@@ -41,6 +41,7 @@
     PegasusTokenizer,
     TrainingArguments,
     Trainer,
+    GPTQConfig
 )
 import threading
 
@@ -279,7 +280,13 @@ def __init__(self, model_name, **kwargs):
         elif self.task == "summarization" or self.task == "translation":
             self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name, **kwargs)
         elif self.task == "text-generation" or self.task == "conversational":
-            self.model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
+            # See: https://huggingface.co/docs/transformers/main/quantization
+            if "quantization_config" in kwargs:
+                quantization_config = kwargs.pop("quantization_config")
+                quantization_config = GPTQConfig(**quantization_config)
+                self.model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=quantization_config, **kwargs)
+            else:
+                self.model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
         else:
             raise PgMLException(f"Unhandled task: {self.task}")
 
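For context, a hedged sketch of how a caller could exercise this new branch: the quantization settings arrive as a plain dict in kwargs, get popped, and are rebuilt as a transformers GPTQConfig before the model loads. The checkpoint name and the bits value below are illustrative assumptions, not part of this patch.

from transformers import AutoModelForCausalLM, GPTQConfig

# Settings arrive as a plain dict (e.g. deserialized from the task JSON).
kwargs = {"quantization_config": {"bits": 4}}
# Mirror the patched branch: pop the dict and rebuild it as a GPTQConfig.
quantization_config = GPTQConfig(**kwargs.pop("quantization_config"))
model = AutoModelForCausalLM.from_pretrained(
    "TheBloke/Llama-2-7B-GPTQ",  # hypothetical pre-quantized checkpoint
    quantization_config=quantization_config,
    **kwargs,
)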
@@ -409,10 +416,13 @@ def create_pipeline(task):
     else:
         try:
             pipe = StandardPipeline(model_name, **task)
-        except TypeError:
-            # some models fail when given "device" kwargs, remove and try again
-            task.pop("device")
-            pipe = StandardPipeline(model_name, **task)
+        except TypeError as error:
+            if "device" in task:
+                # some models fail when given "device" kwargs, remove and try again
+                task.pop("device")
+                pipe = StandardPipeline(model_name, **task)
+            else:
+                raise error
     return pipe
 
 
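For illustration, a minimal sketch (not part of the patch) of the failure mode the new guard fixes: dict.pop without a default raises KeyError, so the old handler converted any TypeError unrelated to "device" into a confusing KeyError instead of surfacing the original error.

task = {"model": "gpt2"}  # hypothetical task dict with no "device" key
try:
    task.pop("device")  # what the old handler ran unconditionally
except KeyError as err:
    print(f"original TypeError masked by KeyError: {err}")  # prints 'device'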