Skip to content

Commit 03ca54e

Browse files
authored
Embeddings support in the SDK (#1475)
1 parent a09fa86 commit 03ca54e

File tree

5 files changed

+91
-3
lines changed

5 files changed

+91
-3
lines changed

pgml-sdks/pgml/Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pgml-sdks/pgml/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "pgml"
3-
version = "1.0.3"
3+
version = "1.0.4"
44
edition = "2021"
55
authors = ["PosgresML <team@postgresml.org>"]
66
homepage = "https://postgresml.org/"
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
pytest
2+
pytest-asyncio

pgml-sdks/pgml/python/tests/test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,18 @@ def test_can_create_builtins():
7272
builtins = pgml.Builtins()
7373
assert builtins is not None
7474

75+
@pytest.mark.asyncio
76+
async def test_can_embed_with_builtins():
77+
builtins = pgml.Builtins()
78+
result = await builtins.embed("intfloat/e5-small-v2", "test")
79+
assert result is not None
80+
81+
@pytest.mark.asyncio
82+
async def test_can_embed_batch_with_builtins():
83+
builtins = pgml.Builtins()
84+
result = await builtins.embed_batch("intfloat/e5-small-v2", ["test"])
85+
assert result is not None
86+
7587

7688
###################################################
7789
## Test searches ##################################

pgml-sdks/pgml/src/builtins.rs

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use anyhow::Context;
12
use rust_bridge::{alias, alias_methods};
23
use sqlx::Row;
34
use tracing::instrument;
@@ -13,7 +14,7 @@ use crate::{get_or_initialize_pool, query_runner::QueryRunner, types::Json};
1314
#[cfg(feature = "python")]
1415
use crate::{query_runner::QueryRunnerPython, types::JsonPython};
1516

16-
#[alias_methods(new, query, transform)]
17+
#[alias_methods(new, query, transform, embed, embed_batch)]
1718
impl Builtins {
1819
pub fn new(database_url: Option<String>) -> Self {
1920
Self { database_url }
@@ -87,6 +88,55 @@ impl Builtins {
8788
let results = results.first().unwrap().get::<serde_json::Value, _>(0);
8889
Ok(Json(results))
8990
}
91+
92+
/// Run the built-in `pgml.embed()` function.
93+
///
94+
/// # Arguments
95+
///
96+
/// * `model` - The model to use.
97+
/// * `text` - The text to embed.
98+
///
99+
pub async fn embed(&self, model: &str, text: &str) -> anyhow::Result<Json> {
100+
let pool = get_or_initialize_pool(&self.database_url).await?;
101+
let query = sqlx::query("SELECT embed FROM pgml.embed($1, $2)");
102+
let result = query.bind(model).bind(text).fetch_one(&pool).await?;
103+
let result = result.get::<Vec<f32>, _>(0);
104+
let result = serde_json::to_value(result)?;
105+
Ok(Json(result))
106+
}
107+
108+
/// Run the built-in `pgml.embed()` function, but with handling for batch inputs and outputs.
109+
///
110+
/// # Arguments
111+
///
112+
/// * `model` - The model to use.
113+
/// * `texts` - The texts to embed.
114+
///
115+
pub async fn embed_batch(&self, model: &str, texts: Json) -> anyhow::Result<Json> {
116+
let texts = texts
117+
.0
118+
.as_array()
119+
.with_context(|| "embed_batch takes an array of strings")?
120+
.into_iter()
121+
.map(|v| {
122+
v.as_str()
123+
.with_context(|| "only text embeddings are supported")
124+
.unwrap()
125+
.to_string()
126+
})
127+
.collect::<Vec<String>>();
128+
let pool = get_or_initialize_pool(&self.database_url).await?;
129+
let query = sqlx::query("SELECT embed AS embed_batch FROM pgml.embed($1, $2)");
130+
let results = query
131+
.bind(model)
132+
.bind(texts)
133+
.fetch_all(&pool)
134+
.await?
135+
.into_iter()
136+
.map(|embeddings| embeddings.get::<Vec<f32>, _>(0))
137+
.collect::<Vec<Vec<f32>>>();
138+
Ok(Json(serde_json::to_value(results)?))
139+
}
90140
}
91141

92142
#[cfg(test)]
@@ -117,4 +167,28 @@ mod tests {
117167
assert!(results.as_array().is_some());
118168
Ok(())
119169
}
170+
171+
#[tokio::test]
172+
async fn can_embed() -> anyhow::Result<()> {
173+
internal_init_logger(None, None).ok();
174+
let builtins = Builtins::new(None);
175+
let results = builtins.embed("intfloat/e5-small-v2", "test").await?;
176+
assert!(results.as_array().is_some());
177+
Ok(())
178+
}
179+
180+
#[tokio::test]
181+
async fn can_embed_batch() -> anyhow::Result<()> {
182+
internal_init_logger(None, None).ok();
183+
let builtins = Builtins::new(None);
184+
let results = builtins
185+
.embed_batch(
186+
"intfloat/e5-small-v2",
187+
Json(serde_json::json!(["test", "test2",])),
188+
)
189+
.await?;
190+
assert!(results.as_array().is_some());
191+
assert_eq!(results.as_array().unwrap().len(), 2);
192+
Ok(())
193+
}
120194
}

0 commit comments

Comments
 (0)