diff --git a/test/distributed/tensor/test_attention.py b/test/distributed/tensor/test_attention.py index db6f638105ec..8c14c1a04829 100644 --- a/test/distributed/tensor/test_attention.py +++ b/test/distributed/tensor/test_attention.py @@ -2,10 +2,13 @@ # Owner(s): ["oncall: distributed"] import unittest +from typing import Any, Callable + import torch import torch.distributed as dist import torch.nn.functional as F from torch import nn +from torch._higher_order_ops.flex_attention import flex_attention as flex_attention_hop from torch.distributed.device_mesh import init_device_mesh from torch.distributed.tensor import DeviceMesh, distribute_tensor, DTensor, Shard from torch.distributed.tensor.debug import CommDebugMode @@ -499,3 +502,34 @@ def causal_mask(b, h, q_idx, kv_idx): v_dist = distribute_tensor(v, device_mesh, [Replicate()]) assert isinstance(q_dist, DTensor) out_dt = flex_attention(q_dist, k_dist, v_dist, block_mask=block_mask) + + +@flex_attention_hop.py_impl(DTensor) +def cp_flex_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + score_mod: Callable, + block_mask: tuple, + scale: float, + kernel_options: dict[str, Any], + score_mod_other_buffers: tuple = (), + mask_mod_other_buffers: tuple = (), +) -> tuple[torch.Tensor, torch.Tensor]: + print("Congrats! Flex attention is successfully dispatched!") + + return flex_attention_hop( + query, + key, + value, + score_mod=score_mod, + block_mask=block_mask, + scale=scale, + kernel_options=kernel_options, + score_mod_other_buffers=score_mod_other_buffers, + mask_mod_other_buffers=mask_mod_other_buffers, + ) + + +if __name__ == "__main__": + run_tests() pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy