diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index 5b8ddc4e7..bb97b31e8 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -4,8 +4,6 @@ use std::str::FromStr; use ndarray::Zip; use pgrx::iter::{SetOfIterator, TableIterator}; use pgrx::*; -use pyo3::prelude::*; -use pyo3::types::{IntoPyDict, PyDict}; #[cfg(feature = "python")] use serde_json::json; @@ -634,40 +632,6 @@ pub fn transform_string( } } -struct TransformStreamIterator { - locals: Py, -} - -impl TransformStreamIterator { - fn new(python_iter: Py) -> Self { - let locals = Python::with_gil(|py| -> Result, PyErr> { - Ok([("python_iter", python_iter)].into_py_dict(py).into()) - }) - .map_err(|e| error!("{e}")) - .unwrap(); - Self { locals } - } -} - -impl Iterator for TransformStreamIterator { - type Item = String; - fn next(&mut self) -> Option { - // We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call - Python::with_gil(|py| -> Result, PyErr> { - let code = "next(python_iter)"; - let res: &PyAny = py.eval(code, Some(self.locals.as_ref(py)), None)?; - if res.is_none() { - Ok(None) - } else { - let res: String = res.extract()?; - Ok(Some(res)) - } - }) - .map_err(|e| error!("{e}")) - .unwrap() - } -} - #[cfg(all(feature = "python", not(feature = "use_as_lib")))] #[pg_extern(immutable, parallel_safe, name = "transform_stream")] #[allow(unused_variables)] // cache is maintained for api compatibility @@ -678,11 +642,11 @@ pub fn transform_stream_json( cache: default!(bool, false), ) -> SetOfIterator<'static, String> { // We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call - let python_iter = crate::bindings::transformers::transform_stream(&task.0, &args.0, input) - .map_err(|e| error!("{e}")) - .unwrap(); - let res = TransformStreamIterator::new(python_iter); - SetOfIterator::new(res) + let python_iter = + crate::bindings::transformers::transform_stream_iterator(&task.0, &args.0, input) + .map_err(|e| error!("{e}")) + .unwrap(); + SetOfIterator::new(python_iter) } #[cfg(all(feature = "python", not(feature = "use_as_lib")))] @@ -696,11 +660,11 @@ pub fn transform_stream_string( ) -> SetOfIterator<'static, String> { let task_json = json!({ "task": task }); // We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call - let python_iter = crate::bindings::transformers::transform_stream(&task_json, &args.0, input) - .map_err(|e| error!("{e}")) - .unwrap(); - let res = TransformStreamIterator::new(python_iter); - SetOfIterator::new(res) + let python_iter = + crate::bindings::transformers::transform_stream_iterator(&task_json, &args.0, input) + .map_err(|e| error!("{e}")) + .unwrap(); + SetOfIterator::new(python_iter) } #[cfg(feature = "python")] diff --git a/pgml-extension/src/bindings/transformers/transformers.rs b/pgml-extension/src/bindings/transformers/transformers.rs index 55d59b070..6b89dd2a8 100644 --- a/pgml-extension/src/bindings/transformers/transformers.rs +++ b/pgml-extension/src/bindings/transformers/transformers.rs @@ -1,17 +1,52 @@ use super::whitelist; use super::TracebackError; use anyhow::Result; +use pgrx::*; use pyo3::prelude::*; -use pyo3::types::PyTuple; +use pyo3::types::{IntoPyDict, PyDict, PyTuple}; + create_pymodule!("/src/bindings/transformers/transformers.py"); +pub struct TransformStreamIterator { + locals: Py, +} + +impl TransformStreamIterator { + fn new(python_iter: Py) -> Self { + let locals = Python::with_gil(|py| -> Result, PyErr> { + Ok([("python_iter", python_iter)].into_py_dict(py).into()) + }) + .map_err(|e| error!("{e}")) + .unwrap(); + Self { locals } + } +} + +impl Iterator for TransformStreamIterator { + type Item = String; + fn next(&mut self) -> Option { + // We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call + Python::with_gil(|py| -> Result, PyErr> { + let code = "next(python_iter)"; + let res: &PyAny = py.eval(code, Some(self.locals.as_ref(py)), None)?; + if res.is_none() { + Ok(None) + } else { + let res: String = res.extract()?; + Ok(Some(res)) + } + }) + .map_err(|e| error!("{e}")) + .unwrap() + } +} + pub fn transform( task: &serde_json::Value, args: &serde_json::Value, inputs: Vec<&str>, ) -> Result { crate::bindings::python::activate()?; - whitelist::verify_task(task)?; let task = serde_json::to_string(task)?; @@ -45,7 +80,6 @@ pub fn transform_stream( input: &str, ) -> Result> { crate::bindings::python::activate()?; - whitelist::verify_task(task)?; let task = serde_json::to_string(task)?; @@ -75,3 +109,14 @@ pub fn transform_stream( Ok(output) }) } + +pub fn transform_stream_iterator( + task: &serde_json::Value, + args: &serde_json::Value, + input: &str, +) -> Result { + let python_iter = transform_stream(task, args, input) + .map_err(|e| error!("{e}")) + .unwrap(); + Ok(TransformStreamIterator::new(python_iter)) +}