Skip to content

Commit f934892

Browse files
authored
Lasso (#320)
1 parent c5f0ea1 commit f934892

File tree

3 files changed

+81
-1
lines changed

3 files changed

+81
-1
lines changed

pgml-extension/pgml_rust/src/orm/algorithm.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ pub enum Algorithm {
77
linear,
88
xgboost,
99
svm,
10+
lasso,
1011
}
1112

1213
impl std::str::FromStr for Algorithm {
@@ -17,6 +18,7 @@ impl std::str::FromStr for Algorithm {
1718
"linear" => Ok(Algorithm::linear),
1819
"xgboost" => Ok(Algorithm::xgboost),
1920
"svm" => Ok(Algorithm::svm),
21+
"lasso" => Ok(Algorithm::lasso),
2022
_ => Err(()),
2123
}
2224
}
@@ -28,6 +30,7 @@ impl std::string::ToString for Algorithm {
2830
Algorithm::linear => "linear".to_string(),
2931
Algorithm::xgboost => "xgboost".to_string(),
3032
Algorithm::svm => "svm".to_string(),
33+
Algorithm::lasso => "lasso".to_string(),
3134
}
3235
}
3336
}

pgml-extension/pgml_rust/src/orm/estimator.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,11 @@ pub fn find_deployed_estimator_by_project_name(name: &str) -> Arc<Box<dyn Estima
8282
> = rmp_serde::from_read(&*data).unwrap();
8383
Box::new(estimator)
8484
}
85+
Algorithm::lasso => {
86+
let estimator: smartcore::linear::lasso::Lasso<f32, Array2<f32>> =
87+
rmp_serde::from_read(&*data).unwrap();
88+
Box::new(estimator)
89+
}
8590
Algorithm::xgboost => {
8691
let bst = Booster::load_buffer(&*data).unwrap();
8792
Box::new(BoosterBox::new(bst))
@@ -143,6 +148,7 @@ pub fn find_deployed_estimator_by_project_name(name: &str) -> Arc<Box<dyn Estima
143148
> = rmp_serde::from_read(&*data).unwrap();
144149
Box::new(estimator)
145150
}
151+
Algorithm::lasso => panic!("Lasso does not support classification"),
146152
Algorithm::xgboost => {
147153
let bst = Booster::load_buffer(&*data).unwrap();
148154
Box::new(BoosterBox::new(bst))
@@ -395,6 +401,17 @@ impl Estimator for smartcore::svm::svr::SVR<f32, Array2<f32>, smartcore::svm::RB
395401
}
396402
}
397403

404+
#[typetag::serialize]
405+
impl Estimator for smartcore::linear::lasso::Lasso<f32, Array2<f32>> {
406+
fn test(&self, task: Task, data: &Dataset) -> HashMap<String, f32> {
407+
test_smartcore(self, task, data)
408+
}
409+
410+
fn predict(&self, features: Vec<f32>) -> f32 {
411+
predict_smartcore(self, features)
412+
}
413+
}
414+
398415
pub struct BoosterBox {
399416
contents: Box<xgboost::Booster>,
400417
}

pgml-extension/pgml_rust/src/orm/model.rs

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -555,9 +555,69 @@ impl Model {
555555
(PgBuiltInOids::INT8OID.oid(), self.id.into_datum()),
556556
(PgBuiltInOids::BYTEAOID.oid(), bytes.into_datum()),
557557
]
558-
).unwrap();
558+
).unwrap();
559559
Some(Box::new(BoosterBox::new(bst)))
560560
}
561+
562+
Algorithm::lasso => {
563+
let x_train = Array2::from_shape_vec(
564+
(dataset.num_train_rows, dataset.num_features),
565+
dataset.x_train().to_vec(),
566+
)
567+
.unwrap();
568+
569+
let y_train =
570+
Array1::from_shape_vec(dataset.num_train_rows, dataset.y_train().to_vec())
571+
.unwrap();
572+
573+
let alpha = match hyperparams.get("alpha") {
574+
Some(alpha) => alpha.as_f64().unwrap_or(1.0) as f32,
575+
_ => 1.0,
576+
};
577+
578+
let normalize = match hyperparams.get("normalize") {
579+
Some(normalize) => normalize.as_bool().unwrap_or(false),
580+
_ => false,
581+
};
582+
583+
let tol = match hyperparams.get("tol") {
584+
Some(tol) => tol.as_f64().unwrap_or(1e-4) as f32,
585+
_ => 1e-4,
586+
};
587+
588+
let max_iter = match hyperparams.get("max_iter") {
589+
Some(max_iter) => max_iter.as_u64().unwrap_or(1000) as usize,
590+
_ => 1000,
591+
};
592+
593+
let estimator: Option<Box<dyn Estimator>> = match project.task {
594+
Task::regression => Some(Box::new(
595+
smartcore::linear::lasso::Lasso::fit(
596+
&x_train,
597+
&y_train,
598+
smartcore::linear::lasso::LassoParameters::default()
599+
.with_alpha(alpha)
600+
.with_normalize(normalize)
601+
.with_tol(tol)
602+
.with_max_iter(max_iter),
603+
)
604+
.unwrap(),
605+
)),
606+
607+
Task::classification => panic!("Lasso only supports regression"),
608+
};
609+
610+
let bytes: Vec<u8> = rmp_serde::to_vec(estimator.as_ref().unwrap()).unwrap();
611+
Spi::get_one_with_args::<i64>(
612+
"INSERT INTO pgml_rust.files (model_id, path, part, data) VALUES($1, 'estimator.rmp', 0, $2) RETURNING id",
613+
vec![
614+
(PgBuiltInOids::INT8OID.oid(), self.id.into_datum()),
615+
(PgBuiltInOids::BYTEAOID.oid(), bytes.into_datum()),
616+
]
617+
).unwrap();
618+
619+
estimator
620+
}
561621
};
562622
}
563623

0 commit comments

Comments
 (0)