Commit 0b42fcc

large models need device_maps (#633)
1 parent 81ff9f3 commit 0b42fcc
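
Context for the commit title: a model that fits on one GPU can simply be pinned to a single device, but a model too large for any single card has to be sharded, which transformers delegates to accelerate through a device_map. The sketch below is illustrative only and not part of this commit (the model id is arbitrary); it shows the two call styles the extension now has to let through.

# Illustrative sketch, not code from this commit: a single-device pipeline versus a
# sharded one. device_map="auto" requires accelerate and lets it spread the weights
# across the available GPUs (and CPU RAM) instead of loading everything onto one device.
import transformers

# Small model: one device string is enough.
classifier = transformers.pipeline("text-classification", device="cpu")

# Large model (arbitrary id, for illustration): let accelerate place the layers.
generator = transformers.pipeline(
    "text-generation",
    model="tiiuae/falcon-7b",
    device_map="auto",
)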

5 files changed: +25 -35 lines

pgml-extension/examples/transformers.sql (6 additions, 5 deletions)

@@ -3,7 +3,8 @@
 \timing on
 
 SELECT pgml.embed('intfloat/e5-small', 'hi mom');
-
+SELECT pgml.embed('intfloat/e5-small', 'hi mom', '{"device": "cuda"}');
+SELECT pgml.embed('intfloat/e5-small', 'hi mom', '{"device": "cpu"}');
 
 SELECT pgml.transform(
     'translation_en_to_fr',

@@ -16,7 +17,7 @@ SELECT pgml.transform(
 SELECT pgml.transform(
     '{"model": "roberta-large-mnli"}'::JSONB,
     inputs => ARRAY[
-        'I love how amazingly simple ML has become!',
+        'I love how amazingly simple ML has become!',
         'Some models are painfully slow and expensive ☹️'
     ]
 ) AS result;

@@ -35,13 +36,13 @@ SELECT pgml.transform(
     ]
 );
 SELECT pgml.transform(
+    task => '{"task": "text-classification",
+              "model": "finiteautomata/bertweet-base-sentiment-analysis"
+             }'::JSONB,
     inputs => ARRAY[
         'I love how amazingly simple ML has become!',
         'I hate doing mundane and thankless tasks. ☹️'
     ],
-    task => '{"task": "text-classification",
-             "model": "finiteautomata/bertweet-base-sentiment-analysis"
-            }'::JSONB
 ) AS positivity;
 
 SELECT pgml.transform(
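
For context on what the new third argument does: on the Python side, embed() parses it as kwargs and runs it through ensure_device() (see the transformers.py diff further down), then hands the device to the embedding model. A minimal sketch, assuming the non-instructor branch of embed() builds a SentenceTransformer (the INSTRUCTOR special case in that diff suggests as much):

# Rough sketch under that assumption: the "device" value from the SQL call ends up as
# the SentenceTransformer device once ensure_device() has resolved a default.
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("intfloat/e5-small", device="cpu")  # or "cuda:0"
embedding = model.encode("hi mom")
print(len(embedding))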

pgml-extension/requirements.txt (2 additions, 2 deletions)

@@ -1,4 +1,4 @@
-accelerate==0.16.0
+accelerate==0.19.0
 datasets==2.10.1
 deepspeed==0.8.1
 InstructorEmbedding

@@ -15,5 +15,5 @@ torch==1.13.1
 torchaudio==0.13.1
 torchvision==0.14.1
 tqdm==4.64.1
-transformers==4.26.1
+transformers==4.28.1
 xgboost
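
The accelerate and transformers bumps are presumably what the device_map handling relies on. A quick, illustrative way to check that an environment matches the new pins:

# Illustrative check that installed package versions match the pins above.
from importlib.metadata import version

for pkg, pinned in (("accelerate", "0.19.0"), ("transformers", "4.28.1")):
    print(f"{pkg}: installed {version(pkg)}, pinned {pinned}")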

pgml-extension/src/api.rs (2 additions, 4 deletions)

@@ -574,10 +574,9 @@ pub fn transform_json(
     task: JsonB,
     args: default!(JsonB, "'{}'"),
     inputs: default!(Vec<String>, "ARRAY[]::TEXT[]"),
-    cache: default!(bool, false),
 ) -> JsonB {
     JsonB(crate::bindings::transformers::transform(
-        &task.0, &args.0, &inputs, cache,
+        &task.0, &args.0, &inputs,
     ))
 }

@@ -587,13 +586,12 @@ pub fn transform_string(
     task: String,
     args: default!(JsonB, "'{}'"),
     inputs: default!(Vec<String>, "ARRAY[]::TEXT[]"),
-    cache: default!(bool, false),
 ) -> JsonB {
     let mut task_map = HashMap::new();
     task_map.insert("task", task);
     let task_json = json!(task_map);
     JsonB(crate::bindings::transformers::transform(
-        &task_json, &args.0, &inputs, cache,
+        &task_json, &args.0, &inputs,
     ))
 }

pgml-extension/src/bindings/transformers.py (15 additions, 22 deletions)

@@ -50,20 +50,17 @@ def default(self, obj):
         return super().default(obj)
 
 
-def transform(task, args, inputs, cache):
+def transform(task, args, inputs):
     task = json.loads(task)
     args = json.loads(args)
     inputs = json.loads(inputs)
 
-    task["device"] = assign_device(task.get("device"))
+    ensure_device(task)
 
-    if cache:
-        key = ",".join([f"{key}:{val}" for (key, val) in sorted(task.items())])
-        if key not in __cache_transform_pipeline_by_task:
-            __cache_transform_pipeline_by_task[key] = transformers.pipeline(**task)
-        pipe = __cache_transform_pipeline_by_task[key]
-    else:
-        pipe = transformers.pipeline(**task)
+    key = ",".join([f"{key}:{val}" for (key, val) in sorted(task.items())])
+    if key not in __cache_transform_pipeline_by_task:
+        __cache_transform_pipeline_by_task[key] = transformers.pipeline(**task)
+    pipe = __cache_transform_pipeline_by_task[key]
 
     if pipe.task == "question-answering":
         inputs = [json.loads(input) for input in inputs]

@@ -73,7 +70,7 @@ def transform(task, args, inputs, cache):
 
 def embed(transformer, text, kwargs):
     kwargs = json.loads(kwargs)
-    kwargs["device"] = assign_device(kwargs.get("device"))
+    ensure_device(kwargs)
     instructor = transformer.startswith("hkunlp/instructor")
     if instructor:
         klass = INSTRUCTOR

@@ -543,16 +540,12 @@ def generate(model_id, data, config):
     return all_preds
 
 
-def assign_device(device=None):
-    if device is not None:
-        if device == "cpu" or "cuda:" in device:
-            return device
-        if "cuda" in device and not torch.cuda.is_available():
-            raise Exception("CUDA is not available")
-
-    if torch.cuda.is_available():
-        device = "cuda:" + str(os.getpid() % torch.cuda.device_count())
-    else:
-        device = "cpu"
+def ensure_device(kwargs):
+    device = kwargs.get("device")
+    device_map = kwargs.get("device_map")
+    if device is None and device_map is None:
+        if torch.cuda.is_available():
+            kwargs["device"] = "cuda:" + str(os.getpid() % torch.cuda.device_count())
+        else:
+            kwargs["device"] = "cpu"
 
-    return device
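
The behavioral core of the commit is ensure_device(): it fills in a default device only when the caller supplied neither "device" nor "device_map", so a user-provided device_map (what large models need) passes straight through to transformers.pipeline(**task). Pipeline caching by task key is now unconditional instead of opt-in. A standalone sketch of the new contract (mirroring the diff above; the task dicts are made up):

# Standalone sketch mirroring ensure_device() from the diff above; the task dicts are
# illustrative, not from the repository.
import os
import torch

def ensure_device(kwargs):
    # Pick a default device only if the caller set neither "device" nor "device_map".
    if kwargs.get("device") is None and kwargs.get("device_map") is None:
        if torch.cuda.is_available():
            kwargs["device"] = "cuda:" + str(os.getpid() % torch.cuda.device_count())
        else:
            kwargs["device"] = "cpu"

sharded = {"task": "text-generation", "model": "some/large-model", "device_map": "auto"}
ensure_device(sharded)
print(sharded)  # unchanged: device_map is respected, no "device" key is added

plain = {"task": "text-classification"}
ensure_device(plain)
print(plain)    # gains a "device" of "cuda:N" or "cpu"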

pgml-extension/src/bindings/transformers.rs (0 additions, 2 deletions)

@@ -25,7 +25,6 @@ pub fn transform(
     task: &serde_json::Value,
     args: &serde_json::Value,
     inputs: &Vec<String>,
-    cache: bool,
 ) -> serde_json::Value {
     crate::bindings::venv::activate();
 

@@ -45,7 +44,6 @@
                         task.into_py(py),
                         args.into_py(py),
                         inputs.into_py(py),
-                        cache.into_py(py),
                     ],
                 ),
             )
