MPS and CPU backends produce different training results with FFT and Adam #151740

@ChenkaiMao97

Description


🐛 Describe the bug

Hi, I have a model that uses 2D FFT operations. Training converges on CUDA and CPU, but diverges on MPS (the loss drops for the first few steps and then explodes).

I'm not sure where the error comes from, so I've put together the minimal example below: a simple model with a Fourier layer, trained on random data. It shows the same difference in behavior. In particular:
(1) with FFT and Adam, the loss drops on the CPU backend but explodes on MPS;
(2) if I change the FFT to a Conv2d, or change Adam to SGD, the loss drops on both CPU and MPS.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

################ model definition ##################

class SpectralConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, hidden_freq, modes1, modes2):
        super().__init__()
        # Complex weights stored as (..., 2) real pairs; hidden_freq is unused in this minimal repro.
        scale = (1 / in_channels / out_channels)
        self.weights = nn.Parameter(scale * torch.rand(in_channels, out_channels, modes1, modes2, 2, dtype=torch.float32))

    def compl_mul2d(self, input, weights):
        # Complex multiply: (batch, in, x, y) x (in, out, x, y) -> (batch, out, x, y)
        return torch.einsum("bixy,ioxy->boxy", input, weights)

    def forward(self, x):
        # 2D real FFT over the last two dims
        x_ft = torch.fft.rfftn(x, dim=[-2, -1])

        # Resize the learned modes to the full rFFT grid (real and imaginary parts separately)
        weights = torch.view_as_complex(self.weights)
        weights_r = F.interpolate(weights.real, size=(x.size(-2), x.size(-1) // 2 + 1))
        weights_i = F.interpolate(weights.imag, size=(x.size(-2), x.size(-1) // 2 + 1))
        weights = torch.view_as_complex(torch.stack((weights_r, weights_i), dim=-1))
        out_ft = self.compl_mul2d(x_ft, weights)

        # Back to the spatial domain
        x = torch.fft.irfftn(out_ft, s=(x.size(-2), x.size(-1)))
        return x

#################### training with different backends ################

batch_size = 8
in_c = 2
out_c = 4
hidden_freq = 8
sizex, sizey = (128, 128)
modes1, modes2 = (16, 16)

def train(backend, seed=42):
    torch.manual_seed(seed)
    device = torch.device(backend)  # 'cpu' or 'mps'
    model = SpectralConv2d(in_c, out_c, hidden_freq, modes1, modes2).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.MSELoss()

    x_train = torch.randn(batch_size, in_c, sizex, sizey)
    y_train = torch.randn(batch_size, out_c, sizex, sizey)

    x_train = x_train.to(device)
    y_train = y_train.to(device)

    for step in range(1000):
        out = model(x_train)
        loss = criterion(out, y_train)
        loss.backward()
        
        optimizer.step()
        optimizer.zero_grad()

        if (step+1) % 100 == 0:
            print(f"Step {(step+1):03d} | Loss: {loss.item():.6f}")

train('cpu')
train('mps')
```

Output for `train('cpu')`:

```
Step 100 | Loss: 0.995368
Step 200 | Loss: 0.992208
Step 300 | Loss: 0.991863
Step 400 | Loss: 0.991827
Step 500 | Loss: 0.991824
Step 600 | Loss: 0.991824
Step 700 | Loss: 0.991824
Step 800 | Loss: 0.991824
Step 900 | Loss: 0.991824
Step 1000 | Loss: 0.991824
```

Output for `train('mps')`:

```
Step 100 | Loss: 1.058992
Step 200 | Loss: 1.172400
Step 300 | Loss: 1.356889
Step 400 | Loss: 1.608124
Step 500 | Loss: 1.922639
Step 600 | Loss: 2.297220
Step 700 | Loss: 2.729716
Step 800 | Loss: 3.218872
Step 900 | Loss: 3.761904
Step 1000 | Loss: 4.357483
```

With a smaller learning rate (e.g. 1e-5), the trend is the same.
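To narrow down where the divergence enters, one check that separates the FFT forward/backward from any Adam effect is to compare gradients after a single forward/backward pass from identical weights on both devices. This is only a sketch reusing `SpectralConv2d` and the sizes from the script above; the helper name and the 1e-4 tolerance are arbitrary choices of mine:

```python
def compare_first_step_grads(atol=1e-4):
    # Same seed gives identical init; copy the CPU weights onto the MPS model to be safe.
    torch.manual_seed(42)
    m_cpu = SpectralConv2d(in_c, out_c, hidden_freq, modes1, modes2)
    m_mps = SpectralConv2d(in_c, out_c, hidden_freq, modes1, modes2).to("mps")
    m_mps.load_state_dict(m_cpu.state_dict())

    x = torch.randn(batch_size, in_c, sizex, sizey)
    y = torch.randn(batch_size, out_c, sizex, sizey)

    # One forward/backward per backend on the same data, no optimizer involved.
    nn.MSELoss()(m_cpu(x), y).backward()
    nn.MSELoss()(m_mps(x.to("mps")), y.to("mps")).backward()

    for (name, p_c), (_, p_m) in zip(m_cpu.named_parameters(), m_mps.named_parameters()):
        diff = (p_c.grad - p_m.grad.cpu()).abs().max().item()
        print(f"{name}: max |grad_cpu - grad_mps| = {diff:.3e}" + ("  <-- MISMATCH" if diff > atol else ""))

compare_first_step_grads()
```

If the gradients already mismatch here, the problem is in the FFT/complex ops rather than in Adam; if they match, the optimizer update itself would be the next suspect.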

I'm using Python 3.10.17 and torch 2.6.0 on a Mac Studio (M2 Ultra) running macOS Sequoia 15.2. Could you please check whether you can reproduce this, and do you have any suggestions on how to debug it? Thanks a lot.
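In case it helps with triage, here is also an isolated check of just the rfftn/irfftn round trip and its gradient, with the model and optimizer taken out of the picture entirely (again only a sketch of mine; the shapes match the repro above):

```python
import torch

def fft_roundtrip(device):
    # Generate on CPU first so both devices see identical inputs.
    torch.manual_seed(0)
    x = torch.randn(8, 2, 128, 128).to(device).requires_grad_(True)
    out = torch.fft.irfftn(torch.fft.rfftn(x, dim=[-2, -1]), s=(128, 128))
    out.pow(2).sum().backward()
    return out.detach().cpu(), x.grad.cpu()

out_cpu, grad_cpu = fft_roundtrip("cpu")
out_mps, grad_mps = fft_roundtrip("mps")
print("forward  max diff:", (out_cpu - out_mps).abs().max().item())
print("backward max diff:", (grad_cpu - grad_mps).abs().max().item())
```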

Versions

```
PyTorch version: 2.6.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 15.2 (arm64)
GCC version: Could not collect
Clang version: 17.0.0 (clang-1700.0.13.3)
CMake version: Could not collect
Libc version: N/A

Python version: 3.10.17 | packaged by conda-forge | (main, Apr 10 2025, 22:23:34) [Clang 18.1.8 ] (64-bit runtime)
Python platform: macOS-15.2-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M2 Ultra

Versions of relevant libraries:
[pip3] numpy==2.2.4
[pip3] optree==0.15.0
[pip3] torch==2.6.0
[pip3] torchvision==0.21.0
[pip3] torchvision-extra-decoders==0.0.2
[conda] Could not collect
```

cc @mruberry @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen

Metadata

Labels

- module: correctness (silent) - issue that returns an incorrect result silently
- module: fft
- module: mps - Related to Apple Metal Performance Shaders framework
- needs reproduction - Someone else needs to try reproducing the issue given the instructions. No action needed from user
- triaged - This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
