
Commit a643f6a

Author: Joan Fontanals
refactor: hnswlib performance (#1727)
Signed-off-by: Joan Fontanals Martinez <joan.martinez@jina.ai>
1 parent 87ec19f commit a643f6a

6 files changed: 92 additions, 50 deletions

docarray/index/backends/hnswlib.py

Lines changed: 69 additions & 19 deletions
@@ -2,7 +2,7 @@
 import hashlib
 import os
 import sqlite3
-from collections import defaultdict
+from collections import OrderedDict, defaultdict
 from dataclasses import dataclass, field
 from pathlib import Path
 from typing import (
@@ -32,7 +32,9 @@
     _raise_not_composable,
     _raise_not_supported,
 )
-from docarray.index.backends.helper import _collect_query_args
+from docarray.index.backends.helper import (
+    _collect_query_args,
+)
 from docarray.proto import DocProto
 from docarray.typing.tensor.abstract_tensor import AbstractTensor
 from docarray.typing.tensor.ndarray import NdArray
@@ -63,7 +65,6 @@
     HNSWLIB_PY_VEC_TYPES.append(tf.Tensor)
     HNSWLIB_PY_VEC_TYPES.append(TensorFlowTensor)

-
 TSchema = TypeVar('TSchema', bound=BaseDoc)
 T = TypeVar('T', bound='HnswDocumentIndex')

@@ -107,7 +108,11 @@ def __init__(self, db_config=None, **kwargs):
             if col.config
         }
         self._hnsw_indices = {}
+        sub_docs_exist = False
+        cosine_metric_index_exist = False
         for col_name, col in self._column_infos.items():
+            if '__' in col_name:
+                sub_docs_exist = True
             if safe_issubclass(col.docarray_type, AnyDocArray):
                 continue
             if not col.config:
@@ -127,7 +132,12 @@ def __init__(self, db_config=None, **kwargs):
             else:
                 self._hnsw_indices[col_name] = self._create_index(col_name, col)
                 self._logger.info(f'Created a new index for column `{col_name}`')
+            if self._hnsw_indices[col_name].space == 'cosine':
+                cosine_metric_index_exist = True

+        self._apply_optim_no_embedding_in_sqlite = (
+            not sub_docs_exist and not cosine_metric_index_exist
+        )  # optimization consisting in not serializing embeddings to SQLite because they are expensive to send and they can be reconstructed from the HNSW index itself.
         # SQLite setup
         self._sqlite_db_path = os.path.join(self._work_dir, 'docs_sqlite.db')
         self._logger.debug(f'DB path set to {self._sqlite_db_path}')
@@ -276,9 +286,7 @@ def index(self, docs: Union[BaseDoc, Sequence[BaseDoc]], **kwargs):
         docs_validated = self._validate_docs(docs)
         self._update_subindex_data(docs_validated)
         data_by_columns = self._get_col_value_dict(docs_validated)
-
         self._index(data_by_columns, docs_validated, **kwargs)
-
         self._send_docs_to_sqlite(docs_validated)
         self._sqlite_conn.commit()
         self._num_docs = 0  # recompute again when needed
@@ -332,7 +340,19 @@ def _filter(
         limit: int,
     ) -> DocList:
         rows = self._execute_filter(filter_query=filter_query, limit=limit)
-        return DocList[self.out_schema](self._doc_from_bytes(blob) for _, blob in rows)  # type: ignore[name-defined]
+        hashed_ids = [doc_id for doc_id, _ in rows]
+        embeddings: OrderedDict[str, list] = OrderedDict()
+        for col_name, index in self._hnsw_indices.items():
+            embeddings[col_name] = index.get_items(hashed_ids)
+
+        docs = DocList.__class_getitem__(cast(Type[BaseDoc], self.out_schema))()
+        for i, row in enumerate(rows):
+            reconstruct_embeddings = {}
+            for col_name in embeddings.keys():
+                reconstruct_embeddings[col_name] = embeddings[col_name][i]
+            docs.append(self._doc_from_bytes(row[1], reconstruct_embeddings))
+
+        return docs

     def _filter_batched(
         self,
@@ -501,12 +521,24 @@ def _get_docs_sqlite_unsorted(self, univ_ids: Sequence[int], out: bool = True):
             assert isinstance(id_, int) or is_np_int(id_)
         sql_id_list = '(' + ', '.join(str(id_) for id_ in univ_ids) + ')'
         self._sqlite_cursor.execute(
-            'SELECT data FROM docs WHERE doc_id IN %s' % sql_id_list,
+            'SELECT doc_id, data FROM docs WHERE doc_id IN %s' % sql_id_list,
         )
-        rows = self._sqlite_cursor.fetchall()
+        rows = (
+            self._sqlite_cursor.fetchall()
+        )  # doc_ids do not come back in the same order
+        embeddings: OrderedDict[str, list] = OrderedDict()
+        for col_name, index in self._hnsw_indices.items():
+            embeddings[col_name] = index.get_items([row[0] for row in rows])
+
         schema = self.out_schema if out else self._schema
-        docs_cls = DocList.__class_getitem__(cast(Type[BaseDoc], schema))
-        return docs_cls([self._doc_from_bytes(row[0], out) for row in rows])
+        docs = DocList.__class_getitem__(cast(Type[BaseDoc], schema))()
+        for i, (_, data_bytes) in enumerate(rows):
+            reconstruct_embeddings = {}
+            for col_name in embeddings.keys():
+                reconstruct_embeddings[col_name] = embeddings[col_name][i]
+            docs.append(self._doc_from_bytes(data_bytes, reconstruct_embeddings, out))
+
+        return docs

     def _get_docs_sqlite_doc_id(
         self, doc_ids: Sequence[str], out: bool = True
@@ -541,12 +573,32 @@ def _get_num_docs_sqlite(self) -> int:

     # serialization helpers
     def _doc_to_bytes(self, doc: BaseDoc) -> bytes:
-        return doc.to_protobuf().SerializeToString()
-
-    def _doc_from_bytes(self, data: bytes, out: bool = True) -> BaseDoc:
+        pb = doc.to_protobuf()
+        if self._apply_optim_no_embedding_in_sqlite:
+            for col_name in self._hnsw_indices.keys():
+                pb.data[col_name].Clear()
+                pb.data[col_name].Clear()
+        return pb.SerializeToString()
+
+    def _doc_from_bytes(
+        self, data: bytes, reconstruct_embeddings: Dict, out: bool = True
+    ) -> BaseDoc:
         schema = self.out_schema if out else self._schema
         schema_cls = cast(Type[BaseDoc], schema)
-        return schema_cls.from_protobuf(DocProto.FromString(data))
+        pb = DocProto.FromString(
+            data
+        )  # I cannot reconstruct directly the DA object because it may fail at validation because embedding may not be Optional
+        if self._apply_optim_no_embedding_in_sqlite:
+            for k, v in reconstruct_embeddings.items():
+                node_proto = (
+                    schema_cls._get_field_type(k)
+                    ._docarray_from_ndarray(np.array(v))
+                    ._to_node_protobuf()
+                )
+                pb.data[k].MergeFrom(node_proto)

+        doc = schema_cls.from_protobuf(pb)
+        return doc

     def _get_root_doc_id(self, id: str, root: str, sub: str) -> str:
         """Get the root_id given the id of a subindex Document and the root and subindex name for hnswlib.
@@ -608,25 +660,24 @@ def _search_and_filter(
             return _FindResultBatched(documents=[], scores=[])  # type: ignore

         # Set limit as the minimum of the provided limit and the total number of documents
-        limit = min(limit, self.num_docs())
+        limit = limit

         # Ensure the search field is in the HNSW indices
         if search_field not in self._hnsw_indices:
             raise ValueError(
                 f'Search field {search_field} is not present in the HNSW indices'
             )

-        index = self._hnsw_indices[search_field]
-
         def accept_hashed_ids(id):
             """Accepts IDs that are in hashed_ids."""
             return id in hashed_ids  # type: ignore[operator]

-        # Choose the appropriate filter function based on whether hashed_ids was provided
         extra_kwargs = {'filter': accept_hashed_ids} if hashed_ids else {}

         # If hashed_ids is provided, k is the minimum of limit and the length of hashed_ids; else it is limit
         k = min(limit, len(hashed_ids)) if hashed_ids else limit
+        index = self._hnsw_indices[search_field]
+
         try:
             labels, distances = index.knn_query(queries, k=k, **extra_kwargs)
         except RuntimeError:
@@ -639,7 +690,6 @@ def accept_hashed_ids(id):
             )
             for ids_per_query in labels
         ]
-
         return _FindResultBatched(documents=result_das, scores=distances)

     @classmethod
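The heart of the refactor is the `_doc_to_bytes` / `_doc_from_bytes` pair: when the index has no subindexes and no cosine-metric column, embeddings are cleared from the protobuf before it is written to SQLite and re-fetched from the HNSW index via `get_items` when documents are read back (cosine is excluded because hnswlib normalizes vectors in that space, so the original values could not be recovered exactly; subindexes presumably complicate which index owns which embedding). Below is a minimal standalone sketch of that round trip using plain `hnswlib` and `sqlite3` rather than DocArray's internals; the `store`/`load` helpers are illustrative only, not part of the codebase:

import sqlite3

import hnswlib
import numpy as np

dim = 8
index = hnswlib.Index(space='l2', dim=dim)  # non-cosine space: vectors are stored unmodified
index.init_index(max_elements=100)

conn = sqlite3.connect(':memory:')
conn.execute('CREATE TABLE docs (doc_id INTEGER PRIMARY KEY, data BLOB)')


def store(doc_id: int, payload: bytes, embedding: np.ndarray) -> None:
    # The embedding goes only into HNSW; SQLite keeps just the stripped payload.
    index.add_items(embedding.reshape(1, -1), [doc_id])
    conn.execute('INSERT INTO docs VALUES (?, ?)', (doc_id, payload))


def load(doc_ids) -> list:
    # SELECT doc_id as well: rows do not come back in the requested order.
    id_list = ', '.join(str(i) for i in doc_ids)
    rows = conn.execute(
        'SELECT doc_id, data FROM docs WHERE doc_id IN (%s)' % id_list
    ).fetchall()
    # Reconstruct the embeddings from the HNSW index instead of from SQLite.
    vectors = index.get_items([row[0] for row in rows])
    return [(data, np.asarray(vec)) for (_, data), vec in zip(rows, vectors)]


store(1, b'doc-1-without-embedding', np.random.random(dim).astype(np.float32))
store(2, b'doc-2-without-embedding', np.random.random(dim).astype(np.float32))
print(load([2, 1]))

This mirrors the changes in `_get_docs_sqlite_unsorted` above: the SELECT now also returns `doc_id` precisely so the fetched rows can be matched with the vectors pulled out of each HNSW index.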

tests/index/elastic/v7/docker-compose.yml

Lines changed: 0 additions & 6 deletions
@@ -8,9 +8,3 @@ services:
       - ES_JAVA_OPTS=-Xmx1024m
     ports:
       - "9200:9200"
-    networks:
-      - elastic
-
-networks:
-  elastic:
-    name: elastic

tests/index/elastic/v8/docker-compose.yml

Lines changed: 0 additions & 6 deletions
@@ -8,9 +8,3 @@ services:
       - ES_JAVA_OPTS=-Xmx1024m
     ports:
       - "9200:9200"
-    networks:
-      - elastic
-
-networks:
-  elastic:
-    name: elastic

tests/index/hnswlib/test_filter.py

Lines changed: 8 additions & 4 deletions
@@ -69,10 +69,14 @@ def test_build_query_invalid_query():
         HnswDocumentIndex._build_filter_query(query, param_values)


-def test_filter_eq(doc_index):
-    docs = doc_index.filter({'text': {'$eq': 'text 1'}})
-    assert len(docs) == 1
-    assert docs[0].text == 'text 1'
+def test_filter_eq(doc_index, docs):
+    filter_result = doc_index.filter({'text': {'$eq': 'text 1'}})
+    assert len(filter_result) == 1
+    assert filter_result[0].text == 'text 1'
+    assert filter_result[0].text == docs[1].text
+    assert filter_result[0].price == docs[1].price
+    assert filter_result[0].id == docs[1].id
+    assert np.allclose(filter_result[0].tensor, docs[1].tensor)


 def test_filter_neq(doc_index):

tests/index/hnswlib/test_index_get_del.py

Lines changed: 12 additions & 12 deletions
@@ -211,7 +211,7 @@ def test_get_single(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path):
     for d in ten_simple_docs:
         id_ = d.id
         assert index[id_].id == id_
-        assert np.all(index[id_].tens == d.tens)
+        assert np.allclose(index[id_].tens, d.tens)

     # flat
     index = HnswDocumentIndex[FlatDoc](work_dir=str(flat_path))
@@ -221,8 +221,8 @@ def test_get_single(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path):
     for d in ten_flat_docs:
         id_ = d.id
         assert index[id_].id == id_
-        assert np.all(index[id_].tens_one == d.tens_one)
-        assert np.all(index[id_].tens_two == d.tens_two)
+        assert np.allclose(index[id_].tens_one, d.tens_one)
+        assert np.allclose(index[id_].tens_two, d.tens_two)

     # nested
     index = HnswDocumentIndex[NestedDoc](work_dir=str(nested_path))
@@ -233,7 +233,7 @@ def test_get_single(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path):
         id_ = d.id
         assert index[id_].id == id_
         assert index[id_].d.id == d.d.id
-        assert np.all(index[id_].d.tens == d.d.tens)
+        assert np.allclose(index[id_].d.tens, d.d.tens)


 def test_get_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path):
@@ -252,7 +252,7 @@ def test_get_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path)
     retrieved_docs = index[ids_to_get]
     for id_, d_in, d_out in zip(ids_to_get, docs_to_get, retrieved_docs):
         assert d_out.id == id_
-        assert np.all(d_out.tens == d_in.tens)
+        assert np.allclose(d_out.tens, d_in.tens)

     # flat
     index = HnswDocumentIndex[FlatDoc](work_dir=str(flat_path))
@@ -264,8 +264,8 @@ def test_get_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path)
     retrieved_docs = index[ids_to_get]
     for id_, d_in, d_out in zip(ids_to_get, docs_to_get, retrieved_docs):
         assert d_out.id == id_
-        assert np.all(d_out.tens_one == d_in.tens_one)
-        assert np.all(d_out.tens_two == d_in.tens_two)
+        assert np.allclose(d_out.tens_one, d_in.tens_one)
+        assert np.allclose(d_out.tens_two, d_in.tens_two)

     # nested
     index = HnswDocumentIndex[NestedDoc](work_dir=str(nested_path))
@@ -278,7 +278,7 @@ def test_get_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path)
     for id_, d_in, d_out in zip(ids_to_get, docs_to_get, retrieved_docs):
         assert d_out.id == id_
         assert d_out.d.id == d_in.d.id
-        assert np.all(d_out.d.tens == d_in.d.tens)
+        assert np.allclose(d_out.d.tens, d_in.d.tens)


 def test_get_key_error(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path):
@@ -303,7 +303,7 @@ def test_del_single(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path):
                 index[id_]
         else:
             assert index[id_].id == id_
-            assert np.all(index[id_].tens == d.tens)
+            assert np.allclose(index[id_].tens, d.tens)
     # delete again
     del index[ten_simple_docs[3].id]
     assert index.num_docs() == 8
@@ -314,7 +314,7 @@ def test_del_single(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path):
                 index[id_]
         else:
             assert index[id_].id == id_
-            assert np.all(index[id_].tens == d.tens)
+            assert np.allclose(index[id_].tens, d.tens)


 def test_del_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path):
@@ -333,7 +333,7 @@ def test_del_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path)
                 index[doc.id]
         else:
             assert index[doc.id].id == doc.id
-            assert np.all(index[doc.id].tens == doc.tens)
+            assert np.allclose(index[doc.id].tens, doc.tens)


 def test_del_key_error(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path):
@@ -410,5 +410,5 @@ class TextSimpleDoc(SimpleDoc):
     for doc in res.documents:
         if doc.id == docs[0].id:
             found = True
-            assert (doc.tens == new_tensor).all()
+            assert np.allclose(doc.tens, new_tensor)
     assert found
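Throughout this file, exact equality checks (`np.all(a == b)` / `(a == b).all()`) become `np.allclose(a, b)`. A plausible reading, given the hnswlib.py changes above: tensors are now reconstructed from the HNSW index, which holds float32 data, so the float64 arrays the tests feed in no longer round-trip bit-for-bit. A small illustration, with the float32 cast standing in for the HNSW round trip:

import numpy as np

original = np.random.random(10)            # float64, as produced by the tests
roundtrip = original.astype(np.float32)    # what effectively comes back from hnswlib
print(np.all(roundtrip == original))       # almost always False: low-order bits differ
print(np.allclose(roundtrip, original))    # True: equal within default tolerances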

tests/index/hnswlib/test_persist_data.py

Lines changed: 3 additions & 3 deletions
@@ -22,7 +22,7 @@ def test_persist_and_restore(tmp_path):
     query = SimpleDoc(tens=np.random.random((10,)))

     # create index
-    index = HnswDocumentIndex[SimpleDoc](work_dir=str(tmp_path))
+    _ = HnswDocumentIndex[SimpleDoc](work_dir=str(tmp_path))

     # load existing index file
     index = HnswDocumentIndex[SimpleDoc](work_dir=str(tmp_path))
@@ -38,7 +38,7 @@ def test_persist_and_restore(tmp_path):
     find_results_after = index.find(query, search_field='tens', limit=5)
     for doc_before, doc_after in zip(find_results_before[0], find_results_after[0]):
         assert doc_before.id == doc_after.id
-        assert (doc_before.tens == doc_after.tens).all()
+        assert np.allclose(doc_before.tens, doc_after.tens)

     # add new data
     index.index([SimpleDoc(tens=np.random.random((10,))) for _ in range(5)])
@@ -70,7 +70,7 @@ def test_persist_and_restore_nested(tmp_path):
     find_results_after = index.find(query, search_field='d__tens', limit=5)
     for doc_before, doc_after in zip(find_results_before[0], find_results_after[0]):
         assert doc_before.id == doc_after.id
-        assert (doc_before.tens == doc_after.tens).all()
+        assert np.allclose(doc_before.tens, doc_after.tens)

     # delete and restore
     index.index(
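These tests exercise the persistence path the refactor must preserve: constructing a second HnswDocumentIndex over the same work_dir reloads both the HNSW index files and the SQLite store. A minimal usage sketch assembled only from the API the tests themselves use (the schema, path, and sizes are illustrative):

import numpy as np
from docarray import BaseDoc
from docarray.index import HnswDocumentIndex
from docarray.typing import NdArray


class MyDoc(BaseDoc):
    tens: NdArray[10]


# first process: create the index and persist some docs under work_dir
index = HnswDocumentIndex[MyDoc](work_dir='/tmp/hnsw_demo')
index.index([MyDoc(tens=np.random.random((10,))) for _ in range(100)])

# later process: pointing at the same work_dir restores the persisted state
restored = HnswDocumentIndex[MyDoc](work_dir='/tmp/hnsw_demo')
assert restored.num_docs() == 100
docs, scores = restored.find(
    MyDoc(tens=np.random.random((10,))), search_field='tens', limit=5
)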
