🐛 Describe the bug
PyTorch reports that "at least one dimension spans across two contiguous subspaces" during the backward pass on the mps device, even though is_contiguous() returns True for every tensor in my forward pass.
Traceback (most recent call last):
File "/Users/prestonraab/GitHub/cs474FinalProject/vocos-main/vocos/irritating.py", line 20, in <module>
loss.backward()
File "/Users/prestonraab/miniforge3/envs/torchaudio/lib/python3.9/site-packages/torch/_tensor.py", line 624, in backward
torch.autograd.backward(
File "/Users/prestonraab/miniforge3/envs/torchaudio/lib/python3.9/site-packages/torch/autograd/__init__.py", line 347, in backward
_engine_run_backward(
File "/Users/prestonraab/miniforge3/envs/torchaudio/lib/python3.9/site-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
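For context, this error is normally raised by Tensor.view() when a tensor's strides no longer describe one contiguous block of memory, e.g. after a transpose. The surprising part here is that every tensor in my forward pass reports is_contiguous() == True, so the offending view presumably happens inside an autograd backward kernel on mps. A minimal sketch of the usual trigger, unrelated to my model:
import torch

x = torch.rand(3, 4).t()     # transpose: strides no longer describe contiguous memory
print(x.is_contiguous())     # False
try:
    x.view(12)               # raises the same "two contiguous subspaces" RuntimeError
except RuntimeError as err:
    print(err)
print(x.reshape(12).shape)   # reshape copies when needed, so it succeeds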
The intent is to normalize the tensor across the channel dimension. This is a snippet of code from Vocos.
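For reference, this is roughly the computation those lines are after (a hypothetical sketch of the math, not the actual Vocos code, and ignoring LayerNorm's learnable weight and bias): each time step is normalized over its channels.
import torch

# Hypothetical sketch: per-position normalization over the channel dimension,
# i.e. what transpose -> LayerNorm -> transpose computes (minus the affine params).
x = torch.rand(1, 50, 25)                     # (batch, channels, length)
mean = x.mean(dim=1, keepdim=True)
var = x.var(dim=1, unbiased=False, keepdim=True)
x_norm = (x - mean) / torch.sqrt(var + 1e-5)  # 1e-5 is LayerNorm's default eps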
Several components of my model seem to be necessary to reproduce the error, though there is likely a smaller reproducible example that I haven't found yet:
- the 'mps' device
- the convolution
- the first transpose
- the layer norm
- OUTPUT_CHANNELS != FEATURES
The error only surfaces during loss.backward(), which makes it tricky to debug:
import torch
import torch.nn as nn

device = torch.device('mps')

INPUT_CHANNELS = 100
OUTPUT_CHANNELS = 50
FEATURES = 25

conv = nn.Conv1d(INPUT_CHANNELS, OUTPUT_CHANNELS, kernel_size=3, padding=1).to(device)
norm = nn.LayerNorm(OUTPUT_CHANNELS).to(device)

a = torch.rand((1, INPUT_CHANNELS, FEATURES)).to(device)
b = conv(a)                         # (1, OUTPUT_CHANNELS, FEATURES)
c = b.transpose(1, 2).contiguous()  # (1, FEATURES, OUTPUT_CHANNELS)
d = norm(c)                         # LayerNorm over the channel dimension
e = d.transpose(1, 2).contiguous()  # back to (1, OUTPUT_CHANNELS, FEATURES)
loss = torch.sum(e)

print([var.is_contiguous() for var in (a, b, c, d, e, loss)])  # all True
loss.backward()                     # raises the RuntimeError above
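The 'mps' entry in the list above refers to the fact that the same snippet does not raise when the device is 'cpu'. A small wrapper (a hypothetical helper, just for comparison) to run the repro on both devices:
import torch
import torch.nn as nn

def run_repro(device_str):
    # Same computation as the snippet above, parameterized by device.
    device = torch.device(device_str)
    conv = nn.Conv1d(100, 50, kernel_size=3, padding=1).to(device)
    norm = nn.LayerNorm(50).to(device)
    a = torch.rand((1, 100, 25), device=device)
    e = norm(conv(a).transpose(1, 2).contiguous()).transpose(1, 2).contiguous()
    try:
        e.sum().backward()
        print(f"{device_str}: backward ok")
    except RuntimeError as err:
        print(f"{device_str}: {err}")

run_repro('cpu')   # backward ok
run_repro('mps')   # fails with the RuntimeError above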
Versions
PyTorch version: 2.6.0.dev20241112
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 15.0.1 (arm64)
GCC version: Could not collect
Clang version: 16.0.0 (clang-1600.0.26.4)
CMake version: version 3.30.5
Libc version: N/A
Python version: 3.9.20 | packaged by conda-forge | (main, Sep 30 2024, 17:48:00) [Clang 17.0.6 ] (64-bit runtime)
Python platform: macOS-15.0.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Apple M1 Pro
Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] pytorch-lightning==2.4.0
[pip3] torch==2.6.0.dev20241112
[pip3] torchaudio==2.5.0.dev20241116
[pip3] torchcrepe==0.0.23
[pip3] torchmetrics==1.5.2
[pip3] torchvision==0.20.0.dev20241116
[conda] libopenvino-pytorch-frontend 2024.4.0 h5833ebf_2 conda-forge
[conda] nomkl 1.0 h5ca1d4c_0 conda-forge
[conda] numpy 1.26.4 py39h7aa2656_0 conda-forge
[conda] pytorch 2.6.0.dev20241112 py3.9_0 pytorch-nightly
[conda] pytorch-lightning 2.4.0 pyhd8ed1ab_0 conda-forge
[conda] torchaudio 2.5.0.dev20241116 py39_cpu pytorch-nightly
[conda] torchcrepe 0.0.23 pypi_0 pypi
[conda] torchmetrics 1.5.2 pyhe5570ce_0 conda-forge
[conda] torchvision 0.20.0.dev20241116 py39_cpu pytorch-nightly
cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @jamesr66a @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen