
Commit 3780ee3

Add driver proxy generator (databrickslabs#10)

1 parent f6db63d

10 files changed: +552 −149 lines changed

databricks/labs/doc_qa/chatbot/chatbot.py

Lines changed: 4 additions & 7 deletions
@@ -2,11 +2,8 @@
 from databricks.labs.doc_qa.llm_utils import PromptTemplate
 import openai
 from databricks.labs.doc_qa.chatbot.retriever import Document, BaseRetriever
-import logging
+from databricks.labs.doc_qa.logging_utils import logger
 import tiktoken
-import dataclasses
-
-logger = logging.getLogger(__name__)
 
 class LlmProvider:
 
@@ -16,9 +13,9 @@ def prompt(self, prompt: str, **kwargs) -> str:
 
 class OpenAILlmProvider(LlmProvider):
 
-    def __init__(self, api_key: str, model, temperature, **kwargs):
+    def __init__(self, api_key: str, model_name: str, temperature, **kwargs):
         self._api_key = api_key
-        self._model = model
+        self.model_name = model_name
         self._temperature = temperature
         openai.api_key = api_key
 
@@ -51,7 +48,7 @@ def __init__(self, llm_provider: str, retriever: BaseRetriever,
         self._whole_prompt_template = whole_prompt_template
         self._document_prompt_tempate = document_prompt_tempate
         self._max_num_tokens_for_context = max_num_tokens_for_context
-        self._enc = tiktoken.encoding_for_model(self._llm_provider.model)
+        self._enc = tiktoken.encoding_for_model(self._llm_provider.model_name)
 
     def chat(self, prompt: str, top_k=1, **kwargs) -> ChatResponse:
         """

databricks/labs/doc_qa/chatbot/retriever.py

Lines changed: 1 addition & 3 deletions
@@ -1,15 +1,13 @@
 from datetime import datetime
 import pandas as pd
 from databricks.labs.doc_qa.llm_utils import PromptTemplate
+from databricks.labs.doc_qa.logging_utils import logger
 import openai
-import logging
 import faiss
 import numpy as np
 import json
 
 
-logger = logging.getLogger(__name__)
-
 
 class EmbeddingProvider:
 
Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
+from databricks.labs.doc_qa.llm_utils import PromptTemplate
+import pandas as pd
+import os
+from databricks.labs.doc_qa.evaluators.templated_evaluator import (
+    OpenAIEvaluator,
+    RetryPolicy,
+)
+from databricks.labs.doc_qa.variables.doc_qa_template_variables import (
+    get_openai_grading_template_and_function,
+)
+from databricks.labs.doc_qa.logging_utils import logger
+
+
+def gpt_4_evaluator():
+    retry_policy = RetryPolicy(max_retry_on_invalid_result=3, max_retry_on_exception=3)
+    (
+        openai_grading_prompt,
+        openai_grading_function,
+    ) = get_openai_grading_template_and_function(scale=10, level_of_details=2)
+    return OpenAIEvaluator(
+        model="gpt-4",
+        temperature=0.0,
+        grading_prompt_tempate=openai_grading_prompt,
+        input_columns=["question", "answer", "context"],
+        openai_function=openai_grading_function,
+        retry_policy=retry_policy,
+    )
+
+
+def vllm_vicuna_model_generator(url, pat_token, model_name):
+    from databricks.labs.doc_qa.model_generators.model_generator import (
+        vLllmOpenAICompletionFormatModelGenerator,
+    )
+    from databricks.labs.doc_qa.variables.doc_qa_template_variables import (
+        vicuna_prompt_format_func,
+        doc_qa_task_prompt_template,
+    )
+
+    return vLllmOpenAICompletionFormatModelGenerator(
+        url=url,
+        pat_token=pat_token,
+        prompt_formatter=doc_qa_task_prompt_template,
+        batch_size=1,
+        model_name=model_name,
+        format_prompt_func=vicuna_prompt_format_func,
+        concurrency=100,
+    )
+
+
+def vllm_llama2_model_generator(url, pat_token, model_name):
+    from databricks.labs.doc_qa.model_generators.model_generator import (
+        vLllmOpenAICompletionFormatModelGenerator,
+    )
+    from databricks.labs.doc_qa.variables.doc_qa_template_variables import (
+        llama2_prompt_format_func,
+        doc_qa_task_prompt_template,
+    )
+
+    return vLllmOpenAICompletionFormatModelGenerator(
+        url=url,
+        pat_token=pat_token,
+        prompt_formatter=doc_qa_task_prompt_template,
+        batch_size=1,
+        model_name=model_name,
+        format_prompt_func=llama2_prompt_format_func,
+        concurrency=100,
+    )
+
+
+def generate_and_evaluate(
+    input_df, model_generator, evaluator, temperature=0, max_tokens=200
+):
+    generate_result = model_generator.run_tasks(
+        input_df=input_df, temperature=temperature, max_tokens=max_tokens
+    )
+
+    result_df = generate_result.to_dataframe()
+
+    logger.info(f"Finished generating {len(result_df)} rows, starting evaluation")
+    return evaluator.run_eval(dataset_df=result_df, concurrency=20, catch_error=True)
+
+
+def evaluate_using_vllm_locally(
+    input_df,
+    hf_model_name,
+    prompt_tempate_format_func,
+    temperature=0,
+    max_tokens=200,
+    max_num_batched_tokens=None,
+):
+    from databricks.labs.doc_qa.model_generators.model_generator import (
+        vLllmLocalModelGenerator,
+    )
+    from databricks.labs.doc_qa.variables.doc_qa_template_variables import (
+        doc_qa_task_prompt_template,
+    )
+
+    model_generator = vLllmLocalModelGenerator(
+        hf_model_name=hf_model_name,
+        format_prompt_func=prompt_tempate_format_func,
+        prompt_formatter=doc_qa_task_prompt_template,
+        max_num_batched_tokens=max_num_batched_tokens,
+    )
+
+    evaluator = gpt_4_evaluator()
+    generate_result = model_generator.run_tasks(
+        input_df=input_df, temperature=temperature, max_tokens=max_tokens
+    )
+
+    result_df = generate_result.to_dataframe()
+
+    logger.info(f"Finished generating {len(result_df)} rows, starting evaluation")
+    return evaluator.run_eval(dataset_df=result_df, concurrency=20, catch_error=True)
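These helpers chain a vLLM-backed model generator with the GPT-4 grader into a two-step generate-then-grade pipeline. A hedged sketch of how they might be wired together; the workspace URL, PAT token, model name, and input columns below are placeholders, since the exact columns `doc_qa_task_prompt_template` expects are not shown in this commit:

```python
import pandas as pd

# Placeholder rows; the real column set required by the
# prompt template is not shown in this diff.
input_df = pd.DataFrame(
    [{"question": "What is Delta Lake?", "context": "Delta Lake is an open table format ..."}]
)

# Placeholder driver-proxy URL and PAT token.
model_generator = vllm_llama2_model_generator(
    url="https://<workspace>/driver-proxy-api/o/0/<cluster-id>/8000",
    pat_token="dapi...",
    model_name="llama-2-13b-chat",
)

eval_result = generate_and_evaluate(
    input_df=input_df,
    model_generator=model_generator,
    evaluator=gpt_4_evaluator(),
)
```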

databricks/labs/doc_qa/evaluators/templated_evaluator.py

Lines changed: 1 addition & 5 deletions
@@ -7,17 +7,13 @@
 from databricks.labs.doc_qa.llm_providers import openai_provider
 from databricks.labs.doc_qa.llm_providers import anthropic_provider
 import json
-import logging
+from databricks.labs.doc_qa.logging_utils import logger
 from tenacity import retry, stop_after_attempt, retry_if_result, retry_if_exception
 import re
 from enum import Enum
 import json
 
 
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__.split(".")[0])
-
-
 class ParameterType(Enum):
     STRING = 'string'
     NUMBER = 'number'

databricks/labs/doc_qa/llm_providers/anthropic_provider.py

Lines changed: 2 additions & 2 deletions
@@ -3,8 +3,9 @@
 import os
 import requests
 from tenacity import retry, stop_after_attempt, stop_after_delay, wait_fixed, retry_if_exception_type, retry_if_exception
-import logging
+from databricks.labs.doc_qa.logging_utils import logger
 from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT
+import logging
 
 anthropic_api_key = os.getenv('ANTHROPIC_API_KEY')
 
@@ -16,7 +17,6 @@ def supress_httpx_logs():
     logger.setLevel(logging.WARNING)
 
 supress_httpx_logs()
-logger = logging.getLogger(__name__)
 
 
 def request_anthropic(prompt, temperature=0.0, model="claude-2", max_tokens_to_sample=300):

databricks/labs/doc_qa/llm_providers/openai_provider.py

Lines changed: 1 addition & 4 deletions
@@ -3,10 +3,7 @@
 import os
 import requests
 from tenacity import retry, stop_after_attempt, stop_after_delay, wait_fixed, retry_if_exception_type, retry_if_exception
-import logging
-
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
+from databricks.labs.doc_qa.logging_utils import logger
 
 
 openai_token = os.getenv('OPENAI_API_KEY')
databricks/labs/doc_qa/logging_utils.py

Lines changed: 10 additions & 0 deletions

@@ -0,0 +1,10 @@
+import logging
+
+logger = logging.getLogger("doc-qa")
+formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+
+handler = logging.StreamHandler()
+handler.setFormatter(formatter)
+
+logger.addHandler(handler)
+logger.setLevel(logging.INFO)
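With this module in place, every file in the package swaps its per-module `logging.getLogger(__name__)` (and the `logging.basicConfig` calls in `openai_provider.py` and `templated_evaluator.py`) for one shared, pre-configured `doc-qa` logger. A minimal sketch of the pattern the other diffs in this commit adopt:

```python
from databricks.labs.doc_qa.logging_utils import logger

# Emits through the shared handler, formatted roughly as:
# 2023-08-01 12:00:00,000 - doc-qa - INFO - starting evaluation
logger.info("starting evaluation")
```

Configuring a named logger instead of calling `logging.basicConfig` at import time means importing the library no longer reconfigures the host application's root logger.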
