@@ -31,6 +31,77 @@ pub fn sklearn_version() -> String {
31
31
version
32
32
}
33
33
34
+ fn sklearn_algorithm_name ( task : Task , algorithm : Algorithm ) -> & ' static str {
35
+ match task {
36
+ Task :: regression => match algorithm {
37
+ Algorithm :: linear => "linear_regression" ,
38
+ Algorithm :: lasso => "lasso_regression" ,
39
+ Algorithm :: svm => "svm_regression" ,
40
+ Algorithm :: elastic_net => "elastic_net_regression" ,
41
+ Algorithm :: ridge => "ridge_regression" ,
42
+ Algorithm :: random_forest => "random_forest_regression" ,
43
+ Algorithm :: xgboost => {
44
+ panic ! ( "Sklearn doesn't support XGBoost, use 'xgboost' engine instead" )
45
+ }
46
+ Algorithm :: orthogonal_matching_pursuit => "orthogonal_matching_persuit_regression" ,
47
+ Algorithm :: bayesian_ridge => "bayesian_ridge_regression" ,
48
+ Algorithm :: automatic_relevance_determination => {
49
+ "automatic_relevance_determination_regression"
50
+ }
51
+ Algorithm :: stochastic_gradient_descent => "stochastic_gradient_descent_regression" ,
52
+ Algorithm :: passive_aggressive => "passive_aggressive_regression" ,
53
+ Algorithm :: ransac => "ransac_regression" ,
54
+ Algorithm :: theil_sen => "theil_sen_regression" ,
55
+ Algorithm :: huber => "huber_regression" ,
56
+ Algorithm :: quantile => "quantile_regression" ,
57
+ Algorithm :: kernel_ridge => "kernel_ridge_regression" ,
58
+ Algorithm :: gaussian_process => "gaussian_process_regression" ,
59
+ Algorithm :: nu_svm => "nu_svm_regression" ,
60
+ Algorithm :: ada_boost => "ada_boost_regression" ,
61
+ Algorithm :: bagging => "bagging_regression" ,
62
+ Algorithm :: extra_trees => "extra_trees_regression" ,
63
+ Algorithm :: gradient_boosting_trees => "gradient_boosting_trees_regression" ,
64
+ Algorithm :: hist_gradient_boosting => "hist_gradient_boosting_regression" ,
65
+ Algorithm :: least_angle => "least_angle_regression" ,
66
+ Algorithm :: lasso_least_angle => "lasso_least_angle_regression" ,
67
+ Algorithm :: linear_svm => "linear_svm_regression" ,
68
+ _ => panic ! ( "{:?} does not support regression" , algorithm) ,
69
+ } ,
70
+
71
+ Task :: classification => match algorithm {
72
+ Algorithm :: linear => "linear_classification" ,
73
+ Algorithm :: lasso => panic ! ( "Sklearn Lasso does not support classification" ) ,
74
+ Algorithm :: svm => "svm_classification" ,
75
+ Algorithm :: elastic_net => panic ! ( "Sklearn Elastic Net does not support classification" ) ,
76
+ Algorithm :: ridge => "ridge_classification" ,
77
+ Algorithm :: random_forest => "random_forest_classification" ,
78
+ Algorithm :: xgboost => {
79
+ panic ! ( "Sklearn doesn't support XGBoost, use 'xgboost' engine instead" )
80
+ }
81
+ Algorithm :: stochastic_gradient_descent => "stochastic_gradient_descent_classification" ,
82
+ Algorithm :: perceptron => "perceptron_classification" ,
83
+ Algorithm :: passive_aggressive => "passive_aggressive_classification" ,
84
+ Algorithm :: gaussian_process => "gaussian_process" ,
85
+ Algorithm :: nu_svm => "nu_svm_classification" ,
86
+ Algorithm :: ada_boost => "ada_boost_classification" ,
87
+ Algorithm :: bagging => "bagging_classification" ,
88
+ Algorithm :: extra_trees => "extra_trees_classification" ,
89
+ Algorithm :: gradient_boosting_trees => "gradient_boosting_trees_classification" ,
90
+ Algorithm :: hist_gradient_boosting => "hist_gradient_boosting_classification" ,
91
+ Algorithm :: linear_svm => "linear_svm_classification" ,
92
+ Algorithm :: least_angle => panic ! ( "least_angle does not support classification" ) ,
93
+ Algorithm :: orthogonal_matching_pursuit => {
94
+ panic ! ( "orthogonal_matching_pursuit does not support classification" )
95
+ }
96
+ Algorithm :: bayesian_ridge => panic ! ( "bayesian_ridge does not support classification" ) ,
97
+ Algorithm :: lasso_least_angle => {
98
+ panic ! ( "lasso_least_angle does not support classification" )
99
+ }
100
+ _ => panic ! ( "{:?} does not support classification" , algorithm) ,
101
+ } ,
102
+ }
103
+ }
104
+
34
105
pub fn sklearn_train (
35
106
task : Task ,
36
107
algorithm : Algorithm ,
@@ -42,18 +113,7 @@ pub fn sklearn_train(
42
113
"/src/engines/wrappers.py"
43
114
) ) ;
44
115
45
- let algorithm_name = match task {
46
- Task :: regression => match algorithm {
47
- Algorithm :: linear => "linear_regression" ,
48
- _ => todo ! ( ) ,
49
- } ,
50
-
51
- Task :: classification => match algorithm {
52
- Algorithm :: linear => "linear_classification" ,
53
- _ => todo ! ( ) ,
54
- } ,
55
- } ;
56
-
116
+ let algorithm_name = sklearn_algorithm_name ( task, algorithm) ;
57
117
let hyperparams = serde_json:: to_string ( hyperparams) . unwrap ( ) ;
58
118
59
119
let estimator = Python :: with_gil ( |py| -> Py < PyAny > {
@@ -189,17 +249,7 @@ pub fn sklearn_search(
189
249
"/src/engines/wrappers.py"
190
250
) ) ;
191
251
192
- let algorithm_name = match task {
193
- Task :: regression => match algorithm {
194
- Algorithm :: linear => "linear_regression" ,
195
- _ => todo ! ( ) ,
196
- } ,
197
-
198
- Task :: classification => match algorithm {
199
- Algorithm :: linear => "linear_classification" ,
200
- _ => todo ! ( ) ,
201
- } ,
202
- } ;
252
+ let algorithm_name = sklearn_algorithm_name ( task, algorithm) ;
203
253
204
254
Python :: with_gil ( |py| -> ( SklearnBox , Hyperparams ) {
205
255
let module = PyModule :: from_code ( py, module, "" , "" ) . unwrap ( ) ;
0 commit comments