@@ -3805,6 +3805,145 @@ def forward(self, arg0_1: "i32[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i3
3805
3805
expected_joint_graph ,
3806
3806
)
3807
3807
3808
+ @supported_platform
3809
+ def test_tensor_subclass_dispatch_order (self , device ):
3810
+ """Test that tensor subclasses get proper dispatch priority over modes.
3811
+
3812
+ This test verifies the fix that allows tensor subclasses' pyimpl to run before
3813
+ FakeTensorMode/FunctionalTensorMode implementations, preventing issues
3814
+ where subclasses that error on as_strided would fail in flex_attention.
3815
+ """
3816
+ import torch .utils ._pytree as pytree
3817
+ from torch .utils ._python_dispatch import return_and_correct_aliasing
3818
+
3819
+ class AsStridedErrorTensor (torch .Tensor ):
3820
+ @staticmethod
3821
+ def __new__ (cls , elem ):
3822
+ assert isinstance (elem , torch .Tensor )
3823
+ return torch .Tensor ._make_wrapper_subclass (
3824
+ cls ,
3825
+ elem .shape ,
3826
+ strides = elem .stride (),
3827
+ storage_offset = elem .storage_offset (),
3828
+ dtype = elem .dtype ,
3829
+ layout = elem .layout ,
3830
+ device = elem .device ,
3831
+ requires_grad = elem .requires_grad ,
3832
+ )
3833
+
3834
+ def __init__ (self , elem ):
3835
+ self .elem = elem
3836
+
3837
+ def __repr__ (self ):
3838
+ return f"AsStridedErrorTensor({ self .elem } )"
3839
+
3840
+ def __tensor_flatten__ (self ):
3841
+ return ["elem" ], None
3842
+
3843
+ @staticmethod
3844
+ def __tensor_unflatten__ (inner_tensors , meta , outer_size , outer_stride ):
3845
+ assert meta is None
3846
+ elem = inner_tensors ["elem" ]
3847
+ return AsStridedErrorTensor (elem )
3848
+
3849
+ @classmethod
3850
+ def __torch_dispatch__ (cls , func , types , args , kwargs = None ):
3851
+ # Error if as_strided is called
3852
+ if func is torch .ops .aten .as_strided .default :
3853
+ raise RuntimeError ("as_strided was called on AsStridedErrorTensor!" )
3854
+
3855
+ if kwargs is None :
3856
+ kwargs = {}
3857
+ args_elem = pytree .tree_map_only (
3858
+ AsStridedErrorTensor , lambda x : x .elem , args
3859
+ )
3860
+ kwargs_elem = pytree .tree_map_only (
3861
+ AsStridedErrorTensor , lambda x : x .elem , kwargs
3862
+ )
3863
+
3864
+ out = func (* args_elem , ** kwargs_elem )
3865
+
3866
+ def wrap_output (x ):
3867
+ if isinstance (x , torch .Tensor ):
3868
+ return AsStridedErrorTensor (x )
3869
+ return x
3870
+
3871
+ out_wrapped = pytree .tree_map (wrap_output , out )
3872
+ return return_and_correct_aliasing (func , args , kwargs , out_wrapped )
3873
+
3874
+ from torch ._higher_order_ops .flex_attention import (
3875
+ flex_attention as flex_attention_hop ,
3876
+ )
3877
+
3878
+ @flex_attention_hop .py_impl (AsStridedErrorTensor )
3879
+ def flex_attention_as_strided_error_tensor (
3880
+ query : torch .Tensor ,
3881
+ key : torch .Tensor ,
3882
+ value : torch .Tensor ,
3883
+ score_mod ,
3884
+ block_mask ,
3885
+ scale ,
3886
+ kernel_options ,
3887
+ score_mod_other_buffers = (),
3888
+ mask_mod_other_buffers = (),
3889
+ ):
3890
+ inner_q , inner_k , inner_v = query .elem , key .elem , value .elem
3891
+ out , lse = flex_attention_hop (
3892
+ inner_q ,
3893
+ inner_k ,
3894
+ inner_v ,
3895
+ score_mod ,
3896
+ block_mask ,
3897
+ scale ,
3898
+ kernel_options ,
3899
+ score_mod_other_buffers ,
3900
+ mask_mod_other_buffers ,
3901
+ )
3902
+ return AsStridedErrorTensor (out ), AsStridedErrorTensor (lse )
3903
+
3904
+ # Test setup
3905
+ B , H , S , D = 2 , 1 , 128 , 16
3906
+ dtype = torch .float32
3907
+
3908
+ # Create regular tensors
3909
+ query_elem = torch .randn (B , H , S , D , device = device , dtype = dtype )
3910
+ key_elem = torch .randn (B , H , S , D , device = device , dtype = dtype )
3911
+ value_elem = torch .randn (B , H , S , D , device = device , dtype = dtype )
3912
+
3913
+ # Test 1: Verify as_strided raises error when called directly on AsStridedErrorTensor
3914
+ test_tensor = AsStridedErrorTensor (query_elem )
3915
+ with self .assertRaisesRegex (
3916
+ RuntimeError , "as_strided was called on AsStridedErrorTensor!"
3917
+ ):
3918
+ torch .as_strided (
3919
+ test_tensor , size = (B , H , S , D ), stride = test_tensor .stride ()
3920
+ )
3921
+
3922
+ # Test 2: Run flex_attention with normal tensors first
3923
+ compiled_fn = torch .compile (flex_attention , backend = "aot_eager" , fullgraph = True )
3924
+ normal_out , normal_lse = compiled_fn (
3925
+ query_elem , key_elem , value_elem , return_lse = True
3926
+ )
3927
+
3928
+ # Test 3: Wrap in our subclass
3929
+ query = AsStridedErrorTensor (query_elem )
3930
+ key = AsStridedErrorTensor (key_elem )
3931
+ value = AsStridedErrorTensor (value_elem )
3932
+
3933
+ # This should NOT error with as_strided after the fix
3934
+ # Before the fix, it would error because FakeTensorMode would directly
3935
+ # call flex_attention_fake_impl which uses as_strided
3936
+ out , lse = compiled_fn (query , key , value , return_lse = True )
3937
+ # Verify we got valid output
3938
+ self .assertIsInstance (out , AsStridedErrorTensor )
3939
+ self .assertIsInstance (lse , AsStridedErrorTensor )
3940
+ self .assertEqual (out .shape , (B , H , S , D ))
3941
+ self .assertEqual (lse .shape , (B , H , S ))
3942
+
3943
+ # Test 4: Compare outputs between normal tensors and subclassed tensors
3944
+ torch .testing .assert_close (out .elem , normal_out , rtol = 1e-5 , atol = 1e-5 )
3945
+ torch .testing .assert_close (lse .elem , normal_lse , rtol = 1e-5 , atol = 1e-5 )
3946
+
3808
3947
@supported_platform
3809
3948
@skip_on_cuda
3810
3949
def test_cpu_error_message_return_lse (self , device ):
0 commit comments