From 46b58839a38c3e2ca144a631f71eaad300d1593d Mon Sep 17 00:00:00 2001
From: Xilun Wu <12968408+XilunWu@users.noreply.github.com>
Date: Thu, 20 Feb 2025 01:54:47 -0800
Subject: [PATCH] [dtensor][cp] experiment: register flex_attention to a
 custom fn on DTensor

[ghstack-poisoned]
---
 test/distributed/tensor/test_attention.py | 34 +++++++++++++++++++++++
 1 file changed, 34 insertions(+)

diff --git a/test/distributed/tensor/test_attention.py b/test/distributed/tensor/test_attention.py
index 457e74cedc25c..899aca4ea4cdc 100644
--- a/test/distributed/tensor/test_attention.py
+++ b/test/distributed/tensor/test_attention.py
@@ -2,10 +2,13 @@
 # Owner(s): ["oncall: distributed"]
 
 import unittest
+from typing import Any, Callable
+
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
 from torch import nn
+from torch._higher_order_ops.flex_attention import flex_attention as flex_attention_hop
 from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.tensor import DeviceMesh, distribute_tensor, DTensor, Shard
 from torch.distributed.tensor.debug import CommDebugMode
@@ -491,3 +494,34 @@ def causal_mask(b, h, q_idx, kv_idx):
         v_dist = distribute_tensor(v, device_mesh, [Replicate()])
         assert isinstance(q_dist, DTensor)
         out_dt = flex_attention(q_dist, k_dist, v_dist, block_mask=block_mask)
+
+
+@flex_attention_hop.py_impl(DTensor)
+def cp_flex_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    score_mod: Callable,
+    block_mask: tuple,
+    scale: float,
+    kernel_options: dict[str, Any],
+    score_mod_other_buffers: tuple = (),
+    mask_mod_other_buffers: tuple = (),
+) -> tuple[torch.Tensor, torch.Tensor]:
+    print("Congrats! Flex attention is successfully dispatched!")
+
+    return flex_attention_hop(
+        query,
+        key,
+        value,
+        score_mod=score_mod,
+        block_mask=block_mask,
+        scale=scale,
+        kernel_options=kernel_options,
+        score_mod_other_buffers=score_mod_other_buffers,
+        mask_mod_other_buffers=mask_mod_other_buffers,
+    )
+
+
+if __name__ == "__main__":
+    run_tests()
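
Usage note (commentary, not part of the patch): below is a minimal sketch of
how the DTensor py_impl registered above would be exercised. It mirrors the
replicate-and-call pattern from the test in this file; the mesh setup, tensor
shapes, and the causal_mask helper here are illustrative assumptions, not code
taken from this diff. Run under torchrun so each rank joins the device mesh.

    import torch
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import Replicate, distribute_tensor
    from torch.nn.attention.flex_attention import create_block_mask, flex_attention


    def causal_mask(b, h, q_idx, kv_idx):
        # Standard causal mask: a query position attends only to earlier keys.
        return q_idx >= kv_idx


    # Assumed setup: one GPU per rank, a 1-D mesh spanning all local GPUs.
    mesh = init_device_mesh("cuda", (torch.cuda.device_count(),))

    # Illustrative shapes: (batch, heads, seq_len, head_dim).
    q, k, v = (torch.randn(2, 8, 256, 64, device="cuda") for _ in range(3))

    # Wrapping the inputs as replicated DTensors is what routes the call below
    # through the flex_attention HOP's DTensor py_impl (cp_flex_attention).
    q_dist, k_dist, v_dist = (
        distribute_tensor(t, mesh, [Replicate()]) for t in (q, k, v)
    )
    block_mask = create_block_mask(causal_mask, B=2, H=8, Q_LEN=256, KV_LEN=256)
    out = flex_attention(q_dist, k_dist, v_dist, block_mask=block_mask)
    # Each rank should print:
    # "Congrats! Flex attention is successfully dispatched!"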