[cp] dispatch flex_attention to CP impl in TorchDispatchMode #151497

Open · wants to merge 21 commits into base: gh/XilunWu/133/base
Commits (21)
6bfb512
DTensor HOP dispatch
XilunWu Apr 17, 2025
0e1171f
Update on "DTensor HOP dispatch"
XilunWu Apr 17, 2025
efc5e67
Update on "DTensor HOP dispatch"
XilunWu Apr 17, 2025
4f8e9d4
Update on "DTensor HOP dispatch"
XilunWu Apr 18, 2025
19ac1ce
Update on "DTensor HOP dispatch"
XilunWu Apr 18, 2025
9ecc674
Update on "DTensor HOP dispatch"
XilunWu Apr 22, 2025
c71ff48
Update
XilunWu Apr 24, 2025
2836766
Update
XilunWu Apr 24, 2025
103f0ef
Update
XilunWu Apr 24, 2025
bdda34a
Update on "[cp] dispatch flex_attention to CP impl in TorchDispatchMode"
XilunWu Apr 27, 2025
929f157
Update on "[cp] dispatch flex_attention to CP impl in TorchDispatchMode"
XilunWu Apr 27, 2025
0ba9169
Update on "[cp] dispatch flex_attention to CP impl in TorchDispatchMode"
XilunWu Apr 28, 2025
644ec19
Update on "[cp] dispatch flex_attention to CP impl in TorchDispatchMode"
XilunWu May 9, 2025
048010a
Update on "[cp] dispatch flex_attention to CP impl in TorchDispatchMode"
XilunWu May 9, 2025
6ad5a79
Update on "[cp] dispatch flex_attention to CP impl in TorchDispatchMode"
XilunWu May 21, 2025
bced5f0
Update on "[cp] dispatch flex_attention to CP impl in TorchDispatchMode"
XilunWu May 27, 2025
177f177
Update on "[cp] dispatch flex_attention to CP impl in TorchDispatchMode"
XilunWu May 27, 2025
446fb80
Update on "[cp] dispatch flex_attention to CP impl in TorchDispatchMode"
XilunWu May 29, 2025
ddd1ab5
Update
XilunWu May 30, 2025
43c91ff
Update on "[cp] dispatch flex_attention to CP impl in TorchDispatchMode"
XilunWu May 30, 2025
362b6e8
Update on "[cp] dispatch flex_attention to CP impl in TorchDispatchMode"
XilunWu May 30, 2025
155 changes: 147 additions & 8 deletions test/distributed/tensor/test_attention.py
@@ -6,15 +6,18 @@
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DeviceMesh
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.experimental._attention import (
_AttentionContextParallel,
_CausalBehavior,
_cp_options,
_DispatchMode,
_FlexAttentionSequentialSharder,
_is_causal_behavior,
_RotateMethod,
_set_dispatch_mode,
context_parallel,
context_parallel_unshard,
set_rotate_method,
@@ -27,7 +30,10 @@
PLATFORM_SUPPORTS_FUSED_ATTENTION,
PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_distributed import (
skip_if_lt_world_size,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import run_tests, skipIfRocm
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
@@ -78,8 +84,8 @@ def test_ring_attention_sdpa(self) -> None:
"rotater": [_RotateMethod.ALL_TO_ALL, _RotateMethod.ALL_GATHER],
"test_forward_only": [True, False],
"dispatch_mode": [
- _DispatchMode.MONKEY_PATCH,
- _DispatchMode.TORCH_FUNCTION,
+ "monkey_patch",
+ "torch_function",
],
},
self._test_ring_attention_sdpa,
@@ -93,9 +99,11 @@ def _test_ring_attention_sdpa(
load_balance: bool,
rotater: _RotateMethod,
test_forward_only: bool,
- dispatch_mode: _DispatchMode,
+ dispatch_mode: str,
) -> None:
- torch.distributed.tensor.experimental._attention._dispatch_mode = dispatch_mode
+ torch.distributed.tensor.experimental._attention._set_dispatch_mode(
+     dispatch_mode
+ )

def fn_eval(fn, *args, **kwargs):
if test_forward_only:
@@ -165,7 +173,7 @@ def fn_eval(fn, *args, **kwargs):
# now. So we can just use context_parallel() to shard q, k, v.
# In reality, context_parallel() should be used to shard the input.
with context_parallel(
- device_mesh, buffers=(cp_q, cp_k, cp_v), buffer_seq_dims=(2, 2, 2)
+ device_mesh, buffers=[cp_q, cp_k, cp_v], buffer_seq_dims=[2, 2, 2]
):
cp_q.requires_grad = True
cp_k.requires_grad = True
@@ -228,8 +236,9 @@ def fn_eval(fn, *args, **kwargs):
cp_k.requires_grad = False
cp_v.requires_grad = False

- torch.distributed.tensor.experimental._attention._dispatch_mode = (
-     _DispatchMode.MONKEY_PATCH
+ # reset to the default mode
+ torch.distributed.tensor.experimental._attention._set_dispatch_mode(
+     "monkey_patch"
)

def test_is_causal_behavior(self) -> None:
@@ -437,5 +446,135 @@ def _test_ring_attention_custom_transformer(self, rotater: _RotateMethod) -> None:
)


class RingFlexAttentionTest(DTensorTestBase):
@property
def world_size(self) -> int:
return torch.cuda.device_count() if torch.cuda.is_available() else 4

@skip_if_lt_world_size()
@with_comms
def test_ring_flex_attention(self) -> None:
def causal_mask(b, h, q_idx, kv_idx):
return q_idx >= kv_idx

from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# Compile the flex_attention function
flex_attention = torch.compile(flex_attention, dynamic=False)
Q_BLOCK_SIZE_DEFAULT = 128
KV_BLOCK_SIZE_DEFAULT = Q_BLOCK_SIZE_DEFAULT

torch.cuda.manual_seed(10)
dtype = torch.float32
bs = 8
query_tokens = Q_BLOCK_SIZE_DEFAULT * self.world_size
context_tokens = KV_BLOCK_SIZE_DEFAULT * self.world_size
dim = 32
nheads = 8

q = torch.rand(
(bs, nheads, query_tokens, dim),
device=self.device_type,
dtype=dtype,
requires_grad=True,
)
k = torch.rand(
(bs, nheads, context_tokens, dim),
device=self.device_type,
dtype=dtype,
requires_grad=True,
)
v = torch.rand(
(bs, nheads, context_tokens, dim),
device=self.device_type,
dtype=dtype,
requires_grad=True,
)

block_mask = create_block_mask(
causal_mask,
B=1,
H=1,
Q_LEN=query_tokens,
KV_LEN=context_tokens,
device=self.device_type,
)

expect_out, expect_lse = flex_attention(
q, k, v, block_mask=block_mask, return_lse=True
)

# test flex attention on DTensor
device_mesh = init_device_mesh(
device_type=self.device_type,
mesh_shape=(self.world_size,),
mesh_dim_names=("cp",),
)

q_local_size = list(q.shape)
q_local_size[2] //= self.world_size
k_local_size = list(k.shape)
k_local_size[2] //= self.world_size

# this is the block_mask created within the training step
# NOTE: flex_attention checks block_mask shape and input shape before
# calling into flex_attention_hop.
block_mask_post_sharding = create_block_mask(
causal_mask,
B=1,
H=1,
Q_LEN=q_local_size[2],
KV_LEN=k_local_size[2],
device=self.device_type,
)

# set CP context dispatch mode to use TorchDispatchMode for flex_attention
_set_dispatch_mode("torch_dispatch")
assert (
torch.distributed.tensor.experimental._attention._dispatch_mode
== _DispatchMode.TORCH_DISPATCH
)

# prepare input buffer
cp_q = q.detach().clone()
cp_k = k.detach().clone()
cp_v = v.detach().clone()

# create sharder
sharder = _FlexAttentionSequentialSharder(device_mesh, block_mask=block_mask)
Review comment (Collaborator): I think exposing a way for the user to control how the mask is partitioned is fine. But I don't think we should do it in the current way; specifically, I don't think we should leak the concept of a "Sharder", as it would confuse users about its relationship with the Shard placement in DTensor.


with CommDebugMode() as comm_mode:
with context_parallel(
device_mesh,
buffers=[cp_q, cp_k, cp_v],
buffer_seq_dims=[2, 2, 2],
sharder=sharder,
):
cp_out, cp_lse = flex_attention(
cp_q,
cp_k,
cp_v,
block_mask=block_mask_post_sharding,
return_lse=True,
)

self.assertDictEqual(
comm_mode.get_comm_counts(),
{
c10d_functional.all_gather_into_tensor: 2
},  # currently k and v are all-gathered separately
)

# unshard the output
cp_out, cp_lse = context_parallel_unshard(
device_mesh, [cp_out, cp_lse], [2, 2], sharder
)
torch.testing.assert_close(cp_out, expect_out, atol=1e-6, rtol=1e-2)
torch.testing.assert_close(cp_lse, expect_lse, atol=1e-6, rtol=1e-2)

# reset to the default mode
_set_dispatch_mode("monkey_patch")


if __name__ == "__main__":
run_tests()
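
For readers following the review discussion above about the Sharder concept, here is a minimal sketch, distilled from the new RingFlexAttentionTest, of how the pieces added in this PR fit together: set the dispatch mode, build a sharder from the global block mask, run flex_attention inside context_parallel, then unshard. It is only an illustration, not the authoritative API: `run_cp_flex_attention` is a hypothetical helper name, the sketch assumes an already-initialized process group with one GPU per rank and sequence lengths divisible by both the world size and the default 128-token block size, and the underscore-prefixed imports are experimental and may change.

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.experimental._attention import (
    _FlexAttentionSequentialSharder,
    _set_dispatch_mode,
    context_parallel,
    context_parallel_unshard,
)
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# the test compiles flex_attention before dispatching it to the CP impl
flex_attention = torch.compile(flex_attention, dynamic=False)


def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


def run_cp_flex_attention(q, k, v, world_size, device="cuda"):
    # route flex_attention to the CP implementation via TorchDispatchMode
    _set_dispatch_mode("torch_dispatch")
    mesh = init_device_mesh(device, (world_size,), mesh_dim_names=("cp",))

    # the global mask drives the sharder; a second mask matches the local
    # (post-sharding) shapes that flex_attention validates before the HOP call
    global_mask = create_block_mask(
        causal_mask, B=1, H=1, Q_LEN=q.shape[2], KV_LEN=k.shape[2], device=device
    )
    local_mask = create_block_mask(
        causal_mask,
        B=1,
        H=1,
        Q_LEN=q.shape[2] // world_size,
        KV_LEN=k.shape[2] // world_size,
        device=device,
    )
    sharder = _FlexAttentionSequentialSharder(mesh, block_mask=global_mask)

    cp_q, cp_k, cp_v = (t.detach().clone() for t in (q, k, v))
    with context_parallel(
        mesh, buffers=[cp_q, cp_k, cp_v], buffer_seq_dims=[2, 2, 2], sharder=sharder
    ):
        out, lse = flex_attention(
            cp_q, cp_k, cp_v, block_mask=local_mask, return_lse=True
        )

    # gather the sharded outputs back to the full sequence length
    out, lse = context_parallel_unshard(mesh, [out, lse], [2, 2], sharder)
    _set_dispatch_mode("monkey_patch")  # restore the default mode
    return out, lse
```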
1 change: 1 addition & 0 deletions torch/distributed/tensor/_ops/_tensor_ops.py
@@ -63,6 +63,7 @@ def default_strategy(op_schema: OpSchema) -> StrategyType:
aten.contiguous.default,
aten.copy_.default,
aten.detach.default,
aten.detach_.default,
aten.fill_.Scalar,
aten.view.dtype,
aten.zero_.default,
30 changes: 29 additions & 1 deletion torch/distributed/tensor/debug/_comm_mode.py
@@ -4,11 +4,12 @@
import re
import weakref
from collections import defaultdict
- from typing import Any
+ from typing import Any, Callable

import torch
import torch.nn
from torch._guards import detect_fake_mode
from torch._higher_order_ops import flex_attention as flex_attention_hop
from torch.autograd.graph import register_multi_grad_hook
from torch.distributed._tools.mod_tracker import ModTracker
from torch.distributed.tensor._api import DTensor
@@ -733,3 +734,30 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
].append(operation_dict)

return out


# register flex_attention HOP to CommDebugMode
@flex_attention_hop.py_impl(CommDebugMode)
Review comment (Collaborator): why does one need to register a HOP here?

def flex_attention_comm_debug_mode(
mode: CommDebugMode,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
score_mod: Callable,
block_mask: tuple,
scale: float,
kernel_options: dict[str, Any],
score_mod_other_buffers: tuple = (),
mask_mod_other_buffers: tuple = (),
) -> tuple[torch.Tensor, torch.Tensor]:
return flex_attention_hop(
query,
key,
value,
score_mod,
block_mask,
scale,
kernel_options,
score_mod_other_buffers,
mask_mod_other_buffers,
)
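
On the review question above: a TorchDispatchMode does not receive higher-order ops by default; each HOP needs an implementation registered for the mode class (or the mode has to opt in to handling HOPs), which is presumably why this pass-through for CommDebugMode is added. It simply re-dispatches to flex_attention_hop so the mode keeps counting the collectives issued around it. A small usage sketch, reusing the hypothetical run_cp_flex_attention helper from the earlier sketch (q, k, v, and world_size are assumed to be defined as there):

```python
from torch.distributed.tensor.debug import CommDebugMode

# Wrap the CP region in CommDebugMode so the collectives issued by the CP
# flex_attention implementation are recorded.
with CommDebugMode() as comm_mode:
    out, lse = run_cp_flex_attention(q, k, v, world_size)

# The new test asserts exactly two all_gather_into_tensor calls here, since
# k and v are currently all-gathered separately.
print(comm_mode.get_comm_counts())
```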