diff --git a/.github/workflows/cd.yml b/.github/workflows/cd.yml index a1aae08ec9b..e0a14b5252c 100644 --- a/.github/workflows/cd.yml +++ b/.github/workflows/cd.yml @@ -21,7 +21,7 @@ jobs: - name: Pre-release (.devN) run: | git fetch --depth=1 origin +refs/tags/*:refs/tags/* - pip install poetry + pip install poetry==1.7.1 ./scripts/release.sh env: PYPI_USERNAME: ${{ secrets.TWINE_USERNAME }} @@ -35,20 +35,16 @@ jobs: steps: - uses: actions/checkout@v3 with: - fetch-depth: 0 - - - name: Get changed files - id: changed-files-specific - uses: tj-actions/changed-files@v41 - with: - files: | - README.md + fetch-depth: 2 - name: Check if README is modified id: step_output - if: steps.changed-files-specific.outputs.any_changed == 'true' run: | - echo "readme_changed=true" >> $GITHUB_OUTPUT + if git diff --name-only HEAD^ HEAD | grep -q "README.md"; then + echo "readme_changed=true" >> $GITHUB_OUTPUT + else + echo "readme_changed=false" >> $GITHUB_OUTPUT + fi publish-docarray-org: needs: check-readme-modification diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0e98f9ce7be..07c32d0b873 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -25,7 +25,7 @@ jobs: - name: Lint with ruff run: | python -m pip install --upgrade pip - python -m pip install poetry + python -m pip install poetry==1.7.1 poetry install # stop the build if there are Python syntax errors or undefined names @@ -44,7 +44,7 @@ jobs: - name: check black run: | python -m pip install --upgrade pip - python -m pip install poetry + python -m pip install poetry==1.7.1 poetry install --only dev poetry run black --check . @@ -62,7 +62,7 @@ jobs: - name: Prepare environment run: | python -m pip install --upgrade pip - python -m pip install poetry + python -m pip install poetry==1.7.1 poetry install --without dev poetry run pip install tensorflow==2.12.0 poetry run pip install jax @@ -106,7 +106,7 @@ jobs: - name: Prepare environment run: | python -m pip install --upgrade pip - python -m pip install poetry + python -m pip install poetry==1.7.1 poetry install --all-extras poetry run pip install elasticsearch==8.6.2 ./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }} @@ -156,7 +156,7 @@ jobs: - name: Prepare environment run: | python -m pip install --upgrade pip - python -m pip install poetry + python -m pip install poetry==1.7.1 poetry install --all-extras ./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }} poetry run pip install protobuf==3.20.0 # we check that we support 3.19 @@ -204,7 +204,7 @@ jobs: - name: Prepare environment run: | python -m pip install --upgrade pip - python -m pip install poetry + python -m pip install poetry==1.7.1 poetry install --all-extras ./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }} poetry run pip install protobuf==3.20.0 @@ -253,7 +253,7 @@ jobs: - name: Prepare environment run: | python -m pip install --upgrade pip - python -m pip install poetry + python -m pip install poetry==1.7.1 poetry install --all-extras ./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }} poetry run pip install protobuf==3.20.0 @@ -302,7 +302,7 @@ jobs: - name: Prepare environment run: | python -m pip install --upgrade pip - python -m pip install poetry + python -m pip install poetry==1.7.1 poetry install --all-extras ./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }} poetry run pip install protobuf==3.20.0 @@ -351,7 +351,7 @@ jobs: - name: Prepare environment run: | python -m pip install --upgrade pip - python -m pip install 
poetry + python -m pip install poetry==1.7.1 poetry install --all-extras ./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }} poetry run pip uninstall -y torch @@ -398,7 +398,7 @@ jobs: - name: Prepare environment run: | python -m pip install --upgrade pip - python -m pip install poetry + python -m pip install poetry==1.7.1 poetry install --all-extras poetry run pip uninstall -y torch poetry run pip install torch diff --git a/.github/workflows/ci_only_pr.yml b/.github/workflows/ci_only_pr.yml index 1e8d3f9694f..9d040e72b62 100644 --- a/.github/workflows/ci_only_pr.yml +++ b/.github/workflows/ci_only_pr.yml @@ -43,7 +43,7 @@ jobs: run: | npm i -g netlify-cli python -m pip install --upgrade pip - python -m pip install poetry + python -m pip install poetry==1.7.1 python -m poetry config virtualenvs.create false && python -m poetry install --no-interaction --no-ansi --all-extras cd docs diff --git a/docarray/__init__.py b/docarray/__init__.py index 6ce3f9eb90f..5a18bb9588b 100644 --- a/docarray/__init__.py +++ b/docarray/__init__.py @@ -20,6 +20,60 @@ from docarray.array import DocList, DocVec from docarray.base_doc.doc import BaseDoc from docarray.utils._internal.misc import _get_path_from_docarray_root_level +from docarray.utils._internal.pydantic import is_pydantic_v2 + + +def unpickle_doclist(doc_type, b): + return DocList[doc_type].from_bytes(b, protocol="protobuf") + + +def unpickle_docvec(doc_type, tensor_type, b): + return DocVec[doc_type].from_bytes(b, protocol="protobuf", tensor_type=tensor_type) + + +if is_pydantic_v2: + # Register the pickle functions + def register_serializers(): + import copyreg + from functools import partial + + unpickle_doc_fn = partial(BaseDoc.from_bytes, protocol="protobuf") + + def pickle_doc(doc): + b = doc.to_bytes(protocol='protobuf') + return unpickle_doc_fn, (doc.__class__, b) + + # Register BaseDoc serialization + copyreg.pickle(BaseDoc, pickle_doc) + + # For DocList, we need to hook into __reduce__ since it's a generic + + def pickle_doclist(doc_list): + b = doc_list.to_bytes(protocol='protobuf') + doc_type = doc_list.doc_type + return unpickle_doclist, (doc_type, b) + + # Replace DocList.__reduce__ with a method that returns the correct format + def doclist_reduce(self): + return pickle_doclist(self) + + DocList.__reduce__ = doclist_reduce + + # For DocVec, we need to hook into __reduce__ since it's a generic + + def pickle_docvec(doc_vec): + b = doc_vec.to_bytes(protocol='protobuf') + doc_type = doc_vec.doc_type + tensor_type = doc_vec.tensor_type + return unpickle_docvec, (doc_type, tensor_type, b) + + # Replace DocList.__reduce__ with a method that returns the correct format + def docvec_reduce(self): + return pickle_docvec(self) + + DocVec.__reduce__ = docvec_reduce + + register_serializers() __all__ = ['BaseDoc', 'DocList', 'DocVec'] diff --git a/docarray/array/any_array.py b/docarray/array/any_array.py index 50c47cf4ec4..0c29e54ae82 100644 --- a/docarray/array/any_array.py +++ b/docarray/array/any_array.py @@ -25,6 +25,7 @@ from docarray.exceptions.exceptions import UnusableObjectError from docarray.typing.abstract_type import AbstractType from docarray.utils._internal._typing import change_cls_name, safe_issubclass +from docarray.utils._internal.pydantic import is_pydantic_v2 if TYPE_CHECKING: from docarray.proto import DocListProto, NodeProto @@ -73,8 +74,19 @@ def __class_getitem__(cls, item: Union[Type[BaseDocWithoutId], TypeVar, str]): # Promote to global scope so multiprocessing can pickle it global _DocArrayTyped - class 
_DocArrayTyped(cls): # type: ignore - doc_type: Type[BaseDocWithoutId] = cast(Type[BaseDocWithoutId], item) + if not is_pydantic_v2: + + class _DocArrayTyped(cls): # type: ignore + doc_type: Type[BaseDocWithoutId] = cast( + Type[BaseDocWithoutId], item + ) + + else: + + class _DocArrayTyped(cls, Generic[T_doc]): # type: ignore + doc_type: Type[BaseDocWithoutId] = cast( + Type[BaseDocWithoutId], item + ) for field in _DocArrayTyped.doc_type._docarray_fields().keys(): @@ -99,14 +111,24 @@ def _setter(self, value): setattr(_DocArrayTyped, field, _property_generator(field)) # this generates property on the fly based on the schema of the item - # The global scope and qualname need to refer to this class a unique name. - # Otherwise, creating another _DocArrayTyped will overwrite this one. - change_cls_name( - _DocArrayTyped, f'{cls.__name__}[{item.__name__}]', globals() - ) - - cls.__typed_da__[cls][item] = _DocArrayTyped + # # The global scope and qualname need to refer to this class a unique name. + # # Otherwise, creating another _DocArrayTyped will overwrite this one. + if not is_pydantic_v2: + change_cls_name( + _DocArrayTyped, f'{cls.__name__}[{item.__name__}]', globals() + ) + cls.__typed_da__[cls][item] = _DocArrayTyped + else: + change_cls_name(_DocArrayTyped, f'{cls.__name__}', globals()) + if sys.version_info < (3, 12): + cls.__typed_da__[cls][item] = Generic.__class_getitem__.__func__( + _DocArrayTyped, item + ) # type: ignore + # this do nothing that checking that item is valid type var or str + # Keep the approach in #1147 to be compatible with lower versions of Python. + else: + cls.__typed_da__[cls][item] = GenericAlias(_DocArrayTyped, item) # type: ignore return cls.__typed_da__[cls][item] @overload diff --git a/docarray/array/doc_list/doc_list.py b/docarray/array/doc_list/doc_list.py index c21cf934132..49236199153 100644 --- a/docarray/array/doc_list/doc_list.py +++ b/docarray/array/doc_list/doc_list.py @@ -12,6 +12,7 @@ Union, cast, overload, + Callable, ) from pydantic import parse_obj_as @@ -28,7 +29,6 @@ from docarray.utils._internal.pydantic import is_pydantic_v2 if is_pydantic_v2: - from pydantic import GetCoreSchemaHandler from pydantic_core import core_schema from docarray.utils._internal._typing import safe_issubclass @@ -45,10 +45,7 @@ class DocList( - ListAdvancedIndexing[T_doc], - PushPullMixin, - IOMixinDocList, - AnyDocArray[T_doc], + ListAdvancedIndexing[T_doc], PushPullMixin, IOMixinDocList, AnyDocArray[T_doc] ): """ DocList is a container of Documents. 
@@ -357,8 +354,20 @@ def __repr__(self): @classmethod def __get_pydantic_core_schema__( - cls, _source_type: Any, _handler: GetCoreSchemaHandler + cls, source: Any, handler: Callable[[Any], core_schema.CoreSchema] ) -> core_schema.CoreSchema: - return core_schema.general_plain_validator_function( - cls.validate, + instance_schema = core_schema.is_instance_schema(cls) + args = getattr(source, '__args__', None) + if args: + sequence_t_schema = handler(Sequence[args[0]]) + else: + sequence_t_schema = handler(Sequence) + + def validate_fn(v, info): + # input has already been validated + return cls(v, validate_input_docs=False) + + non_instance_schema = core_schema.with_info_after_validator_function( + validate_fn, sequence_t_schema ) + return core_schema.union_schema([instance_schema, non_instance_schema]) diff --git a/docarray/array/doc_list/io.py b/docarray/array/doc_list/io.py index 82d00197e26..3acb66bf6e8 100644 --- a/docarray/array/doc_list/io.py +++ b/docarray/array/doc_list/io.py @@ -256,7 +256,6 @@ def to_bytes( :param show_progress: show progress bar, only works when protocol is `pickle` or `protobuf` :return: the binary serialization in bytes or None if file_ctx is passed where to store """ - with file_ctx or io.BytesIO() as bf: self._write_bytes( bf=bf, diff --git a/docarray/array/doc_vec/doc_vec.py b/docarray/array/doc_vec/doc_vec.py index 9d515cfd96f..0cc462f173d 100644 --- a/docarray/array/doc_vec/doc_vec.py +++ b/docarray/array/doc_vec/doc_vec.py @@ -198,7 +198,7 @@ def _check_doc_field_not_none(field_name, doc): if safe_issubclass(tensor.__class__, tensor_type): field_type = tensor_type - if isinstance(field_type, type): + if isinstance(field_type, type) or safe_issubclass(field_type, AnyDocArray): if tf_available and safe_issubclass(field_type, TensorFlowTensor): # tf.Tensor does not allow item assignment, therefore the # optimized way @@ -335,7 +335,9 @@ def _docarray_validate( return cast(T, value.to_doc_vec()) else: raise ValueError(f'DocVec[value.doc_type] is not compatible with {cls}') - elif isinstance(value, DocList.__class_getitem__(cls.doc_type)): + elif not is_pydantic_v2 and isinstance( + value, DocList.__class_getitem__(cls.doc_type) + ): return cast(T, value.to_doc_vec()) elif isinstance(value, Sequence): return cls(value) diff --git a/docarray/base_doc/doc.py b/docarray/base_doc/doc.py index 48fb3076cd0..e880504bc05 100644 --- a/docarray/base_doc/doc.py +++ b/docarray/base_doc/doc.py @@ -326,8 +326,13 @@ def _exclude_doclist( from docarray.array.any_array import AnyDocArray type_ = self._get_field_annotation(field) - if isinstance(type_, type) and safe_issubclass(type_, AnyDocArray): - doclist_exclude_fields.append(field) + if is_pydantic_v2: + # Conservative when touching pydantic v1 logic + if safe_issubclass(type_, AnyDocArray): + doclist_exclude_fields.append(field) + else: + if isinstance(type_, type) and safe_issubclass(type_, AnyDocArray): + doclist_exclude_fields.append(field) original_exclude = exclude if exclude is None: @@ -480,7 +485,6 @@ def model_dump( # type: ignore warnings: bool = True, ) -> Dict[str, Any]: def _model_dump(doc): - ( exclude_, original_exclude, diff --git a/docarray/base_doc/mixins/update.py b/docarray/base_doc/mixins/update.py index 721f8225ebb..7ce596ce1aa 100644 --- a/docarray/base_doc/mixins/update.py +++ b/docarray/base_doc/mixins/update.py @@ -110,9 +110,7 @@ def _group_fields(doc: 'UpdateMixin') -> _FieldGroups: if field_name not in FORBIDDEN_FIELDS_TO_UPDATE: field_type = doc._get_field_annotation(field_name) - if 
isinstance(field_type, type) and safe_issubclass( - field_type, DocList - ): + if safe_issubclass(field_type, DocList): nested_docarray_fields.append(field_name) else: origin = get_origin(field_type) diff --git a/docarray/index/backends/elastic.py b/docarray/index/backends/elastic.py index c008fa29de0..a335f85e32a 100644 --- a/docarray/index/backends/elastic.py +++ b/docarray/index/backends/elastic.py @@ -352,12 +352,12 @@ def python_type_to_db_type(self, python_type: Type) -> Any: dict: 'object', } - for type in elastic_py_types.keys(): - if safe_issubclass(python_type, type): + for t in elastic_py_types.keys(): + if safe_issubclass(python_type, t): self._logger.info( - f'Mapped Python type {python_type} to database type "{elastic_py_types[type]}"' + f'Mapped Python type {python_type} to database type "{elastic_py_types[t]}"' ) - return elastic_py_types[type] + return elastic_py_types[t] err_msg = f'Unsupported column type for {type(self)}: {python_type}' self._logger.error(err_msg) diff --git a/docarray/index/backends/epsilla.py b/docarray/index/backends/epsilla.py index 83c171daed0..0392e9d010e 100644 --- a/docarray/index/backends/epsilla.py +++ b/docarray/index/backends/epsilla.py @@ -100,8 +100,8 @@ def __init__(self, db_config=None, **kwargs): def _validate_column_info(self): vector_columns = [] for info in self._column_infos.values(): - for type in [list, np.ndarray, AbstractTensor]: - if safe_issubclass(info.docarray_type, type) and info.config.get( + for t in [list, np.ndarray, AbstractTensor]: + if safe_issubclass(info.docarray_type, t) and info.config.get( 'is_embedding', False ): # check that dimension is present diff --git a/docarray/typing/bytes/base_bytes.py b/docarray/typing/bytes/base_bytes.py index 4c336ae6940..8a944031b4e 100644 --- a/docarray/typing/bytes/base_bytes.py +++ b/docarray/typing/bytes/base_bytes.py @@ -62,7 +62,7 @@ def _to_node_protobuf(self: T) -> 'NodeProto': def __get_pydantic_core_schema__( cls, _source_type: Any, _handler: 'GetCoreSchemaHandler' ) -> 'core_schema.CoreSchema': - return core_schema.general_after_validator_function( + return core_schema.with_info_after_validator_function( cls.validate, core_schema.bytes_schema(), ) diff --git a/docarray/typing/id.py b/docarray/typing/id.py index c06951eaef7..3e3fdd37ae4 100644 --- a/docarray/typing/id.py +++ b/docarray/typing/id.py @@ -77,7 +77,7 @@ def from_protobuf(cls: Type[T], pb_msg: 'str') -> T: def __get_pydantic_core_schema__( cls, source: Type[Any], handler: 'GetCoreSchemaHandler' ) -> core_schema.CoreSchema: - return core_schema.general_plain_validator_function( + return core_schema.with_info_plain_validator_function( cls.validate, ) diff --git a/docarray/typing/tensor/abstract_tensor.py b/docarray/typing/tensor/abstract_tensor.py index 994fe42cc85..e7e4fbe7056 100644 --- a/docarray/typing/tensor/abstract_tensor.py +++ b/docarray/typing/tensor/abstract_tensor.py @@ -395,10 +395,10 @@ def _docarray_to_ndarray(self) -> np.ndarray: def __get_pydantic_core_schema__( cls, _source_type: Any, handler: GetCoreSchemaHandler ) -> core_schema.CoreSchema: - return core_schema.general_plain_validator_function( + return core_schema.with_info_plain_validator_function( cls.validate, serialization=core_schema.plain_serializer_function_ser_schema( - function=orjson_dumps, + function=lambda x: x._docarray_to_ndarray().tolist(), return_schema=handler.generate_schema(bytes), when_used="json-unless-none", ), diff --git a/docarray/typing/url/any_url.py b/docarray/typing/url/any_url.py index ddd17915132..b7c5d71f835 
100644 --- a/docarray/typing/url/any_url.py +++ b/docarray/typing/url/any_url.py @@ -56,7 +56,7 @@ def _docarray_validate( def __get_pydantic_core_schema__( cls, source: Type[Any], handler: Optional['GetCoreSchemaHandler'] = None ) -> core_schema.CoreSchema: - return core_schema.general_after_validator_function( + return core_schema.with_info_after_validator_function( cls._docarray_validate, core_schema.str_schema(), ) diff --git a/docarray/utils/_internal/_typing.py b/docarray/utils/_internal/_typing.py index 83e350a0602..3c2bd89a8e5 100644 --- a/docarray/utils/_internal/_typing.py +++ b/docarray/utils/_internal/_typing.py @@ -61,11 +61,15 @@ def safe_issubclass(x: type, a_tuple: type) -> bool: :return: A boolean value - 'True' if 'x' is a subclass of 'A_tuple', 'False' otherwise. Note that if the origin of 'x' is a list or tuple, the function immediately returns 'False'. """ + origin = get_origin(x) + if origin: # If x is a generic type like DocList[SomeDoc], get its origin + x = origin if ( - (get_origin(x) in (list, tuple, dict, set, Union)) + (origin in (list, tuple, dict, set, Union)) or is_typevar(x) or (type(x) == ForwardRef) or is_typevar(x) ): return False - return issubclass(x, a_tuple) + + return isinstance(x, type) and issubclass(x, a_tuple) diff --git a/docarray/utils/create_dynamic_doc_class.py b/docarray/utils/create_dynamic_doc_class.py index 744fea58c3e..c82a7c89487 100644 --- a/docarray/utils/create_dynamic_doc_class.py +++ b/docarray/utils/create_dynamic_doc_class.py @@ -54,8 +54,9 @@ class MyDoc(BaseDoc): fields: Dict[str, Any] = {} import copy - fields_copy = copy.deepcopy(model.__fields__) - annotations_copy = copy.deepcopy(model.__annotations__) + copy_model = copy.deepcopy(model) + fields_copy = copy_model.__fields__ + annotations_copy = copy_model.__annotations__ for field_name, field in annotations_copy.items(): if field_name not in fields_copy: continue @@ -65,7 +66,7 @@ class MyDoc(BaseDoc): else: field_info = fields_copy[field_name].field_info try: - if safe_issubclass(field, DocList): + if safe_issubclass(field, DocList) and not is_pydantic_v2: t: Any = field.doc_type t_aux = create_pure_python_type_model(t) fields[field_name] = (List[t_aux], field_info) @@ -74,13 +75,14 @@ class MyDoc(BaseDoc): except TypeError: fields[field_name] = (field, field_info) - return create_model(model.__name__, __base__=model, __doc__=model.__doc__, **fields) + return create_model( + copy_model.__name__, __base__=copy_model, __doc__=copy_model.__doc__, **fields + ) def _get_field_annotation_from_schema( field_schema: Dict[str, Any], field_name: str, - root_schema: Dict[str, Any], cached_models: Dict[str, Any], is_tensor: bool = False, num_recursions: int = 0, @@ -90,7 +92,6 @@ def _get_field_annotation_from_schema( Private method used to extract the corresponding field type from the schema. :param field_schema: The schema from which to extract the type :param field_name: The name of the field to be created - :param root_schema: The schema of the root object, important to get references :param cached_models: Parameter used when this method is called recursively to reuse partial nested classes. :param is_tensor: Boolean used to tell between tensor and list :param num_recursions: Number of recursions to properly handle nested types (Dict, List, etc ..) 
@@ -110,7 +111,7 @@ def _get_field_annotation_from_schema( ref_name = obj_ref.split('/')[-1] any_of_types.append( create_base_doc_from_schema( - root_schema['definitions'][ref_name], + definitions[ref_name], ref_name, cached_models=cached_models, definitions=definitions, @@ -121,7 +122,6 @@ def _get_field_annotation_from_schema( _get_field_annotation_from_schema( any_of_schema, field_name, - root_schema=root_schema, cached_models=cached_models, is_tensor=tensor_shape is not None, num_recursions=0, @@ -160,7 +160,10 @@ def _get_field_annotation_from_schema( doc_type: Any if 'additionalProperties' in field_schema: # handle Dictionaries additional_props = field_schema['additionalProperties'] - if additional_props.get('type') == 'object': + if ( + isinstance(additional_props, dict) + and additional_props.get('type') == 'object' + ): doc_type = create_base_doc_from_schema( additional_props, field_name, cached_models=cached_models ) @@ -201,7 +204,6 @@ def _get_field_annotation_from_schema( ret = _get_field_annotation_from_schema( field_schema=field_schema.get('items', {}), field_name=field_name, - root_schema=root_schema, cached_models=cached_models, is_tensor=tensor_shape is not None, num_recursions=num_recursions + 1, @@ -262,6 +264,24 @@ class MyDoc(BaseDoc): :param definitions: Parameter used when this method is called recursively to reuse root definitions of other schemas. :return: A BaseDoc class dynamically created following the `schema`. """ + + def clean_refs(value): + """Recursively remove $ref keys and #/$defs values from a data structure.""" + if isinstance(value, dict): + # Create a new dictionary without $ref keys and without values containing #/$defs + cleaned_dict = {} + for k, v in value.items(): + if k == '$ref': + continue + cleaned_dict[k] = clean_refs(v) + return cleaned_dict + elif isinstance(value, list): + # Process each item in the list + return [clean_refs(item) for item in value] + else: + # Return primitive values as-is + return value + if not definitions: definitions = ( schema.get('definitions', {}) if not is_pydantic_v2 else schema.get('$defs') @@ -275,10 +295,10 @@ class MyDoc(BaseDoc): for field_name, field_schema in schema.get('properties', {}).items(): if field_name == 'id': has_id = True + # Get the field type field_type = _get_field_annotation_from_schema( field_schema=field_schema, field_name=field_name, - root_schema=schema, cached_models=cached_models, is_tensor=False, num_recursions=0, @@ -294,10 +314,22 @@ class MyDoc(BaseDoc): field_kwargs = {} field_json_schema_extra = {} for k, v in field_schema.items(): + if field_name == 'id': + # Skip default_factory for Optional fields and use None + field_kwargs['default'] = None if k in FieldInfo.__slots__: field_kwargs[k] = v else: - field_json_schema_extra[k] = v + if k != '$ref': + if isinstance(v, dict): + cleaned_v = clean_refs(v) + if ( + cleaned_v + ): # Only add if there's something left after cleaning + field_json_schema_extra[k] = cleaned_v + else: + field_json_schema_extra[k] = v + fields[field_name] = ( field_type, FieldInfo( diff --git a/tests/benchmark_tests/test_map.py b/tests/benchmark_tests/test_map.py index e5c664a408b..2fc7b09496e 100644 --- a/tests/benchmark_tests/test_map.py +++ b/tests/benchmark_tests/test_map.py @@ -29,9 +29,9 @@ def test_map_docs_multiprocessing(): if os.cpu_count() > 1: def time_multiprocessing(num_workers: int) -> float: - n_docs = 5 + n_docs = 10 rng = np.random.RandomState(0) - matrices = [rng.random(size=(1000, 1000)) for _ in range(n_docs)] + matrices = 
[rng.random(size=(100, 100)) for _ in range(n_docs)] da = DocList[MyMatrix]([MyMatrix(matrix=m) for m in matrices]) start_time = time() list( @@ -65,7 +65,7 @@ def test_map_docs_batched_multiprocessing(): def time_multiprocessing(num_workers: int) -> float: n_docs = 16 rng = np.random.RandomState(0) - matrices = [rng.random(size=(1000, 1000)) for _ in range(n_docs)] + matrices = [rng.random(size=(100, 100)) for _ in range(n_docs)] da = DocList[MyMatrix]([MyMatrix(matrix=m) for m in matrices]) start_time = time() list( diff --git a/tests/index/base_classes/test_base_doc_store.py b/tests/index/base_classes/test_base_doc_store.py index faf146df6f1..73379694284 100644 --- a/tests/index/base_classes/test_base_doc_store.py +++ b/tests/index/base_classes/test_base_doc_store.py @@ -13,6 +13,7 @@ from docarray.typing import ID, ImageBytes, ImageUrl, NdArray from docarray.typing.tensor.abstract_tensor import AbstractTensor from docarray.utils._internal.misc import torch_imported +from docarray.utils._internal._typing import safe_issubclass pytestmark = pytest.mark.index @@ -54,7 +55,7 @@ class DummyDocIndex(BaseDocIndex): def __init__(self, db_config=None, **kwargs): super().__init__(db_config=db_config, **kwargs) for col_name, col in self._column_infos.items(): - if issubclass(col.docarray_type, AnyDocArray): + if safe_issubclass(col.docarray_type, AnyDocArray): sub_db_config = copy.deepcopy(self._db_config) self._subindices[col_name] = self.__class__[col.docarray_type.doc_type]( db_config=sub_db_config, subindex=True @@ -159,7 +160,7 @@ def test_create_columns(): assert index._column_infos['id'].n_dim is None assert index._column_infos['id'].config['hi'] == 'there' - assert issubclass(index._column_infos['tens'].docarray_type, AbstractTensor) + assert safe_issubclass(index._column_infos['tens'].docarray_type, AbstractTensor) assert index._column_infos['tens'].db_type == str assert index._column_infos['tens'].n_dim == 10 assert index._column_infos['tens'].config == {'dim': 1000, 'hi': 'there'} @@ -173,12 +174,16 @@ def test_create_columns(): assert index._column_infos['id'].n_dim is None assert index._column_infos['id'].config['hi'] == 'there' - assert issubclass(index._column_infos['tens_one'].docarray_type, AbstractTensor) + assert safe_issubclass( + index._column_infos['tens_one'].docarray_type, AbstractTensor + ) assert index._column_infos['tens_one'].db_type == str assert index._column_infos['tens_one'].n_dim is None assert index._column_infos['tens_one'].config == {'dim': 10, 'hi': 'there'} - assert issubclass(index._column_infos['tens_two'].docarray_type, AbstractTensor) + assert safe_issubclass( + index._column_infos['tens_two'].docarray_type, AbstractTensor + ) assert index._column_infos['tens_two'].db_type == str assert index._column_infos['tens_two'].n_dim is None assert index._column_infos['tens_two'].config == {'dim': 50, 'hi': 'there'} @@ -192,7 +197,7 @@ def test_create_columns(): assert index._column_infos['id'].n_dim is None assert index._column_infos['id'].config['hi'] == 'there' - assert issubclass(index._column_infos['d__tens'].docarray_type, AbstractTensor) + assert safe_issubclass(index._column_infos['d__tens'].docarray_type, AbstractTensor) assert index._column_infos['d__tens'].db_type == str assert index._column_infos['d__tens'].n_dim == 10 assert index._column_infos['d__tens'].config == {'dim': 1000, 'hi': 'there'} @@ -206,7 +211,7 @@ def test_create_columns(): 'parent_id', ] - assert issubclass(index._column_infos['d'].docarray_type, AnyDocArray) + assert 
safe_issubclass(index._column_infos['d'].docarray_type, AnyDocArray) assert index._column_infos['d'].db_type is None assert index._column_infos['d'].n_dim is None assert index._column_infos['d'].config == {} @@ -216,7 +221,7 @@ def test_create_columns(): assert index._subindices['d']._column_infos['id'].n_dim is None assert index._subindices['d']._column_infos['id'].config['hi'] == 'there' - assert issubclass( + assert safe_issubclass( index._subindices['d']._column_infos['tens'].docarray_type, AbstractTensor ) assert index._subindices['d']._column_infos['tens'].db_type == str @@ -245,7 +250,7 @@ def test_create_columns(): 'parent_id', ] - assert issubclass( + assert safe_issubclass( index._subindices['d_root']._column_infos['d'].docarray_type, AnyDocArray ) assert index._subindices['d_root']._column_infos['d'].db_type is None @@ -266,7 +271,7 @@ def test_create_columns(): index._subindices['d_root']._subindices['d']._column_infos['id'].config['hi'] == 'there' ) - assert issubclass( + assert safe_issubclass( index._subindices['d_root'] ._subindices['d'] ._column_infos['tens'] @@ -461,11 +466,16 @@ class OtherNestedDoc(NestedDoc): # SIMPLE index = DummyDocIndex[SimpleDoc]() in_list = [SimpleDoc(tens=np.random.random((10,)))] - assert isinstance(index._validate_docs(in_list), DocList[BaseDoc]) + assert isinstance(index._validate_docs(in_list), DocList) + for d in index._validate_docs(in_list): + assert isinstance(d, BaseDoc) + in_da = DocList[SimpleDoc](in_list) assert index._validate_docs(in_da) == in_da in_other_list = [OtherSimpleDoc(tens=np.random.random((10,)))] - assert isinstance(index._validate_docs(in_other_list), DocList[BaseDoc]) + assert isinstance(index._validate_docs(in_other_list), DocList) + for d in index._validate_docs(in_other_list): + assert isinstance(d, BaseDoc) in_other_da = DocList[OtherSimpleDoc](in_other_list) assert index._validate_docs(in_other_da) == in_other_da @@ -494,7 +504,9 @@ class OtherNestedDoc(NestedDoc): in_list = [ FlatDoc(tens_one=np.random.random((10,)), tens_two=np.random.random((50,))) ] - assert isinstance(index._validate_docs(in_list), DocList[BaseDoc]) + assert isinstance(index._validate_docs(in_list), DocList) + for d in index._validate_docs(in_list): + assert isinstance(d, BaseDoc) in_da = DocList[FlatDoc]( [FlatDoc(tens_one=np.random.random((10,)), tens_two=np.random.random((50,)))] ) @@ -502,7 +514,9 @@ class OtherNestedDoc(NestedDoc): in_other_list = [ OtherFlatDoc(tens_one=np.random.random((10,)), tens_two=np.random.random((50,))) ] - assert isinstance(index._validate_docs(in_other_list), DocList[BaseDoc]) + assert isinstance(index._validate_docs(in_other_list), DocList) + for d in index._validate_docs(in_other_list): + assert isinstance(d, BaseDoc) in_other_da = DocList[OtherFlatDoc]( [ OtherFlatDoc( @@ -521,11 +535,15 @@ class OtherNestedDoc(NestedDoc): # NESTED index = DummyDocIndex[NestedDoc]() in_list = [NestedDoc(d=SimpleDoc(tens=np.random.random((10,))))] - assert isinstance(index._validate_docs(in_list), DocList[BaseDoc]) + assert isinstance(index._validate_docs(in_list), DocList) + for d in index._validate_docs(in_list): + assert isinstance(d, BaseDoc) in_da = DocList[NestedDoc]([NestedDoc(d=SimpleDoc(tens=np.random.random((10,))))]) assert index._validate_docs(in_da) == in_da in_other_list = [OtherNestedDoc(d=OtherSimpleDoc(tens=np.random.random((10,))))] - assert isinstance(index._validate_docs(in_other_list), DocList[BaseDoc]) + assert isinstance(index._validate_docs(in_other_list), DocList) + for d in 
index._validate_docs(in_other_list): + assert isinstance(d, BaseDoc) in_other_da = DocList[OtherNestedDoc]( [OtherNestedDoc(d=OtherSimpleDoc(tens=np.random.random((10,))))] ) @@ -552,7 +570,9 @@ class TensorUnionDoc(BaseDoc): # OPTIONAL index = DummyDocIndex[SimpleDoc]() in_list = [OptionalDoc(tens=np.random.random((10,)))] - assert isinstance(index._validate_docs(in_list), DocList[BaseDoc]) + assert isinstance(index._validate_docs(in_list), DocList) + for d in index._validate_docs(in_list): + assert isinstance(d, BaseDoc) in_da = DocList[OptionalDoc](in_list) assert index._validate_docs(in_da) == in_da @@ -562,9 +582,13 @@ class TensorUnionDoc(BaseDoc): # MIXED UNION index = DummyDocIndex[SimpleDoc]() in_list = [MixedUnionDoc(tens=np.random.random((10,)))] - assert isinstance(index._validate_docs(in_list), DocList[BaseDoc]) + assert isinstance(index._validate_docs(in_list), DocList) + for d in index._validate_docs(in_list): + assert isinstance(d, BaseDoc) in_da = DocList[MixedUnionDoc](in_list) - assert isinstance(index._validate_docs(in_da), DocList[BaseDoc]) + assert isinstance(index._validate_docs(in_da), DocList) + for d in index._validate_docs(in_da): + assert isinstance(d, BaseDoc) with pytest.raises(ValueError): index._validate_docs([MixedUnionDoc(tens='hello')]) @@ -572,13 +596,17 @@ class TensorUnionDoc(BaseDoc): # TENSOR UNION index = DummyDocIndex[TensorUnionDoc]() in_list = [SimpleDoc(tens=np.random.random((10,)))] - assert isinstance(index._validate_docs(in_list), DocList[BaseDoc]) + assert isinstance(index._validate_docs(in_list), DocList) + for d in index._validate_docs(in_list): + assert isinstance(d, BaseDoc) in_da = DocList[SimpleDoc](in_list) assert index._validate_docs(in_da) == in_da index = DummyDocIndex[SimpleDoc]() in_list = [TensorUnionDoc(tens=np.random.random((10,)))] - assert isinstance(index._validate_docs(in_list), DocList[BaseDoc]) + assert isinstance(index._validate_docs(in_list), DocList) + for d in index._validate_docs(in_list): + assert isinstance(d, BaseDoc) in_da = DocList[TensorUnionDoc](in_list) assert index._validate_docs(in_da) == in_da diff --git a/tests/integrations/array/test_optional_doc_vec.py b/tests/integrations/array/test_optional_doc_vec.py index bb793152d3d..dd77c66762b 100644 --- a/tests/integrations/array/test_optional_doc_vec.py +++ b/tests/integrations/array/test_optional_doc_vec.py @@ -20,7 +20,8 @@ class Image(BaseDoc): docs.features = [Features(tensor=np.random.random([100])) for _ in range(10)] print(docs.features) # - assert isinstance(docs.features, DocVec[Features]) + assert isinstance(docs.features, DocVec) + assert isinstance(docs.features[0], Features) docs.features.tensor = np.ones((10, 100)) diff --git a/tests/integrations/externals/test_fastapi.py b/tests/integrations/externals/test_fastapi.py index 02967a07cd0..c5ef1868219 100644 --- a/tests/integrations/externals/test_fastapi.py +++ b/tests/integrations/externals/test_fastapi.py @@ -1,5 +1,5 @@ -from typing import List - +from typing import Any, Dict, List, Optional, Union, ClassVar +import json import numpy as np import pytest from fastapi import FastAPI @@ -8,7 +8,9 @@ from docarray import BaseDoc, DocList from docarray.base_doc import DocArrayResponse from docarray.documents import ImageDoc, TextDoc -from docarray.typing import NdArray +from docarray.typing import NdArray, AnyTensor, ImageUrl + +from docarray.utils._internal.pydantic import is_pydantic_v2 @pytest.mark.asyncio @@ -135,3 +137,256 @@ async def func(fastapi_docs: List[ImageDoc]) -> List[ImageDoc]: docs = 
DocList[ImageDoc].from_json(response.content.decode()) assert len(docs) == 2 assert docs[0].tensor.shape == (3, 224, 224) + + +@pytest.mark.asyncio +@pytest.mark.skipif( + not is_pydantic_v2, reason='Behavior is only available for Pydantic V2' +) +async def test_doclist_directly(): + from fastapi import Body + + doc = ImageDoc(tensor=np.zeros((3, 224, 224)), url='url') + docs = DocList[ImageDoc]([doc, doc]) + + app = FastAPI() + + @app.post("/doc/", response_class=DocArrayResponse) + async def func_embed_false( + fastapi_docs: DocList[ImageDoc] = Body(embed=False), + ) -> DocList[ImageDoc]: + return fastapi_docs + + @app.post("/doc_default/", response_class=DocArrayResponse) + async def func_default(fastapi_docs: DocList[ImageDoc]) -> DocList[ImageDoc]: + return fastapi_docs + + @app.post("/doc_embed/", response_class=DocArrayResponse) + async def func_embed_true( + fastapi_docs: DocList[ImageDoc] = Body(embed=True), + ) -> DocList[ImageDoc]: + return fastapi_docs + + async with AsyncClient(app=app, base_url="http://test") as ac: + response = await ac.post("/doc/", data=docs.to_json()) + response_default = await ac.post("/doc_default/", data=docs.to_json()) + embed_content_json = {'fastapi_docs': json.loads(docs.to_json())} + response_embed = await ac.post( + "/doc_embed/", + json=embed_content_json, + ) + resp_doc = await ac.get("/docs") + resp_redoc = await ac.get("/redoc") + + assert response.status_code == 200 + assert response_default.status_code == 200 + assert response_embed.status_code == 200 + assert resp_doc.status_code == 200 + assert resp_redoc.status_code == 200 + + docs = DocList[ImageDoc].from_json(response.content.decode()) + assert len(docs) == 2 + assert docs[0].tensor.shape == (3, 224, 224) + + docs_default = DocList[ImageDoc].from_json(response_default.content.decode()) + assert len(docs_default) == 2 + assert docs_default[0].tensor.shape == (3, 224, 224) + + docs_embed = DocList[ImageDoc].from_json(response_embed.content.decode()) + assert len(docs_embed) == 2 + assert docs_embed[0].tensor.shape == (3, 224, 224) + + +@pytest.mark.asyncio +@pytest.mark.skipif( + not is_pydantic_v2, reason='Behavior is only available for Pydantic V2' +) +async def test_doclist_complex_schema(): + from fastapi import Body + + class Nested2Doc(BaseDoc): + value: str + classvar: ClassVar[str] = 'classvar2' + + class Nested1Doc(BaseDoc): + nested: Nested2Doc + classvar: ClassVar[str] = 'classvar1' + + class CustomDoc(BaseDoc): + tensor: Optional[AnyTensor] = None + url: ImageUrl + num: float = 0.5 + num_num: List[float] = [1.5, 2.5] + lll: List[List[List[int]]] = [[[5]]] + fff: List[List[List[float]]] = [[[5.2]]] + single_text: TextDoc + texts: DocList[TextDoc] + d: Dict[str, str] = {'a': 'b'} + di: Optional[Dict[str, int]] = None + u: Union[str, int] + lu: List[Union[str, int]] = [0, 1, 2] + tags: Optional[Dict[str, Any]] = None + nested: Nested1Doc + embedding: NdArray + classvar: ClassVar[str] = 'classvar' + + docs = DocList[CustomDoc]( + [ + CustomDoc( + num=3.5, + num_num=[4.5, 5.5], + url='photo.jpg', + lll=[[[40]]], + fff=[[[40.2]]], + d={'b': 'a'}, + texts=DocList[TextDoc]([TextDoc(text='hey ha', embedding=np.zeros(3))]), + single_text=TextDoc(text='single hey ha', embedding=np.zeros(2)), + u='a', + lu=[3, 4], + embedding=np.random.random((1, 4)), + nested=Nested1Doc(nested=Nested2Doc(value='hello world')), + ) + ] + ) + + app = FastAPI() + + @app.post("/doc/", response_class=DocArrayResponse) + async def func_embed_false( + fastapi_docs: DocList[CustomDoc] = Body(embed=False), + ) 
-> DocList[CustomDoc]: + for doc in fastapi_docs: + doc.tensor = np.zeros((10, 10, 10)) + doc.di = {'a': 2} + + return fastapi_docs + + @app.post("/doc_default/", response_class=DocArrayResponse) + async def func_default(fastapi_docs: DocList[CustomDoc]) -> DocList[CustomDoc]: + for doc in fastapi_docs: + doc.tensor = np.zeros((10, 10, 10)) + doc.di = {'a': 2} + return fastapi_docs + + @app.post("/doc_embed/", response_class=DocArrayResponse) + async def func_embed_true( + fastapi_docs: DocList[CustomDoc] = Body(embed=True), + ) -> DocList[CustomDoc]: + for doc in fastapi_docs: + doc.tensor = np.zeros((10, 10, 10)) + doc.di = {'a': 2} + return fastapi_docs + + async with AsyncClient(app=app, base_url="http://test") as ac: + response = await ac.post("/doc/", data=docs.to_json()) + response_default = await ac.post("/doc_default/", data=docs.to_json()) + embed_content_json = {'fastapi_docs': json.loads(docs.to_json())} + response_embed = await ac.post( + "/doc_embed/", + json=embed_content_json, + ) + resp_doc = await ac.get("/docs") + resp_redoc = await ac.get("/redoc") + + assert response.status_code == 200 + assert response_default.status_code == 200 + assert response_embed.status_code == 200 + assert resp_doc.status_code == 200 + assert resp_redoc.status_code == 200 + + resp_json = json.loads(response_default.content.decode()) + assert isinstance(resp_json[0]["tensor"], list) + assert isinstance(resp_json[0]["embedding"], list) + assert isinstance(resp_json[0]["texts"][0]["embedding"], list) + + docs_response = DocList[CustomDoc].from_json(response.content.decode()) + assert len(docs_response) == 1 + assert docs_response[0].url == 'photo.jpg' + assert docs_response[0].num == 3.5 + assert docs_response[0].num_num == [4.5, 5.5] + assert docs_response[0].lll == [[[40]]] + assert docs_response[0].lu == [3, 4] + assert docs_response[0].fff == [[[40.2]]] + assert docs_response[0].di == {'a': 2} + assert docs_response[0].d == {'b': 'a'} + assert len(docs_response[0].texts) == 1 + assert docs_response[0].texts[0].text == 'hey ha' + assert docs_response[0].texts[0].embedding.shape == (3,) + assert docs_response[0].tensor.shape == (10, 10, 10) + assert docs_response[0].u == 'a' + assert docs_response[0].single_text.text == 'single hey ha' + assert docs_response[0].single_text.embedding.shape == (2,) + + docs_default = DocList[CustomDoc].from_json(response_default.content.decode()) + assert len(docs_default) == 1 + assert docs_default[0].url == 'photo.jpg' + assert docs_default[0].num == 3.5 + assert docs_default[0].num_num == [4.5, 5.5] + assert docs_default[0].lll == [[[40]]] + assert docs_default[0].lu == [3, 4] + assert docs_default[0].fff == [[[40.2]]] + assert docs_default[0].di == {'a': 2} + assert docs_default[0].d == {'b': 'a'} + assert len(docs_default[0].texts) == 1 + assert docs_default[0].texts[0].text == 'hey ha' + assert docs_default[0].texts[0].embedding.shape == (3,) + assert docs_default[0].tensor.shape == (10, 10, 10) + assert docs_default[0].u == 'a' + assert docs_default[0].single_text.text == 'single hey ha' + assert docs_default[0].single_text.embedding.shape == (2,) + + docs_embed = DocList[CustomDoc].from_json(response_embed.content.decode()) + assert len(docs_embed) == 1 + assert docs_embed[0].url == 'photo.jpg' + assert docs_embed[0].num == 3.5 + assert docs_embed[0].num_num == [4.5, 5.5] + assert docs_embed[0].lll == [[[40]]] + assert docs_embed[0].lu == [3, 4] + assert docs_embed[0].fff == [[[40.2]]] + assert docs_embed[0].di == {'a': 2} + assert docs_embed[0].d == {'b': 
'a'} + assert len(docs_embed[0].texts) == 1 + assert docs_embed[0].texts[0].text == 'hey ha' + assert docs_embed[0].texts[0].embedding.shape == (3,) + assert docs_embed[0].tensor.shape == (10, 10, 10) + assert docs_embed[0].u == 'a' + assert docs_embed[0].single_text.text == 'single hey ha' + assert docs_embed[0].single_text.embedding.shape == (2,) + + +@pytest.mark.asyncio +@pytest.mark.skipif( + not is_pydantic_v2, reason='Behavior is only available for Pydantic V2' +) +async def test_simple_directly(): + app = FastAPI() + + @app.post("/doc_list/", response_class=DocArrayResponse) + async def func_doc_list(fastapi_docs: DocList[TextDoc]) -> DocList[TextDoc]: + return fastapi_docs + + @app.post("/doc_single/", response_class=DocArrayResponse) + async def func_doc_single(fastapi_doc: TextDoc) -> TextDoc: + return fastapi_doc + + async with AsyncClient(app=app, base_url="http://test") as ac: + response_doc_list = await ac.post( + "/doc_list/", data=json.dumps([{"text": "text"}]) + ) + response_single = await ac.post( + "/doc_single/", data=json.dumps({"text": "text"}) + ) + resp_doc = await ac.get("/docs") + resp_redoc = await ac.get("/redoc") + + assert response_doc_list.status_code == 200 + assert response_single.status_code == 200 + assert resp_doc.status_code == 200 + assert resp_redoc.status_code == 200 + + docs = DocList[TextDoc].from_json(response_doc_list.content.decode()) + assert len(docs) == 1 + assert docs[0].text == 'text' + + doc = TextDoc.from_json(response_single.content.decode()) + assert doc == 'text' diff --git a/tests/integrations/torch/data/test_torch_dataset.py b/tests/integrations/torch/data/test_torch_dataset.py index f358f1c16b8..5d8236a70b3 100644 --- a/tests/integrations/torch/data/test_torch_dataset.py +++ b/tests/integrations/torch/data/test_torch_dataset.py @@ -60,7 +60,9 @@ def test_torch_dataset(captions_da: DocList[PairTextImage]): batch_lens = [] for batch in loader: - assert isinstance(batch, DocVec[PairTextImage]) + assert isinstance(batch, DocVec) + for d in batch: + assert isinstance(d, PairTextImage) batch_lens.append(len(batch)) assert all(x == BATCH_SIZE for x in batch_lens[:-1]) @@ -140,7 +142,9 @@ def test_torch_dl_multiprocessing(captions_da: DocList[PairTextImage]): batch_lens = [] for batch in loader: - assert isinstance(batch, DocVec[PairTextImage]) + assert isinstance(batch, DocVec) + for d in batch: + assert isinstance(d, PairTextImage) batch_lens.append(len(batch)) assert all(x == BATCH_SIZE for x in batch_lens[:-1]) diff --git a/tests/units/array/stack/storage/test_storage.py b/tests/units/array/stack/storage/test_storage.py index 01c1b68a165..b91585d3737 100644 --- a/tests/units/array/stack/storage/test_storage.py +++ b/tests/units/array/stack/storage/test_storage.py @@ -26,8 +26,9 @@ class MyDoc(BaseDoc): for name in storage.any_columns['name']: assert name == 'hello' inner_docs = storage.doc_columns['doc'] - assert isinstance(inner_docs, DocVec[InnerDoc]) + assert isinstance(inner_docs, DocVec) for i, doc in enumerate(inner_docs): + assert isinstance(doc, InnerDoc) assert doc.price == i diff --git a/tests/units/array/stack/test_array_stacked.py b/tests/units/array/stack/test_array_stacked.py index 2a3790da1d3..b1b385840dd 100644 --- a/tests/units/array/stack/test_array_stacked.py +++ b/tests/units/array/stack/test_array_stacked.py @@ -504,7 +504,9 @@ class ImageDoc(BaseDoc): da = parse_obj_as(DocVec[ImageDoc], batch) - assert isinstance(da, DocVec[ImageDoc]) + assert isinstance(da, DocVec) + for d in da: + assert isinstance(d, ImageDoc) 
def test_validation_column_tensor(batch): @@ -536,14 +538,18 @@ def test_validation_column_doc(batch_nested_doc): batch, Doc, Inner = batch_nested_doc batch.inner = DocList[Inner]([Inner(hello='hello') for _ in range(10)]) - assert isinstance(batch.inner, DocVec[Inner]) + assert isinstance(batch.inner, DocVec) + for d in batch.inner: + assert isinstance(d, Inner) def test_validation_list_doc(batch_nested_doc): batch, Doc, Inner = batch_nested_doc batch.inner = [Inner(hello='hello') for _ in range(10)] - assert isinstance(batch.inner, DocVec[Inner]) + assert isinstance(batch.inner, DocVec) + for d in batch.inner: + assert isinstance(d, Inner) def test_validation_col_doc_fail(batch_nested_doc): diff --git a/tests/units/array/stack/test_proto.py b/tests/units/array/stack/test_proto.py index 8c559826b80..d46766cde30 100644 --- a/tests/units/array/stack/test_proto.py +++ b/tests/units/array/stack/test_proto.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os from typing import Dict, Optional, Union import numpy as np @@ -245,6 +246,7 @@ class MyDoc(BaseDoc): assert da_after._storage.any_columns['d'] == [None, None] +@pytest.mark.skipif('GITHUB_WORKFLOW' in os.environ, reason='Flaky in Github') @pytest.mark.proto @pytest.mark.parametrize('tensor_type', [NdArray, TorchTensor]) def test_proto_tensor_type(tensor_type): diff --git a/tests/units/array/test_array.py b/tests/units/array/test_array.py index 1d93fb6b78c..8e51cc1c37e 100644 --- a/tests/units/array/test_array.py +++ b/tests/units/array/test_array.py @@ -486,6 +486,8 @@ def test_validate_list_dict(): dict(url=f'http://url.com/foo_{i}.png', tensor=NdArray(i)) for i in [2, 0, 1] ] + # docs = DocList[Image]([Image(url=image['url'], tensor=image['tensor']) for image in images]) + docs = parse_obj_as(DocList[Image], images) assert docs.url == [ @@ -520,5 +522,3 @@ def test_not_double_subcriptable(): with pytest.raises(TypeError) as excinfo: da = DocList[TextDoc][TextDoc] assert da is None - - assert 'not subscriptable' in str(excinfo.value) diff --git a/tests/units/array/test_array_from_to_bytes.py b/tests/units/array/test_array_from_to_bytes.py index abc31cb4ac7..0ab952ce4a7 100644 --- a/tests/units/array/test_array_from_to_bytes.py +++ b/tests/units/array/test_array_from_to_bytes.py @@ -43,11 +43,11 @@ def test_from_to_bytes(protocol, compress, show_progress, array_cls): @pytest.mark.parametrize( - 'protocol', ['protobuf'] # ['pickle-array', 'protobuf-array', 'protobuf', 'pickle'] + 'protocol', ['pickle-array', 'protobuf-array', 'protobuf', 'pickle'] ) -@pytest.mark.parametrize('compress', ['lz4']) # , 'bz2', 'lzma', 'zlib', 'gzip', None]) -@pytest.mark.parametrize('show_progress', [False]) # [False, True]) -@pytest.mark.parametrize('array_cls', [DocVec]) # [DocList, DocVec]) +@pytest.mark.parametrize('compress', ['lz4', 'bz2', 'lzma', 'zlib', 'gzip', None]) +@pytest.mark.parametrize('show_progress', [False, True]) # [False, True]) +@pytest.mark.parametrize('array_cls', [DocList, DocVec]) def test_from_to_base64(protocol, compress, show_progress, array_cls): da = array_cls[MyDoc]( [ @@ -75,27 +75,35 @@ def test_from_to_base64(protocol, compress, show_progress, array_cls): # test_from_to_base64('protobuf', 'lz4', False, DocVec) +class MyTensorTypeDocNdArray(BaseDoc): + embedding: NdArray + text: str + image: ImageDoc -@pytest.mark.parametrize('tensor_type', [NdArray, TorchTensor]) 
-@pytest.mark.parametrize('protocol', ['protobuf-array', 'pickle-array']) -def test_from_to_base64_tensor_type(tensor_type, protocol): - class MyDoc(BaseDoc): - embedding: tensor_type - text: str - image: ImageDoc +class MyTensorTypeDocTorchTensor(BaseDoc): + embedding: TorchTensor + text: str + image: ImageDoc - da = DocVec[MyDoc]( + +@pytest.mark.parametrize( + 'doc_type, tensor_type', + [(MyTensorTypeDocNdArray, NdArray), (MyTensorTypeDocTorchTensor, TorchTensor)], +) +@pytest.mark.parametrize('protocol', ['protobuf-array', 'pickle-array']) +def test_from_to_base64_tensor_type(doc_type, tensor_type, protocol): + da = DocVec[doc_type]( [ - MyDoc( + doc_type( embedding=[1, 2, 3, 4, 5], text='hello', image=ImageDoc(url='aux.png') ), - MyDoc(embedding=[5, 4, 3, 2, 1], text='hello world', image=ImageDoc()), + doc_type(embedding=[5, 4, 3, 2, 1], text='hello world', image=ImageDoc()), ], tensor_type=tensor_type, ) bytes_da = da.to_base64(protocol=protocol) - da2 = DocVec[MyDoc].from_base64( + da2 = DocVec[doc_type].from_base64( bytes_da, tensor_type=tensor_type, protocol=protocol ) assert da2.tensor_type == tensor_type diff --git a/tests/units/array/test_doclist_schema.py b/tests/units/array/test_doclist_schema.py new file mode 100644 index 00000000000..02a5f562807 --- /dev/null +++ b/tests/units/array/test_doclist_schema.py @@ -0,0 +1,22 @@ +import pytest +from docarray import BaseDoc, DocList +from docarray.utils._internal.pydantic import is_pydantic_v2 + + +@pytest.mark.skipif(not is_pydantic_v2, reason='Feature only available for Pydantic V2') +def test_schema_nested(): + # check issue https://github.com/docarray/docarray/issues/1521 + + class Doc1Test(BaseDoc): + aux: str + + class DocDocTest(BaseDoc): + docs: DocList[Doc1Test] + + assert 'Doc1Test' in DocDocTest.schema()['$defs'] + d = DocDocTest(docs=DocList[Doc1Test]([Doc1Test(aux='aux')])) + + assert isinstance(d.docs, DocList) + for dd in d.docs: + assert isinstance(dd, Doc1Test) + assert d.docs.aux == ['aux'] diff --git a/tests/units/document/test_doc_wo_id.py b/tests/units/document/test_doc_wo_id.py index ffda3ceec4f..4e2a8bba118 100644 --- a/tests/units/document/test_doc_wo_id.py +++ b/tests/units/document/test_doc_wo_id.py @@ -23,4 +23,9 @@ class A(BaseDocWithoutId): cls_doc_list = DocList[A] - assert isinstance(cls_doc_list, type) + da = cls_doc_list([A(text='hey here')]) + + assert isinstance(da, DocList) + for d in da: + assert isinstance(d, A) + assert not hasattr(d, 'id') diff --git a/tests/units/typing/da/test_relations.py b/tests/units/typing/da/test_relations.py index f583abef2ec..cadac712f5a 100644 --- a/tests/units/typing/da/test_relations.py +++ b/tests/units/typing/da/test_relations.py @@ -13,9 +13,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ +import pytest from docarray import BaseDoc, DocList +from docarray.utils._internal.pydantic import is_pydantic_v2 +@pytest.mark.skipif( + is_pydantic_v2, + reason="Subscripted generics cannot be used with class and instance checks", +) def test_instance_and_equivalence(): class MyDoc(BaseDoc): text: str @@ -28,6 +35,10 @@ class MyDoc(BaseDoc): assert isinstance(docs, DocList[MyDoc]) +@pytest.mark.skipif( + is_pydantic_v2, + reason="Subscripted generics cannot be used with class and instance checks", +) def test_subclassing(): class MyDoc(BaseDoc): text: str diff --git a/tests/units/util/test_create_dynamic_code_class.py b/tests/units/util/test_create_dynamic_code_class.py index eba25911c4f..b7df497816d 100644 --- a/tests/units/util/test_create_dynamic_code_class.py +++ b/tests/units/util/test_create_dynamic_code_class.py @@ -45,6 +45,7 @@ class CustomDoc(BaseDoc): new_custom_doc_model = create_base_doc_from_schema( CustomDocCopy.schema(), 'CustomDoc', {} ) + print(f'new_custom_doc_model {new_custom_doc_model.schema()}') original_custom_docs = DocList[CustomDoc]( [ @@ -131,6 +132,7 @@ class TextDocWithId(BaseDoc): new_textdoc_with_id_model = create_base_doc_from_schema( TextDocWithIdCopy.schema(), 'TextDocWithId', {} ) + print(f'new_textdoc_with_id_model {new_textdoc_with_id_model.schema()}') original_text_doc_with_id = DocList[TextDocWithId]( [TextDocWithId(ia=f'ID {i}') for i in range(10)] @@ -207,6 +209,7 @@ class CustomDoc(BaseDoc): new_custom_doc_model = create_base_doc_from_schema( CustomDocCopy.schema(), 'CustomDoc' ) + print(f'new_custom_doc_model {new_custom_doc_model.schema()}') original_custom_docs = DocList[CustomDoc]() if transformation == 'proto': @@ -232,6 +235,7 @@ class TextDocWithId(BaseDoc): new_textdoc_with_id_model = create_base_doc_from_schema( TextDocWithIdCopy.schema(), 'TextDocWithId', {} ) + print(f'new_textdoc_with_id_model {new_textdoc_with_id_model.schema()}') original_text_doc_with_id = DocList[TextDocWithId]() if transformation == 'proto': @@ -255,6 +259,9 @@ class ResultTestDoc(BaseDoc): new_result_test_doc_with_id_model = create_base_doc_from_schema( ResultTestDocCopy.schema(), 'ResultTestDoc', {} ) + print( + f'new_result_test_doc_with_id_model {new_result_test_doc_with_id_model.schema()}' + ) result_test_docs = DocList[ResultTestDoc]() if transformation == 'proto': @@ -309,9 +316,10 @@ class SearchResult(BaseDoc): models_created_by_name = {} SearchResult_aux = create_pure_python_type_model(SearchResult) - _ = create_base_doc_from_schema( + m = create_base_doc_from_schema( SearchResult_aux.schema(), 'SearchResult', models_created_by_name ) + print(f'm {m.schema()}') QuoteFile_reconstructed_in_gateway_from_Search_results = models_created_by_name[ 'QuoteFile' ] @@ -323,3 +331,28 @@ class SearchResult(BaseDoc): QuoteFile_reconstructed_in_gateway_from_Search_results(id='0', texts=textlist) ) assert reconstructed_in_gateway_from_Search_results.texts[0].text == 'hey' + + +def test_id_optional(): + from docarray import BaseDoc + import json + + class MyTextDoc(BaseDoc): + text: str + opt: Optional[str] = None + + MyTextDoc_aux = create_pure_python_type_model(MyTextDoc) + td = create_base_doc_from_schema(MyTextDoc_aux.schema(), 'MyTextDoc') + print(f'{td.schema()}') + direct = MyTextDoc.from_json(json.dumps({"text": "text"})) + aux = MyTextDoc_aux.from_json(json.dumps({"text": "text"})) + indirect = td.from_json(json.dumps({"text": "text"})) + assert direct.text == 'text' + assert aux.text == 'text' + assert indirect.text == 'text' + direct = 
MyTextDoc(text='hey') + aux = MyTextDoc_aux(text='hey') + indirect = td(text='hey') + assert direct.text == 'hey' + assert aux.text == 'hey' + assert indirect.text == 'hey' diff --git a/tests/units/util/test_map.py b/tests/units/util/test_map.py index 3b9f102d928..65dd3c17389 100644 --- a/tests/units/util/test_map.py +++ b/tests/units/util/test_map.py @@ -96,4 +96,6 @@ def test_map_docs_batched(n_docs, batch_size, backend): assert isinstance(it, Generator) for batch in it: - assert isinstance(batch, DocList[MyImage]) + assert isinstance(batch, DocList) + for d in batch: + assert isinstance(d, MyImage)
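
A minimal usage sketch of what the serializer registration in docarray/__init__.py above is meant to enable: under pydantic v2, BaseDoc, DocList, and DocVec instances should round-trip through pickle (and therefore multiprocessing) via the protobuf-based to_bytes/from_bytes helpers. This is an illustrative sketch only; the MyDoc class is an assumption, not part of the diff.

import pickle

import numpy as np

from docarray import BaseDoc, DocList
from docarray.typing import NdArray


class MyDoc(BaseDoc):
    text: str
    embedding: NdArray


docs = DocList[MyDoc]([MyDoc(text='hello', embedding=np.zeros(3)) for _ in range(2)])

# Pickling now routes through DocList.__reduce__ -> protobuf bytes -> from_bytes
restored = pickle.loads(pickle.dumps(docs))
assert restored[0].text == 'hello'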
