from pgml import Collection, Model, Splitter, Pipeline, OpenSourceAI
from datasets import load_dataset
from time import time
from dotenv import load_dotenv
from rich.console import Console
import asyncio


async def main():
    load_dotenv()
    console = Console()

    # Initialize collection
    collection = Collection("squad_collection")
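    # The collection stores documents, chunks, and embeddings in the
    # PostgresML database, using the connection settings loaded by load_dotenv()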

    # Create a pipeline using the default model and splitter
    model = Model()
    splitter = Splitter()
    pipeline = Pipeline("squadv1", model, splitter)
    await collection.add_pipeline(pipeline)
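    # Once added, the pipeline splits and embeds every document
    # subsequently upserted into the collection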

    # Prep documents for upserting
    data = load_dataset("squad", split="train")
    data = data.to_pandas()
    data = data.drop_duplicates(subset=["context"])
    documents = [
        {"id": r["id"], "text": r["context"], "title": r["title"]}
        for r in data.to_dict(orient="records")
    ]

    # Upsert documents
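    # Only the first 200 deduplicated contexts are upserted to keep the demo quick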
    await collection.upsert_documents(documents[:200])

    # Query for context
    query = "Who won more than 20 grammy awards?"

    console.print("Question: %s" % query)
    console.print("Querying for context ...")

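    # vector_recall embeds the query with the pipeline's model and returns the
    # closest chunks; each result is a (score, chunk text, metadata) tuple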
    start = time()
    results = (
        await collection.query().vector_recall(query, pipeline).limit(5).fetch_all()
    )
    end = time()

    console.print("Query time = %0.3f seconds" % (end - start))

    # Use the top-ranked chunk as the context, collapsing extra whitespace
    context = " ".join(results[0][1].strip().split())
    context = context.replace('"', '\\"').replace("'", "''")
    console.print("Context is ready...")

    # Query for answer
    system_prompt = """Use the following pieces of context to answer the question at the end.
    If you don't know the answer, just say that you don't know, don't try to make up an answer.
    Use three sentences maximum and keep the answer as concise as possible.
    Always say "thanks for asking!" at the end of the answer."""
    user_prompt_template = """
    ####
    Documents
    ####
    {context}
    ####
    User: {question}
    ####
    """

    user_prompt = user_prompt_template.format(context=context, question=query)
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    # Use an open-source LLM for chat completion
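    # OpenSourceAI exposes an OpenAI-style chat completions API over
    # open-weight models served through the PostgresML database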
    client = OpenSourceAI()
    chat_completion_model = "HuggingFaceH4/zephyr-7b-beta"
    console.print("Generating response using %s LLM..." % chat_completion_model)
    response = client.chat_completions_create(
        model=chat_completion_model,
        messages=messages,
        temperature=0.3,
        max_tokens=256,
    )
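    # The response mirrors the OpenAI schema: choices[0].message.content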
    output = response["choices"][0]["message"]["content"]
    console.print("Answer: %s" % output)

    # Archive collection
    await collection.archive()


if __name__ == "__main__":
    asyncio.run(main())