import concurrent.futures

logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
+# Instead of using the full name, only use the module name
+logger = logging.getLogger(__name__.split(".")[0])

class RowGenerateResult:
    """
@@ -78,6 +78,7 @@ def __init__(
        self._prompt_formatter = prompt_formatter
        self._batch_size = batch_size
        self._concurrency = concurrency
+        self.input_variables = prompt_formatter.variables

    def _generate(
        self, prompts: list, temperature: float, max_tokens=256, system_prompt=None
@@ -95,7 +96,7 @@ def run_tasks(
        Returns:
            EvalResult: the evaluation result
        """
-        prompt_batches = []
+        task_batches = []
        # First, traverse the input dataframe using batch size
        for i in range(0, len(input_df), self._batch_size):
            # Get the current batch
@@ -107,9 +108,13 @@ def run_tasks(
                # Format the input dataframe into prompts
                prompt = self._prompt_formatter.format(**row)
                prompts.append(prompt)
-            prompt_batches.append(prompts)
+            task = {
+                "prompts": prompts,
+                "df": batch_df,
+            }
+            task_batches.append(task)
        logger.info(
-            f"Generated total number of batches for prompts: {len(prompt_batches)}"
+            f"Generated total number of batches for prompts: {len(task_batches)}"
        )

        # Call the _generate in parallel using multiple threads, each call with a batch of prompts
@@ -118,15 +123,28 @@ def run_tasks(
        ) as executor:
            future_to_batch = {
                executor.submit(
-                    self._generate, prompts, temperature, max_tokens, system_prompt
-                ): prompts
-                for prompts in prompt_batches
+                    self._generate,
+                    task["prompts"],
+                    temperature,
+                    max_tokens,
+                    system_prompt,
+                ): task
+                for task in task_batches
            }
            batch_generate_results = []
            for future in concurrent.futures.as_completed(future_to_batch):
-                prompts = future_to_batch[future]
+                task = future_to_batch[future]
                try:
                    result = future.result()
+                    batch_df = task["df"]
+                    # Attach each input_variables column of batch_df as an attribute on the corresponding RowGenerateResult
+                    for index, row in enumerate(result.rows):
+                        for input_variable in self.input_variables:
+                            setattr(
+                                row,
+                                input_variable,
+                                batch_df[input_variable].iloc[index],
+                            )
                    batch_generate_results.append(result)
                except Exception as exc:
                    logger.error(f"Exception occurred when running the task: {exc}")
@@ -268,7 +286,7 @@ def __init__(
        self._tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

    def _format_prompt(self, message: str, system_prompt_opt: str) -> str:
-        if system_prompt_opt is None:
+        if system_prompt_opt is not None:
            texts = [f"[INST] <<SYS>>\n{system_prompt_opt}\n<</SYS>>\n\n"]
            texts.append(f"{message.strip()} [/INST]")
            return "".join(texts)
@@ -323,3 +341,187 @@ def _generate(
            is_successful=True,
            error_msg=None,
        )
+
+
+class VicunaModelGenerator(BaseModelGenerator):
+    def __init__(
+        self,
+        prompt_formatter: PromptTemplate,
+        model_name_or_path: str,
+        batch_size: int = 1,
+        concurrency: int = 1,
+    ) -> None:
+        """
+        Args:
+            prompt_formatter (PromptTemplate): the prompt template used to format the input dataframe into prompts row by row according to the column names
+            model_name_or_path (str): the model name or local path
+            batch_size (int, optional): batch size used to run tasks. Defaults to 1, which means generation is sequential.
+
+        Recommendations:
+            - for A100 80GB, use batch_size 1 for vicuna-33b
+            - for A100 80GB x 2, use batch_size 64 for vicuna-33b
+        """
+        super().__init__(prompt_formatter, batch_size, concurrency)
+        # require the concurrency to be 1 to avoid race conditions during inference
+        if concurrency != 1:
+            raise ValueError(
+                "VicunaModelGenerator currently only supports concurrency 1"
+            )
+        self._model_name_or_path = model_name_or_path
+        import torch
+        from transformers import (
+            AutoModelForCausalLM,
+            AutoTokenizer,
+            TextIteratorStreamer,
+        )
+
+        if torch.cuda.is_available():
+            self._model = AutoModelForCausalLM.from_pretrained(
+                model_name_or_path, torch_dtype=torch.float16, device_map="auto"
+            )
+        else:
+            raise ValueError("VicunaModelGenerator currently only supports GPU")
+        self._tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
+
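+    # Vicuna-style conversation template: an optional system preamble followed by "USER: ... ASSISTANT:" turns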
+    def _format_prompt(self, message: str, system_prompt_opt: str) -> str:
+        if system_prompt_opt is not None:
+            return f"""{system_prompt_opt}
+
+USER: {message}
+ASSISTANT:
+"""
+        else:
+            return f"""A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
+
+USER: {message}
+ASSISTANT:
+"""
+
+    def _generate(
+        self, prompts: list, temperature: float, max_tokens=256, system_prompt=None
+    ) -> BatchGenerateResult:
+        from transformers import pipeline
+
+        all_formatted_prompts = [
+            self._format_prompt(message=message, system_prompt_opt=system_prompt)
+            for message in prompts
+        ]
+
+        top_p = 0.95
+        repetition_penalty = 1.15
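+        # Run the whole batch of formatted prompts through a transformers text-generation pipeline;
+        # return_full_text=False keeps only the generated continuation, not the prompt itself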
+        pipe = pipeline(
+            "text-generation",
+            model=self._model,
+            tokenizer=self._tokenizer,
+            max_new_tokens=max_tokens,
+            temperature=temperature,
+            top_p=top_p,
+            repetition_penalty=repetition_penalty,
+            return_full_text=False,
+        )
+        responses = pipe(all_formatted_prompts)
+        rows = []
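+        # The pipeline returns one list of candidates per prompt; take the first candidate's text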
+        for index, response in enumerate(responses):
+            response_content = response[0]["generated_text"]
+            row_generate_result = RowGenerateResult(
+                is_successful=True,
+                error_msg=None,
+                answer=response_content,
+                temperature=temperature,
+                max_tokens=max_tokens,
+                model_name=self._model_name_or_path,
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                prompts=all_formatted_prompts[index],
+            )
+            rows.append(row_generate_result)
+
+        return BatchGenerateResult(
+            num_rows=len(rows),
+            num_successful_rows=len(rows),
+            rows=rows,
+            is_successful=True,
+            error_msg=None,
+        )
+
+
+class DriverProxyModelGenerator(BaseModelGenerator):
+    def __init__(
+        self,
+        url: str,
+        pat_token: str,
+        prompt_formatter: PromptTemplate,
+        batch_size: int = 32,
+        concurrency: int = 1,
+    ) -> None:
+        """
+        Args:
+            url (str): the URL of the driver proxy endpoint serving the model
+            pat_token (str): the personal access token used to authenticate against the endpoint
+            prompt_formatter (PromptTemplate): the prompt template used to format the input dataframe into prompts row by row according to the column names
+            batch_size (int, optional): batch size used to run tasks. Defaults to 32.
+
+        Recommendations:
+            - for A100 80GB, use batch_size 16 for llama-2-13b-chat
+        """
+        super().__init__(prompt_formatter, batch_size, concurrency)
+        self._url = url
+        self._pat_token = pat_token
+
+    def _format_prompt(self, message: str, system_prompt_opt: str) -> str:
+        if system_prompt_opt is not None:
+            texts = [f"[INST] <<SYS>>\n{system_prompt_opt}\n<</SYS>>\n\n"]
+            texts.append(f"{message.strip()} [/INST]")
+            return "".join(texts)
+        else:
+            texts = [f"[INST] \n\n"]
+            texts.append(f"{message.strip()} [/INST]")
+            return "".join(texts)
+
+    def _generate(
+        self, prompts: list, temperature: float, max_tokens=256, system_prompt=None
+    ) -> BatchGenerateResult:
+        top_p = 0.95
+
+        all_formatted_prompts = [
+            self._format_prompt(message=message, system_prompt_opt=system_prompt)
+            for message in prompts
+        ]
+
+        import requests
+        import json
+
+        headers = {
+            "Authorization": f"Bearer {self._pat_token}",
+            "Content-Type": "application/json",
+        }
+
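+        # Request/response contract assumed by this class: the proxy accepts a JSON body with
+        # "prompts", "temperature" and "max_tokens", and replies with a JSON object whose
+        # "outputs" field is an array of generated strings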
+        data = {
+            "prompts": all_formatted_prompts,
+            "temperature": temperature,
+            "max_tokens": max_tokens,
+        }
+
+        response = requests.post(self._url, headers=headers, data=json.dumps(data))
+
+        # Extract the "outputs" as a JSON array from the response
+        outputs = response.json()["outputs"]
+        rows = []
+        for index, response_content in enumerate(outputs):
+            row_generate_result = RowGenerateResult(
+                is_successful=True,
+                error_msg=None,
+                answer=response_content,
+                temperature=temperature,
+                max_tokens=max_tokens,
+                top_p=top_p,
+                prompts=all_formatted_prompts[index],
+            )
+            rows.append(row_generate_result)
+
+        return BatchGenerateResult(
+            num_rows=len(rows),
+            num_successful_rows=len(rows),
+            rows=rows,
+            is_successful=True,
+            error_msg=None,
+        )