
Commit 5759ee3

Moved python functions (#1374)
1 parent 66c65c8 commit 5759ee3

File tree

2 files changed (+33, -14 lines)


pgml-extension/src/bindings/mod.rs

Lines changed: 27 additions & 6 deletions

@@ -3,10 +3,31 @@ use std::fmt::Debug;
 use anyhow::{anyhow, Result};
 #[allow(unused_imports)] // used for test macros
 use pgrx::*;
-use pyo3::{PyResult, Python};
+use pyo3::{pyfunction, PyResult, Python};
 
 use crate::orm::*;
 
+#[pyfunction]
+fn r_insert_logs(project_id: i64, model_id: i64, logs: String) -> PyResult<String> {
+    let id_value = Spi::get_one_with_args::<i64>(
+        "INSERT INTO pgml.logs (project_id, model_id, logs) VALUES ($1, $2, $3::JSONB) RETURNING id;",
+        vec![
+            (PgBuiltInOids::INT8OID.oid(), project_id.into_datum()),
+            (PgBuiltInOids::INT8OID.oid(), model_id.into_datum()),
+            (PgBuiltInOids::TEXTOID.oid(), logs.into_datum()),
+        ],
+    )
+    .unwrap()
+    .unwrap();
+    Ok(format!("Inserted logs with id: {}", id_value))
+}
+
+#[pyfunction]
+fn r_print_info(info: String) -> PyResult<String> {
+    info!("{}", info);
+    Ok(info)
+}
+
 #[cfg(feature = "python")]
 #[macro_export]
 macro_rules! create_pymodule {
@@ -16,11 +37,11 @@ macro_rules! create_pymodule {
             pyo3::Python::with_gil(|py| -> anyhow::Result<pyo3::Py<pyo3::types::PyModule>> {
                 use $crate::bindings::TracebackError;
                 let src = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), $pyfile));
-                Ok(
-                    pyo3::types::PyModule::from_code(py, src, "transformers.py", "__main__")
-                        .format_traceback(py)?
-                        .into(),
-                )
+                let module = pyo3::types::PyModule::from_code(py, src, "transformers.py", "__main__")
+                    .format_traceback(py)?;
+                module.add_function(wrap_pyfunction!($crate::bindings::r_insert_logs, module)?)?;
+                module.add_function(wrap_pyfunction!($crate::bindings::r_print_info, module)?)?;
+                Ok(module.into())
             })
         });
     };
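The macro change above is the heart of the commit: instead of returning the module created by `PyModule::from_code` directly, `create_pymodule!` first registers the two Rust `#[pyfunction]`s on it, so the embedded Python source can call them as ordinary module-level names. Below is a minimal, self-contained sketch of that pattern, assuming a pyo3 0.19-style API (`PyModule::from_code`, `wrap_pyfunction!`) and using `println!` as a stand-in for pgrx's `info!` macro; the Python snippet and `log_step` are illustrative, not from the commit:

```rust
use pyo3::prelude::*;
use pyo3::types::PyModule;

// Stand-in for the extension's r_print_info: logs via println! instead of info!.
#[pyfunction]
fn r_print_info(info: String) -> PyResult<String> {
    println!("{}", info);
    Ok(info)
}

fn main() -> PyResult<()> {
    Python::with_gil(|py| {
        // Compile some Python source, as create_pymodule! does for transformers.py.
        let src = "def log_step(step):\n    r_print_info(f'step {step}')\n";
        let module = PyModule::from_code(py, src, "example.py", "__main__")?;
        // Register the Rust function on the module; add_function places it in the
        // module's namespace, so log_step resolves it through its globals.
        module.add_function(wrap_pyfunction!(r_print_info, module)?)?;
        module.getattr("log_step")?.call1((42,))?; // prints "step 42"
        Ok(())
    })
}
```

This is also why the `from pypgrx import print_info, insert_logs` line is deleted in `transformers.py` below: after registration, the functions are already present in the module's namespace.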

pgml-extension/src/bindings/transformers/transformers.py

Lines changed: 6 additions & 8 deletions

@@ -55,7 +55,6 @@
 from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
 from trl.trainer import ConstantLengthDataset
 from peft import LoraConfig, get_peft_model
-from pypgrx import print_info, insert_logs
 from abc import abstractmethod
 
 transformers.logging.set_verbosity_info()
@@ -1017,8 +1016,7 @@ def on_log(self, args, state, control, logs=None, **kwargs):
         logs["step"] = state.global_step
         logs["max_steps"] = state.max_steps
         logs["timestamp"] = str(datetime.now())
-        print_info(json.dumps(logs, indent=4))
-        insert_logs(self.project_id, self.model_id, json.dumps(logs))
+        r_print_info(json.dumps(logs, indent=4))
 
 
 class FineTuningBase:
@@ -1100,9 +1098,9 @@ def print_number_of_trainable_model_parameters(self, model):
             trainable_model_params += param.numel()
 
         # Calculate and print the number and percentage of trainable parameters
-        print_info(f"Trainable model parameters: {trainable_model_params}")
-        print_info(f"All model parameters: {all_model_params}")
-        print_info(
+        r_print_info(f"Trainable model parameters: {trainable_model_params}")
+        r_print_info(f"All model parameters: {all_model_params}")
+        r_print_info(
             f"Percentage of trainable model parameters: {100 * trainable_model_params / all_model_params:.2f}%"
         )
 
@@ -1398,7 +1396,7 @@ def __init__(
                 "bias": "none",
                 "task_type": "CAUSAL_LM",
             }
-            print_info(
+            r_print_info(
                 "LoRA configuration are not set. Using default parameters"
                 + json.dumps(self.lora_config_params)
             )
@@ -1465,7 +1463,7 @@ def formatting_prompts_func(example):
             peft_config=LoraConfig(**self.lora_config_params),
             callbacks=[PGMLCallback(self.project_id, self.model_id)],
         )
-        print_info("Creating Supervised Fine Tuning trainer done. Training ... ")
+        r_print_info("Creating Supervised Fine Tuning trainer done. Training ... ")
 
         # Train
         self.trainer.train()
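With the `pypgrx` import gone, the training callback reaches Postgres only through the registered bindings. Note that the new `r_insert_logs` above unwraps both layers of the SPI result, so a failed insert panics (which pgrx surfaces as a Postgres error). For comparison, here is a hedged sketch (not the commit's code) of the same insert propagating failures to Python as exceptions instead; the name `r_insert_logs_checked` is hypothetical, and it assumes the `Result<Option<T>>` form of `Spi::get_one_with_args` used above:

```rust
use pgrx::prelude::*;
use pyo3::exceptions::PyRuntimeError;
use pyo3::{pyfunction, PyResult};

// Hypothetical variant of r_insert_logs that raises a Python exception
// instead of panicking when SPI fails.
#[pyfunction]
fn r_insert_logs_checked(project_id: i64, model_id: i64, logs: String) -> PyResult<String> {
    let id = Spi::get_one_with_args::<i64>(
        "INSERT INTO pgml.logs (project_id, model_id, logs) VALUES ($1, $2, $3::JSONB) RETURNING id;",
        vec![
            (PgBuiltInOids::INT8OID.oid(), project_id.into_datum()),
            (PgBuiltInOids::INT8OID.oid(), model_id.into_datum()),
            (PgBuiltInOids::TEXTOID.oid(), logs.into_datum()),
        ],
    )
    // SPI error -> Python RuntimeError instead of a panic
    .map_err(|e| PyRuntimeError::new_err(e.to_string()))?
    // INSERT ... RETURNING id should always yield a row, but guard the Option anyway
    .ok_or_else(|| PyRuntimeError::new_err("INSERT returned no id"))?;
    Ok(format!("Inserted logs with id: {}", id))
}
```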

0 commit comments
