Skip to content

SDK - Added re-ranking into vector search #1516

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions pgml-sdks/pgml/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

21 changes: 13 additions & 8 deletions pgml-sdks/pgml/src/collection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1051,6 +1051,7 @@ impl Collection {
/// }).into(), &mut pipeline).await?;
/// Ok(())
/// }
#[allow(clippy::type_complexity)]
#[instrument(skip(self))]
pub async fn vector_search(
&mut self,
Expand All @@ -1061,7 +1062,7 @@ impl Collection {

let (built_query, values) =
build_vector_search_query(query.clone(), self, pipeline).await?;
let results: Result<Vec<(Json, String, f64)>, _> =
let results: Result<Vec<(Json, String, f64, Option<f64>)>, _> =
sqlx::query_as_with(&built_query, values)
.fetch_all(&pool)
.await;
Expand All @@ -1072,7 +1073,8 @@ impl Collection {
serde_json::json!({
"document": v.0,
"chunk": v.1,
"score": v.2
"score": v.2,
"rerank_score": v.3
})
.into()
})
Expand All @@ -1087,7 +1089,7 @@ impl Collection {
.await?;
let (built_query, values) =
build_vector_search_query(query, self, pipeline).await?;
let results: Vec<(Json, String, f64)> =
let results: Vec<(Json, String, f64, Option<f64>)> =
sqlx::query_as_with(&built_query, values)
.fetch_all(&pool)
.await?;
Expand All @@ -1097,7 +1099,8 @@ impl Collection {
serde_json::json!({
"document": v.0,
"chunk": v.1,
"score": v.2
"score": v.2,
"rerank_score": v.3
})
.into()
})
Expand All @@ -1121,16 +1124,18 @@ impl Collection {
let pool = get_or_initialize_pool(&self.database_url).await?;
let (built_query, values) =
build_vector_search_query(query.clone(), self, pipeline).await?;
let results: Vec<(Json, String, f64)> = sqlx::query_as_with(&built_query, values)
.fetch_all(&pool)
.await?;
let results: Vec<(Json, String, f64, Option<f64>)> =
sqlx::query_as_with(&built_query, values)
.fetch_all(&pool)
.await?;
Ok(results
.into_iter()
.map(|v| {
serde_json::json!({
"document": v.0,
"chunk": v.1,
"score": v.2
"score": v.2,
"rerank_score": v.3
})
.into()
})
Expand Down
87 changes: 87 additions & 0 deletions pgml-sdks/pgml/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1553,6 +1553,88 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn can_vector_search_with_local_embeddings_and_rerank() -> anyhow::Result<()> {
internal_init_logger(None, None).ok();
let collection_name = "test r_c_cvswlear_1";
let mut collection = Collection::new(collection_name, None)?;
let documents = generate_dummy_documents(10);
collection.upsert_documents(documents.clone(), None).await?;
let pipeline_name = "0";
let mut pipeline = Pipeline::new(
pipeline_name,
Some(
json!({
"title": {
"semantic_search": {
"model": "intfloat/e5-small-v2",
"parameters": {
"prompt": "passage: "
}
},
"full_text_search": {
"configuration": "english"
}
},
"body": {
"splitter": {
"model": "recursive_character"
},
"semantic_search": {
"model": "intfloat/e5-small-v2",
"parameters": {
"prompt": "passage: "
}
},
},
})
.into(),
),
)?;
collection.add_pipeline(&mut pipeline).await?;
let results = collection
.vector_search(
json!({
"query": {
"fields": {
"title": {
"query": "Test document: 2",
"parameters": {
"prompt": "passage: "
},
"full_text_filter": "test",
"boost": 1.2
},
"body": {
"query": "Test document: 2",
"parameters": {
"prompt": "passage: "
},
"boost": 1.0
},
}
},
Copy link
Contributor Author

@SilasMarvin SilasMarvin Jun 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@montanalow How does this "rerank" key look?

query is the text to compare against.
model is the model to use
num_documents_to_rerank are the number of results to return from vector search and rerank against before limiting it to the limit parameter defined in the next section

"rerank": {
"query": "Test document 2",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like query is being repeated a few places in this example, which may be pretty typical. One enhancement would be to move the query string out and reuse it everywhere, and make passing specific sub clause query strings optional. Not a launch blocker though.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it, I will think more on making that optional and reusing it, but will merge this and get it out in the meantime.

"model": "mixedbread-ai/mxbai-rerank-base-v1",
"num_documents_to_rerank": 100
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about calling this just limit. Does llamaindex or transformers have a similarly named parameter name?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh sorry missed this before merging. I think it might be a little confusing if we make it limit as we already have a limit key, and this isn't actually the limit. We already defined limit with llama index to mean the final number of items returned, but I'm not sure if they or langchain use it elsewhere.

},
"limit": 5
})
.into(),
&mut pipeline,
)
.await?;
assert!(results[0]["rerank_score"].as_f64().is_some());
let ids: Vec<u64> = results
.into_iter()
.map(|r| r["document"]["id"].as_u64().unwrap())
.collect();
assert_eq!(ids, vec![2, 1, 3, 8, 6]);
collection.archive().await?;
Ok(())
}

///////////////////////////////
// Working With Documents /////
///////////////////////////////
Expand Down Expand Up @@ -2207,6 +2289,11 @@ mod tests {
"id"
]
},
"rerank": {
"query": "Test document 2",
"model": "mixedbread-ai/mxbai-rerank-base-v1",
"num_documents_to_rerank": 100
},
"limit": 5
},
"aggregate": {
Expand Down
4 changes: 1 addition & 3 deletions pgml-sdks/pgml/src/rag_query_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,9 +212,7 @@ pub async fn build_rag_query(
r#"(SELECT string_agg(chunk, '{}') FROM "{var_name}")"#,
vector_search.aggregate.join
),
format!(
r#"(SELECT json_agg(jsonb_build_object('chunk', chunk, 'document', document, 'score', score)) FROM "{var_name}")"#
),
format!(r#"(SELECT json_agg(j) FROM "{var_name}" j)"#),
)
}
ValidVariable::RawSQL(sql) => (format!("({})", sql.sql), format!("({})", sql.sql)),
Expand Down
99 changes: 96 additions & 3 deletions pgml-sdks/pgml/src/vector_search_query_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,20 @@ struct ValidDocument {
keys: Option<Vec<String>>,
}

const fn default_num_documents_to_rerank() -> u64 {
10
}

#[derive(Debug, Deserialize, Serialize, Clone)]
#[serde(deny_unknown_fields)]
struct ValidRerank {
query: String,
model: String,
#[serde(default = "default_num_documents_to_rerank")]
num_documents_to_rerank: u64,
parameters: Option<Json>,
}

const fn default_limit() -> u64 {
10
}
Expand All @@ -56,6 +70,8 @@ pub struct ValidQuery {
limit: u64,
// Document related items
document: Option<ValidDocument>,
// Rerank related items
rerank: Option<ValidRerank>,
}

pub async fn build_sqlx_query(
Expand All @@ -66,9 +82,14 @@ pub async fn build_sqlx_query(
prefix: Option<&str>,
) -> anyhow::Result<(SelectStatement, Vec<CommonTableExpression>)> {
let valid_query: ValidQuery = serde_json::from_value(query.0)?;
let limit = valid_query.limit;
let fields = valid_query.query.fields.unwrap_or_default();

let search_limit = if let Some(rerank) = valid_query.rerank.as_ref() {
rerank.num_documents_to_rerank
} else {
valid_query.limit
};

let prefix = prefix.unwrap_or("");

if fields.is_empty() {
Expand Down Expand Up @@ -209,7 +230,7 @@ pub async fn build_sqlx_query(
Expr::col((SIden::Str("documents"), SIden::Str("id")))
.equals((SIden::Str("chunks"), SIden::Str("document_id"))),
)
.limit(limit);
.limit(search_limit);

if let Some(filter) = &valid_query.query.filter {
let filter = FilterBuilder::new(filter.clone().0, "documents", "document").build()?;
Expand Down Expand Up @@ -272,7 +293,79 @@ pub async fn build_sqlx_query(
// Resort and limit
query
.order_by(SIden::Str("score"), Order::Desc)
.limit(limit);
.limit(search_limit);

// Rerank
let query = if let Some(rerank) = &valid_query.rerank {
// Add our vector_search CTE
let mut vector_search_cte = CommonTableExpression::from_select(query);
vector_search_cte.table_name(Alias::new(format!("{prefix}_vector_search")));
ctes.push(vector_search_cte);

// Add our row_number_vector_search CTE
let mut row_number_vector_search = Query::select();
row_number_vector_search
.columns([
SIden::Str("document"),
SIden::Str("chunk"),
SIden::Str("score"),
])
.from(SIden::String(format!("{prefix}_vector_search")));
row_number_vector_search
.expr_as(Expr::cust("ROW_NUMBER() OVER ()"), Alias::new("row_number"));
let mut row_number_vector_search_cte =
CommonTableExpression::from_select(row_number_vector_search);
row_number_vector_search_cte
.table_name(Alias::new(format!("{prefix}_row_number_vector_search")));
ctes.push(row_number_vector_search_cte);

// Our actual select statement
let mut query = Query::select();
query.columns([
SIden::Str("document"),
SIden::Str("chunk"),
SIden::Str("score"),
]);
query.expr_as(Expr::cust("(rank).score"), Alias::new("rank_score"));

// Build the actual select statement sub query
let mut sub_query_rank_call = Query::select();
let model_expr = Expr::cust_with_values("$1", [rerank.model.clone()]);
let query_expr = Expr::cust_with_values("$1", [rerank.query.clone()]);
let parameters_expr =
Expr::cust_with_values("$1", [rerank.parameters.clone().unwrap_or_default().0]);
sub_query_rank_call.expr_as(Expr::cust_with_exprs(
format!(r#"pgml.rank($1, $2, array_agg("chunk"), '{{"return_documents": false, "top_k": {}}}'::jsonb || $3)"#, valid_query.limit),
[model_expr, query_expr, parameters_expr],
), Alias::new("rank"))
.from(SIden::String(format!("{prefix}_row_number_vector_search")));

let mut sub_query = Query::select();
sub_query
.columns([
SIden::Str("document"),
SIden::Str("chunk"),
SIden::Str("score"),
SIden::Str("rank"),
])
.from_as(
SIden::String(format!("{prefix}_row_number_vector_search")),
Alias::new("rnsv1"),
)
.join_subquery(
JoinType::InnerJoin,
sub_query_rank_call,
Alias::new("rnsv2"),
Expr::cust("((rank).corpus_id + 1) = rnsv1.row_number"),
);

// Query from the sub query
query.from_subquery(sub_query, Alias::new("sub_query"));

query
} else {
query
};

Ok((query, ctes))
}
Expand Down