From ea97d59f46b46a9d3dc6ff3c166509311b246406 Mon Sep 17 00:00:00 2001 From: samsja Date: Fri, 17 Feb 2023 11:34:58 +0100 Subject: [PATCH 01/11] refactor: create io mixin Signed-off-by: samsja --- docarray/base_document/document.py | 96 +------------------ docarray/base_document/mixins/__init__.py | 4 +- .../base_document/mixins/{proto.py => io.py} | 96 ++++++++++++++++++- 3 files changed, 99 insertions(+), 97 deletions(-) rename docarray/base_document/mixins/{proto.py => io.py} (67%) diff --git a/docarray/base_document/document.py b/docarray/base_document/document.py index f04e0093254..7bdce5a8cd2 100644 --- a/docarray/base_document/document.py +++ b/docarray/base_document/document.py @@ -1,26 +1,21 @@ import os -from typing import Type, Optional, TypeVar +from typing import Type import orjson from pydantic import BaseModel, Field, parse_obj_as from rich.console import Console -import pickle -import base64 from docarray.base_document.base_node import BaseNode from docarray.base_document.io.json import orjson_dumps, orjson_dumps_and_decode -from docarray.utils.compress import _compress_bytes, _decompress_bytes -from docarray.base_document.mixins import ProtoMixin, UpdateMixin +from docarray.base_document.mixins import IOMixin, UpdateMixin from docarray.typing import ID _console: Console = Console() -T = TypeVar('T', bound='BaseDocument') - -class BaseDocument(BaseModel, ProtoMixin, UpdateMixin, BaseNode): +class BaseDocument(BaseModel, IOMixin, UpdateMixin, BaseNode): """ - The base class for Document + The base class for Documents """ id: ID = Field(default_factory=lambda: parse_obj_as(ID, os.urandom(16).hex())) @@ -33,7 +28,7 @@ class Config: validate_assignment = True @classmethod - def _get_field_type(cls, field: str) -> Type['BaseDocument']: + def _get_field_type(cls, field: str) -> Type: """ Accessing the nested python Class define in the schema. Could be useful for reconstruction of Document in serialization/deserilization @@ -61,87 +56,6 @@ def schema_summary(cls) -> None: DocumentSummary.schema_summary(cls) - def __bytes__(self) -> bytes: - return self.to_bytes() - - def to_bytes( - self, protocol: str = 'protobuf', compress: Optional[str] = None - ) -> bytes: - """Serialize itself into bytes. - - For more Pythonic code, please use ``bytes(...)``. - - :param protocol: protocol to use. It can be 'pickle' or 'protobuf' - :param compress: compress algorithm to use - :return: the binary serialization in bytes - """ - import pickle - - if protocol == 'pickle': - bstr = pickle.dumps(self) - elif protocol == 'protobuf': - bstr = self.to_protobuf().SerializePartialToString() - else: - raise ValueError( - f'protocol={protocol} is not supported. Can be only `protobuf` or pickle protocols 0-5.' - ) - return _compress_bytes(bstr, algorithm=compress) - - @classmethod - def from_bytes( - cls: Type[T], - data: bytes, - protocol: str = 'protobuf', - compress: Optional[str] = None, - ) -> T: - """Build Document object from binary bytes - - :param data: binary bytes - :param protocol: protocol to use. It can be 'pickle' or 'protobuf' - :param compress: compress method to use - :return: a Document object - """ - bstr = _decompress_bytes(data, algorithm=compress) - if protocol == 'pickle': - return pickle.loads(bstr) - elif protocol == 'protobuf': - from docarray.proto import DocumentProto - - pb_msg = DocumentProto() - pb_msg.ParseFromString(bstr) - return cls.from_protobuf(pb_msg) - else: - raise ValueError( - f'protocol={protocol} is not supported. Can be only `protobuf` or pickle protocols 0-5.' 
- ) - - def to_base64( - self, protocol: str = 'protobuf', compress: Optional[str] = None - ) -> str: - """Serialize a Document object into as base64 string - - :param protocol: protocol to use. It can be 'pickle' or 'protobuf' - :param compress: compress method to use - :return: a base64 encoded string - """ - return base64.b64encode(self.to_bytes(protocol, compress)).decode('utf-8') - - @classmethod - def from_base64( - cls: Type[T], - data: str, - protocol: str = 'pickle', - compress: Optional[str] = None, - ) -> T: - """Build Document object from binary bytes - - :param data: a base64 encoded string - :param protocol: protocol to use. It can be 'pickle' or 'protobuf' - :param compress: compress method to use - :return: a Document object - """ - return cls.from_bytes(base64.b64decode(data), protocol, compress) - def _ipython_display_(self): """Displays the object in IPython as a summary""" self.summary() diff --git a/docarray/base_document/mixins/__init__.py b/docarray/base_document/mixins/__init__.py index e4fdf7a6e7e..53b3242874a 100644 --- a/docarray/base_document/mixins/__init__.py +++ b/docarray/base_document/mixins/__init__.py @@ -1,4 +1,4 @@ -from docarray.base_document.mixins.proto import ProtoMixin +from docarray.base_document.mixins.io import IOMixin from docarray.base_document.mixins.update import UpdateMixin -__all__ = ['ProtoMixin', 'UpdateMixin'] +__all__ = ['IOMixin', 'UpdateMixin'] diff --git a/docarray/base_document/mixins/proto.py b/docarray/base_document/mixins/io.py similarity index 67% rename from docarray/base_document/mixins/proto.py rename to docarray/base_document/mixins/io.py index 10bb9fe991d..05806b35680 100644 --- a/docarray/base_document/mixins/proto.py +++ b/docarray/base_document/mixins/io.py @@ -1,8 +1,11 @@ +import base64 +import pickle from abc import abstractmethod -from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Type, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, TypeVar from docarray.base_document.base_node import BaseNode from docarray.typing.proto_register import _PROTO_TYPE_NAME_TO_CLASS +from docarray.utils.compress import _compress_bytes, _decompress_bytes if TYPE_CHECKING: from pydantic.fields import ModelField @@ -10,17 +13,102 @@ from docarray.proto import DocumentProto, NodeProto -T = TypeVar('T', bound='ProtoMixin') +T = TypeVar('T', bound='IOMixin') -class ProtoMixin(Iterable): +class IOMixin: + """ + IOMixin to define all the bytes/protobuf/json related part of BaseDocument + """ + __fields__: Dict[str, 'ModelField'] @classmethod @abstractmethod - def _get_field_type(cls, field: str) -> Type['ProtoMixin']: + def _get_field_type(cls, field: str) -> Type: ... + def __bytes__(self) -> bytes: + return self.to_bytes() + + def to_bytes( + self, protocol: str = 'protobuf', compress: Optional[str] = None + ) -> bytes: + """Serialize itself into bytes. + + For more Pythonic code, please use ``bytes(...)``. + + :param protocol: protocol to use. It can be 'pickle' or 'protobuf' + :param compress: compress algorithm to use + :return: the binary serialization in bytes + """ + import pickle + + if protocol == 'pickle': + bstr = pickle.dumps(self) + elif protocol == 'protobuf': + bstr = self.to_protobuf().SerializePartialToString() + else: + raise ValueError( + f'protocol={protocol} is not supported. Can be only `protobuf` or pickle protocols 0-5.' 
+            )
+        return _compress_bytes(bstr, algorithm=compress)
+
+    @classmethod
+    def from_bytes(
+        cls: Type[T],
+        data: bytes,
+        protocol: str = 'protobuf',
+        compress: Optional[str] = None,
+    ) -> T:
+        """Build Document object from binary bytes
+
+        :param data: binary bytes
+        :param protocol: protocol to use. It can be 'pickle' or 'protobuf'
+        :param compress: compress method to use
+        :return: a Document object
+        """
+        bstr = _decompress_bytes(data, algorithm=compress)
+        if protocol == 'pickle':
+            return pickle.loads(bstr)
+        elif protocol == 'protobuf':
+            from docarray.proto import DocumentProto
+
+            pb_msg = DocumentProto()
+            pb_msg.ParseFromString(bstr)
+            return cls.from_protobuf(pb_msg)
+        else:
+            raise ValueError(
+                f'protocol={protocol} is not supported. Can be only `protobuf` or pickle protocols 0-5.'
+            )
+
+    def to_base64(
+        self, protocol: str = 'protobuf', compress: Optional[str] = None
+    ) -> str:
+        """Serialize a Document object into a base64 string
+
+        :param protocol: protocol to use. It can be 'pickle' or 'protobuf'
+        :param compress: compress method to use
+        :return: a base64 encoded string
+        """
+        return base64.b64encode(self.to_bytes(protocol, compress)).decode('utf-8')
+
+    @classmethod
+    def from_base64(
+        cls: Type[T],
+        data: str,
+        protocol: str = 'pickle',
+        compress: Optional[str] = None,
+    ) -> T:
+        """Build Document object from a base64 encoded string
+
+        :param data: a base64 encoded string
+        :param protocol: protocol to use. It can be 'pickle' or 'protobuf'
+        :param compress: compress method to use
+        :return: a Document object
+        """
+        return cls.from_bytes(base64.b64decode(data), protocol, compress)
+
     @classmethod
     def from_protobuf(cls: Type[T], pb_msg: 'DocumentProto') -> T:
         """create a Document from a protobuf message

From 3d2e6563f6b842408ac39725d9aa4b8cbba74c05 Mon Sep 17 00:00:00 2001
From: samsja
Date: Fri, 17 Feb 2023 11:47:35 +0100
Subject: [PATCH 02/11] fix: fix mypy

Signed-off-by: samsja
---
 docarray/array/array.py             | 34 ++++++++++++++++-------------
 docarray/base_document/mixins/io.py | 14 ++++++++++--
 docarray/utils/find.py              |  4 ++--
 3 files changed, 33 insertions(+), 19 deletions(-)

diff --git a/docarray/array/array.py b/docarray/array/array.py
index b8bb580777c..87cecab589c 100644
--- a/docarray/array/array.py
+++ b/docarray/array/array.py
@@ -1,40 +1,39 @@
+import base64
+import io
+import json
+import os
+import pathlib
+import pickle
 from contextlib import contextmanager, nullcontext
 from functools import wraps
 from typing import (
     TYPE_CHECKING,
     Any,
+    BinaryIO,
     Callable,
+    ContextManager,
+    Generator,
     Generic,
     Iterable,
     List,
     Optional,
     Sequence,
+    Tuple,
     Type,
     TypeVar,
     Union,
     cast,
     overload,
-    BinaryIO,
-    ContextManager,
-    Tuple,
-    Generator,
 )
 
 import numpy as np
-import json
-import io
-import os
-import pickle
-import pathlib
-import base64
-
 from typing_inspect import is_union_type
 
 from docarray.array.abstract_array import AnyDocumentArray
 from docarray.base_document import AnyDocument, BaseDocument
 from docarray.typing import NdArray
-from docarray.utils.misc import is_torch_available
 from docarray.utils.compress import _decompress_bytes, _get_compress_ctx
+from docarray.utils.misc import is_torch_available
 
 if TYPE_CHECKING:
     from pydantic import BaseConfig
@@ -268,7 +267,10 @@ def __setitem__(self: T, key: IndexIterType, value: Union[T, BaseDocument]):
             return self._set_by_mask(key_norm_, value_)
         elif isinstance(head, int):
             key_norm__ = cast(Iterable[int], key_norm)
-            return self._set_by_indices(key_norm__, value)
+            value_ = cast(Sequence[BaseDocument], value)  # this is not strictly true
+            # set_by_mask requires value_ to have __getitem__, which
+            # _normalize_index_item() ensures
+            return self._set_by_indices(key_norm__, value_)
         else:
             raise TypeError(f'Invalid type {type(head)} for indexing')
 
@@ -566,6 +568,7 @@ def _write_bytes(
             f.write(pickle.dumps(self))
         elif protocol in SINGLE_PROTOCOLS:
             from rich import filesize
+
             from docarray.utils.progress_bar import _get_progressbar
 
             pbar, t = _get_progressbar(
@@ -741,6 +744,7 @@ def _load_binary_all(
         # Binary format for streaming case
         else:
             from rich import filesize
+
             from docarray.utils.progress_bar import _get_progressbar
 
             # 1 byte (uint8)
@@ -797,10 +801,10 @@ def _load_binary_stream(
         :return: a generator of `Document` objects
         """
 
-        from docarray import BaseDocument
+        from rich import filesize
 
+        from docarray import BaseDocument
         from docarray.utils.progress_bar import _get_progressbar
-        from rich import filesize
 
         with file_ctx as f:
             version_numdocs_lendoc0 = f.read(9)
diff --git a/docarray/base_document/mixins/io.py b/docarray/base_document/mixins/io.py
index 05806b35680..25c21c3ffac 100644
--- a/docarray/base_document/mixins/io.py
+++ b/docarray/base_document/mixins/io.py
@@ -1,7 +1,17 @@
 import base64
 import pickle
 from abc import abstractmethod
-from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, TypeVar
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+)
 
 from docarray.base_document.base_node import BaseNode
 from docarray.typing.proto_register import _PROTO_TYPE_NAME_TO_CLASS
@@ -16,7 +26,7 @@
 T = TypeVar('T', bound='IOMixin')
 
 
-class IOMixin:
+class IOMixin(Iterable[Tuple[str, Any]]):
     """
     IOMixin to define all the bytes/protobuf/json related part of BaseDocument
     """
diff --git a/docarray/utils/find.py b/docarray/utils/find.py
index 28898752df4..c20f183abb8 100644
--- a/docarray/utils/find.py
+++ b/docarray/utils/find.py
@@ -1,4 +1,4 @@
-from typing import List, NamedTuple, Optional, Type, Union
+from typing import List, NamedTuple, Optional, Type, Union, cast
 
 from typing_inspect import is_union_type
 
@@ -284,4 +284,4 @@ def _da_attr_type(da: AnyDocumentArray, attr: str) -> Type[AnyTensor]:
             f'but {field_type.__class__.__name__}'
         )
 
-    return field_type
+    return cast(Type[AnyTensor], field_type)

From cb30f82bd9b57982cba1169a97a96925b3d05188 Mon Sep 17 00:00:00 2001
From: samsja
Date: Fri, 17 Feb 2023 11:50:59 +0100
Subject: [PATCH 03/11] refactor: move array and array stacked to subfolder

Signed-off-by: samsja
---
 docarray/__init__.py                          |  2 +-
 docarray/array/__init__.py                    |  4 +--
 docarray/array/array/__init__.py              |  0
 docarray/array/{ => array}/array.py           |  8 +++---
 docarray/array/stacked/__init__.py            |  0
 docarray/array/{ => stacked}/array_stacked.py |  2 +-
 docarray/utils/filter.py                      | 27 ++++++++++++-------
 docarray/utils/find.py                        |  4 +--
 .../torch/data/test_torch_dataset.py          |  6 ++---
 tests/units/array/test_array_proto.py         |  2 +-
 10 files changed, 31 insertions(+), 24 deletions(-)
 create mode 100644 docarray/array/array/__init__.py
 rename docarray/array/{ => array}/array.py (99%)
 create mode 100644 docarray/array/stacked/__init__.py
 rename docarray/array/{ => stacked}/array_stacked.py (99%)

diff --git a/docarray/__init__.py b/docarray/__init__.py
index bfc0842a846..9482eae3ebf 100644
--- a/docarray/__init__.py
+++ b/docarray/__init__.py
@@ -1,6 +1,6 @@
 __version__ = '2023.01.18.alpha'
 
-from docarray.array.array import DocumentArray
+from docarray.array.array.array import DocumentArray
 from docarray.base_document.document import
BaseDocument __all__ = [ diff --git a/docarray/array/__init__.py b/docarray/array/__init__.py index 7099e10f238..1b88646ebf1 100644 --- a/docarray/array/__init__.py +++ b/docarray/array/__init__.py @@ -1,4 +1,4 @@ -from docarray.array.array import DocumentArray -from docarray.array.array_stacked import DocumentArrayStacked +from docarray.array.array.array import DocumentArray +from docarray.array.stacked.array_stacked import DocumentArrayStacked __all__ = ['DocumentArray', 'DocumentArrayStacked'] diff --git a/docarray/array/array/__init__.py b/docarray/array/array/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/docarray/array/array.py b/docarray/array/array/array.py similarity index 99% rename from docarray/array/array.py rename to docarray/array/array/array.py index 87cecab589c..5ce98624fcd 100644 --- a/docarray/array/array.py +++ b/docarray/array/array/array.py @@ -39,7 +39,7 @@ from pydantic import BaseConfig from pydantic.fields import ModelField - from docarray.array.array_stacked import DocumentArrayStacked + from docarray.array.stacked.array_stacked import DocumentArrayStacked from docarray.proto import DocumentArrayProto from docarray.typing import TorchTensor from docarray.typing.tensor.abstract_tensor import AbstractTensor @@ -448,7 +448,7 @@ def stacked_mode(self): ... """ - from docarray.array.array_stacked import DocumentArrayStacked + from docarray.array.stacked.array_stacked import DocumentArrayStacked try: da_stacked = DocumentArrayStacked.__class_getitem__(self.document_type)( @@ -465,7 +465,7 @@ def stack(self) -> 'DocumentArrayStacked': Convert the DocumentArray into a DocumentArrayStacked. `Self` cannot be used afterwards """ - from docarray.array.array_stacked import DocumentArrayStacked + from docarray.array.stacked.array_stacked import DocumentArrayStacked return DocumentArrayStacked.__class_getitem__(self.document_type)(self) @@ -476,7 +476,7 @@ def validate( field: 'ModelField', config: 'BaseConfig', ): - from docarray.array.array_stacked import DocumentArrayStacked + from docarray.array.stacked.array_stacked import DocumentArrayStacked if isinstance(value, (cls, DocumentArrayStacked)): return value diff --git a/docarray/array/stacked/__init__.py b/docarray/array/stacked/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/docarray/array/array_stacked.py b/docarray/array/stacked/array_stacked.py similarity index 99% rename from docarray/array/array_stacked.py rename to docarray/array/stacked/array_stacked.py index 95795b6e6ff..cf2bf2accab 100644 --- a/docarray/array/array_stacked.py +++ b/docarray/array/stacked/array_stacked.py @@ -15,7 +15,7 @@ ) from docarray.array.abstract_array import AnyDocumentArray -from docarray.array.array import DocumentArray +from docarray.array.array.array import DocumentArray from docarray.base_document import AnyDocument, BaseDocument from docarray.typing import NdArray from docarray.typing.tensor.abstract_tensor import AbstractTensor diff --git a/docarray/utils/filter.py b/docarray/utils/filter.py index 6d666a4a96f..97233016200 100644 --- a/docarray/utils/filter.py +++ b/docarray/utils/filter.py @@ -1,10 +1,8 @@ import json - -from typing import Union, Dict, List - +from typing import Dict, List, Union from docarray.array.abstract_array import AnyDocumentArray -from docarray.array.array import DocumentArray +from docarray.array.array.array import DocumentArray def filter( @@ -31,12 +29,21 @@ class MyDocument(BaseDocument): docs = DocumentArray[MyDocument]( - [MyDocument(caption='A 
tiger in the jungle', - image=Image(url='tigerphoto.png'), price=100), - MyDocument(caption='A swimming turtle', - image=Image(url='turtlepic.png'), price=50), - MyDocument(caption='A couple birdwatching with binoculars', - image=Image(url='binocularsphoto.png'), price=30)] + [ + MyDocument( + caption='A tiger in the jungle', + image=Image(url='tigerphoto.png'), + price=100, + ), + MyDocument( + caption='A swimming turtle', image=Image(url='turtlepic.png'), price=50 + ), + MyDocument( + caption='A couple birdwatching with binoculars', + image=Image(url='binocularsphoto.png'), + price=30, + ), + ] ) query = { '$and': { diff --git a/docarray/utils/find.py b/docarray/utils/find.py index c20f183abb8..60025ca9c19 100644 --- a/docarray/utils/find.py +++ b/docarray/utils/find.py @@ -3,8 +3,8 @@ from typing_inspect import is_union_type from docarray.array.abstract_array import AnyDocumentArray -from docarray.array.array import DocumentArray -from docarray.array.array_stacked import DocumentArrayStacked +from docarray.array.array.array import DocumentArray +from docarray.array.stacked.array_stacked import DocumentArrayStacked from docarray.base_document import BaseDocument from docarray.typing import AnyTensor from docarray.typing.tensor.abstract_tensor import AbstractTensor diff --git a/tests/integrations/torch/data/test_torch_dataset.py b/tests/integrations/torch/data/test_torch_dataset.py index d79e57f716a..c9f6c54a8fe 100644 --- a/tests/integrations/torch/data/test_torch_dataset.py +++ b/tests/integrations/torch/data/test_torch_dataset.py @@ -56,7 +56,7 @@ def test_torch_dataset(captions_da: DocumentArray[PairTextImage]): dataset, batch_size=BATCH_SIZE, collate_fn=dataset.collate_fn, shuffle=True ) - from docarray.array.array_stacked import DocumentArrayStacked + from docarray.array.stacked.array_stacked import DocumentArrayStacked batch_lens = [] for batch in loader: @@ -135,7 +135,7 @@ def test_torch_dl_multiprocessing(captions_da: DocumentArray[PairTextImage]): multiprocessing_context='fork', ) - from docarray.array.array_stacked import DocumentArrayStacked + from docarray.array.stacked.array_stacked import DocumentArrayStacked batch_lens = [] for batch in loader: @@ -163,7 +163,7 @@ def test_torch_dl_pin_memory(captions_da: DocumentArray[PairTextImage]): multiprocessing_context='fork', ) - from docarray.array.array_stacked import DocumentArrayStacked + from docarray.array.stacked.array_stacked import DocumentArrayStacked batch_lens = [] for batch in loader: diff --git a/tests/units/array/test_array_proto.py b/tests/units/array/test_array_proto.py index 062062d6f2d..dd8de4014e2 100644 --- a/tests/units/array/test_array_proto.py +++ b/tests/units/array/test_array_proto.py @@ -2,7 +2,7 @@ import pytest from docarray import BaseDocument, DocumentArray -from docarray.array.array_stacked import DocumentArrayStacked +from docarray.array.stacked.array_stacked import DocumentArrayStacked from docarray.documents import Image, Text from docarray.typing import NdArray From 27bbbd1b189e9978448813393e048e8df6ed0805 Mon Sep 17 00:00:00 2001 From: samsja Date: Fri, 17 Feb 2023 11:56:40 +0100 Subject: [PATCH 04/11] refactor: move document array io to mixin Signed-off-by: samsja --- docarray/array/array/array.py | 504 +-------------------------------- docarray/array/array/io.py | 516 ++++++++++++++++++++++++++++++++++ 2 files changed, 519 insertions(+), 501 deletions(-) create mode 100644 docarray/array/array/io.py diff --git a/docarray/array/array/array.py b/docarray/array/array/array.py index 
5ce98624fcd..2b3c16b8fb9 100644 --- a/docarray/array/array/array.py +++ b/docarray/array/array/array.py @@ -1,24 +1,15 @@ -import base64 import io -import json -import os -import pathlib -import pickle -from contextlib import contextmanager, nullcontext +from contextlib import contextmanager from functools import wraps from typing import ( TYPE_CHECKING, Any, - BinaryIO, Callable, - ContextManager, - Generator, Generic, Iterable, List, Optional, Sequence, - Tuple, Type, TypeVar, Union, @@ -30,9 +21,9 @@ from typing_inspect import is_union_type from docarray.array.abstract_array import AnyDocumentArray +from docarray.array.array.io import IOMixinArray from docarray.base_document import AnyDocument, BaseDocument from docarray.typing import NdArray -from docarray.utils.compress import _decompress_bytes, _get_compress_ctx from docarray.utils.misc import is_torch_available if TYPE_CHECKING: @@ -40,7 +31,6 @@ from pydantic.fields import ModelField from docarray.array.stacked.array_stacked import DocumentArrayStacked - from docarray.proto import DocumentArrayProto from docarray.typing import TorchTensor from docarray.typing.tensor.abstract_tensor import AbstractTensor @@ -48,42 +38,6 @@ T_doc = TypeVar('T_doc', bound=BaseDocument) IndexIterType = Union[slice, Iterable[int], Iterable[bool], None] -ARRAY_PROTOCOLS = {'protobuf-array', 'pickle-array'} -SINGLE_PROTOCOLS = {'pickle', 'protobuf'} -ALLOWED_PROTOCOLS = ARRAY_PROTOCOLS.union(SINGLE_PROTOCOLS) -ALLOWED_COMPRESSIONS = {'lz4', 'bz2', 'lzma', 'zlib', 'gzip'} - - -def _protocol_and_compress_from_file_path( - file_path: Union[pathlib.Path, str], - default_protocol: Optional[str] = None, - default_compress: Optional[str] = None, -) -> Tuple[Optional[str], Optional[str]]: - """Extract protocol and compression algorithm from a string, use defaults if not found. - :param file_path: path of a file. - :param default_protocol: default serialization protocol used in case not found. - :param default_compress: default compression method used in case not found. - Examples: - >>> _protocol_and_compress_from_file_path('./docarray_fashion_mnist.protobuf.gzip') - ('protobuf', 'gzip') - >>> _protocol_and_compress_from_file_path('/Documents/docarray_fashion_mnist.protobuf') - ('protobuf', None) - >>> _protocol_and_compress_from_file_path('/Documents/docarray_fashion_mnist.gzip') - (None, gzip) - """ - - protocol = default_protocol - compress = default_compress - - file_extensions = [e.replace('.', '') for e in pathlib.Path(file_path).suffixes] - for extension in file_extensions: - if extension in ALLOWED_PROTOCOLS: - protocol = extension - elif extension in ALLOWED_COMPRESSIONS: - compress = extension - - return protocol, compress - def _delegate_meth_to_data(meth_name: str) -> Callable: """ @@ -113,21 +67,7 @@ def _is_np_int(item: Any) -> bool: return False # this is unreachable, but mypy wants it -class _LazyRequestReader: - def __init__(self, r): - self._data = r.iter_content(chunk_size=1024 * 1024) - self.content = b'' - - def __getitem__(self, item: slice): - while len(self.content) < item.stop: - try: - self.content += next(self._data) - except StopIteration: - return self.content[item.start : -1 : item.step] - return self.content[item] - - -class DocumentArray(AnyDocumentArray, Generic[T_doc]): +class DocumentArray(IOMixinArray, AnyDocumentArray, Generic[T_doc]): """ DocumentArray is a container of Documents. 
@@ -493,441 +433,3 @@ def traverse_flat( flattened = AnyDocumentArray._flatten_one_level(nodes) return flattened - - # Methods to load from/to different formats - - @classmethod - def from_protobuf(cls: Type[T], pb_msg: 'DocumentArrayProto') -> T: - """create a Document from a protobuf message - :param pb_msg: The protobuf message from where to construct the DocumentArray - """ - return cls( - cls.document_type.from_protobuf(doc_proto) for doc_proto in pb_msg.docs - ) - - def to_protobuf(self) -> 'DocumentArrayProto': - """Convert DocumentArray into a Protobuf message""" - from docarray.proto import DocumentArrayProto - - da_proto = DocumentArrayProto() - for doc in self: - da_proto.docs.append(doc.to_protobuf()) - - return da_proto - - @classmethod - def from_bytes( - cls: Type[T], - data: bytes, - protocol: str = 'protobuf-array', - compress: Optional[str] = None, - show_progress: bool = False, - ) -> T: - """Deserialize bytes into a DocumentArray. - - :param data: Bytes from which to deserialize - :param protocol: protocol that was used to serialize - :param compress: compress algorithm that was used to serialize - :param show_progress: show progress bar, only works when protocol is `pickle` or `protobuf` - :return: the deserialized DocumentArray - """ - return cls._load_binary_all( - file_ctx=nullcontext(data), - protocol=protocol, - compress=compress, - show_progress=show_progress, - ) - - def _write_bytes( - self, - bf: BinaryIO, - protocol: str = 'protobuf-array', - compress: Optional[str] = None, - show_progress: bool = False, - ) -> None: - if protocol in ARRAY_PROTOCOLS: - compress_ctx = _get_compress_ctx(compress) - else: - # delegate the compression to per-doc compression - compress_ctx = None - - fc: ContextManager - if compress_ctx is None: - # if compress do not support streaming then postpone the compress - # into the for-loop - f, fc = bf, nullcontext() - else: - f = compress_ctx(bf) - fc = f - compress = None - - with fc: - if protocol == 'protobuf-array': - f.write(self.to_protobuf().SerializePartialToString()) - elif protocol == 'pickle-array': - f.write(pickle.dumps(self)) - elif protocol in SINGLE_PROTOCOLS: - from rich import filesize - - from docarray.utils.progress_bar import _get_progressbar - - pbar, t = _get_progressbar( - 'Serializing', disable=not show_progress, total=len(self) - ) - - f.write(self._stream_header) - - with pbar: - _total_size = 0 - pbar.start_task(t) - for doc in self: - doc_bytes = doc.to_bytes(protocol=protocol, compress=compress) - len_doc_as_bytes = len(doc_bytes).to_bytes( - 4, 'big', signed=False - ) - all_bytes = len_doc_as_bytes + doc_bytes - f.write(all_bytes) - _total_size += len(all_bytes) - pbar.update( - t, - advance=1, - total_size=str(filesize.decimal(_total_size)), - ) - else: - raise ValueError( - f'protocol={protocol} is not supported. Can be only {ALLOWED_PROTOCOLS}.' - ) - - def to_bytes( - self, - protocol: str = 'protobuf-array', - compress: Optional[str] = None, - file_ctx: Optional[BinaryIO] = None, - show_progress: bool = False, - ) -> Optional[bytes]: - """Serialize itself into bytes. - - For more Pythonic code, please use ``bytes(...)``. - - :param protocol: protocol to use. It can be 'pickle-array', 'protobuf-array', 'pickle' or 'protobuf' - :param compress: compress algorithm to use - :param file_ctx: File or filename or serialized bytes where the data is stored. 
- :param show_progress: show progress bar, only works when protocol is `pickle` or `protobuf` - :return: the binary serialization in bytes or None if file_ctx is passed where to store - """ - - with (file_ctx or io.BytesIO()) as bf: - self._write_bytes( - bf=bf, - protocol=protocol, - compress=compress, - show_progress=show_progress, - ) - if isinstance(bf, io.BytesIO): - return bf.getvalue() - - return None - - @classmethod - def from_base64( - cls: Type[T], - data: str, - protocol: str = 'protobuf-array', - compress: Optional[str] = None, - show_progress: bool = False, - ) -> T: - """Deserialize base64 strings into a DocumentArray. - - :param data: Base64 string to deserialize - :param protocol: protocol that was used to serialize - :param compress: compress algorithm that was used to serialize - :param show_progress: show progress bar, only works when protocol is `pickle` or `protobuf` - :return: the deserialized DocumentArray - """ - return cls._load_binary_all( - file_ctx=nullcontext(base64.b64decode(data)), - protocol=protocol, - compress=compress, - show_progress=show_progress, - ) - - def to_base64( - self, - protocol: str = 'protobuf-array', - compress: Optional[str] = None, - show_progress: bool = False, - ) -> str: - """Serialize itself into base64 encoded string. - - :param protocol: protocol to use. It can be 'pickle-array', 'protobuf-array', 'pickle' or 'protobuf' - :param compress: compress algorithm to use - :param show_progress: show progress bar, only works when protocol is `pickle` or `protobuf` - :return: the binary serialization in bytes or None if file_ctx is passed where to store - """ - with io.BytesIO() as bf: - self._write_bytes( - bf=bf, - compress=compress, - protocol=protocol, - show_progress=show_progress, - ) - return base64.b64encode(bf.getvalue()).decode('utf-8') - - @classmethod - def from_json( - cls: Type[T], - file: Union[str, bytes, bytearray], - ) -> T: - """Deserialize JSON strings or bytes into a DocumentArray. - - :param file: JSON object from where to deserialize a DocumentArray - :return: the deserialized DocumentArray - """ - json_docs = json.loads(file) - return cls([cls.document_type.parse_raw(v) for v in json_docs]) - - def to_json(self) -> str: - """Convert the object into a JSON string. Can be loaded via :meth:`.from_json`. - :return: JSON serialization of DocumentArray - """ - return json.dumps([doc.json() for doc in self]) - - # Methods to load from/to files in different formats - @property - def _stream_header(self) -> bytes: - # Binary format for streaming case - - # V1 DocArray streaming serialization format - # | 1 byte | 8 bytes | 4 bytes | variable | 4 bytes | variable ... - - # 1 byte (uint8) - version_byte = b'\x01' - # 8 bytes (uint64) - num_docs_as_bytes = len(self).to_bytes(8, 'big', signed=False) - return version_byte + num_docs_as_bytes - - @classmethod - def _load_binary_all( - cls: Type[T], - file_ctx: Union[ContextManager[io.BufferedReader], ContextManager[bytes]], - protocol: Optional[str], - compress: Optional[str], - show_progress: bool, - ): - """Read a `DocumentArray` object from a binary file - :param protocol: protocol to use. 
It can be 'pickle-array', 'protobuf-array', 'pickle' or 'protobuf' - :param compress: compress algorithm to use - :param show_progress: show progress bar, only works when protocol is `pickle` or `protobuf` - :return: a `DocumentArray` - """ - with file_ctx as fp: - if isinstance(fp, bytes): - d = fp - else: - d = fp.read() - - if protocol is not None and protocol in ('pickle-array', 'protobuf-array'): - if _get_compress_ctx(algorithm=compress) is not None: - d = _decompress_bytes(d, algorithm=compress) - compress = None - - if protocol is not None and protocol == 'protobuf-array': - from docarray.proto import DocumentArrayProto - - dap = DocumentArrayProto() - dap.ParseFromString(d) - - return cls.from_protobuf(dap) - elif protocol is not None and protocol == 'pickle-array': - return pickle.loads(d) - - # Binary format for streaming case - else: - from rich import filesize - - from docarray.utils.progress_bar import _get_progressbar - - # 1 byte (uint8) - # 8 bytes (uint64) - num_docs = int.from_bytes(d[1:9], 'big', signed=False) - - pbar, t = _get_progressbar( - 'Deserializing', disable=not show_progress, total=num_docs - ) - - # this 9 is version + num_docs bytes used - start_pos = 9 - docs = [] - with pbar: - _total_size = 0 - pbar.start_task(t) - - for _ in range(num_docs): - # 4 bytes (uint32) - len_current_doc_in_bytes = int.from_bytes( - d[start_pos : start_pos + 4], 'big', signed=False - ) - start_doc_pos = start_pos + 4 - end_doc_pos = start_doc_pos + len_current_doc_in_bytes - start_pos = end_doc_pos - - # variable length bytes doc - load_protocol: str = protocol or 'protobuf' - doc = cls.document_type.from_bytes( - d[start_doc_pos:end_doc_pos], - protocol=load_protocol, - compress=compress, - ) - docs.append(doc) - _total_size += len_current_doc_in_bytes - pbar.update( - t, advance=1, total_size=str(filesize.decimal(_total_size)) - ) - return cls(docs) - - @classmethod - def _load_binary_stream( - cls: Type[T], - file_ctx: ContextManager[io.BufferedReader], - protocol: Optional[str] = None, - compress: Optional[str] = None, - show_progress: bool = False, - ) -> Generator['BaseDocument', None, None]: - """Yield `Document` objects from a binary file - - :param protocol: protocol to use. 
It can be 'pickle' or 'protobuf' - :param compress: compress algorithm to use - :param show_progress: show progress bar, only works when protocol is `pickle` or `protobuf` - :return: a generator of `Document` objects - """ - - from rich import filesize - - from docarray import BaseDocument - from docarray.utils.progress_bar import _get_progressbar - - with file_ctx as f: - version_numdocs_lendoc0 = f.read(9) - # 1 byte (uint8) - # 8 bytes (uint64) - num_docs = int.from_bytes(version_numdocs_lendoc0[1:9], 'big', signed=False) - - pbar, t = _get_progressbar( - 'Deserializing', disable=not show_progress, total=num_docs - ) - - with pbar: - _total_size = 0 - pbar.start_task(t) - for _ in range(num_docs): - # 4 bytes (uint32) - len_current_doc_in_bytes = int.from_bytes( - f.read(4), 'big', signed=False - ) - _total_size += len_current_doc_in_bytes - load_protocol: str = protocol or 'protobuf' - yield BaseDocument.from_bytes( - f.read(len_current_doc_in_bytes), - protocol=load_protocol, - compress=compress, - ) - pbar.update( - t, advance=1, total_size=str(filesize.decimal(_total_size)) - ) - - @classmethod - def load_binary( - cls: Type[T], - file: Union[str, bytes, pathlib.Path, io.BufferedReader, _LazyRequestReader], - protocol: str = 'protobuf-array', - compress: Optional[str] = None, - show_progress: bool = False, - streaming: bool = False, - ) -> Union[T, Generator['BaseDocument', None, None]]: - """Load array elements from a compressed binary file. - - :param file: File or filename or serialized bytes where the data is stored. - :param protocol: protocol to use. It can be 'pickle-array', 'protobuf-array', 'pickle' or 'protobuf' - :param compress: compress algorithm to use - :param show_progress: show progress bar, only works when protocol is `pickle` or `protobuf` - :param streaming: if `True` returns a generator over `Document` objects. - In case protocol is pickle the `Documents` are streamed from disk to save memory usage - :return: a DocumentArray object - - .. note:: - If `file` is `str` it can specify `protocol` and `compress` as file extensions. - This functionality assumes `file=file_name.$protocol.$compress` where `$protocol` and `$compress` refer to a - string interpolation of the respective `protocol` and `compress` methods. - For example if `file=my_docarray.protobuf.lz4` then the binary data will be loaded assuming `protocol=protobuf` - and `compress=lz4`. - """ - load_protocol: Optional[str] = protocol - load_compress: Optional[str] = compress - file_ctx: Union[nullcontext, io.BufferedReader] - if isinstance(file, (io.BufferedReader, _LazyRequestReader, bytes)): - file_ctx = nullcontext(file) - # by checking path existence we allow file to be of type Path, LocalPath, PurePath and str - elif isinstance(file, (str, pathlib.Path)) and os.path.exists(file): - load_protocol, load_compress = _protocol_and_compress_from_file_path( - file, protocol, compress - ) - file_ctx = open(file, 'rb') - else: - raise FileNotFoundError(f'cannot find file {file}') - if streaming: - return cls._load_binary_stream( - file_ctx, - protocol=load_protocol, - compress=load_compress, - show_progress=show_progress, - ) - else: - return cls._load_binary_all( - file_ctx, load_protocol, load_compress, show_progress - ) - - def save_binary( - self, - file: Union[str, pathlib.Path], - protocol: str = 'protobuf-array', - compress: Optional[str] = None, - show_progress: bool = False, - ) -> None: - """Save DocumentArray into a binary file. - - It will use the protocol to pick how to save the DocumentArray. 
- If used 'picke-array` and `protobuf-array` the DocumentArray will be stored - and compressed at complete level using `pickle` or `protobuf`. - When using `protobuf` or `pickle` as protocol each Document in DocumentArray - will be stored individually and this would make it available for streaming. - - :param file: File or filename to which the data is saved. - :param protocol: protocol to use. It can be 'pickle-array', 'protobuf-array', 'pickle' or 'protobuf' - :param compress: compress algorithm to use - :param show_progress: show progress bar, only works when protocol is `pickle` or `protobuf` - - .. note:: - If `file` is `str` it can specify `protocol` and `compress` as file extensions. - This functionality assumes `file=file_name.$protocol.$compress` where `$protocol` and `$compress` refer to a - string interpolation of the respective `protocol` and `compress` methods. - For example if `file=my_docarray.protobuf.lz4` then the binary data will be created using `protocol=protobuf` - and `compress=lz4`. - """ - if isinstance(file, io.BufferedWriter): - file_ctx = nullcontext(file) - else: - _protocol, _compress = _protocol_and_compress_from_file_path(file) - - if _protocol is not None: - protocol = _protocol - if _compress is not None: - compress = _compress - - file_ctx = open(file, 'wb') - - self.to_bytes( - protocol=protocol, - compress=compress, - file_ctx=file_ctx, - show_progress=show_progress, - ) diff --git a/docarray/array/array/io.py b/docarray/array/array/io.py new file mode 100644 index 00000000000..bb3e5aff991 --- /dev/null +++ b/docarray/array/array/io.py @@ -0,0 +1,516 @@ +import base64 +import io +import json +import os +import pathlib +import pickle +from contextlib import nullcontext +from typing import ( + TYPE_CHECKING, + BinaryIO, + ContextManager, + Generator, + Optional, + Tuple, + Type, + TypeVar, + Union, +) + +from docarray.base_document import BaseDocument +from docarray.utils.compress import _decompress_bytes, _get_compress_ctx + +if TYPE_CHECKING: + + from docarray.proto import DocumentArrayProto + +T = TypeVar('T', bound='IOMixinArray') + + +ARRAY_PROTOCOLS = {'protobuf-array', 'pickle-array'} +SINGLE_PROTOCOLS = {'pickle', 'protobuf'} +ALLOWED_PROTOCOLS = ARRAY_PROTOCOLS.union(SINGLE_PROTOCOLS) +ALLOWED_COMPRESSIONS = {'lz4', 'bz2', 'lzma', 'zlib', 'gzip'} + + +def _protocol_and_compress_from_file_path( + file_path: Union[pathlib.Path, str], + default_protocol: Optional[str] = None, + default_compress: Optional[str] = None, +) -> Tuple[Optional[str], Optional[str]]: + """Extract protocol and compression algorithm from a string, use defaults if not found. + :param file_path: path of a file. + :param default_protocol: default serialization protocol used in case not found. + :param default_compress: default compression method used in case not found. 
+    Examples:
+        >>> _protocol_and_compress_from_file_path('./docarray_fashion_mnist.protobuf.gzip')
+        ('protobuf', 'gzip')
+        >>> _protocol_and_compress_from_file_path('/Documents/docarray_fashion_mnist.protobuf')
+        ('protobuf', None)
+        >>> _protocol_and_compress_from_file_path('/Documents/docarray_fashion_mnist.gzip')
+        (None, 'gzip')
+    """
+
+    protocol = default_protocol
+    compress = default_compress
+
+    file_extensions = [e.replace('.', '') for e in pathlib.Path(file_path).suffixes]
+    for extension in file_extensions:
+        if extension in ALLOWED_PROTOCOLS:
+            protocol = extension
+        elif extension in ALLOWED_COMPRESSIONS:
+            compress = extension
+
+    return protocol, compress
+
+
+class _LazyRequestReader:
+    def __init__(self, r):
+        self._data = r.iter_content(chunk_size=1024 * 1024)
+        self.content = b''
+
+    def __getitem__(self, item: slice):
+        while len(self.content) < item.stop:
+            try:
+                self.content += next(self._data)
+            except StopIteration:
+                return self.content[item.start : -1 : item.step]
+        return self.content[item]
+
+
+class IOMixinArray:
+    @classmethod
+    def from_protobuf(cls: Type[T], pb_msg: 'DocumentArrayProto') -> T:
+        """create a DocumentArray from a protobuf message
+        :param pb_msg: The protobuf message from where to construct the DocumentArray
+        """
+        return cls(
+            cls.document_type.from_protobuf(doc_proto) for doc_proto in pb_msg.docs
+        )
+
+    def to_protobuf(self) -> 'DocumentArrayProto':
+        """Convert DocumentArray into a Protobuf message"""
+        from docarray.proto import DocumentArrayProto
+
+        da_proto = DocumentArrayProto()
+        for doc in self:
+            da_proto.docs.append(doc.to_protobuf())
+
+        return da_proto
+
+    @classmethod
+    def from_bytes(
+        cls: Type[T],
+        data: bytes,
+        protocol: str = 'protobuf-array',
+        compress: Optional[str] = None,
+        show_progress: bool = False,
+    ) -> T:
+        """Deserialize bytes into a DocumentArray.
+ + :param data: Bytes from which to deserialize + :param protocol: protocol that was used to serialize + :param compress: compress algorithm that was used to serialize + :param show_progress: show progress bar, only works when protocol is `pickle` or `protobuf` + :return: the deserialized DocumentArray + """ + return cls._load_binary_all( + file_ctx=nullcontext(data), + protocol=protocol, + compress=compress, + show_progress=show_progress, + ) + + def _write_bytes( + self, + bf: BinaryIO, + protocol: str = 'protobuf-array', + compress: Optional[str] = None, + show_progress: bool = False, + ) -> None: + if protocol in ARRAY_PROTOCOLS: + compress_ctx = _get_compress_ctx(compress) + else: + # delegate the compression to per-doc compression + compress_ctx = None + + fc: ContextManager + if compress_ctx is None: + # if compress do not support streaming then postpone the compress + # into the for-loop + f, fc = bf, nullcontext() + else: + f = compress_ctx(bf) + fc = f + compress = None + + with fc: + if protocol == 'protobuf-array': + f.write(self.to_protobuf().SerializePartialToString()) + elif protocol == 'pickle-array': + f.write(pickle.dumps(self)) + elif protocol in SINGLE_PROTOCOLS: + from rich import filesize + + from docarray.utils.progress_bar import _get_progressbar + + pbar, t = _get_progressbar( + 'Serializing', disable=not show_progress, total=len(self) + ) + + f.write(self._stream_header) + + with pbar: + _total_size = 0 + pbar.start_task(t) + for doc in self: + doc_bytes = doc.to_bytes(protocol=protocol, compress=compress) + len_doc_as_bytes = len(doc_bytes).to_bytes( + 4, 'big', signed=False + ) + all_bytes = len_doc_as_bytes + doc_bytes + f.write(all_bytes) + _total_size += len(all_bytes) + pbar.update( + t, + advance=1, + total_size=str(filesize.decimal(_total_size)), + ) + else: + raise ValueError( + f'protocol={protocol} is not supported. Can be only {ALLOWED_PROTOCOLS}.' + ) + + def to_bytes( + self, + protocol: str = 'protobuf-array', + compress: Optional[str] = None, + file_ctx: Optional[BinaryIO] = None, + show_progress: bool = False, + ) -> Optional[bytes]: + """Serialize itself into bytes. + + For more Pythonic code, please use ``bytes(...)``. + + :param protocol: protocol to use. It can be 'pickle-array', 'protobuf-array', 'pickle' or 'protobuf' + :param compress: compress algorithm to use + :param file_ctx: File or filename or serialized bytes where the data is stored. + :param show_progress: show progress bar, only works when protocol is `pickle` or `protobuf` + :return: the binary serialization in bytes or None if file_ctx is passed where to store + """ + + with (file_ctx or io.BytesIO()) as bf: + self._write_bytes( + bf=bf, + protocol=protocol, + compress=compress, + show_progress=show_progress, + ) + if isinstance(bf, io.BytesIO): + return bf.getvalue() + + return None + + @classmethod + def from_base64( + cls: Type[T], + data: str, + protocol: str = 'protobuf-array', + compress: Optional[str] = None, + show_progress: bool = False, + ) -> T: + """Deserialize base64 strings into a DocumentArray. 
+ + :param data: Base64 string to deserialize + :param protocol: protocol that was used to serialize + :param compress: compress algorithm that was used to serialize + :param show_progress: show progress bar, only works when protocol is `pickle` or `protobuf` + :return: the deserialized DocumentArray + """ + return cls._load_binary_all( + file_ctx=nullcontext(base64.b64decode(data)), + protocol=protocol, + compress=compress, + show_progress=show_progress, + ) + + def to_base64( + self, + protocol: str = 'protobuf-array', + compress: Optional[str] = None, + show_progress: bool = False, + ) -> str: + """Serialize itself into base64 encoded string. + + :param protocol: protocol to use. It can be 'pickle-array', 'protobuf-array', 'pickle' or 'protobuf' + :param compress: compress algorithm to use + :param show_progress: show progress bar, only works when protocol is `pickle` or `protobuf` + :return: the binary serialization in bytes or None if file_ctx is passed where to store + """ + with io.BytesIO() as bf: + self._write_bytes( + bf=bf, + compress=compress, + protocol=protocol, + show_progress=show_progress, + ) + return base64.b64encode(bf.getvalue()).decode('utf-8') + + @classmethod + def from_json( + cls: Type[T], + file: Union[str, bytes, bytearray], + ) -> T: + """Deserialize JSON strings or bytes into a DocumentArray. + + :param file: JSON object from where to deserialize a DocumentArray + :return: the deserialized DocumentArray + """ + json_docs = json.loads(file) + return cls([cls.document_type.parse_raw(v) for v in json_docs]) + + def to_json(self) -> str: + """Convert the object into a JSON string. Can be loaded via :meth:`.from_json`. + :return: JSON serialization of DocumentArray + """ + return json.dumps([doc.json() for doc in self]) + + # Methods to load from/to files in different formats + @property + def _stream_header(self) -> bytes: + # Binary format for streaming case + + # V1 DocArray streaming serialization format + # | 1 byte | 8 bytes | 4 bytes | variable | 4 bytes | variable ... + + # 1 byte (uint8) + version_byte = b'\x01' + # 8 bytes (uint64) + num_docs_as_bytes = len(self).to_bytes(8, 'big', signed=False) + return version_byte + num_docs_as_bytes + + @classmethod + def _load_binary_all( + cls: Type[T], + file_ctx: Union[ContextManager[io.BufferedReader], ContextManager[bytes]], + protocol: Optional[str], + compress: Optional[str], + show_progress: bool, + ): + """Read a `DocumentArray` object from a binary file + :param protocol: protocol to use. 
It can be 'pickle-array', 'protobuf-array', 'pickle' or 'protobuf' + :param compress: compress algorithm to use + :param show_progress: show progress bar, only works when protocol is `pickle` or `protobuf` + :return: a `DocumentArray` + """ + with file_ctx as fp: + if isinstance(fp, bytes): + d = fp + else: + d = fp.read() + + if protocol is not None and protocol in ('pickle-array', 'protobuf-array'): + if _get_compress_ctx(algorithm=compress) is not None: + d = _decompress_bytes(d, algorithm=compress) + compress = None + + if protocol is not None and protocol == 'protobuf-array': + from docarray.proto import DocumentArrayProto + + dap = DocumentArrayProto() + dap.ParseFromString(d) + + return cls.from_protobuf(dap) + elif protocol is not None and protocol == 'pickle-array': + return pickle.loads(d) + + # Binary format for streaming case + else: + from rich import filesize + + from docarray.utils.progress_bar import _get_progressbar + + # 1 byte (uint8) + # 8 bytes (uint64) + num_docs = int.from_bytes(d[1:9], 'big', signed=False) + + pbar, t = _get_progressbar( + 'Deserializing', disable=not show_progress, total=num_docs + ) + + # this 9 is version + num_docs bytes used + start_pos = 9 + docs = [] + with pbar: + _total_size = 0 + pbar.start_task(t) + + for _ in range(num_docs): + # 4 bytes (uint32) + len_current_doc_in_bytes = int.from_bytes( + d[start_pos : start_pos + 4], 'big', signed=False + ) + start_doc_pos = start_pos + 4 + end_doc_pos = start_doc_pos + len_current_doc_in_bytes + start_pos = end_doc_pos + + # variable length bytes doc + load_protocol: str = protocol or 'protobuf' + doc = cls.document_type.from_bytes( + d[start_doc_pos:end_doc_pos], + protocol=load_protocol, + compress=compress, + ) + docs.append(doc) + _total_size += len_current_doc_in_bytes + pbar.update( + t, advance=1, total_size=str(filesize.decimal(_total_size)) + ) + return cls(docs) + + @classmethod + def _load_binary_stream( + cls: Type[T], + file_ctx: ContextManager[io.BufferedReader], + protocol: Optional[str] = None, + compress: Optional[str] = None, + show_progress: bool = False, + ) -> Generator['BaseDocument', None, None]: + """Yield `Document` objects from a binary file + + :param protocol: protocol to use. 
It can be 'pickle' or 'protobuf' + :param compress: compress algorithm to use + :param show_progress: show progress bar, only works when protocol is `pickle` or `protobuf` + :return: a generator of `Document` objects + """ + + from rich import filesize + + from docarray import BaseDocument + from docarray.utils.progress_bar import _get_progressbar + + with file_ctx as f: + version_numdocs_lendoc0 = f.read(9) + # 1 byte (uint8) + # 8 bytes (uint64) + num_docs = int.from_bytes(version_numdocs_lendoc0[1:9], 'big', signed=False) + + pbar, t = _get_progressbar( + 'Deserializing', disable=not show_progress, total=num_docs + ) + + with pbar: + _total_size = 0 + pbar.start_task(t) + for _ in range(num_docs): + # 4 bytes (uint32) + len_current_doc_in_bytes = int.from_bytes( + f.read(4), 'big', signed=False + ) + _total_size += len_current_doc_in_bytes + load_protocol: str = protocol or 'protobuf' + yield BaseDocument.from_bytes( + f.read(len_current_doc_in_bytes), + protocol=load_protocol, + compress=compress, + ) + pbar.update( + t, advance=1, total_size=str(filesize.decimal(_total_size)) + ) + + @classmethod + def load_binary( + cls: Type[T], + file: Union[str, bytes, pathlib.Path, io.BufferedReader, _LazyRequestReader], + protocol: str = 'protobuf-array', + compress: Optional[str] = None, + show_progress: bool = False, + streaming: bool = False, + ) -> Union[T, Generator['BaseDocument', None, None]]: + """Load array elements from a compressed binary file. + + :param file: File or filename or serialized bytes where the data is stored. + :param protocol: protocol to use. It can be 'pickle-array', 'protobuf-array', 'pickle' or 'protobuf' + :param compress: compress algorithm to use + :param show_progress: show progress bar, only works when protocol is `pickle` or `protobuf` + :param streaming: if `True` returns a generator over `Document` objects. + In case protocol is pickle the `Documents` are streamed from disk to save memory usage + :return: a DocumentArray object + + .. note:: + If `file` is `str` it can specify `protocol` and `compress` as file extensions. + This functionality assumes `file=file_name.$protocol.$compress` where `$protocol` and `$compress` refer to a + string interpolation of the respective `protocol` and `compress` methods. + For example if `file=my_docarray.protobuf.lz4` then the binary data will be loaded assuming `protocol=protobuf` + and `compress=lz4`. + """ + load_protocol: Optional[str] = protocol + load_compress: Optional[str] = compress + file_ctx: Union[nullcontext, io.BufferedReader] + if isinstance(file, (io.BufferedReader, _LazyRequestReader, bytes)): + file_ctx = nullcontext(file) + # by checking path existence we allow file to be of type Path, LocalPath, PurePath and str + elif isinstance(file, (str, pathlib.Path)) and os.path.exists(file): + load_protocol, load_compress = _protocol_and_compress_from_file_path( + file, protocol, compress + ) + file_ctx = open(file, 'rb') + else: + raise FileNotFoundError(f'cannot find file {file}') + if streaming: + return cls._load_binary_stream( + file_ctx, + protocol=load_protocol, + compress=load_compress, + show_progress=show_progress, + ) + else: + return cls._load_binary_all( + file_ctx, load_protocol, load_compress, show_progress + ) + + def save_binary( + self, + file: Union[str, pathlib.Path], + protocol: str = 'protobuf-array', + compress: Optional[str] = None, + show_progress: bool = False, + ) -> None: + """Save DocumentArray into a binary file. + + It will use the protocol to pick how to save the DocumentArray. 
+        If using `pickle-array` or `protobuf-array` as protocol, the whole DocumentArray
+        will be stored and compressed as a single object, using `pickle` or `protobuf`.
+        When using `protobuf` or `pickle` as protocol, each Document in the DocumentArray
+        will be stored individually, which makes it available for streaming.
+
+        :param file: File or filename to which the data is saved.
+        :param protocol: protocol to use. It can be 'pickle-array', 'protobuf-array', 'pickle' or 'protobuf'
+        :param compress: compress algorithm to use
+        :param show_progress: show progress bar, only works when protocol is `pickle` or `protobuf`
+
+        .. note::
+            If `file` is `str` it can specify `protocol` and `compress` as file extensions.
+            This functionality assumes `file=file_name.$protocol.$compress` where `$protocol` and `$compress` refer to a
+            string interpolation of the respective `protocol` and `compress` methods.
+            For example if `file=my_docarray.protobuf.lz4` then the binary data will be created using `protocol=protobuf`
+            and `compress=lz4`.
+        """
+        if isinstance(file, io.BufferedWriter):
+            file_ctx = nullcontext(file)
+        else:
+            _protocol, _compress = _protocol_and_compress_from_file_path(file)
+
+            if _protocol is not None:
+                protocol = _protocol
+            if _compress is not None:
+                compress = _compress
+
+            file_ctx = open(file, 'wb')
+
+        self.to_bytes(
+            protocol=protocol,
+            compress=compress,
+            file_ctx=file_ctx,
+            show_progress=show_progress,
+        )

From 5e324cf5156b28f616f9d8234823d0b0eacaa417 Mon Sep 17 00:00:00 2001
From: samsja
Date: Fri, 17 Feb 2023 13:26:36 +0100
Subject: [PATCH 05/11] refactor: create io mixin for document array

Signed-off-by: samsja
---
 docarray/array/array/array.py | 8 ++++++++
 docarray/array/array/io.py    | 15 ++++++++++++++-
 2 files changed, 22 insertions(+), 1 deletion(-)

diff --git a/docarray/array/array/array.py b/docarray/array/array/array.py
index 2b3c16b8fb9..f39bae21efa 100644
--- a/docarray/array/array/array.py
+++ b/docarray/array/array/array.py
@@ -31,6 +31,7 @@
     from pydantic.fields import ModelField
 
     from docarray.array.stacked.array_stacked import DocumentArrayStacked
+    from docarray.proto import DocumentArrayProto
     from docarray.typing import TorchTensor
     from docarray.typing.tensor.abstract_tensor import AbstractTensor
 
@@ -433,3 +434,10 @@ def traverse_flat(
         flattened = AnyDocumentArray._flatten_one_level(nodes)
 
         return flattened
+
+    @classmethod
+    def from_protobuf(cls: Type[T], pb_msg: 'DocumentArrayProto') -> T:
+        """create a DocumentArray from a protobuf message
+        :param pb_msg: The protobuf message from where to construct the DocumentArray
+        """
+        return super().from_protobuf(pb_msg)
diff --git a/docarray/array/array/io.py b/docarray/array/array/io.py
index bb3e5aff991..490d30c584d 100644
--- a/docarray/array/array/io.py
+++ b/docarray/array/array/io.py
@@ -4,13 +4,16 @@
 import os
 import pathlib
 import pickle
+from abc import abstractmethod
 from contextlib import nullcontext
 from typing import (
     TYPE_CHECKING,
     BinaryIO,
     ContextManager,
     Generator,
+    Iterable,
     Optional,
+    Sized,
     Tuple,
     Type,
     TypeVar,
     Union,
 )
@@ -78,7 +81,17 @@ def __getitem__(self, item: slice):
         return self.content[item]
 
 
-class IOMixinArray:
+class IOMixinArray(Iterable[BaseDocument], Sized):
+
+    document_type: Type[BaseDocument]
+
+    @abstractmethod
+    def __init__(
+        self,
+        docs: Optional[Iterable[BaseDocument]] = None,
+    ):
+        ...
+ @classmethod def from_protobuf(cls: Type[T], pb_msg: 'DocumentArrayProto') -> T: """create a Document from a protobuf message From c3482b8538eada9c8b9c152067d792e5bba1a37e Mon Sep 17 00:00:00 2001 From: samsja Date: Fri, 17 Feb 2023 14:30:08 +0100 Subject: [PATCH 06/11] fix: fix sized problem Signed-off-by: samsja --- docarray/array/array/io.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/docarray/array/array/io.py b/docarray/array/array/io.py index 490d30c584d..b5a846102d2 100644 --- a/docarray/array/array/io.py +++ b/docarray/array/array/io.py @@ -13,7 +13,6 @@ Generator, Iterable, Optional, - Sized, Tuple, Type, TypeVar, @@ -81,10 +80,14 @@ def __getitem__(self, item: slice): return self.content[item] -class IOMixinArray(Iterable[BaseDocument], Sized): +class IOMixinArray(Iterable[BaseDocument]): document_type: Type[BaseDocument] + @abstractmethod + def __len__(self): + ... + @abstractmethod def __init__( self, From 990f4ff340cece4c8fe3ec4c9dd49f10953976f6 Mon Sep 17 00:00:00 2001 From: samsja Date: Fri, 17 Feb 2023 15:00:55 +0100 Subject: [PATCH 07/11] fix: make any array generic for type var --- docarray/array/abstract_array.py | 2 ++ tests/units/array/test_array.py | 12 +++++++++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/docarray/array/abstract_array.py b/docarray/array/abstract_array.py index b0a4b6db70c..828c5e1085b 100644 --- a/docarray/array/abstract_array.py +++ b/docarray/array/abstract_array.py @@ -35,6 +35,8 @@ def __repr__(self): @classmethod def __class_getitem__(cls, item: Type[BaseDocument]): + if isinstance(item, TypeVar): + return super().__class_getitem__(item) if not issubclass(item, BaseDocument): raise ValueError( f'{cls.__name__}[item] item should be a Document not a {item} ' diff --git a/tests/units/array/test_array.py b/tests/units/array/test_array.py index d69514caab7..ec15f21afbc 100644 --- a/tests/units/array/test_array.py +++ b/tests/units/array/test_array.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Optional, TypeVar, Union import numpy as np import pytest @@ -294,3 +294,13 @@ def test_del_item(da): 'hello 8', 'hello 9', ] + + +def test_generic_type_var(): + T = TypeVar('T', bound=BaseDocument) + + def f(a: DocumentArray[T]) -> DocumentArray[T]: + return a + + a = DocumentArray() + f(a) From 57b2101de4b1f2e4c1c87b64a243d2c442a75fce Mon Sep 17 00:00:00 2001 From: samsja Date: Fri, 17 Feb 2023 15:10:49 +0100 Subject: [PATCH 08/11] fix : allow string --- docarray/array/abstract_array.py | 7 ++++--- tests/units/array/test_array.py | 4 ++++ 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/docarray/array/abstract_array.py b/docarray/array/abstract_array.py index 828c5e1085b..e9ebfb43314 100644 --- a/docarray/array/abstract_array.py +++ b/docarray/array/abstract_array.py @@ -9,6 +9,7 @@ Type, TypeVar, Union, + cast, ) from docarray.base_document import BaseDocument @@ -34,8 +35,8 @@ def __repr__(self): return f'<{self.__class__.__name__} (length={len(self)})>' @classmethod - def __class_getitem__(cls, item: Type[BaseDocument]): - if isinstance(item, TypeVar): + def __class_getitem__(cls, item: Union[Type[BaseDocument], TypeVar, str]): + if not isinstance(item, type): return super().__class_getitem__(item) if not issubclass(item, BaseDocument): raise ValueError( @@ -50,7 +51,7 @@ def __class_getitem__(cls, item: Type[BaseDocument]): global _DocumentArrayTyped class _DocumentArrayTyped(cls): # type: ignore - document_type: Type[BaseDocument] = item + document_type: 
From 0bded51419a96a9960c4a74d8d3411c812dc6979 Mon Sep 17 00:00:00 2001
From: samsja
Date: Fri, 17 Feb 2023 16:12:22 +0100
Subject: [PATCH 09/11] feat: add generic to document array stacked

---
 docarray/array/array/array.py           |  3 +--
 docarray/array/stacked/array_stacked.py | 15 +++++++--------
 2 files changed, 8 insertions(+), 10 deletions(-)

diff --git a/docarray/array/array/array.py b/docarray/array/array/array.py
index f39bae21efa..a42e0d2dca1 100644
--- a/docarray/array/array/array.py
+++ b/docarray/array/array/array.py
@@ -5,7 +5,6 @@
     TYPE_CHECKING,
     Any,
     Callable,
-    Generic,
     Iterable,
     List,
     Optional,
@@ -68,7 +67,7 @@ def _is_np_int(item: Any) -> bool:
     return False  # this is unreachable, but mypy wants it


-class DocumentArray(IOMixinArray, AnyDocumentArray, Generic[T_doc]):
+class DocumentArray(IOMixinArray, AnyDocumentArray[T_doc]):
     """
     DocumentArray is a container of Documents.

diff --git a/docarray/array/stacked/array_stacked.py b/docarray/array/stacked/array_stacked.py
index cf2bf2accab..50ef017113a 100644
--- a/docarray/array/stacked/array_stacked.py
+++ b/docarray/array/stacked/array_stacked.py
@@ -42,11 +42,12 @@
 else:
     TensorFlowTensor = None  # type: ignore

+T_doc = TypeVar('T_doc', bound=BaseDocument)
 T = TypeVar('T', bound='DocumentArrayStacked')
 IndexIterType = Union[slice, Iterable[int], Iterable[bool], None]


-class DocumentArrayStacked(AnyDocumentArray):
+class DocumentArrayStacked(AnyDocumentArray[T_doc]):
     """
     DocumentArrayStacked is a container of Documents appropriates to perform
     computation that require batches of data (ex: matrix multiplication, distance
@@ -70,7 +71,7 @@ class DocumentArrayStacked(AnyDocumentArray):

     def __init__(
         self: T,
-        docs: Optional[Union[DocumentArray, Iterable[BaseDocument]]] = None,
+        docs: Optional[Union[DocumentArray, Iterable[T_doc]]] = None,
         tensor_type: Type['AbstractTensor'] = NdArray,
     ):
         self._doc_columns: Dict[str, 'DocumentArrayStacked'] = {}
@@ -80,7 +81,7 @@ def __init__(
         self.from_iterable_document(docs)

     def from_iterable_document(
-        self: T, docs: Optional[Union[DocumentArray, Iterable[BaseDocument]]]
+        self: T, docs: Optional[Union[DocumentArray, Iterable[T_doc]]]
     ):
         self._docs = (
             docs
@@ -254,7 +255,7 @@ def _set_array_attribute(
         setattr(self._docs, field, values)

     @overload
-    def __getitem__(self: T, item: int) -> BaseDocument:
+    def __getitem__(self: T, item: int) -> T_doc:
         ...

     @overload
@@ -276,9 +277,7 @@ def __getitem__(self, item):
                 setattr(doc, field, self._tensor_columns[field][item])
         return doc

-    def __setitem__(
-        self: T, key: Union[int, IndexIterType], value: Union[T, BaseDocument]
-    ):
+    def __setitem__(self: T, key: Union[int, IndexIterType], value: Union[T, T_doc]):
         # multiple docs case
         if isinstance(key, (slice, Iterable)):
             return self._set_data_and_columns(key, value)
@@ -476,7 +475,7 @@ def unstacked_mode(self):
     @classmethod
     def validate(
         cls: Type[T],
-        value: Union[T, Iterable[BaseDocument]],
+        value: Union[T, Iterable[T_doc]],
         field: 'ModelField',
         config: 'BaseConfig',
     ) -> T:
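
With `AnyDocumentArray` now generic over `T_doc`, the stacked variant can be parametrized the same way, and `__getitem__`, `__setitem__`, and `validate` carry the element type instead of the bare `BaseDocument`. A sketch of the resulting typing; the `ImageDoc` schema is an assumption, and the import path is taken from the diff above:

    import numpy as np

    from docarray import BaseDocument
    from docarray.array.stacked.array_stacked import DocumentArrayStacked
    from docarray.typing import NdArray


    class ImageDoc(BaseDocument):
        tensor: NdArray


    docs = [ImageDoc(tensor=np.zeros(128)) for _ in range(4)]
    stacked = DocumentArrayStacked[ImageDoc](docs)

    # integer indexing is now typed as returning T_doc (here ImageDoc)
    doc = stacked[0]
    assert isinstance(doc, ImageDoc)
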
From 4a477f4256e38a19ec1a20f35f115e11db3194ee Mon Sep 17 00:00:00 2001
From: samsja
Date: Fri, 17 Feb 2023 16:55:40 +0100
Subject: [PATCH 10/11] fix: fix generic class getitem

---
 docarray/array/abstract_array.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docarray/array/abstract_array.py b/docarray/array/abstract_array.py
index e9ebfb43314..c12bb8e7ee4 100644
--- a/docarray/array/abstract_array.py
+++ b/docarray/array/abstract_array.py
@@ -37,7 +37,7 @@ def __repr__(self):
     @classmethod
     def __class_getitem__(cls, item: Union[Type[BaseDocument], TypeVar, str]):
         if not isinstance(item, type):
-            return super().__class_getitem__(item)
+            return Generic.__class_getitem__.__func__(cls, item)  # type: ignore
         if not issubclass(item, BaseDocument):
             raise ValueError(
                 f'{cls.__name__}[item] item should be a Document not a {item} '

From d07ccd074cf93dec37f665dd1be1112103ce5488 Mon Sep 17 00:00:00 2001
From: samsja
Date: Fri, 17 Feb 2023 17:00:17 +0100
Subject: [PATCH 11/11] fix: add comment

---
 docarray/array/abstract_array.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/docarray/array/abstract_array.py b/docarray/array/abstract_array.py
index c12bb8e7ee4..832bc251517 100644
--- a/docarray/array/abstract_array.py
+++ b/docarray/array/abstract_array.py
@@ -38,6 +38,7 @@ def __repr__(self):
     def __class_getitem__(cls, item: Union[Type[BaseDocument], TypeVar, str]):
         if not isinstance(item, type):
             return Generic.__class_getitem__.__func__(cls, item)  # type: ignore
+            # this does nothing except check that item is a valid TypeVar or str
         if not issubclass(item, BaseDocument):
             raise ValueError(
                 f'{cls.__name__}[item] item should be a Document not a {item} '
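
Patches 10 and 11 route non-type items through `Generic.__class_getitem__` directly, which, as the added comment notes, does nothing beyond validating the `TypeVar` or string. A sketch of the resulting subscript behavior, assuming the API at this commit:

    from typing import TypeVar

    from docarray import BaseDocument, DocumentArray

    T = TypeVar('T', bound=BaseDocument)

    DocumentArray[T]               # TypeVar: validated by Generic, no typed subclass built
    DocumentArray['BaseDocument']  # string: treated as a forward reference

    # a real type that is not a Document subclass still fails fast
    try:
        DocumentArray[int]
    except ValueError as e:
        print(e)  # DocumentArray[item] item should be a Document not a <class 'int'>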