
Commit 996b514

GGML and GPTQ compatibility (#748)

Montana Low authored
Co-authored-by: Montana Low <montanalow@gmail.com>
1 parent 11fa8ee

2 files changed: 54 additions, 7 deletions

pgml-extension/requirements.txt

Lines changed: 2 additions & 0 deletions
@@ -1,4 +1,6 @@
 accelerate==0.19.0
+auto-gptq==0.2.2
+ctransformers==0.2.8
 datasets==2.12.0
 deepspeed==0.9.2
 huggingface-hub==0.14.1
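The two new pins back the quantized-model paths added below: auto-gptq for GPTQ checkpoints and ctransformers for GGML ones. A minimal sanity check that both bindings resolve after installation (a sketch, assuming the pinned versions above installed cleanly):

from importlib.metadata import version

import auto_gptq       # GPTQ: quantized model loading (AutoGPTQForCausalLM)
import ctransformers   # GGML: quantized inference (AutoModelForCausalLM)

print(version("auto-gptq"), version("ctransformers"))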

pgml-extension/src/bindings/transformers.py

Lines changed: 52 additions & 7 deletions
@@ -81,6 +81,47 @@ def ensure_device(kwargs):
     else:
         kwargs["device"] = "cpu"
 
+
+class GPTQPipeline(object):
+    def __init__(self, model_name, **task):
+        import auto_gptq
+        from huggingface_hub import snapshot_download
+        model_path = snapshot_download(model_name)
+
+        self.model = auto_gptq.AutoGPTQForCausalLM.from_quantized(model_path, **task)
+        if "use_fast_tokenizer" in task:
+            self.tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=task.pop("use_fast_tokenizer"))
+        else:
+            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+        self.task = "text-generation"
+
+    def __call__(self, inputs, **kwargs):
+        outputs = []
+        for input in inputs:
+            tokens = self.tokenizer(input, return_tensors="pt").to(self.model.device).input_ids
+            token_ids = self.model.generate(input_ids=tokens, **kwargs)[0]
+            outputs.append(self.tokenizer.decode(token_ids))
+        return outputs
+
+
+class GGMLPipeline(object):
+    def __init__(self, model_name, **task):
+        import ctransformers
+
+        task.pop("model")
+        task.pop("task")
+        task.pop("device")
+        self.model = ctransformers.AutoModelForCausalLM.from_pretrained(model_name, **task)
+        self.tokenizer = None
+        self.task = "text-generation"
+
+    def __call__(self, inputs, **kwargs):
+        outputs = []
+        for input in inputs:
+            outputs.append(self.model(input, **kwargs))
+        return outputs
+
+
 def transform(task, args, inputs):
     task = orjson.loads(task)
     args = orjson.loads(args)
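A minimal sketch of how GPTQPipeline might be exercised directly, outside of transform. The repo name is a hypothetical example; any GPTQ-quantized causal LM that auto-gptq can load should behave the same way:

# Hypothetical model name; extra kwargs such as "device" are forwarded
# to auto_gptq's from_quantized by __init__.
pipe = GPTQPipeline("TheBloke/vicuna-7B-v1.3-GPTQ", device="cuda:0")

# __call__ tokenizes each input, generates, and decodes the full sequence.
print(pipe(["PostgresML is"], max_new_tokens=32)[0])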
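GGMLPipeline, by contrast, expects the task dict that transform passes through, so a direct call has to supply the model/task/device keys its __init__ pops before forwarding the rest to ctransformers. A sketch, again with a hypothetical repo name:

name = "TheBloke/vicuna-7B-v1.3-GGML"  # hypothetical model name
pipe = GGMLPipeline(
    name,
    model=name,               # popped and discarded by __init__
    task="text-generation",   # popped and discarded
    device="cpu",             # popped and discarded; ctransformers runs on CPU
    model_type="llama",       # forwarded to ctransformers.from_pretrained
)

# ctransformers models are callable on a prompt string directly, so
# __call__ just maps the model over the inputs.
print(pipe(["PostgresML is"], max_new_tokens=32)[0])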
@@ -90,21 +131,25 @@ def transform(task, args, inputs):
     if key not in __cache_transform_pipeline_by_task:
         ensure_device(task)
         convert_dtype(task)
-        pipe = transformers.pipeline(**task)
-        if pipe.tokenizer is None:
-            pipe.tokenizer = AutoTokenizer.from_pretrained(pipe.model.name_or_path)
+        model_name = task.get("model", None)
+        model_name = model_name.lower() if model_name else None
+        if model_name and "-ggml" in model_name:
+            pipe = GGMLPipeline(model_name, **task)
+        elif model_name and "-gptq" in model_name:
+            pipe = GPTQPipeline(model_name, **task)
+        else:
+            pipe = transformers.pipeline(**task)
+            if pipe.tokenizer is None:
+                pipe.tokenizer = AutoTokenizer.from_pretrained(pipe.model.name_or_path)
         __cache_transform_pipeline_by_task[key] = pipe
 
     pipe = __cache_transform_pipeline_by_task[key]
 
     if pipe.task == "question-answering":
         inputs = [orjson.loads(input) for input in inputs]
-
     convert_eos_token(pipe.tokenizer, args)
 
-    results = pipe(inputs, **args)
-
-    return orjson.dumps(results, default=orjson_default).decode()
+    return orjson.dumps(pipe(inputs, **args), default=orjson_default).decode()
 
 
 def embed(transformer, inputs, kwargs):
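The dispatch added in this hunk is purely name-based: a case-insensitive substring check on the model name selects the backend, and everything else falls through to a stock transformers pipeline. A self-contained illustration of the rule, using hypothetical repo names:

def backend_for(model_name):
    # Same rule as the dispatch in transform: case-insensitive substring match.
    name = (model_name or "").lower()
    if "-ggml" in name:
        return "GGMLPipeline (ctransformers)"
    if "-gptq" in name:
        return "GPTQPipeline (auto-gptq)"
    return "transformers.pipeline"

# Hypothetical names for illustration:
assert backend_for("TheBloke/vicuna-7B-v1.3-GGML") == "GGMLPipeline (ctransformers)"
assert backend_for("TheBloke/vicuna-7B-GPTQ") == "GPTQPipeline (auto-gptq)"
assert backend_for("gpt2") == "transformers.pipeline"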
