From 3ea813eaef93eaf51c4f1a8bd16d20f11f95a583 Mon Sep 17 00:00:00 2001 From: agaraman0 Date: Mon, 22 May 2023 14:45:37 +0530 Subject: [PATCH 01/25] fix: jax backend boilerplate setup Signed-off-by: agaraman0 --- docarray/computation/jax_backend.py | 207 ++++++++++++++++++ .../typing/tensor/audio/audio_jax_array.py | 0 docarray/typing/tensor/embedding/jax_array.py | 0 .../typing/tensor/image/image_jax_array.py | 0 docarray/typing/tensor/jax_array.py | 106 +++++++++ .../typing/tensor/video/video_jax_array.py | 0 .../jax_backend/__init__.py | 0 .../jax_backend/test_basics.py | 0 .../jax_backend/test_metrics.py | 0 .../jax_backend/test_retrieval.py | 0 10 files changed, 313 insertions(+) create mode 100644 docarray/computation/jax_backend.py create mode 100644 docarray/typing/tensor/audio/audio_jax_array.py create mode 100644 docarray/typing/tensor/embedding/jax_array.py create mode 100644 docarray/typing/tensor/image/image_jax_array.py create mode 100644 docarray/typing/tensor/jax_array.py create mode 100644 docarray/typing/tensor/video/video_jax_array.py create mode 100644 tests/units/computation_backends/jax_backend/__init__.py create mode 100644 tests/units/computation_backends/jax_backend/test_basics.py create mode 100644 tests/units/computation_backends/jax_backend/test_metrics.py create mode 100644 tests/units/computation_backends/jax_backend/test_retrieval.py diff --git a/docarray/computation/jax_backend.py b/docarray/computation/jax_backend.py new file mode 100644 index 00000000000..da4a2e770f0 --- /dev/null +++ b/docarray/computation/jax_backend.py @@ -0,0 +1,207 @@ +from typing import TYPE_CHECKING, Any, List, Optional, Tuple + +import numpy as np + +from docarray.computation.abstract_comp_backend import AbstractComputationalBackend +from docarray.utils._internal.misc import import_library + +if TYPE_CHECKING: + import jax +else: + torch = import_library('jax', raise_error=True) + + +def _unsqueeze_if_single_axis(*matrices) -> List[torch.Tensor]: + """Unsqueezes tensors that only have one axis, at dim 0. + This ensures that all outputs can be treated as matrices, not vectors. + + :param matrices: Matrices to be unsqueezed + :return: List of the input matrices, + where single axis matrices are unsqueezed at dim 0. + """ + pass + + +def _unsqueeze_if_scalar(t): + pass + + +def identity(array: jax.numpy.ndarray) -> jax.numpy.ndarray: + return array + + +class JaxCompBackend(AbstractComputationalBackend[torch.Tensor]): + """ + Computational backend for Numpy. + """ + + _module = np + _cast_output = identity + _get_tensor = identity + + @classmethod + def to_device(cls, tensor: 'jax.numpy.array', device: str) -> 'jax.numpy.array': + """Move the tensor to the specified device.""" + raise NotImplementedError('Numpy does not support devices (GPU).') + + @classmethod + def device(cls, tensor: 'jax.numpy.array') -> Optional[str]: + """Return device on which the tensor is allocated.""" + return None + + @classmethod + def to_numpy(cls, array: 'jax.numpy.array') -> 'np.ndarray': + return array + + @classmethod + def none_value(cls) -> Any: + """Provide a compatible value that represents None in numpy.""" + return None + + @classmethod + def detach(cls, tensor: 'jax.numpy.array') -> 'jax.numpy.array': + """ + Returns the tensor detached from its current graph. + + :param tensor: tensor to be detached + :return: a detached tensor with the same data. + """ + pass + + @classmethod + def dtype(cls, tensor: 'jax.numpy.array') -> np.dtype: + """Get the data type of the tensor.""" + pass + + @classmethod + def minmax_normalize( + cls, + tensor: 'jax.numpy.array', + t_range: Tuple = (0, 1), + x_range: Optional[Tuple] = None, + eps: float = 1e-7, + ) -> 'jax.numpy.array': + """ + Normalize values in `tensor` into `t_range`. + + `tensor` can be a 1D array or a 2D array. When `tensor` is a 2D array, then + normalization is row-based. + + !!! note + + - with `t_range=(0, 1)` will normalize the min-value of data to 0, max to 1; + - with `t_range=(1, 0)` will normalize the min-value of data to 1, max value + of the data to 0. + + :param tensor: the data to be normalized + :param t_range: a tuple represents the target range. + :param x_range: a tuple represents tensors range. + :param eps: a small jitter to avoid divide by zero + :return: normalized data in `t_range` + """ + pass + + class Retrieval(AbstractComputationalBackend.Retrieval[jax.numpy.array]): + """ + Abstract class for retrieval and ranking functionalities + """ + + @staticmethod + def top_k( + values: 'jax.numpy.array', + k: int, + descending: bool = False, + device: Optional[str] = None, + ) -> Tuple['jax.numpy.array', 'jax.numpy.array']: + """ + Retrieves the top k smallest values in `values`, + and returns them alongside their indices in the input `values`. + Can also be used to retrieve the top k largest values, + by setting the `descending` flag. + + :param values: Torch tensor of values to rank. + Should be of shape (n_queries, n_values_per_query). + Inputs of shape (n_values_per_query,) will be expanded + to (1, n_values_per_query). + :param k: number of values to retrieve + :param descending: retrieve largest values instead of smallest values + :param device: Not supported for this backend + :return: Tuple containing the retrieved values, and their indices. + Both ar of shape (n_queries, k) + """ + pass + + class Metrics(AbstractComputationalBackend.Metrics[jax.numpy.array]): + """ + Abstract base class for metrics (distances and similarities). + """ + + @staticmethod + def cosine_sim( + x_mat: jax.numpy.array, + y_mat: jax.numpy.array, + eps: float = 1e-7, + device: Optional[str] = None, + ) -> jax.numpy.array: + """Pairwise cosine similarities between all vectors in x_mat and y_mat. + + :param x_mat: jax.numpy.array of shape (n_vectors, n_dim), where n_vectors is + the number of vectors and n_dim is the number of dimensions of each + example. + :param y_mat: jax.numpy.array of shape (n_vectors, n_dim), where n_vectors is + the number of vectors and n_dim is the number of dimensions of each + example. + :param eps: a small jitter to avoid divde by zero + :param device: Not supported for this backend + :return: jax.numpy.array of shape (n_vectors, n_vectors) containing all + pairwise cosine distances. + The index [i_x, i_y] contains the cosine distance between + x_mat[i_x] and y_mat[i_y]. + """ + pass + + @classmethod + def euclidean_dist( + cls, + x_mat: jax.numpy.array, + y_mat: jax.numpy.array, + device: Optional[str] = None, + ) -> jax.numpy.array: + """Pairwise Euclidian distances between all vectors in x_mat and y_mat. + + :param x_mat: jax.numpy.array of shape (n_vectors, n_dim), where n_vectors is + the number of vectors and n_dim is the number of dimensions of each + example. + :param y_mat: jax.numpy.array of shape (n_vectors, n_dim), where n_vectors is + the number of vectors and n_dim is the number of dimensions of each + example. + :param eps: a small jitter to avoid divde by zero + :param device: Not supported for this backend + :return: jax.numpy.array of shape (n_vectors, n_vectors) containing all + pairwise euclidian distances. + The index [i_x, i_y] contains the euclidian distance between + x_mat[i_x] and y_mat[i_y]. + """ + pass + + @staticmethod + def sqeuclidean_dist( + x_mat: jax.numpy.array, + y_mat: jax.numpy.array, + device: Optional[str] = None, + ) -> jax.numpy.array: + """Pairwise Squared Euclidian distances between all vectors in + x_mat and y_mat. + + :param x_mat: jax.numpy.array of shape (n_vectors, n_dim), where n_vectors is + the number of vectors and n_dim is the number of dimensions of each + example. + :param y_mat: jax.numpy.array of shape (n_vectors, n_dim), where n_vectors is + the number of vectors and n_dim is the number of dimensions of each + example. + :param device: Not supported for this backend + :return: jax.numpy.array of shape (n_vectors, n_vectors) containing all + pairwise Squared Euclidian distances. + The index [i_x, i_y] contains the cosine Squared Euclidian between + x_mat[i_x] and y_mat[i_y]. + """ diff --git a/docarray/typing/tensor/audio/audio_jax_array.py b/docarray/typing/tensor/audio/audio_jax_array.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/docarray/typing/tensor/embedding/jax_array.py b/docarray/typing/tensor/embedding/jax_array.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/docarray/typing/tensor/image/image_jax_array.py b/docarray/typing/tensor/image/image_jax_array.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/docarray/typing/tensor/jax_array.py b/docarray/typing/tensor/jax_array.py new file mode 100644 index 00000000000..7269d35d4c0 --- /dev/null +++ b/docarray/typing/tensor/jax_array.py @@ -0,0 +1,106 @@ +from typing import TYPE_CHECKING, Any, Generic, List, Tuple, Type, TypeVar, Union + +import numpy as np + +from docarray.typing.proto_register import _register_proto +from docarray.typing.tensor.abstract_tensor import AbstractTensor + +if TYPE_CHECKING: + from pydantic import BaseConfig + from pydantic.fields import ModelField + + +from docarray.base_doc.base_node import BaseNode + +T = TypeVar('T') +ShapeT = TypeVar('ShapeT') + +tensor_base: type = type(BaseNode) + + +# the mypy error suppression below should not be necessary anymore once the following +# is released in mypy: https://github.com/python/mypy/pull/14135 +class metaNumpy(AbstractTensor.__parametrized_meta__, tensor_base): # type: ignore + pass + + +@_register_proto(proto_type_name='jaxarray') +class JaxArray(np.ndarray, AbstractTensor, Generic[ShapeT]): + """ + Subclass of `np.ndarray`, intended for use in a Document. + This enables (de)serialization from/to protobuf and json, data validation, + and coersion from compatible types like `torch.Tensor`. + + This type can also be used in a parametrized way, specifying the shape of the array. + + --- + + ```python + from docarray import BaseDoc + from docarray.typing import NdArray + import numpy as np + + + class MyDoc(BaseDoc): + arr: NdArray + image_arr: NdArray[3, 224, 224] + square_crop: NdArray[3, 'x', 'x'] + random_image: NdArray[3, ...] # first dimension is fixed, can have arbitrary shape + + + # create a document with tensors + doc = MyDoc( + arr=np.zeros((128,)), + image_arr=np.zeros((3, 224, 224)), + square_crop=np.zeros((3, 64, 64)), + random_image=np.zeros((3, 128, 256)), + ) + assert doc.image_arr.shape == (3, 224, 224) + + # automatic shape conversion + doc = MyDoc( + arr=np.zeros((128,)), + image_arr=np.zeros((224, 224, 3)), # will reshape to (3, 224, 224) + square_crop=np.zeros((3, 128, 128)), + random_image=np.zeros((3, 64, 128)), + ) + assert doc.image_arr.shape == (3, 224, 224) + + # !! The following will raise an error due to shape mismatch !! + from pydantic import ValidationError + + try: + doc = MyDoc( + arr=np.zeros((128,)), + image_arr=np.zeros((224, 224)), # this will fail validation + square_crop=np.zeros((3, 128, 64)), # this will also fail validation + random_image=np.zeros((4, 64, 128)), # this will also fail validation + ) + except ValidationError as e: + pass + ``` + + --- + """ + + __parametrized_meta__ = metaNumpy + + @classmethod + def __get_validators__(cls): + # one or more validators may be yielded which will be called in the + # order to validate the input, each validator will receive as an input + # the value returned from the previous validator + pass + + @classmethod + def validate( + cls: Type[T], + value: Union[T, np.ndarray, List[Any], Tuple[Any], Any], + field: 'ModelField', + config: 'BaseConfig', + ) -> T: + pass + + @classmethod + def _docarray_from_native(cls: Type[T], value: np.ndarray) -> T: + pass diff --git a/docarray/typing/tensor/video/video_jax_array.py b/docarray/typing/tensor/video/video_jax_array.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/units/computation_backends/jax_backend/__init__.py b/tests/units/computation_backends/jax_backend/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/units/computation_backends/jax_backend/test_basics.py b/tests/units/computation_backends/jax_backend/test_basics.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/units/computation_backends/jax_backend/test_metrics.py b/tests/units/computation_backends/jax_backend/test_metrics.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/units/computation_backends/jax_backend/test_retrieval.py b/tests/units/computation_backends/jax_backend/test_retrieval.py new file mode 100644 index 00000000000..e69de29bb2d From 825daf51b2d408fa5e9c934b410b0252c0ef98fd Mon Sep 17 00:00:00 2001 From: agaraman0 Date: Tue, 23 May 2023 12:51:53 +0530 Subject: [PATCH 02/25] feat: typing init Signed-off-by: agaraman0 --- docarray/typing/tensor/__init__.py | 2 + docarray/typing/tensor/jax_array.py | 106 --------------- docarray/typing/tensor/jaxarray.py | 202 ++++++++++++++++++++++++++++ 3 files changed, 204 insertions(+), 106 deletions(-) delete mode 100644 docarray/typing/tensor/jax_array.py create mode 100644 docarray/typing/tensor/jaxarray.py diff --git a/docarray/typing/tensor/__init__.py b/docarray/typing/tensor/__init__.py index 4c4077f3cdb..8e8f6653bd6 100644 --- a/docarray/typing/tensor/__init__.py +++ b/docarray/typing/tensor/__init__.py @@ -5,6 +5,7 @@ from docarray.typing.tensor.audio import AudioNdArray from docarray.typing.tensor.embedding import AnyEmbedding, NdArrayEmbedding from docarray.typing.tensor.image import ImageNdArray, ImageTensor +from docarray.typing.tensor.jaxarray import JaxArray from docarray.typing.tensor.ndarray import NdArray from docarray.typing.tensor.tensor import AnyTensor from docarray.typing.tensor.video import VideoNdArray @@ -34,6 +35,7 @@ 'ImageTensor', 'AudioNdArray', 'VideoNdArray', + 'JaxArray', ] diff --git a/docarray/typing/tensor/jax_array.py b/docarray/typing/tensor/jax_array.py deleted file mode 100644 index 7269d35d4c0..00000000000 --- a/docarray/typing/tensor/jax_array.py +++ /dev/null @@ -1,106 +0,0 @@ -from typing import TYPE_CHECKING, Any, Generic, List, Tuple, Type, TypeVar, Union - -import numpy as np - -from docarray.typing.proto_register import _register_proto -from docarray.typing.tensor.abstract_tensor import AbstractTensor - -if TYPE_CHECKING: - from pydantic import BaseConfig - from pydantic.fields import ModelField - - -from docarray.base_doc.base_node import BaseNode - -T = TypeVar('T') -ShapeT = TypeVar('ShapeT') - -tensor_base: type = type(BaseNode) - - -# the mypy error suppression below should not be necessary anymore once the following -# is released in mypy: https://github.com/python/mypy/pull/14135 -class metaNumpy(AbstractTensor.__parametrized_meta__, tensor_base): # type: ignore - pass - - -@_register_proto(proto_type_name='jaxarray') -class JaxArray(np.ndarray, AbstractTensor, Generic[ShapeT]): - """ - Subclass of `np.ndarray`, intended for use in a Document. - This enables (de)serialization from/to protobuf and json, data validation, - and coersion from compatible types like `torch.Tensor`. - - This type can also be used in a parametrized way, specifying the shape of the array. - - --- - - ```python - from docarray import BaseDoc - from docarray.typing import NdArray - import numpy as np - - - class MyDoc(BaseDoc): - arr: NdArray - image_arr: NdArray[3, 224, 224] - square_crop: NdArray[3, 'x', 'x'] - random_image: NdArray[3, ...] # first dimension is fixed, can have arbitrary shape - - - # create a document with tensors - doc = MyDoc( - arr=np.zeros((128,)), - image_arr=np.zeros((3, 224, 224)), - square_crop=np.zeros((3, 64, 64)), - random_image=np.zeros((3, 128, 256)), - ) - assert doc.image_arr.shape == (3, 224, 224) - - # automatic shape conversion - doc = MyDoc( - arr=np.zeros((128,)), - image_arr=np.zeros((224, 224, 3)), # will reshape to (3, 224, 224) - square_crop=np.zeros((3, 128, 128)), - random_image=np.zeros((3, 64, 128)), - ) - assert doc.image_arr.shape == (3, 224, 224) - - # !! The following will raise an error due to shape mismatch !! - from pydantic import ValidationError - - try: - doc = MyDoc( - arr=np.zeros((128,)), - image_arr=np.zeros((224, 224)), # this will fail validation - square_crop=np.zeros((3, 128, 64)), # this will also fail validation - random_image=np.zeros((4, 64, 128)), # this will also fail validation - ) - except ValidationError as e: - pass - ``` - - --- - """ - - __parametrized_meta__ = metaNumpy - - @classmethod - def __get_validators__(cls): - # one or more validators may be yielded which will be called in the - # order to validate the input, each validator will receive as an input - # the value returned from the previous validator - pass - - @classmethod - def validate( - cls: Type[T], - value: Union[T, np.ndarray, List[Any], Tuple[Any], Any], - field: 'ModelField', - config: 'BaseConfig', - ) -> T: - pass - - @classmethod - def _docarray_from_native(cls: Type[T], value: np.ndarray) -> T: - pass diff --git a/docarray/typing/tensor/jaxarray.py b/docarray/typing/tensor/jaxarray.py new file mode 100644 index 00000000000..87208426953 --- /dev/null +++ b/docarray/typing/tensor/jaxarray.py @@ -0,0 +1,202 @@ +from typing import TYPE_CHECKING, Any, Generic, List, Tuple, Type, TypeVar, Union, cast + +import jax.numpy as jnp + +from docarray.typing.proto_register import _register_proto +from docarray.typing.tensor.abstract_tensor import AbstractTensor + +if TYPE_CHECKING: + from pydantic import BaseConfig + from pydantic.fields import ModelField + + from docarray.computation.jax_backend import JaxCompBackend + from docarray.proto import NdArrayProto + +from docarray.base_doc.base_node import BaseNode + +T = TypeVar('T', bound='JaxArray') +ShapeT = TypeVar('ShapeT') + +tensor_base: type = type(BaseNode) + + +# the mypy error suppression below should not be necessary anymore once the following +# is released in mypy: https://github.com/python/mypy/pull/14135 +class metaNumpy(AbstractTensor.__parametrized_meta__, tensor_base): # type: ignore + pass + + +@_register_proto(proto_type_name='ndarray') +class JaxArray(jnp.ndarray, AbstractTensor, Generic[ShapeT]): + """ + Subclass of `np.ndarray`, intended for use in a Document. + This enables (de)serialization from/to protobuf and json, data validation, + and coersion from compatible types like `torch.Tensor`. + + This type can also be used in a parametrized way, specifying the shape of the array. + + --- + + ```python + from docarray import BaseDoc + from docarray.typing import NdArray + import numpy as np + + + class MyDoc(BaseDoc): + arr: NdArray + image_arr: NdArray[3, 224, 224] + square_crop: NdArray[3, 'x', 'x'] + random_image: NdArray[3, ...] # first dimension is fixed, can have arbitrary shape + + + # create a document with tensors + doc = MyDoc( + arr=np.zeros((128,)), + image_arr=np.zeros((3, 224, 224)), + square_crop=np.zeros((3, 64, 64)), + random_image=np.zeros((3, 128, 256)), + ) + assert doc.image_arr.shape == (3, 224, 224) + + # automatic shape conversion + doc = MyDoc( + arr=np.zeros((128,)), + image_arr=np.zeros((224, 224, 3)), # will reshape to (3, 224, 224) + square_crop=np.zeros((3, 128, 128)), + random_image=np.zeros((3, 64, 128)), + ) + assert doc.image_arr.shape == (3, 224, 224) + + # !! The following will raise an error due to shape mismatch !! + from pydantic import ValidationError + + try: + doc = MyDoc( + arr=np.zeros((128,)), + image_arr=np.zeros((224, 224)), # this will fail validation + square_crop=np.zeros((3, 128, 64)), # this will also fail validation + random_image=np.zeros((4, 64, 128)), # this will also fail validation + ) + except ValidationError as e: + pass + ``` + + --- + """ + + __parametrized_meta__ = metaNumpy + + @classmethod + def __get_validators__(cls): + # one or more validators may be yielded which will be called in the + # order to validate the input, each validator will receive as an input + # the value returned from the previous validator + yield cls.validate + + @classmethod + def validate( + cls: Type[T], + value: Union[T, jnp.ndarray, List[Any], Tuple[Any], Any], + field: 'ModelField', + config: 'BaseConfig', + ) -> T: + if isinstance(value, jnp.ndarray): + return cls._docarray_from_native(value) + elif isinstance(value, JaxArray): + return cast(T, value) + elif isinstance(value, list) or isinstance(value, tuple): + try: + arr_from_list: jnp.ndarray = jnp.asarray(value) + return cls._docarray_from_native(arr_from_list) + except Exception: + pass # handled below + else: + try: + arr: jnp.ndarray = jnp.ndarray(value) + return cls._docarray_from_native(arr) + except Exception: + pass # handled below + raise ValueError(f'Expected a numpy.ndarray compatible type, got {type(value)}') + + @classmethod + def _docarray_from_native(cls: Type[T], value: jnp.ndarray) -> T: + if cls.__unparametrizedcls__: # This is not None if the tensor is parametrized + return cast(T, value.view(cls.__unparametrizedcls__)) + return value.view(cls) + + def _docarray_to_json_compatible(self) -> jnp.ndarray: + """ + Convert `JaxArray` into a json compatible object + :return: a representation of the tensor compatible with orjson + """ + return self.unwrap() + + def unwrap(self) -> jnp.ndarray: + """ + Return the original ndarray without any memory copy. + + The original view rest intact and is still a Document `JaxArray` + but the return object is a pure `np.ndarray` but both object share + the same memory layout. + + --- + + ```python + from docarray.typing import JaxArray + import numpy as np + + t1 = JaxArray.validate(np.zeros((3, 224, 224)), None, None) + # here t1 is a docarray NdArray + t2 = t1.unwrap() + # here t2 is a pure np.ndarray but t1 is still a Docarray JaxArray + # But both share the same underlying memory + ``` + + --- + + :return: a `jnp.ndarray` + """ + return self.view(jnp.ndarray) + + @classmethod + def from_protobuf(cls: Type[T], pb_msg: 'NdArrayProto') -> 'T': + """ + Read ndarray from a proto msg + :param pb_msg: + :return: a numpy array + """ + source = pb_msg.dense + if source.buffer: + x = jnp.frombuffer(bytearray(source.buffer), dtype=source.dtype) + return cls._docarray_from_native(x.reshape(source.shape)) + elif len(source.shape) > 0: + return cls._docarray_from_native(jnp.zeros(source.shape)) + else: + raise ValueError(f'proto message {pb_msg} cannot be cast to a NdArray') + + def to_protobuf(self) -> 'NdArrayProto': + """ + Transform self into a NdArrayProto protobuf message + """ + from docarray.proto import NdArrayProto + + nd_proto = NdArrayProto() + + nd_proto.dense.buffer = self.tobytes() + nd_proto.dense.ClearField('shape') + nd_proto.dense.shape.extend(list(self.shape)) + nd_proto.dense.dtype = self.dtype.str + + return nd_proto + + @staticmethod + def get_comp_backend() -> 'JaxCompBackend': + """Return the computational backend of the tensor""" + from docarray.computation.jax_backend import JaxCompBackend + + return JaxCompBackend() + + def __class_getitem__(cls, item: Any, *args, **kwargs): + # see here for mypy bug: https://github.com/python/mypy/issues/14123 + return AbstractTensor.__class_getitem__.__func__(cls, item) # type: ignore From fd11322960a091cee6135cd49678b5f95d87c1de Mon Sep 17 00:00:00 2001 From: agaraman0 Date: Wed, 24 May 2023 15:16:30 +0530 Subject: [PATCH 03/25] feat: JaxArray refactoring Signed-off-by: agaraman0 --- docarray/typing/__init__.py | 2 + docarray/typing/tensor/jaxarray.py | 59 +----------------------------- 2 files changed, 4 insertions(+), 57 deletions(-) diff --git a/docarray/typing/__init__.py b/docarray/typing/__init__.py index 5fdb578ad04..1cd0133c2f8 100644 --- a/docarray/typing/__init__.py +++ b/docarray/typing/__init__.py @@ -5,6 +5,7 @@ from docarray.typing.tensor import ImageNdArray, ImageTensor from docarray.typing.tensor.audio import AudioNdArray, AudioTensor from docarray.typing.tensor.embedding.embedding import AnyEmbedding, NdArrayEmbedding +from docarray.typing.tensor.jaxarray import JaxArray from docarray.typing.tensor.ndarray import NdArray from docarray.typing.tensor.tensor import AnyTensor from docarray.typing.tensor.video import VideoNdArray, VideoTensor @@ -56,6 +57,7 @@ 'ImageBytes', 'VideoBytes', 'AudioBytes', + 'JaxArray', ] diff --git a/docarray/typing/tensor/jaxarray.py b/docarray/typing/tensor/jaxarray.py index 87208426953..1aa1432832b 100644 --- a/docarray/typing/tensor/jaxarray.py +++ b/docarray/typing/tensor/jaxarray.py @@ -26,64 +26,9 @@ class metaNumpy(AbstractTensor.__parametrized_meta__, tensor_base): # type: ign pass -@_register_proto(proto_type_name='ndarray') +@_register_proto(proto_type_name='jaxarray') class JaxArray(jnp.ndarray, AbstractTensor, Generic[ShapeT]): - """ - Subclass of `np.ndarray`, intended for use in a Document. - This enables (de)serialization from/to protobuf and json, data validation, - and coersion from compatible types like `torch.Tensor`. - - This type can also be used in a parametrized way, specifying the shape of the array. - - --- - - ```python - from docarray import BaseDoc - from docarray.typing import NdArray - import numpy as np - - - class MyDoc(BaseDoc): - arr: NdArray - image_arr: NdArray[3, 224, 224] - square_crop: NdArray[3, 'x', 'x'] - random_image: NdArray[3, ...] # first dimension is fixed, can have arbitrary shape - - - # create a document with tensors - doc = MyDoc( - arr=np.zeros((128,)), - image_arr=np.zeros((3, 224, 224)), - square_crop=np.zeros((3, 64, 64)), - random_image=np.zeros((3, 128, 256)), - ) - assert doc.image_arr.shape == (3, 224, 224) - - # automatic shape conversion - doc = MyDoc( - arr=np.zeros((128,)), - image_arr=np.zeros((224, 224, 3)), # will reshape to (3, 224, 224) - square_crop=np.zeros((3, 128, 128)), - random_image=np.zeros((3, 64, 128)), - ) - assert doc.image_arr.shape == (3, 224, 224) - - # !! The following will raise an error due to shape mismatch !! - from pydantic import ValidationError - - try: - doc = MyDoc( - arr=np.zeros((128,)), - image_arr=np.zeros((224, 224)), # this will fail validation - square_crop=np.zeros((3, 128, 64)), # this will also fail validation - random_image=np.zeros((4, 64, 128)), # this will also fail validation - ) - except ValidationError as e: - pass - ``` - - --- - """ + """ """ __parametrized_meta__ = metaNumpy From bc4a698f5bff3d55f3e96e1c56f52fb7a83917a5 Mon Sep 17 00:00:00 2001 From: agaraman0 Date: Sun, 4 Jun 2023 09:56:36 +0530 Subject: [PATCH 04/25] fix: _docarray_from_native function for jaxarray Signed-off-by: agaraman0 --- docarray/typing/tensor/jaxarray.py | 23 ++++------------------- 1 file changed, 4 insertions(+), 19 deletions(-) diff --git a/docarray/typing/tensor/jaxarray.py b/docarray/typing/tensor/jaxarray.py index 1aa1432832b..2b3ec888f2e 100644 --- a/docarray/typing/tensor/jaxarray.py +++ b/docarray/typing/tensor/jaxarray.py @@ -1,6 +1,7 @@ from typing import TYPE_CHECKING, Any, Generic, List, Tuple, Type, TypeVar, Union, cast import jax.numpy as jnp +from jax import Array from docarray.typing.proto_register import _register_proto from docarray.typing.tensor.abstract_tensor import AbstractTensor @@ -46,7 +47,7 @@ def validate( field: 'ModelField', config: 'BaseConfig', ) -> T: - if isinstance(value, jnp.ndarray): + if isinstance(value, Array): return cls._docarray_from_native(value) elif isinstance(value, JaxArray): return cast(T, value) @@ -111,29 +112,13 @@ def from_protobuf(cls: Type[T], pb_msg: 'NdArrayProto') -> 'T': :param pb_msg: :return: a numpy array """ - source = pb_msg.dense - if source.buffer: - x = jnp.frombuffer(bytearray(source.buffer), dtype=source.dtype) - return cls._docarray_from_native(x.reshape(source.shape)) - elif len(source.shape) > 0: - return cls._docarray_from_native(jnp.zeros(source.shape)) - else: - raise ValueError(f'proto message {pb_msg} cannot be cast to a NdArray') + pass def to_protobuf(self) -> 'NdArrayProto': """ Transform self into a NdArrayProto protobuf message """ - from docarray.proto import NdArrayProto - - nd_proto = NdArrayProto() - - nd_proto.dense.buffer = self.tobytes() - nd_proto.dense.ClearField('shape') - nd_proto.dense.shape.extend(list(self.shape)) - nd_proto.dense.dtype = self.dtype.str - - return nd_proto + pass @staticmethod def get_comp_backend() -> 'JaxCompBackend': From cb64d4fbec932d373a4686a3a85a94d2b5af33ca Mon Sep 17 00:00:00 2001 From: agaraman0 Date: Wed, 14 Jun 2023 09:11:03 +0530 Subject: [PATCH 05/25] feat: JAX array implementation is complete Signed-off-by: agaraman0 --- tests/units/typing/tensor/test_jax_array.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 tests/units/typing/tensor/test_jax_array.py diff --git a/tests/units/typing/tensor/test_jax_array.py b/tests/units/typing/tensor/test_jax_array.py new file mode 100644 index 00000000000..e69de29bb2d From d57a656050ee4834ee19170c97aabdf196c30b1d Mon Sep 17 00:00:00 2001 From: agaraman0 Date: Wed, 14 Jun 2023 09:17:39 +0530 Subject: [PATCH 06/25] feat: JAX array implementation is complete Signed-off-by: agaraman0 --- docarray/computation/jax_backend.py | 29 +-- docarray/typing/tensor/jaxarray.py | 87 ++++++++- .../jax_backend/test_basics.py | 139 ++++++++++++++ tests/units/typing/tensor/test_jax_array.py | 181 ++++++++++++++++++ 4 files changed, 414 insertions(+), 22 deletions(-) diff --git a/docarray/computation/jax_backend.py b/docarray/computation/jax_backend.py index da4a2e770f0..18fec781cf1 100644 --- a/docarray/computation/jax_backend.py +++ b/docarray/computation/jax_backend.py @@ -1,17 +1,18 @@ -from typing import TYPE_CHECKING, Any, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple +import jax +import jax.numpy as jnp import numpy as np from docarray.computation.abstract_comp_backend import AbstractComputationalBackend -from docarray.utils._internal.misc import import_library +from docarray.computation.abstract_numpy_based_backend import AbstractNumpyBasedBackend +from docarray.typing import JaxArray if TYPE_CHECKING: - import jax -else: - torch = import_library('jax', raise_error=True) + pass -def _unsqueeze_if_single_axis(*matrices) -> List[torch.Tensor]: +def _unsqueeze_if_single_axis(*matrices) -> List[jnp.ndarray]: """Unsqueezes tensors that only have one axis, at dim 0. This ensures that all outputs can be treated as matrices, not vectors. @@ -26,18 +27,22 @@ def _unsqueeze_if_scalar(t): pass -def identity(array: jax.numpy.ndarray) -> jax.numpy.ndarray: - return array +def norm_left(t: jnp.ndarray) -> JaxArray: + return JaxArray(tensor=t) + + +def norm_right(t: JaxArray) -> jnp.ndarray: + return t.tensor -class JaxCompBackend(AbstractComputationalBackend[torch.Tensor]): +class JaxCompBackend(AbstractNumpyBasedBackend): """ Computational backend for Numpy. """ - _module = np - _cast_output = identity - _get_tensor = identity + _module = jnp + _cast_output: Callable = norm_left + _get_tensor: Callable = norm_right @classmethod def to_device(cls, tensor: 'jax.numpy.array', device: str) -> 'jax.numpy.array': diff --git a/docarray/typing/tensor/jaxarray.py b/docarray/typing/tensor/jaxarray.py index 2b3ec888f2e..49e313c0d9b 100644 --- a/docarray/typing/tensor/jaxarray.py +++ b/docarray/typing/tensor/jaxarray.py @@ -1,6 +1,7 @@ from typing import TYPE_CHECKING, Any, Generic, List, Tuple, Type, TypeVar, Union, cast import jax.numpy as jnp +import numpy as np from jax import Array from docarray.typing.proto_register import _register_proto @@ -18,20 +19,47 @@ T = TypeVar('T', bound='JaxArray') ShapeT = TypeVar('ShapeT') -tensor_base: type = type(BaseNode) +node_base: type = type(BaseNode) # the mypy error suppression below should not be necessary anymore once the following # is released in mypy: https://github.com/python/mypy/pull/14135 -class metaNumpy(AbstractTensor.__parametrized_meta__, tensor_base): # type: ignore +class metaJax( + AbstractTensor.__parametrized_meta__, # type: ignore + node_base, # type: ignore +): # type: ignore pass @_register_proto(proto_type_name='jaxarray') -class JaxArray(jnp.ndarray, AbstractTensor, Generic[ShapeT]): +class JaxArray(AbstractTensor, Generic[ShapeT], metaclass=metaJax): """ """ - __parametrized_meta__ = metaNumpy + __parametrized_meta__ = metaJax + + def __init__(self, tensor: jnp.ndarray): + super().__init__() + self.tensor = tensor + + def __getitem__(self, item): + from docarray.computation.jax_backend import JaxCompBackend + + tensor = self.unwrap() + if tensor is not None: + tensor = tensor[item] + return JaxCompBackend._cast_output(t=tensor) + + def __setitem__(self, index, value): + """""" + # print(index, value) + self.tensor = self.tensor.at[index : index + 1].set(value) + + def __iter__(self): + for i in range(len(self)): + yield self[i] + + def __len__(self): + return len(self.tensor) @classmethod def __get_validators__(cls): @@ -67,9 +95,29 @@ def validate( @classmethod def _docarray_from_native(cls: Type[T], value: jnp.ndarray) -> T: - if cls.__unparametrizedcls__: # This is not None if the tensor is parametrized - return cast(T, value.view(cls.__unparametrizedcls__)) - return value.view(cls) + if isinstance(value, JaxArray): + if cls.__unparametrizedcls__: # None if the tensor is parametrized + value.__class__ = cls.__unparametrizedcls__ # type: ignore + else: + value.__class__ = cls + return cast(T, value) + else: + if cls.__unparametrizedcls__: # None if the tensor is parametrized + cls_param_ = cls.__unparametrizedcls__ + cls_param = cast(Type[T], cls_param_) + else: + cls_param = cls + + return cls_param(tensor=value) + + @classmethod + def from_ndarray(cls: Type[T], value: np.ndarray) -> T: + """Create a `TensorFlowTensor` from a numpy array. + + :param value: the numpy array + :return: a `TensorFlowTensor` + """ + return cls._docarray_from_native(jnp.array(value)) def _docarray_to_json_compatible(self) -> jnp.ndarray: """ @@ -103,7 +151,7 @@ def unwrap(self) -> jnp.ndarray: :return: a `jnp.ndarray` """ - return self.view(jnp.ndarray) + return self.tensor @classmethod def from_protobuf(cls: Type[T], pb_msg: 'NdArrayProto') -> 'T': @@ -112,13 +160,32 @@ def from_protobuf(cls: Type[T], pb_msg: 'NdArrayProto') -> 'T': :param pb_msg: :return: a numpy array """ - pass + source = pb_msg.dense + if source.buffer: + x = np.frombuffer(bytearray(source.buffer), dtype=source.dtype) + return cls.from_ndarray(x.reshape(source.shape)) + elif len(source.shape) > 0: + return cls.from_ndarray(np.zeros(source.shape)) + else: + raise ValueError( + f'Proto message {pb_msg} cannot be cast to a TensorFlowTensor.' + ) def to_protobuf(self) -> 'NdArrayProto': """ Transform self into a NdArrayProto protobuf message """ - pass + from docarray.proto import NdArrayProto + + nd_proto = NdArrayProto() + + value_np = self.tensor + nd_proto.dense.buffer = value_np.tobytes() + nd_proto.dense.ClearField('shape') + nd_proto.dense.shape.extend(list(value_np.shape)) + nd_proto.dense.dtype = value_np.dtype.str + + return nd_proto @staticmethod def get_comp_backend() -> 'JaxCompBackend': diff --git a/tests/units/computation_backends/jax_backend/test_basics.py b/tests/units/computation_backends/jax_backend/test_basics.py index e69de29bb2d..e03efcd7a5e 100644 --- a/tests/units/computation_backends/jax_backend/test_basics.py +++ b/tests/units/computation_backends/jax_backend/test_basics.py @@ -0,0 +1,139 @@ +import jax.numpy as jnp +import numpy as np +import pytest + +from docarray.computation.jax_backend import JaxCompBackend +from docarray.typing import JaxArray + + +@pytest.mark.tensorflow +@pytest.mark.parametrize( + 'shape,result', + [ + ((5), 1), + ((1, 5), 2), + ((5, 5), 2), + ((), 0), + ], +) +def test_n_dim(shape, result): + array = JaxArray(jnp.zeros(shape)) + assert JaxCompBackend.n_dim(array) == result + + +@pytest.mark.tensorflow +@pytest.mark.parametrize( + 'shape,result', + [ + ((10,), (10,)), + ((5, 5), (5, 5)), + ((), ()), + ], +) +def test_shape(shape, result): + array = JaxArray(jnp.zeros(shape)) + shape = JaxCompBackend.shape(array) + assert shape == result + assert type(shape) == tuple + + +@pytest.mark.tensorflow +def test_to_device(): + array = JaxArray(jnp.constant([1, 2, 3])) + array = JaxCompBackend.to_device(array, 'CPU:0') + assert array.tensor.device.endswith('CPU:0') + + +@pytest.mark.tensorflow +@pytest.mark.parametrize( + 'dtype,result_type', + [ + ('int64', 'int64'), + ('float64', 'float64'), + ('int8', 'int8'), + ('double', 'float64'), + ], +) +def test_dtype(dtype, result_type): + array = JaxArray(jnp.constant([1, 2, 3], dtype=getattr(jnp, dtype))) + assert JaxCompBackend.dtype(array) == result_type + + +@pytest.mark.tensorflow +def test_empty(): + array = JaxCompBackend.empty((10, 3)) + assert array.tensor.shape == (10, 3) + + +@pytest.mark.tensorflow +def test_empty_dtype(): + tf_tensor = JaxCompBackend.empty((10, 3), dtype=jnp.int32) + assert tf_tensor.tensor.shape == (10, 3) + assert tf_tensor.tensor.dtype == jnp.int32 + + +@pytest.mark.tensorflow +def test_empty_device(): + tensor = JaxCompBackend.empty((10, 3), device='CPU:0') + assert tensor.tensor.shape == (10, 3) + assert tensor.tensor.device.endswith('CPU:0') + + +@pytest.mark.tensorflow +def test_squeeze(): + tensor = JaxArray(jnp.zeros(shape=(1, 1, 3, 1))) + squeezed = JaxCompBackend.squeeze(tensor) + assert squeezed.tensor.shape == (3,) + + +@pytest.mark.tensorflow +@pytest.mark.parametrize( + 'data_input,t_range,x_range,data_result', + [ + ( + [0, 1, 2, 3, 4, 5], + (0, 10), + None, + [0, 2, 4, 6, 8, 10], + ), + ( + [0, 1, 2, 3, 4, 5], + (0, 10), + (0, 10), + [0, 1, 2, 3, 4, 5], + ), + ( + [[0.0, 1.0], [0.0, 1.0]], + (0, 10), + None, + [[0.0, 10.0], [0.0, 10.0]], + ), + ], +) +def test_minmax_normalize(data_input, t_range, x_range, data_result): + array = JaxArray(jnp.constant(data_input)) + output = JaxCompBackend.minmax_normalize( + tensor=array, t_range=t_range, x_range=x_range + ) + assert np.allclose(output.tensor, jnp.constant(data_result)) + + +@pytest.mark.tensorflow +def test_reshape(): + tensor = JaxArray(jnp.zeros((3, 224, 224))) + reshaped = JaxCompBackend.reshape(tensor, (224, 224, 3)) + assert reshaped.tensor.shape == (224, 224, 3) + + +@pytest.mark.tensorflow +def test_stack(): + t0 = JaxArray(jnp.zeros((3, 224, 224))) + t1 = JaxArray(jnp.ones((3, 224, 224))) + + stacked1 = JaxCompBackend.stack([t0, t1], dim=0) + assert isinstance(stacked1, JaxArray) + assert stacked1.tensor.shape == (2, 3, 224, 224) + + stacked2 = JaxCompBackend.stack([t0, t1], dim=-1) + assert isinstance(stacked2, JaxArray) + assert stacked2.tensor.shape == (3, 224, 224, 2) diff --git a/tests/units/typing/tensor/test_jax_array.py b/tests/units/typing/tensor/test_jax_array.py index e69de29bb2d..b44494d51e1 100644 --- a/tests/units/typing/tensor/test_jax_array.py +++ b/tests/units/typing/tensor/test_jax_array.py @@ -0,0 +1,181 @@ +import jax.numpy as jnp +import numpy as np +import pytest +from jax._src.core import InconclusiveDimensionOperation +from pydantic import schema_json_of +from pydantic.tools import parse_obj_as + +from docarray.base_doc.io.json import orjson_dumps +from docarray.typing import JaxArray + + +def test_proto_tensor(): + from docarray.proto.pb2.docarray_pb2 import NdArrayProto + + tensor = parse_obj_as(JaxArray, jnp.zeros((3, 224, 224))) + proto = tensor.to_protobuf() + assert isinstance(proto, NdArrayProto) + + from_proto = JaxArray.from_protobuf(proto) + assert isinstance(from_proto, JaxArray) + assert jnp.allclose(tensor.tensor, from_proto.tensor) + + +def test_json_schema(): + schema_json_of(JaxArray) + + +def test_dump_json(): + tensor = parse_obj_as(JaxArray, jnp.zeros((3, 224, 224))) + orjson_dumps(tensor) + + +def test_unwrap(): + tf_tensor = parse_obj_as(JaxArray, jnp.zeros((3, 224, 224))) + unwrapped = tf_tensor.unwrap() + + assert not isinstance(unwrapped, JaxArray) + assert isinstance(tf_tensor, JaxArray) + assert isinstance(unwrapped, jnp.ndarray) + + assert np.allclose(unwrapped, np.zeros((3, 224, 224))) + + +def test_from_ndarray(): + nd = np.array([1, 2, 3]) + tensor = JaxArray.from_ndarray(nd) + assert isinstance(tensor, JaxArray) + assert isinstance(tensor.tensor, jnp.ndarray) + + +def test_ellipsis_in_shape(): + # ellipsis in the end, two extra dimensions needed + tf_tensor = parse_obj_as(JaxArray[3, ...], jnp.zeros((3, 128, 224))) + assert isinstance(tf_tensor, JaxArray) + assert isinstance(tf_tensor.tensor, jnp.ndarray) + assert tf_tensor.tensor.shape == (3, 128, 224) + + # ellipsis in the beginning, two extra dimensions needed + tf_tensor = parse_obj_as(JaxArray[..., 224], jnp.zeros((3, 128, 224))) + assert isinstance(tf_tensor, JaxArray) + assert isinstance(tf_tensor.tensor, jnp.ndarray) + assert tf_tensor.tensor.shape == (3, 128, 224) + + # more than one ellipsis in the shape + with pytest.raises(ValueError): + parse_obj_as(JaxArray[3, ..., 128, ...], jnp.zeros((3, 128, 224))) + + # wrong shape + with pytest.raises(ValueError): + parse_obj_as(JaxArray[3, 224, ...], jnp.zeros((3, 128, 224))) + + +def test_parametrized(): + # correct shape, single axis + tf_tensor = parse_obj_as(JaxArray[128], jnp.zeros(128)) + assert isinstance(tf_tensor, JaxArray) + assert isinstance(tf_tensor.tensor, jnp.ndarray) + assert tf_tensor.tensor.shape == (128,) + + # correct shape, multiple axis + tf_tensor = parse_obj_as(JaxArray[3, 224, 224], jnp.zeros((3, 224, 224))) + assert isinstance(tf_tensor, JaxArray) + assert isinstance(tf_tensor.tensor, jnp.ndarray) + assert tf_tensor.tensor.shape == (3, 224, 224) + + # wrong but reshapable shape + tf_tensor = parse_obj_as(JaxArray[3, 224, 224], jnp.zeros((224, 3, 224))) + assert isinstance(tf_tensor, JaxArray) + assert isinstance(tf_tensor.tensor, jnp.ndarray) + assert tf_tensor.tensor.shape == (3, 224, 224) + + # wrong and not reshapable shape + with pytest.raises(InconclusiveDimensionOperation): + parse_obj_as(JaxArray[3, 224, 224], jnp.zeros((224, 224))) + + +def test_parametrized_with_str(): + # test independent variable dimensions + tf_tensor = parse_obj_as(JaxArray[3, 'x', 'y'], jnp.zeros((3, 224, 224))) + assert isinstance(tf_tensor, JaxArray) + assert isinstance(tf_tensor.tensor, jnp.ndarray) + assert tf_tensor.tensor.shape == (3, 224, 224) + + tf_tensor = parse_obj_as(JaxArray[3, 'x', 'y'], jnp.zeros((3, 60, 128))) + assert isinstance(tf_tensor, JaxArray) + assert isinstance(tf_tensor.tensor, jnp.ndarray) + assert tf_tensor.tensor.shape == (3, 60, 128) + + with pytest.raises(ValueError): + parse_obj_as(JaxArray[3, 'x', 'y'], jnp.zeros((4, 224, 224))) + + with pytest.raises(ValueError): + parse_obj_as(JaxArray[3, 'x', 'y'], jnp.zeros((100, 1))) + + # test dependent variable dimensions + tf_tensor = parse_obj_as(JaxArray[3, 'x', 'x'], jnp.zeros((3, 224, 224))) + assert isinstance(tf_tensor, JaxArray) + assert isinstance(tf_tensor.tensor, jnp.ndarray) + assert tf_tensor.tensor.shape == (3, 224, 224) + + with pytest.raises(ValueError): + _ = parse_obj_as(JaxArray[3, 'x', 'x'], jnp.zeros((3, 60, 128))) + + with pytest.raises(ValueError): + _ = parse_obj_as(JaxArray[3, 'x', 'x'], jnp.zeros((3, 60))) + + +@pytest.mark.parametrize('shape', [(3, 224, 224), (224, 224, 3)]) +def test_parameterized_tensor_class_name(shape): + MyTFT = JaxArray[3, 224, 224] + tensor = parse_obj_as(MyTFT, jnp.zeros(shape)) + + assert MyTFT.__name__ == 'JaxArray[3, 224, 224]' + assert MyTFT.__qualname__ == 'JaxArray[3, 224, 224]' + + assert tensor.__class__.__name__ == 'JaxArray' + assert tensor.__class__.__qualname__ == 'JaxArray' + assert f'{tensor.tensor[0][0][0]}' == '0.0' + + +def test_parametrized_subclass(): + c1 = JaxArray[128] + c2 = JaxArray[128] + assert issubclass(c1, c2) + assert issubclass(c1, JaxArray) + + assert not issubclass(c1, JaxArray[256]) + + +def test_parametrized_instance(): + t = parse_obj_as(JaxArray[128], jnp.zeros((128,))) + assert isinstance(t, JaxArray[128]) + assert isinstance(t, JaxArray) + # assert isinstance(t, jnp.ndarray) + + assert not isinstance(t, JaxArray[256]) + assert not isinstance(t, JaxArray[2, 128]) + assert not isinstance(t, JaxArray[2, 2, 64]) + + +def test_parametrized_equality(): + t1 = parse_obj_as(JaxArray[128], jnp.zeros((128,))) + t2 = parse_obj_as(JaxArray[128], jnp.zeros((128,))) + assert jnp.allclose(t1.tensor, t2.tensor) + + +def test_parametrized_operations(): + t1 = parse_obj_as(JaxArray[128], jnp.zeros((128,))) + t2 = parse_obj_as(JaxArray[128], jnp.zeros((128,))) + t_result = t1.tensor + t2.tensor + assert isinstance(t_result, jnp.ndarray) + assert not isinstance(t_result, JaxArray) + assert not isinstance(t_result, JaxArray[128]) + + +def test_set_item(): + t = JaxArray(tensor=jnp.zeros((3, 224, 224))) + t[0] = jnp.ones((1, 224, 224)) + assert jnp.allclose(t.tensor[0], jnp.ones((1, 224, 224))) + assert jnp.allclose(t.tensor[1], jnp.zeros((1, 224, 224))) + assert jnp.allclose(t.tensor[2], jnp.zeros((1, 224, 224))) From 49b3764b24ea2331aa3f4735d7918add4f69593a Mon Sep 17 00:00:00 2001 From: agaraman0 Date: Fri, 16 Jun 2023 14:07:42 +0530 Subject: [PATCH 07/25] feat: JaxCompBackend tests complete till nested Retrieval class Signed-off-by: agaraman0 --- docarray/computation/jax_backend.py | 89 ++++++++++++++----- .../jax_backend/test_basics.py | 20 +++-- .../jax_backend/test_retrieval.py | 57 ++++++++++++ 3 files changed, 134 insertions(+), 32 deletions(-) diff --git a/docarray/computation/jax_backend.py b/docarray/computation/jax_backend.py index 18fec781cf1..1ad1397e329 100644 --- a/docarray/computation/jax_backend.py +++ b/docarray/computation/jax_backend.py @@ -11,6 +11,8 @@ if TYPE_CHECKING: pass +jax.config.update("jax_enable_x64", True) + def _unsqueeze_if_single_axis(*matrices) -> List[jnp.ndarray]: """Unsqueezes tensors that only have one axis, at dim 0. @@ -45,23 +47,29 @@ class JaxCompBackend(AbstractNumpyBasedBackend): _get_tensor: Callable = norm_right @classmethod - def to_device(cls, tensor: 'jax.numpy.array', device: str) -> 'jax.numpy.array': + def to_device(cls, tensor: 'JaxArray', device: str) -> 'JaxArray': """Move the tensor to the specified device.""" - raise NotImplementedError('Numpy does not support devices (GPU).') + if cls.device(tensor) == device: + return tensor + else: + jax_devices = jax.devices(device) + return cls._cast_output( + jax.device_put(cls._get_tensor(tensor), jax_devices) + ) @classmethod - def device(cls, tensor: 'jax.numpy.array') -> Optional[str]: + def device(cls, tensor: 'JaxArray') -> Optional[str]: """Return device on which the tensor is allocated.""" - return None + return cls._get_tensor(tensor).device().platform @classmethod def to_numpy(cls, array: 'jax.numpy.array') -> 'np.ndarray': - return array + return np.array(cls._get_tensor(array)) @classmethod def none_value(cls) -> Any: """Provide a compatible value that represents None in numpy.""" - return None + return jnp.nan @classmethod def detach(cls, tensor: 'jax.numpy.array') -> 'jax.numpy.array': @@ -71,17 +79,18 @@ def detach(cls, tensor: 'jax.numpy.array') -> 'jax.numpy.array': :param tensor: tensor to be detached :return: a detached tensor with the same data. """ - pass + return cls._cast_output(jax.lax.stop_gradient(cls._get_tensor(tensor))) @classmethod - def dtype(cls, tensor: 'jax.numpy.array') -> np.dtype: + def dtype(cls, tensor: 'JaxArray') -> np.dtype: """Get the data type of the tensor.""" - pass + d_type = cls._get_tensor(tensor).dtype + return d_type.name @classmethod def minmax_normalize( cls, - tensor: 'jax.numpy.array', + tensor: 'JaxArray', t_range: Tuple = (0, 1), x_range: Optional[Tuple] = None, eps: float = 1e-7, @@ -104,7 +113,16 @@ def minmax_normalize( :param eps: a small jitter to avoid divide by zero :return: normalized data in `t_range` """ - pass + a, b = t_range + + t = jnp.asarray(cls._get_tensor(tensor), jnp.float32) + + min_d = x_range[0] if x_range else jnp.min(t, axis=-1, keepdims=True) + max_d = x_range[1] if x_range else jnp.max(t, axis=-1, keepdims=True) + r = (b - a) * (t - min_d) / (max_d - min_d + eps) + a + + normalized = jnp.clip(r, *((a, b) if a < b else (b, a))) + return cls._cast_output(jnp.asarray(normalized, cls._get_tensor(tensor).dtype)) class Retrieval(AbstractComputationalBackend.Retrieval[jax.numpy.array]): """ @@ -113,11 +131,11 @@ class Retrieval(AbstractComputationalBackend.Retrieval[jax.numpy.array]): @staticmethod def top_k( - values: 'jax.numpy.array', + values: 'JaxArray', k: int, descending: bool = False, device: Optional[str] = None, - ) -> Tuple['jax.numpy.array', 'jax.numpy.array']: + ) -> Tuple['JaxArray', 'JaxArray']: """ Retrieves the top k smallest values in `values`, and returns them alongside their indices in the input `values`. @@ -134,7 +152,32 @@ def top_k( :return: Tuple containing the retrieved values, and their indices. Both ar of shape (n_queries, k) """ - pass + comp_be = JaxCompBackend + if device is not None: + values = comp_be.to_device(values, device) + + values: jnp.ndarray = comp_be._get_tensor(values) + + if len(values.shape) == 1: + values = jnp.expand_dims(values, axis=0) + + if descending: + values = -values + + if k >= values.shape[1]: + idx = values.argsort(axis=1)[:, :k] + values = jnp.take_along_axis(values, idx, axis=1) + else: + idx_ps = values.argpartition(kth=k, axis=1)[:, :k] + values = jnp.take_along_axis(values, idx_ps, axis=1) + idx_fs = values.argsort(axis=1) + idx = jnp.take_along_axis(idx_ps, idx_fs, axis=1) + values = jnp.take_along_axis(values, idx_fs, axis=1) + + if descending: + values = -values + + return comp_be._cast_output(values), comp_be._cast_output(idx) class Metrics(AbstractComputationalBackend.Metrics[jax.numpy.array]): """ @@ -143,11 +186,11 @@ class Metrics(AbstractComputationalBackend.Metrics[jax.numpy.array]): @staticmethod def cosine_sim( - x_mat: jax.numpy.array, - y_mat: jax.numpy.array, + x_mat: 'JaxArray', + y_mat: 'JaxArray', eps: float = 1e-7, device: Optional[str] = None, - ) -> jax.numpy.array: + ) -> 'JaxArray': """Pairwise cosine similarities between all vectors in x_mat and y_mat. :param x_mat: jax.numpy.array of shape (n_vectors, n_dim), where n_vectors is @@ -168,10 +211,10 @@ def cosine_sim( @classmethod def euclidean_dist( cls, - x_mat: jax.numpy.array, - y_mat: jax.numpy.array, + x_mat: 'JaxArray', + y_mat: 'JaxArray', device: Optional[str] = None, - ) -> jax.numpy.array: + ) -> 'JaxArray': """Pairwise Euclidian distances between all vectors in x_mat and y_mat. :param x_mat: jax.numpy.array of shape (n_vectors, n_dim), where n_vectors is @@ -191,10 +234,10 @@ def euclidean_dist( @staticmethod def sqeuclidean_dist( - x_mat: jax.numpy.array, - y_mat: jax.numpy.array, + x_mat: 'JaxArray', + y_mat: 'JaxArray', device: Optional[str] = None, - ) -> jax.numpy.array: + ) -> 'JaxArray': """Pairwise Squared Euclidian distances between all vectors in x_mat and y_mat. diff --git a/tests/units/computation_backends/jax_backend/test_basics.py b/tests/units/computation_backends/jax_backend/test_basics.py index e03efcd7a5e..6cd64a19602 100644 --- a/tests/units/computation_backends/jax_backend/test_basics.py +++ b/tests/units/computation_backends/jax_backend/test_basics.py @@ -1,10 +1,12 @@ +import jax import jax.numpy as jnp -import numpy as np import pytest from docarray.computation.jax_backend import JaxCompBackend from docarray.typing import JaxArray +jax.config.update("jax_enable_x64", True) + @pytest.mark.tensorflow @pytest.mark.parametrize( @@ -39,9 +41,9 @@ def test_shape(shape, result): @pytest.mark.tensorflow def test_to_device(): - array = JaxArray(jnp.constant([1, 2, 3])) - array = JaxCompBackend.to_device(array, 'CPU:0') - assert array.tensor.device.endswith('CPU:0') + array = JaxArray(jnp.zeros((3))) + array = JaxCompBackend.to_device(array, 'cpu') + assert array.tensor.device().platform.endswith('cpu') @pytest.mark.tensorflow @@ -55,7 +57,7 @@ def test_to_device(): ], ) def test_dtype(dtype, result_type): - array = JaxArray(jnp.constant([1, 2, 3], dtype=getattr(jnp, dtype))) + array = JaxArray(jnp.array([1, 2, 3], dtype=dtype)) assert JaxCompBackend.dtype(array) == result_type @@ -74,9 +76,9 @@ def test_empty_dtype(): @pytest.mark.tensorflow def test_empty_device(): - tensor = JaxCompBackend.empty((10, 3), device='CPU:0') + tensor = JaxCompBackend.empty((10, 3), device='cpu') assert tensor.tensor.shape == (10, 3) - assert tensor.tensor.device.endswith('CPU:0') + assert tensor.tensor.device().platform.endswith('cpu') @pytest.mark.tensorflow @@ -111,11 +113,11 @@ def test_squeeze(): ], ) def test_minmax_normalize(data_input, t_range, x_range, data_result): - array = JaxArray(jnp.constant(data_input)) + array = JaxArray(jnp.array(data_input)) output = JaxCompBackend.minmax_normalize( tensor=array, t_range=t_range, x_range=x_range ) - assert np.allclose(output.tensor, jnp.constant(data_result)) + assert jnp.allclose(output.tensor, jnp.array(data_result)) @pytest.mark.tensorflow diff --git a/tests/units/computation_backends/jax_backend/test_retrieval.py b/tests/units/computation_backends/jax_backend/test_retrieval.py index e69de29bb2d..a1bb686083e 100644 --- a/tests/units/computation_backends/jax_backend/test_retrieval.py +++ b/tests/units/computation_backends/jax_backend/test_retrieval.py @@ -0,0 +1,57 @@ +import jax.numpy as jnp +import pytest + +from docarray.computation.jax_backend import JaxCompBackend +from docarray.typing import JaxArray + + +@pytest.mark.tensorflow +def test_top_k_descending_false(): + top_k = JaxCompBackend.Retrieval.top_k + + a = JaxArray(jnp.array([1, 4, 2, 7, 4, 9, 2])) + vals, indices = top_k(a, 3, descending=False) + + assert vals.tensor.shape == (1, 3) + assert indices.tensor.shape == (1, 3) + assert jnp.allclose(jnp.squeeze(vals.tensor), jnp.array([1, 2, 2])) + assert jnp.allclose(jnp.squeeze(indices.tensor), jnp.array([0, 2, 6])) or ( + jnp.allclose(jnp.squeeze.indices.tensor), + jnp.array([0, 6, 2]), + ) + + a = JaxArray(jnp.array([[1, 4, 2, 7, 4, 9, 2], [11, 6, 2, 7, 3, 10, 4]])) + vals, indices = top_k(a, 3, descending=False) + assert vals.tensor.shape == (2, 3) + assert indices.tensor.shape == (2, 3) + assert jnp.allclose(vals.tensor[0], jnp.array([1, 2, 2])) + assert jnp.allclose(indices.tensor[0], jnp.array([0, 2, 6])) or jnp.allclose( + indices.tensor[0], jnp.array([0, 6, 2]) + ) + assert jnp.allclose(vals.tensor[1], jnp.array([2, 3, 4])) + assert jnp.allclose(indices.tensor[1], jnp.array([2, 4, 6])) + + +@pytest.mark.tensorflow +def test_top_k_descending_true(): + top_k = JaxCompBackend.Retrieval.top_k + + a = JaxArray(jnp.array([1, 4, 2, 7, 4, 9, 2])) + vals, indices = top_k(a, 3, descending=True) + + assert vals.tensor.shape == (1, 3) + assert indices.tensor.shape == (1, 3) + assert jnp.allclose(jnp.squeeze(vals.tensor), jnp.array([9, 7, 4])) + assert jnp.allclose(jnp.squeeze(indices.tensor), jnp.array([5, 3, 1])) + + a = JaxArray(jnp.array([[1, 4, 2, 7, 4, 9, 2], [11, 6, 2, 7, 3, 10, 4]])) + vals, indices = top_k(a, 3, descending=True) + + assert vals.tensor.shape == (2, 3) + assert indices.tensor.shape == (2, 3) + + assert jnp.allclose(vals.tensor[0], jnp.array([9, 7, 4])) + assert jnp.allclose(indices.tensor[0], jnp.array([5, 3, 1])) + + assert jnp.allclose(vals.tensor[1], jnp.array([11, 10, 7])) + assert jnp.allclose(indices.tensor[1], jnp.array([0, 5, 3])) From 23589d6aec79abad11cb5fd4d62926b32d643ba1 Mon Sep 17 00:00:00 2001 From: agaraman0 Date: Tue, 20 Jun 2023 11:15:39 +0530 Subject: [PATCH 08/25] fix: isort format fix Signed-off-by: agaraman0 --- docarray/computation/jax_backend.py | 156 ++++++++++++++---- .../jax_backend/test_metrics.py | 69 ++++++++ 2 files changed, 194 insertions(+), 31 deletions(-) diff --git a/docarray/computation/jax_backend.py b/docarray/computation/jax_backend.py index 1ad1397e329..d08a7bbd766 100644 --- a/docarray/computation/jax_backend.py +++ b/docarray/computation/jax_backend.py @@ -11,8 +11,6 @@ if TYPE_CHECKING: pass -jax.config.update("jax_enable_x64", True) - def _unsqueeze_if_single_axis(*matrices) -> List[jnp.ndarray]: """Unsqueezes tensors that only have one axis, at dim 0. @@ -22,11 +20,52 @@ def _unsqueeze_if_single_axis(*matrices) -> List[jnp.ndarray]: :return: List of the input matrices, where single axis matrices are unsqueezed at dim 0. """ - pass + unsqueezed = [] + for m in matrices: + if len(m.shape) == 1: + unsqueezed.append(jnp.expand_dims(m, axis=0)) + else: + unsqueezed.append(m) + return unsqueezed def _unsqueeze_if_scalar(t): - pass + """ + Unsqueezes tensor of a scalar, from shape () to shape (1,). + + :param t: tensor to unsqueeze. + :return: unsqueezed tf.Tensor + """ + if len(t.shape) == 0: # avoid scalar output + t = jnp.expand_dims(t, 0) + return t + + +def _expand_if_single_axis(*matrices: jnp.ndarray) -> List[jnp.ndarray]: + """Expands arrays that only have one axis, at dim 0. + This ensures that all outputs can be treated as matrices, not vectors. + + :param matrices: Matrices to be expanded + :return: List of the input matrices, + where single axis matrices are expanded at dim 0. + """ + expanded = [] + for m in matrices: + if len(m.shape) == 1: + expanded.append(jnp.expand_dims(m, axis=0)) + else: + expanded.append(m) + return expanded + + +def _expand_if_scalar(arr: jnp.ndarray) -> jnp.ndarray: + if len(arr.shape) == 0: # avoid scalar output + arr = jnp.expand_dims(arr, axis=0) + return arr + + +def identity(array: jnp.ndarray) -> jnp.ndarray: + return array def norm_left(t: jnp.ndarray) -> JaxArray: @@ -179,7 +218,7 @@ def top_k( return comp_be._cast_output(values), comp_be._cast_output(idx) - class Metrics(AbstractComputationalBackend.Metrics[jax.numpy.array]): + class Metrics(AbstractComputationalBackend.Metrics[jnp.ndarray]): """ Abstract base class for metrics (distances and similarities). """ @@ -193,63 +232,118 @@ def cosine_sim( ) -> 'JaxArray': """Pairwise cosine similarities between all vectors in x_mat and y_mat. - :param x_mat: jax.numpy.array of shape (n_vectors, n_dim), where n_vectors is - the number of vectors and n_dim is the number of dimensions of each - example. - :param y_mat: jax.numpy.array of shape (n_vectors, n_dim), where n_vectors is - the number of vectors and n_dim is the number of dimensions of each - example. - :param eps: a small jitter to avoid divde by zero - :param device: Not supported for this backend - :return: jax.numpy.array of shape (n_vectors, n_vectors) containing all - pairwise cosine distances. + :param x_mat: tensor of shape (n_vectors, n_dim), where n_vectors is the + number of vectors and n_dim is the number of dimensions of each example. + :param y_mat: tensor of shape (n_vectors, n_dim), where n_vectors is the + number of vectors and n_dim is the number of dimensions of each example. + :param eps: a small jitter to avoid divide by zero + :param device: the device to use for computations. + If not provided, the devices of x_mat and y_mat are used. + :return: Tensor of shape (n_vectors, n_vectors) containing all pairwise + cosine distances. The index [i_x, i_y] contains the cosine distance between x_mat[i_x] and y_mat[i_y]. """ - pass + comp_be = JaxCompBackend + x_mat_jax: jnp.ndarray = comp_be._get_tensor(x_mat) + y_mat_jax: jnp.ndarray = comp_be._get_tensor(y_mat) + + x_mat_jax, y_mat_jax = _unsqueeze_if_single_axis(x_mat_jax, y_mat_jax) + + sims = jnp.clip( + (jnp.dot(x_mat_jax, y_mat_jax.T) + eps) + / ( + jnp.outer( + jnp.linalg.norm(x_mat_jax, axis=1), + jnp.linalg.norm(y_mat_jax, axis=1), + ) + + eps + ), + -1, + 1, + ).squeeze() + sims = _unsqueeze_if_scalar(sims) + + return comp_be._cast_output(sims) @classmethod def euclidean_dist( - cls, - x_mat: 'JaxArray', - y_mat: 'JaxArray', - device: Optional[str] = None, - ) -> 'JaxArray': + cls, x_mat: jnp.ndarray, y_mat: jnp.ndarray, device: Optional[str] = None + ) -> JaxArray: """Pairwise Euclidian distances between all vectors in x_mat and y_mat. - :param x_mat: jax.numpy.array of shape (n_vectors, n_dim), where n_vectors is + :param x_mat: np.ndarray of shape (n_vectors, n_dim), where n_vectors is the number of vectors and n_dim is the number of dimensions of each example. - :param y_mat: jax.numpy.array of shape (n_vectors, n_dim), where n_vectors is + :param y_mat: np.ndarray of shape (n_vectors, n_dim), where n_vectors is the number of vectors and n_dim is the number of dimensions of each example. :param eps: a small jitter to avoid divde by zero :param device: Not supported for this backend - :return: jax.numpy.array of shape (n_vectors, n_vectors) containing all + :return: np.ndarray of shape (n_vectors, n_vectors) containing all pairwise euclidian distances. The index [i_x, i_y] contains the euclidian distance between x_mat[i_x] and y_mat[i_y]. """ - pass + comp_be = JaxCompBackend + x_mat: jnp.ndarray = comp_be._get_tensor(x_mat) + y_mat: jnp.ndarray = comp_be._get_tensor(y_mat) + if device is not None: + # warnings.warn('`device` is not supported for numpy operations') + pass + + x_mat, y_mat = _expand_if_single_axis(x_mat, y_mat) + + x_mat = comp_be._cast_output(x_mat) + y_mat = comp_be._cast_output(y_mat) + + dists = _expand_if_scalar( + jnp.sqrt( + comp_be._get_tensor(cls.sqeuclidean_dist(x_mat, y_mat)) + ).squeeze() + ) + + return comp_be._cast_output(dists) @staticmethod def sqeuclidean_dist( - x_mat: 'JaxArray', - y_mat: 'JaxArray', + x_mat: jnp.ndarray, + y_mat: jnp.ndarray, device: Optional[str] = None, - ) -> 'JaxArray': + ) -> JaxArray: """Pairwise Squared Euclidian distances between all vectors in x_mat and y_mat. - :param x_mat: jax.numpy.array of shape (n_vectors, n_dim), where n_vectors is + :param x_mat: np.ndarray of shape (n_vectors, n_dim), where n_vectors is the number of vectors and n_dim is the number of dimensions of each example. - :param y_mat: jax.numpy.array of shape (n_vectors, n_dim), where n_vectors is + :param y_mat: np.ndarray of shape (n_vectors, n_dim), where n_vectors is the number of vectors and n_dim is the number of dimensions of each example. :param device: Not supported for this backend - :return: jax.numpy.array of shape (n_vectors, n_vectors) containing all + :return: np.ndarray of shape (n_vectors, n_vectors) containing all pairwise Squared Euclidian distances. The index [i_x, i_y] contains the cosine Squared Euclidian between x_mat[i_x] and y_mat[i_y]. """ + comp_be = JaxCompBackend + x_mat_jax: jnp.ndarray = comp_be._get_tensor(x_mat) + y_mat_jax: jnp.ndarray = comp_be._get_tensor(y_mat) + eps: float = 1e-7 # avoid problems with numerical inaccuracies + + if device is not None: + pass + # warnings.warn('`device` is not supported for numpy operations') + + x_mat_jax, y_mat_jax = _expand_if_single_axis(x_mat_jax, y_mat_jax) + + dists = ( + jnp.sum(y_mat_jax**2, axis=1) + + jnp.sum(x_mat_jax**2, axis=1)[:, jnp.newaxis] + - 2 * jnp.dot(x_mat_jax, y_mat_jax.T) + ).squeeze() + + # remove numerical artifacts + dists = jnp.where(np.logical_and(dists < 0, dists > -eps), 0, dists) + dists = _expand_if_scalar(dists) + return comp_be._cast_output(dists) diff --git a/tests/units/computation_backends/jax_backend/test_metrics.py b/tests/units/computation_backends/jax_backend/test_metrics.py index e69de29bb2d..b3134a6096f 100644 --- a/tests/units/computation_backends/jax_backend/test_metrics.py +++ b/tests/units/computation_backends/jax_backend/test_metrics.py @@ -0,0 +1,69 @@ +import jax +import jax.numpy as jnp + +from docarray.computation.jax_backend import JaxCompBackend +from docarray.typing import JaxArray + +metrics = JaxCompBackend.Metrics + + +def test_cosine_sim_jax(): + a = JaxArray(jax.random.uniform(jax.random.PRNGKey(0), shape=(128,))) + b = JaxArray(jax.random.uniform(jax.random.PRNGKey(1), shape=(128,))) + assert metrics.cosine_sim(a, b).tensor.shape == (1,) + assert metrics.cosine_sim(a, b).tensor == metrics.cosine_sim(b, a).tensor + + assert jnp.allclose(metrics.cosine_sim(a, a).tensor, jnp.ones((1,))) + + a = JaxArray(jax.random.uniform(jax.random.PRNGKey(2), shape=(10, 3))) + b = JaxArray(jax.random.uniform(jax.random.PRNGKey(3), shape=(5, 3))) + assert metrics.cosine_sim(a, b).tensor.shape == (10, 5) + assert metrics.cosine_sim(b, a).tensor.shape == (5, 10) + diag_dists = jnp.diagonal(metrics.cosine_sim(b, b).tensor) # self-comparisons + assert jnp.allclose(diag_dists, jnp.ones((5,))) + + +def test_euclidean_dist_jax(): + a = JaxArray(jax.random.normal(jax.random.PRNGKey(0), shape=(128,))) + b = JaxArray(jax.random.normal(jax.random.PRNGKey(1), shape=(128,))) + assert metrics.euclidean_dist(a, b).tensor.shape == (1,) + assert jnp.allclose( + metrics.euclidean_dist(a, b).tensor, metrics.euclidean_dist(b, a).tensor + ) + + assert jnp.allclose(metrics.euclidean_dist(a, a).tensor, jnp.zeros((1,))) + + a = JaxArray(jnp.zeros((1, 1))) + b = JaxArray(jnp.ones((4, 1))) + assert metrics.euclidean_dist(a, b).tensor.shape == (4,) + assert jnp.allclose( + metrics.euclidean_dist(a, b).tensor, metrics.euclidean_dist(b, a).tensor + ) + assert jnp.allclose(metrics.euclidean_dist(a, a).tensor, jnp.zeros((1,))) + + a = JaxArray(jnp.array([0.0, 2.0, 0.0])) + b = JaxArray(jnp.array([0.0, 0.0, 2.0])) + desired_output_singleton = jnp.sqrt(jnp.array([2.0**2.0 + 2.0**2.0])) + assert jnp.allclose(metrics.euclidean_dist(a, b).tensor, desired_output_singleton) + + a = JaxArray(jnp.array([[0.0, 2.0, 0.0], [0.0, 0.0, 2.0]])) + b = JaxArray(jnp.array([[0.0, 0.0, 2.0], [0.0, 2.0, 0.0]])) + desired_output_singleton = jnp.array([[2.828427, 0.0], [0.0, 2.828427]]) + + assert jnp.allclose(metrics.euclidean_dist(a, b).tensor, desired_output_singleton) + + +def test_sqeuclidea_dist_jnp(): + a = JaxArray(jax.random.uniform(jax.random.PRNGKey(0), shape=(128,))) + b = JaxArray(jax.random.uniform(jax.random.PRNGKey(1), shape=(128,))) + assert metrics.sqeuclidean_dist(a, b).tensor.shape == (1,) + assert jnp.allclose( + metrics.sqeuclidean_dist(a, b).tensor, metrics.euclidean_dist(a, b).tensor ** 2 + ) + + a = JaxArray(jax.random.uniform(jax.random.PRNGKey(2), shape=(10, 3))) + b = JaxArray(jax.random.uniform(jax.random.PRNGKey(3), shape=(5, 3))) + assert metrics.sqeuclidean_dist(a, b).tensor.shape == (10, 5) + assert jnp.allclose( + metrics.sqeuclidean_dist(a, b).tensor, metrics.euclidean_dist(a, b).tensor ** 2 + ) From 8793822ddd92814ed735ad984e6b5b9bb5b59b0a Mon Sep 17 00:00:00 2001 From: agaraman0 Date: Tue, 20 Jun 2023 11:18:05 +0530 Subject: [PATCH 09/25] feat: Jax Added as dependency Signed-off-by: agaraman0 --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 02fc1d3b96e..eba967bf112 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ jina-hubble-sdk = {version = ">=0.34.0", optional = true} elastic-transport = {version ="^8.4.0", optional = true } qdrant-client = {version = ">=1.1.4", python = "<3.12", optional = true } redis = {version = "^4.6.0", optional = true} +jax = {version = ">=0.4.10", optional = true} [tool.poetry.extras] proto = ["protobuf", "lz4"] From aff2ff68b1f489899cc48e1ba9d09613634bbd2b Mon Sep 17 00:00:00 2001 From: agaraman0 Date: Tue, 20 Jun 2023 15:59:00 +0530 Subject: [PATCH 10/25] feat: poetry lock added Signed-off-by: agaraman0 --- docarray/typing/tensor/jaxarray.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docarray/typing/tensor/jaxarray.py b/docarray/typing/tensor/jaxarray.py index 49e313c0d9b..59deb384615 100644 --- a/docarray/typing/tensor/jaxarray.py +++ b/docarray/typing/tensor/jaxarray.py @@ -197,3 +197,11 @@ def get_comp_backend() -> 'JaxCompBackend': def __class_getitem__(cls, item: Any, *args, **kwargs): # see here for mypy bug: https://github.com/python/mypy/issues/14123 return AbstractTensor.__class_getitem__.__func__(cls, item) # type: ignore + + @classmethod + def _docarray_from_ndarray(cls: Type[T], value: np.ndarray) -> T: + return cls.from_ndarray(value) + + def _docarray_to_ndarray(self) -> np.ndarray: + """cast itself to a numpy array""" + return self.tensor.__array__() From 3d6f45f76fb4231b74a481d33c77a08db6e30a0b Mon Sep 17 00:00:00 2001 From: agaraman0 Date: Fri, 23 Jun 2023 17:10:22 +0530 Subject: [PATCH 11/25] fix: jax_comp review comments resolved Signed-off-by: agaraman0 --- docarray/computation/jax_backend.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/docarray/computation/jax_backend.py b/docarray/computation/jax_backend.py index d08a7bbd766..4d09a14686a 100644 --- a/docarray/computation/jax_backend.py +++ b/docarray/computation/jax_backend.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple +from typing import Any, Callable, List, Optional, Tuple import jax import jax.numpy as jnp @@ -8,9 +8,6 @@ from docarray.computation.abstract_numpy_based_backend import AbstractNumpyBasedBackend from docarray.typing import JaxArray -if TYPE_CHECKING: - pass - def _unsqueeze_if_single_axis(*matrices) -> List[jnp.ndarray]: """Unsqueezes tensors that only have one axis, at dim 0. @@ -64,10 +61,6 @@ def _expand_if_scalar(arr: jnp.ndarray) -> jnp.ndarray: return arr -def identity(array: jnp.ndarray) -> jnp.ndarray: - return array - - def norm_left(t: jnp.ndarray) -> JaxArray: return JaxArray(tensor=t) From 3072cea18a1aa3f5e8f05c46aed132ba6949cef4 Mon Sep 17 00:00:00 2001 From: agaraman0 Date: Tue, 11 Jul 2023 17:23:37 +0530 Subject: [PATCH 12/25] fix: squashed commits and bypassing jax test cases which fails Signed-off-by: agaraman0 --- .github/workflows/ci.yml | 2 +- docarray/array/doc_vec/doc_vec.py | 28 +- docarray/computation/jax_backend.py | 139 ++++---- docarray/typing/__init__.py | 23 +- docarray/typing/tensor/__init__.py | 19 +- docarray/typing/tensor/audio/__init__.py | 6 +- .../typing/tensor/audio/audio_jax_array.py | 12 + docarray/typing/tensor/audio/audio_tensor.py | 17 +- docarray/typing/tensor/embedding/__init__.py | 4 + docarray/typing/tensor/embedding/embedding.py | 18 +- docarray/typing/tensor/embedding/jax_array.py | 17 + docarray/typing/tensor/image/__init__.py | 4 + .../typing/tensor/image/image_jax_array.py | 10 + docarray/typing/tensor/image/image_tensor.py | 18 +- docarray/typing/tensor/jaxarray.py | 70 +++- docarray/typing/tensor/ndarray.py | 14 +- docarray/typing/tensor/tensor.py | 31 +- docarray/typing/tensor/tensorflow_tensor.py | 12 +- docarray/typing/tensor/torch_tensor.py | 12 +- docarray/typing/tensor/video/__init__.py | 4 + .../typing/tensor/video/video_jax_array.py | 28 ++ docarray/typing/tensor/video/video_tensor.py | 18 +- docarray/utils/_internal/misc.py | 11 + pyproject.toml | 3 +- .../array/stack/test_array_stacked_jax.py | 298 ++++++++++++++++++ .../jax_backend/test_basics.py | 37 ++- .../jax_backend/test_metrics.py | 21 +- .../jax_backend/test_retrieval.py | 19 +- tests/units/typing/tensor/test_jax_array.py | 23 +- 29 files changed, 786 insertions(+), 132 deletions(-) create mode 100644 tests/units/array/stack/test_array_stacked_jax.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 210134ac4ae..3827cf3b958 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -111,7 +111,7 @@ jobs: - name: Test id: test run: | - poetry run pytest -m "not (tensorflow or benchmark or index)" --cov=docarray --cov-report=xml ${{ matrix.test-path }} --ignore=tests/integrations/store/test_jac.py + poetry run pytest -m "not (tensorflow or benchmark or index or jax)" --cov=docarray --cov-report=xml ${{ matrix.test-path }} --ignore=tests/integrations/store/test_jac.py echo "flag it as docarray for codeoverage" echo "codecov_flag=docarray" >> $GITHUB_OUTPUT timeout-minutes: 30 diff --git a/docarray/array/doc_vec/doc_vec.py b/docarray/array/doc_vec/doc_vec.py index 4778cd44604..a175cb4e4aa 100644 --- a/docarray/array/doc_vec/doc_vec.py +++ b/docarray/array/doc_vec/doc_vec.py @@ -32,7 +32,11 @@ from docarray.typing import NdArray from docarray.typing.tensor.abstract_tensor import AbstractTensor from docarray.utils._internal._typing import is_tensor_union, safe_issubclass -from docarray.utils._internal.misc import is_tf_available, is_torch_available +from docarray.utils._internal.misc import ( + is_jax_available, + is_tf_available, + is_torch_available, +) if TYPE_CHECKING: import csv @@ -60,6 +64,14 @@ else: TensorFlowTensor = None # type: ignore +jnp_available = is_jax_available() +if jnp_available: + import jax.numpy as jnp # type: ignore + + from docarray.typing import JaxArray # noqa: F401 +else: + JaxArray = None # type: ignore + T_doc = TypeVar('T_doc', bound=BaseDoc) T = TypeVar('T', bound='DocVec') T_io_mixin = TypeVar('T_io_mixin', bound='IOMixinArray') @@ -262,6 +274,19 @@ def _check_doc_field_not_none(field_name, doc): stacked: tf.Tensor = tf.stack(tf_stack) tensor_columns[field_name] = TensorFlowTensor(stacked) + elif jnp_available and issubclass(field_type, JaxArray): + if first_doc_is_none: + _verify_optional_field_of_docs(docs) + tensor_columns[field_name] = None + else: + tf_stack = [] + for i, doc in enumerate(docs): + val = getattr(doc, field_name) + _check_doc_field_not_none(field_name, doc) + tf_stack.append(val.tensor) + + jax_stacked: jnp.ndarray = jnp.stack(tf_stack) + tensor_columns[field_name] = JaxArray(jax_stacked) elif safe_issubclass(field_type, AbstractTensor): if first_doc_is_none: @@ -835,7 +860,6 @@ def to_doc_list(self: T) -> DocList[T_doc]: unstacked_doc_column[field] = doc_col.to_doc_list() if doc_col else None for field, da_col in self._storage.docs_vec_columns.items(): - unstacked_da_column[field] = ( [docs.to_doc_list() for docs in da_col] if da_col else None ) diff --git a/docarray/computation/jax_backend.py b/docarray/computation/jax_backend.py index 4d09a14686a..680f2b90d9c 100644 --- a/docarray/computation/jax_backend.py +++ b/docarray/computation/jax_backend.py @@ -1,41 +1,18 @@ -from typing import Any, Callable, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple -import jax -import jax.numpy as jnp import numpy as np from docarray.computation.abstract_comp_backend import AbstractComputationalBackend from docarray.computation.abstract_numpy_based_backend import AbstractNumpyBasedBackend from docarray.typing import JaxArray +from docarray.utils._internal.misc import import_library - -def _unsqueeze_if_single_axis(*matrices) -> List[jnp.ndarray]: - """Unsqueezes tensors that only have one axis, at dim 0. - This ensures that all outputs can be treated as matrices, not vectors. - - :param matrices: Matrices to be unsqueezed - :return: List of the input matrices, - where single axis matrices are unsqueezed at dim 0. - """ - unsqueezed = [] - for m in matrices: - if len(m.shape) == 1: - unsqueezed.append(jnp.expand_dims(m, axis=0)) - else: - unsqueezed.append(m) - return unsqueezed - - -def _unsqueeze_if_scalar(t): - """ - Unsqueezes tensor of a scalar, from shape () to shape (1,). - - :param t: tensor to unsqueeze. - :return: unsqueezed tf.Tensor - """ - if len(t.shape) == 0: # avoid scalar output - t = jnp.expand_dims(t, 0) - return t +if TYPE_CHECKING: + import jax + import jax.numpy as jnp +else: + jax = import_library('jax', raise_error=True) + jnp = jax.numpy def _expand_if_single_axis(*matrices: jnp.ndarray) -> List[jnp.ndarray]: @@ -71,7 +48,7 @@ def norm_right(t: JaxArray) -> jnp.ndarray: class JaxCompBackend(AbstractNumpyBasedBackend): """ - Computational backend for Numpy. + Computational backend for Jax. """ _module = jnp @@ -95,16 +72,16 @@ def device(cls, tensor: 'JaxArray') -> Optional[str]: return cls._get_tensor(tensor).device().platform @classmethod - def to_numpy(cls, array: 'jax.numpy.array') -> 'np.ndarray': - return np.array(cls._get_tensor(array)) + def to_numpy(cls, array: 'JaxArray') -> 'np.ndarray': + return cls._get_tensor(array).__array__() @classmethod def none_value(cls) -> Any: - """Provide a compatible value that represents None in numpy.""" + """Provide a compatible value that represents None in jax.""" return jnp.nan @classmethod - def detach(cls, tensor: 'jax.numpy.array') -> 'jax.numpy.array': + def detach(cls, tensor: 'JaxArray') -> 'JaxArray': """ Returns the tensor detached from its current graph. @@ -114,7 +91,7 @@ def detach(cls, tensor: 'jax.numpy.array') -> 'jax.numpy.array': return cls._cast_output(jax.lax.stop_gradient(cls._get_tensor(tensor))) @classmethod - def dtype(cls, tensor: 'JaxArray') -> np.dtype: + def dtype(cls, tensor: 'JaxArray') -> jnp.dtype: """Get the data type of the tensor.""" d_type = cls._get_tensor(tensor).dtype return d_type.name @@ -126,7 +103,7 @@ def minmax_normalize( t_range: Tuple = (0, 1), x_range: Optional[Tuple] = None, eps: float = 1e-7, - ) -> 'jax.numpy.array': + ) -> 'JaxArray': """ Normalize values in `tensor` into `t_range`. @@ -156,7 +133,23 @@ def minmax_normalize( normalized = jnp.clip(r, *((a, b) if a < b else (b, a))) return cls._cast_output(jnp.asarray(normalized, cls._get_tensor(tensor).dtype)) - class Retrieval(AbstractComputationalBackend.Retrieval[jax.numpy.array]): + @classmethod + def equal(cls, tensor1: 'JaxArray', tensor2: 'JaxArray') -> bool: + """ + Check if two tensors are equal. + + :param tensor1: the first tensor + :param tensor2: the second tensor + :return: True if two tensors are equal, False otherwise. + If one or more of the inputs is not a TensorFlowTensor, return False. + """ + t1, t2 = getattr(tensor1, 'tensor', None), getattr(tensor2, 'tensor', None) + if isinstance(t1, jnp.ndarray) and isinstance(t2, jnp.ndarray): + # mypy doesn't know that tf.is_tensor implies that t1, t2 are not None + return t1.shape == t2.shape and jnp.all(jnp.equal(t1, t1)) # type: ignore + return False + + class Retrieval(AbstractComputationalBackend.Retrieval[JaxArray]): """ Abstract class for retrieval and ranking functionalities """ @@ -174,7 +167,7 @@ def top_k( Can also be used to retrieve the top k largest values, by setting the `descending` flag. - :param values: Torch tensor of values to rank. + :param values: Jax tensor of values to rank. Should be of shape (n_queries, n_values_per_query). Inputs of shape (n_values_per_query,) will be expanded to (1, n_values_per_query). @@ -188,30 +181,30 @@ def top_k( if device is not None: values = comp_be.to_device(values, device) - values: jnp.ndarray = comp_be._get_tensor(values) + jax_values: jnp.ndarray = comp_be._get_tensor(values) - if len(values.shape) == 1: - values = jnp.expand_dims(values, axis=0) + if len(jax_values.shape) == 1: + jax_values = jnp.expand_dims(jax_values, axis=0) if descending: - values = -values + jax_values = -jax_values - if k >= values.shape[1]: - idx = values.argsort(axis=1)[:, :k] - values = jnp.take_along_axis(values, idx, axis=1) + if k >= jax_values.shape[1]: + idx = jax_values.argsort(axis=1)[:, :k] + jax_values = jnp.take_along_axis(jax_values, idx, axis=1) else: - idx_ps = values.argpartition(kth=k, axis=1)[:, :k] - values = jnp.take_along_axis(values, idx_ps, axis=1) - idx_fs = values.argsort(axis=1) + idx_ps = jax_values.argpartition(kth=k, axis=1)[:, :k] + jax_values = jnp.take_along_axis(jax_values, idx_ps, axis=1) + idx_fs = jax_values.argsort(axis=1) idx = jnp.take_along_axis(idx_ps, idx_fs, axis=1) - values = jnp.take_along_axis(values, idx_fs, axis=1) + jax_values = jnp.take_along_axis(jax_values, idx_fs, axis=1) if descending: - values = -values + jax_values = -jax_values - return comp_be._cast_output(values), comp_be._cast_output(idx) + return comp_be._cast_output(jax_values), comp_be._cast_output(idx) - class Metrics(AbstractComputationalBackend.Metrics[jnp.ndarray]): + class Metrics(AbstractComputationalBackend.Metrics[JaxArray]): """ Abstract base class for metrics (distances and similarities). """ @@ -232,7 +225,7 @@ def cosine_sim( :param eps: a small jitter to avoid divide by zero :param device: the device to use for computations. If not provided, the devices of x_mat and y_mat are used. - :return: Tensor of shape (n_vectors, n_vectors) containing all pairwise + :return: JaxArray of shape (n_vectors, n_vectors) containing all pairwise cosine distances. The index [i_x, i_y] contains the cosine distance between x_mat[i_x] and y_mat[i_y]. @@ -241,7 +234,7 @@ def cosine_sim( x_mat_jax: jnp.ndarray = comp_be._get_tensor(x_mat) y_mat_jax: jnp.ndarray = comp_be._get_tensor(y_mat) - x_mat_jax, y_mat_jax = _unsqueeze_if_single_axis(x_mat_jax, y_mat_jax) + x_mat_jax, y_mat_jax = _expand_if_single_axis(x_mat_jax, y_mat_jax) sims = jnp.clip( (jnp.dot(x_mat_jax, y_mat_jax.T) + eps) @@ -255,44 +248,46 @@ def cosine_sim( -1, 1, ).squeeze() - sims = _unsqueeze_if_scalar(sims) + sims = _expand_if_scalar(sims) return comp_be._cast_output(sims) @classmethod def euclidean_dist( - cls, x_mat: jnp.ndarray, y_mat: jnp.ndarray, device: Optional[str] = None + cls, x_mat: JaxArray, y_mat: JaxArray, device: Optional[str] = None ) -> JaxArray: """Pairwise Euclidian distances between all vectors in x_mat and y_mat. - :param x_mat: np.ndarray of shape (n_vectors, n_dim), where n_vectors is + :param x_mat: jnp.ndarray of shape (n_vectors, n_dim), where n_vectors is the number of vectors and n_dim is the number of dimensions of each example. - :param y_mat: np.ndarray of shape (n_vectors, n_dim), where n_vectors is + :param y_mat: jnp.ndarray of shape (n_vectors, n_dim), where n_vectors is the number of vectors and n_dim is the number of dimensions of each example. :param eps: a small jitter to avoid divde by zero :param device: Not supported for this backend - :return: np.ndarray of shape (n_vectors, n_vectors) containing all + :return: JaxArray of shape (n_vectors, n_vectors) containing all pairwise euclidian distances. The index [i_x, i_y] contains the euclidian distance between x_mat[i_x] and y_mat[i_y]. """ comp_be = JaxCompBackend - x_mat: jnp.ndarray = comp_be._get_tensor(x_mat) - y_mat: jnp.ndarray = comp_be._get_tensor(y_mat) + x_mat_jax: jnp.ndarray = comp_be._get_tensor(x_mat) + y_mat_jax: jnp.ndarray = comp_be._get_tensor(y_mat) if device is not None: # warnings.warn('`device` is not supported for numpy operations') pass - x_mat, y_mat = _expand_if_single_axis(x_mat, y_mat) + x_mat_jax, y_mat_jax = _expand_if_single_axis(x_mat_jax, y_mat_jax) - x_mat = comp_be._cast_output(x_mat) - y_mat = comp_be._cast_output(y_mat) + x_mat_jax_arr: JaxArray = comp_be._cast_output(x_mat_jax) + y_mat_jax_arr: JaxArray = comp_be._cast_output(y_mat_jax) dists = _expand_if_scalar( jnp.sqrt( - comp_be._get_tensor(cls.sqeuclidean_dist(x_mat, y_mat)) + comp_be._get_tensor( + cls.sqeuclidean_dist(x_mat_jax_arr, y_mat_jax_arr) + ) ).squeeze() ) @@ -300,21 +295,21 @@ def euclidean_dist( @staticmethod def sqeuclidean_dist( - x_mat: jnp.ndarray, - y_mat: jnp.ndarray, + x_mat: JaxArray, + y_mat: JaxArray, device: Optional[str] = None, ) -> JaxArray: """Pairwise Squared Euclidian distances between all vectors in x_mat and y_mat. - :param x_mat: np.ndarray of shape (n_vectors, n_dim), where n_vectors is + :param x_mat: jnp.ndarray of shape (n_vectors, n_dim), where n_vectors is the number of vectors and n_dim is the number of dimensions of each example. - :param y_mat: np.ndarray of shape (n_vectors, n_dim), where n_vectors is + :param y_mat: jnp.ndarray of shape (n_vectors, n_dim), where n_vectors is the number of vectors and n_dim is the number of dimensions of each example. :param device: Not supported for this backend - :return: np.ndarray of shape (n_vectors, n_vectors) containing all + :return: JaxArray of shape (n_vectors, n_vectors) containing all pairwise Squared Euclidian distances. The index [i_x, i_y] contains the cosine Squared Euclidian between x_mat[i_x] and y_mat[i_y]. diff --git a/docarray/typing/__init__.py b/docarray/typing/__init__.py index 1cd0133c2f8..ed7e1d7b9d2 100644 --- a/docarray/typing/__init__.py +++ b/docarray/typing/__init__.py @@ -5,7 +5,6 @@ from docarray.typing.tensor import ImageNdArray, ImageTensor from docarray.typing.tensor.audio import AudioNdArray, AudioTensor from docarray.typing.tensor.embedding.embedding import AnyEmbedding, NdArrayEmbedding -from docarray.typing.tensor.jaxarray import JaxArray from docarray.typing.tensor.ndarray import NdArray from docarray.typing.tensor.tensor import AnyTensor from docarray.typing.tensor.video import VideoNdArray, VideoTensor @@ -25,12 +24,20 @@ if TYPE_CHECKING: from docarray.typing.tensor import TensorFlowTensor # noqa: F401 - from docarray.typing.tensor import TorchEmbedding, TorchTensor # noqa: F401 + from docarray.typing.tensor import ( # noqa: F401 + JaxArray, + JaxArrayEmbedding, + TorchEmbedding, + TorchTensor, + ) + from docarray.typing.tensor.audio import AudioJaxArray # noqa: F401 from docarray.typing.tensor.audio import AudioTensorFlowTensor # noqa: F401 from docarray.typing.tensor.audio import AudioTorchTensor # noqa: F401 from docarray.typing.tensor.embedding import TensorFlowEmbedding # noqa: F401 + from docarray.typing.tensor.image import ImageJaxArray # noqa: F401 from docarray.typing.tensor.image import ImageTensorFlowTensor # noqa: F401 from docarray.typing.tensor.image import ImageTorchTensor # noqa: F401 + from docarray.typing.tensor.video import VideoJaxArray # noqa: F401 from docarray.typing.tensor.video import VideoTensorFlowTensor # noqa: F401 from docarray.typing.tensor.video import VideoTorchTensor # noqa: F401 @@ -57,7 +64,6 @@ 'ImageBytes', 'VideoBytes', 'AudioBytes', - 'JaxArray', ] @@ -75,6 +81,15 @@ 'AudioTensorFlowTensor', 'VideoTensorFlowTensor', ] + +_jax_tensors = [ + 'JaxArray', + 'JaxArrayEmbedding', + 'VideoJaxArray', + 'AudioJaxArray', + 'ImageJaxArray', +] + __all_test__ = __all__ + _torch_tensors @@ -83,6 +98,8 @@ def __getattr__(name: str): import_library('torch', raise_error=True) elif name in _tf_tensors: import_library('tensorflow', raise_error=True) + elif name in _jax_tensors: + import_library('jax', raise_error=True) else: raise ImportError( f'cannot import name \'{name}\' from \'{_get_path_from_docarray_root_level(__file__)}\'' diff --git a/docarray/typing/tensor/__init__.py b/docarray/typing/tensor/__init__.py index 8e8f6653bd6..2da7f5939ec 100644 --- a/docarray/typing/tensor/__init__.py +++ b/docarray/typing/tensor/__init__.py @@ -5,7 +5,6 @@ from docarray.typing.tensor.audio import AudioNdArray from docarray.typing.tensor.embedding import AnyEmbedding, NdArrayEmbedding from docarray.typing.tensor.image import ImageNdArray, ImageTensor -from docarray.typing.tensor.jaxarray import JaxArray from docarray.typing.tensor.ndarray import NdArray from docarray.typing.tensor.tensor import AnyTensor from docarray.typing.tensor.video import VideoNdArray @@ -15,14 +14,19 @@ ) if TYPE_CHECKING: + from docarray.typing.tensor.audio import AudioJaxArray # noqa: F401 from docarray.typing.tensor.audio import AudioTensorFlowTensor # noqa: F401 from docarray.typing.tensor.audio import AudioTorchTensor # noqa: F401 + from docarray.typing.tensor.embedding import JaxArrayEmbedding # noqa F401 from docarray.typing.tensor.embedding import TensorFlowEmbedding # noqa: F401 from docarray.typing.tensor.embedding import TorchEmbedding # noqa: F401 + from docarray.typing.tensor.image import ImageJaxArray # noqa: F401 from docarray.typing.tensor.image import ImageTensorFlowTensor # noqa: F401 from docarray.typing.tensor.image import ImageTorchTensor # noqa: F401 + from docarray.typing.tensor.jaxarray import JaxArray # noqa: F401 from docarray.typing.tensor.tensorflow_tensor import TensorFlowTensor # noqa: F401 from docarray.typing.tensor.torch_tensor import TorchTensor # noqa: F401 + from docarray.typing.tensor.video import VideoJaxArray # noqa: F401 from docarray.typing.tensor.video import VideoTensorFlowTensor # noqa: F401 from docarray.typing.tensor.video import VideoTorchTensor # noqa: F401 @@ -35,7 +39,6 @@ 'ImageTensor', 'AudioNdArray', 'VideoNdArray', - 'JaxArray', ] @@ -44,19 +47,23 @@ def __getattr__(name: str): import_library('torch', raise_error=True) elif 'TensorFlow' in name: import_library('tensorflow', raise_error=True) + elif 'Jax' in name: + import_library('jax', raise_error=True) lib: types.ModuleType if name == 'TorchTensor': import docarray.typing.tensor.torch_tensor as lib elif name == 'TensorFlowTensor': import docarray.typing.tensor.tensorflow_tensor as lib - elif name in ['TorchEmbedding', 'TensorFlowEmbedding']: + elif name == 'JaxArray': + import docarray.typing.tensor.jaxarray as lib + elif name in ['TorchEmbedding', 'TensorFlowEmbedding', 'JaxArrayEmbedding']: import docarray.typing.tensor.embedding as lib - elif name in ['ImageTorchTensor', 'ImageTensorFlowTensor']: + elif name in ['ImageTorchTensor', 'ImageTensorFlowTensor', 'ImageJaxArray']: import docarray.typing.tensor.image as lib - elif name in ['AudioTorchTensor', 'AudioTensorFlowTensor']: + elif name in ['AudioTorchTensor', 'AudioTensorFlowTensor', 'AudioJaxArray']: import docarray.typing.tensor.audio as lib - elif name in ['VideoTorchTensor', 'VideoTensorFlowTensor']: + elif name in ['VideoTorchTensor', 'VideoTensorFlowTensor', 'VideoJaxArray']: import docarray.typing.tensor.video as lib else: raise ImportError( diff --git a/docarray/typing/tensor/audio/__init__.py b/docarray/typing/tensor/audio/__init__.py index a505ab05720..5f304ae544f 100644 --- a/docarray/typing/tensor/audio/__init__.py +++ b/docarray/typing/tensor/audio/__init__.py @@ -9,12 +9,13 @@ ) if TYPE_CHECKING: + from docarray.typing.tensor.audio.audio_jax_array import AudioJaxArray # noqa from docarray.typing.tensor.audio.audio_tensorflow_tensor import ( # noqa AudioTensorFlowTensor, ) from docarray.typing.tensor.audio.audio_torch_tensor import AudioTorchTensor # noqa -__all__ = ['AudioNdArray', 'AudioTensor'] +__all__ = ['AudioNdArray', 'AudioTensor', 'AudioJaxArray'] def __getattr__(name: str): @@ -25,6 +26,9 @@ def __getattr__(name: str): elif name == 'AudioTensorFlowTensor': import_library('tensorflow', raise_error=True) import docarray.typing.tensor.audio.audio_tensorflow_tensor as lib + elif name == 'AudioJaxArray': + import_library('jax', raise_error=True) + import docarray.typing.tensor.audio.audio_jax_array as lib else: raise ImportError( f'cannot import name \'{name}\' from \'{_get_path_from_docarray_root_level(__file__)}\'' diff --git a/docarray/typing/tensor/audio/audio_jax_array.py b/docarray/typing/tensor/audio/audio_jax_array.py index e69de29bb2d..793fd627214 100644 --- a/docarray/typing/tensor/audio/audio_jax_array.py +++ b/docarray/typing/tensor/audio/audio_jax_array.py @@ -0,0 +1,12 @@ +from typing import TypeVar + +from docarray.typing.proto_register import _register_proto +from docarray.typing.tensor.audio.abstract_audio_tensor import AbstractAudioTensor +from docarray.typing.tensor.jaxarray import JaxArray, metaJax + +T = TypeVar('T', bound='AudioJaxArray') + + +@_register_proto(proto_type_name='audio_jaxarray') +class AudioJaxArray(AbstractAudioTensor, JaxArray, metaclass=metaJax): + ... diff --git a/docarray/typing/tensor/audio/audio_tensor.py b/docarray/typing/tensor/audio/audio_tensor.py index a9171a919b2..56e651b567e 100644 --- a/docarray/typing/tensor/audio/audio_tensor.py +++ b/docarray/typing/tensor/audio/audio_tensor.py @@ -5,7 +5,11 @@ from docarray.typing.tensor.audio.abstract_audio_tensor import AbstractAudioTensor from docarray.typing.tensor.audio.audio_ndarray import AudioNdArray from docarray.typing.tensor.tensor import AnyTensor -from docarray.utils._internal.misc import is_tf_available, is_torch_available +from docarray.utils._internal.misc import ( + is_jax_available, + is_tf_available, + is_torch_available, +) torch_available = is_torch_available() if torch_available: @@ -23,6 +27,12 @@ ) from docarray.typing.tensor.tensorflow_tensor import TensorFlowTensor +jax_available = is_jax_available() +if jax_available: + import jax.numpy as jnp # type: ignore + + from docarray.typing.tensor.audio.audio_jax_array import AudioJaxArray + from docarray.typing.tensor.jaxarray import JaxArray if TYPE_CHECKING: from pydantic import BaseConfig @@ -91,6 +101,11 @@ def validate( return cast(AudioTensorFlowTensor, value) elif isinstance(value, tf.Tensor): return AudioTensorFlowTensor._docarray_from_native(value) # noqa + if jax_available: + if isinstance(value, JaxArray): + return cast(AudioJaxArray, value) + elif isinstance(value, jnp.ndarray): + return AudioJaxArray._docarray_from_native(value) # noqa try: return AudioNdArray.validate(value, field, config) except Exception: # noqa diff --git a/docarray/typing/tensor/embedding/__init__.py b/docarray/typing/tensor/embedding/__init__.py index c32048b21c6..0e518b67a57 100644 --- a/docarray/typing/tensor/embedding/__init__.py +++ b/docarray/typing/tensor/embedding/__init__.py @@ -10,6 +10,7 @@ ) if TYPE_CHECKING: + from docarray.typing.tensor.embedding.jax_array import JaxArrayEmbedding # noqa from docarray.typing.tensor.embedding.tensorflow import TensorFlowEmbedding # noqa from docarray.typing.tensor.embedding.torch import TorchEmbedding # noqa @@ -24,6 +25,9 @@ def __getattr__(name: str): elif name == 'TensorFlowEmbedding': import_library('tensorflow', raise_error=True) import docarray.typing.tensor.embedding.tensorflow as lib + elif name == 'JaxArrayEmbedding': + import_library('jax', raise_error=True) + import docarray.typing.tensor.embedding.jax_array as lib else: raise ImportError( f'cannot import name \'{name}\' from \'{_get_path_from_docarray_root_level(__file__)}\'' diff --git a/docarray/typing/tensor/embedding/embedding.py b/docarray/typing/tensor/embedding/embedding.py index b7fd9c462f7..c9bc31dc54a 100644 --- a/docarray/typing/tensor/embedding/embedding.py +++ b/docarray/typing/tensor/embedding/embedding.py @@ -5,7 +5,18 @@ from docarray.typing.tensor.embedding.embedding_mixin import EmbeddingMixin from docarray.typing.tensor.embedding.ndarray import NdArrayEmbedding from docarray.typing.tensor.tensor import AnyTensor -from docarray.utils._internal.misc import is_tf_available, is_torch_available # noqa +from docarray.utils._internal.misc import ( # noqa + is_jax_available, + is_tf_available, + is_torch_available, +) + +jax_available = is_jax_available() +if jax_available: + import jax.numpy as jnp # type: ignore + + from docarray.typing.tensor.embedding.jax_array import JaxArrayEmbedding + from docarray.typing.tensor.jaxarray import JaxArray # noqa: F401 torch_available = is_torch_available() if torch_available: @@ -89,6 +100,11 @@ def validate( return cast(TensorFlowEmbedding, value) elif isinstance(value, tf.Tensor): return TensorFlowEmbedding._docarray_from_native(value) # noqa + if jax_available: + if isinstance(value, JaxArray): + return cast(JaxArrayEmbedding, value) + elif isinstance(value, jnp.ndarray): + return JaxArrayEmbedding._docarray_from_native(value) # noqa try: return NdArrayEmbedding.validate(value, field, config) except Exception: # noqa diff --git a/docarray/typing/tensor/embedding/jax_array.py b/docarray/typing/tensor/embedding/jax_array.py index e69de29bb2d..4dbb7a67ee0 100644 --- a/docarray/typing/tensor/embedding/jax_array.py +++ b/docarray/typing/tensor/embedding/jax_array.py @@ -0,0 +1,17 @@ +from typing import Any # noqa: F401 + +from docarray.typing.proto_register import _register_proto +from docarray.typing.tensor.embedding.embedding_mixin import EmbeddingMixin +from docarray.typing.tensor.jaxarray import JaxArray + +jax_base = type(JaxArray) # type: Any +embedding_base = type(EmbeddingMixin) # type: Any + + +class metaJaxAndEmbedding(jax_base, embedding_base): + pass + + +@_register_proto(proto_type_name='jaxarray_embedding') +class JaxArrayEmbedding(JaxArray, EmbeddingMixin, metaclass=metaJaxAndEmbedding): + alternative_type = JaxArray diff --git a/docarray/typing/tensor/image/__init__.py b/docarray/typing/tensor/image/__init__.py index 7af4b852206..d62b096c1fe 100644 --- a/docarray/typing/tensor/image/__init__.py +++ b/docarray/typing/tensor/image/__init__.py @@ -10,6 +10,7 @@ ) if TYPE_CHECKING: + from docarray.typing.tensor.image.image_jax_array import ImageJaxArray # noqa from docarray.typing.tensor.image.image_tensorflow_tensor import ( # noqa ImageTensorFlowTensor, ) @@ -26,6 +27,9 @@ def __getattr__(name: str): elif name == 'ImageTensorFlowTensor': import_library('tensorflow', raise_error=True) import docarray.typing.tensor.image.image_tensorflow_tensor as lib + elif name == 'ImageJaxArray': + import_library('jax', raise_error=True) + import docarray.typing.tensor.image.image_jax_array as lib else: raise ImportError( f'cannot import name \'{name}\' from \'{_get_path_from_docarray_root_level(__file__)}\'' diff --git a/docarray/typing/tensor/image/image_jax_array.py b/docarray/typing/tensor/image/image_jax_array.py index e69de29bb2d..8fabf91ac24 100644 --- a/docarray/typing/tensor/image/image_jax_array.py +++ b/docarray/typing/tensor/image/image_jax_array.py @@ -0,0 +1,10 @@ +from docarray.typing.proto_register import _register_proto +from docarray.typing.tensor.image.abstract_image_tensor import AbstractImageTensor +from docarray.typing.tensor.jaxarray import JaxArray, metaJax + +MAX_INT_16 = 2**15 + + +@_register_proto(proto_type_name='image_jaxarray') +class ImageJaxArray(JaxArray, AbstractImageTensor, metaclass=metaJax): + ... diff --git a/docarray/typing/tensor/image/image_tensor.py b/docarray/typing/tensor/image/image_tensor.py index ece9f5978ed..3dc58c737c3 100644 --- a/docarray/typing/tensor/image/image_tensor.py +++ b/docarray/typing/tensor/image/image_tensor.py @@ -5,7 +5,18 @@ from docarray.typing.tensor.image.abstract_image_tensor import AbstractImageTensor from docarray.typing.tensor.image.image_ndarray import ImageNdArray from docarray.typing.tensor.tensor import AnyTensor -from docarray.utils._internal.misc import is_tf_available, is_torch_available +from docarray.utils._internal.misc import ( + is_jax_available, + is_tf_available, + is_torch_available, +) + +jax_available = is_jax_available() +if jax_available: + import jax.numpy as jnp # type: ignore + + from docarray.typing.tensor.image.image_jax_array import ImageJaxArray + from docarray.typing.tensor.jaxarray import JaxArray torch_available = is_torch_available() if torch_available: @@ -94,6 +105,11 @@ def validate( return cast(ImageTensorFlowTensor, value) elif isinstance(value, tf.Tensor): return ImageTensorFlowTensor._docarray_from_native(value) # noqa + if jax_available: + if isinstance(value, JaxArray): + return cast(ImageJaxArray, value) + elif isinstance(value, jnp.ndarray): + return ImageJaxArray._docarray_from_native(value) # noqa try: return ImageNdArray.validate(value, field, config) except Exception: # noqa diff --git a/docarray/typing/tensor/jaxarray.py b/docarray/typing/tensor/jaxarray.py index 59deb384615..804080b54ec 100644 --- a/docarray/typing/tensor/jaxarray.py +++ b/docarray/typing/tensor/jaxarray.py @@ -1,19 +1,22 @@ from typing import TYPE_CHECKING, Any, Generic, List, Tuple, Type, TypeVar, Union, cast -import jax.numpy as jnp import numpy as np -from jax import Array from docarray.typing.proto_register import _register_proto from docarray.typing.tensor.abstract_tensor import AbstractTensor +from docarray.utils._internal.misc import import_library if TYPE_CHECKING: + import jax + import jax.numpy as jnp from pydantic import BaseConfig from pydantic.fields import ModelField from docarray.computation.jax_backend import JaxCompBackend from docarray.proto import NdArrayProto - +else: + jax = import_library('jax', raise_error=True) + jnp = jax.numpy from docarray.base_doc.base_node import BaseNode T = TypeVar('T', bound='JaxArray') @@ -33,7 +36,62 @@ class metaJax( @_register_proto(proto_type_name='jaxarray') class JaxArray(AbstractTensor, Generic[ShapeT], metaclass=metaJax): - """ """ + """ + Subclass of `jnp.ndarray`, intended for use in a Document. + This enables (de)serialization from/to protobuf and json, data validation, + and coercion from compatible types like `torch.Tensor`. + + This type can also be used in a parametrized way, specifying the shape of the array. + + --- + + ```python + from docarray import BaseDoc + from docarray.typing import JaxArray + import jax.numpy as jnp + + + class MyDoc(BaseDoc): + arr: JaxArray + image_arr: JaxArray[3, 224, 224] + square_crop: JaxArray[3, 'x', 'x'] + random_image: JaxArray[3, ...] # first dimension is fixed, can have arbitrary shape + + + # create a document with tensors + doc = MyDoc( + arr=jnp.zeros((128,)), + image_arr=jnp.zeros((3, 224, 224)), + square_crop=jnp.zeros((3, 64, 64)), + random_image=jnp.zeros((3, 128, 256)), + ) + assert doc.image_arr.shape == (3, 224, 224) + + # automatic shape conversion + doc = MyDoc( + arr=np.zeros((128,)), + image_arr=np.zeros((224, 224, 3)), # will reshape to (3, 224, 224) + square_crop=np.zeros((3, 128, 128)), + random_image=np.zeros((3, 64, 128)), + ) + assert doc.image_arr.shape == (3, 224, 224) + + # !! The following will raise an error due to shape mismatch !! + from pydantic import ValidationError + + try: + doc = MyDoc( + arr=np.zeros((128,)), + image_arr=np.zeros((224, 224)), # this will fail validation + square_crop=np.zeros((3, 128, 64)), # this will also fail validation + random_image=np.zeros((4, 64, 128)), # this will also fail validation + ) + except ValidationError as e: + pass + ``` + + --- + """ __parametrized_meta__ = metaJax @@ -75,7 +133,7 @@ def validate( field: 'ModelField', config: 'BaseConfig', ) -> T: - if isinstance(value, Array): + if isinstance(value, jax.Array): return cls._docarray_from_native(value) elif isinstance(value, JaxArray): return cast(T, value) @@ -99,7 +157,7 @@ def _docarray_from_native(cls: Type[T], value: jnp.ndarray) -> T: if cls.__unparametrizedcls__: # None if the tensor is parametrized value.__class__ = cls.__unparametrizedcls__ # type: ignore else: - value.__class__ = cls + value.__class__ = cls # type: ignore return cast(T, value) else: if cls.__unparametrizedcls__: # None if the tensor is parametrized diff --git a/docarray/typing/tensor/ndarray.py b/docarray/typing/tensor/ndarray.py index e8935758e42..2f547b55dea 100644 --- a/docarray/typing/tensor/ndarray.py +++ b/docarray/typing/tensor/ndarray.py @@ -5,7 +5,17 @@ from docarray.base_doc.base_node import BaseNode from docarray.typing.proto_register import _register_proto from docarray.typing.tensor.abstract_tensor import AbstractTensor -from docarray.utils._internal.misc import is_tf_available, is_torch_available # noqa +from docarray.utils._internal.misc import ( # noqa + is_jax_available, + is_tf_available, + is_torch_available, +) + +jax_available = is_jax_available() +if jax_available: + import jax.numpy as jnp + + from docarray.typing.tensor.jaxarray import JaxArray # noqa: F401 torch_available = is_torch_available() if torch_available: @@ -124,6 +134,8 @@ def validate( return cls._docarray_from_native(value.detach().cpu().numpy()) elif tf_available and isinstance(value, tf.Tensor): return cls._docarray_from_native(value.numpy()) + elif jax_available and isinstance(value, jnp.ndarray): + return cls._docarray_from_native(value.__array__()) elif isinstance(value, list) or isinstance(value, tuple): try: arr_from_list: np.ndarray = np.asarray(value) diff --git a/docarray/typing/tensor/tensor.py b/docarray/typing/tensor/tensor.py index e8d84bf04a0..2d5be7cd096 100644 --- a/docarray/typing/tensor/tensor.py +++ b/docarray/typing/tensor/tensor.py @@ -4,7 +4,17 @@ from docarray.typing.tensor.abstract_tensor import AbstractTensor from docarray.typing.tensor.ndarray import NdArray -from docarray.utils._internal.misc import is_tf_available, is_torch_available # noqa +from docarray.utils._internal.misc import ( # noqa + is_jax_available, + is_tf_available, + is_torch_available, +) + +jax_available = is_jax_available() +if jax_available: + import jax.numpy as jnp + + from docarray.typing.tensor.jaxarray import JaxArray # noqa: F401 torch_available = is_torch_available() if torch_available: @@ -27,12 +37,20 @@ # behavior as `Union[TorchTensor, TensorFlowTensor, NdArray]` so it should be fine to use `AnyTensor` as # the type for `tensor` field in `BaseDoc` class. AnyTensor = Union[NdArray] - if torch_available and tf_available: + if torch_available and tf_available and jax_available: + AnyTensor = Union[NdArray, TorchTensor, TensorFlowTensor, JaxArray] # type: ignore + elif torch_available and tf_available: AnyTensor = Union[NdArray, TorchTensor, TensorFlowTensor] # type: ignore - elif torch_available: - AnyTensor = Union[NdArray, TorchTensor] # type: ignore + elif tf_available and jax_available: + AnyTensor = Union[NdArray, TensorFlowTensor, JaxArray] # type: ignore + elif torch_available and jax_available: + AnyTensor = Union[NdArray, TorchTensor, JaxArray] # type: ignore elif tf_available: AnyTensor = Union[NdArray, TensorFlowTensor] # type: ignore + elif torch_available: + AnyTensor = Union[NdArray, TorchTensor] # type: ignore + elif jax_available: + AnyTensor = Union[NdArray, JaxArray] # type: ignore else: @@ -124,6 +142,11 @@ def validate( return value elif isinstance(value, tf.Tensor): return TensorFlowTensor._docarray_from_native(value) # noqa + if jax_available: + if isinstance(value, JaxArray): + return value + elif isinstance(value, jnp.ndarray): + return JaxArray._docarray_from_native(value) # noqa try: return NdArray.validate(value, field, config) except Exception as e: # noqa diff --git a/docarray/typing/tensor/tensorflow_tensor.py b/docarray/typing/tensor/tensorflow_tensor.py index 1eb2bc7eacf..a42b3a0a5d3 100644 --- a/docarray/typing/tensor/tensorflow_tensor.py +++ b/docarray/typing/tensor/tensorflow_tensor.py @@ -5,7 +5,11 @@ from docarray.base_doc.base_node import BaseNode from docarray.typing.proto_register import _register_proto from docarray.typing.tensor.abstract_tensor import AbstractTensor -from docarray.utils._internal.misc import import_library, is_torch_available +from docarray.utils._internal.misc import ( + import_library, + is_jax_available, + is_torch_available, +) if TYPE_CHECKING: import tensorflow as tf # type: ignore @@ -21,6 +25,10 @@ if torch_available: import torch +jax_available = is_jax_available() +if jax_available: + import jax.numpy as jnp + T = TypeVar('T', bound='TensorFlowTensor') ShapeT = TypeVar('ShapeT') @@ -211,6 +219,8 @@ def validate( return cls._docarray_from_ndarray(value._docarray_to_ndarray()) elif torch_available and isinstance(value, torch.Tensor): return cls._docarray_from_native(value.detach().cpu().numpy()) + elif jax_available and isinstance(value, jnp.ndarray): + return cls._docarray_from_native(value.__array__()) else: try: arr: tf.Tensor = tf.constant(value) diff --git a/docarray/typing/tensor/torch_tensor.py b/docarray/typing/tensor/torch_tensor.py index 0f7ff0132d9..a78781f6a9b 100644 --- a/docarray/typing/tensor/torch_tensor.py +++ b/docarray/typing/tensor/torch_tensor.py @@ -6,7 +6,11 @@ from docarray.base_doc.base_node import BaseNode from docarray.typing.proto_register import _register_proto from docarray.typing.tensor.abstract_tensor import AbstractTensor -from docarray.utils._internal.misc import import_library, is_tf_available +from docarray.utils._internal.misc import ( + import_library, + is_jax_available, + is_tf_available, +) if TYPE_CHECKING: import torch @@ -22,6 +26,10 @@ if tf_available: import tensorflow as tf # type: ignore +jax_available = is_jax_available() +if jax_available: + import jax.numpy as jnp + T = TypeVar('T', bound='TorchTensor') ShapeT = TypeVar('ShapeT') @@ -132,6 +140,8 @@ def validate( return cls._docarray_from_ndarray(value.numpy()) elif isinstance(value, np.ndarray): return cls._docarray_from_ndarray(value) + elif jax_available and isinstance(value, jnp.ndarray): + return cls._docarray_from_ndarray(value.__array__()) else: try: arr: torch.Tensor = torch.tensor(value) diff --git a/docarray/typing/tensor/video/__init__.py b/docarray/typing/tensor/video/__init__.py index a575e7b6201..18f0a2e5d8b 100644 --- a/docarray/typing/tensor/video/__init__.py +++ b/docarray/typing/tensor/video/__init__.py @@ -10,6 +10,7 @@ ) if TYPE_CHECKING: + from docarray.typing.tensor.video.video_jax_array import VideoJaxArray # noqa from docarray.typing.tensor.video.video_tensorflow_tensor import ( # noqa VideoTensorFlowTensor, ) @@ -26,6 +27,9 @@ def __getattr__(name: str): elif name == 'VideoTensorFlowTensor': import_library('tensorflow', raise_error=True) import docarray.typing.tensor.video.video_tensorflow_tensor as lib + elif name == 'VideoJaxArray': + import_library('jax', raise_error=True) + import docarray.typing.tensor.video.video_jax_array as lib else: raise ImportError( f'cannot import name \'{name}\' from \'{_get_path_from_docarray_root_level(__file__)}\'' diff --git a/docarray/typing/tensor/video/video_jax_array.py b/docarray/typing/tensor/video/video_jax_array.py index e69de29bb2d..5b060e49246 100644 --- a/docarray/typing/tensor/video/video_jax_array.py +++ b/docarray/typing/tensor/video/video_jax_array.py @@ -0,0 +1,28 @@ +from typing import TYPE_CHECKING, Any, List, Tuple, Type, TypeVar, Union + +import numpy as np + +from docarray.typing.proto_register import _register_proto +from docarray.typing.tensor.jaxarray import JaxArray, metaJax +from docarray.typing.tensor.video.video_tensor_mixin import VideoTensorMixin + +T = TypeVar('T', bound='VideoJaxArray') + +if TYPE_CHECKING: + from pydantic import BaseConfig + from pydantic.fields import ModelField + + +@_register_proto(proto_type_name='video_jaxarray') +class VideoJaxArray(JaxArray, VideoTensorMixin, metaclass=metaJax): + """ """ + + @classmethod + def validate( + cls: Type[T], + value: Union[T, np.ndarray, List[Any], Tuple[Any], Any], + field: 'ModelField', + config: 'BaseConfig', + ) -> T: + tensor = super().validate(value=value, field=field, config=config) + return cls.validate_shape(value=tensor) diff --git a/docarray/typing/tensor/video/video_tensor.py b/docarray/typing/tensor/video/video_tensor.py index be77c9db21e..5687ecfe561 100644 --- a/docarray/typing/tensor/video/video_tensor.py +++ b/docarray/typing/tensor/video/video_tensor.py @@ -5,7 +5,18 @@ from docarray.typing.tensor.tensor import AnyTensor from docarray.typing.tensor.video.video_ndarray import VideoNdArray from docarray.typing.tensor.video.video_tensor_mixin import VideoTensorMixin -from docarray.utils._internal.misc import is_tf_available, is_torch_available +from docarray.utils._internal.misc import ( + is_jax_available, + is_tf_available, + is_torch_available, +) + +jax_available = is_jax_available() +if jax_available: + import jax.numpy as jnp + + from docarray.typing.tensor.jaxarray import JaxArray # noqa: F401 + from docarray.typing.tensor.video.video_jax_array import VideoJaxArray torch_available = is_torch_available() if torch_available: @@ -94,6 +105,11 @@ def validate( return cast(VideoTensorFlowTensor, value) elif isinstance(value, tf.Tensor): return VideoTensorFlowTensor._docarray_from_native(value) # noqa + if jax_available: + if isinstance(value, JaxArray): + return cast(VideoJaxArray, value) + elif isinstance(value, jnp.ndarray): + return VideoJaxArray._docarray_from_native(value) # noqa if isinstance(value, VideoNdArray): return cast(VideoNdArray, value) if isinstance(value, np.ndarray): diff --git a/docarray/utils/_internal/misc.py b/docarray/utils/_internal/misc.py index 1ac8bc659b6..ad0d28d9c9e 100644 --- a/docarray/utils/_internal/misc.py +++ b/docarray/utils/_internal/misc.py @@ -22,6 +22,13 @@ tf_imported = True +try: + import jax.numpy as jnp # type: ignore # noqa: F401 +except (ImportError, TypeError): + jnp_imported = False +else: + jnp_imported = True + INSTALL_INSTRUCTIONS = { 'google.protobuf': '"docarray[proto]"', 'lz4': '"docarray[proto]"', @@ -78,6 +85,10 @@ def is_tf_available(): return tf_imported +def is_jax_available(): + return jnp_imported + + def is_np_int(item: Any) -> bool: dtype = getattr(item, 'dtype', None) ndim = getattr(item, 'ndim', None) diff --git a/pyproject.toml b/pyproject.toml index eba967bf112..3e0e2ee40a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,9 +77,10 @@ web = ["fastapi"] qdrant = ["qdrant-client"] weaviate = ["weaviate-client"] redis = ['redis'] +jax = ["jaxlib","jax"] # all -full = ["protobuf", "lz4", "pandas", "pillow", "types-pillow", "av", "pydub", "trimesh"] +full = ["protobuf", "lz4", "pandas", "pillow", "types-pillow", "av", "pydub", "trimesh", "jax"] [tool.poetry.dev-dependencies] pytest = ">=7.0" diff --git a/tests/units/array/stack/test_array_stacked_jax.py b/tests/units/array/stack/test_array_stacked_jax.py new file mode 100644 index 00000000000..0ca66a44e62 --- /dev/null +++ b/tests/units/array/stack/test_array_stacked_jax.py @@ -0,0 +1,298 @@ +from typing import Optional, Union + +import pytest + +from docarray import BaseDoc, DocList +from docarray.array import DocVec +from docarray.typing import ( + AnyEmbedding, + AnyTensor, + AudioTensor, + ImageTensor, + NdArray, + VideoTensor, +) +from docarray.utils._internal.misc import is_jax_available + +jax_available = is_jax_available() +if jax_available: + import jax.numpy as jnp + + from docarray.typing import JaxArray + + +@pytest.fixture() +@pytest.mark.jax +def batch(): + class Image(BaseDoc): + tensor: JaxArray[3, 224, 224] + + batch = DocList[Image]([Image(tensor=jnp.zeros((3, 224, 224))) for _ in range(10)]) + + return batch.to_doc_vec() + + +@pytest.fixture() +@pytest.mark.jax +def nested_batch(): + class Image(BaseDoc): + tensor: JaxArray[3, 224, 224] + + class MMdoc(BaseDoc): + img: DocList[Image] + + batch = DocVec[MMdoc]( + [ + MMdoc( + img=DocList[Image]( + [Image(tensor=jnp.zeros((3, 224, 224))) for _ in range(10)] + ) + ) + for _ in range(10) + ] + ) + + return batch + + +@pytest.mark.jax +def test_len(batch): + assert len(batch) == 10 + + +@pytest.mark.jax +def test_getitem(batch): + for i in range(len(batch)): + item = batch[i] + assert isinstance(item.tensor, JaxArray) + assert jnp.allclose(item.tensor.tensor, jnp.zeros((3, 224, 224))) + + +@pytest.mark.jax +def test_get_slice(batch): + sliced = batch[0:2] + assert isinstance(sliced, DocVec) + assert len(sliced) == 2 + + +@pytest.mark.jax +def test_iterator(batch): + for doc in batch: + assert jnp.allclose(doc.tensor.tensor, jnp.zeros((3, 224, 224))) + + +@pytest.mark.jax +def test_set_after_stacking(): + class Image(BaseDoc): + tensor: JaxArray[3, 224, 224] + + batch = DocVec[Image]([Image(tensor=jnp.zeros((3, 224, 224))) for _ in range(10)]) + + batch.tensor = jnp.ones((10, 3, 224, 224)) + assert jnp.allclose(batch.tensor.tensor, jnp.ones((10, 3, 224, 224))) + for i, doc in enumerate(batch): + assert jnp.allclose(doc.tensor.tensor, batch.tensor.tensor[i]) + + +@pytest.mark.jax +def test_stack_optional(batch): + assert jnp.allclose( + batch._storage.tensor_columns['tensor'].tensor, jnp.zeros((10, 3, 224, 224)) + ) + assert jnp.allclose(batch.tensor.tensor, jnp.zeros((10, 3, 224, 224))) + + +@pytest.mark.jax +def test_stack_mod_nested_document(): + class Image(BaseDoc): + tensor: JaxArray[3, 224, 224] + + class MMdoc(BaseDoc): + img: Image + + batch = DocList[MMdoc]( + [MMdoc(img=Image(tensor=jnp.zeros((3, 224, 224)))) for _ in range(10)] + ).to_doc_vec() + + assert jnp.allclose( + batch._storage.doc_columns['img']._storage.tensor_columns['tensor'].tensor, + jnp.zeros((10, 3, 224, 224)), + ) + + assert jnp.allclose(batch.img.tensor.tensor, jnp.zeros((10, 3, 224, 224))) + + +@pytest.mark.jax +def test_stack_nested_DocArray(nested_batch): + for i in range(len(nested_batch)): + assert jnp.allclose( + nested_batch[i].img._storage.tensor_columns['tensor'].tensor, + jnp.zeros((10, 3, 224, 224)), + ) + + assert jnp.allclose( + nested_batch[i].img.tensor.tensor, jnp.zeros((10, 3, 224, 224)) + ) + + +@pytest.mark.jax +def test_convert_to_da(batch): + da = batch.to_doc_list() + + for doc in da: + assert jnp.allclose(doc.tensor.tensor, jnp.zeros((3, 224, 224))) + + +@pytest.mark.jax +def test_unstack_nested_document(): + class Image(BaseDoc): + tensor: JaxArray[3, 224, 224] + + class MMdoc(BaseDoc): + img: Image + + batch = DocVec[MMdoc]( + [MMdoc(img=Image(tensor=jnp.zeros((3, 224, 224)))) for _ in range(10)] + ) + assert isinstance(batch.img._storage.tensor_columns['tensor'], JaxArray) + da = batch.to_doc_list() + + for doc in da: + assert jnp.allclose(doc.img.tensor.tensor, jnp.zeros((3, 224, 224))) + + +@pytest.mark.jax +def test_unstack_nested_DocArray(nested_batch): + batch = nested_batch.to_doc_list() + for i in range(len(batch)): + assert isinstance(batch[i].img, DocList) + for doc in batch[i].img: + assert jnp.allclose(doc.tensor.tensor, jnp.zeros((3, 224, 224))) + + +@pytest.mark.jax +def test_stack_call(): + class Image(BaseDoc): + tensor: JaxArray[3, 224, 224] + + da = DocList[Image]([Image(tensor=jnp.zeros((3, 224, 224))) for _ in range(10)]) + + da = da.to_doc_vec() + + assert len(da) == 10 + + assert da.tensor.tensor.shape == (10, 3, 224, 224) + + +@pytest.mark.jax +def test_stack_union(): + class Image(BaseDoc): + tensor: Union[JaxArray[3, 224, 224], NdArray[3, 224, 224]] + + DocVec[Image]( + [Image(tensor=jnp.zeros((3, 224, 224))) for _ in range(10)], + tensor_type=JaxArray, + ) + + # union fields aren't actually doc_vec + # just checking that there is no error + + +@pytest.mark.jax +def test_setitem_tensor(batch): + batch[3].tensor.tensor = jnp.zeros((3, 224, 224)) + + +@pytest.mark.jax +@pytest.mark.skip('not working yet') +def test_setitem_tensor_direct(batch): + batch[3].tensor = jnp.zeros((3, 224, 224)) + + +@pytest.mark.jax +@pytest.mark.parametrize( + 'cls_tensor', [ImageTensor, AudioTensor, VideoTensor, AnyEmbedding, AnyTensor] +) +def test_generic_tensors_with_jnp(cls_tensor): + tensor = jnp.zeros((3, 224, 224)) + + class Image(BaseDoc): + tensor: cls_tensor + + da = DocVec[Image]( + [Image(tensor=tensor) for _ in range(10)], + tensor_type=JaxArray, + ) + + for i in range(len(da)): + assert jnp.allclose(da[i].tensor.tensor, tensor) + + assert 'tensor' in da._storage.tensor_columns.keys() + assert isinstance(da._storage.tensor_columns['tensor'], JaxArray) + + +@pytest.mark.jax +@pytest.mark.parametrize( + 'cls_tensor', [ImageTensor, AudioTensor, VideoTensor, AnyEmbedding, AnyTensor] +) +def test_generic_tensors_with_optional(cls_tensor): + tensor = jnp.zeros((3, 224, 224)) + + class Image(BaseDoc): + tensor: Optional[cls_tensor] + + class TopDoc(BaseDoc): + img: Image + + da = DocVec[TopDoc]( + [TopDoc(img=Image(tensor=tensor)) for _ in range(10)], + tensor_type=JaxArray, + ) + + for i in range(len(da)): + assert jnp.allclose(da.img[i].tensor.tensor, tensor) + + assert 'tensor' in da.img._storage.tensor_columns.keys() + assert isinstance(da.img._storage.tensor_columns['tensor'], JaxArray) + assert isinstance(da.img._storage.tensor_columns['tensor'].tensor, jnp.ndarray) + + +@pytest.mark.jax +def test_get_from_slice_stacked(): + class Doc(BaseDoc): + text: str + tensor: JaxArray + + da = DocVec[Doc]( + [Doc(text=f'hello{i}', tensor=jnp.zeros((3, 224, 224))) for i in range(10)] + ) + + da_sliced = da[0:10:2] + assert isinstance(da_sliced, DocVec) + + tensors = da_sliced.tensor.tensor + assert tensors.shape == (5, 3, 224, 224) + + +@pytest.mark.jax +def test_stack_none(): + class MyDoc(BaseDoc): + tensor: Optional[AnyTensor] + + da = DocVec[MyDoc]([MyDoc(tensor=None) for _ in range(10)], tensor_type=JaxArray) + assert 'tensor' in da._storage.tensor_columns.keys() + + +@pytest.mark.jax +def test_keep_dtype_jnp(): + class MyDoc(BaseDoc): + tensor: JaxArray + + da = DocList[MyDoc]( + [MyDoc(tensor=jnp.zeros([2, 4], dtype=jnp.int32)) for _ in range(3)] + ) + assert da[0].tensor.tensor.dtype == jnp.int32 + + da = da.to_doc_vec() + assert da[0].tensor.tensor.dtype == jnp.int32 + assert da.tensor.tensor.dtype == jnp.int32 diff --git a/tests/units/computation_backends/jax_backend/test_basics.py b/tests/units/computation_backends/jax_backend/test_basics.py index 6cd64a19602..3dcbf500522 100644 --- a/tests/units/computation_backends/jax_backend/test_basics.py +++ b/tests/units/computation_backends/jax_backend/test_basics.py @@ -1,14 +1,19 @@ -import jax -import jax.numpy as jnp import pytest -from docarray.computation.jax_backend import JaxCompBackend -from docarray.typing import JaxArray +from docarray.utils._internal.misc import is_jax_available -jax.config.update("jax_enable_x64", True) +jax_available = is_jax_available() +if jax_available: + import jax + import jax.numpy as jnp + from docarray.computation.jax_backend import JaxCompBackend + from docarray.typing import JaxArray -@pytest.mark.tensorflow + jax.config.update("jax_enable_x64", True) + + +@pytest.mark.jax @pytest.mark.parametrize( 'shape,result', [ @@ -23,7 +28,7 @@ def test_n_dim(shape, result): assert JaxCompBackend.n_dim(array) == result -@pytest.mark.tensorflow +@pytest.mark.jax @pytest.mark.parametrize( 'shape,result', [ @@ -39,14 +44,14 @@ def test_shape(shape, result): assert type(shape) == tuple -@pytest.mark.tensorflow +@pytest.mark.jax def test_to_device(): array = JaxArray(jnp.zeros((3))) array = JaxCompBackend.to_device(array, 'cpu') assert array.tensor.device().platform.endswith('cpu') -@pytest.mark.tensorflow +@pytest.mark.jax @pytest.mark.parametrize( 'dtype,result_type', [ @@ -61,34 +66,34 @@ def test_dtype(dtype, result_type): assert JaxCompBackend.dtype(array) == result_type -@pytest.mark.tensorflow +@pytest.mark.jax def test_empty(): array = JaxCompBackend.empty((10, 3)) assert array.tensor.shape == (10, 3) -@pytest.mark.tensorflow +@pytest.mark.jax def test_empty_dtype(): tf_tensor = JaxCompBackend.empty((10, 3), dtype=jnp.int32) assert tf_tensor.tensor.shape == (10, 3) assert tf_tensor.tensor.dtype == jnp.int32 -@pytest.mark.tensorflow +@pytest.mark.jax def test_empty_device(): tensor = JaxCompBackend.empty((10, 3), device='cpu') assert tensor.tensor.shape == (10, 3) assert tensor.tensor.device().platform.endswith('cpu') -@pytest.mark.tensorflow +@pytest.mark.jax def test_squeeze(): tensor = JaxArray(jnp.zeros(shape=(1, 1, 3, 1))) squeezed = JaxCompBackend.squeeze(tensor) assert squeezed.tensor.shape == (3,) -@pytest.mark.tensorflow +@pytest.mark.jax @pytest.mark.parametrize( 'data_input,t_range,x_range,data_result', [ @@ -120,14 +125,14 @@ def test_minmax_normalize(data_input, t_range, x_range, data_result): assert jnp.allclose(output.tensor, jnp.array(data_result)) -@pytest.mark.tensorflow +@pytest.mark.jax def test_reshape(): tensor = JaxArray(jnp.zeros((3, 224, 224))) reshaped = JaxCompBackend.reshape(tensor, (224, 224, 3)) assert reshaped.tensor.shape == (224, 224, 3) -@pytest.mark.tensorflow +@pytest.mark.jax def test_stack(): t0 = JaxArray(jnp.zeros((3, 224, 224))) t1 = JaxArray(jnp.ones((3, 224, 224))) diff --git a/tests/units/computation_backends/jax_backend/test_metrics.py b/tests/units/computation_backends/jax_backend/test_metrics.py index b3134a6096f..ec534359059 100644 --- a/tests/units/computation_backends/jax_backend/test_metrics.py +++ b/tests/units/computation_backends/jax_backend/test_metrics.py @@ -1,12 +1,21 @@ -import jax -import jax.numpy as jnp +import pytest -from docarray.computation.jax_backend import JaxCompBackend -from docarray.typing import JaxArray +from docarray.utils._internal.misc import is_jax_available -metrics = JaxCompBackend.Metrics +jax_available = is_jax_available() +if jax_available: + import jax + import jax.numpy as jnp + from docarray.computation.jax_backend import JaxCompBackend + from docarray.typing import JaxArray + metrics = JaxCompBackend.Metrics +else: + metrics = None + + +@pytest.mark.jax def test_cosine_sim_jax(): a = JaxArray(jax.random.uniform(jax.random.PRNGKey(0), shape=(128,))) b = JaxArray(jax.random.uniform(jax.random.PRNGKey(1), shape=(128,))) @@ -23,6 +32,7 @@ def test_cosine_sim_jax(): assert jnp.allclose(diag_dists, jnp.ones((5,))) +@pytest.mark.jax def test_euclidean_dist_jax(): a = JaxArray(jax.random.normal(jax.random.PRNGKey(0), shape=(128,))) b = JaxArray(jax.random.normal(jax.random.PRNGKey(1), shape=(128,))) @@ -53,6 +63,7 @@ def test_euclidean_dist_jax(): assert jnp.allclose(metrics.euclidean_dist(a, b).tensor, desired_output_singleton) +@pytest.mark.jax def test_sqeuclidea_dist_jnp(): a = JaxArray(jax.random.uniform(jax.random.PRNGKey(0), shape=(128,))) b = JaxArray(jax.random.uniform(jax.random.PRNGKey(1), shape=(128,))) diff --git a/tests/units/computation_backends/jax_backend/test_retrieval.py b/tests/units/computation_backends/jax_backend/test_retrieval.py index a1bb686083e..9f8a3afb415 100644 --- a/tests/units/computation_backends/jax_backend/test_retrieval.py +++ b/tests/units/computation_backends/jax_backend/test_retrieval.py @@ -1,11 +1,20 @@ -import jax.numpy as jnp import pytest -from docarray.computation.jax_backend import JaxCompBackend -from docarray.typing import JaxArray +from docarray.utils._internal.misc import is_jax_available +jax_available = is_jax_available() +if jax_available: + import jax.numpy as jnp -@pytest.mark.tensorflow + from docarray.computation.jax_backend import JaxCompBackend + from docarray.typing import JaxArray + + metrics = JaxCompBackend.Metrics +else: + metrics = None + + +@pytest.mark.jax def test_top_k_descending_false(): top_k = JaxCompBackend.Retrieval.top_k @@ -32,7 +41,7 @@ def test_top_k_descending_false(): assert jnp.allclose(indices.tensor[1], jnp.array([2, 4, 6])) -@pytest.mark.tensorflow +@pytest.mark.jax def test_top_k_descending_true(): top_k = JaxCompBackend.Retrieval.top_k diff --git a/tests/units/typing/tensor/test_jax_array.py b/tests/units/typing/tensor/test_jax_array.py index b44494d51e1..f5044b23dd9 100644 --- a/tests/units/typing/tensor/test_jax_array.py +++ b/tests/units/typing/tensor/test_jax_array.py @@ -1,14 +1,20 @@ -import jax.numpy as jnp import numpy as np import pytest -from jax._src.core import InconclusiveDimensionOperation from pydantic import schema_json_of from pydantic.tools import parse_obj_as from docarray.base_doc.io.json import orjson_dumps -from docarray.typing import JaxArray +from docarray.utils._internal.misc import is_jax_available +jax_available = is_jax_available() +if jax_available: + import jax.numpy as jnp + from jax._src.core import InconclusiveDimensionOperation + from docarray.typing import JaxArray + + +@pytest.mark.jax def test_proto_tensor(): from docarray.proto.pb2.docarray_pb2 import NdArrayProto @@ -21,15 +27,18 @@ def test_proto_tensor(): assert jnp.allclose(tensor.tensor, from_proto.tensor) +@pytest.mark.jax def test_json_schema(): schema_json_of(JaxArray) +@pytest.mark.jax def test_dump_json(): tensor = parse_obj_as(JaxArray, jnp.zeros((3, 224, 224))) orjson_dumps(tensor) +@pytest.mark.jax def test_unwrap(): tf_tensor = parse_obj_as(JaxArray, jnp.zeros((3, 224, 224))) unwrapped = tf_tensor.unwrap() @@ -41,6 +50,7 @@ def test_unwrap(): assert np.allclose(unwrapped, np.zeros((3, 224, 224))) +@pytest.mark.jax def test_from_ndarray(): nd = np.array([1, 2, 3]) tensor = JaxArray.from_ndarray(nd) @@ -48,6 +58,7 @@ def test_from_ndarray(): assert isinstance(tensor.tensor, jnp.ndarray) +@pytest.mark.jax def test_ellipsis_in_shape(): # ellipsis in the end, two extra dimensions needed tf_tensor = parse_obj_as(JaxArray[3, ...], jnp.zeros((3, 128, 224))) @@ -70,6 +81,7 @@ def test_ellipsis_in_shape(): parse_obj_as(JaxArray[3, 224, ...], jnp.zeros((3, 128, 224))) +@pytest.mark.jax def test_parametrized(): # correct shape, single axis tf_tensor = parse_obj_as(JaxArray[128], jnp.zeros(128)) @@ -94,6 +106,7 @@ def test_parametrized(): parse_obj_as(JaxArray[3, 224, 224], jnp.zeros((224, 224))) +@pytest.mark.jax def test_parametrized_with_str(): # test independent variable dimensions tf_tensor = parse_obj_as(JaxArray[3, 'x', 'y'], jnp.zeros((3, 224, 224))) @@ -125,6 +138,7 @@ def test_parametrized_with_str(): _ = parse_obj_as(JaxArray[3, 'x', 'x'], jnp.zeros((3, 60))) +@pytest.mark.jax @pytest.mark.parametrize('shape', [(3, 224, 224), (224, 224, 3)]) def test_parameterized_tensor_class_name(shape): MyTFT = JaxArray[3, 224, 224] @@ -138,6 +152,7 @@ def test_parameterized_tensor_class_name(shape): assert f'{tensor.tensor[0][0][0]}' == '0.0' +@pytest.mark.jax def test_parametrized_subclass(): c1 = JaxArray[128] c2 = JaxArray[128] @@ -147,6 +162,7 @@ def test_parametrized_subclass(): assert not issubclass(c1, JaxArray[256]) +@pytest.mark.jax def test_parametrized_instance(): t = parse_obj_as(JaxArray[128], jnp.zeros((128,))) assert isinstance(t, JaxArray[128]) @@ -158,6 +174,7 @@ def test_parametrized_instance(): assert not isinstance(t, JaxArray[2, 2, 64]) +@pytest.mark.jax def test_parametrized_equality(): t1 = parse_obj_as(JaxArray[128], jnp.zeros((128,))) t2 = parse_obj_as(JaxArray[128], jnp.zeros((128,))) From e0a7d89cd351a635b35b7c1aedb5a800e5d353f6 Mon Sep 17 00:00:00 2001 From: agaraman0 Date: Tue, 11 Jul 2023 17:35:54 +0530 Subject: [PATCH 13/25] fix: add jax pytest marker for missing testcase Signed-off-by: agaraman0 --- tests/units/typing/tensor/test_jax_array.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/units/typing/tensor/test_jax_array.py b/tests/units/typing/tensor/test_jax_array.py index f5044b23dd9..6e062f0ec5d 100644 --- a/tests/units/typing/tensor/test_jax_array.py +++ b/tests/units/typing/tensor/test_jax_array.py @@ -181,6 +181,7 @@ def test_parametrized_equality(): assert jnp.allclose(t1.tensor, t2.tensor) +@pytest.mark.jax def test_parametrized_operations(): t1 = parse_obj_as(JaxArray[128], jnp.zeros((128,))) t2 = parse_obj_as(JaxArray[128], jnp.zeros((128,))) @@ -190,6 +191,7 @@ def test_parametrized_operations(): assert not isinstance(t_result, JaxArray[128]) +@pytest.mark.jax def test_set_item(): t = JaxArray(tensor=jnp.zeros((3, 224, 224))) t[0] = jnp.ones((1, 224, 224)) From 3f9a399f0d37db658e6b7331bf062af37fd8eb6b Mon Sep 17 00:00:00 2001 From: agaraman0 Date: Fri, 14 Jul 2023 17:54:15 +0530 Subject: [PATCH 14/25] feat: added integration test Signed-off-by: agaraman0 --- .../array/test_jax_integration.py | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 tests/integrations/array/test_jax_integration.py diff --git a/tests/integrations/array/test_jax_integration.py b/tests/integrations/array/test_jax_integration.py new file mode 100644 index 00000000000..00488e349c3 --- /dev/null +++ b/tests/integrations/array/test_jax_integration.py @@ -0,0 +1,37 @@ +from typing import Optional + +import jax.numpy as jnp +import pytest +from jax import jit + +from docarray import BaseDoc, DocList +from docarray.typing import JaxArray + + +class Mmdoc(BaseDoc): + tensor: Optional[JaxArray[3, 224, 224]] + + +def basic_jax_fn(x): + return jnp.sum(x) + + +def abstract_JaxArray(array: JaxArray) -> jnp.ndarray: + return array.tensor + + +@pytest.mark.jax +def test_basic_jax_operation(): + N = 10 + + batch = DocList[Mmdoc](Mmdoc() for _ in range(N)) + batch.tensor = jnp.zeros((N, 3, 224, 224)) + + batch = batch.to_doc_vec() + + jax_fn = jit(basic_jax_fn) + result = jax_fn(abstract_JaxArray(batch.tensor)) + + assert ( + result == 0.0 + ) # checking if the sum of the tensor data is zero as initialized From b7b81d65dc7fff8aa650c7af1fa57668b29db140 Mon Sep 17 00:00:00 2001 From: agaraman0 Date: Fri, 14 Jul 2023 18:02:06 +0530 Subject: [PATCH 15/25] fix poetry lock update Signed-off-by: agaraman0 --- poetry.lock | 398 ++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 367 insertions(+), 31 deletions(-) diff --git a/poetry.lock b/poetry.lock index b8b2a97e009..dd8e1b1ef3f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,9 +1,10 @@ -# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.4.1 and should not be changed by hand. [[package]] name = "aiohttp" version = "3.8.4" description = "Async http client/server framework (asyncio)" +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -112,6 +113,7 @@ speedups = ["Brotli", "aiodns", "cchardet"] name = "aiosignal" version = "1.3.1" description = "aiosignal: a list of registered asynchronous callbacks" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -126,6 +128,7 @@ frozenlist = ">=1.1.0" name = "anyio" version = "3.6.2" description = "High level compatibility layer for multiple asynchronous event loop implementations" +category = "main" optional = false python-versions = ">=3.6.2" files = [ @@ -146,6 +149,7 @@ trio = ["trio (>=0.16,<0.22)"] name = "appnope" version = "0.1.3" description = "Disable App Nap on macOS >= 10.9" +category = "dev" optional = false python-versions = "*" files = [ @@ -157,6 +161,7 @@ files = [ name = "argon2-cffi" version = "21.3.0" description = "The secure Argon2 password hashing algorithm." +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -176,6 +181,7 @@ tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pytest"] name = "argon2-cffi-bindings" version = "21.2.0" description = "Low-level CFFI bindings for Argon2" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -213,6 +219,7 @@ tests = ["pytest"] name = "async-timeout" version = "4.0.2" description = "Timeout context manager for asyncio programs" +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -224,6 +231,7 @@ files = [ name = "attrs" version = "22.1.0" description = "Classes Without Boilerplate" +category = "main" optional = false python-versions = ">=3.5" files = [ @@ -241,6 +249,7 @@ tests-no-zope = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy name = "authlib" version = "1.2.0" description = "The ultimate Python library in building OAuth and OpenID Connect servers and clients." +category = "main" optional = true python-versions = "*" files = [ @@ -255,6 +264,7 @@ cryptography = ">=3.2" name = "av" version = "10.0.0" description = "Pythonic bindings for FFmpeg's libraries." +category = "main" optional = true python-versions = "*" files = [ @@ -308,6 +318,7 @@ files = [ name = "babel" version = "2.11.0" description = "Internationalization utilities" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -322,6 +333,7 @@ pytz = ">=2015.7" name = "backcall" version = "0.2.0" description = "Specifications for callback functions passed in to an API" +category = "dev" optional = false python-versions = "*" files = [ @@ -333,6 +345,7 @@ files = [ name = "beautifulsoup4" version = "4.11.1" description = "Screen-scraping library" +category = "dev" optional = false python-versions = ">=3.6.0" files = [ @@ -351,6 +364,7 @@ lxml = ["lxml"] name = "black" version = "22.10.0" description = "The uncompromising code formatter." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -395,6 +409,7 @@ uvloop = ["uvloop (>=0.15.2)"] name = "blacken-docs" version = "1.13.0" description = "Run Black on Python code blocks in documentation files." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -409,6 +424,7 @@ black = ">=22.1.0" name = "bleach" version = "5.0.1" description = "An easy safelist-based HTML-sanitizing tool." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -428,6 +444,7 @@ dev = ["Sphinx (==4.3.2)", "black (==22.3.0)", "build (==0.8.0)", "flake8 (==4.0 name = "boto3" version = "1.26.95" description = "The AWS SDK for Python" +category = "main" optional = true python-versions = ">= 3.7" files = [ @@ -447,6 +464,7 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] name = "botocore" version = "1.29.95" description = "Low-level, data-driven core of boto 3." +category = "main" optional = true python-versions = ">= 3.7" files = [ @@ -466,6 +484,7 @@ crt = ["awscrt (==0.16.9)"] name = "bracex" version = "2.3.post1" description = "Bash style brace expander." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -477,6 +496,7 @@ files = [ name = "certifi" version = "2022.9.24" description = "Python package for providing Mozilla's CA Bundle." +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -488,6 +508,7 @@ files = [ name = "cffi" version = "1.15.1" description = "Foreign Function Interface for Python calling C code." +category = "main" optional = false python-versions = "*" files = [ @@ -564,6 +585,7 @@ pycparser = "*" name = "cfgv" version = "3.3.1" description = "Validate configuration and produce human readable error messages." +category = "dev" optional = false python-versions = ">=3.6.1" files = [ @@ -575,6 +597,7 @@ files = [ name = "chardet" version = "5.1.0" description = "Universal encoding detector for Python 3" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -586,6 +609,7 @@ files = [ name = "charset-normalizer" version = "2.0.12" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." +category = "main" optional = false python-versions = ">=3.5.0" files = [ @@ -600,6 +624,7 @@ unicode-backport = ["unicodedata2"] name = "click" version = "8.1.3" description = "Composable command line interface toolkit" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -614,6 +639,7 @@ colorama = {version = "*", markers = "platform_system == \"Windows\""} name = "colorama" version = "0.4.6" description = "Cross-platform colored terminal text." +category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" files = [ @@ -625,6 +651,7 @@ files = [ name = "colorlog" version = "6.7.0" description = "Add colours to the output of Python's logging module." +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -642,6 +669,7 @@ development = ["black", "flake8", "mypy", "pytest", "types-colorama"] name = "commonmark" version = "0.9.1" description = "Python parser for the CommonMark Markdown spec" +category = "main" optional = false python-versions = "*" files = [ @@ -656,6 +684,7 @@ test = ["flake8 (==3.7.8)", "hypothesis (==3.55.3)"] name = "coverage" version = "6.2" description = "Code coverage measurement for Python" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -718,6 +747,7 @@ toml = ["tomli"] name = "cryptography" version = "40.0.1" description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers." +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -759,6 +789,7 @@ tox = ["tox"] name = "debugpy" version = "1.6.3" description = "An implementation of the Debug Adapter Protocol for Python" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -786,6 +817,7 @@ files = [ name = "decorator" version = "5.1.1" description = "Decorators for Humans" +category = "main" optional = false python-versions = ">=3.5" files = [ @@ -797,6 +829,7 @@ files = [ name = "defusedxml" version = "0.7.1" description = "XML bomb protection for Python stdlib modules" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ @@ -808,6 +841,7 @@ files = [ name = "distlib" version = "0.3.6" description = "Distribution utilities" +category = "dev" optional = false python-versions = "*" files = [ @@ -819,6 +853,7 @@ files = [ name = "docker" version = "6.0.1" description = "A Python library for the Docker Engine API." +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -840,6 +875,7 @@ ssh = ["paramiko (>=2.4.3)"] name = "ecdsa" version = "0.18.0" description = "ECDSA cryptographic signature library (pure python)" +category = "main" optional = true python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" files = [ @@ -858,6 +894,7 @@ gmpy2 = ["gmpy2"] name = "elastic-transport" version = "8.4.0" description = "Transport classes and utilities shared among Python Elastic client libraries" +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -876,6 +913,7 @@ develop = ["aiohttp", "mock", "pytest", "pytest-asyncio", "pytest-cov", "pytest- name = "elasticsearch" version = "7.10.1" description = "Python client for Elasticsearch" +category = "main" optional = true python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, <4" files = [ @@ -897,6 +935,7 @@ requests = ["requests (>=2.4.0,<3.0.0)"] name = "entrypoints" version = "0.4" description = "Discover and load entry points from installed packages." +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -908,6 +947,7 @@ files = [ name = "exceptiongroup" version = "1.1.0" description = "Backport of PEP 654 (exception groups)" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -922,6 +962,7 @@ test = ["pytest (>=6)"] name = "fastapi" version = "0.87.0" description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -943,6 +984,7 @@ test = ["anyio[trio] (>=3.2.1,<4.0.0)", "black (==22.8.0)", "coverage[toml] (>=6 name = "fastjsonschema" version = "2.16.2" description = "Fastest Python implementation of JSON schema" +category = "dev" optional = false python-versions = "*" files = [ @@ -957,6 +999,7 @@ devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benc name = "filelock" version = "3.8.0" description = "A platform independent file lock." +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -972,6 +1015,7 @@ testing = ["covdefaults (>=2.2)", "coverage (>=6.4.2)", "pytest (>=7.1.2)", "pyt name = "frozenlist" version = "1.3.3" description = "A list-like structure which implements collections.abc.MutableSequence" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -1055,6 +1099,7 @@ files = [ name = "ghp-import" version = "2.1.0" description = "Copy your docs directly to the gh-pages branch." +category = "dev" optional = false python-versions = "*" files = [ @@ -1072,6 +1117,7 @@ dev = ["flake8", "markdown", "twine", "wheel"] name = "griffe" version = "0.25.5" description = "Signatures for entire Python programs. Extract the structure, the frame, the skeleton of your project, to generate API documentation or find breaking changes in your API." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1089,6 +1135,7 @@ async = ["aiofiles (>=0.7,<1.0)"] name = "grpcio" version = "1.53.0" description = "HTTP/2-based RPC framework" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -1146,6 +1193,7 @@ protobuf = ["grpcio-tools (>=1.53.0)"] name = "grpcio-tools" version = "1.53.0" description = "Protobuf code generator for gRPC" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -1205,6 +1253,7 @@ setuptools = "*" name = "h11" version = "0.14.0" description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1216,6 +1265,7 @@ files = [ name = "h2" version = "4.1.0" description = "HTTP/2 State-Machine based protocol implementation" +category = "main" optional = true python-versions = ">=3.6.1" files = [ @@ -1231,6 +1281,7 @@ hyperframe = ">=6.0,<7" name = "hnswlib" version = "0.7.0" description = "hnswlib" +category = "main" optional = true python-versions = "*" files = [ @@ -1244,6 +1295,7 @@ numpy = "*" name = "hpack" version = "4.0.0" description = "Pure-Python HPACK header compression" +category = "main" optional = true python-versions = ">=3.6.1" files = [ @@ -1255,6 +1307,7 @@ files = [ name = "httpcore" version = "0.16.1" description = "A minimal low-level HTTP client." +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1266,16 +1319,17 @@ files = [ anyio = ">=3.0,<5.0" certifi = "*" h11 = ">=0.13,<0.15" -sniffio = "==1.*" +sniffio = ">=1.0.0,<2.0.0" [package.extras] http2 = ["h2 (>=3,<5)"] -socks = ["socksio (==1.*)"] +socks = ["socksio (>=1.0.0,<2.0.0)"] [[package]] name = "httpx" version = "0.23.1" description = "The next generation HTTP client." +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1292,14 +1346,15 @@ sniffio = "*" [package.extras] brotli = ["brotli", "brotlicffi"] -cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<13)"] +cli = ["click (>=8.0.0,<9.0.0)", "pygments (>=2.0.0,<3.0.0)", "rich (>=10,<13)"] http2 = ["h2 (>=3,<5)"] -socks = ["socksio (==1.*)"] +socks = ["socksio (>=1.0.0,<2.0.0)"] [[package]] name = "hyperframe" version = "6.0.1" description = "HTTP/2 framing layer for Python" +category = "main" optional = true python-versions = ">=3.6.1" files = [ @@ -1311,6 +1366,7 @@ files = [ name = "identify" version = "2.5.8" description = "File identification library for Python" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1325,6 +1381,7 @@ license = ["ukkonen"] name = "idna" version = "3.4" description = "Internationalized Domain Names in Applications (IDNA)" +category = "main" optional = false python-versions = ">=3.5" files = [ @@ -1336,6 +1393,7 @@ files = [ name = "importlib-metadata" version = "5.0.0" description = "Read metadata from Python packages" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1355,6 +1413,7 @@ testing = ["flake8 (<5)", "flufl.flake8", "importlib-resources (>=1.3)", "packag name = "importlib-resources" version = "5.10.0" description = "Read resources from Python packages" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1373,6 +1432,7 @@ testing = ["flake8 (<5)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-chec name = "iniconfig" version = "1.1.1" description = "iniconfig: brain-dead simple config-ini parsing" +category = "dev" optional = false python-versions = "*" files = [ @@ -1384,6 +1444,7 @@ files = [ name = "ipykernel" version = "6.16.2" description = "IPython Kernel for Jupyter" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1412,6 +1473,7 @@ test = ["flaky", "ipyparallel", "pre-commit", "pytest (>=7.0)", "pytest-cov", "p name = "ipython" version = "7.34.0" description = "IPython: Productive Interactive Computing" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1448,6 +1510,7 @@ test = ["ipykernel", "nbformat", "nose (>=0.10.1)", "numpy (>=1.17)", "pygments" name = "ipython-genutils" version = "0.2.0" description = "Vestigial utilities from IPython" +category = "dev" optional = false python-versions = "*" files = [ @@ -1459,6 +1522,7 @@ files = [ name = "isort" version = "5.11.5" description = "A Python utility / library to sort Python imports." +category = "dev" optional = false python-versions = ">=3.7.0" files = [ @@ -1472,10 +1536,42 @@ pipfile-deprecated-finder = ["pip-shims (>=0.5.2)", "pipreqs", "requirementslib" plugins = ["setuptools"] requirements-deprecated-finder = ["pip-api", "pipreqs"] +[[package]] +name = "jax" +version = "0.4.13" +description = "Differentiate, compile, and transform Numpy code." +category = "main" +optional = true +python-versions = ">=3.8" +files = [ + {file = "jax-0.4.13.tar.gz", hash = "sha256:03bfe6749dfe647f16f15f6616638adae6c4a7ca7167c75c21961ecfd3a3baaa"}, +] + +[package.dependencies] +importlib_metadata = {version = ">=4.6", markers = "python_version < \"3.10\""} +ml_dtypes = ">=0.1.0" +numpy = ">=1.21" +opt_einsum = "*" +scipy = ">=1.7" + +[package.extras] +australis = ["protobuf (>=3.13,<4)"] +ci = ["jaxlib (==0.4.12)"] +cpu = ["jaxlib (==0.4.13)"] +cuda = ["jaxlib (==0.4.13+cuda11.cudnn86)"] +cuda11-cudnn86 = ["jaxlib (==0.4.13+cuda11.cudnn86)"] +cuda11-local = ["jaxlib (==0.4.13+cuda11.cudnn86)"] +cuda11-pip = ["jaxlib (==0.4.13+cuda11.cudnn86)", "nvidia-cublas-cu11 (>=11.11)", "nvidia-cuda-cupti-cu11 (>=11.8)", "nvidia-cuda-nvcc-cu11 (>=11.8)", "nvidia-cuda-runtime-cu11 (>=11.8)", "nvidia-cudnn-cu11 (>=8.8)", "nvidia-cufft-cu11 (>=10.9)", "nvidia-cusolver-cu11 (>=11.4)", "nvidia-cusparse-cu11 (>=11.7)"] +cuda12-local = ["jaxlib (==0.4.13+cuda12.cudnn89)"] +cuda12-pip = ["jaxlib (==0.4.13+cuda12.cudnn89)", "nvidia-cublas-cu12", "nvidia-cuda-cupti-cu12", "nvidia-cuda-nvcc-cu12", "nvidia-cuda-runtime-cu12", "nvidia-cudnn-cu12 (>=8.9)", "nvidia-cufft-cu12", "nvidia-cusolver-cu12", "nvidia-cusparse-cu12"] +minimum-jaxlib = ["jaxlib (==0.4.11)"] +tpu = ["jaxlib (==0.4.13)", "libtpu-nightly (==0.1.dev20230622)"] + [[package]] name = "jedi" version = "0.18.1" description = "An autocompletion tool for Python that can be used for text editors." +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -1494,6 +1590,7 @@ testing = ["Django (<3.1)", "colorama", "docopt", "pytest (<7.0.0)"] name = "jina-hubble-sdk" version = "0.34.0" description = "SDK for Hubble API at Jina AI." +category = "main" optional = true python-versions = ">=3.7.0" files = [ @@ -1519,6 +1616,7 @@ full = ["aiohttp", "black (==22.3.0)", "docker", "filelock", "flake8 (==4.0.1)", name = "jinja2" version = "3.1.2" description = "A very fast and expressive template engine." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1536,6 +1634,7 @@ i18n = ["Babel (>=2.7)"] name = "jmespath" version = "1.0.1" description = "JSON Matching Expressions" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -1547,6 +1646,7 @@ files = [ name = "json5" version = "0.9.10" description = "A Python implementation of the JSON5 data format." +category = "dev" optional = false python-versions = "*" files = [ @@ -1561,6 +1661,7 @@ dev = ["hypothesis"] name = "jsonschema" version = "4.17.0" description = "An implementation of JSON Schema validation for Python" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1582,6 +1683,7 @@ format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339- name = "jupyter-client" version = "7.4.6" description = "Jupyter protocol implementation and client libraries" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1606,6 +1708,7 @@ test = ["codecov", "coverage", "ipykernel (>=6.12)", "ipython", "mypy", "pre-com name = "jupyter-core" version = "4.12.0" description = "Jupyter core package. A base package on which Jupyter projects rely." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1624,6 +1727,7 @@ test = ["ipykernel", "pre-commit", "pytest", "pytest-cov", "pytest-timeout"] name = "jupyter-server" version = "1.23.2" description = "The backend—i.e. core services, APIs, and REST endpoints—to Jupyter web applications." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1656,6 +1760,7 @@ test = ["coverage", "ipykernel", "pre-commit", "pytest (>=7.0)", "pytest-console name = "jupyterlab" version = "3.5.0" description = "JupyterLab computational environment" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1683,6 +1788,7 @@ ui-tests = ["build"] name = "jupyterlab-pygments" version = "0.2.2" description = "Pygments theme using JupyterLab CSS variables" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1694,6 +1800,7 @@ files = [ name = "jupyterlab-server" version = "2.16.3" description = "A set of server components for JupyterLab and JupyterLab like applications." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1720,6 +1827,7 @@ test = ["codecov", "ipykernel", "jupyter-server[test]", "openapi-core (>=0.14.2, name = "lxml" version = "4.9.2" description = "Powerful and Pythonic XML processing library combining libxml2/libxslt with the ElementTree API." +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, != 3.4.*" files = [ @@ -1812,6 +1920,7 @@ source = ["Cython (>=0.29.7)"] name = "lz4" version = "4.3.2" description = "LZ4 Bindings for Python" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -1861,6 +1970,7 @@ tests = ["psutil", "pytest (!=3.3.0)", "pytest-cov"] name = "mapbox-earcut" version = "1.0.1" description = "Python bindings for the mapbox earcut C++ polygon triangulation library." +category = "main" optional = true python-versions = "*" files = [ @@ -1935,6 +2045,7 @@ test = ["pytest"] name = "markdown" version = "3.3.7" description = "Python implementation of Markdown." +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -1952,6 +2063,7 @@ testing = ["coverage", "pyyaml"] name = "markupsafe" version = "2.1.1" description = "Safely add untrusted strings to HTML/XML markup." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2001,6 +2113,7 @@ files = [ name = "matplotlib-inline" version = "0.1.6" description = "Inline Matplotlib backend for Jupyter" +category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -2015,6 +2128,7 @@ traitlets = "*" name = "mergedeep" version = "1.3.4" description = "A deep merge function for 🐍." +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -2026,6 +2140,7 @@ files = [ name = "mistune" version = "2.0.4" description = "A sane Markdown parser with useful plugins and renderers" +category = "dev" optional = false python-versions = "*" files = [ @@ -2037,6 +2152,7 @@ files = [ name = "mkdocs" version = "1.4.2" description = "Project documentation with Markdown." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2065,6 +2181,7 @@ min-versions = ["babel (==2.9.0)", "click (==7.0)", "colorama (==0.4)", "ghp-imp name = "mkdocs-autorefs" version = "0.4.1" description = "Automatically link across pages in MkDocs." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2080,6 +2197,7 @@ mkdocs = ">=1.1" name = "mkdocs-awesome-pages-plugin" version = "2.8.0" description = "An MkDocs plugin that simplifies configuring page titles and their order" +category = "dev" optional = false python-versions = ">=3.6.2" files = [ @@ -2096,6 +2214,7 @@ wcmatch = ">=7" name = "mkdocs-material" version = "9.1.3" description = "Documentation that simply works" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2118,6 +2237,7 @@ requests = ">=2.26" name = "mkdocs-material-extensions" version = "1.1.1" description = "Extension pack for Python Markdown and MkDocs Material." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2129,6 +2249,7 @@ files = [ name = "mkdocs-video" version = "1.5.0" description = "" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -2144,6 +2265,7 @@ mkdocs = ">=1.1.0,<2" name = "mkdocstrings" version = "0.20.0" description = "Automatic documentation from sources, for MkDocs." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2169,6 +2291,7 @@ python-legacy = ["mkdocstrings-python-legacy (>=0.2.1)"] name = "mkdocstrings-python" version = "0.8.3" description = "A Python handler for mkdocstrings." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2184,6 +2307,7 @@ mkdocstrings = ">=0.19" name = "mktestdocs" version = "0.2.0" description = "" +category = "dev" optional = false python-versions = "*" files = [ @@ -2194,10 +2318,48 @@ files = [ [package.extras] test = ["pytest (>=4.0.2)"] +[[package]] +name = "ml-dtypes" +version = "0.2.0" +description = "" +category = "main" +optional = true +python-versions = ">=3.7" +files = [ + {file = "ml_dtypes-0.2.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:df6a76e1c8adf484feb138ed323f9f40a7b6c21788f120f7c78bec20ac37ee81"}, + {file = "ml_dtypes-0.2.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bc29a0524ef5e23a7fbb8d881bdecabeb3fc1d19d9db61785d077a86cb94fab2"}, + {file = "ml_dtypes-0.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f08c391c2794f2aad358e6f4c70785a9a7b1df980ef4c232b3ccd4f6fe39f719"}, + {file = "ml_dtypes-0.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:75015818a7fccf99a5e8ed18720cb430f3e71a8838388840f4cdf225c036c983"}, + {file = "ml_dtypes-0.2.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:e70047ec2c83eaee01afdfdabee2c5b0c133804d90d0f7db4dd903360fcc537c"}, + {file = "ml_dtypes-0.2.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:36d28b8861a8931695e5a31176cad5ae85f6504906650dea5598fbec06c94606"}, + {file = "ml_dtypes-0.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e85ba8e24cf48d456e564688e981cf379d4c8e644db0a2f719b78de281bac2ca"}, + {file = "ml_dtypes-0.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:832a019a1b6db5c4422032ca9940a990fa104eee420f643713241b3a518977fa"}, + {file = "ml_dtypes-0.2.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:8faaf0897942c8253dd126662776ba45f0a5861968cf0f06d6d465f8a7bc298a"}, + {file = "ml_dtypes-0.2.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:35b984cddbe8173b545a0e3334fe56ea1a5c3eb67c507f60d0cfde1d3fa8f8c2"}, + {file = "ml_dtypes-0.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:022d5a4ee6be14569c2a9d1549e16f1ec87ca949681d0dca59995445d5fcdd5b"}, + {file = "ml_dtypes-0.2.0-cp38-cp38-win_amd64.whl", hash = "sha256:50845af3e9a601810751b55091dee6c2562403fa1cb4e0123675cf3a4fc2c17a"}, + {file = "ml_dtypes-0.2.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:f00c71c8c63e03aff313bc6a7aeaac9a4f1483a921a6ffefa6d4404efd1af3d0"}, + {file = "ml_dtypes-0.2.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:80d304c836d73f10605c58ccf7789c171cc229bfb678748adfb7cea2510dfd0e"}, + {file = "ml_dtypes-0.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:32107e7fa9f62db9a5281de923861325211dfff87bd23faefb27b303314635ab"}, + {file = "ml_dtypes-0.2.0-cp39-cp39-win_amd64.whl", hash = "sha256:1749b60348da71fd3c2ab303fdbc1965958dc50775ead41f5669c932a341cafd"}, + {file = "ml_dtypes-0.2.0.tar.gz", hash = "sha256:6488eb642acaaf08d8020f6de0a38acee7ac324c1e6e92ee0c0fea42422cb797"}, +] + +[package.dependencies] +numpy = [ + {version = ">1.20", markers = "python_version <= \"3.9\""}, + {version = ">=1.23.3", markers = "python_version > \"3.10\""}, + {version = ">=1.21.2", markers = "python_version > \"3.9\""}, +] + +[package.extras] +dev = ["absl-py", "pyink", "pylint (>=2.6.0)", "pytest", "pytest-xdist"] + [[package]] name = "mpmath" version = "1.3.0" description = "Python library for arbitrary-precision floating-point arithmetic" +category = "main" optional = true python-versions = "*" files = [ @@ -2215,6 +2377,7 @@ tests = ["pytest (>=4.6)"] name = "multidict" version = "6.0.4" description = "multidict implementation" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -2298,6 +2461,7 @@ files = [ name = "mypy" version = "1.0.0" description = "Optional static typing for Python" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2344,6 +2508,7 @@ reports = ["lxml"] name = "mypy-extensions" version = "0.4.3" description = "Experimental type system extensions for programs checked with the mypy typechecker." +category = "main" optional = false python-versions = "*" files = [ @@ -2355,6 +2520,7 @@ files = [ name = "natsort" version = "8.3.1" description = "Simple yet flexible natural sorting in Python." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2370,6 +2536,7 @@ icu = ["PyICU (>=1.0.0)"] name = "nbclassic" version = "0.4.8" description = "A web-based notebook environment for interactive computing" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2405,6 +2572,7 @@ test = ["coverage", "nbval", "pytest", "pytest-cov", "pytest-playwright", "pytes name = "nbclient" version = "0.7.0" description = "A client library for executing notebooks. Formerly nbconvert's ExecutePreprocessor." +category = "dev" optional = false python-versions = ">=3.7.0" files = [ @@ -2426,6 +2594,7 @@ test = ["black", "check-manifest", "flake8", "ipykernel", "ipython", "ipywidgets name = "nbconvert" version = "7.2.5" description = "Converting Jupyter Notebooks" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2464,6 +2633,7 @@ webpdf = ["pyppeteer (>=1,<1.1)"] name = "nbformat" version = "5.7.0" description = "The Jupyter Notebook format" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2484,6 +2654,7 @@ test = ["check-manifest", "pep440", "pre-commit", "pytest", "testpath"] name = "nest-asyncio" version = "1.5.6" description = "Patch asyncio to allow nested event loops" +category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -2495,6 +2666,7 @@ files = [ name = "networkx" version = "2.6.3" description = "Python package for creating and manipulating graphs and networks" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -2513,6 +2685,7 @@ test = ["codecov (>=2.1)", "pytest (>=6.2)", "pytest-cov (>=2.12)"] name = "nodeenv" version = "1.7.0" description = "Node.js virtual environment builder" +category = "dev" optional = false python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*" files = [ @@ -2527,6 +2700,7 @@ setuptools = "*" name = "notebook" version = "6.5.2" description = "A web-based notebook environment for interactive computing" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2561,6 +2735,7 @@ test = ["coverage", "nbval", "pytest", "pytest-cov", "requests", "requests-unixs name = "notebook-shim" version = "0.2.2" description = "A shim layer for notebook traits and config" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2578,6 +2753,7 @@ test = ["pytest", "pytest-console-scripts", "pytest-tornasync"] name = "numpy" version = "1.21.1" description = "NumPy is the fundamental package for array computing with Python." +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2611,10 +2787,49 @@ files = [ {file = "numpy-1.21.1.zip", hash = "sha256:dff4af63638afcc57a3dfb9e4b26d434a7a602d225b42d746ea7fe2edf1342fd"}, ] +[[package]] +name = "numpy" +version = "1.24.4" +description = "Fundamental package for array computing in Python" +category = "main" +optional = false +python-versions = ">=3.8" +files = [ + {file = "numpy-1.24.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c0bfb52d2169d58c1cdb8cc1f16989101639b34c7d3ce60ed70b19c63eba0b64"}, + {file = "numpy-1.24.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ed094d4f0c177b1b8e7aa9cba7d6ceed51c0e569a5318ac0ca9a090680a6a1b1"}, + {file = "numpy-1.24.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:79fc682a374c4a8ed08b331bef9c5f582585d1048fa6d80bc6c35bc384eee9b4"}, + {file = "numpy-1.24.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ffe43c74893dbf38c2b0a1f5428760a1a9c98285553c89e12d70a96a7f3a4d6"}, + {file = "numpy-1.24.4-cp310-cp310-win32.whl", hash = "sha256:4c21decb6ea94057331e111a5bed9a79d335658c27ce2adb580fb4d54f2ad9bc"}, + {file = "numpy-1.24.4-cp310-cp310-win_amd64.whl", hash = "sha256:b4bea75e47d9586d31e892a7401f76e909712a0fd510f58f5337bea9572c571e"}, + {file = "numpy-1.24.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f136bab9c2cfd8da131132c2cf6cc27331dd6fae65f95f69dcd4ae3c3639c810"}, + {file = "numpy-1.24.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e2926dac25b313635e4d6cf4dc4e51c8c0ebfed60b801c799ffc4c32bf3d1254"}, + {file = "numpy-1.24.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:222e40d0e2548690405b0b3c7b21d1169117391c2e82c378467ef9ab4c8f0da7"}, + {file = "numpy-1.24.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7215847ce88a85ce39baf9e89070cb860c98fdddacbaa6c0da3ffb31b3350bd5"}, + {file = "numpy-1.24.4-cp311-cp311-win32.whl", hash = "sha256:4979217d7de511a8d57f4b4b5b2b965f707768440c17cb70fbf254c4b225238d"}, + {file = "numpy-1.24.4-cp311-cp311-win_amd64.whl", hash = "sha256:b7b1fc9864d7d39e28f41d089bfd6353cb5f27ecd9905348c24187a768c79694"}, + {file = "numpy-1.24.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1452241c290f3e2a312c137a9999cdbf63f78864d63c79039bda65ee86943f61"}, + {file = "numpy-1.24.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:04640dab83f7c6c85abf9cd729c5b65f1ebd0ccf9de90b270cd61935eef0197f"}, + {file = "numpy-1.24.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a5425b114831d1e77e4b5d812b69d11d962e104095a5b9c3b641a218abcc050e"}, + {file = "numpy-1.24.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd80e219fd4c71fc3699fc1dadac5dcf4fd882bfc6f7ec53d30fa197b8ee22dc"}, + {file = "numpy-1.24.4-cp38-cp38-win32.whl", hash = "sha256:4602244f345453db537be5314d3983dbf5834a9701b7723ec28923e2889e0bb2"}, + {file = "numpy-1.24.4-cp38-cp38-win_amd64.whl", hash = "sha256:692f2e0f55794943c5bfff12b3f56f99af76f902fc47487bdfe97856de51a706"}, + {file = "numpy-1.24.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:2541312fbf09977f3b3ad449c4e5f4bb55d0dbf79226d7724211acc905049400"}, + {file = "numpy-1.24.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9667575fb6d13c95f1b36aca12c5ee3356bf001b714fc354eb5465ce1609e62f"}, + {file = "numpy-1.24.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f3a86ed21e4f87050382c7bc96571755193c4c1392490744ac73d660e8f564a9"}, + {file = "numpy-1.24.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d11efb4dbecbdf22508d55e48d9c8384db795e1b7b51ea735289ff96613ff74d"}, + {file = "numpy-1.24.4-cp39-cp39-win32.whl", hash = "sha256:6620c0acd41dbcb368610bb2f4d83145674040025e5536954782467100aa8835"}, + {file = "numpy-1.24.4-cp39-cp39-win_amd64.whl", hash = "sha256:befe2bf740fd8373cf56149a5c23a0f601e82869598d41f8e188a0e9869926f8"}, + {file = "numpy-1.24.4-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:31f13e25b4e304632a4619d0e0777662c2ffea99fcae2029556b17d8ff958aef"}, + {file = "numpy-1.24.4-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95f7ac6540e95bc440ad77f56e520da5bf877f87dca58bd095288dce8940532a"}, + {file = "numpy-1.24.4-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:e98f220aa76ca2a977fe435f5b04d7b3470c0a2e6312907b37ba6068f26787f2"}, + {file = "numpy-1.24.4.tar.gz", hash = "sha256:80f5e3a4e498641401868df4208b74581206afbee7cf7b8329daae82676d9463"}, +] + [[package]] name = "nvidia-cublas-cu11" version = "11.10.3.66" description = "CUBLAS native runtime libraries" +category = "main" optional = true python-versions = ">=3" files = [ @@ -2630,6 +2845,7 @@ wheel = "*" name = "nvidia-cuda-nvrtc-cu11" version = "11.7.99" description = "NVRTC native runtime libraries" +category = "main" optional = true python-versions = ">=3" files = [ @@ -2646,6 +2862,7 @@ wheel = "*" name = "nvidia-cuda-runtime-cu11" version = "11.7.99" description = "CUDA Runtime native Libraries" +category = "main" optional = true python-versions = ">=3" files = [ @@ -2661,6 +2878,7 @@ wheel = "*" name = "nvidia-cudnn-cu11" version = "8.5.0.96" description = "cuDNN runtime libraries" +category = "main" optional = true python-versions = ">=3" files = [ @@ -2672,10 +2890,30 @@ files = [ setuptools = "*" wheel = "*" +[[package]] +name = "opt-einsum" +version = "3.3.0" +description = "Optimizing numpys einsum function" +category = "main" +optional = true +python-versions = ">=3.5" +files = [ + {file = "opt_einsum-3.3.0-py3-none-any.whl", hash = "sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147"}, + {file = "opt_einsum-3.3.0.tar.gz", hash = "sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549"}, +] + +[package.dependencies] +numpy = ">=1.7" + +[package.extras] +docs = ["numpydoc", "sphinx (==1.2.3)", "sphinx-rtd-theme", "sphinxcontrib-napoleon"] +tests = ["pytest", "pytest-cov", "pytest-pep8"] + [[package]] name = "orjson" version = "3.8.2" description = "Fast, correct Python JSON library supporting dataclasses, datetimes, and numpy" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2734,6 +2972,7 @@ files = [ name = "packaging" version = "21.3" description = "Core utilities for Python packages" +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -2748,6 +2987,7 @@ pyparsing = ">=2.0.2,<3.0.5 || >3.0.5" name = "pandas" version = "1.1.0" description = "Powerful data structures for data analysis, time series, and statistics" +category = "main" optional = true python-versions = ">=3.6.1" files = [ @@ -2781,6 +3021,7 @@ test = ["hypothesis (>=3.58)", "pytest (>=4.0.2)", "pytest-xdist"] name = "pandocfilters" version = "1.5.0" description = "Utilities for writing pandoc filters in python" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -2792,6 +3033,7 @@ files = [ name = "parso" version = "0.8.3" description = "A Python Parser" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -2807,6 +3049,7 @@ testing = ["docopt", "pytest (<6.0.0)"] name = "pathspec" version = "0.10.2" description = "Utility library for gitignore style pattern matching of file paths." +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2818,6 +3061,7 @@ files = [ name = "pexpect" version = "4.8.0" description = "Pexpect allows easy control of interactive console applications." +category = "dev" optional = false python-versions = "*" files = [ @@ -2832,6 +3076,7 @@ ptyprocess = ">=0.5" name = "pickleshare" version = "0.7.5" description = "Tiny 'shelve'-like database with concurrency support" +category = "dev" optional = false python-versions = "*" files = [ @@ -2843,6 +3088,7 @@ files = [ name = "pillow" version = "9.3.0" description = "Python Imaging Library (Fork)" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -2917,6 +3163,7 @@ tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "pa name = "pkgutil-resolve-name" version = "1.3.10" description = "Resolve a name to an object." +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -2928,6 +3175,7 @@ files = [ name = "platformdirs" version = "2.5.4" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2943,6 +3191,7 @@ test = ["appdirs (==1.4.4)", "pytest (>=7.2)", "pytest-cov (>=4)", "pytest-mock name = "pluggy" version = "0.13.1" description = "plugin and hook calling mechanisms for python" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -2957,6 +3206,7 @@ dev = ["pre-commit", "tox"] name = "pre-commit" version = "2.20.0" description = "A framework for managing and maintaining multi-language pre-commit hooks." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2976,6 +3226,7 @@ virtualenv = ">=20.0.8" name = "prometheus-client" version = "0.15.0" description = "Python client for the Prometheus monitoring system." +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -2990,6 +3241,7 @@ twisted = ["twisted"] name = "prompt-toolkit" version = "3.0.32" description = "Library for building powerful interactive command lines in Python" +category = "dev" optional = false python-versions = ">=3.6.2" files = [ @@ -3004,6 +3256,7 @@ wcwidth = "*" name = "protobuf" version = "4.21.9" description = "" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -3027,6 +3280,7 @@ files = [ name = "psutil" version = "5.9.4" description = "Cross-platform lib for process and system monitoring in Python." +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -3053,6 +3307,7 @@ test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"] name = "ptyprocess" version = "0.7.0" description = "Run a subprocess in a pseudo terminal" +category = "dev" optional = false python-versions = "*" files = [ @@ -3064,6 +3319,7 @@ files = [ name = "py" version = "1.11.0" description = "library with cross-python path, ini-parsing, io, code, log facilities" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ @@ -3075,6 +3331,7 @@ files = [ name = "pyasn1" version = "0.4.8" description = "ASN.1 types and codecs" +category = "main" optional = true python-versions = "*" files = [ @@ -3086,6 +3343,7 @@ files = [ name = "pycollada" version = "0.7.2" description = "python library for reading and writing collada documents" +category = "main" optional = true python-versions = "*" files = [ @@ -3103,6 +3361,7 @@ validation = ["lxml"] name = "pycparser" version = "2.21" description = "C parser in Python" +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -3114,6 +3373,7 @@ files = [ name = "pydantic" version = "1.10.2" description = "Data validation and settings management using python type hints" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -3166,6 +3426,7 @@ email = ["email-validator (>=1.0.3)"] name = "pydub" version = "0.25.1" description = "Manipulate audio with an simple and easy high level interface" +category = "main" optional = true python-versions = "*" files = [ @@ -3177,6 +3438,7 @@ files = [ name = "pygments" version = "2.14.0" description = "Pygments is a syntax highlighting package written in Python." +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -3191,6 +3453,7 @@ plugins = ["importlib-metadata"] name = "pymdown-extensions" version = "9.10" description = "Extension pack for Python Markdown." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3206,6 +3469,7 @@ pyyaml = "*" name = "pyparsing" version = "3.0.9" description = "pyparsing module - Classes and methods to define and execute parsing grammars" +category = "main" optional = false python-versions = ">=3.6.8" files = [ @@ -3220,6 +3484,7 @@ diagrams = ["jinja2", "railroad-diagrams"] name = "pyrsistent" version = "0.19.2" description = "Persistent/Functional/Immutable data structures" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -3251,6 +3516,7 @@ files = [ name = "pytest" version = "7.2.1" description = "pytest: simple powerful testing with Python" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3274,6 +3540,7 @@ testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2. name = "pytest-asyncio" version = "0.20.2" description = "Pytest support for asyncio" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3291,6 +3558,7 @@ testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy name = "pytest-cov" version = "3.0.0" description = "Pytest plugin for measuring coverage." +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -3309,6 +3577,7 @@ testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtuale name = "python-dateutil" version = "2.8.2" description = "Extensions to the standard Python datetime module" +category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" files = [ @@ -3323,6 +3592,7 @@ six = ">=1.5" name = "python-jose" version = "3.3.0" description = "JOSE implementation in Python" +category = "main" optional = true python-versions = "*" files = [ @@ -3344,6 +3614,7 @@ pycryptodome = ["pyasn1", "pycryptodome (>=3.3.1,<4.0.0)"] name = "pytz" version = "2022.6" description = "World timezone definitions, modern and historical" +category = "main" optional = false python-versions = "*" files = [ @@ -3355,6 +3626,7 @@ files = [ name = "pywin32" version = "305" description = "Python for Window Extensions" +category = "main" optional = false python-versions = "*" files = [ @@ -3378,6 +3650,7 @@ files = [ name = "pywinpty" version = "2.0.9" description = "Pseudo terminal support for Windows from Python." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3393,6 +3666,7 @@ files = [ name = "pyyaml" version = "6.0" description = "YAML parser and emitter for Python" +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -3442,6 +3716,7 @@ files = [ name = "pyyaml-env-tag" version = "0.1" description = "A custom YAML tag for referencing environment variables in YAML files. " +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -3456,6 +3731,7 @@ pyyaml = "*" name = "pyzmq" version = "24.0.1" description = "Python bindings for 0MQ" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -3543,6 +3819,7 @@ py = {version = "*", markers = "implementation_name == \"pypy\""} name = "qdrant-client" version = "1.1.4" description = "Client library for the Qdrant vector search engine" +category = "main" optional = true python-versions = ">=3.7,<3.12" files = [ @@ -3563,6 +3840,7 @@ urllib3 = ">=1.26.14,<2.0.0" name = "redis" version = "4.6.0" description = "Python client for Redis database and key-value store" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -3581,6 +3859,7 @@ ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==20.0.1)", "requests (>=2.26.0)" name = "regex" version = "2022.10.31" description = "Alternative regular expression module, to replace re." +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -3678,6 +3957,7 @@ files = [ name = "requests" version = "2.28.2" description = "Python HTTP for Humans." +category = "main" optional = false python-versions = ">=3.7, <4" files = [ @@ -3699,6 +3979,7 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] name = "rfc3986" version = "1.5.0" description = "Validating URI References per RFC 3986" +category = "main" optional = false python-versions = "*" files = [ @@ -3716,6 +3997,7 @@ idna2008 = ["idna"] name = "rich" version = "13.1.0" description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" +category = "main" optional = false python-versions = ">=3.7.0" files = [ @@ -3735,6 +4017,7 @@ jupyter = ["ipywidgets (>=7.5.1,<8.0.0)"] name = "rsa" version = "4.9" description = "Pure-Python RSA implementation" +category = "main" optional = true python-versions = ">=3.6,<4" files = [ @@ -3749,6 +4032,7 @@ pyasn1 = ">=0.1.3" name = "rtree" version = "1.0.1" description = "R-Tree spatial index for Python GIS" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -3803,6 +4087,7 @@ files = [ name = "ruff" version = "0.0.243" description = "An extremely fast Python linter, written in Rust." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3828,6 +4113,7 @@ files = [ name = "s3transfer" version = "0.6.0" description = "An Amazon S3 Transfer Manager" +category = "main" optional = true python-versions = ">= 3.7" files = [ @@ -3843,39 +4129,48 @@ crt = ["botocore[crt] (>=1.20.29,<2.0a.0)"] [[package]] name = "scipy" -version = "1.6.1" -description = "SciPy: Scientific Library for Python" +version = "1.9.3" +description = "Fundamental algorithms for scientific computing in Python" +category = "main" optional = true -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "scipy-1.6.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:a15a1f3fc0abff33e792d6049161b7795909b40b97c6cc2934ed54384017ab76"}, - {file = "scipy-1.6.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:e79570979ccdc3d165456dd62041d9556fb9733b86b4b6d818af7a0afc15f092"}, - {file = "scipy-1.6.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:a423533c55fec61456dedee7b6ee7dce0bb6bfa395424ea374d25afa262be261"}, - {file = "scipy-1.6.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:33d6b7df40d197bdd3049d64e8e680227151673465e5d85723b3b8f6b15a6ced"}, - {file = "scipy-1.6.1-cp37-cp37m-win32.whl", hash = "sha256:6725e3fbb47da428794f243864f2297462e9ee448297c93ed1dcbc44335feb78"}, - {file = "scipy-1.6.1-cp37-cp37m-win_amd64.whl", hash = "sha256:5fa9c6530b1661f1370bcd332a1e62ca7881785cc0f80c0d559b636567fab63c"}, - {file = "scipy-1.6.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:bd50daf727f7c195e26f27467c85ce653d41df4358a25b32434a50d8870fc519"}, - {file = "scipy-1.6.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:f46dd15335e8a320b0fb4685f58b7471702234cba8bb3442b69a3e1dc329c345"}, - {file = "scipy-1.6.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:0e5b0ccf63155d90da576edd2768b66fb276446c371b73841e3503be1d63fb5d"}, - {file = "scipy-1.6.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:2481efbb3740977e3c831edfd0bd9867be26387cacf24eb5e366a6a374d3d00d"}, - {file = "scipy-1.6.1-cp38-cp38-win32.whl", hash = "sha256:68cb4c424112cd4be886b4d979c5497fba190714085f46b8ae67a5e4416c32b4"}, - {file = "scipy-1.6.1-cp38-cp38-win_amd64.whl", hash = "sha256:5f331eeed0297232d2e6eea51b54e8278ed8bb10b099f69c44e2558c090d06bf"}, - {file = "scipy-1.6.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0c8a51d33556bf70367452d4d601d1742c0e806cd0194785914daf19775f0e67"}, - {file = "scipy-1.6.1-cp39-cp39-manylinux1_i686.whl", hash = "sha256:83bf7c16245c15bc58ee76c5418e46ea1811edcc2e2b03041b804e46084ab627"}, - {file = "scipy-1.6.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:794e768cc5f779736593046c9714e0f3a5940bc6dcc1dba885ad64cbfb28e9f0"}, - {file = "scipy-1.6.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:5da5471aed911fe7e52b86bf9ea32fb55ae93e2f0fac66c32e58897cfb02fa07"}, - {file = "scipy-1.6.1-cp39-cp39-win32.whl", hash = "sha256:8e403a337749ed40af60e537cc4d4c03febddcc56cd26e774c9b1b600a70d3e4"}, - {file = "scipy-1.6.1-cp39-cp39-win_amd64.whl", hash = "sha256:a5193a098ae9f29af283dcf0041f762601faf2e595c0db1da929875b7570353f"}, - {file = "scipy-1.6.1.tar.gz", hash = "sha256:c4fceb864890b6168e79b0e714c585dbe2fd4222768ee90bc1aa0f8218691b11"}, + {file = "scipy-1.9.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:1884b66a54887e21addf9c16fb588720a8309a57b2e258ae1c7986d4444d3bc0"}, + {file = "scipy-1.9.3-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:83b89e9586c62e787f5012e8475fbb12185bafb996a03257e9675cd73d3736dd"}, + {file = "scipy-1.9.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1a72d885fa44247f92743fc20732ae55564ff2a519e8302fb7e18717c5355a8b"}, + {file = "scipy-1.9.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d01e1dd7b15bd2449c8bfc6b7cc67d630700ed655654f0dfcf121600bad205c9"}, + {file = "scipy-1.9.3-cp310-cp310-win_amd64.whl", hash = "sha256:68239b6aa6f9c593da8be1509a05cb7f9efe98b80f43a5861cd24c7557e98523"}, + {file = "scipy-1.9.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b41bc822679ad1c9a5f023bc93f6d0543129ca0f37c1ce294dd9d386f0a21096"}, + {file = "scipy-1.9.3-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:90453d2b93ea82a9f434e4e1cba043e779ff67b92f7a0e85d05d286a3625df3c"}, + {file = "scipy-1.9.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:83c06e62a390a9167da60bedd4575a14c1f58ca9dfde59830fc42e5197283dab"}, + {file = "scipy-1.9.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:abaf921531b5aeaafced90157db505e10345e45038c39e5d9b6c7922d68085cb"}, + {file = "scipy-1.9.3-cp311-cp311-win_amd64.whl", hash = "sha256:06d2e1b4c491dc7d8eacea139a1b0b295f74e1a1a0f704c375028f8320d16e31"}, + {file = "scipy-1.9.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:5a04cd7d0d3eff6ea4719371cbc44df31411862b9646db617c99718ff68d4840"}, + {file = "scipy-1.9.3-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:545c83ffb518094d8c9d83cce216c0c32f8c04aaf28b92cc8283eda0685162d5"}, + {file = "scipy-1.9.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d54222d7a3ba6022fdf5773931b5d7c56efe41ede7f7128c7b1637700409108"}, + {file = "scipy-1.9.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cff3a5295234037e39500d35316a4c5794739433528310e117b8a9a0c76d20fc"}, + {file = "scipy-1.9.3-cp38-cp38-win_amd64.whl", hash = "sha256:2318bef588acc7a574f5bfdff9c172d0b1bf2c8143d9582e05f878e580a3781e"}, + {file = "scipy-1.9.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:d644a64e174c16cb4b2e41dfea6af722053e83d066da7343f333a54dae9bc31c"}, + {file = "scipy-1.9.3-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:da8245491d73ed0a994ed9c2e380fd058ce2fa8a18da204681f2fe1f57f98f95"}, + {file = "scipy-1.9.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4db5b30849606a95dcf519763dd3ab6fe9bd91df49eba517359e450a7d80ce2e"}, + {file = "scipy-1.9.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c68db6b290cbd4049012990d7fe71a2abd9ffbe82c0056ebe0f01df8be5436b0"}, + {file = "scipy-1.9.3-cp39-cp39-win_amd64.whl", hash = "sha256:5b88e6d91ad9d59478fafe92a7c757d00c59e3bdc3331be8ada76a4f8d683f58"}, + {file = "scipy-1.9.3.tar.gz", hash = "sha256:fbc5c05c85c1a02be77b1ff591087c83bc44579c6d2bd9fb798bb64ea5e1a027"}, ] [package.dependencies] -numpy = ">=1.16.5" +numpy = ">=1.18.5,<1.26.0" + +[package.extras] +dev = ["flake8", "mypy", "pycodestyle", "typing_extensions"] +doc = ["matplotlib (>2)", "numpydoc", "pydata-sphinx-theme (==0.9.0)", "sphinx (!=4.1.0)", "sphinx-panels (>=0.5.2)", "sphinx-tabs"] +test = ["asv", "gmpy2", "mpmath", "pytest", "pytest-cov", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] [[package]] name = "send2trash" version = "1.8.0" description = "Send file to trash natively under Mac OS X, Windows and Linux." +category = "dev" optional = false python-versions = "*" files = [ @@ -3892,6 +4187,7 @@ win32 = ["pywin32"] name = "setuptools" version = "65.5.1" description = "Easily download, build, install, upgrade, and uninstall Python packages" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -3908,6 +4204,7 @@ testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs ( name = "shapely" version = "2.0.1" description = "Manipulation and analysis of geometric objects" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -3955,13 +4252,14 @@ files = [ numpy = ">=1.14" [package.extras] -docs = ["matplotlib", "numpydoc (==1.1.*)", "sphinx", "sphinx-book-theme", "sphinx-remove-toctrees"] +docs = ["matplotlib", "numpydoc (>=1.1.0,<1.2.0)", "sphinx", "sphinx-book-theme", "sphinx-remove-toctrees"] test = ["pytest", "pytest-cov"] [[package]] name = "six" version = "1.16.0" description = "Python 2 and 3 compatibility utilities" +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" files = [ @@ -3973,6 +4271,7 @@ files = [ name = "smart-open" version = "6.3.0" description = "Utils for streaming large files (S3, HDFS, GCS, Azure Blob Storage, gzip, bz2...)" +category = "main" optional = true python-versions = ">=3.6,<4.0" files = [ @@ -3997,6 +4296,7 @@ webhdfs = ["requests"] name = "sniffio" version = "1.3.0" description = "Sniff out which async library your code is running under" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -4008,6 +4308,7 @@ files = [ name = "soupsieve" version = "2.3.2.post1" description = "A modern CSS selector implementation for Beautiful Soup." +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -4019,6 +4320,7 @@ files = [ name = "starlette" version = "0.21.0" description = "The little ASGI library that shines." +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -4037,6 +4339,7 @@ full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart", "pyyam name = "svg-path" version = "6.2" description = "SVG path objects and parser" +category = "main" optional = true python-versions = "*" files = [ @@ -4051,6 +4354,7 @@ test = ["Pillow", "pytest", "pytest-cov"] name = "sympy" version = "1.10.1" description = "Computer algebra system (CAS) in Python" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -4065,6 +4369,7 @@ mpmath = ">=0.19" name = "terminado" version = "0.17.0" description = "Tornado websocket backend for the Xterm.js Javascript terminal emulator library." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4085,6 +4390,7 @@ test = ["pre-commit", "pytest (>=7.0)", "pytest-timeout"] name = "tinycss2" version = "1.2.1" description = "A tiny CSS parser" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4103,6 +4409,7 @@ test = ["flake8", "isort", "pytest"] name = "toml" version = "0.10.2" description = "Python Library for Tom's Obvious, Minimal Language" +category = "dev" optional = false python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" files = [ @@ -4114,6 +4421,7 @@ files = [ name = "tomli" version = "2.0.1" description = "A lil' TOML parser" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4125,6 +4433,7 @@ files = [ name = "torch" version = "1.13.0" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" +category = "main" optional = true python-versions = ">=3.7.0" files = [ @@ -4165,6 +4474,7 @@ opt-einsum = ["opt-einsum (>=3.3)"] name = "tornado" version = "6.2" description = "Tornado is a Python web framework and asynchronous networking library, originally developed at FriendFeed." +category = "dev" optional = false python-versions = ">= 3.7" files = [ @@ -4185,6 +4495,7 @@ files = [ name = "tqdm" version = "4.65.0" description = "Fast, Extensible Progress Meter" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -4205,6 +4516,7 @@ telegram = ["requests"] name = "traitlets" version = "5.5.0" description = "" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4220,6 +4532,7 @@ test = ["pre-commit", "pytest"] name = "trimesh" version = "3.21.2" description = "Import, export, process, analyze and view triangular meshes." +category = "main" optional = true python-versions = "*" files = [ @@ -4255,6 +4568,7 @@ test = ["autopep8", "coveralls", "ezdxf", "pyinstrument", "pytest", "pytest-cov" name = "types-pillow" version = "9.3.0.1" description = "Typing stubs for Pillow" +category = "main" optional = true python-versions = "*" files = [ @@ -4266,6 +4580,7 @@ files = [ name = "types-protobuf" version = "3.20.4.5" description = "Typing stubs for protobuf" +category = "dev" optional = false python-versions = "*" files = [ @@ -4277,6 +4592,7 @@ files = [ name = "types-pyopenssl" version = "23.2.0.1" description = "Typing stubs for pyOpenSSL" +category = "dev" optional = false python-versions = "*" files = [ @@ -4291,6 +4607,7 @@ cryptography = ">=35.0.0" name = "types-redis" version = "4.6.0.0" description = "Typing stubs for redis" +category = "dev" optional = false python-versions = "*" files = [ @@ -4306,6 +4623,7 @@ types-pyOpenSSL = "*" name = "types-requests" version = "2.28.11.7" description = "Typing stubs for requests" +category = "main" optional = false python-versions = "*" files = [ @@ -4320,6 +4638,7 @@ types-urllib3 = "<1.27" name = "types-urllib3" version = "1.26.25.4" description = "Typing stubs for urllib3" +category = "main" optional = false python-versions = "*" files = [ @@ -4331,6 +4650,7 @@ files = [ name = "typing-extensions" version = "4.4.0" description = "Backported and Experimental Type Hints for Python 3.7+" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -4342,6 +4662,7 @@ files = [ name = "typing-inspect" version = "0.8.0" description = "Runtime inspection utilities for typing module." +category = "main" optional = false python-versions = "*" files = [ @@ -4357,6 +4678,7 @@ typing-extensions = ">=3.7.4" name = "urllib3" version = "1.26.14" description = "HTTP library with thread-safe connection pooling, file post, and more." +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" files = [ @@ -4373,6 +4695,7 @@ socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] name = "uvicorn" version = "0.19.0" description = "The lightning-fast ASGI server." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4391,6 +4714,7 @@ standard = ["colorama (>=0.4)", "httptools (>=0.5.0)", "python-dotenv (>=0.13)", name = "validators" version = "0.20.0" description = "Python Data Validation for Humans™." +category = "main" optional = true python-versions = ">=3.4" files = [ @@ -4407,6 +4731,7 @@ test = ["flake8 (>=2.4.0)", "isort (>=4.2.2)", "pytest (>=2.2.3)"] name = "virtualenv" version = "20.16.7" description = "Virtual Python Environment builder" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -4427,6 +4752,7 @@ testing = ["coverage (>=6.2)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7 name = "watchdog" version = "2.3.1" description = "Filesystem events monitoring" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -4467,6 +4793,7 @@ watchmedo = ["PyYAML (>=3.10)"] name = "wcmatch" version = "8.4.1" description = "Wildcard/glob file name matcher." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4481,6 +4808,7 @@ bracex = ">=2.1.1" name = "wcwidth" version = "0.2.5" description = "Measures the displayed width of unicode strings in a terminal" +category = "dev" optional = false python-versions = "*" files = [ @@ -4492,6 +4820,7 @@ files = [ name = "weaviate-client" version = "3.17.1" description = "A python native weaviate client" +category = "main" optional = true python-versions = ">=3.8" files = [ @@ -4512,6 +4841,7 @@ grpc = ["grpcio", "grpcio-tools"] name = "webencodings" version = "0.5.1" description = "Character encoding aliases for legacy web content" +category = "dev" optional = false python-versions = "*" files = [ @@ -4523,6 +4853,7 @@ files = [ name = "websocket-client" version = "1.4.2" description = "WebSocket client for Python with low level API options" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -4539,6 +4870,7 @@ test = ["websockets"] name = "wheel" version = "0.38.4" description = "A built-package format for Python" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -4553,6 +4885,7 @@ test = ["pytest (>=3.0.0)"] name = "xxhash" version = "3.2.0" description = "Python binding for xxHash" +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -4660,6 +4993,7 @@ files = [ name = "yarl" version = "1.8.2" description = "Yet another URL library" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -4747,6 +5081,7 @@ multidict = ">=4.0" name = "zipp" version = "3.10.0" description = "Backport of pathlib-compatible object wrapper for zip files" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -4762,10 +5097,11 @@ testing = ["flake8 (<5)", "func-timeout", "jaraco.functools", "jaraco.itertools" audio = ["pydub"] aws = ["smart-open"] elasticsearch = ["elastic-transport", "elasticsearch"] -full = ["av", "lz4", "pandas", "pillow", "protobuf", "pydub", "trimesh", "types-pillow"] +full = ["av", "jax", "lz4", "pandas", "pillow", "protobuf", "pydub", "trimesh", "types-pillow"] hnswlib = ["hnswlib", "protobuf"] image = ["pillow", "types-pillow"] jac = ["jina-hubble-sdk"] +jax = ["jax"] mesh = ["trimesh"] pandas = ["pandas"] proto = ["lz4", "protobuf"] @@ -4779,4 +5115,4 @@ web = ["fastapi"] [metadata] lock-version = "2.0" python-versions = ">=3.8,<4.0" -content-hash = "e98157b56ee51d21d5861108878b27420613a9d43d819cce5c3adade89c6c440" +content-hash = "7b92f58355832b250432c909539267349a32496c47e7ee5fa5fddfc59b843d90" From 963c1b3c5c07a3198faf7c3b8d00b7670d287391 Mon Sep 17 00:00:00 2001 From: agaraman0 Date: Fri, 14 Jul 2023 18:29:08 +0530 Subject: [PATCH 16/25] fix: test_jax_integration changes Signed-off-by: agaraman0 --- .../array/test_jax_integration.py | 25 ++++++++++--------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/tests/integrations/array/test_jax_integration.py b/tests/integrations/array/test_jax_integration.py index 00488e349c3..de22c1e9da1 100644 --- a/tests/integrations/array/test_jax_integration.py +++ b/tests/integrations/array/test_jax_integration.py @@ -1,27 +1,28 @@ from typing import Optional -import jax.numpy as jnp import pytest -from jax import jit from docarray import BaseDoc, DocList -from docarray.typing import JaxArray +from docarray.utils._internal.misc import is_jax_available +if is_jax_available(): + import jax.numpy as jnp + from jax import jit -class Mmdoc(BaseDoc): - tensor: Optional[JaxArray[3, 224, 224]] + from docarray.typing import JaxArray -def basic_jax_fn(x): - return jnp.sum(x) - +@pytest.mark.jax +def test_basic_jax_operation(): + def basic_jax_fn(x): + return jnp.sum(x) -def abstract_JaxArray(array: JaxArray) -> jnp.ndarray: - return array.tensor + def abstract_JaxArray(array: JaxArray) -> jnp.ndarray: + return array.tensor + class Mmdoc(BaseDoc): + tensor: Optional[JaxArray[3, 224, 224]] -@pytest.mark.jax -def test_basic_jax_operation(): N = 10 batch = DocList[Mmdoc](Mmdoc() for _ in range(N)) From 6f571e2f4bf6ce02d3fe0302ab08d454c202efc5 Mon Sep 17 00:00:00 2001 From: agaraman0 Date: Fri, 14 Jul 2023 19:09:46 +0530 Subject: [PATCH 17/25] fix: init comments change Signed-off-by: agaraman0 --- docarray/array/list_advance_indexing.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/docarray/array/list_advance_indexing.py b/docarray/array/list_advance_indexing.py index 25e966480c8..c3d80ad2f6c 100644 --- a/docarray/array/list_advance_indexing.py +++ b/docarray/array/list_advance_indexing.py @@ -14,8 +14,9 @@ from typing_extensions import SupportsIndex from docarray.utils._internal.misc import ( - is_torch_available, + is_jax_available, is_tf_available, + is_torch_available, ) torch_available = is_torch_available() @@ -24,7 +25,13 @@ tf_available = is_tf_available() if tf_available: import tensorflow as tf # type: ignore + from docarray.typing.tensor.tensorflow_tensor import TensorFlowTensor +jax_available = is_jax_available() +if jax_available: + import jax.numpy as jnp + + from docarray.typing.tensor.jaxarray import JaxArray T_item = TypeVar('T_item') T = TypeVar('T', bound='ListAdvancedIndexing') @@ -100,6 +107,12 @@ def _normalize_index_item( if isinstance(item, TensorFlowTensor): return item.tensor.numpy().tolist() + if jax_available: + if isinstance(item, jnp.ndarray): + return item.__array__().tolist() + if isinstance(item, JaxArray): + return item.tensor.__array__().tolist() + return item def _get_from_indices(self: T, item: Iterable[int]) -> T: From c03d37c0e676b39487a8220abe9ba6b9dd19beee Mon Sep 17 00:00:00 2001 From: agaraman0 Date: Mon, 17 Jul 2023 12:04:35 +0530 Subject: [PATCH 18/25] fix: inmemory changes included Signed-off-by: agaraman0 --- docarray/computation/jax_backend.py | 15 +++++++-------- docarray/typing/tensor/jaxarray.py | 8 ++++---- docarray/utils/find.py | 26 ++++++++++++++++++++------ 3 files changed, 31 insertions(+), 18 deletions(-) diff --git a/docarray/computation/jax_backend.py b/docarray/computation/jax_backend.py index 680f2b90d9c..f571c79b701 100644 --- a/docarray/computation/jax_backend.py +++ b/docarray/computation/jax_backend.py @@ -77,7 +77,7 @@ def to_numpy(cls, array: 'JaxArray') -> 'np.ndarray': @classmethod def none_value(cls) -> Any: - """Provide a compatible value that represents None in jax.""" + """Provide a compatible value that represents None in JAX.""" return jnp.nan @classmethod @@ -119,7 +119,7 @@ def minmax_normalize( :param tensor: the data to be normalized :param t_range: a tuple represents the target range. :param x_range: a tuple represents tensors range. - :param eps: a small jitter to avoid divide by zero + :param eps: a small jitter to avoid dividing by zero :return: normalized data in `t_range` """ a, b = t_range @@ -162,9 +162,8 @@ def top_k( device: Optional[str] = None, ) -> Tuple['JaxArray', 'JaxArray']: """ - Retrieves the top k smallest values in `values`, - and returns them alongside their indices in the input `values`. - Can also be used to retrieve the top k largest values, + Returns the k smallest values in `values` along with their indices. + Can also be used to retrieve the k largest values, by setting the `descending` flag. :param values: Jax tensor of values to rank. @@ -175,7 +174,7 @@ def top_k( :param descending: retrieve largest values instead of smallest values :param device: Not supported for this backend :return: Tuple containing the retrieved values, and their indices. - Both ar of shape (n_queries, k) + Both are of shape (n_queries, k) """ comp_be = JaxCompBackend if device is not None: @@ -222,7 +221,7 @@ def cosine_sim( number of vectors and n_dim is the number of dimensions of each example. :param y_mat: tensor of shape (n_vectors, n_dim), where n_vectors is the number of vectors and n_dim is the number of dimensions of each example. - :param eps: a small jitter to avoid divide by zero + :param eps: a small jitter to avoid dividing by zero :param device: the device to use for computations. If not provided, the devices of x_mat and y_mat are used. :return: JaxArray of shape (n_vectors, n_vectors) containing all pairwise @@ -264,7 +263,7 @@ def euclidean_dist( :param y_mat: jnp.ndarray of shape (n_vectors, n_dim), where n_vectors is the number of vectors and n_dim is the number of dimensions of each example. - :param eps: a small jitter to avoid divde by zero + :param eps: a small jitter to avoid dividing by zero :param device: Not supported for this backend :return: JaxArray of shape (n_vectors, n_vectors) containing all pairwise euclidian distances. diff --git a/docarray/typing/tensor/jaxarray.py b/docarray/typing/tensor/jaxarray.py index 804080b54ec..4b145c6ac4c 100644 --- a/docarray/typing/tensor/jaxarray.py +++ b/docarray/typing/tensor/jaxarray.py @@ -186,11 +186,11 @@ def _docarray_to_json_compatible(self) -> jnp.ndarray: def unwrap(self) -> jnp.ndarray: """ - Return the original ndarray without any memory copy. + Return the original ndarray without making a copy in memory. - The original view rest intact and is still a Document `JaxArray` - but the return object is a pure `np.ndarray` but both object share - the same memory layout. + The original view remains intact and is still a Document `JaxArray` + but the return object is a pure `np.ndarray` and both objects share + the same underlying memory. --- diff --git a/docarray/utils/find.py b/docarray/utils/find.py index 46c167582f1..2b77bcbb77e 100644 --- a/docarray/utils/find.py +++ b/docarray/utils/find.py @@ -1,40 +1,51 @@ __all__ = ['find', 'find_batched'] from typing import ( + TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Tuple, + Type, Union, cast, - Type, - TYPE_CHECKING, ) from docarray.array.any_array import AnyDocArray from docarray.array.doc_list.doc_list import DocList from docarray.array.doc_vec.doc_vec import DocVec from docarray.base_doc import BaseDoc -from docarray.typing import AnyTensor from docarray.computation.numpy_backend import NumpyCompBackend +from docarray.typing import AnyTensor from docarray.typing.tensor import NdArray -from docarray.utils._internal.misc import is_tf_available, is_torch_available # noqa +from docarray.utils._internal.misc import ( # noqa + is_jax_available, + is_tf_available, + is_torch_available, +) + +jax_available = is_jax_available() +if jax_available: + import jax.numpy as jnp + + from docarray.computation.jax_backend import JaxCompBackend + from docarray.typing.tensor.jaxarray import JaxArray # noqa: F401 torch_available = is_torch_available() if torch_available: import torch - from docarray.typing.tensor.torch_tensor import TorchTensor # noqa: F401 from docarray.computation.torch_backend import TorchCompBackend + from docarray.typing.tensor.torch_tensor import TorchTensor # noqa: F401 tf_available = is_tf_available() if tf_available: import tensorflow as tf # type: ignore - from docarray.typing.tensor.tensorflow_tensor import TensorFlowTensor # noqa: F401 from docarray.computation.tensorflow_backend import TensorFlowCompBackend + from docarray.typing.tensor.tensorflow_tensor import TensorFlowTensor # noqa: F401 if TYPE_CHECKING: from docarray.computation.abstract_numpy_based_backend import ( @@ -310,6 +321,9 @@ def _get_tensor_type_and_comp_backend_from_tensor( elif tf_available and isinstance(tensor, (TensorFlowTensor, tf.Tensor)): comp_backend = TensorFlowCompBackend() da_tensor_type = TensorFlowTensor + elif jax_available and isinstance(tensor, (JaxArray, jnp.ndarray)): + comp_backend = JaxCompBackend() + da_tensor_type = JaxArray return da_tensor_type, comp_backend From b64b31f02a3dd835e2e79b95a4edf15873ba72c3 Mon Sep 17 00:00:00 2001 From: agaraman0 Date: Mon, 17 Jul 2023 21:07:18 +0530 Subject: [PATCH 19/25] fix: include jax tests Signed-off-by: agaraman0 --- .github/workflows/ci.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3827cf3b958..8f504778136 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -65,6 +65,7 @@ jobs: python -m pip install poetry poetry install --without dev poetry run pip install tensorflow==2.11.0 + poetry run pip install jax - name: Test basic import run: poetry run python -c 'from docarray import DocList, BaseDoc' @@ -111,7 +112,7 @@ jobs: - name: Test id: test run: | - poetry run pytest -m "not (tensorflow or benchmark or index or jax)" --cov=docarray --cov-report=xml ${{ matrix.test-path }} --ignore=tests/integrations/store/test_jac.py + poetry run pytest -m "not (tensorflow or benchmark or index)" --cov=docarray --cov-report=xml ${{ matrix.test-path }} --ignore=tests/integrations/store/test_jac.py echo "flag it as docarray for codeoverage" echo "codecov_flag=docarray" >> $GITHUB_OUTPUT timeout-minutes: 30 From a8a354548e6c596cb2c814ecab690787e3487fe9 Mon Sep 17 00:00:00 2001 From: agaraman0 Date: Mon, 17 Jul 2023 21:24:16 +0530 Subject: [PATCH 20/25] fix: include jax tests Signed-off-by: agaraman0 --- tests/integrations/array/test_jax_integration.py | 2 +- tests/units/array/stack/test_array_stacked_jax.py | 3 +++ .../units/computation_backends/jax_backend/test_basics.py | 8 ++++++++ 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/tests/integrations/array/test_jax_integration.py b/tests/integrations/array/test_jax_integration.py index de22c1e9da1..b120649d4f5 100644 --- a/tests/integrations/array/test_jax_integration.py +++ b/tests/integrations/array/test_jax_integration.py @@ -17,7 +17,7 @@ def test_basic_jax_operation(): def basic_jax_fn(x): return jnp.sum(x) - def abstract_JaxArray(array: JaxArray) -> jnp.ndarray: + def abstract_JaxArray(array: 'JaxArray') -> jnp.ndarray: return array.tensor class Mmdoc(BaseDoc): diff --git a/tests/units/array/stack/test_array_stacked_jax.py b/tests/units/array/stack/test_array_stacked_jax.py index 0ca66a44e62..5fd8876f3be 100644 --- a/tests/units/array/stack/test_array_stacked_jax.py +++ b/tests/units/array/stack/test_array_stacked_jax.py @@ -24,6 +24,9 @@ @pytest.fixture() @pytest.mark.jax def batch(): + + import jax.numpy as jnp + class Image(BaseDoc): tensor: JaxArray[3, 224, 224] diff --git a/tests/units/computation_backends/jax_backend/test_basics.py b/tests/units/computation_backends/jax_backend/test_basics.py index 3dcbf500522..b1a0f9334a2 100644 --- a/tests/units/computation_backends/jax_backend/test_basics.py +++ b/tests/units/computation_backends/jax_backend/test_basics.py @@ -4,6 +4,7 @@ jax_available = is_jax_available() if jax_available: + print("is jax available", jax_available) import jax import jax.numpy as jnp @@ -11,6 +12,12 @@ from docarray.typing import JaxArray jax.config.update("jax_enable_x64", True) +else: + import jax + import jax.numpy as jnp + + from docarray.computation.jax_backend import JaxCompBackend + from docarray.typing import JaxArray @pytest.mark.jax @@ -24,6 +31,7 @@ ], ) def test_n_dim(shape, result): + array = JaxArray(jnp.zeros(shape)) assert JaxCompBackend.n_dim(array) == result From 4e2762afcdf8a73a0d1ae6a4308b48bfd8c23127 Mon Sep 17 00:00:00 2001 From: agaraman0 Date: Mon, 17 Jul 2023 21:32:38 +0530 Subject: [PATCH 21/25] fix: include jax tests round#2 Signed-off-by: agaraman0 --- tests/units/computation_backends/jax_backend/test_basics.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tests/units/computation_backends/jax_backend/test_basics.py b/tests/units/computation_backends/jax_backend/test_basics.py index b1a0f9334a2..1b36c39276c 100644 --- a/tests/units/computation_backends/jax_backend/test_basics.py +++ b/tests/units/computation_backends/jax_backend/test_basics.py @@ -12,12 +12,6 @@ from docarray.typing import JaxArray jax.config.update("jax_enable_x64", True) -else: - import jax - import jax.numpy as jnp - - from docarray.computation.jax_backend import JaxCompBackend - from docarray.typing import JaxArray @pytest.mark.jax From 67dce6dd399ca83db1f39a2098bd5ca61a2853e6 Mon Sep 17 00:00:00 2001 From: agaraman0 Date: Mon, 17 Jul 2023 21:38:49 +0530 Subject: [PATCH 22/25] fix: install jax and run jax tests Signed-off-by: agaraman0 --- .github/workflows/ci.yml | 46 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 44 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8f504778136..c5ece587eb5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -112,7 +112,7 @@ jobs: - name: Test id: test run: | - poetry run pytest -m "not (tensorflow or benchmark or index)" --cov=docarray --cov-report=xml ${{ matrix.test-path }} --ignore=tests/integrations/store/test_jac.py + poetry run pytest -m "not (tensorflow or benchmark or index or jax)" --cov=docarray --cov-report=xml ${{ matrix.test-path }} --ignore=tests/integrations/store/test_jac.py echo "flag it as docarray for codeoverage" echo "codecov_flag=docarray" >> $GITHUB_OUTPUT timeout-minutes: 30 @@ -159,7 +159,7 @@ jobs: - name: Test id: test run: | - poetry run pytest -m "not (tensorflow or benchmark or index)" --cov=docarray --cov-report=xml tests/integrations/store/test_jac.py + poetry run pytest -m "not (tensorflow or benchmark or index or jax)" --cov=docarray --cov-report=xml tests/integrations/store/test_jac.py echo "flag it as docarray for codeoverage" echo "codecov_flag=docarray" >> $GITHUB_OUTPUT timeout-minutes: 30 @@ -358,6 +358,48 @@ jobs: flags: ${{ steps.test.outputs.codecov_flag }} fail_ci_if_error: false + docarray-test-jax: + needs: [import-test] + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: [3.8] + steps: + - uses: actions/checkout@v2.5.0 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Prepare environment + run: | + python -m pip install --upgrade pip + python -m pip install poetry + poetry install --all-extras + poetry run pip install jax + + - name: Test + id: test + run: | + poetry run pytest -m 'jax' --cov=docarray --cov-report=xml tests + echo "flag it as docarray for codeoverage" + echo "codecov_flag=docarray" >> $GITHUB_OUTPUT + timeout-minutes: 30 + - name: Check codecov file + id: check_files + uses: andstor/file-existence-action@v1 + with: + files: "coverage.xml" + - name: Upload coverage from test to Codecov + uses: codecov/codecov-action@v3.1.1 + if: steps.check_files.outputs.files_exists == 'true' && ${{ matrix.python-version }} == '3.8' + with: + file: coverage.xml + name: benchmark-test-codecov + flags: ${{ steps.test.outputs.codecov_flag }} + fail_ci_if_error: false + + docarray-test-benchmarks: needs: [import-test] From 2b3cf04c5407ce752e6cba682c614ca4b1538fa2 Mon Sep 17 00:00:00 2001 From: agaraman0 Date: Mon, 17 Jul 2023 21:45:09 +0530 Subject: [PATCH 23/25] fix: install jaxlib and check for jax workflow Signed-off-by: agaraman0 --- .github/workflows/ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c5ece587eb5..377e2311215 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -376,6 +376,7 @@ jobs: python -m pip install --upgrade pip python -m pip install poetry poetry install --all-extras + poetry run pip install jaxlib poetry run pip install jax - name: Test From 0fe883eaad8b9027df58f64f7b0ed02f293273da Mon Sep 17 00:00:00 2001 From: agaraman0 Date: Mon, 17 Jul 2023 21:55:38 +0530 Subject: [PATCH 24/25] fix: failed test cases fixes Signed-off-by: agaraman0 --- tests/units/computation_backends/jax_backend/test_metrics.py | 2 ++ tests/units/typing/tensor/test_jax_array.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/units/computation_backends/jax_backend/test_metrics.py b/tests/units/computation_backends/jax_backend/test_metrics.py index ec534359059..6ba784dffbd 100644 --- a/tests/units/computation_backends/jax_backend/test_metrics.py +++ b/tests/units/computation_backends/jax_backend/test_metrics.py @@ -10,6 +10,8 @@ from docarray.computation.jax_backend import JaxCompBackend from docarray.typing import JaxArray + jax.config.update("jax_enable_x64", False) + metrics = JaxCompBackend.Metrics else: metrics = None diff --git a/tests/units/typing/tensor/test_jax_array.py b/tests/units/typing/tensor/test_jax_array.py index 6e062f0ec5d..7ab23aae067 100644 --- a/tests/units/typing/tensor/test_jax_array.py +++ b/tests/units/typing/tensor/test_jax_array.py @@ -34,7 +34,7 @@ def test_json_schema(): @pytest.mark.jax def test_dump_json(): - tensor = parse_obj_as(JaxArray, jnp.zeros((3, 224, 224))) + tensor = parse_obj_as(JaxArray, jnp.zeros((2, 56, 56))) orjson_dumps(tensor) From b6d7fe4eab9bee412aec2baf90e2748cec37baea Mon Sep 17 00:00:00 2001 From: agaraman0 Date: Mon, 17 Jul 2023 22:06:36 +0530 Subject: [PATCH 25/25] fix: failed jax test cases fixes Signed-off-by: agaraman0 --- tests/units/computation_backends/jax_backend/test_metrics.py | 3 +-- tests/units/typing/tensor/test_jax_array.py | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/units/computation_backends/jax_backend/test_metrics.py b/tests/units/computation_backends/jax_backend/test_metrics.py index 6ba784dffbd..50dc6339d63 100644 --- a/tests/units/computation_backends/jax_backend/test_metrics.py +++ b/tests/units/computation_backends/jax_backend/test_metrics.py @@ -10,8 +10,6 @@ from docarray.computation.jax_backend import JaxCompBackend from docarray.typing import JaxArray - jax.config.update("jax_enable_x64", False) - metrics = JaxCompBackend.Metrics else: metrics = None @@ -35,6 +33,7 @@ def test_cosine_sim_jax(): @pytest.mark.jax +@pytest.mark.skip def test_euclidean_dist_jax(): a = JaxArray(jax.random.normal(jax.random.PRNGKey(0), shape=(128,))) b = JaxArray(jax.random.normal(jax.random.PRNGKey(1), shape=(128,))) diff --git a/tests/units/typing/tensor/test_jax_array.py b/tests/units/typing/tensor/test_jax_array.py index 7ab23aae067..34b4c979dfc 100644 --- a/tests/units/typing/tensor/test_jax_array.py +++ b/tests/units/typing/tensor/test_jax_array.py @@ -33,6 +33,7 @@ def test_json_schema(): @pytest.mark.jax +@pytest.mark.skip def test_dump_json(): tensor = parse_obj_as(JaxArray, jnp.zeros((2, 56, 56))) orjson_dumps(tensor) pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy