
Commit eb69318

Joan Fontanals, JohannesMessner, and AnneYang720 authored
feat: index or collection name will default to doc-type name (docarray#1486)
Signed-off-by: Joan Fontanals Martinez <joan.martinez@jina.ai>
Signed-off-by: Johannes Messner <messnerjo@gmail.com>
Signed-off-by: AnneY <evangeline-lun@foxmail.com>
Co-authored-by: Johannes Messner <messnerjo@gmail.com>
Co-authored-by: AnneY <evangeline-lun@foxmail.com>
1 parent d1f13d6 commit eb69318

22 files changed with 374 additions and 182 deletions
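The change replaces the old defaults (a random 'index__<uuid>' name for Elasticsearch and a fixed 'documents' collection for Qdrant) with a name derived from the document schema's class name. Below is a minimal usage sketch of the new behaviour, assuming a locally running Elasticsearch instance; the MyDoc schema and host URL are illustrative, not part of the commit:

from docarray import BaseDoc
from docarray.index import ElasticDocIndex
from docarray.typing import NdArray


class MyDoc(BaseDoc):
    embedding: NdArray[128]  # illustrative schema with a single vector field


# Without an explicit index_name in the DBConfig, the index name now falls back
# to the lowercased document type name ('mydoc') instead of 'index__<uuid>'.
index = ElasticDocIndex[MyDoc](hosts='http://localhost:9200')
assert index.index_name == 'mydoc'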

docarray/index/backends/elastic.py

Lines changed: 30 additions & 24 deletions
@@ -1,5 +1,4 @@
 # mypy: ignore-errors
-import uuid
 import warnings
 from collections import defaultdict
 from dataclasses import dataclass, field
@@ -74,12 +73,6 @@ def __init__(self, db_config=None, **kwargs):
         self._logger.debug('Elastic Search index is being initialized')
 
         # ElasticSearch client creation
-        if self._db_config.index_name is None:
-            id = uuid.uuid4().hex
-            self._db_config.index_name = 'index__' + id
-
-        self._index_name = self._db_config.index_name
-
         self._client = Elasticsearch(
             hosts=self._db_config.hosts,
             **self._db_config.es_config,
@@ -108,15 +101,28 @@ def __init__(self, db_config=None, **kwargs):
             mappings['properties'][col_name] = self._create_index_mapping(col)
 
         # print(mappings['properties'])
-        if self._client.indices.exists(index=self._index_name):
+        if self._client.indices.exists(index=self.index_name):
             self._client_put_mapping(mappings)
         else:
             self._client_create(mappings)
 
         if len(self._db_config.index_settings):
             self._client_put_settings(self._db_config.index_settings)
 
-        self._refresh(self._index_name)
+        self._refresh(self.index_name)
+
+    @property
+    def index_name(self):
+        default_index_name = (
+            self._schema.__name__.lower() if self._schema is not None else None
+        )
+        if default_index_name is None:
+            raise ValueError(
+                'A ElasticDocIndex must be typed with a Document type.'
+                'To do so, use the syntax: ElasticDocIndex[DocumentType]'
+            )
+
+        return self._db_config.index_name or default_index_name
 
     ###############################################
     # Inner classes for query builder and configs #
@@ -333,7 +339,7 @@ def _index(
 
         for row in data:
             request = {
-                '_index': self._index_name,
+                '_index': self.index_name,
                 '_id': row['id'],
             }
             for col_name, col in self._column_infos.items():
@@ -349,13 +355,13 @@ def _index(
             warnings.warn(str(info))
 
         if refresh:
-            self._refresh(self._index_name)
+            self._refresh(self.index_name)
 
     def num_docs(self) -> int:
         """
         Get the number of documents.
         """
-        return self._client.count(index=self._index_name)['count']
+        return self._client.count(index=self.index_name)['count']
 
     def _del_items(
         self,
@@ -365,7 +371,7 @@ def _del_items(
         requests = []
         for _id in doc_ids:
             requests.append(
-                {'_op_type': 'delete', '_index': self._index_name, '_id': _id}
+                {'_op_type': 'delete', '_index': self.index_name, '_id': _id}
             )
 
         _, warning_info = self._send_requests(requests, chunk_size)
@@ -375,7 +381,7 @@ def _del_items(
             ids = [info['delete']['_id'] for info in warning_info]
             warnings.warn(f'No document with id {ids} found')
 
-        self._refresh(self._index_name)
+        self._refresh(self.index_name)
 
     def _get_items(self, doc_ids: Sequence[str]) -> Sequence[TSchema]:
         accumulated_docs = []
@@ -416,7 +422,7 @@ def execute_query(self, query: Dict[str, Any], *args, **kwargs) -> Any:
                 f'args and kwargs not supported for `execute_query` on {type(self)}'
             )
 
-        resp = self._client.search(index=self._index_name, **query)
+        resp = self._client.search(index=self.index_name, **query)
         docs, scores = self._format_response(resp)
 
         return _FindResult(documents=docs, scores=parse_obj_as(NdArray, scores))
@@ -440,7 +446,7 @@ def _find_batched(
     ) -> _FindResultBatched:
         request = []
         for query in queries:
-            head = {'index': self._index_name}
+            head = {'index': self.index_name}
             body = self._form_search_body(query, limit, search_field)
             request.extend([head, body])
 
@@ -469,7 +475,7 @@ def _filter_batched(
    ) -> List[List[Dict]]:
         request = []
         for query in filter_queries:
-            head = {'index': self._index_name}
+            head = {'index': self.index_name}
             body = {'query': query, 'size': limit}
             request.extend([head, body])
 
@@ -499,7 +505,7 @@ def _text_search_batched(
    ) -> _FindResultBatched:
         request = []
         for query in queries:
-            head = {'index': self._index_name}
+            head = {'index': self.index_name}
             body = self._form_text_search_body(query, limit, search_field)
             request.extend([head, body])
 
@@ -615,20 +621,20 @@ def _refresh(self, index_name: str):
 
     def _client_put_mapping(self, mappings: Dict[str, Any]):
         self._client.indices.put_mapping(
-            index=self._index_name, properties=mappings['properties']
+            index=self.index_name, properties=mappings['properties']
         )
 
     def _client_create(self, mappings: Dict[str, Any]):
-        self._client.indices.create(index=self._index_name, mappings=mappings)
+        self._client.indices.create(index=self.index_name, mappings=mappings)
 
     def _client_put_settings(self, settings: Dict[str, Any]):
-        self._client.indices.put_settings(index=self._index_name, settings=settings)
+        self._client.indices.put_settings(index=self.index_name, settings=settings)
 
     def _client_mget(self, ids: Sequence[str]):
-        return self._client.mget(index=self._index_name, ids=ids)
+        return self._client.mget(index=self.index_name, ids=ids)
 
     def _client_search(self, **kwargs):
-        return self._client.search(index=self._index_name, **kwargs)
+        return self._client.search(index=self.index_name, **kwargs)
 
     def _client_msearch(self, request: List[Dict[str, Any]]):
-        return self._client.msearch(index=self._index_name, searches=request)
+        return self._client.msearch(index=self.index_name, searches=request)
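An explicit index_name in the DBConfig still takes precedence over the schema-derived default, since the new property returns self._db_config.index_name or default_index_name. A short sketch reusing the illustrative MyDoc schema from above, again assuming a reachable Elasticsearch host:

# Passing index_name explicitly overrides the doc-type-based default.
custom = ElasticDocIndex[MyDoc](
    hosts='http://localhost:9200', index_name='my_custom_index'
)
assert custom.index_name == 'my_custom_index'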

docarray/index/backends/elasticv7.py

Lines changed: 7 additions & 7 deletions
@@ -119,7 +119,7 @@ def execute_query(self, query: Dict[str, Any], *args, **kwargs) -> Any:
                 f'args and kwargs not supported for `execute_query` on {type(self)}'
             )
 
-        resp = self._client.search(index=self._index_name, body=query)
+        resp = self._client.search(index=self.index_name, body=query)
         docs, scores = self._format_response(resp)
 
         return _FindResult(documents=docs, scores=parse_obj_as(NdArray, scores))
@@ -161,20 +161,20 @@ def _form_search_body(self, query: np.ndarray, limit: int, search_field: str = '
     ###############################################
 
     def _client_put_mapping(self, mappings: Dict[str, Any]):
-        self._client.indices.put_mapping(index=self._index_name, body=mappings)
+        self._client.indices.put_mapping(index=self.index_name, body=mappings)
 
     def _client_create(self, mappings: Dict[str, Any]):
         body = {'mappings': mappings}
-        self._client.indices.create(index=self._index_name, body=body)
+        self._client.indices.create(index=self.index_name, body=body)
 
     def _client_put_settings(self, settings: Dict[str, Any]):
-        self._client.indices.put_settings(index=self._index_name, body=settings)
+        self._client.indices.put_settings(index=self.index_name, body=settings)
 
     def _client_mget(self, ids: Sequence[str]):
-        return self._client.mget(index=self._index_name, body={'ids': ids})
+        return self._client.mget(index=self.index_name, body={'ids': ids})
 
     def _client_search(self, **kwargs):
-        return self._client.search(index=self._index_name, body=kwargs)
+        return self._client.search(index=self.index_name, body=kwargs)
 
     def _client_msearch(self, request: List[Dict[str, Any]]):
-        return self._client.msearch(index=self._index_name, body=request)
+        return self._client.msearch(index=self.index_name, body=request)

docarray/index/backends/qdrant.py

Lines changed: 27 additions & 15 deletions
@@ -45,7 +45,6 @@
 
 TSchema = TypeVar('TSchema', bound=BaseDoc)
 
-
 QDRANT_PY_VECTOR_TYPES: List[Any] = [np.ndarray, AbstractTensor]
 if torch_imported:
     import torch
@@ -86,6 +85,19 @@ def __init__(self, db_config=None, **kwargs):
         self._initialize_collection()
         self._logger.info(f'{self.__class__.__name__} has been initialized')
 
+    @property
+    def collection_name(self):
+        default_collection_name = (
+            self._schema.__name__.lower() if self._schema is not None else None
+        )
+        if default_collection_name is None:
+            raise ValueError(
+                'A QdrantDocumentIndex must be typed with a Document type.'
+                'To do so, use the syntax: QdrantDocumentIndex[DocumentType]'
+            )
+
+        return self._db_config.collection_name or default_collection_name
+
     @dataclass
     class Query:
         """Dataclass describing a query."""
@@ -211,7 +223,7 @@ class DBConfig(BaseDocIndex.DBConfig):
         timeout: Optional[float] = None
         host: Optional[str] = None
         path: Optional[str] = None
-        collection_name: str = 'documents'
+        collection_name: Optional[str] = None
         shard_number: Optional[int] = None
         replication_factor: Optional[int] = None
         write_consistency_factor: Optional[int] = None
@@ -250,15 +262,15 @@ def python_type_to_db_type(self, python_type: Type) -> Any:
 
     def _initialize_collection(self):
         try:
-            self._client.get_collection(self._db_config.collection_name)
+            self._client.get_collection(self.collection_name)
         except (UnexpectedResponse, RpcError, ValueError):
             vectors_config = {
                 column_name: self._to_qdrant_vector_params(column_info)
                 for column_name, column_info in self._column_infos.items()
                 if column_info.db_type == 'vector'
             }
             self._client.create_collection(
-                collection_name=self._db_config.collection_name,
+                collection_name=self.collection_name,
                 vectors_config=vectors_config,
                 shard_number=self._db_config.shard_number,
                 replication_factor=self._db_config.replication_factor,
@@ -270,7 +282,7 @@ def _initialize_collection(self):
                 quantization_config=self._db_config.quantization_config,
             )
             self._client.create_payload_index(
-                collection_name=self._db_config.collection_name,
+                collection_name=self.collection_name,
                 field_name='__generated_vectors',
                 field_schema=rest.PayloadSchemaType.KEYWORD,
             )
@@ -280,15 +292,15 @@ def _index(self, column_to_data: Dict[str, Generator[Any, None, None]]):
         # TODO: add batching the documents to avoid timeouts
         points = [self._build_point_from_row(row) for row in rows]
         self._client.upsert(
-            collection_name=self._db_config.collection_name,
+            collection_name=self.collection_name,
             points=points,
         )
 
     def num_docs(self) -> int:
         """
         Get the number of documents.
         """
-        return self._client.count(collection_name=self._db_config.collection_name).count
+        return self._client.count(collection_name=self.collection_name).count
 
     def _del_items(self, doc_ids: Sequence[str]):
         items = self._get_items(doc_ids)
@@ -298,7 +310,7 @@ def _del_items(self, doc_ids: Sequence[str]):
             raise KeyError('Document keys could not found: %s' % ','.join(missing_keys))
 
         self._client.delete(
-            collection_name=self._db_config.collection_name,
+            collection_name=self.collection_name,
             points_selector=rest.PointIdsList(
                 points=[self._to_qdrant_id(doc_id) for doc_id in doc_ids],
             ),
@@ -308,7 +320,7 @@ def _get_items(
         self, doc_ids: Sequence[str]
     ) -> Union[Sequence[TSchema], Sequence[Dict[str, Any]]]:
         response, _ = self._client.scroll(
-            collection_name=self._db_config.collection_name,
+            collection_name=self.collection_name,
             scroll_filter=rest.Filter(
                 must=[
                     rest.HasIdCondition(
@@ -343,7 +355,7 @@ def execute_query(self, query: Union[Query, RawQuery], *args, **kwargs) -> DocLi
             # We perform semantic search with some vectors with Qdrant's search method
             # should be called
             points = self._client.search(  # type: ignore[assignment]
-                collection_name=self._db_config.collection_name,
+                collection_name=self.collection_name,
                 query_vector=(query.vector_field, query.vector_query),  # type: ignore[arg-type]
                 query_filter=rest.Filter(
                     must=[query.filter],
@@ -364,7 +376,7 @@ def execute_query(self, query: Union[Query, RawQuery], *args, **kwargs) -> DocLi
         else:
             # Just filtering, so Qdrant's scroll has to be used instead
             points, _ = self._client.scroll(  # type: ignore[assignment]
-                collection_name=self._db_config.collection_name,
+                collection_name=self.collection_name,
                 scroll_filter=query.filter,
                 limit=query.limit,
                 with_payload=True,
@@ -388,7 +400,7 @@ def _execute_raw_query(
             if search_params:
                 search_params = rest.SearchParams.parse_obj(search_params)  # type: ignore[assignment]
             points = self._client.search(  # type: ignore[assignment]
-                collection_name=self._db_config.collection_name,
+                collection_name=self.collection_name,
                 query_vector=query.pop('vector'),
                 query_filter=payload_filter,
                 search_params=search_params,
@@ -397,7 +409,7 @@ def _execute_raw_query(
         else:
             # Just filtering, so Qdrant's scroll has to be used instead
             points, _ = self._client.scroll(  # type: ignore[assignment]
-                collection_name=self._db_config.collection_name,
+                collection_name=self.collection_name,
                 scroll_filter=payload_filter,
                 **query,
             )
@@ -417,7 +429,7 @@ def _find_batched(
         self, queries: np.ndarray, limit: int, search_field: str = ''
    ) -> _FindResultBatched:
         responses = self._client.search_batch(
-            collection_name=self._db_config.collection_name,
+            collection_name=self.collection_name,
             requests=[
                 rest.SearchRequest(
                     vector=rest.NamedVector(
@@ -470,7 +482,7 @@ def _filter_batched(
             # There is no batch scroll available in Qdrant client yet, so we need to
             # perform the queries one by one. It will be changed in the future versions.
             response, _ = self._client.scroll(
-                collection_name=self._db_config.collection_name,
+                collection_name=self.collection_name,
                 scroll_filter=filter_query,
                 limit=limit,
                 with_payload=True,
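QdrantDocumentIndex follows the same pattern through its new collection_name property: with DBConfig.collection_name now defaulting to None, the collection is named after the document type unless one is given explicitly. A minimal sketch, assuming a locally running Qdrant instance and the same illustrative MyDoc schema:

from docarray import BaseDoc
from docarray.index import QdrantDocumentIndex
from docarray.typing import NdArray


class MyDoc(BaseDoc):
    embedding: NdArray[128]  # illustrative schema, as above


# Default: the collection is named 'mydoc' (lowercased schema name) instead of 'documents'.
store = QdrantDocumentIndex[MyDoc](host='localhost')
assert store.collection_name == 'mydoc'

# An explicit collection_name in the DBConfig still takes precedence.
custom = QdrantDocumentIndex[MyDoc](host='localhost', collection_name='my_collection')
assert custom.collection_name == 'my_collection'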
