Skip to content

Commit c5f0ea1

Browse files
authored
Support Vector Machines (#319)
1 parent 877a40b commit c5f0ea1

File tree

4 files changed

+455
-19
lines changed

4 files changed

+455
-19
lines changed

pgml-docs/docs/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ article.md-content__inner.md-typeset a.md-content__button.md-icon {
3939
}
4040
</style>
4141

42-
<h1 align="center">End-to-end<br/>machine learning solution <br/>for everyone</h1>
42+
<h1 align="center">End-to-end<br/>machine learning platform <br/>for everyone</h1>
4343

4444
<p align="center" class="subtitle">
4545
Train and deploy models to make online predictions using only SQL, with an open source extension for Postgres. Manage your projects and visualize datasets using the built-in dashboard.

pgml-extension/pgml_rust/src/orm/algorithm.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use serde::Deserialize;
66
pub enum Algorithm {
77
linear,
88
xgboost,
9+
svm,
910
}
1011

1112
impl std::str::FromStr for Algorithm {
@@ -15,6 +16,7 @@ impl std::str::FromStr for Algorithm {
1516
match input {
1617
"linear" => Ok(Algorithm::linear),
1718
"xgboost" => Ok(Algorithm::xgboost),
19+
"svm" => Ok(Algorithm::svm),
1820
_ => Err(()),
1921
}
2022
}
@@ -25,6 +27,7 @@ impl std::string::ToString for Algorithm {
2527
match *self {
2628
Algorithm::linear => "linear".to_string(),
2729
Algorithm::xgboost => "xgboost".to_string(),
30+
Algorithm::svm => "svm".to_string(),
2831
}
2932
}
3033
}

pgml-extension/pgml_rust/src/orm/estimator.rs

Lines changed: 203 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ pub fn find_deployed_estimator_by_project_name(name: &str) -> Arc<Box<dyn Estima
2626
}
2727
}
2828

29-
let (task, algorithm, data) = Spi::get_three_with_args::<String, String, Vec<u8>>(
29+
let (task, algorithm, model_id) = Spi::get_three_with_args::<String, String, i64>(
3030
"
31-
SELECT projects.task::TEXT, models.algorithm::TEXT, files.data
31+
SELECT projects.task::TEXT, models.algorithm::TEXT, models.id AS model_id
3232
FROM pgml_rust.files
3333
JOIN pgml_rust.models
3434
ON models.id = files.model_id
@@ -55,6 +55,17 @@ pub fn find_deployed_estimator_by_project_name(name: &str) -> Arc<Box<dyn Estima
5555
)
5656
}))
5757
.unwrap();
58+
59+
let (data, hyperparams) = Spi::get_two_with_args::<Vec<u8>, JsonB>(
60+
"SELECT data, hyperparams FROM pgml_rust.models
61+
INNER JOIN pgml_rust.files
62+
ON models.id = files.model_id WHERE models.id = $1
63+
LIMIT 1",
64+
vec![(PgBuiltInOids::INT8OID.oid(), model_id.into_datum())],
65+
);
66+
67+
let hyperparams = hyperparams.unwrap();
68+
5869
let data = data.unwrap_or_else(|| {
5970
panic!(
6071
"Project {} does not have a trained and deployed model.",
@@ -75,6 +86,54 @@ pub fn find_deployed_estimator_by_project_name(name: &str) -> Arc<Box<dyn Estima
7586
let bst = Booster::load_buffer(&*data).unwrap();
7687
Box::new(BoosterBox::new(bst))
7788
}
89+
Algorithm::svm => match &hyperparams.0.as_object().unwrap().get("kernel") {
90+
Some(kernel) => match kernel.as_str().unwrap_or("linear") {
91+
"poly" => {
92+
let estimator: smartcore::svm::svr::SVR<
93+
f32,
94+
Array2<f32>,
95+
smartcore::svm::PolynomialKernel<f32>,
96+
> = rmp_serde::from_read(&*data).unwrap();
97+
Box::new(estimator)
98+
}
99+
100+
"sigmoid" => {
101+
let estimator: smartcore::svm::svr::SVR<
102+
f32,
103+
Array2<f32>,
104+
smartcore::svm::SigmoidKernel<f32>,
105+
> = rmp_serde::from_read(&*data).unwrap();
106+
Box::new(estimator)
107+
}
108+
109+
"rbf" => {
110+
let estimator: smartcore::svm::svr::SVR<
111+
f32,
112+
Array2<f32>,
113+
smartcore::svm::RBFKernel<f32>,
114+
> = rmp_serde::from_read(&*data).unwrap();
115+
Box::new(estimator)
116+
}
117+
118+
_ => {
119+
let estimator: smartcore::svm::svr::SVR<
120+
f32,
121+
Array2<f32>,
122+
smartcore::svm::LinearKernel,
123+
> = rmp_serde::from_read(&*data).unwrap();
124+
Box::new(estimator)
125+
}
126+
},
127+
128+
None => {
129+
let estimator: smartcore::svm::svr::SVR<
130+
f32,
131+
Array2<f32>,
132+
smartcore::svm::LinearKernel,
133+
> = rmp_serde::from_read(&*data).unwrap();
134+
Box::new(estimator)
135+
}
136+
},
78137
},
79138
Task::classification => match algorithm {
80139
Algorithm::linear => {
@@ -88,6 +147,54 @@ pub fn find_deployed_estimator_by_project_name(name: &str) -> Arc<Box<dyn Estima
88147
let bst = Booster::load_buffer(&*data).unwrap();
89148
Box::new(BoosterBox::new(bst))
90149
}
150+
Algorithm::svm => match &hyperparams.0.as_object().unwrap().get("kernel") {
151+
Some(kernel) => match kernel.as_str().unwrap_or("linear") {
152+
"poly" => {
153+
let estimator: smartcore::svm::svc::SVC<
154+
f32,
155+
Array2<f32>,
156+
smartcore::svm::PolynomialKernel<f32>,
157+
> = rmp_serde::from_read(&*data).unwrap();
158+
Box::new(estimator)
159+
}
160+
161+
"sigmoid" => {
162+
let estimator: smartcore::svm::svc::SVC<
163+
f32,
164+
Array2<f32>,
165+
smartcore::svm::SigmoidKernel<f32>,
166+
> = rmp_serde::from_read(&*data).unwrap();
167+
Box::new(estimator)
168+
}
169+
170+
"rbf" => {
171+
let estimator: smartcore::svm::svc::SVC<
172+
f32,
173+
Array2<f32>,
174+
smartcore::svm::RBFKernel<f32>,
175+
> = rmp_serde::from_read(&*data).unwrap();
176+
Box::new(estimator)
177+
}
178+
179+
_ => {
180+
let estimator: smartcore::svm::svc::SVC<
181+
f32,
182+
Array2<f32>,
183+
smartcore::svm::LinearKernel,
184+
> = rmp_serde::from_read(&*data).unwrap();
185+
Box::new(estimator)
186+
}
187+
},
188+
189+
None => {
190+
let estimator: smartcore::svm::svc::SVC<
191+
f32,
192+
Array2<f32>,
193+
smartcore::svm::LinearKernel,
194+
> = rmp_serde::from_read(&*data).unwrap();
195+
Box::new(estimator)
196+
}
197+
},
91198
},
92199
};
93200

@@ -194,6 +301,100 @@ impl Estimator for smartcore::linear::logistic_regression::LogisticRegression<f3
194301
}
195302
}
196303

304+
// All the SVM kernels :popcorn:
305+
306+
#[typetag::serialize]
307+
impl Estimator for smartcore::svm::svc::SVC<f32, Array2<f32>, smartcore::svm::LinearKernel> {
308+
fn test(&self, task: Task, data: &Dataset) -> HashMap<String, f32> {
309+
test_smartcore(self, task, data)
310+
}
311+
312+
fn predict(&self, features: Vec<f32>) -> f32 {
313+
predict_smartcore(self, features)
314+
}
315+
}
316+
317+
#[typetag::serialize]
318+
impl Estimator for smartcore::svm::svr::SVR<f32, Array2<f32>, smartcore::svm::LinearKernel> {
319+
fn test(&self, task: Task, data: &Dataset) -> HashMap<String, f32> {
320+
test_smartcore(self, task, data)
321+
}
322+
323+
fn predict(&self, features: Vec<f32>) -> f32 {
324+
predict_smartcore(self, features)
325+
}
326+
}
327+
328+
#[typetag::serialize]
329+
impl Estimator for smartcore::svm::svc::SVC<f32, Array2<f32>, smartcore::svm::SigmoidKernel<f32>> {
330+
fn test(&self, task: Task, data: &Dataset) -> HashMap<String, f32> {
331+
test_smartcore(self, task, data)
332+
}
333+
334+
fn predict(&self, features: Vec<f32>) -> f32 {
335+
predict_smartcore(self, features)
336+
}
337+
}
338+
339+
#[typetag::serialize]
340+
impl Estimator for smartcore::svm::svr::SVR<f32, Array2<f32>, smartcore::svm::SigmoidKernel<f32>> {
341+
fn test(&self, task: Task, data: &Dataset) -> HashMap<String, f32> {
342+
test_smartcore(self, task, data)
343+
}
344+
345+
fn predict(&self, features: Vec<f32>) -> f32 {
346+
predict_smartcore(self, features)
347+
}
348+
}
349+
350+
#[typetag::serialize]
351+
impl Estimator
352+
for smartcore::svm::svc::SVC<f32, Array2<f32>, smartcore::svm::PolynomialKernel<f32>>
353+
{
354+
fn test(&self, task: Task, data: &Dataset) -> HashMap<String, f32> {
355+
test_smartcore(self, task, data)
356+
}
357+
358+
fn predict(&self, features: Vec<f32>) -> f32 {
359+
predict_smartcore(self, features)
360+
}
361+
}
362+
363+
#[typetag::serialize]
364+
impl Estimator
365+
for smartcore::svm::svr::SVR<f32, Array2<f32>, smartcore::svm::PolynomialKernel<f32>>
366+
{
367+
fn test(&self, task: Task, data: &Dataset) -> HashMap<String, f32> {
368+
test_smartcore(self, task, data)
369+
}
370+
371+
fn predict(&self, features: Vec<f32>) -> f32 {
372+
predict_smartcore(self, features)
373+
}
374+
}
375+
376+
#[typetag::serialize]
377+
impl Estimator for smartcore::svm::svc::SVC<f32, Array2<f32>, smartcore::svm::RBFKernel<f32>> {
378+
fn test(&self, task: Task, data: &Dataset) -> HashMap<String, f32> {
379+
test_smartcore(self, task, data)
380+
}
381+
382+
fn predict(&self, features: Vec<f32>) -> f32 {
383+
predict_smartcore(self, features)
384+
}
385+
}
386+
387+
#[typetag::serialize]
388+
impl Estimator for smartcore::svm::svr::SVR<f32, Array2<f32>, smartcore::svm::RBFKernel<f32>> {
389+
fn test(&self, task: Task, data: &Dataset) -> HashMap<String, f32> {
390+
test_smartcore(self, task, data)
391+
}
392+
393+
fn predict(&self, features: Vec<f32>) -> f32 {
394+
predict_smartcore(self, features)
395+
}
396+
}
397+
197398
pub struct BoosterBox {
198399
contents: Box<xgboost::Booster>,
199400
}

0 commit comments

Comments
 (0)