diff --git a/docarray/index/abstract.py b/docarray/index/abstract.py index 434d9734c82..2f8fddf70a1 100644 --- a/docarray/index/abstract.py +++ b/docarray/index/abstract.py @@ -230,6 +230,16 @@ def execute_query(self, query: Any, *args, **kwargs) -> Any: """ ... + @abstractmethod + def _doc_exists(self, doc_id: str) -> bool: + """ + Checks if a given document exists in the index. + + :param doc_id: The id of a document to check. + :return: True if the document exists in the index, False otherwise. + """ + ... + @abstractmethod def _find( self, @@ -403,6 +413,21 @@ def __delitem__(self, key: Union[str, Sequence[str]]): # delete data self._del_items(key) + def __contains__(self, item: BaseDoc) -> bool: + """ + Checks if a given document exists in the index. + + :param item: The document to check. + It must be an instance of BaseDoc or its subclass. + :return: True if the document exists in the index, False otherwise. + """ + if safe_issubclass(type(item), BaseDoc): + return self._doc_exists(str(item.id)) + else: + raise TypeError( + f"item must be an instance of BaseDoc or its subclass, not '{type(item).__name__}'" + ) + def configure(self, runtime_config=None, **kwargs): """ Configure the DocumentIndex. @@ -1170,14 +1195,6 @@ def _get_root_doc_id(self, id: str, root: str, sub: str) -> str: ) return self._get_root_doc_id(cur_root_id, root, '') - def __contains__(self, item: BaseDoc) -> bool: - """Checks if a given BaseDoc item is contained in the index. - - :param item: the given BaseDoc - :return: if the given BaseDoc item is contained in the index - """ - return False # Will be overridden by backends - def subindex_contains(self, item: BaseDoc) -> bool: """Checks if a given BaseDoc item is contained in the index or any of its subindices. diff --git a/docarray/index/backends/elastic.py b/docarray/index/backends/elastic.py index e201dca4ac0..7981ba3d4e8 100644 --- a/docarray/index/backends/elastic.py +++ b/docarray/index/backends/elastic.py @@ -669,16 +669,11 @@ def _format_response(self, response: Any) -> Tuple[List[Dict], List[Any]]: def _refresh(self, index_name: str): self._client.indices.refresh(index=index_name) - def __contains__(self, item: BaseDoc) -> bool: - if safe_issubclass(type(item), BaseDoc): - if len(item.id) == 0: - return False - ret = self._client_mget([item.id]) - return ret["docs"][0]["found"] - else: - raise TypeError( - f"item must be an instance of BaseDoc or its subclass, not '{type(item).__name__}'" - ) + def _doc_exists(self, doc_id: str) -> bool: + if len(doc_id) == 0: + return False + ret = self._client_mget([doc_id]) + return ret["docs"][0]["found"] ############################################### # API Wrappers # diff --git a/docarray/index/backends/hnswlib.py b/docarray/index/backends/hnswlib.py index b13e143a6c6..00124c7fbd0 100644 --- a/docarray/index/backends/hnswlib.py +++ b/docarray/index/backends/hnswlib.py @@ -393,18 +393,11 @@ def _get_items(self, doc_ids: Sequence[str], out: bool = True) -> Sequence[TSche raise KeyError(f'No document with id {doc_ids} found') return out_docs - def __contains__(self, item: BaseDoc): - if safe_issubclass(type(item), BaseDoc): - hash_id = self._to_hashed_id(item.id) - self._sqlite_cursor.execute( - f"SELECT data FROM docs WHERE doc_id = '{hash_id}'" - ) - rows = self._sqlite_cursor.fetchall() - return len(rows) > 0 - else: - raise TypeError( - f"item must be an instance of BaseDoc or its subclass, not '{type(item).__name__}'" - ) + def _doc_exists(self, doc_id: str) -> bool: + hash_id = self._to_hashed_id(doc_id) + self._sqlite_cursor.execute(f"SELECT data FROM docs WHERE doc_id = '{hash_id}'") + rows = self._sqlite_cursor.fetchall() + return len(rows) > 0 def num_docs(self) -> int: """ diff --git a/docarray/index/backends/in_memory.py b/docarray/index/backends/in_memory.py index 62ee3f0ffee..8a8132c2394 100644 --- a/docarray/index/backends/in_memory.py +++ b/docarray/index/backends/in_memory.py @@ -431,13 +431,8 @@ def _text_search_batched( ) -> _FindResultBatched: raise NotImplementedError(f'{type(self)} does not support text search.') - def __contains__(self, item: BaseDoc): - if safe_issubclass(type(item), BaseDoc): - return any(doc.id == item.id for doc in self._docs) - else: - raise TypeError( - f"item must be an instance of BaseDoc or its subclass, not '{type(item).__name__}'" - ) + def _doc_exists(self, doc_id: str) -> bool: + return any(doc.id == doc_id for doc in self._docs) def persist(self, file: Optional[str] = None) -> None: """Persist InMemoryExactNNIndex into a binary file.""" diff --git a/docarray/index/backends/qdrant.py b/docarray/index/backends/qdrant.py index 0ddf77a1e96..5288da19881 100644 --- a/docarray/index/backends/qdrant.py +++ b/docarray/index/backends/qdrant.py @@ -317,21 +317,16 @@ def num_docs(self) -> int: """ return self._client.count(collection_name=self.collection_name).count - def __contains__(self, item: BaseDoc) -> bool: - if safe_issubclass(type(item), BaseDoc): - response, _ = self._client.scroll( - collection_name=self.index_name, - scroll_filter=rest.Filter( - must=[ - rest.HasIdCondition(has_id=[self._to_qdrant_id(item.id)]), - ], - ), - ) - return len(response) > 0 - else: - raise TypeError( - f"item must be an instance of BaseDoc or its subclass, not '{type(item).__name__}'" - ) + def _doc_exists(self, doc_id: str) -> bool: + response, _ = self._client.scroll( + collection_name=self.index_name, + scroll_filter=rest.Filter( + must=[ + rest.HasIdCondition(has_id=[self._to_qdrant_id(doc_id)]), + ], + ), + ) + return len(response) > 0 def _del_items(self, doc_ids: Sequence[str]): items = self._get_items(doc_ids) diff --git a/docarray/index/backends/redis.py b/docarray/index/backends/redis.py index e82652d86e4..937c77efdaa 100644 --- a/docarray/index/backends/redis.py +++ b/docarray/index/backends/redis.py @@ -377,7 +377,7 @@ def _del_items(self, doc_ids: Sequence[str]) -> None: ): self._client.delete(*batch) - def _doc_exists(self, doc_id) -> bool: + def _doc_exists(self, doc_id: str) -> bool: """ Checks if a document exists in the index. @@ -610,18 +610,3 @@ def _text_search_batched( scores.append(results.scores) return _FindResultBatched(documents=docs, scores=scores) - - def __contains__(self, item: BaseDoc) -> bool: - """ - Checks if a given document exists in the index. - - :param item: The document to check. - It must be an instance of BaseDoc or its subclass. - :return: True if the document exists in the index, False otherwise. - """ - if safe_issubclass(type(item), BaseDoc): - return self._doc_exists(item.id) - else: - raise TypeError( - f"item must be an instance of BaseDoc or its subclass, not '{type(item).__name__}'" - ) diff --git a/docarray/index/backends/weaviate.py b/docarray/index/backends/weaviate.py index 20b1c649c3b..b001888dd98 100644 --- a/docarray/index/backends/weaviate.py +++ b/docarray/index/backends/weaviate.py @@ -760,25 +760,20 @@ def _filter_by_parent_id(self, id: str) -> Optional[List[str]]: ] return ids - def __contains__(self, item: BaseDoc) -> bool: - if safe_issubclass(type(item), BaseDoc): - result = ( - self._client.query.get(self.index_name, ['docarrayid']) - .with_where( - { - "path": ['docarrayid'], - "operator": "Equal", - "valueString": f'{item.id}', - } - ) - .do() - ) - docs = result["data"]["Get"][self.index_name] - return docs is not None and len(docs) > 0 - else: - raise TypeError( - f"item must be an instance of BaseDoc or its subclass, not '{type(item).__name__}'" + def _doc_exists(self, doc_id: str) -> bool: + result = ( + self._client.query.get(self.index_name, ['docarrayid']) + .with_where( + { + "path": ['docarrayid'], + "operator": "Equal", + "valueString": f'{doc_id}', + } ) + .do() + ) + docs = result["data"]["Get"][self.index_name] + return docs is not None and len(docs) > 0 class QueryBuilder(BaseDocIndex.QueryBuilder): def __init__(self, document_index): diff --git a/tests/index/base_classes/test_base_doc_store.py b/tests/index/base_classes/test_base_doc_store.py index 69b63c57e88..09f46ee4535 100644 --- a/tests/index/base_classes/test_base_doc_store.py +++ b/tests/index/base_classes/test_base_doc_store.py @@ -97,6 +97,9 @@ def python_type_to_db_type(self, x): def num_docs(self): return 3 + def _doc_exists(self, doc_id: str) -> bool: + return False + _index = _identity _del_items = _identity _get_items = _identity diff --git a/tests/index/base_classes/test_configs.py b/tests/index/base_classes/test_configs.py index 7b7efbea596..b2a5f0ecfd5 100644 --- a/tests/index/base_classes/test_configs.py +++ b/tests/index/base_classes/test_configs.py @@ -35,7 +35,6 @@ class DBConfig(BaseDocIndex.DBConfig): @dataclass class RuntimeConfig(BaseDocIndex.RuntimeConfig): - default_ef: int = 50 @@ -61,6 +60,7 @@ def python_type_to_db_type(self, x): _filter_batched = _identity _text_search = _identity _text_search_batched = _identity + _doc_exists = _identity def test_defaults(): pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy