diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 210134ac4ae..377e2311215 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -65,6 +65,7 @@ jobs:
           python -m pip install poetry
           poetry install --without dev
           poetry run pip install tensorflow==2.11.0
+          poetry run pip install jax
 
       - name: Test basic import
         run: poetry run python -c 'from docarray import DocList, BaseDoc'
@@ -111,7 +112,7 @@ jobs:
       - name: Test
         id: test
         run: |
-          poetry run pytest -m "not (tensorflow or benchmark or index)" --cov=docarray --cov-report=xml ${{ matrix.test-path }} --ignore=tests/integrations/store/test_jac.py
+          poetry run pytest -m "not (tensorflow or benchmark or index or jax)" --cov=docarray --cov-report=xml ${{ matrix.test-path }} --ignore=tests/integrations/store/test_jac.py
           echo "flag it as docarray for codeoverage"
           echo "codecov_flag=docarray" >> $GITHUB_OUTPUT
         timeout-minutes: 30
@@ -158,7 +159,7 @@ jobs:
       - name: Test
         id: test
         run: |
-          poetry run pytest -m "not (tensorflow or benchmark or index)" --cov=docarray --cov-report=xml tests/integrations/store/test_jac.py
+          poetry run pytest -m "not (tensorflow or benchmark or index or jax)" --cov=docarray --cov-report=xml tests/integrations/store/test_jac.py
           echo "flag it as docarray for codeoverage"
           echo "codecov_flag=docarray" >> $GITHUB_OUTPUT
         timeout-minutes: 30
@@ -357,6 +358,49 @@ jobs:
           flags: ${{ steps.test.outputs.codecov_flag }}
           fail_ci_if_error: false
 
+  docarray-test-jax:
+    needs: [import-test]
+    runs-on: ubuntu-latest
+    strategy:
+      fail-fast: false
+      matrix:
+        python-version: [3.8]
+    steps:
+      - uses: actions/checkout@v2.5.0
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v4
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Prepare environment
+        run: |
+          python -m pip install --upgrade pip
+          python -m pip install poetry
+          poetry install --all-extras
+          poetry run pip install jaxlib
+          poetry run pip install jax
+
+      - name: Test
+        id: test
+        run: |
+          poetry run pytest -m 'jax' --cov=docarray --cov-report=xml tests
+          echo "flag it as docarray for codeoverage"
+          echo "codecov_flag=docarray" >> $GITHUB_OUTPUT
+        timeout-minutes: 30
+      - name: Check codecov file
+        id: check_files
+        uses: andstor/file-existence-action@v1
+        with:
+          files: "coverage.xml"
+      - name: Upload coverage from test to Codecov
+        uses: codecov/codecov-action@v3.1.1
+        if: ${{ steps.check_files.outputs.files_exists == 'true' && matrix.python-version == '3.8' }}
+        with:
+          file: coverage.xml
+          name: benchmark-test-codecov
+          flags: ${{ steps.test.outputs.codecov_flag }}
+          fail_ci_if_error: false
+
+
 
   docarray-test-benchmarks:
     needs: [import-test]
diff --git a/docarray/array/doc_vec/doc_vec.py b/docarray/array/doc_vec/doc_vec.py
index 4778cd44604..a175cb4e4aa 100644
--- a/docarray/array/doc_vec/doc_vec.py
+++ b/docarray/array/doc_vec/doc_vec.py
@@ -32,7 +32,11 @@
 from docarray.typing import NdArray
 from docarray.typing.tensor.abstract_tensor import AbstractTensor
 from docarray.utils._internal._typing import is_tensor_union, safe_issubclass
-from docarray.utils._internal.misc import is_tf_available, is_torch_available
+from docarray.utils._internal.misc import (
+    is_jax_available,
+    is_tf_available,
+    is_torch_available,
+)
 
 if TYPE_CHECKING:
     import csv
@@ -60,6 +64,14 @@
 else:
     TensorFlowTensor = None  # type: ignore
 
+jnp_available = is_jax_available()
+if jnp_available:
+    import jax.numpy as jnp  # type: ignore
+
+    from docarray.typing import JaxArray  # noqa: F401
+else:
+    JaxArray = None  # type: ignore
+
 T_doc = TypeVar('T_doc', bound=BaseDoc)
 T = TypeVar('T', bound='DocVec')
 T_io_mixin = TypeVar('T_io_mixin', bound='IOMixinArray')
@@ -262,6 +274,19 @@ def _check_doc_field_not_none(field_name, doc):
 
                     stacked: tf.Tensor = tf.stack(tf_stack)
                     tensor_columns[field_name] = TensorFlowTensor(stacked)
+            elif jnp_available and issubclass(field_type, JaxArray):
+                if first_doc_is_none:
+                    _verify_optional_field_of_docs(docs)
+                    tensor_columns[field_name] = None
+                else:
+                    jax_stack = []
+                    for doc in docs:
+                        val = getattr(doc, field_name)
+                        _check_doc_field_not_none(field_name, doc)
+                        jax_stack.append(val.tensor)
+
+                    jax_stacked: jnp.ndarray = jnp.stack(jax_stack)
+                    tensor_columns[field_name] = JaxArray(jax_stacked)
 
             elif safe_issubclass(field_type, AbstractTensor):
                 if first_doc_is_none:
@@ -835,7 +860,6 @@ def to_doc_list(self: T) -> DocList[T_doc]:
             unstacked_doc_column[field] = doc_col.to_doc_list() if doc_col else None
 
         for field, da_col in self._storage.docs_vec_columns.items():
-
             unstacked_da_column[field] = (
                 [docs.to_doc_list() for docs in da_col] if da_col else None
             )
diff --git a/docarray/array/list_advance_indexing.py b/docarray/array/list_advance_indexing.py
index 25e966480c8..c3d80ad2f6c 100644
--- a/docarray/array/list_advance_indexing.py
+++ b/docarray/array/list_advance_indexing.py
@@ -14,8 +14,9 @@
 from typing_extensions import SupportsIndex
 
 from docarray.utils._internal.misc import (
-    is_torch_available,
+    is_jax_available,
     is_tf_available,
+    is_torch_available,
 )
 
 torch_available = is_torch_available()
@@ -24,7 +25,13 @@
 tf_available = is_tf_available()
 if tf_available:
     import tensorflow as tf  # type: ignore
+
     from docarray.typing.tensor.tensorflow_tensor import TensorFlowTensor
+jax_available = is_jax_available()
+if jax_available:
+    import jax.numpy as jnp  # type: ignore
+
+    from docarray.typing.tensor.jaxarray import JaxArray
 
 T_item = TypeVar('T_item')
 T = TypeVar('T', bound='ListAdvancedIndexing')
@@ -100,6 +107,12 @@ def _normalize_index_item(
             if isinstance(item, TensorFlowTensor):
                 return item.tensor.numpy().tolist()
 
+        if jax_available:
+            if isinstance(item, jnp.ndarray):
+                return item.__array__().tolist()
+            if isinstance(item, JaxArray):
+                return item.tensor.__array__().tolist()
+
         return item
 
     def _get_from_indices(self: T, item: Iterable[int]) -> T:
diff --git a/docarray/computation/jax_backend.py b/docarray/computation/jax_backend.py
new file mode 100644
index 00000000000..f571c79b701
--- /dev/null
+++ 
b/docarray/computation/jax_backend.py
@@ -0,0 +1,340 @@
+from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
+
+import numpy as np
+
+from docarray.computation.abstract_comp_backend import AbstractComputationalBackend
+from docarray.computation.abstract_numpy_based_backend import AbstractNumpyBasedBackend
+from docarray.typing import JaxArray
+from docarray.utils._internal.misc import import_library
+
+if TYPE_CHECKING:
+    import jax
+    import jax.numpy as jnp
+else:
+    jax = import_library('jax', raise_error=True)
+    jnp = jax.numpy
+
+
+def _expand_if_single_axis(*matrices: jnp.ndarray) -> List[jnp.ndarray]:
+    """Expands arrays that only have one axis, at dim 0.
+    This ensures that all outputs can be treated as matrices, not vectors.
+
+    :param matrices: Matrices to be expanded
+    :return: List of the input matrices,
+        where single axis matrices are expanded at dim 0.
+    """
+    expanded = []
+    for m in matrices:
+        if len(m.shape) == 1:
+            expanded.append(jnp.expand_dims(m, axis=0))
+        else:
+            expanded.append(m)
+    return expanded
+
+
+def _expand_if_scalar(arr: jnp.ndarray) -> jnp.ndarray:
+    """Expand a 0-dimensional array to shape (1,); return other arrays as is."""
+    if len(arr.shape) == 0:  # avoid scalar output
+        arr = jnp.expand_dims(arr, axis=0)
+    return arr
+
+
+def norm_left(t: jnp.ndarray) -> JaxArray:
+    """Wrap a native jax array in a `JaxArray`."""
+    return JaxArray(tensor=t)
+
+
+def norm_right(t: JaxArray) -> jnp.ndarray:
+    """Return the native jax array wrapped by a `JaxArray`."""
+    return t.tensor
+
+
+class JaxCompBackend(AbstractNumpyBasedBackend):
+    """
+    Computational backend for Jax.
+    """
+
+    _module = jnp
+    _cast_output: Callable = norm_left
+    _get_tensor: Callable = norm_right
+
+    @classmethod
+    def to_device(cls, tensor: 'JaxArray', device: str) -> 'JaxArray':
+        """Move the tensor to the specified device."""
+        if cls.device(tensor) == device:
+            return tensor
+        else:
+            # `jax.devices` returns a list; `device_put` needs a single device
+            target_device = jax.devices(device)[0]
+            return cls._cast_output(
+                jax.device_put(cls._get_tensor(tensor), target_device)
+            )
+
+    @classmethod
+    def device(cls, tensor: 'JaxArray') -> Optional[str]:
+        """Return device on which the tensor is allocated."""
+        return cls._get_tensor(tensor).device().platform
+
+    @classmethod
+    def to_numpy(cls, array: 'JaxArray') -> 'np.ndarray':
+        """Convert the tensor to a numpy array."""
+        return cls._get_tensor(array).__array__()
+
+    @classmethod
+    def none_value(cls) -> Any:
+        """Provide a compatible value that represents None in JAX."""
+        return jnp.nan
+
+    @classmethod
+    def detach(cls, tensor: 'JaxArray') -> 'JaxArray':
+        """
+        Returns the tensor detached from its current graph.
+
+        :param tensor: tensor to be detached
+        :return: a detached tensor with the same data.
+        """
+        return cls._cast_output(jax.lax.stop_gradient(cls._get_tensor(tensor)))
+
+    @classmethod
+    def dtype(cls, tensor: 'JaxArray') -> jnp.dtype:
+        """Get the data type of the tensor, as its string name (e.g. 'float32')."""
+        d_type = cls._get_tensor(tensor).dtype
+        return d_type.name
+
+    @classmethod
+    def minmax_normalize(
+        cls,
+        tensor: 'JaxArray',
+        t_range: Tuple = (0, 1),
+        x_range: Optional[Tuple] = None,
+        eps: float = 1e-7,
+    ) -> 'JaxArray':
+        """
+        Normalize values in `tensor` into `t_range`.
+
+        `tensor` can be a 1D array or a 2D array. When `tensor` is a 2D array, then
+        normalization is row-based.
+
+        !!! note
+
+            - with `t_range=(0, 1)` will normalize the min-value of data to 0, max to 1;
+            - with `t_range=(1, 0)` will normalize the min-value of data to 1, max value
+              of the data to 0.
+
+        :param tensor: the data to be normalized
+        :param t_range: a tuple represents the target range.
+        :param x_range: a tuple represents tensors range.
+        :param eps: a small jitter to avoid dividing by zero
+        :return: normalized data in `t_range`
+        """
+        a, b = t_range
+
+        t = jnp.asarray(cls._get_tensor(tensor), jnp.float32)
+
+        min_d = x_range[0] if x_range else jnp.min(t, axis=-1, keepdims=True)
+        max_d = x_range[1] if x_range else jnp.max(t, axis=-1, keepdims=True)
+        r = (b - a) * (t - min_d) / (max_d - min_d + eps) + a
+
+        normalized = jnp.clip(r, *((a, b) if a < b else (b, a)))
+        return cls._cast_output(jnp.asarray(normalized, cls._get_tensor(tensor).dtype))
+
+    @classmethod
+    def equal(cls, tensor1: 'JaxArray', tensor2: 'JaxArray') -> bool:
+        """
+        Check if two tensors are equal.
+
+        :param tensor1: the first tensor
+        :param tensor2: the second tensor
+        :return: True if two tensors are equal, False otherwise.
+            If one or more of the inputs is not a JaxArray, return False.
+        """
+        t1, t2 = getattr(tensor1, 'tensor', None), getattr(tensor2, 'tensor', None)
+        if isinstance(t1, jnp.ndarray) and isinstance(t2, jnp.ndarray):
+            # compare shapes first so arrays of different shapes never broadcast
+            return bool(t1.shape == t2.shape and jnp.all(jnp.equal(t1, t2)))
+        return False
+
+    class Retrieval(AbstractComputationalBackend.Retrieval[JaxArray]):
+        """
+        Retrieval and ranking functionalities for the JAX backend
+        """
+
+        @staticmethod
+        def top_k(
+            values: 'JaxArray',
+            k: int,
+            descending: bool = False,
+            device: Optional[str] = None,
+        ) -> Tuple['JaxArray', 'JaxArray']:
+            """
+            Returns the k smallest values in `values` along with their indices.
+            Can also be used to retrieve the k largest values,
+            by setting the `descending` flag.
+
+            :param values: Jax tensor of values to rank.
+                Should be of shape (n_queries, n_values_per_query).
+                Inputs of shape (n_values_per_query,) will be expanded
+                to (1, n_values_per_query).
+            :param k: number of values to retrieve
+            :param descending: retrieve largest values instead of smallest values
+            :param device: if given, `values` is moved to this device before ranking
+            :return: Tuple containing the retrieved values, and their indices.
+                Both are of shape (n_queries, k)
+            """
+            comp_be = JaxCompBackend
+            if device is not None:
+                values = comp_be.to_device(values, device)
+
+            jax_values: jnp.ndarray = comp_be._get_tensor(values)
+
+            if len(jax_values.shape) == 1:
+                jax_values = jnp.expand_dims(jax_values, axis=0)
+
+            if descending:
+                jax_values = -jax_values
+
+            if k >= jax_values.shape[1]:
+                idx = jax_values.argsort(axis=1)[:, :k]
+                jax_values = jnp.take_along_axis(jax_values, idx, axis=1)
+            else:
+                idx_ps = jax_values.argpartition(kth=k, axis=1)[:, :k]
+                jax_values = jnp.take_along_axis(jax_values, idx_ps, axis=1)
+                idx_fs = jax_values.argsort(axis=1)
+                idx = jnp.take_along_axis(idx_ps, idx_fs, axis=1)
+                jax_values = jnp.take_along_axis(jax_values, idx_fs, axis=1)
+
+            if descending:
+                jax_values = -jax_values
+
+            return comp_be._cast_output(jax_values), comp_be._cast_output(idx)
+
+    class Metrics(AbstractComputationalBackend.Metrics[JaxArray]):
+        """
+        Metrics (distances and similarities) for the JAX backend.
+        """
+
+        @staticmethod
+        def cosine_sim(
+            x_mat: 'JaxArray',
+            y_mat: 'JaxArray',
+            eps: float = 1e-7,
+            device: Optional[str] = None,
+        ) -> 'JaxArray':
+            """Pairwise cosine similarities between all vectors in x_mat and y_mat.
+
+            :param x_mat: tensor of shape (n_vectors, n_dim), where n_vectors is the
+                number of vectors and n_dim is the number of dimensions of each example.
+            :param y_mat: tensor of shape (n_vectors, n_dim), where n_vectors is the
+                number of vectors and n_dim is the number of dimensions of each example.
+            :param eps: a small jitter to avoid dividing by zero
+            :param device: accepted for interface compatibility;
+                it is not used by this implementation.
+            :return: JaxArray of shape (n_vectors, n_vectors) containing all pairwise
+                cosine similarities.
+                The index [i_x, i_y] contains the cosine similarity between
+                x_mat[i_x] and y_mat[i_y].
+            """
+            comp_be = JaxCompBackend
+            x_mat_jax: jnp.ndarray = comp_be._get_tensor(x_mat)
+            y_mat_jax: jnp.ndarray = comp_be._get_tensor(y_mat)
+
+            x_mat_jax, y_mat_jax = _expand_if_single_axis(x_mat_jax, y_mat_jax)
+
+            sims = jnp.clip(
+                (jnp.dot(x_mat_jax, y_mat_jax.T) + eps)
+                / (
+                    jnp.outer(
+                        jnp.linalg.norm(x_mat_jax, axis=1),
+                        jnp.linalg.norm(y_mat_jax, axis=1),
+                    )
+                    + eps
+                ),
+                -1,
+                1,
+            ).squeeze()
+            sims = _expand_if_scalar(sims)
+
+            return comp_be._cast_output(sims)
+
+        @classmethod
+        def euclidean_dist(
+            cls, x_mat: JaxArray, y_mat: JaxArray, device: Optional[str] = None
+        ) -> JaxArray:
+            """Pairwise Euclidean distances between all vectors in x_mat and y_mat.
+
+            :param x_mat: JaxArray of shape (n_vectors, n_dim), where n_vectors is
+                the number of vectors and n_dim is the number of dimensions of each
+                example.
+            :param y_mat: JaxArray of shape (n_vectors, n_dim), where n_vectors is
+                the number of vectors and n_dim is the number of dimensions of each
+                example.
+            :param device: Not supported for this backend
+            :return: JaxArray of shape (n_vectors, n_vectors) containing all
+                pairwise Euclidean distances.
+                The index [i_x, i_y] contains the Euclidean distance between
+                x_mat[i_x] and y_mat[i_y].
+            """
+            comp_be = JaxCompBackend
+            x_mat_jax: jnp.ndarray = comp_be._get_tensor(x_mat)
+            y_mat_jax: jnp.ndarray = comp_be._get_tensor(y_mat)
+            if device is not None:
+                # warnings.warn('`device` is not supported for numpy operations')
+                pass
+
+            x_mat_jax, y_mat_jax = _expand_if_single_axis(x_mat_jax, y_mat_jax)
+
+            x_mat_jax_arr: JaxArray = comp_be._cast_output(x_mat_jax)
+            y_mat_jax_arr: JaxArray = comp_be._cast_output(y_mat_jax)
+
+            dists = _expand_if_scalar(
+                jnp.sqrt(
+                    comp_be._get_tensor(
+                        cls.sqeuclidean_dist(x_mat_jax_arr, y_mat_jax_arr)
+                    )
+                ).squeeze()
+            )
+
+            return comp_be._cast_output(dists)
+
+        @staticmethod
+        def sqeuclidean_dist(
+            x_mat: JaxArray,
+            y_mat: JaxArray,
+            device: Optional[str] = None,
+        ) -> JaxArray:
+            """Pairwise Squared Euclidean distances between all vectors in
+            x_mat and y_mat.
+
+            :param x_mat: JaxArray of shape (n_vectors, n_dim), where n_vectors is
+                the number of vectors and n_dim is the number of dimensions of each
+                example.
+            :param y_mat: JaxArray of shape (n_vectors, n_dim), where n_vectors is
+                the number of vectors and n_dim is the number of dimensions of each
+                example.
+            :param device: Not supported for this backend
+            :return: JaxArray of shape (n_vectors, n_vectors) containing all
+                pairwise Squared Euclidean distances.
+                The index [i_x, i_y] contains the squared Euclidean distance between
+                x_mat[i_x] and y_mat[i_y].
+            """
+            comp_be = JaxCompBackend
+            x_mat_jax: jnp.ndarray = comp_be._get_tensor(x_mat)
+            y_mat_jax: jnp.ndarray = comp_be._get_tensor(y_mat)
+            eps: float = 1e-7  # avoid problems with numerical inaccuracies
+
+            if device is not None:
+                pass
+                # warnings.warn('`device` is not supported for numpy operations')
+
+            x_mat_jax, y_mat_jax = _expand_if_single_axis(x_mat_jax, y_mat_jax)
+
+            dists = (
+                jnp.sum(y_mat_jax**2, axis=1)
+                + jnp.sum(x_mat_jax**2, axis=1)[:, jnp.newaxis]
+                - 2 * jnp.dot(x_mat_jax, y_mat_jax.T)
+            ).squeeze()
+
+            # remove numerical artifacts
+            dists = jnp.where(jnp.logical_and(dists < 0, dists > -eps), 0, dists)
+            dists = _expand_if_scalar(dists)
+            return comp_be._cast_output(dists)
diff --git a/docarray/typing/__init__.py b/docarray/typing/__init__.py
index 5fdb578ad04..ed7e1d7b9d2 100644
--- a/docarray/typing/__init__.py
+++ b/docarray/typing/__init__.py
@@ -24,12 +24,20 @@
 
 if TYPE_CHECKING:
     from docarray.typing.tensor import TensorFlowTensor  # noqa:  F401
-    from docarray.typing.tensor import TorchEmbedding, TorchTensor  # noqa: F401
+    from docarray.typing.tensor import (  # noqa: F401
+        JaxArray,
+        JaxArrayEmbedding,
+        TorchEmbedding,
+        TorchTensor,
+    )
+    from docarray.typing.tensor.audio import AudioJaxArray  # noqa: F401
     from docarray.typing.tensor.audio import AudioTensorFlowTensor  # noqa: F401
     from docarray.typing.tensor.audio import AudioTorchTensor  # noqa: F401
     from docarray.typing.tensor.embedding import TensorFlowEmbedding  # noqa: F401
+    from docarray.typing.tensor.image import ImageJaxArray  # noqa: F401
     from docarray.typing.tensor.image import ImageTensorFlowTensor  # noqa: F401
     from docarray.typing.tensor.image import ImageTorchTensor  # noqa: F401
+    from docarray.typing.tensor.video import VideoJaxArray  # noqa: F401
     from docarray.typing.tensor.video import VideoTensorFlowTensor  # noqa: F401
     from docarray.typing.tensor.video import VideoTorchTensor  # noqa: F401
 
@@ -73,6 +81,15 @@
     'AudioTensorFlowTensor',
     'VideoTensorFlowTensor',
 ]
+
+_jax_tensors = [
+    'JaxArray',
+    'JaxArrayEmbedding',
+    'VideoJaxArray',
+    'AudioJaxArray',
+    'ImageJaxArray',
+]
+
 
 __all_test__ = __all__ + _torch_tensors
 
@@ -81,6 +98,8 @@ def __getattr__(name: str):
         import_library('torch', raise_error=True)
     elif name in _tf_tensors:
         import_library('tensorflow', raise_error=True)
+    elif name in _jax_tensors:
+        import_library('jax', raise_error=True)
     else:
         raise ImportError(
             f'cannot import name \'{name}\' from \'{_get_path_from_docarray_root_level(__file__)}\''
diff --git a/docarray/typing/tensor/__init__.py b/docarray/typing/tensor/__init__.py
index 4c4077f3cdb..2da7f5939ec 100644
--- a/docarray/typing/tensor/__init__.py
+++ b/docarray/typing/tensor/__init__.py
@@ -14,14 +14,19 @@
 )
 
 if TYPE_CHECKING:
+    from docarray.typing.tensor.audio import AudioJaxArray  # noqa: F401
     from docarray.typing.tensor.audio import AudioTensorFlowTensor  # noqa: F401
     from docarray.typing.tensor.audio import AudioTorchTensor  # noqa: F401
+    from docarray.typing.tensor.embedding import JaxArrayEmbedding  # noqa: F401
     from docarray.typing.tensor.embedding import TensorFlowEmbedding  # noqa: F401
     from docarray.typing.tensor.embedding import TorchEmbedding  # noqa: F401
+    from docarray.typing.tensor.image import ImageJaxArray  # noqa: F401
     from docarray.typing.tensor.image import ImageTensorFlowTensor  # noqa: F401
     from docarray.typing.tensor.image import ImageTorchTensor  # noqa: F401
+    from docarray.typing.tensor.jaxarray import JaxArray  # noqa: F401
     from docarray.typing.tensor.tensorflow_tensor import TensorFlowTensor  # noqa: F401
     from docarray.typing.tensor.torch_tensor import TorchTensor  # noqa: F401
+    from docarray.typing.tensor.video import VideoJaxArray  # noqa: F401
     from docarray.typing.tensor.video import VideoTensorFlowTensor  # noqa: F401
     from docarray.typing.tensor.video import VideoTorchTensor  # noqa: F401
 
@@ -42,19 +47,23 @@ def __getattr__(name: str):
         import_library('torch', raise_error=True)
     elif 'TensorFlow' in name:
         import_library('tensorflow', raise_error=True)
+    elif 'Jax' in name:
+        import_library('jax', raise_error=True)
 
     lib: types.ModuleType
     if name == 'TorchTensor':
         import docarray.typing.tensor.torch_tensor as lib
     elif name == 'TensorFlowTensor':
         import docarray.typing.tensor.tensorflow_tensor as lib
-    elif name in ['TorchEmbedding', 'TensorFlowEmbedding']:
+    elif name == 'JaxArray':
+        import docarray.typing.tensor.jaxarray as lib
+    elif name in ['TorchEmbedding', 'TensorFlowEmbedding', 'JaxArrayEmbedding']:
         import docarray.typing.tensor.embedding as lib
-    elif name in ['ImageTorchTensor', 'ImageTensorFlowTensor']:
+    elif name in ['ImageTorchTensor', 'ImageTensorFlowTensor', 'ImageJaxArray']:
         import docarray.typing.tensor.image as lib
-    elif name in ['AudioTorchTensor', 'AudioTensorFlowTensor']:
+    elif name in ['AudioTorchTensor', 'AudioTensorFlowTensor', 'AudioJaxArray']:
         import docarray.typing.tensor.audio as lib
-    elif name in ['VideoTorchTensor', 'VideoTensorFlowTensor']:
+    elif name in ['VideoTorchTensor', 'VideoTensorFlowTensor', 'VideoJaxArray']:
         import docarray.typing.tensor.video as lib
     else:
         raise ImportError(
diff --git a/docarray/typing/tensor/audio/__init__.py b/docarray/typing/tensor/audio/__init__.py
index a505ab05720..5f304ae544f 100644
--- a/docarray/typing/tensor/audio/__init__.py
+++ b/docarray/typing/tensor/audio/__init__.py
@@ -9,12 +9,13 @@
 )
 
 if TYPE_CHECKING:
+    from docarray.typing.tensor.audio.audio_jax_array import AudioJaxArray  # noqa
     from docarray.typing.tensor.audio.audio_tensorflow_tensor import (  # noqa
         AudioTensorFlowTensor,
     )
     from docarray.typing.tensor.audio.audio_torch_tensor import AudioTorchTensor  # noqa
 
 __all__ = ['AudioNdArray', 'AudioTensor']
 
 
 def __getattr__(name: str):
@@ -25,6 +26,9 @@ def __getattr__(name: str):
     elif name == 'AudioTensorFlowTensor':
         import_library('tensorflow', raise_error=True)
         import docarray.typing.tensor.audio.audio_tensorflow_tensor as lib
+    elif name == 'AudioJaxArray':
+        import_library('jax', raise_error=True)
+        import docarray.typing.tensor.audio.audio_jax_array as lib
     else:
         raise ImportError(
             f'cannot import name \'{name}\' from \'{_get_path_from_docarray_root_level(__file__)}\''
diff --git a/docarray/typing/tensor/audio/audio_jax_array.py b/docarray/typing/tensor/audio/audio_jax_array.py
new file mode 100644
index 00000000000..793fd627214
--- /dev/null
+++ b/docarray/typing/tensor/audio/audio_jax_array.py
@@ -0,0 +1,12 @@
+from typing import TypeVar
+
+from docarray.typing.proto_register import _register_proto
+from docarray.typing.tensor.audio.abstract_audio_tensor import AbstractAudioTensor
+from docarray.typing.tensor.jaxarray import JaxArray, metaJax
+
+T = TypeVar('T', bound='AudioJaxArray')
+
+
+@_register_proto(proto_type_name='audio_jaxarray')
+class AudioJaxArray(AbstractAudioTensor, JaxArray, metaclass=metaJax):
+    ...
diff --git a/docarray/typing/tensor/audio/audio_tensor.py b/docarray/typing/tensor/audio/audio_tensor.py
index a9171a919b2..56e651b567e 100644
--- a/docarray/typing/tensor/audio/audio_tensor.py
+++ b/docarray/typing/tensor/audio/audio_tensor.py
@@ -5,7 +5,11 @@
 from docarray.typing.tensor.audio.abstract_audio_tensor import AbstractAudioTensor
 from docarray.typing.tensor.audio.audio_ndarray import AudioNdArray
 from docarray.typing.tensor.tensor import AnyTensor
-from docarray.utils._internal.misc import is_tf_available, is_torch_available
+from docarray.utils._internal.misc import (
+    is_jax_available,
+    is_tf_available,
+    is_torch_available,
+)
 
 torch_available = is_torch_available()
 if torch_available:
@@ -23,6 +27,12 @@
     )
     from docarray.typing.tensor.tensorflow_tensor import TensorFlowTensor
 
+jax_available = is_jax_available()
+if jax_available:
+    import jax.numpy as jnp  # type: ignore
+
+    from docarray.typing.tensor.audio.audio_jax_array import AudioJaxArray
+    from docarray.typing.tensor.jaxarray import JaxArray
 
 if TYPE_CHECKING:
     from pydantic import BaseConfig
@@ -91,6 +101,11 @@ def validate(
                 return cast(AudioTensorFlowTensor, value)
             elif isinstance(value, tf.Tensor):
                 return AudioTensorFlowTensor._docarray_from_native(value)  # noqa
+        if jax_available:
+            if isinstance(value, JaxArray):
+                return cast(AudioJaxArray, value)
+            elif isinstance(value, jnp.ndarray):
+                return AudioJaxArray._docarray_from_native(value)  # noqa
         try:
             return AudioNdArray.validate(value, field, config)
         except Exception:  # noqa
diff --git a/docarray/typing/tensor/embedding/__init__.py b/docarray/typing/tensor/embedding/__init__.py
index c32048b21c6..0e518b67a57 100644
--- a/docarray/typing/tensor/embedding/__init__.py
+++ b/docarray/typing/tensor/embedding/__init__.py
@@ -10,6 +10,7 @@
 )
 
 if TYPE_CHECKING:
+    from docarray.typing.tensor.embedding.jax_array import JaxArrayEmbedding  # noqa
     from docarray.typing.tensor.embedding.tensorflow import TensorFlowEmbedding  # noqa
     from docarray.typing.tensor.embedding.torch import TorchEmbedding  # noqa
 
@@ -24,6 +25,9 @@ def __getattr__(name: str):
     elif name == 'TensorFlowEmbedding':
         import_library('tensorflow', raise_error=True)
         import docarray.typing.tensor.embedding.tensorflow as lib
+    elif name == 'JaxArrayEmbedding':
+        import_library('jax', raise_error=True)
+        import docarray.typing.tensor.embedding.jax_array as lib
     else:
         raise ImportError(
             f'cannot import name \'{name}\' from \'{_get_path_from_docarray_root_level(__file__)}\''
diff --git a/docarray/typing/tensor/embedding/embedding.py b/docarray/typing/tensor/embedding/embedding.py
index b7fd9c462f7..c9bc31dc54a 100644
--- a/docarray/typing/tensor/embedding/embedding.py
+++ b/docarray/typing/tensor/embedding/embedding.py
@@ -5,7 +5,18 @@
 from docarray.typing.tensor.embedding.embedding_mixin import EmbeddingMixin
 from docarray.typing.tensor.embedding.ndarray import NdArrayEmbedding
 from docarray.typing.tensor.tensor import AnyTensor
-from docarray.utils._internal.misc import is_tf_available, is_torch_available  # noqa
+from docarray.utils._internal.misc import (  # noqa
+    is_jax_available,
+    is_tf_available,
+    is_torch_available,
+)
+
+jax_available = is_jax_available()
+if jax_available:
+    import jax.numpy as jnp  # type: ignore
+
+    from docarray.typing.tensor.embedding.jax_array import JaxArrayEmbedding
+    from docarray.typing.tensor.jaxarray import JaxArray  # noqa: F401
 
 torch_available = is_torch_available()
 if torch_available:
@@ -89,6 +100,11 @@ def validate(
                 return cast(TensorFlowEmbedding, value)
             elif isinstance(value, tf.Tensor):
                 return TensorFlowEmbedding._docarray_from_native(value)  # noqa
+        if jax_available:
+            if isinstance(value, JaxArray):
+                return cast(JaxArrayEmbedding, value)
+            elif isinstance(value, jnp.ndarray):
+                return JaxArrayEmbedding._docarray_from_native(value)  # noqa
         try:
             return NdArrayEmbedding.validate(value, field, config)
         except Exception:  # noqa
diff --git a/docarray/typing/tensor/embedding/jax_array.py b/docarray/typing/tensor/embedding/jax_array.py
new file mode 100644
index 00000000000..4dbb7a67ee0
--- /dev/null
+++ b/docarray/typing/tensor/embedding/jax_array.py
@@ -0,0 +1,17 @@
+from typing import Any  # noqa: F401
+
+from docarray.typing.proto_register import _register_proto
+from docarray.typing.tensor.embedding.embedding_mixin import EmbeddingMixin
+from docarray.typing.tensor.jaxarray import JaxArray
+
+jax_base = type(JaxArray)  # type: Any
+embedding_base = type(EmbeddingMixin)  # type: Any
+
+
+class metaJaxAndEmbedding(jax_base, embedding_base):
+    pass
+
+
+@_register_proto(proto_type_name='jaxarray_embedding')
+class JaxArrayEmbedding(JaxArray, EmbeddingMixin, metaclass=metaJaxAndEmbedding):
+    alternative_type = JaxArray
diff --git a/docarray/typing/tensor/image/__init__.py b/docarray/typing/tensor/image/__init__.py
index 7af4b852206..d62b096c1fe 100644
--- a/docarray/typing/tensor/image/__init__.py
+++ b/docarray/typing/tensor/image/__init__.py
@@ -10,6 +10,7 @@
 )
 
 if TYPE_CHECKING:
+    from docarray.typing.tensor.image.image_jax_array import ImageJaxArray  # noqa
     from docarray.typing.tensor.image.image_tensorflow_tensor import (  # noqa
         ImageTensorFlowTensor,
     )
@@ -26,6 +27,9 @@ def __getattr__(name: str):
     elif name == 'ImageTensorFlowTensor':
         import_library('tensorflow', raise_error=True)
         import docarray.typing.tensor.image.image_tensorflow_tensor as lib
+    elif name == 'ImageJaxArray':
+        import_library('jax', raise_error=True)
+        import docarray.typing.tensor.image.image_jax_array as lib
     else:
         raise ImportError(
             f'cannot import name \'{name}\' from \'{_get_path_from_docarray_root_level(__file__)}\''
diff --git a/docarray/typing/tensor/image/image_jax_array.py b/docarray/typing/tensor/image/image_jax_array.py
new file mode 100644
index 00000000000..8fabf91ac24
--- /dev/null
+++ b/docarray/typing/tensor/image/image_jax_array.py
@@ -0,0 +1,10 @@
+from docarray.typing.proto_register import _register_proto
+from docarray.typing.tensor.image.abstract_image_tensor import AbstractImageTensor
+from docarray.typing.tensor.jaxarray import JaxArray, metaJax
+
+MAX_INT_16 = 2**15
+
+
+@_register_proto(proto_type_name='image_jaxarray')
+class ImageJaxArray(JaxArray, AbstractImageTensor, metaclass=metaJax):
+    ...
diff --git a/docarray/typing/tensor/image/image_tensor.py b/docarray/typing/tensor/image/image_tensor.py index ece9f5978ed..3dc58c737c3 100644 --- a/docarray/typing/tensor/image/image_tensor.py +++ b/docarray/typing/tensor/image/image_tensor.py @@ -5,7 +5,18 @@ from docarray.typing.tensor.image.abstract_image_tensor import AbstractImageTensor from docarray.typing.tensor.image.image_ndarray import ImageNdArray from docarray.typing.tensor.tensor import AnyTensor -from docarray.utils._internal.misc import is_tf_available, is_torch_available +from docarray.utils._internal.misc import ( + is_jax_available, + is_tf_available, + is_torch_available, +) + +jax_available = is_jax_available() +if jax_available: + import jax.numpy as jnp # type: ignore + + from docarray.typing.tensor.image.image_jax_array import ImageJaxArray + from docarray.typing.tensor.jaxarray import JaxArray torch_available = is_torch_available() if torch_available: @@ -94,6 +105,11 @@ def validate( return cast(ImageTensorFlowTensor, value) elif isinstance(value, tf.Tensor): return ImageTensorFlowTensor._docarray_from_native(value) # noqa + if jax_available: + if isinstance(value, JaxArray): + return cast(ImageJaxArray, value) + elif isinstance(value, jnp.ndarray): + return ImageJaxArray._docarray_from_native(value) # noqa try: return ImageNdArray.validate(value, field, config) except Exception: # noqa diff --git a/docarray/typing/tensor/jaxarray.py b/docarray/typing/tensor/jaxarray.py new file mode 100644 index 00000000000..4b145c6ac4c --- /dev/null +++ b/docarray/typing/tensor/jaxarray.py @@ -0,0 +1,265 @@ +from typing import TYPE_CHECKING, Any, Generic, List, Tuple, Type, TypeVar, Union, cast + +import numpy as np + +from docarray.typing.proto_register import _register_proto +from docarray.typing.tensor.abstract_tensor import AbstractTensor +from docarray.utils._internal.misc import import_library + +if TYPE_CHECKING: + import jax + import jax.numpy as jnp + from pydantic import BaseConfig + from 
pydantic.fields import ModelField + + from docarray.computation.jax_backend import JaxCompBackend + from docarray.proto import NdArrayProto +else: + jax = import_library('jax', raise_error=True) + jnp = jax.numpy +from docarray.base_doc.base_node import BaseNode + +T = TypeVar('T', bound='JaxArray') +ShapeT = TypeVar('ShapeT') + +node_base: type = type(BaseNode) + + +# the mypy error suppression below should not be necessary anymore once the following +# is released in mypy: https://github.com/python/mypy/pull/14135 +class metaJax( + AbstractTensor.__parametrized_meta__, # type: ignore + node_base, # type: ignore +): # type: ignore + pass + + +@_register_proto(proto_type_name='jaxarray') +class JaxArray(AbstractTensor, Generic[ShapeT], metaclass=metaJax): + """ + Subclass of `jnp.ndarray`, intended for use in a Document. + This enables (de)serialization from/to protobuf and json, data validation, + and coercion from compatible types like `torch.Tensor`. + + This type can also be used in a parametrized way, specifying the shape of the array. + + --- + + ```python + from docarray import BaseDoc + from docarray.typing import JaxArray + import jax.numpy as jnp + + + class MyDoc(BaseDoc): + arr: JaxArray + image_arr: JaxArray[3, 224, 224] + square_crop: JaxArray[3, 'x', 'x'] + random_image: JaxArray[3, ...] # first dimension is fixed, can have arbitrary shape + + + # create a document with tensors + doc = MyDoc( + arr=jnp.zeros((128,)), + image_arr=jnp.zeros((3, 224, 224)), + square_crop=jnp.zeros((3, 64, 64)), + random_image=jnp.zeros((3, 128, 256)), + ) + assert doc.image_arr.shape == (3, 224, 224) + + # automatic shape conversion + doc = MyDoc( + arr=np.zeros((128,)), + image_arr=np.zeros((224, 224, 3)), # will reshape to (3, 224, 224) + square_crop=np.zeros((3, 128, 128)), + random_image=np.zeros((3, 64, 128)), + ) + assert doc.image_arr.shape == (3, 224, 224) + + # !! The following will raise an error due to shape mismatch !! 
+ from pydantic import ValidationError + + try: + doc = MyDoc( + arr=np.zeros((128,)), + image_arr=np.zeros((224, 224)), # this will fail validation + square_crop=np.zeros((3, 128, 64)), # this will also fail validation + random_image=np.zeros((4, 64, 128)), # this will also fail validation + ) + except ValidationError as e: + pass + ``` + + --- + """ + + __parametrized_meta__ = metaJax + + def __init__(self, tensor: jnp.ndarray): + super().__init__() + self.tensor = tensor + + def __getitem__(self, item): + from docarray.computation.jax_backend import JaxCompBackend + + tensor = self.unwrap() + if tensor is not None: + tensor = tensor[item] + return JaxCompBackend._cast_output(t=tensor) + + def __setitem__(self, index, value): + """""" + # print(index, value) + self.tensor = self.tensor.at[index : index + 1].set(value) + + def __iter__(self): + for i in range(len(self)): + yield self[i] + + def __len__(self): + return len(self.tensor) + + @classmethod + def __get_validators__(cls): + # one or more validators may be yielded which will be called in the + # order to validate the input, each validator will receive as an input + # the value returned from the previous validator + yield cls.validate + + @classmethod + def validate( + cls: Type[T], + value: Union[T, jnp.ndarray, List[Any], Tuple[Any], Any], + field: 'ModelField', + config: 'BaseConfig', + ) -> T: + if isinstance(value, jax.Array): + return cls._docarray_from_native(value) + elif isinstance(value, JaxArray): + return cast(T, value) + elif isinstance(value, list) or isinstance(value, tuple): + try: + arr_from_list: jnp.ndarray = jnp.asarray(value) + return cls._docarray_from_native(arr_from_list) + except Exception: + pass # handled below + else: + try: + arr: jnp.ndarray = jnp.ndarray(value) + return cls._docarray_from_native(arr) + except Exception: + pass # handled below + raise ValueError(f'Expected a numpy.ndarray compatible type, got {type(value)}') + + @classmethod + def _docarray_from_native(cls: 
Type[T], value: jnp.ndarray) -> T: + if isinstance(value, JaxArray): + if cls.__unparametrizedcls__: # not None if the tensor is parametrized + value.__class__ = cls.__unparametrizedcls__ # type: ignore + else: + value.__class__ = cls # type: ignore + return cast(T, value) + else: + if cls.__unparametrizedcls__: # not None if the tensor is parametrized + cls_param_ = cls.__unparametrizedcls__ + cls_param = cast(Type[T], cls_param_) + else: + cls_param = cls + + return cls_param(tensor=value) + + @classmethod + def from_ndarray(cls: Type[T], value: np.ndarray) -> T: + """Create a `JaxArray` from a numpy array. + + :param value: the numpy array + :return: a `JaxArray` + """ + return cls._docarray_from_native(jnp.array(value)) + + def _docarray_to_json_compatible(self) -> jnp.ndarray: + """ + Convert `JaxArray` into a json compatible object + :return: a representation of the tensor compatible with orjson + """ + return self.unwrap() + + def unwrap(self) -> jnp.ndarray: + """ + Return the original ndarray without making a copy in memory. + + The original view remains intact and is still a Document `JaxArray` + but the return object is a pure `jnp.ndarray` and both objects share + the same underlying memory. 
+ + --- + + ```python + from docarray.typing import JaxArray + import numpy as np + + t1 = JaxArray.validate(np.zeros((3, 224, 224)), None, None) + # here t1 is a docarray JaxArray + t2 = t1.unwrap() + # here t2 is a pure jnp.ndarray but t1 is still a Docarray JaxArray + # But both share the same underlying memory + ``` + + --- + + :return: a `jnp.ndarray` + """ + return self.tensor + + @classmethod + def from_protobuf(cls: Type[T], pb_msg: 'NdArrayProto') -> 'T': + """ + Read a `JaxArray` from a proto msg + :param pb_msg: + :return: a `JaxArray` + """ + source = pb_msg.dense + if source.buffer: + x = np.frombuffer(bytearray(source.buffer), dtype=source.dtype) + return cls.from_ndarray(x.reshape(source.shape)) + elif len(source.shape) > 0: + return cls.from_ndarray(np.zeros(source.shape)) + else: + raise ValueError( + f'Proto message {pb_msg} cannot be cast to a TensorFlowTensor.' + ) + + def to_protobuf(self) -> 'NdArrayProto': + """ + Transform self into a NdArrayProto protobuf message + """ + from docarray.proto import NdArrayProto + + nd_proto = NdArrayProto() + + value_np = self.tensor + nd_proto.dense.buffer = value_np.tobytes() + nd_proto.dense.ClearField('shape') + nd_proto.dense.shape.extend(list(value_np.shape)) + nd_proto.dense.dtype = value_np.dtype.str + + return nd_proto + + @staticmethod + def get_comp_backend() -> 'JaxCompBackend': + """Return the computational backend of the tensor""" + from docarray.computation.jax_backend import JaxCompBackend + + return JaxCompBackend() + + def __class_getitem__(cls, item: Any, *args, **kwargs): + # see here for mypy bug: https://github.com/python/mypy/issues/14123 + return AbstractTensor.__class_getitem__.__func__(cls, item) # type: ignore + + @classmethod + def _docarray_from_ndarray(cls: Type[T], value: np.ndarray) -> T: + return cls.from_ndarray(value) + + def _docarray_to_ndarray(self) -> np.ndarray: + """cast itself to a numpy array""" + return self.tensor.__array__() diff --git 
a/docarray/typing/tensor/ndarray.py b/docarray/typing/tensor/ndarray.py index e8935758e42..2f547b55dea 100644 --- a/docarray/typing/tensor/ndarray.py +++ b/docarray/typing/tensor/ndarray.py @@ -5,7 +5,17 @@ from docarray.base_doc.base_node import BaseNode from docarray.typing.proto_register import _register_proto from docarray.typing.tensor.abstract_tensor import AbstractTensor -from docarray.utils._internal.misc import is_tf_available, is_torch_available # noqa +from docarray.utils._internal.misc import ( # noqa + is_jax_available, + is_tf_available, + is_torch_available, +) + +jax_available = is_jax_available() +if jax_available: + import jax.numpy as jnp + + from docarray.typing.tensor.jaxarray import JaxArray # noqa: F401 torch_available = is_torch_available() if torch_available: @@ -124,6 +134,8 @@ def validate( return cls._docarray_from_native(value.detach().cpu().numpy()) elif tf_available and isinstance(value, tf.Tensor): return cls._docarray_from_native(value.numpy()) + elif jax_available and isinstance(value, jnp.ndarray): + return cls._docarray_from_native(value.__array__()) elif isinstance(value, list) or isinstance(value, tuple): try: arr_from_list: np.ndarray = np.asarray(value) diff --git a/docarray/typing/tensor/tensor.py b/docarray/typing/tensor/tensor.py index e8d84bf04a0..2d5be7cd096 100644 --- a/docarray/typing/tensor/tensor.py +++ b/docarray/typing/tensor/tensor.py @@ -4,7 +4,17 @@ from docarray.typing.tensor.abstract_tensor import AbstractTensor from docarray.typing.tensor.ndarray import NdArray -from docarray.utils._internal.misc import is_tf_available, is_torch_available # noqa +from docarray.utils._internal.misc import ( # noqa + is_jax_available, + is_tf_available, + is_torch_available, +) + +jax_available = is_jax_available() +if jax_available: + import jax.numpy as jnp + + from docarray.typing.tensor.jaxarray import JaxArray # noqa: F401 torch_available = is_torch_available() if torch_available: @@ -27,12 +37,20 @@ # behavior as 
`Union[TorchTensor, TensorFlowTensor, NdArray]` so it should be fine to use `AnyTensor` as # the type for `tensor` field in `BaseDoc` class. AnyTensor = Union[NdArray] - if torch_available and tf_available: + if torch_available and tf_available and jax_available: + AnyTensor = Union[NdArray, TorchTensor, TensorFlowTensor, JaxArray] # type: ignore + elif torch_available and tf_available: AnyTensor = Union[NdArray, TorchTensor, TensorFlowTensor] # type: ignore - elif torch_available: - AnyTensor = Union[NdArray, TorchTensor] # type: ignore + elif tf_available and jax_available: + AnyTensor = Union[NdArray, TensorFlowTensor, JaxArray] # type: ignore + elif torch_available and jax_available: + AnyTensor = Union[NdArray, TorchTensor, JaxArray] # type: ignore elif tf_available: AnyTensor = Union[NdArray, TensorFlowTensor] # type: ignore + elif torch_available: + AnyTensor = Union[NdArray, TorchTensor] # type: ignore + elif jax_available: + AnyTensor = Union[NdArray, JaxArray] # type: ignore else: @@ -124,6 +142,11 @@ def validate( return value elif isinstance(value, tf.Tensor): return TensorFlowTensor._docarray_from_native(value) # noqa + if jax_available: + if isinstance(value, JaxArray): + return value + elif isinstance(value, jnp.ndarray): + return JaxArray._docarray_from_native(value) # noqa try: return NdArray.validate(value, field, config) except Exception as e: # noqa diff --git a/docarray/typing/tensor/tensorflow_tensor.py b/docarray/typing/tensor/tensorflow_tensor.py index 1eb2bc7eacf..a42b3a0a5d3 100644 --- a/docarray/typing/tensor/tensorflow_tensor.py +++ b/docarray/typing/tensor/tensorflow_tensor.py @@ -5,7 +5,11 @@ from docarray.base_doc.base_node import BaseNode from docarray.typing.proto_register import _register_proto from docarray.typing.tensor.abstract_tensor import AbstractTensor -from docarray.utils._internal.misc import import_library, is_torch_available +from docarray.utils._internal.misc import ( + import_library, + is_jax_available, + 
is_torch_available, +) if TYPE_CHECKING: import tensorflow as tf # type: ignore @@ -21,6 +25,10 @@ if torch_available: import torch +jax_available = is_jax_available() +if jax_available: + import jax.numpy as jnp + T = TypeVar('T', bound='TensorFlowTensor') ShapeT = TypeVar('ShapeT') @@ -211,6 +219,8 @@ def validate( return cls._docarray_from_ndarray(value._docarray_to_ndarray()) elif torch_available and isinstance(value, torch.Tensor): return cls._docarray_from_native(value.detach().cpu().numpy()) + elif jax_available and isinstance(value, jnp.ndarray): + return cls._docarray_from_native(value.__array__()) else: try: arr: tf.Tensor = tf.constant(value) diff --git a/docarray/typing/tensor/torch_tensor.py b/docarray/typing/tensor/torch_tensor.py index 0f7ff0132d9..a78781f6a9b 100644 --- a/docarray/typing/tensor/torch_tensor.py +++ b/docarray/typing/tensor/torch_tensor.py @@ -6,7 +6,11 @@ from docarray.base_doc.base_node import BaseNode from docarray.typing.proto_register import _register_proto from docarray.typing.tensor.abstract_tensor import AbstractTensor -from docarray.utils._internal.misc import import_library, is_tf_available +from docarray.utils._internal.misc import ( + import_library, + is_jax_available, + is_tf_available, +) if TYPE_CHECKING: import torch @@ -22,6 +26,10 @@ if tf_available: import tensorflow as tf # type: ignore +jax_available = is_jax_available() +if jax_available: + import jax.numpy as jnp + T = TypeVar('T', bound='TorchTensor') ShapeT = TypeVar('ShapeT') @@ -132,6 +140,8 @@ def validate( return cls._docarray_from_ndarray(value.numpy()) elif isinstance(value, np.ndarray): return cls._docarray_from_ndarray(value) + elif jax_available and isinstance(value, jnp.ndarray): + return cls._docarray_from_ndarray(value.__array__()) else: try: arr: torch.Tensor = torch.tensor(value) diff --git a/docarray/typing/tensor/video/__init__.py b/docarray/typing/tensor/video/__init__.py index a575e7b6201..18f0a2e5d8b 100644 --- 
a/docarray/typing/tensor/video/__init__.py +++ b/docarray/typing/tensor/video/__init__.py @@ -10,6 +10,7 @@ ) if TYPE_CHECKING: + from docarray.typing.tensor.video.video_jax_array import VideoJaxArray # noqa from docarray.typing.tensor.video.video_tensorflow_tensor import ( # noqa VideoTensorFlowTensor, ) @@ -26,6 +27,9 @@ def __getattr__(name: str): elif name == 'VideoTensorFlowTensor': import_library('tensorflow', raise_error=True) import docarray.typing.tensor.video.video_tensorflow_tensor as lib + elif name == 'VideoJaxArray': + import_library('jax', raise_error=True) + import docarray.typing.tensor.video.video_jax_array as lib else: raise ImportError( f'cannot import name \'{name}\' from \'{_get_path_from_docarray_root_level(__file__)}\'' diff --git a/docarray/typing/tensor/video/video_jax_array.py b/docarray/typing/tensor/video/video_jax_array.py new file mode 100644 index 00000000000..5b060e49246 --- /dev/null +++ b/docarray/typing/tensor/video/video_jax_array.py @@ -0,0 +1,28 @@ +from typing import TYPE_CHECKING, Any, List, Tuple, Type, TypeVar, Union + +import numpy as np + +from docarray.typing.proto_register import _register_proto +from docarray.typing.tensor.jaxarray import JaxArray, metaJax +from docarray.typing.tensor.video.video_tensor_mixin import VideoTensorMixin + +T = TypeVar('T', bound='VideoJaxArray') + +if TYPE_CHECKING: + from pydantic import BaseConfig + from pydantic.fields import ModelField + + +@_register_proto(proto_type_name='video_jaxarray') +class VideoJaxArray(JaxArray, VideoTensorMixin, metaclass=metaJax): + """ """ + + @classmethod + def validate( + cls: Type[T], + value: Union[T, np.ndarray, List[Any], Tuple[Any], Any], + field: 'ModelField', + config: 'BaseConfig', + ) -> T: + tensor = super().validate(value=value, field=field, config=config) + return cls.validate_shape(value=tensor) diff --git a/docarray/typing/tensor/video/video_tensor.py b/docarray/typing/tensor/video/video_tensor.py index be77c9db21e..5687ecfe561 100644 --- 
a/docarray/typing/tensor/video/video_tensor.py +++ b/docarray/typing/tensor/video/video_tensor.py @@ -5,7 +5,18 @@ from docarray.typing.tensor.tensor import AnyTensor from docarray.typing.tensor.video.video_ndarray import VideoNdArray from docarray.typing.tensor.video.video_tensor_mixin import VideoTensorMixin -from docarray.utils._internal.misc import is_tf_available, is_torch_available +from docarray.utils._internal.misc import ( + is_jax_available, + is_tf_available, + is_torch_available, +) + +jax_available = is_jax_available() +if jax_available: + import jax.numpy as jnp + + from docarray.typing.tensor.jaxarray import JaxArray # noqa: F401 + from docarray.typing.tensor.video.video_jax_array import VideoJaxArray torch_available = is_torch_available() if torch_available: @@ -94,6 +105,11 @@ def validate( return cast(VideoTensorFlowTensor, value) elif isinstance(value, tf.Tensor): return VideoTensorFlowTensor._docarray_from_native(value) # noqa + if jax_available: + if isinstance(value, JaxArray): + return cast(VideoJaxArray, value) + elif isinstance(value, jnp.ndarray): + return VideoJaxArray._docarray_from_native(value) # noqa if isinstance(value, VideoNdArray): return cast(VideoNdArray, value) if isinstance(value, np.ndarray): diff --git a/docarray/utils/_internal/misc.py b/docarray/utils/_internal/misc.py index 1ac8bc659b6..ad0d28d9c9e 100644 --- a/docarray/utils/_internal/misc.py +++ b/docarray/utils/_internal/misc.py @@ -22,6 +22,13 @@ tf_imported = True +try: + import jax.numpy as jnp # type: ignore # noqa: F401 +except (ImportError, TypeError): + jnp_imported = False +else: + jnp_imported = True + INSTALL_INSTRUCTIONS = { 'google.protobuf': '"docarray[proto]"', 'lz4': '"docarray[proto]"', @@ -78,6 +85,10 @@ def is_tf_available(): return tf_imported +def is_jax_available(): + return jnp_imported + + def is_np_int(item: Any) -> bool: dtype = getattr(item, 'dtype', None) ndim = getattr(item, 'ndim', None) diff --git a/docarray/utils/find.py 
b/docarray/utils/find.py index 46c167582f1..2b77bcbb77e 100644 --- a/docarray/utils/find.py +++ b/docarray/utils/find.py @@ -1,40 +1,51 @@ __all__ = ['find', 'find_batched'] from typing import ( + TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Tuple, + Type, Union, cast, - Type, - TYPE_CHECKING, ) from docarray.array.any_array import AnyDocArray from docarray.array.doc_list.doc_list import DocList from docarray.array.doc_vec.doc_vec import DocVec from docarray.base_doc import BaseDoc -from docarray.typing import AnyTensor from docarray.computation.numpy_backend import NumpyCompBackend +from docarray.typing import AnyTensor from docarray.typing.tensor import NdArray -from docarray.utils._internal.misc import is_tf_available, is_torch_available # noqa +from docarray.utils._internal.misc import ( # noqa + is_jax_available, + is_tf_available, + is_torch_available, +) + +jax_available = is_jax_available() +if jax_available: + import jax.numpy as jnp + + from docarray.computation.jax_backend import JaxCompBackend + from docarray.typing.tensor.jaxarray import JaxArray # noqa: F401 torch_available = is_torch_available() if torch_available: import torch - from docarray.typing.tensor.torch_tensor import TorchTensor # noqa: F401 from docarray.computation.torch_backend import TorchCompBackend + from docarray.typing.tensor.torch_tensor import TorchTensor # noqa: F401 tf_available = is_tf_available() if tf_available: import tensorflow as tf # type: ignore - from docarray.typing.tensor.tensorflow_tensor import TensorFlowTensor # noqa: F401 from docarray.computation.tensorflow_backend import TensorFlowCompBackend + from docarray.typing.tensor.tensorflow_tensor import TensorFlowTensor # noqa: F401 if TYPE_CHECKING: from docarray.computation.abstract_numpy_based_backend import ( @@ -310,6 +321,9 @@ def _get_tensor_type_and_comp_backend_from_tensor( elif tf_available and isinstance(tensor, (TensorFlowTensor, tf.Tensor)): comp_backend = TensorFlowCompBackend() da_tensor_type = 
TensorFlowTensor + elif jax_available and isinstance(tensor, (JaxArray, jnp.ndarray)): + comp_backend = JaxCompBackend() + da_tensor_type = JaxArray return da_tensor_type, comp_backend diff --git a/poetry.lock b/poetry.lock index b8b2a97e009..dd8e1b1ef3f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,9 +1,10 @@ -# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.4.1 and should not be changed by hand. [[package]] name = "aiohttp" version = "3.8.4" description = "Async http client/server framework (asyncio)" +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -112,6 +113,7 @@ speedups = ["Brotli", "aiodns", "cchardet"] name = "aiosignal" version = "1.3.1" description = "aiosignal: a list of registered asynchronous callbacks" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -126,6 +128,7 @@ frozenlist = ">=1.1.0" name = "anyio" version = "3.6.2" description = "High level compatibility layer for multiple asynchronous event loop implementations" +category = "main" optional = false python-versions = ">=3.6.2" files = [ @@ -146,6 +149,7 @@ trio = ["trio (>=0.16,<0.22)"] name = "appnope" version = "0.1.3" description = "Disable App Nap on macOS >= 10.9" +category = "dev" optional = false python-versions = "*" files = [ @@ -157,6 +161,7 @@ files = [ name = "argon2-cffi" version = "21.3.0" description = "The secure Argon2 password hashing algorithm." 
+category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -176,6 +181,7 @@ tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pytest"] name = "argon2-cffi-bindings" version = "21.2.0" description = "Low-level CFFI bindings for Argon2" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -213,6 +219,7 @@ tests = ["pytest"] name = "async-timeout" version = "4.0.2" description = "Timeout context manager for asyncio programs" +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -224,6 +231,7 @@ files = [ name = "attrs" version = "22.1.0" description = "Classes Without Boilerplate" +category = "main" optional = false python-versions = ">=3.5" files = [ @@ -241,6 +249,7 @@ tests-no-zope = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy name = "authlib" version = "1.2.0" description = "The ultimate Python library in building OAuth and OpenID Connect servers and clients." +category = "main" optional = true python-versions = "*" files = [ @@ -255,6 +264,7 @@ cryptography = ">=3.2" name = "av" version = "10.0.0" description = "Pythonic bindings for FFmpeg's libraries." +category = "main" optional = true python-versions = "*" files = [ @@ -308,6 +318,7 @@ files = [ name = "babel" version = "2.11.0" description = "Internationalization utilities" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -322,6 +333,7 @@ pytz = ">=2015.7" name = "backcall" version = "0.2.0" description = "Specifications for callback functions passed in to an API" +category = "dev" optional = false python-versions = "*" files = [ @@ -333,6 +345,7 @@ files = [ name = "beautifulsoup4" version = "4.11.1" description = "Screen-scraping library" +category = "dev" optional = false python-versions = ">=3.6.0" files = [ @@ -351,6 +364,7 @@ lxml = ["lxml"] name = "black" version = "22.10.0" description = "The uncompromising code formatter." 
+category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -395,6 +409,7 @@ uvloop = ["uvloop (>=0.15.2)"] name = "blacken-docs" version = "1.13.0" description = "Run Black on Python code blocks in documentation files." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -409,6 +424,7 @@ black = ">=22.1.0" name = "bleach" version = "5.0.1" description = "An easy safelist-based HTML-sanitizing tool." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -428,6 +444,7 @@ dev = ["Sphinx (==4.3.2)", "black (==22.3.0)", "build (==0.8.0)", "flake8 (==4.0 name = "boto3" version = "1.26.95" description = "The AWS SDK for Python" +category = "main" optional = true python-versions = ">= 3.7" files = [ @@ -447,6 +464,7 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] name = "botocore" version = "1.29.95" description = "Low-level, data-driven core of boto 3." +category = "main" optional = true python-versions = ">= 3.7" files = [ @@ -466,6 +484,7 @@ crt = ["awscrt (==0.16.9)"] name = "bracex" version = "2.3.post1" description = "Bash style brace expander." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -477,6 +496,7 @@ files = [ name = "certifi" version = "2022.9.24" description = "Python package for providing Mozilla's CA Bundle." +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -488,6 +508,7 @@ files = [ name = "cffi" version = "1.15.1" description = "Foreign Function Interface for Python calling C code." +category = "main" optional = false python-versions = "*" files = [ @@ -564,6 +585,7 @@ pycparser = "*" name = "cfgv" version = "3.3.1" description = "Validate configuration and produce human readable error messages." 
+category = "dev" optional = false python-versions = ">=3.6.1" files = [ @@ -575,6 +597,7 @@ files = [ name = "chardet" version = "5.1.0" description = "Universal encoding detector for Python 3" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -586,6 +609,7 @@ files = [ name = "charset-normalizer" version = "2.0.12" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." +category = "main" optional = false python-versions = ">=3.5.0" files = [ @@ -600,6 +624,7 @@ unicode-backport = ["unicodedata2"] name = "click" version = "8.1.3" description = "Composable command line interface toolkit" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -614,6 +639,7 @@ colorama = {version = "*", markers = "platform_system == \"Windows\""} name = "colorama" version = "0.4.6" description = "Cross-platform colored terminal text." +category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" files = [ @@ -625,6 +651,7 @@ files = [ name = "colorlog" version = "6.7.0" description = "Add colours to the output of Python's logging module." +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -642,6 +669,7 @@ development = ["black", "flake8", "mypy", "pytest", "types-colorama"] name = "commonmark" version = "0.9.1" description = "Python parser for the CommonMark Markdown spec" +category = "main" optional = false python-versions = "*" files = [ @@ -656,6 +684,7 @@ test = ["flake8 (==3.7.8)", "hypothesis (==3.55.3)"] name = "coverage" version = "6.2" description = "Code coverage measurement for Python" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -718,6 +747,7 @@ toml = ["tomli"] name = "cryptography" version = "40.0.1" description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers." 
+category = "main" optional = false python-versions = ">=3.6" files = [ @@ -759,6 +789,7 @@ tox = ["tox"] name = "debugpy" version = "1.6.3" description = "An implementation of the Debug Adapter Protocol for Python" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -786,6 +817,7 @@ files = [ name = "decorator" version = "5.1.1" description = "Decorators for Humans" +category = "main" optional = false python-versions = ">=3.5" files = [ @@ -797,6 +829,7 @@ files = [ name = "defusedxml" version = "0.7.1" description = "XML bomb protection for Python stdlib modules" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ @@ -808,6 +841,7 @@ files = [ name = "distlib" version = "0.3.6" description = "Distribution utilities" +category = "dev" optional = false python-versions = "*" files = [ @@ -819,6 +853,7 @@ files = [ name = "docker" version = "6.0.1" description = "A Python library for the Docker Engine API." 
+category = "main" optional = true python-versions = ">=3.7" files = [ @@ -840,6 +875,7 @@ ssh = ["paramiko (>=2.4.3)"] name = "ecdsa" version = "0.18.0" description = "ECDSA cryptographic signature library (pure python)" +category = "main" optional = true python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" files = [ @@ -858,6 +894,7 @@ gmpy2 = ["gmpy2"] name = "elastic-transport" version = "8.4.0" description = "Transport classes and utilities shared among Python Elastic client libraries" +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -876,6 +913,7 @@ develop = ["aiohttp", "mock", "pytest", "pytest-asyncio", "pytest-cov", "pytest- name = "elasticsearch" version = "7.10.1" description = "Python client for Elasticsearch" +category = "main" optional = true python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, <4" files = [ @@ -897,6 +935,7 @@ requests = ["requests (>=2.4.0,<3.0.0)"] name = "entrypoints" version = "0.4" description = "Discover and load entry points from installed packages." 
+category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -908,6 +947,7 @@ files = [ name = "exceptiongroup" version = "1.1.0" description = "Backport of PEP 654 (exception groups)" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -922,6 +962,7 @@ test = ["pytest (>=6)"] name = "fastapi" version = "0.87.0" description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -943,6 +984,7 @@ test = ["anyio[trio] (>=3.2.1,<4.0.0)", "black (==22.8.0)", "coverage[toml] (>=6 name = "fastjsonschema" version = "2.16.2" description = "Fastest Python implementation of JSON schema" +category = "dev" optional = false python-versions = "*" files = [ @@ -957,6 +999,7 @@ devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benc name = "filelock" version = "3.8.0" description = "A platform independent file lock." +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -972,6 +1015,7 @@ testing = ["covdefaults (>=2.2)", "coverage (>=6.4.2)", "pytest (>=7.1.2)", "pyt name = "frozenlist" version = "1.3.3" description = "A list-like structure which implements collections.abc.MutableSequence" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -1055,6 +1099,7 @@ files = [ name = "ghp-import" version = "2.1.0" description = "Copy your docs directly to the gh-pages branch." +category = "dev" optional = false python-versions = "*" files = [ @@ -1072,6 +1117,7 @@ dev = ["flake8", "markdown", "twine", "wheel"] name = "griffe" version = "0.25.5" description = "Signatures for entire Python programs. Extract the structure, the frame, the skeleton of your project, to generate API documentation or find breaking changes in your API." 
+category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1089,6 +1135,7 @@ async = ["aiofiles (>=0.7,<1.0)"] name = "grpcio" version = "1.53.0" description = "HTTP/2-based RPC framework" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -1146,6 +1193,7 @@ protobuf = ["grpcio-tools (>=1.53.0)"] name = "grpcio-tools" version = "1.53.0" description = "Protobuf code generator for gRPC" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -1205,6 +1253,7 @@ setuptools = "*" name = "h11" version = "0.14.0" description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1216,6 +1265,7 @@ files = [ name = "h2" version = "4.1.0" description = "HTTP/2 State-Machine based protocol implementation" +category = "main" optional = true python-versions = ">=3.6.1" files = [ @@ -1231,6 +1281,7 @@ hyperframe = ">=6.0,<7" name = "hnswlib" version = "0.7.0" description = "hnswlib" +category = "main" optional = true python-versions = "*" files = [ @@ -1244,6 +1295,7 @@ numpy = "*" name = "hpack" version = "4.0.0" description = "Pure-Python HPACK header compression" +category = "main" optional = true python-versions = ">=3.6.1" files = [ @@ -1255,6 +1307,7 @@ files = [ name = "httpcore" version = "0.16.1" description = "A minimal low-level HTTP client." +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1266,16 +1319,17 @@ files = [ anyio = ">=3.0,<5.0" certifi = "*" h11 = ">=0.13,<0.15" -sniffio = "==1.*" +sniffio = ">=1.0.0,<2.0.0" [package.extras] http2 = ["h2 (>=3,<5)"] -socks = ["socksio (==1.*)"] +socks = ["socksio (>=1.0.0,<2.0.0)"] [[package]] name = "httpx" version = "0.23.1" description = "The next generation HTTP client." 
+category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1292,14 +1346,15 @@ sniffio = "*" [package.extras] brotli = ["brotli", "brotlicffi"] -cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<13)"] +cli = ["click (>=8.0.0,<9.0.0)", "pygments (>=2.0.0,<3.0.0)", "rich (>=10,<13)"] http2 = ["h2 (>=3,<5)"] -socks = ["socksio (==1.*)"] +socks = ["socksio (>=1.0.0,<2.0.0)"] [[package]] name = "hyperframe" version = "6.0.1" description = "HTTP/2 framing layer for Python" +category = "main" optional = true python-versions = ">=3.6.1" files = [ @@ -1311,6 +1366,7 @@ files = [ name = "identify" version = "2.5.8" description = "File identification library for Python" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1325,6 +1381,7 @@ license = ["ukkonen"] name = "idna" version = "3.4" description = "Internationalized Domain Names in Applications (IDNA)" +category = "main" optional = false python-versions = ">=3.5" files = [ @@ -1336,6 +1393,7 @@ files = [ name = "importlib-metadata" version = "5.0.0" description = "Read metadata from Python packages" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1355,6 +1413,7 @@ testing = ["flake8 (<5)", "flufl.flake8", "importlib-resources (>=1.3)", "packag name = "importlib-resources" version = "5.10.0" description = "Read resources from Python packages" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1373,6 +1432,7 @@ testing = ["flake8 (<5)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-chec name = "iniconfig" version = "1.1.1" description = "iniconfig: brain-dead simple config-ini parsing" +category = "dev" optional = false python-versions = "*" files = [ @@ -1384,6 +1444,7 @@ files = [ name = "ipykernel" version = "6.16.2" description = "IPython Kernel for Jupyter" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1412,6 +1473,7 @@ test = ["flaky", "ipyparallel", "pre-commit", "pytest (>=7.0)", 
"pytest-cov", "p name = "ipython" version = "7.34.0" description = "IPython: Productive Interactive Computing" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1448,6 +1510,7 @@ test = ["ipykernel", "nbformat", "nose (>=0.10.1)", "numpy (>=1.17)", "pygments" name = "ipython-genutils" version = "0.2.0" description = "Vestigial utilities from IPython" +category = "dev" optional = false python-versions = "*" files = [ @@ -1459,6 +1522,7 @@ files = [ name = "isort" version = "5.11.5" description = "A Python utility / library to sort Python imports." +category = "dev" optional = false python-versions = ">=3.7.0" files = [ @@ -1472,10 +1536,42 @@ pipfile-deprecated-finder = ["pip-shims (>=0.5.2)", "pipreqs", "requirementslib" plugins = ["setuptools"] requirements-deprecated-finder = ["pip-api", "pipreqs"] +[[package]] +name = "jax" +version = "0.4.13" +description = "Differentiate, compile, and transform Numpy code." +category = "main" +optional = true +python-versions = ">=3.8" +files = [ + {file = "jax-0.4.13.tar.gz", hash = "sha256:03bfe6749dfe647f16f15f6616638adae6c4a7ca7167c75c21961ecfd3a3baaa"}, +] + +[package.dependencies] +importlib_metadata = {version = ">=4.6", markers = "python_version < \"3.10\""} +ml_dtypes = ">=0.1.0" +numpy = ">=1.21" +opt_einsum = "*" +scipy = ">=1.7" + +[package.extras] +australis = ["protobuf (>=3.13,<4)"] +ci = ["jaxlib (==0.4.12)"] +cpu = ["jaxlib (==0.4.13)"] +cuda = ["jaxlib (==0.4.13+cuda11.cudnn86)"] +cuda11-cudnn86 = ["jaxlib (==0.4.13+cuda11.cudnn86)"] +cuda11-local = ["jaxlib (==0.4.13+cuda11.cudnn86)"] +cuda11-pip = ["jaxlib (==0.4.13+cuda11.cudnn86)", "nvidia-cublas-cu11 (>=11.11)", "nvidia-cuda-cupti-cu11 (>=11.8)", "nvidia-cuda-nvcc-cu11 (>=11.8)", "nvidia-cuda-runtime-cu11 (>=11.8)", "nvidia-cudnn-cu11 (>=8.8)", "nvidia-cufft-cu11 (>=10.9)", "nvidia-cusolver-cu11 (>=11.4)", "nvidia-cusparse-cu11 (>=11.7)"] +cuda12-local = ["jaxlib (==0.4.13+cuda12.cudnn89)"] +cuda12-pip = ["jaxlib 
(==0.4.13+cuda12.cudnn89)", "nvidia-cublas-cu12", "nvidia-cuda-cupti-cu12", "nvidia-cuda-nvcc-cu12", "nvidia-cuda-runtime-cu12", "nvidia-cudnn-cu12 (>=8.9)", "nvidia-cufft-cu12", "nvidia-cusolver-cu12", "nvidia-cusparse-cu12"] +minimum-jaxlib = ["jaxlib (==0.4.11)"] +tpu = ["jaxlib (==0.4.13)", "libtpu-nightly (==0.1.dev20230622)"] + [[package]] name = "jedi" version = "0.18.1" description = "An autocompletion tool for Python that can be used for text editors." +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -1494,6 +1590,7 @@ testing = ["Django (<3.1)", "colorama", "docopt", "pytest (<7.0.0)"] name = "jina-hubble-sdk" version = "0.34.0" description = "SDK for Hubble API at Jina AI." +category = "main" optional = true python-versions = ">=3.7.0" files = [ @@ -1519,6 +1616,7 @@ full = ["aiohttp", "black (==22.3.0)", "docker", "filelock", "flake8 (==4.0.1)", name = "jinja2" version = "3.1.2" description = "A very fast and expressive template engine." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1536,6 +1634,7 @@ i18n = ["Babel (>=2.7)"] name = "jmespath" version = "1.0.1" description = "JSON Matching Expressions" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -1547,6 +1646,7 @@ files = [ name = "json5" version = "0.9.10" description = "A Python implementation of the JSON5 data format." 
+category = "dev" optional = false python-versions = "*" files = [ @@ -1561,6 +1661,7 @@ dev = ["hypothesis"] name = "jsonschema" version = "4.17.0" description = "An implementation of JSON Schema validation for Python" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1582,6 +1683,7 @@ format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339- name = "jupyter-client" version = "7.4.6" description = "Jupyter protocol implementation and client libraries" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1606,6 +1708,7 @@ test = ["codecov", "coverage", "ipykernel (>=6.12)", "ipython", "mypy", "pre-com name = "jupyter-core" version = "4.12.0" description = "Jupyter core package. A base package on which Jupyter projects rely." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1624,6 +1727,7 @@ test = ["ipykernel", "pre-commit", "pytest", "pytest-cov", "pytest-timeout"] name = "jupyter-server" version = "1.23.2" description = "The backend—i.e. core services, APIs, and REST endpoints—to Jupyter web applications." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1656,6 +1760,7 @@ test = ["coverage", "ipykernel", "pre-commit", "pytest (>=7.0)", "pytest-console name = "jupyterlab" version = "3.5.0" description = "JupyterLab computational environment" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1683,6 +1788,7 @@ ui-tests = ["build"] name = "jupyterlab-pygments" version = "0.2.2" description = "Pygments theme using JupyterLab CSS variables" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1694,6 +1800,7 @@ files = [ name = "jupyterlab-server" version = "2.16.3" description = "A set of server components for JupyterLab and JupyterLab like applications." 
+category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1720,6 +1827,7 @@ test = ["codecov", "ipykernel", "jupyter-server[test]", "openapi-core (>=0.14.2, name = "lxml" version = "4.9.2" description = "Powerful and Pythonic XML processing library combining libxml2/libxslt with the ElementTree API." +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, != 3.4.*" files = [ @@ -1812,6 +1920,7 @@ source = ["Cython (>=0.29.7)"] name = "lz4" version = "4.3.2" description = "LZ4 Bindings for Python" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -1861,6 +1970,7 @@ tests = ["psutil", "pytest (!=3.3.0)", "pytest-cov"] name = "mapbox-earcut" version = "1.0.1" description = "Python bindings for the mapbox earcut C++ polygon triangulation library." +category = "main" optional = true python-versions = "*" files = [ @@ -1935,6 +2045,7 @@ test = ["pytest"] name = "markdown" version = "3.3.7" description = "Python implementation of Markdown." +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -1952,6 +2063,7 @@ testing = ["coverage", "pyyaml"] name = "markupsafe" version = "2.1.1" description = "Safely add untrusted strings to HTML/XML markup." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2001,6 +2113,7 @@ files = [ name = "matplotlib-inline" version = "0.1.6" description = "Inline Matplotlib backend for Jupyter" +category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -2015,6 +2128,7 @@ traitlets = "*" name = "mergedeep" version = "1.3.4" description = "A deep merge function for 🐍." 
+category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -2026,6 +2140,7 @@ files = [ name = "mistune" version = "2.0.4" description = "A sane Markdown parser with useful plugins and renderers" +category = "dev" optional = false python-versions = "*" files = [ @@ -2037,6 +2152,7 @@ files = [ name = "mkdocs" version = "1.4.2" description = "Project documentation with Markdown." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2065,6 +2181,7 @@ min-versions = ["babel (==2.9.0)", "click (==7.0)", "colorama (==0.4)", "ghp-imp name = "mkdocs-autorefs" version = "0.4.1" description = "Automatically link across pages in MkDocs." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2080,6 +2197,7 @@ mkdocs = ">=1.1" name = "mkdocs-awesome-pages-plugin" version = "2.8.0" description = "An MkDocs plugin that simplifies configuring page titles and their order" +category = "dev" optional = false python-versions = ">=3.6.2" files = [ @@ -2096,6 +2214,7 @@ wcmatch = ">=7" name = "mkdocs-material" version = "9.1.3" description = "Documentation that simply works" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2118,6 +2237,7 @@ requests = ">=2.26" name = "mkdocs-material-extensions" version = "1.1.1" description = "Extension pack for Python Markdown and MkDocs Material." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2129,6 +2249,7 @@ files = [ name = "mkdocs-video" version = "1.5.0" description = "" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -2144,6 +2265,7 @@ mkdocs = ">=1.1.0,<2" name = "mkdocstrings" version = "0.20.0" description = "Automatic documentation from sources, for MkDocs." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2169,6 +2291,7 @@ python-legacy = ["mkdocstrings-python-legacy (>=0.2.1)"] name = "mkdocstrings-python" version = "0.8.3" description = "A Python handler for mkdocstrings." 
+category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2184,6 +2307,7 @@ mkdocstrings = ">=0.19" name = "mktestdocs" version = "0.2.0" description = "" +category = "dev" optional = false python-versions = "*" files = [ @@ -2194,10 +2318,48 @@ files = [ [package.extras] test = ["pytest (>=4.0.2)"] +[[package]] +name = "ml-dtypes" +version = "0.2.0" +description = "" +category = "main" +optional = true +python-versions = ">=3.7" +files = [ + {file = "ml_dtypes-0.2.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:df6a76e1c8adf484feb138ed323f9f40a7b6c21788f120f7c78bec20ac37ee81"}, + {file = "ml_dtypes-0.2.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bc29a0524ef5e23a7fbb8d881bdecabeb3fc1d19d9db61785d077a86cb94fab2"}, + {file = "ml_dtypes-0.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f08c391c2794f2aad358e6f4c70785a9a7b1df980ef4c232b3ccd4f6fe39f719"}, + {file = "ml_dtypes-0.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:75015818a7fccf99a5e8ed18720cb430f3e71a8838388840f4cdf225c036c983"}, + {file = "ml_dtypes-0.2.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:e70047ec2c83eaee01afdfdabee2c5b0c133804d90d0f7db4dd903360fcc537c"}, + {file = "ml_dtypes-0.2.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:36d28b8861a8931695e5a31176cad5ae85f6504906650dea5598fbec06c94606"}, + {file = "ml_dtypes-0.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e85ba8e24cf48d456e564688e981cf379d4c8e644db0a2f719b78de281bac2ca"}, + {file = "ml_dtypes-0.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:832a019a1b6db5c4422032ca9940a990fa104eee420f643713241b3a518977fa"}, + {file = "ml_dtypes-0.2.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:8faaf0897942c8253dd126662776ba45f0a5861968cf0f06d6d465f8a7bc298a"}, + {file = "ml_dtypes-0.2.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:35b984cddbe8173b545a0e3334fe56ea1a5c3eb67c507f60d0cfde1d3fa8f8c2"}, + {file = "ml_dtypes-0.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:022d5a4ee6be14569c2a9d1549e16f1ec87ca949681d0dca59995445d5fcdd5b"}, + {file = "ml_dtypes-0.2.0-cp38-cp38-win_amd64.whl", hash = "sha256:50845af3e9a601810751b55091dee6c2562403fa1cb4e0123675cf3a4fc2c17a"}, + {file = "ml_dtypes-0.2.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:f00c71c8c63e03aff313bc6a7aeaac9a4f1483a921a6ffefa6d4404efd1af3d0"}, + {file = "ml_dtypes-0.2.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:80d304c836d73f10605c58ccf7789c171cc229bfb678748adfb7cea2510dfd0e"}, + {file = "ml_dtypes-0.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:32107e7fa9f62db9a5281de923861325211dfff87bd23faefb27b303314635ab"}, + {file = "ml_dtypes-0.2.0-cp39-cp39-win_amd64.whl", hash = "sha256:1749b60348da71fd3c2ab303fdbc1965958dc50775ead41f5669c932a341cafd"}, + {file = "ml_dtypes-0.2.0.tar.gz", hash = "sha256:6488eb642acaaf08d8020f6de0a38acee7ac324c1e6e92ee0c0fea42422cb797"}, +] + +[package.dependencies] +numpy = [ + {version = ">1.20", markers = "python_version <= \"3.9\""}, + {version = ">=1.23.3", markers = "python_version > \"3.10\""}, + {version = ">=1.21.2", markers = "python_version > \"3.9\""}, +] + +[package.extras] +dev = ["absl-py", "pyink", "pylint (>=2.6.0)", "pytest", "pytest-xdist"] + [[package]] name = "mpmath" version = "1.3.0" description = "Python library for arbitrary-precision floating-point arithmetic" +category = "main" optional = true python-versions = "*" files = [ @@ -2215,6 +2377,7 @@ tests = ["pytest (>=4.6)"] name = "multidict" version = "6.0.4" description = "multidict implementation" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -2298,6 +2461,7 @@ files = [ name = "mypy" version = "1.0.0" description = "Optional static typing for Python" +category = "dev" optional = false 
python-versions = ">=3.7" files = [ @@ -2344,6 +2508,7 @@ reports = ["lxml"] name = "mypy-extensions" version = "0.4.3" description = "Experimental type system extensions for programs checked with the mypy typechecker." +category = "main" optional = false python-versions = "*" files = [ @@ -2355,6 +2520,7 @@ files = [ name = "natsort" version = "8.3.1" description = "Simple yet flexible natural sorting in Python." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2370,6 +2536,7 @@ icu = ["PyICU (>=1.0.0)"] name = "nbclassic" version = "0.4.8" description = "A web-based notebook environment for interactive computing" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2405,6 +2572,7 @@ test = ["coverage", "nbval", "pytest", "pytest-cov", "pytest-playwright", "pytes name = "nbclient" version = "0.7.0" description = "A client library for executing notebooks. Formerly nbconvert's ExecutePreprocessor." +category = "dev" optional = false python-versions = ">=3.7.0" files = [ @@ -2426,6 +2594,7 @@ test = ["black", "check-manifest", "flake8", "ipykernel", "ipython", "ipywidgets name = "nbconvert" version = "7.2.5" description = "Converting Jupyter Notebooks" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2464,6 +2633,7 @@ webpdf = ["pyppeteer (>=1,<1.1)"] name = "nbformat" version = "5.7.0" description = "The Jupyter Notebook format" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2484,6 +2654,7 @@ test = ["check-manifest", "pep440", "pre-commit", "pytest", "testpath"] name = "nest-asyncio" version = "1.5.6" description = "Patch asyncio to allow nested event loops" +category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -2495,6 +2666,7 @@ files = [ name = "networkx" version = "2.6.3" description = "Python package for creating and manipulating graphs and networks" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -2513,6 +2685,7 @@ 
test = ["codecov (>=2.1)", "pytest (>=6.2)", "pytest-cov (>=2.12)"] name = "nodeenv" version = "1.7.0" description = "Node.js virtual environment builder" +category = "dev" optional = false python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*" files = [ @@ -2527,6 +2700,7 @@ setuptools = "*" name = "notebook" version = "6.5.2" description = "A web-based notebook environment for interactive computing" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2561,6 +2735,7 @@ test = ["coverage", "nbval", "pytest", "pytest-cov", "requests", "requests-unixs name = "notebook-shim" version = "0.2.2" description = "A shim layer for notebook traits and config" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2578,6 +2753,7 @@ test = ["pytest", "pytest-console-scripts", "pytest-tornasync"] name = "numpy" version = "1.21.1" description = "NumPy is the fundamental package for array computing with Python." +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2611,10 +2787,49 @@ files = [ {file = "numpy-1.21.1.zip", hash = "sha256:dff4af63638afcc57a3dfb9e4b26d434a7a602d225b42d746ea7fe2edf1342fd"}, ] +[[package]] +name = "numpy" +version = "1.24.4" +description = "Fundamental package for array computing in Python" +category = "main" +optional = false +python-versions = ">=3.8" +files = [ + {file = "numpy-1.24.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c0bfb52d2169d58c1cdb8cc1f16989101639b34c7d3ce60ed70b19c63eba0b64"}, + {file = "numpy-1.24.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ed094d4f0c177b1b8e7aa9cba7d6ceed51c0e569a5318ac0ca9a090680a6a1b1"}, + {file = "numpy-1.24.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:79fc682a374c4a8ed08b331bef9c5f582585d1048fa6d80bc6c35bc384eee9b4"}, + {file = "numpy-1.24.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:7ffe43c74893dbf38c2b0a1f5428760a1a9c98285553c89e12d70a96a7f3a4d6"}, + {file = "numpy-1.24.4-cp310-cp310-win32.whl", hash = "sha256:4c21decb6ea94057331e111a5bed9a79d335658c27ce2adb580fb4d54f2ad9bc"}, + {file = "numpy-1.24.4-cp310-cp310-win_amd64.whl", hash = "sha256:b4bea75e47d9586d31e892a7401f76e909712a0fd510f58f5337bea9572c571e"}, + {file = "numpy-1.24.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f136bab9c2cfd8da131132c2cf6cc27331dd6fae65f95f69dcd4ae3c3639c810"}, + {file = "numpy-1.24.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e2926dac25b313635e4d6cf4dc4e51c8c0ebfed60b801c799ffc4c32bf3d1254"}, + {file = "numpy-1.24.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:222e40d0e2548690405b0b3c7b21d1169117391c2e82c378467ef9ab4c8f0da7"}, + {file = "numpy-1.24.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7215847ce88a85ce39baf9e89070cb860c98fdddacbaa6c0da3ffb31b3350bd5"}, + {file = "numpy-1.24.4-cp311-cp311-win32.whl", hash = "sha256:4979217d7de511a8d57f4b4b5b2b965f707768440c17cb70fbf254c4b225238d"}, + {file = "numpy-1.24.4-cp311-cp311-win_amd64.whl", hash = "sha256:b7b1fc9864d7d39e28f41d089bfd6353cb5f27ecd9905348c24187a768c79694"}, + {file = "numpy-1.24.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1452241c290f3e2a312c137a9999cdbf63f78864d63c79039bda65ee86943f61"}, + {file = "numpy-1.24.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:04640dab83f7c6c85abf9cd729c5b65f1ebd0ccf9de90b270cd61935eef0197f"}, + {file = "numpy-1.24.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a5425b114831d1e77e4b5d812b69d11d962e104095a5b9c3b641a218abcc050e"}, + {file = "numpy-1.24.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd80e219fd4c71fc3699fc1dadac5dcf4fd882bfc6f7ec53d30fa197b8ee22dc"}, + {file = "numpy-1.24.4-cp38-cp38-win32.whl", hash = "sha256:4602244f345453db537be5314d3983dbf5834a9701b7723ec28923e2889e0bb2"}, + {file = 
"numpy-1.24.4-cp38-cp38-win_amd64.whl", hash = "sha256:692f2e0f55794943c5bfff12b3f56f99af76f902fc47487bdfe97856de51a706"}, + {file = "numpy-1.24.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:2541312fbf09977f3b3ad449c4e5f4bb55d0dbf79226d7724211acc905049400"}, + {file = "numpy-1.24.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9667575fb6d13c95f1b36aca12c5ee3356bf001b714fc354eb5465ce1609e62f"}, + {file = "numpy-1.24.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f3a86ed21e4f87050382c7bc96571755193c4c1392490744ac73d660e8f564a9"}, + {file = "numpy-1.24.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d11efb4dbecbdf22508d55e48d9c8384db795e1b7b51ea735289ff96613ff74d"}, + {file = "numpy-1.24.4-cp39-cp39-win32.whl", hash = "sha256:6620c0acd41dbcb368610bb2f4d83145674040025e5536954782467100aa8835"}, + {file = "numpy-1.24.4-cp39-cp39-win_amd64.whl", hash = "sha256:befe2bf740fd8373cf56149a5c23a0f601e82869598d41f8e188a0e9869926f8"}, + {file = "numpy-1.24.4-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:31f13e25b4e304632a4619d0e0777662c2ffea99fcae2029556b17d8ff958aef"}, + {file = "numpy-1.24.4-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95f7ac6540e95bc440ad77f56e520da5bf877f87dca58bd095288dce8940532a"}, + {file = "numpy-1.24.4-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:e98f220aa76ca2a977fe435f5b04d7b3470c0a2e6312907b37ba6068f26787f2"}, + {file = "numpy-1.24.4.tar.gz", hash = "sha256:80f5e3a4e498641401868df4208b74581206afbee7cf7b8329daae82676d9463"}, +] + [[package]] name = "nvidia-cublas-cu11" version = "11.10.3.66" description = "CUBLAS native runtime libraries" +category = "main" optional = true python-versions = ">=3" files = [ @@ -2630,6 +2845,7 @@ wheel = "*" name = "nvidia-cuda-nvrtc-cu11" version = "11.7.99" description = "NVRTC native runtime libraries" +category = "main" optional = true python-versions = ">=3" files = [ @@ -2646,6 +2862,7 @@ wheel = 
"*" name = "nvidia-cuda-runtime-cu11" version = "11.7.99" description = "CUDA Runtime native Libraries" +category = "main" optional = true python-versions = ">=3" files = [ @@ -2661,6 +2878,7 @@ wheel = "*" name = "nvidia-cudnn-cu11" version = "8.5.0.96" description = "cuDNN runtime libraries" +category = "main" optional = true python-versions = ">=3" files = [ @@ -2672,10 +2890,30 @@ files = [ setuptools = "*" wheel = "*" +[[package]] +name = "opt-einsum" +version = "3.3.0" +description = "Optimizing numpys einsum function" +category = "main" +optional = true +python-versions = ">=3.5" +files = [ + {file = "opt_einsum-3.3.0-py3-none-any.whl", hash = "sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147"}, + {file = "opt_einsum-3.3.0.tar.gz", hash = "sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549"}, +] + +[package.dependencies] +numpy = ">=1.7" + +[package.extras] +docs = ["numpydoc", "sphinx (==1.2.3)", "sphinx-rtd-theme", "sphinxcontrib-napoleon"] +tests = ["pytest", "pytest-cov", "pytest-pep8"] + [[package]] name = "orjson" version = "3.8.2" description = "Fast, correct Python JSON library supporting dataclasses, datetimes, and numpy" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2734,6 +2972,7 @@ files = [ name = "packaging" version = "21.3" description = "Core utilities for Python packages" +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -2748,6 +2987,7 @@ pyparsing = ">=2.0.2,<3.0.5 || >3.0.5" name = "pandas" version = "1.1.0" description = "Powerful data structures for data analysis, time series, and statistics" +category = "main" optional = true python-versions = ">=3.6.1" files = [ @@ -2781,6 +3021,7 @@ test = ["hypothesis (>=3.58)", "pytest (>=4.0.2)", "pytest-xdist"] name = "pandocfilters" version = "1.5.0" description = "Utilities for writing pandoc filters in python" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, 
!=3.3.*" files = [ @@ -2792,6 +3033,7 @@ files = [ name = "parso" version = "0.8.3" description = "A Python Parser" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -2807,6 +3049,7 @@ testing = ["docopt", "pytest (<6.0.0)"] name = "pathspec" version = "0.10.2" description = "Utility library for gitignore style pattern matching of file paths." +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2818,6 +3061,7 @@ files = [ name = "pexpect" version = "4.8.0" description = "Pexpect allows easy control of interactive console applications." +category = "dev" optional = false python-versions = "*" files = [ @@ -2832,6 +3076,7 @@ ptyprocess = ">=0.5" name = "pickleshare" version = "0.7.5" description = "Tiny 'shelve'-like database with concurrency support" +category = "dev" optional = false python-versions = "*" files = [ @@ -2843,6 +3088,7 @@ files = [ name = "pillow" version = "9.3.0" description = "Python Imaging Library (Fork)" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -2917,6 +3163,7 @@ tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "pa name = "pkgutil-resolve-name" version = "1.3.10" description = "Resolve a name to an object." +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -2928,6 +3175,7 @@ files = [ name = "platformdirs" version = "2.5.4" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." 
+category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2943,6 +3191,7 @@ test = ["appdirs (==1.4.4)", "pytest (>=7.2)", "pytest-cov (>=4)", "pytest-mock name = "pluggy" version = "0.13.1" description = "plugin and hook calling mechanisms for python" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -2957,6 +3206,7 @@ dev = ["pre-commit", "tox"] name = "pre-commit" version = "2.20.0" description = "A framework for managing and maintaining multi-language pre-commit hooks." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2976,6 +3226,7 @@ virtualenv = ">=20.0.8" name = "prometheus-client" version = "0.15.0" description = "Python client for the Prometheus monitoring system." +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -2990,6 +3241,7 @@ twisted = ["twisted"] name = "prompt-toolkit" version = "3.0.32" description = "Library for building powerful interactive command lines in Python" +category = "dev" optional = false python-versions = ">=3.6.2" files = [ @@ -3004,6 +3256,7 @@ wcwidth = "*" name = "protobuf" version = "4.21.9" description = "" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -3027,6 +3280,7 @@ files = [ name = "psutil" version = "5.9.4" description = "Cross-platform lib for process and system monitoring in Python." 
+category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -3053,6 +3307,7 @@ test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"] name = "ptyprocess" version = "0.7.0" description = "Run a subprocess in a pseudo terminal" +category = "dev" optional = false python-versions = "*" files = [ @@ -3064,6 +3319,7 @@ files = [ name = "py" version = "1.11.0" description = "library with cross-python path, ini-parsing, io, code, log facilities" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ @@ -3075,6 +3331,7 @@ files = [ name = "pyasn1" version = "0.4.8" description = "ASN.1 types and codecs" +category = "main" optional = true python-versions = "*" files = [ @@ -3086,6 +3343,7 @@ files = [ name = "pycollada" version = "0.7.2" description = "python library for reading and writing collada documents" +category = "main" optional = true python-versions = "*" files = [ @@ -3103,6 +3361,7 @@ validation = ["lxml"] name = "pycparser" version = "2.21" description = "C parser in Python" +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -3114,6 +3373,7 @@ files = [ name = "pydantic" version = "1.10.2" description = "Data validation and settings management using python type hints" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -3166,6 +3426,7 @@ email = ["email-validator (>=1.0.3)"] name = "pydub" version = "0.25.1" description = "Manipulate audio with an simple and easy high level interface" +category = "main" optional = true python-versions = "*" files = [ @@ -3177,6 +3438,7 @@ files = [ name = "pygments" version = "2.14.0" description = "Pygments is a syntax highlighting package written in Python." 
+category = "main" optional = false python-versions = ">=3.6" files = [ @@ -3191,6 +3453,7 @@ plugins = ["importlib-metadata"] name = "pymdown-extensions" version = "9.10" description = "Extension pack for Python Markdown." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3206,6 +3469,7 @@ pyyaml = "*" name = "pyparsing" version = "3.0.9" description = "pyparsing module - Classes and methods to define and execute parsing grammars" +category = "main" optional = false python-versions = ">=3.6.8" files = [ @@ -3220,6 +3484,7 @@ diagrams = ["jinja2", "railroad-diagrams"] name = "pyrsistent" version = "0.19.2" description = "Persistent/Functional/Immutable data structures" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -3251,6 +3516,7 @@ files = [ name = "pytest" version = "7.2.1" description = "pytest: simple powerful testing with Python" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3274,6 +3540,7 @@ testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2. name = "pytest-asyncio" version = "0.20.2" description = "Pytest support for asyncio" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3291,6 +3558,7 @@ testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy name = "pytest-cov" version = "3.0.0" description = "Pytest plugin for measuring coverage." 
+category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -3309,6 +3577,7 @@ testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtuale name = "python-dateutil" version = "2.8.2" description = "Extensions to the standard Python datetime module" +category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" files = [ @@ -3323,6 +3592,7 @@ six = ">=1.5" name = "python-jose" version = "3.3.0" description = "JOSE implementation in Python" +category = "main" optional = true python-versions = "*" files = [ @@ -3344,6 +3614,7 @@ pycryptodome = ["pyasn1", "pycryptodome (>=3.3.1,<4.0.0)"] name = "pytz" version = "2022.6" description = "World timezone definitions, modern and historical" +category = "main" optional = false python-versions = "*" files = [ @@ -3355,6 +3626,7 @@ files = [ name = "pywin32" version = "305" description = "Python for Window Extensions" +category = "main" optional = false python-versions = "*" files = [ @@ -3378,6 +3650,7 @@ files = [ name = "pywinpty" version = "2.0.9" description = "Pseudo terminal support for Windows from Python." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3393,6 +3666,7 @@ files = [ name = "pyyaml" version = "6.0" description = "YAML parser and emitter for Python" +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -3442,6 +3716,7 @@ files = [ name = "pyyaml-env-tag" version = "0.1" description = "A custom YAML tag for referencing environment variables in YAML files. 
" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -3456,6 +3731,7 @@ pyyaml = "*" name = "pyzmq" version = "24.0.1" description = "Python bindings for 0MQ" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -3543,6 +3819,7 @@ py = {version = "*", markers = "implementation_name == \"pypy\""} name = "qdrant-client" version = "1.1.4" description = "Client library for the Qdrant vector search engine" +category = "main" optional = true python-versions = ">=3.7,<3.12" files = [ @@ -3563,6 +3840,7 @@ urllib3 = ">=1.26.14,<2.0.0" name = "redis" version = "4.6.0" description = "Python client for Redis database and key-value store" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -3581,6 +3859,7 @@ ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==20.0.1)", "requests (>=2.26.0)" name = "regex" version = "2022.10.31" description = "Alternative regular expression module, to replace re." +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -3678,6 +3957,7 @@ files = [ name = "requests" version = "2.28.2" description = "Python HTTP for Humans." 
+category = "main" optional = false python-versions = ">=3.7, <4" files = [ @@ -3699,6 +3979,7 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] name = "rfc3986" version = "1.5.0" description = "Validating URI References per RFC 3986" +category = "main" optional = false python-versions = "*" files = [ @@ -3716,6 +3997,7 @@ idna2008 = ["idna"] name = "rich" version = "13.1.0" description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" +category = "main" optional = false python-versions = ">=3.7.0" files = [ @@ -3735,6 +4017,7 @@ jupyter = ["ipywidgets (>=7.5.1,<8.0.0)"] name = "rsa" version = "4.9" description = "Pure-Python RSA implementation" +category = "main" optional = true python-versions = ">=3.6,<4" files = [ @@ -3749,6 +4032,7 @@ pyasn1 = ">=0.1.3" name = "rtree" version = "1.0.1" description = "R-Tree spatial index for Python GIS" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -3803,6 +4087,7 @@ files = [ name = "ruff" version = "0.0.243" description = "An extremely fast Python linter, written in Rust." 
+category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3828,6 +4113,7 @@ files = [ name = "s3transfer" version = "0.6.0" description = "An Amazon S3 Transfer Manager" +category = "main" optional = true python-versions = ">= 3.7" files = [ @@ -3843,39 +4129,48 @@ crt = ["botocore[crt] (>=1.20.29,<2.0a.0)"] [[package]] name = "scipy" -version = "1.6.1" -description = "SciPy: Scientific Library for Python" +version = "1.9.3" +description = "Fundamental algorithms for scientific computing in Python" +category = "main" optional = true -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "scipy-1.6.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:a15a1f3fc0abff33e792d6049161b7795909b40b97c6cc2934ed54384017ab76"}, - {file = "scipy-1.6.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:e79570979ccdc3d165456dd62041d9556fb9733b86b4b6d818af7a0afc15f092"}, - {file = "scipy-1.6.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:a423533c55fec61456dedee7b6ee7dce0bb6bfa395424ea374d25afa262be261"}, - {file = "scipy-1.6.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:33d6b7df40d197bdd3049d64e8e680227151673465e5d85723b3b8f6b15a6ced"}, - {file = "scipy-1.6.1-cp37-cp37m-win32.whl", hash = "sha256:6725e3fbb47da428794f243864f2297462e9ee448297c93ed1dcbc44335feb78"}, - {file = "scipy-1.6.1-cp37-cp37m-win_amd64.whl", hash = "sha256:5fa9c6530b1661f1370bcd332a1e62ca7881785cc0f80c0d559b636567fab63c"}, - {file = "scipy-1.6.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:bd50daf727f7c195e26f27467c85ce653d41df4358a25b32434a50d8870fc519"}, - {file = "scipy-1.6.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:f46dd15335e8a320b0fb4685f58b7471702234cba8bb3442b69a3e1dc329c345"}, - {file = "scipy-1.6.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:0e5b0ccf63155d90da576edd2768b66fb276446c371b73841e3503be1d63fb5d"}, - {file = "scipy-1.6.1-cp38-cp38-manylinux2014_aarch64.whl", hash = 
"sha256:2481efbb3740977e3c831edfd0bd9867be26387cacf24eb5e366a6a374d3d00d"}, - {file = "scipy-1.6.1-cp38-cp38-win32.whl", hash = "sha256:68cb4c424112cd4be886b4d979c5497fba190714085f46b8ae67a5e4416c32b4"}, - {file = "scipy-1.6.1-cp38-cp38-win_amd64.whl", hash = "sha256:5f331eeed0297232d2e6eea51b54e8278ed8bb10b099f69c44e2558c090d06bf"}, - {file = "scipy-1.6.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0c8a51d33556bf70367452d4d601d1742c0e806cd0194785914daf19775f0e67"}, - {file = "scipy-1.6.1-cp39-cp39-manylinux1_i686.whl", hash = "sha256:83bf7c16245c15bc58ee76c5418e46ea1811edcc2e2b03041b804e46084ab627"}, - {file = "scipy-1.6.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:794e768cc5f779736593046c9714e0f3a5940bc6dcc1dba885ad64cbfb28e9f0"}, - {file = "scipy-1.6.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:5da5471aed911fe7e52b86bf9ea32fb55ae93e2f0fac66c32e58897cfb02fa07"}, - {file = "scipy-1.6.1-cp39-cp39-win32.whl", hash = "sha256:8e403a337749ed40af60e537cc4d4c03febddcc56cd26e774c9b1b600a70d3e4"}, - {file = "scipy-1.6.1-cp39-cp39-win_amd64.whl", hash = "sha256:a5193a098ae9f29af283dcf0041f762601faf2e595c0db1da929875b7570353f"}, - {file = "scipy-1.6.1.tar.gz", hash = "sha256:c4fceb864890b6168e79b0e714c585dbe2fd4222768ee90bc1aa0f8218691b11"}, + {file = "scipy-1.9.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:1884b66a54887e21addf9c16fb588720a8309a57b2e258ae1c7986d4444d3bc0"}, + {file = "scipy-1.9.3-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:83b89e9586c62e787f5012e8475fbb12185bafb996a03257e9675cd73d3736dd"}, + {file = "scipy-1.9.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1a72d885fa44247f92743fc20732ae55564ff2a519e8302fb7e18717c5355a8b"}, + {file = "scipy-1.9.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d01e1dd7b15bd2449c8bfc6b7cc67d630700ed655654f0dfcf121600bad205c9"}, + {file = "scipy-1.9.3-cp310-cp310-win_amd64.whl", hash = 
"sha256:68239b6aa6f9c593da8be1509a05cb7f9efe98b80f43a5861cd24c7557e98523"}, + {file = "scipy-1.9.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b41bc822679ad1c9a5f023bc93f6d0543129ca0f37c1ce294dd9d386f0a21096"}, + {file = "scipy-1.9.3-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:90453d2b93ea82a9f434e4e1cba043e779ff67b92f7a0e85d05d286a3625df3c"}, + {file = "scipy-1.9.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:83c06e62a390a9167da60bedd4575a14c1f58ca9dfde59830fc42e5197283dab"}, + {file = "scipy-1.9.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:abaf921531b5aeaafced90157db505e10345e45038c39e5d9b6c7922d68085cb"}, + {file = "scipy-1.9.3-cp311-cp311-win_amd64.whl", hash = "sha256:06d2e1b4c491dc7d8eacea139a1b0b295f74e1a1a0f704c375028f8320d16e31"}, + {file = "scipy-1.9.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:5a04cd7d0d3eff6ea4719371cbc44df31411862b9646db617c99718ff68d4840"}, + {file = "scipy-1.9.3-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:545c83ffb518094d8c9d83cce216c0c32f8c04aaf28b92cc8283eda0685162d5"}, + {file = "scipy-1.9.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d54222d7a3ba6022fdf5773931b5d7c56efe41ede7f7128c7b1637700409108"}, + {file = "scipy-1.9.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cff3a5295234037e39500d35316a4c5794739433528310e117b8a9a0c76d20fc"}, + {file = "scipy-1.9.3-cp38-cp38-win_amd64.whl", hash = "sha256:2318bef588acc7a574f5bfdff9c172d0b1bf2c8143d9582e05f878e580a3781e"}, + {file = "scipy-1.9.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:d644a64e174c16cb4b2e41dfea6af722053e83d066da7343f333a54dae9bc31c"}, + {file = "scipy-1.9.3-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:da8245491d73ed0a994ed9c2e380fd058ce2fa8a18da204681f2fe1f57f98f95"}, + {file = "scipy-1.9.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:4db5b30849606a95dcf519763dd3ab6fe9bd91df49eba517359e450a7d80ce2e"}, + {file = "scipy-1.9.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c68db6b290cbd4049012990d7fe71a2abd9ffbe82c0056ebe0f01df8be5436b0"}, + {file = "scipy-1.9.3-cp39-cp39-win_amd64.whl", hash = "sha256:5b88e6d91ad9d59478fafe92a7c757d00c59e3bdc3331be8ada76a4f8d683f58"}, + {file = "scipy-1.9.3.tar.gz", hash = "sha256:fbc5c05c85c1a02be77b1ff591087c83bc44579c6d2bd9fb798bb64ea5e1a027"}, ] [package.dependencies] -numpy = ">=1.16.5" +numpy = ">=1.18.5,<1.26.0" + +[package.extras] +dev = ["flake8", "mypy", "pycodestyle", "typing_extensions"] +doc = ["matplotlib (>2)", "numpydoc", "pydata-sphinx-theme (==0.9.0)", "sphinx (!=4.1.0)", "sphinx-panels (>=0.5.2)", "sphinx-tabs"] +test = ["asv", "gmpy2", "mpmath", "pytest", "pytest-cov", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] [[package]] name = "send2trash" version = "1.8.0" description = "Send file to trash natively under Mac OS X, Windows and Linux." 
+category = "dev" optional = false python-versions = "*" files = [ @@ -3892,6 +4187,7 @@ win32 = ["pywin32"] name = "setuptools" version = "65.5.1" description = "Easily download, build, install, upgrade, and uninstall Python packages" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -3908,6 +4204,7 @@ testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs ( name = "shapely" version = "2.0.1" description = "Manipulation and analysis of geometric objects" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -3955,13 +4252,14 @@ files = [ numpy = ">=1.14" [package.extras] -docs = ["matplotlib", "numpydoc (==1.1.*)", "sphinx", "sphinx-book-theme", "sphinx-remove-toctrees"] +docs = ["matplotlib", "numpydoc (>=1.1.0,<1.2.0)", "sphinx", "sphinx-book-theme", "sphinx-remove-toctrees"] test = ["pytest", "pytest-cov"] [[package]] name = "six" version = "1.16.0" description = "Python 2 and 3 compatibility utilities" +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" files = [ @@ -3973,6 +4271,7 @@ files = [ name = "smart-open" version = "6.3.0" description = "Utils for streaming large files (S3, HDFS, GCS, Azure Blob Storage, gzip, bz2...)" +category = "main" optional = true python-versions = ">=3.6,<4.0" files = [ @@ -3997,6 +4296,7 @@ webhdfs = ["requests"] name = "sniffio" version = "1.3.0" description = "Sniff out which async library your code is running under" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -4008,6 +4308,7 @@ files = [ name = "soupsieve" version = "2.3.2.post1" description = "A modern CSS selector implementation for Beautiful Soup." +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -4019,6 +4320,7 @@ files = [ name = "starlette" version = "0.21.0" description = "The little ASGI library that shines." 
+category = "main" optional = true python-versions = ">=3.7" files = [ @@ -4037,6 +4339,7 @@ full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart", "pyyam name = "svg-path" version = "6.2" description = "SVG path objects and parser" +category = "main" optional = true python-versions = "*" files = [ @@ -4051,6 +4354,7 @@ test = ["Pillow", "pytest", "pytest-cov"] name = "sympy" version = "1.10.1" description = "Computer algebra system (CAS) in Python" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -4065,6 +4369,7 @@ mpmath = ">=0.19" name = "terminado" version = "0.17.0" description = "Tornado websocket backend for the Xterm.js Javascript terminal emulator library." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4085,6 +4390,7 @@ test = ["pre-commit", "pytest (>=7.0)", "pytest-timeout"] name = "tinycss2" version = "1.2.1" description = "A tiny CSS parser" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4103,6 +4409,7 @@ test = ["flake8", "isort", "pytest"] name = "toml" version = "0.10.2" description = "Python Library for Tom's Obvious, Minimal Language" +category = "dev" optional = false python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" files = [ @@ -4114,6 +4421,7 @@ files = [ name = "tomli" version = "2.0.1" description = "A lil' TOML parser" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4125,6 +4433,7 @@ files = [ name = "torch" version = "1.13.0" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" +category = "main" optional = true python-versions = ">=3.7.0" files = [ @@ -4165,6 +4474,7 @@ opt-einsum = ["opt-einsum (>=3.3)"] name = "tornado" version = "6.2" description = "Tornado is a Python web framework and asynchronous networking library, originally developed at FriendFeed." 
+category = "dev" optional = false python-versions = ">= 3.7" files = [ @@ -4185,6 +4495,7 @@ files = [ name = "tqdm" version = "4.65.0" description = "Fast, Extensible Progress Meter" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -4205,6 +4516,7 @@ telegram = ["requests"] name = "traitlets" version = "5.5.0" description = "" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4220,6 +4532,7 @@ test = ["pre-commit", "pytest"] name = "trimesh" version = "3.21.2" description = "Import, export, process, analyze and view triangular meshes." +category = "main" optional = true python-versions = "*" files = [ @@ -4255,6 +4568,7 @@ test = ["autopep8", "coveralls", "ezdxf", "pyinstrument", "pytest", "pytest-cov" name = "types-pillow" version = "9.3.0.1" description = "Typing stubs for Pillow" +category = "main" optional = true python-versions = "*" files = [ @@ -4266,6 +4580,7 @@ files = [ name = "types-protobuf" version = "3.20.4.5" description = "Typing stubs for protobuf" +category = "dev" optional = false python-versions = "*" files = [ @@ -4277,6 +4592,7 @@ files = [ name = "types-pyopenssl" version = "23.2.0.1" description = "Typing stubs for pyOpenSSL" +category = "dev" optional = false python-versions = "*" files = [ @@ -4291,6 +4607,7 @@ cryptography = ">=35.0.0" name = "types-redis" version = "4.6.0.0" description = "Typing stubs for redis" +category = "dev" optional = false python-versions = "*" files = [ @@ -4306,6 +4623,7 @@ types-pyOpenSSL = "*" name = "types-requests" version = "2.28.11.7" description = "Typing stubs for requests" +category = "main" optional = false python-versions = "*" files = [ @@ -4320,6 +4638,7 @@ types-urllib3 = "<1.27" name = "types-urllib3" version = "1.26.25.4" description = "Typing stubs for urllib3" +category = "main" optional = false python-versions = "*" files = [ @@ -4331,6 +4650,7 @@ files = [ name = "typing-extensions" version = "4.4.0" description = "Backported and 
Experimental Type Hints for Python 3.7+" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -4342,6 +4662,7 @@ files = [ name = "typing-inspect" version = "0.8.0" description = "Runtime inspection utilities for typing module." +category = "main" optional = false python-versions = "*" files = [ @@ -4357,6 +4678,7 @@ typing-extensions = ">=3.7.4" name = "urllib3" version = "1.26.14" description = "HTTP library with thread-safe connection pooling, file post, and more." +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" files = [ @@ -4373,6 +4695,7 @@ socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] name = "uvicorn" version = "0.19.0" description = "The lightning-fast ASGI server." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4391,6 +4714,7 @@ standard = ["colorama (>=0.4)", "httptools (>=0.5.0)", "python-dotenv (>=0.13)", name = "validators" version = "0.20.0" description = "Python Data Validation for Humans™." +category = "main" optional = true python-versions = ">=3.4" files = [ @@ -4407,6 +4731,7 @@ test = ["flake8 (>=2.4.0)", "isort (>=4.2.2)", "pytest (>=2.2.3)"] name = "virtualenv" version = "20.16.7" description = "Virtual Python Environment builder" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -4427,6 +4752,7 @@ testing = ["coverage (>=6.2)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7 name = "watchdog" version = "2.3.1" description = "Filesystem events monitoring" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -4467,6 +4793,7 @@ watchmedo = ["PyYAML (>=3.10)"] name = "wcmatch" version = "8.4.1" description = "Wildcard/glob file name matcher." 
+category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4481,6 +4808,7 @@ bracex = ">=2.1.1" name = "wcwidth" version = "0.2.5" description = "Measures the displayed width of unicode strings in a terminal" +category = "dev" optional = false python-versions = "*" files = [ @@ -4492,6 +4820,7 @@ files = [ name = "weaviate-client" version = "3.17.1" description = "A python native weaviate client" +category = "main" optional = true python-versions = ">=3.8" files = [ @@ -4512,6 +4841,7 @@ grpc = ["grpcio", "grpcio-tools"] name = "webencodings" version = "0.5.1" description = "Character encoding aliases for legacy web content" +category = "dev" optional = false python-versions = "*" files = [ @@ -4523,6 +4853,7 @@ files = [ name = "websocket-client" version = "1.4.2" description = "WebSocket client for Python with low level API options" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -4539,6 +4870,7 @@ test = ["websockets"] name = "wheel" version = "0.38.4" description = "A built-package format for Python" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -4553,6 +4885,7 @@ test = ["pytest (>=3.0.0)"] name = "xxhash" version = "3.2.0" description = "Python binding for xxHash" +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -4660,6 +4993,7 @@ files = [ name = "yarl" version = "1.8.2" description = "Yet another URL library" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -4747,6 +5081,7 @@ multidict = ">=4.0" name = "zipp" version = "3.10.0" description = "Backport of pathlib-compatible object wrapper for zip files" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -4762,10 +5097,11 @@ testing = ["flake8 (<5)", "func-timeout", "jaraco.functools", "jaraco.itertools" audio = ["pydub"] aws = ["smart-open"] elasticsearch = ["elastic-transport", "elasticsearch"] -full = ["av", "lz4", "pandas", "pillow", "protobuf", "pydub", 
"trimesh", "types-pillow"] +full = ["av", "jax", "lz4", "pandas", "pillow", "protobuf", "pydub", "trimesh", "types-pillow"] hnswlib = ["hnswlib", "protobuf"] image = ["pillow", "types-pillow"] jac = ["jina-hubble-sdk"] +jax = ["jax"] mesh = ["trimesh"] pandas = ["pandas"] proto = ["lz4", "protobuf"] @@ -4779,4 +5115,4 @@ web = ["fastapi"] [metadata] lock-version = "2.0" python-versions = ">=3.8,<4.0" -content-hash = "e98157b56ee51d21d5861108878b27420613a9d43d819cce5c3adade89c6c440" +content-hash = "7b92f58355832b250432c909539267349a32496c47e7ee5fa5fddfc59b843d90" diff --git a/pyproject.toml b/pyproject.toml index 02fc1d3b96e..3e0e2ee40a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ jina-hubble-sdk = {version = ">=0.34.0", optional = true} elastic-transport = {version ="^8.4.0", optional = true } qdrant-client = {version = ">=1.1.4", python = "<3.12", optional = true } redis = {version = "^4.6.0", optional = true} +jax = {version = ">=0.4.10", optional = true} [tool.poetry.extras] proto = ["protobuf", "lz4"] @@ -76,9 +77,10 @@ web = ["fastapi"] qdrant = ["qdrant-client"] weaviate = ["weaviate-client"] redis = ['redis'] +jax = ["jaxlib","jax"] # all -full = ["protobuf", "lz4", "pandas", "pillow", "types-pillow", "av", "pydub", "trimesh"] +full = ["protobuf", "lz4", "pandas", "pillow", "types-pillow", "av", "pydub", "trimesh", "jax"] [tool.poetry.dev-dependencies] pytest = ">=7.0" diff --git a/tests/integrations/array/test_jax_integration.py b/tests/integrations/array/test_jax_integration.py new file mode 100644 index 00000000000..b120649d4f5 --- /dev/null +++ b/tests/integrations/array/test_jax_integration.py @@ -0,0 +1,38 @@ +from typing import Optional + +import pytest + +from docarray import BaseDoc, DocList +from docarray.utils._internal.misc import is_jax_available + +if is_jax_available(): + import jax.numpy as jnp + from jax import jit + + from docarray.typing import JaxArray + + +@pytest.mark.jax +def test_basic_jax_operation(): + def 
basic_jax_fn(x): + return jnp.sum(x) + + def abstract_JaxArray(array: 'JaxArray') -> jnp.ndarray: + return array.tensor + + class Mmdoc(BaseDoc): + tensor: Optional[JaxArray[3, 224, 224]] + + N = 10 + + batch = DocList[Mmdoc](Mmdoc() for _ in range(N)) + batch.tensor = jnp.zeros((N, 3, 224, 224)) + + batch = batch.to_doc_vec() + + jax_fn = jit(basic_jax_fn) + result = jax_fn(abstract_JaxArray(batch.tensor)) + + assert ( + result == 0.0 + ) # checking if the sum of the tensor data is zero as initialized diff --git a/tests/units/array/stack/test_array_stacked_jax.py b/tests/units/array/stack/test_array_stacked_jax.py new file mode 100644 index 00000000000..5fd8876f3be --- /dev/null +++ b/tests/units/array/stack/test_array_stacked_jax.py @@ -0,0 +1,301 @@ +from typing import Optional, Union + +import pytest + +from docarray import BaseDoc, DocList +from docarray.array import DocVec +from docarray.typing import ( + AnyEmbedding, + AnyTensor, + AudioTensor, + ImageTensor, + NdArray, + VideoTensor, +) +from docarray.utils._internal.misc import is_jax_available + +jax_available = is_jax_available() +if jax_available: + import jax.numpy as jnp + + from docarray.typing import JaxArray + + +@pytest.fixture() +@pytest.mark.jax +def batch(): + + import jax.numpy as jnp + + class Image(BaseDoc): + tensor: JaxArray[3, 224, 224] + + batch = DocList[Image]([Image(tensor=jnp.zeros((3, 224, 224))) for _ in range(10)]) + + return batch.to_doc_vec() + + +@pytest.fixture() +@pytest.mark.jax +def nested_batch(): + class Image(BaseDoc): + tensor: JaxArray[3, 224, 224] + + class MMdoc(BaseDoc): + img: DocList[Image] + + batch = DocVec[MMdoc]( + [ + MMdoc( + img=DocList[Image]( + [Image(tensor=jnp.zeros((3, 224, 224))) for _ in range(10)] + ) + ) + for _ in range(10) + ] + ) + + return batch + + +@pytest.mark.jax +def test_len(batch): + assert len(batch) == 10 + + +@pytest.mark.jax +def test_getitem(batch): + for i in range(len(batch)): + item = batch[i] + assert isinstance(item.tensor, 
JaxArray) + assert jnp.allclose(item.tensor.tensor, jnp.zeros((3, 224, 224))) + + +@pytest.mark.jax +def test_get_slice(batch): + sliced = batch[0:2] + assert isinstance(sliced, DocVec) + assert len(sliced) == 2 + + +@pytest.mark.jax +def test_iterator(batch): + for doc in batch: + assert jnp.allclose(doc.tensor.tensor, jnp.zeros((3, 224, 224))) + + +@pytest.mark.jax +def test_set_after_stacking(): + class Image(BaseDoc): + tensor: JaxArray[3, 224, 224] + + batch = DocVec[Image]([Image(tensor=jnp.zeros((3, 224, 224))) for _ in range(10)]) + + batch.tensor = jnp.ones((10, 3, 224, 224)) + assert jnp.allclose(batch.tensor.tensor, jnp.ones((10, 3, 224, 224))) + for i, doc in enumerate(batch): + assert jnp.allclose(doc.tensor.tensor, batch.tensor.tensor[i]) + + +@pytest.mark.jax +def test_stack_optional(batch): + assert jnp.allclose( + batch._storage.tensor_columns['tensor'].tensor, jnp.zeros((10, 3, 224, 224)) + ) + assert jnp.allclose(batch.tensor.tensor, jnp.zeros((10, 3, 224, 224))) + + +@pytest.mark.jax +def test_stack_mod_nested_document(): + class Image(BaseDoc): + tensor: JaxArray[3, 224, 224] + + class MMdoc(BaseDoc): + img: Image + + batch = DocList[MMdoc]( + [MMdoc(img=Image(tensor=jnp.zeros((3, 224, 224)))) for _ in range(10)] + ).to_doc_vec() + + assert jnp.allclose( + batch._storage.doc_columns['img']._storage.tensor_columns['tensor'].tensor, + jnp.zeros((10, 3, 224, 224)), + ) + + assert jnp.allclose(batch.img.tensor.tensor, jnp.zeros((10, 3, 224, 224))) + + +@pytest.mark.jax +def test_stack_nested_DocArray(nested_batch): + for i in range(len(nested_batch)): + assert jnp.allclose( + nested_batch[i].img._storage.tensor_columns['tensor'].tensor, + jnp.zeros((10, 3, 224, 224)), + ) + + assert jnp.allclose( + nested_batch[i].img.tensor.tensor, jnp.zeros((10, 3, 224, 224)) + ) + + +@pytest.mark.jax +def test_convert_to_da(batch): + da = batch.to_doc_list() + + for doc in da: + assert jnp.allclose(doc.tensor.tensor, jnp.zeros((3, 224, 224))) + + 
+@pytest.mark.jax +def test_unstack_nested_document(): + class Image(BaseDoc): + tensor: JaxArray[3, 224, 224] + + class MMdoc(BaseDoc): + img: Image + + batch = DocVec[MMdoc]( + [MMdoc(img=Image(tensor=jnp.zeros((3, 224, 224)))) for _ in range(10)] + ) + assert isinstance(batch.img._storage.tensor_columns['tensor'], JaxArray) + da = batch.to_doc_list() + + for doc in da: + assert jnp.allclose(doc.img.tensor.tensor, jnp.zeros((3, 224, 224))) + + +@pytest.mark.jax +def test_unstack_nested_DocArray(nested_batch): + batch = nested_batch.to_doc_list() + for i in range(len(batch)): + assert isinstance(batch[i].img, DocList) + for doc in batch[i].img: + assert jnp.allclose(doc.tensor.tensor, jnp.zeros((3, 224, 224))) + + +@pytest.mark.jax +def test_stack_call(): + class Image(BaseDoc): + tensor: JaxArray[3, 224, 224] + + da = DocList[Image]([Image(tensor=jnp.zeros((3, 224, 224))) for _ in range(10)]) + + da = da.to_doc_vec() + + assert len(da) == 10 + + assert da.tensor.tensor.shape == (10, 3, 224, 224) + + +@pytest.mark.jax +def test_stack_union(): + class Image(BaseDoc): + tensor: Union[JaxArray[3, 224, 224], NdArray[3, 224, 224]] + + DocVec[Image]( + [Image(tensor=jnp.zeros((3, 224, 224))) for _ in range(10)], + tensor_type=JaxArray, + ) + + # union fields aren't actually doc_vec + # just checking that there is no error + + +@pytest.mark.jax +def test_setitem_tensor(batch): + batch[3].tensor.tensor = jnp.zeros((3, 224, 224)) + + +@pytest.mark.jax +@pytest.mark.skip('not working yet') +def test_setitem_tensor_direct(batch): + batch[3].tensor = jnp.zeros((3, 224, 224)) + + +@pytest.mark.jax +@pytest.mark.parametrize( + 'cls_tensor', [ImageTensor, AudioTensor, VideoTensor, AnyEmbedding, AnyTensor] +) +def test_generic_tensors_with_jnp(cls_tensor): + tensor = jnp.zeros((3, 224, 224)) + + class Image(BaseDoc): + tensor: cls_tensor + + da = DocVec[Image]( + [Image(tensor=tensor) for _ in range(10)], + tensor_type=JaxArray, + ) + + for i in range(len(da)): + assert 
jnp.allclose(da[i].tensor.tensor, tensor) + + assert 'tensor' in da._storage.tensor_columns.keys() + assert isinstance(da._storage.tensor_columns['tensor'], JaxArray) + + +@pytest.mark.jax +@pytest.mark.parametrize( + 'cls_tensor', [ImageTensor, AudioTensor, VideoTensor, AnyEmbedding, AnyTensor] +) +def test_generic_tensors_with_optional(cls_tensor): + tensor = jnp.zeros((3, 224, 224)) + + class Image(BaseDoc): + tensor: Optional[cls_tensor] + + class TopDoc(BaseDoc): + img: Image + + da = DocVec[TopDoc]( + [TopDoc(img=Image(tensor=tensor)) for _ in range(10)], + tensor_type=JaxArray, + ) + + for i in range(len(da)): + assert jnp.allclose(da.img[i].tensor.tensor, tensor) + + assert 'tensor' in da.img._storage.tensor_columns.keys() + assert isinstance(da.img._storage.tensor_columns['tensor'], JaxArray) + assert isinstance(da.img._storage.tensor_columns['tensor'].tensor, jnp.ndarray) + + +@pytest.mark.jax +def test_get_from_slice_stacked(): + class Doc(BaseDoc): + text: str + tensor: JaxArray + + da = DocVec[Doc]( + [Doc(text=f'hello{i}', tensor=jnp.zeros((3, 224, 224))) for i in range(10)] + ) + + da_sliced = da[0:10:2] + assert isinstance(da_sliced, DocVec) + + tensors = da_sliced.tensor.tensor + assert tensors.shape == (5, 3, 224, 224) + + +@pytest.mark.jax +def test_stack_none(): + class MyDoc(BaseDoc): + tensor: Optional[AnyTensor] + + da = DocVec[MyDoc]([MyDoc(tensor=None) for _ in range(10)], tensor_type=JaxArray) + assert 'tensor' in da._storage.tensor_columns.keys() + + +@pytest.mark.jax +def test_keep_dtype_jnp(): + class MyDoc(BaseDoc): + tensor: JaxArray + + da = DocList[MyDoc]( + [MyDoc(tensor=jnp.zeros([2, 4], dtype=jnp.int32)) for _ in range(3)] + ) + assert da[0].tensor.tensor.dtype == jnp.int32 + + da = da.to_doc_vec() + assert da[0].tensor.tensor.dtype == jnp.int32 + assert da.tensor.tensor.dtype == jnp.int32 diff --git a/tests/units/computation_backends/jax_backend/__init__.py b/tests/units/computation_backends/jax_backend/__init__.py new file mode 
100644 index 00000000000..e69de29bb2d diff --git a/tests/units/computation_backends/jax_backend/test_basics.py b/tests/units/computation_backends/jax_backend/test_basics.py new file mode 100644 index 00000000000..1b36c39276c --- /dev/null +++ b/tests/units/computation_backends/jax_backend/test_basics.py @@ -0,0 +1,148 @@ +import pytest + +from docarray.utils._internal.misc import is_jax_available + +jax_available = is_jax_available() +if jax_available: + print("is jax available", jax_available) + import jax + import jax.numpy as jnp + + from docarray.computation.jax_backend import JaxCompBackend + from docarray.typing import JaxArray + + jax.config.update("jax_enable_x64", True) + + +@pytest.mark.jax +@pytest.mark.parametrize( + 'shape,result', + [ + ((5), 1), + ((1, 5), 2), + ((5, 5), 2), + ((), 0), + ], +) +def test_n_dim(shape, result): + + array = JaxArray(jnp.zeros(shape)) + assert JaxCompBackend.n_dim(array) == result + + +@pytest.mark.jax +@pytest.mark.parametrize( + 'shape,result', + [ + ((10,), (10,)), + ((5, 5), (5, 5)), + ((), ()), + ], +) +def test_shape(shape, result): + array = JaxArray(jnp.zeros(shape)) + shape = JaxCompBackend.shape(array) + assert shape == result + assert type(shape) == tuple + + +@pytest.mark.jax +def test_to_device(): + array = JaxArray(jnp.zeros((3))) + array = JaxCompBackend.to_device(array, 'cpu') + assert array.tensor.device().platform.endswith('cpu') + + +@pytest.mark.jax +@pytest.mark.parametrize( + 'dtype,result_type', + [ + ('int64', 'int64'), + ('float64', 'float64'), + ('int8', 'int8'), + ('double', 'float64'), + ], +) +def test_dtype(dtype, result_type): + array = JaxArray(jnp.array([1, 2, 3], dtype=dtype)) + assert JaxCompBackend.dtype(array) == result_type + + +@pytest.mark.jax +def test_empty(): + array = JaxCompBackend.empty((10, 3)) + assert array.tensor.shape == (10, 3) + + +@pytest.mark.jax +def test_empty_dtype(): + tf_tensor = JaxCompBackend.empty((10, 3), dtype=jnp.int32) + assert tf_tensor.tensor.shape == 
(10, 3) + assert tf_tensor.tensor.dtype == jnp.int32 + + +@pytest.mark.jax +def test_empty_device(): + tensor = JaxCompBackend.empty((10, 3), device='cpu') + assert tensor.tensor.shape == (10, 3) + assert tensor.tensor.device().platform.endswith('cpu') + + +@pytest.mark.jax +def test_squeeze(): + tensor = JaxArray(jnp.zeros(shape=(1, 1, 3, 1))) + squeezed = JaxCompBackend.squeeze(tensor) + assert squeezed.tensor.shape == (3,) + + +@pytest.mark.jax +@pytest.mark.parametrize( + 'data_input,t_range,x_range,data_result', + [ + ( + [0, 1, 2, 3, 4, 5], + (0, 10), + None, + [0, 2, 4, 6, 8, 10], + ), + ( + [0, 1, 2, 3, 4, 5], + (0, 10), + (0, 10), + [0, 1, 2, 3, 4, 5], + ), + ( + [[0.0, 1.0], [0.0, 1.0]], + (0, 10), + None, + [[0.0, 10.0], [0.0, 10.0]], + ), + ], +) +def test_minmax_normalize(data_input, t_range, x_range, data_result): + array = JaxArray(jnp.array(data_input)) + output = JaxCompBackend.minmax_normalize( + tensor=array, t_range=t_range, x_range=x_range + ) + assert jnp.allclose(output.tensor, jnp.array(data_result)) + + +@pytest.mark.jax +def test_reshape(): + tensor = JaxArray(jnp.zeros((3, 224, 224))) + reshaped = JaxCompBackend.reshape(tensor, (224, 224, 3)) + assert reshaped.tensor.shape == (224, 224, 3) + + +@pytest.mark.jax +def test_stack(): + t0 = JaxArray(jnp.zeros((3, 224, 224))) + t1 = JaxArray(jnp.ones((3, 224, 224))) + + stacked1 = JaxCompBackend.stack([t0, t1], dim=0) + assert isinstance(stacked1, JaxArray) + assert stacked1.tensor.shape == (2, 3, 224, 224) + + stacked2 = JaxCompBackend.stack([t0, t1], dim=-1) + assert isinstance(stacked2, JaxArray) + assert stacked2.tensor.shape == (3, 224, 224, 2) diff --git a/tests/units/computation_backends/jax_backend/test_metrics.py b/tests/units/computation_backends/jax_backend/test_metrics.py new file mode 100644 index 00000000000..50dc6339d63 --- /dev/null +++ b/tests/units/computation_backends/jax_backend/test_metrics.py @@ -0,0 +1,81 @@ +import pytest + +from docarray.utils._internal.misc import 
is_jax_available + +jax_available = is_jax_available() +if jax_available: + import jax + import jax.numpy as jnp + + from docarray.computation.jax_backend import JaxCompBackend + from docarray.typing import JaxArray + + metrics = JaxCompBackend.Metrics +else: + metrics = None + + +@pytest.mark.jax +def test_cosine_sim_jax(): + a = JaxArray(jax.random.uniform(jax.random.PRNGKey(0), shape=(128,))) + b = JaxArray(jax.random.uniform(jax.random.PRNGKey(1), shape=(128,))) + assert metrics.cosine_sim(a, b).tensor.shape == (1,) + assert metrics.cosine_sim(a, b).tensor == metrics.cosine_sim(b, a).tensor + + assert jnp.allclose(metrics.cosine_sim(a, a).tensor, jnp.ones((1,))) + + a = JaxArray(jax.random.uniform(jax.random.PRNGKey(2), shape=(10, 3))) + b = JaxArray(jax.random.uniform(jax.random.PRNGKey(3), shape=(5, 3))) + assert metrics.cosine_sim(a, b).tensor.shape == (10, 5) + assert metrics.cosine_sim(b, a).tensor.shape == (5, 10) + diag_dists = jnp.diagonal(metrics.cosine_sim(b, b).tensor) # self-comparisons + assert jnp.allclose(diag_dists, jnp.ones((5,))) + + +@pytest.mark.jax +@pytest.mark.skip +def test_euclidean_dist_jax(): + a = JaxArray(jax.random.normal(jax.random.PRNGKey(0), shape=(128,))) + b = JaxArray(jax.random.normal(jax.random.PRNGKey(1), shape=(128,))) + assert metrics.euclidean_dist(a, b).tensor.shape == (1,) + assert jnp.allclose( + metrics.euclidean_dist(a, b).tensor, metrics.euclidean_dist(b, a).tensor + ) + + assert jnp.allclose(metrics.euclidean_dist(a, a).tensor, jnp.zeros((1,))) + + a = JaxArray(jnp.zeros((1, 1))) + b = JaxArray(jnp.ones((4, 1))) + assert metrics.euclidean_dist(a, b).tensor.shape == (4,) + assert jnp.allclose( + metrics.euclidean_dist(a, b).tensor, metrics.euclidean_dist(b, a).tensor + ) + assert jnp.allclose(metrics.euclidean_dist(a, a).tensor, jnp.zeros((1,))) + + a = JaxArray(jnp.array([0.0, 2.0, 0.0])) + b = JaxArray(jnp.array([0.0, 0.0, 2.0])) + desired_output_singleton = jnp.sqrt(jnp.array([2.0**2.0 + 2.0**2.0])) + assert 
jnp.allclose(metrics.euclidean_dist(a, b).tensor, desired_output_singleton) + + a = JaxArray(jnp.array([[0.0, 2.0, 0.0], [0.0, 0.0, 2.0]])) + b = JaxArray(jnp.array([[0.0, 0.0, 2.0], [0.0, 2.0, 0.0]])) + desired_output_singleton = jnp.array([[2.828427, 0.0], [0.0, 2.828427]]) + + assert jnp.allclose(metrics.euclidean_dist(a, b).tensor, desired_output_singleton) + + +@pytest.mark.jax +def test_sqeuclidea_dist_jnp(): + a = JaxArray(jax.random.uniform(jax.random.PRNGKey(0), shape=(128,))) + b = JaxArray(jax.random.uniform(jax.random.PRNGKey(1), shape=(128,))) + assert metrics.sqeuclidean_dist(a, b).tensor.shape == (1,) + assert jnp.allclose( + metrics.sqeuclidean_dist(a, b).tensor, metrics.euclidean_dist(a, b).tensor ** 2 + ) + + a = JaxArray(jax.random.uniform(jax.random.PRNGKey(2), shape=(10, 3))) + b = JaxArray(jax.random.uniform(jax.random.PRNGKey(3), shape=(5, 3))) + assert metrics.sqeuclidean_dist(a, b).tensor.shape == (10, 5) + assert jnp.allclose( + metrics.sqeuclidean_dist(a, b).tensor, metrics.euclidean_dist(a, b).tensor ** 2 + ) diff --git a/tests/units/computation_backends/jax_backend/test_retrieval.py b/tests/units/computation_backends/jax_backend/test_retrieval.py new file mode 100644 index 00000000000..9f8a3afb415 --- /dev/null +++ b/tests/units/computation_backends/jax_backend/test_retrieval.py @@ -0,0 +1,66 @@ +import pytest + +from docarray.utils._internal.misc import is_jax_available + +jax_available = is_jax_available() +if jax_available: + import jax.numpy as jnp + + from docarray.computation.jax_backend import JaxCompBackend + from docarray.typing import JaxArray + + metrics = JaxCompBackend.Metrics +else: + metrics = None + + +@pytest.mark.jax +def test_top_k_descending_false(): + top_k = JaxCompBackend.Retrieval.top_k + + a = JaxArray(jnp.array([1, 4, 2, 7, 4, 9, 2])) + vals, indices = top_k(a, 3, descending=False) + + assert vals.tensor.shape == (1, 3) + assert indices.tensor.shape == (1, 3) + assert jnp.allclose(jnp.squeeze(vals.tensor), 
jnp.array([1, 2, 2])) + assert jnp.allclose(jnp.squeeze(indices.tensor), jnp.array([0, 2, 6])) or ( + jnp.allclose(jnp.squeeze(indices.tensor), + jnp.array([0, 6, 2])) + ) + + a = JaxArray(jnp.array([[1, 4, 2, 7, 4, 9, 2], [11, 6, 2, 7, 3, 10, 4]])) + vals, indices = top_k(a, 3, descending=False) + assert vals.tensor.shape == (2, 3) + assert indices.tensor.shape == (2, 3) + assert jnp.allclose(vals.tensor[0], jnp.array([1, 2, 2])) + assert jnp.allclose(indices.tensor[0], jnp.array([0, 2, 6])) or jnp.allclose( + indices.tensor[0], jnp.array([0, 6, 2]) + ) + assert jnp.allclose(vals.tensor[1], jnp.array([2, 3, 4])) + assert jnp.allclose(indices.tensor[1], jnp.array([2, 4, 6])) + + +@pytest.mark.jax +def test_top_k_descending_true(): + top_k = JaxCompBackend.Retrieval.top_k + + a = JaxArray(jnp.array([1, 4, 2, 7, 4, 9, 2])) + vals, indices = top_k(a, 3, descending=True) + + assert vals.tensor.shape == (1, 3) + assert indices.tensor.shape == (1, 3) + assert jnp.allclose(jnp.squeeze(vals.tensor), jnp.array([9, 7, 4])) + assert jnp.allclose(jnp.squeeze(indices.tensor), jnp.array([5, 3, 1])) + + a = JaxArray(jnp.array([[1, 4, 2, 7, 4, 9, 2], [11, 6, 2, 7, 3, 10, 4]])) + vals, indices = top_k(a, 3, descending=True) + + assert vals.tensor.shape == (2, 3) + assert indices.tensor.shape == (2, 3) + + assert jnp.allclose(vals.tensor[0], jnp.array([9, 7, 4])) + assert jnp.allclose(indices.tensor[0], jnp.array([5, 3, 1])) + + assert jnp.allclose(vals.tensor[1], jnp.array([11, 10, 7])) + assert jnp.allclose(indices.tensor[1], jnp.array([0, 5, 3])) diff --git a/tests/units/typing/tensor/test_jax_array.py b/tests/units/typing/tensor/test_jax_array.py new file mode 100644 index 00000000000..34b4c979dfc --- /dev/null +++ b/tests/units/typing/tensor/test_jax_array.py @@ -0,0 +1,201 @@ +import numpy as np +import pytest +from pydantic import schema_json_of +from pydantic.tools import parse_obj_as + +from docarray.base_doc.io.json import orjson_dumps +from docarray.utils._internal.misc 
import is_jax_available + +jax_available = is_jax_available() +if jax_available: + import jax.numpy as jnp + from jax._src.core import InconclusiveDimensionOperation + + from docarray.typing import JaxArray + + +@pytest.mark.jax +def test_proto_tensor(): + from docarray.proto.pb2.docarray_pb2 import NdArrayProto + + tensor = parse_obj_as(JaxArray, jnp.zeros((3, 224, 224))) + proto = tensor.to_protobuf() + assert isinstance(proto, NdArrayProto) + + from_proto = JaxArray.from_protobuf(proto) + assert isinstance(from_proto, JaxArray) + assert jnp.allclose(tensor.tensor, from_proto.tensor) + + +@pytest.mark.jax +def test_json_schema(): + schema_json_of(JaxArray) + + +@pytest.mark.jax +@pytest.mark.skip +def test_dump_json(): + tensor = parse_obj_as(JaxArray, jnp.zeros((2, 56, 56))) + orjson_dumps(tensor) + + +@pytest.mark.jax +def test_unwrap(): + tf_tensor = parse_obj_as(JaxArray, jnp.zeros((3, 224, 224))) + unwrapped = tf_tensor.unwrap() + + assert not isinstance(unwrapped, JaxArray) + assert isinstance(tf_tensor, JaxArray) + assert isinstance(unwrapped, jnp.ndarray) + + assert np.allclose(unwrapped, np.zeros((3, 224, 224))) + + +@pytest.mark.jax +def test_from_ndarray(): + nd = np.array([1, 2, 3]) + tensor = JaxArray.from_ndarray(nd) + assert isinstance(tensor, JaxArray) + assert isinstance(tensor.tensor, jnp.ndarray) + + +@pytest.mark.jax +def test_ellipsis_in_shape(): + # ellipsis in the end, two extra dimensions needed + tf_tensor = parse_obj_as(JaxArray[3, ...], jnp.zeros((3, 128, 224))) + assert isinstance(tf_tensor, JaxArray) + assert isinstance(tf_tensor.tensor, jnp.ndarray) + assert tf_tensor.tensor.shape == (3, 128, 224) + + # ellipsis in the beginning, two extra dimensions needed + tf_tensor = parse_obj_as(JaxArray[..., 224], jnp.zeros((3, 128, 224))) + assert isinstance(tf_tensor, JaxArray) + assert isinstance(tf_tensor.tensor, jnp.ndarray) + assert tf_tensor.tensor.shape == (3, 128, 224) + + # more than one ellipsis in the shape + with 
pytest.raises(ValueError): + parse_obj_as(JaxArray[3, ..., 128, ...], jnp.zeros((3, 128, 224))) + + # wrong shape + with pytest.raises(ValueError): + parse_obj_as(JaxArray[3, 224, ...], jnp.zeros((3, 128, 224))) + + +@pytest.mark.jax +def test_parametrized(): + # correct shape, single axis + tf_tensor = parse_obj_as(JaxArray[128], jnp.zeros(128)) + assert isinstance(tf_tensor, JaxArray) + assert isinstance(tf_tensor.tensor, jnp.ndarray) + assert tf_tensor.tensor.shape == (128,) + + # correct shape, multiple axis + tf_tensor = parse_obj_as(JaxArray[3, 224, 224], jnp.zeros((3, 224, 224))) + assert isinstance(tf_tensor, JaxArray) + assert isinstance(tf_tensor.tensor, jnp.ndarray) + assert tf_tensor.tensor.shape == (3, 224, 224) + + # wrong but reshapable shape + tf_tensor = parse_obj_as(JaxArray[3, 224, 224], jnp.zeros((224, 3, 224))) + assert isinstance(tf_tensor, JaxArray) + assert isinstance(tf_tensor.tensor, jnp.ndarray) + assert tf_tensor.tensor.shape == (3, 224, 224) + + # wrong and not reshapable shape + with pytest.raises(InconclusiveDimensionOperation): + parse_obj_as(JaxArray[3, 224, 224], jnp.zeros((224, 224))) + + +@pytest.mark.jax +def test_parametrized_with_str(): + # test independent variable dimensions + tf_tensor = parse_obj_as(JaxArray[3, 'x', 'y'], jnp.zeros((3, 224, 224))) + assert isinstance(tf_tensor, JaxArray) + assert isinstance(tf_tensor.tensor, jnp.ndarray) + assert tf_tensor.tensor.shape == (3, 224, 224) + + tf_tensor = parse_obj_as(JaxArray[3, 'x', 'y'], jnp.zeros((3, 60, 128))) + assert isinstance(tf_tensor, JaxArray) + assert isinstance(tf_tensor.tensor, jnp.ndarray) + assert tf_tensor.tensor.shape == (3, 60, 128) + + with pytest.raises(ValueError): + parse_obj_as(JaxArray[3, 'x', 'y'], jnp.zeros((4, 224, 224))) + + with pytest.raises(ValueError): + parse_obj_as(JaxArray[3, 'x', 'y'], jnp.zeros((100, 1))) + + # test dependent variable dimensions + tf_tensor = parse_obj_as(JaxArray[3, 'x', 'x'], jnp.zeros((3, 224, 224))) + assert 
isinstance(tf_tensor, JaxArray) + assert isinstance(tf_tensor.tensor, jnp.ndarray) + assert tf_tensor.tensor.shape == (3, 224, 224) + + with pytest.raises(ValueError): + _ = parse_obj_as(JaxArray[3, 'x', 'x'], jnp.zeros((3, 60, 128))) + + with pytest.raises(ValueError): + _ = parse_obj_as(JaxArray[3, 'x', 'x'], jnp.zeros((3, 60))) + + +@pytest.mark.jax +@pytest.mark.parametrize('shape', [(3, 224, 224), (224, 224, 3)]) +def test_parameterized_tensor_class_name(shape): + MyTFT = JaxArray[3, 224, 224] + tensor = parse_obj_as(MyTFT, jnp.zeros(shape)) + + assert MyTFT.__name__ == 'JaxArray[3, 224, 224]' + assert MyTFT.__qualname__ == 'JaxArray[3, 224, 224]' + + assert tensor.__class__.__name__ == 'JaxArray' + assert tensor.__class__.__qualname__ == 'JaxArray' + assert f'{tensor.tensor[0][0][0]}' == '0.0' + + +@pytest.mark.jax +def test_parametrized_subclass(): + c1 = JaxArray[128] + c2 = JaxArray[128] + assert issubclass(c1, c2) + assert issubclass(c1, JaxArray) + + assert not issubclass(c1, JaxArray[256]) + + +@pytest.mark.jax +def test_parametrized_instance(): + t = parse_obj_as(JaxArray[128], jnp.zeros((128,))) + assert isinstance(t, JaxArray[128]) + assert isinstance(t, JaxArray) + # assert isinstance(t, jnp.ndarray) + + assert not isinstance(t, JaxArray[256]) + assert not isinstance(t, JaxArray[2, 128]) + assert not isinstance(t, JaxArray[2, 2, 64]) + + +@pytest.mark.jax +def test_parametrized_equality(): + t1 = parse_obj_as(JaxArray[128], jnp.zeros((128,))) + t2 = parse_obj_as(JaxArray[128], jnp.zeros((128,))) + assert jnp.allclose(t1.tensor, t2.tensor) + + +@pytest.mark.jax +def test_parametrized_operations(): + t1 = parse_obj_as(JaxArray[128], jnp.zeros((128,))) + t2 = parse_obj_as(JaxArray[128], jnp.zeros((128,))) + t_result = t1.tensor + t2.tensor + assert isinstance(t_result, jnp.ndarray) + assert not isinstance(t_result, JaxArray) + assert not isinstance(t_result, JaxArray[128]) + + +@pytest.mark.jax +def test_set_item(): + t = JaxArray(tensor=jnp.zeros((3, 
224, 224))) + t[0] = jnp.ones((1, 224, 224)) + assert jnp.allclose(t.tensor[0], jnp.ones((1, 224, 224))) + assert jnp.allclose(t.tensor[1], jnp.zeros((1, 224, 224))) + assert jnp.allclose(t.tensor[2], jnp.zeros((1, 224, 224))) pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy