diff --git a/test/distributed/tensor/test_attention.py b/test/distributed/tensor/test_attention.py
index db6f638105ec..8c14c1a04829 100644
--- a/test/distributed/tensor/test_attention.py
+++ b/test/distributed/tensor/test_attention.py
@@ -2,10 +2,13 @@
 # Owner(s): ["oncall: distributed"]
 
 import unittest
+from typing import Any, Callable
+
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
 from torch import nn
+from torch._higher_order_ops.flex_attention import flex_attention as flex_attention_hop
 from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.tensor import DeviceMesh, distribute_tensor, DTensor, Shard
 from torch.distributed.tensor.debug import CommDebugMode
@@ -499,3 +502,34 @@ def causal_mask(b, h, q_idx, kv_idx):
         v_dist = distribute_tensor(v, device_mesh, [Replicate()])
         assert isinstance(q_dist, DTensor)
         out_dt = flex_attention(q_dist, k_dist, v_dist, block_mask=block_mask)
+
+
+@flex_attention_hop.py_impl(DTensor)
+def cp_flex_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    score_mod: Callable,
+    block_mask: tuple,
+    scale: float,
+    kernel_options: dict[str, Any],
+    score_mod_other_buffers: tuple = (),
+    mask_mod_other_buffers: tuple = (),
+) -> tuple[torch.Tensor, torch.Tensor]:
+    print("Congrats! Flex attention is successfully dispatched!")
+
+    return flex_attention_hop(
+        query,
+        key,
+        value,
+        score_mod=score_mod,
+        block_mask=block_mask,
+        scale=scale,
+        kernel_options=kernel_options,
+        score_mod_other_buffers=score_mod_other_buffers,
+        mask_mod_other_buffers=mask_mod_other_buffers,
+    )
+
+
+if __name__ == "__main__":
+    run_tests()
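
A minimal usage sketch of how the DTensor dispatch registered above would be exercised, mirroring the test lines in the second hunk. It assumes a CUDA machine, an already-initialized process group (e.g. launched with torchrun), and that the module containing the cp_flex_attention registration has been imported; the helper name run_demo and the tensor shapes are illustrative, not part of the patch.

# Hypothetical sketch: replicate q/k/v as DTensors on a 1-D device mesh and call the
# public flex_attention wrapper; the DTensor inputs should route the underlying
# higher-order op to cp_flex_attention, which prints its dispatch message.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, distribute_tensor
from torch.nn.attention.flex_attention import create_block_mask, flex_attention


def causal_mask(b, h, q_idx, kv_idx):
    # Standard causal mask: a query position attends only to itself and earlier keys.
    return q_idx >= kv_idx


def run_demo() -> None:
    # Assumes torch.distributed is already initialized, e.g.:
    #   torchrun --nproc-per-node=2 this_script.py
    world_size = dist.get_world_size()
    device_mesh = init_device_mesh("cuda", (world_size,))

    # Illustrative shapes: (batch, heads, seq_len, head_dim).
    q = torch.rand(2, 8, 256, 64, device="cuda")
    k = torch.rand(2, 8, 256, 64, device="cuda")
    v = torch.rand(2, 8, 256, 64, device="cuda")
    block_mask = create_block_mask(causal_mask, B=None, H=None, Q_LEN=256, KV_LEN=256)

    # Replicated DTensor inputs are enough to hit the DTensor py_impl.
    q_dist = distribute_tensor(q, device_mesh, [Replicate()])
    k_dist = distribute_tensor(k, device_mesh, [Replicate()])
    v_dist = distribute_tensor(v, device_mesh, [Replicate()])

    out = flex_attention(q_dist, k_dist, v_dist, block_mask=block_mask)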