Commit f6eda7a

pgml sdk rag example (#1249)
1 parent 793ad2f commit f6eda7a

File tree

1 file changed: +92 -0 lines changed

Lines changed: 92 additions & 0 deletions
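The new file is a complete retrieval-augmented generation (RAG) walkthrough built on the pgml Python SDK: it loads the SQuAD dataset, upserts 200 unique contexts into a PostgresML collection, retrieves the best-matching chunks with vector recall, and has HuggingFaceH4/zephyr-7b-beta (served through OpenSourceAI) answer a question grounded in that context.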
@@ -0,0 +1,92 @@
from pgml import Collection, Model, Splitter, Pipeline, OpenSourceAI
from datasets import load_dataset
from time import time
from dotenv import load_dotenv
from rich.console import Console
import asyncio


async def main():
    load_dotenv()
    console = Console()

    # Initialize collection
    collection = Collection("squad_collection")

    # Create a pipeline using the default model and splitter
    model = Model()
    splitter = Splitter()
    pipeline = Pipeline("squadv1", model, splitter)
    await collection.add_pipeline(pipeline)

    # Prep documents for upserting
    data = load_dataset("squad", split="train")
    data = data.to_pandas()
    data = data.drop_duplicates(subset=["context"])
    documents = [
        {"id": r["id"], "text": r["context"], "title": r["title"]}
        for r in data.to_dict(orient="records")
    ]

    # Upsert documents
    await collection.upsert_documents(documents[:200])
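    # Upserting runs the pipeline registered above: the SDK splits each
    # document into chunks and embeds them inside the database, so no
    # separate indexing step is needed. Only the first 200 unique contexts
    # are upserted to keep the example quick.
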
    # Query for context
    query = "Who won more than 20 grammy awards?"

    console.print("Question: %s" % query)
    console.print("Querying for context ...")

    start = time()
    results = (
        await collection.query().vector_recall(query, pipeline).limit(5).fetch_all()
    )
    end = time()

    console.print("Query time = %0.3f" % (end - start))

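    # fetch_all() returns ranked matches; here the second element of each
    # result tuple is the matched chunk text, which is why results[0][1]
    # below takes the text of the best match.
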
    # Construct context from results
    context = " ".join(results[0][1].strip().split())
    # Escape quotes so the retrieved text can be embedded safely in the prompt
    context = context.replace('"', '\\"').replace("'", "''")
    console.print("Context is ready...")

    # Query for answer
    system_prompt = """Use the following pieces of context to answer the question at the end.
    If you don't know the answer, just say that you don't know, don't try to make up an answer.
    Use three sentences maximum and keep the answer as concise as possible.
    Always say "thanks for asking!" at the end of the answer."""
    user_prompt_template = """
    ####
    Documents
    ####
    {context}
    ###
    User: {question}
    ###
    """

    user_prompt = user_prompt_template.format(context=context, question=query)
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    # Using OpenSource LLMs for Chat Completion
    client = OpenSourceAI()
    chat_completion_model = "HuggingFaceH4/zephyr-7b-beta"
    console.print("Generating response using %s LLM..." % chat_completion_model)
    response = client.chat_completions_create(
        model=chat_completion_model,
        messages=messages,
        temperature=0.3,
        max_tokens=256,
    )
    output = response["choices"][0]["message"]["content"]
    console.print("Answer: %s" % output)
    # Archive collection
    await collection.archive()


if __name__ == "__main__":
    asyncio.run(main())
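To run the example you need the imported packages installed (pgml, datasets, python-dotenv, and rich, plus pandas for the to_pandas() call) and a .env file supplying your PostgresML connection string; the exact variable name the SDK reads (e.g. DATABASE_URL) depends on the SDK version, so check the pgml docs. The final collection.archive() call archives squad_collection so the script can be re-run against a fresh collection.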
