diff --git a/pgml-extension/src/bindings/transformers/transformers.py b/pgml-extension/src/bindings/transformers/transformers.py
index c738fe1f5..7ff3bdc18 100644
--- a/pgml-extension/src/bindings/transformers/transformers.py
+++ b/pgml-extension/src/bindings/transformers/transformers.py
@@ -131,7 +131,7 @@ def put(self, values):
             self.text_index_cache[i] += len(printable_text)
             output.append(printable_text)
         if any(output):
-            self.text_queue.put(output, self.timeout)
+            self.text_queue.put(output)
 
     def end(self):
         self.next_tokens_are_prompt = True
@@ -139,8 +139,8 @@ def end(self):
         for i, tokens in enumerate(self.token_cache):
             text = self.tokenizer.decode(tokens, **self.decode_kwargs)
             output.append(text[self.text_index_cache[i] :])
-        self.text_queue.put(output, self.timeout)
-        self.text_queue.put(self.stop_signal, self.timeout)
+        self.text_queue.put(output)
+        self.text_queue.put(self.stop_signal)
 
     def __iter__(self):
         return self
@@ -264,12 +264,13 @@ def __init__(self, model_name, **kwargs):
         if self.tokenizer.pad_token is None:
             self.tokenizer.pad_token = self.tokenizer.eos_token
 
-    def stream(self, input, **kwargs):
+    def stream(self, input, timeout=None, **kwargs):
         streamer = None
         generation_kwargs = None
         if self.task == "conversational":
             streamer = TextIteratorStreamer(
                 self.tokenizer,
+                timeout=timeout,
                 skip_prompt=True,
             )
             if "chat_template" in kwargs:
@@ -286,7 +287,10 @@
             input = self.tokenizer(input, return_tensors="pt").to(self.model.device)
             generation_kwargs = dict(input, streamer=streamer, **kwargs)
         else:
-            streamer = TextIteratorStreamer(self.tokenizer)
+            streamer = TextIteratorStreamer(
+                self.tokenizer,
+                timeout=timeout,
+            )
             input = self.tokenizer(input, return_tensors="pt", padding=True).to(
                 self.model.device
             )
@@ -355,7 +359,7 @@
     return pipe
 
 
-def transform_using(pipeline, args, inputs, stream=False):
+def transform_using(pipeline, args, inputs, stream=False, timeout=None):
     args = orjson.loads(args)
     inputs = orjson.loads(inputs)
 
@@ -364,7 +368,7 @@ def transform_using(pipeline, args, inputs, stream=False):
 
     convert_eos_token(pipeline.tokenizer, args)
 
     if stream:
-        return pipeline.stream(inputs, **args)
+        return pipeline.stream(inputs, timeout=timeout, **args)
     return orjson.dumps(pipeline(inputs, **args), default=orjson_default).decode()