Skip to content

Commit 29abf63

Browse files
authored
Add more Scikit algorithms and tests (#334)
1 parent aebd36d commit 29abf63

File tree

7 files changed

+417
-33
lines changed

7 files changed

+417
-33
lines changed

pgml-extension/pgml_rust/src/engines/sklearn.rs

Lines changed: 73 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,77 @@ pub fn sklearn_version() -> String {
3131
version
3232
}
3333

34+
fn sklearn_algorithm_name(task: Task, algorithm: Algorithm) -> &'static str {
35+
match task {
36+
Task::regression => match algorithm {
37+
Algorithm::linear => "linear_regression",
38+
Algorithm::lasso => "lasso_regression",
39+
Algorithm::svm => "svm_regression",
40+
Algorithm::elastic_net => "elastic_net_regression",
41+
Algorithm::ridge => "ridge_regression",
42+
Algorithm::random_forest => "random_forest_regression",
43+
Algorithm::xgboost => {
44+
panic!("Sklearn doesn't support XGBoost, use 'xgboost' engine instead")
45+
}
46+
Algorithm::orthogonal_matching_pursuit => "orthogonal_matching_persuit_regression",
47+
Algorithm::bayesian_ridge => "bayesian_ridge_regression",
48+
Algorithm::automatic_relevance_determination => {
49+
"automatic_relevance_determination_regression"
50+
}
51+
Algorithm::stochastic_gradient_descent => "stochastic_gradient_descent_regression",
52+
Algorithm::passive_aggressive => "passive_aggressive_regression",
53+
Algorithm::ransac => "ransac_regression",
54+
Algorithm::theil_sen => "theil_sen_regression",
55+
Algorithm::huber => "huber_regression",
56+
Algorithm::quantile => "quantile_regression",
57+
Algorithm::kernel_ridge => "kernel_ridge_regression",
58+
Algorithm::gaussian_process => "gaussian_process_regression",
59+
Algorithm::nu_svm => "nu_svm_regression",
60+
Algorithm::ada_boost => "ada_boost_regression",
61+
Algorithm::bagging => "bagging_regression",
62+
Algorithm::extra_trees => "extra_trees_regression",
63+
Algorithm::gradient_boosting_trees => "gradient_boosting_trees_regression",
64+
Algorithm::hist_gradient_boosting => "hist_gradient_boosting_regression",
65+
Algorithm::least_angle => "least_angle_regression",
66+
Algorithm::lasso_least_angle => "lasso_least_angle_regression",
67+
Algorithm::linear_svm => "linear_svm_regression",
68+
_ => panic!("{:?} does not support regression", algorithm),
69+
},
70+
71+
Task::classification => match algorithm {
72+
Algorithm::linear => "linear_classification",
73+
Algorithm::lasso => panic!("Sklearn Lasso does not support classification"),
74+
Algorithm::svm => "svm_classification",
75+
Algorithm::elastic_net => panic!("Sklearn Elastic Net does not support classification"),
76+
Algorithm::ridge => "ridge_classification",
77+
Algorithm::random_forest => "random_forest_classification",
78+
Algorithm::xgboost => {
79+
panic!("Sklearn doesn't support XGBoost, use 'xgboost' engine instead")
80+
}
81+
Algorithm::stochastic_gradient_descent => "stochastic_gradient_descent_classification",
82+
Algorithm::perceptron => "perceptron_classification",
83+
Algorithm::passive_aggressive => "passive_aggressive_classification",
84+
Algorithm::gaussian_process => "gaussian_process",
85+
Algorithm::nu_svm => "nu_svm_classification",
86+
Algorithm::ada_boost => "ada_boost_classification",
87+
Algorithm::bagging => "bagging_classification",
88+
Algorithm::extra_trees => "extra_trees_classification",
89+
Algorithm::gradient_boosting_trees => "gradient_boosting_trees_classification",
90+
Algorithm::hist_gradient_boosting => "hist_gradient_boosting_classification",
91+
Algorithm::linear_svm => "linear_svm_classification",
92+
Algorithm::least_angle => panic!("least_angle does not support classification"),
93+
Algorithm::orthogonal_matching_pursuit => {
94+
panic!("orthogonal_matching_pursuit does not support classification")
95+
}
96+
Algorithm::bayesian_ridge => panic!("bayesian_ridge does not support classification"),
97+
Algorithm::lasso_least_angle => {
98+
panic!("lasso_least_angle does not support classification")
99+
}
100+
_ => panic!("{:?} does not support classification", algorithm),
101+
},
102+
}
103+
}
104+
34105
pub fn sklearn_train(
35106
task: Task,
36107
algorithm: Algorithm,
@@ -42,18 +113,7 @@ pub fn sklearn_train(
42113
"/src/engines/wrappers.py"
43114
));
44115

45-
let algorithm_name = match task {
46-
Task::regression => match algorithm {
47-
Algorithm::linear => "linear_regression",
48-
_ => todo!(),
49-
},
50-
51-
Task::classification => match algorithm {
52-
Algorithm::linear => "linear_classification",
53-
_ => todo!(),
54-
},
55-
};
56-
116+
let algorithm_name = sklearn_algorithm_name(task, algorithm);
57117
let hyperparams = serde_json::to_string(hyperparams).unwrap();
58118

59119
let estimator = Python::with_gil(|py| -> Py<PyAny> {
@@ -189,17 +249,7 @@ pub fn sklearn_search(
189249
"/src/engines/wrappers.py"
190250
));
191251

192-
let algorithm_name = match task {
193-
Task::regression => match algorithm {
194-
Algorithm::linear => "linear_regression",
195-
_ => todo!(),
196-
},
197-
198-
Task::classification => match algorithm {
199-
Algorithm::linear => "linear_classification",
200-
_ => todo!(),
201-
},
202-
};
252+
let algorithm_name = sklearn_algorithm_name(task, algorithm);
203253

204254
Python::with_gil(|py| -> (SklearnBox, Hyperparams) {
205255
let module = PyModule::from_code(py, module, "", "").unwrap();

pgml-extension/pgml_rust/src/engines/smartcore.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,8 @@ pub fn smartcore_train(
484484
}
485485
}
486486
}
487+
488+
_ => todo!(),
487489
}
488490
}
489491

@@ -595,6 +597,8 @@ pub fn smartcore_load(
595597
Box::new(estimator)
596598
}
597599
},
600+
601+
_ => todo!(),
598602
},
599603

600604
Task::classification => match algorithm {
@@ -674,6 +678,8 @@ pub fn smartcore_load(
674678
Box::new(estimator)
675679
}
676680
},
681+
682+
_ => todo!(),
677683
},
678684
}
679685
}

pgml-extension/pgml_rust/src/engines/wrappers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
"elastic_net_regression": sklearn.linear_model.ElasticNet,
2424
"least_angle_regression": sklearn.linear_model.Lars,
2525
"lasso_least_angle_regression": sklearn.linear_model.LassoLars,
26-
"orthoganl_matching_pursuit_regression": sklearn.linear_model.OrthogonalMatchingPursuit,
26+
"orthogonal_matching_persuit_regression": sklearn.linear_model.OrthogonalMatchingPursuit,
2727
"bayesian_ridge_regression": sklearn.linear_model.BayesianRidge,
2828
"automatic_relevance_determination_regression": sklearn.linear_model.ARDRegression,
2929
"stochastic_gradient_descent_regression": sklearn.linear_model.SGDRegressor,

pgml-extension/pgml_rust/src/orm/algorithm.rs

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,27 @@ pub enum Algorithm {
1414
dbscan,
1515
knn,
1616
random_forest,
17+
least_angle,
18+
lasso_least_angle,
19+
orthogonal_matching_pursuit,
20+
bayesian_ridge,
21+
automatic_relevance_determination,
22+
stochastic_gradient_descent,
23+
perceptron,
24+
passive_aggressive,
25+
ransac,
26+
theil_sen,
27+
huber,
28+
quantile,
29+
kernel_ridge,
30+
gaussian_process,
31+
nu_svm,
32+
ada_boost,
33+
bagging,
34+
extra_trees,
35+
gradient_boosting_trees,
36+
hist_gradient_boosting,
37+
linear_svm,
1738
}
1839

1940
impl std::str::FromStr for Algorithm {
@@ -31,6 +52,27 @@ impl std::str::FromStr for Algorithm {
3152
"dbscan" => Ok(Algorithm::dbscan),
3253
"knn" => Ok(Algorithm::knn),
3354
"random_forest" => Ok(Algorithm::random_forest),
55+
"least_angle" => Ok(Algorithm::least_angle),
56+
"lasso_least_angle" => Ok(Algorithm::lasso_least_angle),
57+
"orthogonal_matching_pursuit" => Ok(Algorithm::orthogonal_matching_pursuit),
58+
"bayesian_ridge" => Ok(Algorithm::bayesian_ridge),
59+
"automatic_relevance_determination" => Ok(Algorithm::automatic_relevance_determination),
60+
"stochastic_gradient_descent" => Ok(Algorithm::stochastic_gradient_descent),
61+
"perceptron" => Ok(Algorithm::perceptron),
62+
"passive_aggressive" => Ok(Algorithm::passive_aggressive),
63+
"ransac" => Ok(Algorithm::ransac),
64+
"theil_sen" => Ok(Algorithm::theil_sen),
65+
"huber" => Ok(Algorithm::huber),
66+
"quantile" => Ok(Algorithm::quantile),
67+
"kernel_ridge" => Ok(Algorithm::kernel_ridge),
68+
"gaussian_process" => Ok(Algorithm::gaussian_process),
69+
"nu_svm" => Ok(Algorithm::nu_svm),
70+
"ada_boost" => Ok(Algorithm::ada_boost),
71+
"bagging" => Ok(Algorithm::bagging),
72+
"extra_trees" => Ok(Algorithm::extra_trees),
73+
"gradient_boosting_trees" => Ok(Algorithm::gradient_boosting_trees),
74+
"hist_gradient_boosting" => Ok(Algorithm::hist_gradient_boosting),
75+
"linear_svm" => Ok(Algorithm::linear_svm),
3476
_ => Err(()),
3577
}
3678
}
@@ -49,6 +91,29 @@ impl std::string::ToString for Algorithm {
4991
Algorithm::dbscan => "dbscan".to_string(),
5092
Algorithm::knn => "knn".to_string(),
5193
Algorithm::random_forest => "random_forest".to_string(),
94+
Algorithm::least_angle => "least_angle".to_string(),
95+
Algorithm::lasso_least_angle => "lasso_least_angle".to_string(),
96+
Algorithm::orthogonal_matching_pursuit => "orthogonal_matching_pursuit".to_string(),
97+
Algorithm::bayesian_ridge => "bayesian_ridge".to_string(),
98+
Algorithm::automatic_relevance_determination => {
99+
"automatic_relevance_determination".to_string()
100+
}
101+
Algorithm::stochastic_gradient_descent => "stochastic_gradient_descent".to_string(),
102+
Algorithm::perceptron => "perceptron".to_string(),
103+
Algorithm::passive_aggressive => "passive_aggressive".to_string(),
104+
Algorithm::ransac => "ransac".to_string(),
105+
Algorithm::theil_sen => "theil_sen".to_string(),
106+
Algorithm::huber => "huber".to_string(),
107+
Algorithm::quantile => "quantile".to_string(),
108+
Algorithm::kernel_ridge => "kernel_ridge".to_string(),
109+
Algorithm::gaussian_process => "gaussian_process".to_string(),
110+
Algorithm::nu_svm => "nu_svm".to_string(),
111+
Algorithm::ada_boost => "ada_boost".to_string(),
112+
Algorithm::bagging => "bagging".to_string(),
113+
Algorithm::extra_trees => "extra_trees".to_string(),
114+
Algorithm::gradient_boosting_trees => "gradient_boosting_trees".to_string(),
115+
Algorithm::hist_gradient_boosting => "hist_gradient_boosting".to_string(),
116+
Algorithm::linear_svm => "linear_svm".to_string(),
52117
}
53118
}
54119
}

pgml-extension/pgml_rust/src/orm/model.rs

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,7 @@ impl Model {
5151
Some(engine) => engine,
5252
None => match algorithm {
5353
Algorithm::xgboost => Engine::xgboost,
54-
Algorithm::linear => Engine::sklearn,
55-
Algorithm::svm => Engine::sklearn,
56-
Algorithm::lasso => Engine::sklearn,
57-
Algorithm::elastic_net => Engine::sklearn,
58-
Algorithm::ridge => Engine::sklearn,
59-
Algorithm::kmeans => Engine::sklearn,
60-
Algorithm::dbscan => Engine::sklearn,
61-
Algorithm::knn => Engine::sklearn,
62-
Algorithm::random_forest => Engine::sklearn,
54+
_ => Engine::sklearn,
6355
},
6456
};
6557

0 commit comments

Comments
 (0)