Skip to content

refactor: contains method in the base class #1701

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 25 additions & 8 deletions docarray/index/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,16 @@ def execute_query(self, query: Any, *args, **kwargs) -> Any:
"""
...

@abstractmethod
def _doc_exists(self, doc_id: str) -> bool:
"""
Checks if a given document exists in the index.

:param doc_id: The id of a document to check.
:return: True if the document exists in the index, False otherwise.
"""
...

@abstractmethod
def _find(
self,
Expand Down Expand Up @@ -403,6 +413,21 @@ def __delitem__(self, key: Union[str, Sequence[str]]):
# delete data
self._del_items(key)

def __contains__(self, item: BaseDoc) -> bool:
"""
Checks if a given document exists in the index.

:param item: The document to check.
It must be an instance of BaseDoc or its subclass.
:return: True if the document exists in the index, False otherwise.
"""
if safe_issubclass(type(item), BaseDoc):
return self._doc_exists(str(item.id))
else:
raise TypeError(
f"item must be an instance of BaseDoc or its subclass, not '{type(item).__name__}'"
)

def configure(self, runtime_config=None, **kwargs):
"""
Configure the DocumentIndex.
Expand Down Expand Up @@ -1170,14 +1195,6 @@ def _get_root_doc_id(self, id: str, root: str, sub: str) -> str:
)
return self._get_root_doc_id(cur_root_id, root, '')

def __contains__(self, item: BaseDoc) -> bool:
"""Checks if a given BaseDoc item is contained in the index.

:param item: the given BaseDoc
:return: if the given BaseDoc item is contained in the index
"""
return False # Will be overridden by backends

def subindex_contains(self, item: BaseDoc) -> bool:
"""Checks if a given BaseDoc item is contained in the index or any of its subindices.

Expand Down
15 changes: 5 additions & 10 deletions docarray/index/backends/elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,16 +669,11 @@ def _format_response(self, response: Any) -> Tuple[List[Dict], List[Any]]:
def _refresh(self, index_name: str):
self._client.indices.refresh(index=index_name)

def __contains__(self, item: BaseDoc) -> bool:
if safe_issubclass(type(item), BaseDoc):
if len(item.id) == 0:
return False
ret = self._client_mget([item.id])
return ret["docs"][0]["found"]
else:
raise TypeError(
f"item must be an instance of BaseDoc or its subclass, not '{type(item).__name__}'"
)
def _doc_exists(self, doc_id: str) -> bool:
if len(doc_id) == 0:
return False
ret = self._client_mget([doc_id])
return ret["docs"][0]["found"]

###############################################
# API Wrappers #
Expand Down
17 changes: 5 additions & 12 deletions docarray/index/backends/hnswlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,18 +393,11 @@ def _get_items(self, doc_ids: Sequence[str], out: bool = True) -> Sequence[TSche
raise KeyError(f'No document with id {doc_ids} found')
return out_docs

def __contains__(self, item: BaseDoc):
if safe_issubclass(type(item), BaseDoc):
hash_id = self._to_hashed_id(item.id)
self._sqlite_cursor.execute(
f"SELECT data FROM docs WHERE doc_id = '{hash_id}'"
)
rows = self._sqlite_cursor.fetchall()
return len(rows) > 0
else:
raise TypeError(
f"item must be an instance of BaseDoc or its subclass, not '{type(item).__name__}'"
)
def _doc_exists(self, doc_id: str) -> bool:
hash_id = self._to_hashed_id(doc_id)
self._sqlite_cursor.execute(f"SELECT data FROM docs WHERE doc_id = '{hash_id}'")
rows = self._sqlite_cursor.fetchall()
return len(rows) > 0

def num_docs(self) -> int:
"""
Expand Down
9 changes: 2 additions & 7 deletions docarray/index/backends/in_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,13 +431,8 @@ def _text_search_batched(
) -> _FindResultBatched:
raise NotImplementedError(f'{type(self)} does not support text search.')

def __contains__(self, item: BaseDoc):
if safe_issubclass(type(item), BaseDoc):
return any(doc.id == item.id for doc in self._docs)
else:
raise TypeError(
f"item must be an instance of BaseDoc or its subclass, not '{type(item).__name__}'"
)
def _doc_exists(self, doc_id: str) -> bool:
return any(doc.id == doc_id for doc in self._docs)

def persist(self, file: Optional[str] = None) -> None:
"""Persist InMemoryExactNNIndex into a binary file."""
Expand Down
25 changes: 10 additions & 15 deletions docarray/index/backends/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,21 +317,16 @@ def num_docs(self) -> int:
"""
return self._client.count(collection_name=self.collection_name).count

def __contains__(self, item: BaseDoc) -> bool:
if safe_issubclass(type(item), BaseDoc):
response, _ = self._client.scroll(
collection_name=self.index_name,
scroll_filter=rest.Filter(
must=[
rest.HasIdCondition(has_id=[self._to_qdrant_id(item.id)]),
],
),
)
return len(response) > 0
else:
raise TypeError(
f"item must be an instance of BaseDoc or its subclass, not '{type(item).__name__}'"
)
def _doc_exists(self, doc_id: str) -> bool:
response, _ = self._client.scroll(
collection_name=self.index_name,
scroll_filter=rest.Filter(
must=[
rest.HasIdCondition(has_id=[self._to_qdrant_id(doc_id)]),
],
),
)
return len(response) > 0

def _del_items(self, doc_ids: Sequence[str]):
items = self._get_items(doc_ids)
Expand Down
17 changes: 1 addition & 16 deletions docarray/index/backends/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ def _del_items(self, doc_ids: Sequence[str]) -> None:
):
self._client.delete(*batch)

def _doc_exists(self, doc_id) -> bool:
def _doc_exists(self, doc_id: str) -> bool:
"""
Checks if a document exists in the index.

Expand Down Expand Up @@ -610,18 +610,3 @@ def _text_search_batched(
scores.append(results.scores)

return _FindResultBatched(documents=docs, scores=scores)

def __contains__(self, item: BaseDoc) -> bool:
"""
Checks if a given document exists in the index.

:param item: The document to check.
It must be an instance of BaseDoc or its subclass.
:return: True if the document exists in the index, False otherwise.
"""
if safe_issubclass(type(item), BaseDoc):
return self._doc_exists(item.id)
else:
raise TypeError(
f"item must be an instance of BaseDoc or its subclass, not '{type(item).__name__}'"
)
31 changes: 13 additions & 18 deletions docarray/index/backends/weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,25 +760,20 @@ def _filter_by_parent_id(self, id: str) -> Optional[List[str]]:
]
return ids

def __contains__(self, item: BaseDoc) -> bool:
if safe_issubclass(type(item), BaseDoc):
result = (
self._client.query.get(self.index_name, ['docarrayid'])
.with_where(
{
"path": ['docarrayid'],
"operator": "Equal",
"valueString": f'{item.id}',
}
)
.do()
)
docs = result["data"]["Get"][self.index_name]
return docs is not None and len(docs) > 0
else:
raise TypeError(
f"item must be an instance of BaseDoc or its subclass, not '{type(item).__name__}'"
def _doc_exists(self, doc_id: str) -> bool:
result = (
self._client.query.get(self.index_name, ['docarrayid'])
.with_where(
{
"path": ['docarrayid'],
"operator": "Equal",
"valueString": f'{doc_id}',
}
)
.do()
)
docs = result["data"]["Get"][self.index_name]
return docs is not None and len(docs) > 0

class QueryBuilder(BaseDocIndex.QueryBuilder):
def __init__(self, document_index):
Expand Down
3 changes: 3 additions & 0 deletions tests/index/base_classes/test_base_doc_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ def python_type_to_db_type(self, x):
def num_docs(self):
return 3

def _doc_exists(self, doc_id: str) -> bool:
return False

_index = _identity
_del_items = _identity
_get_items = _identity
Expand Down
2 changes: 1 addition & 1 deletion tests/index/base_classes/test_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ class DBConfig(BaseDocIndex.DBConfig):

@dataclass
class RuntimeConfig(BaseDocIndex.RuntimeConfig):

default_ef: int = 50


Expand All @@ -61,6 +60,7 @@ def python_type_to_db_type(self, x):
_filter_batched = _identity
_text_search = _identity
_text_search_batched = _identity
_doc_exists = _identity


def test_defaults():
Expand Down
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy