Skip to content

Commit 541a1d5

Browse files
authored
Removed transactions (#765)
1 parent cba5448 commit 541a1d5

File tree

4 files changed

+126
-195
lines changed

4 files changed

+126
-195
lines changed

pgml-sdks/rust/pgml/src/collection.rs

Lines changed: 99 additions & 141 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use std::collections::HashMap;
1313
use crate::languages::javascript::*;
1414
use crate::models;
1515
use crate::queries;
16-
use crate::{query_builder, transaction_wrapper};
16+
use crate::query_builder;
1717

1818
/// A collection of documents
1919
#[derive(custom_derive, Debug, Clone)]
@@ -314,18 +314,16 @@ impl Collection {
314314
};
315315
let source_uuid = uuid::Uuid::from_slice(&md5_digest.0)?;
316316

317-
transaction_wrapper!(
318-
sqlx::query(&query_builder!(
319-
"INSERT INTO %s (text, source_uuid, metadata) VALUES ($1, $2, $3) ON CONFLICT (source_uuid) DO UPDATE SET text = $4, metadata = $5",
320-
self.documents_table_name
321-
))
322-
.bind(&text)
323-
.bind(source_uuid)
324-
.bind(&document_json)
325-
.bind(&text)
326-
.bind(&document_json),
327-
self.pool.borrow()
328-
);
317+
sqlx::query(&query_builder!(
318+
"INSERT INTO %s (text, source_uuid, metadata) VALUES ($1, $2, $3) ON CONFLICT (source_uuid) DO UPDATE SET text = $4, metadata = $5",
319+
self.documents_table_name
320+
))
321+
.bind(&text)
322+
.bind(source_uuid)
323+
.bind(&document_json)
324+
.bind(&text)
325+
.bind(&document_json)
326+
.execute(self.pool.borrow()).await?;
329327
}
330328
Ok(())
331329
}
@@ -363,18 +361,14 @@ impl Collection {
363361
None => serde_json::json!({}),
364362
};
365363

366-
let current_splitter;
367-
transaction_wrapper!(
368-
current_splitter,
369-
sqlx::query_as::<_, models::Splitter>(&query_builder!(
370-
"SELECT * from %s where name = $1 and parameters = $2;",
371-
self.splitters_table_name
372-
))
373-
.bind(&splitter_name)
374-
.bind(&splitter_params),
375-
self.pool.borrow(),
376-
fetch_optional
377-
);
364+
let current_splitter: Option<models::Splitter> = sqlx::query_as(&query_builder!(
365+
"SELECT * from %s where name = $1 and parameters = $2;",
366+
self.splitters_table_name
367+
))
368+
.bind(&splitter_name)
369+
.bind(&splitter_params)
370+
.fetch_optional(self.pool.borrow())
371+
.await?;
378372

379373
match current_splitter {
380374
Some(_splitter) => {
@@ -384,32 +378,27 @@ impl Collection {
384378
);
385379
}
386380
None => {
387-
transaction_wrapper!(
388-
sqlx::query(&query_builder!(
389-
"INSERT INTO %s (name, parameters) VALUES ($1, $2)",
390-
self.splitters_table_name
391-
))
392-
.bind(splitter_name)
393-
.bind(splitter_params),
394-
self.pool.borrow()
395-
);
381+
sqlx::query(&query_builder!(
382+
"INSERT INTO %s (name, parameters) VALUES ($1, $2)",
383+
self.splitters_table_name
384+
))
385+
.bind(splitter_name)
386+
.bind(splitter_params)
387+
.execute(self.pool.borrow())
388+
.await?;
396389
}
397390
}
398391
Ok(())
399392
}
400393

401394
/// Gets all registered text [models::Splitter]s
402395
pub async fn get_text_splitters(&self) -> anyhow::Result<Vec<models::Splitter>> {
403-
let splitters;
404-
transaction_wrapper!(
405-
splitters,
406-
sqlx::query_as::<_, models::Splitter>(&query_builder!(
407-
"SELECT * from %s",
408-
self.splitters_table_name
409-
)),
410-
self.pool.borrow(),
411-
fetch_all
412-
);
396+
let splitters: Vec<models::Splitter> = sqlx::query_as(&query_builder!(
397+
"SELECT * from %s",
398+
self.splitters_table_name
399+
))
400+
.fetch_all(self.pool.borrow())
401+
.await?;
413402
Ok(splitters)
414403
}
415404

@@ -443,17 +432,16 @@ impl Collection {
443432
/// ```
444433
pub async fn generate_chunks(&self, splitter_id: Option<i64>) -> anyhow::Result<()> {
445434
let splitter_id = splitter_id.unwrap_or(1);
446-
transaction_wrapper!(
447-
sqlx::query(&query_builder!(
448-
queries::GENERATE_CHUNKS,
449-
self.splitters_table_name,
450-
self.chunks_table_name,
451-
self.documents_table_name,
452-
self.chunks_table_name
453-
))
454-
.bind(splitter_id),
455-
self.pool.borrow()
456-
);
435+
sqlx::query(&query_builder!(
436+
queries::GENERATE_CHUNKS,
437+
self.splitters_table_name,
438+
self.chunks_table_name,
439+
self.documents_table_name,
440+
self.chunks_table_name
441+
))
442+
.bind(splitter_id)
443+
.execute(self.pool.borrow())
444+
.await?;
457445
Ok(())
458446
}
459447

@@ -492,19 +480,15 @@ impl Collection {
492480
None => serde_json::json!({}),
493481
};
494482

495-
let current_model;
496-
transaction_wrapper!(
497-
current_model,
498-
sqlx::query_as::<_, models::Model>(&query_builder!(
499-
"SELECT * from %s where task = $1 and name = $2 and parameters = $3;",
500-
self.models_table_name
501-
))
502-
.bind(&task)
503-
.bind(&model_name)
504-
.bind(&model_params),
505-
self.pool.borrow(),
506-
fetch_optional
507-
);
483+
let current_model: Option<models::Model> = sqlx::query_as(&query_builder!(
484+
"SELECT * from %s where task = $1 and name = $2 and parameters = $3;",
485+
self.models_table_name
486+
))
487+
.bind(&task)
488+
.bind(&model_name)
489+
.bind(&model_params)
490+
.fetch_optional(self.pool.borrow())
491+
.await?;
508492

509493
match current_model {
510494
Some(model) => {
@@ -515,37 +499,27 @@ impl Collection {
515499
Ok(model.id)
516500
}
517501
None => {
518-
let id;
519-
transaction_wrapper!(
520-
id,
521-
sqlx::query_as::<_, (i64,)>(&query_builder!(
522-
"INSERT INTO %s (task, name, parameters) VALUES ($1, $2, $3) RETURNING id",
523-
self.models_table_name
524-
))
525-
.bind(task)
526-
.bind(model_name)
527-
.bind(model_params),
528-
self.pool.borrow(),
529-
fetch_one
530-
);
502+
let id: (i64,) = sqlx::query_as(&query_builder!(
503+
"INSERT INTO %s (task, name, parameters) VALUES ($1, $2, $3) RETURNING id",
504+
self.models_table_name
505+
))
506+
.bind(task)
507+
.bind(model_name)
508+
.bind(model_params)
509+
.fetch_one(self.pool.borrow())
510+
.await?;
531511
Ok(id.0)
532512
}
533513
}
534514
}
535515

536516
/// Gets all registered [models::Model]s
537517
pub async fn get_models(&self) -> anyhow::Result<Vec<models::Model>> {
538-
let models;
539-
transaction_wrapper!(
540-
models,
541-
sqlx::query_as::<_, models::Model>(&query_builder!(
542-
"SELECT * from %s",
543-
self.models_table_name
544-
)),
545-
self.pool.borrow(),
546-
fetch_all
547-
);
548-
Ok(models)
518+
Ok(
519+
sqlx::query_as(&query_builder!("SELECT * from %s", self.models_table_name))
520+
.fetch_all(self.pool.borrow())
521+
.await?,
522+
)
549523
}
550524

551525
async fn create_or_get_embeddings_table(
@@ -554,17 +528,13 @@ impl Collection {
554528
splitter_id: i64,
555529
) -> anyhow::Result<String> {
556530
let pool = self.pool.borrow();
557-
let table_name;
558-
transaction_wrapper!(
559-
table_name,
560-
sqlx::query_as::<_, (String,)>(&query_builder!(
531+
let table_name: Option<(String,)> =
532+
sqlx::query_as(&query_builder!(
561533
"SELECT table_name from %s WHERE task = 'embedding' AND model_id = $1 and splitter_id = $2",
562534
self.transforms_table_name))
563535
.bind(model_id)
564-
.bind(splitter_id),
565-
pool,
566-
fetch_optional
567-
);
536+
.bind(splitter_id)
537+
.fetch_optional(pool).await?;
568538
match table_name {
569539
Some((name,)) => Ok(name),
570540
None => {
@@ -573,12 +543,11 @@ impl Collection {
573543
self.name,
574544
&uuid::Uuid::new_v4().to_string()[0..6]
575545
);
576-
let embedding;
577-
transaction_wrapper!(embedding, sqlx::query_as::<_, (Vec<f32>,)>(&query_builder!(
546+
let embedding: (Vec<f32>,) = sqlx::query_as(&query_builder!(
578547
"WITH model as (SELECT name, parameters from %s where id = $1) SELECT embedding from pgml.embed(transformer => (SELECT name FROM model), text => 'Hello, World!', kwargs => (SELECT parameters FROM model)) as embedding",
579548
self.models_table_name))
580-
.bind(model_id),
581-
pool, fetch_one);
549+
.bind(model_id)
550+
.fetch_one(pool).await?;
582551
let embedding = embedding.0;
583552
let embedding_length = embedding.len() as i64;
584553
pool.execute(
@@ -591,15 +560,13 @@ impl Collection {
591560
.as_str(),
592561
)
593562
.await?;
594-
transaction_wrapper!(
595-
sqlx::query(&query_builder!(
596-
"INSERT INTO %s (table_name, task, model_id, splitter_id) VALUES ($1, 'embedding', $2, $3)",
597-
self.transforms_table_name))
598-
.bind(&table_name)
599-
.bind(model_id)
600-
.bind(splitter_id),
601-
pool
602-
);
563+
sqlx::query(&query_builder!(
564+
"INSERT INTO %s (table_name, task, model_id, splitter_id) VALUES ($1, 'embedding', $2, $3)",
565+
self.transforms_table_name))
566+
.bind(&table_name)
567+
.bind(model_id)
568+
.bind(splitter_id)
569+
.execute(pool).await?;
603570
pool.execute(
604571
query_builder!(
605572
queries::CREATE_INDEX,
@@ -677,18 +644,17 @@ impl Collection {
677644
.create_or_get_embeddings_table(model_id, splitter_id)
678645
.await?;
679646

680-
transaction_wrapper!(
681-
sqlx::query(&query_builder!(
682-
queries::GENERATE_EMBEDDINGS,
683-
self.models_table_name,
684-
embeddings_table_name,
685-
self.chunks_table_name,
686-
embeddings_table_name
687-
))
688-
.bind(model_id)
689-
.bind(splitter_id),
690-
self.pool.borrow()
691-
);
647+
sqlx::query(&query_builder!(
648+
queries::GENERATE_EMBEDDINGS,
649+
self.models_table_name,
650+
embeddings_table_name,
651+
self.chunks_table_name,
652+
embeddings_table_name
653+
))
654+
.bind(model_id)
655+
.bind(splitter_id)
656+
.execute(self.pool.borrow())
657+
.await?;
692658

693659
Ok(())
694660
}
@@ -751,17 +717,12 @@ impl Collection {
751717
let model_id = model_id.unwrap_or(1);
752718
let splitter_id = splitter_id.unwrap_or(1);
753719

754-
let embeddings_table_name;
755-
transaction_wrapper!(
756-
embeddings_table_name,
757-
sqlx::query_as::<_, (String,)>(&query_builder!(
720+
let embeddings_table_name: Option<(String,)> = sqlx::query_as(&query_builder!(
758721
"SELECT table_name from %s WHERE task = 'embedding' AND model_id = $1 and splitter_id = $2",
759722
self.transforms_table_name))
760723
.bind(model_id)
761-
.bind(splitter_id),
762-
self.pool.borrow(),
763-
fetch_optional
764-
);
724+
.bind(splitter_id)
725+
.fetch_optional(self.pool.borrow()).await?;
765726

766727
let embeddings_table_name = match embeddings_table_name {
767728
Some((table_name,)) => table_name,
@@ -770,10 +731,8 @@ impl Collection {
770731
}
771732
};
772733

773-
let results;
774-
transaction_wrapper!(
775-
results,
776-
sqlx::query_as::<_, (f64, String, Json<HashMap<String, String>>)>(&query_builder!(
734+
let results: Vec<(f64, String, Json<HashMap<String, String>>)> =
735+
sqlx::query_as(&query_builder!(
777736
queries::VECTOR_SEARCH,
778737
self.models_table_name,
779738
embeddings_table_name,
@@ -784,10 +743,9 @@ impl Collection {
784743
.bind(model_id)
785744
.bind(query)
786745
.bind(query_params)
787-
.bind(top_k),
788-
self.pool.borrow(),
789-
fetch_all
790-
);
746+
.bind(top_k)
747+
.fetch_all(self.pool.borrow())
748+
.await?;
791749
let results: Vec<(f64, String, HashMap<String, String>)> =
792750
results.into_iter().map(|r| (r.0, r.1, r.2 .0)).collect();
793751
Ok(results)

0 commit comments

Comments
 (0)