Skip to content

Commit 9972584

Browse files
authored
XGBoost hyperparams (#318)
1 parent 5aaac95 commit 9972584

File tree

2 files changed

+124
-3
lines changed

2 files changed

+124
-3
lines changed

pgml-extension/pgml_rust/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ pub mod vectors;
1515

1616
pg_module_magic!();
1717

18-
extension_sql_file!("../sql/schema.sql", name = "bootstrap_raw");
18+
extension_sql_file!("../sql/schema.sql", name = "schema");
1919

2020
// The mutex is there just to guarantee to Rust that
2121
// there is no concurrent access.

pgml-extension/pgml_rust/src/orm/model.rs

Lines changed: 123 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,13 +160,134 @@ impl Model {
160160
Some(value) => value.as_u64().unwrap_or(2) as u32,
161161
None => 2,
162162
})
163-
.eta(0.3)
163+
.eta(match hyperparams.get("eta") {
164+
Some(value) => value.as_f64().unwrap_or(0.3) as f32,
165+
None => match hyperparams.get("learning_rate") {
166+
Some(value) => value.as_f64().unwrap_or(0.3) as f32,
167+
None => 0.3,
168+
},
169+
})
170+
.gamma(match hyperparams.get("gamma") {
171+
Some(value) => value.as_f64().unwrap_or(0.0) as f32,
172+
None => match hyperparams.get("min_split_loss") {
173+
Some(value) => value.as_f64().unwrap_or(0.0) as f32,
174+
None => 0.0,
175+
},
176+
})
177+
.min_child_weight(match hyperparams.get("min_child_weight") {
178+
Some(value) => value.as_f64().unwrap_or(1.0) as f32,
179+
None => 1.0,
180+
})
181+
.max_delta_step(match hyperparams.get("max_delta_step") {
182+
Some(value) => value.as_f64().unwrap_or(0.0) as f32,
183+
None => 0.0,
184+
})
185+
.subsample(match hyperparams.get("subsample") {
186+
Some(value) => value.as_f64().unwrap_or(1.0) as f32,
187+
None => 1.0,
188+
})
189+
.lambda(match hyperparams.get("lambda") {
190+
Some(value) => value.as_f64().unwrap_or(1.0) as f32,
191+
None => 1.0,
192+
})
193+
.alpha(match hyperparams.get("alpha") {
194+
Some(value) => value.as_f64().unwrap_or(0.0) as f32,
195+
None => 0.0,
196+
})
197+
.tree_method(match hyperparams.get("tree_method") {
198+
Some(value) => match value.as_str().unwrap_or("auto") {
199+
"auto" => parameters::tree::TreeMethod::Auto,
200+
"exact" => parameters::tree::TreeMethod::Exact,
201+
"approx" => parameters::tree::TreeMethod::Approx,
202+
"hist" => parameters::tree::TreeMethod::Hist,
203+
_ => parameters::tree::TreeMethod::Auto,
204+
},
205+
206+
None => parameters::tree::TreeMethod::Auto,
207+
})
208+
.sketch_eps(match hyperparams.get("sketch_eps") {
209+
Some(value) => value.as_f64().unwrap_or(0.03) as f32,
210+
None => 0.03,
211+
})
212+
.max_leaves(match hyperparams.get("max_leaves") {
213+
Some(value) => value.as_u64().unwrap_or(0) as u32,
214+
None => 0,
215+
})
216+
.max_bin(match hyperparams.get("max_bin") {
217+
Some(value) => value.as_u64().unwrap_or(256) as u32,
218+
None => 256,
219+
})
220+
.num_parallel_tree(match hyperparams.get("num_parallel_tree") {
221+
Some(value) => value.as_u64().unwrap_or(1) as u32,
222+
None => 1,
223+
})
224+
.grow_policy(match hyperparams.get("grow_policy") {
225+
Some(value) => match value.as_str().unwrap_or("depthwise") {
226+
"depthwise" => parameters::tree::GrowPolicy::Depthwise,
227+
"lossguide" => parameters::tree::GrowPolicy::LossGuide,
228+
_ => parameters::tree::GrowPolicy::Depthwise,
229+
},
230+
231+
None => parameters::tree::GrowPolicy::Depthwise,
232+
})
233+
.build()
234+
.unwrap();
235+
236+
let linear_params = parameters::linear::LinearBoosterParametersBuilder::default()
237+
.alpha(match hyperparams.get("alpha") {
238+
Some(value) => value.as_f64().unwrap_or(0.0) as f32,
239+
None => 0.0,
240+
})
241+
.lambda(match hyperparams.get("lambda") {
242+
Some(value) => value.as_f64().unwrap_or(0.0) as f32,
243+
None => 0.0,
244+
})
245+
.build()
246+
.unwrap();
247+
248+
let dart_params = parameters::dart::DartBoosterParametersBuilder::default()
249+
.rate_drop(match hyperparams.get("rate_drop") {
250+
Some(value) => value.as_f64().unwrap_or(0.0) as f32,
251+
None => 0.0,
252+
})
253+
.one_drop(match hyperparams.get("one_drop") {
254+
Some(value) => value.as_u64().unwrap_or(0) != 0,
255+
None => false,
256+
})
257+
.skip_drop(match hyperparams.get("skip_drop") {
258+
Some(value) => value.as_f64().unwrap_or(0.0) as f32,
259+
None => 0.0,
260+
})
261+
.sample_type(match hyperparams.get("sample_type") {
262+
Some(value) => match value.as_str().unwrap_or("uniform") {
263+
"uniform" => parameters::dart::SampleType::Uniform,
264+
"weighted" => parameters::dart::SampleType::Weighted,
265+
_ => parameters::dart::SampleType::Uniform,
266+
},
267+
None => parameters::dart::SampleType::Uniform,
268+
})
269+
.normalize_type(match hyperparams.get("normalize_type") {
270+
Some(value) => match value.as_str().unwrap_or("tree") {
271+
"tree" => parameters::dart::NormalizeType::Tree,
272+
"forest" => parameters::dart::NormalizeType::Forest,
273+
_ => parameters::dart::NormalizeType::Tree,
274+
},
275+
None => parameters::dart::NormalizeType::Tree,
276+
})
164277
.build()
165278
.unwrap();
166279

167280
// overall configuration for Booster
168281
let booster_params = parameters::BoosterParametersBuilder::default()
169-
.booster_type(parameters::BoosterType::Tree(tree_params))
282+
.booster_type(match hyperparams.get("booster") {
283+
Some(value) => match value.as_str().unwrap_or("gbtree") {
284+
"gbtree" => parameters::BoosterType::Tree(tree_params),
285+
"linear" => parameters::BoosterType::Linear(linear_params),
286+
"dart" => parameters::BoosterType::Dart(dart_params),
287+
_ => parameters::BoosterType::Tree(tree_params),
288+
},
289+
None => parameters::BoosterType::Tree(tree_params),
290+
})
170291
.learning_params(learning_params)
171292
.verbose(true)
172293
.build()

0 commit comments

Comments
 (0)