Description
🐛 Describe the bug
Hi, I have a model that uses 2D FFT operations. Training converges on CUDA and CPU but diverges on MPS: the loss drops for the first few steps and then explodes.
I'm not sure where the error comes from, but I've put together the minimal example below: a simple model with a single Fourier layer, trained on random data. In this example I observe the following:
(1) With FFT and Adam, the loss drops on the CPU backend but explodes on the MPS backend.
(2) If I replace the FFT layer with Conv2d, or replace Adam with SGD, the loss drops on both CPU and MPS.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

################ model definition ##################
class SpectralConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, hidden_freq, modes1, modes2):
        super().__init__()
        scale = (1 / in_channels / out_channels)
        # Complex weights stored as (real, imag) pairs in the last dimension
        self.weights = nn.Parameter(scale * torch.rand(in_channels, out_channels, modes1, modes2, 2, dtype=torch.float32))

    def compl_mul2d(self, input, weights):
        # (batch, in, x, y) x (in, out, x, y) -> (batch, out, x, y)
        return torch.einsum("bixy,ioxy->boxy", input, weights)

    def forward(self, x):
        batchsize = x.shape[0]
        x_ft = torch.fft.rfftn(x, dim=[-2,-1])
        # Resize the (modes1, modes2) weights to the rfftn output size (H, W//2 + 1)
        weights = torch.view_as_complex(self.weights)
        weights_r = F.interpolate(weights.real, size=(x.size(-2), x.size(-1)//2+1))
        weights_i = F.interpolate(weights.imag, size=(x.size(-2), x.size(-1)//2+1))
        weights = torch.view_as_complex(torch.stack((weights_r, weights_i), dim=-1))
        out_ft = self.compl_mul2d(x_ft, weights)
        x = torch.fft.irfftn(out_ft, s=(x.size(-2), x.size(-1)))
        return x

#################### training with different backends ################
batch_size = 8
in_c = 2
out_c = 4
hidden_freq = 8
sizex, sizey = (128, 128)
modes1, modes2 = (16, 16)

def train(backend, seed = 42):
    torch.manual_seed(seed)
    if backend=='cpu':
        device = torch.device("cpu")
    elif backend=='mps':
        device = torch.device("mps")
    else:
        raise ValueError(f"unsupported backend: {backend}")
    # Weights and data are created on CPU first, so both backends start from identical values
    model = SpectralConv2d(in_c, out_c, hidden_freq, modes1, modes2).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.MSELoss()
    x_train = torch.randn(batch_size, in_c, sizex, sizey)
    y_train = torch.randn(batch_size, out_c, sizex, sizey)
    x_train = x_train.to(device)
    y_train = y_train.to(device)
    for step in range(1000):
        out = model(x_train)
        loss = criterion(out, y_train)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if (step+1) % 100 == 0:
            print(f"Step {(step+1):03d} | Loss: {loss.item():.6f}")

train('cpu')
train('mps')
```
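If it helps narrow things down, here is a minimal sketch (a suggestion, not an established diagnostic) that runs two pieces of the forward pass above in isolation, once on CPU and once on MPS, with identical inputs and a fixed upstream gradient, and reports the largest absolute difference in the outputs and in the input gradients. The names `compare_on_devices`, `spectral_path`, and `weight_path` are hypothetical, the tensor sizes are arbitrary small choices, and the fixed factor `1 + 0.5j` stands in for the learned complex weights. When MPS is not available, the script compares CPU with itself, so all differences should be zero.

```python
import torch
import torch.nn.functional as F


def compare_on_devices(fn, x_cpu, device, seed=0):
    """Hypothetical helper: run fn forward and backward on CPU and on `device`
    with the same input and upstream gradient; return the largest absolute
    differences as (outputs, input gradients)."""
    outs, grads = [], []
    upstream = None
    for dev in (torch.device("cpu"), device):
        # Fresh leaf tensor per device so gradients never accumulate across runs
        x = x_cpu.detach().clone().to(dev).requires_grad_(True)
        out = fn(x)
        if upstream is None:
            # Fixed random upstream gradient, generated once on CPU
            gen = torch.Generator().manual_seed(seed)
            upstream = torch.randn(out.shape, generator=gen)
        out.backward(upstream.to(dev))
        outs.append(out.detach().cpu())
        grads.append(x.grad.detach().cpu())
    out_diff = (outs[0] - outs[1]).abs().max().item()
    grad_diff = (grads[0] - grads[1]).abs().max().item()
    return out_diff, grad_diff


def spectral_path(x):
    # rfftn -> complex scaling -> irfftn, the FFT part of SpectralConv2d.forward
    x_ft = torch.fft.rfftn(x, dim=[-2, -1])
    return torch.fft.irfftn(x_ft * (1 + 0.5j), s=x.shape[-2:])


def weight_path(w):
    # view_as_complex -> interpolate real/imag -> stack, the weight part of forward
    wc = torch.view_as_complex(w)
    wr = F.interpolate(wc.real, size=(8, 5))
    wi = F.interpolate(wc.imag, size=(8, 5))
    return torch.stack((wr, wi), dim=-1)


device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
torch.manual_seed(0)
x = torch.randn(2, 2, 16, 16)  # (batch, channels, H, W), arbitrary size
w = torch.rand(2, 4, 4, 4, 2)  # (in, out, modes1, modes2, real/imag)

checks = [
    ("rfftn -> irfftn", spectral_path, x),
    ("view_as_complex + interpolate", weight_path, w),
]
for name, fn, inp in checks:
    out_diff, grad_diff = compare_on_devices(fn, inp, device)
    print(f"{name:30s} max|out diff| = {out_diff:.2e}  max|grad diff| = {grad_diff:.2e}")
```

Small differences on the order of float32 round-off are expected between backends; a gradient difference far larger than the output difference for one path would suggest looking more closely at that op's MPS backward.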
Output for `train('cpu')`:
```
Step 100 | Loss: 0.995368
Step 200 | Loss: 0.992208
Step 300 | Loss: 0.991863
Step 400 | Loss: 0.991827
Step 500 | Loss: 0.991824
Step 600 | Loss: 0.991824
Step 700 | Loss: 0.991824
Step 800 | Loss: 0.991824
Step 900 | Loss: 0.991824
Step 1000 | Loss: 0.991824
```
Output for `train('mps')`:
```
Step 100 | Loss: 1.058992
Step 200 | Loss: 1.172400
Step 300 | Loss: 1.356889
Step 400 | Loss: 1.608124
Step 500 | Loss: 1.922639
Step 600 | Loss: 2.297220
Step 700 | Loss: 2.729716
Step 800 | Loss: 3.218872
Step 900 | Loss: 3.761904
Step 1000 | Loss: 4.357483
```
With a smaller learning rate (e.g. 1e-5), the trends are the same.
I'm using Python 3.10.17 and torch 2.6.0 on a Mac Studio (M2 Ultra) running macOS Sequoia 15.2. Could you please check whether you can reproduce the error, and do you have any suggestions on how to debug it? Thanks a lot.
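One possible reading of observation (2), offered only as an assumption and not as a diagnosis: Adam divides each parameter's update by a running estimate of that parameter's gradient magnitude, so its first step is roughly `lr * sign(grad)` no matter how small the gradient is, while an SGD step shrinks with the gradient. If the MPS backward pass produced a small but systematically wrong gradient, Adam would apply it at nearly full step size, whereas SGD would barely move. The plain-Python sketch below shows only that optimizer arithmetic for Adam's first step on a single scalar; `adam_first_step` and `sgd_step` are hypothetical helpers, and the hyperparameters are the `torch.optim.Adam` defaults (lr=1e-3, betas=(0.9, 0.999), eps=1e-8).

```python
import math


def adam_first_step(grad, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """Hypothetical helper: parameter change from Adam's first step (t = 1)
    for a single scalar parameter, starting from zero moment estimates."""
    m = (1 - beta1) * grad            # first moment after one update
    v = (1 - beta2) * grad * grad     # second moment after one update
    m_hat = m / (1 - beta1)           # bias correction for t = 1
    v_hat = v / (1 - beta2)
    return -lr * m_hat / (math.sqrt(v_hat) + eps)


def sgd_step(grad, lr=1e-3):
    """Hypothetical helper: plain SGD step without momentum."""
    return -lr * grad


for g in (1.0, 1e-3, 1e-6):
    print(f"grad={g:g}  adam step={adam_first_step(g):.3e}  sgd step={sgd_step(g):.3e}")
```

For all three gradients the Adam step stays close to 1e-3 in magnitude, while the SGD step scales down by the same factor as the gradient.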
Versions
PyTorch version: 2.6.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 15.2 (arm64)
GCC version: Could not collect
Clang version: 17.0.0 (clang-1700.0.13.3)
CMake version: Could not collect
Libc version: N/A
Python version: 3.10.17 | packaged by conda-forge | (main, Apr 10 2025, 22:23:34) [Clang 18.1.8 ] (64-bit runtime)
Python platform: macOS-15.2-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Apple M2 Ultra
Versions of relevant libraries:
[pip3] numpy==2.2.4
[pip3] optree==0.15.0
[pip3] torch==2.6.0
[pip3] torchvision==0.21.0
[pip3] torchvision-extra-decoders==0.0.2
[conda] Could not collect
cc @mruberry @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen