Skip to content

Commit 8cd4298

Browse files
authored
Fix runtime.pai submitter (#3073)
* fix runtime.pai submitter
* wip
* fix runtime pai submitter
* fix pre-commit
* test ci
* update
* fix ci
* fix ci
* update
* fix precommit
* fix precommit
* test ci
* fix ci
* update
* fix ci
* update python tests
* update
* fix comments
1 parent ce99d2c commit 8cd4298

20 files changed

+282
-258
lines changed

go/codegen/pai/template_tf.go

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -93,6 +93,7 @@ oss.load_dir("{{.OSSModelDir}}/model_save")
9393
// install sklearn-pandas==1.8.0 to fix deps for sklearn2pmml with Python2 on PAI.
9494
const paiRequirementsTmplText = `
9595
adanet==0.8.0
96+
dill==0.3.0
9697
numpy==1.16.2
9798
pandas==0.24.2
9899
plotille==3.7

python/runtime/alisa/submitter.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,6 @@
2222
from runtime.pai.submitter import (ENTRY_FILE, JOB_ARCHIVE_FILE, PARAMS_FILE,
2323
clean_oss_model_path,
2424
create_evaluate_result_table,
25-
create_explain_result_table,
2625
create_predict_result_table,
2726
create_tmp_table_from_select,
2827
create_train_and_eval_tmp_table,
@@ -303,8 +302,9 @@ def submit_alisa_explain(datasource, select, result_table, model_name,
303302

304303
label_column = model_params.get("label_col")
305304
params["label_column"] = label_column
306-
create_explain_result_table(datasource, data_table, result_table,
307-
model_type, estimator, label_column)
305+
# FIXME(typhoonzero): Add this back using runtime.step.create_result_table
306+
# create_explain_result_table(datasource, data_table, result_table,
307+
# model_type, estimator, label_column)
308308

309309
setup_explain_entry(params, model_type)
310310
prepare_archive(cwd, estimator, oss_model_path, params)

python/runtime/local/submitter.py

Lines changed: 40 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -13,11 +13,16 @@
1313

1414
from runtime import db
1515
from runtime.dbapi import table_writer
16-
from runtime.feature.derivation import infer_feature_columns
16+
from runtime.feature.derivation import (get_ordered_field_descs,
17+
infer_feature_columns)
1718
from runtime.model.db import read_metadata_from_db
1819
from runtime.model.model import EstimatorType, Model
20+
from runtime.step.create_result_table import (create_evaluate_table,
21+
create_explain_table,
22+
create_predict_table)
1923
from runtime.step.tensorflow.evaluate import evaluate_step as tf_evaluate
2024
from runtime.step.tensorflow.explain import explain_step as tf_explain
25+
from runtime.step.tensorflow.explain import print_image_as_base64_html
2126
from runtime.step.tensorflow.predict import predict_step as tf_pred
2227
from runtime.step.tensorflow.train import train_step as tf_train
2328
from runtime.step.xgboost.evaluate import evaluate as xgboost_evaluate
@@ -114,10 +119,20 @@ def submit_local_pred(datasource,
114119
else:
115120
pred_func = tf_pred
116121

122+
conn = db.connect_with_data_source(datasource)
123+
if model.get_meta("label") is None:
124+
train_label_desc = None
125+
else:
126+
train_label_desc = model.get_meta("label").get_field_desc()[0]
127+
result_column_names, train_label_idx = create_predict_table(
128+
conn, select, result_table, train_label_desc, label_name)
129+
conn.close()
130+
117131
pred_func(datasource=datasource,
118132
select=select,
119133
result_table=result_table,
120-
label_name=label_name,
134+
result_column_names=result_column_names,
135+
train_label_idx=train_label_idx,
121136
model=model)
122137

123138

@@ -132,15 +147,25 @@ def submit_local_evaluate(datasource,
132147
model = Model.load_from_db(datasource, model)
133148
if model.get_type() == EstimatorType.XGBOOST:
134149
evaluate_func = xgboost_evaluate
150+
validation_metrics = model_params.get("validation.metrics",
151+
"accuracy_score")
135152
else:
136153
evaluate_func = tf_evaluate
154+
validation_metrics = model_params.get("validation.metrics", "Accuracy")
155+
156+
conn = db.connect_with_data_source(datasource)
157+
validation_metrics = [m.strip() for m in validation_metrics.split(",")]
158+
result_column_names = create_evaluate_table(conn, result_table,
159+
validation_metrics)
160+
conn.close()
137161

138162
evaluate_func(datasource=datasource,
139163
select=select,
140164
result_table=result_table,
141165
model=model,
142166
label_name=label_name,
143-
model_params=model_params)
167+
model_params=model_params,
168+
result_column_names=result_column_names)
144169

145170

146171
def submit_local_explain(datasource,
@@ -157,12 +182,24 @@ def submit_local_explain(datasource,
157182
else:
158183
explain_func = tf_explain
159184

185+
if result_table:
186+
feature_columns = model.get_meta("features")
187+
estimator_string = model.get_meta("class_name")
188+
field_descs = get_ordered_field_descs(feature_columns)
189+
feature_column_names = [fd.name for fd in field_descs]
190+
with db.connect_with_data_source(datasource) as conn:
191+
create_explain_table(conn, model.get_type(), explainer,
192+
estimator_string, result_table,
193+
feature_column_names)
194+
160195
explain_func(datasource=datasource,
161196
select=select,
162197
explainer=explainer,
163198
model_params=model_params,
164199
result_table=result_table,
165200
model=model)
201+
if not result_table:
202+
print_image_as_base64_html("summary.png")
166203

167204

168205
def submit_local_run(datasource, select, image_name, params, into):

python/runtime/local/submitter_test.py

Lines changed: 35 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -13,10 +13,11 @@
1313

1414
import unittest
1515

16+
import runtime.temp_file as temp_file
1617
import runtime.testing as testing
1718
from runtime.feature.column import NumericColumn
1819
from runtime.feature.field_desc import FieldDesc
19-
from runtime.local import train
20+
from runtime.local import evaluate, explain, pred, train
2021

2122

2223
class TestXGBoostTrain(unittest.TestCase):
@@ -39,13 +40,39 @@ def test_train(self):
3940
"num_boost_round": 20,
4041
}
4142
model_params = {"num_class": 3, "objective": "multi:softmax"}
42-
eval_result = train(ds, original_sql, select, val_select,
43-
"xgboost.gbtree", "", None,
44-
NumericColumn(FieldDesc(name="class")),
45-
model_params, train_params, None,
46-
"iris.xgboost_train_model_test", None)
47-
self.assertLess(eval_result['train']['merror'][-1], 0.01)
48-
self.assertLess(eval_result['validate']['merror'][-1], 0.01)
43+
with temp_file.TemporaryDirectory(as_cwd=True):
44+
eval_result = train(ds, original_sql, select, val_select,
45+
"xgboost.gbtree", "", None,
46+
NumericColumn(FieldDesc(name="class")),
47+
model_params, train_params, None,
48+
"iris.xgboost_train_model_test", None)
49+
self.assertLess(eval_result['train']['merror'][-1], 0.01)
50+
self.assertLess(eval_result['validate']['merror'][-1], 0.01)
51+
52+
with temp_file.TemporaryDirectory(as_cwd=True):
53+
pred_original_sql = """SELECT * FROM iris.test
54+
TO PREDICT iris.xgboost_pred_result.pred_val
55+
USING iris.xgboost_train_model_test;"""
56+
pred(ds, pred_original_sql, "SELECT * FROM iris.test",
57+
"iris.xgboost_train_model_test", "pred_val", model_params,
58+
"iris.xgboost_pred_result")
59+
60+
with temp_file.TemporaryDirectory(as_cwd=True):
61+
explain_original_sql = """SELECT * FROM iris.test
62+
TO EXPLAIN iris.xgboost_train_model_test
63+
INTO iris.xgboost_explain_result;"""
64+
explain(ds, explain_original_sql, "SELECT * FROM iris.test",
65+
"iris.xgboost_train_model_test", model_params,
66+
"iris.xgboost_explain_result")
67+
68+
with temp_file.TemporaryDirectory(as_cwd=True):
69+
evaluate_original_sql = """SELECT * FROM iris.test
70+
TO EVALUATE iris.xgboost_train_model_test
71+
WITH label_col=class
72+
INTO iris.xgboost_evaluate_result;"""
73+
evaluate(ds, evaluate_original_sql, "SELECT * FROM iris.test",
74+
"class", "iris.xgboost_train_model_test", model_params,
75+
"iris.xgboost_evaluate_result")
4976

5077

5178
if __name__ == '__main__':

python/runtime/pai/create_result_table.py

Lines changed: 0 additions & 56 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,6 @@
1212
# limitations under the License.
1313

1414
from runtime import db
15-
from runtime.diagnostics import SQLFlowDiagnostic
1615
from runtime.model import EstimatorType
1716
from runtime.pai import table_ops
1817

@@ -58,61 +57,6 @@ def create_predict_result_table(datasource, select, result_table, label_column,
5857
(result_table, train_label_column))
5958

6059

61-
# (TODO: lhw) This function is a common tool for prediction
62-
# on all platforms, we need to move it to a new file
63-
def create_explain_result_table(datasource, data_table, result_table,
64-
model_type, estimator, label_column):
65-
"""Create explain result table from given datasource
66-
67-
Args:
68-
datasource: current datasource
69-
data_table: input data table name
70-
result_table: table name to store the result
71-
model_type: type of the model to use
72-
estimator: estimator class if the model is TensorFlow estimator
73-
label_column: column name of the predict label
74-
"""
75-
conn = db.connect_with_data_source(datasource)
76-
drop_stmt = "DROP TABLE IF EXISTS %s" % result_table
77-
conn.execute(drop_stmt)
78-
79-
create_stmt = ""
80-
if model_type == EstimatorType.PAIML:
81-
return
82-
elif model_type == EstimatorType.TENSORFLOW:
83-
if estimator.startswith("BoostedTrees"):
84-
column_def = ""
85-
if conn.driver == "mysql":
86-
column_def = "(feature VARCHAR(255), dfc FLOAT, gain FLOAT)"
87-
else:
88-
# Hive & MaxCompute
89-
column_def = "(feature STRING, dfc STRING, gain STRING)"
90-
create_stmt = "CREATE TABLE IF NOT EXISTS %s %s;" % (result_table,
91-
column_def)
92-
else:
93-
if not label_column:
94-
raise SQLFlowDiagnostic(
95-
"need to specify WITH label_col=lable_col_name "
96-
"when explaining deep models")
97-
create_stmt = get_create_shap_result_sql(conn, data_table,
98-
result_table,
99-
label_column)
100-
elif model_type == EstimatorType.XGBOOST:
101-
if not label_column:
102-
raise SQLFlowDiagnostic(
103-
"need to specify WITH label_col=lable_col_name "
104-
"when explaining xgboost models")
105-
create_stmt = get_create_shap_result_sql(conn, data_table,
106-
result_table, label_column)
107-
else:
108-
raise SQLFlowDiagnostic(
109-
"not supported modelType %d for creating Explain result table" %
110-
model_type)
111-
112-
if not conn.execute(create_stmt):
113-
raise SQLFlowDiagnostic("Can't create explain result table")
114-
115-
11660
def get_create_shap_result_sql(conn, data_table, result_table, label_column):
11761
"""Get a sql statement which create a result table for SHAP
11862

python/runtime/pai/submitter_evaluate.py

Lines changed: 13 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@
1414
import os
1515

1616
import runtime.temp_file as temp_file
17+
from runtime import db
1718
from runtime.diagnostics import SQLFlowDiagnostic
1819
from runtime.model import EstimatorType
1920
from runtime.pai import cluster_conf, pai_model, table_ops
@@ -22,6 +23,7 @@
2223
from runtime.pai.prepare_archive import prepare_archive
2324
from runtime.pai.submit_pai_task import submit_pai_task
2425
from runtime.pai_local.try_run import try_pai_local_run
26+
from runtime.step.create_result_table import create_evaluate_table
2527

2628

2729
def submit_pai_evaluate(datasource,
@@ -72,12 +74,22 @@ def submit_pai_evaluate(datasource,
7274

7375
if model_type == EstimatorType.XGBOOST:
7476
params["entry_type"] = "evaluate_xgb"
77+
validation_metrics = model_params.get("validation.metrics",
78+
"accuracy_score")
7579
else:
7680
params["entry_type"] = "evaluate_tf"
81+
validation_metrics = model_params.get("validation.metrics", "Accuracy")
82+
83+
conn = db.connect_with_data_source(datasource)
84+
validation_metrics = [m.strip() for m in validation_metrics.split(",")]
85+
result_column_names = create_evaluate_table(conn, result_table,
86+
validation_metrics)
87+
conn.close()
7788

78-
# create_evaluate_result_table(datasource, result_table, metrics)
7989
with table_ops.create_tmp_tables_guard(select, datasource) as data_table:
8090
params["pai_table"] = data_table
91+
params["oss_model_path"] = oss_model_path
92+
params["result_column_names"] = result_column_names
8193

8294
if try_pai_local_run(params, oss_model_path):
8395
return

python/runtime/pai/submitter_explain.py

Lines changed: 16 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818
import runtime.temp_file as temp_file
1919
from runtime import db
2020
from runtime.diagnostics import SQLFlowDiagnostic
21+
from runtime.feature.derivation import get_ordered_field_descs
2122
from runtime.model import EstimatorType
2223
from runtime.model.model import Model
2324
from runtime.pai import cluster_conf, pai_model, table_ops
@@ -26,6 +27,7 @@
2627
from runtime.pai.prepare_archive import prepare_archive
2728
from runtime.pai.submit_pai_task import submit_pai_task
2829
from runtime.pai_local.try_run import try_pai_local_run
30+
from runtime.step.create_result_table import create_explain_table
2931
from runtime.step.tensorflow.explain import print_image_as_base64_html
3032

3133

@@ -192,19 +194,30 @@ def submit_pai_explain(datasource,
192194
# is like: "SELECT fields,... FROM table"
193195
with table_ops.create_tmp_tables_guard(select, datasource) as data_table:
194196
params["pai_table"] = data_table
197+
params["oss_model_path"] = oss_model_path
198+
199+
# Create explain result table
200+
if result_table:
201+
conn = db.connect_with_data_source(datasource)
202+
feature_columns = meta.get_meta("features")
203+
estimator_string = meta.get_meta("class_name")
204+
field_descs = get_ordered_field_descs(feature_columns)
205+
feature_column_names = [fd.name for fd in field_descs]
206+
create_explain_table(conn, meta.get_type(), explainer,
207+
estimator_string, result_table,
208+
feature_column_names)
209+
conn.close()
195210

196211
if not try_pai_local_run(params, oss_model_path):
197212
with temp_file.TemporaryDirectory(prefix="sqlflow",
198213
dir="/tmp") as cwd:
199214
prepare_archive(cwd, estimator, oss_model_path, params)
200-
201215
cmd = get_pai_explain_cmd(
202216
datasource, project, oss_model_path, model, data_table,
203217
result_table, model_type, model_params,
204218
"file://" + os.path.join(cwd, JOB_ARCHIVE_FILE),
205219
"file://" + os.path.join(cwd, PARAMS_FILE), label_name)
206-
207-
submit_pai_task(cmd, datasource)
220+
submit_pai_task(cmd, datasource)
208221

209222
print_oss_image(params["oss_dest"], params["oss_ak"], params["oss_sk"],
210223
params["oss_endpoint"], params["oss_bucket_name"])

python/runtime/pai/submitter_predict.py

Lines changed: 12 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -16,13 +16,14 @@
1616
import runtime.temp_file as temp_file
1717
from runtime import db
1818
from runtime.diagnostics import SQLFlowDiagnostic
19-
from runtime.model import EstimatorType
19+
from runtime.model import EstimatorType, oss
2020
from runtime.pai import cluster_conf, pai_model, table_ops
2121
from runtime.pai.get_pai_tf_cmd import (ENTRY_FILE, JOB_ARCHIVE_FILE,
2222
PARAMS_FILE, get_pai_tf_cmd)
2323
from runtime.pai.prepare_archive import prepare_archive
2424
from runtime.pai.submit_pai_task import submit_pai_task
2525
from runtime.pai_local.try_run import try_pai_local_run
26+
from runtime.step.create_result_table import create_predict_table
2627

2728

2829
def get_pai_predict_cmd(datasource, project, oss_model_path, model_name,
@@ -122,16 +123,26 @@ def submit_pai_predict(datasource,
122123
datasource, model)
123124
setup_predict_entry(params, model_type)
124125

126+
# TODO(typhoonzero): load model meta from database.
125127
oss_model_path = pai_model.get_oss_model_save_path(datasource,
126128
model,
127129
user=user)
130+
model_metas = oss.load_metas(oss_model_path, "xgboost_model_desc")
131+
train_label_desc = model_metas[5].get_field_desc()[0]
132+
conn = db.connect_with_data_source(datasource)
133+
result_column_names, train_label_idx = create_predict_table(
134+
conn, select, result_table, train_label_desc, label_name)
135+
conn.close()
128136

129137
# TODO(typhoonzero): Do **NOT** create tmp table when the select statement
130138
# is like: "SELECT fields,... FROM table"
131139
with table_ops.create_tmp_tables_guard(select, datasource) as data_table:
140+
del params["label_name"]
132141
params["pai_table"] = data_table
133142
params["oss_model_path"] = oss_model_path
134143
params["model"] = ""
144+
params["result_column_names"] = result_column_names
145+
params["train_label_idx"] = train_label_idx
135146

136147
if try_pai_local_run(params, oss_model_path):
137148
return

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy