WeaveModel not working with cuda #4071

Open
soerenbrandt opened this issue Jul 27, 2024 · 1 comment
@soerenbrandt

🐛 Bug

I am trying to run the PyTorch WeaveModel on CUDA and am getting TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first. I see the same error with any device other than CPU. I couldn't find any guidance in the documentation or existing issues, and every example I've found uses CPU except for the one below.

To Reproduce

Steps to reproduce the behavior:

import deepchem as dc
from deepchem.models.torch_models import WeaveModel  # PyTorch WeaveModel (see traceback below)

tasks_, datasets_, transformers_ = dc.molnet.load_hiv(featurizer='Weave', splitter='scaffold', reload=False)
train_dataset, valid_dataset, test_dataset = datasets_
model = WeaveModel(mode='classification', n_tasks=1, batch_size=32, learning_rate=1e-1, dropout=0.05)#, device='cuda:0')
loss = model.fit(train_dataset)
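
Note that CUDA gets involved even though device is never passed: DeepChem's TorchModel picks a CUDA device automatically when torch.cuda.is_available() returns True, so the batch tensors reaching Weave.forward live on cuda:0. A quick way to confirm which device the wrapper selected (using the model object above):

print(model.device)                            # the device DeepChem selected (cuda on a GPU machine)
print(next(model.model.parameters()).device)   # device of the underlying torch module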

This is the error trace:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[61], line 5
      2 y = np.array([1, 0])
      3 dataset = dc.data.NumpyDataset(X, y)
----> 5 model.fit(dataset)

File ~/.cache/pants/named_caches/pex_root/venvs/3af3ea6f7fae239cf5723d83faf1ba29848c7f25/a54cf194a38a982fc7904d648eb451a554889da0/lib/python3.10/site-packages/deepchem/models/torch_models/torch_model.py:338, in TorchModel.fit(self, dataset, nb_epoch, max_checkpoints_to_keep, checkpoint_interval, deterministic, restore, variables, loss, callbacks, all_losses)
    289 def fit(self,
    290         dataset: Dataset,
    291         nb_epoch: int = 10,
   (...)
    298         callbacks: Union[Callable, List[Callable]] = [],
    299         all_losses: Optional[List[float]] = None) -> float:
    300     """Train this model on a dataset.
    301 
    302     Parameters
   (...)
    336     The average loss over the most recent checkpoint interval
    337     """
--> 338     return self.fit_generator(
    339         self.default_generator(dataset,
    340                                epochs=nb_epoch,
    341                                deterministic=deterministic),
    342         max_checkpoints_to_keep, checkpoint_interval, restore, variables,
    343         loss, callbacks, all_losses)

File ~/.cache/pants/named_caches/pex_root/venvs/3af3ea6f7fae239cf5723d83faf1ba29848c7f25/a54cf194a38a982fc7904d648eb451a554889da0/lib/python3.10/site-packages/deepchem/models/torch_models/torch_model.py:433, in TorchModel.fit_generator(self, generator, max_checkpoints_to_keep, checkpoint_interval, restore, variables, loss, callbacks, all_losses)
    430     inputs = inputs[0]
    432 optimizer.zero_grad()
--> 433 outputs = self.model(inputs)
    434 if isinstance(outputs, torch.Tensor):
    435     outputs = [outputs]

File ~/.cache/pants/named_caches/pex_root/venvs/3af3ea6f7fae239cf5723d83faf1ba29848c7f25/a54cf194a38a982fc7904d648eb451a554889da0/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.cache/pants/named_caches/pex_root/venvs/3af3ea6f7fae239cf5723d83faf1ba29848c7f25/a54cf194a38a982fc7904d648eb451a554889da0/lib/python3.10/site-packages/deepchem/models/torch_models/weavemodel_pytorch.py:280, in Weave.forward(self, inputs)
    267 def forward(self, inputs: OneOrMany[torch.Tensor]) -> List[torch.Tensor]:
    268     """
    269     Parameters
    270     ----------
   (...)
    277         Output as per use case : regression/classification
    278     """
    279     input1: List[np.ndarray] = [
--> 280         np.array(inputs[0]),
    281         np.array(inputs[1]),
    282         np.array(inputs[2]),
    283         np.array(inputs[4])
    284     ]
    285     for ind in range(self.n_weave):
    286         weave_layer_ind_A, weave_layer_ind_P = self.layers[ind](input1)

File ~/.cache/pants/named_caches/pex_root/venvs/3af3ea6f7fae239cf5723d83faf1ba29848c7f25/a54cf194a38a982fc7904d648eb451a554889da0/lib/python3.10/site-packages/torch/_tensor.py:970, in Tensor.__array__(self, dtype)
    968     return handle_torch_function(Tensor.__array__, (self,), self, dtype=dtype)
    969 if dtype is None:
--> 970     return self.numpy()
    971 else:
    972     return self.numpy().astype(dtype, copy=False)

TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

Expected behavior

The error seems to happen in Weave.forward, where the input tensors are converted to NumPy arrays. I'm not sure converting to NumPy inside forward is the best approach for torch in general, but I would expect both the fit and predict methods to work when the model's device is cuda.
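
The underlying PyTorch behavior is easy to reproduce outside of DeepChem: np.array() on a CUDA tensor goes through Tensor.__array__, which calls Tensor.numpy(), and that only works for CPU tensors.

import numpy as np
import torch

t = torch.zeros(3, device='cuda')  # same situation as the inputs reaching Weave.forward
# np.array(t)                      # raises the TypeError shown above
arr = np.array(t.cpu())            # copying to host memory first works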

Environment

  • Python version: 3.10.11
  • DeepChem version: 2.8.0
  • PyTorch version (optional): 2.0.0
@krtimisra67

The error occurs because DeepChem's WeaveModel (PyTorch version) attempts to convert a CUDA tensor to a NumPy array without first moving it to the CPU. This is a common issue when working with PyTorch models on GPUs.
Fix: Move Tensor to CPU Before Converting to NumPy

Modify the DeepChem code where it converts tensors to NumPy. In this traceback the conversion is implicit: np.array(inputs[i]) in Weave.forward goes through Tensor.__array__, which calls Tensor.numpy(), so any tensor that may live on the GPU has to be moved to the CPU first:

tensor.cpu().numpy()

Steps to Fix

Locate the error source
    Open the DeepChem source code (deepchem/models/torch_models/weavemodel_pytorch.py, per the traceback above).
    Wherever a tensor that may be on the GPU is converted to NumPy (through an explicit .numpy() call or an implicit np.array(...)), move it to the CPU first with .cpu().
    Example fix (a fuller sketch follows below):

my_tensor = some_tensor.cpu().numpy()  # ensure the tensor is on the CPU before conversion
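
For the specific failure in this traceback, a device-safe version of the list construction at weavemodel_pytorch.py:280 could look roughly like the sketch below. This is only an illustration, not an official DeepChem patch, and _to_numpy is a hypothetical helper name:

import numpy as np
import torch

def _to_numpy(x):
    # Hypothetical helper: detach CUDA tensors and copy them to host memory before the
    # NumPy conversion; anything that is not a torch.Tensor is converted as before.
    return x.detach().cpu().numpy() if isinstance(x, torch.Tensor) else np.array(x)

input1 = [_to_numpy(inputs[0]), _to_numpy(inputs[1]), _to_numpy(inputs[2]), _to_numpy(inputs[4])]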

Modify WeaveModel to explicitly use CUDA

Update WeaveModel initialization:

model = WeaveModel(mode='classification', n_tasks=1, batch_size=32, learning_rate=1e-1, dropout=0.05, device='cuda')

If you constructed the model without a device argument, you can still move the underlying torch.nn.Module (exposed by DeepChem's TorchModel wrapper as model.model) to CUDA explicitly:

model.model.to(torch.device('cuda'))
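
To verify where the parameters actually ended up (using the model object from above):

print(next(model.model.parameters()).device)  # expect cuda:0 after a successful move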

Update TensorFlow/Keras (if applicable)
Since DeepChem integrates TensorFlow/Keras and PyTorch, ensure your versions are compatible:

pip install --upgrade deepchem torch tensorflow keras

Check GPU Compatibility
Run:

import torch
print(torch.cuda.is_available()) # Should return True
print(torch.cuda.device_count()) # Should be > 0

If CUDA is unavailable, reinstall PyTorch with a CUDA-enabled build, for example:

pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
