@@ -5,29 +5,29 @@ use pgrx::{pg_schema, pg_test};
5
5
use serde_json:: Value ;
6
6
use std:: ffi:: CStr ;
7
7
8
- use crate :: config:: { PGML_HF_TRUST_REMOTE_CODE , PGML_HF_TRUST_WHITELIST , PGML_HF_WHITELIST } ;
8
+ use crate :: config:: { PGML_HF_TRUST_REMOTE_CODE , PGML_HF_TRUST_REMOTE_CODE_WHITELIST , PGML_HF_WHITELIST } ;
9
9
10
10
/// Verify that the model in the task JSON is allowed based on the huggingface whitelists.
11
11
pub fn verify_task ( task : & Value ) -> Result < ( ) , Error > {
12
12
let task_model = match get_model_name ( task) {
13
13
Some ( model) => model. to_string ( ) ,
14
14
None => return Ok ( ( ) ) ,
15
15
} ;
16
- let whitelisted_models = config_csv_list ( & PGML_HF_WHITELIST . 1 ) ;
16
+ let whitelisted_models = config_csv_list ( & PGML_HF_WHITELIST ) ;
17
17
18
18
let model_is_allowed = whitelisted_models. is_empty ( ) || whitelisted_models. contains ( & task_model) ;
19
19
if !model_is_allowed {
20
20
bail ! (
21
21
"model {} is not whitelisted. Consider adding to {} in postgresql.conf" ,
22
22
task_model,
23
- PGML_HF_WHITELIST . 0
23
+ "pgml.huggingface_whitelist"
24
24
) ;
25
25
}
26
26
27
27
let task_trust = get_trust_remote_code ( task) ;
28
- let trust_remote_code = PGML_HF_TRUST_REMOTE_CODE . 1 . get ( ) ;
28
+ let trust_remote_code = PGML_HF_TRUST_REMOTE_CODE . get ( ) ;
29
29
30
- let trusted_models = config_csv_list ( & PGML_HF_TRUST_WHITELIST . 1 ) ;
30
+ let trusted_models = config_csv_list ( & PGML_HF_TRUST_REMOTE_CODE_WHITELIST ) ;
31
31
32
32
let model_is_trusted = trusted_models. is_empty ( ) || trusted_models. contains ( & task_model) ;
33
33
@@ -36,9 +36,9 @@ pub fn verify_task(task: &Value) -> Result<(), Error> {
36
36
bail ! (
37
37
"model {} is not trusted to run remote code. Consider setting {} = 'true' or adding {} to {}" ,
38
38
task_model,
39
- PGML_HF_TRUST_REMOTE_CODE . 0 ,
39
+ "pgml.huggingface_trust_remote_code" ,
40
40
task_model,
41
- PGML_HF_TRUST_WHITELIST . 0
41
+ "pgml.huggingface_trust_remote_code_whitelist" ,
42
42
) ;
43
43
}
44
44
@@ -129,7 +129,7 @@ mod tests {
129
129
#[ pg_test]
130
130
fn test_empty_whitelist ( ) {
131
131
let model = "Salesforce/xgen-7b-8k-inst" ;
132
- set_config ( PGML_HF_WHITELIST . 0 , "" ) . unwrap ( ) ;
132
+ set_config ( "pgml.huggingface_whitelist" , "" ) . unwrap ( ) ;
133
133
let task_json = format ! ( json_template!( ) , model, false ) ;
134
134
let task: Value = serde_json:: from_str ( & task_json) . unwrap ( ) ;
135
135
assert ! ( verify_task( & task) . is_ok( ) ) ;
@@ -138,12 +138,12 @@ mod tests {
138
138
#[ pg_test]
139
139
fn test_nonempty_whitelist ( ) {
140
140
let model = "Salesforce/xgen-7b-8k-inst" ;
141
- set_config ( PGML_HF_WHITELIST . 0 , model) . unwrap ( ) ;
141
+ set_config ( "pgml.huggingface_whitelist" , model) . unwrap ( ) ;
142
142
let task_json = format ! ( json_template!( ) , model, false ) ;
143
143
let task: Value = serde_json:: from_str ( & task_json) . unwrap ( ) ;
144
144
assert ! ( verify_task( & task) . is_ok( ) ) ;
145
145
146
- set_config ( PGML_HF_WHITELIST . 0 , "other_model" ) . unwrap ( ) ;
146
+ set_config ( "pgml.huggingface_whitelist" , "other_model" ) . unwrap ( ) ;
147
147
let task_json = format ! ( json_template!( ) , model, false ) ;
148
148
let task: Value = serde_json:: from_str ( & task_json) . unwrap ( ) ;
149
149
assert ! ( verify_task( & task) . is_err( ) ) ;
@@ -152,8 +152,8 @@ mod tests {
152
152
#[ pg_test]
153
153
fn test_trusted_model ( ) {
154
154
let model = "Salesforce/xgen-7b-8k-inst" ;
155
- set_config ( PGML_HF_WHITELIST . 0 , model) . unwrap ( ) ;
156
- set_config ( PGML_HF_TRUST_WHITELIST . 0 , model) . unwrap ( ) ;
155
+ set_config ( "pgml.huggingface_whitelist" , model) . unwrap ( ) ;
156
+ set_config ( "pgml.huggingface_trust_remote_code_whitelist" , model) . unwrap ( ) ;
157
157
158
158
let task_json = format ! ( json_template!( ) , model, false ) ;
159
159
let task: Value = serde_json:: from_str ( & task_json) . unwrap ( ) ;
@@ -163,7 +163,7 @@ mod tests {
163
163
let task: Value = serde_json:: from_str ( & task_json) . unwrap ( ) ;
164
164
assert ! ( verify_task( & task) . is_err( ) ) ;
165
165
166
- set_config ( PGML_HF_TRUST_REMOTE_CODE . 0 , "true" ) . unwrap ( ) ;
166
+ set_config ( "pgml.huggingface_trust_remote_code" , "true" ) . unwrap ( ) ;
167
167
let task_json = format ! ( json_template!( ) , model, false ) ;
168
168
let task: Value = serde_json:: from_str ( & task_json) . unwrap ( ) ;
169
169
assert ! ( verify_task( & task) . is_ok( ) ) ;
@@ -176,8 +176,8 @@ mod tests {
176
176
#[ pg_test]
177
177
fn test_untrusted_model ( ) {
178
178
let model = "Salesforce/xgen-7b-8k-inst" ;
179
- set_config ( PGML_HF_WHITELIST . 0 , model) . unwrap ( ) ;
180
- set_config ( PGML_HF_TRUST_WHITELIST . 0 , "other_model" ) . unwrap ( ) ;
179
+ set_config ( "pgml.huggingface_whitelist" , model) . unwrap ( ) ;
180
+ set_config ( "pgml.huggingface_trust_remote_code_whitelist" , "other_model" ) . unwrap ( ) ;
181
181
182
182
let task_json = format ! ( json_template!( ) , model, false ) ;
183
183
let task: Value = serde_json:: from_str ( & task_json) . unwrap ( ) ;
@@ -187,7 +187,7 @@ mod tests {
187
187
let task: Value = serde_json:: from_str ( & task_json) . unwrap ( ) ;
188
188
assert ! ( verify_task( & task) . is_err( ) ) ;
189
189
190
- set_config ( PGML_HF_TRUST_REMOTE_CODE . 0 , "true" ) . unwrap ( ) ;
190
+ set_config ( "pgml.huggingface_trust_remote_code" , "true" ) . unwrap ( ) ;
191
191
let task_json = format ! ( json_template!( ) , model, false ) ;
192
192
let task: Value = serde_json:: from_str ( & task_json) . unwrap ( ) ;
193
193
assert ! ( verify_task( & task) . is_ok( ) ) ;
0 commit comments