diff --git a/pgml-extension/Cargo.toml b/pgml-extension/Cargo.toml
index ab3411447..a4da7bcbe 100644
--- a/pgml-extension/Cargo.toml
+++ b/pgml-extension/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "pgml"
-version = "2.8.0"
+version = "2.8.1"
 edition = "2021"
 
 [lib]
diff --git a/pgml-extension/sql/pgml--2.8.0--2.8.1.sql b/pgml-extension/sql/pgml--2.8.0--2.8.1.sql
new file mode 100644
index 000000000..f5d364156
--- /dev/null
+++ b/pgml-extension/sql/pgml--2.8.0--2.8.1.sql
@@ -0,0 +1,67 @@
+-- pgml::api::transform_conversational_json
+CREATE FUNCTION pgml."transform"(
+    "task" jsonb, /* pgrx::datum::json::JsonB */
+    "args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
+    "inputs" jsonb[] DEFAULT ARRAY[]::JSONB[], /* Vec<pgrx::datum::json::JsonB> */
+    "cache" bool DEFAULT false /* bool */
+) RETURNS jsonb /* alloc::string::String */
+IMMUTABLE STRICT PARALLEL SAFE
+LANGUAGE c /* Rust */
+AS 'MODULE_PATHNAME', 'transform_conversational_json_wrapper';
+
+-- pgml::api::transform_conversational_string
+CREATE FUNCTION pgml."transform"(
+    "task" TEXT, /* alloc::string::String */
+    "args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
+    "inputs" jsonb[] DEFAULT ARRAY[]::JSONB[], /* Vec<pgrx::datum::json::JsonB> */
+    "cache" bool DEFAULT false /* bool */
+) RETURNS jsonb /* alloc::string::String */
+IMMUTABLE STRICT PARALLEL SAFE
+LANGUAGE c /* Rust */
+AS 'MODULE_PATHNAME', 'transform_conversational_string_wrapper';
+
+-- pgml::api::transform_stream_string
+DROP FUNCTION IF EXISTS pgml."transform_stream"(text,jsonb,text,boolean);
+CREATE FUNCTION pgml."transform_stream"(
+    "task" TEXT, /* alloc::string::String */
+    "args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
+    "input" TEXT DEFAULT '', /* &str */
+    "cache" bool DEFAULT false /* bool */
+) RETURNS SETOF jsonb /* pgrx::datum::json::JsonB */
+IMMUTABLE STRICT PARALLEL SAFE
+LANGUAGE c /* Rust */
+AS 'MODULE_PATHNAME', 'transform_stream_string_wrapper';
+
+-- pgml::api::transform_stream_json
+DROP FUNCTION IF EXISTS pgml."transform_stream"(jsonb,jsonb,text,boolean);
+CREATE FUNCTION pgml."transform_stream"(
+    "task" jsonb, /* pgrx::datum::json::JsonB */
+    "args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
+    "input" TEXT DEFAULT '', /* &str */
+    "cache" bool DEFAULT false /* bool */
+) RETURNS SETOF jsonb /* pgrx::datum::json::JsonB */
+IMMUTABLE STRICT PARALLEL SAFE
+LANGUAGE c /* Rust */
+AS 'MODULE_PATHNAME', 'transform_stream_json_wrapper';
+
+-- pgml::api::transform_stream_conversational_string
+CREATE FUNCTION pgml."transform_stream"(
+    "task" TEXT, /* alloc::string::String */
+    "args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
+    "inputs" jsonb[] DEFAULT ARRAY[]::JSONB[], /* Vec<pgrx::datum::json::JsonB> */
+    "cache" bool DEFAULT false /* bool */
+) RETURNS SETOF jsonb /* pgrx::datum::json::JsonB */
+IMMUTABLE STRICT PARALLEL SAFE
+LANGUAGE c /* Rust */
+AS 'MODULE_PATHNAME', 'transform_stream_conversational_string_wrapper';
+
+-- pgml::api::transform_stream_conversational_json
+CREATE FUNCTION pgml."transform_stream"(
+    "task" jsonb, /* pgrx::datum::json::JsonB */
+    "args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
+    "inputs" jsonb[] DEFAULT ARRAY[]::JSONB[], /* Vec<pgrx::datum::json::JsonB> */
+    "cache" bool DEFAULT false /* bool */
+) RETURNS SETOF jsonb /* pgrx::datum::json::JsonB */
+IMMUTABLE STRICT PARALLEL SAFE
+LANGUAGE c /* Rust */
+AS 'MODULE_PATHNAME', 'transform_stream_conversational_json_wrapper';
diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs
index ab132bc4c..3bf663026 100644
--- a/pgml-extension/src/api.rs
+++ b/pgml-extension/src/api.rs
@@ -632,6 +632,50 @@ pub fn transform_string(
     }
 }
 
+#[cfg(all(feature = "python", not(feature = "use_as_lib")))]
+#[pg_extern(immutable, parallel_safe, name = "transform")]
+#[allow(unused_variables)] // cache is maintained for api compatibility
+pub fn transform_conversational_json(
+    task: JsonB,
+    args: default!(JsonB, "'{}'"),
+    inputs: default!(Vec<JsonB>, "ARRAY[]::JSONB[]"),
+    cache: default!(bool, false),
+) -> JsonB {
+    if !task.0["task"]
+        .as_str()
+        .is_some_and(|v| v == "conversational")
+    {
+        error!(
+            "ARRAY[]::JSONB inputs for transform should only be used with a conversational task"
+        );
+    }
+    match crate::bindings::transformers::transform(&task.0, &args.0, inputs) {
+        Ok(output) => JsonB(output),
+        Err(e) => error!("{e}"),
+    }
+}
+
+#[cfg(all(feature = "python", not(feature = "use_as_lib")))]
+#[pg_extern(immutable, parallel_safe, name = "transform")]
+#[allow(unused_variables)] // cache is maintained for api compatibility
+pub fn transform_conversational_string(
+    task: String,
+    args: default!(JsonB, "'{}'"),
+    inputs: default!(Vec<JsonB>, "ARRAY[]::JSONB[]"),
+    cache: default!(bool, false),
+) -> JsonB {
+    if task != "conversational" {
+        error!(
+            "ARRAY[]::JSONB inputs for transform should only be used with a conversational task"
+        );
+    }
+    let task_json = json!({ "task": task });
+    match crate::bindings::transformers::transform(&task_json, &args.0, inputs) {
+        Ok(output) => JsonB(output),
+        Err(e) => error!("{e}"),
+    }
+}
+
 #[cfg(all(feature = "python", not(feature = "use_as_lib")))]
 #[pg_extern(immutable, parallel_safe, name = "transform_stream")]
 #[allow(unused_variables)] // cache is maintained for api compatibility
@@ -640,7 +684,7 @@ pub fn transform_stream_json(
     args: default!(JsonB, "'{}'"),
     input: default!(&str, "''"),
     cache: default!(bool, false),
-) -> SetOfIterator<'static, String> {
+) -> SetOfIterator<'static, JsonB> {
     // We can unwrap this because if there is an error the current transaction is aborted in the map_err call
     let python_iter =
         crate::bindings::transformers::transform_stream_iterator(&task.0, &args.0, input)
@@ -657,7 +701,7 @@ pub fn transform_stream_string(
     args: default!(JsonB, "'{}'"),
     input: default!(&str, "''"),
     cache: default!(bool, false),
-) -> SetOfIterator<'static, String> {
+) -> SetOfIterator<'static, JsonB> {
     let task_json = json!({ "task": task });
     // We can unwrap this because if there is an error the current transaction is aborted in the map_err call
     let python_iter =
         crate::bindings::transformers::transform_stream_iterator(&task_json, &args.0, input)
@@ -667,6 +711,54 @@ pub fn transform_stream_string(
     SetOfIterator::new(python_iter)
 }
 
+#[cfg(all(feature = "python", not(feature = "use_as_lib")))]
+#[pg_extern(immutable, parallel_safe, name = "transform_stream")]
+#[allow(unused_variables)] // cache is maintained for api compatibility
+pub fn transform_stream_conversational_json(
+    task: JsonB,
+    args: default!(JsonB, "'{}'"),
+    inputs: default!(Vec<JsonB>, "ARRAY[]::JSONB[]"),
+    cache: default!(bool, false),
+) -> SetOfIterator<'static, JsonB> {
+    if !task.0["task"]
+        .as_str()
+        .is_some_and(|v| v == "conversational")
+    {
+        error!(
+            "ARRAY[]::JSONB inputs for transform_stream should only be used with a conversational task"
+        );
+    }
+    // We can unwrap this because if there is an error the current transaction is aborted in the map_err call
+    let python_iter =
+        crate::bindings::transformers::transform_stream_iterator(&task.0, &args.0, inputs)
+            .map_err(|e| error!("{e}"))
+            .unwrap();
+    SetOfIterator::new(python_iter)
+}
+
+#[cfg(all(feature = "python", not(feature = "use_as_lib")))]
+#[pg_extern(immutable, parallel_safe, name = "transform_stream")]
+#[allow(unused_variables)] // cache is maintained for api compatibility
+pub fn transform_stream_conversational_string(
+    task: String,
+    args: default!(JsonB, "'{}'"),
+    inputs: default!(Vec<JsonB>, "ARRAY[]::JSONB[]"),
+    cache: default!(bool, false),
+) -> SetOfIterator<'static, JsonB> {
+    if task != "conversational" {
+        error!(
+            "ARRAY[]::JSONB inputs for transform_stream should only be used with a conversational task"
+        );
+    }
+    let task_json = json!({ "task": task });
+    // We can unwrap this because if there is an error the current transaction is aborted in the map_err call
+    let python_iter =
+        crate::bindings::transformers::transform_stream_iterator(&task_json, &args.0, inputs)
+            .map_err(|e| error!("{e}"))
+            .unwrap();
+    SetOfIterator::new(python_iter)
+}
+
 #[cfg(feature = "python")]
 #[pg_extern(immutable, parallel_safe, name = "generate")]
 fn generate(project_name: &str, inputs: &str, config: default!(JsonB, "'{}'")) -> String {
diff --git a/pgml-extension/src/bindings/transformers/transform.rs b/pgml-extension/src/bindings/transformers/transform.rs
index a03c0d751..fa03984d9 100644
--- a/pgml-extension/src/bindings/transformers/transform.rs
+++ b/pgml-extension/src/bindings/transformers/transform.rs
@@ -23,17 +23,17 @@ impl TransformStreamIterator {
 }
 
 impl Iterator for TransformStreamIterator {
-    type Item = String;
+    type Item = JsonB;
     fn next(&mut self) -> Option<Self::Item> {
         // We can unwrap this because if there is an error the current transaction is aborted in the map_err call
-        Python::with_gil(|py| -> Result<Option<String>, PyErr> {
+        Python::with_gil(|py| -> Result<Option<JsonB>, PyErr> {
             let code = "next(python_iter)";
             let res: &PyAny = py.eval(code, Some(self.locals.as_ref(py)), None)?;
             if res.is_none() {
                 Ok(None)
             } else {
-                let res: String = res.extract()?;
-                Ok(Some(res))
+                let res: Vec<String> = res.extract()?;
+                Ok(Some(JsonB(serde_json::to_value(res).unwrap())))
             }
         })
         .map_err(|e| error!("{e}"))
@@ -41,10 +41,10 @@ impl Iterator for TransformStreamIterator {
     }
 }
 
-pub fn transform(
+pub fn transform<T: serde::Serialize>(
     task: &serde_json::Value,
     args: &serde_json::Value,
-    inputs: Vec<&str>,
+    inputs: T,
 ) -> Result<serde_json::Value> {
     crate::bindings::python::activate()?;
     whitelist::verify_task(task)?;
@@ -74,17 +74,17 @@ pub fn transform(
     Ok(serde_json::from_str(&results)?)
 }
 
-pub fn transform_stream(
+pub fn transform_stream<T: serde::Serialize>(
     task: &serde_json::Value,
     args: &serde_json::Value,
-    input: &str,
+    input: T,
 ) -> Result<Py<PyAny>> {
     crate::bindings::python::activate()?;
     whitelist::verify_task(task)?;
 
     let task = serde_json::to_string(task)?;
     let args = serde_json::to_string(args)?;
-    let inputs = serde_json::to_string(&vec![input])?;
+    let input = serde_json::to_string(&input)?;
 
     Python::with_gil(|py| -> Result<Py<PyAny>> {
         let transform: Py<PyAny> = get_module!(PY_MODULE)
@@ -99,7 +99,7 @@ pub fn transform_stream(
                 &[
                     task.into_py(py),
                     args.into_py(py),
-                    inputs.into_py(py),
+                    input.into_py(py),
                     true.into_py(py),
                 ],
             ),
@@ -110,10 +110,10 @@ pub fn transform_stream(
     })
 }
 
-pub fn transform_stream_iterator(
+pub fn transform_stream_iterator<T: serde::Serialize>(
     task: &serde_json::Value,
     args: &serde_json::Value,
-    input: &str,
+    input: T,
 ) -> Result<TransformStreamIterator> {
     let python_iter = transform_stream(task, args, input)
         .map_err(|e| error!("{e}"))
diff --git a/pgml-extension/src/bindings/transformers/transformers.py b/pgml-extension/src/bindings/transformers/transformers.py
index 143f6d393..5c6078785 100644
--- a/pgml-extension/src/bindings/transformers/transformers.py
+++ b/pgml-extension/src/bindings/transformers/transformers.py
@@ -40,7 +40,6 @@
     PegasusTokenizer,
     TrainingArguments,
     Trainer,
-    TextStreamer,
 )
 from threading import Thread
 from typing import Optional
@@ -94,21 +93,31 @@ def ensure_device(kwargs):
         else:
             kwargs["device"] = "cpu"
 
-# A copy of HuggingFace's with small changes in the __next__ to not raise an exception
-class TextIteratorStreamer(TextStreamer):
-    def __init__(
-        self, tokenizer, skip_prompt = False, timeout = None, **decode_kwargs
-    ):
-        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
-        self.text_queue = queue.Queue()
-        self.stop_signal = None
+
+# Follows BaseStreamer template from transformers library
+class TextIteratorStreamer:
+    def __init__(self, tokenizer, skip_prompt=False, timeout=None, **decode_kwargs):
+        self.tokenizer = tokenizer
+        self.skip_prompt = skip_prompt
         self.timeout = timeout
+        self.decode_kwargs = decode_kwargs
+        self.next_tokens_are_prompt = True
+        self.stop_signal = None
+        self.text_queue = queue.Queue()
 
-    def on_finalized_text(self, text: str, stream_end: bool = False):
-        """Put the new text in the queue. If the stream is ending, also put a stop signal in the queue."""
-        self.text_queue.put(text, timeout=self.timeout)
-        if stream_end:
-            self.text_queue.put(self.stop_signal, timeout=self.timeout)
+    def put(self, value):
+        if self.skip_prompt and self.next_tokens_are_prompt:
+            self.next_tokens_are_prompt = False
+            return
+        # Can't batch this decode
+        decoded_values = []
+        for v in value:
+            decoded_values.append(self.tokenizer.decode(v, **self.decode_kwargs))
+        self.text_queue.put(decoded_values, timeout=self.timeout)
+
+    def end(self):
+        self.next_tokens_are_prompt = True
+        self.text_queue.put(self.stop_signal, timeout=self.timeout)
 
     def __iter__(self):
         return self
@@ -118,44 +127,27 @@ def __next__(self):
         if value != self.stop_signal:
             return value
 
-
-class GPTQPipeline(object):
+class GGMLPipeline(object):
     def __init__(self, model_name, **task):
-        from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
-        from huggingface_hub import snapshot_download
-
-        model_path = snapshot_download(model_name)
+        import ctransformers
 
-        quantized_config = BaseQuantizeConfig.from_pretrained(model_path)
-        self.model = AutoGPTQForCausalLM.from_quantized(
-            model_path, quantized_config=quantized_config, **task
+        task.pop("model")
+        task.pop("task")
+        task.pop("device")
+        self.model = ctransformers.AutoModelForCausalLM.from_pretrained(
+            model_name, **task
         )
-        if "use_fast_tokenizer" in task:
-            self.tokenizer = AutoTokenizer.from_pretrained(
-                model_path, use_fast=task.pop("use_fast_tokenizer")
-            )
-        else:
-            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+        self.tokenizer = None
        self.task = "text-generation"
 
     def stream(self, inputs, **kwargs):
-        streamer = TextIteratorStreamer(self.tokenizer)
-        inputs = self.tokenizer(inputs, return_tensors="pt").to(self.model.device)
-        generation_kwargs = dict(inputs, streamer=streamer, **kwargs)
-        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
-        thread.start()
-        return streamer
+        output = self.model(inputs[0], stream=True, **kwargs)
+        return ThreadedGeneratorIterator(output, inputs[0])
 
     def __call__(self, inputs, **kwargs):
         outputs = []
         for input in inputs:
-            tokens = (
-                self.tokenizer(input, return_tensors="pt")
-                .to(self.model.device)
-                .input_ids
-            )
-            token_ids = self.model.generate(input_ids=tokens, **kwargs)[0]
-            outputs.append(self.tokenizer.decode(token_ids))
+            outputs.append(self.model(input, **kwargs))
         return outputs
 
@@ -184,36 +176,18 @@ def __next__(self):
         return v
 
 
-class GGMLPipeline(object):
-    def __init__(self, model_name, **task):
-        import ctransformers
-
-        task.pop("model")
-        task.pop("task")
-        task.pop("device")
-        self.model = ctransformers.AutoModelForCausalLM.from_pretrained(
-            model_name, **task
-        )
-        self.tokenizer = None
-        self.task = "text-generation"
-
-    def stream(self, inputs, **kwargs):
-        output = self.model(inputs[0], stream=True, **kwargs)
-        return ThreadedGeneratorIterator(output, inputs[0])
-
-    def __call__(self, inputs, **kwargs):
-        outputs = []
-        for input in inputs:
-            outputs.append(self.model(input, **kwargs))
-        return outputs
-
-
 class StandardPipeline(object):
     def __init__(self, model_name, **kwargs):
         # the default pipeline constructor doesn't pass all the kwargs (particularly load_in_4bit)
         # to the model constructor, so we construct the model/tokenizer manually if possible,
         # but that is only possible when the task is passed in, since if you pass the model
         # to the pipeline constructor, the task will no longer be inferred from the default...
+
+        # See: https://huggingface.co/docs/hub/security-tokens
+        # This renaming is for backwards compatibility
+        if "use_auth_token" in kwargs:
+            kwargs["token"] = kwargs.pop("use_auth_token")
+
         if (
             "task" in kwargs
             and model_name is not None
@@ -224,6 +198,7 @@ def __init__(self, model_name, **kwargs):
                 "summarization",
                 "translation",
                 "text-generation",
+                "conversational",
             ]
         ):
             self.task = kwargs.pop("task")
@@ -238,14 +213,14 @@ def __init__(self, model_name, **kwargs):
                 )
             elif self.task == "summarization" or self.task == "translation":
                 self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name, **kwargs)
-            elif self.task == "text-generation":
+            elif self.task == "text-generation" or self.task == "conversational":
                 self.model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
             else:
                 raise PgMLException(f"Unhandled task: {self.task}")
 
-            if "use_auth_token" in kwargs:
+            if "token" in kwargs:
                 self.tokenizer = AutoTokenizer.from_pretrained(
-                    model_name, use_auth_token=kwargs["use_auth_token"]
+                    model_name, use_auth_token=kwargs["token"]
                 )
             else:
                 self.tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -257,24 +232,66 @@ def __init__(self, model_name, **kwargs):
                 )
             else:
                 self.pipe = transformers.pipeline(**kwargs)
+                self.tokenizer = self.pipe.tokenizer
             self.task = self.pipe.task
             self.model = self.pipe.model
-            if self.pipe.tokenizer is None:
-                self.pipe.tokenizer = AutoTokenizer.from_pretrained(
-                    self.model.name_or_path
-                )
-            self.tokenizer = self.pipe.tokenizer
 
-    def stream(self, inputs, **kwargs):
-        streamer = TextIteratorStreamer(self.tokenizer)
-        inputs = self.tokenizer(inputs, return_tensors="pt").to(self.model.device)
-        generation_kwargs = dict(inputs, streamer=streamer, **kwargs)
+        # Make sure we set the pad token if it does not exist
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+
+    def stream(self, input, **kwargs):
+        streamer = None
+        generation_kwargs = None
+        if self.task == "conversational":
+            streamer = TextIteratorStreamer(
+                self.tokenizer, skip_prompt=True, skip_special_tokens=True
+            )
+            if "chat_template" in kwargs:
+                input = self.tokenizer.apply_chat_template(
+                    input,
+                    add_generation_prompt=True,
+                    tokenize=False,
+                    chat_template=kwargs.pop("chat_template"),
+                )
+            else:
+                input = self.tokenizer.apply_chat_template(
+                    input, add_generation_prompt=True, tokenize=False
+                )
+            input = self.tokenizer(input, return_tensors="pt").to(self.model.device)
+            generation_kwargs = dict(input, streamer=streamer, **kwargs)
+        else:
+            streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True)
+            input = self.tokenizer(input, return_tensors="pt", padding=True).to(
+                self.model.device
+            )
+            generation_kwargs = dict(input, streamer=streamer, **kwargs)
         thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
         thread.start()
         return streamer
 
     def __call__(self, inputs, **kwargs):
-        return self.pipe(inputs, **kwargs)
+        if self.task == "conversational":
+            if "chat_template" in kwargs:
+                inputs = self.tokenizer.apply_chat_template(
+                    inputs,
+                    add_generation_prompt=True,
+                    tokenize=False,
+                    chat_template=kwargs.pop("chat_template"),
+                )
+            else:
+                inputs = self.tokenizer.apply_chat_template(
+                    inputs, add_generation_prompt=True, tokenize=False
+                )
+            inputs = self.tokenizer(inputs, return_tensors="pt").to(self.model.device)
+            args = dict(inputs, **kwargs)
+            outputs = self.model.generate(**args)
+            # We only want the new outputs for conversational pipelines
+            outputs = outputs[:, inputs["input_ids"].shape[1] :]
+            outputs = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
+            return outputs
+        else:
+            return self.pipe(inputs, **kwargs)
 
 
 def get_model_from(task):
@@ -303,8 +320,6 @@ def create_pipeline(task):
         lower = None
     if lower and ("-ggml" in lower or "-gguf" in lower):
         pipe = GGMLPipeline(model_name, **task)
-    elif lower and "-gptq" in lower and not (model_type == "mistral" or model_type == "llama"):
-        pipe = GPTQPipeline(model_name, **task)
     else:
         try:
             pipe = StandardPipeline(model_name, **task)
diff --git a/pgml-sdks/pgml/build.rs b/pgml-sdks/pgml/build.rs
index 4f476884f..82b51670c 100644
--- a/pgml-sdks/pgml/build.rs
+++ b/pgml-sdks/pgml/build.rs
@@ -14,7 +14,7 @@ const ADDITIONAL_DEFAULTS_FOR_JAVASCRIPT: &[u8] = br#"
 export function init_logger(level?: string, format?: string): void;
 export function migrate(): Promise<void>;
 
-export type Json = { [key: string]: any };
+export type Json = any;
 export type DateTime = Date;
 
 export function newCollection(name: string, database_url?: string): Collection;
@@ -23,6 +23,7 @@ export function newSplitter(name?: string, parameters?: Json): Splitter;
 export function newBuiltins(database_url?: string): Builtins;
 export function newPipeline(name: string, model?: Model, splitter?: Splitter, parameters?: Json): Pipeline;
 export function newTransformerPipeline(task: string, model?: string, args?: Json, database_url?: string): TransformerPipeline;
+export function newOpenSourceAI(database_url?: string): OpenSourceAI;
 "#;
 
 fn main() {
diff --git a/pgml-sdks/pgml/javascript/tests/jest.config.js b/pgml-sdks/pgml/javascript/tests/jest.config.js
index 7e67de525..7cf8a2c1e 100644
--- a/pgml-sdks/pgml/javascript/tests/jest.config.js
+++ b/pgml-sdks/pgml/javascript/tests/jest.config.js
@@ -4,5 +4,6 @@ export default {
   roots: [''],
   transform: {
     '^.+\\.tsx?$': 'ts-jest'
-  }
+  },
+  testTimeout: 30000,
 }
diff --git a/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts b/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts
index affb314fa..acc766bd8 100644
--- a/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts
+++ b/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts
@@ -299,7 +299,93 @@ it("can transformer pipeline stream", async () => {
     output.push(result.value);
     result = await it.next();
   }
-  expect(output.length).toBeGreaterThan(0)
+  expect(output.length).toBeGreaterThan(0);
+});
+
+///////////////////////////////////////////////////
+// Test OpenSourceAI //////////////////////////////
+///////////////////////////////////////////////////
+
+it("can open source ai create", () => {
+  const client = pgml.newOpenSourceAI();
+  const results = client.chat_completions_create(
+    "mistralai/Mistral-7B-v0.1",
+    [
+      {
+        role: "system",
+        content: "You are a friendly chatbot who always responds in the style of a pirate",
+      },
+      {
+        role: "user",
+        content: "How many helicopters can a human eat in one sitting?",
+      },
+    ],
+  );
+  expect(results.choices.length).toBeGreaterThan(0);
+});
+
+
+it("can open source ai create async", async () => {
+  const client = pgml.newOpenSourceAI();
+  const results = await client.chat_completions_create_async(
+    "mistralai/Mistral-7B-v0.1",
+    [
+      {
+        role: "system",
+        content: "You are a friendly chatbot who always responds in the style of a pirate",
+      },
+      {
+        role: "user",
+        content: "How many helicopters can a human eat in one sitting?",
+      },
+    ],
+  );
+  expect(results.choices.length).toBeGreaterThan(0);
+});
+
+
+it("can open source ai create stream", () => {
+  const client = pgml.newOpenSourceAI();
+  const it = client.chat_completions_create_stream(
+    "mistralai/Mistral-7B-v0.1",
+    [
+      {
+        role: "system",
+        content: "You are a friendly chatbot who always responds in the style of a pirate",
+      },
+      {
+        role: "user",
+        content: "How many helicopters can a human eat in one sitting?",
+      },
+    ],
+  );
+  let result = it.next();
+  while (!result.done) {
+    expect(result.value.choices.length).toBeGreaterThan(0);
+    result = it.next();
+  }
+});
+
+it("can open source ai create stream async", async () => {
+  const client = pgml.newOpenSourceAI();
+  const it = await client.chat_completions_create_stream_async(
+    "mistralai/Mistral-7B-v0.1",
+    [
+      {
+        role: "system",
+        content: "You are a friendly chatbot who always responds in the style of a pirate",
+      },
+      {
+        role: "user",
+        content: "How many helicopters can a human eat in one sitting?",
+      },
+    ],
+  );
+  let result = await it.next();
+  while (!result.done) {
+    expect(result.value.choices.length).toBeGreaterThan(0);
+    result = await it.next();
+  }
 });
 
 ///////////////////////////////////////////////////
diff --git a/pgml-sdks/pgml/python/tests/test.py b/pgml-sdks/pgml/python/tests/test.py
index 97ca155f5..f3b1fbec9 100644
--- a/pgml-sdks/pgml/python/tests/test.py
+++ b/pgml-sdks/pgml/python/tests/test.py
@@ -307,7 +307,8 @@ async def test_order_documents():
 async def test_transformer_pipeline():
     t = pgml.TransformerPipeline("text-generation")
     it = await t.transform(["AI is going to"], {"max_new_tokens": 5})
-    assert (len(it)) > 0
+    assert len(it) > 0
+
 
 @pytest.mark.asyncio
 async def test_transformer_pipeline_stream():
@@ -316,7 +317,95 @@ async def test_transformer_pipeline_stream():
     total = []
     async for c in it:
         total.append(c)
-    assert (len(total)) > 0
+    assert len(total) > 0
+
+
+###################################################
+## OpenSourceAI tests ###########################
+###################################################
+
+
+def test_open_source_ai_create():
+    client = pgml.OpenSourceAI()
+    results = client.chat_completions_create(
+        "mistralai/Mistral-7B-v0.1",
+        [
+            {
+                "role": "system",
+                "content": "You are a friendly chatbot who always responds in the style of a pirate",
+            },
+            {
+                "role": "user",
+                "content": "How many helicopters can a human eat in one sitting?",
+            },
+        ],
+        temperature=0.85,
+    )
+    assert len(results["choices"]) > 0
+
+
+@pytest.mark.asyncio
+async def test_open_source_ai_create_async():
+    client = pgml.OpenSourceAI()
+    results = await client.chat_completions_create_async(
+        "mistralai/Mistral-7B-v0.1",
+        [
+            {
+                "role": "system",
+                "content": "You are a friendly chatbot who always responds in the style of a pirate",
+            },
+            {
+                "role": "user",
+                "content": "How many helicopters can a human eat in one sitting?",
+            },
+        ],
+        temperature=0.85,
+    )
+    import json
+    assert len(results["choices"]) > 0
+
+
+def test_open_source_ai_create_stream():
+    client = pgml.OpenSourceAI()
+    results = client.chat_completions_create_stream(
+        "mistralai/Mistral-7B-v0.1",
+        [
+            {
+                "role": "system",
+                "content": "You are a friendly chatbot who always responds in the style of a pirate",
+            },
+            {
+                "role": "user",
+                "content": "How many helicopters can a human eat in one sitting?",
+            },
+        ],
+        temperature=0.85,
+        n=3,
+    )
+    for c in results:
+        assert len(c["choices"]) > 0
+
+
+@pytest.mark.asyncio
+async def test_open_source_ai_create_stream_async():
+    client = pgml.OpenSourceAI()
+    results = await client.chat_completions_create_stream_async(
+        "mistralai/Mistral-7B-v0.1",
+        [
+            {
+                "role": "system",
+                "content": "You are a friendly chatbot who always responds in the style of a pirate",
+            },
+            {
+                "role": "user",
+                "content": "How many helicopters can a human eat in one sitting?",
+            },
+        ],
+        temperature=0.85,
+        n=3,
+    )
+    async for c in results:
+        assert len(c["choices"]) > 0
 
 
 ###################################################
diff --git a/pgml-sdks/pgml/src/builtins.rs b/pgml-sdks/pgml/src/builtins.rs
index 188948c72..db023b951 100644
--- a/pgml-sdks/pgml/src/builtins.rs
+++ b/pgml-sdks/pgml/src/builtins.rs
@@ -101,7 +101,7 @@ mod tests {
         let query = "SELECT * from pgml.collections";
         let results = builtins.query(query).fetch_all().await?;
         assert!(results.as_array().is_some());
-        Ok(())
+        Ok(())
     }
 
     #[sqlx::test]
diff --git a/pgml-sdks/pgml/src/languages/javascript.rs b/pgml-sdks/pgml/src/languages/javascript.rs
index 1aafd654b..c9a09326d 100644
--- a/pgml-sdks/pgml/src/languages/javascript.rs
+++ b/pgml-sdks/pgml/src/languages/javascript.rs
@@ -1,12 +1,12 @@
 use futures::StreamExt;
 use neon::prelude::*;
 use rust_bridge::javascript::{FromJsType, IntoJsResult};
+use std::cell::RefCell;
 use std::sync::Arc;
 
 use crate::{
     pipeline::PipelineSyncData,
-    transformer_pipeline::TransformerStream,
-    types::{DateTime, Json},
+    types::{DateTime, GeneralJsonAsyncIterator, GeneralJsonIterator, Json},
 };
 
 ////////////////////////////////////////////////////////////////////////////////
@@ -74,17 +74,17 @@ impl IntoJsResult for PipelineSyncData {
 }
 
 #[derive(Clone)]
-struct TransformerStreamArcMutex(Arc<tokio::sync::Mutex<TransformerStream>>);
+struct GeneralJsonAsyncIteratorJavaScript(Arc<tokio::sync::Mutex<GeneralJsonAsyncIterator>>);
 
-impl Finalize for TransformerStreamArcMutex {}
+impl Finalize for GeneralJsonAsyncIteratorJavaScript {}
 
 fn transform_stream_iterate_next(mut cx: FunctionContext) -> JsResult<JsPromise> {
     let this = cx.this();
-    let s: Handle<JsBox<TransformerStreamArcMutex>> = this
+    let s: Handle<JsBox<GeneralJsonAsyncIteratorJavaScript>> = this
         .get(&mut cx, "s")
         .expect("Error getting self in transformer_stream_iterate_next");
-    let ts: &TransformerStreamArcMutex = &s;
-    let ts: TransformerStreamArcMutex = ts.clone();
+    let ts: &GeneralJsonAsyncIteratorJavaScript = &s;
+    let ts: GeneralJsonAsyncIteratorJavaScript = ts.clone();
 
     let channel = cx.channel();
     let (deferred, promise) = cx.promise();
@@ -95,17 +95,19 @@ fn transform_stream_iterate_next(mut cx: FunctionContext) -> JsResult<JsPromise>
             .try_settle_with(&channel, move |mut cx| {
                 let o = cx.empty_object();
                 if let Some(v) = v {
-                    let v: String = v.expect("Error calling next on TransformerStream");
-                    let v = cx.string(v);
+                    let v: Json = v.expect("Error calling next on GeneralJsonAsyncIterator");
+                    let v = v
+                        .into_js_result(&mut cx)
+                        .expect("Error converting rust Json to JavaScript Object");
                     let d = cx.boolean(false);
                     o.set(&mut cx, "value", v)
-                        .expect("Error setting object value in transformer_sream_iterate_next");
+                        .expect("Error setting object value in transform_stream_iterate_next");
                     o.set(&mut cx, "done", d)
-                        .expect("Error setting object value in transformer_sream_iterate_next");
+                        .expect("Error setting object value in transform_stream_iterate_next");
                 } else {
                     let d = cx.boolean(true);
                     o.set(&mut cx, "done", d)
-                        .expect("Error setting object value in transformer_sream_iterate_next");
+                        .expect("Error setting object value in transform_stream_iterate_next");
                 }
                 Ok(o)
             })
@@ -114,8 +116,8 @@ fn transform_stream_iterate_next(mut cx: FunctionContext) -> JsResult<JsPromise>
     Ok(promise)
 }
 
-impl IntoJsResult for TransformerStream {
-    type Output = JsObject;
+impl IntoJsResult for GeneralJsonAsyncIterator {
+    type Output = JsValue;
     fn into_js_result<'a, 'b, 'c: 'b, C: Context<'c>>(
         self,
         cx: &mut C,
     ) -> JsResult<'b, Self::Output> {
         let o = cx.empty_object();
         let f: Handle<JsFunction> = JsFunction::new(cx, transform_stream_iterate_next)?;
         o.set(cx, "next", f)?;
-        let s = cx.boxed(TransformerStreamArcMutex(Arc::new(
+        let s = cx.boxed(GeneralJsonAsyncIteratorJavaScript(Arc::new(
             tokio::sync::Mutex::new(self),
         )));
         o.set(cx, "s", s)?;
-        Ok(o)
+        Ok(o.as_value(cx))
+    }
+}
+
+struct GeneralJsonIteratorJavaScript(RefCell<GeneralJsonIterator>);
+
+impl Finalize for GeneralJsonIteratorJavaScript {}
+
+fn transform_iterate_next(mut cx: FunctionContext) -> JsResult<JsObject> {
+    let this = cx.this();
+    let s: Handle<JsBox<GeneralJsonIteratorJavaScript>> = this
+        .get(&mut cx, "s")
+        .expect("Error getting self in transform_iterate_next");
+    let v = s.0.borrow_mut().next();
+    let o = cx.empty_object();
+    if let Some(v) = v {
+        let v: Json = v.expect("Error calling next on GeneralJsonAsyncIterator");
+        let v = v
+            .into_js_result(&mut cx)
+            .expect("Error converting rust Json to JavaScript Object");
+        let d = cx.boolean(false);
+        o.set(&mut cx, "value", v)
+            .expect("Error setting object value in transform_iterate_next");
+        o.set(&mut cx, "done", d)
+            .expect("Error setting object value in transform_iterate_next");
+    } else {
+        let d = cx.boolean(true);
+        o.set(&mut cx, "done", d)
+            .expect("Error setting object value in transform_iterate_next");
+    }
+    Ok(o)
+}
+
+impl IntoJsResult for GeneralJsonIterator {
+    type Output = JsValue;
+    fn into_js_result<'a, 'b, 'c: 'b, C: Context<'c>>(
+        self,
+        cx: &mut C,
+    ) -> JsResult<'b, Self::Output> {
+        let o = cx.empty_object();
+        let f: Handle<JsFunction> = JsFunction::new(cx, transform_iterate_next)?;
+        o.set(cx, "next", f)?;
+        let s = cx.boxed(GeneralJsonIteratorJavaScript(RefCell::new(self)));
+        o.set(cx, "s", s)?;
+        Ok(o.as_value(cx))
     }
 }
diff --git a/pgml-sdks/pgml/src/languages/python.rs b/pgml-sdks/pgml/src/languages/python.rs
index 2cf1bcf9c..9d19b16bd 100644
--- a/pgml-sdks/pgml/src/languages/python.rs
+++ b/pgml-sdks/pgml/src/languages/python.rs
@@ -6,7 +6,10 @@ use std::sync::Arc;
 
 use rust_bridge::python::CustomInto;
 
-use crate::{pipeline::PipelineSyncData, transformer_pipeline::TransformerStream, types::Json};
+use crate::{
+    pipeline::PipelineSyncData,
+    types::{GeneralJsonAsyncIterator, GeneralJsonIterator, Json},
+};
 
 ////////////////////////////////////////////////////////////////////////////////
 // Rust to PY //////////////////////////////////////////////////////////////////
@@ -55,12 +58,12 @@ impl IntoPy<PyObject> for PipelineSyncData {
 
 #[pyclass]
 #[derive(Clone)]
-struct TransformerStreamPython {
-    wrapped: Arc<tokio::sync::Mutex<TransformerStream>>,
+struct GeneralJsonAsyncIteratorPython {
+    wrapped: Arc<tokio::sync::Mutex<GeneralJsonAsyncIterator>>,
 }
 
 #[pymethods]
-impl TransformerStreamPython {
+impl GeneralJsonAsyncIteratorPython {
     fn __aiter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
         slf
     }
@@ -71,8 +74,8 @@ impl TransformerStreamPython {
             let mut ts = ts.lock().await;
             if let Some(o) = ts.next().await {
                 Ok(Some(Python::with_gil(|py| {
-                    o.expect("Error calling next on TransformerStream")
-                        .to_object(py)
+                    o.expect("Error calling next on GeneralJsonAsyncIterator")
+                        .into_py(py)
                 })))
             } else {
                 Err(pyo3::exceptions::PyStopAsyncIteration::new_err(
@@ -84,15 +87,47 @@ impl TransformerStreamPython {
     }
 }
 
-impl IntoPy<PyObject> for TransformerStream {
+impl IntoPy<PyObject> for GeneralJsonAsyncIterator {
     fn into_py(self, py: Python) -> PyObject {
-        let f: Py<TransformerStreamPython> = Py::new(
+        let f: Py<GeneralJsonAsyncIteratorPython> = Py::new(
             py,
-            TransformerStreamPython {
+            GeneralJsonAsyncIteratorPython {
                 wrapped: Arc::new(tokio::sync::Mutex::new(self)),
             },
         )
-        .expect("Error converting TransformerStream to TransformerStreamPython");
+        .expect("Error converting GeneralJsonAsyncIterator to GeneralJsonAsyncIteratorPython");
         f.to_object(py)
     }
 }
 
+#[pyclass]
+struct GeneralJsonIteratorPython {
+    wrapped: GeneralJsonIterator,
+}
+
+#[pymethods]
+impl GeneralJsonIteratorPython {
+    fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
+        slf
+    }
+
+    fn __next__(mut slf: PyRefMut<'_, Self>, py: Python) -> PyResult<Option<PyObject>> {
+        if let Some(o) = slf.wrapped.next() {
+            let o = o.expect("Error calling next on GeneralJsonIterator");
+            Ok(Some(o.into_py(py)))
+        } else {
+            Err(pyo3::exceptions::PyStopIteration::new_err(
+                "stream exhausted",
+            ))
+        }
+    }
+}
+
+impl IntoPy<PyObject> for GeneralJsonIterator {
+    fn into_py(self, py: Python) -> PyObject {
+        let f: Py<GeneralJsonIteratorPython> =
+            Py::new(py, GeneralJsonIteratorPython { wrapped: self })
+                .expect("Error converting GeneralJsonIterator to GeneralJsonIteratorPython");
         f.to_object(py)
     }
 }
@@ -149,7 +184,13 @@ impl FromPyObject<'_> for PipelineSyncData {
     }
 }
 
-impl FromPyObject<'_> for TransformerStream {
+impl FromPyObject<'_> for GeneralJsonAsyncIterator {
+    fn extract(_ob: &PyAny) -> PyResult<Self> {
+        panic!("We must implement this, but this is impossible to be reached")
+    }
+}
+
+impl FromPyObject<'_> for GeneralJsonIterator {
     fn extract(_ob: &PyAny) -> PyResult<Self> {
         panic!("We must implement this, but this is impossible to be reached")
     }
diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs
index cd0eaaeef..f7ef4ceec 100644
--- a/pgml-sdks/pgml/src/lib.rs
+++ b/pgml-sdks/pgml/src/lib.rs
@@ -19,6 +19,7 @@ mod languages;
 pub mod migrations;
 mod model;
 pub mod models;
+mod open_source_ai;
 mod order_by_builder;
 mod pipeline;
 mod queries;
@@ -34,6 +35,7 @@ mod utils;
 pub use builtins::Builtins;
 pub use collection::Collection;
 pub use model::Model;
+pub use open_source_ai::OpenSourceAI;
 pub use pipeline::Pipeline;
 pub use splitter::Splitter;
 pub use transformer_pipeline::TransformerPipeline;
@@ -152,6 +154,7 @@ fn pgml(_py: pyo3::Python, m: &pyo3::types::PyModule) -> pyo3::PyResult<()> {
     m.add_class::<splitter::SplitterPython>()?;
     m.add_class::<builtins::BuiltinsPython>()?;
     m.add_class::<transformer_pipeline::TransformerPipelinePython>()?;
+    m.add_class::<open_source_ai::OpenSourceAIPython>()?;
     Ok(())
 }
 
@@ -201,6 +204,10 @@ fn main(mut cx: neon::context::ModuleContext) -> neon::result::NeonResult<()> {
         transformer_pipeline::TransformerPipelineJavascript::new,
     )?;
     cx.export_function("newPipeline", pipeline::PipelineJavascript::new)?;
+    cx.export_function(
+        "newOpenSourceAI",
+        open_source_ai::OpenSourceAIJavascript::new,
+    )?;
     Ok(())
 }
 
@@ -758,7 +765,6 @@ mod tests {
             .filter(filter)
             .fetch_all()
             .await?;
-        println!("{:?}", results);
         assert_eq!(results.len(), expected_result_count);
     }
 
diff --git a/pgml-sdks/pgml/src/open_source_ai.rs b/pgml-sdks/pgml/src/open_source_ai.rs
new file mode 100644
index 000000000..18adde288
--- /dev/null
+++ b/pgml-sdks/pgml/src/open_source_ai.rs
@@ -0,0 +1,448 @@
+use anyhow::Context;
+use futures::{Stream, StreamExt};
+use rust_bridge::{alias, alias_methods};
+use std::time::{SystemTime, UNIX_EPOCH};
+use uuid::Uuid;
+
+use crate::{
+    get_or_set_runtime,
+    types::{GeneralJsonAsyncIterator, GeneralJsonIterator, Json},
+    TransformerPipeline,
+};
+
+#[cfg(feature = "python")]
+use crate::types::{GeneralJsonAsyncIteratorPython, GeneralJsonIteratorPython, JsonPython};
+
+#[derive(alias, Debug, Clone)]
+pub struct OpenSourceAI {
+    database_url: Option<String>,
+}
+
+fn try_model_nice_name_to_model_name_and_parameters(
+    model_name: &str,
+) -> Option<(&'static str, Json)> {
+    match model_name {
+        "mistralai/Mistral-7B-Instruct-v0.1" => Some((
+            "mistralai/Mistral-7B-Instruct-v0.1",
+            serde_json::json!({
+                "task": "conversational",
+                "model": "mistralai/Mistral-7B-Instruct-v0.1",
"device_map": "auto", + "torch_dtype": "bfloat16" + }) + .into(), + )), + + "HuggingFaceH4/zephyr-7b-beta" => Some(( + "HuggingFaceH4/zephyr-7b-beta", + serde_json::json!({ + "task": "conversational", + "model": "HuggingFaceH4/zephyr-7b-beta", + "device_map": "auto", + "torch_dtype": "bfloat16" + }) + .into(), + )), + + "TheBloke/Llama-2-7B-Chat-GPTQ" => Some(( + "TheBloke/Llama-2-7B-Chat-GPTQ", + serde_json::json!({ + "task": "conversational", + "model": "TheBloke/Llama-2-7B-Chat-GPTQ", + "device_map": "auto", + "revision": "main" + }) + .into(), + )), + + "teknium/OpenHermes-2.5-Mistral-7B" => Some(( + "teknium/OpenHermes-2.5-Mistral-7B", + serde_json::json!({ + "task": "conversational", + "model": "teknium/OpenHermes-2.5-Mistral-7B", + "device_map": "auto", + "torch_dtype": "bfloat16" + }) + .into(), + )), + + "Open-Orca/Mistral-7B-OpenOrca" => Some(( + "Open-Orca/Mistral-7B-OpenOrca", + serde_json::json!({ + "task": "conversational", + "model": "Open-Orca/Mistral-7B-OpenOrca", + "device_map": "auto", + "torch_dtype": "bfloat16" + }) + .into(), + )), + + "Undi95/Toppy-M-7B" => Some(( + "Undi95/Toppy-M-7B", + serde_json::json!({ + "model": "Undi95/Toppy-M-7B", + "device_map": "auto", + "torch_dtype": "bfloat16" + }) + .into(), + )), + + "Undi95/ReMM-SLERP-L2-13B" => Some(( + "Undi95/ReMM-SLERP-L2-13B", + serde_json::json!({ + "model": "Undi95/ReMM-SLERP-L2-13B", + "device_map": "auto", + "torch_dtype": "bfloat16" + }) + .into(), + )), + + "Gryphe/MythoMax-L2-13b" => Some(( + "Gryphe/MythoMax-L2-13b", + serde_json::json!({ + "model": "Gryphe/MythoMax-L2-13b", + "device_map": "auto", + "torch_dtype": "bfloat16" + }) + .into(), + )), + + "PygmalionAI/mythalion-13b" => Some(( + "PygmalionAI/mythalion-13b", + serde_json::json!({ + "model": "PygmalionAI/mythalion-13b", + "device_map": "auto", + "torch_dtype": "bfloat16" + }) + .into(), + )), + + "deepseek-ai/deepseek-llm-7b-chat" => Some(( + "deepseek-ai/deepseek-llm-7b-chat", + serde_json::json!({ + "model": "deepseek-ai/deepseek-llm-7b-chat", + "device_map": "auto", + "torch_dtype": "bfloat16" + }) + .into(), + )), + + "Phind/Phind-CodeLlama-34B-v2" => Some(( + "Phind/Phind-CodeLlama-34B-v2", + serde_json::json!({ + "model": "Phind/Phind-CodeLlama-34B-v2", + "device_map": "auto", + "torch_dtype": "bfloat16" + }) + .into(), + )), + + _ => None, + } +} + +fn try_get_model_chat_template(model_name: &str) -> Option<&'static str> { + match model_name { + // Any Alpaca instruct tuned model + "Undi95/Toppy-M-7B" | "Undi95/ReMM-SLERP-L2-13B" | "Gryphe/MythoMax-L2-13b" | "Phind/Phind-CodeLlama-34B-v2" => Some("{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '### Instruction:\n' + message['content'] + '\n'}}\n{% elif message['role'] == 'system' %}\n{{ message['content'] + '\n'}}\n{% elif message['role'] == 'model' %}\n{{ '### Response:>\n' + message['content'] + eos_token + '\n'}}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '### Response:' }}\n{% endif %}\n{% endfor %}"), + "PygmalionAI/mythalion-13b" => Some("{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'model' %}\n{{ '<|model|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|model|>' }}\n{% endif %}\n{% endfor %}"), + _ => None + } +} + +struct AsyncToSyncJsonIterator(std::pin::Pin> + Send>>); + +impl Iterator 
for AsyncToSyncJsonIterator { + type Item = anyhow::Result; + + fn next(&mut self) -> Option { + let runtime = get_or_set_runtime(); + runtime.block_on(self.0.next()) + } +} + +#[alias_methods( + new, + chat_completions_create, + chat_completions_create_async, + chat_completions_create_stream, + chat_completions_create_stream_async +)] +impl OpenSourceAI { + pub fn new(database_url: Option) -> Self { + Self { database_url } + } + + fn create_pipeline_model_name_parameters( + &self, + mut model: Json, + ) -> anyhow::Result<(TransformerPipeline, String, Json)> { + if model.is_object() { + let args = model.as_object_mut().unwrap(); + let model_name = args + .remove("model") + .context("`model` is a required key in the model object")?; + let model_name = model_name.as_str().context("`model` must be a string")?; + Ok(( + TransformerPipeline::new( + "conversational", + Some(model_name.to_string()), + Some(model.clone()), + self.database_url.clone(), + ), + model_name.to_string(), + model, + )) + } else { + let model_name = model + .as_str() + .context("`model` must either be a string or an object")?; + let (real_model_name, parameters) = + try_model_nice_name_to_model_name_and_parameters(model_name).context( + r#"Please select one of the provided models: +mistralai/Mistral-7B-v0.1 +"#, + )?; + Ok(( + TransformerPipeline::new( + "conversational", + Some(real_model_name.to_string()), + Some(parameters.clone()), + self.database_url.clone(), + ), + model_name.to_string(), + parameters, + )) + } + } + + #[allow(clippy::too_many_arguments)] + pub async fn chat_completions_create_stream_async( + &self, + model: Json, + messages: Vec, + max_tokens: Option, + temperature: Option, + n: Option, + chat_template: Option, + ) -> anyhow::Result { + let (transformer_pipeline, model_name, model_parameters) = + self.create_pipeline_model_name_parameters(model)?; + + let max_tokens = max_tokens.unwrap_or(1000); + let temperature = temperature.unwrap_or(0.8); + let n = n.unwrap_or(1) as usize; + let to_hash = format!("{}{}{}{}", *model_parameters, max_tokens, temperature, n); + let md5_digest = md5::compute(to_hash.as_bytes()); + let fingerprint = uuid::Uuid::from_slice(&md5_digest.0)?; + + let mut args = serde_json::json!({ "max_length": max_tokens, "temperature": temperature, "do_sample": true, "num_return_sequences": n }); + if let Some(t) = chat_template + .or_else(|| try_get_model_chat_template(&model_name).map(|s| s.to_string())) + { + args.as_object_mut().unwrap().insert( + "chat_template".to_string(), + serde_json::to_value(t).unwrap(), + ); + } + + let messages = serde_json::to_value(messages)?.into(); + let iterator = transformer_pipeline + .transform_stream(messages, Some(args.into()), Some(1)) + .await?; + + let id = Uuid::new_v4().to_string(); + let iter = iterator.map(move |choices| { + let since_the_epoch = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Time went backwards"); + Ok(serde_json::json!({ + "id": id.clone(), + "system_fingerprint": fingerprint.clone(), + "object": "chat.completion.chunk", + "created": since_the_epoch.as_secs(), + "model": model_name.clone(), + "choices": choices?.as_array().context("Error parsing choices from GeneralJsonAsyncIterator")?.iter().enumerate().map(|(i, c)| { + serde_json::json!({ + "index": i, + "delta": { + "role": "assistant", + "content": c + } + }) + }).collect::() + }) + .into()) + }); + + Ok(GeneralJsonAsyncIterator(Box::pin(iter))) + } + + #[allow(clippy::too_many_arguments)] + pub fn chat_completions_create_stream( + &self, + model: Json, + 
+        messages: Vec<Json>,
+        max_tokens: Option<i32>,
+        temperature: Option<f64>,
+        n: Option<i32>,
+        chat_template: Option<String>,
+    ) -> anyhow::Result<GeneralJsonIterator> {
+        let runtime = crate::get_or_set_runtime();
+        let iter = runtime.block_on(self.chat_completions_create_stream_async(
+            model,
+            messages,
+            max_tokens,
+            temperature,
+            n,
+            chat_template,
+        ))?;
+        Ok(GeneralJsonIterator(Box::new(AsyncToSyncJsonIterator(
+            Box::pin(iter),
+        ))))
+    }
+
+    #[allow(clippy::too_many_arguments)]
+    pub async fn chat_completions_create_async(
+        &self,
+        model: Json,
+        messages: Vec<Json>,
+        max_tokens: Option<i32>,
+        temperature: Option<f64>,
+        n: Option<i32>,
+        chat_template: Option<String>,
+    ) -> anyhow::Result<Json> {
+        let (transformer_pipeline, model_name, model_parameters) =
+            self.create_pipeline_model_name_parameters(model)?;
+
+        let max_tokens = max_tokens.unwrap_or(1000);
+        let temperature = temperature.unwrap_or(0.8);
+        let n = n.unwrap_or(1) as usize;
+        let to_hash = format!("{}{}{}{}", *model_parameters, max_tokens, temperature, n);
+        let md5_digest = md5::compute(to_hash.as_bytes());
+        let fingerprint = uuid::Uuid::from_slice(&md5_digest.0)?;
+
+        let mut args = serde_json::json!({ "max_length": max_tokens, "temperature": temperature, "do_sample": true, "num_return_sequences": n });
+        if let Some(t) = chat_template
+            .or_else(|| try_get_model_chat_template(&model_name).map(|s| s.to_string()))
+        {
+            args.as_object_mut().unwrap().insert(
+                "chat_template".to_string(),
+                serde_json::to_value(t).unwrap(),
+            );
+        }
+
+        let choices = transformer_pipeline
+            .transform(messages, Some(args.into()))
+            .await?;
+        let choices: Vec<Json> = choices
+            .as_array()
+            .context("Error parsing return from TransformerPipeline")?
+            .iter()
+            .enumerate()
+            .map(|(i, c)| {
+                serde_json::json!({
+                    "index": i,
+                    "message": {
+                        "role": "assistant",
+                        "content": c
+                    }
+                    // Finish reason should be here
+                })
+                .into()
+            })
+            .collect();
+        let since_the_epoch = SystemTime::now()
+            .duration_since(UNIX_EPOCH)
+            .expect("Time went backwards");
+        Ok(serde_json::json!({
+            "id": Uuid::new_v4().to_string(),
+            "object": "chat.completion",
+            "created": since_the_epoch.as_secs(),
+            "model": model_name,
+            "system_fingerprint": fingerprint,
+            "choices": choices,
+            "usage": {
+                "prompt_tokens": 0,
+                "completion_tokens": 0,
+                "total_tokens": 0
+            }
+        })
+        .into())
+    }
+
+    #[allow(clippy::too_many_arguments)]
+    pub fn chat_completions_create(
+        &self,
+        model: Json,
+        messages: Vec<Json>,
+        max_tokens: Option<i32>,
+        temperature: Option<f64>,
+        n: Option<i32>,
+        chat_template: Option<String>,
+    ) -> anyhow::Result<Json> {
+        let runtime = crate::get_or_set_runtime();
+        runtime.block_on(self.chat_completions_create_async(
+            model,
+            messages,
+            max_tokens,
+            temperature,
+            n,
+            chat_template,
+        ))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use futures::StreamExt;
+
+    #[test]
+    fn can_open_source_ai_create() -> anyhow::Result<()> {
+        let client = OpenSourceAI::new(None);
+        let results = client.chat_completions_create(Json::from_serializable("mistralai/Mistral-7B-v0.1"), vec![
+            serde_json::json!({"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate"}).into(),
+            serde_json::json!({"role": "user", "content": "How many helicopters can a human eat in one sitting?"}).into(),
+        ], Some(10), None, Some(3), None)?;
+        assert!(results["choices"].as_array().is_some());
+        Ok(())
+    }
+
+    #[sqlx::test]
+    fn can_open_source_ai_create_async() -> anyhow::Result<()> {
+        let client = OpenSourceAI::new(None);
+        let results = client.chat_completions_create_async(Json::from_serializable("mistralai/Mistral-7B-v0.1"), vec![
+            serde_json::json!({"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate"}).into(),
+            serde_json::json!({"role": "user", "content": "How many helicopters can a human eat in one sitting?"}).into(),
+        ], Some(10), None, Some(3), None).await?;
+        assert!(results["choices"].as_array().is_some());
+        Ok(())
+    }
+
+    #[sqlx::test]
+    fn can_open_source_ai_create_stream_async() -> anyhow::Result<()> {
+        let client = OpenSourceAI::new(None);
+        let mut stream = client.chat_completions_create_stream_async(Json::from_serializable("mistralai/Mistral-7B-v0.1"), vec![
+            serde_json::json!({"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate"}).into(),
+            serde_json::json!({"role": "user", "content": "How many helicopters can a human eat in one sitting?"}).into(),
+        ], Some(10), None, Some(3), None).await?;
+        while let Some(o) = stream.next().await {
+            o?;
+        }
+        Ok(())
+    }
+
+    #[test]
+    fn can_open_source_ai_create_stream() -> anyhow::Result<()> {
+        let client = OpenSourceAI::new(None);
+        let iterator = client.chat_completions_create_stream(Json::from_serializable("mistralai/Mistral-7B-v0.1"), vec![
+            serde_json::json!({"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate"}).into(),
+            serde_json::json!({"role": "user", "content": "How many helicopters can a human eat in one sitting?"}).into(),
+        ], Some(10), None, Some(3), None)?;
+        for o in iterator {
+            o?;
+        }
+        Ok(())
+    }
+}
diff --git a/pgml-sdks/pgml/src/transformer_pipeline.rs b/pgml-sdks/pgml/src/transformer_pipeline.rs
index 70fd3f925..00dd556f7 100644
--- a/pgml-sdks/pgml/src/transformer_pipeline.rs
+++ b/pgml-sdks/pgml/src/transformer_pipeline.rs
@@ -1,5 +1,6 @@
+use anyhow::Context;
 use futures::Stream;
-use rust_bridge::{alias, alias_manual, alias_methods};
+use rust_bridge::{alias, alias_methods};
 use sqlx::{postgres::PgRow, Row};
 use sqlx::{Postgres, Transaction};
 use std::collections::VecDeque;
@@ -15,14 +16,14 @@ pub struct TransformerPipeline {
     database_url: Option<String>,
 }
 
+use crate::types::GeneralJsonAsyncIterator;
 use crate::{get_or_initialize_pool, types::Json};
 
 #[cfg(feature = "python")]
-use crate::types::JsonPython;
+use crate::types::{GeneralJsonAsyncIteratorPython, JsonPython};
 
 #[allow(clippy::type_complexity)]
-#[derive(alias_manual)]
-pub struct TransformerStream {
+struct TransformerStream {
     transaction: Option<Transaction<'static, Postgres>>,
     future: Option<Pin<Box<dyn Future<Output = Result<Vec<PgRow>, sqlx::Error>> + Send + 'static>>>,
     commit: Option<Pin<Box<dyn Future<Output = Result<(), sqlx::Error>> + Send + 'static>>>,
@@ -54,7 +55,7 @@ impl TransformerStream {
 }
 
 impl Stream for TransformerStream {
-    type Item = Result<String, sqlx::Error>;
+    type Item = anyhow::Result<Json>;
 
     fn poll_next(
         mut self: Pin<&mut Self>,
@@ -105,7 +106,7 @@ impl Stream for TransformerStream {
 
         if !self.results.is_empty() {
             let r = self.results.pop_front().unwrap();
-            Poll::Ready(Some(Ok(r.get::<String, usize>(0))))
+            Poll::Ready(Some(Ok(r.get::<Json, usize>(0))))
         } else if self.done {
             Poll::Ready(None)
         } else {
@@ -141,16 +142,36 @@ impl TransformerPipeline {
     }
 
     #[instrument(skip(self))]
-    pub async fn transform(&self, inputs: Vec<String>, args: Option<Json>) -> anyhow::Result<Json> {
+    pub async fn transform(&self, inputs: Vec<Json>, args: Option<Json>) -> anyhow::Result<Json> {
         let pool = get_or_initialize_pool(&self.database_url).await?;
         let args = args.unwrap_or_default();
 
-        let results = sqlx::query("SELECT pgml.transform(task => $1, inputs => $2, args => $3)")
-            .bind(&self.task)
-            .bind(inputs)
-            .bind(&args)
-            .fetch_all(&pool)
-            .await?;
+        // We set the task in the new constructor so we can unwrap here
+        let results = if self.task["task"].as_str().unwrap() == "conversational" {
+            let inputs: Vec<serde_json::Value> = inputs.into_iter().map(|j| j.0).collect();
+            sqlx::query("SELECT pgml.transform(task => $1, inputs => $2, args => $3)")
+                .bind(&self.task)
+                .bind(inputs)
+                .bind(&args)
+                .fetch_all(&pool)
+                .await?
+        } else {
+            let inputs: anyhow::Result<Vec<String>> =
+                inputs
+                    .into_iter()
+                    .map(|input| {
+                        input.as_str().context(
+                            "the inputs arg must be strings when not using the conversational task",
+                        ).map(|s| s.to_string())
+                    })
+                    .collect();
+            sqlx::query("SELECT pgml.transform(task => $1, inputs => $2, args => $3)")
+                .bind(&self.task)
+                .bind(inputs?)
+                .bind(&args)
+                .fetch_all(&pool)
+                .await?
+        };
         let results = results.get(0).unwrap().get::<serde_json::Value, usize>(0);
         Ok(Json(results))
     }
 
@@ -158,25 +179,50 @@ impl TransformerPipeline {
     #[instrument(skip(self))]
     pub async fn transform_stream(
         &self,
-        input: &str,
+        input: Json,
         args: Option<Json>,
         batch_size: Option<i32>,
-    ) -> anyhow::Result<TransformerStream> {
+    ) -> anyhow::Result<GeneralJsonAsyncIterator> {
         let pool = get_or_initialize_pool(&self.database_url).await?;
         let args = args.unwrap_or_default();
         let batch_size = batch_size.unwrap_or(10);
 
         let mut transaction = pool.begin().await?;
-        sqlx::query(
-            "DECLARE c CURSOR FOR SELECT pgml.transform_stream(task => $1, input => $2, args => $3)",
-        )
-        .bind(&self.task)
-        .bind(input)
-        .bind(&args)
-        .execute(&mut *transaction)
-        .await?;
+        // We set the task in the new constructor so we can unwrap here
+        if self.task["task"].as_str().unwrap() == "conversational" {
+            let inputs = input
+                .as_array()
+                .context("`input` to transformer_stream must be an array of objects")?
+                .to_vec();
+            sqlx::query(
+                "DECLARE c CURSOR FOR SELECT pgml.transform_stream(task => $1, inputs => $2, args => $3)",
+            )
+            .bind(&self.task)
+            .bind(inputs)
+            .bind(&args)
+            .execute(&mut *transaction)
+            .await?;
+        } else {
+            let input = input
+                .as_str()
+                .context(
+                    "`input` to transformer_stream must be a string if task is not conversational",
+                )?
+                .to_string();
+            sqlx::query(
+                "DECLARE c CURSOR FOR SELECT pgml.transform_stream(task => $1, input => $2, args => $3)",
+            )
+            .bind(&self.task)
+            .bind(input)
+            .bind(&args)
+            .execute(&mut *transaction)
+            .await?;
+        }
 
-        Ok(TransformerStream::new(transaction, batch_size))
+        Ok(GeneralJsonAsyncIterator(Box::pin(TransformerStream::new(
+            transaction,
+            batch_size,
+        ))))
     }
 }
 
@@ -198,8 +244,8 @@ mod tests {
         let results = t
             .transform(
                 vec![
-                    "How are you doing today?".to_string(),
-                    "What is a good song?".to_string(),
+                    serde_json::Value::String("How are you doing today?".to_string()).into(),
+                    serde_json::Value::String("How are you doing today?".to_string()).into(),
                 ],
                 None,
             )
@@ -215,8 +261,8 @@ mod tests {
         let results = t
             .transform(
                 vec![
-                    "How are you doing today?".to_string(),
-                    "What is a good song?".to_string(),
+                    serde_json::Value::String("How are you doing today?".to_string()).into(),
+                    serde_json::Value::String("How are you doing today?".to_string()).into(),
                 ],
                 None,
            )
@@ -230,10 +276,10 @@ mod tests {
         internal_init_logger(None, None).ok();
         let t = TransformerPipeline::new(
             "text-generation",
-            Some("TheBloke/zephyr-7B-beta-GGUF".to_string()),
+            Some("TheBloke/zephyr-7B-beta-GPTQ".to_string()),
             Some(
                 serde_json::json!({
-                    "model_file": "zephyr-7b-beta.Q5_K_M.gguf", "model_type": "mistral"
+                    "model_type": "mistral", "revision": "main", "device_map": "auto"
                 })
                 .into(),
             ),
@@ -241,7 +287,7 @@ mod tests {
         let mut stream = t
             .transform_stream(
-                "AI is going to",
+                serde_json::json!("AI is going to").into(),
                 Some(
                     serde_json::json!({
                         "max_new_tokens": 10
diff --git a/pgml-sdks/pgml/src/types.rs b/pgml-sdks/pgml/src/types.rs
index ba80583e8..bdf7308a3 100644
--- a/pgml-sdks/pgml/src/types.rs
+++ b/pgml-sdks/pgml/src/types.rs
@@ -1,4 +1,5 @@
 use anyhow::Context;
+use futures::{Stream, StreamExt};
 use itertools::Itertools;
 use rust_bridge::alias_manual;
 use sea_query::Iden;
@@ -42,6 +43,19 @@ impl Serialize for Json {
     }
 }
 
+// This will cause some conflicting trait issue
+// impl<T: serde::Serialize> From<T> for Json {
+//     fn from(v: T) -> Self {
+//         Self(serde_json::to_value(v).unwrap())
+//     }
+// }
+
+impl Json {
+    pub fn from_serializable<T: serde::Serialize>(v: T) -> Self {
+        Self(serde_json::to_value(v).unwrap())
+    }
+}
+
 pub(crate) trait TryToNumeric {
     fn try_to_u64(&self) -> anyhow::Result<u64>;
     fn try_to_i64(&self) -> anyhow::Result<i64> {
@@ -109,3 +123,30 @@ impl IntoTableNameAndSchema for String {
             .expect("Malformed table name in IntoTableNameAndSchema")
     }
 }
+
+#[derive(alias_manual)]
+pub struct GeneralJsonAsyncIterator(
+    pub std::pin::Pin<Box<dyn Stream<Item = anyhow::Result<Json>> + Send>>,
+);
+
+impl Stream for GeneralJsonAsyncIterator {
+    type Item = anyhow::Result<Json>;
+
+    fn poll_next(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Option<Self::Item>> {
+        self.0.poll_next_unpin(cx)
+    }
+}
+
+#[derive(alias_manual)]
+pub struct GeneralJsonIterator(pub Box<dyn Iterator<Item = anyhow::Result<Json>> + Send>);
+
+impl Iterator for GeneralJsonIterator {
+    type Item = anyhow::Result<Json>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        self.0.next()
+    }
+}
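
For reference, the new OpenSourceAI class mirrors OpenAI's chat completions API. A minimal Python sketch following the SDK tests above; it assumes a reachable PostgresML database (a connection string may be passed to the constructor) and that the model fits on the available hardware:

import pgml

client = pgml.OpenSourceAI()  # or pgml.OpenSourceAI(database_url)

messages = [
    {
        "role": "system",
        "content": "You are a friendly chatbot who always responds in the style of a pirate",
    },
    {
        "role": "user",
        "content": "How many helicopters can a human eat in one sitting?",
    },
]

# Blocking call: returns an OpenAI-style chat.completion object,
# with one entry in "choices" per requested sequence.
results = client.chat_completions_create(
    "mistralai/Mistral-7B-v0.1", messages, temperature=0.85
)
print(results["choices"][0]["message"]["content"])

# Streaming call: yields chat.completion.chunk objects whose "choices"
# carry a "delta" with the newly generated text.
for chunk in client.chat_completions_create_stream(
    "mistralai/Mistral-7B-v0.1", messages, temperature=0.85, n=3
):
    print(chunk["choices"])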
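
The conversational task is also reachable through TransformerPipeline directly. A sketch under two assumptions: the Python constructor mirrors the newTransformerPipeline(task, model?, args?, database_url?) declaration in build.rs, and conversational inputs are JSON message objects per the new Vec<Json> transform signature (the model arguments below are copied from the zephyr-7b-beta entry in the nice-name table):

import asyncio
import pgml

async def main():
    t = pgml.TransformerPipeline(
        "conversational",
        "HuggingFaceH4/zephyr-7b-beta",
        {"device_map": "auto", "torch_dtype": "bfloat16"},
    )
    # For conversational tasks, inputs are message objects, not plain strings.
    out = await t.transform(
        [{"role": "user", "content": "How many helicopters can a human eat in one sitting?"}],
        {"max_new_tokens": 25},
    )
    print(out)

asyncio.run(main())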