@@ -13,7 +13,7 @@ use std::collections::HashMap;
13
13
use crate :: languages:: javascript:: * ;
14
14
use crate :: models;
15
15
use crate :: queries;
16
- use crate :: { query_builder, transaction_wrapper } ;
16
+ use crate :: query_builder;
17
17
18
18
/// A collection of documents
19
19
#[ derive( custom_derive, Debug , Clone ) ]
@@ -314,18 +314,16 @@ impl Collection {
314
314
} ;
315
315
let source_uuid = uuid:: Uuid :: from_slice ( & md5_digest. 0 ) ?;
316
316
317
- transaction_wrapper ! (
318
- sqlx:: query( & query_builder!(
319
- "INSERT INTO %s (text, source_uuid, metadata) VALUES ($1, $2, $3) ON CONFLICT (source_uuid) DO UPDATE SET text = $4, metadata = $5" ,
320
- self . documents_table_name
321
- ) )
322
- . bind( & text)
323
- . bind( source_uuid)
324
- . bind( & document_json)
325
- . bind( & text)
326
- . bind( & document_json) ,
327
- self . pool. borrow( )
328
- ) ;
317
+ sqlx:: query ( & query_builder ! (
318
+ "INSERT INTO %s (text, source_uuid, metadata) VALUES ($1, $2, $3) ON CONFLICT (source_uuid) DO UPDATE SET text = $4, metadata = $5" ,
319
+ self . documents_table_name
320
+ ) )
321
+ . bind ( & text)
322
+ . bind ( source_uuid)
323
+ . bind ( & document_json)
324
+ . bind ( & text)
325
+ . bind ( & document_json)
326
+ . execute ( self . pool . borrow ( ) ) . await ?;
329
327
}
330
328
Ok ( ( ) )
331
329
}
@@ -363,18 +361,14 @@ impl Collection {
363
361
None => serde_json:: json!( { } ) ,
364
362
} ;
365
363
366
- let current_splitter;
367
- transaction_wrapper ! (
368
- current_splitter,
369
- sqlx:: query_as:: <_, models:: Splitter >( & query_builder!(
370
- "SELECT * from %s where name = $1 and parameters = $2;" ,
371
- self . splitters_table_name
372
- ) )
373
- . bind( & splitter_name)
374
- . bind( & splitter_params) ,
375
- self . pool. borrow( ) ,
376
- fetch_optional
377
- ) ;
364
+ let current_splitter: Option < models:: Splitter > = sqlx:: query_as ( & query_builder ! (
365
+ "SELECT * from %s where name = $1 and parameters = $2;" ,
366
+ self . splitters_table_name
367
+ ) )
368
+ . bind ( & splitter_name)
369
+ . bind ( & splitter_params)
370
+ . fetch_optional ( self . pool . borrow ( ) )
371
+ . await ?;
378
372
379
373
match current_splitter {
380
374
Some ( _splitter) => {
@@ -384,32 +378,27 @@ impl Collection {
384
378
) ;
385
379
}
386
380
None => {
387
- transaction_wrapper ! (
388
- sqlx:: query( & query_builder!(
389
- "INSERT INTO %s (name, parameters) VALUES ($1, $2)" ,
390
- self . splitters_table_name
391
- ) )
392
- . bind( splitter_name)
393
- . bind( splitter_params) ,
394
- self . pool. borrow( )
395
- ) ;
381
+ sqlx:: query ( & query_builder ! (
382
+ "INSERT INTO %s (name, parameters) VALUES ($1, $2)" ,
383
+ self . splitters_table_name
384
+ ) )
385
+ . bind ( splitter_name)
386
+ . bind ( splitter_params)
387
+ . execute ( self . pool . borrow ( ) )
388
+ . await ?;
396
389
}
397
390
}
398
391
Ok ( ( ) )
399
392
}
400
393
401
394
/// Gets all registered text [models::Splitter]s
402
395
pub async fn get_text_splitters ( & self ) -> anyhow:: Result < Vec < models:: Splitter > > {
403
- let splitters;
404
- transaction_wrapper ! (
405
- splitters,
406
- sqlx:: query_as:: <_, models:: Splitter >( & query_builder!(
407
- "SELECT * from %s" ,
408
- self . splitters_table_name
409
- ) ) ,
410
- self . pool. borrow( ) ,
411
- fetch_all
412
- ) ;
396
+ let splitters: Vec < models:: Splitter > = sqlx:: query_as ( & query_builder ! (
397
+ "SELECT * from %s" ,
398
+ self . splitters_table_name
399
+ ) )
400
+ . fetch_all ( self . pool . borrow ( ) )
401
+ . await ?;
413
402
Ok ( splitters)
414
403
}
415
404
@@ -443,17 +432,16 @@ impl Collection {
443
432
/// ```
444
433
pub async fn generate_chunks ( & self , splitter_id : Option < i64 > ) -> anyhow:: Result < ( ) > {
445
434
let splitter_id = splitter_id. unwrap_or ( 1 ) ;
446
- transaction_wrapper ! (
447
- sqlx:: query( & query_builder!(
448
- queries:: GENERATE_CHUNKS ,
449
- self . splitters_table_name,
450
- self . chunks_table_name,
451
- self . documents_table_name,
452
- self . chunks_table_name
453
- ) )
454
- . bind( splitter_id) ,
455
- self . pool. borrow( )
456
- ) ;
435
+ sqlx:: query ( & query_builder ! (
436
+ queries:: GENERATE_CHUNKS ,
437
+ self . splitters_table_name,
438
+ self . chunks_table_name,
439
+ self . documents_table_name,
440
+ self . chunks_table_name
441
+ ) )
442
+ . bind ( splitter_id)
443
+ . execute ( self . pool . borrow ( ) )
444
+ . await ?;
457
445
Ok ( ( ) )
458
446
}
459
447
@@ -492,19 +480,15 @@ impl Collection {
492
480
None => serde_json:: json!( { } ) ,
493
481
} ;
494
482
495
- let current_model;
496
- transaction_wrapper ! (
497
- current_model,
498
- sqlx:: query_as:: <_, models:: Model >( & query_builder!(
499
- "SELECT * from %s where task = $1 and name = $2 and parameters = $3;" ,
500
- self . models_table_name
501
- ) )
502
- . bind( & task)
503
- . bind( & model_name)
504
- . bind( & model_params) ,
505
- self . pool. borrow( ) ,
506
- fetch_optional
507
- ) ;
483
+ let current_model: Option < models:: Model > = sqlx:: query_as ( & query_builder ! (
484
+ "SELECT * from %s where task = $1 and name = $2 and parameters = $3;" ,
485
+ self . models_table_name
486
+ ) )
487
+ . bind ( & task)
488
+ . bind ( & model_name)
489
+ . bind ( & model_params)
490
+ . fetch_optional ( self . pool . borrow ( ) )
491
+ . await ?;
508
492
509
493
match current_model {
510
494
Some ( model) => {
@@ -515,37 +499,27 @@ impl Collection {
515
499
Ok ( model. id )
516
500
}
517
501
None => {
518
- let id;
519
- transaction_wrapper ! (
520
- id,
521
- sqlx:: query_as:: <_, ( i64 , ) >( & query_builder!(
522
- "INSERT INTO %s (task, name, parameters) VALUES ($1, $2, $3) RETURNING id" ,
523
- self . models_table_name
524
- ) )
525
- . bind( task)
526
- . bind( model_name)
527
- . bind( model_params) ,
528
- self . pool. borrow( ) ,
529
- fetch_one
530
- ) ;
502
+ let id: ( i64 , ) = sqlx:: query_as ( & query_builder ! (
503
+ "INSERT INTO %s (task, name, parameters) VALUES ($1, $2, $3) RETURNING id" ,
504
+ self . models_table_name
505
+ ) )
506
+ . bind ( task)
507
+ . bind ( model_name)
508
+ . bind ( model_params)
509
+ . fetch_one ( self . pool . borrow ( ) )
510
+ . await ?;
531
511
Ok ( id. 0 )
532
512
}
533
513
}
534
514
}
535
515
536
516
/// Gets all registered [models::Model]s
537
517
pub async fn get_models ( & self ) -> anyhow:: Result < Vec < models:: Model > > {
538
- let models;
539
- transaction_wrapper ! (
540
- models,
541
- sqlx:: query_as:: <_, models:: Model >( & query_builder!(
542
- "SELECT * from %s" ,
543
- self . models_table_name
544
- ) ) ,
545
- self . pool. borrow( ) ,
546
- fetch_all
547
- ) ;
548
- Ok ( models)
518
+ Ok (
519
+ sqlx:: query_as ( & query_builder ! ( "SELECT * from %s" , self . models_table_name) )
520
+ . fetch_all ( self . pool . borrow ( ) )
521
+ . await ?,
522
+ )
549
523
}
550
524
551
525
async fn create_or_get_embeddings_table (
@@ -554,17 +528,13 @@ impl Collection {
554
528
splitter_id : i64 ,
555
529
) -> anyhow:: Result < String > {
556
530
let pool = self . pool . borrow ( ) ;
557
- let table_name;
558
- transaction_wrapper ! (
559
- table_name,
560
- sqlx:: query_as:: <_, ( String , ) >( & query_builder!(
531
+ let table_name: Option < ( String , ) > =
532
+ sqlx:: query_as ( & query_builder ! (
561
533
"SELECT table_name from %s WHERE task = 'embedding' AND model_id = $1 and splitter_id = $2" ,
562
534
self . transforms_table_name) )
563
535
. bind ( model_id)
564
- . bind( splitter_id) ,
565
- pool,
566
- fetch_optional
567
- ) ;
536
+ . bind ( splitter_id)
537
+ . fetch_optional ( pool) . await ?;
568
538
match table_name {
569
539
Some ( ( name, ) ) => Ok ( name) ,
570
540
None => {
@@ -573,12 +543,11 @@ impl Collection {
573
543
self . name,
574
544
& uuid:: Uuid :: new_v4( ) . to_string( ) [ 0 ..6 ]
575
545
) ;
576
- let embedding;
577
- transaction_wrapper ! ( embedding, sqlx:: query_as:: <_, ( Vec <f32 >, ) >( & query_builder!(
546
+ let embedding: ( Vec < f32 > , ) = sqlx:: query_as ( & query_builder ! (
578
547
"WITH model as (SELECT name, parameters from %s where id = $1) SELECT embedding from pgml.embed(transformer => (SELECT name FROM model), text => 'Hello, World!', kwargs => (SELECT parameters FROM model)) as embedding" ,
579
548
self . models_table_name) )
580
- . bind( model_id) ,
581
- pool , fetch_one) ;
549
+ . bind ( model_id)
550
+ . fetch_one ( pool ) . await ? ;
582
551
let embedding = embedding. 0 ;
583
552
let embedding_length = embedding. len ( ) as i64 ;
584
553
pool. execute (
@@ -591,15 +560,13 @@ impl Collection {
591
560
. as_str ( ) ,
592
561
)
593
562
. await ?;
594
- transaction_wrapper ! (
595
- sqlx:: query( & query_builder!(
596
- "INSERT INTO %s (table_name, task, model_id, splitter_id) VALUES ($1, 'embedding', $2, $3)" ,
597
- self . transforms_table_name) )
598
- . bind( & table_name)
599
- . bind( model_id)
600
- . bind( splitter_id) ,
601
- pool
602
- ) ;
563
+ sqlx:: query ( & query_builder ! (
564
+ "INSERT INTO %s (table_name, task, model_id, splitter_id) VALUES ($1, 'embedding', $2, $3)" ,
565
+ self . transforms_table_name) )
566
+ . bind ( & table_name)
567
+ . bind ( model_id)
568
+ . bind ( splitter_id)
569
+ . execute ( pool) . await ?;
603
570
pool. execute (
604
571
query_builder ! (
605
572
queries:: CREATE_INDEX ,
@@ -677,18 +644,17 @@ impl Collection {
677
644
. create_or_get_embeddings_table ( model_id, splitter_id)
678
645
. await ?;
679
646
680
- transaction_wrapper ! (
681
- sqlx:: query( & query_builder!(
682
- queries:: GENERATE_EMBEDDINGS ,
683
- self . models_table_name,
684
- embeddings_table_name,
685
- self . chunks_table_name,
686
- embeddings_table_name
687
- ) )
688
- . bind( model_id)
689
- . bind( splitter_id) ,
690
- self . pool. borrow( )
691
- ) ;
647
+ sqlx:: query ( & query_builder ! (
648
+ queries:: GENERATE_EMBEDDINGS ,
649
+ self . models_table_name,
650
+ embeddings_table_name,
651
+ self . chunks_table_name,
652
+ embeddings_table_name
653
+ ) )
654
+ . bind ( model_id)
655
+ . bind ( splitter_id)
656
+ . execute ( self . pool . borrow ( ) )
657
+ . await ?;
692
658
693
659
Ok ( ( ) )
694
660
}
@@ -751,17 +717,12 @@ impl Collection {
751
717
let model_id = model_id. unwrap_or ( 1 ) ;
752
718
let splitter_id = splitter_id. unwrap_or ( 1 ) ;
753
719
754
- let embeddings_table_name;
755
- transaction_wrapper ! (
756
- embeddings_table_name,
757
- sqlx:: query_as:: <_, ( String , ) >( & query_builder!(
720
+ let embeddings_table_name: Option < ( String , ) > = sqlx:: query_as ( & query_builder ! (
758
721
"SELECT table_name from %s WHERE task = 'embedding' AND model_id = $1 and splitter_id = $2" ,
759
722
self . transforms_table_name) )
760
723
. bind ( model_id)
761
- . bind( splitter_id) ,
762
- self . pool. borrow( ) ,
763
- fetch_optional
764
- ) ;
724
+ . bind ( splitter_id)
725
+ . fetch_optional ( self . pool . borrow ( ) ) . await ?;
765
726
766
727
let embeddings_table_name = match embeddings_table_name {
767
728
Some ( ( table_name, ) ) => table_name,
@@ -770,10 +731,8 @@ impl Collection {
770
731
}
771
732
} ;
772
733
773
- let results;
774
- transaction_wrapper ! (
775
- results,
776
- sqlx:: query_as:: <_, ( f64 , String , Json <HashMap <String , String >>) >( & query_builder!(
734
+ let results: Vec < ( f64 , String , Json < HashMap < String , String > > ) > =
735
+ sqlx:: query_as ( & query_builder ! (
777
736
queries:: VECTOR_SEARCH ,
778
737
self . models_table_name,
779
738
embeddings_table_name,
@@ -784,10 +743,9 @@ impl Collection {
784
743
. bind ( model_id)
785
744
. bind ( query)
786
745
. bind ( query_params)
787
- . bind( top_k) ,
788
- self . pool. borrow( ) ,
789
- fetch_all
790
- ) ;
746
+ . bind ( top_k)
747
+ . fetch_all ( self . pool . borrow ( ) )
748
+ . await ?;
791
749
let results: Vec < ( f64 , String , HashMap < String , String > ) > =
792
750
results. into_iter ( ) . map ( |r| ( r. 0 , r. 1 , r. 2 . 0 ) ) . collect ( ) ;
793
751
Ok ( results)
0 commit comments