From 9aebf2e912515716c4ef35088141673a5ac270bc Mon Sep 17 00:00:00 2001 From: Joan Fontanals Martinez Date: Wed, 26 Jul 2023 12:44:54 +0200 Subject: [PATCH 1/8] refactor: hnswlib performance Signed-off-by: Joan Fontanals Martinez --- docarray/index/backends/hnswlib.py | 57 +++++++++++++++++++----------- 1 file changed, 36 insertions(+), 21 deletions(-) diff --git a/docarray/index/backends/hnswlib.py b/docarray/index/backends/hnswlib.py index d4929569c6..ff45c21704 100644 --- a/docarray/index/backends/hnswlib.py +++ b/docarray/index/backends/hnswlib.py @@ -2,6 +2,7 @@ import hashlib import os import sqlite3 +from collections import OrderedDict from dataclasses import dataclass, field from pathlib import Path from typing import ( @@ -34,7 +35,7 @@ _collect_query_args, _execute_find_and_filter_query, ) -from docarray.proto import DocProto +from docarray.proto import DocProto, NdArrayProto, NodeProto from docarray.typing.tensor.abstract_tensor import AbstractTensor from docarray.typing.tensor.ndarray import NdArray from docarray.utils._internal._typing import safe_issubclass @@ -63,7 +64,6 @@ HNSWLIB_PY_VEC_TYPES.append(tf.Tensor) HNSWLIB_PY_VEC_TYPES.append(TensorFlowTensor) - TSchema = TypeVar('TSchema', bound=BaseDoc) T = TypeVar('T', bound='HnswDocumentIndex') @@ -127,7 +127,6 @@ def __init__(self, db_config=None, **kwargs): self._sqlite_cursor = self._sqlite_conn.cursor() self._create_docs_table() self._sqlite_conn.commit() - self._num_docs = self._get_num_docs_sqlite() self._logger.info(f'{self.__class__.__name__} has been initialized') @property @@ -255,12 +254,9 @@ def index(self, docs: Union[BaseDoc, Sequence[BaseDoc]], **kwargs): docs_validated = self._validate_docs(docs) self._update_subindex_data(docs_validated) data_by_columns = self._get_col_value_dict(docs_validated) - self._index(data_by_columns, docs_validated, **kwargs) - self._send_docs_to_sqlite(docs_validated) self._sqlite_conn.commit() - self._num_docs = self._get_num_docs_sqlite() def execute_query(self, query: List[Tuple[str, Dict]], *args, **kwargs) -> Any: """ @@ -293,10 +289,6 @@ def _find_batched( limit: int, search_field: str = '', ) -> _FindResultBatched: - if self.num_docs() == 0: - return _FindResultBatched(documents=[], scores=[]) # type: ignore - - limit = min(limit, self.num_docs()) index = self._hnsw_indices[search_field] labels, distances = index.knn_query(queries, k=int(limit)) @@ -311,9 +303,6 @@ def _find_batched( def _find( self, query: np.ndarray, limit: int, search_field: str = '' ) -> _FindResult: - if self.num_docs() == 0: - return _FindResult(documents=[], scores=[]) # type: ignore - query_batched = np.expand_dims(query, axis=0) docs, scores = self._find_batched( queries=query_batched, limit=limit, search_field=search_field @@ -381,7 +370,6 @@ def _del_items(self, doc_ids: Sequence[str]): self._delete_docs_from_sqlite(doc_ids) self._sqlite_conn.commit() - self._num_docs = self._get_num_docs_sqlite() def _get_items(self, doc_ids: Sequence[str], out: bool = True) -> Sequence[TSchema]: """Get Documents from the hnswlib index, by `id`. @@ -406,7 +394,7 @@ def num_docs(self) -> int: """ Get the number of documents. """ - return self._num_docs + return self._get_num_docs_sqlite() ############################################### # Helpers # @@ -471,10 +459,19 @@ def _get_docs_sqlite_unsorted(self, univ_ids: Sequence[int], out: bool = True): self._sqlite_cursor.execute( 'SELECT data FROM docs WHERE doc_id IN %s' % sql_id_list, ) + embeddings: OrderedDict[str, list] = OrderedDict() + for col_name, index in self._hnsw_indices.items(): + embeddings[col_name] = index.get_items(univ_ids) rows = self._sqlite_cursor.fetchall() schema = self.out_schema if out else self._schema - docs_cls = DocList.__class_getitem__(cast(Type[BaseDoc], schema)) - return docs_cls([self._doc_from_bytes(row[0], out) for row in rows]) + docs = DocList.__class_getitem__(cast(Type[BaseDoc], schema))() + for i, row in enumerate(rows): + reconstruct_embeddings = {} + for col_name in embeddings.keys(): + reconstruct_embeddings[col_name] = embeddings[col_name][i] + docs.append(self._doc_from_bytes(row[0], reconstruct_embeddings, out)) + + return docs def _get_docs_sqlite_doc_id( self, doc_ids: Sequence[str], out: bool = True @@ -509,12 +506,30 @@ def _get_num_docs_sqlite(self) -> int: # serialization helpers def _doc_to_bytes(self, doc: BaseDoc) -> bytes: - return doc.to_protobuf().SerializeToString() - - def _doc_from_bytes(self, data: bytes, out: bool = True) -> BaseDoc: + pb = doc.to_protobuf() + for col_name in self._hnsw_indices.keys(): + pb.data[col_name].Clear() + return pb.SerializeToString() + + def _doc_from_bytes( + self, data: bytes, reconstruct_embeddings: Dict, out: bool = True + ) -> BaseDoc: schema = self.out_schema if out else self._schema schema_cls = cast(Type[BaseDoc], schema) - return schema_cls.from_protobuf(DocProto.FromString(data)) + pb = DocProto.FromString(data) + for k, v in reconstruct_embeddings.items(): + nd_proto = NdArrayProto() + np_array = np.array(v) + nd_proto.dense.buffer = np_array.tobytes() + nd_proto.dense.ClearField('shape') + nd_proto.dense.shape.extend(list(np_array.shape)) + nd_proto.dense.dtype = np_array.dtype.str + node_proto = NodeProto(ndarray=nd_proto, type='ndarray') + + pb.data[k].MergeFrom(node_proto) + + doc = schema_cls.from_protobuf(pb) + return doc def _get_root_doc_id(self, id: str, root: str, sub: str) -> str: """Get the root_id given the id of a subindex Document and the root and subindex name for hnswlib. From f8bb5f922bb1ae480004aac251dafe35b5803391 Mon Sep 17 00:00:00 2001 From: Joan Fontanals Martinez Date: Thu, 27 Jul 2023 04:33:07 +0200 Subject: [PATCH 2/8] fix: reconstruct from proper type Signed-off-by: Joan Fontanals Martinez --- docarray/index/backends/hnswlib.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/docarray/index/backends/hnswlib.py b/docarray/index/backends/hnswlib.py index 2b69085e0b..edd41703ba 100644 --- a/docarray/index/backends/hnswlib.py +++ b/docarray/index/backends/hnswlib.py @@ -35,7 +35,7 @@ from docarray.index.backends.helper import ( _collect_query_args, ) -from docarray.proto import DocProto, NdArrayProto, NodeProto +from docarray.proto import DocProto from docarray.typing.tensor.abstract_tensor import AbstractTensor from docarray.typing.tensor.ndarray import NdArray from docarray.utils._internal._typing import safe_issubclass @@ -571,16 +571,15 @@ def _doc_from_bytes( ) -> BaseDoc: schema = self.out_schema if out else self._schema schema_cls = cast(Type[BaseDoc], schema) - pb = DocProto.FromString(data) + pb = DocProto.FromString( + data + ) # I cannot reconstruct directly the DA object because it may fail at validation because embedding may not be Optional for k, v in reconstruct_embeddings.items(): - nd_proto = NdArrayProto() - np_array = np.array(v) - nd_proto.dense.buffer = np_array.tobytes() - nd_proto.dense.ClearField('shape') - nd_proto.dense.shape.extend(list(np_array.shape)) - nd_proto.dense.dtype = np_array.dtype.str - node_proto = NodeProto(ndarray=nd_proto, type='ndarray') - + node_proto = ( + self.out_schema.__fields__[k] + .type_._docarray_from_ndarray(np.array(v)) + ._to_node_protobuf() + ) pb.data[k].MergeFrom(node_proto) doc = schema_cls.from_protobuf(pb) From ae33fe6bda78ec2e57283beff5aa54c5b88483a0 Mon Sep 17 00:00:00 2001 From: Joan Fontanals Martinez Date: Thu, 27 Jul 2023 05:56:36 +0200 Subject: [PATCH 3/8] fix: get the proper embedding back --- docarray/index/backends/hnswlib.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/docarray/index/backends/hnswlib.py b/docarray/index/backends/hnswlib.py index edd41703ba..d22ba80ad7 100644 --- a/docarray/index/backends/hnswlib.py +++ b/docarray/index/backends/hnswlib.py @@ -512,19 +512,22 @@ def _get_docs_sqlite_unsorted(self, univ_ids: Sequence[int], out: bool = True): assert isinstance(id_, int) or is_np_int(id_) sql_id_list = '(' + ', '.join(str(id_) for id_ in univ_ids) + ')' self._sqlite_cursor.execute( - 'SELECT data FROM docs WHERE doc_id IN %s' % sql_id_list, + 'SELECT doc_id, data FROM docs WHERE doc_id IN %s' % sql_id_list, ) + rows = ( + self._sqlite_cursor.fetchall() + ) # doc_ids do not come back in the same order embeddings: OrderedDict[str, list] = OrderedDict() for col_name, index in self._hnsw_indices.items(): - embeddings[col_name] = index.get_items(univ_ids) - rows = self._sqlite_cursor.fetchall() + embeddings[col_name] = index.get_items([row[0] for row in rows]) + schema = self.out_schema if out else self._schema docs = DocList.__class_getitem__(cast(Type[BaseDoc], schema))() - for i, row in enumerate(rows): + for i, (_, data_bytes) in enumerate(rows): reconstruct_embeddings = {} for col_name in embeddings.keys(): reconstruct_embeddings[col_name] = embeddings[col_name][i] - docs.append(self._doc_from_bytes(row[0], reconstruct_embeddings, out)) + docs.append(self._doc_from_bytes(data_bytes, reconstruct_embeddings, out)) return docs @@ -575,6 +578,7 @@ def _doc_from_bytes( data ) # I cannot reconstruct directly the DA object because it may fail at validation because embedding may not be Optional for k, v in reconstruct_embeddings.items(): + print(f'v {v}') node_proto = ( self.out_schema.__fields__[k] .type_._docarray_from_ndarray(np.array(v)) @@ -583,6 +587,7 @@ def _doc_from_bytes( pb.data[k].MergeFrom(node_proto) doc = schema_cls.from_protobuf(pb) + print(f'doc {doc}') return doc def _get_root_doc_id(self, id: str, root: str, sub: str) -> str: From a91e23ea76b6b33765c22b12fb3492d4fa9a10f0 Mon Sep 17 00:00:00 2001 From: Joan Fontanals Date: Thu, 27 Jul 2023 09:45:09 +0800 Subject: [PATCH 4/8] chore: avoid extra debugging (#1730) Signed-off-by: Joan Fontanals Martinez --- docarray/index/backends/elastic.py | 1 - tests/index/hnswlib/test_index_get_del.py | 22 +++++++++++----------- tests/index/hnswlib/test_persist_data.py | 6 +++--- 3 files changed, 14 insertions(+), 15 deletions(-) diff --git a/docarray/index/backends/elastic.py b/docarray/index/backends/elastic.py index 7981ba3d4e..b83ce82641 100644 --- a/docarray/index/backends/elastic.py +++ b/docarray/index/backends/elastic.py @@ -136,7 +136,6 @@ def index_name(self): self._logger.error(err_msg) raise ValueError(err_msg) index_name = self._db_config.index_name or default_index_name - self._logger.debug(f'Retrieved index name: {index_name}') return index_name ############################################### diff --git a/tests/index/hnswlib/test_index_get_del.py b/tests/index/hnswlib/test_index_get_del.py index ebf06ce4b1..34a194a9d3 100644 --- a/tests/index/hnswlib/test_index_get_del.py +++ b/tests/index/hnswlib/test_index_get_del.py @@ -211,7 +211,7 @@ def test_get_single(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path): for d in ten_simple_docs: id_ = d.id assert index[id_].id == id_ - assert np.all(index[id_].tens == d.tens) + assert np.allclose(index[id_].tens == d.tens) # flat index = HnswDocumentIndex[FlatDoc](work_dir=str(flat_path)) @@ -221,8 +221,8 @@ def test_get_single(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path): for d in ten_flat_docs: id_ = d.id assert index[id_].id == id_ - assert np.all(index[id_].tens_one == d.tens_one) - assert np.all(index[id_].tens_two == d.tens_two) + assert np.allclose(index[id_].tens_one == d.tens_one) + assert np.allclose(index[id_].tens_two == d.tens_two) # nested index = HnswDocumentIndex[NestedDoc](work_dir=str(nested_path)) @@ -233,7 +233,7 @@ def test_get_single(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path): id_ = d.id assert index[id_].id == id_ assert index[id_].d.id == d.d.id - assert np.all(index[id_].d.tens == d.d.tens) + assert np.allclose(index[id_].d.tens == d.d.tens) def test_get_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path): @@ -252,7 +252,7 @@ def test_get_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path) retrieved_docs = index[ids_to_get] for id_, d_in, d_out in zip(ids_to_get, docs_to_get, retrieved_docs): assert d_out.id == id_ - assert np.all(d_out.tens == d_in.tens) + assert np.allclose(d_out.tens == d_in.tens) # flat index = HnswDocumentIndex[FlatDoc](work_dir=str(flat_path)) @@ -264,8 +264,8 @@ def test_get_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path) retrieved_docs = index[ids_to_get] for id_, d_in, d_out in zip(ids_to_get, docs_to_get, retrieved_docs): assert d_out.id == id_ - assert np.all(d_out.tens_one == d_in.tens_one) - assert np.all(d_out.tens_two == d_in.tens_two) + assert np.allclose(d_out.tens_one == d_in.tens_one) + assert np.allclose(d_out.tens_two == d_in.tens_two) # nested index = HnswDocumentIndex[NestedDoc](work_dir=str(nested_path)) @@ -278,7 +278,7 @@ def test_get_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path) for id_, d_in, d_out in zip(ids_to_get, docs_to_get, retrieved_docs): assert d_out.id == id_ assert d_out.d.id == d_in.d.id - assert np.all(d_out.d.tens == d_in.d.tens) + assert np.allclose(d_out.d.tens == d_in.d.tens) def test_get_key_error(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path): @@ -303,7 +303,7 @@ def test_del_single(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path): index[id_] else: assert index[id_].id == id_ - assert np.all(index[id_].tens == d.tens) + assert np.allclose(index[id_].tens == d.tens) # delete again del index[ten_simple_docs[3].id] assert index.num_docs() == 8 @@ -314,7 +314,7 @@ def test_del_single(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path): index[id_] else: assert index[id_].id == id_ - assert np.all(index[id_].tens == d.tens) + assert np.allclose(index[id_].tens == d.tens) def test_del_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path): @@ -333,7 +333,7 @@ def test_del_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path) index[doc.id] else: assert index[doc.id].id == doc.id - assert np.all(index[doc.id].tens == doc.tens) + assert np.allclose(index[doc.id].tens == doc.tens) def test_del_key_error(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path): diff --git a/tests/index/hnswlib/test_persist_data.py b/tests/index/hnswlib/test_persist_data.py index fab761582c..d2395347d6 100644 --- a/tests/index/hnswlib/test_persist_data.py +++ b/tests/index/hnswlib/test_persist_data.py @@ -22,7 +22,7 @@ def test_persist_and_restore(tmp_path): query = SimpleDoc(tens=np.random.random((10,))) # create index - index = HnswDocumentIndex[SimpleDoc](work_dir=str(tmp_path)) + _ = HnswDocumentIndex[SimpleDoc](work_dir=str(tmp_path)) # load existing index file index = HnswDocumentIndex[SimpleDoc](work_dir=str(tmp_path)) @@ -38,7 +38,7 @@ def test_persist_and_restore(tmp_path): find_results_after = index.find(query, search_field='tens', limit=5) for doc_before, doc_after in zip(find_results_before[0], find_results_after[0]): assert doc_before.id == doc_after.id - assert (doc_before.tens == doc_after.tens).all() + assert np.allclose(doc_before.tens == doc_after.tens) # add new data index.index([SimpleDoc(tens=np.random.random((10,))) for _ in range(5)]) @@ -70,7 +70,7 @@ def test_persist_and_restore_nested(tmp_path): find_results_after = index.find(query, search_field='d__tens', limit=5) for doc_before, doc_after in zip(find_results_before[0], find_results_after[0]): assert doc_before.id == doc_after.id - assert (doc_before.tens == doc_after.tens).all() + assert np.allclose(doc_before.tens == doc_after.tens) # delete and restore index.index( From 58ee168c1753a8f773bca7aae24880c654279544 Mon Sep 17 00:00:00 2001 From: Joan Fontanals Martinez Date: Fri, 28 Jul 2023 02:02:48 +0200 Subject: [PATCH 5/8] chore: add comment to take over later --- docarray/index/backends/hnswlib.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/docarray/index/backends/hnswlib.py b/docarray/index/backends/hnswlib.py index d22ba80ad7..141091a8cd 100644 --- a/docarray/index/backends/hnswlib.py +++ b/docarray/index/backends/hnswlib.py @@ -567,6 +567,7 @@ def _doc_to_bytes(self, doc: BaseDoc) -> bytes: pb = doc.to_protobuf() for col_name in self._hnsw_indices.keys(): pb.data[col_name].Clear() + pb.data[col_name].Clear() return pb.SerializeToString() def _doc_from_bytes( @@ -578,6 +579,17 @@ def _doc_from_bytes( data ) # I cannot reconstruct directly the DA object because it may fail at validation because embedding may not be Optional for k, v in reconstruct_embeddings.items(): + print(f' k {k}') + # access = k.split('__') + # proto = pb + # python_schema = self.out_schema + # for field in access: + # print(f' field {field} and {python_schema} and {proto}') + # proto = proto.data[field] + # python_schema = python_schema._get_field_type(field) + # node_proto = python_schema._docarray_from_ndarray(np.array(v))._to_node_protobuf() + # proto.MergeFrom(node_proto) + print(f'v {v}') node_proto = ( self.out_schema.__fields__[k] From 83e55f5af5b94f24ca0ffe6a0bd566e1bf7189ca Mon Sep 17 00:00:00 2001 From: Joan Fontanals Martinez Date: Fri, 28 Jul 2023 03:12:03 +0200 Subject: [PATCH 6/8] test: refactor hnswlib test subindex Signed-off-by: Joan Fontanals Martinez --- tests/index/hnswlib/test_subindex.py | 47 +++++++++++++++++----------- 1 file changed, 28 insertions(+), 19 deletions(-) diff --git a/tests/index/hnswlib/test_subindex.py b/tests/index/hnswlib/test_subindex.py index 857cc8f029..82d51069df 100644 --- a/tests/index/hnswlib/test_subindex.py +++ b/tests/index/hnswlib/test_subindex.py @@ -27,20 +27,7 @@ class MyDoc(BaseDoc): @pytest.fixture(scope='session') -def index(): - index = HnswDocumentIndex[MyDoc](work_dir='./tmp') - return index - - -def test_subindex_init(index): - assert isinstance(index._subindices['docs'], HnswDocumentIndex) - assert isinstance(index._subindices['list_docs'], HnswDocumentIndex) - assert isinstance( - index._subindices['list_docs']._subindices['docs'], HnswDocumentIndex - ) - - -def test_subindex_index(index): +def index_docs(): my_docs = [ MyDoc( id=f'{i}', @@ -82,15 +69,31 @@ def test_subindex_index(index): ) for i in range(5) ] + return my_docs + + +def test_subindex_init(tmpdir, index_docs): + index = HnswDocumentIndex[MyDoc](work_dir=str(tmpdir)) + index.index(index_docs) + assert isinstance(index._subindices['docs'], HnswDocumentIndex) + assert isinstance(index._subindices['list_docs'], HnswDocumentIndex) + assert isinstance( + index._subindices['list_docs']._subindices['docs'], HnswDocumentIndex + ) + - index.index(my_docs) +def test_subindex_index(tmpdir, index_docs): + index = HnswDocumentIndex[MyDoc](work_dir=str(tmpdir)) + index.index(index_docs) assert index.num_docs() == 5 assert index._subindices['docs'].num_docs() == 25 assert index._subindices['list_docs'].num_docs() == 25 assert index._subindices['list_docs']._subindices['docs'].num_docs() == 125 -def test_subindex_get(index): +def test_subindex_get(tmpdir, index_docs): + index = HnswDocumentIndex[MyDoc](work_dir=str(tmpdir)) + index.index(index_docs) doc = index['1'] assert type(doc) == MyDoc assert doc.id == '1' @@ -116,7 +119,9 @@ def test_subindex_get(index): assert np.allclose(doc.my_tens, np.ones(30) * 2) -def test_find_subindex(index): +def test_find_subindex(tmpdir, index_docs): + index = HnswDocumentIndex[MyDoc](work_dir=str(tmpdir)) + index.index(index_docs) # root level query = np.ones((30,)) with pytest.raises(ValueError): @@ -148,7 +153,9 @@ def test_find_subindex(index): assert root_doc.id == f'{doc.id.split("-")[2]}' -def test_subindex_del(index): +def test_subindex_del(tmpdir, index_docs): + index = HnswDocumentIndex[MyDoc](work_dir=str(tmpdir)) + index.index(index_docs) del index['0'] assert index.num_docs() == 4 assert index._subindices['docs'].num_docs() == 20 @@ -156,7 +163,9 @@ def test_subindex_del(index): assert index._subindices['list_docs']._subindices['docs'].num_docs() == 100 -def test_subindex_contain(index): +def test_subindex_contain(tmpdir, index_docs): + index = HnswDocumentIndex[MyDoc](work_dir=str(tmpdir)) + index.index(index_docs) # Checks for individual simple_docs within list_docs for i in range(4): doc = index[f'{i + 1}'] From 2ea5804a26a1f94157025b9754357adfdf59b058 Mon Sep 17 00:00:00 2001 From: Joan Fontanals Martinez Date: Fri, 28 Jul 2023 03:35:27 +0200 Subject: [PATCH 7/8] chore: fix test --- docarray/index/backends/hnswlib.py | 13 ------------- tests/index/elastic/v7/docker-compose.yml | 6 ------ tests/index/elastic/v8/docker-compose.yml | 6 ------ 3 files changed, 25 deletions(-) diff --git a/docarray/index/backends/hnswlib.py b/docarray/index/backends/hnswlib.py index 141091a8cd..dd46b47378 100644 --- a/docarray/index/backends/hnswlib.py +++ b/docarray/index/backends/hnswlib.py @@ -579,18 +579,6 @@ def _doc_from_bytes( data ) # I cannot reconstruct directly the DA object because it may fail at validation because embedding may not be Optional for k, v in reconstruct_embeddings.items(): - print(f' k {k}') - # access = k.split('__') - # proto = pb - # python_schema = self.out_schema - # for field in access: - # print(f' field {field} and {python_schema} and {proto}') - # proto = proto.data[field] - # python_schema = python_schema._get_field_type(field) - # node_proto = python_schema._docarray_from_ndarray(np.array(v))._to_node_protobuf() - # proto.MergeFrom(node_proto) - - print(f'v {v}') node_proto = ( self.out_schema.__fields__[k] .type_._docarray_from_ndarray(np.array(v)) @@ -599,7 +587,6 @@ def _doc_from_bytes( pb.data[k].MergeFrom(node_proto) doc = schema_cls.from_protobuf(pb) - print(f'doc {doc}') return doc def _get_root_doc_id(self, id: str, root: str, sub: str) -> str: diff --git a/tests/index/elastic/v7/docker-compose.yml b/tests/index/elastic/v7/docker-compose.yml index f4dd8a49d0..1559e0b714 100644 --- a/tests/index/elastic/v7/docker-compose.yml +++ b/tests/index/elastic/v7/docker-compose.yml @@ -8,9 +8,3 @@ services: - ES_JAVA_OPTS=-Xmx1024m ports: - "9200:9200" - networks: - - elastic - -networks: - elastic: - name: elastic \ No newline at end of file diff --git a/tests/index/elastic/v8/docker-compose.yml b/tests/index/elastic/v8/docker-compose.yml index 70eedba34f..78d84e05f5 100644 --- a/tests/index/elastic/v8/docker-compose.yml +++ b/tests/index/elastic/v8/docker-compose.yml @@ -8,9 +8,3 @@ services: - ES_JAVA_OPTS=-Xmx1024m ports: - "9200:9200" - networks: - - elastic - -networks: - elastic: - name: elastic \ No newline at end of file From 8c0bf6cf536c99b0e7dee665822cc210018ffd19 Mon Sep 17 00:00:00 2001 From: Joan Fontanals Martinez Date: Fri, 28 Jul 2023 03:47:37 +0200 Subject: [PATCH 8/8] chore: apply optim Signed-off-by: Joan Fontanals Martinez --- docarray/index/backends/hnswlib.py | 31 +++++++++++++++-------- tests/index/hnswlib/test_index_get_del.py | 2 +- 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/docarray/index/backends/hnswlib.py b/docarray/index/backends/hnswlib.py index dd46b47378..8f08ae5c39 100644 --- a/docarray/index/backends/hnswlib.py +++ b/docarray/index/backends/hnswlib.py @@ -108,7 +108,11 @@ def __init__(self, db_config=None, **kwargs): if col.config } self._hnsw_indices = {} + sub_docs_exist = False + cosine_metric_index_exist = False for col_name, col in self._column_infos.items(): + if '__' in col_name: + sub_docs_exist = True if safe_issubclass(col.docarray_type, AnyDocArray): continue if not col.config: @@ -128,7 +132,12 @@ def __init__(self, db_config=None, **kwargs): else: self._hnsw_indices[col_name] = self._create_index(col_name, col) self._logger.info(f'Created a new index for column `{col_name}`') + if self._hnsw_indices[col_name].space == 'cosine': + cosine_metric_index_exist = True + self._apply_optim_no_embedding_in_sqlite = ( + not sub_docs_exist and not cosine_metric_index_exist + ) # optimization consisting in not serializing embeddings to SQLite because they are expensive to send and they can be reconstructed from the HNSW index itself. # SQLite setup self._sqlite_db_path = os.path.join(self._work_dir, 'docs_sqlite.db') self._logger.debug(f'DB path set to {self._sqlite_db_path}') @@ -565,9 +574,10 @@ def _get_num_docs_sqlite(self) -> int: # serialization helpers def _doc_to_bytes(self, doc: BaseDoc) -> bytes: pb = doc.to_protobuf() - for col_name in self._hnsw_indices.keys(): - pb.data[col_name].Clear() - pb.data[col_name].Clear() + if self._apply_optim_no_embedding_in_sqlite: + for col_name in self._hnsw_indices.keys(): + pb.data[col_name].Clear() + pb.data[col_name].Clear() return pb.SerializeToString() def _doc_from_bytes( @@ -578,13 +588,14 @@ def _doc_from_bytes( pb = DocProto.FromString( data ) # I cannot reconstruct directly the DA object because it may fail at validation because embedding may not be Optional - for k, v in reconstruct_embeddings.items(): - node_proto = ( - self.out_schema.__fields__[k] - .type_._docarray_from_ndarray(np.array(v)) - ._to_node_protobuf() - ) - pb.data[k].MergeFrom(node_proto) + if self._apply_optim_no_embedding_in_sqlite: + for k, v in reconstruct_embeddings.items(): + node_proto = ( + schema_cls._get_field_type(k) + ._docarray_from_ndarray(np.array(v)) + ._to_node_protobuf() + ) + pb.data[k].MergeFrom(node_proto) doc = schema_cls.from_protobuf(pb) return doc diff --git a/tests/index/hnswlib/test_index_get_del.py b/tests/index/hnswlib/test_index_get_del.py index 6607d64659..845169da12 100644 --- a/tests/index/hnswlib/test_index_get_del.py +++ b/tests/index/hnswlib/test_index_get_del.py @@ -410,5 +410,5 @@ class TextSimpleDoc(SimpleDoc): for doc in res.documents: if doc.id == docs[0].id: found = True - assert (doc.tens == new_tensor).all() + assert np.allclose(doc.tens, new_tensor) assert found pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy