Skip to content

Commit 5e80c26

Browse files
authored
SDK - Better async streaming and 1.0.2 bump
1 parent f2f5506 commit 5e80c26

File tree

5 files changed

+40
-114
lines changed

5 files changed

+40
-114
lines changed

pgml-sdks/pgml/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "pgml"
3-
version = "1.0.1"
3+
version = "1.0.2"
44
edition = "2021"
55
authors = ["PosgresML <team@postgresml.org>"]
66
homepage = "https://postgresml.org/"

pgml-sdks/pgml/javascript/package.json

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "pgml",
3-
"version": "1.0.1",
3+
"version": "1.0.2",
44
"description": "Open Source Alternative for Building End-to-End Vector Search Applications without OpenAI & Pinecone",
55
"keywords": [
66
"postgres",

pgml-sdks/pgml/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@ build-backend = "maturin"
55
[project]
66
name = "pgml"
77
requires-python = ">=3.7"
8-
version = "1.0.1"
8+
version = "1.0.2"
99
description = "Python SDK is designed to facilitate the development of scalable vector search applications on PostgreSQL databases."
1010
authors = [
1111
{name = "PostgresML", email = "team@postgresml.org"},

pgml-sdks/pgml/src/transformer_pipeline.rs

Lines changed: 34 additions & 106 deletions
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,6 @@
11
use anyhow::Context;
2-
use futures::Stream;
32
use rust_bridge::{alias, alias_methods};
4-
use sqlx::{postgres::PgRow, Row};
5-
use sqlx::{Postgres, Transaction};
6-
use std::collections::VecDeque;
7-
use std::future::Future;
8-
use std::pin::Pin;
9-
use std::task::Poll;
3+
use sqlx::Row;
104
use tracing::instrument;
115

126
/// Provides access to builtin database methods
@@ -22,99 +16,6 @@ use crate::{get_or_initialize_pool, types::Json};
2216
#[cfg(feature = "python")]
2317
use crate::types::{GeneralJsonAsyncIteratorPython, JsonPython};
2418

25-
#[allow(clippy::type_complexity)]
26-
struct TransformerStream {
27-
transaction: Option<Transaction<'static, Postgres>>,
28-
future: Option<Pin<Box<dyn Future<Output = Result<Vec<PgRow>, sqlx::Error>> + Send + 'static>>>,
29-
commit: Option<Pin<Box<dyn Future<Output = Result<(), sqlx::Error>> + Send + 'static>>>,
30-
done: bool,
31-
query: String,
32-
db_batch_size: i32,
33-
results: VecDeque<PgRow>,
34-
}
35-
36-
impl std::fmt::Debug for TransformerStream {
37-
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38-
f.debug_struct("TransformerStream").finish()
39-
}
40-
}
41-
42-
impl TransformerStream {
43-
fn new(transaction: Transaction<'static, Postgres>, db_batch_size: i32) -> Self {
44-
let query = format!("FETCH {} FROM c", db_batch_size);
45-
Self {
46-
transaction: Some(transaction),
47-
future: None,
48-
commit: None,
49-
done: false,
50-
query,
51-
db_batch_size,
52-
results: VecDeque::new(),
53-
}
54-
}
55-
}
56-
57-
impl Stream for TransformerStream {
58-
type Item = anyhow::Result<Json>;
59-
60-
fn poll_next(
61-
mut self: Pin<&mut Self>,
62-
cx: &mut std::task::Context<'_>,
63-
) -> Poll<Option<Self::Item>> {
64-
if self.done {
65-
if let Some(c) = self.commit.as_mut() {
66-
if c.as_mut().poll(cx).is_ready() {
67-
self.commit = None;
68-
}
69-
}
70-
} else {
71-
if self.future.is_none() {
72-
unsafe {
73-
let s = self.as_mut().get_unchecked_mut();
74-
let s: *mut Self = s;
75-
let s = Box::leak(Box::from_raw(s));
76-
s.future = Some(Box::pin(
77-
sqlx::query(&s.query).fetch_all(&mut **s.transaction.as_mut().unwrap()),
78-
));
79-
}
80-
}
81-
82-
if let Poll::Ready(o) = self.as_mut().future.as_mut().unwrap().as_mut().poll(cx) {
83-
let rows = o?;
84-
if rows.len() < self.db_batch_size as usize {
85-
self.done = true;
86-
unsafe {
87-
let s = self.as_mut().get_unchecked_mut();
88-
let transaction = std::mem::take(&mut s.transaction).unwrap();
89-
s.commit = Some(Box::pin(transaction.commit()));
90-
}
91-
} else {
92-
unsafe {
93-
let s = self.as_mut().get_unchecked_mut();
94-
let s: *mut Self = s;
95-
let s = Box::leak(Box::from_raw(s));
96-
s.future = Some(Box::pin(
97-
sqlx::query(&s.query).fetch_all(&mut **s.transaction.as_mut().unwrap()),
98-
));
99-
}
100-
}
101-
for r in rows.into_iter() {
102-
self.results.push_back(r)
103-
}
104-
}
105-
}
106-
107-
if !self.results.is_empty() {
108-
let r = self.results.pop_front().unwrap();
109-
Poll::Ready(Some(Ok(r.get::<Json, _>(0))))
110-
} else if self.done {
111-
Poll::Ready(None)
112-
} else {
113-
Poll::Pending
114-
}
115-
}
116-
}
117-
11819
#[alias_methods(new, transform, transform_stream)]
11920
impl TransformerPipeline {
12021
/// Creates a new [TransformerPipeline]
@@ -200,7 +101,7 @@ impl TransformerPipeline {
200101
) -> anyhow::Result<GeneralJsonAsyncIterator> {
201102
let pool = get_or_initialize_pool(&self.database_url).await?;
202103
let args = args.unwrap_or_default();
203-
let batch_size = batch_size.unwrap_or(10);
104+
let batch_size = batch_size.unwrap_or(1);
204105

205106
let mut transaction = pool.begin().await?;
206107
// We set the task in the new constructor so we can unwrap here
@@ -234,10 +135,37 @@ impl TransformerPipeline {
234135
.await?;
235136
}
236137

237-
Ok(GeneralJsonAsyncIterator(Box::pin(TransformerStream::new(
238-
transaction,
239-
batch_size,
240-
))))
138+
let s = futures::stream::try_unfold(transaction, move |mut transaction| async move {
139+
let query = format!("FETCH {} FROM c", batch_size);
140+
let mut res: Vec<Json> = sqlx::query_scalar(&query)
141+
.fetch_all(&mut *transaction)
142+
.await?;
143+
if !res.is_empty() {
144+
if batch_size > 1 {
145+
let res: Vec<String> = res
146+
.into_iter()
147+
.map(|v| {
148+
v.0.as_array()
149+
.context("internal SDK error - cannot parse db value as array. Please post a new github issue")
150+
.map(|v| {
151+
v[0].as_str()
152+
.context(
153+
"internal SDK error - cannot parse db value as string. Please post a new github issue",
154+
)
155+
.map(|v| v.to_owned())
156+
})
157+
})
158+
.collect::<anyhow::Result<anyhow::Result<Vec<String>>>>()??;
159+
Ok(Some((serde_json::json!(res).into(), transaction)))
160+
} else {
161+
Ok(Some((std::mem::take(&mut res[0]), transaction)))
162+
}
163+
} else {
164+
transaction.commit().await?;
165+
Ok(None)
166+
}
167+
});
168+
Ok(GeneralJsonAsyncIterator(Box::pin(s)))
241169
}
242170
}
243171

@@ -305,7 +233,7 @@ mod tests {
305233
serde_json::json!("AI is going to").into(),
306234
Some(
307235
serde_json::json!({
308-
"max_new_tokens": 10
236+
"max_new_tokens": 30
309237
})
310238
.into(),
311239
),

pgml-sdks/pgml/src/types.rs

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
use anyhow::Context;
2-
use futures::{Stream, StreamExt};
2+
use futures::{stream::BoxStream, Stream, StreamExt};
33
use itertools::Itertools;
44
use rust_bridge::alias_manual;
55
use sea_query::Iden;
@@ -123,11 +123,9 @@ impl IntoTableNameAndSchema for String {
123123
}
124124
}
125125

126-
/// A wrapper around `std::pin::Pin<Box<dyn Stream<Item = anyhow::Result<Json>> + Send>>`
126+
/// A wrapper around `BoxStream<'static, anyhow::Result<Json>>`
127127
#[derive(alias_manual)]
128-
pub struct GeneralJsonAsyncIterator(
129-
pub std::pin::Pin<Box<dyn Stream<Item = anyhow::Result<Json>> + Send>>,
130-
);
128+
pub struct GeneralJsonAsyncIterator(pub BoxStream<'static, anyhow::Result<Json>>);
131129

132130
impl Stream for GeneralJsonAsyncIterator {
133131
type Item = anyhow::Result<Json>;

0 commit comments

Comments
 (0)