Content-Length: 716787 | pFad | http://github.com/postgresml/postgresml/commit/f75114baa668944a3d31a1fef7fd24c31cfc71ff

AB LLM fine-tuning (#1350) · postgresml/postgresml@f75114b · GitHub
Skip to content

Commit f75114b

Browse files
authored
LLM fine-tuning (#1350)
1 parent 790e4f9 commit f75114b

File tree

15 files changed

+1993
-105
lines changed

15 files changed

+1993
-105
lines changed

README.md

Lines changed: 737 additions & 0 deletions
Large diffs are not rendered by default.

pgml-extension/.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,5 @@
1414
.DS_Store
1515

1616

17+
# venv
18+
pgml-venv

pgml-extension/Cargo.lock

Lines changed: 38 additions & 18 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pgml-extension/Cargo.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "pgml"
3-
version = "2.8.2"
3+
version = "2.8.3"
44
edition = "2021"
55

66
[lib]
@@ -39,8 +39,8 @@ openblas-src = { version = "0.10", features = ["cblas", "system"] }
3939
ndarray = { version = "0.15.6", features = ["serde", "blas"] }
4040
ndarray-stats = "0.5.1"
4141
parking_lot = "0.12"
42-
pgrx = "=0.11.2"
43-
pgrx-pg-sys = "=0.11.2"
42+
pgrx = "=0.11.3"
43+
pgrx-pg-sys = "=0.11.3"
4444
pyo3 = { version = "0.20.0", features = ["auto-initialize"], optional = true }
4545
rand = "0.8"
4646
rmp-serde = { version = "1.1" }
@@ -51,7 +51,7 @@ typetag = "0.2"
5151
xgboost = { git = "https://github.com/postgresml/rust-xgboost", branch = "master" }
5252

5353
[dev-dependencies]
54-
pgrx-tests = "=0.11.2"
54+
pgrx-tests = "=0.11.3"
5555

5656
[build-dependencies]
5757
vergen = { version = "8", features = ["build", "git", "gitcl"] }

pgml-extension/requirements.linux.txt

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1-
accelerate==0.25.0
1+
accelerate==0.27.2
22
aiohttp==3.9.1
33
aiosignal==1.3.1
44
annotated-types==0.6.0
55
anyio==4.2.0
6+
appdirs==1.4.4
67
async-timeout==4.0.3
78
attrs==23.1.0
89
auto-gptq==0.6.0
910
bitsandbytes==0.41.3.post2
11+
black==24.1.1
1012
catboost==1.2.2
1113
certifi==2023.11.17
1214
charset-normalizer==3.3.2
@@ -20,13 +22,18 @@ dataclasses-json==0.6.3
2022
datasets==2.15.0
2123
deepspeed==0.12.5
2224
dill==0.3.7
25+
docker-pycreds==0.4.0
26+
docstring-parser==0.15
2327
einops==0.7.0
28+
evaluate==0.4.1
2429
exceptiongroup==1.2.0
2530
filelock==3.13.1
2631
fonttools==4.47.0
2732
frozenlist==1.4.1
2833
fsspec==2023.10.0
2934
gekko==1.0.6
35+
gitdb==4.0.11
36+
GitPython==3.1.41
3037
graphviz==0.20.1
3138
greenlet==3.0.2
3239
hjson==3.1.0
@@ -45,9 +52,11 @@ langchain-core==0.1.1
4552
langsmith==0.0.72
4653
lightgbm==4.1.0
4754
lxml==4.9.3
55+
markdown-it-py==3.0.0
4856
MarkupSafe==2.1.3
4957
marshmallow==3.20.1
5058
matplotlib==3.8.2
59+
mdurl==0.1.2
5160
mpmath==1.3.0
5261
multidict==6.0.4
5362
multiprocess==0.70.15
@@ -72,8 +81,10 @@ optimum==1.16.1
7281
orjson==3.9.10
7382
packaging==23.2
7483
pandas==2.1.4
84+
pathspec==0.12.1
7585
peft==0.7.1
7686
Pillow==10.1.0
87+
platformdirs==4.2.0
7788
plotly==5.18.0
7889
portalocker==2.8.2
7990
protobuf==4.25.1
@@ -83,13 +94,16 @@ pyarrow==11.0.0
8394
pyarrow-hotfix==0.6
8495
pydantic==2.5.2
8596
pydantic_core==2.14.5
97+
Pygments==2.17.2
8698
pynvml==11.5.0
8799
pyparsing==3.1.1
88100
python-dateutil==2.8.2
89101
pytz==2023.3.post1
90102
PyYAML==6.0.1
91103
regex==2023.10.3
92104
requests==2.31.0
105+
responses==0.18.0
106+
rich==13.7.1
93107
rouge==1.0.1
94108
sacrebleu==2.4.0
95109
sacremoses==0.1.1
@@ -98,23 +112,30 @@ scikit-learn==1.3.2
98112
scipy==1.11.4
99113
sentence-transformers==2.5.1
100114
sentencepiece==0.1.99
115+
sentry-sdk==1.40.2
116+
setproctitle==1.3.3
117+
shtab==1.6.5
101118
six==1.16.0
119+
smmap==5.0.1
102120
sniffio==1.3.0
103121
SQLAlchemy==2.0.23
104122
sympy==1.12
105123
tabulate==0.9.0
106124
tenacity==8.2.3
107125
threadpoolctl==3.2.0
108126
tokenizers==0.15.0
127+
tomli==2.0.1
109128
torch==2.1.2
110129
torchaudio==2.1.2
111130
torchvision==0.16.2
112131
tqdm==4.66.1
113132
transformers==4.38.2
114133
transformers-stream-generator==0.0.4
115134
triton==2.1.0
135+
trl==0.7.10
116136
typing-inspect==0.9.0
117137
typing_extensions==4.9.0
138+
tyro==0.7.2
118139
tzdata==2023.3
119140
urllib3==2.1.0
120141
xformers==0.0.23.post1
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
-- Add conversation, text-pair-classification task type
2+
ALTER TYPE pgml.task ADD VALUE IF NOT EXISTS 'conversation';
3+
ALTER TYPE pgml.task ADD VALUE IF NOT EXISTS 'text-pair-classification';
4+
5+
-- Crate pgml.logs table
6+
CREATE TABLE IF NOT EXISTS pgml.logs (
7+
id SERIAL PRIMARY KEY,
8+
model_id BIGINT,
9+
project_id BIGINT,
10+
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
11+
logs JSONB
12+
);

pgml-extension/src/api.rs

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -816,7 +816,7 @@ fn tune(
816816
project_name: &str,
817817
task: default!(Option<&str>, "NULL"),
818818
relation_name: default!(Option<&str>, "NULL"),
819-
y_column_name: default!(Option<&str>, "NULL"),
819+
_y_column_name: default!(Option<&str>, "NULL"),
820820
model_name: default!(Option<&str>, "NULL"),
821821
hyperparams: default!(JsonB, "'{}'"),
822822
test_size: default!(f32, 0.25),
@@ -874,9 +874,7 @@ fn tune(
874874

875875
let snapshot = Snapshot::create(
876876
relation_name,
877-
Some(vec![y_column_name
878-
.expect("You must pass a `y_column_name` when you pass a `relation_name`")
879-
.to_string()]),
877+
None,
880878
test_size,
881879
test_sampling,
882880
materialize_snapshot,
@@ -898,13 +896,14 @@ fn tune(
898896
// algorithm will be transformers, stash the model_name in a hyperparam for v1 compatibility.
899897
let mut hyperparams = hyperparams.0.as_object().unwrap().clone();
900898
hyperparams.insert(String::from("model_name"), json!(model_name));
899+
hyperparams.insert(String::from("project_name"), json!(project_name));
901900
let hyperparams = JsonB(json!(hyperparams));
902901

903902
// # Default repeatable random state when possible
904903
// let algorithm = Model.algorithm_from_name_and_task(algorithm, task);
905904
// if "random_state" in algorithm().get_params() and "random_state" not in hyperparams:
906905
// hyperparams["random_state"] = 0
907-
let model = Model::tune(&project, &mut snapshot, &hyperparams);
906+
let model = Model::finetune(&project, &mut snapshot, &hyperparams);
908907
let new_metrics: &serde_json::Value = &model.metrics.unwrap().0;
909908
let new_metrics = new_metrics.as_object().unwrap();
910909

@@ -928,18 +927,19 @@ fn tune(
928927
Some(true) | None => {
929928
if let Ok(Some(deployed_metrics)) = deployed_metrics {
930929
let deployed_metrics = deployed_metrics.0.as_object().unwrap();
931-
if project.task.value_is_better(
932-
deployed_metrics
933-
.get(&project.task.default_target_metric())
934-
.unwrap()
935-
.as_f64()
936-
.unwrap(),
937-
new_metrics
938-
.get(&project.task.default_target_metric())
939-
.unwrap()
940-
.as_f64()
941-
.unwrap(),
942-
) {
930+
931+
let deployed_value = deployed_metrics
932+
.get(&project.task.default_target_metric())
933+
.and_then(|value| value.as_f64())
934+
.unwrap_or_default(); // Default to 0.0 if the key is not present or conversion fails
935+
936+
// Get the value for the default target metric from new_metrics or provide a default value
937+
let new_value = new_metrics
938+
.get(&project.task.default_target_metric())
939+
.and_then(|value| value.as_f64())
940+
.unwrap_or_default(); // Default to 0.0 if the key is not present or conversion fails
941+
942+
if project.task.value_is_better(deployed_value, new_value) {
943943
deploy = false;
944944
}
945945
}

0 commit comments

Comments
 (0)








ApplySandwichStrip

pFad - (p)hone/(F)rame/(a)nonymizer/(d)eclutterfier!      Saves Data!


--- a PPN by Garber Painting Akron. With Image Size Reduction included!

Fetched URL: http://github.com/postgresml/postgresml/commit/f75114baa668944a3d31a1fef7fd24c31cfc71ff

Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy