Skip to content

Commit c7494db

Browse files
montanalowMontana Low
andauthored
add support for numerics (#1324)
Co-authored-by: Montana Low <montanalow@gmail.com>
1 parent dd7c749 commit c7494db

File tree

2 files changed

+27
-0
lines changed

2 files changed

+27
-0
lines changed

pgml-extension/src/orm/model.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -954,6 +954,12 @@ impl Model {
954954
.unwrap()
955955
.map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| k.to_string())
956956
}
957+
pgrx_pg_sys::NUMERICOID => {
958+
let element: Result<Option<AnyNumeric>, TryFromDatumError> = tuple.get_by_index(index);
959+
element
960+
.unwrap()
961+
.map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| k.to_string())
962+
}
957963
_ => error!(
958964
"Unsupported type for categorical column: {:?}. oid: {:?}",
959965
column.name, attribute.atttypid
@@ -992,6 +998,10 @@ impl Model {
992998
let element: Result<Option<f64>, TryFromDatumError> = tuple.get_by_index(index);
993999
features.push(element.unwrap().map_or(f32::NAN, |v| v as f32));
9941000
}
1001+
pgrx_pg_sys::NUMERICOID => {
1002+
let element: Result<Option<AnyNumeric>, TryFromDatumError> = tuple.get_by_index(index);
1003+
features.push(element.unwrap().map_or(f32::NAN, |v| v.try_into().unwrap()));
1004+
}
9951005
// TODO handle NULL to NaN for arrays
9961006
pgrx_pg_sys::BOOLARRAYOID => {
9971007
let element: Result<Option<Vec<bool>>, TryFromDatumError> =
@@ -1035,6 +1045,13 @@ impl Model {
10351045
features.push(*j as f32);
10361046
}
10371047
}
1048+
pgrx_pg_sys::NUMERICARRAYOID => {
1049+
let element: Result<Option<Vec<AnyNumeric>>, TryFromDatumError> =
1050+
tuple.get_by_index(index);
1051+
for j in element.as_ref().unwrap().as_ref().unwrap() {
1052+
features.push(j.clone().try_into().unwrap());
1053+
}
1054+
}
10381055
_ => error!(
10391056
"Unsupported type for quantitative column: {:?}. oid: {:?}",
10401057
column.name, attribute.atttypid

pgml-extension/src/orm/snapshot.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -990,6 +990,7 @@ impl Snapshot {
990990
"int8" => row[column.position].value::<i64>().unwrap().map(|v| v.to_string()),
991991
"float4" => row[column.position].value::<f32>().unwrap().map(|v| v.to_string()),
992992
"float8" => row[column.position].value::<f64>().unwrap().map(|v| v.to_string()),
993+
"numeric" => row[column.position].value::<AnyNumeric>().unwrap().map(|v| v.to_string()),
993994
"bpchar" | "text" | "varchar" => {
994995
row[column.position].value::<String>().unwrap().map(|v| v.to_string())
995996
}
@@ -1078,6 +1079,14 @@ impl Snapshot {
10781079
vector.push(j as f32)
10791080
}
10801081
}
1082+
"numeric[]" => {
1083+
let vec = row[column.position].value::<Vec<AnyNumeric>>().unwrap().unwrap();
1084+
check_column_size(column, vec.len());
1085+
1086+
for j in vec {
1087+
vector.push(j.rescale::<6,0>().unwrap().try_into().unwrap())
1088+
}
1089+
}
10811090
_ => error!(
10821091
"Unhandled type for quantitative array column: {} {:?}",
10831092
column.name, column.pg_type
@@ -1092,6 +1101,7 @@ impl Snapshot {
10921101
"int8" => row[column.position].value::<i64>().unwrap().map(|v| v as f32),
10931102
"float4" => row[column.position].value::<f32>().unwrap(),
10941103
"float8" => row[column.position].value::<f64>().unwrap().map(|v| v as f32),
1104+
"numeric" => row[column.position].value::<AnyNumeric>().unwrap().map(|v| v.rescale::<6,0>().unwrap().try_into().unwrap()),
10951105
_ => error!(
10961106
"Unhandled type for quantitative scalar column: {} {:?}",
10971107
column.name, column.pg_type

0 commit comments

Comments
 (0)