From cc2e62bf42451b56b85741549463a96c98945cf2 Mon Sep 17 00:00:00 2001 From: Naymul Islam Date: Sat, 2 Dec 2023 11:06:54 +0600 Subject: [PATCH 1/4] fix: fix storage issue in torchtensor class Signed-off-by: Naymul Islam --- docarray/typing/tensor/torch_tensor.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/docarray/typing/tensor/torch_tensor.py b/docarray/typing/tensor/torch_tensor.py index 7ad743721a..2f97a51c62 100644 --- a/docarray/typing/tensor/torch_tensor.py +++ b/docarray/typing/tensor/torch_tensor.py @@ -292,6 +292,16 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): torch.Tensor if t in docarray_torch_tensors else t for t in types ) return super().__torch_function__(func, types_, args, kwargs) + + def __deepcopy__(self, memo): + """ + Custom implementation of deepcopy for TorchTensor to avoid storage sharing issues. + """ + # Create a new tensor with the same data and properties + new_tensor = self.clone() + # Set the class to the custom TorchTensor class + new_tensor.__class__ = self.__class__ + return new_tensor @classmethod def _docarray_from_ndarray(cls: Type[T], value: np.ndarray) -> T: From 39e56ec178363717b5f0564594432c3529fdf531 Mon Sep 17 00:00:00 2001 From: Naymul Islam Date: Sun, 3 Dec 2023 10:58:30 +0600 Subject: [PATCH 2/4] tests for deepcopy method Signed-off-by: Naymul Islam --- tests/integrations/typing/test_torch_tensor.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/integrations/typing/test_torch_tensor.py b/tests/integrations/typing/test_torch_tensor.py index 2a84489cd9..bbaf0caa76 100644 --- a/tests/integrations/typing/test_torch_tensor.py +++ b/tests/integrations/typing/test_torch_tensor.py @@ -1,4 +1,6 @@ import torch +from docarray.typing.tensor.torch_tensor import TorchTensor +import copy from docarray import BaseDoc from docarray.typing import TorchEmbedding, TorchTensor @@ -25,3 +27,18 @@ class MyDocument(BaseDoc): assert isinstance(d.embedding, TorchEmbedding) assert isinstance(d.embedding, torch.Tensor) assert (d.embedding == torch.zeros((128,))).all() + +def test_torchtensor_deepcopy(): + # Setup + original_tensor_float = TorchTensor(torch.rand(10)) + original_tensor_int = TorchTensor(torch.randint(0, 100, (10,))) + + # Exercise + copied_tensor_float = copy.deepcopy(original_tensor_float) + copied_tensor_int = copy.deepcopy(original_tensor_int) + + # Verify + assert torch.equal(original_tensor_float, copied_tensor_float) + assert original_tensor_float is not copied_tensor_float + assert torch.equal(original_tensor_int, copied_tensor_int) + assert original_tensor_int is not copied_tensor_int \ No newline at end of file From e5ac2f31cfc813a1d0694d7c0c2cd3c38b4e4355 Mon Sep 17 00:00:00 2001 From: Naymul Islam Date: Mon, 4 Dec 2023 23:41:24 +0600 Subject: [PATCH 3/4] format code into black Signed-off-by: Naymul Islam --- tests/integrations/typing/test_torch_tensor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/integrations/typing/test_torch_tensor.py b/tests/integrations/typing/test_torch_tensor.py index bbaf0caa76..0e485fcd07 100644 --- a/tests/integrations/typing/test_torch_tensor.py +++ b/tests/integrations/typing/test_torch_tensor.py @@ -28,6 +28,7 @@ class MyDocument(BaseDoc): assert isinstance(d.embedding, torch.Tensor) assert (d.embedding == torch.zeros((128,))).all() + def test_torchtensor_deepcopy(): # Setup original_tensor_float = TorchTensor(torch.rand(10)) @@ -41,4 +42,4 @@ def test_torchtensor_deepcopy(): assert torch.equal(original_tensor_float, copied_tensor_float) assert original_tensor_float is not copied_tensor_float assert torch.equal(original_tensor_int, copied_tensor_int) - assert original_tensor_int is not copied_tensor_int \ No newline at end of file + assert original_tensor_int is not copied_tensor_int From 48b0f32154bba0d096bc1f64cac5615e4db3c612 Mon Sep 17 00:00:00 2001 From: Naymul Islam Date: Thu, 7 Dec 2023 20:26:23 +0600 Subject: [PATCH 4/4] reformat torchtensor class Signed-off-by: Naymul Islam --- docarray/typing/tensor/torch_tensor.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docarray/typing/tensor/torch_tensor.py b/docarray/typing/tensor/torch_tensor.py index 2f97a51c62..7ed3bd3800 100644 --- a/docarray/typing/tensor/torch_tensor.py +++ b/docarray/typing/tensor/torch_tensor.py @@ -292,15 +292,15 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): torch.Tensor if t in docarray_torch_tensors else t for t in types ) return super().__torch_function__(func, types_, args, kwargs) - + def __deepcopy__(self, memo): """ Custom implementation of deepcopy for TorchTensor to avoid storage sharing issues. """ # Create a new tensor with the same data and properties - new_tensor = self.clone() + new_tensor = self.clone() # Set the class to the custom TorchTensor class - new_tensor.__class__ = self.__class__ + new_tensor.__class__ = self.__class__ return new_tensor @classmethod pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy