## Background
The existing Context Parallel API supports three ways to replace `F.scaled_dot_product_attention` with a context-parallel version (let's call it `cp_sdpa`):
- Monkey patch. This approach replaces the SDPA function in the `torch.nn.functional` module with `cp_sdpa` using `setattr` (see the sketch after this list).
- TorchDispatch. This approach wraps a `TorchDispatchMode` in the Python context created when users call `with context_parallel(...)`. The TD mode dispatches any seen `F.scaled_dot_product_attention` to `cp_sdpa`. This approach has been replaced by the third one.
- TorchFunction. This approach wraps a `TorchFunctionMode` in the Python context created when users call `with context_parallel(...)`. The TF mode does the same thing as Approach 2, except it does so in TF mode rather than TD mode.
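For concreteness, here is a minimal sketch of the monkey-patch approach, assuming a `cp_sdpa` function with the same signature as SDPA already exists:

```python
import torch.nn.functional as F

# Replace the SDPA entry point in torch.nn.functional with the context-parallel
# version; every subsequent F.scaled_dot_product_attention call then hits cp_sdpa.
setattr(F, "scaled_dot_product_attention", cp_sdpa)
```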
All three approaches work fine with `torch.compile`, in the sense that the code below runs without any issue:
```python
class MyModule(torch.nn.Module):
    def forward(self, x):
        return F.scaled_dot_product_attention(x, x, x)  # self-attention for illustration

compiled_module = torch.compile(MyModule(), fullgraph=True)
with context_parallel(...):
    out = compiled_module(inp)
```
## Goal
I want to enable the above pattern for `flex_attention`, so that a single-device `flex_attention` call
```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# simply use causal masking to demonstrate
def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

B = 1
H = 1
Q_LEN = 128
KV_LEN = 128

q, k, v = (torch.rand(B, H, Q_LEN, KV_LEN) for _ in range(3))
block_mask = create_block_mask(
    causal_mask,
    B=B,
    H=H,
    Q_LEN=Q_LEN,
    KV_LEN=KV_LEN,
)
compiled_flex_attention = torch.compile(flex_attention, fullgraph=True)
out, lse = compiled_flex_attention(q, k, v, block_mask=block_mask, return_lse=True)
```
can be parallelized with minimal code change using the existing CP API:
```python
device_mesh = init_device_mesh(
    device_type=...,
    mesh_shape=(4,),  # init a device mesh of 4 devices for CP
    mesh_dim_names=("cp",),
)

with context_parallel(
    device_mesh,
    buffers=(q, k, v),          # shard QKV
    buffer_seq_dims=(2, 2, 2),  # over seq dim
):
    block_mask = create_block_mask(
        causal_mask,
        B=B,
        H=H,
        Q_LEN=Q_LEN,
        KV_LEN=KV_LEN,
    )
    compiled_flex_attention = torch.compile(flex_attention, fullgraph=True)
    out, lse = compiled_flex_attention(q, k, v, block_mask=block_mask, return_lse=True)
```
However, getting the above example to work correctly runs into some difficulties. First, let's see how I want it to work:
- `context_parallel` introduces a `TorchFunctionMode` that can detect calls to the `flex_attention` HOP.
- Within `__torch_function__`, the call is replaced with a call to the context-parallel version `cp_flex_attention` instead.
- `cp_flex_attention` takes several steps to compute the correct attention score:
  - all-gather KV (this is for simplicity; it shouldn't be hard to adopt other algorithms);
  - rewrite the `block_mask` object passed to `flex_attention`. Because we choose to all-gather KV, the new `mask_mod` is a direct translation of the original `block_mask`, mapping the `q_idx` of the local q shard to its index in the global view of q (see the sketch after this list).
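To illustrate the `mask_mod` rewrite, here is a minimal sketch, not the actual CP implementation; `make_cp_mask_mod`, `cp_rank`, and `q_shard_len` are names made up for illustration:

```python
def make_cp_mask_mod(orig_mask_mod, cp_rank: int, q_shard_len: int):
    """Wrap a mask_mod written for the global sequence so it works on a local Q shard
    while K/V are all-gathered (i.e. kv_idx is already a global index)."""
    def cp_mask_mod(b, h, q_idx, kv_idx):
        # Translate the local query index into its position in the global sequence.
        global_q_idx = q_idx + cp_rank * q_shard_len
        return orig_mask_mod(b, h, global_q_idx, kv_idx)
    return cp_mask_mod
```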
## Challenge 1: BlockMask
### [need discussion] Challenge 1.1: the `flex_attention` function checks QKV shapes and the `block_mask` argument before HOP dispatch
The `flex_attention` function performs a shape check on QKV and the `block_mask` argument, which happens before the HOP dispatch. This requires that `block_mask` match QKV's shards:
```python
with context_parallel(
    device_mesh,
    buffers=(q, k, v),          # shard QKV
    buffer_seq_dims=(2, 2, 2),  # over seq dim
):
    block_mask_post_sharding = create_block_mask(
        causal_mask,
        B=B,
        H=H,
        Q_LEN=Q_LEN // 4,
        KV_LEN=KV_LEN // 4,
    )
    compiled_flex_attention = torch.compile(flex_attention, fullgraph=True)
    out, lse = compiled_flex_attention(q, k, v, block_mask=block_mask_post_sharding, return_lse=True)
```
### [need verification] Challenge 1.2: the same `mask_mod` callable somehow makes a difference
However, simply passing `block_mask_post_sharding` is not enough, because of the `causal_mask` argument. For some reason, using `block_mask_post_sharding.mask_mod` versus `block_mask.mask_mod` in `create_block_mask` introduces different masking in CP's result (TODO: verify in a unit test, as sketched below).
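One possible verification, as a hypothetical test sketch rather than existing test code (it assumes `BlockMask.to_dense()` is available), is to build a `BlockMask` from each callable and compare the materialized block-level masks:

```python
# Build the sharded-shape BlockMask twice, once from each mask_mod callable,
# and check whether the materialized masks actually differ.
mask_from_global = create_block_mask(
    block_mask.mask_mod, B=B, H=H, Q_LEN=Q_LEN // 4, KV_LEN=KV_LEN // 4
)
mask_from_sharded = create_block_mask(
    block_mask_post_sharding.mask_mod, B=B, H=H, Q_LEN=Q_LEN // 4, KV_LEN=KV_LEN // 4
)
assert torch.equal(mask_from_global.to_dense(), mask_from_sharded.to_dense())
```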
### [unblocked] Challenge 1.3: cannot call `create_block_mask` within `TorchFunctionMode` to create `cp_block_mask` (i.e. the `BlockMask` for the query shard, global key, and global value)
My attempt to call `create_block_mask` inside the `TorchFunctionMode` caused a "Triton error" in `create_fw_bw_graph()`.
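For reference, the failing pattern is roughly the following sketch (hypothetical code, not the exact implementation): while the CP `TorchFunctionMode` is active, build the `BlockMask` for the local query shard against the global key/value.

```python
# cp_mask_mod is the rewritten mask_mod from the earlier sketch; DistributeFunction
# is the CP TorchFunctionMode introduced in Proposal 1 below.
with DistributeFunction():
    cp_block_mask = create_block_mask(
        cp_mask_mod,
        B=B,
        H=H,
        Q_LEN=Q_LEN // 4,  # local query shard
        KV_LEN=KV_LEN,     # global key/value after all-gather
    )  # reportedly fails with a "Triton error" from create_fw_bw_graph()
```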
## Challenge 2: make the `flex_attention` HOP's `__call__` correctly dispatch to `cp_flex_attention`
### [need verification, low priority] Challenge 2.1: `flex_attention_hop.py_impl` wasn't compatible with DTensor (likely fixed in #151719)
Using `py_impl` to register `cp_flex_attention` on `DTensor` (a `torch.Tensor` subclass) led to the error "RuntimeError: Attempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode()". See #147515.
This should be fixed by #151719, but I haven't had time to verify, because this path requires the use of DTensor in the context-parallel region, which may not be feasible for some users.
### [low priority] Challenge 2.2: TorchDispatchMode isn't compatible with `torch.compile` and skips compilation
Unlike doing `py_impl` on a Tensor subclass, we can `py_impl` our `cp_flex_attention` on a `TorchDispatchMode` so that every `flex_attention` HOP call within the mode is redirected to `cp_flex_attention` (a registration sketch follows below). #151497 is a prototype showing the correctness. Unfortunately, dynamo skips non-infra TorchDispatchModes; see `pytorch/torch/_dynamo/convert_frame.py`, lines 1524 to 1525 at 0decd96:

```python
elif is_in_torch_dispatch_mode(include_infra_modes=False):
    skip_reason = "non-infra torch dispatch mode present, this is not supported today in torch.compile"
```
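For context, the registration in such a prototype presumably looks something like the sketch below; the mode class name and the implementation signature are illustrative, not the actual code in #151497:

```python
from torch._higher_order_ops import flex_attention as flex_attention_hop
from torch.utils._python_dispatch import TorchDispatchMode

class CPDispatchMode(TorchDispatchMode):
    """Hypothetical dispatch mode installed by context_parallel()."""

@flex_attention_hop.py_impl(CPDispatchMode)
def flex_attention_cp_mode(mode, query, key, value, score_mod, block_mask, *args, **kwargs):
    # Redirect the HOP call to the context-parallel implementation.
    return cp_flex_attention(query, key, value, score_mod, block_mask, *args, **kwargs)
```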
### [inactive] Challenge 2.3: monkey-patching mixes tracing and actual execution
The monkey-patch approach simply replaces the HOP's `__call__` method with a custom function (in this case `cp_flex_attention`) and should be avoided if possible. Besides, the `flex_attention` HOP is called more than once for a single `flex_attention` function call (at least the first time it is called), and I don't have a simple way to figure out the right time to monkey-patch. Overall, this doesn't look like a promising direction.
### [unblocked] Challenge 2.4: TorchFunctionMode doesn't support correct HOP dispatch
The first issue is that `py_impl` doesn't support `TorchFunctionMode` the way it supports `TorchDispatchMode`.
The second issue is that the `flex_attention` HOP call is not correctly dispatched to `cp_flex_attention`, even though the `TorchFunctionMode`'s `__torch_function__` has the following logic:
```python
def __torch_function__(
    self,
    func: Callable,
    types: Any,
    args: tuple[Any, ...] = (),
    kwargs: Optional[dict[str, Any]] = None,
) -> Any:
    if func == torch._higher_order_ops.flex_attention:
        return cp_flex_attention(*args, **kwargs)
```
It still doesn't dispatch to `cp_flex_attention`. Thanks to @mlazos, whose PR #155452 fixes this issue.
The third issue is that `torch._higher_order_ops.flex_attention_backward` will not appear in `TorchFunctionMode`. One workaround is to insert an autograd function before calling `cp_flex_attention`:
```python
def __torch_function__(
    self,
    func: Callable,
    types: Any,
    args: tuple[Any, ...] = (),
    kwargs: Optional[dict[str, Any]] = None,
) -> Any:
    if func == torch._higher_order_ops.flex_attention:
        autograd_KV_allgather(k, v)  # perform reduce-scatter on KV's grad in backward
        return cp_flex_attention(*args, **kwargs)
```
## Challenge 3: Make the process pt2-friendly
### [unblocked] Challenge 3.1: Recompiles
When using `TorchFunctionMode`, the current usage of the `context_parallel` API creates one TF mode instance for every input batch. For example, in TorchTitan every input batch is first sharded by `context_parallel` and then goes through the computation within a `TorchFunctionMode`.
When working with dynamo, this pattern triggers recompilations and leads to an error once recompilations reach the preset limit. One workaround is to create a single TF mode and assign it to a global variable. An alternative proposed by @mlazos is to loosen the recompile guard on TF mode.
### [resolved] Challenge 3.2: the current DeviceMesh is not pt2-friendly
dynamo does not have full support for tracing `DeviceMesh` (e.g. `get_local_rank()`), and the past CP implementation heavily used those untraceable methods. However, this is resolved by #155441 by 1) replacing manual tensor sharding with the DTensor `distribute_tensor` API (see the sketch below); and 2) replacing the use of `DeviceMesh` with `ProcessGroup`.
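As a reminder of what (1) looks like, here is a minimal sketch of sharding a buffer along the sequence dimension with the public `distribute_tensor` API; this is illustrative, not the exact code in #155441:

```python
from torch.distributed.tensor import Shard, distribute_tensor

# Shard q along its sequence dimension (dim 2) across the "cp" mesh dimension;
# .to_local() gives back the plain local shard for non-DTensor code paths.
q_dt = distribute_tensor(q, device_mesh["cp"], placements=[Shard(2)])
q_local = q_dt.to_local()
```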
## Conclusion
The overall journey of integrating FlexAttention into Context Parallel has been bumpy, and it has exposed gaps between the current Distributed components, HOPs, and `torch.compile`. Some of these can be addressed by re-designing Context Parallel, some may require users to make minimal changes to their code, and some can be enabled on the HOP and `torch.compile` side.
It would be great if HOP directly supported the following items:
- [high prio] correct HOP dispatch within `TorchFunctionMode` (land [Dynamo] Enable torch function dispatch on HOPs #155452)
- [high prio, need discussion] Is the "Triton error" thrown from `create_fw_bw_graph()` when calling `create_block_mask` in `TorchFunctionMode` expected, or an actual bug? If it is not expected, it would be good to support this use case.
- [medium prio] Is the shape check in the `flex_attention` Python function skippable? This check forces CP to have users create `block_mask_post_sharding` only to meet the shape requirement (see Challenge 1.1). Ideally, the only `BlockMask` a user needs to create is `block_mask`.
- [low prio] correct `py_impl` on Tensor subclasses
`torch.compile` support:
- [medium prio] correct recompilation check on TorchFunctionMode instances
- [medium prio] better notice of skipped compilation. The skip of FlexAttention within `TorchDispatchMode` is quite hard to catch until directly viewing the generated artifact (see Challenge 2.2).
Distributed support:
- [high prio] FlexAttention + CP via a `TorchFunctionMode` singleton
- [medium prio] PT2-friendly DeviceMesh
## Proposal
### Proposal 1: Introduce a `TorchFunctionMode` singleton for replacing SDPA/FlexAttention with their context-parallel implementations
```python
class DistributeFunction(TorchFunctionMode):
    def __torch_function__(
        self,
        func: Callable,
        types: Any,
        args: tuple[Any, ...] = (),
        kwargs: Optional[dict[str, Any]] = None,
    ) -> Any:
        if func == torch._higher_order_ops.flex_attention:
            return cp_flex_attention()
        if func == F.scaled_dot_product_attention:
            return cp_sdpa()
        ...


# module-level variable
_tf_mode: Optional[TorchFunctionMode] = None


# DistributeFunction is instantiated once, and only when the `context_parallel()` API is called
def context_parallel():
    ...
    global _tf_mode
    if _tf_mode is None:
        _tf_mode = DistributeFunction()
    ...
```
### Proposal 2: For simplicity, only support all-gather based CP + FlexAttention for now, by inserting an autograd function `CPFlexAttentionPreOp` before calling `flex_attention` in `cp_flex_attention`
```python
class CPFlexAttentionPreOp(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx: Any,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        device_mesh: DeviceMesh,
    ) -> tuple[Any, ...]:
        # all-gather key, value
        return query, global_key, global_value

    @staticmethod
    def backward(
        ctx: Any,
        grad_query: torch.Tensor,
        grad_key: torch.Tensor,
        grad_value: torch.Tensor,
    ) -> tuple[Optional[torch.Tensor], ...]:
        # reduce-scatter grad_key, grad_value
        return grad_query, grad_key, grad_value, None


def cp_flex_attention(...):
    query, global_key, global_value = CPFlexAttentionPreOp.apply(query, key, value, device_mesh)
    return flex_attention(query, global_key, global_value, cp_block_mask, ...)
```
### Proposal 3: [need better solution] Require users to manually create both `block_mask` and `block_mask_post_sharding`
Besides, `block_mask` must be stored somewhere accessible by the Context Parallel module (e.g. a module-level variable in `torch.distributed.tensor.experimental._attention`), and `block_mask_post_sharding` must be passed to `flex_attention` calls.
This compromise comes from not being able to call `create_block_mask` within `TorchFunctionMode` (Challenge 1.3), and it is the most controversial part of this proposal IMO. A sketch of the resulting user-facing code follows.
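The following is a hypothetical illustration of what Proposal 3 would ask of users; the module-level `_cp_global_block_mask` variable and the direct assignment to it are made up for illustration, not an existing API:

```python
from torch.distributed.tensor.experimental import _attention as cp_attention

# The global-view mask, stored where the CP module can find it (hypothetical storage).
block_mask = create_block_mask(causal_mask, B=B, H=H, Q_LEN=Q_LEN, KV_LEN=KV_LEN)
cp_attention._cp_global_block_mask = block_mask

# The sharded-shape mask, created only to satisfy flex_attention's shape check.
block_mask_post_sharding = create_block_mask(
    causal_mask, B=B, H=H, Q_LEN=Q_LEN // 4, KV_LEN=KV_LEN // 4
)

with context_parallel(device_mesh, buffers=(q, k, v), buffer_seq_dims=(2, 2, 2)):
    out, lse = compiled_flex_attention(
        q, k, v, block_mask=block_mask_post_sharding, return_lse=True
    )
```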
## UI Change
WIP
cc @chauhang @penguinwu @zou3519 @ydwu4 @bdhirsh @Chillee @drisspg @yanboliang @BoyuanFeng