@@ -17,7 +17,8 @@
     validate_subgraph_args_types,
 )
 from torch._ops import HigherOrderOperator
-from torch._subclasses import FakeTensorMode
+from torch._subclasses import FakeTensor, FakeTensorMode
+from torch._subclasses.functional_tensor import FunctionalTensor
 from torch.fx.experimental.proxy_tensor import (
     make_fx,
     ProxyTorchDispatchMode,
@@ -396,6 +397,29 @@ def flex_attention_functionalize(
     """
     from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex

+    flat_args, _ = pytree.tree_flatten(
+        (
+            query,
+            key,
+            value,
+            score_mod,
+            block_mask,
+            scale,
+            kernel_options,
+            score_mod_other_buffers,
+            mask_mod_other_buffers,
+        )
+    )
+    # For tensor subclasses, give the subclass a chance to run first
+    if any(
+        isinstance(a, torch.Tensor)
+        and type(a) is not torch.Tensor
+        and not isinstance(a, FakeTensor)
+        and not isinstance(a, FunctionalTensor)
+        for a in flat_args
+    ):
+        return NotImplemented
+
     query_unwrapped = ctx.unwrap_tensors(query)
     key_unwrapped = ctx.unwrap_tensors(key)
     value_unwrapped = ctx.unwrap_tensors(value)
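
The guard above flattens every argument and, if any tensor subclass other than the infrastructure wrappers (FakeTensor, FunctionalTensor) is present, returns NotImplemented so dispatch falls back to the subclass's own handler rather than running the default rule on a type it does not understand. Below is a minimal, self-contained sketch of that deferral protocol in isolation; MySubclass and default_rule are made-up names for illustration and are not part of this PR.

    import torch
    import torch.utils._pytree as pytree


    class MySubclass(torch.Tensor):
        """Stand-in for a user-defined wrapper subclass (e.g. a DTensor-like type)."""
        pass


    def default_rule(*args):
        # Same shape as the guard added in the diff: bail out with NotImplemented
        # when an unexpected tensor subclass shows up, so the caller can try the
        # subclass's own implementation next.
        flat_args, _ = pytree.tree_flatten(args)
        if any(
            isinstance(a, torch.Tensor) and type(a) is not torch.Tensor
            for a in flat_args
        ):
            return NotImplemented
        return "handled by the default rule"


    plain = torch.randn(2, 3)
    wrapped = plain.as_subclass(MySubclass)

    print(default_rule(plain, plain))    # default rule runs
    print(default_rule(plain, wrapped))  # NotImplemented -> defer to the subclass
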
@@ -473,6 +497,27 @@ def flex_attention_fake_tensor_mode(
     score_mod_other_buffers: tuple = (),
     mask_mod_other_buffers: tuple = (),
 ) -> tuple[torch.Tensor, torch.Tensor]:
+    flat_args, _ = pytree.tree_flatten(
+        (
+            query,
+            key,
+            value,
+            score_mod,
+            block_mask,
+            scale,
+            kernel_options,
+            score_mod_other_buffers,
+            mask_mod_other_buffers,
+        )
+    )
+    # For tensor subclasses, give the subclass a chance to run first
+    if any(
+        isinstance(a, torch.Tensor)
+        and type(a) is not torch.Tensor
+        and not isinstance(a, FakeTensor)
+        for a in flat_args
+    ):
+        return NotImplemented
     with mode:
         out, logsumexp = flex_attention_fake_impl(query, value)
         return out, logsumexp
@@ -1086,6 +1131,31 @@ def flex_attention_backward_functionalize(
     since we know that the forward score mod function is assured to be free of mutations
     to the other_buffers, we skip that mutate check and go straight to redispatching.
     """
+    flat_args, _ = pytree.tree_flatten(
+        (
+            query,
+            key,
+            value,
+            out,
+            logsumexp,
+            grad_out,
+            grad_logsumexp,
+            block_mask,
+            scale,
+            kernel_options,
+            score_mod_other_buffers,
+            mask_mod_other_buffers,
+        )
+    )
+    # For tensor subclasses, give the subclass a chance to run first
+    if any(
+        isinstance(a, torch.Tensor)
+        and type(a) is not torch.Tensor
+        and not isinstance(a, FakeTensor)
+        and not isinstance(a, FunctionalTensor)
+        for a in flat_args
+    ):
+        return NotImplemented
     query_unwrapped = ctx.unwrap_tensors(query)
     key_unwrapped = ctx.unwrap_tensors(key)
     value_unwrapped = ctx.unwrap_tensors(value)
@@ -1158,6 +1228,30 @@ def flex_attention_backward_fake_tensor_mode(
 ) -> tuple[
     torch.Tensor, torch.Tensor, torch.Tensor, tuple[Optional[torch.Tensor], ...]
 ]:
+    flat_args, _ = pytree.tree_flatten(
+        (
+            query,
+            key,
+            value,
+            out,
+            logsumexp,
+            grad_out,
+            grad_logsumexp,
+            block_mask,
+            scale,
+            kernel_options,
+            score_mod_other_buffers,
+            mask_mod_other_buffers,
+        )
+    )
+    # For tensor subclasses, give the subclass a chance to run first
+    if any(
+        isinstance(a, torch.Tensor)
+        and type(a) is not torch.Tensor
+        and not isinstance(a, FakeTensor)
+        for a in flat_args
+    ):
+        return NotImplemented
     with mode:
         Bq, _, _, qk_head_dim = query.shape
         Bkv, Hkv, seq_len_kv, v_head_dim = value.shape
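
The same flatten-and-check block is repeated in all four dispatch rules (forward/backward functionalize and fake-tensor-mode). A natural follow-up would be to factor it into a small helper; the sketch below is hypothetical (the helper name _has_user_subclass is not part of this diff) and only restates the guard's logic in one place.

    from typing import Any

    import torch
    import torch.utils._pytree as pytree


    def _has_user_subclass(args: Any, allowed_subclasses: tuple[type, ...]) -> bool:
        """Return True if any flattened argument is a tensor subclass other than
        the infrastructure wrappers listed in allowed_subclasses.

        Hypothetical refactor of the guard repeated in this diff.
        """
        flat_args, _ = pytree.tree_flatten(args)
        return any(
            isinstance(a, torch.Tensor)
            and type(a) is not torch.Tensor
            and not isinstance(a, allowed_subclasses)
            for a in flat_args
        )

Each rule would then reduce to a single call, e.g. `if _has_user_subclass(all_args, (FakeTensor, FunctionalTensor)): return NotImplemented`, with the fake-tensor-mode rules passing only `(FakeTensor,)`, mirroring the conditions in the diff.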