Skip to content

Commit 7cf4da0

Browse files
authored
collection.register_model() now returns the created or found model id (#759)
1 parent 97f7d4e commit 7cf4da0

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

pgml-sdks/rust/pgml/src/collection.rs

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ impl Collection {
484484
task: Option<String>,
485485
model_name: Option<String>,
486486
model_params: Option<HashMap<String, String>>,
487-
) -> anyhow::Result<()> {
487+
) -> anyhow::Result<i64> {
488488
let task = task.unwrap_or("embedding".to_string());
489489
let model_name = model_name.unwrap_or("intfloat/e5-small".to_string());
490490
let model_params = match model_params {
@@ -507,27 +507,30 @@ impl Collection {
507507
);
508508

509509
match current_model {
510-
Some(_model) => {
510+
Some(model) => {
511511
warn!(
512512
"Model with name: {} and parameters: {:?} already exists",
513513
model_name, model_params
514514
);
515+
Ok(model.id)
515516
}
516517
None => {
518+
let id;
517519
transaction_wrapper!(
518-
sqlx::query(&query_builder!(
519-
"INSERT INTO %s (task, name, parameters) VALUES ($1, $2, $3)",
520+
id,
521+
sqlx::query_as::<_, (i64,)>(&query_builder!(
522+
"INSERT INTO %s (task, name, parameters) VALUES ($1, $2, $3) RETURNING id",
520523
self.models_table_name
521524
))
522525
.bind(task)
523526
.bind(model_name)
524527
.bind(model_params),
525-
self.pool.borrow()
528+
self.pool.borrow(),
529+
fetch_one
526530
);
531+
Ok(id.0)
527532
}
528533
}
529-
530-
Ok(())
531534
}
532535

533536
/// Gets all registered [models::Model]s

0 commit comments

Comments
 (0)