Skip to content

Commit 63963d4

Browse files
authored
Stratified sampling (#1336)
I verified tests locally, because I wasn't able to figure out how to get them running via github actions...
1 parent 347168a commit 63963d4

File tree

5 files changed

+241
-34
lines changed

5 files changed

+241
-34
lines changed

.github/workflows/ci.yml

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,12 @@ jobs:
4747
if: steps.pgml_extension_changed.outputs.PGML_EXTENSION_CHANGED_FILES != '0'
4848
run: |
4949
git submodule update --init --recursive
50+
- name: Get current version
51+
id: current-version
52+
run: echo "CI_BRANCH=$(git name-rev --name-only HEAD)" >> $GITHUB_OUTPUT
5053
- name: Run tests
54+
env:
55+
CI_BRANCH: ${{ steps.current-version.outputs.CI_BRANCH }}
5156
if: steps.pgml_extension_changed.outputs.PGML_EXTENSION_CHANGED_FILES != '0'
5257
run: |
5358
curl https://sh.rustup.rs -sSf | sh -s -- -y
@@ -58,8 +63,13 @@ jobs:
5863
cargo pgrx init
5964
fi
6065
66+
git checkout master
67+
echo "\q" | cargo pgrx run
68+
psql -p 28816 -h localhost -d pgml -P pager -c "CREATE EXTENSION pgml;"
69+
git checkout $CI_BRANCH
70+
echo "\q" | cargo pgrx run
71+
psql -p 28816 -h localhost -d pgml -P pager -c "ALTER EXTENSION pgml UPDATE;"
6172
cargo pgrx test
62-
6373
# cargo pgrx start
6474
# psql -p 28815 -h 127.0.0.1 -d pgml -P pager -f tests/test.sql
6575
# cargo pgrx stop

pgml-extension/sql/pgml--2.8.1--2.8.2.sql

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,102 @@ CREATE FUNCTION pgml."deploy"(
2525
AS 'MODULE_PATHNAME', 'deploy_strategy_wrapper';
2626

2727
ALTER TYPE pgml.strategy ADD VALUE 'specific';
28+
29+
ALTER TYPE pgml.Sampling ADD VALUE 'stratified';
30+
31+
-- src/api.rs:534
32+
-- pgml::api::snapshot
33+
DROP FUNCTION IF EXISTS pgml."snapshot"(text, text, real, pgml.Sampling, jsonb);
34+
CREATE FUNCTION pgml."snapshot"(
35+
"relation_name" TEXT, /* &str */
36+
"y_column_name" TEXT, /* &str */
37+
"test_size" real DEFAULT 0.25, /* f32 */
38+
"test_sampling" pgml.Sampling DEFAULT 'stratified', /* pgml::orm::sampling::Sampling */
39+
"preprocess" jsonb DEFAULT '{}' /* pgrx::datum::json::JsonB */
40+
) RETURNS TABLE (
41+
"relation" TEXT, /* alloc::string::String */
42+
"y_column_name" TEXT /* alloc::string::String */
43+
)
44+
STRICT
45+
LANGUAGE c /* Rust */
46+
AS 'MODULE_PATHNAME', 'snapshot_wrapper';
47+
48+
-- src/api.rs:802
49+
-- pgml::api::tune
50+
DROP FUNCTION IF EXISTS pgml."tune"(text, text, text, text, text, jsonb, real, pgml.Sampling, bool, bool);
51+
CREATE FUNCTION pgml."tune"(
52+
"project_name" TEXT, /* &str */
53+
"task" TEXT DEFAULT NULL, /* core::option::Option<&str> */
54+
"relation_name" TEXT DEFAULT NULL, /* core::option::Option<&str> */
55+
"y_column_name" TEXT DEFAULT NULL, /* core::option::Option<&str> */
56+
"model_name" TEXT DEFAULT NULL, /* core::option::Option<&str> */
57+
"hyperparams" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
58+
"test_size" real DEFAULT 0.25, /* f32 */
59+
"test_sampling" pgml.Sampling DEFAULT 'stratified', /* pgml::orm::sampling::Sampling */
60+
"automatic_deploy" bool DEFAULT true, /* core::option::Option<bool> */
61+
"materialize_snapshot" bool DEFAULT false /* bool */
62+
) RETURNS TABLE (
63+
"status" TEXT, /* alloc::string::String */
64+
"task" TEXT, /* alloc::string::String */
65+
"algorithm" TEXT, /* alloc::string::String */
66+
"deployed" bool /* bool */
67+
)
68+
PARALLEL SAFE
69+
LANGUAGE c /* Rust */
70+
AS 'MODULE_PATHNAME', 'tune_wrapper';
71+
72+
-- src/api.rs:92
73+
-- pgml::api::train
74+
DROP FUNCTION IF EXISTS pgml."train"(text, text, text, text, pgml.Algorithm, jsonb, pgml.Search, jsonb, jsonb, real, pgml.Sampling, pgml.Runtime, bool, bool, jsonb);
75+
CREATE FUNCTION pgml."train"(
76+
"project_name" TEXT, /* &str */
77+
"task" TEXT DEFAULT NULL, /* core::option::Option<&str> */
78+
"relation_name" TEXT DEFAULT NULL, /* core::option::Option<&str> */
79+
"y_column_name" TEXT DEFAULT NULL, /* core::option::Option<&str> */
80+
"algorithm" pgml.Algorithm DEFAULT 'linear', /* pgml::orm::algorithm::Algorithm */
81+
"hyperparams" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
82+
"search" pgml.Search DEFAULT NULL, /* core::option::Option<pgml::orm::search::Search> */
83+
"search_params" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
84+
"search_args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
85+
"test_size" real DEFAULT 0.25, /* f32 */
86+
"test_sampling" pgml.Sampling DEFAULT 'stratified', /* pgml::orm::sampling::Sampling */
87+
"runtime" pgml.Runtime DEFAULT NULL, /* core::option::Option<pgml::orm::runtime::Runtime> */
88+
"automatic_deploy" bool DEFAULT true, /* core::option::Option<bool> */
89+
"materialize_snapshot" bool DEFAULT false, /* bool */
90+
"preprocess" jsonb DEFAULT '{}' /* pgrx::datum::json::JsonB */
91+
) RETURNS TABLE (
92+
"project" TEXT, /* alloc::string::String */
93+
"task" TEXT, /* alloc::string::String */
94+
"algorithm" TEXT, /* alloc::string::String */
95+
"deployed" bool /* bool */
96+
)
97+
LANGUAGE c /* Rust */
98+
AS 'MODULE_PATHNAME', 'train_wrapper';
99+
100+
-- src/api.rs:138
101+
-- pgml::api::train_joint
102+
DROP FUNCTION IF EXISTS pgml."train_joint"(text, text, text, text, pgml.Algorithm, jsonb, pgml.Search, jsonb, jsonb, real, pgml.Sampling, pgml.Runtime, bool, bool, jsonb);
103+
CREATE FUNCTION pgml."train_joint"(
104+
"project_name" TEXT, /* &str */
105+
"task" TEXT DEFAULT NULL, /* core::option::Option<&str> */
106+
"relation_name" TEXT DEFAULT NULL, /* core::option::Option<&str> */
107+
"y_column_name" TEXT[] DEFAULT NULL, /* core::option::Option<alloc::vec::Vec<alloc::string::String>> */
108+
"algorithm" pgml.Algorithm DEFAULT 'linear', /* pgml::orm::algorithm::Algorithm */
109+
"hyperparams" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
110+
"search" pgml.Search DEFAULT NULL, /* core::option::Option<pgml::orm::search::Search> */
111+
"search_params" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
112+
"search_args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
113+
"test_size" real DEFAULT 0.25, /* f32 */
114+
"test_sampling" pgml.Sampling DEFAULT 'stratified', /* pgml::orm::sampling::Sampling */
115+
"runtime" pgml.Runtime DEFAULT NULL, /* core::option::Option<pgml::orm::runtime::Runtime> */
116+
"automatic_deploy" bool DEFAULT true, /* core::option::Option<bool> */
117+
"materialize_snapshot" bool DEFAULT false, /* bool */
118+
"preprocess" jsonb DEFAULT '{}' /* pgrx::datum::json::JsonB */
119+
) RETURNS TABLE (
120+
"project" TEXT, /* alloc::string::String */
121+
"task" TEXT, /* alloc::string::String */
122+
"algorithm" TEXT, /* alloc::string::String */
123+
"deployed" bool /* bool */
124+
)
125+
LANGUAGE c /* Rust */
126+
AS 'MODULE_PATHNAME', 'train_joint_wrapper';

pgml-extension/src/api.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ fn train(
100100
search_params: default!(JsonB, "'{}'"),
101101
search_args: default!(JsonB, "'{}'"),
102102
test_size: default!(f32, 0.25),
103-
test_sampling: default!(Sampling, "'last'"),
103+
test_sampling: default!(Sampling, "'stratified'"),
104104
runtime: default!(Option<Runtime>, "NULL"),
105105
automatic_deploy: default!(Option<bool>, true),
106106
materialize_snapshot: default!(bool, false),
@@ -146,7 +146,7 @@ fn train_joint(
146146
search_params: default!(JsonB, "'{}'"),
147147
search_args: default!(JsonB, "'{}'"),
148148
test_size: default!(f32, 0.25),
149-
test_sampling: default!(Sampling, "'last'"),
149+
test_sampling: default!(Sampling, "'stratified'"),
150150
runtime: default!(Option<Runtime>, "NULL"),
151151
automatic_deploy: default!(Option<bool>, true),
152152
materialize_snapshot: default!(bool, false),
@@ -535,7 +535,7 @@ fn snapshot(
535535
relation_name: &str,
536536
y_column_name: &str,
537537
test_size: default!(f32, 0.25),
538-
test_sampling: default!(Sampling, "'last'"),
538+
test_sampling: default!(Sampling, "'stratified'"),
539539
preprocess: default!(JsonB, "'{}'"),
540540
) -> TableIterator<'static, (name!(relation, String), name!(y_column_name, String))> {
541541
Snapshot::create(
@@ -807,7 +807,7 @@ fn tune(
807807
model_name: default!(Option<&str>, "NULL"),
808808
hyperparams: default!(JsonB, "'{}'"),
809809
test_size: default!(f32, 0.25),
810-
test_sampling: default!(Sampling, "'last'"),
810+
test_sampling: default!(Sampling, "'stratified'"),
811811
automatic_deploy: default!(Option<bool>, true),
812812
materialize_snapshot: default!(bool, false),
813813
) -> TableIterator<

pgml-extension/src/orm/sampling.rs

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
use pgrx::*;
22
use serde::Deserialize;
33

4+
use super::snapshot::Column;
5+
46
#[derive(PostgresEnum, Copy, Clone, Eq, PartialEq, Debug, Deserialize)]
57
#[allow(non_camel_case_types)]
68
pub enum Sampling {
79
random,
810
last,
11+
stratified,
912
}
1013

1114
impl std::str::FromStr for Sampling {
@@ -15,6 +18,7 @@ impl std::str::FromStr for Sampling {
1518
match input {
1619
"random" => Ok(Sampling::random),
1720
"last" => Ok(Sampling::last),
21+
"stratified" => Ok(Sampling::stratified),
1822
_ => Err(()),
1923
}
2024
}
@@ -25,6 +29,111 @@ impl std::string::ToString for Sampling {
2529
match *self {
2630
Sampling::random => "random".to_string(),
2731
Sampling::last => "last".to_string(),
32+
Sampling::stratified => "stratified".to_string(),
2833
}
2934
}
3035
}
36+
37+
impl Sampling {
38+
// Implementing the sampling strategy in SQL
39+
// Effectively orders the table according to the train/test split
40+
// e.g. first N rows are train, last M rows are test
41+
// where M is configured by the user
42+
pub fn get_sql(&self, relation_name: &str, y_column_names: Vec<Column>) -> String {
43+
let col_string = y_column_names
44+
.iter()
45+
.map(|c| c.quoted_name())
46+
.collect::<Vec<String>>()
47+
.join(", ");
48+
match *self {
49+
Sampling::random => {
50+
format!("SELECT * FROM {relation_name} ORDER BY RANDOM()")
51+
}
52+
Sampling::last => {
53+
format!("SELECT * FROM {relation_name}")
54+
}
55+
Sampling::stratified => {
56+
format!(
57+
"
58+
SELECT *
59+
FROM (
60+
SELECT
61+
*,
62+
ROW_NUMBER() OVER(PARTITION BY {col_string} ORDER BY RANDOM()) AS rn
63+
FROM {relation_name}
64+
) AS subquery
65+
ORDER BY rn, RANDOM();
66+
"
67+
)
68+
}
69+
}
70+
}
71+
}
72+
73+
#[cfg(test)]
74+
mod tests {
75+
use crate::orm::snapshot::{Preprocessor, Statistics};
76+
77+
use super::*;
78+
79+
fn get_column_fixtures() -> Vec<Column> {
80+
vec![
81+
Column {
82+
name: "col1".to_string(),
83+
pg_type: "text".to_string(),
84+
nullable: false,
85+
label: true,
86+
position: 0,
87+
size: 0,
88+
array: false,
89+
preprocessor: Preprocessor::default(),
90+
statistics: Statistics::default(),
91+
},
92+
Column {
93+
name: "col2".to_string(),
94+
pg_type: "text".to_string(),
95+
nullable: false,
96+
label: true,
97+
position: 0,
98+
size: 0,
99+
array: false,
100+
preprocessor: Preprocessor::default(),
101+
statistics: Statistics::default(),
102+
},
103+
]
104+
}
105+
106+
#[test]
107+
fn test_get_sql_random_sampling() {
108+
let sampling = Sampling::random;
109+
let columns = get_column_fixtures();
110+
let sql = sampling.get_sql("my_table", columns);
111+
assert_eq!(sql, "SELECT * FROM my_table ORDER BY RANDOM()");
112+
}
113+
114+
#[test]
115+
fn test_get_sql_last_sampling() {
116+
let sampling = Sampling::last;
117+
let columns = get_column_fixtures();
118+
let sql = sampling.get_sql("my_table", columns);
119+
assert_eq!(sql, "SELECT * FROM my_table");
120+
}
121+
122+
#[test]
123+
fn test_get_sql_stratified_sampling() {
124+
let sampling = Sampling::stratified;
125+
let columns = get_column_fixtures();
126+
let sql = sampling.get_sql("my_table", columns);
127+
let expected_sql = "
128+
SELECT *
129+
FROM (
130+
SELECT
131+
*,
132+
ROW_NUMBER() OVER(PARTITION BY \"col1\", \"col2\" ORDER BY RANDOM()) AS rn
133+
FROM my_table
134+
) AS subquery
135+
ORDER BY rn, RANDOM();
136+
";
137+
assert_eq!(sql, expected_sql);
138+
}
139+
}

pgml-extension/src/orm/snapshot.rs

Lines changed: 18 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ pub(crate) struct Preprocessor {
119119
}
120120

121121
#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
122-
pub(crate) struct Column {
122+
pub struct Column {
123123
pub(crate) name: String,
124124
pub(crate) pg_type: String,
125125
pub(crate) nullable: bool,
@@ -147,7 +147,7 @@ impl Column {
147147
)
148148
}
149149

150-
fn quoted_name(&self) -> String {
150+
pub(crate) fn quoted_name(&self) -> String {
151151
format!(r#""{}""#, self.name)
152152
}
153153

@@ -608,13 +608,8 @@ impl Snapshot {
608608
};
609609

610610
if materialized {
611-
let mut sql = format!(
612-
r#"CREATE TABLE "pgml"."snapshot_{}" AS SELECT * FROM {}"#,
613-
s.id, s.relation_name
614-
);
615-
if s.test_sampling == Sampling::random {
616-
sql += " ORDER BY random()";
617-
}
611+
let sampled_query = s.test_sampling.get_sql(&s.relation_name, s.columns.clone());
612+
let sql = format!(r#"CREATE TABLE "pgml"."snapshot_{}" AS {}"#, s.id, sampled_query);
618613
client.update(&sql, None, None).unwrap();
619614
}
620615
snapshot = Some(s);
@@ -742,26 +737,20 @@ impl Snapshot {
742737
}
743738

744739
fn select_sql(&self) -> String {
745-
format!(
746-
"SELECT {} FROM {} {}",
747-
self.columns
748-
.iter()
749-
.map(|c| c.quoted_name())
750-
.collect::<Vec<String>>()
751-
.join(", "),
752-
self.relation_name_quoted(),
753-
match self.materialized {
754-
// If the snapshot is materialized, we already randomized it.
755-
true => "",
756-
false => {
757-
if self.test_sampling == Sampling::random {
758-
"ORDER BY random()"
759-
} else {
760-
""
761-
}
762-
}
763-
},
764-
)
740+
match self.materialized {
741+
true => {
742+
format!(
743+
"SELECT {} FROM {}",
744+
self.columns
745+
.iter()
746+
.map(|c| c.quoted_name())
747+
.collect::<Vec<String>>()
748+
.join(", "),
749+
self.relation_name_quoted()
750+
)
751+
}
752+
false => self.test_sampling.get_sql(&self.relation_name_quoted(), self.columns.clone()),
753+
}
765754
}
766755

767756
fn train_test_split(&self, num_rows: usize) -> (usize, usize) {

0 commit comments

Comments
 (0)