refactor: hnswlib performance #1727

Merged: 15 commits, Jul 31, 2023
88 changes: 69 additions & 19 deletions docarray/index/backends/hnswlib.py
@@ -2,7 +2,7 @@
import hashlib
import os
import sqlite3
from collections import defaultdict
from collections import OrderedDict, defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import (
@@ -32,7 +32,9 @@
_raise_not_composable,
_raise_not_supported,
)
from docarray.index.backends.helper import _collect_query_args
from docarray.index.backends.helper import (
_collect_query_args,
)
from docarray.proto import DocProto
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.typing.tensor.ndarray import NdArray
@@ -63,7 +65,6 @@
HNSWLIB_PY_VEC_TYPES.append(tf.Tensor)
HNSWLIB_PY_VEC_TYPES.append(TensorFlowTensor)


TSchema = TypeVar('TSchema', bound=BaseDoc)
T = TypeVar('T', bound='HnswDocumentIndex')

@@ -107,7 +108,11 @@ def __init__(self, db_config=None, **kwargs):
if col.config
}
self._hnsw_indices = {}
sub_docs_exist = False
cosine_metric_index_exist = False
for col_name, col in self._column_infos.items():
if '__' in col_name:
sub_docs_exist = True
if safe_issubclass(col.docarray_type, AnyDocArray):
continue
if not col.config:
@@ -127,7 +132,12 @@
else:
self._hnsw_indices[col_name] = self._create_index(col_name, col)
self._logger.info(f'Created a new index for column `{col_name}`')
if self._hnsw_indices[col_name].space == 'cosine':
cosine_metric_index_exist = True
Contributor: why do we care about this?

Member (Author): Because for cosine, HNSWLib normalizes the vectors, so if we retrieve them they have changed, and no consistent API can be provided.
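A minimal sketch of the behavior in question, assuming only that hnswlib is installed (the index parameters are illustrative): with space='cosine', get_items returns the normalized copy rather than the original vector.

import hnswlib
import numpy as np

dim = 8
index = hnswlib.Index(space='cosine', dim=dim)
index.init_index(max_elements=10, ef_construction=50, M=16)

v = np.random.rand(dim).astype(np.float32) + 0.1  # avoid an all-zero vector
index.add_items(v[np.newaxis, :], [0])

# For the cosine space, hnswlib normalizes vectors at insertion time,
# so the stored vector is v / ||v||, not v itself.
stored = np.array(index.get_items([0])[0])
assert np.allclose(stored, v / np.linalg.norm(v), atol=1e-6)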


self._apply_optim_no_embedding_in_sqlite = (
not sub_docs_exist and not cosine_metric_index_exist
) # optimization: skip serializing embeddings to SQLite, since they are expensive to serialize and store, and they can be reconstructed from the HNSW index itself.
# SQLite setup
self._sqlite_db_path = os.path.join(self._work_dir, 'docs_sqlite.db')
self._logger.debug(f'DB path set to {self._sqlite_db_path}')
@@ -276,9 +286,7 @@ def index(self, docs: Union[BaseDoc, Sequence[BaseDoc]], **kwargs):
docs_validated = self._validate_docs(docs)
self._update_subindex_data(docs_validated)
data_by_columns = self._get_col_value_dict(docs_validated)

self._index(data_by_columns, docs_validated, **kwargs)

self._send_docs_to_sqlite(docs_validated)
self._sqlite_conn.commit()
self._num_docs = 0 # recompute again when needed
@@ -332,7 +340,19 @@ def _filter(
limit: int,
) -> DocList:
rows = self._execute_filter(filter_query=filter_query, limit=limit)
return DocList[self.out_schema](self._doc_from_bytes(blob) for _, blob in rows) # type: ignore[name-defined]
hashed_ids = [doc_id for doc_id, _ in rows]
embeddings: OrderedDict[str, list] = OrderedDict()
Member: Why specifically OrderedDict? I think a normal dict in Python will already be ordered by default.

Member (Author): I would still like to be specific about the requirement for order.
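For what it's worth, a quick illustration of the point, assuming nothing beyond the standard library: plain dicts preserve insertion order on Python 3.7+, so OrderedDict here mostly documents the alignment requirement (embeddings[col_name][i] must line up with rows[i]).

from collections import OrderedDict

d = {'tens_one': [0.1], 'tens_two': [0.2]}  # plain dict
od = OrderedDict(d)

# Both iterate in insertion order on Python 3.7+; OrderedDict just makes
# the dependence on that order explicit to readers of the code.
assert list(d) == list(od) == ['tens_one', 'tens_two']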

for col_name, index in self._hnsw_indices.items():
embeddings[col_name] = index.get_items(hashed_ids)

docs = DocList.__class_getitem__(cast(Type[BaseDoc], self.out_schema))()
for i, row in enumerate(rows):
reconstruct_embeddings = {}
for col_name in embeddings.keys():
reconstruct_embeddings[col_name] = embeddings[col_name][i]
docs.append(self._doc_from_bytes(row[1], reconstruct_embeddings))

return docs

def _filter_batched(
self,
@@ -501,12 +521,24 @@ def _get_docs_sqlite_unsorted(self, univ_ids: Sequence[int], out: bool = True):
assert isinstance(id_, int) or is_np_int(id_)
sql_id_list = '(' + ', '.join(str(id_) for id_ in univ_ids) + ')'
self._sqlite_cursor.execute(
'SELECT data FROM docs WHERE doc_id IN %s' % sql_id_list,
'SELECT doc_id, data FROM docs WHERE doc_id IN %s' % sql_id_list,
)
rows = self._sqlite_cursor.fetchall()
rows = (
self._sqlite_cursor.fetchall()
) # doc_ids do not come back in the order of the IN clause
embeddings: OrderedDict[str, list] = OrderedDict()
for col_name, index in self._hnsw_indices.items():
embeddings[col_name] = index.get_items([row[0] for row in rows])

schema = self.out_schema if out else self._schema
docs_cls = DocList.__class_getitem__(cast(Type[BaseDoc], schema))
return docs_cls([self._doc_from_bytes(row[0], out) for row in rows])
docs = DocList.__class_getitem__(cast(Type[BaseDoc], schema))()
for i, (_, data_bytes) in enumerate(rows):
reconstruct_embeddings = {}
for col_name in embeddings.keys():
reconstruct_embeddings[col_name] = embeddings[col_name][i]
docs.append(self._doc_from_bytes(data_bytes, reconstruct_embeddings, out))

return docs
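The ordering caveat noted above is easy to reproduce with the standard library alone; a self-contained sketch (table name and ids are hypothetical) of why the query selects doc_id alongside data:

import sqlite3

conn = sqlite3.connect(':memory:')
conn.execute('CREATE TABLE docs (doc_id INTEGER PRIMARY KEY, data BLOB)')
conn.executemany(
    'INSERT INTO docs VALUES (?, ?)', [(1, b'a'), (2, b'b'), (3, b'c')]
)

# No ORDER BY: SQLite may walk the primary-key index, so the rows of an
# IN (...) query typically come back sorted, not in the order requested.
rows = conn.execute(
    'SELECT doc_id, data FROM docs WHERE doc_id IN (3, 1)'
).fetchall()

# Realign by using the ids actually returned, as the code above does when
# it fetches embeddings with index.get_items([row[0] for row in rows]).
ids_in_row_order = [row[0] for row in rows]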

def _get_docs_sqlite_doc_id(
self, doc_ids: Sequence[str], out: bool = True
@@ -541,12 +573,32 @@ def _get_num_docs_sqlite(self) -> int:

# serialization helpers
def _doc_to_bytes(self, doc: BaseDoc) -> bytes:
return doc.to_protobuf().SerializeToString()

def _doc_from_bytes(self, data: bytes, out: bool = True) -> BaseDoc:
pb = doc.to_protobuf()
if self._apply_optim_no_embedding_in_sqlite:
for col_name in self._hnsw_indices.keys():
pb.data[col_name].Clear()
return pb.SerializeToString()

def _doc_from_bytes(
self, data: bytes, reconstruct_embeddings: Dict, out: bool = True
) -> BaseDoc:
schema = self.out_schema if out else self._schema
schema_cls = cast(Type[BaseDoc], schema)
return schema_cls.from_protobuf(DocProto.FromString(data))
pb = DocProto.FromString(
data
) # we cannot reconstruct the document directly here: validation may fail because the embedding field may not be Optional
if self._apply_optim_no_embedding_in_sqlite:
for k, v in reconstruct_embeddings.items():
node_proto = (
schema_cls._get_field_type(k)
._docarray_from_ndarray(np.array(v))
._to_node_protobuf()
)
pb.data[k].MergeFrom(node_proto)

doc = schema_cls.from_protobuf(pb)
return doc
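The comment above is about validation: a hedged sketch (hypothetical schema) of why the document cannot be parsed first and have its embedding patched in later when the field is required.

from docarray import BaseDoc
from docarray.typing import NdArray
from pydantic import ValidationError

class MyDoc(BaseDoc):
    embedding: NdArray  # required, not Optional[NdArray]

try:
    MyDoc()  # embedding missing -> validation fails
except ValidationError:
    # Hence the proto is patched with the reconstructed embedding first,
    # and the document is only built once the field is present.
    pass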

def _get_root_doc_id(self, id: str, root: str, sub: str) -> str:
"""Get the root_id given the id of a subindex Document and the root and subindex name for hnswlib.
@@ -608,25 +660,24 @@ def _search_and_filter(
return _FindResultBatched(documents=[], scores=[]) # type: ignore

# Set limit as the minimum of the provided limit and the total number of documents
limit = min(limit, self.num_docs())
limit = limit

# Ensure the search field is in the HNSW indices
if search_field not in self._hnsw_indices:
raise ValueError(
f'Search field {search_field} is not present in the HNSW indices'
)

index = self._hnsw_indices[search_field]

def accept_hashed_ids(id):
"""Accepts IDs that are in hashed_ids."""
return id in hashed_ids # type: ignore[operator]

# Choose the appropriate filter function based on whether hashed_ids was provided
extra_kwargs = {'filter': accept_hashed_ids} if hashed_ids else {}

# If hashed_ids is provided, k is the minimum of limit and the length of hashed_ids; else it is limit
k = min(limit, len(hashed_ids)) if hashed_ids else limit
index = self._hnsw_indices[search_field]

try:
labels, distances = index.knn_query(queries, k=k, **extra_kwargs)
except RuntimeError:
@@ -639,7 +690,6 @@ def accept_hashed_ids(id):
)
for ids_per_query in labels
]

return _FindResultBatched(documents=result_das, scores=distances)
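For reference, a minimal sketch of the filter hook used above, assuming hnswlib >= 0.7 (where knn_query accepts a filter callable) and illustrative parameters; the callable is invoked per candidate label and only accepted labels are returned.

import hnswlib
import numpy as np

dim = 8
index = hnswlib.Index(space='l2', dim=dim)
index.init_index(max_elements=100, ef_construction=50, M=16)
index.add_items(np.random.rand(100, dim).astype(np.float32), np.arange(100))
index.set_ef(100)  # widen the search so all allowed candidates are reachable

allowed = {3, 7, 42}  # e.g. ids that already passed a metadata filter

# Mirrors accept_hashed_ids above; note k <= len(allowed), and hnswlib can
# still raise RuntimeError if too few candidates pass (hence the try/except).
labels, distances = index.knn_query(
    np.random.rand(1, dim).astype(np.float32),
    k=3,
    filter=lambda label: label in allowed,
)
assert set(labels[0]) <= allowed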

@classmethod
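Putting the pieces together, a condensed, self-contained sketch of the optimization this file implements (all names here are illustrative, not docarray's API): vectors live only in the HNSW index, SQLite keeps the embedding-free document bytes, and reads reassemble the two.

import sqlite3

import hnswlib
import numpy as np

dim = 8
index = hnswlib.Index(space='l2', dim=dim)
index.init_index(max_elements=10, ef_construction=50, M=16)
conn = sqlite3.connect(':memory:')
conn.execute('CREATE TABLE docs (doc_id INTEGER PRIMARY KEY, data BLOB)')

# Write path: the embedding goes to HNSW only; SQLite stores the rest.
vec = np.random.rand(dim).astype(np.float32)
index.add_items(vec[np.newaxis, :], [7])
conn.execute('INSERT INTO docs VALUES (?, ?)', (7, b'doc-bytes-without-embedding'))

# Read path: reconstruct the embedding from the HNSW index.
blob = conn.execute('SELECT data FROM docs WHERE doc_id = 7').fetchone()[0]
restored = np.array(index.get_items([7])[0])
assert np.allclose(restored, vec)  # holds for l2/ip; not for cosine (normalized)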
6 changes: 0 additions & 6 deletions tests/index/elastic/v7/docker-compose.yml
@@ -8,9 +8,3 @@ services:
- ES_JAVA_OPTS=-Xmx1024m
ports:
- "9200:9200"
networks:
- elastic

networks:
elastic:
name: elastic
6 changes: 0 additions & 6 deletions tests/index/elastic/v8/docker-compose.yml
@@ -8,9 +8,3 @@ services:
- ES_JAVA_OPTS=-Xmx1024m
ports:
- "9200:9200"
networks:
- elastic

networks:
elastic:
name: elastic
12 changes: 8 additions & 4 deletions tests/index/hnswlib/test_filter.py
@@ -69,10 +69,14 @@ def test_build_query_invalid_query():
HnswDocumentIndex._build_filter_query(query, param_values)


def test_filter_eq(doc_index):
docs = doc_index.filter({'text': {'$eq': 'text 1'}})
assert len(docs) == 1
assert docs[0].text == 'text 1'
def test_filter_eq(doc_index, docs):
filter_result = doc_index.filter({'text': {'$eq': 'text 1'}})
assert len(filter_result) == 1
assert filter_result[0].text == 'text 1'
assert filter_result[0].text == docs[1].text
assert filter_result[0].price == docs[1].price
assert filter_result[0].id == docs[1].id
assert np.allclose(filter_result[0].tensor, docs[1].tensor)


def test_filter_neq(doc_index):
24 changes: 12 additions & 12 deletions tests/index/hnswlib/test_index_get_del.py
@@ -211,7 +211,7 @@ def test_get_single(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path):
for d in ten_simple_docs:
id_ = d.id
assert index[id_].id == id_
assert np.all(index[id_].tens == d.tens)
assert np.allclose(index[id_].tens, d.tens)

# flat
index = HnswDocumentIndex[FlatDoc](work_dir=str(flat_path))
@@ -221,8 +221,8 @@ def test_get_single(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path):
for d in ten_flat_docs:
id_ = d.id
assert index[id_].id == id_
assert np.all(index[id_].tens_one == d.tens_one)
assert np.all(index[id_].tens_two == d.tens_two)
assert np.allclose(index[id_].tens_one, d.tens_one)
assert np.allclose(index[id_].tens_two, d.tens_two)

# nested
index = HnswDocumentIndex[NestedDoc](work_dir=str(nested_path))
@@ -233,7 +233,7 @@ def test_get_single(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path):
id_ = d.id
assert index[id_].id == id_
assert index[id_].d.id == d.d.id
assert np.all(index[id_].d.tens == d.d.tens)
assert np.allclose(index[id_].d.tens, d.d.tens)


def test_get_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path):
@@ -252,7 +252,7 @@ def test_get_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path)
retrieved_docs = index[ids_to_get]
for id_, d_in, d_out in zip(ids_to_get, docs_to_get, retrieved_docs):
assert d_out.id == id_
assert np.all(d_out.tens == d_in.tens)
assert np.allclose(d_out.tens, d_in.tens)

# flat
index = HnswDocumentIndex[FlatDoc](work_dir=str(flat_path))
@@ -264,8 +264,8 @@ def test_get_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path)
retrieved_docs = index[ids_to_get]
for id_, d_in, d_out in zip(ids_to_get, docs_to_get, retrieved_docs):
assert d_out.id == id_
assert np.all(d_out.tens_one == d_in.tens_one)
assert np.all(d_out.tens_two == d_in.tens_two)
assert np.allclose(d_out.tens_one, d_in.tens_one)
assert np.allclose(d_out.tens_two, d_in.tens_two)

# nested
index = HnswDocumentIndex[NestedDoc](work_dir=str(nested_path))
@@ -278,7 +278,7 @@ def test_get_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path)
for id_, d_in, d_out in zip(ids_to_get, docs_to_get, retrieved_docs):
assert d_out.id == id_
assert d_out.d.id == d_in.d.id
assert np.all(d_out.d.tens == d_in.d.tens)
assert np.allclose(d_out.d.tens, d_in.d.tens)


def test_get_key_error(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path):
@@ -303,7 +303,7 @@ def test_del_single(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path):
index[id_]
else:
assert index[id_].id == id_
assert np.all(index[id_].tens == d.tens)
assert np.allclose(index[id_].tens, d.tens)
# delete again
del index[ten_simple_docs[3].id]
assert index.num_docs() == 8
@@ -314,7 +314,7 @@ def test_del_single(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path):
index[id_]
else:
assert index[id_].id == id_
assert np.all(index[id_].tens == d.tens)
assert np.allclose(index[id_].tens, d.tens)


def test_del_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path):
@@ -333,7 +333,7 @@ def test_del_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path)
index[doc.id]
else:
assert index[doc.id].id == doc.id
assert np.all(index[doc.id].tens == doc.tens)
assert np.allclose(index[doc.id].tens, doc.tens)


def test_del_key_error(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path):
@@ -410,5 +410,5 @@ class TextSimpleDoc(SimpleDoc):
for doc in res.documents:
if doc.id == docs[0].id:
found = True
assert (doc.tens == new_tensor).all()
assert np.allclose(doc.tens, new_tensor)
assert found
6 changes: 3 additions & 3 deletions tests/index/hnswlib/test_persist_data.py
@@ -22,7 +22,7 @@ def test_persist_and_restore(tmp_path):
query = SimpleDoc(tens=np.random.random((10,)))

# create index
index = HnswDocumentIndex[SimpleDoc](work_dir=str(tmp_path))
_ = HnswDocumentIndex[SimpleDoc](work_dir=str(tmp_path))

# load existing index file
index = HnswDocumentIndex[SimpleDoc](work_dir=str(tmp_path))
@@ -38,7 +38,7 @@ def test_persist_and_restore(tmp_path):
find_results_after = index.find(query, search_field='tens', limit=5)
for doc_before, doc_after in zip(find_results_before[0], find_results_after[0]):
assert doc_before.id == doc_after.id
assert (doc_before.tens == doc_after.tens).all()
assert np.allclose(doc_before.tens, doc_after.tens)

# add new data
index.index([SimpleDoc(tens=np.random.random((10,))) for _ in range(5)])
@@ -70,7 +70,7 @@ def test_persist_and_restore_nested(tmp_path):
find_results_after = index.find(query, search_field='d__tens', limit=5)
for doc_before, doc_after in zip(find_results_before[0], find_results_after[0]):
assert doc_before.id == doc_after.id
assert (doc_before.tens == doc_after.tens).all()
assert np.allclose(doc_before.tens, doc_after.tens)

# delete and restore
index.index(