Skip to content

Commit eb27ee2

Browse files
Merge pull request #32 from Anush008/master
feat: Qdrant vectorstore support
2 parents f7e24b3 + 28c32a7 commit eb27ee2

File tree

3 files changed

+165
-6
lines changed

3 files changed

+165
-6
lines changed

mindsql/vectorstores/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from .ivectorstore import IVectorstore
22
from .chromadb import ChromaDB
33
from .faiss_db import Faiss
4+
from .qdrant import Qdrant

mindsql/vectorstores/qdrant.py

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
import json
2+
import os
3+
import uuid
4+
from typing import List
5+
6+
import pandas as pd
7+
from qdrant_client import QdrantClient
8+
from qdrant_client.http.models import Distance, VectorParams, PointStruct
9+
from sentence_transformers import SentenceTransformer
10+
11+
from . import IVectorstore
12+
13+
# Default embedding model, loaded eagerly at module import time (a heavyweight
# side effect: first use downloads the model weights). Presumably emits
# 1024-dim vectors to match the class's default "dimension" — TODO confirm.
sentence_transformer_ef = SentenceTransformer("WhereIsAI/UAE-Large-V1")
14+
15+
16+
class Qdrant(IVectorstore):
    """Vector store backed by a Qdrant server.

    Manages three cosine-distance collections — "sql" (question/SQL pairs
    stored as JSON), "ddl" (table definitions) and "documentation"
    (free-form text) — embedding every document with
    ``self.embedding_function``.
    """

    # The three fixed collections this store manages.
    _COLLECTION_NAMES = ("sql", "ddl", "documentation")
    # Maps the 4-char id suffix produced by the index_* methods back to
    # the collection the point lives in (used by delete_vectorstore_data).
    _SUFFIX_TO_COLLECTION = {"-sql": "sql", "-ddl": "ddl", "-doc": "documentation"}

    def __init__(self, config=None):
        """Connect to Qdrant and ensure the managed collections exist.

        :param config: optional dict with keys
            - "embedding_function": encoder exposing ``encode(list[str])``
              (defaults to the module-level SentenceTransformer);
            - "dimension": embedding size (default 1024);
            - "qdrant_client_options": kwargs forwarded to ``QdrantClient``
              (e.g. url, api_key); an empty dict means local defaults.
        """
        # A falsy config (None or {}) yields the same defaults the two
        # original branches produced, without duplicating them.
        cfg = config or {}
        self.embedding_function = cfg.get("embedding_function", sentence_transformer_ef)
        self.dimension = cfg.get("dimension", 1024)
        self.client = QdrantClient(**cfg.get("qdrant_client_options", {}))
        self._init_collections()

    def _create_collection(self, name: str) -> None:
        """Create collection *name* with this store's vector parameters."""
        self.client.create_collection(
            collection_name=name,
            vectors_config=VectorParams(size=self.dimension, distance=Distance.COSINE),
        )

    def _init_collections(self) -> None:
        """Create any of the three managed collections that are missing."""
        for name in self._COLLECTION_NAMES:
            if not self.client.collection_exists(collection_name=name):
                self._create_collection(name)

    def _embed(self, text: str):
        """Return the embedding vector for a single *text*."""
        return self.embedding_function.encode([text])[0]

    def _upsert(self, collection: str, payload: dict) -> str:
        """Embed ``payload["data"]``, upsert it into *collection*, and
        return the new point's UUID string."""
        chunk_id = str(uuid.uuid4())
        self.client.upsert(
            collection_name=collection,
            points=[
                PointStruct(
                    id=chunk_id,
                    vector=self._embed(payload["data"]),
                    payload=payload,
                )
            ],
        )
        return chunk_id

    def index_question_sql(self, question: str, sql: str, **kwargs) -> str:
        """Store a question/SQL pair; return its id (point UUID + "-sql")."""
        document = json.dumps({"question": question, "sql": sql}, ensure_ascii=False)
        return self._upsert("sql", {"data": document}) + "-sql"

    def index_ddl(self, ddl: str, **kwargs) -> str:
        """Store a DDL statement; return its id (point UUID + "-ddl").

        An optional ``table`` kwarg is recorded as the "table_name" payload key.
        """
        payload = {"data": ddl}
        table = kwargs.get("table")
        if table:
            payload["table_name"] = table
        return self._upsert("ddl", payload) + "-ddl"

    def index_documentation(self, documentation: str, **kwargs) -> str:
        """Store documentation text; return its id (point UUID + "-doc")."""
        return self._upsert("documentation", {"data": documentation}) + "-doc"

    def fetch_all_vectorstore_data(self, **kwargs) -> pd.DataFrame:
        """Return every stored point as a DataFrame with columns
        id / question / content / training_data_type.

        Pages through ``scroll`` until the server reports no next offset,
        instead of capping at a single 10000-point batch — stores larger
        than one page are now fully returned.
        """
        rows = []
        for name in self._COLLECTION_NAMES:
            offset = None
            while True:
                points, offset = self.client.scroll(
                    collection_name=name, limit=1000, offset=offset
                )
                for point in points:
                    payload = point.payload or {}
                    if name == "sql":
                        # "sql" payloads hold a JSON question/sql document.
                        doc = json.loads(payload.get("data", "{}"))
                        question, content = doc.get("question"), doc.get("sql")
                    else:
                        question, content = None, payload.get("data")
                    rows.append(
                        {
                            "id": point.id,
                            "question": question,
                            "content": content,
                            "training_data_type": name,
                        }
                    )
                if offset is None:
                    break
        return pd.DataFrame(rows)

    def delete_vectorstore_data(self, item_id: str, **kwargs) -> bool:
        """Delete the point behind *item_id* (as returned by index_*).

        :return: True when the suffix identifies a managed collection,
            False for unrecognized ids (nothing is deleted).
        """
        collection = self._SUFFIX_TO_COLLECTION.get(item_id[-4:])
        if collection is None:
            return False
        # Strip the 4-char suffix to recover the raw point UUID.
        self.client.delete(collection_name=collection, points_selector=[item_id[:-4]])
        return True

    def remove_collection(self, collection_name: str) -> bool:
        """Drop and immediately recreate *collection_name*, clearing its data.

        :return: True if the collection existed, False otherwise.
        """
        if not self.client.collection_exists(collection_name=collection_name):
            return False
        self.client.delete_collection(collection_name=collection_name)
        self._create_collection(collection_name)
        return True

    def _query(self, collection: str, question: str, n: int):
        """Return the top-*n* scored points for *question* in *collection*."""
        return self.client.query_points(
            collection_name=collection, query=self._embed(question), limit=n
        ).points

    def retrieve_relevant_question_sql(self, question: str, **kwargs) -> list:
        """Return up to ``n_results`` (default 2) question/SQL dicts most
        similar to *question*."""
        hits = self._query("sql", question, kwargs.get("n_results", 2))
        # Guard against a None payload, consistent with fetch_all_vectorstore_data.
        return [json.loads((hit.payload or {}).get("data", "{}")) for hit in hits]

    def retrieve_relevant_ddl(self, question: str, **kwargs) -> list:
        """Return up to ``n_results`` (default 2) stored DDL strings most
        similar to *question*."""
        hits = self._query("ddl", question, kwargs.get("n_results", 2))
        return [(hit.payload or {}).get("data") for hit in hits]

    def retrieve_relevant_documentation(self, question: str, **kwargs) -> list:
        """Return up to ``n_results`` (default 2) documentation strings most
        similar to *question*."""
        hits = self._query("documentation", question, kwargs.get("n_results", 2))
        return [(hit.payload or {}).get("data") for hit in hits]

pyproject.toml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,21 @@ classifiers = [
1616

1717

1818
[tool.poetry.dependencies]
19-
python = "^3.10"
20-
chromadb = "^0.4.22"
21-
pandas = "2.2.0"
19+
python = "^3.11"
20+
chromadb = "^1.0.15"
21+
pandas = "2.3.1"
2222
plotly = "5.19.0"
2323
mysql-connector-python = "^8.3.0"
2424
google-generativeai="0.3.2"
2525
llama-cpp-python = "0.2.47"
2626
openai = "^1.12.0"
2727
sqlparse = "^0.4.4"
28-
numpy = "^1.26.4"
28+
numpy = "2.3.1"
2929
sentence-transformers = "^2.3.1"
3030
psycopg2-binary = "^2.9.9"
31-
faiss-cpu = "^1.8.0"
32-
pysqlite3-binary = "^0.5.2.post3"
31+
faiss-cpu = "^1.11.0.post1"
3332
transformers = "^4.38.2"
33+
qdrant-client = "^1.14.3"
3434

3535

3636
[build-system]

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy