Skip to content

Commit 343d2e2

Browse files
authored
Add more smartcore (#322)
1 parent bbaf2f4 commit 343d2e2

File tree

3 files changed

+269
-5
lines changed

3 files changed

+269
-5
lines changed

pgml-extension/pgml_rust/src/orm/algorithm.rs

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@ pub enum Algorithm {
99
svm,
1010
lasso,
1111
elastic_net,
12-
// ridge,
13-
// kmeans,
14-
// dbscan,
15-
// knn,
16-
// random_forest,
12+
ridge,
13+
kmeans,
14+
dbscan,
15+
knn,
16+
random_forest,
1717
}
1818

1919
impl std::str::FromStr for Algorithm {
@@ -26,6 +26,11 @@ impl std::str::FromStr for Algorithm {
2626
"svm" => Ok(Algorithm::svm),
2727
"lasso" => Ok(Algorithm::lasso),
2828
"elastic_net" => Ok(Algorithm::elastic_net),
29+
"ridge" => Ok(Algorithm::ridge),
30+
"kmeans" => Ok(Algorithm::kmeans),
31+
"dbscan" => Ok(Algorithm::dbscan),
32+
"knn" => Ok(Algorithm::knn),
33+
"random_forest" => Ok(Algorithm::random_forest),
2934
_ => Err(()),
3035
}
3136
}
@@ -39,6 +44,11 @@ impl std::string::ToString for Algorithm {
3944
Algorithm::svm => "svm".to_string(),
4045
Algorithm::lasso => "lasso".to_string(),
4146
Algorithm::elastic_net => "elastic_net".to_string(),
47+
Algorithm::ridge => "ridge".to_string(),
48+
Algorithm::kmeans => "kmeans".to_string(),
49+
Algorithm::dbscan => "dbscan".to_string(),
50+
Algorithm::knn => "knn".to_string(),
51+
Algorithm::random_forest => "random_forest".to_string(),
4252
}
4353
}
4454
}

pgml-extension/pgml_rust/src/orm/estimator.rs

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,32 @@ pub fn find_deployed_estimator_by_project_name(name: &str) -> Arc<Box<dyn Estima
9292
rmp_serde::from_read(&*data).unwrap();
9393
Box::new(estimator)
9494
}
95+
Algorithm::ridge => {
96+
let estimator: smartcore::linear::ridge_regression::RidgeRegression<
97+
f32,
98+
Array2<f32>,
99+
> = rmp_serde::from_read(&*data).unwrap();
100+
Box::new(estimator)
101+
}
102+
Algorithm::kmeans => todo!(),
103+
104+
Algorithm::dbscan => todo!(),
105+
106+
Algorithm::knn => {
107+
let estimator: smartcore::neighbors::knn_regressor::KNNRegressor<
108+
f32,
109+
smartcore::math::distance::euclidian::Euclidian,
110+
> = rmp_serde::from_read(&*data).unwrap();
111+
Box::new(estimator)
112+
}
113+
114+
Algorithm::random_forest => {
115+
let estimator: smartcore::ensemble::random_forest_regressor::RandomForestRegressor<
116+
f32,
117+
> = rmp_serde::from_read(&*data).unwrap();
118+
Box::new(estimator)
119+
}
120+
95121
Algorithm::xgboost => {
96122
let bst = Booster::load_buffer(&*data).unwrap();
97123
Box::new(BoosterBox::new(bst))
@@ -155,6 +181,26 @@ pub fn find_deployed_estimator_by_project_name(name: &str) -> Arc<Box<dyn Estima
155181
}
156182
Algorithm::lasso => panic!("Lasso does not support classification"),
157183
Algorithm::elastic_net => panic!("Elastic Net does not support classification"),
184+
Algorithm::ridge => panic!("Ridge does not support classification"),
185+
186+
Algorithm::kmeans => todo!(),
187+
188+
Algorithm::dbscan => todo!(),
189+
190+
Algorithm::knn => {
191+
let estimator: smartcore::neighbors::knn_classifier::KNNClassifier<
192+
f32,
193+
smartcore::math::distance::euclidian::Euclidian,
194+
> = rmp_serde::from_read(&*data).unwrap();
195+
Box::new(estimator)
196+
}
197+
198+
Algorithm::random_forest => {
199+
let estimator: smartcore::ensemble::random_forest_classifier::RandomForestClassifier<f32> =
200+
rmp_serde::from_read(&*data).unwrap();
201+
Box::new(estimator)
202+
}
203+
158204
Algorithm::xgboost => {
159205
let bst = Booster::load_buffer(&*data).unwrap();
160206
Box::new(BoosterBox::new(bst))
@@ -320,6 +366,13 @@ smartcore_estimator_impl!(smartcore::svm::svc::SVC<f32, Array2<f32>, smartcore::
320366
smartcore_estimator_impl!(smartcore::svm::svr::SVR<f32, Array2<f32>, smartcore::svm::RBFKernel<f32>>);
321367
smartcore_estimator_impl!(smartcore::linear::lasso::Lasso<f32, Array2<f32>>);
322368
smartcore_estimator_impl!(smartcore::linear::elastic_net::ElasticNet<f32, Array2<f32>>);
369+
smartcore_estimator_impl!(smartcore::linear::ridge_regression::RidgeRegression<f32, Array2<f32>>);
370+
smartcore_estimator_impl!(smartcore::neighbors::knn_regressor::KNNRegressor<f32, smartcore::math::distance::euclidian::Euclidian>);
371+
smartcore_estimator_impl!(smartcore::neighbors::knn_classifier::KNNClassifier<f32, smartcore::math::distance::euclidian::Euclidian>);
372+
smartcore_estimator_impl!(smartcore::ensemble::random_forest_regressor::RandomForestRegressor<f32>);
373+
smartcore_estimator_impl!(
374+
smartcore::ensemble::random_forest_classifier::RandomForestClassifier<f32>
375+
);
323376

324377
pub struct BoosterBox {
325378
contents: Box<xgboost::Booster>,

pgml-extension/pgml_rust/src/orm/model.rs

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -628,6 +628,207 @@ impl Model {
628628

629629
estimator
630630
}
631+
632+
Algorithm::ridge => {
633+
train_test_split!(dataset, x_train, y_train);
634+
hyperparam_f32!(alpha, hyperparams, 1.0);
635+
hyperparam_bool!(normalize, hyperparams, false);
636+
637+
let solver = match hyperparams.get("solver") {
638+
Some(solver) => match solver.as_str().unwrap_or("cholesky") {
639+
"svd" => {
640+
smartcore::linear::ridge_regression::RidgeRegressionSolverName::SVD
641+
}
642+
_ => {
643+
smartcore::linear::ridge_regression::RidgeRegressionSolverName::Cholesky
644+
}
645+
},
646+
None => smartcore::linear::ridge_regression::RidgeRegressionSolverName::SVD,
647+
};
648+
649+
let estimator: Option<Box<dyn Estimator>> = match project.task {
650+
Task::regression => Some(
651+
Box::new(
652+
smartcore::linear::ridge_regression::RidgeRegression::fit(
653+
&x_train,
654+
&y_train,
655+
smartcore::linear::ridge_regression::RidgeRegressionParameters::default()
656+
.with_alpha(alpha)
657+
.with_normalize(normalize)
658+
.with_solver(solver)
659+
).unwrap()
660+
)
661+
),
662+
663+
Task::classification => panic!("Ridge does not support classification"),
664+
};
665+
666+
save_estimator!(estimator, self);
667+
668+
estimator
669+
}
670+
671+
Algorithm::kmeans => {
672+
todo!();
673+
}
674+
675+
Algorithm::dbscan => {
676+
todo!();
677+
}
678+
679+
Algorithm::knn => {
680+
train_test_split!(dataset, x_train, y_train);
681+
let algorithm = match hyperparams
682+
.get("algorithm")
683+
.unwrap_or(&serde_json::Value::from("linear_search"))
684+
.as_str()
685+
.unwrap_or("linear_search")
686+
{
687+
"cover_tree" => smartcore::algorithm::neighbour::KNNAlgorithmName::CoverTree,
688+
_ => smartcore::algorithm::neighbour::KNNAlgorithmName::LinearSearch,
689+
};
690+
let weight = match hyperparams
691+
.get("weight")
692+
.unwrap_or(&serde_json::Value::from("uniform"))
693+
.as_str()
694+
.unwrap_or("uniform")
695+
{
696+
"distance" => smartcore::neighbors::KNNWeightFunction::Distance,
697+
_ => smartcore::neighbors::KNNWeightFunction::Uniform,
698+
};
699+
hyperparam_usize!(k, hyperparams, 3);
700+
701+
let estimator: Option<Box<dyn Estimator>> = match project.task {
702+
Task::regression => Some(Box::new(
703+
smartcore::neighbors::knn_regressor::KNNRegressor::fit(
704+
&x_train,
705+
&y_train,
706+
smartcore::neighbors::knn_regressor::KNNRegressorParameters::default()
707+
.with_algorithm(algorithm)
708+
.with_weight(weight)
709+
.with_k(k),
710+
)
711+
.unwrap(),
712+
)),
713+
714+
Task::classification => Some(Box::new(
715+
smartcore::neighbors::knn_classifier::KNNClassifier::fit(
716+
&x_train,
717+
&y_train,
718+
smartcore::neighbors::knn_classifier::KNNClassifierParameters::default(
719+
)
720+
.with_algorithm(algorithm)
721+
.with_weight(weight)
722+
.with_k(k),
723+
)
724+
.unwrap(),
725+
)),
726+
};
727+
728+
save_estimator!(estimator, self);
729+
730+
estimator
731+
}
732+
733+
Algorithm::random_forest => {
734+
train_test_split!(dataset, x_train, y_train);
735+
736+
let max_depth = match hyperparams.get("max_depth") {
737+
Some(max_depth) => match max_depth.as_u64() {
738+
Some(max_depth) => Some(max_depth as u16),
739+
None => None,
740+
},
741+
None => None,
742+
};
743+
744+
let m = match hyperparams.get("m") {
745+
Some(m) => match m.as_u64() {
746+
Some(m) => Some(m as usize),
747+
None => None,
748+
},
749+
None => None,
750+
};
751+
752+
let split_criterion = match hyperparams
753+
.get("split_criterion")
754+
.unwrap_or(&serde_json::Value::from("gini"))
755+
.as_str()
756+
.unwrap_or("gini") {
757+
"entropy" => smartcore::tree::decision_tree_classifier::SplitCriterion::Entropy,
758+
"classification_error" => smartcore::tree::decision_tree_classifier::SplitCriterion::ClassificationError,
759+
_ => smartcore::tree::decision_tree_classifier::SplitCriterion::Gini,
760+
};
761+
762+
hyperparam_usize!(min_samples_leaf, hyperparams, 1);
763+
hyperparam_usize!(min_samples_split, hyperparams, 2);
764+
hyperparam_usize!(n_trees, hyperparams, 10);
765+
hyperparam_usize!(seed, hyperparams, 0);
766+
hyperparam_bool!(keep_samples, hyperparams, false);
767+
768+
let estimator: Option<Box<dyn Estimator>> = match project.task {
769+
Task::regression => {
770+
let mut params = smartcore::ensemble::random_forest_regressor::RandomForestRegressorParameters::default()
771+
.with_min_samples_leaf(min_samples_leaf)
772+
.with_min_samples_split(min_samples_split)
773+
.with_seed(seed as u64)
774+
.with_n_trees(n_trees as usize)
775+
.with_keep_samples(keep_samples);
776+
match max_depth {
777+
Some(max_depth) => params = params.with_max_depth(max_depth),
778+
None => (),
779+
};
780+
781+
match m {
782+
Some(m) => params = params.with_m(m),
783+
None => (),
784+
};
785+
786+
Some(
787+
Box::new(
788+
smartcore::ensemble::random_forest_regressor::RandomForestRegressor::fit(
789+
&x_train,
790+
&y_train,
791+
params,
792+
).unwrap()
793+
)
794+
)
795+
}
796+
797+
Task::classification => {
798+
let mut params = smartcore::ensemble::random_forest_classifier::RandomForestClassifierParameters::default()
799+
.with_min_samples_leaf(min_samples_leaf)
800+
.with_min_samples_split(min_samples_split)
801+
.with_seed(seed as u64)
802+
.with_n_trees(n_trees as u16)
803+
.with_keep_samples(keep_samples)
804+
.with_criterion(split_criterion);
805+
806+
match max_depth {
807+
Some(max_depth) => params = params.with_max_depth(max_depth),
808+
None => (),
809+
};
810+
811+
match m {
812+
Some(m) => params = params.with_m(m),
813+
None => (),
814+
};
815+
816+
Some(
817+
Box::new(
818+
smartcore::ensemble::random_forest_classifier::RandomForestClassifier::fit(
819+
&x_train,
820+
&y_train,
821+
params,
822+
).unwrap()
823+
)
824+
)
825+
}
826+
};
827+
828+
save_estimator!(estimator, self);
829+
830+
estimator
831+
}
631832
};
632833
}
633834

0 commit comments

Comments
 (0)