postgresml/postgresml · Commit 82fcbd4
Chatbot page is almost ready to go (#1054)
1 parent a2908fa commit 82fcbd4


55 files changed: +2992 −614 lines

pgml-dashboard/Cargo.lock

Lines changed: 455 additions & 59 deletions
(Generated file; diff not rendered.)

pgml-dashboard/Cargo.toml

Lines changed: 2 additions & 0 deletions
@@ -45,3 +45,5 @@ pgvector = { version = "0.2.2", features = [ "sqlx", "postgres" ] }
 console-subscriber = "*"
 glob = "*"
 pgml-components = { path = "../packages/pgml-components" }
+reqwest = { version = "0.11.20", features = ["json"] }
+pgml = { version = "0.9.2", path = "../pgml-sdks/pgml/" }

pgml-dashboard/package-lock.json

Lines changed: 35 additions & 0 deletions
(Generated file; diff not rendered.)

pgml-dashboard/package.json

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
{
  "dependencies": {
    "autosize": "^6.0.1",
    "dompurify": "^3.0.6",
    "marked": "^9.1.0"
  }
}

pgml-dashboard/src/api/chatbot.rs

Lines changed: 338 additions & 0 deletions
@@ -0,0 +1,338 @@
use anyhow::Context;
use pgml::{Collection, Pipeline};
use rand::{distributions::Alphanumeric, Rng};
use reqwest::Client;
use rocket::{
    http::Status,
    outcome::IntoOutcome,
    request::{self, FromRequest},
    route::Route,
    serde::json::Json,
    Request,
};
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::time::{SystemTime, UNIX_EPOCH};

use crate::{
    forms,
    responses::{Error, ResponseOk},
};

pub struct User {
    chatbot_session_id: String,
}

#[rocket::async_trait]
impl<'r> FromRequest<'r> for User {
    type Error = ();

    async fn from_request(request: &'r Request<'_>) -> request::Outcome<User, ()> {
        request
            .cookies()
            .get_private("chatbot_session_id")
            .map(|c| User {
                chatbot_session_id: c.value().to_string(),
            })
            .or_forward(Status::Unauthorized)
    }
}

#[derive(Serialize, Deserialize, PartialEq, Eq)]
enum ChatRole {
    User,
    Bot,
}

#[derive(Clone, Copy, Serialize, Deserialize)]
enum ChatbotBrain {
    OpenAIGPT4,
    PostgresMLFalcon180b,
    AnthropicClaude,
    MetaLlama2,
}

impl TryFrom<u8> for ChatbotBrain {
    type Error = anyhow::Error;

    fn try_from(value: u8) -> anyhow::Result<Self> {
        match value {
            0 => Ok(ChatbotBrain::OpenAIGPT4),
            1 => Ok(ChatbotBrain::PostgresMLFalcon180b),
            2 => Ok(ChatbotBrain::AnthropicClaude),
            3 => Ok(ChatbotBrain::MetaLlama2),
            _ => Err(anyhow::anyhow!("Invalid brain id")),
        }
    }
}

#[derive(Clone, Copy, Serialize, Deserialize)]
enum KnowledgeBase {
    PostgresML,
    PyTorch,
    Rust,
    PostgreSQL,
}

impl KnowledgeBase {
    // The topic and knowledge base are the same for now but may be different later
    fn topic(&self) -> &'static str {
        match self {
            Self::PostgresML => "PostgresML",
            Self::PyTorch => "PyTorch",
            Self::Rust => "Rust",
            Self::PostgreSQL => "PostgreSQL",
        }
    }

    fn collection(&self) -> &'static str {
        match self {
            Self::PostgresML => "PostgresML",
            Self::PyTorch => "PyTorch",
            Self::Rust => "Rust",
            Self::PostgreSQL => "PostgreSQL",
        }
    }
}

impl TryFrom<u8> for KnowledgeBase {
    type Error = anyhow::Error;

    fn try_from(value: u8) -> anyhow::Result<Self> {
        match value {
            0 => Ok(KnowledgeBase::PostgresML),
            1 => Ok(KnowledgeBase::PyTorch),
            2 => Ok(KnowledgeBase::Rust),
            3 => Ok(KnowledgeBase::PostgreSQL),
            _ => Err(anyhow::anyhow!("Invalid knowledge base id")),
        }
    }
}

#[derive(Serialize, Deserialize)]
struct Document {
    id: String,
    text: String,
    role: ChatRole,
    user_id: String,
    model: ChatbotBrain,
    knowledge_base: KnowledgeBase,
    timestamp: u128,
}

impl Document {
    fn new(
        text: String,
        role: ChatRole,
        user_id: String,
        model: ChatbotBrain,
        knowledge_base: KnowledgeBase,
    ) -> Document {
        let id = rand::thread_rng()
            .sample_iter(&Alphanumeric)
            .take(32)
            .map(char::from)
            .collect();
        let timestamp = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_millis();
        Document {
            id,
            text,
            role,
            user_id,
            model,
            knowledge_base,
            timestamp,
        }
    }
}

async fn get_openai_chatgpt_answer(
    knowledge_base: KnowledgeBase,
    history: &str,
    context: &str,
    question: &str,
) -> Result<String, Error> {
    let openai_api_key = std::env::var("OPENAI_API_KEY")?;
    let base_prompt = std::env::var("CHATBOT_CHATGPT_BASE_PROMPT")?;
    let system_prompt = std::env::var("CHATBOT_CHATGPT_SYSTEM_PROMPT")?;

    let system_prompt = system_prompt
        .replace("{topic}", knowledge_base.topic())
        .replace("{persona}", "Engineer")
        .replace("{language}", "English");

    let content = base_prompt
        .replace("{history}", history)
        .replace("{context}", context)
        .replace("{question}", question);

    let body = json!({
        "model": "gpt-4",
        "messages": [{"role": "system", "content": system_prompt}, {"role": "user", "content": content}],
        "temperature": 0.7
    });

    let response = Client::new()
        .post("https://api.openai.com/v1/chat/completions")
        .bearer_auth(openai_api_key)
        .json(&body)
        .send()
        .await?
        .json::<serde_json::Value>()
        .await?;

    let response = response["choices"]
        .as_array()
        .context("No data returned from OpenAI")?[0]["message"]["content"]
        .as_str()
        .context("The response content from OpenAI was not a string")?
        .to_string();

    Ok(response)
}

#[post("/chatbot/get-answer", format = "json", data = "<data>")]
pub async fn chatbot_get_answer(
    user: User,
    data: Json<forms::ChatbotPostData>,
) -> Result<ResponseOk, Error> {
    match wrapped_chatbot_get_answer(user, data).await {
        Ok(response) => Ok(ResponseOk(
            json!({
                "answer": response,
            })
            .to_string(),
        )),
        Err(error) => {
            eprintln!("Error: {:?}", error);
            Ok(ResponseOk(
                json!({
                    "error": error.to_string(),
                })
                .to_string(),
            ))
        }
    }
}

pub async fn wrapped_chatbot_get_answer(
    user: User,
    data: Json<forms::ChatbotPostData>,
) -> Result<String, Error> {
    let brain = ChatbotBrain::try_from(data.model)?;
    let knowledge_base = KnowledgeBase::try_from(data.knowledge_base)?;

    // Create it up here so the timestamps that order the conversation are accurate
    let user_document = Document::new(
        data.question.clone(),
        ChatRole::User,
        user.chatbot_session_id.clone(),
        brain,
        knowledge_base,
    );

    let collection = knowledge_base.collection();
    let collection = Collection::new(
        collection,
        Some(std::env::var("CHATBOT_DATABASE_URL").expect("CHATBOT_DATABASE_URL not set")),
    );

    let mut history_collection = Collection::new(
        "ChatHistory",
        Some(std::env::var("CHATBOT_DATABASE_URL").expect("CHATBOT_DATABASE_URL not set")),
    );
    let messages = history_collection
        .get_documents(Some(
            json!({
                "limit": 5,
                "order_by": {"timestamp": "desc"},
                "filter": {
                    "metadata": {
                        "$and": [
                            {
                                "$or": [
                                    {"role": {"$eq": ChatRole::Bot}},
                                    {"role": {"$eq": ChatRole::User}}
                                ]
                            },
                            {
                                "user_id": {
                                    "$eq": user.chatbot_session_id
                                }
                            },
                            {
                                "knowledge_base": {
                                    "$eq": knowledge_base
                                }
                            },
                            {
                                "model": {
                                    "$eq": brain
                                }
                            }
                        ]
                    }
                }
            })
            .into(),
        ))
        .await?;

    let mut history = messages
        .into_iter()
        .map(|m| {
            // Can probably remove this clone
            let chat_role: ChatRole = serde_json::from_value(m["document"]["role"].to_owned())?;
            if chat_role == ChatRole::Bot {
                Ok(format!("Assistant: {}", m["document"]["text"]))
            } else {
                Ok(format!("User: {}", m["document"]["text"]))
            }
        })
        .collect::<anyhow::Result<Vec<String>>>()?;
    history.reverse();
    let history = history.join("\n");

    let mut pipeline = Pipeline::new("v1", None, None, None);
    let context = collection
        .query()
        .vector_recall(
            &data.question,
            &mut pipeline,
            Some(
                json!({
                    "instruction": "Represent the Wikipedia question for retrieving supporting documents: "
                })
                .into(),
            ),
        )
        .limit(5)
        .fetch_all()
        .await?
        .into_iter()
        .map(|(_, context, metadata)| format!("#### Document {}: {}", metadata["id"], context))
        .collect::<Vec<String>>()
        .join("\n");

    let answer = match brain {
        _ => get_openai_chatgpt_answer(knowledge_base, &history, &context, &data.question).await,
    }?;

    let new_history_messages: Vec<pgml::types::Json> = vec![
        serde_json::to_value(user_document).unwrap().into(),
        serde_json::to_value(Document::new(
            answer.clone(),
            ChatRole::Bot,
            user.chatbot_session_id.clone(),
            brain,
            knowledge_base,
        ))
        .unwrap()
        .into(),
    ];

    // We do not want to block our return waiting for this to happen
    tokio::spawn(async move {
        history_collection
            .upsert_documents(new_history_messages, None)
            .await
            .expect("Failed to upsert user history");
    });

    Ok(answer)
}

pub fn routes() -> Vec<Route> {
    routes![chatbot_get_answer]
}
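The commit does not show forms::ChatbotPostData or the place where these routes get mounted; both live elsewhere among the 55 changed files. As a rough, non-authoritative sketch, inferred only from the fields the handler reads (data.model, data.knowledge_base, data.question) and from routes() returning Vec<Route>, the request form and the Rocket wiring might look roughly like this; the struct shape, module path, and mount point are assumptions, not part of this diff:

use serde::Deserialize;

// Assumed shape of the POST body for /chatbot/get-answer; the real definition
// lives in the dashboard's forms module and may differ.
#[derive(Deserialize)]
pub struct ChatbotPostData {
    pub model: u8,          // converted to ChatbotBrain via TryFrom<u8>
    pub knowledge_base: u8, // converted to KnowledgeBase via TryFrom<u8>
    pub question: String,
}

// Assumed wiring in the dashboard's Rocket entrypoint: mounting the returned
// routes is what makes POST /chatbot/get-answer reachable.
#[rocket::launch]
fn rocket() -> _ {
    rocket::build().mount("/", crate::api::chatbot::routes())
}

// A request would then look roughly like:
//   POST /chatbot/get-answer
//   {"model": 0, "knowledge_base": 0, "question": "How do I call pgml.embed?"}
// and, per the handler above, the response is {"answer": "..."} on success
// or {"error": "..."} on failure.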
