From 600bbfe8541e68455479fbb0347f22b5840e3193 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Thu, 6 Jun 2024 15:29:50 -0700 Subject: [PATCH] Updated load_dataset to be resistent to bad columns --- .../src/bindings/transformers/mod.rs | 49 ++++++++++++------- 1 file changed, 30 insertions(+), 19 deletions(-) diff --git a/pgml-extension/src/bindings/transformers/mod.rs b/pgml-extension/src/bindings/transformers/mod.rs index 33f103e62..a5bd045c3 100644 --- a/pgml-extension/src/bindings/transformers/mod.rs +++ b/pgml-extension/src/bindings/transformers/mod.rs @@ -434,6 +434,7 @@ pub fn load_dataset( Spi::run(&format!(r#"CREATE TABLE {table_name} ({column_types})"#))?; let insert = format!(r#"INSERT INTO {table_name} ({column_names}) VALUES ({column_placeholders})"#); for i in 0..num_rows { + let mut skip = false; let mut row = Vec::with_capacity(num_cols); for (name, values) in data { let value = values @@ -441,47 +442,57 @@ pub fn load_dataset( .ok_or_else(|| anyhow!("expected {values} to be an array"))? .get(i) .ok_or_else(|| anyhow!("invalid index {i} for {values}"))?; - match types + let (ty, datum) = match types .get(name) .ok_or_else(|| anyhow!("{types:?} expected to have key {name}"))? .as_str() .ok_or_else(|| anyhow!("json field {name} expected to be string"))? { - "string" => row.push(( + "string" => ( PgBuiltInOids::TEXTOID.oid(), value .as_str() - .ok_or_else(|| anyhow!("expected {value} to be string"))? - .into_datum(), - )), - "dict" | "list" => row.push((PgBuiltInOids::JSONBOID.oid(), JsonB(value.clone()).into_datum())), - "int64" | "int32" | "int16" => row.push(( + .map(IntoDatum::into_datum) + .ok_or_else(|| anyhow!("expected column {name} with {value} to be string")), + ), + "dict" | "list" => (PgBuiltInOids::JSONBOID.oid(), Ok(JsonB(value.clone()).into_datum())), + "int64" | "int32" | "int16" => ( PgBuiltInOids::INT8OID.oid(), value .as_i64() - .ok_or_else(|| anyhow!("expected {value} to be i64"))? - .into_datum(), - )), - "float64" | "float32" | "float16" => row.push(( + .map(IntoDatum::into_datum) + .ok_or_else(|| anyhow!("expected column {name} with {value} to be i64")), + ), + "float64" | "float32" | "float16" => ( PgBuiltInOids::FLOAT8OID.oid(), value .as_f64() - .ok_or_else(|| anyhow!("expected {value} to be f64"))? - .into_datum(), - )), - "bool" => row.push(( + .map(IntoDatum::into_datum) + .ok_or_else(|| anyhow!("expected column {name} with {value} to be f64")), + ), + "bool" => ( PgBuiltInOids::BOOLOID.oid(), value .as_bool() - .ok_or_else(|| anyhow!("expected {value} to be bool"))? - .into_datum(), - )), + .map(IntoDatum::into_datum) + .ok_or_else(|| anyhow!("expected column {name} with {value} to be bool")), + ), type_ => { bail!("unhandled dataset value type while reading dataset: {value:?} {type_:?}") } + }; + match datum { + Ok(datum) => row.push((ty, datum)), + Err(e) => { + warning!("failed to convert dataset value to datum while reading dataset: {e}"); + skip = true; + break; + } } } - Spi::run_with_args(&insert, Some(row))? + if !skip { + Spi::run_with_args(&insert, Some(row))? + } } Ok(num_rows)