
[dtensor][cp] experiment: register flex_attention to a custom fn on DTensor #147515


Closed · wants to merge 2 commits

Changes from 1 commit
[dtensor][cp] experiment: register flex_attention to a custom fn on DTensor

[ghstack-poisoned]

XilunWu committed Feb 20, 2025
commit 46b58839a38c3e2ca144a631f71eaad300d1593d
34 changes: 34 additions & 0 deletions test/distributed/tensor/test_attention.py
@@ -2,10 +2,13 @@
# Owner(s): ["oncall: distributed"]
import unittest

from typing import Any, Callable

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn
from torch._higher_order_ops.flex_attention import flex_attention as flex_attention_hop
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DeviceMesh, distribute_tensor, DTensor, Shard
from torch.distributed.tensor.debug import CommDebugMode
@@ -491,3 +494,34 @@ def causal_mask(b, h, q_idx, kv_idx):
        v_dist = distribute_tensor(v, device_mesh, [Replicate()])
        assert isinstance(q_dist, DTensor)
        out_dt = flex_attention(q_dist, k_dist, v_dist, block_mask=block_mask)


@flex_attention_hop.py_impl(DTensor)
def cp_flex_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    score_mod: Callable,
    block_mask: tuple,
    scale: float,
    kernel_options: dict[str, Any],
    score_mod_other_buffers: tuple = (),
    mask_mod_other_buffers: tuple = (),
) -> tuple[torch.Tensor, torch.Tensor]:
    print("Congrats! Flex attention is successfully dispatched!")

    return flex_attention_hop(
        query,
        key,
        value,
        score_mod=score_mod,
        block_mask=block_mask,
        scale=scale,
        kernel_options=kernel_options,
        score_mod_other_buffers=score_mod_other_buffers,
        mask_mod_other_buffers=mask_mod_other_buffers,
    )


if __name__ == "__main__":
    run_tests()
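
For context, below is a minimal sketch (not part of this PR) of how the py_impl(DTensor) registration above would be exercised. The helper name run_demo and the tensor shapes are illustrative; it assumes a distributed process group is already initialized with one GPU per rank, and that calling flex_attention on DTensor inputs reaches flex_attention_hop so dispatch lands on cp_flex_attention, as in the test added in this diff.

# Illustrative sketch, not part of this PR: exercise the py_impl(DTensor)
# registration by calling flex_attention with replicated DTensor inputs.
# Assumes torch.distributed is initialized and each rank has a CUDA device.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, distribute_tensor
from torch.nn.attention.flex_attention import create_block_mask, flex_attention


def run_demo(world_size: int) -> None:  # hypothetical helper, not in the PR
    device_mesh = init_device_mesh("cuda", (world_size,))

    def causal_mask(b, h, q_idx, kv_idx):
        return q_idx >= kv_idx

    q = torch.rand(2, 4, 256, 64, device="cuda")
    k = torch.rand(2, 4, 256, 64, device="cuda")
    v = torch.rand(2, 4, 256, 64, device="cuda")
    block_mask = create_block_mask(causal_mask, 2, 4, 256, 256, device="cuda")

    # Replicated DTensors, mirroring the test added in this diff; sharding the
    # sequence dimension is the eventual context-parallel goal.
    q_dist = distribute_tensor(q, device_mesh, [Replicate()])
    k_dist = distribute_tensor(k, device_mesh, [Replicate()])
    v_dist = distribute_tensor(v, device_mesh, [Replicate()])

    # If dispatch reaches the DTensor registration, cp_flex_attention prints
    # its message before re-invoking the higher-order op.
    out = flex_attention(q_dist, k_dist, v_dist, block_mask=block_mask)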