From df474741e5b24e2641da63ff66d6f4e9b3daa7d9 Mon Sep 17 00:00:00 2001 From: jupyterjazz Date: Wed, 17 May 2023 16:24:11 +0200 Subject: [PATCH 01/28] feat: support redis Signed-off-by: jupyterjazz --- docarray/index/__init__.py | 4 ++ docarray/index/backends/redis.py | 88 ++++++++++++++++++++++++++++ docarray/utils/_internal/misc.py | 1 + poetry.lock | 25 +++++++- pyproject.toml | 2 + tests/index/redis/__init__.py | 0 tests/index/redis/docker-compose.yml | 6 ++ tests/index/redis/fixtures.py | 34 +++++++++++ 8 files changed, 158 insertions(+), 2 deletions(-) create mode 100644 docarray/index/backends/redis.py create mode 100644 tests/index/redis/__init__.py create mode 100644 tests/index/redis/docker-compose.yml create mode 100644 tests/index/redis/fixtures.py diff --git a/docarray/index/__init__.py b/docarray/index/__init__.py index 9e4dbde474a..b24877526a2 100644 --- a/docarray/index/__init__.py +++ b/docarray/index/__init__.py @@ -13,6 +13,7 @@ from docarray.index.backends.hnswlib import HnswDocumentIndex # noqa: F401 from docarray.index.backends.qdrant import QdrantDocumentIndex # noqa: F401 from docarray.index.backends.weaviate import WeaviateDocumentIndex # noqa: F401 + from docarray.index.backends.redis import RedisDocumentIndex # noqa: F401 __all__ = ['InMemoryExactNNIndex'] @@ -34,6 +35,9 @@ def __getattr__(name: str): elif name == 'WeaviateDocumentIndex': import_library('weaviate', raise_error=True) import docarray.index.backends.weaviate as lib + elif name == 'RedisDocumentIndex': + import_library('redis', raise_error=True) + import docarray.index.backends.redis as lib else: raise ImportError( f'cannot import name \'{name}\' from \'{_get_path_from_docarray_root_level(__file__)}\'' diff --git a/docarray/index/backends/redis.py b/docarray/index/backends/redis.py new file mode 100644 index 00000000000..e492dbf9b47 --- /dev/null +++ b/docarray/index/backends/redis.py @@ -0,0 +1,88 @@ +from typing import TypeVar, Generic, Optional, List, Dict, Any, Sequence, Union, Generator, Type +from dataclasses import dataclass, field + +import numpy as np + +from docarray import BaseDoc, DocList +from docarray.index.abstract import BaseDocIndex +from docarray.typing.tensor.abstract_tensor import AbstractTensor +from docarray.utils.find import _FindResultBatched, _FindResult +from redis.commands.search.field import NumericField, TextField, VectorField, GeoField + + +TSchema = TypeVar('TSchema', bound=BaseDoc) + + +class RedisDocumentIndex(BaseDocIndex, Generic[TSchema]): + def __init__(self, db_config=None, **kwargs): + super().__init__(db_config=db_config, **kwargs) + + @dataclass + class DBConfig(BaseDocIndex.DBConfig): + """Dataclass that contains all "static" configurations of RedisDocumentIndex.""" + + host: str = 'http://localhost:6379' + index_name: Optional[str] = None + username: Optional[str] = None + password: Optional[str] = None + + @dataclass + class RuntimeConfig(BaseDocIndex.RuntimeConfig): + """Dataclass that contains all "dynamic" configurations of RedisDocumentIndex.""" + + default_column_config: Dict[Any, Dict[str, Any]] = field( + default_factory=lambda: { + TextField: {}, + NumericField: {}, + VectorField: {}, + } + ) + + def python_type_to_db_type(self, python_type: Type) -> Any: + type_map = { + int: NumericField, + float: NumericField, + str: TextField, + np.ndarray: VectorField, + list: VectorField, + AbstractTensor: VectorField, + } + + for py_type, redis_type in type_map.items(): + if issubclass(python_type, py_type): + return redis_type + raise ValueError(f'Unsupported 
column type for {type(self)}: {python_type}') + + def _index(self, column_to_data: Dict[str, Generator[Any, None, None]]): + pass + + def num_docs(self) -> int: + pass + + def _del_items(self, doc_ids: Sequence[str]): + pass + + def _get_items(self, doc_ids: Sequence[str]) -> Union[Sequence[TSchema], Sequence[Dict[str, Any]]]: + pass + + def execute_query(self, query: Any, *args, **kwargs) -> Any: + pass + + def _find(self, query: np.ndarray, limit: int, search_field: str = '') -> _FindResult: + pass + + def _find_batched(self, queries: np.ndarray, limit: int, search_field: str = '') -> _FindResultBatched: + pass + + def _filter(self, filter_query: Any, limit: int) -> Union[DocList, List[Dict]]: + pass + + def _filter_batched(self, filter_queries: Any, limit: int) -> Union[List[DocList], List[List[Dict]]]: + pass + + def _text_search(self, query: str, limit: int, search_field: str = '') -> _FindResult: + pass + + def _text_search_batched(self, queries: Sequence[str], limit: int, search_field: str = '') -> _FindResultBatched: + pass + diff --git a/docarray/utils/_internal/misc.py b/docarray/utils/_internal/misc.py index ea1b7399ffd..b2a5d2ea3fc 100644 --- a/docarray/utils/_internal/misc.py +++ b/docarray/utils/_internal/misc.py @@ -42,6 +42,7 @@ 'smart_open': '"docarray[aws]"', 'boto3': '"docarray[aws]"', 'botocore': '"docarray[aws]"', + 'redis': '"docarray[redis]' } diff --git a/poetry.lock b/poetry.lock index 13264584717..c0c2c5c344b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -224,7 +224,7 @@ name = "async-timeout" version = "4.0.2" description = "Timeout context manager for asyncio programs" category = "main" -optional = true +optional = false python-versions = ">=3.6" files = [ {file = "async-timeout-4.0.2.tar.gz", hash = "sha256:2163e1640ddb52b7a8c80d0a67a08587e5d245cc9c553a74a847056bc2976b15"}, @@ -3715,6 +3715,27 @@ pydantic = ">=1.8,<2.0" typing-extensions = ">=4.0.0,<5.0.0" urllib3 = ">=1.26.14,<2.0.0" +[[package]] +name = "redis" +version = "4.5.5" +description = "Python client for Redis database and key-value store" +category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "redis-4.5.5-py3-none-any.whl", hash = "sha256:77929bc7f5dab9adf3acba2d3bb7d7658f1e0c2f1cafe7eb36434e751c471119"}, + {file = "redis-4.5.5.tar.gz", hash = "sha256:dc87a0bdef6c8bfe1ef1e1c40be7034390c2ae02d92dcd0c7ca1729443899880"}, +] + +[package.dependencies] +async-timeout = {version = ">=4.0.2", markers = "python_full_version <= \"3.11.2\""} +importlib-metadata = {version = ">=1.0", markers = "python_version < \"3.8\""} +typing-extensions = {version = "*", markers = "python_version < \"3.8\""} + +[package.extras] +hiredis = ["hiredis (>=1.0.0)"] +ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==20.0.1)", "requests (>=2.26.0)"] + [[package]] name = "regex" version = "2022.10.31" @@ -4972,4 +4993,4 @@ web = ["fastapi"] [metadata] lock-version = "2.0" python-versions = ">=3.7,<4.0" -content-hash = "0900f3f885ab2d0f0ef79e2772ce0322868d275a4576dc88a911178a8edda910" +content-hash = "d4b47231ad2c9f49e277b347c7ccb5a5db1d79a2fdcad6739381f8dbd2c4054f" diff --git a/pyproject.toml b/pyproject.toml index ed084e92914..b593538385e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ smart-open = {version = ">=6.3.0", extras = ["s3"], optional = true} jina-hubble-sdk = {version = ">=0.34.0", optional = true} elastic-transport = {version ="^8.4.0", optional = true } qdrant-client = {version = ">=1.1.4", python = "<3.12", optional = true } +redis = {version = "^4.5.5", optional = true 
} [tool.poetry.extras] proto = ["protobuf", "lz4"] @@ -75,6 +76,7 @@ torch = ["torch"] web = ["fastapi"] qdrant = ["qdrant-client"] weaviate = ["weaviate-client"] +redis = ['redis'] # all full = ["protobuf", "lz4", "pandas", "pillow", "types-pillow", "av", "pydub", "trimesh"] diff --git a/tests/index/redis/__init__.py b/tests/index/redis/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/index/redis/docker-compose.yml b/tests/index/redis/docker-compose.yml new file mode 100644 index 00000000000..4db029055dd --- /dev/null +++ b/tests/index/redis/docker-compose.yml @@ -0,0 +1,6 @@ +version: '3.8' +services: + redis: + image: redis:latest + ports: + - "6379:6379" diff --git a/tests/index/redis/fixtures.py b/tests/index/redis/fixtures.py new file mode 100644 index 00000000000..c9d89d766a1 --- /dev/null +++ b/tests/index/redis/fixtures.py @@ -0,0 +1,34 @@ +import os +import time +import uuid +import pytest +import redis + + +@pytest.fixture(scope='session', autouse=True) +def start_redis(): + os.system('docker run -d -p 6379:6379 --name test-redis redis') + time.sleep(1) + + yield + + os.system('docker rm -f test-redis') + + +@pytest.fixture(scope='function') +def tmp_collection_name(): + return uuid.uuid4().hex + + +@pytest.fixture +def redis_client(): + """This fixture provides a Redis client""" + client = redis.Redis(host='localhost', port=6379) + yield client + client.flushall() + + +@pytest.fixture +def redis_config(redis_client): + """This fixture provides the Redis client and flushes all data after each test case""" + return redis_client From 51558b7df8dbc16c74619a23849f7c1bf6a7f349 Mon Sep 17 00:00:00 2001 From: jupyterjazz Date: Thu, 15 Jun 2023 11:57:34 +0200 Subject: [PATCH 02/28] fix: index creation Signed-off-by: jupyterjazz --- docarray/index/backends/redis.py | 170 +++++++++++++++++++++++++-- tests/index/redis/docker-compose.yml | 6 - tests/index/redis/fixtures.py | 4 +- 3 files changed, 160 insertions(+), 20 deletions(-) delete mode 100644 tests/index/redis/docker-compose.yml diff --git a/docarray/index/backends/redis.py b/docarray/index/backends/redis.py index e492dbf9b47..915746ac605 100644 --- a/docarray/index/backends/redis.py +++ b/docarray/index/backends/redis.py @@ -1,13 +1,39 @@ -from typing import TypeVar, Generic, Optional, List, Dict, Any, Sequence, Union, Generator, Type +from typing import ( + TypeVar, + Generic, + Optional, + List, + Dict, + Any, + Sequence, + Union, + Generator, + Type, + cast, + TYPE_CHECKING, Iterator, +) from dataclasses import dataclass, field import numpy as np +import pickle from docarray import BaseDoc, DocList from docarray.index.abstract import BaseDocIndex from docarray.typing.tensor.abstract_tensor import AbstractTensor +from docarray.utils._internal.misc import import_library from docarray.utils.find import _FindResultBatched, _FindResult -from redis.commands.search.field import NumericField, TextField, VectorField, GeoField + +if TYPE_CHECKING: + import redis +else: + redis = import_library('redis') + + from redis.commands.search.field import ( + NumericField, + TextField, + VectorField, + ) + from redis.commands.search.indexDefinition import IndexDefinition, IndexType TSchema = TypeVar('TSchema', bound=BaseDoc) @@ -16,15 +42,76 @@ class RedisDocumentIndex(BaseDocIndex, Generic[TSchema]): def __init__(self, db_config=None, **kwargs): super().__init__(db_config=db_config, **kwargs) + self._db_config = cast(RedisDocumentIndex.DBConfig, self._db_config) + + if not self._db_config.index_name: + 
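+            # no index name supplied: fall back to an auto-generated one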
self._db_config.index_name = 'index_name__' + 'random_name' # todo + self._prefix = self._db_config.index_name + ':' + + # initialize Redis client + self._client = redis.Redis( + host=self._db_config.host, + port=self._db_config.port, + username=self._db_config.username, + password=self._db_config.password, + ) + self._create_index() + self._logger.info(f'{self.__class__.__name__} has been initialized') + + def _create_index(self): + if not self._check_index_exists(self._db_config.index_name): + schema = [] + for column, info in self._column_infos.items(): + if info.db_type == VectorField: + schema.append( + info.db_type( + name=column, + algorithm=info.config.get( + 'algorithm', self._db_config.algorithm + ), + attributes={ + 'TYPE': 'FLOAT32', + 'DIM': info.n_dim, + 'DISTANCE_METRIC': 'COSINE', + }, + ) + ) + else: + schema.append(info.db_type(name=column)) + + + # Create Redis Index + self._client.ft(self._db_config.index_name).create_index( + fields=schema, + definition=IndexDefinition(prefix=[self._prefix], index_type=IndexType.HASH), + ) + + self._logger.info(f'index {self._db_config.index_name} has been created') + else: + self._logger.info( + f'connected to existing {self._db_config.index_name} index' + ) + + def _check_index_exists(self, index_name: str) -> bool: + """Check if Redis index exists.""" + try: + self._client.ft(index_name).info() + except: # noqa: E722 + self._logger.info("Index does not exist") + return False + self._logger.info("Index already exists") + return True @dataclass class DBConfig(BaseDocIndex.DBConfig): """Dataclass that contains all "static" configurations of RedisDocumentIndex.""" - host: str = 'http://localhost:6379' + host: str = 'localhost' + port: int = 6379 index_name: Optional[str] = None username: Optional[str] = None password: Optional[str] = None + algorithm: str = 'FLAT' @dataclass class RuntimeConfig(BaseDocIndex.RuntimeConfig): @@ -43,6 +130,7 @@ def python_type_to_db_type(self, python_type: Type) -> Any: int: NumericField, float: NumericField, str: TextField, + bytes: TextField, np.ndarray: VectorField, list: VectorField, AbstractTensor: VectorField, @@ -53,36 +141,94 @@ def python_type_to_db_type(self, python_type: Type) -> Any: return redis_type raise ValueError(f'Unsupported column type for {type(self)}: {python_type}') + @staticmethod + def _generate_item(column_to_data: Dict[str, Generator[Any, None, None]]) -> Iterator[Dict[str, Any]]: + """ + Given a dictionary of generators, yield a dictionary where each item consists of a key and + a single item from the corresponding generator. + + :param column_to_data: A dictionary where each key is a column and each value + is a generator. + + :yield: A dictionary where each item consists of a column name and an item from + the corresponding generator. Yields until all generators are exhausted. 
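+
+        Example: given {'id': iter(['0', '1']), 'title': iter(['a', 'b'])},
+        this yields {'id': '0', 'title': 'a'}, then {'id': '1', 'title': 'b'}.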
+ """ + keys = list(column_to_data.keys()) + iterators = [iter(column_to_data[key]) for key in keys] + while True: + item_dict = {} + for key, it in zip(keys, iterators): + item = next(it, None) + if item is None: # If item is not None, add it to the dictionary + continue + if isinstance(item, AbstractTensor): + item_dict[key] = pickle.dumps(item) + else: + item_dict[key] = item + + if not item_dict: # If item_dict is empty, break the loop + break + yield item_dict + def _index(self, column_to_data: Dict[str, Generator[Any, None, None]]): - pass + ids = [] + pipeline = self._client.pipeline(transaction=False) + batch_size = 10 + for item in self._generate_item(column_to_data): + doc_id = self._prefix + item.pop('id') + pipeline.hset( + doc_id, + mapping=item, + ) + ids.append(doc_id) + + if len(ids) % batch_size == 0: + pipeline.execute() + + pipeline.execute() + + num_docs = self.num_docs() + print('indexed', num_docs) + return ids def num_docs(self) -> int: - pass + return self._client.ft(self._db_config.index_name).info()['num_docs'] def _del_items(self, doc_ids: Sequence[str]): pass - def _get_items(self, doc_ids: Sequence[str]) -> Union[Sequence[TSchema], Sequence[Dict[str, Any]]]: + def _get_items( + self, doc_ids: Sequence[str] + ) -> Union[Sequence[TSchema], Sequence[Dict[str, Any]]]: pass def execute_query(self, query: Any, *args, **kwargs) -> Any: pass - def _find(self, query: np.ndarray, limit: int, search_field: str = '') -> _FindResult: + def _find( + self, query: np.ndarray, limit: int, search_field: str = '' + ) -> _FindResult: pass - def _find_batched(self, queries: np.ndarray, limit: int, search_field: str = '') -> _FindResultBatched: + def _find_batched( + self, queries: np.ndarray, limit: int, search_field: str = '' + ) -> _FindResultBatched: pass def _filter(self, filter_query: Any, limit: int) -> Union[DocList, List[Dict]]: pass - def _filter_batched(self, filter_queries: Any, limit: int) -> Union[List[DocList], List[List[Dict]]]: + def _filter_batched( + self, filter_queries: Any, limit: int + ) -> Union[List[DocList], List[List[Dict]]]: pass - def _text_search(self, query: str, limit: int, search_field: str = '') -> _FindResult: + def _text_search( + self, query: str, limit: int, search_field: str = '' + ) -> _FindResult: pass - def _text_search_batched(self, queries: Sequence[str], limit: int, search_field: str = '') -> _FindResultBatched: + def _text_search_batched( + self, queries: Sequence[str], limit: int, search_field: str = '' + ) -> _FindResultBatched: pass - diff --git a/tests/index/redis/docker-compose.yml b/tests/index/redis/docker-compose.yml deleted file mode 100644 index 4db029055dd..00000000000 --- a/tests/index/redis/docker-compose.yml +++ /dev/null @@ -1,6 +0,0 @@ -version: '3.8' -services: - redis: - image: redis:latest - ports: - - "6379:6379" diff --git a/tests/index/redis/fixtures.py b/tests/index/redis/fixtures.py index c9d89d766a1..b1e55f4a8ec 100644 --- a/tests/index/redis/fixtures.py +++ b/tests/index/redis/fixtures.py @@ -7,12 +7,12 @@ @pytest.fixture(scope='session', autouse=True) def start_redis(): - os.system('docker run -d -p 6379:6379 --name test-redis redis') + os.system('docker run -d --name redis-stack-server -p 6379:6379 redis/redis-stack-server:latest') time.sleep(1) yield - os.system('docker rm -f test-redis') + os.system('docker rm -f redis-stack-server') @pytest.fixture(scope='function') From 12da714a673af63045f11fbc226f6895687ddae0 Mon Sep 17 00:00:00 2001 From: jupyterjazz Date: Wed, 21 Jun 2023 16:59:52 +0200 Subject: [PATCH 
03/28] feat: 1st draft, needs polishing Signed-off-by: jupyterjazz --- docarray/index/backends/redis.py | 269 ++++++++++++++++++++++++++++--- tests/index/redis/tests.py | 74 +++++++++ 2 files changed, 319 insertions(+), 24 deletions(-) create mode 100644 tests/index/redis/tests.py diff --git a/docarray/index/backends/redis.py b/docarray/index/backends/redis.py index 915746ac605..a65694239a5 100644 --- a/docarray/index/backends/redis.py +++ b/docarray/index/backends/redis.py @@ -1,3 +1,4 @@ +import uuid from typing import ( TypeVar, Generic, @@ -10,12 +11,16 @@ Generator, Type, cast, - TYPE_CHECKING, Iterator, + TYPE_CHECKING, + Iterator, + Mapping, ) from dataclasses import dataclass, field +import binascii import numpy as np -import pickle + +from redis.commands.search.query import Query from docarray import BaseDoc, DocList from docarray.index.abstract import BaseDocIndex @@ -34,10 +39,23 @@ VectorField, ) from redis.commands.search.indexDefinition import IndexDefinition, IndexType - + from redis.commands.search.querystring import ( + DistjunctUnion, + IntersectNode, + equal, + ge, + gt, + intersect, + le, + lt, + union, + ) TSchema = TypeVar('TSchema', bound=BaseDoc) +VALID_DISTANCES = ['L2', 'IP', 'COSINE'] +VALID_ALGORITHMS = ['FLAT', 'HNSW'] + class RedisDocumentIndex(BaseDocIndex, Generic[TSchema]): def __init__(self, db_config=None, **kwargs): @@ -45,7 +63,7 @@ def __init__(self, db_config=None, **kwargs): self._db_config = cast(RedisDocumentIndex.DBConfig, self._db_config) if not self._db_config.index_name: - self._db_config.index_name = 'index_name__' + 'random_name' # todo + self._db_config.index_name = 'index_name__' + self._random_name() self._prefix = self._db_config.index_name + ':' # initialize Redis client @@ -54,15 +72,29 @@ def __init__(self, db_config=None, **kwargs): port=self._db_config.port, username=self._db_config.username, password=self._db_config.password, + decode_responses=False, ) self._create_index() self._logger.info(f'{self.__class__.__name__} has been initialized') + @staticmethod + def _random_name(): + return uuid.uuid4().hex + def _create_index(self): if not self._check_index_exists(self._db_config.index_name): schema = [] for column, info in self._column_infos.items(): + if info.db_type == VectorField: + space = info.config.get('space') + if space: + for valid_dist in VALID_DISTANCES: + if space.upper() == valid_dist: + space = valid_dist + if space not in VALID_DISTANCES: + space = self._db_config.distance + schema.append( info.db_type( name=column, @@ -72,18 +104,19 @@ def _create_index(self): attributes={ 'TYPE': 'FLOAT32', 'DIM': info.n_dim, - 'DISTANCE_METRIC': 'COSINE', + 'DISTANCE_METRIC': space, }, ) ) else: schema.append(info.db_type(name=column)) - # Create Redis Index self._client.ft(self._db_config.index_name).create_index( fields=schema, - definition=IndexDefinition(prefix=[self._prefix], index_type=IndexType.HASH), + definition=IndexDefinition( + prefix=[self._prefix], index_type=IndexType.HASH + ), ) self._logger.info(f'index {self._db_config.index_name} has been created') @@ -111,7 +144,22 @@ class DBConfig(BaseDocIndex.DBConfig): index_name: Optional[str] = None username: Optional[str] = None password: Optional[str] = None - algorithm: str = 'FLAT' + algorithm: str = field(default='FLAT') + distance: str = field(default='COSINE') + ef_construction: Optional[int] = None + m: Optional[int] = None + ef_runtime: Optional[int] = None + block_size: Optional[int] = None + initial_cap: Optional[int] = None + + def __post_init__(self): + if 
self.algorithm not in VALID_ALGORITHMS: + raise ValueError(f"Invalid algorithm '{self.algorithm}' provided. " + f"Must be one of: {', '.join(VALID_ALGORITHMS)}") + + if self.distance not in VALID_DISTANCES: + raise ValueError(f"Invalid distance metric '{self.distance}' provided. " + f"Must be one of: {', '.join(VALID_DISTANCES)}") @dataclass class RuntimeConfig(BaseDocIndex.RuntimeConfig): @@ -142,7 +190,9 @@ def python_type_to_db_type(self, python_type: Type) -> Any: raise ValueError(f'Unsupported column type for {type(self)}: {python_type}') @staticmethod - def _generate_item(column_to_data: Dict[str, Generator[Any, None, None]]) -> Iterator[Dict[str, Any]]: + def _generate_item( + column_to_data: Dict[str, Generator[Any, None, None]] + ) -> Iterator[Dict[str, Any]]: """ Given a dictionary of generators, yield a dictionary where each item consists of a key and a single item from the corresponding generator. @@ -159,21 +209,25 @@ def _generate_item(column_to_data: Dict[str, Generator[Any, None, None]]) -> Ite item_dict = {} for key, it in zip(keys, iterators): item = next(it, None) - if item is None: # If item is not None, add it to the dictionary - continue - if isinstance(item, AbstractTensor): - item_dict[key] = pickle.dumps(item) + + if key == 'id' and not item: + return + + if item is None: + item_dict[key] = '__None__' + elif isinstance(item, AbstractTensor): + item_dict[key] = np.array( + item._docarray_to_ndarray(), dtype=np.float32 + ).tobytes() else: item_dict[key] = item - if not item_dict: # If item_dict is empty, break the loop - break yield item_dict def _index(self, column_to_data: Dict[str, Generator[Any, None, None]]): ids = [] pipeline = self._client.pipeline(transaction=False) - batch_size = 10 + batch_size = 10 # variable [1k] for item in self._generate_item(column_to_data): doc_id = self._prefix + item.pop('id') pipeline.hset( @@ -195,40 +249,207 @@ def num_docs(self) -> int: return self._client.ft(self._db_config.index_name).info()['num_docs'] def _del_items(self, doc_ids: Sequence[str]): - pass + doc_ids = [self._prefix + id for id in doc_ids if self._doc_exists(id)] + if doc_ids: + self._client.delete(*doc_ids) + + def _doc_exists(self, doc_id): + return self._client.exists(self._prefix + doc_id) def _get_items( self, doc_ids: Sequence[str] ) -> Union[Sequence[TSchema], Sequence[Dict[str, Any]]]: - pass + if not doc_ids: + return [] + + pipe = self._client.pipeline() + for id in doc_ids: + pipe.hgetall(self._prefix + id) + + results = pipe.execute() + + docs = [ + {k.decode('utf-8'): v.decode('utf-8', 'ignore') for k, v in d.items()} + for d in results + ] + + docs = [{k: v for k, v in d.items() if k != 'tens'} for d in docs] # todo (vector decoding problem) + docs = [{k: None if v == '__None__' else v for k, v in d.items()} for d in docs] # todo (converting to None) + return docs def execute_query(self, query: Any, *args, **kwargs) -> Any: pass + def _convert_to_schema(self, document): + doc_kwargs = {} + for column, info in self._column_infos.items(): + if column == 'id': + doc_kwargs['id'] = document.id[len(self._prefix) :] + elif document[column] == '__None__': + doc_kwargs[column] = None + elif info.db_type == VectorField: + # byte_string = document[column] + # byte_data = byte_string.encode('utf-8') + doc_kwargs[column] = np.frombuffer(document[column], dtype=np.float32) + elif info.db_type == NumericField: + doc_kwargs[column] = info.docarray_type(document[column]) + else: + doc_kwargs[column] = document[column] + + return doc_kwargs + def _find( self, 
query: np.ndarray, limit: int, search_field: str = '' ) -> _FindResult: - pass + limit = 5 + query_str = '*' + redis_query = ( + Query(f'{query_str}=>[KNN {limit} @{search_field} $vec AS vector_score]') + .sort_by('vector_score') + .paging(0, limit) + .dialect(2) + ) + query_params: Mapping[str, str] = { # type: ignore + 'vec': np.array(query, dtype=np.float32).tobytes() + } + results = ( + self._client.ft(self._db_config.index_name) + .search(redis_query, query_params) + .docs + ) + + scores = [document['vector_score'] for document in results] + docs = [self._convert_to_schema(document) for document in results] + + return _FindResult(documents=docs, scores=scores) def _find_batched( self, queries: np.ndarray, limit: int, search_field: str = '' ) -> _FindResultBatched: - pass + docs, scores = [], [] + for query in queries: + results = self._find(query=query, search_field=search_field, limit=limit) + docs.append(results.documents) + scores.append(results.scores) + + return _FindResultBatched(documents=docs, scores=scores) def _filter(self, filter_query: Any, limit: int) -> Union[DocList, List[Dict]]: - pass + query_str = self._get_redis_filter_query(filter_query) + q = Query(query_str) + q.paging(0, limit) + + results = self._client.ft(index_name=self._db_config.index_name).search(q).docs + docs = [self._convert_to_schema(document) for document in results] + + return docs + + def _build_query_node(self, key, condition): + operator = list(condition.keys())[0] + value = condition[operator] + + query_dict = {} + + if operator in ['$ne', '$eq']: + if isinstance(value, bool): + query_dict[key] = equal(int(value)) + elif isinstance(value, (int, float)): + query_dict[key] = equal(value) + else: + query_dict[key] = value + elif operator == '$gt': + query_dict[key] = gt(value) + elif operator == '$gte': + query_dict[key] = ge(value) + elif operator == '$lt': + query_dict[key] = lt(value) + elif operator == '$lte': + query_dict[key] = le(value) + else: + raise ValueError( + f'Expecting filter operator one of $gt, $gte, $lt, $lte, $eq, $ne, $and OR $or, got {operator} instead' + ) + + if operator == '$ne': + return DistjunctUnion(**query_dict) + return IntersectNode(**query_dict) + + def _build_query_nodes(self, filter): + nodes = [] + for k, v in filter.items(): + if k == '$and': + children = self._build_query_nodes(v) + node = intersect(*children) + nodes.append(node) + elif k == '$or': + children = self._build_query_nodes(v) + node = union(*children) + nodes.append(node) + else: + child = self._build_query_node(k, v) + nodes.append(child) + + return nodes + + def _get_redis_filter_query(self, filter: Union[str, Dict]): + if isinstance(filter, dict): + nodes = self._build_query_nodes(filter) + query_str = intersect(*nodes).to_string() + elif isinstance(filter, str): + query_str = filter + else: + raise ValueError(f'Unexpected type of filter: {type(filter)}, expected str') + + return query_str def _filter_batched( self, filter_queries: Any, limit: int ) -> Union[List[DocList], List[List[Dict]]]: - pass + results = [] + for query in filter_queries: + results.append(self._filter(filter_query=query, limit=limit)) + return results def _text_search( self, query: str, limit: int, search_field: str = '' ) -> _FindResult: - pass + query_str = '|'.join(query.split(' ')) + + scorer = 'BM25' + if scorer not in [ + 'BM25', + 'TFIDF', + 'TFIDF.DOCNORM', + 'DISMAX', + 'DOCSCORE', + 'HAMMING', + ]: + raise ValueError( + f'Expecting a valid text similarity ranking algorithm, got {scorer} instead' + ) + q = ( + 
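+            # e.g. query 'hey everyone' on field 'title' becomes Query('@title:hey|everyone')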
Query(f'@{search_field}:{query_str}') + .scorer(scorer) + .with_scores() + .paging(0, limit) + ) + + results = self._client.ft(index_name=self._db_config.index_name).search(q).docs + + scores = [document['score'] for document in results] + docs = [self._convert_to_schema(document) for document in results] + + return _FindResult(documents=docs, scores=scores) def _text_search_batched( self, queries: Sequence[str], limit: int, search_field: str = '' ) -> _FindResultBatched: - pass + docs, scores = [], [] + for query in queries: + results = self._text_search( + query=query, search_field=search_field, limit=limit + ) + docs.append(results.documents) + scores.append(results.scores) + + return _FindResultBatched(documents=docs, scores=scores) diff --git a/tests/index/redis/tests.py b/tests/index/redis/tests.py new file mode 100644 index 00000000000..65ef62a8cae --- /dev/null +++ b/tests/index/redis/tests.py @@ -0,0 +1,74 @@ +import numpy as np +import pytest + +from docarray import BaseDoc +from docarray.index import RedisDocumentIndex +from pydantic import Field +from docarray.typing import NdArray +from tests.index.redis.fixtures import start_redis, tmp_collection_name # noqa: F401 +from typing import Optional + + +@pytest.mark.parametrize('space', ['cosine']) +def test_find_simple_schema(space): + class SimpleSchema(BaseDoc): + tens: Optional[NdArray[10]] = Field(space=space, algorithm='HNSW') # type: ignore[valid-type] + bla: int + title: str + smth: Optional[str] = None + tenss: Optional[NdArray[10]] = None + + index = RedisDocumentIndex[SimpleSchema](host='localhost') + + docs = [SimpleSchema(bla=i, title=f'zdall {i}', tens=np.random.rand(10)) for i in range(5)] + docs.append(SimpleSchema(bla=6, title=f'hey everyone how are you', tens=np.random.rand(10))) + docs.append(SimpleSchema(bla=7, title=f'hey how are you', tens=np.random.rand(10))) + + + index.index(docs) + + query = np.random.rand(10) + results = index.find(query, search_field='tens') + print(len(results)) + + results = index.find_batched(np.array([np.random.rand(10), np.random.rand(10)]), search_field='tens') + print('find batched', results) + res = index[docs[0].id] + print(index.num_docs()) + del index[docs[0].id] + print(index.num_docs()) + + docs = index.filter({'bla': {'$gt': 3}}) + + print('filtered', docs) + + docs = index.filter_batched([{'bla': {'$gt': 3}}, {'bla': {'$lte': 3}}]) + print('batched filt', docs) + + docs = index.text_search(query='hey everyone', search_field='title') + print(docs) + + docs = index.text_search_batched(queries=['hey hey', 'hey everyone'], search_field='title') + print(docs) + + +def test_simple_scenario(): + # Define a document schema + class SimpleSchema(BaseDoc): + tensor: Optional[NdArray[10]] = Field(space='COSINE') + year: int + title: Optional[str] = None + + # Create a document index + index = RedisDocumentIndex[SimpleSchema](host='localhost') + + # Prepare documents + docs = [SimpleSchema(year=i, title=f'some text {i}', tensor=np.random.rand(10)) for i in range(5)] + + # Index + index.index(docs) + + # Search + query = np.random.rand(10) + results = index.find(query, search_field='tensor') + print(results) From be7bed7792448abd857dab0afe240228d58df939 Mon Sep 17 00:00:00 2001 From: jupyterjazz Date: Wed, 28 Jun 2023 09:12:14 +0200 Subject: [PATCH 04/28] feat: query builder, tests Signed-off-by: jupyterjazz --- docarray/index/backends/redis.py | 379 +++++++++++++---------- tests/index/redis/test_configurations.py | 40 +++ tests/index/redis/test_find.py | 294 ++++++++++++++++++ 
tests/index/redis/test_index_get_del.py | 101 ++++++ tests/index/redis/test_persist_data.py | 41 +++ tests/index/redis/tests.py | 74 ----- 6 files changed, 687 insertions(+), 242 deletions(-) create mode 100644 tests/index/redis/test_configurations.py create mode 100644 tests/index/redis/test_find.py create mode 100644 tests/index/redis/test_index_get_del.py create mode 100644 tests/index/redis/test_persist_data.py delete mode 100644 tests/index/redis/tests.py diff --git a/docarray/index/backends/redis.py b/docarray/index/backends/redis.py index a65694239a5..b96ea730a90 100644 --- a/docarray/index/backends/redis.py +++ b/docarray/index/backends/redis.py @@ -14,22 +14,34 @@ TYPE_CHECKING, Iterator, Mapping, + Tuple, ) from dataclasses import dataclass, field -import binascii +import json import numpy as np +from numpy import ndarray -from redis.commands.search.query import Query - +from docarray.index.backends.helper import _collect_query_args from docarray import BaseDoc, DocList -from docarray.index.abstract import BaseDocIndex +from docarray.index.abstract import ( + BaseDocIndex, + _raise_not_composable, +) +from docarray.typing import NdArray from docarray.typing.tensor.abstract_tensor import AbstractTensor from docarray.utils._internal.misc import import_library -from docarray.utils.find import _FindResultBatched, _FindResult +from docarray.utils.find import _FindResultBatched, _FindResult, FindResult if TYPE_CHECKING: import redis + from redis.commands.search.query import Query + from redis.commands.search.field import ( + NumericField, + TextField, + VectorField, + ) + from redis.commands.search.indexDefinition import IndexDefinition, IndexType else: redis = import_library('redis') @@ -39,22 +51,20 @@ VectorField, ) from redis.commands.search.indexDefinition import IndexDefinition, IndexType - from redis.commands.search.querystring import ( - DistjunctUnion, - IntersectNode, - equal, - ge, - gt, - intersect, - le, - lt, - union, - ) + from redis.commands.search.query import Query TSchema = TypeVar('TSchema', bound=BaseDoc) VALID_DISTANCES = ['L2', 'IP', 'COSINE'] VALID_ALGORITHMS = ['FLAT', 'HNSW'] +VALID_TEXT_SCORERS = [ + 'BM25', + 'TFIDF', + 'TFIDF.DOCNORM', + 'DISMAX', + 'DOCSCORE', + 'HAMMING', +] class RedisDocumentIndex(BaseDocIndex, Generic[TSchema]): @@ -85,7 +95,6 @@ def _create_index(self): if not self._check_index_exists(self._db_config.index_name): schema = [] for column, info in self._column_infos.items(): - if info.db_type == VectorField: space = info.config.get('space') if space: @@ -95,27 +104,36 @@ def _create_index(self): if space not in VALID_DISTANCES: space = self._db_config.distance + attributes = { + 'TYPE': 'FLOAT32', + 'DIM': info.n_dim or info.config.get('dim'), + 'DISTANCE_METRIC': space, + 'EF_CONSTRUCTION': self._db_config.ef_construction, + 'EF_RUNTIME': self._db_config.ef_runtime, + 'M': self._db_config.m, + 'INITIAL_CAP': self._db_config.initial_cap, + } + attributes = { + name: value for name, value in attributes.items() if value + } schema.append( info.db_type( - name=column, + '$.' + column, algorithm=info.config.get( 'algorithm', self._db_config.algorithm ), - attributes={ - 'TYPE': 'FLOAT32', - 'DIM': info.n_dim, - 'DISTANCE_METRIC': space, - }, + attributes=attributes, + as_name=column, ) ) else: - schema.append(info.db_type(name=column)) + schema.append(info.db_type('$.' 
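+                        # documents are now stored as JSON, so each field is indexed
+                        # at JSONPath '$.<column>' and aliased back to plain '<column>'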
+ column, as_name=column)) # Create Redis Index self._client.ft(self._db_config.index_name).create_index( - fields=schema, + schema, definition=IndexDefinition( - prefix=[self._prefix], index_type=IndexType.HASH + prefix=[self._prefix], index_type=IndexType.JSON ), ) @@ -135,6 +153,23 @@ def _check_index_exists(self, index_name: str) -> bool: self._logger.info("Index already exists") return True + class QueryBuilder(BaseDocIndex.QueryBuilder): + def __init__(self, query: Optional[List[Tuple[str, Dict]]] = None): + super().__init__() + # list of tuples (method name, kwargs) + self._queries: List[Tuple[str, Dict]] = query or [] + + def build(self, *args, **kwargs) -> Any: + """Build the query object.""" + return self._queries + + find = _collect_query_args('find') + filter = _collect_query_args('filter') + text_search = _raise_not_composable('text_search') + find_batched = _raise_not_composable('find_batched') + filter_batched = _raise_not_composable('filter_batched') + text_search_batched = _raise_not_composable('text_search_batched') + @dataclass class DBConfig(BaseDocIndex.DBConfig): """Dataclass that contains all "static" configurations of RedisDocumentIndex.""" @@ -146,6 +181,7 @@ class DBConfig(BaseDocIndex.DBConfig): password: Optional[str] = None algorithm: str = field(default='FLAT') distance: str = field(default='COSINE') + text_scorer: str = field(default='BM25') ef_construction: Optional[int] = None m: Optional[int] = None ef_runtime: Optional[int] = None @@ -153,13 +189,26 @@ class DBConfig(BaseDocIndex.DBConfig): initial_cap: Optional[int] = None def __post_init__(self): + self.algorithm = self.algorithm.upper() + self.distance = self.distance.upper() + self.text_scorer = self.text_scorer.upper() if self.algorithm not in VALID_ALGORITHMS: - raise ValueError(f"Invalid algorithm '{self.algorithm}' provided. " - f"Must be one of: {', '.join(VALID_ALGORITHMS)}") + raise ValueError( + f"Invalid algorithm '{self.algorithm}' provided. " + f"Must be one of: {', '.join(VALID_ALGORITHMS)}" + ) if self.distance not in VALID_DISTANCES: - raise ValueError(f"Invalid distance metric '{self.distance}' provided. " - f"Must be one of: {', '.join(VALID_DISTANCES)}") + raise ValueError( + f"Invalid distance metric '{self.distance}' provided. " + f"Must be one of: {', '.join(VALID_DISTANCES)}" + ) + + if self.text_scorer not in VALID_TEXT_SCORERS: + raise ValueError( + f"Invalid text scorer '{self.text_scorer}' provided. 
" + f"Must be one of: {', '.join(VALID_TEXT_SCORERS)}" + ) @dataclass class RuntimeConfig(BaseDocIndex.RuntimeConfig): @@ -213,40 +262,29 @@ def _generate_item( if key == 'id' and not item: return - if item is None: - item_dict[key] = '__None__' - elif isinstance(item, AbstractTensor): - item_dict[key] = np.array( - item._docarray_to_ndarray(), dtype=np.float32 - ).tobytes() - else: + if isinstance(item, AbstractTensor): + item_dict[key] = item._docarray_to_ndarray().tolist() + elif isinstance(item, ndarray): + item_dict[key] = item.astype(np.float32).tolist() + elif item is not None: item_dict[key] = item yield item_dict def _index(self, column_to_data: Dict[str, Generator[Any, None, None]]): ids = [] - pipeline = self._client.pipeline(transaction=False) - batch_size = 10 # variable [1k] for item in self._generate_item(column_to_data): - doc_id = self._prefix + item.pop('id') - pipeline.hset( - doc_id, - mapping=item, - ) - ids.append(doc_id) - - if len(ids) % batch_size == 0: - pipeline.execute() - - pipeline.execute() + ids.append(item['id']) + doc_id = self._prefix + item['id'] + self._client.json().set(doc_id, '$', item) num_docs = self.num_docs() print('indexed', num_docs) return ids def num_docs(self) -> int: - return self._client.ft(self._db_config.index_name).info()['num_docs'] + num_docs = self._client.ft(self._db_config.index_name).info()['num_docs'] + return int(num_docs) def _del_items(self, doc_ids: Sequence[str]): doc_ids = [self._prefix + id for id in doc_ids if self._doc_exists(id)] @@ -262,55 +300,60 @@ def _get_items( if not doc_ids: return [] - pipe = self._client.pipeline() + docs = [] for id in doc_ids: - pipe.hgetall(self._prefix + id) + doc = self._client.json().get(self._prefix + id) + if doc: + docs.append(doc) - results = pipe.execute() - - docs = [ - {k.decode('utf-8'): v.decode('utf-8', 'ignore') for k, v in d.items()} - for d in results - ] - - docs = [{k: v for k, v in d.items() if k != 'tens'} for d in docs] # todo (vector decoding problem) - docs = [{k: None if v == '__None__' else v for k, v in d.items()} for d in docs] # todo (converting to None) + if len(docs) == 0: + raise KeyError(f'No document with id {doc_ids} found') return docs - def execute_query(self, query: Any, *args, **kwargs) -> Any: - pass - - def _convert_to_schema(self, document): - doc_kwargs = {} - for column, info in self._column_infos.items(): - if column == 'id': - doc_kwargs['id'] = document.id[len(self._prefix) :] - elif document[column] == '__None__': - doc_kwargs[column] = None - elif info.db_type == VectorField: - # byte_string = document[column] - # byte_data = byte_string.encode('utf-8') - doc_kwargs[column] = np.frombuffer(document[column], dtype=np.float32) - elif info.db_type == NumericField: - doc_kwargs[column] = info.docarray_type(document[column]) - else: - doc_kwargs[column] = document[column] - - return doc_kwargs + def execute_query(self, query: Any, *args: Any, **kwargs: Any) -> Any: + components: Dict[str, List[Dict[str, Any]]] = {} + for component, value in query: + if component not in components: + components[component] = [] + components[component].append(value) + + if ( + len(components) != 2 + or len(components.get('find', [])) != 1 + or len(components.get('filter', [])) != 1 + ): + raise ValueError( + 'The query must contain exactly one "find" and "filter" components.' 
+ ) - def _find( - self, query: np.ndarray, limit: int, search_field: str = '' - ) -> _FindResult: - limit = 5 - query_str = '*' + filter_query = components['filter'][0]['filter_query'] + query = components['find'][0]['query'] + search_field = components['find'][0]['search_field'] + limit = ( + components['find'][0].get('limit') + or components['filter'][0].get('limit') + or 10 + ) + docs, scores = self._hybrid_search( + query=query, + filter_query=filter_query, + search_field=search_field, + limit=limit, + ) + docs = self._dict_list_to_docarray(docs) + return FindResult(documents=docs, scores=scores) + + def _hybrid_search( + self, query: np.ndarray, filter_query: str, search_field: str, limit: int + ): redis_query = ( - Query(f'{query_str}=>[KNN {limit} @{search_field} $vec AS vector_score]') + Query(f'{filter_query}=>[KNN {limit} @{search_field} $vec AS vector_score]') .sort_by('vector_score') .paging(0, limit) .dialect(2) ) query_params: Mapping[str, str] = { # type: ignore - 'vec': np.array(query, dtype=np.float32).tobytes() + 'vec': np.array(query, dtype=np.float32).tobytes() # type: ignore } results = ( self._client.ft(self._db_config.index_name) @@ -318,11 +361,23 @@ def _find( .docs ) - scores = [document['vector_score'] for document in results] - docs = [self._convert_to_schema(document) for document in results] + scores: NdArray = NdArray._docarray_from_native( + np.array([document['vector_score'] for document in results]) + ) + docs = [] + for out_doc in results: + doc_dict = json.loads(out_doc.json) + docs.append(doc_dict) return _FindResult(documents=docs, scores=scores) + def _find( + self, query: np.ndarray, limit: int, search_field: str = '' + ) -> _FindResult: + return self._hybrid_search( + query=query, filter_query='*', search_field=search_field, limit=limit + ) + def _find_batched( self, queries: np.ndarray, limit: int, search_field: str = '' ) -> _FindResultBatched: @@ -335,72 +390,70 @@ def _find_batched( return _FindResultBatched(documents=docs, scores=scores) def _filter(self, filter_query: Any, limit: int) -> Union[DocList, List[Dict]]: - query_str = self._get_redis_filter_query(filter_query) - q = Query(query_str) + q = Query(filter_query) q.paging(0, limit) results = self._client.ft(index_name=self._db_config.index_name).search(q).docs - docs = [self._convert_to_schema(document) for document in results] - + docs = [json.loads(doc.json) for doc in results] return docs - def _build_query_node(self, key, condition): - operator = list(condition.keys())[0] - value = condition[operator] - - query_dict = {} - - if operator in ['$ne', '$eq']: - if isinstance(value, bool): - query_dict[key] = equal(int(value)) - elif isinstance(value, (int, float)): - query_dict[key] = equal(value) - else: - query_dict[key] = value - elif operator == '$gt': - query_dict[key] = gt(value) - elif operator == '$gte': - query_dict[key] = ge(value) - elif operator == '$lt': - query_dict[key] = lt(value) - elif operator == '$lte': - query_dict[key] = le(value) - else: - raise ValueError( - f'Expecting filter operator one of $gt, $gte, $lt, $lte, $eq, $ne, $and OR $or, got {operator} instead' - ) - - if operator == '$ne': - return DistjunctUnion(**query_dict) - return IntersectNode(**query_dict) - - def _build_query_nodes(self, filter): - nodes = [] - for k, v in filter.items(): - if k == '$and': - children = self._build_query_nodes(v) - node = intersect(*children) - nodes.append(node) - elif k == '$or': - children = self._build_query_nodes(v) - node = union(*children) - nodes.append(node) - 
else: - child = self._build_query_node(k, v) - nodes.append(child) - - return nodes - - def _get_redis_filter_query(self, filter: Union[str, Dict]): - if isinstance(filter, dict): - nodes = self._build_query_nodes(filter) - query_str = intersect(*nodes).to_string() - elif isinstance(filter, str): - query_str = filter - else: - raise ValueError(f'Unexpected type of filter: {type(filter)}, expected str') - - return query_str + # def _build_query_node(self, key, condition): + # operator = list(condition.keys())[0] + # value = condition[operator] + # + # query_dict = {} + # + # if operator in ['$ne', '$eq']: + # if isinstance(value, bool): + # query_dict[key] = equal(int(value)) + # elif isinstance(value, (int, float)): + # query_dict[key] = equal(value) + # else: + # query_dict[key] = '"' + value + '"' + # elif operator == '$gt': + # query_dict[key] = gt(value) + # elif operator == '$gte': + # query_dict[key] = ge(value) + # elif operator == '$lt': + # query_dict[key] = lt(value) + # elif operator == '$lte': + # query_dict[key] = le(value) + # else: + # raise ValueError( + # f'Expecting filter operator one of $gt, $gte, $lt, $lte, $eq, $ne, $and OR $or, got {operator} instead' + # ) + # + # if operator == '$ne': + # return DistjunctUnion(**query_dict) + # return IntersectNode(**query_dict) + # + # def _build_query_nodes(self, filter): + # nodes = [] + # for k, v in filter.items(): + # if k == '$and': + # children = self._build_query_nodes(v) + # node = intersect(*children) + # nodes.append(node) + # elif k == '$or': + # children = self._build_query_nodes(v) + # node = union(*children) + # nodes.append(node) + # else: + # child = self._build_query_node(k, v) + # nodes.append(child) + # + # return nodes + # + # def _get_redis_filter_query(self, filter: Union[str, Dict]): + # if isinstance(filter, dict): + # nodes = self._build_query_nodes(filter) + # query_str = intersect(*nodes).to_string() + # elif isinstance(filter, str): + # query_str = filter + # else: + # raise ValueError(f'Unexpected type of filter: {type(filter)}, expected str') + # + # return query_str def _filter_batched( self, filter_queries: Any, limit: int @@ -414,30 +467,20 @@ def _text_search( self, query: str, limit: int, search_field: str = '' ) -> _FindResult: query_str = '|'.join(query.split(' ')) - - scorer = 'BM25' - if scorer not in [ - 'BM25', - 'TFIDF', - 'TFIDF.DOCNORM', - 'DISMAX', - 'DOCSCORE', - 'HAMMING', - ]: - raise ValueError( - f'Expecting a valid text similarity ranking algorithm, got {scorer} instead' - ) q = ( Query(f'@{search_field}:{query_str}') - .scorer(scorer) + .scorer(self._db_config.text_scorer) .with_scores() .paging(0, limit) ) results = self._client.ft(index_name=self._db_config.index_name).search(q).docs - scores = [document['score'] for document in results] - docs = [self._convert_to_schema(document) for document in results] + scores: NdArray = NdArray._docarray_from_native( + np.array([document['score'] for document in results]) + ) + + docs = [json.loads(doc.json) for doc in results] return _FindResult(documents=docs, scores=scores) diff --git a/tests/index/redis/test_configurations.py b/tests/index/redis/test_configurations.py new file mode 100644 index 00000000000..6cd7ae8f3d6 --- /dev/null +++ b/tests/index/redis/test_configurations.py @@ -0,0 +1,40 @@ +import numpy as np +import pytest +from pydantic import Field + +from docarray import BaseDoc +from docarray.index import RedisDocumentIndex +from docarray.typing import NdArray +from tests.index.redis.fixtures import start_redis # noqa: F401 
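+# the tests below expect a Redis Stack server on localhost:6379,
+# started once per session by the autouse start_redis fixture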
+ + +pytestmark = [pytest.mark.slow, pytest.mark.index] + + +def test_configure_dim(): + class Schema(BaseDoc): + tens: NdArray = Field(dim=10) + + index = RedisDocumentIndex[Schema](host='localhost') + + docs = [Schema(tens=np.random.random((10,))) for _ in range(10)] + index.index(docs) + + assert index.num_docs() == 10 + + +def test_configure_index(tmp_path): + class Schema(BaseDoc): + tens: NdArray[100] = Field(space='cosine') + title: str + year: int + + types = {'id': 'TEXT', 'tens': 'VECTOR', 'title': 'TEXT', 'year': 'NUMERIC'} + index = RedisDocumentIndex[Schema](host='localhost') + + attr_bytes = index._client.ft(index._db_config.index_name).info()['attributes'] + attr = [[byte.decode() for byte in sublist] for sublist in attr_bytes] + + assert len(Schema.__fields__) == len(attr) + for field, attr in zip(Schema.__fields__, attr): + assert field in attr and types[field] in attr diff --git a/tests/index/redis/test_find.py b/tests/index/redis/test_find.py new file mode 100644 index 00000000000..35665389b1c --- /dev/null +++ b/tests/index/redis/test_find.py @@ -0,0 +1,294 @@ +from typing import Optional + +import numpy as np +import pytest +import torch +from pydantic import Field + +from docarray import BaseDoc, DocList +from docarray.index import RedisDocumentIndex +from docarray.typing import NdArray, TorchTensor +from tests.index.redis.fixtures import start_redis # noqa: F401 + +N_DIM = 10 + + +def get_simple_schema(**kwargs): + class SimpleSchema(BaseDoc): + tens: NdArray[N_DIM] = Field(**kwargs) + + return SimpleSchema + + +class TorchDoc(BaseDoc): + tens: TorchTensor[N_DIM] + + +@pytest.mark.parametrize('space', ['cosine', 'l2', 'ip']) +def test_find_simple_schema(space): + schema = get_simple_schema(space=space) + db = RedisDocumentIndex[schema](host='localhost') + + index_docs = [schema(tens=np.random.rand(N_DIM)) for _ in range(10)] + index_docs.append(schema(tens=np.ones(N_DIM))) + + db.index(index_docs) + + query = schema(tens=np.ones(N_DIM)) + + docs, scores = db.find(query, search_field='tens', limit=5) + + assert len(docs) == 5 + assert len(scores) == 5 + assert docs[0].id == index_docs[-1].id + assert np.allclose(docs[0].tens, index_docs[-1].tens) + + +def test_find_empty_index(): + schema = get_simple_schema() + empty_index = RedisDocumentIndex[schema](host='localhost') + query = schema(tens=np.random.rand(N_DIM)) + + docs, scores = empty_index.find(query, search_field='tens', limit=5) + assert len(docs) == 0 + assert len(scores) == 0 + + +def test_find_limit_larger_than_index(): + schema = get_simple_schema() + db = RedisDocumentIndex[schema](host='localhost') + query = schema(tens=np.ones(10)) + index_docs = [schema(tens=np.zeros(10)) for _ in range(10)] + db.index(index_docs) + docs, scores = db.find(query, search_field='tens', limit=20) + assert len(docs) == 10 + assert len(scores) == 10 + + +@pytest.mark.parametrize('space', ['cosine', 'l2', 'ip']) +def test_find_torch(space): + db = RedisDocumentIndex[TorchDoc](host='localhost') + index_docs = [TorchDoc(tens=np.random.rand(10)) for _ in range(10)] + index_docs.append(TorchDoc(tens=np.ones(10, dtype=np.float32))) + db.index(index_docs) + + for doc in index_docs: + assert isinstance(doc.tens, TorchTensor) + + query = TorchDoc(tens=np.ones(10, dtype=np.float32)) + + result_docs, scores = db.find(query, search_field='tens', limit=5) + + assert len(result_docs) == 5 + assert len(scores) == 5 + for doc in result_docs: + assert isinstance(doc.tens, TorchTensor) + assert result_docs[0].id == index_docs[-1].id + assert 
torch.allclose(result_docs[0].tens, index_docs[-1].tens) + + +@pytest.mark.parametrize('space', ['cosine', 'l2', 'ip']) +def test_find_flat_schema(space): + class FlatSchema(BaseDoc): + tens_one: NdArray = Field(dim=10, space=space) + tens_two: NdArray = Field(dim=50, space=space) + + index = RedisDocumentIndex[FlatSchema](host='localhost') + + index_docs = [ + FlatSchema(tens_one=np.random.rand(10), tens_two=np.random.rand(50)) + for _ in range(10) + ] + index_docs.append(FlatSchema(tens_one=np.zeros(10), tens_two=np.ones(50))) + index_docs.append(FlatSchema(tens_one=np.ones(10), tens_two=np.zeros(50))) + index.index(index_docs) + + query = FlatSchema(tens_one=np.ones(10), tens_two=np.ones(50)) + + # find on tens_one + docs, scores = index.find(query, search_field='tens_one', limit=5) + assert len(docs) == 5 + assert len(scores) == 5 + assert docs[0].id == index_docs[-1].id + assert np.allclose(docs[0].tens_one, index_docs[-1].tens_one) + assert np.allclose(docs[0].tens_two, index_docs[-1].tens_two) + + # find on tens_two + docs, scores = index.find(query, search_field='tens_two', limit=5) + assert len(docs) == 5 + assert len(scores) == 5 + assert docs[0].id == index_docs[-2].id + assert np.allclose(docs[0].tens_one, index_docs[-2].tens_one) + assert np.allclose(docs[0].tens_two, index_docs[-2].tens_two) + + +@pytest.mark.parametrize('space', ['cosine', 'l2', 'ip']) +def test_find_nested_schema(space): + class SimpleDoc(BaseDoc): + tens: NdArray[10] = Field(space=space) + + class NestedDoc(BaseDoc): + d: SimpleDoc + tens: NdArray[10] = Field(space=space) + + class DeepNestedDoc(BaseDoc): + d: NestedDoc + tens: NdArray = Field(space=space, dim=10) + + index = RedisDocumentIndex[DeepNestedDoc](host='localhost') + + index_docs = [ + DeepNestedDoc( + d=NestedDoc(d=SimpleDoc(tens=np.random.rand(10)), tens=np.random.rand(10)), + tens=np.random.rand(10), + ) + for _ in range(10) + ] + index_docs.append( + DeepNestedDoc( + d=NestedDoc(d=SimpleDoc(tens=np.ones(10)), tens=np.zeros(10)), + tens=np.zeros(10), + ) + ) + index_docs.append( + DeepNestedDoc( + d=NestedDoc(d=SimpleDoc(tens=np.zeros(10)), tens=np.ones(10)), + tens=np.zeros(10), + ) + ) + index_docs.append( + DeepNestedDoc( + d=NestedDoc(d=SimpleDoc(tens=np.zeros(10)), tens=np.zeros(10)), + tens=np.ones(10), + ) + ) + index.index(index_docs) + + query = DeepNestedDoc( + d=NestedDoc(d=SimpleDoc(tens=np.ones(10)), tens=np.ones(10)), tens=np.ones(10) + ) + + # find on root level + docs, scores = index.find(query, search_field='tens', limit=5) + assert len(docs) == 5 + assert len(scores) == 5 + assert docs[0].id == index_docs[-1].id + assert np.allclose(docs[0].tens, index_docs[-1].tens) + + # find on first nesting level + docs, scores = index.find(query, search_field='d__tens', limit=5) + assert len(docs) == 5 + assert len(scores) == 5 + assert docs[0].id == index_docs[-2].id + assert np.allclose(docs[0].d.tens, index_docs[-2].d.tens) + + # find on second nesting level + docs, scores = index.find(query, search_field='d__d__tens', limit=5) + assert len(docs) == 5 + assert len(scores) == 5 + assert docs[0].id == index_docs[-3].id + assert np.allclose(docs[0].d.d.tens, index_docs[-3].d.d.tens) + + +def test_simple_usage(): + class MyDoc(BaseDoc): + text: str + embedding: NdArray[128] + + docs = [MyDoc(text='hey', embedding=np.random.rand(128)) for _ in range(200)] + queries = docs[0:3] + index = RedisDocumentIndex[MyDoc](host='localhost') + index.index(docs=DocList[MyDoc](docs)) + resp = index.find_batched(queries=queries, 
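+        # one batch of matches is returned per query document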
search_field='embedding', limit=10) + docs_responses = resp.documents + assert len(docs_responses) == 3 + for q, matches in zip(queries, docs_responses): + assert len(matches) == 10 + assert q.id == matches[0].id + + +def test_query_builder(): + class SimpleSchema(BaseDoc): + tensor: NdArray[10] = Field(space='cosine') + price: int + + db = RedisDocumentIndex[SimpleSchema](host='localhost') + + index_docs = [ + SimpleSchema(tensor=np.array([i + 1] * 10), price=i + 1) for i in range(10) + ] + db.index(index_docs) + + q = ( + db.build_query() + .find(query=np.ones(10), search_field='tensor', limit=5) + .filter(filter_query='@price:[-inf 3]') + .build() + ) + + docs, scores = db.execute_query(q) + + assert len(docs) == 3 + for doc in docs: + assert doc.price <= 3 + + +def test_text_search(): + class SimpleSchema(BaseDoc): + description: str + some_field: Optional[int] + + texts_to_index = [ + "Text processing with Python is a valuable skill for data analysis.", + "Gardening tips for a beautiful backyard oasis.", + "Explore the wonders of deep-sea diving in tropical locations.", + "The history and art of classical music compositions.", + "An introduction to the world of gourmet cooking.", + ] + + query_string = "Python and text processing" + + docs = [SimpleSchema(description=text) for text in texts_to_index] + + db = RedisDocumentIndex[SimpleSchema](host='localhost') + db.index(docs) + + docs, _ = db.text_search(query=query_string, search_field='description') + + assert docs[0].description == texts_to_index[0] + + +def test_filter(): + class SimpleSchema(BaseDoc): + description: str + price: int + + doc1 = SimpleSchema(description='Python book', price=50) + doc2 = SimpleSchema(description='Python book by some author', price=60) + doc3 = SimpleSchema(description='Random book', price=40) + docs = [doc1, doc2, doc3] + + db = RedisDocumentIndex[SimpleSchema](host='localhost') + db.index(docs) + + # filter on price < 45 + docs = db.filter(filter_query='@price:[-inf 45]') + assert len(docs) == 1 + assert docs[0].price == 40 + + # filter on price >= 50 + docs = db.filter(filter_query='@price:[50 inf]') + assert len(docs) == 2 + for doc in docs: + assert doc.price >= 50 + + # get documents with the phrase "python book" in the description + docs = db.filter(filter_query='@description:"python book"') + assert len(docs) == 2 + for doc in docs: + assert 'python book' in doc.description.lower() + + # get documents with the word "book" in the description that have price <= 45 + docs = db.filter(filter_query='@description:"book" @price:[-inf 45]') + assert len(docs) == 1 + assert docs[0].description == 'Random book' and docs[0].price == 40 diff --git a/tests/index/redis/test_index_get_del.py b/tests/index/redis/test_index_get_del.py new file mode 100644 index 00000000000..dd6cc64433d --- /dev/null +++ b/tests/index/redis/test_index_get_del.py @@ -0,0 +1,101 @@ +import numpy as np +import pytest +from pydantic import Field + +from docarray import BaseDoc +from docarray.index import RedisDocumentIndex +from docarray.typing import NdArray +from tests.index.redis.fixtures import start_redis # noqa: F401 + +pytestmark = [pytest.mark.slow, pytest.mark.index] + + +class SimpleDoc(BaseDoc): + tens: NdArray[10] = Field(dim=1000) + + +@pytest.fixture +def ten_simple_docs(): + return [SimpleDoc(tens=np.random.randn(10)) for _ in range(10)] + + +def test_num_docs(ten_simple_docs): + index = RedisDocumentIndex[SimpleDoc](host='localhost') + index.index(ten_simple_docs) + + assert index.num_docs() == 10 + + del 
diff --git a/tests/index/redis/test_persist_data.py b/tests/index/redis/test_persist_data.py
new file mode 100644
index 00000000000..3898e1827fa
--- /dev/null
+++ b/tests/index/redis/test_persist_data.py
@@ -0,0 +1,41 @@
+import numpy as np
+import pytest
+from pydantic import Field
+
+from docarray import BaseDoc
+from docarray.index import RedisDocumentIndex
+from docarray.typing import NdArray
+from tests.index.redis.fixtures import start_redis  # noqa: F401
+
+
+pytestmark = [pytest.mark.slow, pytest.mark.index]
+
+
+class SimpleDoc(BaseDoc):
+    tens: NdArray[10] = Field(dim=10)
+
+
+def test_persist():
+    query = SimpleDoc(tens=np.random.random((10,)))
+
+    # create index
+    index = RedisDocumentIndex[SimpleDoc](host='localhost')
+    index_name = index._db_config.index_name
+
+    assert index.num_docs() == 0
+
+    index.index([SimpleDoc(tens=np.random.random((10,))) for _ in range(10)])
+    assert index.num_docs() == 10
+    find_results_before = index.find(query, search_field='tens', limit=5)
+
+    # load existing index
+    index = RedisDocumentIndex[SimpleDoc](host='localhost', index_name=index_name)
+    assert index.num_docs() == 10
+    find_results_after = index.find(query, search_field='tens', limit=5)
+    for doc_before, doc_after in zip(find_results_before[0], find_results_after[0]):
+        assert doc_before.id == doc_after.id
+        assert (doc_before.tens == doc_after.tens).all()
+
+    # add new data
+    index.index([SimpleDoc(tens=np.random.random((10,))) for _ in range(5)])
+    assert index.num_docs() == 15
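Note: test_persist works because all state lives server-side in Redis; the client object only holds the index name, so constructing a new RedisDocumentIndex with the same index_name reattaches to the existing data. A sketch of the reconnect pattern, where 'my_index' is a placeholder for a previously created index:

index = RedisDocumentIndex[SimpleDoc](host='localhost', index_name='my_index')
print(index.num_docs())  # counts whatever was previously indexed under 'my_index'
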
diff --git a/tests/index/redis/tests.py b/tests/index/redis/tests.py
deleted file mode 100644
index 65ef62a8cae..00000000000
--- a/tests/index/redis/tests.py
+++ /dev/null
@@ -1,74 +0,0 @@
-import numpy as np
-import pytest
-
-from docarray import BaseDoc
-from docarray.index import RedisDocumentIndex
-from pydantic import Field
-from docarray.typing import NdArray
-from tests.index.redis.fixtures import start_redis, tmp_collection_name  # noqa: F401
-from typing import Optional
-
-
-@pytest.mark.parametrize('space', ['cosine'])
-def test_find_simple_schema(space):
-    class SimpleSchema(BaseDoc):
-        tens: Optional[NdArray[10]] = Field(space=space, algorithm='HNSW')  # type: ignore[valid-type]
-        bla: int
-        title: str
-        smth: Optional[str] = None
-        tenss: Optional[NdArray[10]] = None
-
-    index = RedisDocumentIndex[SimpleSchema](host='localhost')
-
-    docs = [SimpleSchema(bla=i, title=f'zdall {i}', tens=np.random.rand(10)) for i in range(5)]
-    docs.append(SimpleSchema(bla=6, title=f'hey everyone how are you', tens=np.random.rand(10)))
-    docs.append(SimpleSchema(bla=7, title=f'hey how are you', tens=np.random.rand(10)))
-
-
-    index.index(docs)
-
-    query = np.random.rand(10)
-    results = index.find(query, search_field='tens')
-    print(len(results))
-
-    results = index.find_batched(np.array([np.random.rand(10), np.random.rand(10)]), search_field='tens')
-    print('find batched', results)
-    res = index[docs[0].id]
-    print(index.num_docs())
-    del index[docs[0].id]
-    print(index.num_docs())
-
-    docs = index.filter({'bla': {'$gt': 3}})
-
-    print('filtered', docs)
-
-    docs = index.filter_batched([{'bla': {'$gt': 3}}, {'bla': {'$lte': 3}}])
-    print('batched filt', docs)
-
-    docs = index.text_search(query='hey everyone', search_field='title')
-    print(docs)
-
-    docs = index.text_search_batched(queries=['hey hey', 'hey everyone'], search_field='title')
-    print(docs)
-
-
-def test_simple_scenario():
-    # Define a document schema
-    class SimpleSchema(BaseDoc):
-        tensor: Optional[NdArray[10]] = Field(space='COSINE')
-        year: int
-        title: Optional[str] = None
-
-    # Create a document index
-    index = RedisDocumentIndex[SimpleSchema](host='localhost')
-
-    # Prepare documents
-    docs = [SimpleSchema(year=i, title=f'some text {i}', tensor=np.random.rand(10)) for i in range(5)]
-
-    # Index
-    index.index(docs)
-
-    # Search
-    query = np.random.rand(10)
-    results = index.find(query, search_field='tensor')
-    print(results)

From cb25869c0e46b952ea18a5dbfd2fd8204e18ef69 Mon Sep 17 00:00:00 2001
From: jupyterjazz
Date: Wed, 28 Jun 2023 09:21:03 +0200
Subject: [PATCH 05/28] chore: update poetry lock

Signed-off-by: jupyterjazz
---
 docarray/utils/_internal/misc.py |   2 +-
 poetry.lock                      | 217 +++++++++++++++++++++++++++++--
 tests/index/redis/fixtures.py    |   4 +-
 3 files changed, 211 insertions(+), 12 deletions(-)

diff --git a/docarray/utils/_internal/misc.py b/docarray/utils/_internal/misc.py
index b2a5d2ea3fc..207ad832f61 100644
--- a/docarray/utils/_internal/misc.py
+++ b/docarray/utils/_internal/misc.py
@@ -42,7 +42,7 @@
     'smart_open': '"docarray[aws]"',
     'boto3': '"docarray[aws]"',
     'botocore': '"docarray[aws]"',
-    'redis': '"docarray[redis]'
+    'redis': '"docarray[redis]',
 }

diff --git a/poetry.lock b/poetry.lock
index a36859ff04e..55e1d93a79b 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,9 +1,10 @@
-# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.4.2 and should not be changed by hand. [[package]] name = "aiohttp" version = "3.8.4" description = "Async http client/server framework (asyncio)" +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -112,6 +113,7 @@ speedups = ["Brotli", "aiodns", "cchardet"] name = "aiosignal" version = "1.3.1" description = "aiosignal: a list of registered asynchronous callbacks" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -126,6 +128,7 @@ frozenlist = ">=1.1.0" name = "anyio" version = "3.6.2" description = "High level compatibility layer for multiple asynchronous event loop implementations" +category = "main" optional = false python-versions = ">=3.6.2" files = [ @@ -146,6 +149,7 @@ trio = ["trio (>=0.16,<0.22)"] name = "appnope" version = "0.1.3" description = "Disable App Nap on macOS >= 10.9" +category = "dev" optional = false python-versions = "*" files = [ @@ -157,6 +161,7 @@ files = [ name = "argon2-cffi" version = "21.3.0" description = "The secure Argon2 password hashing algorithm." +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -176,6 +181,7 @@ tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pytest"] name = "argon2-cffi-bindings" version = "21.2.0" description = "Low-level CFFI bindings for Argon2" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -213,6 +219,7 @@ tests = ["pytest"] name = "async-timeout" version = "4.0.2" description = "Timeout context manager for asyncio programs" +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -224,6 +231,7 @@ files = [ name = "attrs" version = "22.1.0" description = "Classes Without Boilerplate" +category = "main" optional = false python-versions = ">=3.5" files = [ @@ -241,6 +249,7 @@ tests-no-zope = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy name = "authlib" version = "1.2.0" description = "The ultimate Python library in building OAuth and OpenID Connect servers and clients." +category = "main" optional = true python-versions = "*" files = [ @@ -255,6 +264,7 @@ cryptography = ">=3.2" name = "av" version = "10.0.0" description = "Pythonic bindings for FFmpeg's libraries." +category = "main" optional = true python-versions = "*" files = [ @@ -308,6 +318,7 @@ files = [ name = "babel" version = "2.11.0" description = "Internationalization utilities" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -322,6 +333,7 @@ pytz = ">=2015.7" name = "backcall" version = "0.2.0" description = "Specifications for callback functions passed in to an API" +category = "dev" optional = false python-versions = "*" files = [ @@ -333,6 +345,7 @@ files = [ name = "beautifulsoup4" version = "4.11.1" description = "Screen-scraping library" +category = "dev" optional = false python-versions = ">=3.6.0" files = [ @@ -351,6 +364,7 @@ lxml = ["lxml"] name = "black" version = "22.10.0" description = "The uncompromising code formatter." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -395,6 +409,7 @@ uvloop = ["uvloop (>=0.15.2)"] name = "blacken-docs" version = "1.13.0" description = "Run Black on Python code blocks in documentation files." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -409,6 +424,7 @@ black = ">=22.1.0" name = "bleach" version = "5.0.1" description = "An easy safelist-based HTML-sanitizing tool." 
+category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -428,6 +444,7 @@ dev = ["Sphinx (==4.3.2)", "black (==22.3.0)", "build (==0.8.0)", "flake8 (==4.0 name = "boto3" version = "1.26.95" description = "The AWS SDK for Python" +category = "main" optional = true python-versions = ">= 3.7" files = [ @@ -447,6 +464,7 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] name = "botocore" version = "1.29.95" description = "Low-level, data-driven core of boto 3." +category = "main" optional = true python-versions = ">= 3.7" files = [ @@ -466,6 +484,7 @@ crt = ["awscrt (==0.16.9)"] name = "bracex" version = "2.3.post1" description = "Bash style brace expander." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -477,6 +496,7 @@ files = [ name = "certifi" version = "2022.9.24" description = "Python package for providing Mozilla's CA Bundle." +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -488,6 +508,7 @@ files = [ name = "cffi" version = "1.15.1" description = "Foreign Function Interface for Python calling C code." +category = "main" optional = false python-versions = "*" files = [ @@ -564,6 +585,7 @@ pycparser = "*" name = "cfgv" version = "3.3.1" description = "Validate configuration and produce human readable error messages." +category = "dev" optional = false python-versions = ">=3.6.1" files = [ @@ -575,6 +597,7 @@ files = [ name = "chardet" version = "5.1.0" description = "Universal encoding detector for Python 3" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -586,6 +609,7 @@ files = [ name = "charset-normalizer" version = "2.0.12" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." +category = "main" optional = false python-versions = ">=3.5.0" files = [ @@ -600,6 +624,7 @@ unicode-backport = ["unicodedata2"] name = "click" version = "8.1.3" description = "Composable command line interface toolkit" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -614,6 +639,7 @@ colorama = {version = "*", markers = "platform_system == \"Windows\""} name = "colorama" version = "0.4.6" description = "Cross-platform colored terminal text." +category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" files = [ @@ -625,6 +651,7 @@ files = [ name = "colorlog" version = "6.7.0" description = "Add colours to the output of Python's logging module." +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -642,6 +669,7 @@ development = ["black", "flake8", "mypy", "pytest", "types-colorama"] name = "commonmark" version = "0.9.1" description = "Python parser for the CommonMark Markdown spec" +category = "main" optional = false python-versions = "*" files = [ @@ -656,6 +684,7 @@ test = ["flake8 (==3.7.8)", "hypothesis (==3.55.3)"] name = "cryptography" version = "40.0.1" description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers." 
+category = "main" optional = true python-versions = ">=3.6" files = [ @@ -697,6 +726,7 @@ tox = ["tox"] name = "debugpy" version = "1.6.3" description = "An implementation of the Debug Adapter Protocol for Python" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -724,6 +754,7 @@ files = [ name = "decorator" version = "5.1.1" description = "Decorators for Humans" +category = "main" optional = false python-versions = ">=3.5" files = [ @@ -735,6 +766,7 @@ files = [ name = "defusedxml" version = "0.7.1" description = "XML bomb protection for Python stdlib modules" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ @@ -746,6 +778,7 @@ files = [ name = "distlib" version = "0.3.6" description = "Distribution utilities" +category = "dev" optional = false python-versions = "*" files = [ @@ -757,6 +790,7 @@ files = [ name = "docker" version = "6.0.1" description = "A Python library for the Docker Engine API." +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -778,6 +812,7 @@ ssh = ["paramiko (>=2.4.3)"] name = "ecdsa" version = "0.18.0" description = "ECDSA cryptographic signature library (pure python)" +category = "main" optional = true python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" files = [ @@ -796,6 +831,7 @@ gmpy2 = ["gmpy2"] name = "elastic-transport" version = "8.4.0" description = "Transport classes and utilities shared among Python Elastic client libraries" +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -814,6 +850,7 @@ develop = ["aiohttp", "mock", "pytest", "pytest-asyncio", "pytest-cov", "pytest- name = "elasticsearch" version = "7.10.1" description = "Python client for Elasticsearch" +category = "main" optional = true python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, <4" files = [ @@ -835,6 +872,7 @@ requests = ["requests (>=2.4.0,<3.0.0)"] name = "entrypoints" version = "0.4" description = "Discover and load entry points from installed packages." +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -846,6 +884,7 @@ files = [ name = "exceptiongroup" version = "1.1.0" description = "Backport of PEP 654 (exception groups)" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -860,6 +899,7 @@ test = ["pytest (>=6)"] name = "fastapi" version = "0.87.0" description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -881,6 +921,7 @@ test = ["anyio[trio] (>=3.2.1,<4.0.0)", "black (==22.8.0)", "coverage[toml] (>=6 name = "fastjsonschema" version = "2.16.2" description = "Fastest Python implementation of JSON schema" +category = "dev" optional = false python-versions = "*" files = [ @@ -895,6 +936,7 @@ devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benc name = "filelock" version = "3.8.0" description = "A platform independent file lock." +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -910,6 +952,7 @@ testing = ["covdefaults (>=2.2)", "coverage (>=6.4.2)", "pytest (>=7.1.2)", "pyt name = "frozenlist" version = "1.3.3" description = "A list-like structure which implements collections.abc.MutableSequence" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -993,6 +1036,7 @@ files = [ name = "ghp-import" version = "2.1.0" description = "Copy your docs directly to the gh-pages branch." 
+category = "dev" optional = false python-versions = "*" files = [ @@ -1010,6 +1054,7 @@ dev = ["flake8", "markdown", "twine", "wheel"] name = "griffe" version = "0.25.5" description = "Signatures for entire Python programs. Extract the structure, the frame, the skeleton of your project, to generate API documentation or find breaking changes in your API." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1027,6 +1072,7 @@ async = ["aiofiles (>=0.7,<1.0)"] name = "grpcio" version = "1.53.0" description = "HTTP/2-based RPC framework" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -1084,6 +1130,7 @@ protobuf = ["grpcio-tools (>=1.53.0)"] name = "grpcio-tools" version = "1.53.0" description = "Protobuf code generator for gRPC" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -1143,6 +1190,7 @@ setuptools = "*" name = "h11" version = "0.14.0" description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1154,6 +1202,7 @@ files = [ name = "h2" version = "4.1.0" description = "HTTP/2 State-Machine based protocol implementation" +category = "main" optional = true python-versions = ">=3.6.1" files = [ @@ -1169,6 +1218,7 @@ hyperframe = ">=6.0,<7" name = "hnswlib" version = "0.7.0" description = "hnswlib" +category = "main" optional = true python-versions = "*" files = [ @@ -1182,6 +1232,7 @@ numpy = "*" name = "hpack" version = "4.0.0" description = "Pure-Python HPACK header compression" +category = "main" optional = true python-versions = ">=3.6.1" files = [ @@ -1193,6 +1244,7 @@ files = [ name = "httpcore" version = "0.16.1" description = "A minimal low-level HTTP client." +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1204,16 +1256,17 @@ files = [ anyio = ">=3.0,<5.0" certifi = "*" h11 = ">=0.13,<0.15" -sniffio = "==1.*" +sniffio = ">=1.0.0,<2.0.0" [package.extras] http2 = ["h2 (>=3,<5)"] -socks = ["socksio (==1.*)"] +socks = ["socksio (>=1.0.0,<2.0.0)"] [[package]] name = "httpx" version = "0.23.1" description = "The next generation HTTP client." 
+category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1230,14 +1283,15 @@ sniffio = "*" [package.extras] brotli = ["brotli", "brotlicffi"] -cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<13)"] +cli = ["click (>=8.0.0,<9.0.0)", "pygments (>=2.0.0,<3.0.0)", "rich (>=10,<13)"] http2 = ["h2 (>=3,<5)"] -socks = ["socksio (==1.*)"] +socks = ["socksio (>=1.0.0,<2.0.0)"] [[package]] name = "hyperframe" version = "6.0.1" description = "HTTP/2 framing layer for Python" +category = "main" optional = true python-versions = ">=3.6.1" files = [ @@ -1249,6 +1303,7 @@ files = [ name = "identify" version = "2.5.8" description = "File identification library for Python" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1263,6 +1318,7 @@ license = ["ukkonen"] name = "idna" version = "3.4" description = "Internationalized Domain Names in Applications (IDNA)" +category = "main" optional = false python-versions = ">=3.5" files = [ @@ -1274,6 +1330,7 @@ files = [ name = "importlib-metadata" version = "5.0.0" description = "Read metadata from Python packages" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1293,6 +1350,7 @@ testing = ["flake8 (<5)", "flufl.flake8", "importlib-resources (>=1.3)", "packag name = "importlib-resources" version = "5.10.0" description = "Read resources from Python packages" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1311,6 +1369,7 @@ testing = ["flake8 (<5)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-chec name = "iniconfig" version = "1.1.1" description = "iniconfig: brain-dead simple config-ini parsing" +category = "dev" optional = false python-versions = "*" files = [ @@ -1322,6 +1381,7 @@ files = [ name = "ipykernel" version = "6.16.2" description = "IPython Kernel for Jupyter" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1350,6 +1410,7 @@ test = ["flaky", "ipyparallel", "pre-commit", "pytest (>=7.0)", "pytest-cov", "p name = "ipython" version = "7.34.0" description = "IPython: Productive Interactive Computing" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1386,6 +1447,7 @@ test = ["ipykernel", "nbformat", "nose (>=0.10.1)", "numpy (>=1.17)", "pygments" name = "ipython-genutils" version = "0.2.0" description = "Vestigial utilities from IPython" +category = "dev" optional = false python-versions = "*" files = [ @@ -1397,6 +1459,7 @@ files = [ name = "isort" version = "5.11.5" description = "A Python utility / library to sort Python imports." +category = "dev" optional = false python-versions = ">=3.7.0" files = [ @@ -1414,6 +1477,7 @@ requirements-deprecated-finder = ["pip-api", "pipreqs"] name = "jedi" version = "0.18.1" description = "An autocompletion tool for Python that can be used for text editors." +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -1432,6 +1496,7 @@ testing = ["Django (<3.1)", "colorama", "docopt", "pytest (<7.0.0)"] name = "jina-hubble-sdk" version = "0.34.0" description = "SDK for Hubble API at Jina AI." +category = "main" optional = true python-versions = ">=3.7.0" files = [ @@ -1457,6 +1522,7 @@ full = ["aiohttp", "black (==22.3.0)", "docker", "filelock", "flake8 (==4.0.1)", name = "jinja2" version = "3.1.2" description = "A very fast and expressive template engine." 
+category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1474,6 +1540,7 @@ i18n = ["Babel (>=2.7)"] name = "jmespath" version = "1.0.1" description = "JSON Matching Expressions" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -1485,6 +1552,7 @@ files = [ name = "json5" version = "0.9.10" description = "A Python implementation of the JSON5 data format." +category = "dev" optional = false python-versions = "*" files = [ @@ -1499,6 +1567,7 @@ dev = ["hypothesis"] name = "jsonschema" version = "4.17.0" description = "An implementation of JSON Schema validation for Python" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1520,6 +1589,7 @@ format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339- name = "jupyter-client" version = "7.4.6" description = "Jupyter protocol implementation and client libraries" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1544,6 +1614,7 @@ test = ["codecov", "coverage", "ipykernel (>=6.12)", "ipython", "mypy", "pre-com name = "jupyter-core" version = "4.12.0" description = "Jupyter core package. A base package on which Jupyter projects rely." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1562,6 +1633,7 @@ test = ["ipykernel", "pre-commit", "pytest", "pytest-cov", "pytest-timeout"] name = "jupyter-server" version = "1.23.2" description = "The backend—i.e. core services, APIs, and REST endpoints—to Jupyter web applications." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1594,6 +1666,7 @@ test = ["coverage", "ipykernel", "pre-commit", "pytest (>=7.0)", "pytest-console name = "jupyterlab" version = "3.5.0" description = "JupyterLab computational environment" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1621,6 +1694,7 @@ ui-tests = ["build"] name = "jupyterlab-pygments" version = "0.2.2" description = "Pygments theme using JupyterLab CSS variables" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1632,6 +1706,7 @@ files = [ name = "jupyterlab-server" version = "2.16.3" description = "A set of server components for JupyterLab and JupyterLab like applications." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1658,6 +1733,7 @@ test = ["codecov", "ipykernel", "jupyter-server[test]", "openapi-core (>=0.14.2, name = "lxml" version = "4.9.2" description = "Powerful and Pythonic XML processing library combining libxml2/libxslt with the ElementTree API." +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, != 3.4.*" files = [ @@ -1750,6 +1826,7 @@ source = ["Cython (>=0.29.7)"] name = "lz4" version = "4.3.2" description = "LZ4 Bindings for Python" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -1799,6 +1876,7 @@ tests = ["psutil", "pytest (!=3.3.0)", "pytest-cov"] name = "mapbox-earcut" version = "1.0.1" description = "Python bindings for the mapbox earcut C++ polygon triangulation library." +category = "main" optional = true python-versions = "*" files = [ @@ -1873,6 +1951,7 @@ test = ["pytest"] name = "markdown" version = "3.3.7" description = "Python implementation of Markdown." +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -1890,6 +1969,7 @@ testing = ["coverage", "pyyaml"] name = "markupsafe" version = "2.1.1" description = "Safely add untrusted strings to HTML/XML markup." 
+category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1939,6 +2019,7 @@ files = [ name = "matplotlib-inline" version = "0.1.6" description = "Inline Matplotlib backend for Jupyter" +category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -1953,6 +2034,7 @@ traitlets = "*" name = "mergedeep" version = "1.3.4" description = "A deep merge function for 🐍." +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -1964,6 +2046,7 @@ files = [ name = "mistune" version = "2.0.4" description = "A sane Markdown parser with useful plugins and renderers" +category = "dev" optional = false python-versions = "*" files = [ @@ -1975,6 +2058,7 @@ files = [ name = "mkdocs" version = "1.4.2" description = "Project documentation with Markdown." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2003,6 +2087,7 @@ min-versions = ["babel (==2.9.0)", "click (==7.0)", "colorama (==0.4)", "ghp-imp name = "mkdocs-autorefs" version = "0.4.1" description = "Automatically link across pages in MkDocs." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2018,6 +2103,7 @@ mkdocs = ">=1.1" name = "mkdocs-awesome-pages-plugin" version = "2.8.0" description = "An MkDocs plugin that simplifies configuring page titles and their order" +category = "dev" optional = false python-versions = ">=3.6.2" files = [ @@ -2034,6 +2120,7 @@ wcmatch = ">=7" name = "mkdocs-material" version = "9.1.3" description = "Documentation that simply works" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2056,6 +2143,7 @@ requests = ">=2.26" name = "mkdocs-material-extensions" version = "1.1.1" description = "Extension pack for Python Markdown and MkDocs Material." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2067,6 +2155,7 @@ files = [ name = "mkdocs-video" version = "1.5.0" description = "" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -2082,6 +2171,7 @@ mkdocs = ">=1.1.0,<2" name = "mkdocstrings" version = "0.20.0" description = "Automatic documentation from sources, for MkDocs." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2107,6 +2197,7 @@ python-legacy = ["mkdocstrings-python-legacy (>=0.2.1)"] name = "mkdocstrings-python" version = "0.8.3" description = "A Python handler for mkdocstrings." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2122,6 +2213,7 @@ mkdocstrings = ">=0.19" name = "mktestdocs" version = "0.2.0" description = "" +category = "dev" optional = false python-versions = "*" files = [ @@ -2136,6 +2228,7 @@ test = ["pytest (>=4.0.2)"] name = "mpmath" version = "1.3.0" description = "Python library for arbitrary-precision floating-point arithmetic" +category = "main" optional = true python-versions = "*" files = [ @@ -2153,6 +2246,7 @@ tests = ["pytest (>=4.6)"] name = "multidict" version = "6.0.4" description = "multidict implementation" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -2236,6 +2330,7 @@ files = [ name = "mypy" version = "1.0.0" description = "Optional static typing for Python" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2282,6 +2377,7 @@ reports = ["lxml"] name = "mypy-extensions" version = "0.4.3" description = "Experimental type system extensions for programs checked with the mypy typechecker." 
+category = "main" optional = false python-versions = "*" files = [ @@ -2293,6 +2389,7 @@ files = [ name = "natsort" version = "8.3.1" description = "Simple yet flexible natural sorting in Python." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2308,6 +2405,7 @@ icu = ["PyICU (>=1.0.0)"] name = "nbclassic" version = "0.4.8" description = "A web-based notebook environment for interactive computing" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2343,6 +2441,7 @@ test = ["coverage", "nbval", "pytest", "pytest-cov", "pytest-playwright", "pytes name = "nbclient" version = "0.7.0" description = "A client library for executing notebooks. Formerly nbconvert's ExecutePreprocessor." +category = "dev" optional = false python-versions = ">=3.7.0" files = [ @@ -2364,6 +2463,7 @@ test = ["black", "check-manifest", "flake8", "ipykernel", "ipython", "ipywidgets name = "nbconvert" version = "7.2.5" description = "Converting Jupyter Notebooks" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2402,6 +2502,7 @@ webpdf = ["pyppeteer (>=1,<1.1)"] name = "nbformat" version = "5.7.0" description = "The Jupyter Notebook format" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2422,6 +2523,7 @@ test = ["check-manifest", "pep440", "pre-commit", "pytest", "testpath"] name = "nest-asyncio" version = "1.5.6" description = "Patch asyncio to allow nested event loops" +category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -2433,6 +2535,7 @@ files = [ name = "networkx" version = "2.6.3" description = "Python package for creating and manipulating graphs and networks" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -2451,6 +2554,7 @@ test = ["codecov (>=2.1)", "pytest (>=6.2)", "pytest-cov (>=2.12)"] name = "nodeenv" version = "1.7.0" description = "Node.js virtual environment builder" +category = "dev" optional = false python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*" files = [ @@ -2465,6 +2569,7 @@ setuptools = "*" name = "notebook" version = "6.5.2" description = "A web-based notebook environment for interactive computing" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2499,6 +2604,7 @@ test = ["coverage", "nbval", "pytest", "pytest-cov", "requests", "requests-unixs name = "notebook-shim" version = "0.2.2" description = "A shim layer for notebook traits and config" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2516,6 +2622,7 @@ test = ["pytest", "pytest-console-scripts", "pytest-tornasync"] name = "numpy" version = "1.21.1" description = "NumPy is the fundamental package for array computing with Python." 
+category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2553,6 +2660,7 @@ files = [ name = "nvidia-cublas-cu11" version = "11.10.3.66" description = "CUBLAS native runtime libraries" +category = "main" optional = true python-versions = ">=3" files = [ @@ -2568,6 +2676,7 @@ wheel = "*" name = "nvidia-cuda-nvrtc-cu11" version = "11.7.99" description = "NVRTC native runtime libraries" +category = "main" optional = true python-versions = ">=3" files = [ @@ -2584,6 +2693,7 @@ wheel = "*" name = "nvidia-cuda-runtime-cu11" version = "11.7.99" description = "CUDA Runtime native Libraries" +category = "main" optional = true python-versions = ">=3" files = [ @@ -2599,6 +2709,7 @@ wheel = "*" name = "nvidia-cudnn-cu11" version = "8.5.0.96" description = "cuDNN runtime libraries" +category = "main" optional = true python-versions = ">=3" files = [ @@ -2614,6 +2725,7 @@ wheel = "*" name = "orjson" version = "3.8.2" description = "Fast, correct Python JSON library supporting dataclasses, datetimes, and numpy" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2672,6 +2784,7 @@ files = [ name = "packaging" version = "21.3" description = "Core utilities for Python packages" +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -2686,6 +2799,7 @@ pyparsing = ">=2.0.2,<3.0.5 || >3.0.5" name = "pandas" version = "1.1.0" description = "Powerful data structures for data analysis, time series, and statistics" +category = "main" optional = true python-versions = ">=3.6.1" files = [ @@ -2719,6 +2833,7 @@ test = ["hypothesis (>=3.58)", "pytest (>=4.0.2)", "pytest-xdist"] name = "pandocfilters" version = "1.5.0" description = "Utilities for writing pandoc filters in python" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -2730,6 +2845,7 @@ files = [ name = "parso" version = "0.8.3" description = "A Python Parser" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -2745,6 +2861,7 @@ testing = ["docopt", "pytest (<6.0.0)"] name = "pathspec" version = "0.10.2" description = "Utility library for gitignore style pattern matching of file paths." +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2756,6 +2873,7 @@ files = [ name = "pexpect" version = "4.8.0" description = "Pexpect allows easy control of interactive console applications." +category = "dev" optional = false python-versions = "*" files = [ @@ -2770,6 +2888,7 @@ ptyprocess = ">=0.5" name = "pickleshare" version = "0.7.5" description = "Tiny 'shelve'-like database with concurrency support" +category = "dev" optional = false python-versions = "*" files = [ @@ -2781,6 +2900,7 @@ files = [ name = "pillow" version = "9.3.0" description = "Python Imaging Library (Fork)" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -2855,6 +2975,7 @@ tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "pa name = "pkgutil-resolve-name" version = "1.3.10" description = "Resolve a name to an object." +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -2866,6 +2987,7 @@ files = [ name = "platformdirs" version = "2.5.4" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." 
+category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2881,6 +3003,7 @@ test = ["appdirs (==1.4.4)", "pytest (>=7.2)", "pytest-cov (>=4)", "pytest-mock name = "pluggy" version = "0.13.1" description = "plugin and hook calling mechanisms for python" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -2895,6 +3018,7 @@ dev = ["pre-commit", "tox"] name = "pre-commit" version = "2.20.0" description = "A framework for managing and maintaining multi-language pre-commit hooks." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2914,6 +3038,7 @@ virtualenv = ">=20.0.8" name = "prometheus-client" version = "0.15.0" description = "Python client for the Prometheus monitoring system." +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -2928,6 +3053,7 @@ twisted = ["twisted"] name = "prompt-toolkit" version = "3.0.32" description = "Library for building powerful interactive command lines in Python" +category = "dev" optional = false python-versions = ">=3.6.2" files = [ @@ -2942,6 +3068,7 @@ wcwidth = "*" name = "protobuf" version = "4.21.9" description = "" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -2965,6 +3092,7 @@ files = [ name = "psutil" version = "5.9.4" description = "Cross-platform lib for process and system monitoring in Python." +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -2991,6 +3119,7 @@ test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"] name = "ptyprocess" version = "0.7.0" description = "Run a subprocess in a pseudo terminal" +category = "dev" optional = false python-versions = "*" files = [ @@ -3002,6 +3131,7 @@ files = [ name = "py" version = "1.11.0" description = "library with cross-python path, ini-parsing, io, code, log facilities" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ @@ -3013,6 +3143,7 @@ files = [ name = "pyasn1" version = "0.4.8" description = "ASN.1 types and codecs" +category = "main" optional = true python-versions = "*" files = [ @@ -3024,6 +3155,7 @@ files = [ name = "pycollada" version = "0.7.2" description = "python library for reading and writing collada documents" +category = "main" optional = true python-versions = "*" files = [ @@ -3041,6 +3173,7 @@ validation = ["lxml"] name = "pycparser" version = "2.21" description = "C parser in Python" +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -3052,6 +3185,7 @@ files = [ name = "pydantic" version = "1.10.2" description = "Data validation and settings management using python type hints" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -3104,6 +3238,7 @@ email = ["email-validator (>=1.0.3)"] name = "pydub" version = "0.25.1" description = "Manipulate audio with an simple and easy high level interface" +category = "main" optional = true python-versions = "*" files = [ @@ -3115,6 +3250,7 @@ files = [ name = "pygments" version = "2.14.0" description = "Pygments is a syntax highlighting package written in Python." +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -3129,6 +3265,7 @@ plugins = ["importlib-metadata"] name = "pymdown-extensions" version = "9.10" description = "Extension pack for Python Markdown." 
+category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3144,6 +3281,7 @@ pyyaml = "*" name = "pyparsing" version = "3.0.9" description = "pyparsing module - Classes and methods to define and execute parsing grammars" +category = "main" optional = false python-versions = ">=3.6.8" files = [ @@ -3158,6 +3296,7 @@ diagrams = ["jinja2", "railroad-diagrams"] name = "pyrsistent" version = "0.19.2" description = "Persistent/Functional/Immutable data structures" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -3189,6 +3328,7 @@ files = [ name = "pytest" version = "7.2.1" description = "pytest: simple powerful testing with Python" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3212,6 +3352,7 @@ testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2. name = "pytest-asyncio" version = "0.20.2" description = "Pytest support for asyncio" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3229,6 +3370,7 @@ testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy name = "python-dateutil" version = "2.8.2" description = "Extensions to the standard Python datetime module" +category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" files = [ @@ -3243,6 +3385,7 @@ six = ">=1.5" name = "python-jose" version = "3.3.0" description = "JOSE implementation in Python" +category = "main" optional = true python-versions = "*" files = [ @@ -3264,6 +3407,7 @@ pycryptodome = ["pyasn1", "pycryptodome (>=3.3.1,<4.0.0)"] name = "pytz" version = "2022.6" description = "World timezone definitions, modern and historical" +category = "main" optional = false python-versions = "*" files = [ @@ -3275,6 +3419,7 @@ files = [ name = "pywin32" version = "305" description = "Python for Window Extensions" +category = "main" optional = false python-versions = "*" files = [ @@ -3298,6 +3443,7 @@ files = [ name = "pywinpty" version = "2.0.9" description = "Pseudo terminal support for Windows from Python." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3313,6 +3459,7 @@ files = [ name = "pyyaml" version = "6.0" description = "YAML parser and emitter for Python" +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -3362,6 +3509,7 @@ files = [ name = "pyyaml-env-tag" version = "0.1" description = "A custom YAML tag for referencing environment variables in YAML files. 
" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -3376,6 +3524,7 @@ pyyaml = "*" name = "pyzmq" version = "24.0.1" description = "Python bindings for 0MQ" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -3463,6 +3612,7 @@ py = {version = "*", markers = "implementation_name == \"pypy\""} name = "qdrant-client" version = "1.1.4" description = "Client library for the Qdrant vector search engine" +category = "main" optional = true python-versions = ">=3.7,<3.12" files = [ @@ -3484,7 +3634,7 @@ name = "redis" version = "4.5.5" description = "Python client for Redis database and key-value store" category = "main" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "redis-4.5.5-py3-none-any.whl", hash = "sha256:77929bc7f5dab9adf3acba2d3bb7d7658f1e0c2f1cafe7eb36434e751c471119"}, @@ -3493,8 +3643,6 @@ files = [ [package.dependencies] async-timeout = {version = ">=4.0.2", markers = "python_full_version <= \"3.11.2\""} -importlib-metadata = {version = ">=1.0", markers = "python_version < \"3.8\""} -typing-extensions = {version = "*", markers = "python_version < \"3.8\""} [package.extras] hiredis = ["hiredis (>=1.0.0)"] @@ -3504,6 +3652,7 @@ ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==20.0.1)", "requests (>=2.26.0)" name = "regex" version = "2022.10.31" description = "Alternative regular expression module, to replace re." +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -3601,6 +3750,7 @@ files = [ name = "requests" version = "2.28.2" description = "Python HTTP for Humans." +category = "main" optional = false python-versions = ">=3.7, <4" files = [ @@ -3622,6 +3772,7 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] name = "rfc3986" version = "1.5.0" description = "Validating URI References per RFC 3986" +category = "main" optional = false python-versions = "*" files = [ @@ -3639,6 +3790,7 @@ idna2008 = ["idna"] name = "rich" version = "13.1.0" description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" +category = "main" optional = false python-versions = ">=3.7.0" files = [ @@ -3658,6 +3810,7 @@ jupyter = ["ipywidgets (>=7.5.1,<8.0.0)"] name = "rsa" version = "4.9" description = "Pure-Python RSA implementation" +category = "main" optional = true python-versions = ">=3.6,<4" files = [ @@ -3672,6 +3825,7 @@ pyasn1 = ">=0.1.3" name = "rtree" version = "1.0.1" description = "R-Tree spatial index for Python GIS" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -3726,6 +3880,7 @@ files = [ name = "ruff" version = "0.0.243" description = "An extremely fast Python linter, written in Rust." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3751,6 +3906,7 @@ files = [ name = "s3transfer" version = "0.6.0" description = "An Amazon S3 Transfer Manager" +category = "main" optional = true python-versions = ">= 3.7" files = [ @@ -3768,6 +3924,7 @@ crt = ["botocore[crt] (>=1.20.29,<2.0a.0)"] name = "scipy" version = "1.6.1" description = "SciPy: Scientific Library for Python" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -3799,6 +3956,7 @@ numpy = ">=1.16.5" name = "send2trash" version = "1.8.0" description = "Send file to trash natively under Mac OS X, Windows and Linux." 
+category = "dev" optional = false python-versions = "*" files = [ @@ -3815,6 +3973,7 @@ win32 = ["pywin32"] name = "setuptools" version = "65.5.1" description = "Easily download, build, install, upgrade, and uninstall Python packages" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -3831,6 +3990,7 @@ testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs ( name = "shapely" version = "2.0.1" description = "Manipulation and analysis of geometric objects" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -3878,13 +4038,14 @@ files = [ numpy = ">=1.14" [package.extras] -docs = ["matplotlib", "numpydoc (==1.1.*)", "sphinx", "sphinx-book-theme", "sphinx-remove-toctrees"] +docs = ["matplotlib", "numpydoc (>=1.1.0,<1.2.0)", "sphinx", "sphinx-book-theme", "sphinx-remove-toctrees"] test = ["pytest", "pytest-cov"] [[package]] name = "six" version = "1.16.0" description = "Python 2 and 3 compatibility utilities" +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" files = [ @@ -3896,6 +4057,7 @@ files = [ name = "smart-open" version = "6.3.0" description = "Utils for streaming large files (S3, HDFS, GCS, Azure Blob Storage, gzip, bz2...)" +category = "main" optional = true python-versions = ">=3.6,<4.0" files = [ @@ -3920,6 +4082,7 @@ webhdfs = ["requests"] name = "sniffio" version = "1.3.0" description = "Sniff out which async library your code is running under" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -3931,6 +4094,7 @@ files = [ name = "soupsieve" version = "2.3.2.post1" description = "A modern CSS selector implementation for Beautiful Soup." +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -3942,6 +4106,7 @@ files = [ name = "starlette" version = "0.21.0" description = "The little ASGI library that shines." +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -3960,6 +4125,7 @@ full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart", "pyyam name = "svg-path" version = "6.2" description = "SVG path objects and parser" +category = "main" optional = true python-versions = "*" files = [ @@ -3974,6 +4140,7 @@ test = ["Pillow", "pytest", "pytest-cov"] name = "sympy" version = "1.10.1" description = "Computer algebra system (CAS) in Python" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -3988,6 +4155,7 @@ mpmath = ">=0.19" name = "terminado" version = "0.17.0" description = "Tornado websocket backend for the Xterm.js Javascript terminal emulator library." 
+category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4008,6 +4176,7 @@ test = ["pre-commit", "pytest (>=7.0)", "pytest-timeout"] name = "tinycss2" version = "1.2.1" description = "A tiny CSS parser" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4026,6 +4195,7 @@ test = ["flake8", "isort", "pytest"] name = "toml" version = "0.10.2" description = "Python Library for Tom's Obvious, Minimal Language" +category = "dev" optional = false python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" files = [ @@ -4037,6 +4207,7 @@ files = [ name = "tomli" version = "2.0.1" description = "A lil' TOML parser" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4048,6 +4219,7 @@ files = [ name = "torch" version = "1.13.0" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" +category = "main" optional = true python-versions = ">=3.7.0" files = [ @@ -4088,6 +4260,7 @@ opt-einsum = ["opt-einsum (>=3.3)"] name = "tornado" version = "6.2" description = "Tornado is a Python web framework and asynchronous networking library, originally developed at FriendFeed." +category = "dev" optional = false python-versions = ">= 3.7" files = [ @@ -4108,6 +4281,7 @@ files = [ name = "tqdm" version = "4.65.0" description = "Fast, Extensible Progress Meter" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -4128,6 +4302,7 @@ telegram = ["requests"] name = "traitlets" version = "5.5.0" description = "" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4143,6 +4318,7 @@ test = ["pre-commit", "pytest"] name = "trimesh" version = "3.21.2" description = "Import, export, process, analyze and view triangular meshes." +category = "main" optional = true python-versions = "*" files = [ @@ -4178,6 +4354,7 @@ test = ["autopep8", "coveralls", "ezdxf", "pyinstrument", "pytest", "pytest-cov" name = "types-pillow" version = "9.3.0.1" description = "Typing stubs for Pillow" +category = "main" optional = true python-versions = "*" files = [ @@ -4189,6 +4366,7 @@ files = [ name = "types-protobuf" version = "3.20.4.5" description = "Typing stubs for protobuf" +category = "dev" optional = false python-versions = "*" files = [ @@ -4200,6 +4378,7 @@ files = [ name = "types-requests" version = "2.28.11.7" description = "Typing stubs for requests" +category = "main" optional = false python-versions = "*" files = [ @@ -4214,6 +4393,7 @@ types-urllib3 = "<1.27" name = "types-urllib3" version = "1.26.25.4" description = "Typing stubs for urllib3" +category = "main" optional = false python-versions = "*" files = [ @@ -4225,6 +4405,7 @@ files = [ name = "typing-extensions" version = "4.4.0" description = "Backported and Experimental Type Hints for Python 3.7+" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -4236,6 +4417,7 @@ files = [ name = "typing-inspect" version = "0.8.0" description = "Runtime inspection utilities for typing module." +category = "main" optional = false python-versions = "*" files = [ @@ -4251,6 +4433,7 @@ typing-extensions = ">=3.7.4" name = "urllib3" version = "1.26.14" description = "HTTP library with thread-safe connection pooling, file post, and more." +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" files = [ @@ -4267,6 +4450,7 @@ socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] name = "uvicorn" version = "0.19.0" description = "The lightning-fast ASGI server." 
+category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4285,6 +4469,7 @@ standard = ["colorama (>=0.4)", "httptools (>=0.5.0)", "python-dotenv (>=0.13)", name = "validators" version = "0.20.0" description = "Python Data Validation for Humans™." +category = "main" optional = true python-versions = ">=3.4" files = [ @@ -4301,6 +4486,7 @@ test = ["flake8 (>=2.4.0)", "isort (>=4.2.2)", "pytest (>=2.2.3)"] name = "virtualenv" version = "20.16.7" description = "Virtual Python Environment builder" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -4321,6 +4507,7 @@ testing = ["coverage (>=6.2)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7 name = "watchdog" version = "2.3.1" description = "Filesystem events monitoring" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -4361,6 +4548,7 @@ watchmedo = ["PyYAML (>=3.10)"] name = "wcmatch" version = "8.4.1" description = "Wildcard/glob file name matcher." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4375,6 +4563,7 @@ bracex = ">=2.1.1" name = "wcwidth" version = "0.2.5" description = "Measures the displayed width of unicode strings in a terminal" +category = "dev" optional = false python-versions = "*" files = [ @@ -4386,6 +4575,7 @@ files = [ name = "weaviate-client" version = "3.17.1" description = "A python native weaviate client" +category = "main" optional = true python-versions = ">=3.8" files = [ @@ -4406,6 +4596,7 @@ grpc = ["grpcio", "grpcio-tools"] name = "webencodings" version = "0.5.1" description = "Character encoding aliases for legacy web content" +category = "dev" optional = false python-versions = "*" files = [ @@ -4417,6 +4608,7 @@ files = [ name = "websocket-client" version = "1.4.2" description = "WebSocket client for Python with low level API options" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -4433,6 +4625,7 @@ test = ["websockets"] name = "wheel" version = "0.38.4" description = "A built-package format for Python" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -4447,6 +4640,7 @@ test = ["pytest (>=3.0.0)"] name = "xxhash" version = "3.2.0" description = "Python binding for xxHash" +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -4554,6 +4748,7 @@ files = [ name = "yarl" version = "1.8.2" description = "Yet another URL library" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -4641,6 +4836,7 @@ multidict = ">=4.0" name = "zipp" version = "3.10.0" description = "Backport of pathlib-compatible object wrapper for zip files" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -4664,6 +4860,7 @@ mesh = ["trimesh"] pandas = ["pandas"] proto = ["lz4", "protobuf"] qdrant = ["qdrant-client"] +redis = ["redis"] torch = ["torch"] video = ["av"] weaviate = ["weaviate-client"] @@ -4672,4 +4869,4 @@ web = ["fastapi"] [metadata] lock-version = "2.0" python-versions = ">=3.8,<4.0" -content-hash = "efa11671865be91f94b93c6988ac1d348f8cad1b3c9314ca989ac7f471fce497" +content-hash = "3301ebf3cf49f3980af50c544315315578fed5f1cce94d3307a7a1f94ea396a8" diff --git a/tests/index/redis/fixtures.py b/tests/index/redis/fixtures.py index b1e55f4a8ec..d4db7375a3f 100644 --- a/tests/index/redis/fixtures.py +++ b/tests/index/redis/fixtures.py @@ -7,7 +7,9 @@ @pytest.fixture(scope='session', autouse=True) def start_redis(): - os.system('docker run -d --name redis-stack-server -p 6379:6379 redis/redis-stack-server:latest') + 
os.system( + 'docker run -d --name redis-stack-server -p 6379:6379 redis/redis-stack-server:latest' + ) time.sleep(1) yield From 341fa9afca2d4542e6b5249dd5b23d67e7fbd851 Mon Sep 17 00:00:00 2001 From: jupyterjazz Date: Wed, 28 Jun 2023 10:00:49 +0200 Subject: [PATCH 06/28] chore: run tests Signed-off-by: jupyterjazz --- .github/workflows/ci.yml | 2 +- docarray/utils/_internal/misc.py | 2 +- poetry.lock | 35 ++++++++++++++++++++++++++++++-- pyproject.toml | 3 ++- 4 files changed, 37 insertions(+), 5 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2a82474b005..c0d9508822d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -196,7 +196,7 @@ jobs: fail-fast: false matrix: python-version: [3.8] - db_test_folder: [base_classes, elastic, hnswlib, qdrant, weaviate] + db_test_folder: [base_classes, elastic, hnswlib, qdrant, weaviate, redis] steps: - uses: actions/checkout@v2.5.0 - name: Set up Python ${{ matrix.python-version }} diff --git a/docarray/utils/_internal/misc.py b/docarray/utils/_internal/misc.py index 207ad832f61..1ac8bc659b6 100644 --- a/docarray/utils/_internal/misc.py +++ b/docarray/utils/_internal/misc.py @@ -42,7 +42,7 @@ 'smart_open': '"docarray[aws]"', 'boto3': '"docarray[aws]"', 'botocore': '"docarray[aws]"', - 'redis': '"docarray[redis]', + 'redis': '"docarray[redis]"', } diff --git a/poetry.lock b/poetry.lock index 55e1d93a79b..c5543885141 100644 --- a/poetry.lock +++ b/poetry.lock @@ -685,7 +685,7 @@ name = "cryptography" version = "40.0.1" description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers." category = "main" -optional = true +optional = false python-versions = ">=3.6" files = [ {file = "cryptography-40.0.1-cp36-abi3-macosx_10_12_universal2.whl", hash = "sha256:918cb89086c7d98b1b86b9fdb70c712e5a9325ba6f7d7cfb509e784e0cfc6917"}, @@ -4374,6 +4374,37 @@ files = [ {file = "types_protobuf-3.20.4.5-py3-none-any.whl", hash = "sha256:97af5ce70d890fdb94cb0c906f5a6624ca2fef58bc04e27990a25509e992a950"}, ] +[[package]] +name = "types-pyopenssl" +version = "23.2.0.1" +description = "Typing stubs for pyOpenSSL" +category = "dev" +optional = false +python-versions = "*" +files = [ + {file = "types-pyOpenSSL-23.2.0.1.tar.gz", hash = "sha256:beeb5d22704c625a1e4b6dc756355c5b4af0b980138b702a9d9f932acf020903"}, + {file = "types_pyOpenSSL-23.2.0.1-py3-none-any.whl", hash = "sha256:0568553f104466f1b8e0db3360fbe6770137d02e21a1a45c209bf2b1b03d90d4"}, +] + +[package.dependencies] +cryptography = ">=35.0.0" + +[[package]] +name = "types-redis" +version = "4.6.0.0" +description = "Typing stubs for redis" +category = "dev" +optional = false +python-versions = "*" +files = [ + {file = "types-redis-4.6.0.0.tar.gz", hash = "sha256:4ad588026d89ba72eae29b6276448ea117d77e5e4df258c0429d274da652ef9c"}, + {file = "types_redis-4.6.0.0-py3-none-any.whl", hash = "sha256:528038f32a0a2642e00d9c80dd95879a348ced6071bb747c746c0cb1ad06426c"}, +] + +[package.dependencies] +cryptography = ">=35.0.0" +types-pyOpenSSL = "*" + [[package]] name = "types-requests" version = "2.28.11.7" @@ -4869,4 +4900,4 @@ web = ["fastapi"] [metadata] lock-version = "2.0" python-versions = ">=3.8,<4.0" -content-hash = "3301ebf3cf49f3980af50c544315315578fed5f1cce94d3307a7a1f94ea396a8" +content-hash = "145f974cefae9a7b73b729acb0f46f005c233cccbcf03fe17f0ad783ce0496f1" diff --git a/pyproject.toml b/pyproject.toml index 2de21647e09..d22c1b637dd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,7 +58,7 @@ smart-open = 
{version = ">=6.3.0", extras = ["s3"], optional = true} jina-hubble-sdk = {version = ">=0.34.0", optional = true} elastic-transport = {version ="^8.4.0", optional = true } qdrant-client = {version = ">=1.1.4", python = "<3.12", optional = true } -redis = {version = "^4.5.5", optional = true } +redis = {version = "^4.5.5", optional = true} [tool.poetry.extras] proto = ["protobuf", "lz4"] @@ -90,6 +90,7 @@ black = ">=22.10.0" isort = ">=5.10.1" ruff = ">=0.0.243" blacken-docs = ">=1.13.0" +types-redis = ">=4.6.0.0" [tool.poetry.group.dev.dependencies] uvicorn = ">=0.19.0" From abca1dd228fa4df59d385250e6d540908390d334 Mon Sep 17 00:00:00 2001 From: jupyterjazz Date: Wed, 28 Jun 2023 10:43:35 +0200 Subject: [PATCH 07/28] fix: defaultdict for column config Signed-off-by: jupyterjazz --- docarray/index/backends/redis.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docarray/index/backends/redis.py b/docarray/index/backends/redis.py index b96ea730a90..0d4d70f8310 100644 --- a/docarray/index/backends/redis.py +++ b/docarray/index/backends/redis.py @@ -1,4 +1,5 @@ import uuid +from collections import defaultdict from typing import ( TypeVar, Generic, @@ -187,6 +188,9 @@ class DBConfig(BaseDocIndex.DBConfig): ef_runtime: Optional[int] = None block_size: Optional[int] = None initial_cap: Optional[int] = None + default_column_config: Dict[Type, Dict[str, Any]] = field( + default_factory=lambda: defaultdict(dict) + ) def __post_init__(self): self.algorithm = self.algorithm.upper() From c5abd80320136e5989b4e90a1d2666ddc4c9bc8a Mon Sep 17 00:00:00 2001 From: jupyterjazz Date: Wed, 28 Jun 2023 11:01:04 +0200 Subject: [PATCH 08/28] style: ignore mypy errors Signed-off-by: jupyterjazz --- docarray/index/backends/redis.py | 32 +++++++++---------- tests/index/redis/test_find.py | 53 +++++++++++++++++--------------- 2 files changed, 45 insertions(+), 40 deletions(-) diff --git a/docarray/index/backends/redis.py b/docarray/index/backends/redis.py index 0d4d70f8310..2342141a294 100644 --- a/docarray/index/backends/redis.py +++ b/docarray/index/backends/redis.py @@ -37,12 +37,12 @@ if TYPE_CHECKING: import redis from redis.commands.search.query import Query - from redis.commands.search.field import ( + from redis.commands.search.field import ( # type: ignore[import] NumericField, TextField, VectorField, ) - from redis.commands.search.indexDefinition import IndexDefinition, IndexType + from redis.commands.search.indexDefinition import IndexDefinition, IndexType # type: ignore[import] else: redis = import_library('redis') @@ -75,7 +75,7 @@ def __init__(self, db_config=None, **kwargs): if not self._db_config.index_name: self._db_config.index_name = 'index_name__' + self._random_name() - self._prefix = self._db_config.index_name + ':' + self._prefix = self._db_config.index_name + ':' # type: ignore[operator] # initialize Redis client self._client = redis.Redis( @@ -93,7 +93,7 @@ def _random_name(): return uuid.uuid4().hex def _create_index(self): - if not self._check_index_exists(self._db_config.index_name): + if not self._check_index_exists(self._db_config.index_name): # type: ignore[arg-type] schema = [] for column, info in self._column_infos.items(): if info.db_type == VectorField: @@ -103,16 +103,16 @@ def _create_index(self): if space.upper() == valid_dist: space = valid_dist if space not in VALID_DISTANCES: - space = self._db_config.distance + space = self._db_config.distance # type: ignore[union-attr] attributes = { 'TYPE': 'FLOAT32', 'DIM': info.n_dim or info.config.get('dim'), 'DISTANCE_METRIC': 
space, - 'EF_CONSTRUCTION': self._db_config.ef_construction, - 'EF_RUNTIME': self._db_config.ef_runtime, - 'M': self._db_config.m, - 'INITIAL_CAP': self._db_config.initial_cap, + 'EF_CONSTRUCTION': self._db_config.ef_construction, # type: ignore[union-attr] + 'EF_RUNTIME': self._db_config.ef_runtime, # type: ignore[union-attr] + 'M': self._db_config.m, # type: ignore[union-attr] + 'INITIAL_CAP': self._db_config.initial_cap, # type: ignore[union-attr] } attributes = { name: value for name, value in attributes.items() if value @@ -121,7 +121,7 @@ def _create_index(self): info.db_type( '$.' + column, algorithm=info.config.get( - 'algorithm', self._db_config.algorithm + 'algorithm', self._db_config.algorithm # type: ignore[union-attr] ), attributes=attributes, as_name=column, @@ -131,7 +131,7 @@ def _create_index(self): schema.append(info.db_type('$.' + column, as_name=column)) # Create Redis Index - self._client.ft(self._db_config.index_name).create_index( + self._client.ft(self._db_config.index_name).create_index( # type: ignore[arg-type] schema, definition=IndexDefinition( prefix=[self._prefix], index_type=IndexType.JSON @@ -287,7 +287,7 @@ def _index(self, column_to_data: Dict[str, Generator[Any, None, None]]): return ids def num_docs(self) -> int: - num_docs = self._client.ft(self._db_config.index_name).info()['num_docs'] + num_docs = self._client.ft(self._db_config.index_name).info()['num_docs'] # type: ignore[arg-type] return int(num_docs) def _del_items(self, doc_ids: Sequence[str]): @@ -360,7 +360,7 @@ def _hybrid_search( 'vec': np.array(query, dtype=np.float32).tobytes() # type: ignore } results = ( - self._client.ft(self._db_config.index_name) + self._client.ft(self._db_config.index_name) # type: ignore[arg-type] .search(redis_query, query_params) .docs ) @@ -397,7 +397,7 @@ def _filter(self, filter_query: Any, limit: int) -> Union[DocList, List[Dict]]: q = Query(filter_query) q.paging(0, limit) - results = self._client.ft(index_name=self._db_config.index_name).search(q).docs + results = self._client.ft(index_name=self._db_config.index_name).search(q).docs # type: ignore[arg-type] docs = [json.loads(doc.json) for doc in results] return docs @@ -473,12 +473,12 @@ def _text_search( query_str = '|'.join(query.split(' ')) q = ( Query(f'@{search_field}:{query_str}') - .scorer(self._db_config.text_scorer) + .scorer(self._db_config.text_scorer) # type: ignore[union-attr] .with_scores() .paging(0, limit) ) - results = self._client.ft(index_name=self._db_config.index_name).search(q).docs + results = self._client.ft(index_name=self._db_config.index_name).search(q).docs # type: ignore[arg-type] scores: NdArray = NdArray._docarray_from_native( np.array([document['score'] for document in results]) diff --git a/tests/index/redis/test_find.py b/tests/index/redis/test_find.py index 35665389b1c..93c97369826 100644 --- a/tests/index/redis/test_find.py +++ b/tests/index/redis/test_find.py @@ -10,6 +10,8 @@ from docarray.typing import NdArray, TorchTensor from tests.index.redis.fixtures import start_redis # noqa: F401 +pytestmark = [pytest.mark.slow, pytest.mark.index] + N_DIM = 10 @@ -57,8 +59,8 @@ def test_find_empty_index(): def test_find_limit_larger_than_index(): schema = get_simple_schema() db = RedisDocumentIndex[schema](host='localhost') - query = schema(tens=np.ones(10)) - index_docs = [schema(tens=np.zeros(10)) for _ in range(10)] + query = schema(tens=np.ones(N_DIM)) + index_docs = [schema(tens=np.zeros(N_DIM)) for _ in range(10)] db.index(index_docs) docs, scores = db.find(query, 
search_field='tens', limit=20) assert len(docs) == 10 @@ -68,14 +70,14 @@ def test_find_limit_larger_than_index(): @pytest.mark.parametrize('space', ['cosine', 'l2', 'ip']) def test_find_torch(space): db = RedisDocumentIndex[TorchDoc](host='localhost') - index_docs = [TorchDoc(tens=np.random.rand(10)) for _ in range(10)] - index_docs.append(TorchDoc(tens=np.ones(10, dtype=np.float32))) + index_docs = [TorchDoc(tens=np.random.rand(N_DIM)) for _ in range(10)] + index_docs.append(TorchDoc(tens=np.ones(N_DIM, dtype=np.float32))) db.index(index_docs) for doc in index_docs: assert isinstance(doc.tens, TorchTensor) - query = TorchDoc(tens=np.ones(10, dtype=np.float32)) + query = TorchDoc(tens=np.ones(N_DIM, dtype=np.float32)) result_docs, scores = db.find(query, search_field='tens', limit=5) @@ -90,20 +92,20 @@ def test_find_torch(space): @pytest.mark.parametrize('space', ['cosine', 'l2', 'ip']) def test_find_flat_schema(space): class FlatSchema(BaseDoc): - tens_one: NdArray = Field(dim=10, space=space) + tens_one: NdArray = Field(dim=N_DIM, space=space) tens_two: NdArray = Field(dim=50, space=space) index = RedisDocumentIndex[FlatSchema](host='localhost') index_docs = [ - FlatSchema(tens_one=np.random.rand(10), tens_two=np.random.rand(50)) + FlatSchema(tens_one=np.random.rand(N_DIM), tens_two=np.random.rand(50)) for _ in range(10) ] - index_docs.append(FlatSchema(tens_one=np.zeros(10), tens_two=np.ones(50))) - index_docs.append(FlatSchema(tens_one=np.ones(10), tens_two=np.zeros(50))) + index_docs.append(FlatSchema(tens_one=np.zeros(N_DIM), tens_two=np.ones(50))) + index_docs.append(FlatSchema(tens_one=np.ones(N_DIM), tens_two=np.zeros(50))) index.index(index_docs) - query = FlatSchema(tens_one=np.ones(10), tens_two=np.ones(50)) + query = FlatSchema(tens_one=np.ones(N_DIM), tens_two=np.ones(50)) # find on tens_one docs, scores = index.find(query, search_field='tens_one', limit=5) @@ -125,47 +127,50 @@ class FlatSchema(BaseDoc): @pytest.mark.parametrize('space', ['cosine', 'l2', 'ip']) def test_find_nested_schema(space): class SimpleDoc(BaseDoc): - tens: NdArray[10] = Field(space=space) + tens: NdArray[N_DIM] = Field(space=space) class NestedDoc(BaseDoc): d: SimpleDoc - tens: NdArray[10] = Field(space=space) + tens: NdArray[N_DIM] = Field(space=space) class DeepNestedDoc(BaseDoc): d: NestedDoc - tens: NdArray = Field(space=space, dim=10) + tens: NdArray = Field(space=space, dim=N_DIM) index = RedisDocumentIndex[DeepNestedDoc](host='localhost') index_docs = [ DeepNestedDoc( - d=NestedDoc(d=SimpleDoc(tens=np.random.rand(10)), tens=np.random.rand(10)), - tens=np.random.rand(10), + d=NestedDoc( + d=SimpleDoc(tens=np.random.rand(N_DIM)), tens=np.random.rand(N_DIM) + ), + tens=np.random.rand(N_DIM), ) for _ in range(10) ] index_docs.append( DeepNestedDoc( - d=NestedDoc(d=SimpleDoc(tens=np.ones(10)), tens=np.zeros(10)), - tens=np.zeros(10), + d=NestedDoc(d=SimpleDoc(tens=np.ones(N_DIM)), tens=np.zeros(N_DIM)), + tens=np.zeros(N_DIM), ) ) index_docs.append( DeepNestedDoc( - d=NestedDoc(d=SimpleDoc(tens=np.zeros(10)), tens=np.ones(10)), - tens=np.zeros(10), + d=NestedDoc(d=SimpleDoc(tens=np.zeros(N_DIM)), tens=np.ones(N_DIM)), + tens=np.zeros(N_DIM), ) ) index_docs.append( DeepNestedDoc( - d=NestedDoc(d=SimpleDoc(tens=np.zeros(10)), tens=np.zeros(10)), - tens=np.ones(10), + d=NestedDoc(d=SimpleDoc(tens=np.zeros(N_DIM)), tens=np.zeros(N_DIM)), + tens=np.ones(N_DIM), ) ) index.index(index_docs) query = DeepNestedDoc( - d=NestedDoc(d=SimpleDoc(tens=np.ones(10)), tens=np.ones(10)), tens=np.ones(10) + 
d=NestedDoc(d=SimpleDoc(tens=np.ones(N_DIM)), tens=np.ones(N_DIM)), + tens=np.ones(N_DIM), ) # find on root level @@ -209,7 +214,7 @@ class MyDoc(BaseDoc): def test_query_builder(): class SimpleSchema(BaseDoc): - tensor: NdArray[10] = Field(space='cosine') + tensor: NdArray[N_DIM] = Field(space='cosine') price: int db = RedisDocumentIndex[SimpleSchema](host='localhost') @@ -221,7 +226,7 @@ class SimpleSchema(BaseDoc): q = ( db.build_query() - .find(query=np.ones(10), search_field='tensor', limit=5) + .find(query=np.ones(N_DIM), search_field='tensor', limit=5) .filter(filter_query='@price:[-inf 3]') .build() ) From 22434f16b91175c636a97a6abc7dccb45bb897a7 Mon Sep 17 00:00:00 2001 From: jupyterjazz Date: Wed, 28 Jun 2023 11:25:00 +0200 Subject: [PATCH 09/28] refactor: put vectorfield args in column info Signed-off-by: jupyterjazz --- docarray/index/backends/redis.py | 62 ++++++++++++++------------------ 1 file changed, 27 insertions(+), 35 deletions(-) diff --git a/docarray/index/backends/redis.py b/docarray/index/backends/redis.py index 2342141a294..ce3d5c8678f 100644 --- a/docarray/index/backends/redis.py +++ b/docarray/index/backends/redis.py @@ -97,22 +97,24 @@ def _create_index(self): schema = [] for column, info in self._column_infos.items(): if info.db_type == VectorField: - space = info.config.get('space') - if space: - for valid_dist in VALID_DISTANCES: - if space.upper() == valid_dist: - space = valid_dist - if space not in VALID_DISTANCES: - space = self._db_config.distance # type: ignore[union-attr] + space = info.config.get('space') or info.config.get('distance') + for valid_dist in VALID_DISTANCES: + if space.upper() == valid_dist: # type: ignore[union-attr] + space = valid_dist + if not space: + raise ValueError( + f"Invalid distance metric '{space}' provided. " + f"Must be one of: {', '.join(VALID_DISTANCES)}" + ) attributes = { 'TYPE': 'FLOAT32', 'DIM': info.n_dim or info.config.get('dim'), 'DISTANCE_METRIC': space, - 'EF_CONSTRUCTION': self._db_config.ef_construction, # type: ignore[union-attr] - 'EF_RUNTIME': self._db_config.ef_runtime, # type: ignore[union-attr] - 'M': self._db_config.m, # type: ignore[union-attr] - 'INITIAL_CAP': self._db_config.initial_cap, # type: ignore[union-attr] + 'EF_CONSTRUCTION': info.config['ef_construction'], + 'EF_RUNTIME': info.config['ef_runtime'], + 'M': info.config['m'], + 'INITIAL_CAP': info.config['initial_cap'], } attributes = { name: value for name, value in attributes.items() if value @@ -120,9 +122,7 @@ def _create_index(self): schema.append( info.db_type( '$.' 
+ column, - algorithm=info.config.get( - 'algorithm', self._db_config.algorithm # type: ignore[union-attr] - ), + algorithm=info.config['algorithm'], attributes=attributes, as_name=column, ) @@ -180,33 +180,25 @@ class DBConfig(BaseDocIndex.DBConfig): index_name: Optional[str] = None username: Optional[str] = None password: Optional[str] = None - algorithm: str = field(default='FLAT') - distance: str = field(default='COSINE') text_scorer: str = field(default='BM25') - ef_construction: Optional[int] = None - m: Optional[int] = None - ef_runtime: Optional[int] = None - block_size: Optional[int] = None - initial_cap: Optional[int] = None default_column_config: Dict[Type, Dict[str, Any]] = field( - default_factory=lambda: defaultdict(dict) + default_factory=lambda: defaultdict( + dict, + { + VectorField: { + 'algorithm': 'FLAT', + 'distance': 'COSINE', + 'ef_construction': None, + 'm': None, + 'ef_runtime': None, + 'initial_cap': None, + }, + }, + ) ) def __post_init__(self): - self.algorithm = self.algorithm.upper() - self.distance = self.distance.upper() self.text_scorer = self.text_scorer.upper() - if self.algorithm not in VALID_ALGORITHMS: - raise ValueError( - f"Invalid algorithm '{self.algorithm}' provided. " - f"Must be one of: {', '.join(VALID_ALGORITHMS)}" - ) - - if self.distance not in VALID_DISTANCES: - raise ValueError( - f"Invalid distance metric '{self.distance}' provided. " - f"Must be one of: {', '.join(VALID_DISTANCES)}" - ) if self.text_scorer not in VALID_TEXT_SCORERS: raise ValueError( From 4a961946469d3a232e892beb14bdd925d2ffa10f Mon Sep 17 00:00:00 2001 From: jupyterjazz Date: Wed, 28 Jun 2023 11:28:17 +0200 Subject: [PATCH 10/28] chore: remove unused code Signed-off-by: jupyterjazz --- docarray/index/backends/redis.py | 66 ++++---------------------------- 1 file changed, 7 insertions(+), 59 deletions(-) diff --git a/docarray/index/backends/redis.py b/docarray/index/backends/redis.py index ce3d5c8678f..24cb69ac2e3 100644 --- a/docarray/index/backends/redis.py +++ b/docarray/index/backends/redis.py @@ -119,10 +119,16 @@ def _create_index(self): attributes = { name: value for name, value in attributes.items() if value } + algorithm = info.config['algorithm'].upper() + if algorithm not in VALID_ALGORITHMS: + raise ValueError( + f"Invalid algorithm '{algorithm}' provided. " + f"Must be one of: {', '.join(VALID_ALGORITHMS)}" + ) schema.append( info.db_type( '$.' 
+ column, - algorithm=info.config['algorithm'], + algorithm=algorithm, attributes=attributes, as_name=column, ) @@ -393,64 +399,6 @@ def _filter(self, filter_query: Any, limit: int) -> Union[DocList, List[Dict]]: docs = [json.loads(doc.json) for doc in results] return docs - # def _build_query_node(self, key, condition): - # operator = list(condition.keys())[0] - # value = condition[operator] - # - # query_dict = {} - # - # if operator in ['$ne', '$eq']: - # if isinstance(value, bool): - # query_dict[key] = equal(int(value)) - # elif isinstance(value, (int, float)): - # query_dict[key] = equal(value) - # else: - # query_dict[key] = '"' + value + '"' - # elif operator == '$gt': - # query_dict[key] = gt(value) - # elif operator == '$gte': - # query_dict[key] = ge(value) - # elif operator == '$lt': - # query_dict[key] = lt(value) - # elif operator == '$lte': - # query_dict[key] = le(value) - # else: - # raise ValueError( - # f'Expecting filter operator one of $gt, $gte, $lt, $lte, $eq, $ne, $and OR $or, got {operator} instead' - # ) - # - # if operator == '$ne': - # return DistjunctUnion(**query_dict) - # return IntersectNode(**query_dict) - # - # def _build_query_nodes(self, filter): - # nodes = [] - # for k, v in filter.items(): - # if k == '$and': - # children = self._build_query_nodes(v) - # node = intersect(*children) - # nodes.append(node) - # elif k == '$or': - # children = self._build_query_nodes(v) - # node = union(*children) - # nodes.append(node) - # else: - # child = self._build_query_node(k, v) - # nodes.append(child) - # - # return nodes - # - # def _get_redis_filter_query(self, filter: Union[str, Dict]): - # if isinstance(filter, dict): - # nodes = self._build_query_nodes(filter) - # query_str = intersect(*nodes).to_string() - # elif isinstance(filter, str): - # query_str = filter - # else: - # raise ValueError(f'Unexpected type of filter: {type(filter)}, expected str') - # - # return query_str - def _filter_batched( self, filter_queries: Any, limit: int ) -> Union[List[DocList], List[List[Dict]]]: From a52cdfde72d8dd41867e1735aeb33e1e0e9ad608 Mon Sep 17 00:00:00 2001 From: jupyterjazz Date: Wed, 28 Jun 2023 13:28:32 +0200 Subject: [PATCH 11/28] docs: add docstrings Signed-off-by: jupyterjazz --- docarray/index/backends/redis.py | 147 ++++++++++++++++++++++++++----- 1 file changed, 127 insertions(+), 20 deletions(-) diff --git a/docarray/index/backends/redis.py b/docarray/index/backends/redis.py index 24cb69ac2e3..491456a34bb 100644 --- a/docarray/index/backends/redis.py +++ b/docarray/index/backends/redis.py @@ -70,6 +70,7 @@ class RedisDocumentIndex(BaseDocIndex, Generic[TSchema]): def __init__(self, db_config=None, **kwargs): + """Initialize RedisDocumentIndex""" super().__init__(db_config=db_config, **kwargs) self._db_config = cast(RedisDocumentIndex.DBConfig, self._db_config) @@ -89,24 +90,23 @@ def __init__(self, db_config=None, **kwargs): self._logger.info(f'{self.__class__.__name__} has been initialized') @staticmethod - def _random_name(): + def _random_name() -> str: + """Generate a random index name.""" return uuid.uuid4().hex - def _create_index(self): + def _create_index(self) -> None: + """Create a new index in the Redis database if it doesn't already exist.""" if not self._check_index_exists(self._db_config.index_name): # type: ignore[arg-type] schema = [] for column, info in self._column_infos.items(): if info.db_type == VectorField: space = info.config.get('space') or info.config.get('distance') - for valid_dist in VALID_DISTANCES: - if space.upper() == 
valid_dist: # type: ignore[union-attr] - space = valid_dist - if not space: + space = space.upper() + if space not in VALID_DISTANCES: raise ValueError( f"Invalid distance metric '{space}' provided. " f"Must be one of: {', '.join(VALID_DISTANCES)}" ) - attributes = { 'TYPE': 'FLOAT32', 'DIM': info.n_dim or info.config.get('dim'), @@ -151,13 +151,18 @@ def _create_index(self): ) def _check_index_exists(self, index_name: str) -> bool: - """Check if Redis index exists.""" + """ + Check if an index exists in the Redis database. + + :param index_name: The name of the index. + :return: True if the index exists, False otherwise. + """ try: self._client.ft(index_name).info() except: # noqa: E722 - self._logger.info("Index does not exist") + self._logger.info(f'Index {index_name} does not exist') return False - self._logger.info("Index already exists") + self._logger.info(f'Index {index_name} already exists') return True class QueryBuilder(BaseDocIndex.QueryBuilder): @@ -179,8 +184,18 @@ def build(self, *args, **kwargs) -> Any: @dataclass class DBConfig(BaseDocIndex.DBConfig): - """Dataclass that contains all "static" configurations of RedisDocumentIndex.""" - + """Dataclass that contains all "static" configurations of RedisDocumentIndex. + + :param host: The host address for the Redis server. Default is 'localhost'. + :param port: The port number for the Redis server. Default is 6379. + :param index_name: The name of the index in the Redis database. + In case it's not provided, a random index name will be generated. + :param username: The username for the Redis server. Default is None. + :param password: The password for the Redis server. Default is None. + :param text_scorer: The method for scoring text during text search. + Default is 'BM25'. + :param default_column_config: Default configuration for columns. + """ host: str = 'localhost' port: int = 6379 index_name: Optional[str] = None @@ -225,6 +240,12 @@ class RuntimeConfig(BaseDocIndex.RuntimeConfig): ) def python_type_to_db_type(self, python_type: Type) -> Any: + """ + Map python types to corresponding Redis types. + + :param python_type: Python type. + :return: Corresponding Redis type. + """ type_map = { int: NumericField, float: NumericField, @@ -273,32 +294,57 @@ def _generate_item( yield item_dict - def _index(self, column_to_data: Dict[str, Generator[Any, None, None]]): + def _index(self, column_to_data: Dict[str, Generator[Any, None, None]]) -> List[str]: + """ + Indexes the given data into Redis. + + :param column_to_data: A dictionary where each key is a column and each value is a generator. + :return: A list of document ids that have been indexed. + """ ids = [] for item in self._generate_item(column_to_data): ids.append(item['id']) doc_id = self._prefix + item['id'] self._client.json().set(doc_id, '$', item) - - num_docs = self.num_docs() - print('indexed', num_docs) return ids def num_docs(self) -> int: + """ + Fetch the number of documents in the index. + + :return: Number of documents in the index. + """ num_docs = self._client.ft(self._db_config.index_name).info()['num_docs'] # type: ignore[arg-type] return int(num_docs) - def _del_items(self, doc_ids: Sequence[str]): + def _del_items(self, doc_ids: Sequence[str]) -> None: + """ + Deletes documents from the index based on document ids. + + :param doc_ids: A sequence of document ids to be deleted. 
+ """ doc_ids = [self._prefix + id for id in doc_ids if self._doc_exists(id)] if doc_ids: self._client.delete(*doc_ids) - def _doc_exists(self, doc_id): - return self._client.exists(self._prefix + doc_id) + def _doc_exists(self, doc_id) -> bool: + """ + Checks if a document exists in the index. + + :param doc_id: The id of the document. + :return: True if the document exists, False otherwise. + """ + return bool(self._client.exists(self._prefix + doc_id)) def _get_items( self, doc_ids: Sequence[str] ) -> Union[Sequence[TSchema], Sequence[Dict[str, Any]]]: + """ + Fetches the documents from the index based on document ids. + + :param doc_ids: A sequence of document ids. + :return: A sequence of documents from the index. + """ if not doc_ids: return [] @@ -313,6 +359,12 @@ def _get_items( return docs def execute_query(self, query: Any, *args: Any, **kwargs: Any) -> Any: + """ + Executes a hybrid query on the index. + + :param query: Query to execute on the index. + :return: Query results. + """ components: Dict[str, List[Dict[str, Any]]] = {} for component, value in query: if component not in components: @@ -347,7 +399,16 @@ def execute_query(self, query: Any, *args: Any, **kwargs: Any) -> Any: def _hybrid_search( self, query: np.ndarray, filter_query: str, search_field: str, limit: int - ): + ) -> _FindResult: + """ + Conducts a hybrid search (a combination of vector search and filter-based search) on the index. + + :param query: The query to search. + :param filter_query: The filter condition. + :param search_field: The vector field to search on. + :param limit: The maximum number of results to return. + :return: Query results. + """ redis_query = ( Query(f'{filter_query}=>[KNN {limit} @{search_field} $vec AS vector_score]') .sort_by('vector_score') @@ -376,6 +437,14 @@ def _hybrid_search( def _find( self, query: np.ndarray, limit: int, search_field: str = '' ) -> _FindResult: + """ + Conducts a search on the index. + + :param query: The vector query to search. + :param limit: The maximum number of results to return. + :param search_field: The field to search the query. + :return: Search results. + """ return self._hybrid_search( query=query, filter_query='*', search_field=search_field, limit=limit ) @@ -383,6 +452,14 @@ def _find( def _find_batched( self, queries: np.ndarray, limit: int, search_field: str = '' ) -> _FindResultBatched: + """ + Conducts a batched search on the index. + + :param queries: The queries to search. + :param limit: The maximum number of results to return for each query. + :param search_field: The field to search the queries. + :return: Search results. + """ docs, scores = [], [] for query in queries: results = self._find(query=query, search_field=search_field, limit=limit) @@ -392,6 +469,13 @@ def _find_batched( return _FindResultBatched(documents=docs, scores=scores) def _filter(self, filter_query: Any, limit: int) -> Union[DocList, List[Dict]]: + """ + Filters the index based on the given filter query. + + :param filter_query: The filter condition. + :param limit: The maximum number of results to return. + :return: Filter results. + """ q = Query(filter_query) q.paging(0, limit) @@ -402,6 +486,13 @@ def _filter(self, filter_query: Any, limit: int) -> Union[DocList, List[Dict]]: def _filter_batched( self, filter_queries: Any, limit: int ) -> Union[List[DocList], List[List[Dict]]]: + """ + Filters the index based on the given batch of filter queries. + + :param filter_queries: The filter conditions. 
+ :param limit: The maximum number of results to return for each filter query. + :return: Filter results. + """ results = [] for query in filter_queries: results.append(self._filter(filter_query=query, limit=limit)) @@ -410,6 +501,14 @@ def _filter_batched( def _text_search( self, query: str, limit: int, search_field: str = '' ) -> _FindResult: + """ + Conducts a text-based search on the index. + + :param query: The query to search. + :param limit: The maximum number of results to return. + :param search_field: The field to search the query. + :return: Search results. + """ query_str = '|'.join(query.split(' ')) q = ( Query(f'@{search_field}:{query_str}') @@ -431,6 +530,14 @@ def _text_search( def _text_search_batched( self, queries: Sequence[str], limit: int, search_field: str = '' ) -> _FindResultBatched: + """ + Conducts a batched text-based search on the index. + + :param queries: The queries to search. + :param limit: The maximum number of results to return for each query. + :param search_field: The field to search the queries. + :return: Search results. + """ docs, scores = [], [] for query in queries: results = self._text_search( From 9ac683f31bcbae1193593fbc6023a3cd77a07154 Mon Sep 17 00:00:00 2001 From: jupyterjazz Date: Wed, 28 Jun 2023 13:41:55 +0200 Subject: [PATCH 12/28] test: add tensorflow test Signed-off-by: jupyterjazz --- docarray/index/backends/redis.py | 7 +++++-- tests/index/redis/test_find.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/docarray/index/backends/redis.py b/docarray/index/backends/redis.py index 491456a34bb..e1a69dcd153 100644 --- a/docarray/index/backends/redis.py +++ b/docarray/index/backends/redis.py @@ -101,7 +101,7 @@ def _create_index(self) -> None: for column, info in self._column_infos.items(): if info.db_type == VectorField: space = info.config.get('space') or info.config.get('distance') - space = space.upper() + space = space.upper() # type: ignore[union-attr] if space not in VALID_DISTANCES: raise ValueError( f"Invalid distance metric '{space}' provided. " @@ -196,6 +196,7 @@ class DBConfig(BaseDocIndex.DBConfig): Default is 'BM25'. :param default_column_config: Default configuration for columns. """ + host: str = 'localhost' port: int = 6379 index_name: Optional[str] = None @@ -294,7 +295,9 @@ def _generate_item( yield item_dict - def _index(self, column_to_data: Dict[str, Generator[Any, None, None]]) -> List[str]: + def _index( + self, column_to_data: Dict[str, Generator[Any, None, None]] + ) -> List[str]: """ Indexes the given data into Redis. 
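# --- Editor's note: illustrative sketch only, not part of the patch above. ---
# `_index` drains the per-column generators through `_generate_item`, which
# walks them in lockstep and yields one JSON-serializable dict per document,
# stopping once the 'id' generator is exhausted. Roughly, with hypothetical
# toy data:
#
#     column_to_data = {'id': iter(['a', 'b']), 'year': iter([1999, 2001])}
#     list(RedisDocumentIndex._generate_item(column_to_data))
#     # -> [{'id': 'a', 'year': 1999}, {'id': 'b', 'year': 2001}]
#
# Tensor-valued columns are assumed to be unwrapped into plain Python lists so
# that each dict can be stored via `client.json().set(prefix + id, '$', item)`.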
diff --git a/tests/index/redis/test_find.py b/tests/index/redis/test_find.py index 93c97369826..911c32d9d2d 100644 --- a/tests/index/redis/test_find.py +++ b/tests/index/redis/test_find.py @@ -89,6 +89,37 @@ def test_find_torch(space): assert torch.allclose(result_docs[0].tens, index_docs[-1].tens) +@pytest.mark.tensorflow +@pytest.mark.parametrize('space', ['cosine', 'l2', 'ip']) +def test_find_tensorflow(space): + from docarray.typing import TensorFlowTensor + + class TfDoc(BaseDoc): + tens: TensorFlowTensor[10] + + db = RedisDocumentIndex[TorchDoc](host='localhost') + + index_docs = [TfDoc(tens=np.random.rand(N_DIM)) for _ in range(10)] + index_docs.append(TfDoc(tens=np.ones(10))) + db.index(index_docs) + + for doc in index_docs: + assert isinstance(doc.tens, TensorFlowTensor) + + query = TfDoc(tens=np.ones(10)) + + result_docs, scores = db.find(query, search_field='tens', limit=5) + + assert len(result_docs) == 5 + assert len(scores) == 5 + for doc in result_docs: + assert isinstance(doc.tens, TensorFlowTensor) + assert result_docs[0].id == index_docs[-1].id + assert np.allclose( + result_docs[0].tens.unwrap().numpy(), index_docs[-1].tens.unwrap().numpy() + ) + + @pytest.mark.parametrize('space', ['cosine', 'l2', 'ip']) def test_find_flat_schema(space): class FlatSchema(BaseDoc): From a7f54c387ba5d3f199c1aceb91fe502b7ea8aae0 Mon Sep 17 00:00:00 2001 From: jupyterjazz Date: Wed, 28 Jun 2023 13:50:25 +0200 Subject: [PATCH 13/28] fix: tensorflow test Signed-off-by: jupyterjazz --- tests/index/redis/test_find.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/index/redis/test_find.py b/tests/index/redis/test_find.py index 911c32d9d2d..ba871d477b6 100644 --- a/tests/index/redis/test_find.py +++ b/tests/index/redis/test_find.py @@ -97,7 +97,7 @@ def test_find_tensorflow(space): class TfDoc(BaseDoc): tens: TensorFlowTensor[10] - db = RedisDocumentIndex[TorchDoc](host='localhost') + db = RedisDocumentIndex[TensorFlowTensor](host='localhost') index_docs = [TfDoc(tens=np.random.rand(N_DIM)) for _ in range(10)] index_docs.append(TfDoc(tens=np.ones(10))) From be9d771168f609dc68dce20543ecbca349cb9abc Mon Sep 17 00:00:00 2001 From: jupyterjazz Date: Wed, 28 Jun 2023 13:59:56 +0200 Subject: [PATCH 14/28] fix: tf tst Signed-off-by: jupyterjazz --- tests/index/redis/test_find.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/index/redis/test_find.py b/tests/index/redis/test_find.py index ba871d477b6..d5fa1d8bb52 100644 --- a/tests/index/redis/test_find.py +++ b/tests/index/redis/test_find.py @@ -97,7 +97,7 @@ def test_find_tensorflow(space): class TfDoc(BaseDoc): tens: TensorFlowTensor[10] - db = RedisDocumentIndex[TensorFlowTensor](host='localhost') + db = RedisDocumentIndex[TfDoc](host='localhost') index_docs = [TfDoc(tens=np.random.rand(N_DIM)) for _ in range(10)] index_docs.append(TfDoc(tens=np.ones(10))) From 4dc99e61520665eb222765886ed4c3e25ca861f8 Mon Sep 17 00:00:00 2001 From: jupyterjazz Date: Wed, 28 Jun 2023 16:45:16 +0200 Subject: [PATCH 15/28] refactor: reduce ignore types Signed-off-by: jupyterjazz --- docarray/index/backends/redis.py | 79 ++++++++++++------------ tests/index/redis/test_configurations.py | 2 +- tests/index/redis/test_persist_data.py | 2 +- 3 files changed, 43 insertions(+), 40 deletions(-) diff --git a/docarray/index/backends/redis.py b/docarray/index/backends/redis.py index e1a69dcd153..cd325ccf676 100644 --- a/docarray/index/backends/redis.py +++ b/docarray/index/backends/redis.py @@ -74,9 +74,12 @@ def __init__(self, 
db_config=None, **kwargs): super().__init__(db_config=db_config, **kwargs) self._db_config = cast(RedisDocumentIndex.DBConfig, self._db_config) - if not self._db_config.index_name: - self._db_config.index_name = 'index_name__' + self._random_name() - self._prefix = self._db_config.index_name + ':' # type: ignore[operator] + self._index_name = ( + self._db_config.index_name + if self._db_config.index_name + else 'index_name__' + self._random_name() + ) + self._prefix = self._index_name + ':' # initialize Redis client self._client = redis.Redis( @@ -96,7 +99,7 @@ def _random_name() -> str: def _create_index(self) -> None: """Create a new index in the Redis database if it doesn't already exist.""" - if not self._check_index_exists(self._db_config.index_name): # type: ignore[arg-type] + if not self._check_index_exists(self._index_name): schema = [] for column, info in self._column_infos.items(): if info.db_type == VectorField: @@ -137,18 +140,16 @@ def _create_index(self) -> None: schema.append(info.db_type('$.' + column, as_name=column)) # Create Redis Index - self._client.ft(self._db_config.index_name).create_index( # type: ignore[arg-type] + self._client.ft(self._index_name).create_index( schema, definition=IndexDefinition( prefix=[self._prefix], index_type=IndexType.JSON ), ) - self._logger.info(f'index {self._db_config.index_name} has been created') + self._logger.info(f'index {self._index_name} has been created') else: - self._logger.info( - f'connected to existing {self._db_config.index_name} index' - ) + self._logger.info(f'connected to existing {self._index_name} index') def _check_index_exists(self, index_name: str) -> bool: """ @@ -232,13 +233,7 @@ def __post_init__(self): class RuntimeConfig(BaseDocIndex.RuntimeConfig): """Dataclass that contains all "dynamic" configurations of RedisDocumentIndex.""" - default_column_config: Dict[Any, Dict[str, Any]] = field( - default_factory=lambda: { - TextField: {}, - NumericField: {}, - VectorField: {}, - } - ) + pass def python_type_to_db_type(self, python_type: Type) -> Any: """ @@ -264,8 +259,9 @@ def python_type_to_db_type(self, python_type: Type) -> Any: @staticmethod def _generate_item( - column_to_data: Dict[str, Generator[Any, None, None]] - ) -> Iterator[Dict[str, Any]]: + column_to_data: Dict[str, Generator[Any, None, None]], + batch_size: int = 4, + ) -> Iterator[List[Dict[str, Any]]]: """ Given a dictionary of generators, yield a dictionary where each item consists of a key and a single item from the corresponding generator. @@ -278,12 +274,15 @@ def _generate_item( """ keys = list(column_to_data.keys()) iterators = [iter(column_to_data[key]) for key in keys] + batch = [] + while True: item_dict = {} for key, it in zip(keys, iterators): item = next(it, None) if key == 'id' and not item: + yield batch return if isinstance(item, AbstractTensor): @@ -293,7 +292,10 @@ def _generate_item( elif item is not None: item_dict[key] = item - yield item_dict + batch.append(item_dict) + if len(batch) == batch_size: + yield batch + batch = [] def _index( self, column_to_data: Dict[str, Generator[Any, None, None]] @@ -305,10 +307,16 @@ def _index( :return: A list of document ids that have been indexed. 
""" ids = [] - for item in self._generate_item(column_to_data): - ids.append(item['id']) - doc_id = self._prefix + item['id'] - self._client.json().set(doc_id, '$', item) + for items in self._generate_item(column_to_data): + for item in items: + ids.append(item['id']) + doc_id = self._prefix + item['id'] + self._client.json().set(doc_id, '$', item) + + ## this does not work for now + # for items in self._generate_item(column_to_data): + # self._client.json().mset(((self._prefix + item['id'], '$', item) for item in items)) + return ids def num_docs(self) -> int: @@ -317,7 +325,7 @@ def num_docs(self) -> int: :return: Number of documents in the index. """ - num_docs = self._client.ft(self._db_config.index_name).info()['num_docs'] # type: ignore[arg-type] + num_docs = self._client.ft(self._index_name).info()['num_docs'] return int(num_docs) def _del_items(self, doc_ids: Sequence[str]) -> None: @@ -351,13 +359,10 @@ def _get_items( if not doc_ids: return [] - docs = [] - for id in doc_ids: - doc = self._client.json().get(self._prefix + id) - if doc: - docs.append(doc) - - if len(docs) == 0: + ids = [self._prefix + id for id in doc_ids] + docs = self._client.json().mget(ids, '$') + docs = [doc[0] for doc in docs if doc] + if not docs: raise KeyError(f'No document with id {doc_ids} found') return docs @@ -418,13 +423,11 @@ def _hybrid_search( .paging(0, limit) .dialect(2) ) - query_params: Mapping[str, str] = { # type: ignore - 'vec': np.array(query, dtype=np.float32).tobytes() # type: ignore + query_params: Mapping[str, str] = { + 'vec': np.array(query, dtype=np.float32).tobytes() # type: ignore[dict-item] } results = ( - self._client.ft(self._db_config.index_name) # type: ignore[arg-type] - .search(redis_query, query_params) - .docs + self._client.ft(self._index_name).search(redis_query, query_params).docs ) scores: NdArray = NdArray._docarray_from_native( @@ -482,7 +485,7 @@ def _filter(self, filter_query: Any, limit: int) -> Union[DocList, List[Dict]]: q = Query(filter_query) q.paging(0, limit) - results = self._client.ft(index_name=self._db_config.index_name).search(q).docs # type: ignore[arg-type] + results = self._client.ft(index_name=self._index_name).search(q).docs docs = [json.loads(doc.json) for doc in results] return docs @@ -520,7 +523,7 @@ def _text_search( .paging(0, limit) ) - results = self._client.ft(index_name=self._db_config.index_name).search(q).docs # type: ignore[arg-type] + results = self._client.ft(index_name=self._index_name).search(q).docs scores: NdArray = NdArray._docarray_from_native( np.array([document['score'] for document in results]) diff --git a/tests/index/redis/test_configurations.py b/tests/index/redis/test_configurations.py index 6cd7ae8f3d6..163d196e844 100644 --- a/tests/index/redis/test_configurations.py +++ b/tests/index/redis/test_configurations.py @@ -32,7 +32,7 @@ class Schema(BaseDoc): types = {'id': 'TEXT', 'tens': 'VECTOR', 'title': 'TEXT', 'year': 'NUMERIC'} index = RedisDocumentIndex[Schema](host='localhost') - attr_bytes = index._client.ft(index._db_config.index_name).info()['attributes'] + attr_bytes = index._client.ft(index._index_name).info()['attributes'] attr = [[byte.decode() for byte in sublist] for sublist in attr_bytes] assert len(Schema.__fields__) == len(attr) diff --git a/tests/index/redis/test_persist_data.py b/tests/index/redis/test_persist_data.py index 3898e1827fa..95e8ae7aab2 100644 --- a/tests/index/redis/test_persist_data.py +++ b/tests/index/redis/test_persist_data.py @@ -20,7 +20,7 @@ def test_persist(): # create index 
index = RedisDocumentIndex[SimpleDoc](host='localhost') - index_name = index._db_config.index_name + index_name = index._index_name assert index.num_docs() == 0 From d2718b44a1089bdd11fe177cac39c6fefb070a9f Mon Sep 17 00:00:00 2001 From: jupyterjazz Date: Wed, 28 Jun 2023 16:49:43 +0200 Subject: [PATCH 16/28] style: remove other type ignores Signed-off-by: jupyterjazz --- docarray/index/backends/redis.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docarray/index/backends/redis.py b/docarray/index/backends/redis.py index cd325ccf676..5d7b4b3632e 100644 --- a/docarray/index/backends/redis.py +++ b/docarray/index/backends/redis.py @@ -80,7 +80,7 @@ def __init__(self, db_config=None, **kwargs): else 'index_name__' + self._random_name() ) self._prefix = self._index_name + ':' - + self._text_scorer = self._db_config.text_scorer # initialize Redis client self._client = redis.Redis( host=self._db_config.host, @@ -104,12 +104,12 @@ def _create_index(self) -> None: for column, info in self._column_infos.items(): if info.db_type == VectorField: space = info.config.get('space') or info.config.get('distance') - space = space.upper() # type: ignore[union-attr] - if space not in VALID_DISTANCES: + if not space or space.upper() not in VALID_DISTANCES: raise ValueError( f"Invalid distance metric '{space}' provided. " f"Must be one of: {', '.join(VALID_DISTANCES)}" ) + space = space.upper() attributes = { 'TYPE': 'FLOAT32', 'DIM': info.n_dim or info.config.get('dim'), @@ -274,7 +274,7 @@ def _generate_item( """ keys = list(column_to_data.keys()) iterators = [iter(column_to_data[key]) for key in keys] - batch = [] + batch: List[Dict[str, Any]] = [] while True: item_dict = {} @@ -518,7 +518,7 @@ def _text_search( query_str = '|'.join(query.split(' ')) q = ( Query(f'@{search_field}:{query_str}') - .scorer(self._db_config.text_scorer) # type: ignore[union-attr] + .scorer(self._text_scorer) .with_scores() .paging(0, limit) ) From 9f58474fc4f6674d14e90f93d69f65fa89b82c5f Mon Sep 17 00:00:00 2001 From: jupyterjazz Date: Wed, 28 Jun 2023 16:51:09 +0200 Subject: [PATCH 17/28] style: try removing import ignores Signed-off-by: jupyterjazz --- docarray/index/backends/redis.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docarray/index/backends/redis.py b/docarray/index/backends/redis.py index 5d7b4b3632e..710df1ef733 100644 --- a/docarray/index/backends/redis.py +++ b/docarray/index/backends/redis.py @@ -37,12 +37,12 @@ if TYPE_CHECKING: import redis from redis.commands.search.query import Query - from redis.commands.search.field import ( # type: ignore[import] + from redis.commands.search.field import ( NumericField, TextField, VectorField, ) - from redis.commands.search.indexDefinition import IndexDefinition, IndexType # type: ignore[import] + from redis.commands.search.indexDefinition import IndexDefinition, IndexType else: redis = import_library('redis') From 7e791dac79533ed70d463ea3a73944b55df09589 Mon Sep 17 00:00:00 2001 From: jupyterjazz Date: Wed, 28 Jun 2023 16:56:39 +0200 Subject: [PATCH 18/28] style: i think mypy hates me Signed-off-by: jupyterjazz --- docarray/index/backends/redis.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docarray/index/backends/redis.py b/docarray/index/backends/redis.py index 710df1ef733..5d7b4b3632e 100644 --- a/docarray/index/backends/redis.py +++ b/docarray/index/backends/redis.py @@ -37,12 +37,12 @@ if TYPE_CHECKING: import redis from redis.commands.search.query import Query - from 
redis.commands.search.field import ( + from redis.commands.search.field import ( # type: ignore[import] NumericField, TextField, VectorField, ) - from redis.commands.search.indexDefinition import IndexDefinition, IndexType + from redis.commands.search.indexDefinition import IndexDefinition, IndexType # type: ignore[import] else: redis = import_library('redis') From 8ed3dbbebff238cb583cd4b9881a18b6d2b8fb7a Mon Sep 17 00:00:00 2001 From: jupyterjazz Date: Thu, 29 Jun 2023 12:03:16 +0200 Subject: [PATCH 19/28] feat: batch indexing Signed-off-by: jupyterjazz --- docarray/index/backends/redis.py | 69 ++++++++++++++++-------- tests/index/redis/fixtures.py | 2 +- tests/index/redis/test_configurations.py | 13 ++++- tests/index/redis/test_index_get_del.py | 11 ++++ 4 files changed, 71 insertions(+), 24 deletions(-) diff --git a/docarray/index/backends/redis.py b/docarray/index/backends/redis.py index 5d7b4b3632e..b52dbca981d 100644 --- a/docarray/index/backends/redis.py +++ b/docarray/index/backends/redis.py @@ -31,6 +31,7 @@ ) from docarray.typing import NdArray from docarray.typing.tensor.abstract_tensor import AbstractTensor +from docarray.utils._internal._typing import safe_issubclass from docarray.utils._internal.misc import import_library from docarray.utils.find import _FindResultBatched, _FindResult, FindResult @@ -74,6 +75,10 @@ def __init__(self, db_config=None, **kwargs): super().__init__(db_config=db_config, **kwargs) self._db_config = cast(RedisDocumentIndex.DBConfig, self._db_config) + self._runtime_config: RedisDocumentIndex.RuntimeConfig = cast( + RedisDocumentIndex.RuntimeConfig, self._runtime_config + ) + self._index_name = ( self._db_config.index_name if self._db_config.index_name @@ -231,9 +236,12 @@ def __post_init__(self): @dataclass class RuntimeConfig(BaseDocIndex.RuntimeConfig): - """Dataclass that contains all "dynamic" configurations of RedisDocumentIndex.""" + """Dataclass that contains all "dynamic" configurations of RedisDocumentIndex. + + :param batch_size: Batch size during indexing. + """ - pass + batch_size: int = 100 def python_type_to_db_type(self, python_type: Type) -> Any: """ @@ -258,19 +266,21 @@ def python_type_to_db_type(self, python_type: Type) -> Any: raise ValueError(f'Unsupported column type for {type(self)}: {python_type}') @staticmethod - def _generate_item( + def _generate_items( column_to_data: Dict[str, Generator[Any, None, None]], - batch_size: int = 4, + batch_size: int, ) -> Iterator[List[Dict[str, Any]]]: """ - Given a dictionary of generators, yield a dictionary where each item consists of a key and - a single item from the corresponding generator. + Given a dictionary of generators, yield a list of dictionaries where each + item consists of a key and a single item from the corresponding generator. :param column_to_data: A dictionary where each key is a column and each value is a generator. + :param batch_size: Size of batch to generate each time. - :yield: A dictionary where each item consists of a column name and an item from - the corresponding generator. Yields until all generators are exhausted. + :yield: A list of dictionaries where each item consists of a column name and + an item from the corresponding generator. Yields until all generators + are exhausted. 
""" keys = list(column_to_data.keys()) iterators = [iter(column_to_data[key]) for key in keys] @@ -282,7 +292,8 @@ def _generate_item( item = next(it, None) if key == 'id' and not item: - yield batch + if batch: + yield batch return if isinstance(item, AbstractTensor): @@ -306,16 +317,15 @@ def _index( :param column_to_data: A dictionary where each key is a column and each value is a generator. :return: A list of document ids that have been indexed. """ - ids = [] - for items in self._generate_item(column_to_data): - for item in items: - ids.append(item['id']) - doc_id = self._prefix + item['id'] - self._client.json().set(doc_id, '$', item) - - ## this does not work for now - # for items in self._generate_item(column_to_data): - # self._client.json().mset(((self._prefix + item['id'], '$', item) for item in items)) + ids: List[str] = [] + for items in self._generate_items( + column_to_data, self._runtime_config.batch_size + ): + doc_id_item_pairs = [ + (self._prefix + item['id'], '$', item) for item in items + ] + ids.extend(doc_id for doc_id, _, _ in doc_id_item_pairs) + self._client.json().mset(doc_id_item_pairs) # type: ignore[attr-defined] return ids @@ -423,11 +433,11 @@ def _hybrid_search( .paging(0, limit) .dialect(2) ) - query_params: Mapping[str, str] = { - 'vec': np.array(query, dtype=np.float32).tobytes() # type: ignore[dict-item] + query_params: Mapping[str, bytes] = { + 'vec': np.array(query, dtype=np.float32).tobytes() } results = ( - self._client.ft(self._index_name).search(redis_query, query_params).docs + self._client.ft(self._index_name).search(redis_query, query_params).docs # type: ignore[arg-type] ) scores: NdArray = NdArray._docarray_from_native( @@ -553,3 +563,18 @@ def _text_search_batched( scores.append(results.scores) return _FindResultBatched(documents=docs, scores=scores) + + def __contains__(self, item: BaseDoc) -> bool: + """ + Checks if a given document exists in the index. + + :param item: The document to check. + It must be an instance of BaseDoc or its subclass. + :return: True if the document exists in the index, False otherwise. 
+ """ + if safe_issubclass(type(item), BaseDoc): + return self._doc_exists(item.id) + else: + raise TypeError( + f"item must be an instance of BaseDoc or its subclass, not '{type(item).__name__}'" + ) diff --git a/tests/index/redis/fixtures.py b/tests/index/redis/fixtures.py index d4db7375a3f..0c97f4f0dc9 100644 --- a/tests/index/redis/fixtures.py +++ b/tests/index/redis/fixtures.py @@ -8,7 +8,7 @@ @pytest.fixture(scope='session', autouse=True) def start_redis(): os.system( - 'docker run -d --name redis-stack-server -p 6379:6379 redis/redis-stack-server:latest' + 'docker run --name redis-stack-server -p 6379:6379 -d redis/redis-stack-server:7.2.0-RC2' ) time.sleep(1) diff --git a/tests/index/redis/test_configurations.py b/tests/index/redis/test_configurations.py index 163d196e844..ac3c1ddaf9f 100644 --- a/tests/index/redis/test_configurations.py +++ b/tests/index/redis/test_configurations.py @@ -23,7 +23,7 @@ class Schema(BaseDoc): assert index.num_docs() == 10 -def test_configure_index(tmp_path): +def test_configure_index(): class Schema(BaseDoc): tens: NdArray[100] = Field(space='cosine') title: str @@ -38,3 +38,14 @@ class Schema(BaseDoc): assert len(Schema.__fields__) == len(attr) for field, attr in zip(Schema.__fields__, attr): assert field in attr and types[field] in attr + + +def test_runtime_config(): + class Schema(BaseDoc): + tens: NdArray = Field(dim=10) + + index = RedisDocumentIndex[Schema](host='localhost') + assert index._runtime_config.batch_size == 100 + + index.configure(batch_size=10) + assert index._runtime_config.batch_size == 10 diff --git a/tests/index/redis/test_index_get_del.py b/tests/index/redis/test_index_get_del.py index dd6cc64433d..fbe199a1bd5 100644 --- a/tests/index/redis/test_index_get_del.py +++ b/tests/index/redis/test_index_get_del.py @@ -99,3 +99,14 @@ def test_del_multiple(ten_simple_docs): else: assert index[doc.id].id == doc.id assert np.allclose(index[doc.id].tens, doc.tens) + + +def test_contains(ten_simple_docs): + index = RedisDocumentIndex[SimpleDoc](host='localhost') + index.index(ten_simple_docs) + + for doc in ten_simple_docs: + assert doc in index + + other_doc = SimpleDoc(tens=np.random.randn(10)) + assert other_doc not in index From 83afb5f51b79df676555ff21d8eba1b95161765c Mon Sep 17 00:00:00 2001 From: jupyterjazz Date: Thu, 29 Jun 2023 13:09:38 +0200 Subject: [PATCH 20/28] chore: bump redis version Signed-off-by: jupyterjazz --- poetry.lock | 8 ++++---- pyproject.toml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/poetry.lock b/poetry.lock index c5543885141..615be2579c9 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3631,14 +3631,14 @@ urllib3 = ">=1.26.14,<2.0.0" [[package]] name = "redis" -version = "4.5.5" +version = "4.6.0" description = "Python client for Redis database and key-value store" category = "main" optional = true python-versions = ">=3.7" files = [ - {file = "redis-4.5.5-py3-none-any.whl", hash = "sha256:77929bc7f5dab9adf3acba2d3bb7d7658f1e0c2f1cafe7eb36434e751c471119"}, - {file = "redis-4.5.5.tar.gz", hash = "sha256:dc87a0bdef6c8bfe1ef1e1c40be7034390c2ae02d92dcd0c7ca1729443899880"}, + {file = "redis-4.6.0-py3-none-any.whl", hash = "sha256:e2b03db868160ee4591de3cb90d40ebb50a90dd302138775937f6a42b7ed183c"}, + {file = "redis-4.6.0.tar.gz", hash = "sha256:585dc516b9eb042a619ef0a39c3d7d55fe81bdb4df09a52c9cdde0d07bf1aa7d"}, ] [package.dependencies] @@ -4900,4 +4900,4 @@ web = ["fastapi"] [metadata] lock-version = "2.0" python-versions = ">=3.8,<4.0" -content-hash = 
"145f974cefae9a7b73b729acb0f46f005c233cccbcf03fe17f0ad783ce0496f1" +content-hash = "495897558a14972f5d1542e0fd2e9dc48cb7ac7e435340c0673a8e6fb4fbc669" diff --git a/pyproject.toml b/pyproject.toml index d22c1b637dd..c00944a7f68 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,7 +58,7 @@ smart-open = {version = ">=6.3.0", extras = ["s3"], optional = true} jina-hubble-sdk = {version = ">=0.34.0", optional = true} elastic-transport = {version ="^8.4.0", optional = true } qdrant-client = {version = ">=1.1.4", python = "<3.12", optional = true } -redis = {version = "^4.5.5", optional = true} +redis = {version = "^4.6.0", optional = true} [tool.poetry.extras] proto = ["protobuf", "lz4"] From 4b0bc73f8fc1d199fb81eec6def303e16cbc4440 Mon Sep 17 00:00:00 2001 From: jupyterjazz Date: Sun, 2 Jul 2023 23:28:56 +0400 Subject: [PATCH 21/28] feat: subindex not fully finished Signed-off-by: jupyterjazz --- docarray/index/backends/redis.py | 40 ++++++-- tests/index/redis/test_subindex.py | 159 +++++++++++++++++++++++++++++ 2 files changed, 192 insertions(+), 7 deletions(-) create mode 100644 tests/index/redis/test_subindex.py diff --git a/docarray/index/backends/redis.py b/docarray/index/backends/redis.py index b52dbca981d..1b7931714c5 100644 --- a/docarray/index/backends/redis.py +++ b/docarray/index/backends/redis.py @@ -23,6 +23,7 @@ import numpy as np from numpy import ndarray +from docarray.array import AnyDocArray from docarray.index.backends.helper import _collect_query_args from docarray import BaseDoc, DocList from docarray.index.abstract import ( @@ -72,18 +73,20 @@ class RedisDocumentIndex(BaseDocIndex, Generic[TSchema]): def __init__(self, db_config=None, **kwargs): """Initialize RedisDocumentIndex""" + if db_config is not None and getattr(db_config, 'index_name'): + self._index_name = db_config.index_name + elif kwargs.get('index_name'): + self._index_name = kwargs.get('index_name') + else: + self._index_name = 'index_name__' + self._random_name() + + super().__init__(db_config=db_config, **kwargs) self._db_config = cast(RedisDocumentIndex.DBConfig, self._db_config) self._runtime_config: RedisDocumentIndex.RuntimeConfig = cast( RedisDocumentIndex.RuntimeConfig, self._runtime_config ) - - self._index_name = ( - self._db_config.index_name - if self._db_config.index_name - else 'index_name__' + self._random_name() - ) self._prefix = self._index_name + ':' self._text_scorer = self._db_config.text_scorer # initialize Redis client @@ -107,7 +110,9 @@ def _create_index(self) -> None: if not self._check_index_exists(self._index_name): schema = [] for column, info in self._column_infos.items(): - if info.db_type == VectorField: + if issubclass(info.docarray_type, AnyDocArray): + continue + elif info.db_type == VectorField: space = info.config.get('space') or info.config.get('distance') if not space or space.upper() not in VALID_DISTANCES: raise ValueError( @@ -171,6 +176,17 @@ def _check_index_exists(self, index_name: str) -> bool: self._logger.info(f'Index {index_name} already exists') return True + @property + def index_name(self): + return self._index_name + + @property + def out_schema(self) -> Type[BaseDoc]: + """Return the real schema of the index.""" + if self._is_subindex: + return self._ori_schema + return cast(Type[BaseDoc], self._schema) + class QueryBuilder(BaseDocIndex.QueryBuilder): def __init__(self, query: Optional[List[Tuple[str, Dict]]] = None): super().__init__() @@ -317,6 +333,7 @@ def _index( :param column_to_data: A dictionary where each key is a column and each value is a 
generator. :return: A list of document ids that have been indexed. """ + self._index_subindex(column_to_data) ids: List[str] = [] for items in self._generate_items( column_to_data, self._runtime_config.batch_size @@ -514,6 +531,15 @@ def _filter_batched( results.append(self._filter(filter_query=query, limit=limit)) return results + def _filter_by_parent_id(self, id: str) -> Optional[List[str]]: + """Filter the ids of the subindex documents given id of root document. + + :param id: the root document id to filter by + :return: a list of ids of the subindex documents + """ + docs = self._filter(filter_query=f'@parent_id:"{id}"', limit=self.num_docs()) + return [doc['id'] for doc in docs] + def _text_search( self, query: str, limit: int, search_field: str = '' ) -> _FindResult: diff --git a/tests/index/redis/test_subindex.py b/tests/index/redis/test_subindex.py new file mode 100644 index 00000000000..ecfc709f230 --- /dev/null +++ b/tests/index/redis/test_subindex.py @@ -0,0 +1,159 @@ +import numpy as np +import pytest +from pydantic import Field + +from docarray import BaseDoc, DocList +from docarray.index import RedisDocumentIndex +from docarray.typing import NdArray +from tests.index.redis.fixtures import start_redis # noqa: F401 + + +pytestmark = [pytest.mark.slow, pytest.mark.index] + + +class SimpleDoc(BaseDoc): + simple_tens: NdArray[10] = Field(space='l2') + simple_text: str + + +class ListDoc(BaseDoc): + docs: DocList[SimpleDoc] + simple_doc: SimpleDoc + list_tens: NdArray[20] = Field(space='l2') + + +class MyDoc(BaseDoc): + docs: DocList[SimpleDoc] + list_docs: DocList[ListDoc] + my_tens: NdArray[30] = Field(space='l2') + + +@pytest.fixture(scope='session') +def index(): + index = RedisDocumentIndex[MyDoc](host='localhost') + return index + + +@pytest.fixture(scope='session') +def data(): + my_docs = [ + MyDoc( + id=f'{i}', + docs=DocList[SimpleDoc]( + [ + SimpleDoc( + id=f'docs_{i}_{j}', + simple_tens=np.ones(10) * (j + 1), + simple_text=f'hello {j}', + ) + for j in range(5) + ] + ), + list_docs=DocList[ListDoc]( + [ + ListDoc( + id=f'list_docs_{i}_{j}', + docs=DocList[SimpleDoc]( + [ + SimpleDoc( + id=f'list_docs_docs_{i}_{j}_{k}', + simple_tens=np.ones(10) * (k + 1), + simple_text=f'hello {k}', + ) + for k in range(5) + ] + ), + simple_doc=SimpleDoc( + id=f'list_docs_simple_doc_{i}_{j}', + simple_tens=np.ones(10) * (j + 1), + simple_text=f'hello {j}', + ), + list_tens=np.ones(20) * (j + 1), + ) + for j in range(5) + ] + ), + my_tens=np.ones((30,)) * (i + 1), + ) + for i in range(5) + ] + return my_docs + +def test_subindex_init(index): + assert isinstance(index._subindices['docs'], RedisDocumentIndex) + assert isinstance(index._subindices['list_docs'], RedisDocumentIndex) + assert isinstance( + index._subindices['list_docs']._subindices['docs'], RedisDocumentIndex + ) + +def test_subindex_index(index, data): + index.index(data) + assert index.num_docs() == 5 + assert index._subindices['docs'].num_docs() == 25 + assert index._subindices['list_docs'].num_docs() == 25 + assert index._subindices['list_docs']._subindices['docs'].num_docs() == 125 + +def test_subindex_get(index, data): + index.index(data) + doc = index['1'] + assert type(doc) == MyDoc + assert doc.id == '1' + assert len(doc.docs) == 5 + assert type(doc.docs[0]) == SimpleDoc + assert doc.docs[0].id == 'docs_1_0' + assert np.allclose(doc.docs[0].simple_tens, np.ones(10)) + + assert len(doc.list_docs) == 5 + assert type(doc.list_docs[0]) == ListDoc + assert doc.list_docs[0].id == 'list_docs_1_0' + assert 
len(doc.list_docs[0].docs) == 5 + assert type(doc.list_docs[0].docs[0]) == SimpleDoc + assert doc.list_docs[0].docs[0].id == 'list_docs_docs_1_0_0' + assert np.allclose(doc.list_docs[0].docs[0].simple_tens, np.ones(10)) + assert doc.list_docs[0].docs[0].simple_text == 'hello 0' + assert type(doc.list_docs[0].simple_doc) == SimpleDoc + assert doc.list_docs[0].simple_doc.id == 'list_docs_simple_doc_1_0' + assert np.allclose(doc.list_docs[0].simple_doc.simple_tens, np.ones(10)) + assert doc.list_docs[0].simple_doc.simple_text == 'hello 0' + assert np.allclose(doc.list_docs[0].list_tens, np.ones(20)) + + assert np.allclose(doc.my_tens, np.ones(30) * 2) + + +def test_subindex_del(index, data): + index.index(data) + del index['0'] + assert index.num_docs() == 4 + assert index._subindices['docs'].num_docs() == 20 + assert index._subindices['list_docs'].num_docs() == 20 + assert index._subindices['list_docs']._subindices['docs'].num_docs() == 100 + + +def test_subindex_contain(index, data): + index.index(data) + # Checks for individual simple_docs within list_docs + for i in range(4): + doc = index[f'{i + 1}'] + for simple_doc in doc.list_docs: + assert index.subindex_contains(simple_doc) + for nested_doc in simple_doc.docs: + assert index.subindex_contains(nested_doc) + + invalid_doc = SimpleDoc( + id='non_existent', + simple_tens=np.zeros(10), + simple_text='invalid', + ) + assert not index.subindex_contains(invalid_doc) + + # Checks for an empty doc + empty_doc = SimpleDoc( + id='', + simple_tens=np.zeros(10), + simple_text='', + ) + assert not index.subindex_contains(empty_doc) + + # Empty index + empty_index = RedisDocumentIndex[MyDoc](host='localhost') + assert empty_doc not in empty_index From 0b15adc5fe2dea5410123203ecc2ae2fd3ce9497 Mon Sep 17 00:00:00 2001 From: jupyterjazz Date: Wed, 5 Jul 2023 08:58:06 +0400 Subject: [PATCH 22/28] feat: finalize subindex Signed-off-by: jupyterjazz --- docarray/index/backends/redis.py | 7 +++-- tests/index/redis/test_configurations.py | 2 +- tests/index/redis/test_subindex.py | 36 ++++++++++++++++++++++++ 3 files changed, 42 insertions(+), 3 deletions(-) diff --git a/docarray/index/backends/redis.py b/docarray/index/backends/redis.py index 1b7931714c5..079d2fb4623 100644 --- a/docarray/index/backends/redis.py +++ b/docarray/index/backends/redis.py @@ -43,6 +43,7 @@ NumericField, TextField, VectorField, + TagField, ) from redis.commands.search.indexDefinition import IndexDefinition, IndexType # type: ignore[import] else: @@ -52,6 +53,7 @@ NumericField, TextField, VectorField, + TagField, ) from redis.commands.search.indexDefinition import IndexDefinition, IndexType from redis.commands.search.query import Query @@ -80,7 +82,6 @@ def __init__(self, db_config=None, **kwargs): else: self._index_name = 'index_name__' + self._random_name() - super().__init__(db_config=db_config, **kwargs) self._db_config = cast(RedisDocumentIndex.DBConfig, self._db_config) @@ -146,6 +147,8 @@ def _create_index(self) -> None: as_name=column, ) ) + elif column in ['id', 'parent_id']: + schema.append(TagField('$.' + column, as_name=column)) else: schema.append(info.db_type('$.' 
+ column, as_name=column)) @@ -537,7 +540,7 @@ def _filter_by_parent_id(self, id: str) -> Optional[List[str]]: :param id: the root document id to filter by :return: a list of ids of the subindex documents """ - docs = self._filter(filter_query=f'@parent_id:"{id}"', limit=self.num_docs()) + docs = self._filter(filter_query=f'@parent_id:{{{id}}}', limit=self.num_docs()) return [doc['id'] for doc in docs] def _text_search( diff --git a/tests/index/redis/test_configurations.py b/tests/index/redis/test_configurations.py index ac3c1ddaf9f..d8de2649091 100644 --- a/tests/index/redis/test_configurations.py +++ b/tests/index/redis/test_configurations.py @@ -29,7 +29,7 @@ class Schema(BaseDoc): title: str year: int - types = {'id': 'TEXT', 'tens': 'VECTOR', 'title': 'TEXT', 'year': 'NUMERIC'} + types = {'id': 'TAG', 'tens': 'VECTOR', 'title': 'TEXT', 'year': 'NUMERIC'} index = RedisDocumentIndex[Schema](host='localhost') attr_bytes = index._client.ft(index._index_name).info()['attributes'] diff --git a/tests/index/redis/test_subindex.py b/tests/index/redis/test_subindex.py index ecfc709f230..c873ea00b00 100644 --- a/tests/index/redis/test_subindex.py +++ b/tests/index/redis/test_subindex.py @@ -79,6 +79,7 @@ def data(): ] return my_docs + def test_subindex_init(index): assert isinstance(index._subindices['docs'], RedisDocumentIndex) assert isinstance(index._subindices['list_docs'], RedisDocumentIndex) @@ -86,6 +87,7 @@ def test_subindex_init(index): index._subindices['list_docs']._subindices['docs'], RedisDocumentIndex ) + def test_subindex_index(index, data): index.index(data) assert index.num_docs() == 5 @@ -93,6 +95,7 @@ def test_subindex_index(index, data): assert index._subindices['list_docs'].num_docs() == 25 assert index._subindices['list_docs']._subindices['docs'].num_docs() == 125 + def test_subindex_get(index, data): index.index(data) doc = index['1'] @@ -157,3 +160,36 @@ def test_subindex_contain(index, data): # Empty index empty_index = RedisDocumentIndex[MyDoc](host='localhost') assert empty_doc not in empty_index + + +def test_find_subindex(index, data): + index.index(data) + # root level + query = np.ones((30,)) + with pytest.raises(ValueError): + _, _ = index.find_subindex(query, subindex='', search_field='my_tens', limit=5) + + # sub level + query = np.ones((10,)) + root_docs, docs, scores = index.find_subindex( + query, subindex='docs', search_field='simple_tens', limit=5 + ) + assert type(root_docs[0]) == MyDoc + assert type(docs[0]) == SimpleDoc + assert len(scores) == 5 + for root_doc, doc in zip(root_docs, docs): + assert np.allclose(doc.simple_tens, np.ones(10)) + assert root_doc.id == f'{doc.id.split("_")[-2]}' + + # sub sub level + query = np.ones((10,)) + root_docs, docs, scores = index.find_subindex( + query, subindex='list_docs__docs', search_field='simple_tens', limit=5 + ) + assert len(docs) == 5 + assert len(scores) == 5 + assert type(root_docs[0]) == MyDoc + assert type(docs[0]) == SimpleDoc + for root_doc, doc in zip(root_docs, docs): + assert np.allclose(doc.simple_tens, np.ones(10)) + assert root_doc.id == f'{doc.id.split("_")[-3]}' From e8fbc4795cf170c806cbea81a9bf61069418d9bf Mon Sep 17 00:00:00 2001 From: jupyterjazz Date: Wed, 5 Jul 2023 09:08:04 +0400 Subject: [PATCH 23/28] docs: update readme Signed-off-by: jupyterjazz --- README.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 7aed194b9b3..3b7d6330c7d 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,7 @@ DocArray is a Python library expertly crafted 
for the [representation](#represen - :fire: Offers native support for **[NumPy](https://github.com/numpy/numpy)**, **[PyTorch](https://github.com/pytorch/pytorch)**, and **[TensorFlow](https://github.com/tensorflow/tensorflow)**, catering specifically to **model training scenarios**. - :zap: Based on **[Pydantic](https://github.com/pydantic/pydantic)**, and instantly compatible with web and microservice frameworks like **[FastAPI](https://github.com/tiangolo/fastapi/)** and **[Jina](https://github.com/jina-ai/jina/)**. -- :package: Provides support for vector databases such as **[Weaviate](https://weaviate.io/), [Qdrant](https://qdrant.tech/), [ElasticSearch](https://www.elastic.co/de/elasticsearch/)**, and **[HNSWLib](https://github.com/nmslib/hnswlib)**. +- :package: Provides support for vector databases such as **[Weaviate](https://weaviate.io/), [Qdrant](https://qdrant.tech/), [ElasticSearch](https://www.elastic.co/de/elasticsearch/), [Redis](https://redis.io/)**, and **[HNSWLib](https://github.com/nmslib/hnswlib)**. - :chains: Allows data transmission as JSON over **HTTP** or as **[Protobuf](https://protobuf.dev/)** over **[gRPC](https://grpc.io/)**. ## Installation @@ -349,7 +349,7 @@ This is useful for: - :mag: **Neural search** applications - :bulb: **Recommender systems** -Currently, Document Indexes support **[Weaviate](https://weaviate.io/)**, **[Qdrant](https://qdrant.tech/)**, **[ElasticSearch](https://www.elastic.co/)**, and **[HNSWLib](https://github.com/nmslib/hnswlib)**, with more to come! +Currently, Document Indexes support **[Weaviate](https://weaviate.io/)**, **[Qdrant](https://qdrant.tech/)**, **[ElasticSearch](https://www.elastic.co/)**, **[Redis](https://redis.io/)**, and **[HNSWLib](https://github.com/nmslib/hnswlib)**, with more to come! The Document Index interface lets you index and retrieve Documents from multiple vector databases, all with the same user interface. @@ -421,7 +421,7 @@ They are now called **Document Indexes** and offer the following improvements (s - **Production-ready:** The new Document Indexes are a much thinner wrapper around the various vector DB libraries, making them more robust and easier to maintain - **Increased flexibility:** We strive to support any configuration or setting that you could perform through the DB's first-party client -For now, Document Indexes support **[Weaviate](https://weaviate.io/)**, **[Qdrant](https://qdrant.tech/)**, **[ElasticSearch](https://www.elastic.co/)**, and **[HNSWLib](https://github.com/nmslib/hnswlib)**, with more to come. +For now, Document Indexes support **[Weaviate](https://weaviate.io/)**, **[Qdrant](https://qdrant.tech/)**, **[ElasticSearch](https://www.elastic.co/)**, **[Redis](https://redis.io/)**, and **[HNSWLib](https://github.com/nmslib/hnswlib)**, with more to come. @@ -775,6 +775,7 @@ Currently, DocArray supports the following vector databases: - [Weaviate](https://www.weaviate.io/) - [Qdrant](https://qdrant.tech/) - [Elasticsearch](https://www.elastic.co/elasticsearch/) v8 and v7 +- [Redis](https://redis.io/) - [HNSWlib](https://github.com/nmslib/hnswlib) as a local-first alternative An integration of [OpenSearch](https://opensearch.org/) is currently in progress. 
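To make the addition concrete, here is a minimal usage sketch of the new backend, in the style of the tests in this series. It assumes a local Redis instance with the search module enabled (e.g. the docker-compose setup under tests/index/redis); the `BookDoc` schema is illustrative and not part of the patch:

```python
import numpy as np
from pydantic import Field

from docarray import BaseDoc, DocList
from docarray.index import RedisDocumentIndex
from docarray.typing import NdArray


class BookDoc(BaseDoc):  # illustrative schema, not defined in this patch
    title: str
    embedding: NdArray[128] = Field(space='cosine')  # vector field with a distance metric


# connect to a local Redis instance (default port 6379)
index = RedisDocumentIndex[BookDoc](host='localhost')

# index a batch of documents
index.index(
    DocList[BookDoc](
        [BookDoc(title=f'book {i}', embedding=np.random.rand(128)) for i in range(100)]
    )
)

# nearest-neighbor search on the `embedding` field
matches, scores = index.find(np.random.rand(128), search_field='embedding', limit=10)
```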
@@ -836,6 +837,7 @@ from docarray.index import ( WeaviateDocumentIndex, QdrantDocumentIndex, ElasticDocIndex, + RedisDocumentIndex, ) # Select a suitable backend and initialize it with data From 848e95dc3be7b6b4c49217d3cc8044e367fc2b57 Mon Sep 17 00:00:00 2001 From: jupyterjazz Date: Wed, 5 Jul 2023 11:11:56 +0400 Subject: [PATCH 24/28] chore: commits not showing Signed-off-by: jupyterjazz --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 3b7d6330c7d..e7cf30830ab 100644 --- a/README.md +++ b/README.md @@ -349,7 +349,7 @@ This is useful for: - :mag: **Neural search** applications - :bulb: **Recommender systems** -Currently, Document Indexes support **[Weaviate](https://weaviate.io/)**, **[Qdrant](https://qdrant.tech/)**, **[ElasticSearch](https://www.elastic.co/)**, **[Redis](https://redis.io/)**, and **[HNSWLib](https://github.com/nmslib/hnswlib)**, with more to come! +Currently, Document Indexes support **[Weaviate](https://weaviate.io/)**, **[Qdrant](https://qdrant.tech/)**, **[ElasticSearch](https://www.elastic.co/)**, **[Redis](https://redis.io/)**, and **[HNSWLib](https://github.com/nmslib/hnswlib)**, with more to come! The Document Index interface lets you index and retrieve Documents from multiple vector databases, all with the same user interface. From 27dd29ac1dc56ca156eb9e6085616d32aa0bc4ac Mon Sep 17 00:00:00 2001 From: jupyterjazz Date: Thu, 6 Jul 2023 15:01:45 +0400 Subject: [PATCH 25/28] feat: del and get batched Signed-off-by: jupyterjazz --- docarray/index/backends/redis.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/docarray/index/backends/redis.py b/docarray/index/backends/redis.py index 079d2fb4623..0063f8e2e96 100644 --- a/docarray/index/backends/redis.py +++ b/docarray/index/backends/redis.py @@ -366,7 +366,10 @@ def _del_items(self, doc_ids: Sequence[str]) -> None: """ doc_ids = [self._prefix + id for id in doc_ids if self._doc_exists(id)] if doc_ids: - self._client.delete(*doc_ids) + for batch in self._generate_batches( + doc_ids, batch_size=self._runtime_config.batch_size + ): + self._client.delete(*batch) def _doc_exists(self, doc_id) -> bool: """ @@ -377,6 +380,11 @@ def _doc_exists(self, doc_id) -> bool: """ return bool(self._client.exists(self._prefix + doc_id)) + @staticmethod + def _generate_batches(data, batch_size): + for i in range(0, len(data), batch_size): + yield data[i : i + batch_size] + def _get_items( self, doc_ids: Sequence[str] ) -> Union[Sequence[TSchema], Sequence[Dict[str, Any]]]: @@ -388,10 +396,14 @@ def _get_items( """ if not doc_ids: return [] + docs: List[Dict[str, Any]] = [] + for batch in self._generate_batches( + doc_ids, batch_size=self._runtime_config.batch_size + ): + ids = [self._prefix + id for id in batch] + retrieved_docs = self._client.json().mget(ids, '$') + docs.extend(doc[0] for doc in retrieved_docs if doc) - ids = [self._prefix + id for id in doc_ids] - docs = self._client.json().mget(ids, '$') - docs = [doc[0] for doc in docs if doc] if not docs: raise KeyError(f'No document with id {doc_ids} found') return docs From 162c6a8b565dd6f8a3731065d7be901711f1c06b Mon Sep 17 00:00:00 2001 From: jupyterjazz Date: Thu, 6 Jul 2023 15:05:25 +0400 Subject: [PATCH 26/28] docs: update batchsize docstring Signed-off-by: jupyterjazz --- docarray/index/backends/redis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docarray/index/backends/redis.py b/docarray/index/backends/redis.py index 0063f8e2e96..953eb9e3abc 100644 
--- a/docarray/index/backends/redis.py +++ b/docarray/index/backends/redis.py @@ -257,7 +257,7 @@ def __post_init__(self): class RuntimeConfig(BaseDocIndex.RuntimeConfig): """Dataclass that contains all "dynamic" configurations of RedisDocumentIndex. - :param batch_size: Batch size during indexing. + :param batch_size: Batch size for index/get/del. """ batch_size: int = 100 From 56829d6986ef5d21e48e097397b506c3ca028f63 Mon Sep 17 00:00:00 2001 From: jupyterjazz Date: Sun, 9 Jul 2023 15:56:17 +0400 Subject: [PATCH 27/28] refactor: index name Signed-off-by: jupyterjazz --- docarray/index/backends/redis.py | 57 ++++++++++++++++---------------- 1 file changed, 28 insertions(+), 29 deletions(-) diff --git a/docarray/index/backends/redis.py b/docarray/index/backends/redis.py index 953eb9e3abc..69d1ae485e6 100644 --- a/docarray/index/backends/redis.py +++ b/docarray/index/backends/redis.py @@ -75,20 +75,14 @@ class RedisDocumentIndex(BaseDocIndex, Generic[TSchema]): def __init__(self, db_config=None, **kwargs): """Initialize RedisDocumentIndex""" - if db_config is not None and getattr(db_config, 'index_name'): - self._index_name = db_config.index_name - elif kwargs.get('index_name'): - self._index_name = kwargs.get('index_name') - else: - self._index_name = 'index_name__' + self._random_name() - + self._index_name = None super().__init__(db_config=db_config, **kwargs) self._db_config = cast(RedisDocumentIndex.DBConfig, self._db_config) self._runtime_config: RedisDocumentIndex.RuntimeConfig = cast( RedisDocumentIndex.RuntimeConfig, self._runtime_config ) - self._prefix = self._index_name + ':' + self._prefix = self.index_name + ':' self._text_scorer = self._db_config.text_scorer # initialize Redis client self._client = redis.Redis( @@ -108,7 +102,7 @@ def _random_name() -> str: def _create_index(self) -> None: """Create a new index in the Redis database if it doesn't already exist.""" - if not self._check_index_exists(self._index_name): + if not self._check_index_exists(self.index_name): schema = [] for column, info in self._column_infos.items(): if issubclass(info.docarray_type, AnyDocArray): @@ -153,16 +147,16 @@ def _create_index(self) -> None: schema.append(info.db_type('$.' + column, as_name=column)) # Create Redis Index - self._client.ft(self._index_name).create_index( + self._client.ft(self.index_name).create_index( schema, definition=IndexDefinition( prefix=[self._prefix], index_type=IndexType.JSON ), ) - self._logger.info(f'index {self._index_name} has been created') + self._logger.info(f'index {self.index_name} has been created') else: - self._logger.info(f'connected to existing {self._index_name} index') + self._logger.info(f'connected to existing {self.index_name} index') def _check_index_exists(self, index_name: str) -> bool: """ @@ -181,6 +175,11 @@ def _check_index_exists(self, index_name: str) -> bool: @property def index_name(self): + if not self._index_name: + self._index_name = index_name = ( + self._db_config.index_name or 'index_name__' + self._random_name() + ) + self._logger.debug(f'Retrieved index name: {index_name}') return self._index_name @property @@ -290,10 +289,10 @@ def _generate_items( batch_size: int, ) -> Iterator[List[Dict[str, Any]]]: """ - Given a dictionary of generators, yield a list of dictionaries where each - item consists of a key and a single item from the corresponding generator. + Given a dictionary of data generators, yield a list of dictionaries where each + item consists of a column name and a single item from the corresponding generator. 
- :param column_to_data: A dictionary where each key is a column and each value + :param column_to_data: A dictionary where each key is a column name and each value is a generator. :param batch_size: Size of batch to generate each time. @@ -301,28 +300,28 @@ def _generate_items( an item from the corresponding generator. Yields until all generators are exhausted. """ - keys = list(column_to_data.keys()) - iterators = [iter(column_to_data[key]) for key in keys] + column_names = list(column_to_data.keys()) + data_generators = [iter(column_to_data[name]) for name in column_names] batch: List[Dict[str, Any]] = [] while True: - item_dict = {} - for key, it in zip(keys, iterators): - item = next(it, None) + data_dict = {} + for name, generator in zip(column_names, data_generators): + item = next(generator, None) - if key == 'id' and not item: + if name == 'id' and not item: if batch: yield batch return if isinstance(item, AbstractTensor): - item_dict[key] = item._docarray_to_ndarray().tolist() + data_dict[name] = item._docarray_to_ndarray().tolist() elif isinstance(item, ndarray): - item_dict[key] = item.astype(np.float32).tolist() + data_dict[name] = item.astype(np.float32).tolist() elif item is not None: - item_dict[key] = item + data_dict[name] = item - batch.append(item_dict) + batch.append(data_dict) if len(batch) == batch_size: yield batch batch = [] @@ -355,7 +354,7 @@ def num_docs(self) -> int: :return: Number of documents in the index. """ - num_docs = self._client.ft(self._index_name).info()['num_docs'] + num_docs = self._client.ft(self.index_name).info()['num_docs'] return int(num_docs) def _del_items(self, doc_ids: Sequence[str]) -> None: @@ -469,7 +468,7 @@ def _hybrid_search( 'vec': np.array(query, dtype=np.float32).tobytes() } results = ( - self._client.ft(self._index_name).search(redis_query, query_params).docs # type: ignore[arg-type] + self._client.ft(self.index_name).search(redis_query, query_params).docs # type: ignore[arg-type] ) scores: NdArray = NdArray._docarray_from_native( @@ -527,7 +526,7 @@ def _filter(self, filter_query: Any, limit: int) -> Union[DocList, List[Dict]]: q = Query(filter_query) q.paging(0, limit) - results = self._client.ft(index_name=self._index_name).search(q).docs + results = self._client.ft(index_name=self.index_name).search(q).docs docs = [json.loads(doc.json) for doc in results] return docs @@ -574,7 +573,7 @@ def _text_search( .paging(0, limit) ) - results = self._client.ft(index_name=self._index_name).search(q).docs + results = self._client.ft(index_name=self.index_name).search(q).docs scores: NdArray = NdArray._docarray_from_native( np.array([document['score'] for document in results]) From 7a5ed5e1ab400f02b01271ae1f29886b26d3e947 Mon Sep 17 00:00:00 2001 From: jupyterjazz Date: Sun, 9 Jul 2023 17:27:59 +0400 Subject: [PATCH 28/28] refactor: default index name following schema Signed-off-by: jupyterjazz --- docarray/index/backends/redis.py | 19 +++++++++---- tests/index/redis/fixtures.py | 17 +---------- tests/index/redis/test_configurations.py | 8 +++--- tests/index/redis/test_find.py | 36 +++++++++++++----------- tests/index/redis/test_index_get_del.py | 22 +++++++-------- tests/index/redis/test_persist_data.py | 9 +++--- tests/index/redis/test_subindex.py | 14 ++++----- 7 files changed, 59 insertions(+), 66 deletions(-) diff --git a/docarray/index/backends/redis.py b/docarray/index/backends/redis.py index 69d1ae485e6..bc8c8991671 100644 --- a/docarray/index/backends/redis.py +++ b/docarray/index/backends/redis.py @@ -75,7 +75,6 @@ class 
RedisDocumentIndex(BaseDocIndex, Generic[TSchema]): def __init__(self, db_config=None, **kwargs): """Initialize RedisDocumentIndex""" - self._index_name = None super().__init__(db_config=db_config, **kwargs) self._db_config = cast(RedisDocumentIndex.DBConfig, self._db_config) @@ -175,12 +174,20 @@ def _check_index_exists(self, index_name: str) -> bool: @property def index_name(self): - if not self._index_name: - self._index_name = index_name = ( - self._db_config.index_name or 'index_name__' + self._random_name() + default_index_name = ( + self._schema.__name__.lower() if self._schema is not None else None + ) + if default_index_name is None: + err_msg = ( + 'A RedisDocumentIndex must be typed with a Document type. ' + 'To do so, use the syntax: RedisDocumentIndex[DocumentType]' ) - self._logger.debug(f'Retrieved index name: {index_name}') - return self._index_name + + self._logger.error(err_msg) + raise ValueError(err_msg) + index_name = self._db_config.index_name or default_index_name + self._logger.debug(f'Retrieved index name: {index_name}') + return index_name @property def out_schema(self) -> Type[BaseDoc]: diff --git a/tests/index/redis/fixtures.py b/tests/index/redis/fixtures.py index 0c97f4f0dc9..42acb2c1b78 100644 --- a/tests/index/redis/fixtures.py +++ b/tests/index/redis/fixtures.py @@ -2,7 +2,6 @@ import time import uuid import pytest -import redis @pytest.fixture(scope='session', autouse=True) @@ -18,19 +17,5 @@ def start_redis(): @pytest.fixture(scope='function') -def tmp_collection_name(): +def tmp_index_name(): return uuid.uuid4().hex - - -@pytest.fixture -def redis_client(): - """This fixture provides a Redis client""" - client = redis.Redis(host='localhost', port=6379) - yield client - client.flushall() - - -@pytest.fixture -def redis_config(redis_client): - """This fixture provides the Redis client and flushes all data after each test case""" - return redis_client diff --git a/tests/index/redis/test_configurations.py b/tests/index/redis/test_configurations.py index d8de2649091..c2855017ec9 100644 --- a/tests/index/redis/test_configurations.py +++ b/tests/index/redis/test_configurations.py @@ -5,7 +5,7 @@ from docarray import BaseDoc from docarray.index import RedisDocumentIndex from docarray.typing import NdArray -from tests.index.redis.fixtures import start_redis # noqa: F401 +from tests.index.redis.fixtures import start_redis, tmp_index_name # noqa: F401 pytestmark = [pytest.mark.slow, pytest.mark.index] @@ -23,16 +23,16 @@ class Schema(BaseDoc): assert index.num_docs() == 10 -def test_configure_index(): +def test_configure_index(tmp_index_name): class Schema(BaseDoc): tens: NdArray[100] = Field(space='cosine') title: str year: int types = {'id': 'TAG', 'tens': 'VECTOR', 'title': 'TEXT', 'year': 'NUMERIC'} - index = RedisDocumentIndex[Schema](host='localhost') + index = RedisDocumentIndex[Schema](host='localhost', index_name=tmp_index_name) - attr_bytes = index._client.ft(index._index_name).info()['attributes'] + attr_bytes = index._client.ft(index.index_name).info()['attributes'] attr = [[byte.decode() for byte in sublist] for sublist in attr_bytes] assert len(Schema.__fields__) == len(attr) diff --git a/tests/index/redis/test_find.py b/tests/index/redis/test_find.py index d5fa1d8bb52..39285650acc 100644 --- a/tests/index/redis/test_find.py +++ b/tests/index/redis/test_find.py @@ -8,7 +8,7 @@ from docarray import BaseDoc, DocList from docarray.index import RedisDocumentIndex from docarray.typing import NdArray, TorchTensor -from tests.index.redis.fixtures import 
start_redis # noqa: F401 +from tests.index.redis.fixtures import start_redis, tmp_index_name # noqa: F401 pytestmark = [pytest.mark.slow, pytest.mark.index] @@ -27,9 +27,9 @@ class TorchDoc(BaseDoc): @pytest.mark.parametrize('space', ['cosine', 'l2', 'ip']) -def test_find_simple_schema(space): +def test_find_simple_schema(space, tmp_index_name): schema = get_simple_schema(space=space) - db = RedisDocumentIndex[schema](host='localhost') + db = RedisDocumentIndex[schema](host='localhost', index_name=tmp_index_name) index_docs = [schema(tens=np.random.rand(N_DIM)) for _ in range(10)] index_docs.append(schema(tens=np.ones(N_DIM))) @@ -68,8 +68,8 @@ def test_find_limit_larger_than_index(): @pytest.mark.parametrize('space', ['cosine', 'l2', 'ip']) -def test_find_torch(space): - db = RedisDocumentIndex[TorchDoc](host='localhost') +def test_find_torch(space, tmp_index_name): + db = RedisDocumentIndex[TorchDoc](host='localhost', index_name=tmp_index_name) index_docs = [TorchDoc(tens=np.random.rand(N_DIM)) for _ in range(10)] index_docs.append(TorchDoc(tens=np.ones(N_DIM, dtype=np.float32))) db.index(index_docs) @@ -91,13 +91,13 @@ def test_find_torch(space): @pytest.mark.tensorflow @pytest.mark.parametrize('space', ['cosine', 'l2', 'ip']) -def test_find_tensorflow(space): +def test_find_tensorflow(space, tmp_index_name): from docarray.typing import TensorFlowTensor class TfDoc(BaseDoc): tens: TensorFlowTensor[10] - db = RedisDocumentIndex[TfDoc](host='localhost') + db = RedisDocumentIndex[TfDoc](host='localhost', index_name=tmp_index_name) index_docs = [TfDoc(tens=np.random.rand(N_DIM)) for _ in range(10)] index_docs.append(TfDoc(tens=np.ones(10))) @@ -121,12 +121,12 @@ class TfDoc(BaseDoc): @pytest.mark.parametrize('space', ['cosine', 'l2', 'ip']) -def test_find_flat_schema(space): +def test_find_flat_schema(space, tmp_index_name): class FlatSchema(BaseDoc): tens_one: NdArray = Field(dim=N_DIM, space=space) tens_two: NdArray = Field(dim=50, space=space) - index = RedisDocumentIndex[FlatSchema](host='localhost') + index = RedisDocumentIndex[FlatSchema](host='localhost', index_name=tmp_index_name) index_docs = [ FlatSchema(tens_one=np.random.rand(N_DIM), tens_two=np.random.rand(50)) @@ -156,7 +156,7 @@ class FlatSchema(BaseDoc): @pytest.mark.parametrize('space', ['cosine', 'l2', 'ip']) -def test_find_nested_schema(space): +def test_find_nested_schema(space, tmp_index_name): class SimpleDoc(BaseDoc): tens: NdArray[N_DIM] = Field(space=space) @@ -168,7 +168,9 @@ class DeepNestedDoc(BaseDoc): d: NestedDoc tens: NdArray = Field(space=space, dim=N_DIM) - index = RedisDocumentIndex[DeepNestedDoc](host='localhost') + index = RedisDocumentIndex[DeepNestedDoc]( + host='localhost', index_name=tmp_index_name + ) index_docs = [ DeepNestedDoc( @@ -243,12 +245,12 @@ class MyDoc(BaseDoc): assert q.id == matches[0].id -def test_query_builder(): +def test_query_builder(tmp_index_name): class SimpleSchema(BaseDoc): tensor: NdArray[N_DIM] = Field(space='cosine') price: int - db = RedisDocumentIndex[SimpleSchema](host='localhost') + db = RedisDocumentIndex[SimpleSchema](host='localhost', index_name=tmp_index_name) index_docs = [ SimpleSchema(tensor=np.array([i + 1] * 10), price=i + 1) for i in range(10) @@ -269,7 +271,7 @@ class SimpleSchema(BaseDoc): assert doc.price <= 3 -def test_text_search(): +def test_text_search(tmp_index_name): class SimpleSchema(BaseDoc): description: str some_field: Optional[int] @@ -286,7 +288,7 @@ class SimpleSchema(BaseDoc): docs = [SimpleSchema(description=text) for text in 
texts_to_index] - db = RedisDocumentIndex[SimpleSchema](host='localhost') + db = RedisDocumentIndex[SimpleSchema](host='localhost', index_name=tmp_index_name) db.index(docs) docs, _ = db.text_search(query=query_string, search_field='description') @@ -294,7 +296,7 @@ class SimpleSchema(BaseDoc): assert docs[0].description == texts_to_index[0] -def test_filter(): +def test_filter(tmp_index_name): class SimpleSchema(BaseDoc): description: str price: int @@ -304,7 +306,7 @@ class SimpleSchema(BaseDoc): doc3 = SimpleSchema(description='Random book', price=40) docs = [doc1, doc2, doc3] - db = RedisDocumentIndex[SimpleSchema](host='localhost') + db = RedisDocumentIndex[SimpleSchema](host='localhost', index_name=tmp_index_name) db.index(docs) # filter on price < 45 diff --git a/tests/index/redis/test_index_get_del.py b/tests/index/redis/test_index_get_del.py index fbe199a1bd5..31e67212610 100644 --- a/tests/index/redis/test_index_get_del.py +++ b/tests/index/redis/test_index_get_del.py @@ -5,7 +5,7 @@ from docarray import BaseDoc from docarray.index import RedisDocumentIndex from docarray.typing import NdArray -from tests.index.redis.fixtures import start_redis # noqa: F401 +from tests.index.redis.fixtures import start_redis, tmp_index_name # noqa: F401 pytestmark = [pytest.mark.slow, pytest.mark.index] @@ -39,8 +39,8 @@ def test_num_docs(ten_simple_docs): assert index.num_docs() == 10 -def test_get_single(ten_simple_docs): - index = RedisDocumentIndex[SimpleDoc](host='localhost') +def test_get_single(ten_simple_docs, tmp_index_name): + index = RedisDocumentIndex[SimpleDoc](host='localhost', index_name=tmp_index_name) index.index(ten_simple_docs) assert index.num_docs() == 10 @@ -54,9 +54,9 @@ def test_get_single(ten_simple_docs): index['some_id'] -def test_get_multiple(ten_simple_docs): +def test_get_multiple(ten_simple_docs, tmp_index_name): docs_to_get_idx = [0, 2, 4, 6, 8] - index = RedisDocumentIndex[SimpleDoc](host='localhost') + index = RedisDocumentIndex[SimpleDoc](host='localhost', index_name=tmp_index_name) index.index(ten_simple_docs) assert index.num_docs() == 10 @@ -68,8 +68,8 @@ def test_get_multiple(ten_simple_docs): assert np.allclose(d_out.tens, d_in.tens) -def test_del_single(ten_simple_docs): - index = RedisDocumentIndex[SimpleDoc](host='localhost') +def test_del_single(ten_simple_docs, tmp_index_name): + index = RedisDocumentIndex[SimpleDoc](host='localhost', index_name=tmp_index_name) index.index(ten_simple_docs) assert index.num_docs() == 10 @@ -82,10 +82,10 @@ def test_del_single(ten_simple_docs): index[doc_id] -def test_del_multiple(ten_simple_docs): +def test_del_multiple(ten_simple_docs, tmp_index_name): docs_to_del_idx = [0, 2, 4, 6, 8] - index = RedisDocumentIndex[SimpleDoc](host='localhost') + index = RedisDocumentIndex[SimpleDoc](host='localhost', index_name=tmp_index_name) index.index(ten_simple_docs) assert index.num_docs() == 10 @@ -101,8 +101,8 @@ def test_del_multiple(ten_simple_docs): assert np.allclose(index[doc.id].tens, doc.tens) -def test_contains(ten_simple_docs): - index = RedisDocumentIndex[SimpleDoc](host='localhost') +def test_contains(ten_simple_docs, tmp_index_name): + index = RedisDocumentIndex[SimpleDoc](host='localhost', index_name=tmp_index_name) index.index(ten_simple_docs) for doc in ten_simple_docs: diff --git a/tests/index/redis/test_persist_data.py b/tests/index/redis/test_persist_data.py index 95e8ae7aab2..3e590247f56 100644 --- a/tests/index/redis/test_persist_data.py +++ b/tests/index/redis/test_persist_data.py @@ -5,7 +5,7 @@ from 
docarray import BaseDoc from docarray.index import RedisDocumentIndex from docarray.typing import NdArray -from tests.index.redis.fixtures import start_redis # noqa: F401 +from tests.index.redis.fixtures import start_redis, tmp_index_name # noqa: F401 pytestmark = [pytest.mark.slow, pytest.mark.index] @@ -15,12 +15,11 @@ class SimpleDoc(BaseDoc): tens: NdArray[10] = Field(dim=1000) -def test_persist(): +def test_persist(tmp_index_name): query = SimpleDoc(tens=np.random.random((10,))) # create index - index = RedisDocumentIndex[SimpleDoc](host='localhost') - index_name = index._index_name + index = RedisDocumentIndex[SimpleDoc](host='localhost', index_name=tmp_index_name) assert index.num_docs() == 0 @@ -29,7 +28,7 @@ def test_persist(): find_results_before = index.find(query, search_field='tens', limit=5) # load existing index - index = RedisDocumentIndex[SimpleDoc](host='localhost', index_name=index_name) + index = RedisDocumentIndex[SimpleDoc](host='localhost', index_name=tmp_index_name) assert index.num_docs() == 10 find_results_after = index.find(query, search_field='tens', limit=5) for doc_before, doc_after in zip(find_results_before[0], find_results_after[0]): diff --git a/tests/index/redis/test_subindex.py b/tests/index/redis/test_subindex.py index c873ea00b00..6885dc79db6 100644 --- a/tests/index/redis/test_subindex.py +++ b/tests/index/redis/test_subindex.py @@ -22,7 +22,7 @@ class ListDoc(BaseDoc): list_tens: NdArray[20] = Field(space='l2') -class MyDoc(BaseDoc): +class NestedDoc(BaseDoc): docs: DocList[SimpleDoc] list_docs: DocList[ListDoc] my_tens: NdArray[30] = Field(space='l2') @@ -30,14 +30,14 @@ class MyDoc(BaseDoc): @pytest.fixture(scope='session') def index(): - index = RedisDocumentIndex[MyDoc](host='localhost') + index = RedisDocumentIndex[NestedDoc](host='localhost') return index @pytest.fixture(scope='session') def data(): my_docs = [ - MyDoc( + NestedDoc( id=f'{i}', docs=DocList[SimpleDoc]( [ @@ -99,7 +99,7 @@ def test_subindex_index(index, data): def test_subindex_get(index, data): index.index(data) doc = index['1'] - assert type(doc) == MyDoc + assert type(doc) == NestedDoc assert doc.id == '1' assert len(doc.docs) == 5 assert type(doc.docs[0]) == SimpleDoc @@ -158,7 +158,7 @@ def test_subindex_contain(index, data): assert not index.subindex_contains(empty_doc) # Empty index - empty_index = RedisDocumentIndex[MyDoc](host='localhost') + empty_index = RedisDocumentIndex[NestedDoc](host='localhost') assert empty_doc not in empty_index @@ -174,7 +174,7 @@ def test_find_subindex(index, data): root_docs, docs, scores = index.find_subindex( query, subindex='docs', search_field='simple_tens', limit=5 ) - assert type(root_docs[0]) == MyDoc + assert type(root_docs[0]) == NestedDoc assert type(docs[0]) == SimpleDoc assert len(scores) == 5 for root_doc, doc in zip(root_docs, docs): assert np.allclose(doc.simple_tens, np.ones(10)) assert root_doc.id == f'{doc.id.split("_")[-2]}' @@ -188,7 +188,7 @@ def test_find_subindex(index, data): ) assert len(docs) == 5 assert len(scores) == 5 - assert type(root_docs[0]) == MyDoc + assert type(root_docs[0]) == NestedDoc assert type(docs[0]) == SimpleDoc for root_doc, doc in zip(root_docs, docs): assert np.allclose(doc.simple_tens, np.ones(10))
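With the subindex work above finalized, nested DocList fields can be searched directly through `find_subindex`. A minimal sketch of that workflow, reusing the `NestedDoc`/`SimpleDoc` schemas and the `data` fixture from tests/index/redis/test_subindex.py (names come from the tests, not from a public API guarantee):

```python
import numpy as np

# assumes NestedDoc and the `data` list from tests/index/redis/test_subindex.py
index = RedisDocumentIndex[NestedDoc](host='localhost')
index.index(data)  # each nested DocList field is written to its own subindex

# vector search over the nested `docs` field: returns the matching SimpleDoc
# entries together with the root document each match belongs to
root_docs, sub_docs, scores = index.find_subindex(
    np.ones(10), subindex='docs', search_field='simple_tens', limit=5
)

# deeper nesting is addressed with a double-underscore path
root_docs, sub_docs, scores = index.find_subindex(
    np.ones(10), subindex='list_docs__docs', search_field='simple_tens', limit=5
)
```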
