Skip to content

Commit 79f9833

Browse files
authored
organize python related modules (#962)
1 parent 6bdcf00 commit 79f9833

File tree

14 files changed

+127
-120
lines changed

14 files changed

+127
-120
lines changed

pgml-extension/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ sacremoses==0.0.53
1717
scikit-learn==1.3.0
1818
sentencepiece==0.1.99
1919
sentence-transformers==2.2.2
20+
tokenizers==0.13.3
2021
torch==2.0.1
2122
torchaudio==2.0.2
2223
torchvision==0.15.2

pgml-extension/src/api.rs

Lines changed: 14 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,9 @@ use pgrx::iter::{SetOfIterator, TableIterator};
66
use pgrx::*;
77

88
#[cfg(feature = "python")]
9-
use pyo3::prelude::*;
109
use serde_json::json;
1110

1211
#[cfg(feature = "python")]
13-
use crate::bindings::sklearn::package_version;
1412
use crate::orm::*;
1513

1614
macro_rules! unwrap_or_error {
@@ -25,38 +23,13 @@ macro_rules! unwrap_or_error {
2523
#[cfg(feature = "python")]
2624
#[pg_extern]
2725
pub fn activate_venv(venv: &str) -> bool {
28-
unwrap_or_error!(crate::bindings::venv::activate_venv(venv))
26+
unwrap_or_error!(crate::bindings::python::activate_venv(venv))
2927
}
3028

3129
#[cfg(feature = "python")]
3230
#[pg_extern(immutable, parallel_safe)]
3331
pub fn validate_python_dependencies() -> bool {
34-
unwrap_or_error!(crate::bindings::venv::activate());
35-
36-
Python::with_gil(|py| {
37-
let sys = PyModule::import(py, "sys").unwrap();
38-
let version: String = sys.getattr("version").unwrap().extract().unwrap();
39-
info!("Python version: {version}");
40-
for module in ["xgboost", "lightgbm", "numpy", "sklearn"] {
41-
match py.import(module) {
42-
Ok(_) => (),
43-
Err(e) => {
44-
panic!(
45-
"The {module} package is missing. Install it with `sudo pip3 install {module}`\n{e}"
46-
);
47-
}
48-
}
49-
}
50-
});
51-
52-
let sklearn = unwrap_or_error!(package_version("sklearn"));
53-
let xgboost = unwrap_or_error!(package_version("xgboost"));
54-
let lightgbm = unwrap_or_error!(package_version("lightgbm"));
55-
let numpy = unwrap_or_error!(package_version("numpy"));
56-
57-
info!("Scikit-learn {sklearn}, XGBoost {xgboost}, LightGBM {lightgbm}, NumPy {numpy}",);
58-
59-
true
32+
unwrap_or_error!(crate::bindings::python::validate_dependencies())
6033
}
6134

6235
#[cfg(not(feature = "python"))]
@@ -66,8 +39,7 @@ pub fn validate_python_dependencies() {}
6639
#[cfg(feature = "python")]
6740
#[pg_extern]
6841
pub fn python_package_version(name: &str) -> String {
69-
unwrap_or_error!(crate::bindings::venv::activate());
70-
unwrap_or_error!(package_version(name))
42+
unwrap_or_error!(crate::bindings::python::package_version(name))
7143
}
7244

7345
#[cfg(not(feature = "python"))]
@@ -79,13 +51,19 @@ pub fn python_package_version(name: &str) {
7951
#[cfg(feature = "python")]
8052
#[pg_extern]
8153
pub fn python_pip_freeze() -> TableIterator<'static, (name!(package, String),)> {
82-
unwrap_or_error!(crate::bindings::venv::activate());
54+
unwrap_or_error!(crate::bindings::python::pip_freeze())
55+
}
8356

84-
let packages = unwrap_or_error!(crate::bindings::venv::freeze())
85-
.into_iter()
86-
.map(|package| (package,));
57+
#[cfg(feature = "python")]
58+
#[pg_extern]
59+
fn python_version() -> String {
60+
unwrap_or_error!(crate::bindings::python::version())
61+
}
8762

88-
TableIterator::new(packages)
63+
#[cfg(not(feature = "python"))]
64+
#[pg_extern]
65+
pub fn python_version() -> String {
66+
String::from("Python is not installed, recompile with `--features python`")
8967
}
9068

9169
#[pg_extern]
@@ -104,26 +82,6 @@ pub fn validate_shared_library() {
10482
}
10583
}
10684

107-
#[cfg(feature = "python")]
108-
#[pg_extern]
109-
fn python_version() -> String {
110-
unwrap_or_error!(crate::bindings::venv::activate());
111-
let mut version = String::new();
112-
113-
Python::with_gil(|py| {
114-
let sys = PyModule::import(py, "sys").unwrap();
115-
version = sys.getattr("version").unwrap().extract().unwrap();
116-
});
117-
118-
version
119-
}
120-
121-
#[cfg(not(feature = "python"))]
122-
#[pg_extern]
123-
pub fn python_version() -> String {
124-
String::from("Python is not installed, recompile with `--features python`")
125-
}
126-
12785
#[pg_extern(immutable, parallel_safe)]
12886
fn version() -> String {
12987
crate::VERSION.to_string()

pgml-extension/src/bindings/langchain.rs renamed to pgml-extension/src/bindings/langchain/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@ use pyo3::types::PyTuple;
66

77
use crate::{bindings::TracebackError, create_pymodule};
88

9-
create_pymodule!("/src/bindings/langchain.py");
9+
create_pymodule!("/src/bindings/langchain/langchain.py");
1010

1111
pub fn chunk(splitter: &str, text: &str, kwargs: &serde_json::Value) -> Result<Vec<String>> {
12-
crate::bindings::venv::activate()?;
12+
crate::bindings::python::activate()?;
1313

1414
let kwargs = serde_json::to_string(kwargs).unwrap();
1515

pgml-extension/src/bindings/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,11 @@ pub mod langchain;
3838
pub mod lightgbm;
3939
pub mod linfa;
4040
#[cfg(feature = "python")]
41+
pub mod python;
42+
#[cfg(feature = "python")]
4143
pub mod sklearn;
4244
#[cfg(feature = "python")]
4345
pub mod transformers;
44-
#[cfg(feature = "python")]
45-
pub mod venv;
4646
pub mod xgboost;
4747

4848
pub type Fit = fn(dataset: &Dataset, hyperparams: &Hyperparams) -> Result<Box<dyn Bindings>>;
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
//! Python runtime helpers: virtualenv activation, dependency validation, and version lookups.
2+
3+
use anyhow::Result;
4+
use once_cell::sync::Lazy;
5+
use pgrx::iter::TableIterator;
6+
use pgrx::*;
7+
use pyo3::prelude::*;
8+
use pyo3::types::PyTuple;
9+
10+
use crate::config::get_config;
11+
use crate::{bindings::TracebackError, create_pymodule};
12+
13+
static CONFIG_NAME: &str = "pgml.venv";
14+
15+
create_pymodule!("/src/bindings/python/python.py");
16+
17+
pub fn activate_venv(venv: &str) -> Result<bool> {
18+
Python::with_gil(|py| {
19+
let activate_venv: Py<PyAny> = get_module!(PY_MODULE).getattr(py, "activate_venv")?;
20+
let result: Py<PyAny> =
21+
activate_venv.call1(py, PyTuple::new(py, &[venv.to_string().into_py(py)]))?;
22+
23+
Ok(result.extract(py)?)
24+
})
25+
}
26+
27+
pub fn activate() -> Result<bool> {
28+
match get_config(CONFIG_NAME) {
29+
Some(venv) => activate_venv(&venv),
30+
None => Ok(false),
31+
}
32+
}
33+
34+
pub fn pip_freeze() -> Result<TableIterator<'static, (name!(package, String),)>> {
35+
activate()?;
36+
let packages = Python::with_gil(|py| -> Result<Vec<String>> {
37+
let freeze = get_module!(PY_MODULE).getattr(py, "freeze")?;
38+
let result = freeze.call0(py)?;
39+
40+
Ok(result.extract(py)?)
41+
})?;
42+
43+
Ok(TableIterator::new(
44+
packages.into_iter().map(|package| (package,)),
45+
))
46+
}
47+
48+
pub fn validate_dependencies() -> Result<bool> {
49+
activate()?;
50+
Python::with_gil(|py| {
51+
let sys = PyModule::import(py, "sys").unwrap();
52+
let version: String = sys.getattr("version").unwrap().extract().unwrap();
53+
info!("Python version: {version}");
54+
for module in ["xgboost", "lightgbm", "numpy", "sklearn"] {
55+
match py.import(module) {
56+
Ok(_) => (),
57+
Err(e) => {
58+
panic!(
59+
"The {module} package is missing. Install it with `sudo pip3 install {module}`\n{e}"
60+
);
61+
}
62+
}
63+
}
64+
});
65+
66+
let sklearn = package_version("sklearn")?;
67+
let xgboost = package_version("xgboost")?;
68+
let lightgbm = package_version("lightgbm")?;
69+
let numpy = package_version("numpy")?;
70+
71+
info!("Scikit-learn {sklearn}, XGBoost {xgboost}, LightGBM {lightgbm}, NumPy {numpy}",);
72+
73+
Ok(true)
74+
}
75+
76+
pub fn version() -> Result<String> {
77+
activate()?;
78+
Python::with_gil(|py| {
79+
let sys = PyModule::import(py, "sys").unwrap();
80+
let version: String = sys.getattr("version").unwrap().extract().unwrap();
81+
Ok(version)
82+
})
83+
}
84+
85+
pub fn package_version(name: &str) -> Result<String> {
86+
activate()?;
87+
Python::with_gil(|py| {
88+
let package = py.import(name)?;
89+
Ok(package.getattr("__version__")?.extract()?)
90+
})
91+
}

pgml-extension/src/bindings/sklearn.rs renamed to pgml-extension/src/bindings/sklearn/mod.rs

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,13 @@ use once_cell::sync::Lazy;
1515
use pyo3::prelude::*;
1616
use pyo3::types::PyTuple;
1717

18-
use crate::bindings::Bindings;
18+
use crate::{
19+
bindings::{Bindings, TracebackError},
20+
create_pymodule,
21+
orm::*,
22+
};
1923

20-
use crate::{bindings::TracebackError, create_pymodule, orm::*};
21-
22-
create_pymodule!("/src/bindings/sklearn.py");
24+
create_pymodule!("/src/bindings/sklearn/sklearn.py");
2325

2426
macro_rules! wrap_fit {
2527
($fn_name:tt, $task:literal) => {
@@ -355,10 +357,3 @@ pub fn cluster_metrics(
355357
Ok(scores)
356358
})
357359
}
358-
359-
pub fn package_version(name: &str) -> Result<String> {
360-
Python::with_gil(|py| {
361-
let package = py.import(name)?;
362-
Ok(package.getattr("__version__")?.extract()?)
363-
})
364-
}

pgml-extension/src/bindings/transformers/mod.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ pub fn transform(
2424
args: &serde_json::Value,
2525
inputs: Vec<&str>,
2626
) -> Result<serde_json::Value> {
27-
crate::bindings::venv::activate()?;
27+
crate::bindings::python::activate()?;
2828

2929
whitelist::verify_task(task)?;
3030

@@ -70,7 +70,7 @@ pub fn embed(
7070
inputs: Vec<&str>,
7171
kwargs: &serde_json::Value,
7272
) -> Result<Vec<Vec<f32>>> {
73-
crate::bindings::venv::activate()?;
73+
crate::bindings::python::activate()?;
7474

7575
let kwargs = serde_json::to_string(kwargs)?;
7676
Python::with_gil(|py| -> Result<Vec<Vec<f32>>> {
@@ -101,7 +101,7 @@ pub fn tune(
101101
hyperparams: &JsonB,
102102
path: &Path,
103103
) -> Result<HashMap<String, f64>> {
104-
crate::bindings::venv::activate()?;
104+
crate::bindings::python::activate()?;
105105

106106
let task = task.to_string();
107107
let hyperparams = serde_json::to_string(&hyperparams.0)?;
@@ -131,7 +131,7 @@ pub fn tune(
131131
}
132132

133133
pub fn generate(model_id: i64, inputs: Vec<&str>, config: JsonB) -> Result<Vec<String>> {
134-
crate::bindings::venv::activate()?;
134+
crate::bindings::python::activate()?;
135135

136136
Python::with_gil(|py| -> Result<Vec<String>> {
137137
let generate = get_module!(PY_MODULE)
@@ -219,7 +219,7 @@ pub fn load_dataset(
219219
limit: Option<usize>,
220220
kwargs: &serde_json::Value,
221221
) -> Result<usize> {
222-
crate::bindings::venv::activate()?;
222+
crate::bindings::python::activate()?;
223223

224224
let kwargs = serde_json::to_string(kwargs)?;
225225

@@ -376,7 +376,7 @@ pub fn load_dataset(
376376
}
377377

378378
pub fn clear_gpu_cache(memory_usage: Option<f32>) -> Result<bool> {
379-
crate::bindings::venv::activate().unwrap();
379+
crate::bindings::python::activate().unwrap();
380380

381381
Python::with_gil(|py| -> Result<bool> {
382382
let clear_gpu_cache: Py<PyAny> = get_module!(PY_MODULE)

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy