Skip to content

Commit f10607a

Browse files
authored
SDK - Specify which keys to retrieve for documents when using collection.get_documents (#1460)
1 parent 5ce41e7 commit f10607a

File tree

3 files changed

+74
-4
lines changed

3 files changed

+74
-4
lines changed

pgml-sdks/pgml/src/collection.rs

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use indicatif::MultiProgress;
33
use itertools::Itertools;
44
use regex::Regex;
55
use rust_bridge::{alias, alias_methods};
6+
use sea_query::Alias;
67
use sea_query::{Expr, NullOrdering, Order, PostgresQueryBuilder, Query};
78
use sea_query_binder::SqlxBinder;
89
use serde_json::json;
@@ -656,8 +657,9 @@ impl Collection {
656657
/// Each object must have a `field` key with the name of the field to order by, and a `direction`
657658
/// key with the value `asc` or `desc`.
658659
/// * `last_row_id` - The id of the last document returned
659-
/// * `offset` - The number of documents to skip before returning results.
660-
/// * `filter` - A JSON object specifying the filter to apply to the documents.
660+
/// * `offset` - The number of documents to skip before returning results
661+
/// * `filter` - A JSON object specifying the filter to apply to the documents
662+
/// * `keys` - A JSON array specifying the document keys to return
661663
///
662664
/// # Example
663665
///
@@ -691,9 +693,33 @@ impl Collection {
691693
self.documents_table_name.to_table_tuple(),
692694
SIden::Str("documents"),
693695
)
694-
.expr(Expr::cust("*")) // Adds the * in SELECT * FROM
696+
.columns([
697+
SIden::Str("id"),
698+
SIden::Str("created_at"),
699+
SIden::Str("source_uuid"),
700+
SIden::Str("version"),
701+
])
695702
.limit(limit);
696703

704+
if let Some(keys) = args.remove("keys") {
705+
let document_queries = keys
706+
.as_array()
707+
.context("`keys` must be an array")?
708+
.iter()
709+
.map(|d| {
710+
let key = d.as_str().context("`key` value must be a string")?;
711+
anyhow::Ok(format!("'{key}', document #> '{{{key}}}'"))
712+
})
713+
.collect::<anyhow::Result<Vec<String>>>()?
714+
.join(",");
715+
query.expr_as(
716+
Expr::cust(format!("jsonb_build_object({document_queries})")),
717+
Alias::new("document"),
718+
);
719+
} else {
720+
query.column(SIden::Str("document"));
721+
}
722+
697723
if let Some(order_by) = args.remove("order_by") {
698724
let order_by_builder =
699725
order_by_builder::OrderByBuilder::new(order_by, "documents", "document").build()?;

pgml-sdks/pgml/src/lib.rs

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1343,6 +1343,50 @@ mod tests {
13431343
Ok(())
13441344
}
13451345

1346+
#[tokio::test]
1347+
async fn can_get_document_keys_get_documents() -> anyhow::Result<()> {
1348+
internal_init_logger(None, None).ok();
1349+
let mut collection = Collection::new("test r_c_cuafgd_1", None)?;
1350+
1351+
let documents = vec![
1352+
serde_json::json!({"id": 1, "random_key": 10, "nested": {"nested2": "test" } , "text": "hello world 1"}).into(),
1353+
serde_json::json!({"id": 2, "random_key": 11, "text": "hello world 2"}).into(),
1354+
serde_json::json!({"id": 3, "random_key": 12, "text": "hello world 3"}).into(),
1355+
];
1356+
collection.upsert_documents(documents.clone(), None).await?;
1357+
1358+
let documents = collection
1359+
.get_documents(Some(
1360+
serde_json::json!({
1361+
"keys": [
1362+
"id",
1363+
"random_key",
1364+
"nested,nested2"
1365+
]
1366+
})
1367+
.into(),
1368+
))
1369+
.await?;
1370+
assert!(!documents[0]["document"]
1371+
.as_object()
1372+
.unwrap()
1373+
.contains_key("text"));
1374+
assert!(documents[0]["document"]
1375+
.as_object()
1376+
.unwrap()
1377+
.contains_key("id"));
1378+
assert!(documents[0]["document"]
1379+
.as_object()
1380+
.unwrap()
1381+
.contains_key("random_key"));
1382+
assert!(documents[0]["document"]
1383+
.as_object()
1384+
.unwrap()
1385+
.contains_key("nested,nested2"));
1386+
collection.archive().await?;
1387+
Ok(())
1388+
}
1389+
13461390
#[tokio::test]
13471391
async fn can_paginate_get_documents() -> anyhow::Result<()> {
13481392
internal_init_logger(None, None).ok();

pgml-sdks/pgml/src/queries.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ CREATE TABLE IF NOT EXISTS %s (
3030
id serial8 PRIMARY KEY,
3131
created_at timestamp NOT NULL DEFAULT now(),
3232
source_uuid uuid NOT NULL,
33-
document jsonb NOT NULL,
3433
version jsonb NOT NULL DEFAULT '{}'::jsonb,
34+
document jsonb NOT NULL,
3535
UNIQUE (source_uuid)
3636
);
3737
"#;

0 commit comments

Comments
 (0)