Skip to content

Commit 3768179

Browse files
committed
Simplify global variables.
1 parent 2a1853c commit 3768179

File tree

4 files changed

+40
-48
lines changed

4 files changed

+40
-48
lines changed

pgml-extension/src/bindings/python/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ pub fn activate_venv(venv: &str) -> Result<bool> {
2121
}
2222

2323
pub fn activate() -> Result<bool> {
24-
match PGML_VENV.1.get() {
24+
match PGML_VENV.get() {
2525
Some(venv) => activate_venv(&venv.to_string_lossy()),
2626
None => Ok(false),
2727
}

pgml-extension/src/bindings/transformers/whitelist.rs

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,29 +5,29 @@ use pgrx::{pg_schema, pg_test};
55
use serde_json::Value;
66
use std::ffi::CStr;
77

8-
use crate::config::{PGML_HF_TRUST_REMOTE_CODE, PGML_HF_TRUST_WHITELIST, PGML_HF_WHITELIST};
8+
use crate::config::{PGML_HF_TRUST_REMOTE_CODE, PGML_HF_TRUST_REMOTE_CODE_WHITELIST, PGML_HF_WHITELIST};
99

1010
/// Verify that the model in the task JSON is allowed based on the huggingface whitelists.
1111
pub fn verify_task(task: &Value) -> Result<(), Error> {
1212
let task_model = match get_model_name(task) {
1313
Some(model) => model.to_string(),
1414
None => return Ok(()),
1515
};
16-
let whitelisted_models = config_csv_list(&PGML_HF_WHITELIST.1);
16+
let whitelisted_models = config_csv_list(&PGML_HF_WHITELIST);
1717

1818
let model_is_allowed = whitelisted_models.is_empty() || whitelisted_models.contains(&task_model);
1919
if !model_is_allowed {
2020
bail!(
2121
"model {} is not whitelisted. Consider adding to {} in postgresql.conf",
2222
task_model,
23-
PGML_HF_WHITELIST.0
23+
"pgml.huggingface_whitelist"
2424
);
2525
}
2626

2727
let task_trust = get_trust_remote_code(task);
28-
let trust_remote_code = PGML_HF_TRUST_REMOTE_CODE.1.get();
28+
let trust_remote_code = PGML_HF_TRUST_REMOTE_CODE.get();
2929

30-
let trusted_models = config_csv_list(&PGML_HF_TRUST_WHITELIST.1);
30+
let trusted_models = config_csv_list(&PGML_HF_TRUST_REMOTE_CODE_WHITELIST);
3131

3232
let model_is_trusted = trusted_models.is_empty() || trusted_models.contains(&task_model);
3333

@@ -36,9 +36,9 @@ pub fn verify_task(task: &Value) -> Result<(), Error> {
3636
bail!(
3737
"model {} is not trusted to run remote code. Consider setting {} = 'true' or adding {} to {}",
3838
task_model,
39-
PGML_HF_TRUST_REMOTE_CODE.0,
39+
"pgml.huggingface_trust_remote_code",
4040
task_model,
41-
PGML_HF_TRUST_WHITELIST.0
41+
"pgml.huggingface_trust_remote_code_whitelist",
4242
);
4343
}
4444

@@ -129,7 +129,7 @@ mod tests {
129129
#[pg_test]
130130
fn test_empty_whitelist() {
131131
let model = "Salesforce/xgen-7b-8k-inst";
132-
set_config(PGML_HF_WHITELIST.0, "").unwrap();
132+
set_config("pgml.huggingface_whitelist", "").unwrap();
133133
let task_json = format!(json_template!(), model, false);
134134
let task: Value = serde_json::from_str(&task_json).unwrap();
135135
assert!(verify_task(&task).is_ok());
@@ -138,12 +138,12 @@ mod tests {
138138
#[pg_test]
139139
fn test_nonempty_whitelist() {
140140
let model = "Salesforce/xgen-7b-8k-inst";
141-
set_config(PGML_HF_WHITELIST.0, model).unwrap();
141+
set_config("pgml.huggingface_whitelist", model).unwrap();
142142
let task_json = format!(json_template!(), model, false);
143143
let task: Value = serde_json::from_str(&task_json).unwrap();
144144
assert!(verify_task(&task).is_ok());
145145

146-
set_config(PGML_HF_WHITELIST.0, "other_model").unwrap();
146+
set_config("pgml.huggingface_whitelist", "other_model").unwrap();
147147
let task_json = format!(json_template!(), model, false);
148148
let task: Value = serde_json::from_str(&task_json).unwrap();
149149
assert!(verify_task(&task).is_err());
@@ -152,8 +152,8 @@ mod tests {
152152
#[pg_test]
153153
fn test_trusted_model() {
154154
let model = "Salesforce/xgen-7b-8k-inst";
155-
set_config(PGML_HF_WHITELIST.0, model).unwrap();
156-
set_config(PGML_HF_TRUST_WHITELIST.0, model).unwrap();
155+
set_config("pgml.huggingface_whitelist", model).unwrap();
156+
set_config("pgml.huggingface_trust_remote_code_whitelist", model).unwrap();
157157

158158
let task_json = format!(json_template!(), model, false);
159159
let task: Value = serde_json::from_str(&task_json).unwrap();
@@ -163,7 +163,7 @@ mod tests {
163163
let task: Value = serde_json::from_str(&task_json).unwrap();
164164
assert!(verify_task(&task).is_err());
165165

166-
set_config(PGML_HF_TRUST_REMOTE_CODE.0, "true").unwrap();
166+
set_config("pgml.huggingface_trust_remote_code", "true").unwrap();
167167
let task_json = format!(json_template!(), model, false);
168168
let task: Value = serde_json::from_str(&task_json).unwrap();
169169
assert!(verify_task(&task).is_ok());
@@ -176,8 +176,8 @@ mod tests {
176176
#[pg_test]
177177
fn test_untrusted_model() {
178178
let model = "Salesforce/xgen-7b-8k-inst";
179-
set_config(PGML_HF_WHITELIST.0, model).unwrap();
180-
set_config(PGML_HF_TRUST_WHITELIST.0, "other_model").unwrap();
179+
set_config("pgml.huggingface_whitelist", model).unwrap();
180+
set_config("pgml.huggingface_trust_remote_code_whitelist", "other_model").unwrap();
181181

182182
let task_json = format!(json_template!(), model, false);
183183
let task: Value = serde_json::from_str(&task_json).unwrap();
@@ -187,7 +187,7 @@ mod tests {
187187
let task: Value = serde_json::from_str(&task_json).unwrap();
188188
assert!(verify_task(&task).is_err());
189189

190-
set_config(PGML_HF_TRUST_REMOTE_CODE.0, "true").unwrap();
190+
set_config("pgml.huggingface_trust_remote_code", "true").unwrap();
191191
let task_json = format!(json_template!(), model, false);
192192
let task: Value = serde_json::from_str(&task_json).unwrap();
193193
assert!(verify_task(&task).is_ok());

pgml-extension/src/config.rs

Lines changed: 22 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,67 +1,58 @@
1-
use once_cell::sync::Lazy;
21
use pgrx::{GucContext, GucFlags, GucRegistry, GucSetting};
32
use std::ffi::CStr;
43

54
#[cfg(any(test, feature = "pg_test"))]
65
use pgrx::{pg_schema, pg_test};
76

8-
pub static PGML_VENV: Lazy<(&'static str, GucSetting<Option<&'static CStr>>)> =
9-
Lazy::new(|| ("pgml.venv", GucSetting::<Option<&'static CStr>>::new(None)));
10-
pub static PGML_HF_WHITELIST: Lazy<(&'static str, GucSetting<Option<&'static CStr>>)> = Lazy::new(|| {
11-
(
12-
"pgml.huggingface_whitelist",
13-
GucSetting::<Option<&'static CStr>>::new(None),
14-
)
15-
});
16-
pub static PGML_HF_TRUST_REMOTE_CODE: Lazy<(&'static str, GucSetting<bool>)> =
17-
Lazy::new(|| ("pgml.huggingface_trust_remote_code", GucSetting::<bool>::new(false)));
18-
pub static PGML_HF_TRUST_WHITELIST: Lazy<(&'static str, GucSetting<Option<&'static CStr>>)> = Lazy::new(|| {
19-
(
20-
"pgml.huggingface_trust_remote_code_whitelist",
21-
GucSetting::<Option<&'static CStr>>::new(None),
22-
)
23-
});
24-
pub static PGML_OMP_NUM_THREADS: Lazy<(&'static str, GucSetting<i32>)> =
25-
Lazy::new(|| ("pgml.omp_num_threads", GucSetting::<i32>::new(0)));
7+
pub static PGML_VENV: GucSetting<Option<&'static CStr>> = GucSetting::<Option<&'static CStr>>::new(None);
8+
pub static PGML_HF_WHITELIST: GucSetting<Option<&'static CStr>> = GucSetting::<Option<&'static CStr>>::new(None);
9+
pub static PGML_HF_TRUST_REMOTE_CODE: GucSetting<bool> = GucSetting::<bool>::new(false);
10+
pub static PGML_HF_TRUST_REMOTE_CODE_WHITELIST: GucSetting<Option<&'static CStr>> =
11+
GucSetting::<Option<&'static CStr>>::new(None);
12+
pub static PGML_OMP_NUM_THREADS: GucSetting<i32> = GucSetting::<i32>::new(0);
2613

2714
pub fn initialize_server_params() {
2815
GucRegistry::define_string_guc(
29-
PGML_VENV.0,
16+
"pgml.venv",
3017
"Python's virtual environment path",
3118
"",
32-
&PGML_VENV.1,
19+
&PGML_VENV,
3320
GucContext::Userset,
3421
GucFlags::default(),
3522
);
23+
3624
GucRegistry::define_string_guc(
37-
PGML_HF_WHITELIST.0,
25+
"pgml.huggingface_whitelist",
3826
"Models allowed to be downloaded from huggingface",
3927
"",
40-
&PGML_HF_WHITELIST.1,
28+
&PGML_HF_WHITELIST,
4129
GucContext::Userset,
4230
GucFlags::default(),
4331
);
32+
4433
GucRegistry::define_bool_guc(
45-
PGML_HF_TRUST_REMOTE_CODE.0,
34+
"pgml.huggingface_trust_remote_code",
4635
"Whether model can execute remote codes",
4736
"",
48-
&PGML_HF_TRUST_REMOTE_CODE.1,
37+
&PGML_HF_TRUST_REMOTE_CODE,
4938
GucContext::Userset,
5039
GucFlags::default(),
5140
);
41+
5242
GucRegistry::define_string_guc(
53-
PGML_HF_TRUST_WHITELIST.0,
43+
"pgml.huggingface_trust_remote_code_whitelist",
5444
"Models allowed to execute remote codes when pgml.hugging_face_trust_remote_code = 'on'",
5545
"",
56-
&PGML_HF_TRUST_WHITELIST.1,
46+
&PGML_HF_TRUST_REMOTE_CODE_WHITELIST,
5747
GucContext::Userset,
5848
GucFlags::default(),
5949
);
50+
6051
GucRegistry::define_int_guc(
61-
PGML_OMP_NUM_THREADS.0,
52+
"pgml.omp_num_threads",
6253
"Specifies the number of threads used by default of underlying OpenMP library. Only positive integers are valid",
6354
"",
64-
&PGML_OMP_NUM_THREADS.1,
55+
&PGML_OMP_NUM_THREADS,
6556
0,
6657
i32::max_value(),
6758
GucContext::Backend,
@@ -87,7 +78,8 @@ mod tests {
8778
let name = "pgml.huggingface_whitelist";
8879
let value = "meta-llama/Llama-2-7b";
8980
set_config(name, value).unwrap();
90-
assert_eq!(PGML_HF_WHITELIST.1.get().unwrap().to_string_lossy(), value);
81+
assert_eq!(PGML_HF_WHITELIST.get().unwrap().to_str().unwrap(), value);
82+
//assert_eq!((&*PGML_HF_WHITELIST).get().unwrap().to_str().unwrap(), value);
9183
}
9284

9385
#[pg_test]

pgml-extension/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ extern "C" {
2929
#[pg_guard]
3030
pub extern "C" fn _PG_init() {
3131
config::initialize_server_params();
32-
let omp_num_threads = config::PGML_OMP_NUM_THREADS.1.get();
32+
let omp_num_threads = config::PGML_OMP_NUM_THREADS.get();
3333
if omp_num_threads > 0 {
3434
unsafe {
3535
omp_set_num_threads(omp_num_threads);

0 commit comments

Comments
 (0)