diff --git a/docarray/array/abstract_array.py b/docarray/array/abstract_array.py index b0a4b6db70c..832bc251517 100644 --- a/docarray/array/abstract_array.py +++ b/docarray/array/abstract_array.py @@ -9,6 +9,7 @@ Type, TypeVar, Union, + cast, ) from docarray.base_document import BaseDocument @@ -34,7 +35,10 @@ def __repr__(self): return f'<{self.__class__.__name__} (length={len(self)})>' @classmethod - def __class_getitem__(cls, item: Type[BaseDocument]): + def __class_getitem__(cls, item: Union[Type[BaseDocument], TypeVar, str]): + if not isinstance(item, type): + return Generic.__class_getitem__.__func__(cls, item) # type: ignore + # this do nothing that checking that item is valid type var or str if not issubclass(item, BaseDocument): raise ValueError( f'{cls.__name__}[item] item should be a Document not a {item} ' @@ -48,7 +52,7 @@ def __class_getitem__(cls, item: Type[BaseDocument]): global _DocumentArrayTyped class _DocumentArrayTyped(cls): # type: ignore - document_type: Type[BaseDocument] = item + document_type: Type[BaseDocument] = cast(Type[BaseDocument], item) for field in _DocumentArrayTyped.document_type.__fields__.keys(): diff --git a/docarray/array/array/array.py b/docarray/array/array/array.py index f39bae21efa..a42e0d2dca1 100644 --- a/docarray/array/array/array.py +++ b/docarray/array/array/array.py @@ -5,7 +5,6 @@ TYPE_CHECKING, Any, Callable, - Generic, Iterable, List, Optional, @@ -68,7 +67,7 @@ def _is_np_int(item: Any) -> bool: return False # this is unreachable, but mypy wants it -class DocumentArray(IOMixinArray, AnyDocumentArray, Generic[T_doc]): +class DocumentArray(IOMixinArray, AnyDocumentArray[T_doc]): """ DocumentArray is a container of Documents. diff --git a/docarray/array/stacked/array_stacked.py b/docarray/array/stacked/array_stacked.py index cf2bf2accab..50ef017113a 100644 --- a/docarray/array/stacked/array_stacked.py +++ b/docarray/array/stacked/array_stacked.py @@ -42,11 +42,12 @@ else: TensorFlowTensor = None # type: ignore +T_doc = TypeVar('T_doc', bound=BaseDocument) T = TypeVar('T', bound='DocumentArrayStacked') IndexIterType = Union[slice, Iterable[int], Iterable[bool], None] -class DocumentArrayStacked(AnyDocumentArray): +class DocumentArrayStacked(AnyDocumentArray[T_doc]): """ DocumentArrayStacked is a container of Documents appropriates to perform computation that require batches of data (ex: matrix multiplication, distance @@ -70,7 +71,7 @@ class DocumentArrayStacked(AnyDocumentArray): def __init__( self: T, - docs: Optional[Union[DocumentArray, Iterable[BaseDocument]]] = None, + docs: Optional[Union[DocumentArray, Iterable[T_doc]]] = None, tensor_type: Type['AbstractTensor'] = NdArray, ): self._doc_columns: Dict[str, 'DocumentArrayStacked'] = {} @@ -80,7 +81,7 @@ def __init__( self.from_iterable_document(docs) def from_iterable_document( - self: T, docs: Optional[Union[DocumentArray, Iterable[BaseDocument]]] + self: T, docs: Optional[Union[DocumentArray, Iterable[T_doc]]] ): self._docs = ( docs @@ -254,7 +255,7 @@ def _set_array_attribute( setattr(self._docs, field, values) @overload - def __getitem__(self: T, item: int) -> BaseDocument: + def __getitem__(self: T, item: int) -> T_doc: ... @overload @@ -276,9 +277,7 @@ def __getitem__(self, item): setattr(doc, field, self._tensor_columns[field][item]) return doc - def __setitem__( - self: T, key: Union[int, IndexIterType], value: Union[T, BaseDocument] - ): + def __setitem__(self: T, key: Union[int, IndexIterType], value: Union[T, T_doc]): # multiple docs case if isinstance(key, (slice, Iterable)): return self._set_data_and_columns(key, value) @@ -476,7 +475,7 @@ def unstacked_mode(self): @classmethod def validate( cls: Type[T], - value: Union[T, Iterable[BaseDocument]], + value: Union[T, Iterable[T_doc]], field: 'ModelField', config: 'BaseConfig', ) -> T: diff --git a/tests/units/array/test_array.py b/tests/units/array/test_array.py index d69514caab7..7c0e9b329f8 100644 --- a/tests/units/array/test_array.py +++ b/tests/units/array/test_array.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Optional, TypeVar, Union import numpy as np import pytest @@ -294,3 +294,17 @@ def test_del_item(da): 'hello 8', 'hello 9', ] + + +def test_generic_type_var(): + T = TypeVar('T', bound=BaseDocument) + + def f(a: DocumentArray[T]) -> DocumentArray[T]: + return a + + def g(a: DocumentArray['BaseDocument']) -> DocumentArray['BaseDocument']: + return a + + a = DocumentArray() + f(a) + g(a) pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy