Skip to content

Commit 8b5c00d

Browse files
authored
Fixed batch embedding (#1452)
1 parent 63a8f4a commit 8b5c00d

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

pgml-extension/src/api.rs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -589,17 +589,21 @@ fn load_dataset(
589589
#[cfg(all(feature = "python", not(feature = "use_as_lib")))]
590590
#[pg_extern(immutable, parallel_safe, name = "embed")]
591591
pub fn embed(transformer: &str, text: &str, kwargs: default!(JsonB, "'{}'")) -> Vec<f32> {
592-
embed_batch(transformer, Vec::from([text]), kwargs)
593-
.first()
594-
.unwrap()
595-
.to_vec()
592+
match crate::bindings::transformers::embed(transformer, vec![text], &kwargs.0) {
593+
Ok(output) => output.first().unwrap().to_vec(),
594+
Err(e) => error!("{e}"),
595+
}
596596
}
597597

598598
#[cfg(all(feature = "python", not(feature = "use_as_lib")))]
599599
#[pg_extern(immutable, parallel_safe, name = "embed")]
600-
pub fn embed_batch(transformer: &str, inputs: Vec<&str>, kwargs: default!(JsonB, "'{}'")) -> Vec<Vec<f32>> {
600+
pub fn embed_batch(
601+
transformer: &str,
602+
inputs: Vec<&str>,
603+
kwargs: default!(JsonB, "'{}'"),
604+
) -> SetOfIterator<'static, Vec<f32>> {
601605
match crate::bindings::transformers::embed(transformer, inputs, &kwargs.0) {
602-
Ok(output) => output,
606+
Ok(output) => SetOfIterator::new(output.into_iter()),
603607
Err(e) => error!("{e}"),
604608
}
605609
}

0 commit comments

Comments
 (0)