diff --git a/pgml-extension/src/bindings/mod.rs b/pgml-extension/src/bindings/mod.rs index d877f490a..9c9449103 100644 --- a/pgml-extension/src/bindings/mod.rs +++ b/pgml-extension/src/bindings/mod.rs @@ -3,10 +3,31 @@ use std::fmt::Debug; use anyhow::{anyhow, Result}; #[allow(unused_imports)] // used for test macros use pgrx::*; -use pyo3::{PyResult, Python}; +use pyo3::{pyfunction, PyResult, Python}; use crate::orm::*; +#[pyfunction] +fn r_insert_logs(project_id: i64, model_id: i64, logs: String) -> PyResult { + let id_value = Spi::get_one_with_args::( + "INSERT INTO pgml.logs (project_id, model_id, logs) VALUES ($1, $2, $3::JSONB) RETURNING id;", + vec![ + (PgBuiltInOids::INT8OID.oid(), project_id.into_datum()), + (PgBuiltInOids::INT8OID.oid(), model_id.into_datum()), + (PgBuiltInOids::TEXTOID.oid(), logs.into_datum()), + ], + ) + .unwrap() + .unwrap(); + Ok(format!("Inserted logs with id: {}", id_value)) +} + +#[pyfunction] +fn r_print_info(info: String) -> PyResult { + info!("{}", info); + Ok(info) +} + #[cfg(feature = "python")] #[macro_export] macro_rules! create_pymodule { @@ -16,11 +37,11 @@ macro_rules! create_pymodule { pyo3::Python::with_gil(|py| -> anyhow::Result> { use $crate::bindings::TracebackError; let src = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), $pyfile)); - Ok( - pyo3::types::PyModule::from_code(py, src, "transformers.py", "__main__") - .format_traceback(py)? - .into(), - ) + let module = pyo3::types::PyModule::from_code(py, src, "transformers.py", "__main__") + .format_traceback(py)?; + module.add_function(wrap_pyfunction!($crate::bindings::r_insert_logs, module)?)?; + module.add_function(wrap_pyfunction!($crate::bindings::r_print_info, module)?)?; + Ok(module.into()) }) }); }; diff --git a/pgml-extension/src/bindings/transformers/transformers.py b/pgml-extension/src/bindings/transformers/transformers.py index 42ac43fe0..f3a6d63d4 100644 --- a/pgml-extension/src/bindings/transformers/transformers.py +++ b/pgml-extension/src/bindings/transformers/transformers.py @@ -55,7 +55,6 @@ from trl import SFTTrainer, DataCollatorForCompletionOnlyLM from trl.trainer import ConstantLengthDataset from peft import LoraConfig, get_peft_model -from pypgrx import print_info, insert_logs from abc import abstractmethod transformers.logging.set_verbosity_info() @@ -1017,8 +1016,7 @@ def on_log(self, args, state, control, logs=None, **kwargs): logs["step"] = state.global_step logs["max_steps"] = state.max_steps logs["timestamp"] = str(datetime.now()) - print_info(json.dumps(logs, indent=4)) - insert_logs(self.project_id, self.model_id, json.dumps(logs)) + r_print_info(json.dumps(logs, indent=4)) class FineTuningBase: @@ -1100,9 +1098,9 @@ def print_number_of_trainable_model_parameters(self, model): trainable_model_params += param.numel() # Calculate and print the number and percentage of trainable parameters - print_info(f"Trainable model parameters: {trainable_model_params}") - print_info(f"All model parameters: {all_model_params}") - print_info( + r_print_info(f"Trainable model parameters: {trainable_model_params}") + r_print_info(f"All model parameters: {all_model_params}") + r_print_info( f"Percentage of trainable model parameters: {100 * trainable_model_params / all_model_params:.2f}%" ) @@ -1398,7 +1396,7 @@ def __init__( "bias": "none", "task_type": "CAUSAL_LM", } - print_info( + r_print_info( "LoRA configuration are not set. Using default parameters" + json.dumps(self.lora_config_params) ) @@ -1465,7 +1463,7 @@ def formatting_prompts_func(example): peft_config=LoraConfig(**self.lora_config_params), callbacks=[PGMLCallback(self.project_id, self.model_id)], ) - print_info("Creating Supervised Fine Tuning trainer done. Training ... ") + r_print_info("Creating Supervised Fine Tuning trainer done. Training ... ") # Train self.trainer.train() pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy