torch.compile on .sum() and .item() calls errors from tensorify_python_scalars #158083

@StrongerXi

🐛 Describe the bug

This was discovered in #157499, which tries to turn on capture_scalar_outputs by default. Specifically, running TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1 python test/dynamo/test_utils.py TestUtils.test_graph_break_counting fails. A minimal repro is below.

Note that it passes with backend="eager".

import torch


torch._dynamo.config.capture_scalar_outputs = True
@torch.compile(backend="aot_eager")
def f(x):
    y = x.sum()
    return x + y.item()


x = torch.ones(3, 3)
res = f(x)

Error logs

Traceback (most recent call last):
  File "/home/ryanguo99/scratch/comp.py", line 11, in <module>
    res = f(x)
          ^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_dynamo/eval_frame.py", line 797, in compile_wrapper
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_dynamo/output_graph.py", line 1886, in _call_user_compiler
    raise BackendCompilerFailed(
  File "/home/ryanguo99/repos/pytorch/torch/_dynamo/output_graph.py", line 1861, in _call_user_compiler
    compiled_fn = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_dynamo/repro/after_dynamo.py", line 150, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/__init__.py", line 2437, in __call__
    return self.compiler_fn(model_, inputs_, **self.kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_dynamo/backends/debugging.py", line 182, in aot_eager
    return aot_autograd(
           ^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_dynamo/backends/common.py", line 109, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_functorch/aot_autograd.py", line 1212, in aot_module_simplified
    compiled_fn = dispatch_and_compile()
                  ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_functorch/aot_autograd.py", line 1184, in dispatch_and_compile
    compiled_fn, _ = create_aot_dispatcher_function(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_functorch/aot_autograd.py", line 575, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_functorch/aot_autograd.py", line 836, in _create_aot_dispatcher_function
    compiled_fn, fw_metadata = compiler_fn(
                               ^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 245, in aot_dispatch_base
    tensorify_python_scalars(fw_module, fake_mode.shape_env, fake_mode)
  File "/home/ryanguo99/repos/pytorch/torch/fx/passes/_tensorify_python_scalars.py", line 262, in tensorify_python_scalars
    proxy = _sympy_interp(zf.node.expr)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/fx/passes/_tensorify_python_scalars.py", line 149, in _sympy_interp
    expr_to_sym_proxy[expr]
    ~~~~~~~~~~~~~~~~~^^^^^^
torch._dynamo.exc.BackendCompilerFailed: backend='aot_eager' raised:
KeyError: zuf0

Versions

main fcc682b, python 3.11

cc @chauhang @penguinwu @ezyang @bobrenjc93
