From d36496224fb35dc52d2e4d7747a4aba8c45499ed Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Fri, 14 Jun 2024 11:39:51 -0700 Subject: [PATCH 1/2] Added re-ranking into document search --- pgml-sdks/pgml/Cargo.lock | 2 +- pgml-sdks/pgml/src/lib.rs | 16 +- pgml-sdks/pgml/src/search_query_builder.rs | 177 +++++++++++++++++++-- 3 files changed, 175 insertions(+), 20 deletions(-) diff --git a/pgml-sdks/pgml/Cargo.lock b/pgml-sdks/pgml/Cargo.lock index 784b528a7..8de1d3967 100644 --- a/pgml-sdks/pgml/Cargo.lock +++ b/pgml-sdks/pgml/Cargo.lock @@ -1590,7 +1590,7 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "pgml" -version = "1.0.4" +version = "1.1.0" dependencies = [ "anyhow", "async-trait", diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index 16ec25ece..c95180fc6 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -1038,7 +1038,12 @@ mod tests { "full_text_search": { "title": { "query": "test 9", - "boost": 4.0 + "boost": 4.0, + "rerank": { + "query": "Test document 2", + "model": "mixedbread-ai/mxbai-rerank-base-v1", + "num_documents_to_rerank": 100 + } }, "body": { "query": "Test", @@ -1051,7 +1056,12 @@ mod tests { "parameters": { "prompt": "query: ", }, - "boost": 2.0 + "boost": 2.0, + "rerank": { + "query": "Test document 2", + "model": "mixedbread-ai/mxbai-rerank-base-v1", + "num_documents_to_rerank": 100 + } }, "body": { "query": "This is the body test", @@ -1086,7 +1096,7 @@ mod tests { .iter() .map(|r| r["document"]["id"].as_u64().unwrap()) .collect(); - assert_eq!(ids, vec![9, 3, 4, 7, 5]); + assert_eq!(ids, vec![9, 3, 4, 5, 6]); let pool = get_or_initialize_pool(&None).await?; diff --git a/pgml-sdks/pgml/src/search_query_builder.rs b/pgml-sdks/pgml/src/search_query_builder.rs index e76371541..f519add6f 100644 --- a/pgml-sdks/pgml/src/search_query_builder.rs +++ b/pgml-sdks/pgml/src/search_query_builder.rs @@ -25,6 
+25,7 @@ struct ValidSemanticSearchAction { query: String, parameters: Option, boost: Option, + rerank: Option, } #[derive(Debug, Deserialize)] @@ -32,6 +33,7 @@ struct ValidSemanticSearchAction { struct ValidFullTextSearchAction { query: String, boost: Option, + rerank: Option, } #[derive(Debug, Deserialize)] @@ -42,6 +44,20 @@ struct ValidQueryActions { filter: Option, } +const fn default_num_documents_to_rerank() -> u64 { + 10 +} + +#[derive(Debug, Deserialize, Clone)] +#[serde(deny_unknown_fields)] +struct ValidRerank { + query: String, + model: String, + #[serde(default = "default_num_documents_to_rerank")] + num_documents_to_rerank: u64, + parameters: Option, +} + const fn default_limit() -> u64 { 10 } @@ -106,7 +122,11 @@ pub async fn build_search_query( // Build the CTE we actually use later let embeddings_table = format!("{}_{}.{}_embeddings", collection.name, pipeline.name, key); let chunks_table = format!("{}_{}.{}_chunks", collection.name, pipeline.name, key); - let cte_name = format!("{key}_embedding_score"); + let cte_name = if vsa.rerank.is_some() { + format!("pre_rerank_{key}_embedding_score") + } else { + format!("{key}_embedding_score") + }; let boost = vsa.boost.unwrap_or(1.); let mut score_cte_non_recursive = Query::select(); let mut score_cte_recurisive = Query::select(); @@ -295,18 +315,78 @@ pub async fn build_search_query( .from_subquery(score_cte_non_recursive, Alias::new("non_recursive")) .union(sea_query::UnionType::All, score_cte_recurisive) .to_owned(); - let mut score_cte = CommonTableExpression::from_select(score_cte); score_cte.table_name(Alias::new(&cte_name)); with_clause.cte(score_cte); + if let Some(rerank) = vsa.rerank { + // Add our row_number_pre_rerank CTE + let mut row_number_pre_rerank = Query::select(); + row_number_pre_rerank + .column(SIden::Str("id")) + .from(SIden::String(cte_name.clone())) + .expr_as(Expr::cust("ROW_NUMBER() OVER ()"), Alias::new("row_number")) + .limit(rerank.num_documents_to_rerank); + let mut 
row_number_pre_rerank_cte = + CommonTableExpression::from_select(row_number_pre_rerank); + row_number_pre_rerank_cte.table_name(Alias::new(format!("row_number_{cte_name}"))); + with_clause.cte(row_number_pre_rerank_cte); + + // Our actual CTE + let mut query = Query::select(); + query.column(SIden::Str("id")); + query.expr_as(Expr::cust("(rank).score"), Alias::new("score")); + + // Build the actual CTE + let mut sub_query_rank_call = Query::select(); + let model_expr = Expr::cust_with_values("$1", [rerank.model.clone()]); + let query_expr = Expr::cust_with_values("$1", [rerank.query.clone()]); + let parameters_expr = + Expr::cust_with_values("$1", [rerank.parameters.clone().unwrap_or_default().0]); + sub_query_rank_call.expr_as(Expr::cust_with_exprs( + format!(r#"pgml.rank($1, $2, array_agg("chunk"), '{{"return_documents": false, "top_k": {}}}'::jsonb || $3)"#, valid_query.limit), + [model_expr, query_expr, parameters_expr], + ), Alias::new("rank")) + .from(SIden::String(format!("row_number_{cte_name}"))) + .join_as( + JoinType::InnerJoin, + chunks_table.to_table_tuple(), + Alias::new("chunks"), + Expr::col((SIden::Str("chunks"), SIden::Str("id"))) + .equals((SIden::String(format!("row_number_{cte_name}")), SIden::Str("id"))), + ); + + let mut sub_query = Query::select(); + sub_query + .columns([SIden::Str("id"), SIden::Str("rank")]) + .from_as( + SIden::String(format!("row_number_{cte_name}")), + Alias::new("rnsv1"), + ) + .join_subquery( + JoinType::InnerJoin, + sub_query_rank_call, + Alias::new("rnsv2"), + Expr::cust("((rank).corpus_id + 1) = rnsv1.row_number"), + ); + + query.from_subquery(sub_query, Alias::new("sub_query")); + let mut query_cte = CommonTableExpression::from_select(query); + query_cte.table_name(Alias::new(format!("{key}_embedding_score"))); + with_clause.cte(query_cte); + } + // Add to the sum expression sum_expression = if let Some(expr) = sum_expression { - Some(expr.add(Expr::cust(format!(r#"COALESCE("{cte_name}".score, 0.0)"#)))) + 
Some(expr.add(Expr::cust(format!( + r#"COALESCE("{key}_embedding_score".score, 0.0)"# + )))) } else { - Some(Expr::cust(format!(r#"COALESCE("{cte_name}".score, 0.0)"#))) + Some(Expr::cust(format!( + r#"COALESCE("{key}_embedding_score".score, 0.0)"# + ))) }; - score_table_names.push(cte_name); + score_table_names.push(format!("{key}_embedding_score")); } for (key, vma) in valid_query.query.full_text_search.unwrap_or_default() { @@ -315,7 +395,11 @@ pub async fn build_search_query( let boost = vma.boost.unwrap_or(1.0); // Build the score CTE - let cte_name = format!("{key}_tsvectors_score"); + let cte_name = if vma.rerank.is_some() { + format!("pre_rerank_{key}_tsvectors_score") + } else { + format!("{key}_tsvectors_score") + }; let mut score_cte_non_recursive = Query::select() .column((SIden::Str("documents"), SIden::Str("id"))) @@ -425,13 +509,74 @@ pub async fn build_search_query( score_cte.table_name(Alias::new(&cte_name)); with_clause.cte(score_cte); + if let Some(rerank) = vma.rerank { + // Add our row_number_pre_rerank CTE + let mut row_number_pre_rerank = Query::select(); + row_number_pre_rerank + .column(SIden::Str("id")) + .from(SIden::String(cte_name.clone())) + .expr_as(Expr::cust("ROW_NUMBER() OVER ()"), Alias::new("row_number")) + .limit(rerank.num_documents_to_rerank); + let mut row_number_pre_rerank_cte = + CommonTableExpression::from_select(row_number_pre_rerank); + row_number_pre_rerank_cte.table_name(Alias::new(format!("row_number_{cte_name}"))); + with_clause.cte(row_number_pre_rerank_cte); + + // Our actual CTE + let mut query = Query::select(); + query.column(SIden::Str("id")); + query.expr_as(Expr::cust("(rank).score"), Alias::new("score")); + + // Build the actual CTE + let mut sub_query_rank_call = Query::select(); + let model_expr = Expr::cust_with_values("$1", [rerank.model.clone()]); + let query_expr = Expr::cust_with_values("$1", [rerank.query.clone()]); + let parameters_expr = + Expr::cust_with_values("$1", 
[rerank.parameters.clone().unwrap_or_default().0]); + sub_query_rank_call.expr_as(Expr::cust_with_exprs( + format!(r#"pgml.rank($1, $2, array_agg("chunk"), '{{"return_documents": false, "top_k": {}}}'::jsonb || $3)"#, valid_query.limit), + [model_expr, query_expr, parameters_expr], + ), Alias::new("rank")) + .from(SIden::String(format!("row_number_{cte_name}"))) + .join_as( + JoinType::InnerJoin, + chunks_table.to_table_tuple(), + Alias::new("chunks"), + Expr::col((SIden::Str("chunks"), SIden::Str("id"))) + .equals((SIden::String(format!("row_number_{cte_name}")), SIden::Str("id"))), + ); + + let mut sub_query = Query::select(); + sub_query + .columns([SIden::Str("id"), SIden::Str("rank")]) + .from_as( + SIden::String(format!("row_number_{cte_name}")), + Alias::new("rnsv1"), + ) + .join_subquery( + JoinType::InnerJoin, + sub_query_rank_call, + Alias::new("rnsv2"), + Expr::cust("((rank).corpus_id + 1) = rnsv1.row_number"), + ); + + query.from_subquery(sub_query, Alias::new("sub_query")); + let mut query_cte = CommonTableExpression::from_select(query); + query_cte.table_name(Alias::new(format!("{key}_tsvectors_score"))); + with_clause.cte(query_cte); + } + // Add to the sum expression sum_expression = if let Some(expr) = sum_expression { - Some(expr.add(Expr::cust(format!(r#"COALESCE("{cte_name}".score, 0.0)"#)))) + Some(expr.add(Expr::cust(format!( + r#"COALESCE("{key}_tsvectors_score".score, 0.0)"# + )))) } else { - Some(Expr::cust(format!(r#"COALESCE("{cte_name}".score, 0.0)"#))) + Some(Expr::cust(format!( + r#"COALESCE("{key}_tsvectors_score".score, 0.0)"# + ))) }; - score_table_names.push(cte_name); + score_table_names.push(format!("{key}_tsvectors_score")); } let query = if let Some(select_from) = score_table_names.first() { @@ -440,9 +585,9 @@ pub async fn build_search_query( .into_iter() .map(|t| Expr::col((SIden::String(t), SIden::Str("id"))).into()) .collect(); - let mut main_query = Query::select(); + let mut joined_query = Query::select(); for i in 
1..score_table_names_e.len() { - main_query.full_outer_join( + joined_query.full_outer_join( SIden::String(score_table_names[i].to_string()), Expr::col(( SIden::String(score_table_names[i].to_string()), @@ -455,7 +600,8 @@ pub async fn build_search_query( let sum_expression = sum_expression .context("query requires some scoring through full_text_search or semantic_search")?; - main_query + + joined_query .expr_as(Expr::expr(id_select_expression.clone()), Alias::new("id")) .expr_as(sum_expression, Alias::new("score")) .column(SIden::Str("document")) @@ -468,10 +614,9 @@ pub async fn build_search_query( ) .order_by(SIden::Str("score"), Order::Desc) .limit(limit); - - let mut main_query = CommonTableExpression::from_select(main_query); - main_query.table_name(Alias::new("main")); - with_clause.cte(main_query); + let mut joined_query = CommonTableExpression::from_select(joined_query); + joined_query.table_name(Alias::new("main")); + with_clause.cte(joined_query); // Insert into searches table let searches_table = format!("{}_{}.searches", collection.name, pipeline.name); From 16a7799cc82bc54476269ac9c01de3b80483f14c Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Fri, 14 Jun 2024 14:01:23 -0700 Subject: [PATCH 2/2] Finalized re-ranking in document search --- pgml-sdks/pgml/src/lib.rs | 6 ++-- pgml-sdks/pgml/src/search_query_builder.rs | 36 +++++++++++----------- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index c95180fc6..30ce09fea 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -980,7 +980,7 @@ mod tests { #[tokio::test] async fn can_search_with_local_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_cswle_123"; + let collection_name = "test_r_c_cswle_126"; let mut collection = Collection::new(collection_name, None)?; let documents = generate_dummy_documents(10); 
collection.upsert_documents(documents.clone(), None).await?; @@ -1096,7 +1096,7 @@ mod tests { .iter() .map(|r| r["document"]["id"].as_u64().unwrap()) .collect(); - assert_eq!(ids, vec![9, 3, 4, 5, 6]); + assert_eq!(ids, vec![2, 9, 3, 8, 4]); let pool = get_or_initialize_pool(&None).await?; @@ -1121,7 +1121,7 @@ mod tests { // Document ids are 1 based in the db not 0 based like they are here assert_eq!( search_results.iter().map(|sr| sr.2).collect::>(), - vec![10, 4, 5, 8, 6] + vec![3, 10, 4, 9, 5] ); let event = json!({"clicked": true}); diff --git a/pgml-sdks/pgml/src/search_query_builder.rs b/pgml-sdks/pgml/src/search_query_builder.rs index f519add6f..7ca23ff25 100644 --- a/pgml-sdks/pgml/src/search_query_builder.rs +++ b/pgml-sdks/pgml/src/search_query_builder.rs @@ -151,6 +151,7 @@ pub async fn build_search_query( score_cte_non_recursive .from_as(embeddings_table.to_table_tuple(), Alias::new("embeddings")) .column((SIden::Str("documents"), SIden::Str("id"))) + .column((SIden::Str("chunks"), SIden::Str("chunk"))) .join_as( JoinType::InnerJoin, chunks_table.to_table_tuple(), @@ -177,6 +178,7 @@ pub async fn build_search_query( score_cte_recurisive .from_as(embeddings_table.to_table_tuple(), Alias::new("embeddings")) .column((SIden::Str("documents"), SIden::Str("id"))) + .column((SIden::Str("chunks"), SIden::Str("chunk"))) .expr(Expr::cust(format!(r#""{cte_name}".previous_document_ids || documents.id"#))) .expr(Expr::cust(format!( r#"(1 - (embeddings.embedding <=> (SELECT embedding FROM "{key}_embedding")::vector)) * {boost} AS score"# @@ -233,6 +235,7 @@ pub async fn build_search_query( score_cte_non_recursive .from_as(embeddings_table.to_table_tuple(), Alias::new("embeddings")) .column((SIden::Str("documents"), SIden::Str("id"))) + .column((SIden::Str("chunks"), SIden::Str("chunk"))) .expr(Expr::cust("ARRAY[documents.id] as previous_document_ids")) .expr(Expr::cust_with_values( format!("(1 - (embeddings.embedding <=> $1::vector)) * {boost} AS score"), @@ -269,6 
+272,7 @@ pub async fn build_search_query( Expr::cust("1 = 1"), ) .column((SIden::Str("documents"), SIden::Str("id"))) + .column((SIden::Str("chunks"), SIden::Str("chunk"))) .expr(Expr::cust(format!( r#""{cte_name}".previous_document_ids || documents.id"# ))) @@ -324,6 +328,7 @@ pub async fn build_search_query( let mut row_number_pre_rerank = Query::select(); row_number_pre_rerank .column(SIden::Str("id")) + .column(SIden::Str("chunk")) .from(SIden::String(cte_name.clone())) .expr_as(Expr::cust("ROW_NUMBER() OVER ()"), Alias::new("row_number")) .limit(rerank.num_documents_to_rerank); @@ -335,7 +340,10 @@ pub async fn build_search_query( // Our actual CTE let mut query = Query::select(); query.column(SIden::Str("id")); - query.expr_as(Expr::cust("(rank).score"), Alias::new("score")); + query.expr_as( + Expr::cust(format!("(rank).score * {boost}")), + Alias::new("score"), + ); // Build the actual CTE let mut sub_query_rank_call = Query::select(); @@ -347,14 +355,7 @@ pub async fn build_search_query( format!(r#"pgml.rank($1, $2, array_agg("chunk"), '{{"return_documents": false, "top_k": {}}}'::jsonb || $3)"#, valid_query.limit), [model_expr, query_expr, parameters_expr], ), Alias::new("rank")) - .from(SIden::String(format!("row_number_{cte_name}"))) - .join_as( - JoinType::InnerJoin, - chunks_table.to_table_tuple(), - Alias::new("chunks"), - Expr::col((SIden::Str("chunks"), SIden::Str("id"))) - .equals((SIden::String(format!("row_number_{cte_name}")), SIden::Str("id"))), - ); + .from(SIden::String(format!("row_number_{cte_name}"))); let mut sub_query = Query::select(); sub_query @@ -403,6 +404,7 @@ pub async fn build_search_query( let mut score_cte_non_recursive = Query::select() .column((SIden::Str("documents"), SIden::Str("id"))) + .column((SIden::Str("chunks"), SIden::Str("chunk"))) .expr_as( Expr::cust_with_values( format!( @@ -445,6 +447,7 @@ pub async fn build_search_query( let mut score_cte_recursive = Query::select() .column((SIden::Str("documents"), 
SIden::Str("id"))) + .column((SIden::Str("chunks"), SIden::Str("chunk"))) .expr_as( Expr::cust_with_values( format!( @@ -514,6 +517,7 @@ pub async fn build_search_query( let mut row_number_pre_rerank = Query::select(); row_number_pre_rerank .column(SIden::Str("id")) + .column(SIden::Str("chunk")) .from(SIden::String(cte_name.clone())) .expr_as(Expr::cust("ROW_NUMBER() OVER ()"), Alias::new("row_number")) .limit(rerank.num_documents_to_rerank); @@ -525,7 +529,10 @@ pub async fn build_search_query( // Our actual CTE let mut query = Query::select(); query.column(SIden::Str("id")); - query.expr_as(Expr::cust("(rank).score"), Alias::new("score")); + query.expr_as( + Expr::cust(format!("(rank).score * {boost}")), + Alias::new("score"), + ); // Build the actual CTE let mut sub_query_rank_call = Query::select(); @@ -537,14 +544,7 @@ pub async fn build_search_query( format!(r#"pgml.rank($1, $2, array_agg("chunk"), '{{"return_documents": false, "top_k": {}}}'::jsonb || $3)"#, valid_query.limit), [model_expr, query_expr, parameters_expr], ), Alias::new("rank")) - .from(SIden::String(format!("row_number_{cte_name}"))) - .join_as( - JoinType::InnerJoin, - chunks_table.to_table_tuple(), - Alias::new("chunks"), - Expr::col((SIden::Str("chunks"), SIden::Str("id"))) - .equals((SIden::String(format!("row_number_{cte_name}")), SIden::Str("id"))), - ); + .from(SIden::String(format!("row_number_{cte_name}"))); let mut sub_query = Query::select(); sub_query pFad - Phonifier reborn

pFad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy