@@ -4,74 +4,102 @@
 from databricks.labs.doc_qa.chatbot.retriever import Document, BaseRetriever
 from databricks.labs.doc_qa.logging_utils import logger
 import tiktoken
+from concurrent.futures import ThreadPoolExecutor
 
-class LlmProvider:
 
+class LlmProvider:
     def prompt(self, prompt: str, **kwargs) -> str:
         raise NotImplementedError()
 
 
 class OpenAILlmProvider(LlmProvider):
-
     def __init__(self, api_key: str, model_name: str, temperature, **kwargs):
         self._api_key = api_key
         self.model_name = model_name
         self._temperature = temperature
         openai.api_key = api_key
-
+
     def prompt(self, prompt: str, **kwargs) -> str:
-        messages = [
-            {
-                "role": "user",
-                "content": prompt
-            }
-        ]
-        response_message = openai_provider.request_openai(messages=messages, functions=[], model=self._model, temperature=self._temperature)
-        return response_message['content']
+        messages = [{"role": "user", "content": prompt}]
+        response_message = openai_provider.request_openai(
+            messages=messages,
+            functions=[],
+            model=self.model_name,
+            temperature=self._temperature,
+        )
+        return response_message["content"]
+
 
 class ChatResponse:
+    query: str
     content: str
     relevant_documents: list[Document]
 
-    def __init__(self, content: str, relevant_documents: list[Document]):
+    def __init__(self, query: str, content: str, relevant_documents: list[Document]):
+        self.query = query
         self.content = content
         self.relevant_documents = relevant_documents
 
 
 class BaseChatBot:
-    def __init__(self, llm_provider: str, retriever: BaseRetriever,
-                 whole_prompt_template: PromptTemplate,
-                 document_prompt_tempate: PromptTemplate,
-                 max_num_tokens_for_context: int = 3500, **kwargs):
+    def __init__(
+        self,
+        llm_provider: str,
+        retriever: BaseRetriever,
+        whole_prompt_template: PromptTemplate,
+        document_prompt_tempate: PromptTemplate,
+        max_num_tokens_for_context: int = 3500,
+        **kwargs,
+    ):
         self._llm_provider = llm_provider
         self._retriever = retriever
         self._whole_prompt_template = whole_prompt_template
         self._document_prompt_tempate = document_prompt_tempate
         self._max_num_tokens_for_context = max_num_tokens_for_context
         self._enc = tiktoken.encoding_for_model(self._llm_provider.model_name)
 
-    def chat(self, prompt: str, top_k=1, **kwargs) -> ChatResponse:
+    def chat_batch(
+        self, queries: list[str], top_k=1, concurrency: int = 20, **kwargs
+    ) -> list[ChatResponse]:
+        with ThreadPoolExecutor(max_workers=concurrency) as executor:
+            results = executor.map(
+                lambda query: self.chat(query=query, top_k=top_k, **kwargs),
+                queries,
+            )
+            return list(results)
+
+    def chat(self, query: str, top_k=1, **kwargs) -> ChatResponse:
         """
         Chat with the chatbot.
         """
-        relevant_documents = self._retriever.find_similar_docs(query=prompt, top_k=top_k)
+        relevant_documents = self._retriever.find_similar_docs(query=query, top_k=top_k)
         # First, format the prompt for each document
         document_str = ""
         total_num_tokens = 0
         for index, document in enumerate(relevant_documents):
             # use all attributes of document, except for created_at or vector, as the format parameter
-            doc_formated_prompt = self._document_prompt_tempate.format(**{k: v for k, v in document.__dict__.items() if k not in ['created_at', 'vector']})
+            doc_formated_prompt = self._document_prompt_tempate.format(
+                **{
+                    k: v
+                    for k, v in document.__dict__.items()
+                    if k not in ["created_at", "vector"]
+                }
+            )
             num_tokens = len(self._enc.encode(doc_formated_prompt))
             if total_num_tokens + num_tokens > self._max_num_tokens_for_context:
-                logger.warning(f"Exceeding max number of tokens for context: {self._max_num_tokens_for_context}, existing on {index}th document out of {len(relevant_documents)} documents")
+                logger.warning(
+                    f"Exceeding max number of tokens for context: {self._max_num_tokens_for_context}, existing on {index}th document out of {len(relevant_documents)} documents"
+                )
                 break
             total_num_tokens += num_tokens
            document_str += doc_formated_prompt + "\n"
         logger.debug(f"Document string: {document_str}")
         # Then, format the whole prompt
-        whole_prompt = self._whole_prompt_template.format(context=document_str, prompt=prompt)
+        whole_prompt = self._whole_prompt_template.format(
+            context=document_str, query=query
+        )
         logger.debug(f"Whole prompt: {whole_prompt}")
         response = self._llm_provider.prompt(whole_prompt)
-        return ChatResponse(content=response, relevant_documents=relevant_documents)
-
-
+        return ChatResponse(
+            query=query, content=response, relevant_documents=relevant_documents
+        )
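For reference, a minimal sketch of how the new `chat_batch` entry point might be exercised. The `StaticRetriever`, `my_documents`, and the two `PromptTemplate` values are illustrative placeholders, not part of this change (the `PromptTemplate` constructor is not shown in this diff); the whole-prompt template is assumed to expose `{context}` and `{query}`, matching the `format(context=..., query=...)` call above.

```python
from databricks.labs.doc_qa.chatbot.retriever import BaseRetriever, Document

# Hypothetical retriever: returns a fixed list of documents regardless of the query.
# Only find_similar_docs(query=..., top_k=...) is assumed, since that is the only
# retriever call BaseChatBot.chat makes.
class StaticRetriever(BaseRetriever):
    def __init__(self, docs):
        self._docs = docs

    def find_similar_docs(self, query: str, top_k: int = 1):
        return self._docs[:top_k]

provider = OpenAILlmProvider(
    api_key="sk-...",            # placeholder key
    model_name="gpt-3.5-turbo",  # any model tiktoken.encoding_for_model recognizes
    temperature=0.0,
)

bot = BaseChatBot(
    llm_provider=provider,
    retriever=StaticRetriever(docs=my_documents),     # Document objects built elsewhere
    whole_prompt_template=whole_prompt_template,      # PromptTemplate with {context} and {query}
    document_prompt_tempate=document_prompt_tempate,  # PromptTemplate over document fields
)

# chat_batch fans the queries out across a thread pool; ThreadPoolExecutor.map yields
# results in input order, so responses[i] corresponds to queries[i].
responses = bot.chat_batch(
    queries=["How do I create a cluster?", "What is Unity Catalog?"],
    top_k=3,
    concurrency=8,
)
for response in responses:
    print(response.query, "->", response.content)
```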