
torch.layer_norm() gives wrong results on MPS if applied on a slice of tensor #131750

@henrysky

Description


🐛 Describe the bug

torch.layer_norm() gives wrong results on MPS if applied on a slice of tensor.

For example, if I run torch.layer_norm(a_tensor[123:125]) on MPS, it actually computes torch.layer_norm(a_tensor[0:2]), which is incorrect. If I run torch.layer_norm(a_tensor[123:125].clone()) on MPS, it gives the correct result.
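The symptom (a slice `[123:125]` behaving like `[0:2]`) is what you would expect if the kernel read the slice's data from the start of the underlying storage and ignored the view's storage offset. The sketch below is a stdlib-only model of that hypothesis, not the actual MPS kernel: a 3x4 row-major buffer, a row-wise layer norm with weight 1 and bias 0, and a hypothetical `buggy_view_rows` that drops the offset. It also models why `.clone()` avoids the problem: the copy lives in a fresh buffer that starts at offset 0.

```python
import math


def layer_norm_row(row, eps=1e-5):
    """Normalize one row to zero mean and unit (biased) variance, like layer_norm with weight=1, bias=0."""
    mean = sum(row) / len(row)
    var = sum((x - mean) ** 2 for x in row) / len(row)
    return [(x - mean) / math.sqrt(var + eps) for x in row]


def view_rows(storage, offset, n_rows, n_cols):
    """Read n_rows rows of a contiguous row-major view that starts `offset` elements into storage."""
    return [storage[offset + r * n_cols: offset + (r + 1) * n_cols] for r in range(n_rows)]


def buggy_view_rows(storage, offset, n_rows, n_cols):
    """Hypothetical faulty read: ignores the storage offset and always starts at element 0."""
    return view_rows(storage, 0, n_rows, n_cols)


n_cols = 4
storage = [0.5, -1.0, 2.0, 3.5,   # row 0
           1.0, 4.0, -2.0, 0.0,   # row 1
           -3.0, 2.5, 1.5, 6.0]   # row 2

# Equivalent of tensor[1:3]: 2 rows starting at a storage offset of 1 * n_cols = 4 elements.
offset = 1 * n_cols
correct = [layer_norm_row(r) for r in view_rows(storage, offset, 2, n_cols)]
buggy = [layer_norm_row(r) for r in buggy_view_rows(storage, offset, 2, n_cols)]
rows_0_1 = [layer_norm_row(r) for r in view_rows(storage, 0, 2, n_cols)]

# Model of .clone(): copy the view into a new buffer, so its data starts at offset 0.
cloned_storage = storage[offset: offset + 2 * n_cols]
cloned = [layer_norm_row(r) for r in buggy_view_rows(cloned_storage, 0, 2, n_cols)]

print(buggy == rows_0_1)   # True: the faulty read normalizes rows 0:2
print(correct == buggy)    # False
print(cloned == correct)   # True: the offset no longer matters after the copy
```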

Here is a complete example that demonstrates the issue:

import torch

device = "mps"
rand_tensor = torch.randn(3, 4).to(device)
weights = torch.ones(4).to(device)
bias = torch.zeros(4).to(device)

_a_mps = torch.layer_norm(rand_tensor[0:1], (4,), weights, bias, 1e-5, False)
_b_mps = torch.layer_norm(rand_tensor[1:3], (4,), weights, bias, 1e-5, False)
_b_mps_clone = torch.layer_norm(rand_tensor[1:3].clone(), (4,), weights, bias, 1e-5, False)
_c_mps = torch.layer_norm(rand_tensor, (4,), weights, bias, 1e-5, False)
# should print 0, 0, 4 (as on CPU and CUDA)
# but on MPS it prints 4, 4, 4
print(torch.eq(_a_mps, _b_mps).sum(), torch.eq(_c_mps[0:1], _b_mps).sum(), torch.eq(_c_mps[1:2], _b_mps).sum())
# prints 0 (no element matches)
print(torch.eq(_b_mps, _b_mps_clone).sum())

device = "cpu"
rand_tensor = rand_tensor.to(device)
weights = weights.to(device)
bias = bias.to(device)

_a_cpu = torch.layer_norm(rand_tensor[0:1], (4,), weights, bias, 1e-5, False)
_b_cpu = torch.layer_norm(rand_tensor[1:3], (4,), weights, bias, 1e-5, False)
_b_cpu_clone = torch.layer_norm(rand_tensor[1:3].clone(), (4,), weights, bias, 1e-5, False)
_c_cpu = torch.layer_norm(rand_tensor, (4,), weights, bias, 1e-5, False)
# prints 0, 0, 4
print(torch.eq(_a_cpu, _b_cpu).sum(), torch.eq(_c_cpu[0:1], _b_cpu).sum(), torch.eq(_c_cpu[1:2], _b_cpu).sum())
# prints 8 (all elements match)
print(torch.eq(_b_cpu, _b_cpu_clone).sum())

# _a_mps matches _a_cpu
# _b_mps does NOT match _b_cpu
# _c_mps matches _c_cpu
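Until the kernel is fixed, one possible workaround, consistent with the `.clone()` observation above, is to copy views that do not start at the beginning of their storage before calling layer norm on MPS. `safe_layer_norm` below is a hypothetical helper, not a PyTorch API, and it assumes a nonzero `storage_offset()` is the trigger, which the report suggests but does not confirm. Note that `.contiguous()` would likely not help: `rand_tensor[1:3]` is already contiguous, so `.contiguous()` returns the same view with the same offset. The check compares against CPU with `torch.allclose` rather than exact equality, since MPS and CPU results may differ in the last bits.

```python
import torch
import torch.nn.functional as F


def safe_layer_norm(x, normalized_shape, weight=None, bias=None, eps=1e-5):
    """Hypothetical workaround: on MPS, copy views with a nonzero storage offset
    so the kernel sees data starting at offset 0 (mirrors the .clone() fix)."""
    if x.device.type == "mps" and x.storage_offset() != 0:
        x = x.clone()
    return F.layer_norm(x, normalized_shape, weight, bias, eps)


# Falls back to CPU so the sketch also runs on machines without MPS.
device = "mps" if torch.backends.mps.is_available() else "cpu"
torch.manual_seed(0)
rand_tensor = torch.randn(3, 4, device=device)
weights = torch.ones(4, device=device)
bias = torch.zeros(4, device=device)

view = rand_tensor[1:3]
print(view.is_contiguous(), view.storage_offset())  # True 4

out = safe_layer_norm(view, (4,), weights, bias)
ref = F.layer_norm(view.cpu(), (4,), weights.cpu(), bias.cpu(), 1e-5)
print(torch.allclose(out.cpu(), ref, atol=1e-5))
```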

Versions

PyTorch version: 2.4.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.5 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: Could not collect
Libc version: N/A

Python version: 3.12.4 (main, Jul 17 2024, 22:53:15) [Clang 15.0.0 (clang-1500.3.9.4)] (64-bit runtime)
Python platform: macOS-14.5-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M3 Pro

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] optree==0.12.1
[pip3] torch==2.4.0
[conda] Could not collect

cc @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen

Labels:
module: correctness (silent) (issue that returns an incorrect result silently)
module: mps (Related to Apple Metal Performance Shaders framework)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
