
Commit 577af2e

flex attention: fix dispatch order for tensor subclasses, avoid hardcoding call to faketensor impl in dynamo
ghstack-source-id: 17535e0 Pull Request resolved: #151719
1 parent 4dfcdbe commit 577af2e

File tree

2 files changed: +122, -8 lines


torch/_dynamo/variables/higher_order_ops.py

Lines changed: 27 additions & 7 deletions
@@ -2636,7 +2636,8 @@ def call_function(
         args: "list[VariableTracker]",
         kwargs: "dict[str, VariableTracker]",
     ) -> "VariableTracker":
-        from torch._higher_order_ops.flex_attention import flex_attention_fake_impl
+        from torch._higher_order_ops.flex_attention import flex_attention
+        from . import TensorVariable
 
         from .builder import wrap_fx_proxy
 
@@ -2660,6 +2661,31 @@ def call_function(
             tx, query, mask_fn, "mask_fn"
         )
 
+        def unwrap_proxy_to_faketensor(x):
+            if isinstance(x, TupleVariable):
+                return pytree.tree_map(unwrap_proxy_to_faketensor, x.items)
+            if isinstance(x, (TensorVariable, SymNodeVariable)):
+                x_proxy = x.as_proxy()
+                return x_proxy.node.meta['example_value']
+            else:
+                return x.as_python_constant()
+
+        # use all of the args for faketensor prop
+        vt_full_args = [
+            query,
+            key,
+            value,
+            score_mod,
+            block_mask,
+            scale,
+            kernel_options,
+        ]
+        all_fake_args = pytree.tree_map(unwrap_proxy_to_faketensor, vt_full_args)
+
+        with torch._guards.TracingContext.try_get().fake_mode:
+            out_meta, lse_meta = flex_attention(*all_fake_args)
+        example_value = (out_meta, lse_meta)
+
         proxied_args = [
             query,
             key,
@@ -2674,12 +2700,6 @@ def call_function(
         # Proxying user defined functions is not supported.
         inp_args, _ = proxy_args_kwargs(proxied_args, {})
 
-        query_meta = query.as_proxy().node.meta["example_value"]
-        value_meta = value.as_proxy().node.meta["example_value"]
-        with torch._guards.TracingContext.try_get().fake_mode:
-            out_meta, lse_meta = flex_attention_fake_impl(query_meta, value_meta)
-        example_value = (out_meta, lse_meta)
-
         # Compose the ordered HOO args:
         # - inp_args: [query, key, value, block_mask, scale, kernel_options]
         # - subgraph node: [score_mod, mask_fn_node]
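
The Dynamo change above stops hardcoding flex_attention_fake_impl: it unwraps the VariableTrackers to their fake-tensor example values and calls the flex_attention HOP under the ambient fake mode, so the example_value metadata comes from whatever fake implementation the operator dispatches to. Below is a minimal, self-contained sketch of that fake-propagation idea; the helper name fake_propagate is assumed, and scaled_dot_product_attention stands in for the flex_attention HOP (this is not the Dynamo code itself):

import torch
from torch._subclasses import FakeTensorMode
from torch.utils._pytree import tree_map


def fake_propagate(fn, *args):
    # Run `fn` on fake copies of `args`: the outputs carry only
    # shape/dtype/device metadata, no real compute happens.
    fake_mode = FakeTensorMode()

    def to_fake(x):
        return fake_mode.from_tensor(x) if isinstance(x, torch.Tensor) else x

    with fake_mode:
        return fn(*tree_map(to_fake, args))


q = torch.randn(2, 4, 128, 64)
k = torch.randn(2, 4, 128, 64)
v = torch.randn(2, 4, 128, 64)
out = fake_propagate(torch.nn.functional.scaled_dot_product_attention, q, k, v)
print(out.shape)  # torch.Size([2, 4, 128, 64]) -- metadata only, no attention computed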

torch/_higher_order_ops/flex_attention.py

Lines changed: 95 additions & 1 deletion
@@ -17,7 +17,8 @@
     validate_subgraph_args_types,
 )
 from torch._ops import HigherOrderOperator
-from torch._subclasses import FakeTensorMode
+from torch._subclasses import FakeTensor, FakeTensorMode
+from torch._subclasses.functional_tensor import FunctionalTensor
 from torch.fx.experimental.proxy_tensor import (
     make_fx,
     ProxyTorchDispatchMode,
@@ -396,6 +397,29 @@ def flex_attention_functionalize(
     """
     from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
 
+    flat_args, _ = pytree.tree_flatten(
+        (
+            query,
+            key,
+            value,
+            score_mod,
+            block_mask,
+            scale,
+            kernel_options,
+            score_mod_other_buffers,
+            mask_mod_other_buffers,
+        )
+    )
+    # For tensor subclasses, give the subclass a chance to run first
+    if any(
+        isinstance(a, torch.Tensor)
+        and type(a) is not torch.Tensor
+        and not isinstance(a, FakeTensor)
+        and not isinstance(a, FunctionalTensor)
+        for a in flat_args
+    ):
+        return NotImplemented
+
     query_unwrapped = ctx.unwrap_tensors(query)
     key_unwrapped = ctx.unwrap_tensors(key)
     value_unwrapped = ctx.unwrap_tensors(value)
@@ -473,6 +497,27 @@ def flex_attention_fake_tensor_mode(
     score_mod_other_buffers: tuple = (),
     mask_mod_other_buffers: tuple = (),
 ) -> tuple[torch.Tensor, torch.Tensor]:
+    flat_args, _ = pytree.tree_flatten(
+        (
+            query,
+            key,
+            value,
+            score_mod,
+            block_mask,
+            scale,
+            kernel_options,
+            score_mod_other_buffers,
+            mask_mod_other_buffers,
+        )
+    )
+    # For tensor subclasses, give the subclass a chance to run first
+    if any(
+        isinstance(a, torch.Tensor)
+        and type(a) is not torch.Tensor
+        and not isinstance(a, FakeTensor)
+        for a in flat_args
+    ):
+        return NotImplemented
     with mode:
         out, logsumexp = flex_attention_fake_impl(query, value)
         return out, logsumexp
@@ -1086,6 +1131,31 @@ def flex_attention_backward_functionalize(
     since we know that the forward score mod function is assured to be free of mutations
     to the other_buffers, we skip that mutate check and go straight to redispatching.
     """
+    flat_args, _ = pytree.tree_flatten(
+        (
+            query,
+            key,
+            value,
+            out,
+            logsumexp,
+            grad_out,
+            grad_logsumexp,
+            block_mask,
+            scale,
+            kernel_options,
+            score_mod_other_buffers,
+            mask_mod_other_buffers,
+        )
+    )
+    # For tensor subclasses, give the subclass a chance to run first
+    if any(
+        isinstance(a, torch.Tensor)
+        and type(a) is not torch.Tensor
+        and not isinstance(a, FakeTensor)
+        and not isinstance(a, FunctionalTensor)
+        for a in flat_args
+    ):
+        return NotImplemented
     query_unwrapped = ctx.unwrap_tensors(query)
     key_unwrapped = ctx.unwrap_tensors(key)
     value_unwrapped = ctx.unwrap_tensors(value)
@@ -1158,6 +1228,30 @@ def flex_attention_backward_fake_tensor_mode(
 ) -> tuple[
     torch.Tensor, torch.Tensor, torch.Tensor, tuple[Optional[torch.Tensor], ...]
 ]:
+    flat_args, _ = pytree.tree_flatten(
+        (
+            query,
+            key,
+            value,
+            out,
+            logsumexp,
+            grad_out,
+            grad_logsumexp,
+            block_mask,
+            scale,
+            kernel_options,
+            score_mod_other_buffers,
+            mask_mod_other_buffers,
+        )
+    )
+    # For tensor subclasses, give the subclass a chance to run first
+    if any(
+        isinstance(a, torch.Tensor)
+        and type(a) is not torch.Tensor
+        and not isinstance(a, FakeTensor)
+        for a in flat_args
+    ):
+        return NotImplemented
     with mode:
         Bq, _, _, qk_head_dim = query.shape
         Bkv, Hkv, seq_len_kv, v_head_dim = value.shape
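
The checks added above follow the NotImplemented convention in PyTorch's Python dispatch: when any argument is a tensor subclass that is not already a FakeTensor (or a FunctionalTensor, in the functionalization kernels), these implementations step aside and return NotImplemented so the subclass's own __torch_dispatch__ handler runs first and decides how to unwrap and redispatch. A toy wrapper subclass, illustrative only and not part of this commit (the name WrapperTensor is assumed), showing what "the subclass runs first" looks like for an ordinary op:

import torch
from torch.utils._pytree import tree_map


class WrapperTensor(torch.Tensor):
    # Generic wrapper subclass: owns no storage, delegates every op to `inner`.
    @staticmethod
    def __new__(cls, inner):
        return torch.Tensor._make_wrapper_subclass(
            cls, inner.shape, dtype=inner.dtype, device=inner.device
        )

    def __init__(self, inner):
        self.inner = inner

    def __repr__(self):
        return f"WrapperTensor({self.inner!r})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        def unwrap(x):
            return x.inner if isinstance(x, WrapperTensor) else x

        def rewrap(x):
            return WrapperTensor(x) if isinstance(x, torch.Tensor) else x

        # The subclass intercepts the op before any fake/functional kernel does,
        # runs the real kernel on the inner tensors, and rewraps the result.
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
        return tree_map(rewrap, out)


x = WrapperTensor(torch.randn(4))
print(x + 1)  # handled by WrapperTensor.__torch_dispatch__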

0 commit comments
