
Commit 19f69d3

bdhirsh authored and drisspg committed
flex attention: fix dispatch order for tensor subclasses, avoid hardcoding call to faketensor impl in dynamo
ghstack-source-id: 18b3717
Pull Request resolved: #151719
1 parent c92f107 commit 19f69d3
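
Background on the "dispatch order" in the title: a HigherOrderOperator keeps a table of Python implementations registered via py_impl, keyed by dispatch keys, TorchDispatchModes (e.g. FakeTensorMode), and tensor-subclass types. Per the commit title and the new test below, the fix lets a subclass-keyed impl run before the FakeTensorMode/FunctionalTensorMode impls when the HOP's arguments are instances of that subclass. Below is a minimal registration sketch, not the flex attention implementation itself; MyHOP, my_hop_example, and MyWrapper are hypothetical names used only for illustration.

import torch
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode


class MyHOP(HigherOrderOperator):
    # Toy higher-order op, defined only to show where py_impl registrations live.
    def __init__(self):
        super().__init__("my_hop_example")

    def __call__(self, x):
        return super().__call__(x)


my_hop_example = MyHOP()


class MyWrapper(torch.Tensor):
    # Minimal wrapper subclass; a fully traceable one would also define
    # __tensor_flatten__ / __tensor_unflatten__ / __torch_dispatch__
    # (see AsStridedErrorTensor in the new test below).
    @staticmethod
    def __new__(cls, elem):
        return torch.Tensor._make_wrapper_subclass(
            cls, elem.shape, dtype=elem.dtype, device=elem.device
        )

    def __init__(self, elem):
        self.elem = elem


@my_hop_example.py_impl(FakeTensorMode)
def my_hop_fake(mode, x):
    # Shape-only propagation used while fake-tracing/compiling.
    return torch.empty_like(x)


@my_hop_example.py_impl(MyWrapper)
def my_hop_wrapper(x):
    # Unwrap, re-dispatch on the inner dense tensor, rewrap the result.
    return MyWrapper(my_hop_example(x.elem))

The new test below registers exactly this kind of subclass impl on the real flex_attention HOP and checks that compilation reaches it instead of the as_strided-based fake impl.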

File tree

4 files changed (+288, -59 lines)


test/inductor/test_flex_attention.py

Lines changed: 139 additions & 0 deletions
@@ -3805,6 +3805,145 @@ def forward(self, arg0_1: "i32[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i3
            expected_joint_graph,
        )

+    @supported_platform
+    def test_tensor_subclass_dispatch_order(self, device):
+        """Test that tensor subclasses get proper dispatch priority over modes.
+
+        This test verifies the fix that allows tensor subclasses' pyimpl to run before
+        FakeTensorMode/FunctionalTensorMode implementations, preventing issues
+        where subclasses that error on as_strided would fail in flex_attention.
+        """
+        import torch.utils._pytree as pytree
+        from torch.utils._python_dispatch import return_and_correct_aliasing
+
+        class AsStridedErrorTensor(torch.Tensor):
+            @staticmethod
+            def __new__(cls, elem):
+                assert isinstance(elem, torch.Tensor)
+                return torch.Tensor._make_wrapper_subclass(
+                    cls,
+                    elem.shape,
+                    strides=elem.stride(),
+                    storage_offset=elem.storage_offset(),
+                    dtype=elem.dtype,
+                    layout=elem.layout,
+                    device=elem.device,
+                    requires_grad=elem.requires_grad,
+                )
+
+            def __init__(self, elem):
+                self.elem = elem
+
+            def __repr__(self):
+                return f"AsStridedErrorTensor({self.elem})"
+
+            def __tensor_flatten__(self):
+                return ["elem"], None
+
+            @staticmethod
+            def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
+                assert meta is None
+                elem = inner_tensors["elem"]
+                return AsStridedErrorTensor(elem)
+
+            @classmethod
+            def __torch_dispatch__(cls, func, types, args, kwargs=None):
+                # Error if as_strided is called
+                if func is torch.ops.aten.as_strided.default:
+                    raise RuntimeError("as_strided was called on AsStridedErrorTensor!")
+
+                if kwargs is None:
+                    kwargs = {}
+                args_elem = pytree.tree_map_only(
+                    AsStridedErrorTensor, lambda x: x.elem, args
+                )
+                kwargs_elem = pytree.tree_map_only(
+                    AsStridedErrorTensor, lambda x: x.elem, kwargs
+                )
+
+                out = func(*args_elem, **kwargs_elem)
+
+                def wrap_output(x):
+                    if isinstance(x, torch.Tensor):
+                        return AsStridedErrorTensor(x)
+                    return x
+
+                out_wrapped = pytree.tree_map(wrap_output, out)
+                return return_and_correct_aliasing(func, args, kwargs, out_wrapped)
+
+        from torch._higher_order_ops.flex_attention import (
+            flex_attention as flex_attention_hop,
+        )
+
+        @flex_attention_hop.py_impl(AsStridedErrorTensor)
+        def flex_attention_as_strided_error_tensor(
+            query: torch.Tensor,
+            key: torch.Tensor,
+            value: torch.Tensor,
+            score_mod,
+            block_mask,
+            scale,
+            kernel_options,
+            score_mod_other_buffers=(),
+            mask_mod_other_buffers=(),
+        ):
+            inner_q, inner_k, inner_v = query.elem, key.elem, value.elem
+            out, lse = flex_attention_hop(
+                inner_q,
+                inner_k,
+                inner_v,
+                score_mod,
+                block_mask,
+                scale,
+                kernel_options,
+                score_mod_other_buffers,
+                mask_mod_other_buffers,
+            )
+            return AsStridedErrorTensor(out), AsStridedErrorTensor(lse)
+
+        # Test setup
+        B, H, S, D = 2, 1, 128, 16
+        dtype = torch.float32
+
+        # Create regular tensors
+        query_elem = torch.randn(B, H, S, D, device=device, dtype=dtype)
+        key_elem = torch.randn(B, H, S, D, device=device, dtype=dtype)
+        value_elem = torch.randn(B, H, S, D, device=device, dtype=dtype)
+
+        # Test 1: Verify as_strided raises error when called directly on AsStridedErrorTensor
+        test_tensor = AsStridedErrorTensor(query_elem)
+        with self.assertRaisesRegex(
+            RuntimeError, "as_strided was called on AsStridedErrorTensor!"
+        ):
+            torch.as_strided(
+                test_tensor, size=(B, H, S, D), stride=test_tensor.stride()
+            )
+
+        # Test 2: Run flex_attention with normal tensors first
+        compiled_fn = torch.compile(flex_attention, backend="aot_eager", fullgraph=True)
+        normal_out, normal_lse = compiled_fn(
+            query_elem, key_elem, value_elem, return_lse=True
+        )
+
+        # Test 3: Wrap in our subclass
+        query = AsStridedErrorTensor(query_elem)
+        key = AsStridedErrorTensor(key_elem)
+        value = AsStridedErrorTensor(value_elem)
+
+        # This should NOT error with as_strided after the fix
+        # Before the fix, it would error because FakeTensorMode would directly
+        # call flex_attention_fake_impl which uses as_strided
+        out, lse = compiled_fn(query, key, value, return_lse=True)
+        # Verify we got valid output
+        self.assertIsInstance(out, AsStridedErrorTensor)
+        self.assertIsInstance(lse, AsStridedErrorTensor)
+        self.assertEqual(out.shape, (B, H, S, D))
+        self.assertEqual(lse.shape, (B, H, S))
+
+        # Test 4: Compare outputs between normal tensors and subclassed tensors
+        torch.testing.assert_close(out.elem, normal_out, rtol=1e-5, atol=1e-5)
+        torch.testing.assert_close(lse.elem, normal_lse, rtol=1e-5, atol=1e-5)
+
    @supported_platform
    @skip_on_cuda
    def test_cpu_error_message_return_lse(self, device):

torch/_dynamo/variables/higher_order_ops.py

Lines changed: 1 addition & 9 deletions
@@ -2826,8 +2826,6 @@ def call_function(
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
-        from torch._higher_order_ops.flex_attention import flex_attention_fake_impl
-
        from .builder import wrap_fx_proxy

        (
@@ -2864,12 +2862,6 @@ def call_function(
        # Proxying user defined functions is not supported.
        inp_args, _ = proxy_args_kwargs(proxied_args, {})

-        query_meta = query.as_proxy().node.meta["example_value"]
-        value_meta = value.as_proxy().node.meta["example_value"]
-        with torch._guards.TracingContext.try_get().fake_mode:
-            out_meta, lse_meta = flex_attention_fake_impl(query_meta, value_meta)
-        example_value = (out_meta, lse_meta)
-
        # Compose the ordered HOO args:
        # - inp_args: [query, key, value, block_mask, scale, kernel_options]
        # - subgraph node: [score_mod, mask_fn_node]
@@ -2892,7 +2884,7 @@ def call_function(
                ),
                kwargs={},
            ),
-            example_value=example_value,
+            example_value=None,
        )
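Why passing example_value=None is sufficient here: when no example value is supplied, wrap_fx_proxy falls back to fake-tensor propagation of the proxy node, and that propagation re-enters the flex_attention HOP's normal dispatch (its FakeTensorMode py_impl, or a tensor-subclass py_impl like the one registered in the new test) instead of a hardcoded flex_attention_fake_impl call. Below is a rough sketch of that fallback, assuming the usual wrap_fx_proxy behavior in torch/_dynamo/variables/builder.py; example_value_fallback is a hypothetical helper written only to illustrate the flow.

from torch._dynamo.utils import get_fake_value  # existing dynamo helper


def example_value_fallback(tx, proxy, example_value=None):
    # Illustrative only: when no example_value is hardcoded, dynamo recovers one
    # by running the proxied call (here the flex_attention HOP) under the ambient
    # FakeTensorMode, so normal HOP dispatch decides which py_impl runs.
    if example_value is None:
        example_value = get_fake_value(proxy.node, tx)
    return example_value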
0 commit comments
