Skip to content

Commit 403816a

Browse files
authored
add backtrace for debugging embed (#760)
1 parent 7cf4da0 commit 403816a

File tree

2 files changed

+25
-33
lines changed

2 files changed

+25
-33
lines changed

pgml-extension/src/api.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -580,7 +580,6 @@ pub fn embed_batch(
580580
crate::bindings::transformers::embed(transformer, inputs, &kwargs.0)
581581
}
582582

583-
584583
/// Clears the GPU cache.
585584
///
586585
/// # Arguments
@@ -596,10 +595,9 @@ pub fn embed_batch(
596595
/// SELECT pgml.clear_gpu_cache(memory_usage => 0.5);
597596
/// ```
598597
#[pg_extern(immutable, parallel_safe, name = "clear_gpu_cache")]
599-
pub fn clear_gpu_cache(
600-
memory_usage: default!(Option<f32>, "NULL")
601-
) -> bool {
602-
let memory_usage: Option<f32> = memory_usage.map(|memory_usage| memory_usage.try_into().unwrap());
598+
pub fn clear_gpu_cache(memory_usage: default!(Option<f32>, "NULL")) -> bool {
599+
let memory_usage: Option<f32> =
600+
memory_usage.map(|memory_usage| memory_usage.try_into().unwrap());
603601
crate::bindings::transformers::clear_gpu_cache(memory_usage)
604602
}
605603

pgml-extension/src/bindings/transformers.rs

Lines changed: 22 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -62,21 +62,27 @@ pub fn embed(transformer: &str, inputs: Vec<&str>, kwargs: &serde_json::Value) -
6262
let kwargs = serde_json::to_string(kwargs).unwrap();
6363
Python::with_gil(|py| -> Vec<Vec<f32>> {
6464
let embed: Py<PyAny> = PY_MODULE.getattr(py, "embed").unwrap().into();
65-
embed
66-
.call1(
65+
let result = embed.call1(
66+
py,
67+
PyTuple::new(
6768
py,
68-
PyTuple::new(
69-
py,
70-
&[
71-
transformer.to_string().into_py(py),
72-
inputs.into_py(py),
73-
kwargs.into_py(py),
74-
],
75-
),
76-
)
77-
.unwrap()
78-
.extract(py)
79-
.unwrap()
69+
&[
70+
transformer.to_string().into_py(py),
71+
inputs.into_py(py),
72+
kwargs.into_py(py),
73+
],
74+
),
75+
);
76+
77+
let result = match result {
78+
Err(e) => {
79+
let traceback = e.traceback(py).unwrap().format().unwrap();
80+
error!("{traceback} {e}")
81+
}
82+
Ok(o) => o.extract(py).unwrap(),
83+
};
84+
85+
result
8086
})
8187
}
8288

@@ -312,25 +318,13 @@ pub fn load_dataset(
312318
num_rows
313319
}
314320

315-
pub fn clear_gpu_cache(
316-
memory_usage: Option<f32>
317-
) -> bool {
318-
321+
pub fn clear_gpu_cache(memory_usage: Option<f32>) -> bool {
319322
Python::with_gil(|py| -> bool {
320323
let clear_gpu_cache: Py<PyAny> = PY_MODULE.getattr(py, "clear_gpu_cache").unwrap().into();
321324
clear_gpu_cache
322-
.call1(
323-
py,
324-
PyTuple::new(
325-
py,
326-
&[
327-
memory_usage.into_py(py),
328-
],
329-
),
330-
)
325+
.call1(py, PyTuple::new(py, &[memory_usage.into_py(py)]))
331326
.unwrap()
332327
.extract(py)
333328
.unwrap()
334329
})
335330
}
336-

0 commit comments

Comments
 (0)