From a8645899186f0a5e09bddaa16547d052ea7db52d Mon Sep 17 00:00:00 2001 From: Joan Martinez Date: Fri, 8 Mar 2024 15:39:42 +0100 Subject: [PATCH 01/28] fix: try to fix doclist schema --- docarray/array/any_array.py | 4 +++- docarray/array/doc_list/doc_list.py | 10 ++++++---- docarray/documents/legacy/legacy_document.py | 6 +++--- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/docarray/array/any_array.py b/docarray/array/any_array.py index 50c47cf4ec..1c8a591d75 100644 --- a/docarray/array/any_array.py +++ b/docarray/array/any_array.py @@ -16,6 +16,7 @@ Union, cast, overload, + Tuple ) import numpy as np @@ -48,6 +49,8 @@ class AnyDocArray(Sequence[T_doc], Generic[T_doc], AbstractType): doc_type: Type[BaseDocWithoutId] __typed_da__: Dict[Type['AnyDocArray'], Dict[Type[BaseDocWithoutId], Type]] = {} + # __origin__: Type['AnyDocArray'] = cls # add this + # __args__: Tuple[Any, ...] = (item,) # add this def __repr__(self): return f'<{self.__class__.__name__} (length={len(self)})>' @@ -72,7 +75,6 @@ def __class_getitem__(cls, item: Union[Type[BaseDocWithoutId], TypeVar, str]): if item not in cls.__typed_da__[cls]: # Promote to global scope so multiprocessing can pickle it global _DocArrayTyped - class _DocArrayTyped(cls): # type: ignore doc_type: Type[BaseDocWithoutId] = cast(Type[BaseDocWithoutId], item) diff --git a/docarray/array/doc_list/doc_list.py b/docarray/array/doc_list/doc_list.py index c21cf93413..43cd0a72d4 100644 --- a/docarray/array/doc_list/doc_list.py +++ b/docarray/array/doc_list/doc_list.py @@ -12,6 +12,8 @@ Union, cast, overload, + Callable, + get_args ) from pydantic import parse_obj_as @@ -357,8 +359,8 @@ def __repr__(self): @classmethod def __get_pydantic_core_schema__( - cls, _source_type: Any, _handler: GetCoreSchemaHandler + cls, source: Any, handler: Callable[[Any], core_schema.CoreSchema] ) -> core_schema.CoreSchema: - return core_schema.general_plain_validator_function( - cls.validate, - ) + return core_schema.with_info_after_validator_function( + function=cls.validate, + schema=core_schema.list_schema(core_schema.any_schema())) diff --git a/docarray/documents/legacy/legacy_document.py b/docarray/documents/legacy/legacy_document.py index dc77f10d0b..e4165d5207 100644 --- a/docarray/documents/legacy/legacy_document.py +++ b/docarray/documents/legacy/legacy_document.py @@ -15,7 +15,7 @@ # limitations under the License. 
from __future__ import annotations -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, List, Union from docarray import BaseDoc, DocList from docarray.typing import AnyEmbedding, AnyTensor @@ -50,8 +50,8 @@ class LegacyDocument(BaseDoc): """ tensor: Optional[AnyTensor] = None - chunks: Optional[DocList[LegacyDocument]] = None - matches: Optional[DocList[LegacyDocument]] = None + chunks: Optional[Union[DocList[LegacyDocument], List[LegacyDocument]]] = None + matches: Optional[Union[DocList[LegacyDocument], List[LegacyDocument]]] = None blob: Optional[bytes] = None text: Optional[str] = None url: Optional[str] = None From 951679cb50604eee919653303df7d69e4c6a019c Mon Sep 17 00:00:00 2001 From: Joan Martinez Date: Fri, 8 Mar 2024 18:07:39 +0100 Subject: [PATCH 02/28] chore: push tmp changes --- aux.py | 96 +++++++++++++++++++++++++++++ docarray/array/any_array.py | 17 +++-- docarray/array/doc_list/doc_list.py | 30 ++++++++- 3 files changed, 137 insertions(+), 6 deletions(-) create mode 100644 aux.py diff --git a/aux.py b/aux.py new file mode 100644 index 0000000000..84076be1d7 --- /dev/null +++ b/aux.py @@ -0,0 +1,96 @@ +from typing import Sequence, TypeVar, Any, Callable, get_args, Generic + +from pydantic_core import core_schema, ValidationError + +from pydantic import BaseModel + +T = TypeVar('T') + + +class MySequence(Sequence[T], Generic[T]): + def __init__(self, v: Sequence[T]): + self.v = v + + def __getitem__(self, i): + return self.v[i] + + def __len__(self): + return len(self.v) + + @classmethod + def __get_pydantic_core_schema__( + cls, source: Any, handler: Callable[[Any], core_schema.CoreSchema] + ) -> core_schema.CoreSchema: + print(f'source {source} and MySeq handler {handler}') + instance_schema = core_schema.is_instance_schema(cls) + + args = get_args(source) + print(f'args1 {args}') + if args: + sequence_t_schema = handler(Sequence[args[0]]) + else: + sequence_t_schema = handler(Sequence) + + non_instance_schema = core_schema.with_info_after_validator_function( + lambda v, i: MySequence(v), sequence_t_schema + ) + return core_schema.union_schema([instance_schema, non_instance_schema]) + + +class MySequence2(MySequence, Generic[T]): + pass + + +class A(BaseModel): + b: int + +class M(BaseModel): + model_config = dict(validate_default=True) + + s1: MySequence2[A] + + +print(M.schema()) + +args = get_args(MySequence2[A]) +print(f'MySequence2 args {args}') + +from typing import List, Union +from docarray.array.any_array import AnyDocArray +from docarray import BaseDoc, DocList +import pydantic + + +class Doc(BaseDoc): + a: str + + + +print(f'Doc {Doc.schema()}') + + +class DocDoc(BaseDoc): + docs: DocList[Doc] + + +print(DocDoc.schema()) + +args = get_args(DocList[Doc]) +print(f'DocList args {args}') + + +args = get_args(AnyDocArray[Doc]) +print(f'AnyDocArray args {args}') + + + + + + + + + + + + + diff --git a/docarray/array/any_array.py b/docarray/array/any_array.py index 1c8a591d75..745bd24b74 100644 --- a/docarray/array/any_array.py +++ b/docarray/array/any_array.py @@ -16,7 +16,9 @@ Union, cast, overload, - Tuple + Tuple, + get_args, + get_origin, ) import numpy as np @@ -46,17 +48,16 @@ ) -class AnyDocArray(Sequence[T_doc], Generic[T_doc], AbstractType): +class AnyDocArray(AbstractType, Sequence[T_doc], Generic[T_doc]): doc_type: Type[BaseDocWithoutId] __typed_da__: Dict[Type['AnyDocArray'], Dict[Type[BaseDocWithoutId], Type]] = {} - # __origin__: Type['AnyDocArray'] = cls # add this - # __args__: Tuple[Any, ...] 
= (item,) # add this def __repr__(self): return f'<{self.__class__.__name__} (length={len(self)})>' @classmethod def __class_getitem__(cls, item: Union[Type[BaseDocWithoutId], TypeVar, str]): + print(f' hey here {item}') if not isinstance(item, type): if sys.version_info < (3, 12): return Generic.__class_getitem__.__func__(cls, item) # type: ignore @@ -75,8 +76,10 @@ def __class_getitem__(cls, item: Union[Type[BaseDocWithoutId], TypeVar, str]): if item not in cls.__typed_da__[cls]: # Promote to global scope so multiprocessing can pickle it global _DocArrayTyped - class _DocArrayTyped(cls): # type: ignore + class _DocArrayTyped(cls, Generic[T_doc]): # type: ignore doc_type: Type[BaseDocWithoutId] = cast(Type[BaseDocWithoutId], item) + # __origin__: Type['AnyDocArray'] = cls # add this + # __args__: Tuple[Any, ...] = (item,) # add this for field in _DocArrayTyped.doc_type._docarray_fields().keys(): @@ -109,6 +112,10 @@ def _setter(self, value): cls.__typed_da__[cls][item] = _DocArrayTyped + print(f'return {cls.__typed_da__[cls][item]}') + a = get_args(cls.__typed_da__[cls][item]) + print(f'a {a}') + print(f'get_origin {get_origin(cls.__typed_da__[cls][item])}') return cls.__typed_da__[cls][item] @overload diff --git a/docarray/array/doc_list/doc_list.py b/docarray/array/doc_list/doc_list.py index 43cd0a72d4..b22e2f0f68 100644 --- a/docarray/array/doc_list/doc_list.py +++ b/docarray/array/doc_list/doc_list.py @@ -13,7 +13,8 @@ cast, overload, Callable, - get_args + get_args, + Generic ) from pydantic import parse_obj_as @@ -51,6 +52,7 @@ class DocList( PushPullMixin, IOMixinDocList, AnyDocArray[T_doc], + Generic[T_doc] ): """ DocList is a container of Documents. @@ -361,6 +363,32 @@ def __repr__(self): def __get_pydantic_core_schema__( cls, source: Any, handler: Callable[[Any], core_schema.CoreSchema] ) -> core_schema.CoreSchema: + def get_args_2(tp): + """Get type arguments with all substitutions performed. + + For unions, basic simplifications used by Union constructor are performed. 
+ Examples:: + get_args(Dict[str, int]) == (str, int) + get_args(int) == () + get_args(Union[int, Union[T, int], str][int]) == (int, str) + get_args(Union[int, Tuple[T, int]][str]) == (int, Tuple[str, int]) + get_args(Callable[[], T][int]) == ([], int) + """ + from typing import _GenericAlias, get_origin + import collections + if isinstance(tp, _GenericAlias): + res = tp.__args__ + if get_origin(tp) is collections.abc.Callable and res[0] is not Ellipsis: + res = (list(res[:-1]), res[-1]) + return res + else: + print(f'IN ELSE') + return () + + instance_schema = core_schema.is_instance_schema(cls) + print(f'instance_schema {instance_schema} and {handler}') + args = get_args_2(DocList[BaseDocWithoutId]) + print(f' args {args}') return core_schema.with_info_after_validator_function( function=cls.validate, schema=core_schema.list_schema(core_schema.any_schema())) From febea8d563febe259b7c0d7d2548ff709c5d9f6d Mon Sep 17 00:00:00 2001 From: Joan Martinez Date: Mon, 11 Mar 2024 10:56:09 +0100 Subject: [PATCH 03/28] fix: make DocList properly a Generic --- aux.py | 96 -------------------- docarray/array/any_array.py | 32 ++++--- docarray/array/doc_list/doc_list.py | 47 +++------- docarray/documents/legacy/legacy_document.py | 6 +- 4 files changed, 35 insertions(+), 146 deletions(-) delete mode 100644 aux.py diff --git a/aux.py b/aux.py deleted file mode 100644 index 84076be1d7..0000000000 --- a/aux.py +++ /dev/null @@ -1,96 +0,0 @@ -from typing import Sequence, TypeVar, Any, Callable, get_args, Generic - -from pydantic_core import core_schema, ValidationError - -from pydantic import BaseModel - -T = TypeVar('T') - - -class MySequence(Sequence[T], Generic[T]): - def __init__(self, v: Sequence[T]): - self.v = v - - def __getitem__(self, i): - return self.v[i] - - def __len__(self): - return len(self.v) - - @classmethod - def __get_pydantic_core_schema__( - cls, source: Any, handler: Callable[[Any], core_schema.CoreSchema] - ) -> core_schema.CoreSchema: - print(f'source {source} and MySeq handler {handler}') - instance_schema = core_schema.is_instance_schema(cls) - - args = get_args(source) - print(f'args1 {args}') - if args: - sequence_t_schema = handler(Sequence[args[0]]) - else: - sequence_t_schema = handler(Sequence) - - non_instance_schema = core_schema.with_info_after_validator_function( - lambda v, i: MySequence(v), sequence_t_schema - ) - return core_schema.union_schema([instance_schema, non_instance_schema]) - - -class MySequence2(MySequence, Generic[T]): - pass - - -class A(BaseModel): - b: int - -class M(BaseModel): - model_config = dict(validate_default=True) - - s1: MySequence2[A] - - -print(M.schema()) - -args = get_args(MySequence2[A]) -print(f'MySequence2 args {args}') - -from typing import List, Union -from docarray.array.any_array import AnyDocArray -from docarray import BaseDoc, DocList -import pydantic - - -class Doc(BaseDoc): - a: str - - - -print(f'Doc {Doc.schema()}') - - -class DocDoc(BaseDoc): - docs: DocList[Doc] - - -print(DocDoc.schema()) - -args = get_args(DocList[Doc]) -print(f'DocList args {args}') - - -args = get_args(AnyDocArray[Doc]) -print(f'AnyDocArray args {args}') - - - - - - - - - - - - - diff --git a/docarray/array/any_array.py b/docarray/array/any_array.py index 745bd24b74..e03b985d88 100644 --- a/docarray/array/any_array.py +++ b/docarray/array/any_array.py @@ -17,8 +17,6 @@ cast, overload, Tuple, - get_args, - get_origin, ) import numpy as np @@ -28,6 +26,7 @@ from docarray.exceptions.exceptions import UnusableObjectError from docarray.typing.abstract_type 
import AbstractType from docarray.utils._internal._typing import change_cls_name, safe_issubclass +from docarray.utils._internal.pydantic import is_pydantic_v2 if TYPE_CHECKING: from docarray.proto import DocListProto, NodeProto @@ -35,6 +34,11 @@ if sys.version_info >= (3, 12): from types import GenericAlias +else: + try: + from typing import GenericAlias + except: + from typing import _GenericAlias as GenericAlias T = TypeVar('T', bound='AnyDocArray') T_doc = TypeVar('T_doc', bound=BaseDocWithoutId) @@ -48,7 +52,7 @@ ) -class AnyDocArray(AbstractType, Sequence[T_doc], Generic[T_doc]): +class AnyDocArray(Sequence[T_doc], Generic[T_doc], AbstractType): doc_type: Type[BaseDocWithoutId] __typed_da__: Dict[Type['AnyDocArray'], Dict[Type[BaseDocWithoutId], Type]] = {} @@ -57,7 +61,6 @@ def __repr__(self): @classmethod def __class_getitem__(cls, item: Union[Type[BaseDocWithoutId], TypeVar, str]): - print(f' hey here {item}') if not isinstance(item, type): if sys.version_info < (3, 12): return Generic.__class_getitem__.__func__(cls, item) # type: ignore @@ -76,10 +79,12 @@ def __class_getitem__(cls, item: Union[Type[BaseDocWithoutId], TypeVar, str]): if item not in cls.__typed_da__[cls]: # Promote to global scope so multiprocessing can pickle it global _DocArrayTyped + class _DocArrayTyped(cls, Generic[T_doc]): # type: ignore doc_type: Type[BaseDocWithoutId] = cast(Type[BaseDocWithoutId], item) - # __origin__: Type['AnyDocArray'] = cls # add this - # __args__: Tuple[Any, ...] = (item,) # add this + if is_pydantic_v2: + __origin__: Type['AnyDocArray'] = cls # add this + __args__: Tuple[Any, ...] = (item,) # add this for field in _DocArrayTyped.doc_type._docarray_fields().keys(): @@ -109,13 +114,16 @@ def _setter(self, value): change_cls_name( _DocArrayTyped, f'{cls.__name__}[{item.__name__}]', globals() ) + if is_pydantic_v2: + if sys.version_info < (3, 12): + cls.__typed_da__[cls][item] = Generic.__class_getitem__.__func__(_DocArrayTyped, item) # type: ignore + # this do nothing that checking that item is valid type var or str + # Keep the approach in #1147 to be compatible with lower versions of Python. + else: + cls.__typed_da__[cls][item] = GenericAlias(_DocArrayTyped, item) + else: + cls.__typed_da__[cls][item] = _DocArrayTyped - cls.__typed_da__[cls][item] = _DocArrayTyped - - print(f'return {cls.__typed_da__[cls][item]}') - a = get_args(cls.__typed_da__[cls][item]) - print(f'a {a}') - print(f'get_origin {get_origin(cls.__typed_da__[cls][item])}') return cls.__typed_da__[cls][item] @overload diff --git a/docarray/array/doc_list/doc_list.py b/docarray/array/doc_list/doc_list.py index b22e2f0f68..8ea4983817 100644 --- a/docarray/array/doc_list/doc_list.py +++ b/docarray/array/doc_list/doc_list.py @@ -14,7 +14,6 @@ overload, Callable, get_args, - Generic ) from pydantic import parse_obj_as @@ -31,7 +30,6 @@ from docarray.utils._internal.pydantic import is_pydantic_v2 if is_pydantic_v2: - from pydantic import GetCoreSchemaHandler from pydantic_core import core_schema from docarray.utils._internal._typing import safe_issubclass @@ -48,11 +46,7 @@ class DocList( - ListAdvancedIndexing[T_doc], - PushPullMixin, - IOMixinDocList, - AnyDocArray[T_doc], - Generic[T_doc] + ListAdvancedIndexing[T_doc], PushPullMixin, IOMixinDocList, AnyDocArray[T_doc] ): """ DocList is a container of Documents. 
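The hunk below adopts the union-schema pattern prototyped in aux.py above: accept an existing instance unchanged, otherwise validate the raw input as a sequence of the parametrized document type and coerce it. As a self-contained illustration of that pattern under pydantic v2 (MySeq, Item and Holder are illustrative names, not docarray classes), a minimal sketch:

    from typing import Any, Callable, Generic, Sequence, TypeVar, get_args

    from pydantic import BaseModel
    from pydantic_core import core_schema

    T = TypeVar('T')


    class MySeq(Sequence[T], Generic[T]):
        def __init__(self, v: Sequence[T]):
            self.v = v

        def __getitem__(self, i):
            return self.v[i]

        def __len__(self):
            return len(self.v)

        @classmethod
        def __get_pydantic_core_schema__(
            cls, source: Any, handler: Callable[[Any], core_schema.CoreSchema]
        ) -> core_schema.CoreSchema:
            # branch 1: an existing MySeq instance passes through untouched
            instance_schema = core_schema.is_instance_schema(cls)
            # branch 2: validate raw input as Sequence[T] for the concrete T,
            # then wrap the validated list in MySeq
            args = get_args(source)
            seq_schema = handler(Sequence[args[0]] if args else Sequence)
            coerce_schema = core_schema.with_info_after_validator_function(
                lambda v, _info: cls(v), seq_schema
            )
            return core_schema.union_schema([instance_schema, coerce_schema])


    class Item(BaseModel):
        x: int


    class Holder(BaseModel):
        items: MySeq[Item]


    h = Holder(items=[{'x': 1}, {'x': 2}])
    assert isinstance(h.items, MySeq) and h.items[1].x == 2

Annotating the field as MySeq[Item] is what makes get_args(source) return (Item,) inside the hook; that is exactly why the any_array.py change above works to make DocList[Doc] a real generic alias.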
@@ -363,32 +357,15 @@ def __repr__(self): def __get_pydantic_core_schema__( cls, source: Any, handler: Callable[[Any], core_schema.CoreSchema] ) -> core_schema.CoreSchema: - def get_args_2(tp): - """Get type arguments with all substitutions performed. - - For unions, basic simplifications used by Union constructor are performed. - Examples:: - get_args(Dict[str, int]) == (str, int) - get_args(int) == () - get_args(Union[int, Union[T, int], str][int]) == (int, str) - get_args(Union[int, Tuple[T, int]][str]) == (int, Tuple[str, int]) - get_args(Callable[[], T][int]) == ([], int) - """ - from typing import _GenericAlias, get_origin - import collections - if isinstance(tp, _GenericAlias): - res = tp.__args__ - if get_origin(tp) is collections.abc.Callable and res[0] is not Ellipsis: - res = (list(res[:-1]), res[-1]) - return res - else: - print(f'IN ELSE') - return () - instance_schema = core_schema.is_instance_schema(cls) - print(f'instance_schema {instance_schema} and {handler}') - args = get_args_2(DocList[BaseDocWithoutId]) - print(f' args {args}') - return core_schema.with_info_after_validator_function( - function=cls.validate, - schema=core_schema.list_schema(core_schema.any_schema())) + + args = get_args(source) + if args: + sequence_t_schema = handler(Sequence[args[0]]) + else: + sequence_t_schema = handler(Sequence) + + non_instance_schema = core_schema.with_info_after_validator_function( + lambda v, i: DocList(v), sequence_t_schema + ) + return core_schema.union_schema([instance_schema, non_instance_schema]) diff --git a/docarray/documents/legacy/legacy_document.py b/docarray/documents/legacy/legacy_document.py index e4165d5207..dc77f10d0b 100644 --- a/docarray/documents/legacy/legacy_document.py +++ b/docarray/documents/legacy/legacy_document.py @@ -15,7 +15,7 @@ # limitations under the License. 
from __future__ import annotations -from typing import Any, Dict, Optional, List, Union +from typing import Any, Dict, Optional from docarray import BaseDoc, DocList from docarray.typing import AnyEmbedding, AnyTensor @@ -50,8 +50,8 @@ class LegacyDocument(BaseDoc): """ tensor: Optional[AnyTensor] = None - chunks: Optional[Union[DocList[LegacyDocument], List[LegacyDocument]]] = None - matches: Optional[Union[DocList[LegacyDocument], List[LegacyDocument]]] = None + chunks: Optional[DocList[LegacyDocument]] = None + matches: Optional[DocList[LegacyDocument]] = None blob: Optional[bytes] = None text: Optional[str] = None url: Optional[str] = None From 11eeb6a4d6bfe73455fb34e8425c74bbab8099e9 Mon Sep 17 00:00:00 2001 From: Joan Martinez Date: Mon, 11 Mar 2024 10:56:09 +0100 Subject: [PATCH 04/28] fix: make DocList properly a Generic --- aux.py | 96 -------------------- docarray/array/any_array.py | 32 ++++--- docarray/array/doc_list/doc_list.py | 47 +++------- docarray/documents/legacy/legacy_document.py | 6 +- 4 files changed, 35 insertions(+), 146 deletions(-) delete mode 100644 aux.py diff --git a/aux.py b/aux.py deleted file mode 100644 index 84076be1d7..0000000000 --- a/aux.py +++ /dev/null @@ -1,96 +0,0 @@ -from typing import Sequence, TypeVar, Any, Callable, get_args, Generic - -from pydantic_core import core_schema, ValidationError - -from pydantic import BaseModel - -T = TypeVar('T') - - -class MySequence(Sequence[T], Generic[T]): - def __init__(self, v: Sequence[T]): - self.v = v - - def __getitem__(self, i): - return self.v[i] - - def __len__(self): - return len(self.v) - - @classmethod - def __get_pydantic_core_schema__( - cls, source: Any, handler: Callable[[Any], core_schema.CoreSchema] - ) -> core_schema.CoreSchema: - print(f'source {source} and MySeq handler {handler}') - instance_schema = core_schema.is_instance_schema(cls) - - args = get_args(source) - print(f'args1 {args}') - if args: - sequence_t_schema = handler(Sequence[args[0]]) - else: - sequence_t_schema = handler(Sequence) - - non_instance_schema = core_schema.with_info_after_validator_function( - lambda v, i: MySequence(v), sequence_t_schema - ) - return core_schema.union_schema([instance_schema, non_instance_schema]) - - -class MySequence2(MySequence, Generic[T]): - pass - - -class A(BaseModel): - b: int - -class M(BaseModel): - model_config = dict(validate_default=True) - - s1: MySequence2[A] - - -print(M.schema()) - -args = get_args(MySequence2[A]) -print(f'MySequence2 args {args}') - -from typing import List, Union -from docarray.array.any_array import AnyDocArray -from docarray import BaseDoc, DocList -import pydantic - - -class Doc(BaseDoc): - a: str - - - -print(f'Doc {Doc.schema()}') - - -class DocDoc(BaseDoc): - docs: DocList[Doc] - - -print(DocDoc.schema()) - -args = get_args(DocList[Doc]) -print(f'DocList args {args}') - - -args = get_args(AnyDocArray[Doc]) -print(f'AnyDocArray args {args}') - - - - - - - - - - - - - diff --git a/docarray/array/any_array.py b/docarray/array/any_array.py index 745bd24b74..38d5e60884 100644 --- a/docarray/array/any_array.py +++ b/docarray/array/any_array.py @@ -17,8 +17,6 @@ cast, overload, Tuple, - get_args, - get_origin, ) import numpy as np @@ -28,6 +26,7 @@ from docarray.exceptions.exceptions import UnusableObjectError from docarray.typing.abstract_type import AbstractType from docarray.utils._internal._typing import change_cls_name, safe_issubclass +from docarray.utils._internal.pydantic import is_pydantic_v2 if TYPE_CHECKING: from docarray.proto import DocListProto, 
NodeProto @@ -35,6 +34,11 @@ if sys.version_info >= (3, 12): from types import GenericAlias +else: + try: + from typing import GenericAlias + except ImportError: + from typing import _GenericAlias as GenericAlias T = TypeVar('T', bound='AnyDocArray') T_doc = TypeVar('T_doc', bound=BaseDocWithoutId) @@ -48,7 +52,7 @@ ) -class AnyDocArray(AbstractType, Sequence[T_doc], Generic[T_doc]): +class AnyDocArray(Sequence[T_doc], Generic[T_doc], AbstractType): doc_type: Type[BaseDocWithoutId] __typed_da__: Dict[Type['AnyDocArray'], Dict[Type[BaseDocWithoutId], Type]] = {} @@ -57,7 +61,6 @@ def __repr__(self): @classmethod def __class_getitem__(cls, item: Union[Type[BaseDocWithoutId], TypeVar, str]): - print(f' hey here {item}') if not isinstance(item, type): if sys.version_info < (3, 12): return Generic.__class_getitem__.__func__(cls, item) # type: ignore @@ -76,10 +79,12 @@ def __class_getitem__(cls, item: Union[Type[BaseDocWithoutId], TypeVar, str]): if item not in cls.__typed_da__[cls]: # Promote to global scope so multiprocessing can pickle it global _DocArrayTyped + class _DocArrayTyped(cls, Generic[T_doc]): # type: ignore doc_type: Type[BaseDocWithoutId] = cast(Type[BaseDocWithoutId], item) - # __origin__: Type['AnyDocArray'] = cls # add this - # __args__: Tuple[Any, ...] = (item,) # add this + if is_pydantic_v2: + __origin__: Type['AnyDocArray'] = cls # add this + __args__: Tuple[Any, ...] = (item,) # add this for field in _DocArrayTyped.doc_type._docarray_fields().keys(): @@ -109,13 +114,16 @@ def _setter(self, value): change_cls_name( _DocArrayTyped, f'{cls.__name__}[{item.__name__}]', globals() ) + if is_pydantic_v2: + if sys.version_info < (3, 12): + cls.__typed_da__[cls][item] = Generic.__class_getitem__.__func__(_DocArrayTyped, item) # type: ignore + # this do nothing that checking that item is valid type var or str + # Keep the approach in #1147 to be compatible with lower versions of Python. + else: + cls.__typed_da__[cls][item] = GenericAlias(_DocArrayTyped, item) + else: + cls.__typed_da__[cls][item] = _DocArrayTyped - cls.__typed_da__[cls][item] = _DocArrayTyped - - print(f'return {cls.__typed_da__[cls][item]}') - a = get_args(cls.__typed_da__[cls][item]) - print(f'a {a}') - print(f'get_origin {get_origin(cls.__typed_da__[cls][item])}') return cls.__typed_da__[cls][item] @overload diff --git a/docarray/array/doc_list/doc_list.py b/docarray/array/doc_list/doc_list.py index b22e2f0f68..8ea4983817 100644 --- a/docarray/array/doc_list/doc_list.py +++ b/docarray/array/doc_list/doc_list.py @@ -14,7 +14,6 @@ overload, Callable, get_args, - Generic ) from pydantic import parse_obj_as @@ -31,7 +30,6 @@ from docarray.utils._internal.pydantic import is_pydantic_v2 if is_pydantic_v2: - from pydantic import GetCoreSchemaHandler from pydantic_core import core_schema from docarray.utils._internal._typing import safe_issubclass @@ -48,11 +46,7 @@ class DocList( - ListAdvancedIndexing[T_doc], - PushPullMixin, - IOMixinDocList, - AnyDocArray[T_doc], - Generic[T_doc] + ListAdvancedIndexing[T_doc], PushPullMixin, IOMixinDocList, AnyDocArray[T_doc] ): """ DocList is a container of Documents. @@ -363,32 +357,15 @@ def __repr__(self): def __get_pydantic_core_schema__( cls, source: Any, handler: Callable[[Any], core_schema.CoreSchema] ) -> core_schema.CoreSchema: - def get_args_2(tp): - """Get type arguments with all substitutions performed. - - For unions, basic simplifications used by Union constructor are performed. 
- Examples:: - get_args(Dict[str, int]) == (str, int) - get_args(int) == () - get_args(Union[int, Union[T, int], str][int]) == (int, str) - get_args(Union[int, Tuple[T, int]][str]) == (int, Tuple[str, int]) - get_args(Callable[[], T][int]) == ([], int) - """ - from typing import _GenericAlias, get_origin - import collections - if isinstance(tp, _GenericAlias): - res = tp.__args__ - if get_origin(tp) is collections.abc.Callable and res[0] is not Ellipsis: - res = (list(res[:-1]), res[-1]) - return res - else: - print(f'IN ELSE') - return () - instance_schema = core_schema.is_instance_schema(cls) - print(f'instance_schema {instance_schema} and {handler}') - args = get_args_2(DocList[BaseDocWithoutId]) - print(f' args {args}') - return core_schema.with_info_after_validator_function( - function=cls.validate, - schema=core_schema.list_schema(core_schema.any_schema())) + + args = get_args(source) + if args: + sequence_t_schema = handler(Sequence[args[0]]) + else: + sequence_t_schema = handler(Sequence) + + non_instance_schema = core_schema.with_info_after_validator_function( + lambda v, i: DocList(v), sequence_t_schema + ) + return core_schema.union_schema([instance_schema, non_instance_schema]) diff --git a/docarray/documents/legacy/legacy_document.py b/docarray/documents/legacy/legacy_document.py index e4165d5207..dc77f10d0b 100644 --- a/docarray/documents/legacy/legacy_document.py +++ b/docarray/documents/legacy/legacy_document.py @@ -15,7 +15,7 @@ # limitations under the License. from __future__ import annotations -from typing import Any, Dict, Optional, List, Union +from typing import Any, Dict, Optional from docarray import BaseDoc, DocList from docarray.typing import AnyEmbedding, AnyTensor @@ -50,8 +50,8 @@ class LegacyDocument(BaseDoc): """ tensor: Optional[AnyTensor] = None - chunks: Optional[Union[DocList[LegacyDocument], List[LegacyDocument]]] = None - matches: Optional[Union[DocList[LegacyDocument], List[LegacyDocument]]] = None + chunks: Optional[DocList[LegacyDocument]] = None + matches: Optional[DocList[LegacyDocument]] = None blob: Optional[bytes] = None text: Optional[str] = None url: Optional[str] = None From a07919d924807d8da84600b0b2225c2999df1e2a Mon Sep 17 00:00:00 2001 From: Joan Martinez Date: Mon, 11 Mar 2024 17:19:18 +0100 Subject: [PATCH 05/28] fix: undo some changes --- docarray/array/any_array.py | 18 +++--------------- docarray/array/doc_list/doc_list.py | 9 +++++---- 2 files changed, 8 insertions(+), 19 deletions(-) diff --git a/docarray/array/any_array.py b/docarray/array/any_array.py index 38d5e60884..e6ccfe9429 100644 --- a/docarray/array/any_array.py +++ b/docarray/array/any_array.py @@ -34,11 +34,6 @@ if sys.version_info >= (3, 12): from types import GenericAlias -else: - try: - from typing import GenericAlias - except ImportError: - from typing import _GenericAlias as GenericAlias T = TypeVar('T', bound='AnyDocArray') T_doc = TypeVar('T_doc', bound=BaseDocWithoutId) @@ -80,7 +75,7 @@ def __class_getitem__(cls, item: Union[Type[BaseDocWithoutId], TypeVar, str]): # Promote to global scope so multiprocessing can pickle it global _DocArrayTyped - class _DocArrayTyped(cls, Generic[T_doc]): # type: ignore + class _DocArrayTyped(cls): # type: ignore doc_type: Type[BaseDocWithoutId] = cast(Type[BaseDocWithoutId], item) if is_pydantic_v2: __origin__: Type['AnyDocArray'] = cls # add this @@ -114,15 +109,8 @@ def _setter(self, value): change_cls_name( _DocArrayTyped, f'{cls.__name__}[{item.__name__}]', globals() ) - if is_pydantic_v2: - if sys.version_info < (3, 12): - 
cls.__typed_da__[cls][item] = Generic.__class_getitem__.__func__(_DocArrayTyped, item) # type: ignore - # this do nothing that checking that item is valid type var or str - # Keep the approach in #1147 to be compatible with lower versions of Python. - else: - cls.__typed_da__[cls][item] = GenericAlias(_DocArrayTyped, item) - else: - cls.__typed_da__[cls][item] = _DocArrayTyped + + cls.__typed_da__[cls][item] = _DocArrayTyped return cls.__typed_da__[cls][item] diff --git a/docarray/array/doc_list/doc_list.py b/docarray/array/doc_list/doc_list.py index 8ea4983817..7e54fb9301 100644 --- a/docarray/array/doc_list/doc_list.py +++ b/docarray/array/doc_list/doc_list.py @@ -13,7 +13,6 @@ cast, overload, Callable, - get_args, ) from pydantic import parse_obj_as @@ -358,14 +357,16 @@ def __get_pydantic_core_schema__( cls, source: Any, handler: Callable[[Any], core_schema.CoreSchema] ) -> core_schema.CoreSchema: instance_schema = core_schema.is_instance_schema(cls) - - args = get_args(source) + args = getattr(source, '__args__', None) if args: sequence_t_schema = handler(Sequence[args[0]]) else: sequence_t_schema = handler(Sequence) + def validate_fn(v, info): + return cls(v) + non_instance_schema = core_schema.with_info_after_validator_function( - lambda v, i: DocList(v), sequence_t_schema + validate_fn, sequence_t_schema ) return core_schema.union_schema([instance_schema, non_instance_schema]) From 5ec78bd5ab7fa87021b59575419b677f4964204e Mon Sep 17 00:00:00 2001 From: Joan Martinez Date: Tue, 12 Mar 2024 16:23:17 +0100 Subject: [PATCH 06/28] test: test fixes --- .github/workflows/cd.yml | 2 +- .github/workflows/ci.yml | 20 +++++++-------- .github/workflows/ci_only_pr.yml | 2 +- docarray/array/doc_list/doc_list.py | 3 ++- tests/integrations/externals/test_fastapi.py | 27 ++++++++++++++++++++ tests/units/array/test_array.py | 2 ++ tests/units/array/test_doclist_schema.py | 20 +++++++++++++++ 7 files changed, 63 insertions(+), 13 deletions(-) create mode 100644 tests/units/array/test_doclist_schema.py diff --git a/.github/workflows/cd.yml b/.github/workflows/cd.yml index a1aae08ec9..5f565ecb7a 100644 --- a/.github/workflows/cd.yml +++ b/.github/workflows/cd.yml @@ -21,7 +21,7 @@ jobs: - name: Pre-release (.devN) run: | git fetch --depth=1 origin +refs/tags/*:refs/tags/* - pip install poetry + pip install poetry==1.7.1 ./scripts/release.sh env: PYPI_USERNAME: ${{ secrets.TWINE_USERNAME }} diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b8c4added6..982be8264b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -25,7 +25,7 @@ jobs: - name: Lint with ruff run: | python -m pip install --upgrade pip - python -m pip install poetry + python -m pip install poetry==1.7.1 poetry install # stop the build if there are Python syntax errors or undefined names @@ -44,7 +44,7 @@ jobs: - name: check black run: | python -m pip install --upgrade pip - python -m pip install poetry + python -m pip install poetry==1.7.1 poetry install --only dev poetry run black --check . 
@@ -62,7 +62,7 @@ jobs: - name: Prepare environment run: | python -m pip install --upgrade pip - python -m pip install poetry + python -m pip install poetry==1.7.1 poetry install --without dev poetry run pip install tensorflow==2.12.0 poetry run pip install jax @@ -106,7 +106,7 @@ jobs: - name: Prepare environment run: | python -m pip install --upgrade pip - python -m pip install poetry + python -m pip install poetry==1.7.1 poetry install --all-extras poetry run pip install elasticsearch==8.6.2 ./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }} @@ -156,7 +156,7 @@ jobs: - name: Prepare environment run: | python -m pip install --upgrade pip - python -m pip install poetry + python -m pip install poetry==1.7.1 poetry install --all-extras ./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }} poetry run pip install protobuf==3.20.0 # we check that we support 3.19 @@ -204,7 +204,7 @@ jobs: - name: Prepare environment run: | python -m pip install --upgrade pip - python -m pip install poetry + python -m pip install poetry==1.7.1 poetry install --all-extras ./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }} poetry run pip install protobuf==3.20.0 @@ -253,7 +253,7 @@ jobs: - name: Prepare environment run: | python -m pip install --upgrade pip - python -m pip install poetry + python -m pip install poetry==1.7.1 poetry install --all-extras ./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }} poetry run pip install protobuf==3.20.0 @@ -302,7 +302,7 @@ jobs: - name: Prepare environment run: | python -m pip install --upgrade pip - python -m pip install poetry + python -m pip install poetry==1.7.1 poetry install --all-extras ./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }} poetry run pip install protobuf==3.20.0 @@ -351,7 +351,7 @@ jobs: - name: Prepare environment run: | python -m pip install --upgrade pip - python -m pip install poetry + python -m pip install poetry==1.7.1 poetry install --all-extras ./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }} poetry run pip uninstall -y torch @@ -398,7 +398,7 @@ jobs: - name: Prepare environment run: | python -m pip install --upgrade pip - python -m pip install poetry + python -m pip install poetry==1.7.1 poetry install --all-extras poetry run pip uninstall -y torch poetry run pip install torch diff --git a/.github/workflows/ci_only_pr.yml b/.github/workflows/ci_only_pr.yml index 1e8d3f9694..9d040e72b6 100644 --- a/.github/workflows/ci_only_pr.yml +++ b/.github/workflows/ci_only_pr.yml @@ -43,7 +43,7 @@ jobs: run: | npm i -g netlify-cli python -m pip install --upgrade pip - python -m pip install poetry + python -m pip install poetry==1.7.1 python -m poetry config virtualenvs.create false && python -m poetry install --no-interaction --no-ansi --all-extras cd docs diff --git a/docarray/array/doc_list/doc_list.py b/docarray/array/doc_list/doc_list.py index 7e54fb9301..4923619915 100644 --- a/docarray/array/doc_list/doc_list.py +++ b/docarray/array/doc_list/doc_list.py @@ -364,7 +364,8 @@ def __get_pydantic_core_schema__( sequence_t_schema = handler(Sequence) def validate_fn(v, info): - return cls(v) + # input has already been validated + return cls(v, validate_input_docs=False) non_instance_schema = core_schema.with_info_after_validator_function( validate_fn, sequence_t_schema diff --git a/tests/integrations/externals/test_fastapi.py b/tests/integrations/externals/test_fastapi.py index 02967a07cd..821a2cb6b4 100644 --- a/tests/integrations/externals/test_fastapi.py +++ 
b/tests/integrations/externals/test_fastapi.py @@ -9,6 +9,7 @@ from docarray.base_doc import DocArrayResponse from docarray.documents import ImageDoc, TextDoc from docarray.typing import NdArray +from docarray.utils._internal.pydantic import is_pydantic_v2 @pytest.mark.asyncio @@ -135,3 +136,29 @@ async def func(fastapi_docs: List[ImageDoc]) -> List[ImageDoc]: docs = DocList[ImageDoc].from_json(response.content.decode()) assert len(docs) == 2 assert docs[0].tensor.shape == (3, 224, 224) + + +@pytest.mark.asyncio +@pytest.mark.skipif(is_pydantic_v2, reason='Behavior is only available for Pydantic V2') +async def test_doclist_directly(): + doc = ImageDoc(tensor=np.zeros((3, 224, 224))) + docs = DocList[ImageDoc]([doc, doc]) + + app = FastAPI() + + @app.post("/doc/", response_class=DocArrayResponse) + async def func(fastapi_docs: DocList[ImageDoc]) -> DocList[ImageDoc]: + return fastapi_docs + + async with AsyncClient(app=app, base_url="http://test") as ac: + response = await ac.post("/doc/", data=docs.to_json()) + resp_doc = await ac.get("/docs") + resp_redoc = await ac.get("/redoc") + + assert response.status_code == 200 + assert resp_doc.status_code == 200 + assert resp_redoc.status_code == 200 + + docs = DocList[ImageDoc].from_json(response.content.decode()) + assert len(docs) == 2 + assert docs[0].tensor.shape == (3, 224, 224) diff --git a/tests/units/array/test_array.py b/tests/units/array/test_array.py index 1d93fb6b78..ab2772f71c 100644 --- a/tests/units/array/test_array.py +++ b/tests/units/array/test_array.py @@ -486,6 +486,8 @@ def test_validate_list_dict(): dict(url=f'http://url.com/foo_{i}.png', tensor=NdArray(i)) for i in [2, 0, 1] ] + # docs = DocList[Image]([Image(url=image['url'], tensor=image['tensor']) for image in images]) + docs = parse_obj_as(DocList[Image], images) assert docs.url == [ diff --git a/tests/units/array/test_doclist_schema.py b/tests/units/array/test_doclist_schema.py new file mode 100644 index 0000000000..e9a78a36c2 --- /dev/null +++ b/tests/units/array/test_doclist_schema.py @@ -0,0 +1,20 @@ +import pytest +from docarray import BaseDoc, DocList +from docarray.utils._internal.pydantic import is_pydantic_v2 + + +@pytest.mark.skipif(not is_pydantic_v2, reason='Feature only available for Pydantic V2') +def test_schema_nested(): + # check issue https://github.com/docarray/docarray/issues/1521 + + class Doc1Test(BaseDoc): + aux: str + + class DocDocTest(BaseDoc): + docs: DocList[Doc1Test] + + assert 'Doc1Test' in DocDocTest.schema()['$defs'] + d = DocDocTest(docs=DocList[Doc1Test]([Doc1Test(aux='aux')])) + + assert type(d.docs) == DocList[Doc1Test] + assert d.docs.aux == ['aux'] From 50e191adb2b45167cc1736108858bcdcb83cc906 Mon Sep 17 00:00:00 2001 From: Joan Martinez Date: Wed, 13 Mar 2024 17:20:03 +0100 Subject: [PATCH 07/28] test: set test --- docarray/typing/bytes/base_bytes.py | 2 +- docarray/typing/id.py | 2 +- docarray/typing/tensor/abstract_tensor.py | 2 +- docarray/typing/url/any_url.py | 2 +- tests/integrations/externals/test_fastapi.py | 15 +++++++++++++-- 5 files changed, 17 insertions(+), 6 deletions(-) diff --git a/docarray/typing/bytes/base_bytes.py b/docarray/typing/bytes/base_bytes.py index 4c336ae694..8a944031b4 100644 --- a/docarray/typing/bytes/base_bytes.py +++ b/docarray/typing/bytes/base_bytes.py @@ -62,7 +62,7 @@ def _to_node_protobuf(self: T) -> 'NodeProto': def __get_pydantic_core_schema__( cls, _source_type: Any, _handler: 'GetCoreSchemaHandler' ) -> 'core_schema.CoreSchema': - return core_schema.general_after_validator_function( + 
return core_schema.with_info_after_validator_function( cls.validate, core_schema.bytes_schema(), ) diff --git a/docarray/typing/id.py b/docarray/typing/id.py index c06951eaef..3e3fdd37ae 100644 --- a/docarray/typing/id.py +++ b/docarray/typing/id.py @@ -77,7 +77,7 @@ def from_protobuf(cls: Type[T], pb_msg: 'str') -> T: def __get_pydantic_core_schema__( cls, source: Type[Any], handler: 'GetCoreSchemaHandler' ) -> core_schema.CoreSchema: - return core_schema.general_plain_validator_function( + return core_schema.with_info_plain_validator_function( cls.validate, ) diff --git a/docarray/typing/tensor/abstract_tensor.py b/docarray/typing/tensor/abstract_tensor.py index 994fe42cc8..4836e39dde 100644 --- a/docarray/typing/tensor/abstract_tensor.py +++ b/docarray/typing/tensor/abstract_tensor.py @@ -395,7 +395,7 @@ def _docarray_to_ndarray(self) -> np.ndarray: def __get_pydantic_core_schema__( cls, _source_type: Any, handler: GetCoreSchemaHandler ) -> core_schema.CoreSchema: - return core_schema.general_plain_validator_function( + return core_schema.with_info_plain_validator_function( cls.validate, serialization=core_schema.plain_serializer_function_ser_schema( function=orjson_dumps, diff --git a/docarray/typing/url/any_url.py b/docarray/typing/url/any_url.py index ddd1791513..b7c5d71f83 100644 --- a/docarray/typing/url/any_url.py +++ b/docarray/typing/url/any_url.py @@ -56,7 +56,7 @@ def _docarray_validate( def __get_pydantic_core_schema__( cls, source: Type[Any], handler: Optional['GetCoreSchemaHandler'] = None ) -> core_schema.CoreSchema: - return core_schema.general_after_validator_function( + return core_schema.with_info_after_validator_function( cls._docarray_validate, core_schema.str_schema(), ) diff --git a/tests/integrations/externals/test_fastapi.py b/tests/integrations/externals/test_fastapi.py index 821a2cb6b4..9eee0d9e93 100644 --- a/tests/integrations/externals/test_fastapi.py +++ b/tests/integrations/externals/test_fastapi.py @@ -139,26 +139,37 @@ async def func(fastapi_docs: List[ImageDoc]) -> List[ImageDoc]: @pytest.mark.asyncio -@pytest.mark.skipif(is_pydantic_v2, reason='Behavior is only available for Pydantic V2') +@pytest.mark.skipif(not is_pydantic_v2, reason='Behavior is only available for Pydantic V2') async def test_doclist_directly(): + from fastapi import Body doc = ImageDoc(tensor=np.zeros((3, 224, 224))) docs = DocList[ImageDoc]([doc, doc]) app = FastAPI() @app.post("/doc/", response_class=DocArrayResponse) - async def func(fastapi_docs: DocList[ImageDoc]) -> DocList[ImageDoc]: + async def func_embed_false(fastapi_docs: DocList[ImageDoc] = Body(embed=False)) -> DocList[ImageDoc]: + return fastapi_docs + + @app.post("/doc_embed/", response_class=DocArrayResponse) + async def func_embed_true(fastapi_docs: DocList[ImageDoc] = Body(embed=True)) -> DocList[ImageDoc]: return fastapi_docs async with AsyncClient(app=app, base_url="http://test") as ac: response = await ac.post("/doc/", data=docs.to_json()) + response_embed = await ac.post("/doc_embed/", json={'fastapi_docs': [{'tensor': doc.tensor.tolist()}, {'tensor': doc.tensor.tolist()}]}) resp_doc = await ac.get("/docs") resp_redoc = await ac.get("/redoc") assert response.status_code == 200 + assert response_embed.status_code == 200 assert resp_doc.status_code == 200 assert resp_redoc.status_code == 200 docs = DocList[ImageDoc].from_json(response.content.decode()) assert len(docs) == 2 assert docs[0].tensor.shape == (3, 224, 224) + + docs_embed = DocList[ImageDoc].from_json(response_embed.content.decode()) + assert 
len(docs_embed) == 2 + assert docs_embed[0].tensor.shape == (3, 224, 224) From 949f185300f0991ca7ef845d614f59b10ed356ab Mon Sep 17 00:00:00 2001 From: Joan Martinez Date: Thu, 14 Mar 2024 10:50:08 +0100 Subject: [PATCH 08/28] fix: full test for fastapi --- docarray/array/any_array.py | 18 ++++++----- tests/integrations/externals/test_fastapi.py | 33 +++++++++++++++++--- 2 files changed, 39 insertions(+), 12 deletions(-) diff --git a/docarray/array/any_array.py b/docarray/array/any_array.py index e6ccfe9429..93731d6d8a 100644 --- a/docarray/array/any_array.py +++ b/docarray/array/any_array.py @@ -75,11 +75,8 @@ def __class_getitem__(cls, item: Union[Type[BaseDocWithoutId], TypeVar, str]): # Promote to global scope so multiprocessing can pickle it global _DocArrayTyped - class _DocArrayTyped(cls): # type: ignore + class _DocArrayTyped(cls, Generic[T_doc]): # type: ignore doc_type: Type[BaseDocWithoutId] = cast(Type[BaseDocWithoutId], item) - if is_pydantic_v2: - __origin__: Type['AnyDocArray'] = cls # add this - __args__: Tuple[Any, ...] = (item,) # add this for field in _DocArrayTyped.doc_type._docarray_fields().keys(): @@ -104,13 +101,18 @@ def _setter(self, value): setattr(_DocArrayTyped, field, _property_generator(field)) # this generates property on the fly based on the schema of the item - # The global scope and qualname need to refer to this class a unique name. - # Otherwise, creating another _DocArrayTyped will overwrite this one. + # # The global scope and qualname need to refer to this class a unique name. + # # Otherwise, creating another _DocArrayTyped will overwrite this one. change_cls_name( - _DocArrayTyped, f'{cls.__name__}[{item.__name__}]', globals() + _DocArrayTyped, f'{cls.__name__}', globals() ) - cls.__typed_da__[cls][item] = _DocArrayTyped + if sys.version_info < (3, 12): + cls.__typed_da__[cls][item] = Generic.__class_getitem__.__func__(_DocArrayTyped, item) # type: ignore + # this do nothing that checking that item is valid type var or str + # Keep the approach in #1147 to be compatible with lower versions of Python. 
+ else: + cls.__typed_da__[cls][item] = GenericAlias(_DocArrayTyped, item) # type: ignore return cls.__typed_da__[cls][item] diff --git a/tests/integrations/externals/test_fastapi.py b/tests/integrations/externals/test_fastapi.py index 9eee0d9e93..bfc6ba0c80 100644 --- a/tests/integrations/externals/test_fastapi.py +++ b/tests/integrations/externals/test_fastapi.py @@ -139,29 +139,50 @@ async def func(fastapi_docs: List[ImageDoc]) -> List[ImageDoc]: @pytest.mark.asyncio -@pytest.mark.skipif(not is_pydantic_v2, reason='Behavior is only available for Pydantic V2') +@pytest.mark.skipif( + not is_pydantic_v2, reason='Behavior is only available for Pydantic V2' +) async def test_doclist_directly(): from fastapi import Body + doc = ImageDoc(tensor=np.zeros((3, 224, 224))) docs = DocList[ImageDoc]([doc, doc]) app = FastAPI() @app.post("/doc/", response_class=DocArrayResponse) - async def func_embed_false(fastapi_docs: DocList[ImageDoc] = Body(embed=False)) -> DocList[ImageDoc]: + async def func_embed_false( + fastapi_docs: DocList[ImageDoc] = Body(embed=False), + ) -> DocList[ImageDoc]: + return fastapi_docs + + @app.post("/doc_default/", response_class=DocArrayResponse) + async def func_default(fastapi_docs: DocList[ImageDoc]) -> DocList[ImageDoc]: return fastapi_docs @app.post("/doc_embed/", response_class=DocArrayResponse) - async def func_embed_true(fastapi_docs: DocList[ImageDoc] = Body(embed=True)) -> DocList[ImageDoc]: + async def func_embed_true( + fastapi_docs: DocList[ImageDoc] = Body(embed=True), + ) -> DocList[ImageDoc]: return fastapi_docs async with AsyncClient(app=app, base_url="http://test") as ac: response = await ac.post("/doc/", data=docs.to_json()) - response_embed = await ac.post("/doc_embed/", json={'fastapi_docs': [{'tensor': doc.tensor.tolist()}, {'tensor': doc.tensor.tolist()}]}) + response_default = await ac.post("/doc_default/", data=docs.to_json()) + response_embed = await ac.post( + "/doc_embed/", + json={ + 'fastapi_docs': [ + {'tensor': doc.tensor.tolist()}, + {'tensor': doc.tensor.tolist()}, + ] + }, + ) resp_doc = await ac.get("/docs") resp_redoc = await ac.get("/redoc") assert response.status_code == 200 + assert response_default.status_code == 200 assert response_embed.status_code == 200 assert resp_doc.status_code == 200 assert resp_redoc.status_code == 200 @@ -170,6 +191,10 @@ async def func_embed_true(fastapi_docs: DocList[ImageDoc] = Body(embed=True)) -> assert len(docs) == 2 assert docs[0].tensor.shape == (3, 224, 224) + docs_default = DocList[ImageDoc].from_json(response_default.content.decode()) + assert len(docs_default) == 2 + assert docs_default[0].tensor.shape == (3, 224, 224) + docs_embed = DocList[ImageDoc].from_json(response_embed.content.decode()) assert len(docs_embed) == 2 assert docs_embed[0].tensor.shape == (3, 224, 224) From d62645327fe999063a9232b6d2e5c7fc33ecd5d7 Mon Sep 17 00:00:00 2001 From: Joan Martinez Date: Thu, 14 Mar 2024 10:50:47 +0100 Subject: [PATCH 09/28] fix: try to make generic --- docarray/array/any_array.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/docarray/array/any_array.py b/docarray/array/any_array.py index 93731d6d8a..9f2d0fa890 100644 --- a/docarray/array/any_array.py +++ b/docarray/array/any_array.py @@ -16,7 +16,6 @@ Union, cast, overload, - Tuple, ) import numpy as np @@ -103,17 +102,19 @@ def _setter(self, value): # # The global scope and qualname need to refer to this class a unique name. # # Otherwise, creating another _DocArrayTyped will overwrite this one. 
- change_cls_name( - _DocArrayTyped, f'{cls.__name__}', globals() - ) - - if sys.version_info < (3, 12): - cls.__typed_da__[cls][item] = Generic.__class_getitem__.__func__(_DocArrayTyped, item) # type: ignore - # this do nothing that checking that item is valid type var or str - # Keep the approach in #1147 to be compatible with lower versions of Python. + if not is_pydantic_v2: + change_cls_name(_DocArrayTyped, f'{cls.__name__}[{item}]', globals()) + cls.__typed_da__[cls][item] = _DocArrayTyped else: - cls.__typed_da__[cls][item] = GenericAlias(_DocArrayTyped, item) # type: ignore - + change_cls_name(_DocArrayTyped, f'{cls.__name__}', globals()) + if sys.version_info < (3, 12): + cls.__typed_da__[cls][item] = Generic.__class_getitem__.__func__( + _DocArrayTyped, item + ) # type: ignore + # this do nothing that checking that item is valid type var or str + # Keep the approach in #1147 to be compatible with lower versions of Python. + else: + cls.__typed_da__[cls][item] = GenericAlias(_DocArrayTyped, item) # type: ignore return cls.__typed_da__[cls][item] @overload From 4180a4e5a6b2b691e78f64fd90940fc8d1c3b3a8 Mon Sep 17 00:00:00 2001 From: Joan Martinez Date: Tue, 25 Feb 2025 14:32:16 +0100 Subject: [PATCH 10/28] test: fix some tests --- docarray/utils/_internal/_typing.py | 12 ++-- .../index/base_classes/test_base_doc_store.py | 57 ++++++++++++------- 2 files changed, 44 insertions(+), 25 deletions(-) diff --git a/docarray/utils/_internal/_typing.py b/docarray/utils/_internal/_typing.py index 83e350a060..1ebdc6ea3c 100644 --- a/docarray/utils/_internal/_typing.py +++ b/docarray/utils/_internal/_typing.py @@ -61,11 +61,13 @@ def safe_issubclass(x: type, a_tuple: type) -> bool: :return: A boolean value - 'True' if 'x' is a subclass of 'A_tuple', 'False' otherwise. Note that if the origin of 'x' is a list or tuple, the function immediately returns 'False'. 
""" + origin = get_origin(x) or x if ( - (get_origin(x) in (list, tuple, dict, set, Union)) - or is_typevar(x) - or (type(x) == ForwardRef) - or is_typevar(x) + (origin in (list, tuple, dict, set, Union)) + or is_typevar(origin) + or (type(origin) == ForwardRef) + or is_typevar(origin) ): return False - return issubclass(x, a_tuple) + + return issubclass(origin, a_tuple) diff --git a/tests/index/base_classes/test_base_doc_store.py b/tests/index/base_classes/test_base_doc_store.py index faf146df6f..70116bbf9e 100644 --- a/tests/index/base_classes/test_base_doc_store.py +++ b/tests/index/base_classes/test_base_doc_store.py @@ -13,6 +13,7 @@ from docarray.typing import ID, ImageBytes, ImageUrl, NdArray from docarray.typing.tensor.abstract_tensor import AbstractTensor from docarray.utils._internal.misc import torch_imported +from docarray.utils._internal._typing import safe_issubclass pytestmark = pytest.mark.index @@ -54,7 +55,7 @@ class DummyDocIndex(BaseDocIndex): def __init__(self, db_config=None, **kwargs): super().__init__(db_config=db_config, **kwargs) for col_name, col in self._column_infos.items(): - if issubclass(col.docarray_type, AnyDocArray): + if safe_issubclass(col.docarray_type, AnyDocArray): sub_db_config = copy.deepcopy(self._db_config) self._subindices[col_name] = self.__class__[col.docarray_type.doc_type]( db_config=sub_db_config, subindex=True @@ -159,7 +160,7 @@ def test_create_columns(): assert index._column_infos['id'].n_dim is None assert index._column_infos['id'].config['hi'] == 'there' - assert issubclass(index._column_infos['tens'].docarray_type, AbstractTensor) + assert safe_issubclass(index._column_infos['tens'].docarray_type, AbstractTensor) assert index._column_infos['tens'].db_type == str assert index._column_infos['tens'].n_dim == 10 assert index._column_infos['tens'].config == {'dim': 1000, 'hi': 'there'} @@ -173,12 +174,16 @@ def test_create_columns(): assert index._column_infos['id'].n_dim is None assert index._column_infos['id'].config['hi'] == 'there' - assert issubclass(index._column_infos['tens_one'].docarray_type, AbstractTensor) + assert safe_issubclass( + index._column_infos['tens_one'].docarray_type, AbstractTensor + ) assert index._column_infos['tens_one'].db_type == str assert index._column_infos['tens_one'].n_dim is None assert index._column_infos['tens_one'].config == {'dim': 10, 'hi': 'there'} - assert issubclass(index._column_infos['tens_two'].docarray_type, AbstractTensor) + assert safe_issubclass( + index._column_infos['tens_two'].docarray_type, AbstractTensor + ) assert index._column_infos['tens_two'].db_type == str assert index._column_infos['tens_two'].n_dim is None assert index._column_infos['tens_two'].config == {'dim': 50, 'hi': 'there'} @@ -192,7 +197,7 @@ def test_create_columns(): assert index._column_infos['id'].n_dim is None assert index._column_infos['id'].config['hi'] == 'there' - assert issubclass(index._column_infos['d__tens'].docarray_type, AbstractTensor) + assert safe_issubclass(index._column_infos['d__tens'].docarray_type, AbstractTensor) assert index._column_infos['d__tens'].db_type == str assert index._column_infos['d__tens'].n_dim == 10 assert index._column_infos['d__tens'].config == {'dim': 1000, 'hi': 'there'} @@ -206,7 +211,7 @@ def test_create_columns(): 'parent_id', ] - assert issubclass(index._column_infos['d'].docarray_type, AnyDocArray) + assert safe_issubclass(index._column_infos['d'].docarray_type, AnyDocArray) assert index._column_infos['d'].db_type is None assert index._column_infos['d'].n_dim is None assert 
index._column_infos['d'].config == {} @@ -216,7 +221,7 @@ def test_create_columns(): assert index._subindices['d']._column_infos['id'].n_dim is None assert index._subindices['d']._column_infos['id'].config['hi'] == 'there' - assert issubclass( + assert safe_issubclass( index._subindices['d']._column_infos['tens'].docarray_type, AbstractTensor ) assert index._subindices['d']._column_infos['tens'].db_type == str @@ -245,7 +250,7 @@ def test_create_columns(): 'parent_id', ] - assert issubclass( + assert safe_issubclass( index._subindices['d_root']._column_infos['d'].docarray_type, AnyDocArray ) assert index._subindices['d_root']._column_infos['d'].db_type is None @@ -266,7 +271,7 @@ def test_create_columns(): index._subindices['d_root']._subindices['d']._column_infos['id'].config['hi'] == 'there' ) - assert issubclass( + assert safe_issubclass( index._subindices['d_root'] ._subindices['d'] ._column_infos['tens'] @@ -461,11 +466,14 @@ class OtherNestedDoc(NestedDoc): # SIMPLE index = DummyDocIndex[SimpleDoc]() in_list = [SimpleDoc(tens=np.random.random((10,)))] - assert isinstance(index._validate_docs(in_list), DocList[BaseDoc]) + assert isinstance(index._validate_docs(in_list), DocList) + assert isinstance(index._validate_docs(in_list)[0], BaseDoc) + in_da = DocList[SimpleDoc](in_list) assert index._validate_docs(in_da) == in_da in_other_list = [OtherSimpleDoc(tens=np.random.random((10,)))] - assert isinstance(index._validate_docs(in_other_list), DocList[BaseDoc]) + assert isinstance(index._validate_docs(in_other_list), DocList) + assert isinstance(index._validate_docs(in_other_list)[0], BaseDoc) in_other_da = DocList[OtherSimpleDoc](in_other_list) assert index._validate_docs(in_other_da) == in_other_da @@ -494,7 +502,8 @@ class OtherNestedDoc(NestedDoc): in_list = [ FlatDoc(tens_one=np.random.random((10,)), tens_two=np.random.random((50,))) ] - assert isinstance(index._validate_docs(in_list), DocList[BaseDoc]) + assert isinstance(index._validate_docs(in_list), DocList) + assert isinstance(index._validate_docs(in_list)[0], BaseDoc) in_da = DocList[FlatDoc]( [FlatDoc(tens_one=np.random.random((10,)), tens_two=np.random.random((50,)))] ) @@ -502,7 +511,8 @@ class OtherNestedDoc(NestedDoc): in_other_list = [ OtherFlatDoc(tens_one=np.random.random((10,)), tens_two=np.random.random((50,))) ] - assert isinstance(index._validate_docs(in_other_list), DocList[BaseDoc]) + assert isinstance(index._validate_docs(in_other_list), DocList) + assert isinstance(index._validate_docs(in_other_list)[0], BaseDoc) in_other_da = DocList[OtherFlatDoc]( [ OtherFlatDoc( @@ -521,11 +531,13 @@ class OtherNestedDoc(NestedDoc): # NESTED index = DummyDocIndex[NestedDoc]() in_list = [NestedDoc(d=SimpleDoc(tens=np.random.random((10,))))] - assert isinstance(index._validate_docs(in_list), DocList[BaseDoc]) + assert isinstance(index._validate_docs(in_list), DocList) + assert isinstance(index._validate_docs(in_list)[0], BaseDoc) in_da = DocList[NestedDoc]([NestedDoc(d=SimpleDoc(tens=np.random.random((10,))))]) assert index._validate_docs(in_da) == in_da in_other_list = [OtherNestedDoc(d=OtherSimpleDoc(tens=np.random.random((10,))))] - assert isinstance(index._validate_docs(in_other_list), DocList[BaseDoc]) + assert isinstance(index._validate_docs(in_other_list), DocList) + assert isinstance(index._validate_docs(in_other_list)[0], BaseDoc) in_other_da = DocList[OtherNestedDoc]( [OtherNestedDoc(d=OtherSimpleDoc(tens=np.random.random((10,))))] ) @@ -552,7 +564,8 @@ class TensorUnionDoc(BaseDoc): # OPTIONAL index = 
DummyDocIndex[SimpleDoc]() in_list = [OptionalDoc(tens=np.random.random((10,)))] - assert isinstance(index._validate_docs(in_list), DocList[BaseDoc]) + assert isinstance(index._validate_docs(in_list), DocList) + assert isinstance(index._validate_docs(in_list)[0], BaseDoc) in_da = DocList[OptionalDoc](in_list) assert index._validate_docs(in_da) == in_da @@ -562,9 +575,11 @@ class TensorUnionDoc(BaseDoc): # MIXED UNION index = DummyDocIndex[SimpleDoc]() in_list = [MixedUnionDoc(tens=np.random.random((10,)))] - assert isinstance(index._validate_docs(in_list), DocList[BaseDoc]) + assert isinstance(index._validate_docs(in_list), DocList) + assert isinstance(index._validate_docs(in_list)[0], BaseDoc) in_da = DocList[MixedUnionDoc](in_list) - assert isinstance(index._validate_docs(in_da), DocList[BaseDoc]) + assert isinstance(index._validate_docs(in_da), DocList) + assert isinstance(index._validate_docs(in_da)[0], BaseDoc) with pytest.raises(ValueError): index._validate_docs([MixedUnionDoc(tens='hello')]) @@ -572,13 +587,15 @@ class TensorUnionDoc(BaseDoc): # TENSOR UNION index = DummyDocIndex[TensorUnionDoc]() in_list = [SimpleDoc(tens=np.random.random((10,)))] - assert isinstance(index._validate_docs(in_list), DocList[BaseDoc]) + assert isinstance(index._validate_docs(in_list), DocList) + assert isinstance(index._validate_docs(in_list)[0], BaseDoc) in_da = DocList[SimpleDoc](in_list) assert index._validate_docs(in_da) == in_da index = DummyDocIndex[SimpleDoc]() in_list = [TensorUnionDoc(tens=np.random.random((10,)))] - assert isinstance(index._validate_docs(in_list), DocList[BaseDoc]) + assert isinstance(index._validate_docs(in_list), DocList) + assert isinstance(index._validate_docs(in_list)[0], BaseDoc) in_da = DocList[TensorUnionDoc](in_list) assert index._validate_docs(in_da) == in_da From fb13c65e72dd3339e39b6defed695ebda5878e5f Mon Sep 17 00:00:00 2001 From: Joan Martinez Date: Tue, 25 Feb 2025 15:58:55 +0100 Subject: [PATCH 11/28] fix: small tests --- docarray/__init__.py | 35 +++++++++++++++++++ docarray/array/doc_list/io.py | 1 - docarray/base_doc/mixins/io.py | 6 ++-- tests/benchmark_tests/test_map.py | 6 ++-- tests/units/array/test_array_from_to_bytes.py | 6 ++-- tests/units/array/test_array_save_load.py | 8 ++--- 6 files changed, 46 insertions(+), 16 deletions(-) diff --git a/docarray/__init__.py b/docarray/__init__.py index 6ce3f9eb90..e171bae562 100644 --- a/docarray/__init__.py +++ b/docarray/__init__.py @@ -21,6 +21,41 @@ from docarray.base_doc.doc import BaseDoc from docarray.utils._internal.misc import _get_path_from_docarray_root_level + +def unpickle_doclist(doc_type, b): + return DocList[doc_type].from_bytes(b, protocol="protobuf") + + +# Register the pickle functions +def register_serializers(): + import copyreg + from functools import partial + + unpickle_doc_fn = partial(BaseDoc.from_bytes, protocol="protobuf") + + def pickle_doc(doc): + b = doc.to_bytes(protocol='protobuf') + return unpickle_doc_fn, (doc.__class__, b) + + # Register BaseDoc serialization + copyreg.pickle(BaseDoc, pickle_doc) + + # For DocList, we need to hook into __reduce__ since it's a generic + + def pickle_doclist(doc_list): + b = doc_list.to_bytes(protocol='protobuf') + doc_type = doc_list.doc_type + return unpickle_doclist, (doc_type, b) + + # Replace DocList.__reduce__ with a method that returns the correct format + def doclist_reduce(self): + return pickle_doclist(self) + + DocList.__reduce__ = doclist_reduce + + +register_serializers() + __all__ = ['BaseDoc', 'DocList', 'DocVec'] logger = 
logging.getLogger('docarray') diff --git a/docarray/array/doc_list/io.py b/docarray/array/doc_list/io.py index 82d00197e2..3acb66bf6e 100644 --- a/docarray/array/doc_list/io.py +++ b/docarray/array/doc_list/io.py @@ -256,7 +256,6 @@ def to_bytes( :param show_progress: show progress bar, only works when protocol is `pickle` or `protobuf` :return: the binary serialization in bytes or None if file_ctx is passed where to store """ - with file_ctx or io.BytesIO() as bf: self._write_bytes( bf=bf, diff --git a/docarray/base_doc/mixins/io.py b/docarray/base_doc/mixins/io.py index 3121c45c44..07e54239df 100644 --- a/docarray/base_doc/mixins/io.py +++ b/docarray/base_doc/mixins/io.py @@ -159,9 +159,9 @@ def to_bytes( :param compress: compression algorithm to use :return: the binary serialization in bytes """ - import pickle - if protocol == 'pickle': + import pickle + bstr = pickle.dumps(self) elif protocol == 'protobuf': bstr = self.to_protobuf().SerializePartialToString() @@ -188,6 +188,8 @@ def from_bytes( """ bstr = _decompress_bytes(data, algorithm=compress) if protocol == 'pickle': + import pickle + return pickle.loads(bstr) elif protocol == 'protobuf': from docarray.proto import DocProto diff --git a/tests/benchmark_tests/test_map.py b/tests/benchmark_tests/test_map.py index e5c664a408..2fc7b09496 100644 --- a/tests/benchmark_tests/test_map.py +++ b/tests/benchmark_tests/test_map.py @@ -29,9 +29,9 @@ def test_map_docs_multiprocessing(): if os.cpu_count() > 1: def time_multiprocessing(num_workers: int) -> float: - n_docs = 5 + n_docs = 10 rng = np.random.RandomState(0) - matrices = [rng.random(size=(1000, 1000)) for _ in range(n_docs)] + matrices = [rng.random(size=(100, 100)) for _ in range(n_docs)] da = DocList[MyMatrix]([MyMatrix(matrix=m) for m in matrices]) start_time = time() list( @@ -65,7 +65,7 @@ def test_map_docs_batched_multiprocessing(): def time_multiprocessing(num_workers: int) -> float: n_docs = 16 rng = np.random.RandomState(0) - matrices = [rng.random(size=(1000, 1000)) for _ in range(n_docs)] + matrices = [rng.random(size=(100, 100)) for _ in range(n_docs)] da = DocList[MyMatrix]([MyMatrix(matrix=m) for m in matrices]) start_time = time() list( diff --git a/tests/units/array/test_array_from_to_bytes.py b/tests/units/array/test_array_from_to_bytes.py index abc31cb4ac..ac9fbe313f 100644 --- a/tests/units/array/test_array_from_to_bytes.py +++ b/tests/units/array/test_array_from_to_bytes.py @@ -11,9 +11,7 @@ class MyDoc(BaseDoc): image: ImageDoc -@pytest.mark.parametrize( - 'protocol', ['pickle-array', 'protobuf-array', 'protobuf', 'pickle'] -) +@pytest.mark.parametrize('protocol', ['protobuf-array', 'protobuf']) @pytest.mark.parametrize('compress', ['lz4', 'bz2', 'lzma', 'zlib', 'gzip', None]) @pytest.mark.parametrize('show_progress', [False, True]) @pytest.mark.parametrize('array_cls', [DocList, DocVec]) @@ -78,7 +76,7 @@ def test_from_to_base64(protocol, compress, show_progress, array_cls): @pytest.mark.parametrize('tensor_type', [NdArray, TorchTensor]) -@pytest.mark.parametrize('protocol', ['protobuf-array', 'pickle-array']) +@pytest.mark.parametrize('protocol', ['protobuf-array']) def test_from_to_base64_tensor_type(tensor_type, protocol): class MyDoc(BaseDoc): embedding: tensor_type diff --git a/tests/units/array/test_array_save_load.py b/tests/units/array/test_array_save_load.py index b5ee6b616e..17c9f80b0d 100644 --- a/tests/units/array/test_array_save_load.py +++ b/tests/units/array/test_array_save_load.py @@ -30,9 +30,7 @@ class MyDoc(BaseDoc): @pytest.mark.slow 
-@pytest.mark.parametrize( - 'protocol', ['pickle-array', 'protobuf-array', 'protobuf', 'pickle', 'json-array'] -) +@pytest.mark.parametrize('protocol', ['protobuf-array', 'protobuf', 'json-array']) @pytest.mark.parametrize('compress', ['lz4', 'bz2', 'lzma', 'zlib', 'gzip', None]) @pytest.mark.parametrize('show_progress', [False, True]) @pytest.mark.parametrize('array_cls', [DocList, DocVec]) @@ -67,9 +65,7 @@ def test_array_save_load_binary(protocol, compress, tmp_path, show_progress, arr @pytest.mark.slow -@pytest.mark.parametrize( - 'protocol', ['pickle-array', 'protobuf-array', 'protobuf', 'pickle', 'json-array'] -) +@pytest.mark.parametrize('protocol', ['protobuf-array', 'protobuf', 'json-array']) @pytest.mark.parametrize('compress', ['lz4', 'bz2', 'lzma', 'zlib', 'gzip', None]) @pytest.mark.parametrize('show_progress', [False, True]) @pytest.mark.parametrize('to_doc_vec', [True, False]) From bfee17af0d935ab1b78286e9ca69931cbb1d2086 Mon Sep 17 00:00:00 2001 From: Joan Martinez Date: Tue, 25 Feb 2025 16:35:06 +0100 Subject: [PATCH 12/28] test: fix test --- docarray/__init__.py | 58 ++++++++++++++++++--------- docarray/base_doc/mixins/io.py | 4 -- tests/units/array/stack/test_proto.py | 2 + 3 files changed, 40 insertions(+), 24 deletions(-) diff --git a/docarray/__init__.py b/docarray/__init__.py index e171bae562..3ef7a4ad97 100644 --- a/docarray/__init__.py +++ b/docarray/__init__.py @@ -20,41 +20,59 @@ from docarray.array import DocList, DocVec from docarray.base_doc.doc import BaseDoc from docarray.utils._internal.misc import _get_path_from_docarray_root_level +from docarray.utils._internal.pydantic import is_pydantic_v2 def unpickle_doclist(doc_type, b): return DocList[doc_type].from_bytes(b, protocol="protobuf") -# Register the pickle functions -def register_serializers(): - import copyreg - from functools import partial +def unpickle_docvec(doc_type, b): + return DocVec[doc_type].from_bytes(b, protocol="protobuf") - unpickle_doc_fn = partial(BaseDoc.from_bytes, protocol="protobuf") - def pickle_doc(doc): - b = doc.to_bytes(protocol='protobuf') - return unpickle_doc_fn, (doc.__class__, b) +if is_pydantic_v2: + # Register the pickle functions + def register_serializers(): + import copyreg + from functools import partial - # Register BaseDoc serialization - copyreg.pickle(BaseDoc, pickle_doc) + unpickle_doc_fn = partial(BaseDoc.from_bytes, protocol="protobuf") - # For DocList, we need to hook into __reduce__ since it's a generic + def pickle_doc(doc): + b = doc.to_bytes(protocol='protobuf') + return unpickle_doc_fn, (doc.__class__, b) - def pickle_doclist(doc_list): - b = doc_list.to_bytes(protocol='protobuf') - doc_type = doc_list.doc_type - return unpickle_doclist, (doc_type, b) + # Register BaseDoc serialization + copyreg.pickle(BaseDoc, pickle_doc) - # Replace DocList.__reduce__ with a method that returns the correct format - def doclist_reduce(self): - return pickle_doclist(self) + # For DocList, we need to hook into __reduce__ since it's a generic - DocList.__reduce__ = doclist_reduce + def pickle_doclist(doc_list): + b = doc_list.to_bytes(protocol='protobuf') + doc_type = doc_list.doc_type + return unpickle_doclist, (doc_type, b) + # Replace DocList.__reduce__ with a method that returns the correct format + def doclist_reduce(self): + return pickle_doclist(self) -register_serializers() + DocList.__reduce__ = doclist_reduce + + # For DocVec, we need to hook into __reduce__ since it's a generic + + def pickle_docvec(doc_vec): + b = doc_vec.to_bytes(protocol='protobuf') + 
doc_type = doc_vec.doc_type + return unpickle_docvec, (doc_type, b) + + # Replace DocList.__reduce__ with a method that returns the correct format + def docvec_reduce(self): + return pickle_docvec(self) + + DocVec.__reduce__ = docvec_reduce + + register_serializers() __all__ = ['BaseDoc', 'DocList', 'DocVec'] diff --git a/docarray/base_doc/mixins/io.py b/docarray/base_doc/mixins/io.py index 07e54239df..903cc91b9d 100644 --- a/docarray/base_doc/mixins/io.py +++ b/docarray/base_doc/mixins/io.py @@ -160,8 +160,6 @@ def to_bytes( :return: the binary serialization in bytes """ if protocol == 'pickle': - import pickle - bstr = pickle.dumps(self) elif protocol == 'protobuf': bstr = self.to_protobuf().SerializePartialToString() @@ -188,8 +186,6 @@ def from_bytes( """ bstr = _decompress_bytes(data, algorithm=compress) if protocol == 'pickle': - import pickle - return pickle.loads(bstr) elif protocol == 'protobuf': from docarray.proto import DocProto diff --git a/tests/units/array/stack/test_proto.py b/tests/units/array/stack/test_proto.py index 8c559826b8..d46766cde3 100644 --- a/tests/units/array/stack/test_proto.py +++ b/tests/units/array/stack/test_proto.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os from typing import Dict, Optional, Union import numpy as np @@ -245,6 +246,7 @@ class MyDoc(BaseDoc): assert da_after._storage.any_columns['d'] == [None, None] +@pytest.mark.skipif('GITHUB_WORKFLOW' in os.environ, reason='Flaky in Github') @pytest.mark.proto @pytest.mark.parametrize('tensor_type', [NdArray, TorchTensor]) def test_proto_tensor_type(tensor_type): From 2f80c5618543c29b7b1d3442b6e3353e64e535bb Mon Sep 17 00:00:00 2001 From: Joan Martinez Date: Tue, 25 Feb 2025 18:15:05 +0100 Subject: [PATCH 13/28] fix: small test fix --- docarray/index/backends/elastic.py | 8 ++++---- docarray/index/backends/epsilla.py | 4 ++-- docarray/utils/_internal/_typing.py | 14 ++++++++------ 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/docarray/index/backends/elastic.py b/docarray/index/backends/elastic.py index c008fa29de..a335f85e32 100644 --- a/docarray/index/backends/elastic.py +++ b/docarray/index/backends/elastic.py @@ -352,12 +352,12 @@ def python_type_to_db_type(self, python_type: Type) -> Any: dict: 'object', } - for type in elastic_py_types.keys(): - if safe_issubclass(python_type, type): + for t in elastic_py_types.keys(): + if safe_issubclass(python_type, t): self._logger.info( - f'Mapped Python type {python_type} to database type "{elastic_py_types[type]}"' + f'Mapped Python type {python_type} to database type "{elastic_py_types[t]}"' ) - return elastic_py_types[type] + return elastic_py_types[t] err_msg = f'Unsupported column type for {type(self)}: {python_type}' self._logger.error(err_msg) diff --git a/docarray/index/backends/epsilla.py b/docarray/index/backends/epsilla.py index 83c171daed..0392e9d010 100644 --- a/docarray/index/backends/epsilla.py +++ b/docarray/index/backends/epsilla.py @@ -100,8 +100,8 @@ def __init__(self, db_config=None, **kwargs): def _validate_column_info(self): vector_columns = [] for info in self._column_infos.values(): - for type in [list, np.ndarray, AbstractTensor]: - if safe_issubclass(info.docarray_type, type) and info.config.get( + for t in [list, np.ndarray, AbstractTensor]: + if safe_issubclass(info.docarray_type, t) and info.config.get( 'is_embedding', False ): # check that dimension 
is present diff --git a/docarray/utils/_internal/_typing.py b/docarray/utils/_internal/_typing.py index 1ebdc6ea3c..3286c1a062 100644 --- a/docarray/utils/_internal/_typing.py +++ b/docarray/utils/_internal/_typing.py @@ -61,13 +61,15 @@ def safe_issubclass(x: type, a_tuple: type) -> bool: :return: A boolean value - 'True' if 'x' is a subclass of 'A_tuple', 'False' otherwise. Note that if the origin of 'x' is a list or tuple, the function immediately returns 'False'. """ - origin = get_origin(x) or x + origin = get_origin(x) + if origin: # If x is a generic type like DocList[SomeDoc], get its origin + x = origin if ( - (origin in (list, tuple, dict, set, Union)) - or is_typevar(origin) - or (type(origin) == ForwardRef) - or is_typevar(origin) + (get_origin(x) in (list, tuple, dict, set, Union)) + or is_typevar(x) + or (type(x) == ForwardRef) + or is_typevar(x) ): return False - return issubclass(origin, a_tuple) + return isinstance(x, type) and issubclass(x, a_tuple) From 5ff2f687cbe59e6dc037868a0315adbbc77e8c71 Mon Sep 17 00:00:00 2001 From: Joan Martinez Date: Wed, 12 Mar 2025 17:53:11 +0100 Subject: [PATCH 14/28] test: change tests --- docarray/array/any_array.py | 20 ++++++++++++++++--- docarray/base_doc/mixins/io.py | 2 ++ tests/units/array/test_array_from_to_bytes.py | 8 +++++--- tests/units/array/test_array_save_load.py | 8 ++++++-- 4 files changed, 30 insertions(+), 8 deletions(-) diff --git a/docarray/array/any_array.py b/docarray/array/any_array.py index 9f2d0fa890..0c29e54ae8 100644 --- a/docarray/array/any_array.py +++ b/docarray/array/any_array.py @@ -74,8 +74,19 @@ def __class_getitem__(cls, item: Union[Type[BaseDocWithoutId], TypeVar, str]): # Promote to global scope so multiprocessing can pickle it global _DocArrayTyped - class _DocArrayTyped(cls, Generic[T_doc]): # type: ignore - doc_type: Type[BaseDocWithoutId] = cast(Type[BaseDocWithoutId], item) + if not is_pydantic_v2: + + class _DocArrayTyped(cls): # type: ignore + doc_type: Type[BaseDocWithoutId] = cast( + Type[BaseDocWithoutId], item + ) + + else: + + class _DocArrayTyped(cls, Generic[T_doc]): # type: ignore + doc_type: Type[BaseDocWithoutId] = cast( + Type[BaseDocWithoutId], item + ) for field in _DocArrayTyped.doc_type._docarray_fields().keys(): @@ -103,7 +114,10 @@ def _setter(self, value): # # The global scope and qualname need to refer to this class a unique name. # # Otherwise, creating another _DocArrayTyped will overwrite this one. 
if not is_pydantic_v2: - change_cls_name(_DocArrayTyped, f'{cls.__name__}[{item}]', globals()) + change_cls_name( + _DocArrayTyped, f'{cls.__name__}[{item.__name__}]', globals() + ) + cls.__typed_da__[cls][item] = _DocArrayTyped else: change_cls_name(_DocArrayTyped, f'{cls.__name__}', globals()) diff --git a/docarray/base_doc/mixins/io.py b/docarray/base_doc/mixins/io.py index 903cc91b9d..3121c45c44 100644 --- a/docarray/base_doc/mixins/io.py +++ b/docarray/base_doc/mixins/io.py @@ -159,6 +159,8 @@ def to_bytes( :param compress: compression algorithm to use :return: the binary serialization in bytes """ + import pickle + if protocol == 'pickle': bstr = pickle.dumps(self) elif protocol == 'protobuf': diff --git a/tests/units/array/test_array_from_to_bytes.py b/tests/units/array/test_array_from_to_bytes.py index ac9fbe313f..1da72f571e 100644 --- a/tests/units/array/test_array_from_to_bytes.py +++ b/tests/units/array/test_array_from_to_bytes.py @@ -11,7 +11,9 @@ class MyDoc(BaseDoc): image: ImageDoc -@pytest.mark.parametrize('protocol', ['protobuf-array', 'protobuf']) +@pytest.mark.parametrize( + 'protocol', ['pickle-array', 'protobuf-array', 'protobuf', 'pickle'] +) @pytest.mark.parametrize('compress', ['lz4', 'bz2', 'lzma', 'zlib', 'gzip', None]) @pytest.mark.parametrize('show_progress', [False, True]) @pytest.mark.parametrize('array_cls', [DocList, DocVec]) @@ -41,7 +43,7 @@ def test_from_to_bytes(protocol, compress, show_progress, array_cls): @pytest.mark.parametrize( - 'protocol', ['protobuf'] # ['pickle-array', 'protobuf-array', 'protobuf', 'pickle'] + 'protocol', ['pickle-array', 'protobuf-array', 'protobuf', 'pickle'] ) @pytest.mark.parametrize('compress', ['lz4']) # , 'bz2', 'lzma', 'zlib', 'gzip', None]) @pytest.mark.parametrize('show_progress', [False]) # [False, True]) @@ -76,7 +78,7 @@ def test_from_to_base64(protocol, compress, show_progress, array_cls): @pytest.mark.parametrize('tensor_type', [NdArray, TorchTensor]) -@pytest.mark.parametrize('protocol', ['protobuf-array']) +@pytest.mark.parametrize('protocol', ['protobuf-array', 'pickle-array']) def test_from_to_base64_tensor_type(tensor_type, protocol): class MyDoc(BaseDoc): embedding: tensor_type diff --git a/tests/units/array/test_array_save_load.py b/tests/units/array/test_array_save_load.py index 17c9f80b0d..b5ee6b616e 100644 --- a/tests/units/array/test_array_save_load.py +++ b/tests/units/array/test_array_save_load.py @@ -30,7 +30,9 @@ class MyDoc(BaseDoc): @pytest.mark.slow -@pytest.mark.parametrize('protocol', ['protobuf-array', 'protobuf', 'json-array']) +@pytest.mark.parametrize( + 'protocol', ['pickle-array', 'protobuf-array', 'protobuf', 'pickle', 'json-array'] +) @pytest.mark.parametrize('compress', ['lz4', 'bz2', 'lzma', 'zlib', 'gzip', None]) @pytest.mark.parametrize('show_progress', [False, True]) @pytest.mark.parametrize('array_cls', [DocList, DocVec]) @@ -65,7 +67,9 @@ def test_array_save_load_binary(protocol, compress, tmp_path, show_progress, arr @pytest.mark.slow -@pytest.mark.parametrize('protocol', ['protobuf-array', 'protobuf', 'json-array']) +@pytest.mark.parametrize( + 'protocol', ['pickle-array', 'protobuf-array', 'protobuf', 'pickle', 'json-array'] +) @pytest.mark.parametrize('compress', ['lz4', 'bz2', 'lzma', 'zlib', 'gzip', None]) @pytest.mark.parametrize('show_progress', [False, True]) @pytest.mark.parametrize('to_doc_vec', [True, False]) From 00245aa1a0ccb9d3ba70ab0dd6432843ffe3ac1f Mon Sep 17 00:00:00 2001 From: Joan Martinez Date: Thu, 13 Mar 2025 09:48:48 +0100 Subject: [PATCH 15/28] fix: try 
to fix all pydantic-v1 tests --- docarray/utils/_internal/_typing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docarray/utils/_internal/_typing.py b/docarray/utils/_internal/_typing.py index 3286c1a062..3c2bd89a8e 100644 --- a/docarray/utils/_internal/_typing.py +++ b/docarray/utils/_internal/_typing.py @@ -65,7 +65,7 @@ def safe_issubclass(x: type, a_tuple: type) -> bool: if origin: # If x is a generic type like DocList[SomeDoc], get its origin x = origin if ( - (get_origin(x) in (list, tuple, dict, set, Union)) + (origin in (list, tuple, dict, set, Union)) or is_typevar(x) or (type(x) == ForwardRef) or is_typevar(x) From 40f94202f23a3a58934532e3695f00fce527dacb Mon Sep 17 00:00:00 2001 From: Joan Martinez Date: Thu, 13 Mar 2025 11:30:33 +0100 Subject: [PATCH 16/28] fix: fix small dynamic creation --- docarray/utils/create_dynamic_doc_class.py | 11 +++++++---- tests/units/typing/da/test_relations.py | 11 +++++++++++ 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/docarray/utils/create_dynamic_doc_class.py b/docarray/utils/create_dynamic_doc_class.py index 744fea58c3..ee79d16e70 100644 --- a/docarray/utils/create_dynamic_doc_class.py +++ b/docarray/utils/create_dynamic_doc_class.py @@ -54,8 +54,9 @@ class MyDoc(BaseDoc): fields: Dict[str, Any] = {} import copy - fields_copy = copy.deepcopy(model.__fields__) - annotations_copy = copy.deepcopy(model.__annotations__) + copy_model = copy.deepcopy(model) + fields_copy = copy_model.__fields__ + annotations_copy = copy_model.__annotations__ for field_name, field in annotations_copy.items(): if field_name not in fields_copy: continue @@ -65,7 +66,7 @@ class MyDoc(BaseDoc): else: field_info = fields_copy[field_name].field_info try: - if safe_issubclass(field, DocList): + if safe_issubclass(field, DocList) and not is_pydantic_v2: t: Any = field.doc_type t_aux = create_pure_python_type_model(t) fields[field_name] = (List[t_aux], field_info) @@ -74,7 +75,9 @@ class MyDoc(BaseDoc): except TypeError: fields[field_name] = (field, field_info) - return create_model(model.__name__, __base__=model, __doc__=model.__doc__, **fields) + return create_model( + copy_model.__name__, __base__=copy_model, __doc__=copy_model.__doc__, **fields + ) def _get_field_annotation_from_schema( diff --git a/tests/units/typing/da/test_relations.py b/tests/units/typing/da/test_relations.py index f583abef2e..cadac712f5 100644 --- a/tests/units/typing/da/test_relations.py +++ b/tests/units/typing/da/test_relations.py @@ -13,9 +13,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ +import pytest from docarray import BaseDoc, DocList +from docarray.utils._internal.pydantic import is_pydantic_v2 +@pytest.mark.skipif( + is_pydantic_v2, + reason="Subscripted generics cannot be used with class and instance checks", +) def test_instance_and_equivalence(): class MyDoc(BaseDoc): text: str @@ -28,6 +35,10 @@ class MyDoc(BaseDoc): assert isinstance(docs, DocList[MyDoc]) +@pytest.mark.skipif( + is_pydantic_v2, + reason="Subscripted generics cannot be used with class and instance checks", +) def test_subclassing(): class MyDoc(BaseDoc): text: str From cb9cc944ab9c5e7712720decc02f85f86b657362 Mon Sep 17 00:00:00 2001 From: Joan Martinez Date: Thu, 13 Mar 2025 13:22:47 +0100 Subject: [PATCH 17/28] test: further fix tests --- docarray/array/doc_vec/doc_vec.py | 4 ++- .../index/base_classes/test_base_doc_store.py | 33 ++++++++++++------- .../array/test_optional_doc_vec.py | 3 +- tests/units/array/test_array.py | 2 -- tests/units/array/test_doclist_schema.py | 4 ++- tests/units/util/test_map.py | 4 ++- 6 files changed, 33 insertions(+), 17 deletions(-) diff --git a/docarray/array/doc_vec/doc_vec.py b/docarray/array/doc_vec/doc_vec.py index 9d515cfd96..9ac8af89cf 100644 --- a/docarray/array/doc_vec/doc_vec.py +++ b/docarray/array/doc_vec/doc_vec.py @@ -335,7 +335,9 @@ def _docarray_validate( return cast(T, value.to_doc_vec()) else: raise ValueError(f'DocVec[value.doc_type] is not compatible with {cls}') - elif isinstance(value, DocList.__class_getitem__(cls.doc_type)): + elif not is_pydantic_v2 and isinstance( + value, DocList.__class_getitem__(cls.doc_type) + ): return cast(T, value.to_doc_vec()) elif isinstance(value, Sequence): return cls(value) diff --git a/tests/index/base_classes/test_base_doc_store.py b/tests/index/base_classes/test_base_doc_store.py index 70116bbf9e..7337969428 100644 --- a/tests/index/base_classes/test_base_doc_store.py +++ b/tests/index/base_classes/test_base_doc_store.py @@ -467,13 +467,15 @@ class OtherNestedDoc(NestedDoc): index = DummyDocIndex[SimpleDoc]() in_list = [SimpleDoc(tens=np.random.random((10,)))] assert isinstance(index._validate_docs(in_list), DocList) - assert isinstance(index._validate_docs(in_list)[0], BaseDoc) + for d in index._validate_docs(in_list): + assert isinstance(d, BaseDoc) in_da = DocList[SimpleDoc](in_list) assert index._validate_docs(in_da) == in_da in_other_list = [OtherSimpleDoc(tens=np.random.random((10,)))] assert isinstance(index._validate_docs(in_other_list), DocList) - assert isinstance(index._validate_docs(in_other_list)[0], BaseDoc) + for d in index._validate_docs(in_other_list): + assert isinstance(d, BaseDoc) in_other_da = DocList[OtherSimpleDoc](in_other_list) assert index._validate_docs(in_other_da) == in_other_da @@ -503,7 +505,8 @@ class OtherNestedDoc(NestedDoc): FlatDoc(tens_one=np.random.random((10,)), tens_two=np.random.random((50,))) ] assert isinstance(index._validate_docs(in_list), DocList) - assert isinstance(index._validate_docs(in_list)[0], BaseDoc) + for d in index._validate_docs(in_list): + assert isinstance(d, BaseDoc) in_da = DocList[FlatDoc]( [FlatDoc(tens_one=np.random.random((10,)), tens_two=np.random.random((50,)))] ) @@ -512,7 +515,8 @@ class OtherNestedDoc(NestedDoc): OtherFlatDoc(tens_one=np.random.random((10,)), tens_two=np.random.random((50,))) ] assert isinstance(index._validate_docs(in_other_list), DocList) - assert isinstance(index._validate_docs(in_other_list)[0], BaseDoc) + for d in index._validate_docs(in_other_list): + assert isinstance(d, BaseDoc) in_other_da = 
DocList[OtherFlatDoc]( [ OtherFlatDoc( @@ -532,12 +536,14 @@ class OtherNestedDoc(NestedDoc): index = DummyDocIndex[NestedDoc]() in_list = [NestedDoc(d=SimpleDoc(tens=np.random.random((10,))))] assert isinstance(index._validate_docs(in_list), DocList) - assert isinstance(index._validate_docs(in_list)[0], BaseDoc) + for d in index._validate_docs(in_list): + assert isinstance(d, BaseDoc) in_da = DocList[NestedDoc]([NestedDoc(d=SimpleDoc(tens=np.random.random((10,))))]) assert index._validate_docs(in_da) == in_da in_other_list = [OtherNestedDoc(d=OtherSimpleDoc(tens=np.random.random((10,))))] assert isinstance(index._validate_docs(in_other_list), DocList) - assert isinstance(index._validate_docs(in_other_list)[0], BaseDoc) + for d in index._validate_docs(in_other_list): + assert isinstance(d, BaseDoc) in_other_da = DocList[OtherNestedDoc]( [OtherNestedDoc(d=OtherSimpleDoc(tens=np.random.random((10,))))] ) @@ -565,7 +571,8 @@ class TensorUnionDoc(BaseDoc): index = DummyDocIndex[SimpleDoc]() in_list = [OptionalDoc(tens=np.random.random((10,)))] assert isinstance(index._validate_docs(in_list), DocList) - assert isinstance(index._validate_docs(in_list)[0], BaseDoc) + for d in index._validate_docs(in_list): + assert isinstance(d, BaseDoc) in_da = DocList[OptionalDoc](in_list) assert index._validate_docs(in_da) == in_da @@ -576,10 +583,12 @@ class TensorUnionDoc(BaseDoc): index = DummyDocIndex[SimpleDoc]() in_list = [MixedUnionDoc(tens=np.random.random((10,)))] assert isinstance(index._validate_docs(in_list), DocList) - assert isinstance(index._validate_docs(in_list)[0], BaseDoc) + for d in index._validate_docs(in_list): + assert isinstance(d, BaseDoc) in_da = DocList[MixedUnionDoc](in_list) assert isinstance(index._validate_docs(in_da), DocList) - assert isinstance(index._validate_docs(in_da)[0], BaseDoc) + for d in index._validate_docs(in_da): + assert isinstance(d, BaseDoc) with pytest.raises(ValueError): index._validate_docs([MixedUnionDoc(tens='hello')]) @@ -588,14 +597,16 @@ class TensorUnionDoc(BaseDoc): index = DummyDocIndex[TensorUnionDoc]() in_list = [SimpleDoc(tens=np.random.random((10,)))] assert isinstance(index._validate_docs(in_list), DocList) - assert isinstance(index._validate_docs(in_list)[0], BaseDoc) + for d in index._validate_docs(in_list): + assert isinstance(d, BaseDoc) in_da = DocList[SimpleDoc](in_list) assert index._validate_docs(in_da) == in_da index = DummyDocIndex[SimpleDoc]() in_list = [TensorUnionDoc(tens=np.random.random((10,)))] assert isinstance(index._validate_docs(in_list), DocList) - assert isinstance(index._validate_docs(in_list)[0], BaseDoc) + for d in index._validate_docs(in_list): + assert isinstance(d, BaseDoc) in_da = DocList[TensorUnionDoc](in_list) assert index._validate_docs(in_da) == in_da diff --git a/tests/integrations/array/test_optional_doc_vec.py b/tests/integrations/array/test_optional_doc_vec.py index bb793152d3..dd77c66762 100644 --- a/tests/integrations/array/test_optional_doc_vec.py +++ b/tests/integrations/array/test_optional_doc_vec.py @@ -20,7 +20,8 @@ class Image(BaseDoc): docs.features = [Features(tensor=np.random.random([100])) for _ in range(10)] print(docs.features) # - assert isinstance(docs.features, DocVec[Features]) + assert isinstance(docs.features, DocVec) + assert isinstance(docs.features[0], Features) docs.features.tensor = np.ones((10, 100)) diff --git a/tests/units/array/test_array.py b/tests/units/array/test_array.py index ab2772f71c..8e51cc1c37 100644 --- a/tests/units/array/test_array.py +++ 
b/tests/units/array/test_array.py @@ -522,5 +522,3 @@ def test_not_double_subcriptable(): with pytest.raises(TypeError) as excinfo: da = DocList[TextDoc][TextDoc] assert da is None - - assert 'not subscriptable' in str(excinfo.value) diff --git a/tests/units/array/test_doclist_schema.py b/tests/units/array/test_doclist_schema.py index e9a78a36c2..7cab659192 100644 --- a/tests/units/array/test_doclist_schema.py +++ b/tests/units/array/test_doclist_schema.py @@ -16,5 +16,7 @@ class DocDocTest(BaseDoc): assert 'Doc1Test' in DocDocTest.schema()['$defs'] d = DocDocTest(docs=DocList[Doc1Test]([Doc1Test(aux='aux')])) - assert type(d.docs) == DocList[Doc1Test] + assert isinstance(d.docs, DocList) + for d in d.docs: + assert isinstance(d, Doc1Test) assert d.docs.aux == ['aux'] diff --git a/tests/units/util/test_map.py b/tests/units/util/test_map.py index 3b9f102d92..65dd3c1738 100644 --- a/tests/units/util/test_map.py +++ b/tests/units/util/test_map.py @@ -96,4 +96,6 @@ def test_map_docs_batched(n_docs, batch_size, backend): assert isinstance(it, Generator) for batch in it: - assert isinstance(batch, DocList[MyImage]) + assert isinstance(batch, DocList) + for d in batch: + assert isinstance(d, MyImage) From 4f42c0b8b35e8cc69fb1642def5e884539ff1124 Mon Sep 17 00:00:00 2001 From: Joan Martinez Date: Thu, 13 Mar 2025 14:16:09 +0100 Subject: [PATCH 18/28] test: new iteration --- docarray/__init__.py | 7 ++-- .../torch/data/test_torch_dataset.py | 8 +++-- .../units/array/stack/storage/test_storage.py | 3 +- tests/units/array/stack/test_array_stacked.py | 12 +++++-- tests/units/array/test_array_from_to_bytes.py | 36 +++++++++++-------- tests/units/array/test_doclist_schema.py | 4 +-- tests/units/document/test_doc_wo_id.py | 7 +++- 7 files changed, 51 insertions(+), 26 deletions(-) diff --git a/docarray/__init__.py b/docarray/__init__.py index 3ef7a4ad97..5a18bb9588 100644 --- a/docarray/__init__.py +++ b/docarray/__init__.py @@ -27,8 +27,8 @@ def unpickle_doclist(doc_type, b): return DocList[doc_type].from_bytes(b, protocol="protobuf") -def unpickle_docvec(doc_type, b): - return DocVec[doc_type].from_bytes(b, protocol="protobuf") +def unpickle_docvec(doc_type, tensor_type, b): + return DocVec[doc_type].from_bytes(b, protocol="protobuf", tensor_type=tensor_type) if is_pydantic_v2: @@ -64,7 +64,8 @@ def doclist_reduce(self): def pickle_docvec(doc_vec): b = doc_vec.to_bytes(protocol='protobuf') doc_type = doc_vec.doc_type - return unpickle_docvec, (doc_type, b) + tensor_type = doc_vec.tensor_type + return unpickle_docvec, (doc_type, tensor_type, b) # Replace DocList.__reduce__ with a method that returns the correct format def docvec_reduce(self): diff --git a/tests/integrations/torch/data/test_torch_dataset.py b/tests/integrations/torch/data/test_torch_dataset.py index f358f1c16b..5d8236a70b 100644 --- a/tests/integrations/torch/data/test_torch_dataset.py +++ b/tests/integrations/torch/data/test_torch_dataset.py @@ -60,7 +60,9 @@ def test_torch_dataset(captions_da: DocList[PairTextImage]): batch_lens = [] for batch in loader: - assert isinstance(batch, DocVec[PairTextImage]) + assert isinstance(batch, DocVec) + for d in batch: + assert isinstance(d, PairTextImage) batch_lens.append(len(batch)) assert all(x == BATCH_SIZE for x in batch_lens[:-1]) @@ -140,7 +142,9 @@ def test_torch_dl_multiprocessing(captions_da: DocList[PairTextImage]): batch_lens = [] for batch in loader: - assert isinstance(batch, DocVec[PairTextImage]) + assert isinstance(batch, DocVec) + for d in batch: + assert isinstance(d, 
PairTextImage) batch_lens.append(len(batch)) assert all(x == BATCH_SIZE for x in batch_lens[:-1]) diff --git a/tests/units/array/stack/storage/test_storage.py b/tests/units/array/stack/storage/test_storage.py index 01c1b68a16..b91585d373 100644 --- a/tests/units/array/stack/storage/test_storage.py +++ b/tests/units/array/stack/storage/test_storage.py @@ -26,8 +26,9 @@ class MyDoc(BaseDoc): for name in storage.any_columns['name']: assert name == 'hello' inner_docs = storage.doc_columns['doc'] - assert isinstance(inner_docs, DocVec[InnerDoc]) + assert isinstance(inner_docs, DocVec) for i, doc in enumerate(inner_docs): + assert isinstance(doc, InnerDoc) assert doc.price == i diff --git a/tests/units/array/stack/test_array_stacked.py b/tests/units/array/stack/test_array_stacked.py index 2a3790da1d..b1b385840d 100644 --- a/tests/units/array/stack/test_array_stacked.py +++ b/tests/units/array/stack/test_array_stacked.py @@ -504,7 +504,9 @@ class ImageDoc(BaseDoc): da = parse_obj_as(DocVec[ImageDoc], batch) - assert isinstance(da, DocVec[ImageDoc]) + assert isinstance(da, DocVec) + for d in da: + assert isinstance(d, ImageDoc) def test_validation_column_tensor(batch): @@ -536,14 +538,18 @@ def test_validation_column_doc(batch_nested_doc): batch, Doc, Inner = batch_nested_doc batch.inner = DocList[Inner]([Inner(hello='hello') for _ in range(10)]) - assert isinstance(batch.inner, DocVec[Inner]) + assert isinstance(batch.inner, DocVec) + for d in batch.inner: + assert isinstance(d, Inner) def test_validation_list_doc(batch_nested_doc): batch, Doc, Inner = batch_nested_doc batch.inner = [Inner(hello='hello') for _ in range(10)] - assert isinstance(batch.inner, DocVec[Inner]) + assert isinstance(batch.inner, DocVec) + for d in batch.inner: + assert isinstance(d, Inner) def test_validation_col_doc_fail(batch_nested_doc): diff --git a/tests/units/array/test_array_from_to_bytes.py b/tests/units/array/test_array_from_to_bytes.py index 1da72f571e..0ab952ce4a 100644 --- a/tests/units/array/test_array_from_to_bytes.py +++ b/tests/units/array/test_array_from_to_bytes.py @@ -45,9 +45,9 @@ def test_from_to_bytes(protocol, compress, show_progress, array_cls): @pytest.mark.parametrize( 'protocol', ['pickle-array', 'protobuf-array', 'protobuf', 'pickle'] ) -@pytest.mark.parametrize('compress', ['lz4']) # , 'bz2', 'lzma', 'zlib', 'gzip', None]) -@pytest.mark.parametrize('show_progress', [False]) # [False, True]) -@pytest.mark.parametrize('array_cls', [DocVec]) # [DocList, DocVec]) +@pytest.mark.parametrize('compress', ['lz4', 'bz2', 'lzma', 'zlib', 'gzip', None]) +@pytest.mark.parametrize('show_progress', [False, True]) # [False, True]) +@pytest.mark.parametrize('array_cls', [DocList, DocVec]) def test_from_to_base64(protocol, compress, show_progress, array_cls): da = array_cls[MyDoc]( [ @@ -75,27 +75,35 @@ def test_from_to_base64(protocol, compress, show_progress, array_cls): # test_from_to_base64('protobuf', 'lz4', False, DocVec) +class MyTensorTypeDocNdArray(BaseDoc): + embedding: NdArray + text: str + image: ImageDoc -@pytest.mark.parametrize('tensor_type', [NdArray, TorchTensor]) -@pytest.mark.parametrize('protocol', ['protobuf-array', 'pickle-array']) -def test_from_to_base64_tensor_type(tensor_type, protocol): - class MyDoc(BaseDoc): - embedding: tensor_type - text: str - image: ImageDoc +class MyTensorTypeDocTorchTensor(BaseDoc): + embedding: TorchTensor + text: str + image: ImageDoc - da = DocVec[MyDoc]( + +@pytest.mark.parametrize( + 'doc_type, tensor_type', + [(MyTensorTypeDocNdArray, NdArray), 
(MyTensorTypeDocTorchTensor, TorchTensor)], +) +@pytest.mark.parametrize('protocol', ['protobuf-array', 'pickle-array']) +def test_from_to_base64_tensor_type(doc_type, tensor_type, protocol): + da = DocVec[doc_type]( [ - MyDoc( + doc_type( embedding=[1, 2, 3, 4, 5], text='hello', image=ImageDoc(url='aux.png') ), - MyDoc(embedding=[5, 4, 3, 2, 1], text='hello world', image=ImageDoc()), + doc_type(embedding=[5, 4, 3, 2, 1], text='hello world', image=ImageDoc()), ], tensor_type=tensor_type, ) bytes_da = da.to_base64(protocol=protocol) - da2 = DocVec[MyDoc].from_base64( + da2 = DocVec[doc_type].from_base64( bytes_da, tensor_type=tensor_type, protocol=protocol ) assert da2.tensor_type == tensor_type diff --git a/tests/units/array/test_doclist_schema.py b/tests/units/array/test_doclist_schema.py index 7cab659192..02a5f56280 100644 --- a/tests/units/array/test_doclist_schema.py +++ b/tests/units/array/test_doclist_schema.py @@ -17,6 +17,6 @@ class DocDocTest(BaseDoc): d = DocDocTest(docs=DocList[Doc1Test]([Doc1Test(aux='aux')])) assert isinstance(d.docs, DocList) - for d in d.docs: - assert isinstance(d, Doc1Test) + for dd in d.docs: + assert isinstance(dd, Doc1Test) assert d.docs.aux == ['aux'] diff --git a/tests/units/document/test_doc_wo_id.py b/tests/units/document/test_doc_wo_id.py index ffda3ceec4..4e2a8bba11 100644 --- a/tests/units/document/test_doc_wo_id.py +++ b/tests/units/document/test_doc_wo_id.py @@ -23,4 +23,9 @@ class A(BaseDocWithoutId): cls_doc_list = DocList[A] - assert isinstance(cls_doc_list, type) + da = cls_doc_list([A(text='hey here')]) + + assert isinstance(da, DocList) + for d in da: + assert isinstance(d, A) + assert not hasattr(d, 'id') From cfa7ea5d007f029c70463729c9b15b3ef786491d Mon Sep 17 00:00:00 2001 From: Joan Martinez Date: Fri, 14 Mar 2025 09:22:30 +0100 Subject: [PATCH 19/28] test: more tests --- docarray/array/doc_vec/doc_vec.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docarray/array/doc_vec/doc_vec.py b/docarray/array/doc_vec/doc_vec.py index 9ac8af89cf..0cc462f173 100644 --- a/docarray/array/doc_vec/doc_vec.py +++ b/docarray/array/doc_vec/doc_vec.py @@ -198,7 +198,7 @@ def _check_doc_field_not_none(field_name, doc): if safe_issubclass(tensor.__class__, tensor_type): field_type = tensor_type - if isinstance(field_type, type): + if isinstance(field_type, type) or safe_issubclass(field_type, AnyDocArray): if tf_available and safe_issubclass(field_type, TensorFlowTensor): # tf.Tensor does not allow item assignment, therefore the # optimized way From 659b992a426534ec80a2ab4731986082ceec405c Mon Sep 17 00:00:00 2001 From: Joan Martinez Date: Fri, 14 Mar 2025 09:45:16 +0100 Subject: [PATCH 20/28] fix: tests --- docarray/base_doc/doc.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/docarray/base_doc/doc.py b/docarray/base_doc/doc.py index 48fb3076cd..e880504bc0 100644 --- a/docarray/base_doc/doc.py +++ b/docarray/base_doc/doc.py @@ -326,8 +326,13 @@ def _exclude_doclist( from docarray.array.any_array import AnyDocArray type_ = self._get_field_annotation(field) - if isinstance(type_, type) and safe_issubclass(type_, AnyDocArray): - doclist_exclude_fields.append(field) + if is_pydantic_v2: + # Conservative when touching pydantic v1 logic + if safe_issubclass(type_, AnyDocArray): + doclist_exclude_fields.append(field) + else: + if isinstance(type_, type) and safe_issubclass(type_, AnyDocArray): + doclist_exclude_fields.append(field) original_exclude = exclude if exclude is None: @@ -480,7 +485,6 @@ def 
model_dump( # type: ignore warnings: bool = True, ) -> Dict[str, Any]: def _model_dump(doc): - ( exclude_, original_exclude, From 9c9439427ab3d0bb786d7c871d8850cc93ed9620 Mon Sep 17 00:00:00 2001 From: Joan Martinez Date: Mon, 17 Mar 2025 18:07:55 +0100 Subject: [PATCH 21/28] test: fix schemas from new model --- docarray/utils/create_dynamic_doc_class.py | 9 +- tests/integrations/externals/test_fastapi.py | 167 ++++++++++++++++-- .../util/test_create_dynamic_code_class.py | 10 +- 3 files changed, 173 insertions(+), 13 deletions(-) diff --git a/docarray/utils/create_dynamic_doc_class.py b/docarray/utils/create_dynamic_doc_class.py index ee79d16e70..d3ea2720b8 100644 --- a/docarray/utils/create_dynamic_doc_class.py +++ b/docarray/utils/create_dynamic_doc_class.py @@ -163,7 +163,10 @@ def _get_field_annotation_from_schema( doc_type: Any if 'additionalProperties' in field_schema: # handle Dictionaries additional_props = field_schema['additionalProperties'] - if additional_props.get('type') == 'object': + if ( + isinstance(additional_props, dict) + and additional_props.get('type') == 'object' + ): doc_type = create_base_doc_from_schema( additional_props, field_name, cached_models=cached_models ) @@ -300,7 +303,9 @@ class MyDoc(BaseDoc): if k in FieldInfo.__slots__: field_kwargs[k] = v else: - field_json_schema_extra[k] = v + if k != '$ref' and '#/$defs' not in str(v): + field_json_schema_extra[k] = v + fields[field_name] = ( field_type, FieldInfo( diff --git a/tests/integrations/externals/test_fastapi.py b/tests/integrations/externals/test_fastapi.py index bfc6ba0c80..1d0d6ebba6 100644 --- a/tests/integrations/externals/test_fastapi.py +++ b/tests/integrations/externals/test_fastapi.py @@ -1,5 +1,5 @@ -from typing import List - +from typing import Any, Dict, List, Optional, Union, ClassVar +import json import numpy as np import pytest from fastapi import FastAPI @@ -8,7 +8,8 @@ from docarray import BaseDoc, DocList from docarray.base_doc import DocArrayResponse from docarray.documents import ImageDoc, TextDoc -from docarray.typing import NdArray +from docarray.typing import NdArray, AnyTensor, ImageUrl + from docarray.utils._internal.pydantic import is_pydantic_v2 @@ -145,7 +146,7 @@ async def func(fastapi_docs: List[ImageDoc]) -> List[ImageDoc]: async def test_doclist_directly(): from fastapi import Body - doc = ImageDoc(tensor=np.zeros((3, 224, 224))) + doc = ImageDoc(tensor=np.zeros((3, 224, 224)), url='url') docs = DocList[ImageDoc]([doc, doc]) app = FastAPI() @@ -169,14 +170,10 @@ async def func_embed_true( async with AsyncClient(app=app, base_url="http://test") as ac: response = await ac.post("/doc/", data=docs.to_json()) response_default = await ac.post("/doc_default/", data=docs.to_json()) + embed_content_json = {'fastapi_docs': json.loads(docs.to_json())} response_embed = await ac.post( "/doc_embed/", - json={ - 'fastapi_docs': [ - {'tensor': doc.tensor.tolist()}, - {'tensor': doc.tensor.tolist()}, - ] - }, + json=embed_content_json, ) resp_doc = await ac.get("/docs") resp_redoc = await ac.get("/redoc") @@ -198,3 +195,153 @@ async def func_embed_true( docs_embed = DocList[ImageDoc].from_json(response_embed.content.decode()) assert len(docs_embed) == 2 assert docs_embed[0].tensor.shape == (3, 224, 224) + + +@pytest.mark.asyncio +@pytest.mark.skipif( + not is_pydantic_v2, reason='Behavior is only available for Pydantic V2' +) +async def test_doclist_complex_schema(): + from fastapi import Body + + class Nested2Doc(BaseDoc): + value: str + classvar: ClassVar[str] = 'classvar2' + + class 
Nested1Doc(BaseDoc): + nested: Nested2Doc + classvar: ClassVar[str] = 'classvar1' + + class CustomDoc(BaseDoc): + tensor: Optional[AnyTensor] = None + url: ImageUrl + num: float = 0.5 + num_num: List[float] = [1.5, 2.5] + lll: List[List[List[int]]] = [[[5]]] + fff: List[List[List[float]]] = [[[5.2]]] + single_text: TextDoc + texts: DocList[TextDoc] + d: Dict[str, str] = {'a': 'b'} + di: Optional[Dict[str, int]] = None + u: Union[str, int] + lu: List[Union[str, int]] = [0, 1, 2] + tags: Optional[Dict[str, Any]] = None + nested: Nested1Doc + classvar: ClassVar[str] = 'classvar' + + docs = DocList[CustomDoc]( + [ + CustomDoc( + num=3.5, + num_num=[4.5, 5.5], + url='photo.jpg', + lll=[[[40]]], + fff=[[[40.2]]], + d={'b': 'a'}, + texts=DocList[TextDoc]([TextDoc(text='hey ha', embedding=np.zeros(3))]), + single_text=TextDoc(text='single hey ha', embedding=np.zeros(2)), + u='a', + lu=[3, 4], + nested=Nested1Doc(nested=Nested2Doc(value='hello world')), + ) + ] + ) + + app = FastAPI() + + @app.post("/doc/", response_class=DocArrayResponse) + async def func_embed_false( + fastapi_docs: DocList[CustomDoc] = Body(embed=False), + ) -> DocList[CustomDoc]: + for doc in fastapi_docs: + doc.tensor = np.zeros((10, 10, 10)) + doc.di = {'a': 2} + + return fastapi_docs + + @app.post("/doc_default/", response_class=DocArrayResponse) + async def func_default(fastapi_docs: DocList[CustomDoc]) -> DocList[CustomDoc]: + for doc in fastapi_docs: + doc.tensor = np.zeros((10, 10, 10)) + doc.di = {'a': 2} + return fastapi_docs + + @app.post("/doc_embed/", response_class=DocArrayResponse) + async def func_embed_true( + fastapi_docs: DocList[CustomDoc] = Body(embed=True), + ) -> DocList[CustomDoc]: + for doc in fastapi_docs: + doc.tensor = np.zeros((10, 10, 10)) + doc.di = {'a': 2} + return fastapi_docs + + async with AsyncClient(app=app, base_url="http://test") as ac: + response = await ac.post("/doc/", data=docs.to_json()) + response_default = await ac.post("/doc_default/", data=docs.to_json()) + embed_content_json = {'fastapi_docs': json.loads(docs.to_json())} + response_embed = await ac.post( + "/doc_embed/", + json=embed_content_json, + ) + resp_doc = await ac.get("/docs") + resp_redoc = await ac.get("/redoc") + + assert response.status_code == 200 + assert response_default.status_code == 200 + assert response_embed.status_code == 200 + assert resp_doc.status_code == 200 + assert resp_redoc.status_code == 200 + + docs_response = DocList[CustomDoc].from_json(response.content.decode()) + assert len(docs_response) == 1 + assert docs_response[0].url == 'photo.jpg' + assert docs_response[0].num == 3.5 + assert docs_response[0].num_num == [4.5, 5.5] + assert docs_response[0].lll == [[[40]]] + assert docs_response[0].lu == [3, 4] + assert docs_response[0].fff == [[[40.2]]] + assert docs_response[0].di == {'a': 2} + assert docs_response[0].d == {'b': 'a'} + assert len(docs_response[0].texts) == 1 + assert docs_response[0].texts[0].text == 'hey ha' + assert docs_response[0].texts[0].embedding.shape == (3,) + assert docs_response[0].tensor.shape == (10, 10, 10) + assert docs_response[0].u == 'a' + assert docs_response[0].single_text.text == 'single hey ha' + assert docs_response[0].single_text.embedding.shape == (2,) + + docs_default = DocList[CustomDoc].from_json(response_default.content.decode()) + assert len(docs_default) == 1 + assert docs_default[0].url == 'photo.jpg' + assert docs_default[0].num == 3.5 + assert docs_default[0].num_num == [4.5, 5.5] + assert docs_default[0].lll == [[[40]]] + assert docs_default[0].lu == 
[3, 4] + assert docs_default[0].fff == [[[40.2]]] + assert docs_default[0].di == {'a': 2} + assert docs_default[0].d == {'b': 'a'} + assert len(docs_default[0].texts) == 1 + assert docs_default[0].texts[0].text == 'hey ha' + assert docs_default[0].texts[0].embedding.shape == (3,) + assert docs_default[0].tensor.shape == (10, 10, 10) + assert docs_default[0].u == 'a' + assert docs_default[0].single_text.text == 'single hey ha' + assert docs_default[0].single_text.embedding.shape == (2,) + + docs_embed = DocList[CustomDoc].from_json(response_embed.content.decode()) + assert len(docs_embed) == 1 + assert docs_embed[0].url == 'photo.jpg' + assert docs_embed[0].num == 3.5 + assert docs_embed[0].num_num == [4.5, 5.5] + assert docs_embed[0].lll == [[[40]]] + assert docs_embed[0].lu == [3, 4] + assert docs_embed[0].fff == [[[40.2]]] + assert docs_embed[0].di == {'a': 2} + assert docs_embed[0].d == {'b': 'a'} + assert len(docs_embed[0].texts) == 1 + assert docs_embed[0].texts[0].text == 'hey ha' + assert docs_embed[0].texts[0].embedding.shape == (3,) + assert docs_embed[0].tensor.shape == (10, 10, 10) + assert docs_embed[0].u == 'a' + assert docs_embed[0].single_text.text == 'single hey ha' + assert docs_embed[0].single_text.embedding.shape == (2,) diff --git a/tests/units/util/test_create_dynamic_code_class.py b/tests/units/util/test_create_dynamic_code_class.py index eba25911c4..3fa4fe26cb 100644 --- a/tests/units/util/test_create_dynamic_code_class.py +++ b/tests/units/util/test_create_dynamic_code_class.py @@ -45,6 +45,7 @@ class CustomDoc(BaseDoc): new_custom_doc_model = create_base_doc_from_schema( CustomDocCopy.schema(), 'CustomDoc', {} ) + print(f'new_custom_doc_model {new_custom_doc_model.schema()}') original_custom_docs = DocList[CustomDoc]( [ @@ -131,6 +132,7 @@ class TextDocWithId(BaseDoc): new_textdoc_with_id_model = create_base_doc_from_schema( TextDocWithIdCopy.schema(), 'TextDocWithId', {} ) + print(f'new_textdoc_with_id_model {new_textdoc_with_id_model.schema()}') original_text_doc_with_id = DocList[TextDocWithId]( [TextDocWithId(ia=f'ID {i}') for i in range(10)] @@ -207,6 +209,7 @@ class CustomDoc(BaseDoc): new_custom_doc_model = create_base_doc_from_schema( CustomDocCopy.schema(), 'CustomDoc' ) + print(f'new_custom_doc_model {new_custom_doc_model.schema()}') original_custom_docs = DocList[CustomDoc]() if transformation == 'proto': @@ -232,6 +235,7 @@ class TextDocWithId(BaseDoc): new_textdoc_with_id_model = create_base_doc_from_schema( TextDocWithIdCopy.schema(), 'TextDocWithId', {} ) + print(f'new_textdoc_with_id_model {new_textdoc_with_id_model.schema()}') original_text_doc_with_id = DocList[TextDocWithId]() if transformation == 'proto': @@ -255,6 +259,9 @@ class ResultTestDoc(BaseDoc): new_result_test_doc_with_id_model = create_base_doc_from_schema( ResultTestDocCopy.schema(), 'ResultTestDoc', {} ) + print( + f'new_result_test_doc_with_id_model {new_result_test_doc_with_id_model.schema()}' + ) result_test_docs = DocList[ResultTestDoc]() if transformation == 'proto': @@ -309,9 +316,10 @@ class SearchResult(BaseDoc): models_created_by_name = {} SearchResult_aux = create_pure_python_type_model(SearchResult) - _ = create_base_doc_from_schema( + m = create_base_doc_from_schema( SearchResult_aux.schema(), 'SearchResult', models_created_by_name ) + print(f'm {m.schema()}') QuoteFile_reconstructed_in_gateway_from_Search_results = models_created_by_name[ 'QuoteFile' ] From f422e88e88a4980a95147638663ce99bacdd593c Mon Sep 17 00:00:00 2001 From: Joan Martinez Date: Tue, 18 Mar 2025 
10:21:46 +0100 Subject: [PATCH 22/28] fix: improve cleaning refs --- docarray/utils/create_dynamic_doc_class.py | 29 ++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/docarray/utils/create_dynamic_doc_class.py b/docarray/utils/create_dynamic_doc_class.py index d3ea2720b8..30458df995 100644 --- a/docarray/utils/create_dynamic_doc_class.py +++ b/docarray/utils/create_dynamic_doc_class.py @@ -268,6 +268,24 @@ class MyDoc(BaseDoc): :param definitions: Parameter used when this method is called recursively to reuse root definitions of other schemas. :return: A BaseDoc class dynamically created following the `schema`. """ + + def clean_refs(value): + """Recursively remove $ref keys and #/$defs values from a data structure.""" + if isinstance(value, dict): + # Create a new dictionary without $ref keys and without values containing #/$defs + cleaned_dict = {} + for k, v in value.items(): + if k == '$ref': + continue + cleaned_dict[k] = clean_refs(v) + return cleaned_dict + elif isinstance(value, list): + # Process each item in the list + return [clean_refs(item) for item in value] + else: + # Return primitive values as-is + return value + if not definitions: definitions = ( schema.get('definitions', {}) if not is_pydantic_v2 else schema.get('$defs') @@ -303,8 +321,15 @@ class MyDoc(BaseDoc): if k in FieldInfo.__slots__: field_kwargs[k] = v else: - if k != '$ref' and '#/$defs' not in str(v): - field_json_schema_extra[k] = v + if k != '$ref': + if isinstance(v, dict): + cleaned_v = clean_refs(v) + if ( + cleaned_v + ): # Only add if there's something left after cleaning + field_json_schema_extra[k] = cleaned_v + else: + field_json_schema_extra[k] = v fields[field_name] = ( field_type, From c7e9bf618c7e729aa40e244802dcae0cdd068903 Mon Sep 17 00:00:00 2001 From: Joan Martinez Date: Tue, 18 Mar 2025 12:12:06 +0100 Subject: [PATCH 23/28] fix: get from definitions --- docarray/utils/create_dynamic_doc_class.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docarray/utils/create_dynamic_doc_class.py b/docarray/utils/create_dynamic_doc_class.py index 30458df995..accdfa04aa 100644 --- a/docarray/utils/create_dynamic_doc_class.py +++ b/docarray/utils/create_dynamic_doc_class.py @@ -113,7 +113,7 @@ def _get_field_annotation_from_schema( ref_name = obj_ref.split('/')[-1] any_of_types.append( create_base_doc_from_schema( - root_schema['definitions'][ref_name], + definitions[ref_name], ref_name, cached_models=cached_models, definitions=definitions, From 5d3e73fd2dcf29ad2f838c6d9ccaffb825f5c4f7 Mon Sep 17 00:00:00 2001 From: Joan Martinez Date: Tue, 18 Mar 2025 12:13:38 +0100 Subject: [PATCH 24/28] fix: remove unneeded argument --- docarray/utils/create_dynamic_doc_class.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/docarray/utils/create_dynamic_doc_class.py b/docarray/utils/create_dynamic_doc_class.py index accdfa04aa..96a4391458 100644 --- a/docarray/utils/create_dynamic_doc_class.py +++ b/docarray/utils/create_dynamic_doc_class.py @@ -83,7 +83,6 @@ class MyDoc(BaseDoc): def _get_field_annotation_from_schema( field_schema: Dict[str, Any], field_name: str, - root_schema: Dict[str, Any], cached_models: Dict[str, Any], is_tensor: bool = False, num_recursions: int = 0, @@ -93,7 +92,6 @@ def _get_field_annotation_from_schema( Private method used to extract the corresponding field type from the schema. 
:param field_schema: The schema from which to extract the type :param field_name: The name of the field to be created - :param root_schema: The schema of the root object, important to get references :param cached_models: Parameter used when this method is called recursively to reuse partial nested classes. :param is_tensor: Boolean used to tell between tensor and list :param num_recursions: Number of recursions to properly handle nested types (Dict, List, etc ..) @@ -124,7 +122,6 @@ def _get_field_annotation_from_schema( _get_field_annotation_from_schema( any_of_schema, field_name, - root_schema=root_schema, cached_models=cached_models, is_tensor=tensor_shape is not None, num_recursions=0, @@ -207,7 +204,6 @@ def _get_field_annotation_from_schema( ret = _get_field_annotation_from_schema( field_schema=field_schema.get('items', {}), field_name=field_name, - root_schema=root_schema, cached_models=cached_models, is_tensor=tensor_shape is not None, num_recursions=num_recursions + 1, @@ -302,7 +298,6 @@ def clean_refs(value): field_type = _get_field_annotation_from_schema( field_schema=field_schema, field_name=field_name, - root_schema=schema, cached_models=cached_models, is_tensor=False, num_recursions=0, From e12c03ac73ba9c46ee902379ddfd00814d40158a Mon Sep 17 00:00:00 2001 From: Joan Martinez Date: Tue, 18 Mar 2025 13:46:00 +0100 Subject: [PATCH 25/28] fix: fix update --- docarray/base_doc/mixins/update.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/docarray/base_doc/mixins/update.py b/docarray/base_doc/mixins/update.py index 721f8225eb..7ce596ce1a 100644 --- a/docarray/base_doc/mixins/update.py +++ b/docarray/base_doc/mixins/update.py @@ -110,9 +110,7 @@ def _group_fields(doc: 'UpdateMixin') -> _FieldGroups: if field_name not in FORBIDDEN_FIELDS_TO_UPDATE: field_type = doc._get_field_annotation(field_name) - if isinstance(field_type, type) and safe_issubclass( - field_type, DocList - ): + if safe_issubclass(field_type, DocList): nested_docarray_fields.append(field_name) else: origin = get_origin(field_type) From a6fef44fcf973bbd94f47a13384484a72f1bacef Mon Sep 17 00:00:00 2001 From: Joan Martinez Date: Tue, 18 Mar 2025 16:56:39 +0100 Subject: [PATCH 26/28] fix: handle ID optional --- docarray/utils/create_dynamic_doc_class.py | 4 ++ tests/integrations/externals/test_fastapi.py | 38 +++++++++++++++++++ .../util/test_create_dynamic_code_class.py | 25 ++++++++++++ 3 files changed, 67 insertions(+) diff --git a/docarray/utils/create_dynamic_doc_class.py b/docarray/utils/create_dynamic_doc_class.py index 96a4391458..c82a7c8948 100644 --- a/docarray/utils/create_dynamic_doc_class.py +++ b/docarray/utils/create_dynamic_doc_class.py @@ -295,6 +295,7 @@ def clean_refs(value): for field_name, field_schema in schema.get('properties', {}).items(): if field_name == 'id': has_id = True + # Get the field type field_type = _get_field_annotation_from_schema( field_schema=field_schema, field_name=field_name, @@ -313,6 +314,9 @@ def clean_refs(value): field_kwargs = {} field_json_schema_extra = {} for k, v in field_schema.items(): + if field_name == 'id': + # Skip default_factory for Optional fields and use None + field_kwargs['default'] = None if k in FieldInfo.__slots__: field_kwargs[k] = v else: diff --git a/tests/integrations/externals/test_fastapi.py b/tests/integrations/externals/test_fastapi.py index 1d0d6ebba6..821852f884 100644 --- a/tests/integrations/externals/test_fastapi.py +++ b/tests/integrations/externals/test_fastapi.py @@ -345,3 +345,41 @@ async def func_embed_true( 
assert docs_embed[0].u == 'a' assert docs_embed[0].single_text.text == 'single hey ha' assert docs_embed[0].single_text.embedding.shape == (2,) + + +@pytest.mark.asyncio +@pytest.mark.skipif( + not is_pydantic_v2, reason='Behavior is only available for Pydantic V2' +) +async def test_simple_directly(): + app = FastAPI() + + @app.post("/doc_list/", response_class=DocArrayResponse) + async def func_doc_list(fastapi_docs: DocList[TextDoc]) -> DocList[TextDoc]: + return fastapi_docs + + @app.post("/doc_single/", response_class=DocArrayResponse) + async def func_doc_single(fastapi_doc: TextDoc) -> TextDoc: + return fastapi_doc + + async with AsyncClient(app=app, base_url="http://test") as ac: + response_doc_list = await ac.post( + "/doc_list/", data=json.dumps([{"text": "text"}]) + ) + response_single = await ac.post( + "/doc_single/", data=json.dumps({"text": "text"}) + ) + resp_doc = await ac.get("/docs") + resp_redoc = await ac.get("/redoc") + + assert response_doc_list.status_code == 200 + assert response_single.status_code == 200 + assert resp_doc.status_code == 200 + assert resp_redoc.status_code == 200 + + docs = DocList[TextDoc].from_json(response_doc_list.content.decode()) + assert len(docs) == 1 + assert docs[0].text == 'text' + + doc = TextDoc.from_json(response_single.content.decode()) + assert doc == 'text' diff --git a/tests/units/util/test_create_dynamic_code_class.py b/tests/units/util/test_create_dynamic_code_class.py index 3fa4fe26cb..b7df497816 100644 --- a/tests/units/util/test_create_dynamic_code_class.py +++ b/tests/units/util/test_create_dynamic_code_class.py @@ -331,3 +331,28 @@ class SearchResult(BaseDoc): QuoteFile_reconstructed_in_gateway_from_Search_results(id='0', texts=textlist) ) assert reconstructed_in_gateway_from_Search_results.texts[0].text == 'hey' + + +def test_id_optional(): + from docarray import BaseDoc + import json + + class MyTextDoc(BaseDoc): + text: str + opt: Optional[str] = None + + MyTextDoc_aux = create_pure_python_type_model(MyTextDoc) + td = create_base_doc_from_schema(MyTextDoc_aux.schema(), 'MyTextDoc') + print(f'{td.schema()}') + direct = MyTextDoc.from_json(json.dumps({"text": "text"})) + aux = MyTextDoc_aux.from_json(json.dumps({"text": "text"})) + indirect = td.from_json(json.dumps({"text": "text"})) + assert direct.text == 'text' + assert aux.text == 'text' + assert indirect.text == 'text' + direct = MyTextDoc(text='hey') + aux = MyTextDoc_aux(text='hey') + indirect = td(text='hey') + assert direct.text == 'hey' + assert aux.text == 'hey' + assert indirect.text == 'hey' From 559c0ee105f83ce0f57253b038ee47c208dd53c5 Mon Sep 17 00:00:00 2001 From: Joan Martinez Date: Tue, 18 Mar 2025 18:19:22 +0100 Subject: [PATCH 27/28] fix: remove problematic action --- .github/workflows/cd.yml | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/.github/workflows/cd.yml b/.github/workflows/cd.yml index 5f565ecb7a..e0a14b5252 100644 --- a/.github/workflows/cd.yml +++ b/.github/workflows/cd.yml @@ -35,20 +35,16 @@ jobs: steps: - uses: actions/checkout@v3 with: - fetch-depth: 0 - - - name: Get changed files - id: changed-files-specific - uses: tj-actions/changed-files@v41 - with: - files: | - README.md + fetch-depth: 2 - name: Check if README is modified id: step_output - if: steps.changed-files-specific.outputs.any_changed == 'true' run: | - echo "readme_changed=true" >> $GITHUB_OUTPUT + if git diff --name-only HEAD^ HEAD | grep -q "README.md"; then + echo "readme_changed=true" >> $GITHUB_OUTPUT + else + echo 
"readme_changed=false" >> $GITHUB_OUTPUT + fi publish-docarray-org: needs: check-readme-modification From 8265d959a4adf1c7bd9e4ce8ff1b7cd8a182daf1 Mon Sep 17 00:00:00 2001 From: Joan Martinez Date: Wed, 19 Mar 2025 18:43:53 +0100 Subject: [PATCH 28/28] fix: fix resp as json --- docarray/typing/tensor/abstract_tensor.py | 2 +- tests/integrations/externals/test_fastapi.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/docarray/typing/tensor/abstract_tensor.py b/docarray/typing/tensor/abstract_tensor.py index 4836e39dde..e7e4fbe705 100644 --- a/docarray/typing/tensor/abstract_tensor.py +++ b/docarray/typing/tensor/abstract_tensor.py @@ -398,7 +398,7 @@ def __get_pydantic_core_schema__( return core_schema.with_info_plain_validator_function( cls.validate, serialization=core_schema.plain_serializer_function_ser_schema( - function=orjson_dumps, + function=lambda x: x._docarray_to_ndarray().tolist(), return_schema=handler.generate_schema(bytes), when_used="json-unless-none", ), diff --git a/tests/integrations/externals/test_fastapi.py b/tests/integrations/externals/test_fastapi.py index 821852f884..c5ef186821 100644 --- a/tests/integrations/externals/test_fastapi.py +++ b/tests/integrations/externals/test_fastapi.py @@ -227,6 +227,7 @@ class CustomDoc(BaseDoc): lu: List[Union[str, int]] = [0, 1, 2] tags: Optional[Dict[str, Any]] = None nested: Nested1Doc + embedding: NdArray classvar: ClassVar[str] = 'classvar' docs = DocList[CustomDoc]( @@ -242,6 +243,7 @@ class CustomDoc(BaseDoc): single_text=TextDoc(text='single hey ha', embedding=np.zeros(2)), u='a', lu=[3, 4], + embedding=np.random.random((1, 4)), nested=Nested1Doc(nested=Nested2Doc(value='hello world')), ) ] @@ -292,6 +294,11 @@ async def func_embed_true( assert resp_doc.status_code == 200 assert resp_redoc.status_code == 200 + resp_json = json.loads(response_default.content.decode()) + assert isinstance(resp_json[0]["tensor"], list) + assert isinstance(resp_json[0]["embedding"], list) + assert isinstance(resp_json[0]["texts"][0]["embedding"], list) + docs_response = DocList[CustomDoc].from_json(response.content.decode()) assert len(docs_response) == 1 assert docs_response[0].url == 'photo.jpg' pFad - Phonifier reborn

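Taken together, PATCH 11, 12 and 18 register pickle support for BaseDoc, DocList and DocVec under pydantic v2 by routing __reduce__ through the protobuf codec, which is what lets these containers cross multiprocessing boundaries. A minimal sketch of the behaviour this enables, again assuming the patched docarray under pydantic v2 (the doc class below is illustrative, not part of the patches):

import pickle

import numpy as np

from docarray import BaseDoc, DocList
from docarray.typing import NdArray


class MyDoc(BaseDoc):
    tens: NdArray


docs = DocList[MyDoc]([MyDoc(tens=np.zeros(3))])
# The patched __reduce__ serializes to protobuf bytes, and unpickle_doclist
# rebuilds DocList[MyDoc] from those bytes on load.
restored = pickle.loads(pickle.dumps(docs))
assert isinstance(restored, DocList)
assert restored[0].tens.shape == (3,)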