Skip to content

Commit 46f4936

Browse files
authored
Fix default XGBoost n_estimators (#338)
1 parent 17d1199 commit 46f4936

File tree

3 files changed

+4
-4
lines changed

3 files changed

+4
-4
lines changed

pgml-extension/pgml_rust/src/engines/xgboost.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,8 +186,8 @@ pub fn xgboost_train(
186186
let params = parameters::TrainingParametersBuilder::default()
187187
.dtrain(&dtrain) // dataset to train with
188188
.boost_rounds(match hyperparams.get("n_estimators") {
189-
Some(value) => value.as_u64().unwrap_or(2) as u32,
190-
None => 2,
189+
Some(value) => value.as_u64().unwrap_or(10) as u32,
190+
None => 10,
191191
}) // number of training iterations
192192
.booster_params(booster_params) // model parameters
193193
.evaluation_sets(Some(evaluation_sets)) // optional datasets to evaluate against in each iteration

pgml-extension/pgml_rust/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@ extern crate openblas_src;
33
extern crate serde;
44

55
use once_cell::sync::Lazy; // 1.3.1
6+
use parking_lot::Mutex;
67
use pgx::*;
78
use std::collections::HashMap;
89
use std::fs;
9-
use parking_lot::Mutex;
1010
use xgboost::{Booster, DMatrix};
1111

1212
pub mod api;

pgml-extension/pgml_rust/src/orm/estimator.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1+
use parking_lot::Mutex;
12
use std::collections::HashMap;
23
use std::fmt::Debug;
34
use std::str::FromStr;
45
use std::sync::Arc;
5-
use parking_lot::Mutex;
66

77
use ndarray::{Array1, Array2};
88
use once_cell::sync::Lazy;

0 commit comments

Comments
 (0)