Content-Length: 844333 | pFad | http://github.com/postgresml/postgresml/commit/f674f70b542b164b98a9a63e009b2b3b7a73afdb

Allow user to limit the number of threads that OpenMP spawns. (#1362) · postgresml/postgresml@f674f70 · GitHub
Skip to content

Commit f674f70

Browse files
higuoxing, Xuebin Su, xuebinsu
authored
Allow user to limit the number of threads that OpenMP spawns. (#1362)
Co-authored-by: Xuebin Su <sxuebin@vmware.com> Co-authored-by: Xuebin Su (苏学斌) <12034000+xuebinsu@users.noreply.github.com>
1 parent 1042f85 commit f674f70

File tree

4 files changed

+98
-46
lines changed

4 files changed

+98
-46
lines changed

pgml-extension/src/bindings/python/mod.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,9 @@ use pgrx::*;
66
use pyo3::prelude::*;
77
use pyo3::types::PyTuple;
88

9-
use crate::config::get_config;
9+
use crate::config::PGML_VENV;
1010
use crate::create_pymodule;
1111

12-
static CONFIG_NAME: &str = "pgml.venv";
13-
1412
create_pymodule!("/src/bindings/python/python.py");
1513

1614
pub fn activate_venv(venv: &str) -> Result<bool> {
@@ -23,8 +21,8 @@ pub fn activate_venv(venv: &str) -> Result<bool> {
2321
}
2422

2523
pub fn activate() -> Result<bool> {
26-
match get_config(CONFIG_NAME) {
27-
Some(venv) => activate_venv(&venv),
24+
match PGML_VENV.get() {
25+
Some(venv) => activate_venv(&venv.to_string_lossy()),
2826
None => Ok(false),
2927
}
3028
}

pgml-extension/src/bindings/transformers/whitelist.rs

Lines changed: 21 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,44 @@
11
use anyhow::{bail, Error};
2+
use pgrx::GucSetting;
23
#[cfg(any(test, feature = "pg_test"))]
34
use pgrx::{pg_schema, pg_test};
45
use serde_json::Value;
6+
use std::ffi::CStr;
57

6-
use crate::config::get_config;
7-
8-
static CONFIG_HF_WHITELIST: &str = "pgml.huggingface_whitelist";
9-
static CONFIG_HF_TRUST_REMOTE_CODE_BOOL: &str = "pgml.huggingface_trust_remote_code";
10-
static CONFIG_HF_TRUST_WHITELIST: &str = "pgml.huggingface_trust_remote_code_whitelist";
8+
use crate::config::{PGML_HF_TRUST_REMOTE_CODE, PGML_HF_TRUST_REMOTE_CODE_WHITELIST, PGML_HF_WHITELIST};
119

1210
/// Verify that the model in the task JSON is allowed based on the huggingface whitelists.
1311
pub fn verify_task(task: &Value) -> Result<(), Error> {
1412
let task_model = match get_model_name(task) {
1513
Some(model) => model.to_string(),
1614
None => return Ok(()),
1715
};
18-
let whitelisted_models = config_csv_list(CONFIG_HF_WHITELIST);
16+
let whitelisted_models = config_csv_list(&PGML_HF_WHITELIST);
1917

2018
let model_is_allowed = whitelisted_models.is_empty() || whitelisted_models.contains(&task_model);
2119
if !model_is_allowed {
22-
bail!("model {task_model} is not whitelisted. Consider adding to {CONFIG_HF_WHITELIST} in postgresql.conf");
20+
bail!("model {task_model} is not whitelisted. Consider adding to `pgml.huggingface_whitelist` in postgresql.conf");
2321
}
2422

2523
let task_trust = get_trust_remote_code(task);
26-
let trust_remote_code = get_config(CONFIG_HF_TRUST_REMOTE_CODE_BOOL)
27-
.map(|v| v == "true")
28-
.unwrap_or(true);
24+
let trust_remote_code = PGML_HF_TRUST_REMOTE_CODE.get();
2925

30-
let trusted_models = config_csv_list(CONFIG_HF_TRUST_WHITELIST);
26+
let trusted_models = config_csv_list(&PGML_HF_TRUST_REMOTE_CODE_WHITELIST);
3127

3228
let model_is_trusted = trusted_models.is_empty() || trusted_models.contains(&task_model);
3329

3430
let remote_code_allowed = trust_remote_code && model_is_trusted;
3531
if !remote_code_allowed && task_trust == Some(true) {
36-
bail!("model {task_model} is not trusted to run remote code. Consider setting {CONFIG_HF_TRUST_REMOTE_CODE_BOOL} = 'true' or adding {task_model} to {CONFIG_HF_TRUST_WHITELIST}");
32+
bail!("model {task_model} is not trusted to run remote code. Consider setting pgml.huggingface_trust_remote_code = 'true' or adding {task_model} to pgml.huggingface_trust_remote_code_whitelist");
3733
}
3834

3935
Ok(())
4036
}
4137

42-
fn config_csv_list(name: &str) -> Vec<String> {
43-
match get_config(name) {
38+
fn config_csv_list(csv_list: &GucSetting<Option<&'static CStr>>) -> Vec<String> {
39+
match csv_list.get() {
4440
Some(value) => value
41+
.to_string_lossy()
4542
.trim_matches('"')
4643
.split(',')
4744
.filter_map(|s| if s.is_empty() { None } else { Some(s.to_string()) })
@@ -122,7 +119,7 @@ mod tests {
122119
#[pg_test]
123120
fn test_empty_whitelist() {
124121
let model = "Salesforce/xgen-7b-8k-inst";
125-
set_config(CONFIG_HF_WHITELIST, "").unwrap();
122+
set_config("pgml.huggingface_whitelist", "").unwrap();
126123
let task_json = format!(json_template!(), model, false);
127124
let task: Value = serde_json::from_str(&task_json).unwrap();
128125
assert!(verify_task(&task).is_ok());
@@ -131,12 +128,12 @@ mod tests {
131128
#[pg_test]
132129
fn test_nonempty_whitelist() {
133130
let model = "Salesforce/xgen-7b-8k-inst";
134-
set_config(CONFIG_HF_WHITELIST, model).unwrap();
131+
set_config("pgml.huggingface_whitelist", model).unwrap();
135132
let task_json = format!(json_template!(), model, false);
136133
let task: Value = serde_json::from_str(&task_json).unwrap();
137134
assert!(verify_task(&task).is_ok());
138135

139-
set_config(CONFIG_HF_WHITELIST, "other_model").unwrap();
136+
set_config("pgml.huggingface_whitelist", "other_model").unwrap();
140137
let task_json = format!(json_template!(), model, false);
141138
let task: Value = serde_json::from_str(&task_json).unwrap();
142139
assert!(verify_task(&task).is_err());
@@ -145,18 +142,18 @@ mod tests {
145142
#[pg_test]
146143
fn test_trusted_model() {
147144
let model = "Salesforce/xgen-7b-8k-inst";
148-
set_config(CONFIG_HF_WHITELIST, model).unwrap();
149-
set_config(CONFIG_HF_TRUST_WHITELIST, model).unwrap();
145+
set_config("pgml.huggingface_whitelist", model).unwrap();
146+
set_config("pgml.huggingface_trust_remote_code_whitelist", model).unwrap();
150147

151148
let task_json = format!(json_template!(), model, false);
152149
let task: Value = serde_json::from_str(&task_json).unwrap();
153150
assert!(verify_task(&task).is_ok());
154151

155152
let task_json = format!(json_template!(), model, true);
156153
let task: Value = serde_json::from_str(&task_json).unwrap();
157-
assert!(verify_task(&task).is_ok());
154+
assert!(verify_task(&task).is_err());
158155

159-
set_config(CONFIG_HF_TRUST_REMOTE_CODE_BOOL, "true").unwrap();
156+
set_config("pgml.huggingface_trust_remote_code", "true").unwrap();
160157
let task_json = format!(json_template!(), model, false);
161158
let task: Value = serde_json::from_str(&task_json).unwrap();
162159
assert!(verify_task(&task).is_ok());
@@ -169,8 +166,8 @@ mod tests {
169166
#[pg_test]
170167
fn test_untrusted_model() {
171168
let model = "Salesforce/xgen-7b-8k-inst";
172-
set_config(CONFIG_HF_WHITELIST, model).unwrap();
173-
set_config(CONFIG_HF_TRUST_WHITELIST, "other_model").unwrap();
169+
set_config("pgml.huggingface_whitelist", model).unwrap();
170+
set_config("pgml.huggingface_trust_remote_code_whitelist", "other_model").unwrap();
174171

175172
let task_json = format!(json_template!(), model, false);
176173
let task: Value = serde_json::from_str(&task_json).unwrap();
@@ -180,7 +177,7 @@ mod tests {
180177
let task: Value = serde_json::from_str(&task_json).unwrap();
181178
assert!(verify_task(&task).is_err());
182179

183-
set_config(CONFIG_HF_TRUST_REMOTE_CODE_BOOL, "true").unwrap();
180+
set_config("pgml.huggingface_trust_remote_code", "true").unwrap();
184181
let task_json = format!(json_template!(), model, false);
185182
let task: Value = serde_json::from_str(&task_json).unwrap();
186183
assert!(verify_task(&task).is_ok());

pgml-extension/src/config.rs

Lines changed: 72 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,72 @@
1+
use pgrx::{GucContext, GucFlags, GucRegistry, GucSetting};
12
use std::ffi::CStr;
23

34
#[cfg(any(test, feature = "pg_test"))]
45
use pgrx::{pg_schema, pg_test};
5-
use pgrx_pg_sys::AsPgCStr;
6-
7-
pub fn get_config(name: &str) -> Option<String> {
8-
// SAFETY: name is not null because it is a Rust reference.
9-
let ptr = unsafe { pgrx_pg_sys::GetConfigOption(name.as_pg_cstr(), true, false) };
10-
(!ptr.is_null()).then(move || {
11-
// SAFETY: assuming pgrx_pg_sys is providing a valid, null terminated pointer.
12-
unsafe { CStr::from_ptr(ptr) }.to_string_lossy().to_string()
13-
})
6+
7+
pub static PGML_VENV: GucSetting<Option<&'static CStr>> = GucSetting::<Option<&'static CStr>>::new(None);
8+
pub static PGML_HF_WHITELIST: GucSetting<Option<&'static CStr>> = GucSetting::<Option<&'static CStr>>::new(None);
9+
pub static PGML_HF_TRUST_REMOTE_CODE: GucSetting<bool> = GucSetting::<bool>::new(false);
10+
pub static PGML_HF_TRUST_REMOTE_CODE_WHITELIST: GucSetting<Option<&'static CStr>> =
11+
GucSetting::<Option<&'static CStr>>::new(None);
12+
pub static PGML_OMP_NUM_THREADS: GucSetting<i32> = GucSetting::<i32>::new(1);
13+
14+
extern "C" {
15+
fn omp_set_num_threads(num_threads: i32);
16+
}
17+
18+
pub fn initialize_server_params() {
19+
GucRegistry::define_string_guc(
20+
"pgml.venv",
21+
"Python's virtual environment path",
22+
"",
23+
&PGML_VENV,
24+
GucContext::Userset,
25+
GucFlags::default(),
26+
);
27+
28+
GucRegistry::define_string_guc(
29+
"pgml.huggingface_whitelist",
30+
"Models allowed to be downloaded from huggingface",
31+
"",
32+
&PGML_HF_WHITELIST,
33+
GucContext::Userset,
34+
GucFlags::default(),
35+
);
36+
37+
GucRegistry::define_bool_guc(
38+
"pgml.huggingface_trust_remote_code",
39+
"Whether model can execute remote codes",
40+
"",
41+
&PGML_HF_TRUST_REMOTE_CODE,
42+
GucContext::Userset,
43+
GucFlags::default(),
44+
);
45+
46+
GucRegistry::define_string_guc(
47+
"pgml.huggingface_trust_remote_code_whitelist",
48+
"Models allowed to execute remote codes when pgml.hugging_face_trust_remote_code = 'on'",
49+
"",
50+
&PGML_HF_TRUST_REMOTE_CODE_WHITELIST,
51+
GucContext::Userset,
52+
GucFlags::default(),
53+
);
54+
55+
GucRegistry::define_int_guc(
56+
"pgml.omp_num_threads",
57+
"Specifies the number of threads used by default of underlying OpenMP library. Only positive integers are valid",
58+
"",
59+
&PGML_OMP_NUM_THREADS,
60+
1,
61+
i32::max_value(),
62+
GucContext::Backend,
63+
GucFlags::default(),
64+
);
65+
66+
let omp_num_threads = PGML_OMP_NUM_THREADS.get();
67+
unsafe {
68+
omp_set_num_threads(omp_num_threads);
69+
}
1470
}
1571

1672
#[cfg(any(test, feature = "pg_test"))]
@@ -26,17 +82,17 @@ pub fn set_config(name: &str, value: &str) -> Result<(), pgrx::spi::Error> {
2682
mod tests {
2783
use super::*;
2884

29-
#[pg_test]
30-
fn read_config_max_connections() {
31-
let name = "max_connections";
32-
assert_eq!(get_config(name), Some("100".into()));
33-
}
34-
3585
#[pg_test]
3686
fn read_pgml_huggingface_whitelist() {
3787
let name = "pgml.huggingface_whitelist";
3888
let value = "meta-llama/Llama-2-7b";
3989
set_config(name, value).unwrap();
40-
assert_eq!(get_config(name), Some(value.into()));
90+
assert_eq!(PGML_HF_WHITELIST.get().unwrap().to_str().unwrap(), value);
91+
}
92+
93+
#[pg_test]
94+
fn omp_num_threads_cannot_be_set_after_startup() {
95+
let result = std::panic::catch_unwind(|| set_config("pgml.omp_num_threads", "1"));
96+
assert!(result.is_err());
4197
}
4298
}

pgml-extension/src/lib.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ extension_sql_file!("../sql/schema.sql", name = "schema");
2424
#[cfg(not(feature = "use_as_lib"))]
2525
#[pg_guard]
2626
pub extern "C" fn _PG_init() {
27+
config::initialize_server_params();
2728
bindings::python::activate().expect("Error setting python venv");
2829
orm::project::init();
2930
}
@@ -53,7 +54,7 @@ pub mod pg_test {
5354

5455
pub fn postgresql_conf_options() -> Vec<&'static str> {
5556
// return any postgresql.conf settings that are required for your tests
56-
let mut options = vec!["shared_preload_libraries = 'pgml'"];
57+
let mut options = vec!["shared_preload_libraries = 'pgml'", "pgml.omp_num_threads = '1'"];
5758
if let Some(venv) = option_env!("PGML_VENV") {
5859
let option = format!("pgml.venv = '{venv}'");
5960
options.push(Box::leak(option.into_boxed_str()));

0 commit comments

Comments
 (0)








ApplySandwichStrip

pFad - (p)hone/(F)rame/(a)nonymizer/(d)eclutterfier!      Saves Data!


--- a PPN by Garber Painting Akron. With Image Size Reduction included!

Fetched URL: http://github.com/postgresml/postgresml/commit/f674f70b542b164b98a9a63e009b2b3b7a73afdb

Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy