diff --git a/pgml-sdks/rust/pgml/src/collection.rs b/pgml-sdks/rust/pgml/src/collection.rs index 2946fdecd..6dd1aa268 100644 --- a/pgml-sdks/rust/pgml/src/collection.rs +++ b/pgml-sdks/rust/pgml/src/collection.rs @@ -13,7 +13,7 @@ use std::collections::HashMap; use crate::languages::javascript::*; use crate::models; use crate::queries; -use crate::{query_builder, transaction_wrapper}; +use crate::query_builder; /// A collection of documents #[derive(custom_derive, Debug, Clone)] @@ -314,18 +314,16 @@ impl Collection { }; let source_uuid = uuid::Uuid::from_slice(&md5_digest.0)?; - transaction_wrapper!( - sqlx::query(&query_builder!( - "INSERT INTO %s (text, source_uuid, metadata) VALUES ($1, $2, $3) ON CONFLICT (source_uuid) DO UPDATE SET text = $4, metadata = $5", - self.documents_table_name - )) - .bind(&text) - .bind(source_uuid) - .bind(&document_json) - .bind(&text) - .bind(&document_json), - self.pool.borrow() - ); + sqlx::query(&query_builder!( + "INSERT INTO %s (text, source_uuid, metadata) VALUES ($1, $2, $3) ON CONFLICT (source_uuid) DO UPDATE SET text = $4, metadata = $5", + self.documents_table_name + )) + .bind(&text) + .bind(source_uuid) + .bind(&document_json) + .bind(&text) + .bind(&document_json) + .execute(self.pool.borrow()).await?; } Ok(()) } @@ -363,18 +361,14 @@ impl Collection { None => serde_json::json!({}), }; - let current_splitter; - transaction_wrapper!( - current_splitter, - sqlx::query_as::<_, models::Splitter>(&query_builder!( - "SELECT * from %s where name = $1 and parameters = $2;", - self.splitters_table_name - )) - .bind(&splitter_name) - .bind(&splitter_params), - self.pool.borrow(), - fetch_optional - ); + let current_splitter: Option = sqlx::query_as(&query_builder!( + "SELECT * from %s where name = $1 and parameters = $2;", + self.splitters_table_name + )) + .bind(&splitter_name) + .bind(&splitter_params) + .fetch_optional(self.pool.borrow()) + .await?; match current_splitter { Some(_splitter) => { @@ -384,15 +378,14 @@ impl Collection { ); } None => { - transaction_wrapper!( - sqlx::query(&query_builder!( - "INSERT INTO %s (name, parameters) VALUES ($1, $2)", - self.splitters_table_name - )) - .bind(splitter_name) - .bind(splitter_params), - self.pool.borrow() - ); + sqlx::query(&query_builder!( + "INSERT INTO %s (name, parameters) VALUES ($1, $2)", + self.splitters_table_name + )) + .bind(splitter_name) + .bind(splitter_params) + .execute(self.pool.borrow()) + .await?; } } Ok(()) @@ -400,16 +393,12 @@ impl Collection { /// Gets all registered text [models::Splitter]s pub async fn get_text_splitters(&self) -> anyhow::Result> { - let splitters; - transaction_wrapper!( - splitters, - sqlx::query_as::<_, models::Splitter>(&query_builder!( - "SELECT * from %s", - self.splitters_table_name - )), - self.pool.borrow(), - fetch_all - ); + let splitters: Vec = sqlx::query_as(&query_builder!( + "SELECT * from %s", + self.splitters_table_name + )) + .fetch_all(self.pool.borrow()) + .await?; Ok(splitters) } @@ -443,17 +432,16 @@ impl Collection { /// ``` pub async fn generate_chunks(&self, splitter_id: Option) -> anyhow::Result<()> { let splitter_id = splitter_id.unwrap_or(1); - transaction_wrapper!( - sqlx::query(&query_builder!( - queries::GENERATE_CHUNKS, - self.splitters_table_name, - self.chunks_table_name, - self.documents_table_name, - self.chunks_table_name - )) - .bind(splitter_id), - self.pool.borrow() - ); + sqlx::query(&query_builder!( + queries::GENERATE_CHUNKS, + self.splitters_table_name, + self.chunks_table_name, + self.documents_table_name, + self.chunks_table_name + )) + .bind(splitter_id) + .execute(self.pool.borrow()) + .await?; Ok(()) } @@ -492,19 +480,15 @@ impl Collection { None => serde_json::json!({}), }; - let current_model; - transaction_wrapper!( - current_model, - sqlx::query_as::<_, models::Model>(&query_builder!( - "SELECT * from %s where task = $1 and name = $2 and parameters = $3;", - self.models_table_name - )) - .bind(&task) - .bind(&model_name) - .bind(&model_params), - self.pool.borrow(), - fetch_optional - ); + let current_model: Option = sqlx::query_as(&query_builder!( + "SELECT * from %s where task = $1 and name = $2 and parameters = $3;", + self.models_table_name + )) + .bind(&task) + .bind(&model_name) + .bind(&model_params) + .fetch_optional(self.pool.borrow()) + .await?; match current_model { Some(model) => { @@ -515,19 +499,15 @@ impl Collection { Ok(model.id) } None => { - let id; - transaction_wrapper!( - id, - sqlx::query_as::<_, (i64,)>(&query_builder!( - "INSERT INTO %s (task, name, parameters) VALUES ($1, $2, $3) RETURNING id", - self.models_table_name - )) - .bind(task) - .bind(model_name) - .bind(model_params), - self.pool.borrow(), - fetch_one - ); + let id: (i64,) = sqlx::query_as(&query_builder!( + "INSERT INTO %s (task, name, parameters) VALUES ($1, $2, $3) RETURNING id", + self.models_table_name + )) + .bind(task) + .bind(model_name) + .bind(model_params) + .fetch_one(self.pool.borrow()) + .await?; Ok(id.0) } } @@ -535,17 +515,11 @@ impl Collection { /// Gets all registered [models::Model]s pub async fn get_models(&self) -> anyhow::Result> { - let models; - transaction_wrapper!( - models, - sqlx::query_as::<_, models::Model>(&query_builder!( - "SELECT * from %s", - self.models_table_name - )), - self.pool.borrow(), - fetch_all - ); - Ok(models) + Ok( + sqlx::query_as(&query_builder!("SELECT * from %s", self.models_table_name)) + .fetch_all(self.pool.borrow()) + .await?, + ) } async fn create_or_get_embeddings_table( @@ -554,17 +528,13 @@ impl Collection { splitter_id: i64, ) -> anyhow::Result { let pool = self.pool.borrow(); - let table_name; - transaction_wrapper!( - table_name, - sqlx::query_as::<_, (String,)>(&query_builder!( + let table_name: Option<(String,)> = + sqlx::query_as(&query_builder!( "SELECT table_name from %s WHERE task = 'embedding' AND model_id = $1 and splitter_id = $2", self.transforms_table_name)) .bind(model_id) - .bind(splitter_id), - pool, - fetch_optional - ); + .bind(splitter_id) + .fetch_optional(pool).await?; match table_name { Some((name,)) => Ok(name), None => { @@ -573,12 +543,11 @@ impl Collection { self.name, &uuid::Uuid::new_v4().to_string()[0..6] ); - let embedding; - transaction_wrapper!(embedding, sqlx::query_as::<_, (Vec,)>(&query_builder!( + let embedding: (Vec,) = sqlx::query_as(&query_builder!( "WITH model as (SELECT name, parameters from %s where id = $1) SELECT embedding from pgml.embed(transformer => (SELECT name FROM model), text => 'Hello, World!', kwargs => (SELECT parameters FROM model)) as embedding", self.models_table_name)) - .bind(model_id), - pool, fetch_one); + .bind(model_id) + .fetch_one(pool).await?; let embedding = embedding.0; let embedding_length = embedding.len() as i64; pool.execute( @@ -591,15 +560,13 @@ impl Collection { .as_str(), ) .await?; - transaction_wrapper!( - sqlx::query(&query_builder!( - "INSERT INTO %s (table_name, task, model_id, splitter_id) VALUES ($1, 'embedding', $2, $3)", - self.transforms_table_name)) - .bind(&table_name) - .bind(model_id) - .bind(splitter_id), - pool - ); + sqlx::query(&query_builder!( + "INSERT INTO %s (table_name, task, model_id, splitter_id) VALUES ($1, 'embedding', $2, $3)", + self.transforms_table_name)) + .bind(&table_name) + .bind(model_id) + .bind(splitter_id) + .execute(pool).await?; pool.execute( query_builder!( queries::CREATE_INDEX, @@ -677,18 +644,17 @@ impl Collection { .create_or_get_embeddings_table(model_id, splitter_id) .await?; - transaction_wrapper!( - sqlx::query(&query_builder!( - queries::GENERATE_EMBEDDINGS, - self.models_table_name, - embeddings_table_name, - self.chunks_table_name, - embeddings_table_name - )) - .bind(model_id) - .bind(splitter_id), - self.pool.borrow() - ); + sqlx::query(&query_builder!( + queries::GENERATE_EMBEDDINGS, + self.models_table_name, + embeddings_table_name, + self.chunks_table_name, + embeddings_table_name + )) + .bind(model_id) + .bind(splitter_id) + .execute(self.pool.borrow()) + .await?; Ok(()) } @@ -751,17 +717,12 @@ impl Collection { let model_id = model_id.unwrap_or(1); let splitter_id = splitter_id.unwrap_or(1); - let embeddings_table_name; - transaction_wrapper!( - embeddings_table_name, - sqlx::query_as::<_, (String,)>(&query_builder!( + let embeddings_table_name: Option<(String,)> = sqlx::query_as(&query_builder!( "SELECT table_name from %s WHERE task = 'embedding' AND model_id = $1 and splitter_id = $2", self.transforms_table_name)) .bind(model_id) - .bind(splitter_id), - self.pool.borrow(), - fetch_optional - ); + .bind(splitter_id) + .fetch_optional(self.pool.borrow()).await?; let embeddings_table_name = match embeddings_table_name { Some((table_name,)) => table_name, @@ -770,10 +731,8 @@ impl Collection { } }; - let results; - transaction_wrapper!( - results, - sqlx::query_as::<_, (f64, String, Json>)>(&query_builder!( + let results: Vec<(f64, String, Json>)> = + sqlx::query_as(&query_builder!( queries::VECTOR_SEARCH, self.models_table_name, embeddings_table_name, @@ -784,10 +743,9 @@ impl Collection { .bind(model_id) .bind(query) .bind(query_params) - .bind(top_k), - self.pool.borrow(), - fetch_all - ); + .bind(top_k) + .fetch_all(self.pool.borrow()) + .await?; let results: Vec<(f64, String, HashMap)> = results.into_iter().map(|r| (r.0, r.1, r.2 .0)).collect(); Ok(results) diff --git a/pgml-sdks/rust/pgml/src/database.rs b/pgml-sdks/rust/pgml/src/database.rs index 23ef21515..341f2f777 100644 --- a/pgml-sdks/rust/pgml/src/database.rs +++ b/pgml-sdks/rust/pgml/src/database.rs @@ -11,7 +11,7 @@ use crate::collection::*; use crate::languages::javascript::*; use crate::models; use crate::queries; -use crate::{query_builder, transaction_wrapper}; +use crate::query_builder; /// A connection to a postgres database #[derive(custom_derive, Clone, Debug)] @@ -46,10 +46,9 @@ impl Database { .max_connections(5) .connect_with(connection_options) .await?; - transaction_wrapper!( - sqlx::query(queries::CREATE_COLLECTIONS_TABLE), - pool.borrow() - ); + sqlx::query(queries::CREATE_COLLECTIONS_TABLE) + .execute(pool.borrow()) + .await?; let pool = pool; Ok(Self { pool }) } @@ -74,23 +73,19 @@ impl Database { /// } /// ``` pub async fn create_or_get_collection(&self, name: &str) -> anyhow::Result { - let collection; - transaction_wrapper!( - collection, - sqlx::query_as::<_, models::Collection>( - "SELECT * from pgml.collections where name = $1;" - ) - .bind(name), - self.pool.borrow(), - fetch_optional - ); + let collection: Option = sqlx::query_as::<_, models::Collection>( + "SELECT * from pgml.collections where name = $1;", + ) + .bind(name) + .fetch_optional(self.pool.borrow()) + .await?; match collection { Some(c) => Ok(Collection::from_model_and_pool(c, self.pool.clone())), None => { - transaction_wrapper!( - sqlx::query("INSERT INTO pgml.collections (name) VALUES ($1)").bind(name), - self.pool.borrow() - ); + sqlx::query("INSERT INTO pgml.collections (name) VALUES ($1)") + .bind(name) + .execute(self.pool.borrow()) + .await?; Ok(Collection::new(name.to_string(), self.pool.clone()).await?) } } @@ -120,20 +115,18 @@ impl Database { .expect("Error getting system time") .as_secs(); let archive_table_name = format!("{}_archive_{}", name, timestamp); - transaction_wrapper!( - sqlx::query(&query_builder!( - "ALTER SCHEMA %s RENAME TO %s", - name, - archive_table_name - )), - self.pool.borrow() - ); - transaction_wrapper!( - sqlx::query("UPDATE pgml.collections SET name = $1, active = FALSE where name = $2") - .bind(archive_table_name) - .bind(name), - self.pool.borrow() - ); + sqlx::query(&query_builder!( + "ALTER SCHEMA %s RENAME TO %s", + name, + archive_table_name + )) + .execute(self.pool.borrow()) + .await?; + sqlx::query("UPDATE pgml.collections SET name = $1, active = FALSE where name = $2") + .bind(archive_table_name) + .bind(name) + .execute(self.pool.borrow()) + .await?; Ok(()) } } diff --git a/pgml-sdks/rust/pgml/src/lib.rs b/pgml-sdks/rust/pgml/src/lib.rs index f57b9373f..eecee526e 100644 --- a/pgml-sdks/rust/pgml/src/lib.rs +++ b/pgml-sdks/rust/pgml/src/lib.rs @@ -101,7 +101,7 @@ mod tests { let connection_string = env::var("DATABASE_URL").unwrap(); init_logger(LevelFilter::Info).unwrap(); - let collection_name = "test25"; + let collection_name = "test26"; let db = Database::new(&connection_string).await.unwrap(); let collection = db.create_or_get_collection(collection_name).await.unwrap(); diff --git a/pgml-sdks/rust/pgml/src/utils.rs b/pgml-sdks/rust/pgml/src/utils.rs index 5c7b93670..ad72bc8b5 100644 --- a/pgml-sdks/rust/pgml/src/utils.rs +++ b/pgml-sdks/rust/pgml/src/utils.rs @@ -7,23 +7,3 @@ macro_rules! query_builder { query }}; } - -#[macro_export] -macro_rules! transaction_wrapper { - ($e:expr, $a:expr) => { - let mut transaction = $a.begin().await?; - $e.execute(&mut transaction).await?; - sqlx::query("DEALLOCATE ALL") - .execute(&mut transaction) - .await?; - transaction.commit().await?; - }; - ($n:ident, $e:expr, $a:expr, $i:ident) => { - let mut transaction = $a.begin().await?; - $n = $e.$i(&mut transaction).await?; - sqlx::query("DEALLOCATE ALL") - .execute(&mut transaction) - .await?; - transaction.commit().await?; - }; -}