diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index cc5b769d59..b8c4added6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -193,7 +193,7 @@ jobs: fail-fast: false matrix: python-version: [3.8] - db_test_folder: [base_classes, elastic, hnswlib, qdrant, weaviate, redis, milvus] + db_test_folder: [base_classes, elastic, epsilla, hnswlib, qdrant, weaviate, redis, milvus] pydantic-version: ["pydantic-v2", "pydantic-v1"] steps: - uses: actions/checkout@v2.5.0 diff --git a/docarray/index/__init__.py b/docarray/index/__init__.py index dfd0d52f7c..72596cd73a 100644 --- a/docarray/index/__init__.py +++ b/docarray/index/__init__.py @@ -10,16 +10,18 @@ if TYPE_CHECKING: from docarray.index.backends.elastic import ElasticDocIndex # noqa: F401 from docarray.index.backends.elasticv7 import ElasticV7DocIndex # noqa: F401 + from docarray.index.backends.epsilla import EpsillaDocumentIndex # noqa: F401 from docarray.index.backends.hnswlib import HnswDocumentIndex # noqa: F401 + from docarray.index.backends.milvus import MilvusDocumentIndex # noqa: F401 from docarray.index.backends.qdrant import QdrantDocumentIndex # noqa: F401 - from docarray.index.backends.weaviate import WeaviateDocumentIndex # noqa: F401 from docarray.index.backends.redis import RedisDocumentIndex # noqa: F401 - from docarray.index.backends.milvus import MilvusDocumentIndex # noqa: F401 + from docarray.index.backends.weaviate import WeaviateDocumentIndex # noqa: F401 __all__ = [ 'InMemoryExactNNIndex', 'ElasticDocIndex', 'ElasticV7DocIndex', + 'EpsillaDocumentIndex', 'QdrantDocumentIndex', 'WeaviateDocumentIndex', 'RedisDocumentIndex', @@ -38,6 +40,9 @@ def __getattr__(name: str): elif name == 'ElasticV7DocIndex': import_library('elasticsearch', raise_error=True) import docarray.index.backends.elasticv7 as lib + elif name == 'EpsillaDocumentIndex': + import_library('pyepsilla', raise_error=True) + import docarray.index.backends.epsilla as lib elif name == 'QdrantDocumentIndex': import_library('qdrant_client', raise_error=True) import docarray.index.backends.qdrant as lib diff --git a/docarray/index/backends/epsilla.py b/docarray/index/backends/epsilla.py new file mode 100644 index 0000000000..83c171daed --- /dev/null +++ b/docarray/index/backends/epsilla.py @@ -0,0 +1,531 @@ +import copy +from dataclasses import dataclass, field +from http import HTTPStatus +from typing import ( + Any, + Dict, + Generator, + Generic, + List, + Optional, + Sequence, + Type, + TypeVar, + Union, + cast, +) + +import numpy as np +from pyepsilla import cloud, vectordb + +from docarray import BaseDoc, DocList +from docarray.index.abstract import ( + BaseDocIndex, + _FindResultBatched, + _raise_not_composable, + _raise_not_supported, +) +from docarray.typing import ID, NdArray +from docarray.typing.tensor.abstract_tensor import AbstractTensor +from docarray.utils._internal._typing import safe_issubclass +from docarray.utils.find import _FindResult + +TSchema = TypeVar('TSchema', bound=BaseDoc) + + +class EpsillaDocumentIndex(BaseDocIndex, Generic[TSchema]): + def __init__(self, db_config=None, **kwargs): + # will set _db_config from args / kwargs + super().__init__(db_config=db_config, **kwargs) + + self._db_config: EpsillaDocumentIndex.DBConfig = cast( + EpsillaDocumentIndex.DBConfig, self._db_config + ) + self._db_config.validate_config() + self._validate_column_info() + + self._table_name = ( + self._db_config.table_name + if self._db_config.table_name + else self._schema.__name__ + ) + + if 
self._db_config.is_self_hosted:
+            self._db = vectordb.Client(
+                protocol=self._db_config.protocol,
+                host=self._db_config.host,
+                port=self._db_config.port,
+            )
+            status_code, response = self._db.load_db(
+                db_name=self._db_config.db_name,
+                db_path=self._db_config.db_path,
+            )
+
+            if status_code != HTTPStatus.OK:
+                if status_code == HTTPStatus.CONFLICT:
+                    self._logger.info(f'{self._db_config.db_name} already loaded.')
+                else:
+                    raise IOError(
+                        f"Failed to load database {self._db_config.db_name}. "
+                        f"Error code: {status_code}. Error message: {response}."
+                    )
+            self._db.use_db(self._db_config.db_name)
+
+            status_code, response = self._db.list_tables()
+            if status_code != HTTPStatus.OK:
+                raise IOError(
+                    f"Failed to list tables. "
+                    f"Error code: {status_code}. Error message: {response}."
+                )
+
+            if self._table_name not in response["result"]:
+                self._create_table_self_hosted()
+        else:
+            self._client = cloud.Client(
+                project_id=self._db_config.cloud_project_id,
+                api_key=self._db_config.api_key,
+            )
+            self._db = self._client.vectordb(self._db_config.cloud_db_id)
+
+            status_code, response = self._db.list_tables()
+            if status_code != HTTPStatus.OK:
+                raise IOError(
+                    f"Failed to list tables. "
+                    f"Error code: {status_code}. Error message: {response}."
+                )
+
+            # Epsilla cloud requires the table to be created in the web UI before inserting data;
+            # creating tables from the Python client is not yet supported.
+
+    def _validate_column_info(self):
+        vector_columns = []
+        for info in self._column_infos.values():
+            for allowed_type in [list, np.ndarray, AbstractTensor]:
+                if safe_issubclass(info.docarray_type, allowed_type) and info.config.get(
+                    'is_embedding', False
+                ):
+                    # check that dimension is present
+                    if info.n_dim is None and info.config.get('dim', None) is None:
+                        raise ValueError("The dimension information is missing")
+
+                    vector_columns.append(info.docarray_type)
+                    break
+
+        if len(vector_columns) == 0:
+            raise ValueError(
+                "Unable to find any vector columns. Please make sure that at least one "
+                "column is of a vector type with the is_embedding=True attribute specified."
+            )
+        elif len(vector_columns) > 1:
+            raise ValueError("Specifying multiple vector fields is not supported.")
+
+    def _create_table_self_hosted(self):
+        """Use _column_infos to create a table in the database."""
+        table_fields = []
+
+        primary_keys = []
+        for column_name, column_info in self._column_infos.items():
+            if column_info.docarray_type == ID:
+                primary_keys.append(column_name)
+
+        # When there is a nested schema, we may have multiple "ID" fields. We use the presence of "__"
+        # to determine whether a field is nested or not.
+        if len(primary_keys) > 1:
+            sorted_pkeys = sorted(primary_keys, key=lambda x: x.count("__"))
+            primary_keys = sorted_pkeys[:1]
+
+        for column_name, column_info in self._column_infos.items():
+            dim = (
+                column_info.n_dim
+                if column_info.n_dim is not None
+                else column_info.config.get('dim', None)
+            )
+            if dim is None:
+                table_fields.append(
+                    {
+                        'name': column_name,
+                        'dataType': column_info.db_type,
+                        'primaryKey': column_name in primary_keys,
+                    }
+                )
+            else:
+                table_fields.append(
+                    {
+                        'name': column_name,
+                        'dataType': column_info.db_type,
+                        'dimensions': dim,
+                    }
+                )
+
+        status_code, response = self._db.create_table(
+            table_name=self._table_name,
+            table_fields=table_fields,
+        )
+        if status_code != HTTPStatus.OK:
+            raise IOError(
+                f"Failed to create table {self._table_name}. "
+                f"Error code: {status_code}. Error message: {response}."
+ ) + + @dataclass + class Query: + """Dataclass describing a query.""" + + vector_field: Optional[str] + vector_query: Optional[NdArray] + filter: Optional[str] + limit: int + + class QueryBuilder(BaseDocIndex.QueryBuilder): + def __init__( + self, + vector_search_field: Optional[str] = None, + vector_queries: Optional[List[NdArray]] = None, + filter: Optional[str] = None, + ): + self._vector_search_field: Optional[str] = vector_search_field + self._vector_queries: List[NdArray] = vector_queries or [] + self._filter: Optional[str] = filter + + def find(self, query: NdArray, search_field: str = ''): + if self._vector_search_field and self._vector_search_field != search_field: + raise ValueError( + f'Trying to call .find for search_field = {search_field}, but ' + f'previously {self._vector_search_field} was used. Only a single ' + f'field might be used in chained calls.' + ) + return EpsillaDocumentIndex.QueryBuilder( + vector_search_field=search_field, + vector_queries=self._vector_queries + [query], + filter=self._filter, + ) + + def filter(self, filter_query: str): # type: ignore[override] + return EpsillaDocumentIndex.QueryBuilder( + vector_search_field=self._vector_search_field, + vector_queries=self._vector_queries, + filter=filter_query, + ) + + def build(self, limit: int) -> Any: + if len(self._vector_queries) > 0: + # If there are multiple vector queries applied, we can average them and + # perform semantic search on a single vector instead + vector_query = np.average(self._vector_queries, axis=0) + else: + vector_query = None + return EpsillaDocumentIndex.Query( + vector_field=self._vector_search_field, + vector_query=vector_query, + filter=self._filter, + limit=limit, + ) + + find_batched = _raise_not_composable('find_batched') + filter_batched = _raise_not_composable('filter_batched') + text_search = _raise_not_supported('text_search') + text_search_batched = _raise_not_supported('text_search_batched') + + @dataclass + class DBConfig(BaseDocIndex.DBConfig): + """Static configuration for EpsillaDocumentIndex""" + + # default value is the schema type name + table_name: Optional[str] = None + + # Indicator for self-hosted or cloud version + is_self_hosted: bool = False + + # self-hosted version uses the following configs + protocol: Optional[str] = None + host: Optional[str] = None + port: Optional[int] = 8888 + db_path: Optional[str] = None + db_name: Optional[str] = None + + # cloud version uses the following configs + cloud_project_id: Optional[str] = None + cloud_db_id: Optional[str] = None + api_key: Optional[str] = None + + default_column_config: Dict[Any, Dict[str, Any]] = field( + default_factory=lambda: { + 'TINYINT': {}, + 'SMALLINT': {}, + 'INT': {}, + 'BIGINT': {}, + 'FLOAT': {}, + 'DOUBLE': {}, + 'STRING': {}, + 'BOOL': {}, + 'JSON': {}, + 'VECTOR_FLOAT': {}, + } + ) + + def validate_config(self): + if self.is_self_hosted: + self.validate_self_hosted_config() + else: + self.validate_cloud_config() + + def validate_self_hosted_config(self): + missing_attributes = [ + attr + for attr in ["protocol", "host", "port", "db_path", "db_name"] + if getattr(self, attr, None) is None + ] + + if missing_attributes: + raise ValueError( + f"Missing required attributes for self-hosted version: {', '.join(missing_attributes)}" + ) + + def validate_cloud_config(self): + missing_attributes_cloud = [ + attr + for attr in ["cloud_project_id", "cloud_db_id", "api_key"] + if getattr(self, attr, None) is None + ] + + if missing_attributes_cloud: + raise ValueError( + f"Missing required 
attributes for cloud version: {', '.join(missing_attributes_cloud)}"
+                )
+
+    @dataclass
+    class RuntimeConfig(BaseDocIndex.RuntimeConfig):
+        # No dynamic config used
+        pass
+
+    @property
+    def collection_name(self):
+        return self._db_config.table_name
+
+    @property
+    def index_name(self):
+        return self.collection_name
+
+    def python_type_to_db_type(self, python_type: Type) -> str:
+        # AbstractTensor does not have n_dim, which is required by Epsilla.
+        # Use NdArray instead.
+        for allowed_type in [list, np.ndarray, AbstractTensor]:
+            if safe_issubclass(python_type, allowed_type):
+                return 'VECTOR_FLOAT'
+
+        py_type_map = {
+            ID: 'STRING',
+            str: 'STRING',
+            bytes: 'STRING',
+            int: 'BIGINT',
+            float: 'FLOAT',
+            bool: 'BOOL',
+            np.ndarray: 'VECTOR_FLOAT',
+        }
+
+        for py_type, epsilla_type in py_type_map.items():
+            if safe_issubclass(python_type, py_type):
+                return epsilla_type
+
+        raise ValueError(f'Unsupported column type for {type(self)}: {python_type}')
+
+    def _index(self, column_to_data: Dict[str, Generator[Any, None, None]]):
+        self._index_subindex(column_to_data)
+
+        rows = list(self._transpose_col_value_dict(column_to_data))
+        normalized_rows = []
+        for row in rows:
+            normalized_row = {}
+            for key, value in row.items():
+                if isinstance(value, NdArray):
+                    normalized_row[key] = value.tolist()
+                elif isinstance(value, np.ndarray):
+                    normalized_row[key] = value.tolist()
+                else:
+                    normalized_row[key] = value
+            normalized_rows.append(normalized_row)
+
+        status_code, response = self._db.insert(
+            table_name=self._table_name, records=normalized_rows
+        )
+
+        if status_code != HTTPStatus.OK:
+            raise IOError(
+                f"Failed to insert documents. "
+                f"Error code: {status_code}. Error message: {response}."
+            )
+
+    def num_docs(self) -> int:
+        raise NotImplementedError
+
+    @property
+    def _is_index_empty(self) -> bool:
+        """
+        Check if index is empty by comparing the number of documents to zero.
+        :return: True if the index is empty, False otherwise.
+        """
+        # Overriding this method to always return False because Epsilla does not have a count API for num_docs
+        return False
+
+    def _del_items(self, doc_ids: Sequence[str]):
+        status_code, response = self._db.delete(
+            table_name=self._table_name,
+            primary_keys=list(doc_ids),
+        )
+        if status_code != HTTPStatus.OK:
+            raise IOError(
+                f"Failed to delete documents with ids {doc_ids}. "
+                f"Error code: {status_code}. Error message: {response}."
+            )
+        return response['message']
+
+    def _get_items(
+        self, doc_ids: Sequence[str]
+    ) -> Union[Sequence[TSchema], Sequence[Dict[str, Any]]]:
+        status_code, response = self._db.get(
+            table_name=self._table_name,
+            primary_keys=list(doc_ids),
+        )
+        if status_code != HTTPStatus.OK:
+            raise IOError(
+                f"Failed to get documents with ids {doc_ids}. "
+                f"Error code: {status_code}. Error message: {response}."
+            )
+        return response['result']
+
+    def execute_query(self, query: Query) -> DocList:
+        if query.vector_query is not None:
+            result = self._find_with_filter_batched(
+                queries=np.expand_dims(query.vector_query, axis=0),
+                filter=query.filter,
+                limit=query.limit,
+                search_field=query.vector_field,
+            )
+            return self._dict_list_to_docarray(result.documents[0])
+        else:
+            return self._dict_list_to_docarray(
+                self._filter(
+                    filter_query=query.filter,
+                    limit=query.limit,
+                )
+            )
+
+    def _doc_exists(self, doc_id: str) -> bool:
+        return len(self._get_items([doc_id])) > 0
+
+    def _find(
+        self,
+        query: np.ndarray,
+        limit: int,
+        search_field: str = '',
+    ) -> _FindResult:
+        query_batched = np.expand_dims(query, axis=0)
+        docs, scores = self._find_batched(
+            queries=query_batched, limit=limit, search_field=search_field
+        )
+        return _FindResult(documents=docs[0], scores=scores[0])
+
+    def _find_batched(
+        self,
+        queries: np.ndarray,
+        limit: int,
+        search_field: str = '',
+    ) -> _FindResultBatched:
+        return self._find_with_filter_batched(
+            queries=queries, limit=limit, search_field=search_field
+        )
+
+    def _find_with_filter_batched(
+        self,
+        queries: np.ndarray,
+        limit: int,
+        search_field: str,
+        filter: Optional[str] = None,
+    ) -> _FindResultBatched:
+        if search_field == '':
+            raise ValueError(
+                'EpsillaDocumentIndex requires a search_field to be specified.'
+            )
+
+        responses = []
+        for query in queries:
+            status_code, response = self._db.query(
+                table_name=self._table_name,
+                query_field=search_field,
+                limit=limit,
+                filter=filter if filter is not None else '',
+                query_vector=query.tolist(),
+                with_distance=True,
+            )
+
+            if status_code != HTTPStatus.OK:
+                raise IOError(
+                    f"Failed to find documents with query {query}. "
+                    f"Error code: {status_code}. Error message: {response}."
+                )
+
+            results = response['result']
+            scores = NdArray._docarray_from_native(
+                np.array([result['@distance'] for result in results])
+            )
+            documents = []
+            for result in results:
+                doc = copy.copy(result)
+                del doc["@distance"]
+                documents.append(doc)
+
+            responses.append((documents, scores))
+
+        return _FindResultBatched(
+            documents=[r[0] for r in responses],
+            scores=[r[1] for r in responses],
+        )
+
+    def _filter(
+        self,
+        filter_query: str,
+        limit: int,
+    ) -> Union[DocList, List[Dict]]:
+        query_batched = [filter_query]
+        docs = self._filter_batched(filter_queries=query_batched, limit=limit)
+        return docs[0]
+
+    def _filter_batched(
+        self,
+        filter_queries: Sequence[str],
+        limit: int,
+    ) -> Union[List[DocList], List[List[Dict]]]:
+        responses = []
+        for filter_query in filter_queries:
+            status_code, response = self._db.get(
+                table_name=self._table_name,
+                limit=limit,
+                filter=filter_query,
+            )
+
+            if status_code != HTTPStatus.OK:
+                raise IOError(
+                    f"Failed to find documents with filter {filter_query}. "
+                    f"Error code: {status_code}. Error message: {response}."
+ ) + + results = response['result'] + responses.append(results) + + return responses + + def _text_search( + self, + query: str, + limit: int, + search_field: str = '', + ) -> _FindResult: + raise NotImplementedError(f'{type(self)} does not support text search.') + + def _text_search_batched( + self, + queries: Sequence[str], + limit: int, + search_field: str = '', + ) -> _FindResultBatched: + raise NotImplementedError(f'{type(self)} does not support text search.') diff --git a/poetry.lock b/poetry.lock index 8924ce7bd9..631a0b8d07 100644 --- a/poetry.lock +++ b/poetry.lock @@ -329,6 +329,17 @@ files = [ {file = "backcall-0.2.0.tar.gz", hash = "sha256:5cbdbf27be5e7cfadb448baf0aa95508f91f2bbc6c6437cd9cd06e2a4c215e1e"}, ] +[[package]] +name = "backoff" +version = "2.2.1" +description = "Function decoration for backoff and retry" +optional = true +python-versions = ">=3.7,<4.0" +files = [ + {file = "backoff-2.2.1-py3-none-any.whl", hash = "sha256:63579f9a0628e06278f7e47b7d7d5b6ce20dc65c5e96a6f3ca99a6adca0396e8"}, + {file = "backoff-2.2.1.tar.gz", hash = "sha256:03f829f5bb1923180821643f8753b0502c3b682293992485b0eef2807afa5cba"}, +] + [[package]] name = "beautifulsoup4" version = "4.11.1" @@ -1926,6 +1937,14 @@ files = [ {file = "mapbox_earcut-1.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9af9369266bf0ca32f4d401152217c46c699392513f22639c6b1be32bde9c1cc"}, {file = "mapbox_earcut-1.0.1-cp311-cp311-win32.whl", hash = "sha256:ff9a13be4364625697b0e0e04ba6a0f77300148b871bba0a85bfa67e972e85c4"}, {file = "mapbox_earcut-1.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:5e736557539c74fa969e866889c2b0149fc12668f35e3ae33667d837ff2880d3"}, + {file = "mapbox_earcut-1.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:4fe92174410e4120022393013705d77cb856ead5bdf6c81bec614a70df4feb5d"}, + {file = "mapbox_earcut-1.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:082f70a865c6164a60af039aa1c377073901cf1f94fd37b1c5610dfbae2a7369"}, + {file = "mapbox_earcut-1.0.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:43d268ece49d0c9e22cb4f92cd54c2cc64f71bf1c5e10800c189880d923e1292"}, + {file = "mapbox_earcut-1.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7748f1730fd36dd1fcf0809d8f872d7e1ddaa945f66a6a466ad37ef3c552ae93"}, + {file = "mapbox_earcut-1.0.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:5a82d10c8dec2a0bd9a6a6c90aca7044017c8dad79f7e209fd0667826f842325"}, + {file = "mapbox_earcut-1.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:01b292588cd3f6bad7d76ee31c004ed1b557a92bbd9602a72d2be15513b755be"}, + {file = "mapbox_earcut-1.0.1-cp312-cp312-win32.whl", hash = "sha256:fce236ddc3a56ea7260acc94601a832c260e6ac5619374bb2cec2e73e7414ff0"}, + {file = "mapbox_earcut-1.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:1ce86407353b4f09f5778c436518bbbc6f258f46c5736446f25074fe3d3a3bd8"}, {file = "mapbox_earcut-1.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:aa6111a18efacb79c081f3d3cdd7d25d0585bb0e9f28896b207ebe1d56efa40e"}, {file = "mapbox_earcut-1.0.1-cp36-cp36m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f2911829d1e6e5e1282fbe2840fadf578f606580f02ed436346c2d51c92f810b"}, {file = "mapbox_earcut-1.0.1-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:01ff909a7b8405a923abedd701b53633c997cc2b5dc9d5b78462f51c25ec2c33"}, @@ -2290,14 +2309,25 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.21.2", markers = "python_version > \"3.9\" and python_version <= \"3.10\""}, {version 
= ">1.20", markers = "python_version <= \"3.9\""}, + {version = ">=1.21.2", markers = "python_version > \"3.9\""}, {version = ">=1.23.3", markers = "python_version > \"3.10\""}, ] [package.extras] dev = ["absl-py", "pyink", "pylint (>=2.6.0)", "pytest", "pytest-xdist"] +[[package]] +name = "monotonic" +version = "1.6" +description = "An implementation of time.monotonic() for Python 2 & < 3.3" +optional = true +python-versions = "*" +files = [ + {file = "monotonic-1.6-py2.py3-none-any.whl", hash = "sha256:68687e19a14f11f26d140dd5c86f3dba4bf5df58003000ed467e0e2a69bca96c"}, + {file = "monotonic-1.6.tar.gz", hash = "sha256:3a55207bcfed53ddd5c5bae174524062935efed17792e9de2ad0205ce9ad63f7"}, +] + [[package]] name = "mpmath" version = "1.3.0" @@ -2842,8 +2872,8 @@ files = [ [package.dependencies] numpy = [ {version = ">=1.20.3", markers = "python_version < \"3.10\""}, + {version = ">=1.21.0", markers = "python_version >= \"3.10\""}, {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, - {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -3060,6 +3090,29 @@ docs = ["sphinx (>=1.7.1)"] redis = ["redis"] tests = ["pytest (>=5.4.1)", "pytest-cov (>=2.8.1)", "pytest-mypy (>=0.8.0)", "pytest-timeout (>=2.1.0)", "redis", "sphinx (>=6.0.0)"] +[[package]] +name = "posthog" +version = "3.0.2" +description = "Integrate PostHog into any python application." +optional = true +python-versions = "*" +files = [ + {file = "posthog-3.0.2-py2.py3-none-any.whl", hash = "sha256:a8c0af6f2401fbe50f90e68c4143d0824b54e872de036b1c2f23b5abb39d88ce"}, + {file = "posthog-3.0.2.tar.gz", hash = "sha256:701fba6e446a4de687c6e861b587e7b7741955ad624bf34fe013c06a0fec6fb3"}, +] + +[package.dependencies] +backoff = ">=1.10.0" +monotonic = ">=1.5" +python-dateutil = ">2.1" +requests = ">=2.7,<3.0" +six = ">=1.5" + +[package.extras] +dev = ["black", "flake8", "flake8-print", "isort", "pre-commit"] +sentry = ["django", "sentry-sdk"] +test = ["coverage", "flake8", "freezegun (==0.3.15)", "mock (>=2.0.0)", "pylint", "pytest"] + [[package]] name = "pre-commit" version = "2.20.0" @@ -3280,6 +3333,22 @@ files = [ {file = "pydub-0.25.1.tar.gz", hash = "sha256:980a33ce9949cab2a569606b65674d748ecbca4f0796887fd6f46173a7b0d30f"}, ] +[[package]] +name = "pyepsilla" +version = "0.2.3" +description = "Epsilla Python SDK" +optional = true +python-versions = "*" +files = [ + {file = "pyepsilla-0.2.3-py3-none-any.whl", hash = "sha256:05bf5f95dc1bd0dfdacac84b844d1505d8aeac442e0c0eadc834ce3ab75ab845"}, + {file = "pyepsilla-0.2.3.tar.gz", hash = "sha256:ce302ad965d428dbb22acb574f51046bfa8456204ead7f874ebd63bb5bc820a0"}, +] + +[package.dependencies] +posthog = "*" +requests = "*" +sentry-sdk = "*" + [[package]] name = "pygments" version = "2.14.0" @@ -4035,6 +4104,51 @@ nativelib = ["pyobjc-framework-Cocoa", "pywin32"] objc = ["pyobjc-framework-Cocoa"] win32 = ["pywin32"] +[[package]] +name = "sentry-sdk" +version = "1.38.0" +description = "Python client for Sentry (https://sentry.io)" +optional = true +python-versions = "*" +files = [ + {file = "sentry-sdk-1.38.0.tar.gz", hash = "sha256:8feab81de6bbf64f53279b085bd3820e3e737403b0a0d9317f73a2c3374ae359"}, + {file = "sentry_sdk-1.38.0-py2.py3-none-any.whl", hash = "sha256:0017fa73b8ae2d4e57fd2522ee3df30453715b29d2692142793ec5d5f90b94a6"}, +] + +[package.dependencies] +certifi = "*" +urllib3 = {version = ">=1.26.11", markers = "python_version >= \"3.6\""} + +[package.extras] +aiohttp = ["aiohttp (>=3.5)"] 
+arq = ["arq (>=0.23)"] +asyncpg = ["asyncpg (>=0.23)"] +beam = ["apache-beam (>=2.12)"] +bottle = ["bottle (>=0.12.13)"] +celery = ["celery (>=3)"] +chalice = ["chalice (>=1.16.0)"] +clickhouse-driver = ["clickhouse-driver (>=0.2.0)"] +django = ["django (>=1.8)"] +falcon = ["falcon (>=1.4)"] +fastapi = ["fastapi (>=0.79.0)"] +flask = ["blinker (>=1.1)", "flask (>=0.11)", "markupsafe"] +grpcio = ["grpcio (>=1.21.1)"] +httpx = ["httpx (>=0.16.0)"] +huey = ["huey (>=2)"] +loguru = ["loguru (>=0.5)"] +opentelemetry = ["opentelemetry-distro (>=0.35b0)"] +opentelemetry-experimental = ["opentelemetry-distro (>=0.40b0,<1.0)", "opentelemetry-instrumentation-aiohttp-client (>=0.40b0,<1.0)", "opentelemetry-instrumentation-django (>=0.40b0,<1.0)", "opentelemetry-instrumentation-fastapi (>=0.40b0,<1.0)", "opentelemetry-instrumentation-flask (>=0.40b0,<1.0)", "opentelemetry-instrumentation-requests (>=0.40b0,<1.0)", "opentelemetry-instrumentation-sqlite3 (>=0.40b0,<1.0)", "opentelemetry-instrumentation-urllib (>=0.40b0,<1.0)"] +pure-eval = ["asttokens", "executing", "pure-eval"] +pymongo = ["pymongo (>=3.1)"] +pyspark = ["pyspark (>=2.4.4)"] +quart = ["blinker (>=1.1)", "quart (>=0.16.1)"] +rq = ["rq (>=0.6)"] +sanic = ["sanic (>=0.8)"] +sqlalchemy = ["sqlalchemy (>=1.2)"] +starlette = ["starlette (>=0.19.1)"] +starlite = ["starlite (>=1.48)"] +tornado = ["tornado (>=5)"] + [[package]] name = "setuptools" version = "65.5.1" @@ -4975,6 +5089,7 @@ testing = ["flake8 (<5)", "func-timeout", "jaraco.functools", "jaraco.itertools" audio = ["pydub"] aws = ["smart-open"] elasticsearch = ["elastic-transport", "elasticsearch"] +epsilla = ["pyepsilla"] full = ["av", "jax", "lz4", "pandas", "pillow", "protobuf", "pydub", "trimesh", "types-pillow"] hnswlib = ["hnswlib", "protobuf"] image = ["pillow", "types-pillow"] @@ -4994,4 +5109,4 @@ web = ["fastapi"] [metadata] lock-version = "2.0" python-versions = ">=3.8,<4.0" -content-hash = "76f26e1728fcb194a46799bccdec97ffcb5778bbb1a73eabb7aa9ee18fbced6e" +content-hash = "469714891dd7e3e6ddb406402602f0b1bb09215bfbd3fd8d237a061a0f6b3167" diff --git a/pyproject.toml b/pyproject.toml index c6444e44be..9eae1d0cee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,6 +61,7 @@ qdrant-client = {version = ">=1.4.0", python = "<3.12", optional = true } pymilvus = {version = "^2.2.12", optional = true } redis = {version = "^4.6.0", optional = true} jax = {version = ">=0.4.10", optional = true} +pyepsilla = {version = ">=0.2.3", optional = true} [tool.poetry.extras] proto = ["protobuf", "lz4"] @@ -80,6 +81,7 @@ weaviate = ["weaviate-client"] milvus = ["pymilvus"] redis = ['redis'] jax = ["jaxlib","jax"] +epsilla = ["pyepsilla"] # all full = ["protobuf", "lz4", "pandas", "pillow", "types-pillow", "av", "pydub", "trimesh", "jax"] diff --git a/tests/index/epsilla/__init__.py b/tests/index/epsilla/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/index/epsilla/common.py b/tests/index/epsilla/common.py new file mode 100644 index 0000000000..4dc0d02336 --- /dev/null +++ b/tests/index/epsilla/common.py @@ -0,0 +1,12 @@ +epsilla_config = { + "protocol": 'http', + "host": 'localhost', + "port": 8888, + "is_self_hosted": True, + "db_path": "/epsilla", + "db_name": "tony_doc_array_test", +} + + +def index_len(index, max_len=20): + return len(index.filter("", limit=max_len)) diff --git a/tests/index/epsilla/conftest.py b/tests/index/epsilla/conftest.py new file mode 100644 index 0000000000..8339a4de99 --- /dev/null +++ b/tests/index/epsilla/conftest.py @@ -0,0 
+1,11 @@ +import random +import string + +import pytest + + +@pytest.fixture(scope='function') +def tmp_index_name(): + letters = string.ascii_lowercase + random_string = ''.join(random.choice(letters) for _ in range(15)) + return random_string diff --git a/tests/index/epsilla/docker-compose.yml b/tests/index/epsilla/docker-compose.yml new file mode 100644 index 0000000000..8be3fa5dba --- /dev/null +++ b/tests/index/epsilla/docker-compose.yml @@ -0,0 +1,12 @@ +version: '3.5' + +services: + standalone: + container_name: epsilla + image: epsilla/vectordb + ports: + - "8888:8888" + +networks: + default: + name: epsilla \ No newline at end of file diff --git a/tests/index/epsilla/fixtures.py b/tests/index/epsilla/fixtures.py new file mode 100644 index 0000000000..260fdf54f8 --- /dev/null +++ b/tests/index/epsilla/fixtures.py @@ -0,0 +1,16 @@ +import os +import time + +import pytest + +cur_dir = os.path.dirname(os.path.abspath(__file__)) +epsilla_yml = os.path.abspath(os.path.join(cur_dir, 'docker-compose.yml')) + + +@pytest.fixture(scope='session', autouse=True) +def start_storage(): + os.system(f"docker compose -f {epsilla_yml} up -d --remove-orphans") + time.sleep(2) + + yield + os.system(f"docker compose -f {epsilla_yml} down --remove-orphans") diff --git a/tests/index/epsilla/test_configuration.py b/tests/index/epsilla/test_configuration.py new file mode 100644 index 0000000000..5bee7fa543 --- /dev/null +++ b/tests/index/epsilla/test_configuration.py @@ -0,0 +1,62 @@ +import numpy as np +import pytest +from pydantic import Field + +from docarray import BaseDoc +from docarray.index import EpsillaDocumentIndex +from docarray.typing import NdArray +from tests.index.epsilla.common import epsilla_config +from tests.index.epsilla.fixtures import start_storage # noqa: F401 + +pytestmark = [pytest.mark.slow, pytest.mark.index] + + +def test_configure_dim(): + class Schema1(BaseDoc): + tens: NdArray[10] = Field(is_embedding=True) + + index = EpsillaDocumentIndex[Schema1](**epsilla_config) + + docs = [Schema1(tens=np.random.random((10,))) for _ in range(10)] + + assert len(index.find(docs[0], limit=30, search_field="tens")[0]) == 0 + + index.index(docs) + + doc_found = index.find(docs[0], limit=1, search_field="tens")[0][0] + assert doc_found.id == docs[0].id + + assert len(index.find(docs[0], limit=30, search_field="tens")[0]) == 10 + + class Schema2(BaseDoc): + tens: NdArray = Field(is_embedding=True, dim=10) + + index = EpsillaDocumentIndex[Schema2](**epsilla_config) + + docs = [Schema2(tens=np.random.random((10,))) for _ in range(10)] + index.index(docs) + + assert len(index.find(docs[0], limit=30, search_field="tens")[0]) == 10 + + class Schema3(BaseDoc): + tens: NdArray = Field(is_embedding=True) + + with pytest.raises(ValueError, match='The dimension information is missing'): + EpsillaDocumentIndex[Schema3](**epsilla_config) + + +def test_incorrect_vector_field(): + class Schema1(BaseDoc): + tens: NdArray[10] + + with pytest.raises(ValueError, match='Unable to find any vector columns'): + EpsillaDocumentIndex[Schema1](**epsilla_config) + + class Schema2(BaseDoc): + tens1: NdArray[10] = Field(is_embedding=True) + tens2: NdArray[20] = Field(is_embedding=True) + + with pytest.raises( + ValueError, match='Specifying multiple vector fields is not supported' + ): + EpsillaDocumentIndex[Schema2](**epsilla_config) diff --git a/tests/index/epsilla/test_find.py b/tests/index/epsilla/test_find.py new file mode 100644 index 0000000000..f360163b11 --- /dev/null +++ b/tests/index/epsilla/test_find.py @@ 
-0,0 +1,323 @@ +import numpy as np +import pytest +from pydantic import Field + +from docarray import BaseDoc, DocList +from docarray.index import EpsillaDocumentIndex +from docarray.typing import NdArray, TorchTensor +from tests.index.epsilla.common import epsilla_config +from tests.index.epsilla.fixtures import start_storage # noqa: F401 + +pytestmark = [pytest.mark.slow, pytest.mark.index] + + +class SimpleDoc(BaseDoc): + tens: NdArray[10] = Field(is_embedding=True, dim=1000) # type: ignore[valid-type] + + +class FlatDoc(BaseDoc): + tens_one: NdArray = Field(is_embedding=True, dim=10) + tens_two: NdArray = Field(dim=50) + + +class TorchDoc(BaseDoc): + tens: TorchTensor[10] = Field(is_embedding=True) # type: ignore[valid-type] + + +@pytest.mark.parametrize('space', ['l2', 'ip']) +def test_find_simple_schema(space, tmp_index_name): + class SimpleSchema(BaseDoc): + tens: NdArray[10] = Field(is_embedding=True, space=space) # type: ignore[valid-type] + + index = EpsillaDocumentIndex[SimpleSchema]( + **epsilla_config, table_name=tmp_index_name + ) + + index_docs = [SimpleDoc(tens=np.zeros(10)) for _ in range(10)] + index_docs.append(SimpleDoc(tens=np.ones(10))) + index.index(index_docs) + + query = SimpleDoc(tens=np.ones(10)) + + docs, scores = index.find(query, limit=5, search_field="tens") + + assert len(docs) == 5 + assert len(scores) == 5 + + +def test_find_torch(tmp_index_name): + index = EpsillaDocumentIndex[TorchDoc](**epsilla_config, table_name=tmp_index_name) + + index_docs = [TorchDoc(tens=np.zeros(10)) for _ in range(10)] + index_docs.append(TorchDoc(tens=np.ones(10))) + index.index(index_docs) + + for doc in index_docs: + assert isinstance(doc.tens, TorchTensor) + + query = TorchDoc(tens=np.ones(10)) + + result_docs, scores = index.find(query, limit=5, search_field="tens") + + assert len(result_docs) == 5 + assert len(scores) == 5 + for doc in result_docs: + assert isinstance(doc.tens, TorchTensor) + + +@pytest.mark.tensorflow +def test_find_tensorflow(): + from docarray.typing import TensorFlowTensor + + class TfDoc(BaseDoc): + tens: TensorFlowTensor[10] = Field(is_embedding=True) # type: ignore[valid-type] + + index = EpsillaDocumentIndex[TfDoc](**epsilla_config) + + index_docs = [TfDoc(tens=np.random.rand(10)) for _ in range(10)] + index.index(index_docs) + + for doc in index_docs: + assert isinstance(doc.tens, TensorFlowTensor) + + query = index_docs[-1] + docs, scores = index.find(query, limit=5, search_field="tens") + + assert len(docs) == 5 + assert len(scores) == 5 + for doc in docs: + assert isinstance(doc.tens, TensorFlowTensor) + + +def test_find_batched(tmp_index_name): # noqa: F811 + class SimpleSchema(BaseDoc): + tens: NdArray[10] = Field(is_embedding=True) + + index = EpsillaDocumentIndex[SimpleSchema]( + **epsilla_config, table_name=tmp_index_name + ) + + index_docs = [SimpleDoc(tens=vector) for vector in np.identity(10)] + index.index(index_docs) + + queries = DocList[SimpleDoc]( + [ + SimpleDoc( + tens=np.array([0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) + ), + SimpleDoc( + tens=np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1]) + ), + ] + ) + + docs, scores = index.find_batched(queries, limit=1, search_field="tens") + + assert len(docs) == 2 + assert len(docs[0]) == 1 + assert len(docs[1]) == 1 + assert len(scores) == 2 + assert len(scores[0]) == 1 + assert len(scores[1]) == 1 + + +def test_contain(tmp_index_name): + class SimpleDoc(BaseDoc): + tens: NdArray[10] = Field(is_embedding=True) + + class SimpleSchema(BaseDoc): + tens: NdArray[10] = 
Field(is_embedding=True) + + index = EpsillaDocumentIndex[SimpleSchema]( + **epsilla_config, table_name=tmp_index_name + ) + index_docs = [SimpleDoc(tens=np.zeros(10)) for _ in range(10)] + + assert (index_docs[0] in index) is False + + index.index(index_docs) + + for doc in index_docs: + assert (doc in index) is True + + index_docs_new = [SimpleDoc(tens=np.zeros(10)) for _ in range(10)] + for doc in index_docs_new: + assert (doc in index) is False + + +@pytest.mark.parametrize('space', ['l2', 'ip']) +def test_find_flat_schema(space, tmp_index_name): + class FlatSchema(BaseDoc): + tens_one: NdArray[10] = Field(space=space, is_embedding=True) + tens_two: NdArray[50] = Field(space=space) + + index = EpsillaDocumentIndex[FlatSchema]( + **epsilla_config, table_name=tmp_index_name + ) + + index_docs = [ + FlatDoc(tens_one=np.zeros(10), tens_two=np.zeros(50)) for _ in range(10) + ] + index_docs.append(FlatDoc(tens_one=np.zeros(10), tens_two=np.ones(50))) + index_docs.append(FlatDoc(tens_one=np.ones(10), tens_two=np.zeros(50))) + index.index(index_docs) + + query = FlatDoc(tens_one=np.ones(10), tens_two=np.ones(50)) + + # find on tens_one + docs, scores = index.find(query, limit=5, search_field="tens_one") + assert len(docs) == 5 + assert len(scores) == 5 + + +def test_find_nested_schema(tmp_index_name): + class SimpleDoc(BaseDoc): + tens: NdArray[10] # type: ignore[valid-type] + + class NestedDoc(BaseDoc): + d: SimpleDoc + tens: NdArray[10] # type: ignore[valid-type] + + class DeepNestedDoc(BaseDoc): + d: NestedDoc + tens: NdArray[10] = Field(is_embedding=True) + + index = EpsillaDocumentIndex[DeepNestedDoc]( + **epsilla_config, table_name=tmp_index_name + ) + + index_docs = [ + DeepNestedDoc( + d=NestedDoc(d=SimpleDoc(tens=np.zeros(10)), tens=np.zeros(10)), + tens=np.zeros(10), + ) + for _ in range(10) + ] + index_docs.append( + DeepNestedDoc( + d=NestedDoc(d=SimpleDoc(tens=np.ones(10)), tens=np.zeros(10)), + tens=np.zeros(10), + ) + ) + index_docs.append( + DeepNestedDoc( + d=NestedDoc(d=SimpleDoc(tens=np.zeros(10)), tens=np.ones(10)), + tens=np.zeros(10), + ) + ) + index_docs.append( + DeepNestedDoc( + d=NestedDoc(d=SimpleDoc(tens=np.zeros(10)), tens=np.zeros(10)), + tens=np.ones(10), + ) + ) + index.index(index_docs) + + query = DeepNestedDoc( + d=NestedDoc(d=SimpleDoc(tens=np.ones(10)), tens=np.ones(10)), tens=np.ones(10) + ) + + # find on root level (only support one level now) + docs, scores = index.find(query, limit=5, search_field="tens") + assert len(docs) == 5 + assert len(scores) == 5 + + +def test_find_empty_index(tmp_index_name): + empty_index = EpsillaDocumentIndex[SimpleDoc]( + **epsilla_config, table_name=tmp_index_name + ) + query = SimpleDoc(tens=np.random.rand(10)) + + # find + docs, scores = empty_index.find(query, limit=5, search_field="tens") + assert len(docs) == 0 + assert len(scores) == 0 + + # find_batched + queries = DocList[SimpleDoc]( + [ + SimpleDoc( + tens=np.array([0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) + ), + SimpleDoc( + tens=np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1]) + ), + ] + ) + docs_list, scores = empty_index.find_batched(queries, limit=10, search_field="tens") + + for docs in docs_list: + assert len(docs) == 0 + + +def test_simple_usage(tmp_index_name): + class MyDoc(BaseDoc): + text: str + embedding: NdArray[128] = Field(is_embedding=True) + + docs = [MyDoc(text='hey', embedding=np.random.rand(128)) for _ in range(200)] + queries = docs[0:3] + index = EpsillaDocumentIndex[MyDoc](**epsilla_config, table_name=tmp_index_name) + 
index.index(docs=DocList[MyDoc](docs)) + resp = index.find_batched(queries=queries, limit=5, search_field="embedding") + docs_responses = resp.documents + assert len(docs_responses) == 3 + for q, matches in zip(queries, docs_responses): + assert len(matches) == 5 + assert q.id == matches[0].id + + +def test_filter_range(tmp_index_name): # noqa: F811 + class SimpleSchema(BaseDoc): + embedding: NdArray[10] = Field(space='l2', is_embedding=True) # type: ignore[valid-type] + number: int + + index = EpsillaDocumentIndex[SimpleSchema]( + **epsilla_config, table_name=tmp_index_name + ) + + docs = index.filter("number > 8", limit=5) + assert len(docs) == 0 + + index_docs = [ + SimpleSchema( + embedding=np.zeros(10), + number=i, + ) + for i in range(10) + ] + index.index(index_docs) + + docs = index.filter("number > 8", limit=5) + + assert len(docs) == 1 + + docs = index.filter(f"id = '{index_docs[0].id}'", limit=5) + assert docs[0].id == index_docs[0].id + + +def test_query_builder(tmp_index_name): + class SimpleSchema(BaseDoc): + tensor: NdArray[10] = Field(is_embedding=True) + price: int + + db = EpsillaDocumentIndex[SimpleSchema](**epsilla_config, table_name=tmp_index_name) + + index_docs = [ + SimpleSchema(tensor=np.array([i + 1] * 10), price=i + 1) for i in range(10) + ] + db.index(index_docs) + + q = ( + db.build_query() + .find(query=np.ones(10), search_field="tensor") + .filter(filter_query='price <= 3') + .build(limit=5) + ) + + docs = db.execute_query(q) + + assert len(docs) == 3 + for doc in docs: + assert doc.price <= 3 diff --git a/tests/index/epsilla/test_index_get_del.py b/tests/index/epsilla/test_index_get_del.py new file mode 100644 index 0000000000..2fdf066c56 --- /dev/null +++ b/tests/index/epsilla/test_index_get_del.py @@ -0,0 +1,155 @@ +import numpy as np +import pytest +import torch +from pydantic import Field + +from docarray import BaseDoc, DocList +from docarray.index import EpsillaDocumentIndex +from docarray.typing import NdArray, TorchTensor +from tests.index.epsilla.common import epsilla_config, index_len +from tests.index.epsilla.fixtures import start_storage # noqa: F401 + +pytestmark = [pytest.mark.slow, pytest.mark.index] + + +class SimpleDoc(BaseDoc): + tens: NdArray[10] = Field(is_embedding=True) + + +class FlatDoc(BaseDoc): + tens_one: NdArray[10] = Field(is_embedding=True) + tens_two: NdArray[50] + + +class NestedDoc(BaseDoc): + d: SimpleDoc + + +class DeepNestedDoc(BaseDoc): + d: NestedDoc + + +class TorchDoc(BaseDoc): + tens: TorchTensor[10] = Field(is_embedding=True) # type: ignore[valid-type] + + +@pytest.fixture +def ten_simple_docs(): + return [SimpleDoc(tens=np.random.randn(10)) for _ in range(10)] + + +@pytest.fixture +def ten_flat_docs(): + return [ + FlatDoc(tens_one=np.random.randn(10), tens_two=np.random.randn(50)) + for _ in range(10) + ] + + +@pytest.fixture +def ten_nested_docs(): + return [NestedDoc(d=SimpleDoc(tens=np.random.randn(10))) for _ in range(10)] + + +@pytest.mark.parametrize('use_docarray', [True, False]) +def test_index_simple_schema( + ten_simple_docs, use_docarray, tmp_index_name +): # noqa: F811 + index = EpsillaDocumentIndex[SimpleDoc](**epsilla_config, table_name=tmp_index_name) + if use_docarray: + ten_simple_docs = DocList[SimpleDoc](ten_simple_docs) + + index.index(ten_simple_docs) + assert index_len(index) == 10 + + +@pytest.mark.parametrize('use_docarray', [True, False]) +def test_index_flat_schema(ten_flat_docs, use_docarray, tmp_index_name): # noqa: F811 + index = EpsillaDocumentIndex[FlatDoc](**epsilla_config, 
table_name=tmp_index_name) + if use_docarray: + ten_flat_docs = DocList[FlatDoc](ten_flat_docs) + + index.index(ten_flat_docs) + assert index_len(index) == 10 + + +def test_index_torch(tmp_index_name): + docs = [TorchDoc(tens=np.random.randn(10)) for _ in range(10)] + assert isinstance(docs[0].tens, torch.Tensor) + assert isinstance(docs[0].tens, TorchTensor) + + index = EpsillaDocumentIndex[TorchDoc](**epsilla_config, table_name=tmp_index_name) + + index.index(docs) + assert index_len(index) == 10 + + +def test_del_from_empty(ten_simple_docs, tmp_index_name): # noqa: F811 + index = EpsillaDocumentIndex[SimpleDoc](**epsilla_config, table_name=tmp_index_name) + assert index_len(index) == 0 + del index[ten_simple_docs[0].id] + assert index_len(index) == 0 + + +def test_del_single(ten_simple_docs, tmp_index_name): # noqa: F811 + index = EpsillaDocumentIndex[SimpleDoc](**epsilla_config, table_name=tmp_index_name) + index.index(ten_simple_docs) + # delete once + assert index_len(index) == 10 + del index[ten_simple_docs[0].id] + assert index_len(index) == 9 + for i, d in enumerate(ten_simple_docs): + id_ = d.id + if i == 0: # deleted + with pytest.raises(KeyError): + index[id_] + else: + assert index[id_].id == id_ + # delete again + del index[ten_simple_docs[3].id] + assert index_len(index) == 8 + for i, d in enumerate(ten_simple_docs): + id_ = d.id + if i in (0, 3): # deleted + with pytest.raises(KeyError): + index[id_] + else: + assert index[id_].id == id_ + + +def test_del_multiple(ten_simple_docs, tmp_index_name): + docs_to_del_idx = [0, 2, 4, 6, 8] + + index = EpsillaDocumentIndex[SimpleDoc](**epsilla_config, table_name=tmp_index_name) + index.index(ten_simple_docs) + + assert index_len(index) == 10 + docs_to_del = [ten_simple_docs[i] for i in docs_to_del_idx] + ids_to_del = [d.id for d in docs_to_del] + del index[ids_to_del] + for i, doc in enumerate(ten_simple_docs): + if i in docs_to_del_idx: + with pytest.raises(KeyError): + index[doc.id] + else: + assert index[doc.id].id == doc.id + + +def test_num_docs(ten_simple_docs, tmp_index_name): # noqa: F811 + index = EpsillaDocumentIndex[SimpleDoc](**epsilla_config, table_name=tmp_index_name) + index.index(ten_simple_docs) + + assert index_len(index) == 10 + + del index[ten_simple_docs[0].id] + assert index_len(index) == 9 + + del index[ten_simple_docs[3].id, ten_simple_docs[5].id] + assert index_len(index) == 7 + + more_docs = [SimpleDoc(tens=np.random.rand(10)) for _ in range(5)] + index.index(more_docs) + assert index_len(index) == 12 + + del index[more_docs[2].id, ten_simple_docs[7].id] # type: ignore[arg-type] + assert index_len(index) == 10 diff --git a/tests/index/epsilla/test_persist_data.py b/tests/index/epsilla/test_persist_data.py new file mode 100644 index 0000000000..16bd6d16c4 --- /dev/null +++ b/tests/index/epsilla/test_persist_data.py @@ -0,0 +1,42 @@ +import numpy as np +import pytest +from pydantic import Field + +from docarray import BaseDoc +from docarray.index import EpsillaDocumentIndex +from docarray.typing import NdArray +from tests.index.epsilla.common import epsilla_config, index_len +from tests.index.epsilla.fixtures import start_storage # noqa: F401 + +pytestmark = [pytest.mark.slow, pytest.mark.index] + + +class SimpleDoc(BaseDoc): + tens: NdArray[10] = Field(is_embedding=True) + + +def test_persist(tmp_index_name): + query = SimpleDoc(tens=np.random.random((10,))) + + # create index + index = EpsillaDocumentIndex[SimpleDoc](**epsilla_config, table_name=tmp_index_name) + + index_name = index.index_name + + assert 
index_len(index) == 0
+
+    index.index([SimpleDoc(tens=np.random.random((10,))) for _ in range(10)])
+    assert index_len(index) == 10
+    find_results_before = index.find(query, limit=5, search_field="tens")
+
+    # load existing index
+    index = EpsillaDocumentIndex[SimpleDoc](**epsilla_config, table_name=index_name)
+    assert index_len(index) == 10
+    find_results_after = index.find(query, limit=5, search_field="tens")
+    for doc_before, doc_after in zip(find_results_before[0], find_results_after[0]):
+        assert doc_before.id == doc_after.id
+        assert (doc_before.tens == doc_after.tens).all()
+
+    # add new data
+    index.index([SimpleDoc(tens=np.random.random((10,))) for _ in range(5)])
+    assert index_len(index) == 15
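
For reviewers: a minimal usage sketch of the new `EpsillaDocumentIndex` backend, distilled from the tests in this diff. The connection values mirror `tests/index/epsilla/common.py` and assume the self-hosted Epsilla container from `tests/index/epsilla/docker-compose.yml` is running on `localhost:8888`; the schema, `db_name`, and `table_name` below are illustrative, not part of the change.

```python
import numpy as np
from pydantic import Field

from docarray import BaseDoc, DocList
from docarray.index import EpsillaDocumentIndex
from docarray.typing import NdArray


class MyDoc(BaseDoc):
    text: str
    price: int
    # exactly one vector column with is_embedding=True is required by this backend
    embedding: NdArray[128] = Field(is_embedding=True)


# self-hosted configuration, mirroring tests/index/epsilla/common.py
index = EpsillaDocumentIndex[MyDoc](
    protocol='http',
    host='localhost',
    port=8888,
    is_self_hosted=True,
    db_path='/epsilla',
    db_name='example_db',   # illustrative database name
    table_name='my_docs',   # defaults to the schema class name if omitted
)

docs = DocList[MyDoc](
    [MyDoc(text=f'doc {i}', price=i, embedding=np.random.rand(128)) for i in range(100)]
)
index.index(docs)

# vector search requires an explicit search_field
matches, scores = index.find(docs[0], limit=5, search_field='embedding')

# combined vector search and filtering via the query builder,
# as exercised in tests/index/epsilla/test_find.py::test_query_builder
q = (
    index.build_query()
    .find(query=np.ones(128), search_field='embedding')
    .filter(filter_query='price <= 3')
    .build(limit=5)
)
filtered_docs = index.execute_query(q)
```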