diff --git a/pgml-extension/examples/embedding.sql b/pgml-extension/examples/embedding.sql new file mode 100644 index 000000000..4e6c5968d --- /dev/null +++ b/pgml-extension/examples/embedding.sql @@ -0,0 +1,7 @@ +\timing on + +SELECT pgml.embed('Alibaba-NLP/gte-base-en-v1.5', 'hi mom', '{"trust_remote_code": true}'); +SELECT pgml.embed('Alibaba-NLP/gte-base-en-v1.5', 'hi mom', '{"device": "cuda", "trust_remote_code": true}'); +SELECT pgml.embed('Alibaba-NLP/gte-base-en-v1.5', 'hi mom', '{"device": "cpu", "trust_remote_code": true}'); +SELECT pgml.embed('hkunlp/instructor-xl', 'hi mom', '{"instruction": "Encode it with love"}'); +SELECT pgml.embed('mixedbread-ai/mxbai-embed-large-v1', 'test', '{"prompt": "test prompt: "}'); diff --git a/pgml-extension/examples/image_classification.sql b/pgml-extension/examples/image_classification.sql index f9a7888a6..24e363e4a 100644 --- a/pgml-extension/examples/image_classification.sql +++ b/pgml-extension/examples/image_classification.sql @@ -66,7 +66,7 @@ SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'xgboost', hyperpara -- runtimes SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'linear', runtime => 'python'); -SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'linear', runtime => 'rust'); +--SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'linear', runtime => 'rust'); --SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'xgboost', runtime => 'python', hyperparams => '{"n_estimators": 10}'); -- too slow SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'xgboost', runtime => 'rust', hyperparams => '{"n_estimators": 10}'); diff --git a/pgml-extension/src/bindings/transformers/transformers.py b/pgml-extension/src/bindings/transformers/transformers.py index baa2c2500..ea2df12b9 100644 --- a/pgml-extension/src/bindings/transformers/transformers.py +++ b/pgml-extension/src/bindings/transformers/transformers.py @@ -527,8 +527,8 @@ def rank(transformer, query, documents, kwargs): return rank_using(model, query, documents, kwargs) -def create_embedding(transformer): - return SentenceTransformer(transformer) +def create_embedding(transformer, kwargs): + return SentenceTransformer(transformer, **kwargs) def embed_using(model, transformer, inputs, kwargs): @@ -545,16 +545,32 @@ def embed_using(model, transformer, inputs, kwargs): def embed(transformer, inputs, kwargs): kwargs = orjson.loads(kwargs) - ensure_device(kwargs) + init_kwarg_keys = [ + "device", + "trust_remote_code", + "revision", + "model_kwargs", + "tokenizer_kwargs", + "config_kwargs", + "truncate_dim", + "token", + ] + init_kwargs = { + key: value for key, value in kwargs.items() if key in init_kwarg_keys + } + encode_kwargs = { + key: value for key, value in kwargs.items() if key not in init_kwarg_keys + } + if transformer not in __cache_sentence_transformer_by_name: __cache_sentence_transformer_by_name[transformer] = create_embedding( - transformer + transformer, init_kwargs ) model = __cache_sentence_transformer_by_name[transformer] - return embed_using(model, transformer, inputs, kwargs) + return embed_using(model, transformer, inputs, encode_kwargs) def clear_gpu_cache(memory_usage: None): diff --git a/pgml-extension/tests/test.sql b/pgml-extension/tests/test.sql index 2256e0ca4..10ffb4339 100644 --- a/pgml-extension/tests/test.sql +++ b/pgml-extension/tests/test.sql @@ -31,5 +31,6 @@ SELECT pgml.load_dataset('wine'); \i examples/vectors.sql \i examples/chunking.sql \i examples/preprocessing.sql +\i examples/embedding.sql -- transformers are generally too slow to run in the test suite --\i examples/transformers.sql