Content-Length: 619880 | pFad | http://github.com/postgresml/postgresml/pull/1350/files

01 LLM fine-tuning by santiatpml · Pull Request #1350 · postgresml/postgresml · GitHub
Skip to content

LLM fine-tuning #1350

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 36 commits into from
Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
e3bea27
fine-tuning text classification in progress
santiatpml Jan 31, 2024
c4cf332
More commit messages
santiatpml Feb 1, 2024
fb7cc2a
Working text classification with dataset args and training args
santiatpml Feb 6, 2024
5584487
finetuing with text dataset enum to handle different tasks
santiatpml Feb 7, 2024
82cb4f7
text pair classification task support
santiatpml Feb 7, 2024
c10de47
saving model after training
santiatpml Feb 7, 2024
63ee09b
removed device to cpu
santiatpml Feb 7, 2024
865ae28
updated transforemrs
Feb 8, 2024
097a8cf
Working e2e finetunig for two tasks
Feb 8, 2024
2dd50e6
Integration with huggingface hub and wandb
Feb 9, 2024
6ac8722
Conversation dataset + training placeholder
Feb 13, 2024
1e40cd8
Updated rust to fix failing tests
Feb 13, 2024
312d893
working version of conversation with lora + load 8bit + hf hub
Feb 13, 2024
afc2e93
Tested llama2-7b finetuning
Feb 22, 2024
22ee5c7
pypgrx first working version
Feb 27, 2024
97d455d
refactoring finetuning code to add callbacks
santiatpml Feb 27, 2024
b700944
fixed merge conflicts
santiatpml Mar 5, 2024
65d2f8b
Refactored finetuning + conversation + pgml callbacks
Mar 2, 2024
5f1b5f4
removed wandb dependency
Mar 4, 2024
08084bf
removed local pypgrx from requirements
Mar 4, 2024
dc0c6ee
removed maturin from requirements
Mar 4, 2024
421af8f
removed flash attn
Mar 4, 2024
4bbca96
Added indent for info display
Mar 5, 2024
3db857c
Updated readme with LLM fine-tuning for text classification
santiatpml Mar 7, 2024
7cbee43
README updates
santiatpml Mar 7, 2024
9284cf1
Added a tutorial for 9 classes - draft 1
santiatpml Mar 8, 2024
66c65c8
README updates
santiatpml Mar 8, 2024
5759ee3
Moved python functions (#1374)
SilasMarvin Mar 18, 2024
b539168
README updates
santiatpml Mar 19, 2024
31215b8
migrations and removed pypgrx
santiatpml Mar 20, 2024
dae6b74
Added r_log to take log level and message
santiatpml Mar 20, 2024
dae5ffc
Updated version and requirements
Mar 22, 2024
435f5bd
Changed version 2.8.3
Mar 22, 2024
aeb2683
README updates for conversation task fine-tuning using lora
santiatpml Mar 22, 2024
e5221cc
minor readme updates
santiatpml Mar 26, 2024
6db147e
added new line
santiatpml Mar 26, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
737 changes: 737 additions & 0 deletions README.md

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pgml-extension/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,5 @@
.DS_Store


# venv
pgml-venv
56 changes: 38 additions & 18 deletions pgml-extension/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions pgml-extension/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "pgml"
version = "2.8.2"
version = "2.8.3"
edition = "2021"

[lib]
Expand Down Expand Up @@ -39,8 +39,8 @@ openblas-src = { version = "0.10", features = ["cblas", "system"] }
ndarray = { version = "0.15.6", features = ["serde", "blas"] }
ndarray-stats = "0.5.1"
parking_lot = "0.12"
pgrx = "=0.11.2"
pgrx-pg-sys = "=0.11.2"
pgrx = "=0.11.3"
pgrx-pg-sys = "=0.11.3"
pyo3 = { version = "0.20.0", features = ["auto-initialize"], optional = true }
rand = "0.8"
rmp-serde = { version = "1.1" }
Expand All @@ -51,7 +51,7 @@ typetag = "0.2"
xgboost = { git = "https://github.com/postgresml/rust-xgboost", branch = "master" }

[dev-dependencies]
pgrx-tests = "=0.11.2"
pgrx-tests = "=0.11.3"

[build-dependencies]
vergen = { version = "8", features = ["build", "git", "gitcl"] }
Expand Down
25 changes: 23 additions & 2 deletions pgml-extension/requirements.linux.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
accelerate==0.25.0
accelerate==0.27.2
aiohttp==3.9.1
aiosignal==1.3.1
annotated-types==0.6.0
anyio==4.2.0
appdirs==1.4.4
async-timeout==4.0.3
attrs==23.1.0
auto-gptq==0.6.0
bitsandbytes==0.41.3.post2
black==24.1.1
catboost==1.2.2
certifi==2023.11.17
charset-normalizer==3.3.2
Expand All @@ -20,13 +22,18 @@ dataclasses-json==0.6.3
datasets==2.15.0
deepspeed==0.12.5
dill==0.3.7
docker-pycreds==0.4.0
docstring-parser==0.15
einops==0.7.0
evaluate==0.4.1
exceptiongroup==1.2.0
filelock==3.13.1
fonttools==4.47.0
frozenlist==1.4.1
fsspec==2023.10.0
gekko==1.0.6
gitdb==4.0.11
GitPython==3.1.41
graphviz==0.20.1
greenlet==3.0.2
hjson==3.1.0
Expand All @@ -45,9 +52,11 @@ langchain-core==0.1.1
langsmith==0.0.72
lightgbm==4.1.0
lxml==4.9.3
markdown-it-py==3.0.0
MarkupSafe==2.1.3
marshmallow==3.20.1
matplotlib==3.8.2
mdurl==0.1.2
mpmath==1.3.0
multidict==6.0.4
multiprocess==0.70.15
Expand All @@ -72,8 +81,10 @@ optimum==1.16.1
orjson==3.9.10
packaging==23.2
pandas==2.1.4
pathspec==0.12.1
peft==0.7.1
Pillow==10.1.0
platformdirs==4.2.0
plotly==5.18.0
portalocker==2.8.2
protobuf==4.25.1
Expand All @@ -83,13 +94,16 @@ pyarrow==11.0.0
pyarrow-hotfix==0.6
pydantic==2.5.2
pydantic_core==2.14.5
Pygments==2.17.2
pynvml==11.5.0
pyparsing==3.1.1
python-dateutil==2.8.2
pytz==2023.3.post1
PyYAML==6.0.1
regex==2023.10.3
requests==2.31.0
responses==0.18.0
rich==13.7.1
rouge==1.0.1
sacrebleu==2.4.0
sacremoses==0.1.1
Expand All @@ -98,23 +112,30 @@ scikit-learn==1.3.2
scipy==1.11.4
sentence-transformers==2.2.2
sentencepiece==0.1.99
sentry-sdk==1.40.2
setproctitle==1.3.3
shtab==1.6.5
six==1.16.0
smmap==5.0.1
sniffio==1.3.0
SQLAlchemy==2.0.23
sympy==1.12
tabulate==0.9.0
tenacity==8.2.3
threadpoolctl==3.2.0
tokenizers==0.15.0
tomli==2.0.1
torch==2.1.2
torchaudio==2.1.2
torchvision==0.16.2
tqdm==4.66.1
transformers==4.38.0
transformers==4.38.2
transformers-stream-generator==0.0.4
triton==2.1.0
trl==0.7.10
typing-inspect==0.9.0
typing_extensions==4.9.0
tyro==0.7.2
tzdata==2023.3
urllib3==2.1.0
xformers==0.0.23.post1
Expand Down
12 changes: 12 additions & 0 deletions pgml-extension/sql/pgml--2.8.2--2.8.3.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
-- Add conversation, text-pair-classification task type
ALTER TYPE pgml.task ADD VALUE IF NOT EXISTS 'conversation';
ALTER TYPE pgml.task ADD VALUE IF NOT EXISTS 'text-pair-classification';

-- Crate pgml.logs table
CREATE TABLE IF NOT EXISTS pgml.logs (
id SERIAL PRIMARY KEY,
model_id BIGINT,
project_id BIGINT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
logs JSONB
);
34 changes: 17 additions & 17 deletions pgml-extension/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -803,7 +803,7 @@ fn tune(
project_name: &str,
task: default!(Option<&str>, "NULL"),
relation_name: default!(Option<&str>, "NULL"),
y_column_name: default!(Option<&str>, "NULL"),
_y_column_name: default!(Option<&str>, "NULL"),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why the underscore? Is it because it's not used?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's correct.

model_name: default!(Option<&str>, "NULL"),
hyperparams: default!(JsonB, "'{}'"),
test_size: default!(f32, 0.25),
Expand Down Expand Up @@ -861,9 +861,7 @@ fn tune(

let snapshot = Snapshot::create(
relation_name,
Some(vec![y_column_name
.expect("You must pass a `y_column_name` when you pass a `relation_name`")
.to_string()]),
None,
test_size,
test_sampling,
materialize_snapshot,
Expand All @@ -885,13 +883,14 @@ fn tune(
// algorithm will be transformers, stash the model_name in a hyperparam for v1 compatibility.
let mut hyperparams = hyperparams.0.as_object().unwrap().clone();
hyperparams.insert(String::from("model_name"), json!(model_name));
hyperparams.insert(String::from("project_name"), json!(project_name));
let hyperparams = JsonB(json!(hyperparams));

// # Default repeatable random state when possible
// let algorithm = Model.algorithm_from_name_and_task(algorithm, task);
// if "random_state" in algorithm().get_params() and "random_state" not in hyperparams:
// hyperparams["random_state"] = 0
let model = Model::tune(&project, &mut snapshot, &hyperparams);
let model = Model::finetune(&project, &mut snapshot, &hyperparams);
let new_metrics: &serde_json::Value = &model.metrics.unwrap().0;
let new_metrics = new_metrics.as_object().unwrap();

Expand All @@ -915,18 +914,19 @@ fn tune(
Some(true) | None => {
if let Ok(Some(deployed_metrics)) = deployed_metrics {
let deployed_metrics = deployed_metrics.0.as_object().unwrap();
if project.task.value_is_better(
deployed_metrics
.get(&project.task.default_target_metric())
.unwrap()
.as_f64()
.unwrap(),
new_metrics
.get(&project.task.default_target_metric())
.unwrap()
.as_f64()
.unwrap(),
) {

let deployed_value = deployed_metrics
.get(&project.task.default_target_metric())
.and_then(|value| value.as_f64())
.unwrap_or_default(); // Default to 0.0 if the key is not present or conversion fails

// Get the value for the default target metric from new_metrics or provide a default value
let new_value = new_metrics
.get(&project.task.default_target_metric())
.and_then(|value| value.as_f64())
.unwrap_or_default(); // Default to 0.0 if the key is not present or conversion fails

if project.task.value_is_better(deployed_value, new_value) {
deploy = false;
}
}
Expand Down
Loading








ApplySandwichStrip

pFad - (p)hone/(F)rame/(a)nonymizer/(d)eclutterfier!      Saves Data!


--- a PPN by Garber Painting Akron. With Image Size Reduction included!

Fetched URL: http://github.com/postgresml/postgresml/pull/1350/files

Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy