Open
Labels: module: dynamic shapes, module: dynamo, oncall: pt2, triaged
Description
🐛 Describe the bug
This was discovered in #157499, which tries to turn on `capture_scalar_outputs` by default. Specifically, `python test/dynamo/test_repros.py ReproTests.test_do_paste_mask` fails. I made a minimal repro below. We should graph break rather than hard error in this case (if we can eliminate the graph break, even better).
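```python
import torch

# Trace .item()-style scalar extraction into the graph instead of graph breaking.
torch._dynamo.config.capture_scalar_outputs = True

@torch.compile(backend="eager", fullgraph=False)
def f(xs):
    x1, x2 = xs  # two 0-d int64 tensors
    res = torch.arange(x1, x2)  # start/end come from tensor data
    return res

xs = torch.tensor([2, 4])
res = f(xs)  # eager result would be [2, 3]; here it raises InternalTorchDynamoError
```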
Error logs
Running with `TORCH_LOGS="+dynamo, +dynamic"`:
I0710 11:54:28.616000 4046875 torch/_dynamo/utils.py:1697] [0/0] ChromiumEventLogger initialized with id f19ae67f-33d7-46c0-86b1-398570aa8d39
V0710 11:54:28.617000 4046875 torch/_dynamo/convert_frame.py:1097] [0/0] torchdynamo start compiling f /home/ryanguo99/scratch/comp.py:5, stack (elided 5 frames):
V0710 11:54:28.617000 4046875 torch/_dynamo/convert_frame.py:1097] [0/0] File "/home/ryanguo99/scratch/comp.py", line 13, in <module>
V0710 11:54:28.617000 4046875 torch/_dynamo/convert_frame.py:1097] [0/0] res = f(xs)
V0710 11:54:28.617000 4046875 torch/_dynamo/convert_frame.py:1097] [0/0]
I0710 11:54:29.104000 4046875 torch/_dynamo/symbolic_convert.py:3340] [0/0] Step 1: torchdynamo start tracing f /home/ryanguo99/scratch/comp.py:5
I0710 11:54:29.105000 4046875 torch/fx/experimental/symbolic_shapes.py:3776] [0/0] create_env
V0710 11:54:29.107000 4046875 torch/_dynamo/symbolic_convert.py:1242] [0/0] [__trace_source] TRACE starts_line /home/ryanguo99/scratch/comp.py:5 in f ()
V0710 11:54:29.107000 4046875 torch/_dynamo/symbolic_convert.py:1242] [0/0] [__trace_source] @torch.compile(backend="eager", fullgraph=True)
V0710 11:54:29.107000 4046875 torch/_dynamo/symbolic_convert.py:1267] [0/0] [__trace_bytecode] TRACE RESUME 0 []
V0710 11:54:29.107000 4046875 torch/_dynamo/symbolic_convert.py:1242] [0/0] [__trace_source] TRACE starts_line /home/ryanguo99/scratch/comp.py:7 in f (f)
V0710 11:54:29.107000 4046875 torch/_dynamo/symbolic_convert.py:1242] [0/0] [__trace_source] x1, x2 = xs
V0710 11:54:29.108000 4046875 torch/_dynamo/symbolic_convert.py:1267] [0/0] [__trace_bytecode] TRACE LOAD_FAST xs []
V0710 11:54:29.108000 4046875 torch/_dynamo/symbolic_convert.py:1267] [0/0] [__trace_bytecode] TRACE UNPACK_SEQUENCE 2 [LazyVariableTracker()]
V0710 11:54:29.109000 4046875 torch/_dynamo/variables/builder.py:3438] [0/0] wrap_to_fake L['xs'] (2,) StatefulSymbolicContext(dynamic_sizes=[<DimDynamic.STATIC: 2>], dynamic_strides=[<DimDynamic.INFER_STRIDE: 4>], constraint_sizes=[None], constraint_strides=[None], specialize_on=[[]], view_base_context=None, tensor_source=LocalSource(local_name='xs', is_input=True, dynamism=None, is_derefed_cell_contents=False), shape_env_to_source_to_symbol_cache={}) <class 'torch.Tensor'>
V0710 11:54:29.109000 4046875 torch/_dynamo/output_graph.py:2629] [0/0] create_graph_input L_xs_ L['xs'] FakeTensor(..., size=(2,), dtype=torch.int64) at debug_level 0 before=False
V0710 11:54:29.110000 4046875 torch/_dynamo/output_graph.py:2477] [0/0] [__trace_call] TRACE FX call getitem from /home/ryanguo99/scratch/comp.py:7 in f (f)
V0710 11:54:29.110000 4046875 torch/_dynamo/output_graph.py:2477] [0/0] [__trace_call] x1, x2 = xs
V0710 11:54:29.110000 4046875 torch/_dynamo/output_graph.py:2477] [0/0] [__trace_call] ^^^^^^
V0710 11:54:29.112000 4046875 torch/_dynamo/symbolic_convert.py:1267] [0/0] [__trace_bytecode] TRACE STORE_FAST x1 [TensorVariable(), TensorVariable()]
V0710 11:54:29.112000 4046875 torch/_dynamo/symbolic_convert.py:1267] [0/0] [__trace_bytecode] TRACE STORE_FAST x2 [TensorVariable()]
V0710 11:54:29.112000 4046875 torch/_dynamo/symbolic_convert.py:1242] [0/0] [__trace_source] TRACE starts_line /home/ryanguo99/scratch/comp.py:8 in f (f)
V0710 11:54:29.112000 4046875 torch/_dynamo/symbolic_convert.py:1242] [0/0] [__trace_source] res = torch.arange(x1, x2)
V0710 11:54:29.112000 4046875 torch/_dynamo/symbolic_convert.py:1267] [0/0] [__trace_bytecode] TRACE LOAD_GLOBAL torch []
V0710 11:54:29.112000 4046875 torch/_dynamo/symbolic_convert.py:1267] [0/0] [__trace_bytecode] TRACE LOAD_ATTR arange [NullVariable, LazyVariableTracker()]
V0710 11:54:29.113000 4046875 torch/_dynamo/symbolic_convert.py:1267] [0/0] [__trace_bytecode] TRACE LOAD_FAST x1 [NullVariable, LazyVariableTracker()]
V0710 11:54:29.113000 4046875 torch/_dynamo/symbolic_convert.py:1267] [0/0] [__trace_bytecode] TRACE LOAD_FAST x2 [NullVariable, LazyVariableTracker(), TensorVariable()]
V0710 11:54:29.114000 4046875 torch/_dynamo/symbolic_convert.py:1267] [0/0] [__trace_bytecode] TRACE PRECALL 2 [NullVariable, LazyVariableTracker(), TensorVariable(), TensorVariable()]
V0710 11:54:29.114000 4046875 torch/_dynamo/symbolic_convert.py:1267] [0/0] [__trace_bytecode] TRACE CALL 2 [NullVariable, LazyVariableTracker(), TensorVariable(), TensorVariable()]
V0710 11:54:29.119000 4046875 torch/_dynamo/output_graph.py:2477] [0/0] [__trace_call] TRACE FX call arange from /home/ryanguo99/scratch/comp.py:8 in f (f)
V0710 11:54:29.119000 4046875 torch/_dynamo/output_graph.py:2477] [0/0] [__trace_call] res = torch.arange(x1, x2)
V0710 11:54:29.119000 4046875 torch/_dynamo/output_graph.py:2477] [0/0] [__trace_call] ~~~~~~~~~~~~^^^^^^^^
I0710 11:54:29.123000 4046875 torch/fx/experimental/symbolic_shapes.py:4783] [0/0] create_unbacked_symint u0 [-int_oo, int_oo] res = torch.arange(x1, x2) # scratch/comp.py:8 in f (_subclasses/fake_impls.py:425 in local_scalar_dense)
I0710 11:54:29.123000 4046875 torch/fx/experimental/symbolic_shapes.py:4783] [0/0] create_unbacked_symint u1 [-int_oo, int_oo] res = torch.arange(x1, x2) # scratch/comp.py:8 in f (_subclasses/fake_impls.py:425 in local_scalar_dense)
I0710 11:54:29.168000 4046875 torch/fx/experimental/symbolic_shapes.py:7211] [0/0] runtime_assert u0 >= u1 [guard added] res = torch.arange(x1, x2) # scratch/comp.py:8 in f (_refs/__init__.py:5105 in arange), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="u0 >= u1"
V0710 11:54:29.170000 4046875 torch/fx/experimental/symbolic_shapes.py:7708] [0/0] runtime_assert u0 - u1 >= 0 == True [statically known]
I0710 11:54:29.172000 4046875 torch/fx/experimental/symbolic_shapes.py:7383] [0/0] could not evaluate Eq(u0 - u1, 0) due to data dependency, it was assumed to be False with no runtime assertions res = torch.arange(x1, x2) # scratch/comp.py:8 in f (utils/_stats.py:28 in wrapper)
I0710 11:54:29.172000 4046875 torch/fx/experimental/symbolic_shapes.py:7383] [0/0] User Stack (most recent call last):
I0710 11:54:29.172000 4046875 torch/fx/experimental/symbolic_shapes.py:7383] [0/0] (snipped, see stack below for prefix)
I0710 11:54:29.172000 4046875 torch/fx/experimental/symbolic_shapes.py:7383] [0/0] File "/home/ryanguo99/scratch/comp.py", line 8, in f
I0710 11:54:29.172000 4046875 torch/fx/experimental/symbolic_shapes.py:7383] [0/0] res = torch.arange(x1, x2)
I0710 11:54:29.172000 4046875 torch/fx/experimental/symbolic_shapes.py:7383] [0/0]
I0710 11:54:29.172000 4046875 torch/fx/experimental/symbolic_shapes.py:7383] [0/0] For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
I0710 11:54:29.174000 4046875 torch/fx/experimental/symbolic_shapes.py:1290] [0/0] compute_unbacked_bindings [u0, u1]
I0710 11:54:29.175000 4046875 torch/_dynamo/convert_frame.py:1225] [0/0] run_gc_after_compile: running gc
Traceback (most recent call last):
File "/home/ryanguo99/scratch/comp.py", line 13, in <module>
res = f(xs)
^^^^^
File "/home/ryanguo99/repos/pytorch/torch/_dynamo/eval_frame.py", line 784, in compile_wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/_dynamo/convert_frame.py", line 1578, in __call__
return self._torchdynamo_orig_backend(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/_dynamo/convert_frame.py", line 1336, in __call__
result = self._inner_convert(
^^^^^^^^^^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/_dynamo/convert_frame.py", line 651, in __call__
result = _compile(
^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/_dynamo/convert_frame.py", line 1214, in _compile
raise InternalTorchDynamoError(
File "/home/ryanguo99/repos/pytorch/torch/_dynamo/convert_frame.py", line 1153, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/_utils_internal.py", line 98, in wrapper_function
return function(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/_dynamo/convert_frame.py", line 827, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/_dynamo/convert_frame.py", line 866, in _compile_inner
out_code = transform_code_object(code, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/_dynamo/bytecode_transformation.py", line 1424, in transform_code_object
transformations(instructions, code_options)
File "/home/ryanguo99/repos/pytorch/torch/_dynamo/convert_frame.py", line 276, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/_dynamo/convert_frame.py", line 787, in transform
tracer.run()
File "/home/ryanguo99/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 3517, in run
super().run()
File "/home/ryanguo99/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 1370, in run
while self.step():
^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 1274, in step
self.dispatch_table[inst.opcode](self, inst)
File "/home/ryanguo99/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 851, in wrapper
return inner_fn(self, inst)
^^^^^^^^^^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 2921, in CALL
self._call(inst)
File "/home/ryanguo99/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 2915, in _call
self.call_function(fn, args, kwargs)
File "/home/ryanguo99/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 1198, in call_function
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/_dynamo/variables/lazy.py", line 201, in realize_and_forward
return getattr(self.realize(), name)(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/_dynamo/variables/torch.py", line 1462, in call_function
tensor_variable = wrap_fx_proxy(
^^^^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/_dynamo/variables/builder.py", line 2618, in wrap_fx_proxy
return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/_dynamo/variables/builder.py", line 2684, in wrap_fx_proxy_cls
return _wrap_fx_proxy(
^^^^^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/_dynamo/variables/builder.py", line 2784, in _wrap_fx_proxy
return handle_traced_output(
^^^^^^^^^^^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/_dynamo/variables/builder.py", line 2796, in handle_traced_output
var = construct_tensor_variable(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/_dynamo/variables/builder.py", line 2999, in construct_tensor_variable
set_example_value(proxy.node, example_value)
File "/home/ryanguo99/repos/pytorch/torch/_dynamo/utils.py", line 2614, in set_example_value
:= torch.fx.experimental.symbolic_shapes.compute_unbacked_bindings(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/fx/experimental/symbolic_shapes.py", line 1302, in compute_unbacked_bindings
raise PendingUnbackedSymbolNotFound(
torch._dynamo.exc.InternalTorchDynamoError: PendingUnbackedSymbolNotFound: Pending unbacked symbols {u1, u0} not in returned outputs FakeTensor(..., size=(u0 - u1,), dtype=torch.int64) ((1,), 0).
Did you accidentally call new_dynamic_size() or item() more times than you needed to in your fake implementation?
For more help, see https://docs.google.com/document/d/1RWrH-3wLEpzR9kCS6gGBNen_-Fs-8PVbWWFE5AcgeWE/edit
from user code:
File "/home/ryanguo99/scratch/comp.py", line 8, in f
res = torch.arange(x1, x2)
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
I0710 11:54:29.181000 4046875 torch/_dynamo/eval_frame.py:540] TorchDynamo attempted to trace the following frames: [
I0710 11:54:29.181000 4046875 torch/_dynamo/eval_frame.py:540] * f /home/ryanguo99/scratch/comp.py:5
I0710 11:54:29.181000 4046875 torch/_dynamo/eval_frame.py:540] ]
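The error says the pending unbacked symbols `u0` and `u1` (created when the fake `arange` implementation reads `x1` and `x2`) were not found in the returned output, whose size is `u0 - u1` (`torch.arange(start, end)` with the default step of 1 has length `end - start`). The toy sketch below models that check under a simplifying assumption: a pending symbol counts as bound only when some output dimension is exactly that symbol, so it can be read back from the result at runtime. `Sym`, `Sub`, `free_symbols`, and `bind_unbacked` are hypothetical names; this is not PyTorch's `compute_unbacked_bindings`, which handles more cases.

```python
from dataclasses import dataclass
from typing import Iterable, Set, Tuple, Union


# Hypothetical, simplified stand-ins for symbolic size expressions.
@dataclass(frozen=True)
class Sym:
    name: str


@dataclass(frozen=True)
class Sub:
    lhs: "Expr"
    rhs: "Expr"


Expr = Union[int, Sym, Sub]


def free_symbols(e: Expr) -> Set[str]:
    """Names of all symbols appearing anywhere in the expression."""
    if isinstance(e, Sym):
        return {e.name}
    if isinstance(e, Sub):
        return free_symbols(e.lhs) | free_symbols(e.rhs)
    return set()


def bind_unbacked(pending: Iterable[str], out_shape: Tuple[Expr, ...]) -> Set[str]:
    """Toy binding check: a pending symbol is bound only if some output
    dimension is exactly that bare symbol; otherwise raise."""
    pending = set(pending)
    bound = {d.name for d in out_shape if isinstance(d, Sym)}
    missing = pending - bound
    if missing:
        raise RuntimeError(
            f"Pending unbacked symbols {sorted(missing)} not in returned outputs "
            f"(shape={out_shape})"
        )
    return bound & pending


u0, u1 = Sym("u0"), Sym("u1")
arange_shape = (Sub(u0, u1),)  # size of arange(start=u1, end=u0), as in the log

# Both symbols occur in the output size, but only inside an expression...
print(free_symbols(arange_shape[0]))
# ...so neither can be read back directly, and the toy check fails.
try:
    bind_unbacked({"u0", "u1"}, arange_shape)
except RuntimeError as err:
    print(err)
```

In this model, an output whose dimensions include the bare symbols (for example `(u0, 3)`) binds without error, which is why an expression-valued size like `u0 - u1` is the problematic case here.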
Versions
main fcc682b, python 3.11
cc @chauhang @penguinwu @ezyang @bobrenjc93 @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames