flex attention: fix dispatch order for tensor subclasses, avoid hardcoding call to faketensor impl in dynamo #151719
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/151719
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure) As of commit 1ba7ad0 with merge base c92f107:
BROKEN TRUNK - The following job failed but was present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@@ -2892,7 +2884,7 @@ def call_function(
            ),
            kwargs={},
        ),
-       example_value=example_value,
+       example_value=None,
By passing `None` for `example_value`, `wrap_fx_proxy` will implicitly do the above.
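To illustrate the principle behind this change, here is a standalone toy (not the actual dynamo code, and `example_value_via_dispatcher` is a made-up helper name): compute the example value by running the op itself on fake inputs and letting the dispatcher choose the implementation, instead of hardcoding a call to the op's fake implementation.

```python
# Standalone toy (not dynamo internals): get an example value by calling the
# op on FakeTensor inputs, so the dispatcher picks the fake implementation.
# When subclass inputs are present, this same routing is what lets the
# subclass interpose before fake prop runs.
import torch
from torch._subclasses.fake_tensor import FakeTensorMode


def example_value_via_dispatcher(op, *args):
    with FakeTensorMode() as mode:
        fake_args = [
            mode.from_tensor(a) if isinstance(a, torch.Tensor) else a
            for a in args
        ]
        return op(*fake_args)


q = torch.randn(2, 4, 8, 16)
out = example_value_via_dispatcher(torch.ops.aten.mul.Tensor, q, q)
print(out.shape)  # shape/dtype metadata only; no real compute happens
```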
@pytorchbot merge -i
Merge started. Your change will be merged while ignoring the following check: inductor / linux-jammy-cpu-py3.9-gcc11-inductor / test (inductor_torchbench_cpu_smoketest_perf, 1, 1, linux.24xl.spr-metal). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
if has_user_subclass(
    (
        query,
        key,
        value,
        score_mod,
        block_mask,
        scale,
        kernel_options,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    ),
    allowed_subclasses=(FakeTensor, FunctionalTensor),
):
    return NotImplemented
This should go into py_functionalize_impl so that other HOPs can benefit:
Lines 173 to 188 in 1bb9b18
def functionalize_dk_fn(*args: _P.args, **kwargs: _P.kwargs) -> _T:
    return fn(CppFunctionalizeAPI(), *args, **kwargs)

def functionalize_dispatch_mode_fn(
    mode: Optional[FunctionalTensorMode], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
    return fn(PythonFunctionalizeAPI(mode), *args, **kwargs)

def functionalize_functorch_fn(
    interpreter, *args: _P.args, **kwargs: _P.kwargs
) -> _T:
    return fn(FunctorchFunctionalizeAPI(interpreter), *args, **kwargs)

self.py_impl(DispatchKey.Functionalize)(functionalize_dk_fn)
self.py_impl(FunctionalTensorMode)(functionalize_dispatch_mode_fn)
self.py_impl(TransformType.Functionalize)(functionalize_functorch_fn)
All HOPs need this behavior; this PR only adds it for FlexAttention.
I think Brian mentioned that he wasn't sure if we should make this the default for everyone. Do we need a way for HOPs to opt out?
We could just make it an option on py_functionalize_impl if we do need a way for HOPs to opt out. But I don't think we should allow HOPs to opt out; there's no way for regular operators to opt out of this.
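A minimal sketch of what folding this check into `py_functionalize_impl` could look like (an assumption for illustration, not the actual `torch/_ops.py` code; `_has_user_subclass` stands in for the helper this PR adds, and `fn` / `PythonFunctionalizeAPI` refer to the same names as in the excerpt above):

```python
# Hypothetical sketch: do the subclass check once inside the generic
# FunctionalTensorMode wrapper that py_functionalize_impl registers, so every
# HOP inherits the NotImplemented behavior instead of repeating it per HOP.
from typing import Optional

import torch
import torch.utils._pytree as pytree
from torch._subclasses.fake_tensor import FakeTensor
from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode


def _has_user_subclass(args, allowed_subclasses) -> bool:
    # A tensor counts as a "user" subclass if its type is a proper subclass of
    # torch.Tensor and it is not one of the allowed Fake/Functional wrappers.
    flat_args, _ = pytree.tree_flatten(args)
    return any(
        type(a) is not torch.Tensor and not isinstance(a, allowed_subclasses)
        for a in flat_args
        if isinstance(a, torch.Tensor)
    )


def functionalize_dispatch_mode_fn(
    mode: Optional[FunctionalTensorMode], *args, **kwargs
):
    # Mirror what FunctionalTensorMode.__torch_dispatch__ does for regular ops:
    # if a user subclass is present, return NotImplemented so the subclass's
    # __torch_dispatch__ can desugar the HOP before functionalization runs.
    if _has_user_subclass((args, kwargs), (FakeTensor, FunctionalTensor)):
        return NotImplemented
    return fn(PythonFunctionalizeAPI(mode), *args, **kwargs)
```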
if has_user_subclass(
    (
        query,
        key,
        value,
        score_mod,
        block_mask,
        scale,
        kernel_options,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    ),
    allowed_subclasses=(FakeTensor,),
):
    return NotImplemented
I don't see why this is necessary -- @register_fake will directly call FakeTensor's `__torch_dispatch__`, and FakeTensor already has this logic. So we shouldn't need to duplicate it.
if has_user_subclass(
    (
        query,
        key,
        value,
        out,
        logsumexp,
        grad_out,
        grad_logsumexp,
        block_mask,
        scale,
        kernel_options,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    ),
    allowed_subclasses=(FakeTensor, FunctionalTensor),
):
    return NotImplemented
Same as above -- if we put this into `py_functionalize_impl`, then we don't need to do this again for `flex_attention_backward`.
@drisspg, if you have bandwidth, could you try to address the comments?
@zou3519 yeah, I'll open up a new one.
This is enough to get @XilunWu's stack in a state where his flex_attention DTensor implementations worked E2E for me. It also required these changes on the DTensor side, to properly add a DTensor rule for flex backward: P1789852198
There are two problems:
(1) in the normal dispatcher, we have a precedence ordering between modes and subclasses. Modes are dispatched to first, but modes are allowed to return NotImplemented, giving subclasses a chance to run.
This normally happens automatically in `FakeTensorMode.__torch_dispatch__` and `FunctionalTensorMode.__torch_dispatch__`. However, since HOPs implement these two modes themselves, HOPs do not get this benefit. For now, I ended up hardcoding this `NotImplemented` logic directly into the functional/fake rules for flex attention.
Having to do this for every HOP seems a bit painful. If we could plumb every HOP through `Fake[|Functional]TensorMode.__torch_dispatch__` then we would get this support. Another option could be to just assume that most HOP <> mode implementations want the same treatment by default, and hardcode this `NotImplemented` logic into `torch/_ops.py`. I'm not sure if we'd need a way for the HOP to opt out of this though.
(2) We were hardcoding a call to flex attention's fake implementation in dynamo to run fake prop. This is technically wrong for subclasses, because it doesn't give subclasses the chance to interpose on the op and desugar it before fake prop runs. I tweaked dynamo's logic to call the op, and let the dispatcher handle invoking the fake implementation.
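To make the precedence in (1) concrete, here is a small runnable toy using the generic `__torch_dispatch__` machinery (a custom mode and wrapper subclass with illustrative names, not the flex attention code): the mode dispatches first, returns `NotImplemented` when it sees the subclass, and the subclass then handles the op.

```python
# Standalone toy of the mode-vs-subclass precedence described in (1):
# modes dispatch first, but a mode can return NotImplemented to give a
# tensor subclass a chance to run, the way Fake/FunctionalTensorMode do
# for regular ops. MySubclass and DeferringMode are illustrative names.
import torch
from torch.utils._python_dispatch import TorchDispatchMode


class MySubclass(torch.Tensor):
    @staticmethod
    def __new__(cls, elem):
        r = torch.Tensor._make_wrapper_subclass(cls, elem.shape, dtype=elem.dtype)
        r.elem = elem
        return r

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        print(f"subclass handling {func}")
        # "Desugar" by unwrapping to plain tensors for this toy.
        unwrapped = [a.elem if isinstance(a, MySubclass) else a for a in args]
        return func(*unwrapped, **(kwargs or {}))


class DeferringMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if any(issubclass(t, MySubclass) for t in types):
            # Hand the op to the subclass instead of handling it here.
            return NotImplemented
        return func(*args, **(kwargs or {}))


x = MySubclass(torch.randn(3))
with DeferringMode():
    y = x + 1  # prints "subclass handling aten.add.Tensor"
```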
**Testing**: Xilun is adding some DTensor tests in his PR that will end up testing this logic. If folks would prefer, though, I can try to add a test that uses another, more basic subclass instead.
This is the tlparse that his DTensor test generated for me: https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/hirsheybar/0196c1d3-a9a2-46ea-a46d-aa21618aa060/custom/rank_0/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000
Stack from ghstack (oldest at bottom):
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov