Skip to content

Commit 2beb247

Browse files
authored
Allow for users to pass in a quantization config (#1269)
1 parent a8eb55b commit 2beb247

File tree

1 file changed

+15
-5
lines changed

1 file changed

+15
-5
lines changed

pgml-extension/src/bindings/transformers/transformers.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
PegasusTokenizer,
4242
TrainingArguments,
4343
Trainer,
44+
GPTQConfig
4445
)
4546
import threading
4647

@@ -279,7 +280,13 @@ def __init__(self, model_name, **kwargs):
279280
elif self.task == "summarization" or self.task == "translation":
280281
self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name, **kwargs)
281282
elif self.task == "text-generation" or self.task == "conversational":
282-
self.model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
283+
# See: https://huggingface.co/docs/transformers/main/quantization
284+
if "quantization_config" in kwargs:
285+
quantization_config = kwargs.pop("quantization_config")
286+
quantization_config = GPTQConfig(**quantization_config)
287+
self.model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=quantization_config, **kwargs)
288+
else:
289+
self.model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
283290
else:
284291
raise PgMLException(f"Unhandled task: {self.task}")
285292

@@ -409,10 +416,13 @@ def create_pipeline(task):
409416
else:
410417
try:
411418
pipe = StandardPipeline(model_name, **task)
412-
except TypeError:
413-
# some models fail when given "device" kwargs, remove and try again
414-
task.pop("device")
415-
pipe = StandardPipeline(model_name, **task)
419+
except TypeError as error:
420+
if "device" in task:
421+
# some models fail when given "device" kwargs, remove and try again
422+
task.pop("device")
423+
pipe = StandardPipeline(model_name, **task)
424+
else:
425+
raise error
416426
return pipe
417427

418428

0 commit comments

Comments
 (0)