Skip to content

Commit fb2426f

Browse files
authored
Silas add ranking (#1498)
1 parent 72473de commit fb2426f

File tree

6 files changed

+175
-16
lines changed

6 files changed

+175
-16
lines changed

pgml-extension/Cargo.lock

Lines changed: 20 additions & 11 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pgml-extension/Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "pgml"
3-
version = "2.8.5"
3+
version = "2.9.0"
44
edition = "2021"
55

66
[lib]
@@ -41,7 +41,7 @@ ndarray-stats = "0.5.1"
4141
parking_lot = "0.12"
4242
pgrx = "=0.11.3"
4343
pgrx-pg-sys = "=0.11.3"
44-
pyo3 = { version = "0.20.0", features = ["auto-initialize"], optional = true }
44+
pyo3 = { version = "0.20.0", features = ["anyhow", "auto-initialize"], optional = true }
4545
rand = "0.8"
4646
rmp-serde = { version = "1.1" }
4747
signal-hook = "0.3"
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
-- src/api.rs:613
2+
-- pgml::api::rank
-- SQL entry point for the Rust function `pgml::api::rank` (exported C symbol `rank_wrapper`).
-- Arguments: the transformer name, the query text, the candidate documents, and
-- optional JSON keyword arguments (defaults to '{}').
-- Returns a table of (corpus_id, score, text); "text" is nullable.
-- NOTE(review): "corpus_id" is presumably the position of the document in "documents" -- confirm.
-- STRICT: a NULL in any argument yields NULL without invoking the function.
3+
CREATE FUNCTION pgml."rank"(
4+
"transformer" TEXT, /* &str */
5+
"query" TEXT, /* &str */
6+
"documents" TEXT[], /* alloc::vec::Vec<&str> */
7+
"kwargs" jsonb DEFAULT '{}' /* pgrx::datum::json::JsonB */
8+
) RETURNS TABLE (
9+
"corpus_id" bigint, /* i64 */
10+
"score" double precision, /* f64 */
11+
"text" TEXT /* core::option::Option<alloc::string::String> */
12+
)
13+
IMMUTABLE STRICT PARALLEL SAFE
14+
LANGUAGE c /* Rust */
15+
AS 'MODULE_PATHNAME', 'rank_wrapper';

pgml-extension/src/api.rs

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -603,7 +603,21 @@ pub fn embed_batch(
603603
kwargs: default!(JsonB, "'{}'"),
604604
) -> SetOfIterator<'static, Vec<f32>> {
605605
match crate::bindings::transformers::embed(transformer, inputs, &kwargs.0) {
606-
Ok(output) => SetOfIterator::new(output.into_iter()),
606+
Ok(output) => SetOfIterator::new(output),
607+
Err(e) => error!("{e}"),
608+
}
609+
}
610+
611+
// SQL-callable `pgml.rank(transformer, query, documents, kwargs)`.
//
// Delegates to `crate::bindings::transformers::rank` and exposes every returned
// `RankResult` as one `(corpus_id, score, text)` row. Any error from the binding
// is re-raised through `error!` with the error's display text.
#[cfg(all(feature = "python", not(feature = "use_as_lib")))]
#[pg_extern(immutable, parallel_safe, name = "rank")]
pub fn rank(
    transformer: &str,
    query: &str,
    documents: Vec<&str>,
    kwargs: default!(JsonB, "'{}'"),
) -> TableIterator<'static, (name!(corpus_id, i64), name!(score, f64), name!(text, Option<String>))> {
    // Resolve the fallible call first so the row mapping below reads linearly.
    let ranked = match crate::bindings::transformers::rank(transformer, query, documents, &kwargs.0) {
        Ok(ranked) => ranked,
        Err(e) => error!("{e}"),
    };

    TableIterator::new(ranked.into_iter().map(|hit| (hit.corpus_id, hit.score, hit.text)))
}

pgml-extension/src/bindings/transformers/mod.rs

Lines changed: 95 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ use std::{collections::HashMap, path::Path};
66
use anyhow::{anyhow, bail, Context, Result};
77
use pgrx::*;
88
use pyo3::prelude::*;
9-
use pyo3::types::PyTuple;
9+
use pyo3::types::{PyBool, PyDict, PyFloat, PyInt, PyList, PyString, PyTuple};
10+
use serde::Deserialize;
1011
use serde_json::Value;
1112

1213
use crate::create_pymodule;
@@ -21,6 +22,59 @@ pub use transform::*;
2122

2223
create_pymodule!("/src/bindings/transformers/transformers.py");
2324

25+
// Need a wrapper so we can implement traits for it
26+
struct Json(Value);
27+
28+
impl From<Json> for Value {
29+
fn from(value: Json) -> Self {
30+
value.0
31+
}
32+
}
33+
34+
impl FromPyObject<'_> for Json {
35+
fn extract(ob: &PyAny) -> PyResult<Self> {
36+
if ob.is_instance_of::<PyDict>() {
37+
let dict: &PyDict = ob.downcast()?;
38+
let mut json = serde_json::Map::new();
39+
for (key, value) in dict.iter() {
40+
let value = Json::extract(value)?;
41+
json.insert(String::extract(key)?, value.0);
42+
}
43+
Ok(Self(serde_json::Value::Object(json)))
44+
} else if ob.is_instance_of::<PyBool>() {
45+
let value = bool::extract(ob)?;
46+
Ok(Self(serde_json::Value::Bool(value)))
47+
} else if ob.is_instance_of::<PyInt>() {
48+
let value = i64::extract(ob)?;
49+
Ok(Self(serde_json::Value::Number(value.into())))
50+
} else if ob.is_instance_of::<PyFloat>() {
51+
let value = f64::extract(ob)?;
52+
let value =
53+
serde_json::value::Number::from_f64(value).context("Could not convert f64 to serde_json::Number")?;
54+
Ok(Self(serde_json::Value::Number(value)))
55+
} else if ob.is_instance_of::<PyString>() {
56+
let value = String::extract(ob)?;
57+
Ok(Self(serde_json::Value::String(value)))
58+
} else if ob.is_instance_of::<PyList>() {
59+
let value = ob.downcast::<PyList>()?;
60+
let mut json_values = Vec::new();
61+
for v in value {
62+
let v = v.extract::<Json>()?;
63+
json_values.push(v.0);
64+
}
65+
Ok(Self(serde_json::Value::Array(json_values)))
66+
} else {
67+
if ob.is_none() {
68+
return Ok(Self(serde_json::Value::Null));
69+
}
70+
Err(anyhow::anyhow!(
71+
"Unsupported type for JSON conversion: {:?}",
72+
ob.get_type()
73+
))?
74+
}
75+
}
76+
}
77+
2478
pub fn get_model_from(task: &Value) -> Result<String> {
2579
Python::with_gil(|py| -> Result<String> {
2680
let get_model_from = get_module!(PY_MODULE)
@@ -55,6 +109,46 @@ pub fn embed(transformer: &str, inputs: Vec<&str>, kwargs: &serde_json::Value) -
55109
})
56110
}
57111

112+
#[derive(Deserialize)]
113+
pub struct RankResult {
114+
pub corpus_id: i64,
115+
pub score: f64,
116+
pub text: Option<String>,
117+
}
118+
119+
pub fn rank(
120+
transformer: &str,
121+
query: &str,
122+
documents: Vec<&str>,
123+
kwargs: &serde_json::Value,
124+
) -> Result<Vec<RankResult>> {
125+
let kwargs = serde_json::to_string(kwargs)?;
126+
Python::with_gil(|py| -> Result<Vec<RankResult>> {
127+
let embed: Py<PyAny> = get_module!(PY_MODULE).getattr(py, "rank").format_traceback(py)?;
128+
let output = embed
129+
.call1(
130+
py,
131+
PyTuple::new(
132+
py,
133+
&[
134+
transformer.to_string().into_py(py),
135+
query.into_py(py),
136+
documents.into_py(py),
137+
kwargs.into_py(py),
138+
],
139+
),
140+
)
141+
.format_traceback(py)?;
142+
let out: Vec<Json> = output.extract(py).format_traceback(py)?;
143+
out.into_iter()
144+
.map(|x| {
145+
let x: RankResult = serde_json::from_value(x.0)?;
146+
Ok(x)
147+
})
148+
.collect()
149+
})
150+
}
151+
58152
pub fn finetune_text_classification(
59153
task: &Task,
60154
dataset: TextClassificationDataset,

pgml-extension/src/bindings/transformers/transformers.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import orjson
1313
from rouge import Rouge
1414
from sacrebleu.metrics import BLEU
15-
from sentence_transformers import SentenceTransformer
15+
from sentence_transformers import SentenceTransformer, CrossEncoder
1616
from sklearn.metrics import (
1717
mean_squared_error,
1818
r2_score,
@@ -500,6 +500,33 @@ def transform(task, args, inputs, stream=False):
500500
return orjson.dumps(pipe(inputs, **args), default=orjson_default).decode()
501501

502502

503+
def create_cross_encoder(transformer):
    """Build a `CrossEncoder` for the given transformer name."""
    encoder = CrossEncoder(transformer)
    return encoder
505+
506+
507+
def rank_using(model, query, documents, kwargs):
    """Rank `documents` against `query` with an already-loaded model.

    `kwargs` may be a dict or a JSON string; a string is decoded first. The
    keyword arguments are forwarded to `model.rank`. Each returned hit is
    re-emitted as a dict with "score" first, followed by the hit's other keys.
    """
    if isinstance(kwargs, str):
        kwargs = orjson.loads(kwargs)

    ranked = []
    for hit in model.rank(query, documents, **kwargs):
        # The score is a numpy float32 before we convert it
        score = hit.pop("score").item()
        entry = {"score": score}
        entry.update(hit)
        ranked.append(entry)
    return ranked
516+
517+
518+
def rank(transformer, query, documents, kwargs):
    """Rank `documents` against `query` with the cross-encoder named `transformer`.

    `kwargs` is a JSON string of keyword arguments; it is decoded here and
    passed on to `rank_using`. The cross-encoder is created on first use and
    cached for later calls. Returns whatever `rank_using` returns.
    """
    kwargs = orjson.loads(kwargs)

    # Cache under a namespaced key rather than the bare model name, so a
    # cross-encoder can never be confused with an entry of a different model
    # type stored under the same name.
    # NOTE(review): the cache's name suggests it also holds SentenceTransformer
    # embedding models keyed by name -- confirm against the embedding code path.
    cache_key = ("cross_encoder", transformer)
    model = __cache_sentence_transformer_by_name.get(cache_key)
    if model is None:
        model = create_cross_encoder(transformer)
        __cache_sentence_transformer_by_name[cache_key] = model

    return rank_using(model, query, documents, kwargs)
528+
529+
503530
def create_embedding(transformer):
504531
return SentenceTransformer(transformer)
505532

0 commit comments

Comments
 (0)