Skip to content

Commit 7fab677

Browse files
authored
start integrating smartcore for common algos in rust (#301)
1 parent a5a1ff8 commit 7fab677

File tree

17 files changed

+2000
-764
lines changed

17 files changed

+2000
-764
lines changed

pgml-extension/pgml_rust/Cargo.toml

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,19 @@ pg14 = ["pgx/pg14", "pgx-tests/pg14" ]
1616
pg_test = []
1717

1818
[dependencies]
19-
pgx = "=0.4.5"
20-
xgboost = { path = "rust-xgboost" }
21-
rustlearn = "0.5"
19+
pgx = "0.4.5"
2220
once_cell = "1"
2321
rand = "0.8"
22+
xgboost = { path = "rust-xgboost" }
23+
smartcore = { version = "0.2.0", features = ["serde", "ndarray-bindings"] }
24+
ndarray = { version = "0.15.6", features = ["serde", "blas"] }
2425
blas = { version = "0.22.0" }
2526
blas-src = { version = "0.8", features = ["openblas"] }
2627
openblas-src = { version = "0.10", features = ["cblas", "system"] }
28+
serde = { version = "1.0.2" }
29+
serde_json = { version = "1.0.85" }
30+
rmp-serde = { version = "1.1.0" }
31+
typetag = "0.2"
2732

2833
[dev-dependencies]
2934
pgx-tests = "=0.4.5"
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
comment = 'pgml_rust: Created by pgx'
1+
comment = 'pgml_rust: Created by the PostgresML team'
22
default_version = '@CARGO_VERSION@'
33
module_pathname = '$libdir/pgml_rust'
44
relocatable = false
55
superuser = false
6+
schema = 'pgml_rust'

pgml-extension/pgml_rust/sql/schema.sql

Lines changed: 130 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
CREATE SCHEMA IF NOT EXISTS pgml_rust;
2-
31
---
42
--- Track of updates to data
53
---
@@ -33,43 +31,164 @@ BEGIN
3331
) THEN
3432
NEW.updated_at := clock_timestamp();
3533
END IF;
36-
RETURN NEW;
34+
RETURN new;
3735
END;
3836
$$
3937
LANGUAGE plpgsql;
4038

39+
4140
---
4241
--- Projects organize work
4342
---
4443
CREATE TABLE IF NOT EXISTS pgml_rust.projects(
4544
id BIGSERIAL PRIMARY KEY,
46-
name TEXT NOT NULL UNIQUE,
47-
task TEXT NOT NULL,
45+
name TEXT NOT NULL,
46+
task pgml_rust.task NOT NULL,
4847
created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp(),
4948
updated_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp()
5049
);
5150
SELECT pgml_rust.auto_updated_at('pgml_rust.projects');
51+
CREATE UNIQUE INDEX IF NOT EXISTS projects_name_idx ON pgml_rust.projects(name);
5252

5353

54-
CREATE TABLE IF NOT EXISTS pgml_rust.models (
54+
---
55+
--- Snapshots freeze data for training
56+
---
57+
CREATE TABLE IF NOT EXISTS pgml_rust.snapshots(
5558
id BIGSERIAL PRIMARY KEY,
56-
project_id BIGINT NOT NULL REFERENCES pgml_rust.projects(id),
57-
algorithm VARCHAR,
58-
data BYTEA
59+
relation_name TEXT NOT NULL,
60+
y_column_name TEXT[] NOT NULL,
61+
test_size FLOAT4 NOT NULL,
62+
test_sampling pgml_rust.sampling NOT NULL,
63+
status TEXT NOT NULL,
64+
columns JSONB,
65+
analysis JSONB,
66+
created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp(),
67+
updated_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp()
5968
);
69+
SELECT pgml_rust.auto_updated_at('pgml_rust.snapshots');
70+
6071

6172
---
62-
--- Deployments determine which model is live
73+
--- Models save the learned parameters
74+
---
75+
CREATE TABLE IF NOT EXISTS pgml_rust.models(
76+
id BIGSERIAL PRIMARY KEY,
77+
project_id BIGINT NOT NULL,
78+
snapshot_id BIGINT NOT NULL,
79+
algorithm TEXT NOT NULL,
80+
hyperparams JSONB NOT NULL,
81+
status TEXT NOT NULL,
82+
metrics JSONB,
83+
search pgml_rust.search,
84+
search_params JSONB NOT NULL,
85+
search_args JSONB NOT NULL,
86+
created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp(),
87+
updated_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp(),
88+
CONSTRAINT project_id_fk FOREIGN KEY(project_id) REFERENCES pgml_rust.projects(id),
89+
CONSTRAINT snapshot_id_fk FOREIGN KEY(snapshot_id) REFERENCES pgml_rust.snapshots(id)
90+
);
91+
CREATE INDEX IF NOT EXISTS models_project_id_idx ON pgml_rust.models(project_id);
92+
CREATE INDEX IF NOT EXISTS models_snapshot_id_idx ON pgml_rust.models(snapshot_id);
93+
SELECT pgml_rust.auto_updated_at('pgml_rust.models');
94+
95+
96+
---
97+
--- Deployements determine which model is live
6398
---
6499
CREATE TABLE IF NOT EXISTS pgml_rust.deployments(
65100
id BIGSERIAL PRIMARY KEY,
66101
project_id BIGINT NOT NULL,
67102
model_id BIGINT NOT NULL,
68-
strategy TEXT NOT NULL,
103+
strategy pgml_rust.strategy NOT NULL,
69104
created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp(),
70105
CONSTRAINT project_id_fk FOREIGN KEY(project_id) REFERENCES pgml_rust.projects(id),
71106
CONSTRAINT model_id_fk FOREIGN KEY(model_id) REFERENCES pgml_rust.models(id)
72107
);
73108
CREATE INDEX IF NOT EXISTS deployments_project_id_created_at_idx ON pgml_rust.deployments(project_id);
74109
CREATE INDEX IF NOT EXISTS deployments_model_id_created_at_idx ON pgml_rust.deployments(model_id);
75110
SELECT pgml_rust.auto_updated_at('pgml_rust.deployments');
111+
112+
---
113+
--- Distribute serialized models consistently for HA
114+
---
115+
CREATE TABLE IF NOT EXISTS pgml_rust.files(
116+
id BIGSERIAL PRIMARY KEY,
117+
model_id BIGINT NOT NULL,
118+
path TEXT NOT NULL,
119+
part INTEGER NOT NULL,
120+
created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp(),
121+
updated_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp(),
122+
data BYTEA NOT NULL
123+
);
124+
CREATE UNIQUE INDEX IF NOT EXISTS files_model_id_path_part_idx ON pgml_rust.files(model_id, path, part);
125+
SELECT pgml_rust.auto_updated_at('pgml_rust.files');
126+
127+
---
128+
--- Quick status check on the system.
129+
---
130+
DROP VIEW IF EXISTS pgml_rust.overview;
131+
CREATE VIEW pgml_rust.overview AS
132+
SELECT
133+
p.name,
134+
d.created_at AS deployed_at,
135+
p.task,
136+
m.algorithm,
137+
s.relation_name,
138+
s.y_column_name,
139+
s.test_sampling,
140+
s.test_size
141+
FROM pgml_rust.projects p
142+
INNER JOIN pgml_rust.models m ON p.id = m.project_id
143+
INNER JOIN pgml_rust.deployments d ON d.project_id = p.id
144+
AND d.model_id = m.id
145+
INNER JOIN pgml_rust.snapshots s ON s.id = m.snapshot_id
146+
ORDER BY d.created_at DESC;
147+
148+
149+
---
150+
--- List details of trained models.
151+
---
152+
DROP VIEW IF EXISTS pgml_rust.trained_models;
153+
CREATE VIEW pgml_rust.trained_models AS
154+
SELECT
155+
m.id,
156+
p.name,
157+
p.task,
158+
m.algorithm,
159+
m.created_at,
160+
s.test_sampling,
161+
s.test_size,
162+
d.model_id IS NOT NULL AS deployed
163+
FROM pgml_rust.projects p
164+
INNER JOIN pgml_rust.models m ON p.id = m.project_id
165+
INNER JOIN pgml_rust.snapshots s ON s.id = m.snapshot_id
166+
LEFT JOIN (
167+
SELECT DISTINCT ON(project_id)
168+
project_id, model_id, created_at
169+
FROM pgml_rust.deployments
170+
ORDER BY project_id, created_at desc
171+
) d ON d.model_id = m.id
172+
ORDER BY m.created_at DESC;
173+
174+
175+
---
176+
--- List details of deployed models.
177+
---
178+
DROP VIEW IF EXISTS pgml_rust.deployed_models;
179+
CREATE VIEW pgml_rust.deployed_models AS
180+
SELECT
181+
m.id,
182+
p.name,
183+
p.task,
184+
m.algorithm,
185+
d.created_at as deployed_at
186+
FROM pgml_rust.projects p
187+
INNER JOIN (
188+
SELECT DISTINCT ON(project_id)
189+
project_id, model_id, created_at
190+
FROM pgml_rust.deployments
191+
ORDER BY project_id, created_at desc
192+
) d ON d.project_id = p.id
193+
INNER JOIN pgml_rust.models m ON m.id = d.model_id
194+
ORDER BY p.name ASC;

pgml-extension/pgml_rust/src/api.rs

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
use pgx::*;
2+
3+
use crate::orm::Algorithm;
4+
use crate::orm::Model;
5+
use crate::orm::Project;
6+
use crate::orm::Sampling;
7+
use crate::orm::Search;
8+
use crate::orm::Snapshot;
9+
use crate::orm::Strategy;
10+
use crate::orm::Task;
11+
12+
#[pg_extern]
13+
fn train(
14+
project_name: &str,
15+
task: Option<default!(Task, "NULL")>,
16+
relation_name: Option<default!(&str, "NULL")>,
17+
y_column_name: Option<default!(&str, "NULL")>,
18+
algorithm: default!(Algorithm, "'linear'"),
19+
hyperparams: default!(JsonB, "'{}'"),
20+
search: Option<default!(Search, "NULL")>,
21+
search_params: default!(JsonB, "'{}'"),
22+
search_args: default!(JsonB, "'{}'"),
23+
test_size: default!(f32, 0.25),
24+
test_sampling: default!(Sampling, "'last'"),
25+
) {
26+
let project = match Project::find_by_name(project_name) {
27+
Some(project) => project,
28+
None => Project::create(project_name, task.unwrap()),
29+
};
30+
if task.is_some() && task.unwrap() != project.task {
31+
error!("Project `{:?}` already exists with a different task: `{:?}`. Create a new project instead.", project.name, project.task);
32+
}
33+
let snapshot = match relation_name {
34+
None => project.last_snapshot().expect("You must pass a `relation_name` and `y_column_name` to snapshot the first time you train a model."),
35+
Some(relation_name) => Snapshot::create(relation_name, y_column_name.expect("You must pass a `y_column_name` when you pass a `relation_name`"), test_size, test_sampling)
36+
};
37+
38+
// # Default repeatable random state when possible
39+
// let algorithm = Model.algorithm_from_name_and_task(algorithm, task);
40+
// if "random_state" in algorithm().get_params() and "random_state" not in hyperparams:
41+
// hyperparams["random_state"] = 0
42+
43+
let model = Model::create(
44+
&project,
45+
&snapshot,
46+
algorithm,
47+
hyperparams,
48+
search,
49+
search_params,
50+
search_args,
51+
);
52+
53+
// TODO move deployment into a struct and only deploy if new model is better than old model
54+
Spi::get_one_with_args::<i64>(
55+
"INSERT INTO pgml_rust.deployments (project_id, model_id, strategy) VALUES ($1, $2, $3::pgml_rust.strategy) RETURNING id",
56+
vec![
57+
(PgBuiltInOids::INT8OID.oid(), project.id.into_datum()),
58+
(PgBuiltInOids::INT8OID.oid(), model.id.into_datum()),
59+
(PgBuiltInOids::TEXTOID.oid(), Strategy::most_recent.to_string().into_datum()),
60+
]
61+
);
62+
}
63+
64+
#[pg_extern]
65+
fn predict(project_name: &str, features: Vec<f32>) -> f32 {
66+
let estimator = crate::orm::estimator::find_deployed_estimator_by_project_name(project_name);
67+
estimator.predict(features)
68+
}
69+
70+
// #[pg_extern]
71+
// fn return_table_example() -> impl std::Iterator<Item = (name!(id, Option<i64>), name!(title, Option<String>))> {
72+
// let tuple = Spi::get_two_with_args("SELECT 1 AS id, 2 AS title;", None, None)
73+
// vec![tuple].into_iter()
74+
// }
75+
76+
#[pg_extern]
77+
fn create_snapshot(
78+
relation_name: &str,
79+
y_column_name: &str,
80+
test_size: f32,
81+
test_sampling: Sampling,
82+
) -> i64 {
83+
let snapshot = Snapshot::create(relation_name, y_column_name, test_size, test_sampling);
84+
info!("{:?}", snapshot);
85+
snapshot.id
86+
}
87+
88+
#[cfg(any(test, feature = "pg_test"))]
89+
#[pg_schema]
90+
mod tests {
91+
use super::*;
92+
93+
#[pg_test]
94+
fn test_project_lifecycle() {
95+
assert_eq!(Project::create("test", Task::regression).id, 1);
96+
assert_eq!(Project::find(1).id, 1);
97+
}
98+
99+
#[pg_test]
100+
fn test_snapshot_lifecycle() {
101+
let snapshot = Snapshot::create("test", "column", 0.5, Sampling::last);
102+
assert_eq!(snapshot.id, 1);
103+
}
104+
}

0 commit comments

Comments
 (0)