
Commit be11967

Working simple python thread metrics
1 parent f7401b8 commit be11967

File tree

1 file changed: +63 -5 lines changed


pgml-extension/src/bindings/transformers/transformers.py

Lines changed: 63 additions & 5 deletions
@@ -3,6 +3,8 @@
 import shutil
 import time
 import queue
+import sys
+import json
 
 import datasets
 from InstructorEmbedding import INSTRUCTOR
@@ -40,7 +42,7 @@
     TrainingArguments,
     Trainer,
 )
-from threading import Thread
+import threading
 
 __cache_transformer_by_model_id = {}
 __cache_sentence_transformer_by_name = {}
@@ -62,6 +64,26 @@
 }
 
 
+class WorkerThreads:
+    def __init__(self):
+        self.worker_threads = {}
+
+    def delete_thread(self, id):
+        del self.worker_threads[id]
+
+    def update_thread(self, id, value):
+        self.worker_threads[id] = value
+
+    def get_thread(self, id):
+        if id in self.worker_threads:
+            return self.worker_threads[id]
+        else:
+            return None
+
+
+worker_threads = WorkerThreads()
+
+
 class PgMLException(Exception):
     pass
 
@@ -105,6 +127,12 @@ def __init__(self, tokenizer, skip_prompt=False, timeout=None, **decode_kwargs):
         self.token_cache = []
         self.text_index_cache = []
 
+    def set_worker_thread_id(self, id):
+        self.worker_thread_id = id
+
+    def get_worker_thread_id(self):
+        return self.worker_thread_id
+
     def put(self, values):
         if self.skip_prompt and self.next_tokens_are_prompt:
             self.next_tokens_are_prompt = False
@@ -149,6 +177,22 @@ def __next__(self):
         return value
 
 
+def streaming_worker(worker_threads, model, **kwargs):
+    thread_id = threading.get_native_id()
+    try:
+        worker_threads.update_thread(
+            thread_id, json.dumps({"model": model.name_or_path})
+        )
+    except:
+        worker_threads.update_thread(thread_id, "Error setting data")
+    try:
+        model.generate(**kwargs)
+    except BaseException as error:
+        print(f"Error in streaming_worker: {error}", file=sys.stderr)
+    finally:
+        worker_threads.delete_thread(thread_id)
+
+
 class GGMLPipeline(object):
     def __init__(self, model_name, **task):
         import ctransformers
@@ -185,7 +229,7 @@ def do_work():
                 self.q.put(x)
             self.done = True
 
-        thread = Thread(target=do_work)
+        thread = threading.Thread(target=do_work)
         thread.start()
 
     def __iter__(self):
@@ -283,7 +327,13 @@ def stream(self, input, timeout=None, **kwargs):
                 input, add_generation_prompt=True, tokenize=False
             )
             input = self.tokenizer(input, return_tensors="pt").to(self.model.device)
-            generation_kwargs = dict(input, streamer=streamer, **kwargs)
+            generation_kwargs = dict(
+                input,
+                worker_threads=worker_threads,
+                model=self.model,
+                streamer=streamer,
+                **kwargs,
+            )
         else:
             streamer = TextIteratorStreamer(
                 self.tokenizer,
@@ -292,9 +342,17 @@ def stream(self, input, timeout=None, **kwargs):
             input = self.tokenizer(input, return_tensors="pt", padding=True).to(
                 self.model.device
             )
-            generation_kwargs = dict(input, streamer=streamer, **kwargs)
-        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
+            generation_kwargs = dict(
+                input,
+                worker_threads=worker_threads,
+                model=self.model,
+                streamer=streamer,
+                **kwargs,
+            )
+        # thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
+        thread = threading.Thread(target=streaming_worker, kwargs=generation_kwargs)
        thread.start()
+        streamer.set_worker_thread_id(thread.native_id)
        return streamer
 
    def __call__(self, inputs, **kwargs):
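
For reference, the pattern this diff introduces is a process-wide registry keyed by native thread id: each streaming generation thread records a small JSON blob about itself when it starts, removes its entry when it finishes, and the streamer remembers which thread is feeding it so callers can look up that entry later. The sketch below reproduces the pattern with only the standard library; fake_generate and the "example-model" label are illustrative stand-ins and are not part of the commit.

import json
import threading
import time


class WorkerThreads:
    """Registry of per-thread metrics, keyed by native thread id."""

    def __init__(self):
        self.worker_threads = {}

    def delete_thread(self, id):
        del self.worker_threads[id]

    def update_thread(self, id, value):
        self.worker_threads[id] = value

    def get_thread(self, id):
        return self.worker_threads.get(id)


worker_threads = WorkerThreads()


def fake_generate():
    # Stand-in for model.generate(**kwargs); just burns a little time.
    time.sleep(0.5)


def streaming_worker(worker_threads):
    # Register metadata under this thread's native id, do the work,
    # then always remove the entry so finished threads don't linger.
    thread_id = threading.get_native_id()
    worker_threads.update_thread(thread_id, json.dumps({"model": "example-model"}))
    try:
        fake_generate()
    finally:
        worker_threads.delete_thread(thread_id)


thread = threading.Thread(target=streaming_worker, kwargs={"worker_threads": worker_threads})
thread.start()
time.sleep(0.1)
# While generation is running, metrics are visible under thread.native_id.
print(worker_threads.get_thread(thread.native_id))  # '{"model": "example-model"}'
thread.join()
print(worker_threads.get_thread(thread.native_id))  # None after cleanup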

0 commit comments
