Skip to content

Commit 8d0c2de

Browse files
authored
implement deploys in rust (#311)
1 parent 7fab677 commit 8d0c2de

File tree

3 files changed

+126
-20
lines changed

3 files changed

+126
-20
lines changed

pgml-extension/pgml_rust/src/api.rs

Lines changed: 121 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::str::FromStr;
2+
13
use pgx::*;
24

35
use crate::orm::Algorithm;
@@ -22,7 +24,8 @@ fn train(
2224
search_args: default!(JsonB, "'{}'"),
2325
test_size: default!(f32, 0.25),
2426
test_sampling: default!(Sampling, "'last'"),
25-
) {
27+
) -> impl std::iter::Iterator<Item = (name!(project, String), name!(task, String), name!(algorithm, String), name!(deployed, bool))>
28+
{
2629
let project = match Project::find_by_name(project_name) {
2730
Some(project) => project,
2831
None => Project::create(project_name, task.unwrap()),
@@ -50,15 +53,122 @@ fn train(
5053
search_args,
5154
);
5255

53-
// TODO move deployment into a struct and only deploy if new model is better than old model
56+
let new_metrics: &serde_json::Value = &model.metrics.unwrap().0;
57+
let new_metrics = new_metrics.as_object().unwrap();
58+
59+
let deployed_metrics = Spi::get_one_with_args::<JsonB>(
60+
"
61+
SELECT models.metrics
62+
FROM pgml_rust.models
63+
JOIN pgml_rust.deployments
64+
ON deployments.model_id = models.id
65+
JOIN pgml_rust.projects
66+
ON projects.id = deployments.project_id
67+
WHERE projects.name = $1
68+
ORDER by deployments.created_at DESC
69+
LIMIT 1;",
70+
vec![(PgBuiltInOids::TEXTOID.oid(), project_name.into_datum())],
71+
);
72+
73+
let mut deploy = false;
74+
if deployed_metrics.is_none() {
75+
deploy = true;
76+
} else {
77+
let deployed_metrics = deployed_metrics.unwrap().0;
78+
let deployed_metrics = deployed_metrics.as_object().unwrap();
79+
if project.task == Task::classification && deployed_metrics.get("f1").unwrap().as_f64() < new_metrics.get("f1").unwrap().as_f64() {
80+
deploy = true;
81+
}
82+
if project.task == Task::regression && deployed_metrics.get("r2").unwrap().as_f64() < new_metrics.get("r2").unwrap().as_f64() {
83+
deploy = true;
84+
}
85+
}
86+
87+
if deploy {
88+
Spi::get_one_with_args::<i64>(
89+
"INSERT INTO pgml_rust.deployments (project_id, model_id, strategy) VALUES ($1, $2, $3::pgml_rust.strategy) RETURNING id",
90+
vec![
91+
(PgBuiltInOids::INT8OID.oid(), project.id.into_datum()),
92+
(PgBuiltInOids::INT8OID.oid(), model.id.into_datum()),
93+
(PgBuiltInOids::TEXTOID.oid(), Strategy::most_recent.to_string().into_datum()),
94+
]
95+
);
96+
}
97+
98+
vec![(project.name, project.task.to_string(), model.algorithm.to_string(), deploy)].into_iter()
99+
}
100+
101+
#[pg_extern]
102+
fn deploy(
103+
project_name: &str,
104+
strategy: Strategy,
105+
algorithm: Option<default!(Algorithm, "NULL")>,
106+
) -> impl std::iter::Iterator<Item = (name!(project, String), name!(strategy, String), name!(algorithm, String))> {
107+
let (project_id, task) = Spi::get_two_with_args::<i64, String>(
108+
"SELECT id, task::TEXT from pgml_rust.projects WHERE name = $1",
109+
vec![(PgBuiltInOids::TEXTOID.oid(), project_name.into_datum())],
110+
);
111+
let project_id = project_id.expect(format!("Project named `{}` does not exist.", project_name).as_str());
112+
let task = Task::from_str(&task.unwrap()).unwrap();
113+
114+
let mut sql = "SELECT models.id, models.algorithm::TEXT FROM pgml_rust.models JOIN pgml_rust.projects ON projects.id = models.project_id".to_string();
115+
let mut predicate = "\nWHERE projects.name = $1".to_string();
116+
match algorithm {
117+
Some(algorithm) => predicate += &format!("\nAND algorithm::TEXT = '{}'", algorithm.to_string().as_str()),
118+
_ => (),
119+
}
120+
match strategy {
121+
Strategy::best_score => {
122+
match task {
123+
Task::regression => {
124+
sql += &format!("{predicate}\nORDER BY models.metrics->>'r2' DESC NULLS LAST");
125+
},
126+
Task::classification => {
127+
sql += &format!("{predicate}\nORDER BY models.metrics->>'f1' DESC NULLS LAST");
128+
}
129+
}
130+
},
131+
Strategy::most_recent => {
132+
sql += &format!("{predicate}\nORDER by models.created_at DESC");
133+
},
134+
Strategy::rollback => {
135+
sql += &format!("
136+
JOIN pgml_rust.deployments ON deployments.project_id = projects.id
137+
AND deployments.model_id = models.id
138+
AND models.id != (
139+
SELECT models.id
140+
FROM pgml_rust.models
141+
JOIN pgml_rust.deployments
142+
ON deployments.model_id = models.id
143+
JOIN pgml_rust.projects
144+
ON projects.id = deployments.project_id
145+
WHERE projects.name = $1
146+
ORDER by deployments.created_at DESC
147+
LIMIT 1
148+
)
149+
{predicate}
150+
ORDER by deployments.created_at DESC
151+
");
152+
},
153+
_ => error!("invalid stategy")
154+
}
155+
sql += "\nLIMIT 1";
156+
let (model_id, algorithm_name) = Spi::get_two_with_args::<i64, String>(&sql,
157+
vec![(PgBuiltInOids::TEXTOID.oid(), project_name.into_datum())],
158+
);
159+
let model_id = model_id.expect("No qualified models exist for this deployment.");
160+
let algorithm_name = algorithm_name.expect("No qualified models exist for this deployment.");
161+
54162
Spi::get_one_with_args::<i64>(
55163
"INSERT INTO pgml_rust.deployments (project_id, model_id, strategy) VALUES ($1, $2, $3::pgml_rust.strategy) RETURNING id",
56164
vec![
57-
(PgBuiltInOids::INT8OID.oid(), project.id.into_datum()),
58-
(PgBuiltInOids::INT8OID.oid(), model.id.into_datum()),
59-
(PgBuiltInOids::TEXTOID.oid(), Strategy::most_recent.to_string().into_datum()),
165+
(PgBuiltInOids::INT8OID.oid(), project_id.into_datum()),
166+
(PgBuiltInOids::INT8OID.oid(), model_id.into_datum()),
167+
(PgBuiltInOids::TEXTOID.oid(), strategy.to_string().into_datum()),
60168
]
61169
);
170+
171+
vec![(project_name.to_string(), strategy.to_string(), algorithm_name)].into_iter()
62172
}
63173

64174
#[pg_extern]
@@ -67,22 +177,15 @@ fn predict(project_name: &str, features: Vec<f32>) -> f32 {
67177
estimator.predict(features)
68178
}
69179

70-
// #[pg_extern]
71-
// fn return_table_example() -> impl std::Iterator<Item = (name!(id, Option<i64>), name!(title, Option<String>))> {
72-
// let tuple = Spi::get_two_with_args("SELECT 1 AS id, 2 AS title;", None, None)
73-
// vec![tuple].into_iter()
74-
// }
75-
76180
#[pg_extern]
77-
fn create_snapshot(
181+
fn snapshot(
78182
relation_name: &str,
79183
y_column_name: &str,
80-
test_size: f32,
81-
test_sampling: Sampling,
82-
) -> i64 {
83-
let snapshot = Snapshot::create(relation_name, y_column_name, test_size, test_sampling);
84-
info!("{:?}", snapshot);
85-
snapshot.id
184+
test_size: default!(f32, 0.25),
185+
test_sampling: default!(Sampling, "'last'"),
186+
) -> impl std::iter::Iterator<Item = (name!(relation, String), name!(y_column_name, String))> {
187+
Snapshot::create(relation_name, y_column_name, test_size, test_sampling);
188+
vec![(relation_name.to_string(), y_column_name.to_string())].into_iter()
86189
}
87190

88191
#[cfg(any(test, feature = "pg_test"))]

pgml-extension/pgml_rust/src/orm/project.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ impl Project {
1919
let mut project: Option<Project> = None;
2020

2121
Spi::connect(|client| {
22-
let result = client.select("SELECT id, name, task, created_at, updated_at FROM pgml_rust.projects WHERE id = $1 LIMIT 1;",
22+
let result = client.select("SELECT id, name, task::TEXT, created_at, updated_at FROM pgml_rust.projects WHERE id = $1 LIMIT 1;",
2323
Some(1),
2424
Some(vec![
2525
(PgBuiltInOids::INT8OID.oid(), id.into_datum()),
@@ -44,7 +44,7 @@ impl Project {
4444
let mut project = None;
4545

4646
Spi::connect(|client| {
47-
let result = client.select("SELECT id, name, task, created_at, updated_at FROM pgml_rust.projects WHERE name = $1 LIMIT 1;",
47+
let result = client.select("SELECT id, name, task::TEXT, created_at, updated_at FROM pgml_rust.projects WHERE name = $1 LIMIT 1;",
4848
Some(1),
4949
Some(vec![
5050
(PgBuiltInOids::TEXTOID.oid(), name.into_datum()),

pgml-extension/pgml_rust/src/orm/strategy.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use serde::Deserialize;
44
#[derive(PostgresEnum, Copy, Clone, PartialEq, Debug, Deserialize)]
55
#[allow(non_camel_case_types)]
66
pub enum Strategy {
7+
new_score,
78
best_score,
89
most_recent,
910
rollback,
@@ -14,6 +15,7 @@ impl std::str::FromStr for Strategy {
1415

1516
fn from_str(input: &str) -> Result<Strategy, Self::Err> {
1617
match input {
18+
"new_score" => Ok(Strategy::new_score),
1719
"best_score" => Ok(Strategy::best_score),
1820
"most_recent" => Ok(Strategy::most_recent),
1921
"rollback" => Ok(Strategy::rollback),
@@ -25,6 +27,7 @@ impl std::str::FromStr for Strategy {
2527
impl std::string::ToString for Strategy {
2628
fn to_string(&self) -> String {
2729
match *self {
30+
Strategy::new_score => "new_score".to_string(),
2831
Strategy::best_score => "best_score".to_string(),
2932
Strategy::most_recent => "most_recent".to_string(),
3033
Strategy::rollback => "rollback".to_string(),

0 commit comments

Comments
 (0)