Skip to content

Commit 78ba2ca

Browse files
authored
Update chatbot code. (databrickslabs#11)
1 parent 3780ee3 commit 78ba2ca

File tree

6 files changed

+385
-127
lines changed

6 files changed

+385
-127
lines changed

databricks/labs/doc_qa/chatbot/chatbot.py

Lines changed: 52 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,74 +4,102 @@
44
from databricks.labs.doc_qa.chatbot.retriever import Document, BaseRetriever
55
from databricks.labs.doc_qa.logging_utils import logger
66
import tiktoken
7+
from concurrent.futures import ThreadPoolExecutor
78

8-
class LlmProvider:
    """Abstract interface to a large-language-model backend.

    Concrete subclasses implement `prompt` to send one prompt string to the
    underlying model and return the completion text.
    """

    def prompt(self, prompt: str, **kwargs) -> str:
        """Send `prompt` to the model and return its text reply."""
        raise NotImplementedError()
1213

1314

1415
class OpenAILlmProvider(LlmProvider):
    """LlmProvider backed by the OpenAI chat-completion API."""

    def __init__(self, api_key: str, model_name: str, temperature, **kwargs):
        """Remember the model settings and register the API key.

        NOTE(review): assigning `openai.api_key` mutates module-global state,
        so the most recently constructed provider wins for the whole process.
        """
        self._api_key = api_key
        self.model_name = model_name
        self._temperature = temperature
        openai.api_key = api_key

    def prompt(self, prompt: str, **kwargs) -> str:
        """Send a single-turn user message and return the model's reply text."""
        user_message = {"role": "user", "content": prompt}
        reply = openai_provider.request_openai(
            messages=[user_message],
            functions=[],
            model=self.model_name,
            temperature=self._temperature,
        )
        return reply["content"]
31+
3132

3233
class ChatResponse:
    """Outcome of one chatbot query.

    Attributes:
        query: the user question that produced this response.
        content: the model's answer text.
        relevant_documents: the retrieved documents used as context.
    """

    query: str
    content: str
    relevant_documents: list[Document]

    def __init__(self, query: str, content: str, relevant_documents: list[Document]):
        self.query, self.content, self.relevant_documents = (
            query,
            content,
            relevant_documents,
        )
3942

4043

4144
class BaseChatBot:
    """Retrieval-augmented chatbot.

    For each query it retrieves the top-k similar documents, renders each one
    through `document_prompt_tempate`, concatenates as many as fit within the
    token budget, renders the combined context plus the query through
    `whole_prompt_template`, and asks the LLM provider for an answer.
    """

    def __init__(
        self,
        # Fixed annotation: this was declared `str`, but the object's
        # `.model_name` and `.prompt()` are used below, so it must be an
        # LlmProvider instance.
        llm_provider: LlmProvider,
        retriever: BaseRetriever,
        whole_prompt_template: PromptTemplate,
        # NOTE(review): parameter name misspells "template"; kept as-is so
        # existing keyword callers keep working.
        document_prompt_tempate: PromptTemplate,
        max_num_tokens_for_context: int = 3500,
        **kwargs,
    ):
        self._llm_provider = llm_provider
        self._retriever = retriever
        self._whole_prompt_template = whole_prompt_template
        self._document_prompt_tempate = document_prompt_tempate
        self._max_num_tokens_for_context = max_num_tokens_for_context
        # Tokenizer matching the provider's model; used to enforce the
        # context-size budget in chat().
        self._enc = tiktoken.encoding_for_model(self._llm_provider.model_name)

    def chat_batch(
        self, queries: list[str], top_k=1, concurrency: int = 20, **kwargs
    ) -> list[ChatResponse]:
        """Answer many queries concurrently.

        Results preserve the order of `queries`. `concurrency` bounds the
        number of in-flight chat() calls.
        """
        with ThreadPoolExecutor(max_workers=concurrency) as executor:
            results = executor.map(
                lambda query: self.chat(query=query, top_k=top_k, **kwargs),
                queries,
            )
            # Materialize inside the `with` so all futures complete before
            # the executor shuts down.
            return list(results)

    def chat(self, query: str, top_k=1, **kwargs) -> ChatResponse:
        """Answer a single query.

        Retrieves up to `top_k` documents, builds the context within the
        token budget, prompts the LLM, and returns a ChatResponse.
        """
        relevant_documents = self._retriever.find_similar_docs(query=query, top_k=top_k)
        # First, format the prompt for each document.
        document_str = ""
        total_num_tokens = 0
        for index, document in enumerate(relevant_documents):
            # Use all attributes of the document, except created_at/vector,
            # as format parameters.
            doc_formated_prompt = self._document_prompt_tempate.format(
                **{
                    k: v
                    for k, v in document.__dict__.items()
                    if k not in ["created_at", "vector"]
                }
            )
            num_tokens = len(self._enc.encode(doc_formated_prompt))
            if total_num_tokens + num_tokens > self._max_num_tokens_for_context:
                # Fixed log-message typo: "existing on" -> "stopping at".
                logger.warning(
                    f"Exceeding max number of tokens for context: {self._max_num_tokens_for_context}, stopping at {index}th document out of {len(relevant_documents)} documents"
                )
                break
            total_num_tokens += num_tokens
            document_str += doc_formated_prompt + "\n"
        logger.debug(f"Document string: {document_str}")
        # Then, format the whole prompt.
        whole_prompt = self._whole_prompt_template.format(
            context=document_str, query=query
        )
        logger.debug(f"Whole prompt: {whole_prompt}")
        response = self._llm_provider.prompt(whole_prompt)
        return ChatResponse(
            query=query, content=response, relevant_documents=relevant_documents
        )

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

pFad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy