torch.compile on torch.arange hard errors with PendingUnbackedSymbolNotFound #158058

@StrongerXi

🐛 Describe the bug

This was discovered in #157499, which tries to turn on `capture_scalar_outputs` by default. Specifically, `python test/dynamo/test_repros.py ReproTests.test_do_paste_mask` fails. I made a minimal repro below; we should graph break rather than hard error in this case (if we can eliminate the graph break, even better).

import torch


torch._dynamo.config.capture_scalar_outputs = True  # the flag that #157499 turns on by default


@torch.compile(backend="eager", fullgraph=False)
def f(xs):
    x1, x2 = xs  # x1 and x2 are 0-dim int64 tensors
    res = torch.arange(x1, x2)  # hard errors here under compile; should graph break instead
    return res


xs = torch.tensor([2, 4])
res = f(xs)
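
For reference, eager mode handles the same inputs fine, which is what a graph-break fallback would give us. A minimal sketch (not part of the original repro, just the same operations run without compilation):

import torch

# Plain eager execution of the operations from the repro above.
xs = torch.tensor([2, 4])
x1, x2 = xs                    # 0-dim int64 tensors
print(torch.arange(x1, x2))    # tensor([2, 3])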

Error logs

Running with `TORCH_LOGS="+dynamo, +dynamic"`:

I0710 11:54:28.616000 4046875 torch/_dynamo/utils.py:1697] [0/0] ChromiumEventLogger initialized with id f19ae67f-33d7-46c0-86b1-398570aa8d39
V0710 11:54:28.617000 4046875 torch/_dynamo/convert_frame.py:1097] [0/0] torchdynamo start compiling f /home/ryanguo99/scratch/comp.py:5, stack (elided 5 frames):
V0710 11:54:28.617000 4046875 torch/_dynamo/convert_frame.py:1097] [0/0]   File "/home/ryanguo99/scratch/comp.py", line 13, in <module>
V0710 11:54:28.617000 4046875 torch/_dynamo/convert_frame.py:1097] [0/0]     res = f(xs)
V0710 11:54:28.617000 4046875 torch/_dynamo/convert_frame.py:1097] [0/0]
I0710 11:54:29.104000 4046875 torch/_dynamo/symbolic_convert.py:3340] [0/0] Step 1: torchdynamo start tracing f /home/ryanguo99/scratch/comp.py:5
I0710 11:54:29.105000 4046875 torch/fx/experimental/symbolic_shapes.py:3776] [0/0] create_env
V0710 11:54:29.107000 4046875 torch/_dynamo/symbolic_convert.py:1242] [0/0] [__trace_source] TRACE starts_line /home/ryanguo99/scratch/comp.py:5 in f ()
V0710 11:54:29.107000 4046875 torch/_dynamo/symbolic_convert.py:1242] [0/0] [__trace_source]     @torch.compile(backend="eager", fullgraph=True)
V0710 11:54:29.107000 4046875 torch/_dynamo/symbolic_convert.py:1267] [0/0] [__trace_bytecode] TRACE RESUME 0 []
V0710 11:54:29.107000 4046875 torch/_dynamo/symbolic_convert.py:1242] [0/0] [__trace_source] TRACE starts_line /home/ryanguo99/scratch/comp.py:7 in f (f)
V0710 11:54:29.107000 4046875 torch/_dynamo/symbolic_convert.py:1242] [0/0] [__trace_source]         x1, x2 = xs
V0710 11:54:29.108000 4046875 torch/_dynamo/symbolic_convert.py:1267] [0/0] [__trace_bytecode] TRACE LOAD_FAST xs []
V0710 11:54:29.108000 4046875 torch/_dynamo/symbolic_convert.py:1267] [0/0] [__trace_bytecode] TRACE UNPACK_SEQUENCE 2 [LazyVariableTracker()]
V0710 11:54:29.109000 4046875 torch/_dynamo/variables/builder.py:3438] [0/0] wrap_to_fake L['xs'] (2,) StatefulSymbolicContext(dynamic_sizes=[<DimDynamic.STATIC: 2>], dynamic_strides=[<DimDynamic.INFER_STRIDE: 4>], constraint_sizes=[None], constraint_strides=[None], specialize_on=[[]], view_base_context=None, tensor_source=LocalSource(local_name='xs', is_input=True, dynamism=None, is_derefed_cell_contents=False), shape_env_to_source_to_symbol_cache={}) <class 'torch.Tensor'>
V0710 11:54:29.109000 4046875 torch/_dynamo/output_graph.py:2629] [0/0] create_graph_input L_xs_ L['xs'] FakeTensor(..., size=(2,), dtype=torch.int64) at debug_level 0 before=False
V0710 11:54:29.110000 4046875 torch/_dynamo/output_graph.py:2477] [0/0] [__trace_call] TRACE FX call getitem from /home/ryanguo99/scratch/comp.py:7 in f (f)
V0710 11:54:29.110000 4046875 torch/_dynamo/output_graph.py:2477] [0/0] [__trace_call]     x1, x2 = xs
V0710 11:54:29.110000 4046875 torch/_dynamo/output_graph.py:2477] [0/0] [__trace_call]     ^^^^^^
V0710 11:54:29.112000 4046875 torch/_dynamo/symbolic_convert.py:1267] [0/0] [__trace_bytecode] TRACE STORE_FAST x1 [TensorVariable(), TensorVariable()]
V0710 11:54:29.112000 4046875 torch/_dynamo/symbolic_convert.py:1267] [0/0] [__trace_bytecode] TRACE STORE_FAST x2 [TensorVariable()]
V0710 11:54:29.112000 4046875 torch/_dynamo/symbolic_convert.py:1242] [0/0] [__trace_source] TRACE starts_line /home/ryanguo99/scratch/comp.py:8 in f (f)
V0710 11:54:29.112000 4046875 torch/_dynamo/symbolic_convert.py:1242] [0/0] [__trace_source]         res = torch.arange(x1, x2)
V0710 11:54:29.112000 4046875 torch/_dynamo/symbolic_convert.py:1267] [0/0] [__trace_bytecode] TRACE LOAD_GLOBAL torch []
V0710 11:54:29.112000 4046875 torch/_dynamo/symbolic_convert.py:1267] [0/0] [__trace_bytecode] TRACE LOAD_ATTR arange [NullVariable, LazyVariableTracker()]
V0710 11:54:29.113000 4046875 torch/_dynamo/symbolic_convert.py:1267] [0/0] [__trace_bytecode] TRACE LOAD_FAST x1 [NullVariable, LazyVariableTracker()]
V0710 11:54:29.113000 4046875 torch/_dynamo/symbolic_convert.py:1267] [0/0] [__trace_bytecode] TRACE LOAD_FAST x2 [NullVariable, LazyVariableTracker(), TensorVariable()]
V0710 11:54:29.114000 4046875 torch/_dynamo/symbolic_convert.py:1267] [0/0] [__trace_bytecode] TRACE PRECALL 2 [NullVariable, LazyVariableTracker(), TensorVariable(), TensorVariable()]
V0710 11:54:29.114000 4046875 torch/_dynamo/symbolic_convert.py:1267] [0/0] [__trace_bytecode] TRACE CALL 2 [NullVariable, LazyVariableTracker(), TensorVariable(), TensorVariable()]
V0710 11:54:29.119000 4046875 torch/_dynamo/output_graph.py:2477] [0/0] [__trace_call] TRACE FX call arange from /home/ryanguo99/scratch/comp.py:8 in f (f)
V0710 11:54:29.119000 4046875 torch/_dynamo/output_graph.py:2477] [0/0] [__trace_call]     res = torch.arange(x1, x2)
V0710 11:54:29.119000 4046875 torch/_dynamo/output_graph.py:2477] [0/0] [__trace_call]           ~~~~~~~~~~~~^^^^^^^^
I0710 11:54:29.123000 4046875 torch/fx/experimental/symbolic_shapes.py:4783] [0/0] create_unbacked_symint u0 [-int_oo, int_oo] res = torch.arange(x1, x2)  # scratch/comp.py:8 in f (_subclasses/fake_impls.py:425 in local_scalar_dense)
I0710 11:54:29.123000 4046875 torch/fx/experimental/symbolic_shapes.py:4783] [0/0] create_unbacked_symint u1 [-int_oo, int_oo] res = torch.arange(x1, x2)  # scratch/comp.py:8 in f (_subclasses/fake_impls.py:425 in local_scalar_dense)
I0710 11:54:29.168000 4046875 torch/fx/experimental/symbolic_shapes.py:7211] [0/0] runtime_assert u0 >= u1 [guard added] res = torch.arange(x1, x2)  # scratch/comp.py:8 in f (_refs/__init__.py:5105 in arange), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="u0 >= u1"
V0710 11:54:29.170000 4046875 torch/fx/experimental/symbolic_shapes.py:7708] [0/0] runtime_assert u0 - u1 >= 0 == True [statically known]
I0710 11:54:29.172000 4046875 torch/fx/experimental/symbolic_shapes.py:7383] [0/0] could not evaluate Eq(u0 - u1, 0) due to data dependency, it was assumed to be False with no runtime assertions res = torch.arange(x1, x2)  # scratch/comp.py:8 in f (utils/_stats.py:28 in wrapper)
I0710 11:54:29.172000 4046875 torch/fx/experimental/symbolic_shapes.py:7383] [0/0] User Stack (most recent call last):
I0710 11:54:29.172000 4046875 torch/fx/experimental/symbolic_shapes.py:7383] [0/0]   (snipped, see stack below for prefix)
I0710 11:54:29.172000 4046875 torch/fx/experimental/symbolic_shapes.py:7383] [0/0]   File "/home/ryanguo99/scratch/comp.py", line 8, in f
I0710 11:54:29.172000 4046875 torch/fx/experimental/symbolic_shapes.py:7383] [0/0]     res = torch.arange(x1, x2)
I0710 11:54:29.172000 4046875 torch/fx/experimental/symbolic_shapes.py:7383] [0/0]
I0710 11:54:29.172000 4046875 torch/fx/experimental/symbolic_shapes.py:7383] [0/0] For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
I0710 11:54:29.174000 4046875 torch/fx/experimental/symbolic_shapes.py:1290] [0/0] compute_unbacked_bindings [u0, u1]
I0710 11:54:29.175000 4046875 torch/_dynamo/convert_frame.py:1225] [0/0] run_gc_after_compile: running gc
Traceback (most recent call last):
  File "/home/ryanguo99/scratch/comp.py", line 13, in <module>
    res = f(xs)
          ^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_dynamo/eval_frame.py", line 784, in compile_wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_dynamo/convert_frame.py", line 1578, in __call__
    return self._torchdynamo_orig_backend(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_dynamo/convert_frame.py", line 1336, in __call__
    result = self._inner_convert(
             ^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_dynamo/convert_frame.py", line 651, in __call__
    result = _compile(
             ^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_dynamo/convert_frame.py", line 1214, in _compile
    raise InternalTorchDynamoError(
  File "/home/ryanguo99/repos/pytorch/torch/_dynamo/convert_frame.py", line 1153, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_utils_internal.py", line 98, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_dynamo/convert_frame.py", line 827, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_dynamo/convert_frame.py", line 866, in _compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_dynamo/bytecode_transformation.py", line 1424, in transform_code_object
    transformations(instructions, code_options)
  File "/home/ryanguo99/repos/pytorch/torch/_dynamo/convert_frame.py", line 276, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_dynamo/convert_frame.py", line 787, in transform
    tracer.run()
  File "/home/ryanguo99/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 3517, in run
    super().run()
  File "/home/ryanguo99/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 1370, in run
    while self.step():
          ^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 1274, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/ryanguo99/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 851, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 2921, in CALL
    self._call(inst)
  File "/home/ryanguo99/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 2915, in _call
    self.call_function(fn, args, kwargs)
  File "/home/ryanguo99/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 1198, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_dynamo/variables/lazy.py", line 201, in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_dynamo/variables/torch.py", line 1462, in call_function
    tensor_variable = wrap_fx_proxy(
                      ^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_dynamo/variables/builder.py", line 2618, in wrap_fx_proxy                                                                                                                               [0/13272]
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_dynamo/variables/builder.py", line 2684, in wrap_fx_proxy_cls
    return _wrap_fx_proxy(
           ^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_dynamo/variables/builder.py", line 2784, in _wrap_fx_proxy
    return handle_traced_output(
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_dynamo/variables/builder.py", line 2796, in handle_traced_output
    var = construct_tensor_variable(
          ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_dynamo/variables/builder.py", line 2999, in construct_tensor_variable
    set_example_value(proxy.node, example_value)
  File "/home/ryanguo99/repos/pytorch/torch/_dynamo/utils.py", line 2614, in set_example_value
    := torch.fx.experimental.symbolic_shapes.compute_unbacked_bindings(
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/fx/experimental/symbolic_shapes.py", line 1302, in compute_unbacked_bindings
    raise PendingUnbackedSymbolNotFound(
torch._dynamo.exc.InternalTorchDynamoError: PendingUnbackedSymbolNotFound: Pending unbacked symbols {u1, u0} not in returned outputs FakeTensor(..., size=(u0 - u1,), dtype=torch.int64) ((1,), 0).
Did you accidentally call new_dynamic_size() or item() more times than you needed to in your fake implementation?
For more help, see https://docs.google.com/document/d/1RWrH-3wLEpzR9kCS6gGBNen_-Fs-8PVbWWFE5AcgeWE/edit

from user code:
   File "/home/ryanguo99/scratch/comp.py", line 8, in f
    res = torch.arange(x1, x2)

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

I0710 11:54:29.181000 4046875 torch/_dynamo/eval_frame.py:540] TorchDynamo attempted to trace the following frames: [
I0710 11:54:29.181000 4046875 torch/_dynamo/eval_frame.py:540]   * f /home/ryanguo99/scratch/comp.py:5
I0710 11:54:29.181000 4046875 torch/_dynamo/eval_frame.py:540] ]

Versions

main (fcc682b), Python 3.11

cc @chauhang @penguinwu @ezyang @bobrenjc93 @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames
