refactor: hnswlib performance #1727
```diff
@@ -2,7 +2,7 @@
 import hashlib
 import os
 import sqlite3
-from collections import defaultdict
+from collections import OrderedDict, defaultdict
 from dataclasses import dataclass, field
 from pathlib import Path
 from typing import (
```
```diff
@@ -32,7 +32,9 @@
     _raise_not_composable,
     _raise_not_supported,
 )
-from docarray.index.backends.helper import _collect_query_args
+from docarray.index.backends.helper import (
+    _collect_query_args,
+)
 from docarray.proto import DocProto
 from docarray.typing.tensor.abstract_tensor import AbstractTensor
 from docarray.typing.tensor.ndarray import NdArray
```
```diff
@@ -63,7 +65,6 @@
     HNSWLIB_PY_VEC_TYPES.append(tf.Tensor)
     HNSWLIB_PY_VEC_TYPES.append(TensorFlowTensor)

-
 TSchema = TypeVar('TSchema', bound=BaseDoc)
 T = TypeVar('T', bound='HnswDocumentIndex')

```
```diff
@@ -107,7 +108,11 @@ def __init__(self, db_config=None, **kwargs):
             if col.config
         }
         self._hnsw_indices = {}
+        sub_docs_exist = False
+        cosine_metric_index_exist = False
         for col_name, col in self._column_infos.items():
+            if '__' in col_name:
+                sub_docs_exist = True
             if safe_issubclass(col.docarray_type, AnyDocArray):
                 continue
             if not col.config:
```
@@ -127,7 +132,12 @@ def __init__(self, db_config=None, **kwargs): | |
else: | ||
self._hnsw_indices[col_name] = self._create_index(col_name, col) | ||
self._logger.info(f'Created a new index for column `{col_name}`') | ||
if self._hnsw_indices[col_name].space == 'cosine': | ||
cosine_metric_index_exist = True | ||
|
||
self._apply_optim_no_embedding_in_sqlite = ( | ||
not sub_docs_exist and not cosine_metric_index_exist | ||
) # optimization consisting in not serializing embeddings to SQLite because they are expensive to send and they can be reconstructed from the HNSW index itself. | ||
# SQLite setup | ||
self._sqlite_db_path = os.path.join(self._work_dir, 'docs_sqlite.db') | ||
self._logger.debug(f'DB path set to {self._sqlite_db_path}') | ||
|
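The effect of this flag, sketched standalone with plain hnswlib (the payload fields here are hypothetical; the actual write/read code is in `_doc_to_bytes`/`_doc_from_bytes` below): the embedding already lives inside the HNSW index, so the SQLite row can omit it, and a read can rebuild the full document via `get_items`.

```python
import hnswlib
import numpy as np

dim = 4
index = hnswlib.Index(space='l2', dim=dim)  # not 'cosine': see the review thread at the end
index.init_index(max_elements=10, ef_construction=200, M=16)

doc_id = 0
embedding = np.array([1.0, 2.0, 3.0, 4.0], dtype='float32')
index.add_items(embedding[None, :], [doc_id])

# Write path: persist only the non-vector payload to SQLite (no embedding blob).
payload = {'id': 'doc-0', 'text': 'hello'}  # hypothetical document fields

# Read path: the embedding is reconstructed from the HNSW index itself.
restored = np.asarray(index.get_items([doc_id])[0])
assert np.allclose(restored, embedding)  # exact round-trip for 'l2'/'ip' spaces
```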
```diff
@@ -276,9 +286,7 @@ def index(self, docs: Union[BaseDoc, Sequence[BaseDoc]], **kwargs):
         docs_validated = self._validate_docs(docs)
         self._update_subindex_data(docs_validated)
         data_by_columns = self._get_col_value_dict(docs_validated)
-
         self._index(data_by_columns, docs_validated, **kwargs)
-
         self._send_docs_to_sqlite(docs_validated)
         self._sqlite_conn.commit()
         self._num_docs = 0  # recompute again when needed
```
```diff
@@ -332,7 +340,19 @@ def _filter(
         limit: int,
     ) -> DocList:
         rows = self._execute_filter(filter_query=filter_query, limit=limit)
-        return DocList[self.out_schema](self._doc_from_bytes(blob) for _, blob in rows)  # type: ignore[name-defined]
+        hashed_ids = [doc_id for doc_id, _ in rows]
+        embeddings: OrderedDict[str, list] = OrderedDict()
+        for col_name, index in self._hnsw_indices.items():
+            embeddings[col_name] = index.get_items(hashed_ids)
+
+        docs = DocList.__class_getitem__(cast(Type[BaseDoc], self.out_schema))()
+        for i, row in enumerate(rows):
+            reconstruct_embeddings = {}
+            for col_name in embeddings.keys():
+                reconstruct_embeddings[col_name] = embeddings[col_name][i]
+            docs.append(self._doc_from_bytes(row[1], reconstruct_embeddings))
+
+        return docs

     def _filter_batched(
         self,
```

> **Reviewer:** Why specifically `OrderedDict`? A normal `dict` in Python is already ordered by default.
>
> **Author:** I would still like to be specific about the requirement for order.
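As background to this exchange: insertion order has been a documented guarantee of plain `dict` since Python 3.7. `OrderedDict` iterates the same way and mainly documents the intent; one behavioral difference is that its equality against another `OrderedDict` is order-sensitive. A quick check:

```python
from collections import OrderedDict

# Since Python 3.7, plain dicts preserve insertion order as a language guarantee.
d = {}
d['second_field'] = 2
d['first_field'] = 1
assert list(d) == ['second_field', 'first_field']

# OrderedDict compares order-sensitively with other OrderedDicts,
# while plain dict equality ignores order entirely.
assert OrderedDict(a=1, b=2) != OrderedDict(b=2, a=1)
assert dict(a=1, b=2) == dict(b=2, a=1)
```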
```diff
@@ -501,12 +521,24 @@ def _get_docs_sqlite_unsorted(self, univ_ids: Sequence[int], out: bool = True):
             assert isinstance(id_, int) or is_np_int(id_)
         sql_id_list = '(' + ', '.join(str(id_) for id_ in univ_ids) + ')'
         self._sqlite_cursor.execute(
-            'SELECT data FROM docs WHERE doc_id IN %s' % sql_id_list,
+            'SELECT doc_id, data FROM docs WHERE doc_id IN %s' % sql_id_list,
         )
-        rows = self._sqlite_cursor.fetchall()
+        rows = (
+            self._sqlite_cursor.fetchall()
+        )  # doc_ids do not come back in the same order
+        embeddings: OrderedDict[str, list] = OrderedDict()
+        for col_name, index in self._hnsw_indices.items():
+            embeddings[col_name] = index.get_items([row[0] for row in rows])
+
         schema = self.out_schema if out else self._schema
-        docs_cls = DocList.__class_getitem__(cast(Type[BaseDoc], schema))
-        return docs_cls([self._doc_from_bytes(row[0], out) for row in rows])
+        docs = DocList.__class_getitem__(cast(Type[BaseDoc], schema))()
+        for i, (_, data_bytes) in enumerate(rows):
+            reconstruct_embeddings = {}
+            for col_name in embeddings.keys():
+                reconstruct_embeddings[col_name] = embeddings[col_name][i]
+            docs.append(self._doc_from_bytes(data_bytes, reconstruct_embeddings, out))
+
+        return docs

     def _get_docs_sqlite_doc_id(
         self, doc_ids: Sequence[str], out: bool = True
```
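The `# doc_ids do not come back in the same order` comment is easy to reproduce: SQLite makes no promise about the order of rows matching an `IN` list, which is why the new code selects `doc_id` back out and aligns embeddings by row rather than by the requested order. A self-contained check (here SQLite happens to walk the primary-key index, so rows come back in key order regardless of the list):

```python
import sqlite3

conn = sqlite3.connect(':memory:')
conn.execute('CREATE TABLE docs (doc_id INTEGER PRIMARY KEY, data BLOB)')
conn.executemany(
    'INSERT INTO docs VALUES (?, ?)', [(1, b'a'), (2, b'b'), (3, b'c')]
)

# The IN list is ordered (3, 1, 2), but the result set is not.
rows = conn.execute(
    'SELECT doc_id, data FROM docs WHERE doc_id IN (3, 1, 2)'
).fetchall()
print(rows)  # [(1, b'a'), (2, b'b'), (3, b'c')] -- not the order of the IN list
```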
```diff
@@ -541,12 +573,32 @@ def _get_num_docs_sqlite(self) -> int:

     # serialization helpers
     def _doc_to_bytes(self, doc: BaseDoc) -> bytes:
-        return doc.to_protobuf().SerializeToString()
-
-    def _doc_from_bytes(self, data: bytes, out: bool = True) -> BaseDoc:
+        pb = doc.to_protobuf()
+        if self._apply_optim_no_embedding_in_sqlite:
+            for col_name in self._hnsw_indices.keys():
+                pb.data[col_name].Clear()
+                pb.data[col_name].Clear()
+        return pb.SerializeToString()
+
+    def _doc_from_bytes(
+        self, data: bytes, reconstruct_embeddings: Dict, out: bool = True
+    ) -> BaseDoc:
         schema = self.out_schema if out else self._schema
         schema_cls = cast(Type[BaseDoc], schema)
-        return schema_cls.from_protobuf(DocProto.FromString(data))
+        pb = DocProto.FromString(
+            data
+        )  # I cannot reconstruct directly the DA object because it may fail at validation because embedding may not be Optional
+        if self._apply_optim_no_embedding_in_sqlite:
+            for k, v in reconstruct_embeddings.items():
+                node_proto = (
+                    schema_cls._get_field_type(k)
+                    ._docarray_from_ndarray(np.array(v))
+                    ._to_node_protobuf()
+                )
+                pb.data[k].MergeFrom(node_proto)
+
+        doc = schema_cls.from_protobuf(pb)
+        return doc

     def _get_root_doc_id(self, id: str, root: str, sub: str) -> str:
         """Get the root_id given the id of a subindex Document and the root and subindex name for hnswlib.
```
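Stripped of protobuf details, this write/read pair follows roughly the shape below (a sketch with JSON standing in for protobuf and hypothetical helper names, not docarray's API): drop the vector columns on write, merge the vectors fetched from HNSW back in on read, before any schema validation sees the document.

```python
import json
from typing import Any, Dict, Set

VECTOR_COLS: Set[str] = {'embedding'}  # columns whose values live in the HNSW index

def doc_to_bytes(doc: Dict[str, Any]) -> bytes:
    # Write path: drop vector columns before serializing; they are
    # already stored inside the HNSW index.
    slim = {k: v for k, v in doc.items() if k not in VECTOR_COLS}
    return json.dumps(slim).encode()

def doc_from_bytes(data: bytes, reconstruct_embeddings: Dict[str, list]) -> Dict[str, Any]:
    # Read path: deserialize the slim record, then merge the embeddings
    # fetched from the HNSW index back in.
    doc: Dict[str, Any] = json.loads(data)
    doc.update(reconstruct_embeddings)
    return doc

blob = doc_to_bytes({'id': 'x', 'text': 'hi', 'embedding': [0.1, 0.2]})
doc = doc_from_bytes(blob, {'embedding': [0.1, 0.2]})
assert doc == {'id': 'x', 'text': 'hi', 'embedding': [0.1, 0.2]}
```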
```diff
@@ -608,25 +660,24 @@ def _search_and_filter(
             return _FindResultBatched(documents=[], scores=[])  # type: ignore

-        # Set limit as the minimum of the provided limit and the total number of documents
-        limit = min(limit, self.num_docs())
+        limit = limit

         # Ensure the search field is in the HNSW indices
         if search_field not in self._hnsw_indices:
             raise ValueError(
                 f'Search field {search_field} is not present in the HNSW indices'
             )
-
-        index = self._hnsw_indices[search_field]

         def accept_hashed_ids(id):
             """Accepts IDs that are in hashed_ids."""
             return id in hashed_ids  # type: ignore[operator]

         # Choose the appropriate filter function based on whether hashed_ids was provided
         extra_kwargs = {'filter': accept_hashed_ids} if hashed_ids else {}

         # If hashed_ids is provided, k is the minimum of limit and the length of hashed_ids; else it is limit
         k = min(limit, len(hashed_ids)) if hashed_ids else limit
+        index = self._hnsw_indices[search_field]

         try:
             labels, distances = index.knn_query(queries, k=k, **extra_kwargs)
```
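The `extra_kwargs` mechanism relies on the `filter` callback accepted by hnswlib's `knn_query` (available since hnswlib 0.7.0): it receives each candidate's integer label and returns whether that candidate is admissible. A minimal sketch under that version assumption:

```python
import hnswlib
import numpy as np

dim = 4
index = hnswlib.Index(space='l2', dim=dim)
index.init_index(max_elements=100, ef_construction=200, M=16)
index.add_items(np.random.rand(50, dim).astype('float32'), np.arange(50))
index.set_ef(50)  # widen the search so enough admissible candidates are visited

allowed = {1, 5, 9, 12}  # e.g. ids that survived a metadata pre-filter

def accept(label: int) -> bool:
    """Admit only candidates whose label is in the pre-filtered set."""
    return label in allowed

query = np.random.rand(1, dim).astype('float32')
limit = 2
# k is capped by the number of admissible ids, mirroring min(limit, len(hashed_ids)).
labels, distances = index.knn_query(query, k=min(limit, len(allowed)), filter=accept)
```

If the filter leaves fewer than `k` reachable candidates, `knn_query` raises a `RuntimeError`, which is what the surrounding `try`/`except RuntimeError` guards against.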
```diff
@@ -639,7 +690,6 @@ def accept_hashed_ids(id):
                 )
                 for ids_per_query in labels
             ]
-
         return _FindResultBatched(documents=result_das, scores=distances)

     @classmethod
```
Two docker-compose files receive the same change: the dedicated `elastic` network is removed, so the services fall back to Compose's default network.

```diff
@@ -8,9 +8,3 @@ services:
       - ES_JAVA_OPTS=-Xmx1024m
     ports:
       - "9200:9200"
-    networks:
-      - elastic
-
-networks:
-  elastic:
-    name: elastic
```
```diff
@@ -8,9 +8,3 @@ services:
       - ES_JAVA_OPTS=-Xmx1024m
     ports:
       - "9200:9200"
-    networks:
-      - elastic
-
-networks:
-  elastic:
-    name: elastic
```
> **Reviewer:** Why do we care about this?
>
> **Author:** Because for `cosine`, HNSWLib normalizes the vectors, and then if we retrieve them, they have changed, so no consistent API can be provided.
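The behavior described is easy to confirm: with `space='cosine'`, hnswlib normalizes vectors at insertion time, so `get_items` returns unit-norm copies rather than the values that were indexed:

```python
import hnswlib
import numpy as np

dim = 3
index = hnswlib.Index(space='cosine', dim=dim)
index.init_index(max_elements=10, ef_construction=200, M=16)

v = np.array([[3.0, 0.0, 4.0]], dtype='float32')  # norm 5.0
index.add_items(v, [0])

stored = np.asarray(index.get_items([0])[0])
print(stored)  # approximately [0.6, 0.0, 0.8]: the normalized vector
assert not np.allclose(stored, v[0])  # the round-trip changed the values
```

This is why the no-embedding-in-SQLite optimization is switched off whenever any column index uses the cosine metric: reconstructed embeddings would not match what the user originally indexed.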