diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3c43bfa7b4..cc5b769d59 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -94,7 +94,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.8] + python-version: [3.9] pydantic-version: ["pydantic-v2", "pydantic-v1"] test-path: [tests/integrations, tests/units, tests/documentation] steps: @@ -112,6 +112,7 @@ jobs: ./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }} poetry run pip uninstall -y torch poetry run pip install torch + poetry run pip install numpy==1.26.1 sudo apt-get update sudo apt-get install --no-install-recommends ffmpeg diff --git a/docarray/base_doc/doc.py b/docarray/base_doc/doc.py index 4154f3248a..4d45f1369a 100644 --- a/docarray/base_doc/doc.py +++ b/docarray/base_doc/doc.py @@ -22,7 +22,7 @@ import typing_extensions from pydantic import BaseModel, Field from pydantic.fields import FieldInfo -from typing_inspect import is_optional_type +from typing_inspect import get_args, is_optional_type from docarray.utils._internal.pydantic import is_pydantic_v2 @@ -185,7 +185,7 @@ def _get_field_annotation(cls, field: str) -> Type: if is_optional_type( annotation ): # this is equivalent to `outer_type_` in pydantic v1 - return annotation.__args__[0] + return get_args(annotation)[0] else: return annotation else: @@ -205,12 +205,12 @@ def _get_field_inner_type(cls, field: str) -> Type: if is_optional_type( annotation ): # this is equivalent to `outer_type_` in pydantic v1 - return annotation.__args__[0] + return get_args(annotation)[0] elif annotation == Tuple: - if len(annotation.__args__) == 0: + if len(get_args(annotation)) == 0: return Any else: - annotation.__args__[0] + get_args(annotation)[0] else: return annotation else: diff --git a/docarray/base_doc/mixins/io.py b/docarray/base_doc/mixins/io.py index 958897555c..cc4a3470d7 100644 --- a/docarray/base_doc/mixins/io.py +++ b/docarray/base_doc/mixins/io.py @@ -336,7 +336,7 @@ def _get_content_from_node_proto( field_type = None if isinstance(field_type, GenericAlias): - field_type = field_type.__args__[0] + field_type = get_args(field_type)[0] return_field = arg_to_container[content_key]( cls._get_content_from_node_proto(node, field_type=field_type) diff --git a/docarray/display/document_summary.py b/docarray/display/document_summary.py index 7a3730016e..265236a8d3 100644 --- a/docarray/display/document_summary.py +++ b/docarray/display/document_summary.py @@ -1,4 +1,4 @@ -from typing import Any, List, Optional, Type, Union +from typing import Any, List, Optional, Type, Union, get_args from rich.highlighter import RegexHighlighter from rich.theme import Theme @@ -83,7 +83,7 @@ def _get_schema( if is_union_type(field_type) or is_optional_type(field_type): sub_tree = Tree(node_name, highlight=True) - for arg in field_type.__args__: + for arg in get_args(field_type): if safe_issubclass(arg, BaseDoc): sub_tree.add( DocumentSummary._get_schema( diff --git a/docarray/helper.py b/docarray/helper.py index d242b05ea9..34b0c2bfd4 100644 --- a/docarray/helper.py +++ b/docarray/helper.py @@ -15,7 +15,23 @@ Union, ) +import numpy as np + from docarray.utils._internal._typing import safe_issubclass +from docarray.utils._internal.misc import ( + is_jax_available, + is_tf_available, + is_torch_available, +) + +if is_torch_available(): + import torch + +if is_jax_available(): + import jax + +if is_tf_available(): + import tensorflow as tf if TYPE_CHECKING: from docarray import BaseDoc @@ -54,6 +70,35 @@ def _access_path_to_dict(access_path: str, value) -> Dict[str, Any]: return result +def _is_none_like(val: Any) -> bool: + """ + :param val: any value + :return: true iff `val` equals to `None`, `'None'` or `''` + """ + # Convoluted implementation, but fixes https://github.com/docarray/docarray/issues/1821 + + # tensor-like types can have unexpected (= broadcast) `==`/`in` semantics, + # so treat separately + is_np_arr = isinstance(val, np.ndarray) + if is_np_arr: + return False + + is_torch_tens = is_torch_available() and isinstance(val, torch.Tensor) + if is_torch_tens: + return False + + is_tf_tens = is_tf_available() and isinstance(val, tf.Tensor) + if is_tf_tens: + return False + + is_jax_arr = is_jax_available() and isinstance(val, jax.numpy.ndarray) + if is_jax_arr: + return False + + # "normal" case + return val in ['', 'None', None] + + def _access_path_dict_to_nested_dict(access_path2val: Dict[str, Any]) -> Dict[Any, Any]: """ Convert a dict, where the keys are access paths ("__"-separated) to a nested dictionary. @@ -76,7 +121,7 @@ def _access_path_dict_to_nested_dict(access_path2val: Dict[str, Any]) -> Dict[An for access_path, value in access_path2val.items(): field2val = _access_path_to_dict( access_path=access_path, - value=value if value not in ['', 'None'] else None, + value=None if _is_none_like(value) else value, ) _update_nested_dicts(to_update=nested_dict, update_with=field2val) return nested_dict diff --git a/tests/units/array/test_array_from_to_pandas.py b/tests/units/array/test_array_from_to_pandas.py index d89902c2f8..440398562f 100644 --- a/tests/units/array/test_array_from_to_pandas.py +++ b/tests/units/array/test_array_from_to_pandas.py @@ -136,7 +136,8 @@ class BasisUnion(BaseDoc): @pytest.mark.parametrize('tensor_type', [NdArray, TorchTensor]) -def test_from_to_pandas_tensor_type(tensor_type): +@pytest.mark.parametrize('tensor_len', [0, 5]) +def test_from_to_pandas_tensor_type(tensor_type, tensor_len): class MyDoc(BaseDoc): embedding: tensor_type text: str @@ -145,9 +146,13 @@ class MyDoc(BaseDoc): da = DocVec[MyDoc]( [ MyDoc( - embedding=[1, 2, 3, 4, 5], text='hello', image=ImageDoc(url='aux.png') + embedding=list(range(tensor_len)), + text='hello', + image=ImageDoc(url='aux.png'), + ), + MyDoc( + embedding=list(range(tensor_len)), text='hello world', image=ImageDoc() ), - MyDoc(embedding=[5, 4, 3, 2, 1], text='hello world', image=ImageDoc()), ], tensor_type=tensor_type, ) diff --git a/tests/units/typing/tensor/test_ndarray.py b/tests/units/typing/tensor/test_ndarray.py index 49d5d34d1b..93ed58b382 100644 --- a/tests/units/typing/tensor/test_ndarray.py +++ b/tests/units/typing/tensor/test_ndarray.py @@ -200,9 +200,9 @@ def test_parametrized_instance(): def test_parametrized_equality(): t1 = parse_obj_as(NdArray[128], np.zeros(128)) t2 = parse_obj_as(NdArray[128], np.zeros(128)) - t3 = parse_obj_as(NdArray[256], np.zeros(256)) + t3 = parse_obj_as(NdArray[128], np.ones(128)) assert (t1 == t2).all() - assert not t1 == t3 + assert not (t1 == t3).any() def test_parametrized_operations():
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: