[dtensor][cp] experiment: register flex_attention to a custom fn on DTensor #147515
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/147515. Note: links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures, 1 Unrelated Failure as of commit 42882bc with merge base 5a7588f.
NEW FAILURES - the following jobs have failed.
UNSTABLE - the following job is marked as unstable, possibly due to flakiness on trunk.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
### Summary
Attempted to dispatch flex_attention on DTensor to a custom CP (context parallel) flex_attention function, but hit the error below. This error should be identical to #146994.

```
E0220 00:44:53.839000 1006342 torch/testing/_internal/common_distributed.py:733] Caught exception:
Traceback (most recent call last):
  File "/data/users/xilunwu/oss/pytorch/torch/testing/_internal/common_distributed.py", line 726, in run_test
    getattr(self, test_name)()
  File "/data/users/xilunwu/oss/pytorch/torch/testing/_internal/common_distributed.py", line 599, in wrapper
    fn()
  File "/data/users/xilunwu/oss/pytorch/torch/testing/_internal/common_utils.py", line 3155, in wrapper
    method(*args, **kwargs)
  File "/data/users/xilunwu/oss/pytorch/torch/testing/_internal/distributed/_tensor/common_dtensor.py", line 405, in wrapper
    raise e
  File "/data/users/xilunwu/oss/pytorch/torch/testing/_internal/distributed/_tensor/common_dtensor.py", line 402, in wrapper
    func(self, *args, **kwargs)  # type: ignore[misc]
  File "/data/users/xilunwu/oss/pytorch/test/distributed/tensor/test_attention.py", line 493, in test_ring_flex_attention
    out_dt = flex_attention(q_dist, k_dist, v_dist, block_mask=block_mask)
  File "/data/users/xilunwu/oss/pytorch/torch/_dynamo/eval_frame.py", line 589, in _fn
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
  File "/data/users/xilunwu/oss/pytorch/torch/_dynamo/output_graph.py", line 1509, in _call_user_compiler
    raise BackendCompilerFailed(
  File "/data/users/xilunwu/oss/pytorch/torch/_dynamo/output_graph.py", line 1488, in _call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/data/users/xilunwu/oss/pytorch/torch/_dynamo/repro/after_dynamo.py", line 150, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/data/users/xilunwu/oss/pytorch/torch/__init__.py", line 2339, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/data/users/xilunwu/oss/pytorch/torch/_inductor/compile_fx.py", line 2168, in compile_fx
    return aot_autograd(
  File "/data/users/xilunwu/oss/pytorch/torch/_dynamo/backends/common.py", line 101, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
  File "/data/users/xilunwu/oss/pytorch/torch/_functorch/aot_autograd.py", line 1158, in aot_module_simplified
    compiled_fn = AOTAutogradCache.load(
  File "/data/users/xilunwu/oss/pytorch/torch/_functorch/_aot_autograd/autograd_cache.py", line 779, in load
    compiled_fn = dispatch_and_compile()
  File "/data/users/xilunwu/oss/pytorch/torch/_functorch/aot_autograd.py", line 1143, in dispatch_and_compile
    compiled_fn, _ = create_aot_dispatcher_function(
  File "/data/users/xilunwu/oss/pytorch/torch/_functorch/aot_autograd.py", line 570, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
  File "/data/users/xilunwu/oss/pytorch/torch/_functorch/aot_autograd.py", line 671, in _create_aot_dispatcher_function
    fw_metadata = run_functionalized_fw_and_collect_metadata(
  File "/data/users/xilunwu/oss/pytorch/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 197, in inner
    flat_f_outs = f(*flat_f_args)
  File "/data/users/xilunwu/oss/pytorch/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 899, in functional_call
    out = PropagateUnbackedSymInts(mod).run(
  File "/data/users/xilunwu/oss/pytorch/torch/fx/interpreter.py", line 171, in run
    self.env[node] = self.run_node(node)
  File "/data/users/xilunwu/oss/pytorch/torch/fx/experimental/symbolic_shapes.py", line 7084, in run_node
    result = super().run_node(n)
  File "/data/users/xilunwu/oss/pytorch/torch/fx/interpreter.py", line 236, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/data/users/xilunwu/oss/pytorch/torch/fx/interpreter.py", line 316, in call_function
    return target(*args, **kwargs)
  File "/data/users/xilunwu/oss/pytorch/torch/_higher_order_ops/flex_attention.py", line 92, in __call__
    return super().__call__(
  File "/data/users/xilunwu/oss/pytorch/torch/_ops.py", line 471, in __call__
    return wrapper()
  File "/data/users/xilunwu/oss/pytorch/torch/_ops.py", line 467, in wrapper
    return self.dispatch(
  File "/data/users/xilunwu/oss/pytorch/torch/_ops.py", line 327, in dispatch
    return kernel(*args, **kwargs)
  File "/data/users/xilunwu/oss/pytorch/torch/_higher_order_ops/flex_attention.py", line 744, in flex_attention_autograd
    out, logsumexp = FlexAttentionAutogradOp.apply(
  File "/data/users/xilunwu/oss/pytorch/torch/autograd/function.py", line 575, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/data/users/xilunwu/oss/pytorch/torch/_higher_order_ops/flex_attention.py", line 610, in forward
    out, logsumexp = flex_attention(
  File "/data/users/xilunwu/oss/pytorch/torch/_higher_order_ops/flex_attention.py", line 92, in __call__
    return super().__call__(
  File "/data/users/xilunwu/oss/pytorch/torch/_ops.py", line 471, in __call__
    return wrapper()
  File "/data/users/xilunwu/oss/pytorch/torch/_ops.py", line 462, in wrapper
    return torch.overrides.handle_torch_function(
  File "/data/users/xilunwu/oss/pytorch/torch/overrides.py", line 1721, in handle_torch_function
    result = mode.__torch_function__(public_api, types, args, kwargs)
  File "/data/users/xilunwu/oss/pytorch/torch/_dynamo/_trace_wrapped_higher_order_op.py", line 142, in __torch_function__
    return func(*args, **(kwargs or {}))
  File "/data/users/xilunwu/oss/pytorch/torch/_higher_order_ops/flex_attention.py", line 92, in __call__
    return super().__call__(
  File "/data/users/xilunwu/oss/pytorch/torch/_ops.py", line 471, in __call__
    return wrapper()
  File "/data/users/xilunwu/oss/pytorch/torch/_ops.py", line 467, in wrapper
    return self.dispatch(
  File "/data/users/xilunwu/oss/pytorch/torch/_ops.py", line 363, in dispatch
    result = handler(mode, *args, **kwargs)
  File "/data/users/xilunwu/oss/pytorch/torch/_ops.py", line 179, in functionalize_dispatch_mode_fn
    return fn(PythonFunctionalizeAPI(mode), *args, **kwargs)
  File "/data/users/xilunwu/oss/pytorch/torch/_higher_order_ops/flex_attention.py", line 415, in flex_attention_functionalize
    [query_unwrapped.new_zeros(())]
  File "/data/users/xilunwu/oss/pytorch/torch/_compile.py", line 51, in inner
    return disable_fn(*args, **kwargs)
  File "/data/users/xilunwu/oss/pytorch/torch/_dynamo/eval_frame.py", line 764, in _fn
    return fn(*args, **kwargs)
  File "/data/users/xilunwu/oss/pytorch/torch/distributed/tensor/_api.py", line 348, in __torch_dispatch__
    return DTensor._op_dispatcher.dispatch(
  File "/data/users/xilunwu/oss/pytorch/torch/distributed/tensor/_dispatch.py", line 221, in dispatch
    local_results = op_call(*local_tensor_args, **op_info.local_kwargs)
  File "/data/users/xilunwu/oss/pytorch/torch/_ops.py", line 756, in __call__
    return self._op(*args, **kwargs)
  File "/data/users/xilunwu/oss/pytorch/torch/_subclasses/functional_tensor.py", line 201, in __torch_dispatch__
    raise RuntimeError(
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
RuntimeError: Attempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode()
```
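For reference, the failing call comes from `test_ring_flex_attention`; the sketch below is a rough reconstruction of that setup based only on the single test line visible in the traceback. The mesh setup, tensor shapes, causal mask, and the `torch.compile` wrapping are assumptions, not the actual test code.

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# Assumes the default process group is already initialized (e.g. via torchrun).
world_size = dist.get_world_size()
mesh = init_device_mesh("cuda", (world_size,))

B, H, S, D = 2, 8, 1024, 64  # assumed shapes, not the test's actual values
q = torch.randn(B, H, S, D, device="cuda")
k = torch.randn(B, H, S, D, device="cuda")
v = torch.randn(B, H, S, D, device="cuda")

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

block_mask = create_block_mask(causal, B, H, S, S, device="cuda")

# Shard q/k/v along the sequence dimension across the context-parallel mesh.
q_dist = distribute_tensor(q, mesh, [Shard(2)])
k_dist = distribute_tensor(k, mesh, [Shard(2)])
v_dist = distribute_tensor(v, mesh, [Shard(2)])

# The traceback goes through the inductor backend, so the test presumably
# compiles flex_attention; calling it on DTensor inputs is what reaches the
# DTensor dispatch inside AOTAutograd's functionalization and raises the
# FunctionalTensor error above.
flex_attention = torch.compile(flex_attention)
out_dt = flex_attention(q_dist, k_dist, v_dist, block_mask=block_mask)
```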
ghstack-source-id: efca286
Pull Request resolved: pytorch/pytorch#147515
Looks like this PR hasn't been updated in a while, so we're going to go ahead and mark this as Stale.
Stack from ghstack (oldest at bottom):
Summary
Attempted to dispatch flex_attention on DTensor to a custom CP (context parallel) flex_attention function, but hit the error shown in the comment above. This error should be identical to #146994.
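The registration the title refers to could look roughly like the sketch below: `HigherOrderOperator.py_impl` accepts a tensor subclass, so a Python implementation of the `flex_attention` higher-order op can be registered for `DTensor` inputs. The handler name, signature handling, and body here are hypothetical illustrations, not this PR's actual code.

```python
import torch
from torch._higher_order_ops.flex_attention import flex_attention as flex_attention_hop
from torch.distributed.tensor import DTensor


# Hypothetical registration: route the flex_attention higher-order op to a
# custom context-parallel handler whenever it sees DTensor inputs.
@flex_attention_hop.py_impl(DTensor)
def cp_flex_attention(query, key, value, score_mod, block_mask, *other_args, **kwargs):
    # A real context-parallel implementation would ring-rotate the K/V shards
    # across ranks and combine the partial outputs using the logsumexp. Here we
    # only unwrap to local shards and re-dispatch, to show where that logic
    # would plug in.
    out, logsumexp = flex_attention_hop(
        query.to_local(),
        key.to_local(),
        value.to_local(),
        score_mod,
        block_mask,
        *other_args,
        **kwargs,
    )
    # A complete handler would wrap out/logsumexp back into DTensors with the
    # appropriate placements before returning.
    return out, logsumexp
```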
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k