From e3bea27fd6410faed493f1afbd386b1beaffdfc5 Mon Sep 17 00:00:00 2001
From: Santi Adavani
Date: Wed, 31 Jan 2024 07:39:06 -0800
Subject: [PATCH 01/36] fine-tuning text classification in progress
---
.../src/bindings/transformers/mod.rs | 2 +-
.../src/bindings/transformers/transformers.py | 97 +++++++++++++++++++
2 files changed, 98 insertions(+), 1 deletion(-)
diff --git a/pgml-extension/src/bindings/transformers/mod.rs b/pgml-extension/src/bindings/transformers/mod.rs
index 6a4a2133e..aa38687b9 100644
--- a/pgml-extension/src/bindings/transformers/mod.rs
+++ b/pgml-extension/src/bindings/transformers/mod.rs
@@ -60,7 +60,7 @@ pub fn tune(task: &Task, dataset: TextDataset, hyperparams: &JsonB, path: &Path)
let hyperparams = serde_json::to_string(&hyperparams.0)?;
Python::with_gil(|py| -> Result<HashMap<String, f64>> {
- let tune = get_module!(PY_MODULE).getattr(py, "tune").format_traceback(py)?;
+ let tune = get_module!(PY_MODULE).getattr(py, "finetune").format_traceback(py)?;
let path = path.to_string_lossy();
let output = tune
.call1(
diff --git a/pgml-extension/src/bindings/transformers/transformers.py b/pgml-extension/src/bindings/transformers/transformers.py
index 9390cac44..ab02f58ac 100644
--- a/pgml-extension/src/bindings/transformers/transformers.py
+++ b/pgml-extension/src/bindings/transformers/transformers.py
@@ -46,6 +46,20 @@
PegasusTokenizer,
)
import threading
+import logging
+from rich.logging import RichHandler
+
+transformers.logging.set_verbosity_info()
+
+
+FORMAT = "%(message)s"
+logging.basicConfig(
+ level=os.environ.get("LOG_LEVEL", "INFO"),
+ format="%(asctime)s - %(message)s",
+ datefmt="[%X]",
+ handlers=[RichHandler()],
+)
+log = logging.getLogger("rich")
__cache_transformer_by_model_id = {}
__cache_sentence_transformer_by_name = {}
@@ -983,3 +997,86 @@ def generate(model_id, data, config):
)
all_preds.extend(decoded_preds)
return all_preds
+
+
+#######################
+# LLM Fine-Tuning
+#######################
+def finetune(task, hyperparams, path, x_train, x_test, y_train, y_test):
+ # Get model and tokenizer
+ hyperparams = orjson.loads(hyperparams)
+ model_name = hyperparams.pop("model_name")
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ classes = list(set(y_train))
+ num_classes = len(classes)
+ model = AutoModelForSequenceClassification.from_pretrained(
+ model_name, num_labels=num_classes
+ )
+ id2label = {}
+ label2id = {}
+ for id, label in enumerate(classes):
+ label2id[label] = float(id)
+ id2label[id] = label
+
+ model.config.id2label = id2label
+ model.config.label2id = label2id
+
+ y_train_label = [label2id[_class] for _class in y_train]
+ y_test_label = [label2id[_class] for _class in y_test]
+
+ # Prepare dataset
+ train_dataset = datasets.Dataset.from_dict(
+ {
+ "text": x_train,
+ "label": y_train_label,
+ }
+ )
+ test_dataset = datasets.Dataset.from_dict(
+ {
+ "text": x_test,
+ "label": y_test_label,
+ }
+ )
+ # tokenization function
+ def tokenize_function(example):
+ tokenized_example = tokenizer(
+ example["text"],
+ padding=True,
+ truncation=True,
+ return_tensors="pt"
+ )
+ return tokenized_example
+
+ # Generate tokens
+ train_tokenized_datasets = train_dataset.map(tokenize_function, batched=True)
+ test_tokenized_datasets = test_dataset.map(tokenize_function, batched=True)
+ log.info("Tokenization done")
+ log.info("Train dataset")
+ log.info(train_tokenized_datasets[0:2])
+ log.info("Test dataset")
+ log.info(test_tokenized_datasets[0:2])
+ # Data collator
+ data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
+
+ # Training Args
+ log.info("Training args setup started path=%s"%path)
+ training_args=TrainingArguments(output_dir="/tmp/postgresml/models/", **hyperparams)
+ log.info("Trainer setup done")
+ # Trainer
+ try:
+ trainer = Trainer(
+ model=model.to("cpu"),
+ args=training_args,
+ train_dataset=train_tokenized_datasets,
+ eval_dataset=test_tokenized_datasets,
+ tokenizer=tokenizer,
+ data_collator=data_collator,
+ )
+ except Exception as e:
+ log.error(e)
+ log.info("Training started")
+
+ # Train
+ trainer.train()
+ metrics = {"loss" : 0.0}
+ return metrics
\ No newline at end of file
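
A minimal sketch of how the new Python entry point expects its inputs, assuming an illustrative checkpoint and training settings (only "model_name" is consumed directly; the remaining keys are forwarded to Hugging Face TrainingArguments, and both the task string and the path are only logged at this stage):

import orjson

# hyperparams arrives as a JSON string; "model_name" is popped, the rest is
# unpacked into transformers.TrainingArguments.
hyperparams = orjson.dumps({
    "model_name": "distilbert-base-uncased",   # illustrative checkpoint
    "num_train_epochs": 1,
    "per_device_train_batch_size": 8,
}).decode()

metrics = finetune(
    "text_classification",                 # task string as passed from Rust (value illustrative)
    hyperparams,
    "/tmp/postgresml/models/1",             # path (only logged so far)
    ["great product", "terrible product"],  # x_train
    ["decent product"],                     # x_test
    ["positive", "negative"],               # y_train
    ["positive"],                           # y_test
)
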
From c4cf332115a1f240bf0308a23d9a3735f90a0eba Mon Sep 17 00:00:00 2001
From: Santi Adavani
Date: Wed, 31 Jan 2024 16:51:14 -0800
Subject: [PATCH 02/36] More commit messages
---
pgml-extension/src/bindings/transformers/transformers.py | 8 +++++++-
1 file changed, 7 insertions(+), 1 deletion(-)
diff --git a/pgml-extension/src/bindings/transformers/transformers.py b/pgml-extension/src/bindings/transformers/transformers.py
index ab02f58ac..9d8aa4f63 100644
--- a/pgml-extension/src/bindings/transformers/transformers.py
+++ b/pgml-extension/src/bindings/transformers/transformers.py
@@ -1060,9 +1060,15 @@ def tokenize_function(example):
# Training Args
log.info("Training args setup started path=%s"%path)
- training_args=TrainingArguments(output_dir="/tmp/postgresml/models/", **hyperparams)
+ training_args=TrainingArguments(output_dir="/tmp/postgresml/models/", logging_dir="/tmp/postgresml/runs", **hyperparams)
log.info("Trainer setup done")
# Trainer
+ log.info(model)
+ log.info(training_args)
+ log.info(train_tokenized_datasets)
+ log.info(test_tokenized_datasets)
+ log.info(tokenizer)
+ log.info(data_collator)
try:
trainer = Trainer(
model=model.to("cpu"),
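
For reference, logging_dir is the standard Hugging Face TrainingArguments field for TensorBoard-format event files, so with this change training logs land under /tmp/postgresml/runs while checkpoints stay under /tmp/postgresml/models/. A minimal sketch of the arguments this commit builds (any additional keys come from the user-supplied hyperparams):

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="/tmp/postgresml/models/",   # checkpoints and final model
    logging_dir="/tmp/postgresml/runs",     # TensorBoard event files
)
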
From fb7cc2ae6633efdf10b20194e469c38751e4574a Mon Sep 17 00:00:00 2001
From: Santi Adavani
Date: Mon, 5 Feb 2024 20:52:18 -0800
Subject: [PATCH 03/36] Working text classification with dataset args and
training args
---
pgml-extension/src/api.rs | 9 ++--
.../src/bindings/transformers/mod.rs | 39 +++++++++++---
.../src/bindings/transformers/transformers.py | 14 ++---
pgml-extension/src/orm/dataset.rs | 16 +++---
pgml-extension/src/orm/mod.rs | 2 +-
pgml-extension/src/orm/model.rs | 10 ++--
pgml-extension/src/orm/snapshot.rs | 53 +++++++++++++------
7 files changed, 96 insertions(+), 47 deletions(-)
diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs
index 7fd5012c8..bd2136be6 100644
--- a/pgml-extension/src/api.rs
+++ b/pgml-extension/src/api.rs
@@ -803,7 +803,7 @@ fn tune(
project_name: &str,
task: default!(Option<&str>, "NULL"),
relation_name: default!(Option<&str>, "NULL"),
- y_column_name: default!(Option<&str>, "NULL"),
+ _y_column_name: default!(Option<&str>, "NULL"),
model_name: default!(Option<&str>, "NULL"),
hyperparams: default!(JsonB, "'{}'"),
test_size: default!(f32, 0.25),
@@ -861,9 +861,7 @@ fn tune(
let snapshot = Snapshot::create(
relation_name,
- Some(vec![y_column_name
- .expect("You must pass a `y_column_name` when you pass a `relation_name`")
- .to_string()]),
+ None,
test_size,
test_sampling,
materialize_snapshot,
@@ -891,7 +889,7 @@ fn tune(
// let algorithm = Model.algorithm_from_name_and_task(algorithm, task);
// if "random_state" in algorithm().get_params() and "random_state" not in hyperparams:
// hyperparams["random_state"] = 0
- let model = Model::tune(&project, &mut snapshot, &hyperparams);
+ let model = Model::finetune(&project, &mut snapshot, &hyperparams);
let new_metrics: &serde_json::Value = &model.metrics.unwrap().0;
let new_metrics = new_metrics.as_object().unwrap();
@@ -947,6 +945,7 @@ fn tune(
)])
}
+
#[cfg(feature = "python")]
#[pg_extern(name = "sklearn_f1_score")]
pub fn sklearn_f1_score(ground_truth: Vec<f32>, y_hat: Vec<f32>) -> f32 {
diff --git a/pgml-extension/src/bindings/transformers/mod.rs b/pgml-extension/src/bindings/transformers/mod.rs
index aa38687b9..1b650cbbd 100644
--- a/pgml-extension/src/bindings/transformers/mod.rs
+++ b/pgml-extension/src/bindings/transformers/mod.rs
@@ -10,7 +10,7 @@ use pyo3::types::PyTuple;
use serde_json::Value;
use crate::create_pymodule;
-use crate::orm::{Task, TextDataset};
+use crate::orm::{Task, TextClassificationDataset};
use super::TracebackError;
@@ -55,7 +55,33 @@ pub fn embed(transformer: &str, inputs: Vec<&str>, kwargs: &serde_json::Value) -
})
}
-pub fn tune(task: &Task, dataset: TextDataset, hyperparams: &JsonB, path: &Path) -> Result<HashMap<String, f64>> {
+// pub fn tune(task: &Task, dataset: TextDatasetType, hyperparams: &JsonB, path: &Path) -> Result<HashMap<String, f64>> {
+// let task = task.to_string();
+// let hyperparams = serde_json::to_string(&hyperparams.0)?;
+
+// Python::with_gil(|py| -> Result<HashMap<String, f64>> {
+// let tune = get_module!(PY_MODULE).getattr(py, "finetune").format_traceback(py)?;
+// let path = path.to_string_lossy();
+// let output = tune
+// .call1(
+// py,
+// (
+// &task,
+// &hyperparams,
+// path.as_ref(),
+// dataset.x_train,
+// dataset.x_test,
+// dataset.y_train,
+// dataset.y_test,
+// ),
+// )
+// .format_traceback(py)?;
+
+// output.extract(py).format_traceback(py)
+// })
+// }
+
+pub fn finetune(task: &Task, dataset: TextClassificationDataset, hyperparams: &JsonB, path: &Path) -> Result<HashMap<String, f64>> {
let task = task.to_string();
let hyperparams = serde_json::to_string(&hyperparams.0)?;
@@ -69,10 +95,10 @@ pub fn tune(task: &Task, dataset: TextDataset, hyperparams: &JsonB, path: &Path)
&task,
&hyperparams,
path.as_ref(),
- dataset.x_train,
- dataset.x_test,
- dataset.y_train,
- dataset.y_test,
+ dataset.text_train,
+ dataset.text_test,
+ dataset.class_train,
+ dataset.class_test,
),
)
.format_traceback(py)?;
@@ -80,7 +106,6 @@ pub fn tune(task: &Task, dataset: TextDataset, hyperparams: &JsonB, path: &Path)
output.extract(py).format_traceback(py)
})
}
-
pub fn generate(model_id: i64, inputs: Vec<&str>, config: JsonB) -> Result<Vec<String>> {
Python::with_gil(|py| -> Result<Vec<String>> {
let generate = get_module!(PY_MODULE).getattr(py, "generate").format_traceback(py)?;
diff --git a/pgml-extension/src/bindings/transformers/transformers.py b/pgml-extension/src/bindings/transformers/transformers.py
index 9d8aa4f63..e1ec38715 100644
--- a/pgml-extension/src/bindings/transformers/transformers.py
+++ b/pgml-extension/src/bindings/transformers/transformers.py
@@ -1009,15 +1009,17 @@ def finetune(task, hyperparams, path, x_train, x_test, y_train, y_test):
tokenizer = AutoTokenizer.from_pretrained(model_name)
classes = list(set(y_train))
num_classes = len(classes)
- model = AutoModelForSequenceClassification.from_pretrained(
- model_name, num_labels=num_classes
- )
+
id2label = {}
label2id = {}
for id, label in enumerate(classes):
- label2id[label] = float(id)
+ label2id[label] = id
id2label[id] = label
-
+
+ model = AutoModelForSequenceClassification.from_pretrained(
+ model_name, num_labels=num_classes, id2label=id2label, label2id=label2id
+ )
+
model.config.id2label = id2label
model.config.label2id = label2id
@@ -1060,7 +1062,7 @@ def tokenize_function(example):
# Training Args
log.info("Training args setup started path=%s"%path)
- training_args=TrainingArguments(output_dir="/tmp/postgresml/models/", logging_dir="/tmp/postgresml/runs", **hyperparams)
+ training_args=TrainingArguments(output_dir="/tmp/postgresml/models/", logging_dir="/tmp/postgresml/runs", **hyperparams["training_args"])
log.info("Trainer setup done")
# Trainer
log.info(model)
diff --git a/pgml-extension/src/orm/dataset.rs b/pgml-extension/src/orm/dataset.rs
index 062886a5c..ea56ea19c 100644
--- a/pgml-extension/src/orm/dataset.rs
+++ b/pgml-extension/src/orm/dataset.rs
@@ -68,12 +68,12 @@ impl Dataset {
}
}
-#[derive(Debug)]
-pub struct TextDataset {
- pub x_train: Vec<String>,
- pub y_train: Vec<String>,
- pub x_test: Vec<String>,
- pub y_test: Vec<String>,
+// TextClassificationDataset
+pub struct TextClassificationDataset {
+ pub text_train: Vec<String>,
+ pub class_train: Vec<String>,
+ pub text_test: Vec<String>,
+ pub class_test: Vec<String>,
pub num_features: usize,
pub num_labels: usize,
pub num_rows: usize,
@@ -82,11 +82,11 @@ pub struct TextDataset {
pub num_distinct_labels: usize,
}
-impl Display for TextDataset {
+impl Display for TextClassificationDataset {
fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
write!(
f,
- "TextDataset {{ num_features: {}, num_labels: {}, num_distinct_labels: {}, num_rows: {}, num_train_rows: {}, num_test_rows: {} }}",
+ "TextClassificationDataset {{ num_features: {}, num_labels: {}, num_distinct_labels: {}, num_rows: {}, num_train_rows: {}, num_test_rows: {} }}",
self.num_features, self.num_labels, self.num_distinct_labels, self.num_rows, self.num_train_rows, self.num_test_rows,
)
}
diff --git a/pgml-extension/src/orm/mod.rs b/pgml-extension/src/orm/mod.rs
index abe00f1c1..b67cd748c 100644
--- a/pgml-extension/src/orm/mod.rs
+++ b/pgml-extension/src/orm/mod.rs
@@ -13,7 +13,7 @@ pub mod task;
pub use algorithm::Algorithm;
pub use dataset::Dataset;
-pub use dataset::TextDataset;
+pub use dataset::TextClassificationDataset;
pub use model::Model;
pub use project::Project;
pub use runtime::Runtime;
diff --git a/pgml-extension/src/orm/model.rs b/pgml-extension/src/orm/model.rs
index 5d1aadbde..6799aecd2 100644
--- a/pgml-extension/src/orm/model.rs
+++ b/pgml-extension/src/orm/model.rs
@@ -157,10 +157,14 @@ impl Model {
model
}
+
#[allow(clippy::too_many_arguments)]
- pub fn tune(project: &Project, snapshot: &mut Snapshot, hyperparams: &JsonB) -> Model {
+ pub fn finetune(project: &Project, snapshot: &mut Snapshot, hyperparams: &JsonB) -> Model {
let mut model: Option<Model> = None;
- let dataset = snapshot.text_dataset();
+
+ let dataset_args = JsonB(json!(hyperparams.0.get("dataset_args").unwrap()));
+
+ let dataset = snapshot.text_classification_dataset(dataset_args);
// Create the model record.
Spi::connect(|mut client| {
@@ -211,7 +215,7 @@ impl Model {
let path = std::path::PathBuf::from(format!("/tmp/postgresml/models/{id}"));
info!("Tuning {}", model);
- let metrics = match transformers::tune(&project.task, dataset, &model.hyperparams, &path) {
+ let metrics = match transformers::finetune(&project.task, dataset, &model.hyperparams, &path) {
Ok(metrics) => metrics,
Err(e) => error!("{e}"),
};
diff --git a/pgml-extension/src/orm/snapshot.rs b/pgml-extension/src/orm/snapshot.rs
index 402dff976..c21a3342f 100644
--- a/pgml-extension/src/orm/snapshot.rs
+++ b/pgml-extension/src/orm/snapshot.rs
@@ -11,7 +11,8 @@ use serde_json::json;
use crate::orm::Sampling;
use crate::orm::Status;
-use crate::orm::{Dataset, TextDataset};
+use crate::orm::{Dataset, TextClassificationDataset};
+
// Categories use a designated string to represent NULL categorical values,
// rather than Option = None, because the JSONB serialization schema
@@ -773,7 +774,7 @@ impl Snapshot {
(num_train_rows, num_test_rows)
}
- pub fn text_dataset(&mut self) -> TextDataset {
+ pub fn text_classification_dataset(&mut self, dataset_args: default!(JsonB, "'{}'")) -> TextClassificationDataset {
let mut data = None;
Spi::connect(|client| {
@@ -783,23 +784,41 @@ impl Snapshot {
let num_features = self.num_features();
let num_labels = self.num_labels();
- let mut x_train: Vec<String> = Vec::with_capacity(num_train_rows * num_features);
- let mut y_train: Vec<String> = Vec::with_capacity(num_train_rows * num_labels);
- let mut x_test: Vec<String> = Vec::with_capacity(num_test_rows * num_features);
- let mut y_test: Vec<String> = Vec::with_capacity(num_test_rows * num_labels);
+ let mut text_train: Vec<String> = Vec::with_capacity(num_train_rows);
+ let mut class_train: Vec<String> = Vec::with_capacity(num_train_rows);
+ let mut text_test: Vec<String> = Vec::with_capacity(num_test_rows);
+ let mut class_test: Vec<String> = Vec::with_capacity(num_test_rows);
+
+ let class_column_value = dataset_args.0
+ .get("class_column")
+ .and_then(|v| v.as_str())
+ .map(|s| s.to_string())
+ .unwrap_or_else(|| "class".to_string());
+
+ let text_column_value = dataset_args.0
+ .get("text_column")
+ .and_then(|v| v.as_str())
+ .map(|s| s.to_string())
+ .unwrap_or_else(|| "text".to_string());
result.enumerate().for_each(|(i, row)| {
for column in &mut self.columns {
- let vector = if column.label {
+ let vector = if column.name == class_column_value {
if i < num_train_rows {
- &mut y_train
+ &mut class_train
} else {
- &mut y_test
+ &mut class_test
+ }
+ } else if column.name == text_column_value {
+ if i < num_train_rows {
+ &mut text_train
+ } else {
+ &mut text_test
}
- } else if i < num_train_rows {
- &mut x_train
} else {
- &mut x_test
+ // Handle the case when neither "class_column" nor "text_column" is present
+ // You might want to provide a default value or raise an error.
+ panic!("Neither 'class_column' nor 'text_column' found in dataset_args");
};
match column.pg_type.as_str() {
@@ -812,11 +831,11 @@ impl Snapshot {
}
});
- data = Some(TextDataset {
- x_train,
- y_train,
- x_test,
- y_test,
+ data = Some(TextClassificationDataset {
+ text_train,
+ class_train,
+ text_test,
+ class_test,
num_features,
num_labels,
num_rows,
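
Taken together, this commit changes the shape of the hyperparams document that the tune path expects: Model::finetune reads a nested "dataset_args" object (text_column/class_column select the snapshot columns, defaulting to "text" and "class"), while the Python side unpacks "training_args" into TrainingArguments. A sketch of that shape, with illustrative values:

# Hypothetical hyperparams layout after this commit; keys under "training_args"
# are standard Hugging Face TrainingArguments fields, values are placeholders.
hyperparams = {
    "model_name": "distilbert-base-uncased",
    "dataset_args": {
        "text_column": "text",     # defaults to "text" if omitted
        "class_column": "class",   # defaults to "class" if omitted
    },
    "training_args": {
        "learning_rate": 2e-5,
        "num_train_epochs": 1,
        "per_device_train_batch_size": 8,
    },
}
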
From 55844878dd1ac54f1575678095c3cf6c706d3442 Mon Sep 17 00:00:00 2001
From: Santi Adavani
Date: Tue, 6 Feb 2024 17:27:35 -0800
Subject: [PATCH 04/36] finetuning with text dataset enum to handle different
tasks
---
.../src/bindings/transformers/mod.rs | 4 +--
.../src/bindings/transformers/transformers.py | 2 +-
pgml-extension/src/orm/dataset.rs | 12 ++++++++
pgml-extension/src/orm/mod.rs | 1 +
pgml-extension/src/orm/model.rs | 28 +++++++++++++++----
5 files changed, 39 insertions(+), 8 deletions(-)
diff --git a/pgml-extension/src/bindings/transformers/mod.rs b/pgml-extension/src/bindings/transformers/mod.rs
index 1b650cbbd..853a3d436 100644
--- a/pgml-extension/src/bindings/transformers/mod.rs
+++ b/pgml-extension/src/bindings/transformers/mod.rs
@@ -81,12 +81,12 @@ pub fn embed(transformer: &str, inputs: Vec<&str>, kwargs: &serde_json::Value) -
// })
// }
-pub fn finetune(task: &Task, dataset: TextClassificationDataset, hyperparams: &JsonB, path: &Path) -> Result<HashMap<String, f64>> {
+pub fn finetune_text_classification(task: &Task, dataset: TextClassificationDataset, hyperparams: &JsonB, path: &Path) -> Result<HashMap<String, f64>> {
let task = task.to_string();
let hyperparams = serde_json::to_string(&hyperparams.0)?;
Python::with_gil(|py| -> Result<HashMap<String, f64>> {
- let tune = get_module!(PY_MODULE).getattr(py, "finetune").format_traceback(py)?;
+ let tune = get_module!(PY_MODULE).getattr(py, "finetune_text_classification").format_traceback(py)?;
let path = path.to_string_lossy();
let output = tune
.call1(
diff --git a/pgml-extension/src/bindings/transformers/transformers.py b/pgml-extension/src/bindings/transformers/transformers.py
index e1ec38715..9dce8a9ed 100644
--- a/pgml-extension/src/bindings/transformers/transformers.py
+++ b/pgml-extension/src/bindings/transformers/transformers.py
@@ -1002,7 +1002,7 @@ def generate(model_id, data, config):
#######################
# LLM Fine-Tuning
#######################
-def finetune(task, hyperparams, path, x_train, x_test, y_train, y_test):
+def finetune_text_classification(task, hyperparams, path, x_train, x_test, y_train, y_test):
# Get model and tokenizer
hyperparams = orjson.loads(hyperparams)
model_name = hyperparams.pop("model_name")
diff --git a/pgml-extension/src/orm/dataset.rs b/pgml-extension/src/orm/dataset.rs
index ea56ea19c..ce165acba 100644
--- a/pgml-extension/src/orm/dataset.rs
+++ b/pgml-extension/src/orm/dataset.rs
@@ -92,6 +92,18 @@ impl Display for TextClassificationDataset {
}
}
+pub enum TextDatasetType {
+ TextClassification(TextClassificationDataset),
+}
+
+impl TextDatasetType {
+ pub fn num_features(&self) -> usize {
+ match self {
+ TextDatasetType::TextClassification(dataset) => dataset.num_features,
+ }
+ }
+}
+
fn drop_table_if_exists(table_name: &str) {
// Avoid the existence for DROP TABLE IF EXISTS warning by checking the schema for the table first
let table_count = Spi::get_one_with_args::<i64>(
diff --git a/pgml-extension/src/orm/mod.rs b/pgml-extension/src/orm/mod.rs
index b67cd748c..c41306afe 100644
--- a/pgml-extension/src/orm/mod.rs
+++ b/pgml-extension/src/orm/mod.rs
@@ -13,6 +13,7 @@ pub mod task;
pub use algorithm::Algorithm;
pub use dataset::Dataset;
+pub use dataset::TextDatasetType;
pub use dataset::TextClassificationDataset;
pub use model::Model;
pub use project::Project;
diff --git a/pgml-extension/src/orm/model.rs b/pgml-extension/src/orm/model.rs
index 6799aecd2..ff8a3c1e8 100644
--- a/pgml-extension/src/orm/model.rs
+++ b/pgml-extension/src/orm/model.rs
@@ -164,7 +164,12 @@ impl Model {
let dataset_args = JsonB(json!(hyperparams.0.get("dataset_args").unwrap()));
- let dataset = snapshot.text_classification_dataset(dataset_args);
+ // let dataset = snapshot.text_classification_dataset(dataset_args);
+ let dataset = if project.task == Task::text_classification {
+ TextDatasetType::TextClassification(snapshot.text_classification_dataset(dataset_args))
+ } else {
+ TextDatasetType::TextClassification(snapshot.text_classification_dataset(dataset_args))
+ };
// Create the model record.
Spi::connect(|mut client| {
@@ -183,7 +188,7 @@ impl Model {
(PgBuiltInOids::TEXTOID.oid(), None::