Skip to content

Commit c3a8514

Browse files
authored
SDK - Added re-ranking into vector search (#1516)
1 parent 34e64d8 commit c3a8514

File tree

5 files changed

+197
-17
lines changed

5 files changed

+197
-17
lines changed

pgml-sdks/pgml/Cargo.lock

Lines changed: 0 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pgml-sdks/pgml/src/collection.rs

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1051,6 +1051,7 @@ impl Collection {
10511051
/// }).into(), &mut pipeline).await?;
10521052
/// Ok(())
10531053
/// }
1054+
#[allow(clippy::type_complexity)]
10541055
#[instrument(skip(self))]
10551056
pub async fn vector_search(
10561057
&mut self,
@@ -1061,7 +1062,7 @@ impl Collection {
10611062

10621063
let (built_query, values) =
10631064
build_vector_search_query(query.clone(), self, pipeline).await?;
1064-
let results: Result<Vec<(Json, String, f64)>, _> =
1065+
let results: Result<Vec<(Json, String, f64, Option<f64>)>, _> =
10651066
sqlx::query_as_with(&built_query, values)
10661067
.fetch_all(&pool)
10671068
.await;
@@ -1072,7 +1073,8 @@ impl Collection {
10721073
serde_json::json!({
10731074
"document": v.0,
10741075
"chunk": v.1,
1075-
"score": v.2
1076+
"score": v.2,
1077+
"rerank_score": v.3
10761078
})
10771079
.into()
10781080
})
@@ -1087,7 +1089,7 @@ impl Collection {
10871089
.await?;
10881090
let (built_query, values) =
10891091
build_vector_search_query(query, self, pipeline).await?;
1090-
let results: Vec<(Json, String, f64)> =
1092+
let results: Vec<(Json, String, f64, Option<f64>)> =
10911093
sqlx::query_as_with(&built_query, values)
10921094
.fetch_all(&pool)
10931095
.await?;
@@ -1097,7 +1099,8 @@ impl Collection {
10971099
serde_json::json!({
10981100
"document": v.0,
10991101
"chunk": v.1,
1100-
"score": v.2
1102+
"score": v.2,
1103+
"rerank_score": v.3
11011104
})
11021105
.into()
11031106
})
@@ -1121,16 +1124,18 @@ impl Collection {
11211124
let pool = get_or_initialize_pool(&self.database_url).await?;
11221125
let (built_query, values) =
11231126
build_vector_search_query(query.clone(), self, pipeline).await?;
1124-
let results: Vec<(Json, String, f64)> = sqlx::query_as_with(&built_query, values)
1125-
.fetch_all(&pool)
1126-
.await?;
1127+
let results: Vec<(Json, String, f64, Option<f64>)> =
1128+
sqlx::query_as_with(&built_query, values)
1129+
.fetch_all(&pool)
1130+
.await?;
11271131
Ok(results
11281132
.into_iter()
11291133
.map(|v| {
11301134
serde_json::json!({
11311135
"document": v.0,
11321136
"chunk": v.1,
1133-
"score": v.2
1137+
"score": v.2,
1138+
"rerank_score": v.3
11341139
})
11351140
.into()
11361141
})

pgml-sdks/pgml/src/lib.rs

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1553,6 +1553,88 @@ mod tests {
15531553
Ok(())
15541554
}
15551555

1556+
// Integration test: runs `Collection::vector_search` with a `rerank` section and
// checks that each result carries a `rerank_score` and that the returned
// document ids come back in the expected order.
#[tokio::test]
1557+
async fn can_vector_search_with_local_embeddings_and_rerank() -> anyhow::Result<()> {
1558+
internal_init_logger(None, None).ok();
1559+
let collection_name = "test r_c_cvswlear_1";
1560+
let mut collection = Collection::new(collection_name, None)?;
1561+
// Seed the collection with 10 generated documents.
let documents = generate_dummy_documents(10);
1562+
collection.upsert_documents(documents.clone(), None).await?;
1563+
let pipeline_name = "0";
1564+
// Pipeline schema: `title` gets semantic search plus full text search,
// `body` gets a splitter plus semantic search.
let mut pipeline = Pipeline::new(
1565+
pipeline_name,
1566+
Some(
1567+
json!({
1568+
"title": {
1569+
"semantic_search": {
1570+
"model": "intfloat/e5-small-v2",
1571+
"parameters": {
1572+
"prompt": "passage: "
1573+
}
1574+
},
1575+
"full_text_search": {
1576+
"configuration": "english"
1577+
}
1578+
},
1579+
"body": {
1580+
"splitter": {
1581+
"model": "recursive_character"
1582+
},
1583+
"semantic_search": {
1584+
"model": "intfloat/e5-small-v2",
1585+
"parameters": {
1586+
"prompt": "passage: "
1587+
}
1588+
},
1589+
},
1590+
})
1591+
.into(),
1592+
),
1593+
)?;
1594+
collection.add_pipeline(&mut pipeline).await?;
1595+
// Search over `title` and `body`, and request a rerank pass with
// `num_documents_to_rerank: 100` and a final `limit` of 5.
let results = collection
1596+
.vector_search(
1597+
json!({
1598+
"query": {
1599+
"fields": {
1600+
"title": {
1601+
"query": "Test document: 2",
1602+
"parameters": {
1603+
"prompt": "passage: "
1604+
},
1605+
"full_text_filter": "test",
1606+
"boost": 1.2
1607+
},
1608+
"body": {
1609+
"query": "Test document: 2",
1610+
"parameters": {
1611+
"prompt": "passage: "
1612+
},
1613+
"boost": 1.0
1614+
},
1615+
}
1616+
},
1617+
"rerank": {
1618+
"query": "Test document 2",
1619+
"model": "mixedbread-ai/mxbai-rerank-base-v1",
1620+
"num_documents_to_rerank": 100
1621+
},
1622+
"limit": 5
1623+
})
1624+
.into(),
1625+
&mut pipeline,
1626+
)
1627+
.await?;
1628+
// A rerank was requested, so the first result must expose a numeric `rerank_score`.
assert!(results[0]["rerank_score"].as_f64().is_some());
1629+
let ids: Vec<u64> = results
1630+
.into_iter()
1631+
.map(|r| r["document"]["id"].as_u64().unwrap())
1632+
.collect();
1633+
// NOTE(review): this exact ordering presumably depends on the outputs of the
// embedding and rerank models named above; confirm it is stable across model versions.
assert_eq!(ids, vec![2, 1, 3, 8, 6]);
1634+
// Archive the collection created by this test.
collection.archive().await?;
1635+
Ok(())
1636+
}
1637+
15561638
///////////////////////////////
15571639
// Working With Documents /////
15581640
///////////////////////////////
@@ -2207,6 +2289,11 @@ mod tests {
22072289
"id"
22082290
]
22092291
},
2292+
"rerank": {
2293+
"query": "Test document 2",
2294+
"model": "mixedbread-ai/mxbai-rerank-base-v1",
2295+
"num_documents_to_rerank": 100
2296+
},
22102297
"limit": 5
22112298
},
22122299
"aggregate": {

pgml-sdks/pgml/src/rag_query_builder.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -212,9 +212,7 @@ pub async fn build_rag_query(
212212
r#"(SELECT string_agg(chunk, '{}') FROM "{var_name}")"#,
213213
vector_search.aggregate.join
214214
),
215-
format!(
216-
r#"(SELECT json_agg(jsonb_build_object('chunk', chunk, 'document', document, 'score', score)) FROM "{var_name}")"#
217-
),
215+
format!(r#"(SELECT json_agg(j) FROM "{var_name}" j)"#),
218216
)
219217
}
220218
ValidVariable::RawSQL(sql) => (format!("({})", sql.sql), format!("({})", sql.sql)),

pgml-sdks/pgml/src/vector_search_query_builder.rs

Lines changed: 96 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,20 @@ struct ValidDocument {
4141
keys: Option<Vec<String>>,
4242
}
4343

44+
/// Serde default for `ValidRerank::num_documents_to_rerank`: 10 candidates
/// when the request omits the field.
const fn default_num_documents_to_rerank() -> u64 {
45+
10
46+
}
47+
48+
/// Options accepted under the `rerank` key of a vector search request.
///
/// Deserialized with `deny_unknown_fields`, so unrecognized keys are rejected
/// rather than silently ignored.
#[derive(Debug, Deserialize, Serialize, Clone)]
49+
#[serde(deny_unknown_fields)]
50+
struct ValidRerank {
51+
/// Query text bound as the second argument of the `pgml.rank` call.
query: String,
52+
/// Model name bound as the first argument of the `pgml.rank` call.
model: String,
53+
/// Row limit applied to the vector search when reranking is requested, in
/// place of the query's `limit`; falls back to
/// `default_num_documents_to_rerank()` when omitted.
#[serde(default = "default_num_documents_to_rerank")]
54+
num_documents_to_rerank: u64,
55+
/// Optional JSON concatenated (`||`) onto the built-in `pgml.rank` options;
/// an empty default is used when absent.
parameters: Option<Json>,
56+
}
57+
4458
/// Returns the constant 10.
// NOTE(review): presumably the serde default for `ValidQuery::limit`; the
// attribute that would reference it is outside the visible hunk, so confirm.
const fn default_limit() -> u64 {
4559
10
4660
}
@@ -56,6 +70,8 @@ pub struct ValidQuery {
5670
limit: u64,
5771
// Document related items
5872
document: Option<ValidDocument>,
73+
// Rerank related items
74+
rerank: Option<ValidRerank>,
5975
}
6076

6177
pub async fn build_sqlx_query(
@@ -66,9 +82,14 @@ pub async fn build_sqlx_query(
6682
prefix: Option<&str>,
6783
) -> anyhow::Result<(SelectStatement, Vec<CommonTableExpression>)> {
6884
let valid_query: ValidQuery = serde_json::from_value(query.0)?;
69-
let limit = valid_query.limit;
7085
let fields = valid_query.query.fields.unwrap_or_default();
7186

87+
let search_limit = if let Some(rerank) = valid_query.rerank.as_ref() {
88+
rerank.num_documents_to_rerank
89+
} else {
90+
valid_query.limit
91+
};
92+
7293
let prefix = prefix.unwrap_or("");
7394

7495
if fields.is_empty() {
@@ -209,7 +230,7 @@ pub async fn build_sqlx_query(
209230
Expr::col((SIden::Str("documents"), SIden::Str("id")))
210231
.equals((SIden::Str("chunks"), SIden::Str("document_id"))),
211232
)
212-
.limit(limit);
233+
.limit(search_limit);
213234

214235
if let Some(filter) = &valid_query.query.filter {
215236
let filter = FilterBuilder::new(filter.clone().0, "documents", "document").build()?;
@@ -272,7 +293,79 @@ pub async fn build_sqlx_query(
272293
// Resort and limit
273294
query
274295
.order_by(SIden::Str("score"), Order::Desc)
275-
.limit(limit);
296+
.limit(search_limit);
297+
298+
// Rerank
299+
let query = if let Some(rerank) = &valid_query.rerank {
300+
// Add our vector_search CTE
301+
let mut vector_search_cte = CommonTableExpression::from_select(query);
302+
vector_search_cte.table_name(Alias::new(format!("{prefix}_vector_search")));
303+
ctes.push(vector_search_cte);
304+
305+
// Add our row_number_vector_search CTE
306+
let mut row_number_vector_search = Query::select();
307+
row_number_vector_search
308+
.columns([
309+
SIden::Str("document"),
310+
SIden::Str("chunk"),
311+
SIden::Str("score"),
312+
])
313+
.from(SIden::String(format!("{prefix}_vector_search")));
314+
row_number_vector_search
315+
.expr_as(Expr::cust("ROW_NUMBER() OVER ()"), Alias::new("row_number"));
316+
let mut row_number_vector_search_cte =
317+
CommonTableExpression::from_select(row_number_vector_search);
318+
row_number_vector_search_cte
319+
.table_name(Alias::new(format!("{prefix}_row_number_vector_search")));
320+
ctes.push(row_number_vector_search_cte);
321+
322+
// Our actual select statement
323+
let mut query = Query::select();
324+
query.columns([
325+
SIden::Str("document"),
326+
SIden::Str("chunk"),
327+
SIden::Str("score"),
328+
]);
329+
query.expr_as(Expr::cust("(rank).score"), Alias::new("rank_score"));
330+
331+
// Build the actual select statement sub query
332+
let mut sub_query_rank_call = Query::select();
333+
let model_expr = Expr::cust_with_values("$1", [rerank.model.clone()]);
334+
let query_expr = Expr::cust_with_values("$1", [rerank.query.clone()]);
335+
let parameters_expr =
336+
Expr::cust_with_values("$1", [rerank.parameters.clone().unwrap_or_default().0]);
337+
sub_query_rank_call.expr_as(Expr::cust_with_exprs(
338+
format!(r#"pgml.rank($1, $2, array_agg("chunk"), '{{"return_documents": false, "top_k": {}}}'::jsonb || $3)"#, valid_query.limit),
339+
[model_expr, query_expr, parameters_expr],
340+
), Alias::new("rank"))
341+
.from(SIden::String(format!("{prefix}_row_number_vector_search")));
342+
343+
let mut sub_query = Query::select();
344+
sub_query
345+
.columns([
346+
SIden::Str("document"),
347+
SIden::Str("chunk"),
348+
SIden::Str("score"),
349+
SIden::Str("rank"),
350+
])
351+
.from_as(
352+
SIden::String(format!("{prefix}_row_number_vector_search")),
353+
Alias::new("rnsv1"),
354+
)
355+
.join_subquery(
356+
JoinType::InnerJoin,
357+
sub_query_rank_call,
358+
Alias::new("rnsv2"),
359+
Expr::cust("((rank).corpus_id + 1) = rnsv1.row_number"),
360+
);
361+
362+
// Query from the sub query
363+
query.from_subquery(sub_query, Alias::new("sub_query"));
364+
365+
query
366+
} else {
367+
query
368+
};
276369

277370
Ok((query, ctes))
278371
}

0 commit comments

Comments
 (0)