Skip to content

Commit 4cfc458

Browse files
authored
Added correct timeout (#1219)
1 parent b85cef4 commit 4cfc458

File tree

1 file changed

+11 -7 lines changed

1 file changed

+11
-7
lines changed

pgml-extension/src/bindings/transformers/transformers.py

Lines changed: 11 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -131,16 +131,16 @@ def put(self, values):
131131
self.text_index_cache[i] += len(printable_text)
132132
output.append(printable_text)
133133
if any(output):
134-
self.text_queue.put(output, self.timeout)
134+
self.text_queue.put(output)
135135

136136
def end(self):
137137
self.next_tokens_are_prompt = True
138138
output = []
139139
for i, tokens in enumerate(self.token_cache):
140140
text = self.tokenizer.decode(tokens, **self.decode_kwargs)
141141
output.append(text[self.text_index_cache[i] :])
142-
self.text_queue.put(output, self.timeout)
143-
self.text_queue.put(self.stop_signal, self.timeout)
142+
self.text_queue.put(output)
143+
self.text_queue.put(self.stop_signal)
144144

145145
def __iter__(self):
146146
return self
@@ -264,12 +264,13 @@ def __init__(self, model_name, **kwargs):
264264
if self.tokenizer.pad_token is None:
265265
self.tokenizer.pad_token = self.tokenizer.eos_token
266266

267-
def stream(self, input, **kwargs):
267+
def stream(self, input, timeout=None, **kwargs):
268268
streamer = None
269269
generation_kwargs = None
270270
if self.task == "conversational":
271271
streamer = TextIteratorStreamer(
272272
self.tokenizer,
273+
timeout=timeout,
273274
skip_prompt=True,
274275
)
275276
if "chat_template" in kwargs:
@@ -286,7 +287,10 @@ def stream(self, input, **kwargs):
286287
input = self.tokenizer(input, return_tensors="pt").to(self.model.device)
287288
generation_kwargs = dict(input, streamer=streamer, **kwargs)
288289
else:
289-
streamer = TextIteratorStreamer(self.tokenizer)
290+
streamer = TextIteratorStreamer(
291+
self.tokenizer,
292+
timeout=timeout,
293+
)
290294
input = self.tokenizer(input, return_tensors="pt", padding=True).to(
291295
self.model.device
292296
)
@@ -355,7 +359,7 @@ def create_pipeline(task):
355359
return pipe
356360

357361

358-
def transform_using(pipeline, args, inputs, stream=False):
362+
def transform_using(pipeline, args, inputs, stream=False, timeout=None):
359363
args = orjson.loads(args)
360364
inputs = orjson.loads(inputs)
361365

@@ -364,7 +368,7 @@ def transform_using(pipeline, args, inputs, stream=False):
364368
convert_eos_token(pipeline.tokenizer, args)
365369

366370
if stream:
367-
return pipeline.stream(inputs, **args)
371+
return pipeline.stream(inputs, timeout=timeout, **args)
368372
return orjson.dumps(pipeline(inputs, **args), default=orjson_default).decode()
369373

370374

0 commit comments

Comments
 (0)