@@ -45,7 +45,6 @@
 
 TSchema = TypeVar('TSchema', bound=BaseDoc)
 
-
 QDRANT_PY_VECTOR_TYPES: List[Any] = [np.ndarray, AbstractTensor]
 if torch_imported:
     import torch
@@ -86,6 +85,19 @@ def __init__(self, db_config=None, **kwargs):
         self._initialize_collection()
         self._logger.info(f'{self.__class__.__name__} has been initialized')
 
+    @property
+    def collection_name(self):
+        default_collection_name = (
+            self._schema.__name__.lower() if self._schema is not None else None
+        )
+        if default_collection_name is None:
+            raise ValueError(
+                'A QdrantDocumentIndex must be typed with a Document type. '
+                'To do so, use the syntax: QdrantDocumentIndex[DocumentType]'
+            )
+
+        return self._db_config.collection_name or default_collection_name
+
     @dataclass
     class Query:
         """Dataclass describing a query."""
@@ -211,7 +223,7 @@ class DBConfig(BaseDocIndex.DBConfig):
         timeout: Optional[float] = None
         host: Optional[str] = None
         path: Optional[str] = None
-        collection_name: str = 'documents'
+        collection_name: Optional[str] = None
         shard_number: Optional[int] = None
         replication_factor: Optional[int] = None
         write_consistency_factor: Optional[int] = None
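
With the dataclass default changed to None, the configured name only materialises through the new property. A quick check of the resulting behaviour might look like the following (a sketch assuming the usual docarray public imports and a reachable Qdrant instance; `MyDoc` is illustrative):

    from docarray import BaseDoc
    from docarray.typing import NdArray
    from docarray.index import QdrantDocumentIndex

    class MyDoc(BaseDoc):
        embedding: NdArray[128]

    config = QdrantDocumentIndex.DBConfig(host='localhost')
    print(config.collection_name)   # None now, instead of the old 'documents' default

    index = QdrantDocumentIndex[MyDoc](db_config=config)
    print(index.collection_name)    # 'mydoc', derived from the schema class name
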
@@ -250,15 +262,15 @@ def python_type_to_db_type(self, python_type: Type) -> Any:
 
     def _initialize_collection(self):
         try:
-            self._client.get_collection(self._db_config.collection_name)
+            self._client.get_collection(self.collection_name)
         except (UnexpectedResponse, RpcError, ValueError):
             vectors_config = {
                 column_name: self._to_qdrant_vector_params(column_info)
                 for column_name, column_info in self._column_infos.items()
                 if column_info.db_type == 'vector'
             }
             self._client.create_collection(
-                collection_name=self._db_config.collection_name,
+                collection_name=self.collection_name,
                 vectors_config=vectors_config,
                 shard_number=self._db_config.shard_number,
                 replication_factor=self._db_config.replication_factor,
@@ -270,7 +282,7 @@ def _initialize_collection(self):
                 quantization_config=self._db_config.quantization_config,
             )
             self._client.create_payload_index(
-                collection_name=self._db_config.collection_name,
+                collection_name=self.collection_name,
                 field_name='__generated_vectors',
                 field_schema=rest.PayloadSchemaType.KEYWORD,
             )
@@ -280,15 +292,15 @@ def _index(self, column_to_data: Dict[str, Generator[Any, None, None]]):
         # TODO: add batching the documents to avoid timeouts
         points = [self._build_point_from_row(row) for row in rows]
         self._client.upsert(
-            collection_name=self._db_config.collection_name,
+            collection_name=self.collection_name,
             points=points,
         )
 
     def num_docs(self) -> int:
         """
         Get the number of documents.
         """
-        return self._client.count(collection_name=self._db_config.collection_name).count
+        return self._client.count(collection_name=self.collection_name).count
 
     def _del_items(self, doc_ids: Sequence[str]):
         items = self._get_items(doc_ids)
@@ -298,7 +310,7 @@ def _del_items(self, doc_ids: Sequence[str]):
             raise KeyError('Document keys could not found: %s' % ','.join(missing_keys))
 
         self._client.delete(
-            collection_name=self._db_config.collection_name,
+            collection_name=self.collection_name,
             points_selector=rest.PointIdsList(
                 points=[self._to_qdrant_id(doc_id) for doc_id in doc_ids],
             ),
@@ -308,7 +320,7 @@ def _get_items(
         self, doc_ids: Sequence[str]
     ) -> Union[Sequence[TSchema], Sequence[Dict[str, Any]]]:
         response, _ = self._client.scroll(
-            collection_name=self._db_config.collection_name,
+            collection_name=self.collection_name,
             scroll_filter=rest.Filter(
                 must=[
                     rest.HasIdCondition(
@@ -343,7 +355,7 @@ def execute_query(self, query: Union[Query, RawQuery], *args, **kwargs) -> DocLi
             # We perform semantic search with some vectors with Qdrant's search method
             # should be called
             points = self._client.search(  # type: ignore[assignment]
-                collection_name=self._db_config.collection_name,
+                collection_name=self.collection_name,
                 query_vector=(query.vector_field, query.vector_query),  # type: ignore[arg-type]
                 query_filter=rest.Filter(
                     must=[query.filter],
@@ -364,7 +376,7 @@ def execute_query(self, query: Union[Query, RawQuery], *args, **kwargs) -> DocLi
         else:
             # Just filtering, so Qdrant's scroll has to be used instead
             points, _ = self._client.scroll(  # type: ignore[assignment]
-                collection_name=self._db_config.collection_name,
+                collection_name=self.collection_name,
                 scroll_filter=query.filter,
                 limit=query.limit,
                 with_payload=True,
@@ -388,7 +400,7 @@ def _execute_raw_query(
             if search_params:
                 search_params = rest.SearchParams.parse_obj(search_params)  # type: ignore[assignment]
             points = self._client.search(  # type: ignore[assignment]
-                collection_name=self._db_config.collection_name,
+                collection_name=self.collection_name,
                 query_vector=query.pop('vector'),
                 query_filter=payload_filter,
                 search_params=search_params,
@@ -397,7 +409,7 @@ def _execute_raw_query(
         else:
             # Just filtering, so Qdrant's scroll has to be used instead
             points, _ = self._client.scroll(  # type: ignore[assignment]
-                collection_name=self._db_config.collection_name,
+                collection_name=self.collection_name,
                 scroll_filter=payload_filter,
                 **query,
             )
@@ -417,7 +429,7 @@ def _find_batched(
         self, queries: np.ndarray, limit: int, search_field: str = ''
     ) -> _FindResultBatched:
         responses = self._client.search_batch(
-            collection_name=self._db_config.collection_name,
+            collection_name=self.collection_name,
             requests=[
                 rest.SearchRequest(
                     vector=rest.NamedVector(
@@ -470,7 +482,7 @@ def _filter_batched(
             # There is no batch scroll available in Qdrant client yet, so we need to
             # perform the queries one by one. It will be changed in the future versions.
             response, _ = self._client.scroll(
-                collection_name=self._db_config.collection_name,
+                collection_name=self.collection_name,
                 scroll_filter=filter_query,
                 limit=limit,
                 with_payload=True,
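
One migration note implied by the changes above: the implicit 'documents' default is gone, so existing setups that relied on it would presumably pin the name explicitly (sketch; reuses the illustrative MyDoc from above and assumes constructor kwargs are forwarded to DBConfig):

    # Keep targeting the collection created under the previous default name.
    index = QdrantDocumentIndex[MyDoc](host='localhost', collection_name='documents')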