From f95b623180483b490221f0f48c2861dfe6fc2735 Mon Sep 17 00:00:00 2001 From: Joan Fontanals Martinez Date: Wed, 26 Jul 2023 13:00:55 +0200 Subject: [PATCH 1/2] refactor: do not recompute every time num_docs Signed-off-by: Joan Fontanals Martinez --- docarray/index/backends/hnswlib.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/docarray/index/backends/hnswlib.py b/docarray/index/backends/hnswlib.py index 6048a5fec8..2841b26e3b 100644 --- a/docarray/index/backends/hnswlib.py +++ b/docarray/index/backends/hnswlib.py @@ -137,7 +137,7 @@ def __init__(self, db_config=None, **kwargs): self._column_names: List[str] = [] self._create_docs_table() self._sqlite_conn.commit() - self._num_docs = self._get_num_docs_sqlite() + self._num_docs = 0 # recompute again when needed self._logger.info(f'{self.__class__.__name__} has been initialized') @property @@ -281,7 +281,7 @@ def index(self, docs: Union[BaseDoc, Sequence[BaseDoc]], **kwargs): self._send_docs_to_sqlite(docs_validated) self._sqlite_conn.commit() - self._num_docs = self._get_num_docs_sqlite() + self._num_docs = 0 # recompute again when needed def execute_query(self, query: List[Tuple[str, Dict]], *args, **kwargs) -> Any: """ @@ -318,9 +318,6 @@ def _find_batched( def _find( self, query: np.ndarray, limit: int, search_field: str = '' ) -> _FindResult: - if self.num_docs() == 0: - return _FindResult(documents=[], scores=[]) # type: ignore - query_batched = np.expand_dims(query, axis=0) docs, scores = self._find_batched( queries=query_batched, limit=limit, search_field=search_field @@ -385,7 +382,7 @@ def _del_items(self, doc_ids: Sequence[str]): self._delete_docs_from_sqlite(doc_ids) self._sqlite_conn.commit() - self._num_docs = self._get_num_docs_sqlite() + self._num_docs = 0 # recompute again when needed def _get_items(self, doc_ids: Sequence[str], out: bool = True) -> Sequence[TSchema]: """Get Documents from the hnswlib index, by `id`. @@ -410,6 +407,8 @@ def num_docs(self) -> int: """ Get the number of documents. """ + if self._num_docs == 0: + self._num_docs = self._get_num_docs_sqlite() return self._num_docs ############################################### @@ -605,11 +604,11 @@ def _search_and_filter( documents and their corresponding scores. """ # If there are no documents or hashed_ids is an empty set, return an empty _FindResultBatched - if self.num_docs() == 0 or (hashed_ids is not None and len(hashed_ids) == 0): + if hashed_ids is not None and len(hashed_ids) == 0: return _FindResultBatched(documents=[], scores=[]) # type: ignore # Set limit as the minimum of the provided limit and the total number of documents - limit = min(limit, self.num_docs()) + limit = limit # Ensure the search field is in the HNSW indices if search_field not in self._hnsw_indices: From 3ca5780b9f5c89769e8ead133475151d205105d3 Mon Sep 17 00:00:00 2001 From: jupyterjazz Date: Wed, 26 Jul 2023 13:28:53 +0200 Subject: [PATCH 2/2] fix: k more than num docs Signed-off-by: jupyterjazz --- docarray/index/backends/hnswlib.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docarray/index/backends/hnswlib.py b/docarray/index/backends/hnswlib.py index 2841b26e3b..8b40c043c9 100644 --- a/docarray/index/backends/hnswlib.py +++ b/docarray/index/backends/hnswlib.py @@ -608,7 +608,7 @@ def _search_and_filter( return _FindResultBatched(documents=[], scores=[]) # type: ignore # Set limit as the minimum of the provided limit and the total number of documents - limit = limit + limit = min(limit, self.num_docs()) # Ensure the search field is in the HNSW indices if search_field not in self._hnsw_indices: pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy