Skip to content

flex attention: fix dispatch order for tensor subclasses, avoid hardcoding call to faketensor impl in dynamo #151719

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 14 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Update
[ghstack-poisoned]
  • Loading branch information
drisspg committed Jun 17, 2025
commit fcadeb7605579f9a259c48da8c37ea3ca4283b62
53 changes: 17 additions & 36 deletions torch/_higher_order_ops/flex_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
_has_potential_branch_input_mutation,
_maybe_reenter_make_fx,
autograd_not_implemented,
has_user_subclass,
redirect_to_mode,
reenter_make_fx,
register_fake,
Expand Down Expand Up @@ -400,7 +401,8 @@ def flex_attention_functionalize(
"""
from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex

flat_args, _ = pytree.tree_flatten(

if has_user_subclass(
(
query,
key,
Expand All @@ -411,14 +413,8 @@ def flex_attention_functionalize(
kernel_options,
score_mod_other_buffers,
mask_mod_other_buffers,
)
)
# For tensor subclasses, give the subclass a chance to run first
if any(
isinstance(a, torch.Tensor)
and type(a) is not torch.Tensor
and not isinstance(a, (FakeTensor, FunctionalTensor))
for a in flat_args
),
allowed_subclasses=(FakeTensor, FunctionalTensor),
):
return NotImplemented

Expand Down Expand Up @@ -483,7 +479,8 @@ def flex_attention_fake_impl(
score_mod_other_buffers: tuple = (),
mask_mod_other_buffers: tuple = (),
) -> tuple[torch.Tensor, torch.Tensor]:
flat_args, _ = pytree.tree_flatten(

if has_user_subclass(
(
query,
key,
Expand All @@ -494,14 +491,8 @@ def flex_attention_fake_impl(
kernel_options,
score_mod_other_buffers,
mask_mod_other_buffers,
)
)
# For tensor subclasses, give the subclass a chance to run first
if any(
isinstance(a, torch.Tensor)
and type(a) is not torch.Tensor
and not isinstance(a, FakeTensor)
for a in flat_args
),
allowed_subclasses=(FakeTensor,),
):
return NotImplemented

Expand Down Expand Up @@ -1134,7 +1125,8 @@ def flex_attention_backward_functionalize(
since we know that the forward score mod function is assured to be free of mutations
to the other_buffers, we skip that mutate check and go straight to redispatching.
"""
flat_args, _ = pytree.tree_flatten(

if has_user_subclass(
(
query,
key,
Expand All @@ -1148,14 +1140,8 @@ def flex_attention_backward_functionalize(
kernel_options,
score_mod_other_buffers,
mask_mod_other_buffers,
)
)
# For tensor subclasses, give the subclass a chance to run first
if any(
isinstance(a, torch.Tensor)
and type(a) is not torch.Tensor
and not isinstance(a, (FakeTensor, FunctionalTensor))
for a in flat_args
),
allowed_subclasses=(FakeTensor, FunctionalTensor),
):
return NotImplemented
query_unwrapped = ctx.unwrap_tensors(query)
Expand Down Expand Up @@ -1229,7 +1215,8 @@ def flex_attention_backward_fake_tensor_mode(
) -> tuple[
torch.Tensor, torch.Tensor, torch.Tensor, tuple[Optional[torch.Tensor], ...]
]:
flat_args, _ = pytree.tree_flatten(

if has_user_subclass(
(
query,
key,
Expand All @@ -1243,14 +1230,8 @@ def flex_attention_backward_fake_tensor_mode(
kernel_options,
score_mod_other_buffers,
mask_mod_other_buffers,
)
)
# For tensor subclasses, give the subclass a chance to run first
if any(
isinstance(a, torch.Tensor)
and type(a) is not torch.Tensor
and not isinstance(a, FakeTensor)
for a in flat_args
),
allowed_subclasses=(FakeTensor,),
):
return NotImplemented
Bq, _, _, qk_head_dim = query.shape
Expand Down
22 changes: 22 additions & 0 deletions torch/_higher_order_ops/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1100,6 +1100,28 @@ def wrapped_fn(*flat_args):
return pytree.tree_unflatten(materialized_args, flat_spec)


def has_user_subclass(args, allowed_subclasses):
"""Check if any tensor arguments are user subclasses.

This is used to determine if tensor subclasses should get a chance to run
their own implementation first before falling back to the default implementation.

Args:
args: Arguments to check (will be flattened with pytree)
allowed_subclasses: Tuple of allowed subclass types

Returns:
True if user tensor subclasses are found, False otherwise
"""
flat_args, _ = pytree.tree_flatten(args)
return any(
isinstance(a, torch.Tensor)
and type(a) is not torch.Tensor
and not isinstance(a, allowed_subclasses)
for a in flat_args
)


def _has_gen_schema(op: HigherOrderOperator):
# There is an InvokeQuant argument we cannot gen_schema.
if op is torch.ops.higher_order.invoke_quant_packed:
Expand Down
Loading
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy