Skip to content

Commit 4c301ca

Browse files
authored
SDK - Fix in operator to match expected mongo behavior (#1354)
1 parent 8f8d8bf commit 4c301ca

File tree

1 file changed

+39
-28
lines changed

1 file changed

+39
-28
lines changed

pgml-sdks/pgml/src/filter_builder.rs

Lines changed: 39 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -27,24 +27,27 @@ fn build_expression(expression: Expr, filter: serde_json::Value) -> SimpleExpr {
2727
"$gte" => expression.gte(Expr::val(serde_value_to_sea_query_value(value))),
2828
"$lt" => expression.lt(Expr::val(serde_value_to_sea_query_value(value))),
2929
"$lte" => expression.lte(Expr::val(serde_value_to_sea_query_value(value))),
30-
"$in" => {
30+
e @ "$in" | e @ "$nin" => {
3131
let value = value
3232
.as_array()
3333
.expect("Invalid metadata filter configuration")
3434
.iter()
35-
// .map(|value| handle_value(value))
36-
.map(|value| Expr::val(serde_value_to_sea_query_value(value)))
37-
.collect::<Vec<_>>();
38-
expression.is_in(value)
39-
}
40-
"$nin" => {
41-
let value = value
42-
.as_array()
43-
.expect("Invalid metadata filter configuration")
44-
.iter()
45-
.map(|value| Expr::val(serde_value_to_sea_query_value(value)))
35+
.map(|value| {
36+
if value.is_string() {
37+
value.as_str().unwrap().to_owned()
38+
} else {
39+
value.to_string()
40+
}
41+
})
4642
.collect::<Vec<_>>();
47-
expression.is_not_in(value)
43+
let value_expr = Expr::cust_with_values("$1", [value]);
44+
let expr =
45+
Expr::cust_with_exprs("$1 && $2", [SimpleExpr::from(expression), value_expr]);
46+
if e == "$in" {
47+
expr
48+
} else {
49+
expr.not()
50+
}
4851
}
4952
_ => panic!("Invalid metadata filter configuration"),
5053
};
@@ -115,6 +118,15 @@ fn build_recursive<'a>(
115118
.contains(Expr::val(serde_value_to_sea_query_value(&json)));
116119
expression.not()
117120
}
121+
} else if operator == "$in" || operator == "$nin" {
122+
let expression = Expr::cust(
123+
format!(
124+
r#"ARRAY(SELECT JSONB_ARRAY_ELEMENTS_TEXT(JSONB_PATH_QUERY_ARRAY("{table_name}"."{column_name}", '$.{}[*]')))"#,
125+
local_path.join(".")
126+
).as_str()
127+
);
128+
let expression = Expr::expr(expression);
129+
build_expression(expression, value.clone())
118130
} else {
119131
let expression = Expr::cust(
120132
format!(
@@ -256,7 +268,6 @@ mod tests {
256268
}))
257269
.build()?
258270
.to_valid_sql_query();
259-
println!("{sql}");
260271
assert_eq!(
261272
sql,
262273
format!(
@@ -270,25 +281,25 @@ mod tests {
270281

271282
#[test]
272283
fn array_comparison_operators() -> anyhow::Result<()> {
273-
let array_comparison_operators = vec!["IN", "NOT IN"];
274284
let array_comparison_operators_names = vec!["$in", "$nin"];
275-
for (operator, name) in array_comparison_operators
276-
.into_iter()
277-
.zip(array_comparison_operators_names.into_iter())
278-
{
285+
for name in array_comparison_operators_names {
279286
let sql = construct_filter_builder_with_json(json!({
280-
"id": {name: [1]},
281-
"id2": {"id3": {name: [1]}}
287+
"id": {name: ["key_1", "key_2", 10]},
288+
"id2": {"id3": {name: ["key_1", false]}}
282289
}))
283290
.build()?
284291
.to_valid_sql_query();
285-
assert_eq!(
286-
sql,
287-
format!(
288-
r##"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata"#>'{{id}}') {} ('1') AND ("test_table"."metadata"#>'{{id2,id3}}') {} ('1')"##,
289-
operator, operator
290-
)
291-
);
292+
if name == "$in" {
293+
assert_eq!(
294+
sql,
295+
r#"SELECT "id" FROM "test_table" WHERE (ARRAY(SELECT JSONB_ARRAY_ELEMENTS_TEXT(JSONB_PATH_QUERY_ARRAY("test_table"."metadata", '$.id[*]'))) && ARRAY ['key_1','key_2','10']) AND (ARRAY(SELECT JSONB_ARRAY_ELEMENTS_TEXT(JSONB_PATH_QUERY_ARRAY("test_table"."metadata", '$.id2.id3[*]'))) && ARRAY ['key_1','false'])"#
296+
);
297+
} else {
298+
assert_eq!(
299+
sql,
300+
r#"SELECT "id" FROM "test_table" WHERE (NOT (ARRAY(SELECT JSONB_ARRAY_ELEMENTS_TEXT(JSONB_PATH_QUERY_ARRAY("test_table"."metadata", '$.id[*]'))) && ARRAY ['key_1','key_2','10'])) AND (NOT (ARRAY(SELECT JSONB_ARRAY_ELEMENTS_TEXT(JSONB_PATH_QUERY_ARRAY("test_table"."metadata", '$.id2.id3[*]'))) && ARRAY ['key_1','false']))"#
301+
);
302+
}
292303
}
293304
Ok(())
294305
}

0 commit comments

Comments
 (0)