Skip to content

Commit 41de0aa

Browse files
authored
Adds pipeline model caching in the transform function. (#593)
1 parent 48fdfca commit 41de0aa

File tree

4 files changed

+21
-9
lines changed

4 files changed

+21
-9
lines changed

pgml-docs/docs/user_guides/transformers/pre_trained_models.md

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -11,9 +11,10 @@ The Hugging Face [`Pipeline`](https://huggingface.co/docs/transformers/main_clas
1111

1212
```sql linenums="1" title="transformer.sql"
1313
pgml.transform(
14-
task TEXT OR JSONB, -- task name or full pipeline initializer arguments
15-
call JSONB, -- additional call arguments alongside the inputs
16-
inputs TEXT[] OR BYTEA[] -- inputs for inference
14+
task TEXT OR JSONB, -- task name or full pipeline initializer arguments
15+
call JSONB, -- additional call arguments alongside the inputs
16+
inputs TEXT[] OR BYTEA[], -- inputs for inference
17+
cache BOOLEAN -- if TRUE, the pipeline is cached in memory and reused by later calls with the same task arguments. Defaults to FALSE.
1718
)
1819
```
1920

@@ -73,7 +74,8 @@ Sentiment analysis is one use of `text-classification`, but there are [many othe
7374
inputs => ARRAY[
7475
'I love how amazingly simple ML has become!',
7576
'I hate doing mundane and thankless tasks. ☹️'
76-
]
77+
],
78+
cache => TRUE
7779
) AS positivity;
7880
```
7981

pgml-extension/src/api.rs

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -564,9 +564,10 @@ pub fn transform_json(
564564
task: JsonB,
565565
args: default!(JsonB, "'{}'"),
566566
inputs: default!(Vec<String>, "ARRAY[]::TEXT[]"),
567+
cache: default!(bool, false)
567568
) -> JsonB {
568569
JsonB(crate::bindings::transformers::transform(
569-
&task.0, &args.0, &inputs,
570+
&task.0, &args.0, &inputs, cache
570571
))
571572
}
572573

@@ -576,12 +577,13 @@ pub fn transform_string(
576577
task: String,
577578
args: default!(JsonB, "'{}'"),
578579
inputs: default!(Vec<String>, "ARRAY[]::TEXT[]"),
580+
cache: default!(bool, false)
579581
) -> JsonB {
580582
let mut task_map = HashMap::new();
581583
task_map.insert("task", task);
582584
let task_json = json!(task_map);
583585
JsonB(crate::bindings::transformers::transform(
584-
&task_json, &args.0, &inputs,
586+
&task_json, &args.0, &inputs, cache
585587
))
586588
}
587589

pgml-extension/src/bindings/transformers.py

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -39,19 +39,26 @@
3939

4040
__cache_transformer_by_model_id = {}
4141
__cache_sentence_transformer_by_name = {}
42+
__cache_transform_pipeline_by_task = {}
4243

4344
class NumpyJSONEncoder(json.JSONEncoder):
4445
def default(self, obj):
4546
if isinstance(obj, np.float32):
4647
return float(obj)
4748
return super().default(obj)
4849

49-
def transform(task, args, inputs):
50+
def transform(task, args, inputs, cache):
5051
task = json.loads(task)
5152
args = json.loads(args)
5253
inputs = json.loads(inputs)
5354

54-
pipe = transformers.pipeline(**task)
55+
if cache:
56+
key = ",".join([f"{key}:{val}" for (key, val) in sorted(task.items())])
57+
if key not in __cache_transform_pipeline_by_task:
58+
__cache_transform_pipeline_by_task[key] = transformers.pipeline(**task)
59+
pipe = __cache_transform_pipeline_by_task[key]
60+
else:
61+
pipe = transformers.pipeline(**task)
5562

5663
if pipe.task == "question-answering":
5764
inputs = [json.loads(input) for input in inputs]

pgml-extension/src/bindings/transformers.rs

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@ pub fn transform(
2525
task: &serde_json::Value,
2626
args: &serde_json::Value,
2727
inputs: &Vec<String>,
28+
cache: bool
2829
) -> serde_json::Value {
2930
let task = serde_json::to_string(task).unwrap();
3031
let args = serde_json::to_string(args).unwrap();
@@ -38,7 +39,7 @@ pub fn transform(
3839
py,
3940
PyTuple::new(
4041
py,
41-
&[task.into_py(py), args.into_py(py), inputs.into_py(py)],
42+
&[task.into_py(py), args.into_py(py), inputs.into_py(py), cache.into_py(py)],
4243
),
4344
)
4445
.unwrap()

0 commit comments

Comments
 (0)