Commit f6db63d

Add ModelGenerator for LLama2, driver proxy and Vicuna (databrickslabs#9)

1 parent 96ec0c4 commit f6db63d
5 files changed: +646 -317 lines changed

databricks/labs/doc_qa/evaluators/templated_evaluator.py

Lines changed: 1 addition & 1 deletion

@@ -15,7 +15,7 @@


 logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
+logger = logging.getLogger(__name__.split(".")[0])


 class ParameterType(Enum):
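
The intent of this change (spelled out by a comment added in model_generator.py below) is to attach the logger to the top-level package name rather than the full module path. A minimal sketch of the difference, assuming a module whose __name__ is the dotted path shown:

    import logging

    name = "databricks.labs.doc_qa.evaluators.templated_evaluator"  # example __name__

    full_logger = logging.getLogger(name)               # logger named after the full module path
    top_logger = logging.getLogger(name.split(".")[0])  # logger named "databricks"

    # Every module that does the same split shares the single "databricks" logger,
    # so handlers and levels only need to be configured once.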

databricks/labs/doc_qa/llm_providers/openai_provider.py

Lines changed: 0 additions & 2 deletions

@@ -10,7 +10,6 @@


 openai_token = os.getenv('OPENAI_API_KEY')
-openai_org = os.getenv('OPENAI_ORGANIZATION')

 class StatusCode429Error(Exception):
     pass
@@ -24,7 +23,6 @@ def request_openai(messages, functions=[], temperature=0.0, model="gpt-4"):
     headers = {
         "Content-Type": "application/json",
         "Authorization": f"Bearer {openai_token}",
-        "OpenAI-Organization": openai_org,
     }
     data = {
         "model": model,

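With openai_org gone, requests authenticate with the API key alone and the key's default organization applies. A minimal sketch of the resulting request, assuming the provider targets the standard chat completions endpoint (the message content is illustrative):

    import json
    import os

    import requests

    openai_token = os.getenv("OPENAI_API_KEY")

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {openai_token}",
        # No "OpenAI-Organization" header any more; the key's default org is used.
    }
    data = {
        "model": "gpt-4",
        "messages": [{"role": "user", "content": "Say hello."}],
        "temperature": 0.0,
    }
    response = requests.post(
        "https://api.openai.com/v1/chat/completions",
        headers=headers,
        data=json.dumps(data),
    )
    print(response.json()["choices"][0]["message"]["content"])
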
databricks/labs/doc_qa/model_generators/model_generator.py

Lines changed: 212 additions & 10 deletions
@@ -5,8 +5,8 @@
 import concurrent.futures

 logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
+# Instead of using full name, only use the module name
+logger = logging.getLogger(__name__.split(".")[0])

 class RowGenerateResult:
     """
@@ -78,6 +78,7 @@ def __init__(
         self._prompt_formatter = prompt_formatter
         self._batch_size = batch_size
         self._concurrency = concurrency
+        self.input_variables = prompt_formatter.variables

     def _generate(
         self, prompts: list, temperature: float, max_tokens=256, system_prompt=None
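
The new input_variables attribute records which input-dataframe columns the prompt formatter consumes; run_tasks uses it further down to copy those columns onto each generated row. A small stand-in illustrating the two members this diff relies on (.variables and .format(**row)); this is not the repo's PromptTemplate, whose constructor is not shown here:

    from dataclasses import dataclass, field

    @dataclass
    class StubPromptTemplate:
        # Stand-in with the same contract the diff depends on
        template: str
        variables: list = field(default_factory=list)

        def format(self, **row) -> str:
            return self.template.format(**row)

    stub = StubPromptTemplate(
        template="Context: {context}\n\nQuestion: {question}",
        variables=["context", "question"],
    )
    prompt = stub.format(context="Delta Lake supports ...", question="What is a Delta table?")
    input_variables = stub.variables  # ["context", "question"]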
@@ -95,7 +96,7 @@ def run_tasks(
         Returns:
             EvalResult: the evaluation result
         """
-        prompt_batches = []
+        task_batches = []
         # First, traverse the input dataframe using batch size
         for i in range(0, len(input_df), self._batch_size):
             # Get the current batch
@@ -107,9 +108,13 @@ def run_tasks(
                 # Format the input dataframe into prompts
                 prompt = self._prompt_formatter.format(**row)
                 prompts.append(prompt)
-            prompt_batches.append(prompts)
+            task = {
+                "prompts": prompts,
+                "df": batch_df,
+            }
+            task_batches.append(task)
         logger.info(
-            f"Generated total number of batches for prompts: {len(prompt_batches)}"
+            f"Generated total number of batches for prompts: {len(task_batches)}"
         )

         # Call the _generate in parallel using multiple threads, each call with a batch of prompts
@@ -118,15 +123,28 @@ def run_tasks(
         ) as executor:
             future_to_batch = {
                 executor.submit(
-                    self._generate, prompts, temperature, max_tokens, system_prompt
-                ): prompts
-                for prompts in prompt_batches
+                    self._generate,
+                    task["prompts"],
+                    temperature,
+                    max_tokens,
+                    system_prompt,
+                ): task
+                for task in task_batches
             }
             batch_generate_results = []
             for future in concurrent.futures.as_completed(future_to_batch):
-                prompts = future_to_batch[future]
+                task = future_to_batch[future]
                 try:
                     result = future.result()
+                    batch_df = task["df"]
+                    # Copy the batch_df columns named in input_variables onto each
+                    # RowGenerateResult as attribute/value pairs
+                    for index, row in enumerate(result.rows):
+                        for input_variable in self.input_variables:
+                            setattr(
+                                row,
+                                input_variable,
+                                batch_df[input_variable].iloc[index],
+                            )
                     batch_generate_results.append(result)
                 except Exception as exc:
                     logger.error(f"Exception occurred when running the task: {exc}")
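
The net effect of the three hunks above: each batch now carries its slice of the input dataframe, and once a batch finishes, its input columns are copied back onto the generated rows. A small self-contained sketch of that attachment step (stand-in row objects, not the repo's RowGenerateResult):

    import pandas as pd

    class StubRow:
        # Stand-in for RowGenerateResult: a plain object that accepts arbitrary attributes
        def __init__(self, answer):
            self.answer = answer

    batch_df = pd.DataFrame({"question": ["q1", "q2"], "context": ["c1", "c2"]})
    rows = [StubRow("a1"), StubRow("a2")]
    input_variables = ["question", "context"]

    # Same pattern as run_tasks: position i in the result rows maps to row i of batch_df
    for index, row in enumerate(rows):
        for input_variable in input_variables:
            setattr(row, input_variable, batch_df[input_variable].iloc[index])

    assert rows[0].question == "q1" and rows[1].context == "c2"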
@@ -268,7 +286,7 @@ def __init__(
         self._tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

     def _format_prompt(self, message: str, system_prompt_opt: str) -> str:
-        if system_prompt_opt is None:
+        if system_prompt_opt is not None:
             texts = [f"[INST] <<SYS>>\n{system_prompt_opt}\n<</SYS>>\n\n"]
             texts.append(f"{message.strip()} [/INST]")
             return "".join(texts)
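
This flips an inverted condition: the <<SYS>> block is now emitted only when a system prompt is actually supplied. A quick worked example of the string the corrected branch produces (standalone re-implementation for illustration only):

    def format_llama2_prompt(message: str, system_prompt: str) -> str:
        # Mirrors the corrected branch: wrap the system prompt in <<SYS>> tags,
        # then append the user message inside the [INST] ... [/INST] markers.
        texts = [f"[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"]
        texts.append(f"{message.strip()} [/INST]")
        return "".join(texts)

    print(format_llama2_prompt("What is MLflow?", "You are a concise assistant."))
    # [INST] <<SYS>>
    # You are a concise assistant.
    # <</SYS>>
    #
    # What is MLflow? [/INST]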
@@ -323,3 +341,187 @@ def _generate(
             is_successful=True,
             error_msg=None,
         )
+
+
+class VicunaModelGenerator(BaseModelGenerator):
+    def __init__(
+        self,
+        prompt_formatter: PromptTemplate,
+        model_name_or_path: str,
+        batch_size: int = 1,
+        concurrency: int = 1,
+    ) -> None:
+        """
+        Args:
+            prompt_formatter (PromptTemplate): the prompt formatter used to turn each input dataframe row into a prompt, based on the column names
+            model_name_or_path (str): the Hugging Face model name or local path
+            batch_size (int, optional): batch size that will be used to run tasks. Defaults to 1, which means generation is sequential.
+
+        Recommendations:
+            - for A100 80GB, use batch_size 1 for vicuna-33b
+            - for A100 80GB x 2, use batch_size 64 for vicuna-33b
+        """
+        super().__init__(prompt_formatter, batch_size, concurrency)
+        # require the concurrency to be 1 to avoid race conditions during inference
+        if concurrency != 1:
+            raise ValueError(
+                "VicunaModelGenerator currently only supports concurrency 1"
+            )
+        self._model_name_or_path = model_name_or_path
+        import torch
+        from transformers import (
+            AutoModelForCausalLM,
+            AutoTokenizer,
+            TextIteratorStreamer,
+        )
+
+        if torch.cuda.is_available():
+            self._model = AutoModelForCausalLM.from_pretrained(
+                model_name_or_path, torch_dtype=torch.float16, device_map="auto"
+            )
+        else:
+            raise ValueError("VicunaModelGenerator currently only supports GPU")
+        self._tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
+
+    def _format_prompt(self, message: str, system_prompt_opt: str) -> str:
+        if system_prompt_opt is not None:
+            return f"""{system_prompt_opt}
+
+USER: {message}
+ASSISTANT:
+"""
+        else:
+            return f"""A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
+
+USER: {message}
+ASSISTANT:
+"""
+
+    def _generate(
+        self, prompts: list, temperature: float, max_tokens=256, system_prompt=None
+    ) -> BatchGenerateResult:
+        from transformers import pipeline
+
+        all_formatted_prompts = [
+            self._format_prompt(message=message, system_prompt_opt=system_prompt)
+            for message in prompts
+        ]
+
+        top_p = 0.95
+        repetition_penalty = 1.15
+        pipe = pipeline(
+            "text-generation",
+            model=self._model,
+            tokenizer=self._tokenizer,
+            max_new_tokens=max_tokens,
+            temperature=temperature,
+            top_p=top_p,
+            repetition_penalty=repetition_penalty,
+            return_full_text=False,
+        )
+        responses = pipe(all_formatted_prompts)
+        rows = []
+        for index, response in enumerate(responses):
+            response_content = response[0]["generated_text"]
+            row_generate_result = RowGenerateResult(
+                is_successful=True,
+                error_msg=None,
+                answer=response_content,
+                temperature=temperature,
+                max_tokens=max_tokens,
+                model_name=self._model_name_or_path,
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                prompts=all_formatted_prompts[index],
+            )
+            rows.append(row_generate_result)
+
+        return BatchGenerateResult(
+            num_rows=len(rows),
+            num_successful_rows=len(rows),
+            rows=rows,
+            is_successful=True,
+            error_msg=None,
+        )
+
+
+class DriverProxyModelGenerator(BaseModelGenerator):
+    def __init__(
+        self,
+        url: str,
+        pat_token: str,
+        prompt_formatter: PromptTemplate,
+        batch_size: int = 32,
+        concurrency: int = 1,
+    ) -> None:
+        """
+        Args:
+            url (str): the driver proxy endpoint that prompts are POSTed to
+            pat_token (str): the personal access token used as the bearer token
+            prompt_formatter (PromptTemplate): the prompt formatter used to turn each input dataframe row into a prompt, based on the column names
+            batch_size (int, optional): batch size that will be used to run tasks. Defaults to 32.
+
+        Recommendations:
+            - for A100 80GB, use batch_size 16 for llama-2-13b-chat
+        """
+        super().__init__(prompt_formatter, batch_size, concurrency)
+        self._url = url
+        self._pat_token = pat_token
+
+    def _format_prompt(self, message: str, system_prompt_opt: str) -> str:
+        if system_prompt_opt is not None:
+            texts = [f"[INST] <<SYS>>\n{system_prompt_opt}\n<</SYS>>\n\n"]
+            texts.append(f"{message.strip()} [/INST]")
+            return "".join(texts)
+        else:
+            texts = [f"[INST] \n\n"]
+            texts.append(f"{message.strip()} [/INST]")
+            return "".join(texts)
+
+    def _generate(
+        self, prompts: list, temperature: float, max_tokens=256, system_prompt=None
+    ) -> BatchGenerateResult:
+        top_p = 0.95
+
+        all_formatted_prompts = [
+            self._format_prompt(message=message, system_prompt_opt=system_prompt)
+            for message in prompts
+        ]
+
+        import requests
+        import json
+
+        headers = {
+            "Authentication": f"Bearer {self._pat_token}",
+            "Content-Type": "application/json",
+        }
+
+        data = {
+            "prompts": all_formatted_prompts,
+            "temperature": temperature,
+            "max_tokens": max_tokens,
+        }
+
+        response = requests.post(self._url, headers=headers, data=json.dumps(data))
+
+        # Extract the "outputs" JSON array from the response
+        outputs = response.json()["outputs"]
+        rows = []
+        for index, response_content in enumerate(outputs):
+            row_generate_result = RowGenerateResult(
+                is_successful=True,
+                error_msg=None,
+                answer=response_content,
+                temperature=temperature,
+                max_tokens=max_tokens,
+                top_p=top_p,
+                prompts=all_formatted_prompts[index],
+            )
+            rows.append(row_generate_result)
+
+        return BatchGenerateResult(
+            num_rows=len(rows),
+            num_successful_rows=len(rows),
+            rows=rows,
+            is_successful=True,
+            error_msg=None,
+        )
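
A hedged usage sketch of the new generators. The run_tasks signature is not shown in this diff, so the call shape below (input_df, temperature, max_tokens, system_prompt) is inferred from how run_tasks consumes those values internally; the URL, token, and prompt_formatter are placeholders:

    import pandas as pd

    # prompt_formatter: a PromptTemplate over the "question" column, built with the
    # repo's own API (its constructor is not shown in this diff).
    input_df = pd.DataFrame({"question": ["What is Unity Catalog?", "What is DLT?"]})

    generator = DriverProxyModelGenerator(
        url="https://<workspace-host>/driver-proxy-api/o/0/<cluster-id>/<port>",  # placeholder
        pat_token="<databricks-pat>",                                             # placeholder
        prompt_formatter=prompt_formatter,
        batch_size=16,
        concurrency=1,
    )

    # Assumed call shape: run_tasks batches input_df, formats one prompt per row,
    # fans the batches out to _generate, and attaches the input columns to each result row.
    result = generator.run_tasks(
        input_df=input_df,
        temperature=0.0,
        max_tokens=256,
        system_prompt="You are a helpful assistant for Databricks documentation.",
    )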
