From d36496224fb35dc52d2e4d7747a4aba8c45499ed Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Fri, 14 Jun 2024 11:39:51 -0700 Subject: [PATCH 1/2] Added re-ranking into document search --- pgml-sdks/pgml/Cargo.lock | 2 +- pgml-sdks/pgml/src/lib.rs | 16 +- pgml-sdks/pgml/src/search_query_builder.rs | 177 +++++++++++++++++++-- 3 files changed, 175 insertions(+), 20 deletions(-) diff --git a/pgml-sdks/pgml/Cargo.lock b/pgml-sdks/pgml/Cargo.lock index 784b528a7..8de1d3967 100644 --- a/pgml-sdks/pgml/Cargo.lock +++ b/pgml-sdks/pgml/Cargo.lock @@ -1590,7 +1590,7 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "pgml" -version = "1.0.4" +version = "1.1.0" dependencies = [ "anyhow", "async-trait", diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index 16ec25ece..c95180fc6 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -1038,7 +1038,12 @@ mod tests { "full_text_search": { "title": { "query": "test 9", - "boost": 4.0 + "boost": 4.0, + "rerank": { + "query": "Test document 2", + "model": "mixedbread-ai/mxbai-rerank-base-v1", + "num_documents_to_rerank": 100 + } }, "body": { "query": "Test", @@ -1051,7 +1056,12 @@ mod tests { "parameters": { "prompt": "query: ", }, - "boost": 2.0 + "boost": 2.0, + "rerank": { + "query": "Test document 2", + "model": "mixedbread-ai/mxbai-rerank-base-v1", + "num_documents_to_rerank": 100 + } }, "body": { "query": "This is the body test", @@ -1086,7 +1096,7 @@ mod tests { .iter() .map(|r| r["document"]["id"].as_u64().unwrap()) .collect(); - assert_eq!(ids, vec![9, 3, 4, 7, 5]); + assert_eq!(ids, vec![9, 3, 4, 5, 6]); let pool = get_or_initialize_pool(&None).await?; diff --git a/pgml-sdks/pgml/src/search_query_builder.rs b/pgml-sdks/pgml/src/search_query_builder.rs index e76371541..f519add6f 100644 --- a/pgml-sdks/pgml/src/search_query_builder.rs +++ b/pgml-sdks/pgml/src/search_query_builder.rs @@ -25,6 +25,7 @@ struct ValidSemanticSearchAction { query: String, parameters: Option, boost: Option, + rerank: Option, } #[derive(Debug, Deserialize)] @@ -32,6 +33,7 @@ struct ValidSemanticSearchAction { struct ValidFullTextSearchAction { query: String, boost: Option, + rerank: Option, } #[derive(Debug, Deserialize)] @@ -42,6 +44,20 @@ struct ValidQueryActions { filter: Option, } +const fn default_num_documents_to_rerank() -> u64 { + 10 +} + +#[derive(Debug, Deserialize, Clone)] +#[serde(deny_unknown_fields)] +struct ValidRerank { + query: String, + model: String, + #[serde(default = "default_num_documents_to_rerank")] + num_documents_to_rerank: u64, + parameters: Option, +} + const fn default_limit() -> u64 { 10 } @@ -106,7 +122,11 @@ pub async fn build_search_query( // Build the CTE we actually use later let embeddings_table = format!("{}_{}.{}_embeddings", collection.name, pipeline.name, key); let chunks_table = format!("{}_{}.{}_chunks", collection.name, pipeline.name, key); - let cte_name = format!("{key}_embedding_score"); + let cte_name = if vsa.rerank.is_some() { + format!("pre_rerank_{key}_embedding_score") + } else { + format!("{key}_embedding_score") + }; let boost = vsa.boost.unwrap_or(1.); let mut score_cte_non_recursive = Query::select(); let mut score_cte_recurisive = Query::select(); @@ -295,18 +315,78 @@ pub async fn build_search_query( .from_subquery(score_cte_non_recursive, Alias::new("non_recursive")) .union(sea_query::UnionType::All, score_cte_recurisive) .to_owned(); - let mut score_cte = CommonTableExpression::from_select(score_cte); score_cte.table_name(Alias::new(&cte_name)); with_clause.cte(score_cte); + if let Some(rerank) = vsa.rerank { + // Add our row_number_pre_rerank CTE + let mut row_number_pre_rerank = Query::select(); + row_number_pre_rerank + .column(SIden::Str("id")) + .from(SIden::String(cte_name.clone())) + .expr_as(Expr::cust("ROW_NUMBER() OVER ()"), Alias::new("row_number")) + .limit(rerank.num_documents_to_rerank); + let mut row_number_pre_rerank_cte = + CommonTableExpression::from_select(row_number_pre_rerank); + row_number_pre_rerank_cte.table_name(Alias::new(format!("row_number_{cte_name}"))); + with_clause.cte(row_number_pre_rerank_cte); + + // Our actual CTE + let mut query = Query::select(); + query.column(SIden::Str("id")); + query.expr_as(Expr::cust("(rank).score"), Alias::new("score")); + + // Build the actual CTE + let mut sub_query_rank_call = Query::select(); + let model_expr = Expr::cust_with_values("$1", [rerank.model.clone()]); + let query_expr = Expr::cust_with_values("$1", [rerank.query.clone()]); + let parameters_expr = + Expr::cust_with_values("$1", [rerank.parameters.clone().unwrap_or_default().0]); + sub_query_rank_call.expr_as(Expr::cust_with_exprs( + format!(r#"pgml.rank($1, $2, array_agg("chunk"), '{{"return_documents": false, "top_k": {}}}'::jsonb || $3)"#, valid_query.limit), + [model_expr, query_expr, parameters_expr], + ), Alias::new("rank")) + .from(SIden::String(format!("row_number_{cte_name}"))) + .join_as( + JoinType::InnerJoin, + chunks_table.to_table_tuple(), + Alias::new("chunks"), + Expr::col((SIden::Str("chunks"), SIden::Str("id"))) + .equals((SIden::String(format!("row_number_{cte_name}")), SIden::Str("id"))), + ); + + let mut sub_query = Query::select(); + sub_query + .columns([SIden::Str("id"), SIden::Str("rank")]) + .from_as( + SIden::String(format!("row_number_{cte_name}")), + Alias::new("rnsv1"), + ) + .join_subquery( + JoinType::InnerJoin, + sub_query_rank_call, + Alias::new("rnsv2"), + Expr::cust("((rank).corpus_id + 1) = rnsv1.row_number"), + ); + + query.from_subquery(sub_query, Alias::new("sub_query")); + let mut query_cte = CommonTableExpression::from_select(query); + query_cte.table_name(Alias::new(format!("{key}_embedding_score"))); + with_clause.cte(query_cte); + } + // Add to the sum expression sum_expression = if let Some(expr) = sum_expression { - Some(expr.add(Expr::cust(format!(r#"COALESCE("{cte_name}".score, 0.0)"#)))) + Some(expr.add(Expr::cust(format!( + r#"COALESCE("{key}_embedding_score".score, 0.0)"# + )))) } else { - Some(Expr::cust(format!(r#"COALESCE("{cte_name}".score, 0.0)"#))) + Some(Expr::cust(format!( + r#"COALESCE("{key}_embedding_score".score, 0.0)"# + ))) }; - score_table_names.push(cte_name); + score_table_names.push(format!("{key}_embedding_score")); } for (key, vma) in valid_query.query.full_text_search.unwrap_or_default() { @@ -315,7 +395,11 @@ pub async fn build_search_query( let boost = vma.boost.unwrap_or(1.0); // Build the score CTE - let cte_name = format!("{key}_tsvectors_score"); + let cte_name = if vma.rerank.is_some() { + format!("pre_rerank_{key}_tsvectors_score") + } else { + format!("{key}_tsvectors_score") + }; let mut score_cte_non_recursive = Query::select() .column((SIden::Str("documents"), SIden::Str("id"))) @@ -425,13 +509,74 @@ pub async fn build_search_query( score_cte.table_name(Alias::new(&cte_name)); with_clause.cte(score_cte); + if let Some(rerank) = vma.rerank { + // Add our row_number_pre_rerank CTE + let mut row_number_pre_rerank = Query::select(); + row_number_pre_rerank + .column(SIden::Str("id")) + .from(SIden::String(cte_name.clone())) + .expr_as(Expr::cust("ROW_NUMBER() OVER ()"), Alias::new("row_number")) + .limit(rerank.num_documents_to_rerank); + let mut row_number_pre_rerank_cte = + CommonTableExpression::from_select(row_number_pre_rerank); + row_number_pre_rerank_cte.table_name(Alias::new(format!("row_number_{cte_name}"))); + with_clause.cte(row_number_pre_rerank_cte); + + // Our actual CTE + let mut query = Query::select(); + query.column(SIden::Str("id")); + query.expr_as(Expr::cust("(rank).score"), Alias::new("score")); + + // Build the actual CTE + let mut sub_query_rank_call = Query::select(); + let model_expr = Expr::cust_with_values("$1", [rerank.model.clone()]); + let query_expr = Expr::cust_with_values("$1", [rerank.query.clone()]); + let parameters_expr = + Expr::cust_with_values("$1", [rerank.parameters.clone().unwrap_or_default().0]); + sub_query_rank_call.expr_as(Expr::cust_with_exprs( + format!(r#"pgml.rank($1, $2, array_agg("chunk"), '{{"return_documents": false, "top_k": {}}}'::jsonb || $3)"#, valid_query.limit), + [model_expr, query_expr, parameters_expr], + ), Alias::new("rank")) + .from(SIden::String(format!("row_number_{cte_name}"))) + .join_as( + JoinType::InnerJoin, + chunks_table.to_table_tuple(), + Alias::new("chunks"), + Expr::col((SIden::Str("chunks"), SIden::Str("id"))) + .equals((SIden::String(format!("row_number_{cte_name}")), SIden::Str("id"))), + ); + + let mut sub_query = Query::select(); + sub_query + .columns([SIden::Str("id"), SIden::Str("rank")]) + .from_as( + SIden::String(format!("row_number_{cte_name}")), + Alias::new("rnsv1"), + ) + .join_subquery( + JoinType::InnerJoin, + sub_query_rank_call, + Alias::new("rnsv2"), + Expr::cust("((rank).corpus_id + 1) = rnsv1.row_number"), + ); + + query.from_subquery(sub_query, Alias::new("sub_query")); + let mut query_cte = CommonTableExpression::from_select(query); + query_cte.table_name(Alias::new(format!("{key}_tsvectors_score"))); + with_clause.cte(query_cte); + } + // Add to the sum expression sum_expression = if let Some(expr) = sum_expression { - Some(expr.add(Expr::cust(format!(r#"COALESCE("{cte_name}".score, 0.0)"#)))) + Some(expr.add(Expr::cust(format!( + r#"COALESCE("{key}_tsvectors_score".score, 0.0)"# + )))) } else { - Some(Expr::cust(format!(r#"COALESCE("{cte_name}".score, 0.0)"#))) + Some(Expr::cust(format!( + r#"COALESCE("{key}_tsvectors_score".score, 0.0)"# + ))) }; - score_table_names.push(cte_name); + score_table_names.push(format!("{key}_tsvectors_score")); } let query = if let Some(select_from) = score_table_names.first() { @@ -440,9 +585,9 @@ pub async fn build_search_query( .into_iter() .map(|t| Expr::col((SIden::String(t), SIden::Str("id"))).into()) .collect(); - let mut main_query = Query::select(); + let mut joined_query = Query::select(); for i in 1..score_table_names_e.len() { - main_query.full_outer_join( + joined_query.full_outer_join( SIden::String(score_table_names[i].to_string()), Expr::col(( SIden::String(score_table_names[i].to_string()), @@ -455,7 +600,8 @@ pub async fn build_search_query( let sum_expression = sum_expression .context("query requires some scoring through full_text_search or semantic_search")?; - main_query + + joined_query .expr_as(Expr::expr(id_select_expression.clone()), Alias::new("id")) .expr_as(sum_expression, Alias::new("score")) .column(SIden::Str("document")) @@ -468,10 +614,9 @@ pub async fn build_search_query( ) .order_by(SIden::Str("score"), Order::Desc) .limit(limit); - - let mut main_query = CommonTableExpression::from_select(main_query); - main_query.table_name(Alias::new("main")); - with_clause.cte(main_query); + let mut joined_query = CommonTableExpression::from_select(joined_query); + joined_query.table_name(Alias::new("main")); + with_clause.cte(joined_query); // Insert into searches table let searches_table = format!("{}_{}.searches", collection.name, pipeline.name); From 16a7799cc82bc54476269ac9c01de3b80483f14c Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Fri, 14 Jun 2024 14:01:23 -0700 Subject: [PATCH 2/2] Finalized re-ranking in document search --- pgml-sdks/pgml/src/lib.rs | 6 ++-- pgml-sdks/pgml/src/search_query_builder.rs | 36 +++++++++++----------- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index c95180fc6..30ce09fea 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -980,7 +980,7 @@ mod tests { #[tokio::test] async fn can_search_with_local_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_cswle_123"; + let collection_name = "test_r_c_cswle_126"; let mut collection = Collection::new(collection_name, None)?; let documents = generate_dummy_documents(10); collection.upsert_documents(documents.clone(), None).await?; @@ -1096,7 +1096,7 @@ mod tests { .iter() .map(|r| r["document"]["id"].as_u64().unwrap()) .collect(); - assert_eq!(ids, vec![9, 3, 4, 5, 6]); + assert_eq!(ids, vec![2, 9, 3, 8, 4]); let pool = get_or_initialize_pool(&None).await?; @@ -1121,7 +1121,7 @@ mod tests { // Document ids are 1 based in the db not 0 based like they are here assert_eq!( search_results.iter().map(|sr| sr.2).collect::>(), - vec![10, 4, 5, 8, 6] + vec![3, 10, 4, 9, 5] ); let event = json!({"clicked": true}); diff --git a/pgml-sdks/pgml/src/search_query_builder.rs b/pgml-sdks/pgml/src/search_query_builder.rs index f519add6f..7ca23ff25 100644 --- a/pgml-sdks/pgml/src/search_query_builder.rs +++ b/pgml-sdks/pgml/src/search_query_builder.rs @@ -151,6 +151,7 @@ pub async fn build_search_query( score_cte_non_recursive .from_as(embeddings_table.to_table_tuple(), Alias::new("embeddings")) .column((SIden::Str("documents"), SIden::Str("id"))) + .column((SIden::Str("chunks"), SIden::Str("chunk"))) .join_as( JoinType::InnerJoin, chunks_table.to_table_tuple(), @@ -177,6 +178,7 @@ pub async fn build_search_query( score_cte_recurisive .from_as(embeddings_table.to_table_tuple(), Alias::new("embeddings")) .column((SIden::Str("documents"), SIden::Str("id"))) + .column((SIden::Str("chunks"), SIden::Str("chunk"))) .expr(Expr::cust(format!(r#""{cte_name}".previous_document_ids || documents.id"#))) .expr(Expr::cust(format!( r#"(1 - (embeddings.embedding <=> (SELECT embedding FROM "{key}_embedding")::vector)) * {boost} AS score"# @@ -233,6 +235,7 @@ pub async fn build_search_query( score_cte_non_recursive .from_as(embeddings_table.to_table_tuple(), Alias::new("embeddings")) .column((SIden::Str("documents"), SIden::Str("id"))) + .column((SIden::Str("chunks"), SIden::Str("chunk"))) .expr(Expr::cust("ARRAY[documents.id] as previous_document_ids")) .expr(Expr::cust_with_values( format!("(1 - (embeddings.embedding <=> $1::vector)) * {boost} AS score"), @@ -269,6 +272,7 @@ pub async fn build_search_query( Expr::cust("1 = 1"), ) .column((SIden::Str("documents"), SIden::Str("id"))) + .column((SIden::Str("chunks"), SIden::Str("chunk"))) .expr(Expr::cust(format!( r#""{cte_name}".previous_document_ids || documents.id"# ))) @@ -324,6 +328,7 @@ pub async fn build_search_query( let mut row_number_pre_rerank = Query::select(); row_number_pre_rerank .column(SIden::Str("id")) + .column(SIden::Str("chunk")) .from(SIden::String(cte_name.clone())) .expr_as(Expr::cust("ROW_NUMBER() OVER ()"), Alias::new("row_number")) .limit(rerank.num_documents_to_rerank); @@ -335,7 +340,10 @@ pub async fn build_search_query( // Our actual CTE let mut query = Query::select(); query.column(SIden::Str("id")); - query.expr_as(Expr::cust("(rank).score"), Alias::new("score")); + query.expr_as( + Expr::cust(format!("(rank).score * {boost}")), + Alias::new("score"), + ); // Build the actual CTE let mut sub_query_rank_call = Query::select(); @@ -347,14 +355,7 @@ pub async fn build_search_query( format!(r#"pgml.rank($1, $2, array_agg("chunk"), '{{"return_documents": false, "top_k": {}}}'::jsonb || $3)"#, valid_query.limit), [model_expr, query_expr, parameters_expr], ), Alias::new("rank")) - .from(SIden::String(format!("row_number_{cte_name}"))) - .join_as( - JoinType::InnerJoin, - chunks_table.to_table_tuple(), - Alias::new("chunks"), - Expr::col((SIden::Str("chunks"), SIden::Str("id"))) - .equals((SIden::String(format!("row_number_{cte_name}")), SIden::Str("id"))), - ); + .from(SIden::String(format!("row_number_{cte_name}"))); let mut sub_query = Query::select(); sub_query @@ -403,6 +404,7 @@ pub async fn build_search_query( let mut score_cte_non_recursive = Query::select() .column((SIden::Str("documents"), SIden::Str("id"))) + .column((SIden::Str("chunks"), SIden::Str("chunk"))) .expr_as( Expr::cust_with_values( format!( @@ -445,6 +447,7 @@ pub async fn build_search_query( let mut score_cte_recursive = Query::select() .column((SIden::Str("documents"), SIden::Str("id"))) + .column((SIden::Str("chunks"), SIden::Str("chunk"))) .expr_as( Expr::cust_with_values( format!( @@ -514,6 +517,7 @@ pub async fn build_search_query( let mut row_number_pre_rerank = Query::select(); row_number_pre_rerank .column(SIden::Str("id")) + .column(SIden::Str("chunk")) .from(SIden::String(cte_name.clone())) .expr_as(Expr::cust("ROW_NUMBER() OVER ()"), Alias::new("row_number")) .limit(rerank.num_documents_to_rerank); @@ -525,7 +529,10 @@ pub async fn build_search_query( // Our actual CTE let mut query = Query::select(); query.column(SIden::Str("id")); - query.expr_as(Expr::cust("(rank).score"), Alias::new("score")); + query.expr_as( + Expr::cust(format!("(rank).score * {boost}")), + Alias::new("score"), + ); // Build the actual CTE let mut sub_query_rank_call = Query::select(); @@ -537,14 +544,7 @@ pub async fn build_search_query( format!(r#"pgml.rank($1, $2, array_agg("chunk"), '{{"return_documents": false, "top_k": {}}}'::jsonb || $3)"#, valid_query.limit), [model_expr, query_expr, parameters_expr], ), Alias::new("rank")) - .from(SIden::String(format!("row_number_{cte_name}"))) - .join_as( - JoinType::InnerJoin, - chunks_table.to_table_tuple(), - Alias::new("chunks"), - Expr::col((SIden::Str("chunks"), SIden::Str("id"))) - .equals((SIden::String(format!("row_number_{cte_name}")), SIden::Str("id"))), - ); + .from(SIden::String(format!("row_number_{cte_name}"))); let mut sub_query = Query::select(); sub_query