
flex attention: fix dispatch order for tensor subclasses, avoid hardcoding call to faketensor impl in dynamo #151719


Closed
wants to merge 14 commits into from

Conversation

bdhirsh
Contributor

@bdhirsh bdhirsh commented Apr 18, 2025

This is enough to get @XilunWu's stack into a state where his flex_attention DTensor implementations worked E2E for me. It also required these changes on the DTensor side, to properly add a DTensor rule for flex backward: P1789852198

There are two problems:

(1) in the normal dispatcher, we have a precedence ordering between modes and subclasses. Modes are dispatched to first, but modes are allowed to return NotImplemented, giving subclasses a chance to run.

This normally happens automatically in `FakeTensorMode.__torch_dispatch__` and `FunctionalTensorMode.__torch_dispatch__`. However, since HOPs implement these two modes themselves, HOPs do not get this benefit. For now, I ended up hardcoding this `NotImplemented` logic directly into the functional/fake rules for flex attention.

Having to do this for every HOP seems a bit painful. If we could plumb every HOP through `Fake[|Functional]TensorMode.__torch_dispatch__` then we would get this support. Another option could be to just assume that most HOP <> mode implementations want the same treatment by default, and hardcode this `NotImplemented` logic into `torch/_ops.py`. I'm not sure if we'd need a way for the HOP to opt out of this though.
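For context, here is a minimal sketch of the precedence pattern described above (not this PR's code; the mode below is a toy for illustration only). A dispatch mode returns `NotImplemented` when it sees a tensor subclass it doesn't recognize, which hands dispatch back to that subclass's `__torch_dispatch__` -- the behavior `FakeTensorMode`/`FunctionalTensorMode` already provide for regular ops:

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_flatten


class ExampleMode(TorchDispatchMode):
    """Toy mode illustrating mode-before-subclass precedence (illustrative only)."""

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        flat_args, _ = tree_flatten((args, kwargs))
        # If any input is a tensor subclass this mode doesn't understand,
        # decline to handle the op so the subclass's __torch_dispatch__
        # gets a chance to run.
        if any(
            isinstance(a, torch.Tensor) and type(a) is not torch.Tensor
            for a in flat_args
        ):
            return NotImplemented
        # Otherwise handle the op here; a real mode would do its own thing.
        return func(*args, **kwargs)
```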

(2) We were hardcoding a call to flex attention's fake implementation in dynamo to run fake prop. This is technically wrong for subclasses, because it doesn't give subclasses the chance to interpose on the op and desugar it before fake prop runs. I tweaked dynamo's logic to call the op, and let the dispatcher handle invoking the fake implementation.
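As a rough illustration of the principle (using an ordinary aten op rather than the flex_attention HOP, and not showing the dynamo change itself): running the op on fake inputs and letting the dispatcher route to the registered fake implementation, instead of hand-calling a specific fake impl:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Tensors created under FakeTensorMode are FakeTensors; calling the op lets
# the dispatcher pick the fake/meta implementation, and any subclass in the
# inputs would get to interpose and desugar the op first.
with FakeTensorMode():
    q = torch.empty(2, 8, 128, 64)
    k = torch.empty(2, 8, 128, 64)
    v = torch.empty(2, 8, 128, 64)
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    print(out.shape)  # torch.Size([2, 8, 128, 64]) -- no real kernel ran
```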

**Testing** Xilun is adding some DTensor tests in his PR that will end up testing this logic. If folks would prefer, though, I can try to add a test that uses another, more basic subclass instead.

This is the tlparse that his DTensor test generated for me: https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/hirsheybar/0196c1d3-a9a2-46ea-a46d-aa21618aa060/custom/rank_0/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000

Stack from ghstack (oldest at bottom):

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov

…oding call to faketensor impl in dynamo

@bdhirsh bdhirsh requested a review from zou3519 as a code owner April 18, 2025 23:14

pytorch-bot bot commented Apr 18, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/151719

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 1ba7ad0 with merge base c92f107:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

bdhirsh added a commit that referenced this pull request Apr 18, 2025
…oding call to faketensor impl in dynamo

ghstack-source-id: c4d5d49
Pull Request resolved: #151719
bdhirsh added a commit that referenced this pull request Apr 25, 2025
…oding call to faketensor impl in dynamo

ghstack-source-id: ba6f1ee
Pull Request resolved: #151719
bdhirsh added a commit that referenced this pull request Apr 25, 2025
…oding call to faketensor impl in dynamo

ghstack-source-id: cc7c17d
Pull Request resolved: #151719
bdhirsh added a commit that referenced this pull request Apr 28, 2025
…oding call to faketensor impl in dynamo

ghstack-source-id: e402015
Pull Request resolved: #151719
bdhirsh added a commit that referenced this pull request Apr 29, 2025
…oding call to faketensor impl in dynamo

ghstack-source-id: 0eeb838
Pull Request resolved: #151719
bdhirsh added a commit that referenced this pull request May 2, 2025
…oding call to faketensor impl in dynamo

ghstack-source-id: 17535e0
Pull Request resolved: #151719
XilunWu added a commit that referenced this pull request Jun 12, 2025
XilunWu added a commit that referenced this pull request Jun 12, 2025
ghstack-source-id: 02f4618
Pull Request resolved: #155851
@drisspg
Contributor

drisspg commented Jun 17, 2025

@zou3519 @bdhirsh @ydwu4
Updated the code, added a test, and addressed all comments

drisspg pushed a commit that referenced this pull request Jun 17, 2025
…oding call to faketensor impl in dynamo

ghstack-source-id: d7f1a8e
Pull Request resolved: #151719
drisspg pushed a commit that referenced this pull request Jun 17, 2025
…oding call to faketensor impl in dynamo

ghstack-source-id: 872151d
Pull Request resolved: #151719
drisspg pushed a commit that referenced this pull request Jun 17, 2025
…oding call to faketensor impl in dynamo

ghstack-source-id: 0daefe7
Pull Request resolved: #151719
@@ -2892,7 +2884,7 @@ def call_function(
),
kwargs={},
),
example_value=example_value,
example_value=None,
Contributor

By passing in `None` for `example_value`, `wrap_fx_proxy` will implicitly do the above.

drisspg pushed a commit that referenced this pull request Jun 18, 2025
…oding call to faketensor impl in dynamo

ghstack-source-id: 18b3717
Pull Request resolved: #151719
@drisspg
Contributor

drisspg commented Jun 18, 2025

@pytorchbot merge -i

@pytorch-bot pytorch-bot bot added the ciflow/trunk label (Trigger trunk jobs on your pull request) Jun 18, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged while ignoring the following 1 check: inductor / linux-jammy-cpu-py3.9-gcc11-inductor / test (inductor_torchbench_cpu_smoketest_perf, 1, 1, linux.24xl.spr-metal)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

Comment on lines +404 to +418
if has_user_subclass(
(
query,
key,
value,
score_mod,
block_mask,
scale,
kernel_options,
score_mod_other_buffers,
mask_mod_other_buffers,
),
allowed_subclasses=(FakeTensor, FunctionalTensor),
):
return NotImplemented
Contributor

This should go into py_functionalize_impl so that other HOPs can benefit:

pytorch/torch/_ops.py

Lines 173 to 188 in 1bb9b18

def functionalize_dk_fn(*args: _P.args, **kwargs: _P.kwargs) -> _T:
    return fn(CppFunctionalizeAPI(), *args, **kwargs)

def functionalize_dispatch_mode_fn(
    mode: Optional[FunctionalTensorMode], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
    return fn(PythonFunctionalizeAPI(mode), *args, **kwargs)

def functionalize_functorch_fn(
    interpreter, *args: _P.args, **kwargs: _P.kwargs
) -> _T:
    return fn(FunctorchFunctionalizeAPI(interpreter), *args, **kwargs)

self.py_impl(DispatchKey.Functionalize)(functionalize_dk_fn)
self.py_impl(FunctionalTensorMode)(functionalize_dispatch_mode_fn)
self.py_impl(TransformType.Functionalize)(functionalize_functorch_fn)

All HOPs need this behavior; this PR only adds it for FlexAttention.
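For illustration, a rough sketch of what that could look like (not part of this PR; it assumes the `has_user_subclass` helper added here is available in `torch/_ops.py`'s scope and that it accepts a pytree of the HOP's inputs):

```python
# Hypothetical change to py_functionalize_impl's FunctionalTensorMode wrapper,
# sketching the suggestion above -- not what this PR actually does.
def functionalize_dispatch_mode_fn(
    mode: Optional[FunctionalTensorMode], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
    # Defer to unrecognized tensor subclasses before running the HOP's
    # functionalization rule, mirroring what FunctionalTensorMode does
    # for regular ops.
    if has_user_subclass(
        (args, kwargs), allowed_subclasses=(FakeTensor, FunctionalTensor)
    ):
        return NotImplemented
    return fn(PythonFunctionalizeAPI(mode), *args, **kwargs)
```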

Contributor

I think Brian mentioned that he wasn't sure if we should make this the default for everyone; do we need a way for HOPs to opt out?

Contributor

We could just make it an option on py_functionalize_impl if we do need a way for HOPs to opt out. But I don't think we should allow HOPs to opt out; there's no way for operators to opt out of this.

Comment on lines +481 to +495
if has_user_subclass(
(
query,
key,
value,
score_mod,
block_mask,
scale,
kernel_options,
score_mod_other_buffers,
mask_mod_other_buffers,
),
allowed_subclasses=(FakeTensor,),
):
return NotImplemented
Contributor

I don't see why this is necessary -- `@register_fake` will directly call FakeTensor's `__torch_dispatch__`, and FakeTensor already has this logic. So we shouldn't need to duplicate it.

Comment on lines +1127 to +1144
if has_user_subclass(
(
query,
key,
value,
out,
logsumexp,
grad_out,
grad_logsumexp,
block_mask,
scale,
kernel_options,
score_mod_other_buffers,
mask_mod_other_buffers,
),
allowed_subclasses=(FakeTensor, FunctionalTensor),
):
return NotImplemented
Contributor

Same as above -- if we put this into py_functionalize_impl, then we don't need to do this again for flex_attention_backward

Contributor

@zou3519 zou3519 left a comment

@drisspg, if you have bandwidth, could you try to address the comments?

@drisspg
Contributor

drisspg commented Jun 18, 2025

@zou3519 yeah, I'll open up a new one
