From cd516248b217b6a7d836dfa943312dcaa71e6e83 Mon Sep 17 00:00:00 2001 From: jupyterjazz Date: Thu, 20 Jul 2023 20:17:57 +0200 Subject: [PATCH 1/5] feat: filtering in hnsw Signed-off-by: jupyterjazz --- docarray/index/backends/hnswlib.py | 430 +++++++++++++++++++--- tests/index/hnswlib/test_filter.py | 158 ++++++++ tests/index/hnswlib/test_find.py | 27 -- tests/index/hnswlib/test_query_builder.py | 186 ++++++++++ 4 files changed, 725 insertions(+), 76 deletions(-) create mode 100644 tests/index/hnswlib/test_filter.py create mode 100644 tests/index/hnswlib/test_query_builder.py diff --git a/docarray/index/backends/hnswlib.py b/docarray/index/backends/hnswlib.py index d4929569c63..7d492cb0031 100644 --- a/docarray/index/backends/hnswlib.py +++ b/docarray/index/backends/hnswlib.py @@ -2,6 +2,7 @@ import hashlib import os import sqlite3 +from collections import defaultdict from dataclasses import dataclass, field from pathlib import Path from typing import ( @@ -13,6 +14,7 @@ List, Optional, Sequence, + Set, Tuple, Type, TypeVar, @@ -30,16 +32,14 @@ _raise_not_composable, _raise_not_supported, ) -from docarray.index.backends.helper import ( - _collect_query_args, - _execute_find_and_filter_query, -) +from docarray.index.backends.helper import _collect_query_args from docarray.proto import DocProto from docarray.typing.tensor.abstract_tensor import AbstractTensor from docarray.typing.tensor.ndarray import NdArray from docarray.utils._internal._typing import safe_issubclass from docarray.utils._internal.misc import import_library, is_np_int -from docarray.utils.find import _FindResult, _FindResultBatched +from docarray.utils.filter import filter_docs +from docarray.utils.find import FindResult, _FindResult, _FindResultBatched if TYPE_CHECKING: import hnswlib @@ -67,6 +67,15 @@ TSchema = TypeVar('TSchema', bound=BaseDoc) T = TypeVar('T', bound='HnswDocumentIndex') +OPERATOR_MAPPING = { + '$eq': '=', + '$neq': '!=', + '$lt': '<', + '$lte': '<=', + '$gt': '>', + '$gte': '>=', +} + class HnswDocumentIndex(BaseDocIndex, Generic[TSchema]): def __init__(self, db_config=None, **kwargs): @@ -125,6 +134,7 @@ def __init__(self, db_config=None, **kwargs): self._sqlite_conn = sqlite3.connect(self._sqlite_db_path) self._logger.info('Connection to DB has been established') self._sqlite_cursor = self._sqlite_conn.cursor() + self._column_names: List[str] = [] self._create_docs_table() self._sqlite_conn.commit() self._num_docs = self._get_num_docs_sqlite() @@ -167,21 +177,22 @@ class DBConfig(BaseDocIndex.DBConfig): work_dir: str = '.' 
default_column_config: Dict[Type, Dict[str, Any]] = field( - default_factory=lambda: { - np.ndarray: { - 'dim': -1, - 'index': True, # if False, don't index at all - 'space': 'l2', # 'l2', 'ip', 'cosine' - 'max_elements': 1024, - 'ef_construction': 200, - 'ef': 10, - 'M': 16, - 'allow_replace_deleted': True, - 'num_threads': 1, + default_factory=lambda: defaultdict( + dict, + { + np.ndarray: { + 'dim': -1, + 'index': True, # if False, don't index at all + 'space': 'l2', # 'l2', 'ip', 'cosine' + 'max_elements': 1024, + 'ef_construction': 200, + 'ef': 10, + 'M': 16, + 'allow_replace_deleted': True, + 'num_threads': 1, + }, }, - # `None` is not a Type, but we allow it here anyway - None: {}, # type: ignore - } + ) ) @dataclass @@ -206,6 +217,15 @@ def python_type_to_db_type(self, python_type: Type) -> Any: if safe_issubclass(python_type, allowed_type): return np.ndarray + type_map = { + int: 'INTEGER', + float: 'REAL', + str: 'TEXT', + } + for py_type, sqlite_type in type_map.items(): + if safe_issubclass(python_type, py_type): + return sqlite_type + return None # all types allowed, but no db type needed def _index( @@ -281,11 +301,8 @@ def execute_query(self, query: List[Tuple[str, Dict]], *args, **kwargs) -> Any: raise ValueError( f'args and kwargs not supported for `execute_query` on {type(self)}' ) - find_res = _execute_find_and_filter_query( - doc_index=self, - query=query, - ) - return find_res + + return self._execute_find_and_filter_query(query) def _find_batched( self, @@ -293,20 +310,9 @@ def _find_batched( limit: int, search_field: str = '', ) -> _FindResultBatched: - if self.num_docs() == 0: - return _FindResultBatched(documents=[], scores=[]) # type: ignore - - limit = min(limit, self.num_docs()) - - index = self._hnsw_indices[search_field] - labels, distances = index.knn_query(queries, k=int(limit)) - result_das = [ - self._get_docs_sqlite_hashed_id( - ids_per_query.tolist(), - ) - for ids_per_query in labels - ] - return _FindResultBatched(documents=result_das, scores=distances) + return self._search_and_filter( + queries=queries, limit=limit, search_field=search_field + ) def _find( self, query: np.ndarray, limit: int, search_field: str = '' @@ -327,11 +333,8 @@ def _filter( filter_query: Any, limit: int, ) -> DocList: - raise NotImplementedError( - f'{type(self)} does not support filter-only queries.' - f' To perform post-filtering on a query, use' - f' `build_query()` and `execute_query()`.' 
- ) + rows = self._execute_filter(filter_query=filter_query, limit=limit) + return DocList[self.out_schema](self._doc_from_bytes(blob) for _, blob in rows) # type: ignore[name-defined] def _filter_batched( self, @@ -451,17 +454,46 @@ def _create_index(self, col_name: str, col: '_ColumnInfo') -> hnswlib.Index: # SQLite helpers def _create_docs_table(self): - self._sqlite_cursor.execute( - 'CREATE TABLE IF NOT EXISTS docs (doc_id INTEGER PRIMARY KEY, data BLOB)' - ) + columns: List[Tuple[str, str]] = [] + for col, info in self._column_infos.items(): + if ( + col == 'id' + or '__' in col + or not info.db_type + or info.db_type == np.ndarray + ): + continue + columns.append((col, info.db_type)) + + columns_str = ', '.join(f'{name} {type}' for name, type in columns) + if columns_str: + columns_str = ', ' + columns_str + + query = f'CREATE TABLE IF NOT EXISTS docs (doc_id INTEGER PRIMARY KEY, data BLOB{columns_str})' + self._sqlite_cursor.execute(query) def _send_docs_to_sqlite(self, docs: Sequence[BaseDoc]): + # Generate the IDs ids = (self._to_hashed_id(doc.id) for doc in docs) - self._sqlite_cursor.executemany( - 'INSERT OR REPLACE INTO docs VALUES (?, ?)', - ((id_, self._doc_to_bytes(doc)) for id_, doc in zip(ids, docs)), + + column_names = self._get_column_names() + # Construct the field names and placeholders for the SQL query + all_fields = ', '.join(column_names) + placeholders = ', '.join(['?'] * len(column_names)) + + # Prepare the SQL statement + query = f'INSERT OR REPLACE INTO docs ({all_fields}) VALUES ({placeholders})' + + # Prepare the data for insertion + data_to_insert = ( + (id_, self._doc_to_bytes(doc)) + + tuple(getattr(doc, field) for field in column_names[2:]) + for id_, doc in zip(ids, docs) ) + # Execute the query + self._sqlite_cursor.executemany(query, data_to_insert) + def _get_docs_sqlite_unsorted(self, univ_ids: Sequence[int], out: bool = True): for id_ in univ_ids: # I hope this protects from injection attacks @@ -540,3 +572,303 @@ def _get_root_doc_id(self, id: str, root: str, sub: str) -> str: id, fields[0], '__'.join(fields[1:]) ) return self._get_root_doc_id(cur_root_id, root, '') + + def _get_column_names(self) -> List[str]: + """ + Retrieves the column names of the 'docs' table in the SQLite database. + The column names are cached in `self._column_names` to prevent multiple queries to the SQLite database. + + :return: A list of strings, where each string is a column name. + """ + if not self._column_names: + self._sqlite_cursor.execute('PRAGMA table_info(docs)') + info = self._sqlite_cursor.fetchall() + self._column_names = [row[1] for row in info] + return self._column_names + + def _search_and_filter( + self, + queries: np.ndarray, + limit: int, + search_field: str = '', + hashed_ids: Optional[Set[str]] = None, + ) -> _FindResultBatched: + """ + Executes a search and filter operation on the database. + + :param queries: A numpy array of queries. + :param limit: The maximum number of results to return. + :param search_field: The field to search in. + :param hashed_ids: A set of hashed IDs to filter the results with. + :return: An instance of _FindResultBatched, containing the matching + documents and their corresponding scores. 
+ """ + # If there are no documents or hashed_ids is an empty set, return an empty _FindResultBatched + if self.num_docs() == 0 or (hashed_ids is not None and len(hashed_ids) == 0): + return _FindResultBatched(documents=[], scores=[]) # type: ignore + + # Set limit as the minimum of the provided limit and the total number of documents + limit = min(limit, self.num_docs()) + + # Ensure the search field is in the HNSW indices + if search_field not in self._hnsw_indices: + raise ValueError( + f"Search field {search_field} is not present in the HNSW indices" + ) + + index = self._hnsw_indices[search_field] + + def accept_all(id): + """Accepts all IDs.""" + return True + + def accept_hashed_ids(id): + """Accepts IDs that are in hashed_ids.""" + return id in hashed_ids + + # Choose the appropriate filter function based on whether hashed_ids was provided + filter_function = accept_hashed_ids if hashed_ids else accept_all + + # If hashed_ids is provided, k is the minimum of limit and the length of hashed_ids; else it is limit + k = min(limit, len(hashed_ids)) if hashed_ids else limit + + labels, distances = index.knn_query(queries, k=k, filter=filter_function) + + result_das = [ + self._get_docs_sqlite_hashed_id( + ids_per_query.tolist(), + ) + for ids_per_query in labels + ] + + return _FindResultBatched(documents=result_das, scores=distances) + + @classmethod + def _build_filter_query( + cls, query: Union[Dict, str], param_values: List[Any] + ) -> str: + """ + Builds a filter query for database operations. + + :param query: Query for filtering. + :param param_values: A list to store the parameters for the query. + :return: A string representing a SQL filter query. + """ + if isinstance(query, dict): + if len(query) != 1: + raise ValueError("Each nested dict must have exactly one key") + + key, value = next(iter(query.items())) + + if key in ['$and', '$or']: + # Combine subqueries using the AND or OR operator + subqueries = [cls._build_filter_query(q, param_values) for q in value] + return f'({f" {key[1:].upper()} ".join(subqueries)})' + elif key == '$not': + # Negate the query + return f'NOT {cls._build_filter_query(value, param_values)}' + else: # normal field + field = key + if not isinstance(value, dict) or len(value) != 1: + raise ValueError(f"Invalid condition for field {field}") + operator_key, operator_value = next(iter(value.items())) + + if operator_key == "$exists": + # Check for the existence or non-existence of a field + if operator_value: + return f'{field} IS NOT NULL' + else: + return f'{field} IS NULL' + elif operator_key not in OPERATOR_MAPPING: + raise ValueError(f"Invalid operator {operator_key}") + else: + # If the operator is valid, create a placeholder and append the value to param_values + operator = OPERATOR_MAPPING[operator_key] + placeholder = '?' + param_values.append(operator_value) + return f'{field} {operator} {placeholder}' + else: + raise ValueError("Invalid query") + + def _execute_filter( + self, + filter_query: Any, + limit: int, + ) -> List[Tuple[str, bytes]]: + """ + Executes a filter query on the database. + + :param filter_query: Query for filtering. + :param limit: Maximum number of rows to be fetched. + :return: A list of rows fetched from the database. 
+ """ + param_values: List[Any] = [] + sql_query = self._build_filter_query(filter_query, param_values) + sql_query = f'SELECT doc_id, data FROM docs WHERE {sql_query} LIMIT {limit}' + return self._sqlite_cursor.execute(sql_query, param_values).fetchall() + + def _execute_find_and_filter_query( + self, query: List[Tuple[str, Dict]] + ) -> FindResult: + """ + Executes a query to find and filter documents. + + :param query: A list of operations and their corresponding arguments. + :return: A FindResult object containing filtered documents and their scores. + """ + # Dictionary to store the score of each document + doc_to_score: Dict[BaseDoc, Any] = {} + + # Pre- and post-filter conditions + pre_filters: Dict[str, Dict] = {} + post_filters: Dict[str, Dict] = {} + + # Define filter limits + pre_filter_limit = self.num_docs() + post_filter_limit = self.num_docs() + + find_executed: bool = False + + # Document list with output schema + out_docs: DocList = DocList[self.out_schema]() # type: ignore[name-defined] + + for op, op_kwargs in query: + if op == 'find': + hashed_ids: Optional[Set[str]] = None + if pre_filters: + hashed_ids = self._pre_filtering(pre_filters, pre_filter_limit) + + query_vector = self._get_vector_for_query_builder(op_kwargs) + # Perform search and filter if hashed_ids returned by pre-filtering is not empty + if not (pre_filters and not hashed_ids): + # Returns batched output, so we need to get the first lists + out_docs, scores = self._search_and_filter( # type: ignore[assignment] + queries=query_vector, + limit=op_kwargs.get('limit', self.num_docs()), + search_field=op_kwargs['search_field'], + hashed_ids=hashed_ids, + ) + out_docs = DocList[self.out_schema](out_docs[0]) # type: ignore[name-defined] + doc_to_score.update(zip(out_docs.__getattribute__('id'), scores[0])) + find_executed = True + elif op == 'filter': + if find_executed: + post_filters, post_filter_limit = self._update_filter_conditions( + post_filters, op_kwargs, post_filter_limit + ) + else: + pre_filters, pre_filter_limit = self._update_filter_conditions( + pre_filters, op_kwargs, pre_filter_limit + ) + else: + raise ValueError(f'Query operation is not supported: {op}') + + if post_filters: + out_docs = self._post_filtering( + out_docs, post_filters, post_filter_limit, find_executed + ) + + return self._prepare_out_docs(out_docs, doc_to_score) + + def _update_filter_conditions( + self, filter_conditions: Dict, operation_args: Dict, filter_limit: int + ) -> Tuple[Dict, int]: + """ + Updates filter conditions based on the operation arguments and updates the filter limit. + + :param filter_conditions: Current filter conditions. + :param operation_args: Arguments of the operation to be executed. + :param filter_limit: Current filter limit. + :return: Updated filter conditions and filter limit. + """ + # Use '$and' operator if filter_conditions is not empty, else use operation_args['filter_query'] + updated_filter_conditions = ( + {'$and': {**filter_conditions, **operation_args['filter_query']}} + if filter_conditions + else operation_args['filter_query'] + ) + # Update filter limit based on the operation_args limit + updated_filter_limit = min( + filter_limit, operation_args.get('limit', filter_limit) + ) + return updated_filter_conditions, updated_filter_limit + + def _pre_filtering( + self, pre_filters: Dict[str, Dict], pre_filter_limit: int + ) -> Set[str]: + """ + Performs pre-filtering on the data. + + :param pre_filters: Filter conditions. + :param pre_filter_limit: Limit for the filtering. 
+ :return: A set of hashed IDs from the filtered rows. + """ + rows = self._execute_filter(filter_query=pre_filters, limit=pre_filter_limit) + return set(hashed_id for hashed_id, _ in rows) + + def _get_vector_for_query_builder(self, find_args: Dict[str, Any]) -> np.ndarray: + """ + Prepares the query vector for search operation. + + :param find_args: Arguments for the 'find' operation. + :return: A numpy array representing the query vector. + """ + if isinstance(find_args['query'], BaseDoc): + query_vec = self._get_values_by_column( + [find_args['query']], find_args['search_field'] + )[0] + else: + query_vec = find_args['query'] + query_vec_np = self._to_numpy(query_vec) + query_batched = np.expand_dims(query_vec_np, axis=0) + return query_batched + + def _post_filtering( + self, + out_docs: DocList, + post_filters: Dict[str, Dict], + post_filter_limit: int, + find_executed: bool, + ) -> DocList: + """ + Performs post-filtering on the found documents. + + :param out_docs: The documents found by the 'find' operation. + :param post_filters: The post-filter conditions. + :param post_filter_limit: Limit for the post-filtering. + :param find_executed: Whether 'find' operation was executed. + :return: Filtered documents as per the post-filter conditions. + """ + if not find_executed: + out_docs = self.filter(post_filters, limit=self.num_docs()) + else: + docs_cls = DocList.__class_getitem__(cast(Type[BaseDoc], self.out_schema)) + out_docs = docs_cls(filter_docs(out_docs, post_filters)) + + if post_filters: + out_docs = out_docs[:post_filter_limit] + + return out_docs + + def _prepare_out_docs( + self, out_docs: DocList, doc_to_score: Dict[BaseDoc, Any] + ) -> FindResult: + """ + Prepares output documents with their scores. + + :param out_docs: The documents to be output. + :param doc_to_score: Mapping of documents to their scores. + :return: FindResult object with documents and their scores. + """ + if out_docs: + # If the "find" operation isn't called through the query builder, + # all returned scores will be 0 + docs_and_scores = zip( + out_docs, (doc_to_score.get(doc.id, 0) for doc in out_docs) + ) + docs_sorted = sorted(docs_and_scores, key=lambda x: x[1]) + out_docs, out_scores = zip(*docs_sorted) + else: + out_docs, out_scores = [], [] # type: ignore[assignment] + + return FindResult(documents=out_docs, scores=out_scores) diff --git a/tests/index/hnswlib/test_filter.py b/tests/index/hnswlib/test_filter.py new file mode 100644 index 00000000000..10466288a6d --- /dev/null +++ b/tests/index/hnswlib/test_filter.py @@ -0,0 +1,158 @@ +import numpy as np +import pytest + +from docarray import BaseDoc, DocList +from docarray.index import HnswDocumentIndex +from docarray.typing import NdArray + + +class SchemaDoc(BaseDoc): + text: str + price: int + tensor: NdArray[10] + + +@pytest.fixture +def docs(): + docs = DocList[SchemaDoc]( + [ + SchemaDoc(text=f'text {i}', price=i, tensor=np.random.rand(10)) + for i in range(9) + ] + ) + docs.append(SchemaDoc(text='zd all', price=100, tensor=np.random.rand(10))) + return docs + + +@pytest.fixture +def doc_index(docs, tmp_path): + doc_index = HnswDocumentIndex[SchemaDoc](work_dir=tmp_path) + doc_index.index(docs) + return doc_index + + +def test_build_query_eq(): + param_values = [] + query = {'text': {'$eq': 'text 1'}} + assert HnswDocumentIndex._build_filter_query(query, param_values) == 'text = ?' 
+ assert param_values == ['text 1'] + + +def test_build_query_lt(): + param_values = [] + query = {'price': {'$lt': 500}} + assert HnswDocumentIndex._build_filter_query(query, param_values) == 'price < ?' + assert param_values == [500] + + +def test_build_query_and(): + param_values = [] + query = {'$and': [{'text': {'$eq': 'text 1'}}, {'price': {'$lt': 500}}]} + assert ( + HnswDocumentIndex._build_filter_query(query, param_values) + == '(text = ? AND price < ?)' + ) + assert param_values == ['text 1', 500] + + +def test_build_query_invalid_operator(): + param_values = [] + query = {'price': {'$invalid': 500}} + with pytest.raises(ValueError, match=r"Invalid operator \$invalid"): + HnswDocumentIndex._build_filter_query(query, param_values) + + +def test_build_query_invalid_query(): + param_values = [] + query = {'price': 500} + with pytest.raises(ValueError, match=r"Invalid condition for field price"): + HnswDocumentIndex._build_filter_query(query, param_values) + + +def test_filter_eq(doc_index): + docs = doc_index.filter({'text': {'$eq': 'text 1'}}) + assert len(docs) == 1 + assert docs[0].text == 'text 1' + + +def test_filter_neq(doc_index): + docs = doc_index.filter({'text': {'$neq': 'text 1'}}) + assert len(docs) == 9 + assert all(doc.text != 'text 1' for doc in docs) + + +def test_filter_lt(doc_index): + docs = doc_index.filter({'price': {'$lt': 3}}) + assert len(docs) == 3 + assert all(doc.price < 3 for doc in docs) + + +def test_filter_lte(doc_index): + docs = doc_index.filter({'price': {'$lte': 2}}) + assert len(docs) == 3 + assert all(doc.price <= 2 for doc in docs) + + +def test_filter_gt(doc_index): + docs = doc_index.filter({'price': {'$gt': 5}}) + assert len(docs) == 4 + assert all(doc.price > 5 for doc in docs) + + +def test_filter_gte(doc_index): + docs = doc_index.filter({'price': {'$gte': 6}}) + assert len(docs) == 4 + assert all(doc.price >= 6 for doc in docs) + + +def test_filter_exists(doc_index): + docs = doc_index.filter({'price': {'$exists': True}}) + assert len(docs) == 10 + assert all(hasattr(doc, 'price') for doc in docs) + + +def test_filter_or(doc_index): + docs = doc_index.filter( + { + '$or': [ + {'text': {'$eq': 'text 1'}}, + {'price': {'$eq': 2}}, + ] + } + ) + assert len(docs) == 2 + assert any(doc.text == 'text 1' or doc.price == 2 for doc in docs) + + +def test_filter_and(doc_index): + docs = doc_index.filter( + { + '$and': [ + {'text': {'$eq': 'text 1'}}, + {'price': {'$eq': 1}}, + ] + } + ) + assert len(docs) == 1 + assert any(doc.text == 'text 1' and doc.price == 1 for doc in docs) + + +def test_filter_not(doc_index): + docs = doc_index.filter({'$not': {'text': {'$eq': 'text 1'}}}) + assert len(docs) == 9 + assert all(doc.text != 'text 1' for doc in docs) + + +def test_filter_not_and(doc_index): + docs = doc_index.filter( + { + '$not': { + '$and': [ + {'text': {'$eq': 'text 1'}}, + {'price': {'$eq': 1}}, + ] + } + } + ) + assert len(docs) == 9 + assert all(not (doc.text == 'text 1' and doc.price == 1) for doc in docs) diff --git a/tests/index/hnswlib/test_find.py b/tests/index/hnswlib/test_find.py index 644f3665278..406e412f959 100644 --- a/tests/index/hnswlib/test_find.py +++ b/tests/index/hnswlib/test_find.py @@ -304,33 +304,6 @@ class MyDoc(BaseDoc): assert q.id == matches[0].id -@pytest.mark.parametrize( - 'find_limit, filter_limit, expected_docs', [(10, 3, 3), (5, None, 5)] -) -def test_query_builder_limits(find_limit, filter_limit, expected_docs, tmp_path): - class SimpleSchema(BaseDoc): - tensor: NdArray[10] = Field(space='l2') - price: int - - 
index = HnswDocumentIndex[SimpleSchema](work_dir=str(tmp_path)) - - index_docs = [SimpleSchema(tensor=np.array([i] * 10), price=i) for i in range(10)] - index.index(index_docs) - - query = SimpleSchema(tensor=np.array([3] * 10), price=3) - - q = ( - index.build_query() - .find(query=query, search_field='tensor', limit=find_limit) - .filter(filter_query={'price': {'$lte': 5}}, limit=filter_limit) - .build() - ) - - docs, scores = index.execute_query(q) - - assert len(docs) == expected_docs - - def test_contain(tmp_path): class SimpleSchema(BaseDoc): tens: NdArray[10] = Field(space="cosine") diff --git a/tests/index/hnswlib/test_query_builder.py b/tests/index/hnswlib/test_query_builder.py new file mode 100644 index 00000000000..3070d67e3b2 --- /dev/null +++ b/tests/index/hnswlib/test_query_builder.py @@ -0,0 +1,186 @@ +import numpy as np +import pytest + +from docarray import BaseDoc, DocList +from docarray.index import HnswDocumentIndex +from docarray.typing import NdArray + + +class SchemaDoc(BaseDoc): + text: str + price: int + tensor: NdArray[10] + + +@pytest.fixture +def docs(): + docs = DocList[SchemaDoc]( + [ + SchemaDoc(text=f'text {i}', price=i, tensor=np.random.rand(10)) + for i in range(9) + ] + ) + docs.append(SchemaDoc(text='zd all', price=100, tensor=np.random.rand(10))) + return docs + + +@pytest.fixture +def doc_index(docs, tmp_path): + doc_index = HnswDocumentIndex[SchemaDoc](work_dir=tmp_path) + doc_index.index(docs) + return doc_index + + +def test_query_filter_find_filter(doc_index): + q = ( + doc_index.build_query() + .filter(filter_query={'price': {'$lte': 3}}) + .find(query=np.ones(10), search_field='tensor') + .filter(filter_query={'text': {'$eq': 'text 1'}}) + .build() + ) + + docs, scores = doc_index.execute_query(q) + + assert len(docs) == 1 + assert docs[0].price <= 3 + assert docs[0].text == 'text 1' + + +def test_query_find_filter(doc_index): + q = ( + doc_index.build_query() + .find(query=np.ones(10), search_field='tensor') + .filter(filter_query={'price': {'$gt': 3}}, limit=5) + .build() + ) + + docs, scores = doc_index.execute_query(q) + + assert len(docs) <= 5 + for doc in docs: + assert doc.price > 3 + + +def test_query_filter_exists_find(doc_index): + q = ( + doc_index.build_query() + .filter(filter_query={'text': {'$exists': True}}) + .find(query=np.ones(10), search_field='tensor') + .build() + ) + + docs, scores = doc_index.execute_query(q) + + # All documents have a 'text' field, so all documents should be returned. + assert len(docs) == 10 + + +def test_query_filter_not_exists_find(doc_index): + q = ( + doc_index.build_query() + .filter(filter_query={'text': {'$exists': False}}) + .find(query=np.ones(10), search_field='tensor') + .build() + ) + + docs, scores = doc_index.execute_query(q) + + # No documents have missing 'text' field, so no documents should be returned. 
+ assert len(docs) == 0 + + +def test_query_find_filter_neq(doc_index): + q = ( + doc_index.build_query() + .find(query=np.ones(10), search_field='tensor') + .filter(filter_query={'price': {'$neq': 3}}, limit=5) + .build() + ) + + docs, scores = doc_index.execute_query(q) + + assert len(docs) <= 5 + for doc in docs: + assert doc.price != 3 + + +def test_query_filter_gte_find(doc_index): + q = ( + doc_index.build_query() + .filter(filter_query={'price': {'$gte': 5}}) + .find(query=np.ones(10), search_field='tensor') + .build() + ) + + docs, scores = doc_index.execute_query(q) + + for doc in docs: + assert doc.price >= 5 + + +def test_query_filter_lt_find_filter_gt(doc_index): + q = ( + doc_index.build_query() + .filter(filter_query={'price': {'$lt': 8}}) + .find(query=np.ones(10), search_field='tensor') + .filter(filter_query={'price': {'$gt': 2}}, limit=5) + .build() + ) + + docs, scores = doc_index.execute_query(q) + + assert len(docs) <= 5 + for doc in docs: + assert 2 < doc.price < 8 + + +def test_query_find_filter_and(doc_index): + q = ( + doc_index.build_query() + .find(query=np.ones(10), search_field='tensor') + .filter( + filter_query={ + '$and': [{'price': {'$gt': 2}}, {'text': {'$neq': 'text 1'}}] + }, + limit=5, + ) + .build() + ) + + docs, scores = doc_index.execute_query(q) + + assert len(docs) <= 5 + for doc in docs: + assert doc.price > 2 and doc.text != 'text 1' + + +def test_query_filter_or_find(doc_index): + q = ( + doc_index.build_query() + .filter( + filter_query={'$or': [{'price': {'$eq': 3}}, {'text': {'$eq': 'text 3'}}]} + ) + .find(query=np.ones(10), search_field='tensor') + .build() + ) + + docs, scores = doc_index.execute_query(q) + + for doc in docs: + assert doc.price == 3 or doc.text == 'text 3' + + +def test_query_find_filter_not(doc_index): + q = ( + doc_index.build_query() + .find(query=np.ones(10), search_field='tensor') + .filter(filter_query={'$not': {'price': {'$eq': 3}}}, limit=5) + .build() + ) + + docs, scores = doc_index.execute_query(q) + + assert len(docs) <= 5 + for doc in docs: + assert doc.price != 3 From 75d0bc38a625164207daf45d2182d3ef06a021ad Mon Sep 17 00:00:00 2001 From: jupyterjazz Date: Thu, 20 Jul 2023 20:41:43 +0200 Subject: [PATCH 2/5] style: mypys the worst Signed-off-by: jupyterjazz --- docarray/index/backends/hnswlib.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docarray/index/backends/hnswlib.py b/docarray/index/backends/hnswlib.py index 7d492cb0031..cf133606b4f 100644 --- a/docarray/index/backends/hnswlib.py +++ b/docarray/index/backends/hnswlib.py @@ -624,7 +624,7 @@ def accept_all(id): def accept_hashed_ids(id): """Accepts IDs that are in hashed_ids.""" - return id in hashed_ids + return id in hashed_ids # type: ignore[operator] # Choose the appropriate filter function based on whether hashed_ids was provided filter_function = accept_hashed_ids if hashed_ids else accept_all From ed621ef027cf92087f017193ea9493e2704f5bfc Mon Sep 17 00:00:00 2001 From: jupyterjazz Date: Fri, 21 Jul 2023 12:08:03 +0200 Subject: [PATCH 3/5] style: minor changes Signed-off-by: jupyterjazz --- docarray/index/backends/hnswlib.py | 70 +++++++++++++++--------------- 1 file changed, 35 insertions(+), 35 deletions(-) diff --git a/docarray/index/backends/hnswlib.py b/docarray/index/backends/hnswlib.py index cf133606b4f..01c63ea04fc 100644 --- a/docarray/index/backends/hnswlib.py +++ b/docarray/index/backends/hnswlib.py @@ -613,7 +613,7 @@ def _search_and_filter( # Ensure the search field is in the HNSW indices if search_field not in 
self._hnsw_indices: raise ValueError( - f"Search field {search_field} is not present in the HNSW indices" + f'Search field {search_field} is not present in the HNSW indices' ) index = self._hnsw_indices[search_field] @@ -654,41 +654,41 @@ def _build_filter_query( :param param_values: A list to store the parameters for the query. :return: A string representing a SQL filter query. """ - if isinstance(query, dict): - if len(query) != 1: - raise ValueError("Each nested dict must have exactly one key") - - key, value = next(iter(query.items())) - - if key in ['$and', '$or']: - # Combine subqueries using the AND or OR operator - subqueries = [cls._build_filter_query(q, param_values) for q in value] - return f'({f" {key[1:].upper()} ".join(subqueries)})' - elif key == '$not': - # Negate the query - return f'NOT {cls._build_filter_query(value, param_values)}' - else: # normal field - field = key - if not isinstance(value, dict) or len(value) != 1: - raise ValueError(f"Invalid condition for field {field}") - operator_key, operator_value = next(iter(value.items())) - - if operator_key == "$exists": - # Check for the existence or non-existence of a field - if operator_value: - return f'{field} IS NOT NULL' - else: - return f'{field} IS NULL' - elif operator_key not in OPERATOR_MAPPING: - raise ValueError(f"Invalid operator {operator_key}") + if not isinstance(query, dict): + raise ValueError('Invalid query') + + if len(query) != 1: + raise ValueError('Each nested dict must have exactly one key') + + key, value = next(iter(query.items())) + + if key in ['$and', '$or']: + # Combine subqueries using the AND or OR operator + subqueries = [cls._build_filter_query(q, param_values) for q in value] + return f'({f" {key[1:].upper()} ".join(subqueries)})' + elif key == '$not': + # Negate the query + return f'NOT {cls._build_filter_query(value, param_values)}' + else: # normal field + field = key + if not isinstance(value, dict) or len(value) != 1: + raise ValueError(f'Invalid condition for field {field}') + operator_key, operator_value = next(iter(value.items())) + + if operator_key == "$exists": + # Check for the existence or non-existence of a field + if operator_value: + return f'{field} IS NOT NULL' else: - # If the operator is valid, create a placeholder and append the value to param_values - operator = OPERATOR_MAPPING[operator_key] - placeholder = '?' - param_values.append(operator_value) - return f'{field} {operator} {placeholder}' - else: - raise ValueError("Invalid query") + return f'{field} IS NULL' + elif operator_key not in OPERATOR_MAPPING: + raise ValueError(f"Invalid operator {operator_key}") + else: + # If the operator is valid, create a placeholder and append the value to param_values + operator = OPERATOR_MAPPING[operator_key] + placeholder = '?' 
+                param_values.append(operator_value)
+                return f'{field} {operator} {placeholder}'
 
     def _execute_filter(
         self,

From 1a7eeb99d83e3a1a0b5c7ea38e1d91d82ee6f96c Mon Sep 17 00:00:00 2001
From: jupyterjazz
Date: Fri, 21 Jul 2023 12:25:08 +0200
Subject: [PATCH 4/5] chore: bump hnswlib version

Signed-off-by: jupyterjazz
---
 poetry.lock    | 2 +-
 pyproject.toml | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/poetry.lock b/poetry.lock
index ee685ca38b3..55fe579ba53 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -5273,4 +5273,4 @@ web = ["fastapi"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8,<4.0"
-content-hash = "3f67aa7266b35860429a1911f5acdbef0db4b0ec2b5151ae2f563030c177c19e"
+content-hash = "acf833d086fbe0c98e995ca60533883e5d90f24d2bba29ef7910b2bedabb93cb"

diff --git a/pyproject.toml b/pyproject.toml
index b6107f46ad8..32b05cf251b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -48,7 +48,7 @@ types-requests = ">=2.28.11.6"
 av = {version = ">=10.0.0", optional = true}
 fastapi = {version = ">=0.87.0", optional = true }
 rich = ">=13.1.0"
-hnswlib = {version = ">=0.6.2", optional = true }
+hnswlib = {version = ">=0.7.0", optional = true }
 lz4 = {version= ">=1.0.0", optional = true}
 pydub = {version = "^0.25.1", optional = true }
 pandas = {version = ">=1.1.0", optional = true }

From e67dfba5a9d388cc7f12b78ded2c82ea06bf1f60 Mon Sep 17 00:00:00 2001
From: jupyterjazz
Date: Tue, 25 Jul 2023 13:31:41 +0200
Subject: [PATCH 5/5] test: add a test with multiple limits

Signed-off-by: jupyterjazz
---
 docarray/index/backends/hnswlib.py        |  1 +
 tests/index/hnswlib/test_query_builder.py | 16 ++++++++++++++++
 2 files changed, 17 insertions(+)

diff --git a/docarray/index/backends/hnswlib.py b/docarray/index/backends/hnswlib.py
index 01c63ea04fc..6048a5fec8d 100644
--- a/docarray/index/backends/hnswlib.py
+++ b/docarray/index/backends/hnswlib.py
@@ -217,6 +217,7 @@ def python_type_to_db_type(self, python_type: Type) -> Any:
             if safe_issubclass(python_type, allowed_type):
                 return np.ndarray
 
+        # types allowed for filtering
        type_map = {
             int: 'INTEGER',
             float: 'REAL',
diff --git a/tests/index/hnswlib/test_query_builder.py b/tests/index/hnswlib/test_query_builder.py
index 3070d67e3b2..13581cdb29a 100644
--- a/tests/index/hnswlib/test_query_builder.py
+++ b/tests/index/hnswlib/test_query_builder.py
@@ -184,3 +184,19 @@ def test_query_find_filter_not(doc_index):
     assert len(docs) <= 5
     for doc in docs:
         assert doc.price != 3
+
+
+@pytest.mark.parametrize(
+    'find_limit, filter_limit, expected_docs', [(10, 3, 3), (5, 8, 5)]
+)
+def test_query_builder_limits(find_limit, filter_limit, expected_docs, doc_index):
+    q = (
+        doc_index.build_query()
+        .filter(filter_query={'price': {'$lte': 5}}, limit=filter_limit)
+        .find(query=np.random.rand(10), search_field='tensor', limit=find_limit)
+        .build()
+    )
+
+    docs, scores = doc_index.execute_query(q)
+
+    assert len(docs) == expected_docs
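
As a usage sketch, the filtering support introduced by this series can be
exercised as follows, modeled on the new tests in
tests/index/hnswlib/test_filter.py and tests/index/hnswlib/test_query_builder.py.
The schema, the work_dir path, and the toy data below are illustrative
placeholders, not part of the patch:

import numpy as np

from docarray import BaseDoc, DocList
from docarray.index import HnswDocumentIndex
from docarray.typing import NdArray


class SchemaDoc(BaseDoc):
    text: str
    price: int
    tensor: NdArray[10]


# Hypothetical toy index; the work_dir path is an arbitrary placeholder.
index = HnswDocumentIndex[SchemaDoc](work_dir='./hnsw_filter_demo')
index.index(
    DocList[SchemaDoc](
        [
            SchemaDoc(text=f'text {i}', price=i, tensor=np.random.rand(10))
            for i in range(10)
        ]
    )
)

# Filter-only query: now answered directly from the typed SQLite columns
# instead of raising NotImplementedError.
cheap_docs = index.filter({'price': {'$lte': 3}})

# Query builder: a filter() placed before find() runs in SQLite and restricts
# the HNSW search to the matching hashed ids (pre-filtering); a filter() placed
# after find() is applied to the search results (post-filtering).
q = (
    index.build_query()
    .filter(filter_query={'price': {'$lt': 8}})
    .find(query=np.ones(10), search_field='tensor')
    .filter(filter_query={'text': {'$neq': 'text 1'}})
    .build()
)
docs, scores = index.execute_query(q)

Filter queries support the comparison operators registered in OPERATOR_MAPPING
($eq, $neq, $lt, $lte, $gt, $gte) plus $exists, $and, $or, and $not; each
condition is translated into parameterized SQL against the typed columns that
_create_docs_table now creates for int (INTEGER), float (REAL), and str (TEXT)
fields.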