
[inductor] [silence] inconsistent swap with eager when compiling torch.rot90-torch.randn_like #147847

@shaoyuyoung

Description


🐛 Describe the bug

Description: this bug is triggered only when torch.rot90 and torch.randn_like are used together. In my case, you can see that inductor swaps the second element (-2.1788) and the third element (-0.2934) of the torch.randn_like output compared with eager.
Device backends: both Triton and CPP.
Note: I have set config.fallback_random = True and called torch.manual_seed(0).

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch._inductor import config

config.fallback_random = True  # make inductor fall back to eager's random ops so values can match
torch.set_grad_enabled(False)
torch.manual_seed(0)


class Model(torch.nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x):
        torch.manual_seed(0)
        x = torch.rot90(x, k=1, dims=[2, 3])
        print(x)
        x = torch.randn_like(x)  # defaults to memory_format=torch.preserve_format
        print(x)
        return x


model = Model()


x = torch.randn(1, 1, 2, 2)

inputs = [x]


def run_test(model, inputs, backend):
    if backend != "eager":
        model = torch.compile(model, backend=backend)
    torch.manual_seed(0)
    output = model(*inputs)
    return output


output = run_test(model, inputs, 'eager')
c_output = run_test(model, inputs, 'inductor')

print(torch.allclose(output, c_output, 1e-3, 1e-3, equal_nan=True))
print(torch.max(torch.abs(output - c_output)))
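A smaller, eager-only variant of the repro may help narrow this down. It is a sketch built on an assumption this report does not confirm: that the swap comes from memory layout rather than from the random values themselves. In eager mode, torch.rot90 with k=1 returns a tensor with transposed (non-contiguous) strides, and torch.randn_like defaults to memory_format=torch.preserve_format, so its output keeps those strides. The snippet draws the same seeded stream once with that default and once with memory_format=torch.contiguous_format; the names r, a and b are only illustrative.

```python
import torch

# Same input as the repro: seed 0, shape (1, 1, 2, 2).
torch.manual_seed(0)
x = torch.randn(1, 1, 2, 2)

# rot90 over the last two dims; in eager the result has transposed strides.
r = torch.rot90(x, k=1, dims=[2, 3])
print("rot90 strides:", r.stride(), "contiguous:", r.is_contiguous())

# Default memory_format=torch.preserve_format: the output copies r's strides.
torch.manual_seed(0)
a = torch.randn_like(r)

# Forcing a contiguous layout: same seed, same random stream, different layout.
torch.manual_seed(0)
b = torch.randn_like(r, memory_format=torch.contiguous_format)

print("preserve_format strides:", a.stride())
print(a)
print("contiguous_format strides:", b.stride())
print(b)

# Hypothesis (not confirmed in this issue): if both fills consume the stream in
# memory order, the results differ only by a transpose of dims 2 and 3,
# i.e. the 2nd and 3rd elements trade places.
print("a equals b transposed:", torch.equal(a, b.transpose(2, 3)))
```

If a and b print like the eager and inductor torch.randn_like outputs in the logs, one candidate explanation is that the compiled graph allocates the randn_like output contiguously instead of preserving rot90's strides; this is a hypothesis for triage, not a confirmed root cause.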

Error logs

tensor([[[[-0.2934,  0.5684],
          [ 1.5410, -2.1788]]]])
tensor([[[[ 1.5410, -2.1788],
          [-0.2934,  0.5684]]]])
tensor([[[[-0.2934,  0.5684],
          [ 1.5410, -2.1788]]]])
tensor([[[[ 1.5410, -0.2934],
          [-2.1788,  0.5684]]]])
False
tensor(1.8854)
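As a quick check on the logs, the sketch below rebuilds the two torch.randn_like outputs from their printed values (4 decimal places) and confirms that the compiled result is the eager result with dims 2 and 3 transposed, which is the second/third element swap described above, with the same maximum difference as reported. The names eager_out and compiled_out are illustrative.

```python
import torch

# torch.randn_like outputs copied from the error logs (rounded to 4 decimals).
eager_out = torch.tensor([[[[1.5410, -2.1788],
                            [-0.2934, 0.5684]]]])
compiled_out = torch.tensor([[[[1.5410, -0.2934],
                               [-2.1788, 0.5684]]]])

# Same set of values: only their positions differ.
same_values = torch.equal(eager_out.flatten().sort().values,
                          compiled_out.flatten().sort().values)

# Swapping the 2nd and 3rd elements of a 2x2 block is a transpose of dims 2 and 3.
is_transpose = torch.equal(compiled_out, eager_out.transpose(2, 3))

# Largest elementwise difference, comparable to the reported tensor(1.8854).
max_diff = (eager_out - compiled_out).abs().max()

print(same_values, is_transpose, max_diff)
```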

Versions

PyTorch nightly 20250225

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @chauhang @penguinwu @bdhirsh @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @muchulee8 @amjames @aakhundov

Labels: high priority, module: aotdispatch, module: pt2-dispatcher, oncall: pt2, triaged
