Skip to content

Removed transactions #765

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
240 changes: 99 additions & 141 deletions pgml-sdks/rust/pgml/src/collection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use std::collections::HashMap;
use crate::languages::javascript::*;
use crate::models;
use crate::queries;
use crate::{query_builder, transaction_wrapper};
use crate::query_builder;

/// A collection of documents
#[derive(custom_derive, Debug, Clone)]
Expand Down Expand Up @@ -314,18 +314,16 @@ impl Collection {
};
let source_uuid = uuid::Uuid::from_slice(&md5_digest.0)?;

transaction_wrapper!(
sqlx::query(&query_builder!(
"INSERT INTO %s (text, source_uuid, metadata) VALUES ($1, $2, $3) ON CONFLICT (source_uuid) DO UPDATE SET text = $4, metadata = $5",
self.documents_table_name
))
.bind(&text)
.bind(source_uuid)
.bind(&document_json)
.bind(&text)
.bind(&document_json),
self.pool.borrow()
);
sqlx::query(&query_builder!(
"INSERT INTO %s (text, source_uuid, metadata) VALUES ($1, $2, $3) ON CONFLICT (source_uuid) DO UPDATE SET text = $4, metadata = $5",
self.documents_table_name
))
.bind(&text)
.bind(source_uuid)
.bind(&document_json)
.bind(&text)
.bind(&document_json)
.execute(self.pool.borrow()).await?;
}
Ok(())
}
Expand Down Expand Up @@ -363,18 +361,14 @@ impl Collection {
None => serde_json::json!({}),
};

let current_splitter;
transaction_wrapper!(
current_splitter,
sqlx::query_as::<_, models::Splitter>(&query_builder!(
"SELECT * from %s where name = $1 and parameters = $2;",
self.splitters_table_name
))
.bind(&splitter_name)
.bind(&splitter_params),
self.pool.borrow(),
fetch_optional
);
let current_splitter: Option<models::Splitter> = sqlx::query_as(&query_builder!(
"SELECT * from %s where name = $1 and parameters = $2;",
self.splitters_table_name
))
.bind(&splitter_name)
.bind(&splitter_params)
.fetch_optional(self.pool.borrow())
.await?;

match current_splitter {
Some(_splitter) => {
Expand All @@ -384,32 +378,27 @@ impl Collection {
);
}
None => {
transaction_wrapper!(
sqlx::query(&query_builder!(
"INSERT INTO %s (name, parameters) VALUES ($1, $2)",
self.splitters_table_name
))
.bind(splitter_name)
.bind(splitter_params),
self.pool.borrow()
);
sqlx::query(&query_builder!(
"INSERT INTO %s (name, parameters) VALUES ($1, $2)",
self.splitters_table_name
))
.bind(splitter_name)
.bind(splitter_params)
.execute(self.pool.borrow())
.await?;
}
}
Ok(())
}

/// Gets all registered text [models::Splitter]s
///
/// Runs `SELECT * from %s` with `self.splitters_table_name` handed to
/// `query_builder!` and returns every row decoded as a [models::Splitter].
///
/// # Errors
///
/// Any error raised while executing or awaiting the query is propagated to
/// the caller through `?`.
pub async fn get_text_splitters(&self) -> anyhow::Result<Vec<models::Splitter>> {
// NOTE(review): this text is a flattened diff, so two revisions of the body
// appear back to back. Presumably the `transaction_wrapper!` form below is the
// removed one (the change is titled "Removed transactions") — confirm against
// the merged file before treating this span as compilable code.
// Macro-based form: `splitters` is declared first and filled in by the macro,
// which is given the query, the pool and the `fetch_all` method name.
let splitters;
transaction_wrapper!(
splitters,
sqlx::query_as::<_, models::Splitter>(&query_builder!(
"SELECT * from %s",
self.splitters_table_name
)),
self.pool.borrow(),
fetch_all
);
// Direct form: the same SELECT issued with `fetch_all` on the borrowed pool;
// the row type comes from the annotation on `splitters` rather than a turbofish.
let splitters: Vec<models::Splitter> = sqlx::query_as(&query_builder!(
"SELECT * from %s",
self.splitters_table_name
))
.fetch_all(self.pool.borrow())
.await?;
Ok(splitters)
}

Expand Down Expand Up @@ -443,17 +432,16 @@ impl Collection {
/// ```
pub async fn generate_chunks(&self, splitter_id: Option<i64>) -> anyhow::Result<()> {
// Falls back to splitter id 1 when the caller passes `None`.
let splitter_id = splitter_id.unwrap_or(1);
// NOTE(review): this text is a flattened diff, so two revisions of the
// statement appear back to back. Presumably the `transaction_wrapper!` form is
// the removed one (the change is titled "Removed transactions") — confirm
// against the merged file.
// Macro-based form: the query and the pool are handed to `transaction_wrapper!`.
transaction_wrapper!(
sqlx::query(&query_builder!(
queries::GENERATE_CHUNKS,
self.splitters_table_name,
self.chunks_table_name,
self.documents_table_name,
self.chunks_table_name
))
.bind(splitter_id),
self.pool.borrow()
);
// Direct form: `queries::GENERATE_CHUNKS` is built with the splitters, chunks,
// documents and (again) chunks table names, `splitter_id` is bound as the only
// parameter, and the statement is executed on the borrowed pool. Errors are
// propagated through `?`; no rows are read back.
sqlx::query(&query_builder!(
queries::GENERATE_CHUNKS,
self.splitters_table_name,
self.chunks_table_name,
self.documents_table_name,
self.chunks_table_name
))
.bind(splitter_id)
.execute(self.pool.borrow())
.await?;
Ok(())
}

Expand Down Expand Up @@ -492,19 +480,15 @@ impl Collection {
None => serde_json::json!({}),
};

let current_model;
transaction_wrapper!(
current_model,
sqlx::query_as::<_, models::Model>(&query_builder!(
"SELECT * from %s where task = $1 and name = $2 and parameters = $3;",
self.models_table_name
))
.bind(&task)
.bind(&model_name)
.bind(&model_params),
self.pool.borrow(),
fetch_optional
);
let current_model: Option<models::Model> = sqlx::query_as(&query_builder!(
"SELECT * from %s where task = $1 and name = $2 and parameters = $3;",
self.models_table_name
))
.bind(&task)
.bind(&model_name)
.bind(&model_params)
.fetch_optional(self.pool.borrow())
.await?;

match current_model {
Some(model) => {
Expand All @@ -515,37 +499,27 @@ impl Collection {
Ok(model.id)
}
None => {
let id;
transaction_wrapper!(
id,
sqlx::query_as::<_, (i64,)>(&query_builder!(
"INSERT INTO %s (task, name, parameters) VALUES ($1, $2, $3) RETURNING id",
self.models_table_name
))
.bind(task)
.bind(model_name)
.bind(model_params),
self.pool.borrow(),
fetch_one
);
let id: (i64,) = sqlx::query_as(&query_builder!(
"INSERT INTO %s (task, name, parameters) VALUES ($1, $2, $3) RETURNING id",
self.models_table_name
))
.bind(task)
.bind(model_name)
.bind(model_params)
.fetch_one(self.pool.borrow())
.await?;
Ok(id.0)
}
}
}

/// Gets all registered [models::Model]s
///
/// Runs `SELECT * from %s` with `self.models_table_name` handed to
/// `query_builder!` and returns every row decoded as a [models::Model].
///
/// # Errors
///
/// Any error raised while executing or awaiting the query is propagated to
/// the caller through `?`.
pub async fn get_models(&self) -> anyhow::Result<Vec<models::Model>> {
// NOTE(review): this text is a flattened diff, so two revisions of the body
// appear back to back (hence two consecutive `Ok(...)` tails). Presumably the
// `transaction_wrapper!` form is the removed one (the change is titled
// "Removed transactions") — confirm against the merged file.
// Macro-based form: `models` is declared first and filled in by the macro,
// which is given the query, the pool and the `fetch_all` method name.
let models;
transaction_wrapper!(
models,
sqlx::query_as::<_, models::Model>(&query_builder!(
"SELECT * from %s",
self.models_table_name
)),
self.pool.borrow(),
fetch_all
);
Ok(models)
// Direct form: the fetched rows are returned straight from the expression,
// with the element type taken from the function's return type.
Ok(
sqlx::query_as(&query_builder!("SELECT * from %s", self.models_table_name))
.fetch_all(self.pool.borrow())
.await?,
)
}

async fn create_or_get_embeddings_table(
Expand All @@ -554,17 +528,13 @@ impl Collection {
splitter_id: i64,
) -> anyhow::Result<String> {
let pool = self.pool.borrow();
let table_name;
transaction_wrapper!(
table_name,
sqlx::query_as::<_, (String,)>(&query_builder!(
let table_name: Option<(String,)> =
sqlx::query_as(&query_builder!(
"SELECT table_name from %s WHERE task = 'embedding' AND model_id = $1 and splitter_id = $2",
self.transforms_table_name))
.bind(model_id)
.bind(splitter_id),
pool,
fetch_optional
);
.bind(splitter_id)
.fetch_optional(pool).await?;
match table_name {
Some((name,)) => Ok(name),
None => {
Expand All @@ -573,12 +543,11 @@ impl Collection {
self.name,
&uuid::Uuid::new_v4().to_string()[0..6]
);
let embedding;
transaction_wrapper!(embedding, sqlx::query_as::<_, (Vec<f32>,)>(&query_builder!(
let embedding: (Vec<f32>,) = sqlx::query_as(&query_builder!(
"WITH model as (SELECT name, parameters from %s where id = $1) SELECT embedding from pgml.embed(transformer => (SELECT name FROM model), text => 'Hello, World!', kwargs => (SELECT parameters FROM model)) as embedding",
self.models_table_name))
.bind(model_id),
pool, fetch_one);
.bind(model_id)
.fetch_one(pool).await?;
let embedding = embedding.0;
let embedding_length = embedding.len() as i64;
pool.execute(
Expand All @@ -591,15 +560,13 @@ impl Collection {
.as_str(),
)
.await?;
transaction_wrapper!(
sqlx::query(&query_builder!(
"INSERT INTO %s (table_name, task, model_id, splitter_id) VALUES ($1, 'embedding', $2, $3)",
self.transforms_table_name))
.bind(&table_name)
.bind(model_id)
.bind(splitter_id),
pool
);
sqlx::query(&query_builder!(
"INSERT INTO %s (table_name, task, model_id, splitter_id) VALUES ($1, 'embedding', $2, $3)",
self.transforms_table_name))
.bind(&table_name)
.bind(model_id)
.bind(splitter_id)
.execute(pool).await?;
pool.execute(
query_builder!(
queries::CREATE_INDEX,
Expand Down Expand Up @@ -677,18 +644,17 @@ impl Collection {
.create_or_get_embeddings_table(model_id, splitter_id)
.await?;

transaction_wrapper!(
sqlx::query(&query_builder!(
queries::GENERATE_EMBEDDINGS,
self.models_table_name,
embeddings_table_name,
self.chunks_table_name,
embeddings_table_name
))
.bind(model_id)
.bind(splitter_id),
self.pool.borrow()
);
sqlx::query(&query_builder!(
queries::GENERATE_EMBEDDINGS,
self.models_table_name,
embeddings_table_name,
self.chunks_table_name,
embeddings_table_name
))
.bind(model_id)
.bind(splitter_id)
.execute(self.pool.borrow())
.await?;

Ok(())
}
Expand Down Expand Up @@ -751,17 +717,12 @@ impl Collection {
let model_id = model_id.unwrap_or(1);
let splitter_id = splitter_id.unwrap_or(1);

let embeddings_table_name;
transaction_wrapper!(
embeddings_table_name,
sqlx::query_as::<_, (String,)>(&query_builder!(
let embeddings_table_name: Option<(String,)> = sqlx::query_as(&query_builder!(
"SELECT table_name from %s WHERE task = 'embedding' AND model_id = $1 and splitter_id = $2",
self.transforms_table_name))
.bind(model_id)
.bind(splitter_id),
self.pool.borrow(),
fetch_optional
);
.bind(splitter_id)
.fetch_optional(self.pool.borrow()).await?;

let embeddings_table_name = match embeddings_table_name {
Some((table_name,)) => table_name,
Expand All @@ -770,10 +731,8 @@ impl Collection {
}
};

let results;
transaction_wrapper!(
results,
sqlx::query_as::<_, (f64, String, Json<HashMap<String, String>>)>(&query_builder!(
let results: Vec<(f64, String, Json<HashMap<String, String>>)> =
sqlx::query_as(&query_builder!(
queries::VECTOR_SEARCH,
self.models_table_name,
embeddings_table_name,
Expand All @@ -784,10 +743,9 @@ impl Collection {
.bind(model_id)
.bind(query)
.bind(query_params)
.bind(top_k),
self.pool.borrow(),
fetch_all
);
.bind(top_k)
.fetch_all(self.pool.borrow())
.await?;
let results: Vec<(f64, String, HashMap<String, String>)> =
results.into_iter().map(|r| (r.0, r.1, r.2 .0)).collect();
Ok(results)
Expand Down
Loading