@@ -2,7 +2,7 @@
 import hashlib
 import os
 import sqlite3
-from collections import defaultdict
+from collections import OrderedDict, defaultdict
 from dataclasses import dataclass, field
 from pathlib import Path
 from typing import (
@@ -32,7 +32,9 @@
     _raise_not_composable,
     _raise_not_supported,
 )
-from docarray.index.backends.helper import _collect_query_args
+from docarray.index.backends.helper import (
+    _collect_query_args,
+)
 from docarray.proto import DocProto
 from docarray.typing.tensor.abstract_tensor import AbstractTensor
 from docarray.typing.tensor.ndarray import NdArray
@@ -63,7 +65,6 @@
     HNSWLIB_PY_VEC_TYPES.append(tf.Tensor)
     HNSWLIB_PY_VEC_TYPES.append(TensorFlowTensor)

-
 TSchema = TypeVar('TSchema', bound=BaseDoc)
 T = TypeVar('T', bound='HnswDocumentIndex')

@@ -107,7 +108,11 @@ def __init__(self, db_config=None, **kwargs):
             if col.config
         }
         self._hnsw_indices = {}
+        sub_docs_exist = False
+        cosine_metric_index_exist = False
         for col_name, col in self._column_infos.items():
+            if '__' in col_name:
+                sub_docs_exist = True
             if safe_issubclass(col.docarray_type, AnyDocArray):
                 continue
             if not col.config:
@@ -127,7 +132,12 @@ def __init__(self, db_config=None, **kwargs):
             else:
                 self._hnsw_indices[col_name] = self._create_index(col_name, col)
                 self._logger.info(f'Created a new index for column `{col_name}`')
+            if self._hnsw_indices[col_name].space == 'cosine':
+                cosine_metric_index_exist = True

+        self._apply_optim_no_embedding_in_sqlite = (
+            not sub_docs_exist and not cosine_metric_index_exist
+        )  # optimization: do not serialize embeddings to SQLite; they are expensive to store and can be reconstructed from the HNSW index itself
         # SQLite setup
         self._sqlite_db_path = os.path.join(self._work_dir, 'docs_sqlite.db')
         self._logger.debug(f'DB path set to {self._sqlite_db_path}')
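
The optimization gated by `_apply_optim_no_embedding_in_sqlite` relies on hnswlib being able to hand back stored vectors by ID. A minimal, self-contained sketch of the idea using plain hnswlib (the `dim`, `ids`, and `vectors` values are illustrative, not part of this change):

```python
import hnswlib
import numpy as np

dim = 8
index = hnswlib.Index(space='l2', dim=dim)  # 'l2' stores vectors unmodified
index.init_index(max_elements=100)

vectors = np.random.rand(3, dim).astype(np.float32)
ids = [11, 22, 33]
index.add_items(vectors, ids)

# No second copy of the embeddings is needed in SQLite:
# they can be recovered from the HNSW index itself by ID.
recovered = np.asarray(index.get_items(ids))
assert np.allclose(vectors, recovered)
```

This also explains the `cosine` guard above: hnswlib normalizes vectors internally for `space='cosine'`, so `get_items` would return the normalized copies rather than the original embeddings, and the optimization must stay off.
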
@@ -276,9 +286,7 @@ def index(self, docs: Union[BaseDoc, Sequence[BaseDoc]], **kwargs):
         docs_validated = self._validate_docs(docs)
         self._update_subindex_data(docs_validated)
         data_by_columns = self._get_col_value_dict(docs_validated)
-
         self._index(data_by_columns, docs_validated, **kwargs)
-
         self._send_docs_to_sqlite(docs_validated)
         self._sqlite_conn.commit()
         self._num_docs = 0  # recompute again when needed
@@ -332,7 +340,19 @@ def _filter(
         limit: int,
     ) -> DocList:
         rows = self._execute_filter(filter_query=filter_query, limit=limit)
-        return DocList[self.out_schema](self._doc_from_bytes(blob) for _, blob in rows)  # type: ignore[name-defined]
+        hashed_ids = [doc_id for doc_id, _ in rows]
+        embeddings: OrderedDict[str, list] = OrderedDict()
+        for col_name, index in self._hnsw_indices.items():
+            embeddings[col_name] = index.get_items(hashed_ids)
+
+        docs = DocList.__class_getitem__(cast(Type[BaseDoc], self.out_schema))()
+        for i, row in enumerate(rows):
+            reconstruct_embeddings = {}
+            for col_name in embeddings.keys():
+                reconstruct_embeddings[col_name] = embeddings[col_name][i]
+            docs.append(self._doc_from_bytes(row[1], reconstruct_embeddings))
+
+        return docs

     def _filter_batched(
         self,
@@ -501,12 +521,24 @@ def _get_docs_sqlite_unsorted(self, univ_ids: Sequence[int], out: bool = True):
             assert isinstance(id_, int) or is_np_int(id_)
         sql_id_list = '(' + ', '.join(str(id_) for id_ in univ_ids) + ')'
         self._sqlite_cursor.execute(
-            'SELECT data FROM docs WHERE doc_id IN %s' % sql_id_list,
+            'SELECT doc_id, data FROM docs WHERE doc_id IN %s' % sql_id_list,
         )
-        rows = self._sqlite_cursor.fetchall()
+        rows = (
+            self._sqlite_cursor.fetchall()
+        )  # doc_ids do not come back in the order they were requested
+        embeddings: OrderedDict[str, list] = OrderedDict()
+        for col_name, index in self._hnsw_indices.items():
+            embeddings[col_name] = index.get_items([row[0] for row in rows])
+
         schema = self.out_schema if out else self._schema
-        docs_cls = DocList.__class_getitem__(cast(Type[BaseDoc], schema))
-        return docs_cls([self._doc_from_bytes(row[0], out) for row in rows])
+        docs = DocList.__class_getitem__(cast(Type[BaseDoc], schema))()
+        for i, (_, data_bytes) in enumerate(rows):
+            reconstruct_embeddings = {}
+            for col_name in embeddings.keys():
+                reconstruct_embeddings[col_name] = embeddings[col_name][i]
+            docs.append(self._doc_from_bytes(data_bytes, reconstruct_embeddings, out))
+
+        return docs

     def _get_docs_sqlite_doc_id(
         self, doc_ids: Sequence[str], out: bool = True
@@ -541,12 +573,32 @@ def _get_num_docs_sqlite(self) -> int:

     # serialization helpers
     def _doc_to_bytes(self, doc: BaseDoc) -> bytes:
-        return doc.to_protobuf().SerializeToString()
-
-    def _doc_from_bytes(self, data: bytes, out: bool = True) -> BaseDoc:
+        pb = doc.to_protobuf()
+        if self._apply_optim_no_embedding_in_sqlite:
+            for col_name in self._hnsw_indices.keys():
+                pb.data[col_name].Clear()
+                pb.data[col_name].Clear()
+        return pb.SerializeToString()
+
+    def _doc_from_bytes(
+        self, data: bytes, reconstruct_embeddings: Dict, out: bool = True
+    ) -> BaseDoc:
         schema = self.out_schema if out else self._schema
         schema_cls = cast(Type[BaseDoc], schema)
-        return schema_cls.from_protobuf(DocProto.FromString(data))
+        pb = DocProto.FromString(
+            data
+        )  # the Document cannot be reconstructed directly: validation may fail because the embedding field may not be Optional
+        if self._apply_optim_no_embedding_in_sqlite:
+            for k, v in reconstruct_embeddings.items():
+                node_proto = (
+                    schema_cls._get_field_type(k)
+                    ._docarray_from_ndarray(np.array(v))
+                    ._to_node_protobuf()
+                )
+                pb.data[k].MergeFrom(node_proto)
+
+        doc = schema_cls.from_protobuf(pb)
+        return doc

     def _get_root_doc_id(self, id: str, root: str, sub: str) -> str:
         """Get the root_id given the id of a subindex Document and the root and subindex name for hnswlib.
@@ -608,25 +660,24 @@ def _search_and_filter(
             return _FindResultBatched(documents=[], scores=[])  # type: ignore

         # Set limit as the minimum of the provided limit and the total number of documents
-        limit = min(limit, self.num_docs())
+        limit = limit

         # Ensure the search field is in the HNSW indices
         if search_field not in self._hnsw_indices:
             raise ValueError(
                 f'Search field {search_field} is not present in the HNSW indices'
             )

-        index = self._hnsw_indices[search_field]
-
         def accept_hashed_ids(id):
             """Accepts IDs that are in hashed_ids."""
             return id in hashed_ids  # type: ignore[operator]

-        # Choose the appropriate filter function based on whether hashed_ids was provided
         extra_kwargs = {'filter': accept_hashed_ids} if hashed_ids else {}

         # If hashed_ids is provided, k is the minimum of limit and the length of hashed_ids; else it is limit
         k = min(limit, len(hashed_ids)) if hashed_ids else limit
+        index = self._hnsw_indices[search_field]
+
         try:
             labels, distances = index.knn_query(queries, k=k, **extra_kwargs)
         except RuntimeError:
@@ -639,7 +690,6 @@ def accept_hashed_ids(id):
             )
             for ids_per_query in labels
         ]
-
         return _FindResultBatched(documents=result_das, scores=distances)

     @classmethod
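
The `filter=accept_hashed_ids` passed through `extra_kwargs` uses hnswlib's label-filtering hook: since hnswlib 0.7, `knn_query` accepts a `filter` callable that receives each candidate label and returns whether it may appear in the results. A minimal standalone sketch of the pattern (the `allowed` set is illustrative, standing in for `hashed_ids`):

```python
import hnswlib
import numpy as np

dim = 8
index = hnswlib.Index(space='l2', dim=dim)
index.init_index(max_elements=100)
index.add_items(np.random.rand(5, dim).astype(np.float32), [10, 11, 12, 13, 14])

allowed = {11, 13}  # stands in for hashed_ids
k = min(3, len(allowed))  # same guard as `k = min(limit, len(hashed_ids))`

# The callable is invoked per candidate label during the search.
labels, distances = index.knn_query(
    np.random.rand(1, dim).astype(np.float32),
    k=k,
    filter=lambda label: label in allowed,
)
assert all(label in allowed for label in labels[0])
```

The surrounding `try`/`except RuntimeError` in the hunk above exists because hnswlib raises a `RuntimeError` when it cannot return `k` results, which becomes more likely once a filter discards candidates.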