
[WIP][RFC] Compilable flex_attention + Context Parallel #157015

@XilunWu

Description


Background

The existing Context Parallel API supports 3 ways to replace F.scaled_dot_product_attention with a context-parallel version (let's call it cp_sdpa):

  1. Monkey patch. This approach replaces the SDPA function in the torch.nn.functional module with cp_sdpa using setattr (a minimal sketch follows this list).
  2. TorchDispatch. This approach installs a TorchDispatchMode in the python context entered via with context_parallel(...). This TD mode dispatches any F.scaled_dot_product_attention call it sees to cp_sdpa. This approach has been replaced by the third one.
  3. TorchFunction. This approach installs a TorchFunctionMode in the python context entered via with context_parallel(...). This TF mode does the same thing as Approach 2, except it does so in TF mode rather than TD mode.
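
For reference, here is a minimal sketch of Approach 1; cp_sdpa below is a hypothetical stand-in that simply forwards to the original SDPA, not the actual CP implementation.

import torch.nn.functional as F

# keep a handle to the original SDPA so cp_sdpa can fall back to it
_orig_sdpa = F.scaled_dot_product_attention

def cp_sdpa(query, key, value, *args, **kwargs):
    # a real implementation would shard and communicate here; we only forward for illustration
    return _orig_sdpa(query, key, value, *args, **kwargs)

# replace the SDPA function in torch.nn.functional with cp_sdpa
setattr(F, "scaled_dot_product_attention", cp_sdpa)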

All three approaches work with torch.compile, in the sense that the code below runs without issue:

import torch
import torch.nn.functional as F
from torch.distributed.tensor.experimental import context_parallel

class MyModule(torch.nn.Module):
    def forward(self, x):
        # self-attention for illustration: x serves as query, key, and value
        return F.scaled_dot_product_attention(x, x, x)

compiled_module = torch.compile(MyModule(), fullgraph=True)

with context_parallel(...):
    out = compiled_module(inp)

Goal

I want to enable the above pattern for flex_attention so that a single-device flex_attention call

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# simply use causal masking to demonstrate
def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

B = 1
H = 1
Q_LEN = 128
KV_LEN = 128
HEAD_DIM = 64
# Q_LEN == KV_LEN here, so one shape works for q, k, and v
q, k, v = (torch.rand(B, H, Q_LEN, HEAD_DIM) for _ in range(3))

block_mask = create_block_mask(
    causal_mask,
    B=B,
    H=H,
    Q_LEN=Q_LEN,
    KV_LEN=KV_LEN,
)
compiled_flex_attention = torch.compile(flex_attention, fullgraph=True)
out, lse = compiled_flex_attention(q, k, v, block_mask=block_mask, return_lse=True)

can be parallelized with minimal code change using the existing CP API:

from torch.distributed.device_mesh import init_device_mesh

device_mesh = init_device_mesh(
    device_type=...,    # e.g. "cuda"
    mesh_shape=(4,),    # init a device mesh of 4 devices for CP
    mesh_dim_names=("cp",),
)

with context_parallel(
    device_mesh,
    buffers=(q, k, v),    # shard QKV
    buffer_seq_dims=(2, 2, 2),    # over seq dim
):
    block_mask = create_block_mask(
        causal_mask,
        B=B,
        H=H,
        Q_LEN=Q_LEN,
        KV_LEN=KV_LEN,
    )

    compiled_flex_attention = torch.compile(flex_attention, fullgraph=True)
    out, lse = compiled_flex_attention(q, k, v, block_mask=block_mask, return_lse=True)

However, getting the above example to work correctly runs into several difficulties. Before going through them, here is how I want it to work:

  1. context_parallel introduces a TorchFunctionMode that can detect calls to the flex_attention HOP
  2. within __torch_function__, replace the call and call the context-parallel version cp_flex_attention instead
  3. cp_flex_attention takes several steps to compute the correct attention score:
    • all-gather KV (this is for simplicity; it shouldn't be hard to adopt other algorithms)
    • rewrite the block_mask object passed to flex_attention. Because we choose to all-gather KV, the new mask_mod is a direct translation of the original one: it maps each q_idx of the local q shard to the corresponding index in the global view of q (a sketch follows this list).
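
Below is a hypothetical sketch of that mask_mod rewrite, assuming a plain contiguous shard of q along the sequence dimension (no load balancing); cp_rank and q_shard_len are illustrative names, not part of the existing CP API.

def rewrite_mask_mod_for_local_q(orig_mask_mod, cp_rank: int, q_shard_len: int):
    def local_q_mask_mod(b, h, q_idx, kv_idx):
        # translate the local q index into its global counterpart; kv_idx is already
        # global because KV has been all-gathered
        return orig_mask_mod(b, h, q_idx + cp_rank * q_shard_len, kv_idx)

    return local_q_mask_mod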

Challenge 1: BlockMask

[need discussion] Challenge 1.1: flex_attention function checks QKV shape and the block_mask argument before HOP dispatch

The flex_attention Python function performs a shape check on QKV and on the block_mask argument before the HOP dispatch. This requires block_mask to match QKV's shard:

with context_parallel(
    device_mesh,
    buffers=(q, k, v),    # shard QKV
    buffer_seq_dims=(2, 2, 2),    # over seq dim
):
    block_mask_post_sharding = create_block_mask(
        causal_mask,
        B=B,
        H=H,
        Q_LEN=Q_LEN // 4,
        KV_LEN=KV_LEN // 4,
    )

    compiled_flex_attention = torch.compile(flex_attention, fullgraph=True)
    out, lse = compiled_flex_attention(q, k, v, block_mask=block_mask_post_sharding, return_lse=True)

[need verification] Challenge 1.2: the same mask_mod callable somehow makes a difference

However, simply passing block_mask_post_sharding is not enough, because of the causal_mask argument. For some reason, passing block_mask_post_sharding.mask_mod versus block_mask.mask_mod to create_block_mask introduces different masking in CP's result, even though they refer to the same callable (TODO: verify in a unit test).
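
A hypothetical repro sketch of this observation, reusing the names from the earlier examples:

# both masks use the same callable: block_mask.mask_mod is causal_mask
mask_from_callable = create_block_mask(causal_mask, B=B, H=H, Q_LEN=Q_LEN // 4, KV_LEN=KV_LEN // 4)
mask_from_attr = create_block_mask(block_mask.mask_mod, B=B, H=H, Q_LEN=Q_LEN // 4, KV_LEN=KV_LEN // 4)
# reportedly, using mask_from_callable vs. mask_from_attr inside the CP region
# produces different masking in the result (see the TODO above)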

[unblocked] Challenge 1.3: cannot call create_block_mask within TorchFunctionMode to create cp_block_mask (i.e. the BlockMask for the query shard, the global key, and the global value)

My attempt to call create_block_mask inside a TorchFunctionMode caused a "Triton error" in create_fw_bw_graph().

Challenge 2: make the flex_attention HOP's __call__ correctly dispatch to cp_flex_attention

[need verification, low priority] Challenge 2.1: flex_attention_hop.py_impl was not compatible with DTensor (likely fixed in #151719)

Using py_impl to register cp_flex_attention on DTensor (a torch.Tensor subclass) led to the error "RuntimeError: Attempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode()". See #147515.

This should be fixed by #151719, but I haven't had time to verify it, because this path requires using DTensor inside the context-parallel region, which may not be feasible for some users.

[low priority] Challenge 2.2: TorchDispatchMode isn't compatible with torch.compile and causes compilation to be skipped

Instead of registering py_impl on a Tensor subclass, we can register cp_flex_attention via py_impl on a TorchDispatchMode so that every flex_attention HOP call within the mode is redirected to cp_flex_attention. #151497 is a prototype demonstrating correctness. Unfortunately, dynamo skips compilation when a non-infra TorchDispatchMode is active; see:

elif is_in_torch_dispatch_mode(include_infra_modes=False):
    skip_reason = "non-infra torch dispatch mode present, this is not supported today in torch.compile"
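
For concreteness, a minimal sketch of this py_impl-on-TorchDispatchMode registration (not the actual #151497 code) might look like the following; CPAttentionDispatchMode and cp_flex_attention are hypothetical placeholders.

from torch._higher_order_ops.flex_attention import flex_attention as flex_attention_hop
from torch.utils._python_dispatch import TorchDispatchMode


def cp_flex_attention(*args, **kwargs):
    raise NotImplementedError("placeholder for the context-parallel flex_attention")


class CPAttentionDispatchMode(TorchDispatchMode):
    # pass regular ops through unchanged; only the HOP registration below matters
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        return func(*args, **(kwargs or {}))


@flex_attention_hop.py_impl(CPAttentionDispatchMode)
def _cp_flex_attention_impl(mode, *args, **kwargs):
    # every flex_attention HOP call made while CPAttentionDispatchMode is active lands here
    return cp_flex_attention(*args, **kwargs)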

[inactive] Challenge 2.3: Monkey-patching mixes tracing and actual execution

The monkey-patch approach simply replaces the HOP's __call__ method with a custom function (in this case cp_flex_attention) and should be avoided if possible. Besides, the flex_attention HOP is called more than once for a single flex_attention function call (at least the first time it is called), and I don't have a simple way to figure out the right time to monkey-patch. Overall, this doesn't look like a promising direction.

[unblocked] Challenge 2.4: TorchFunctionMode doesn't support correct HOP dispatch

The first issue is that py_impl doesn't support TorchFunctionMode the way it supports TorchDispatchMode.

The second issue is that a flex_attention HOP call is not correctly dispatched to cp_flex_attention, even though my TorchFunctionMode's __torch_function__ has the following logic:

def __torch_function__(
    self,
    func: Callable,
    types: Any,
    args: tuple[Any, ...] = (),
    kwargs: Optional[dict[str, Any]] = None,
) -> Any:
    if func == torch._higher_order_ops.flex_attention:
        return cp_flex_attention(*args, **kwargs)
    return func(*args, **(kwargs or {}))

It still doesn't dispatch to cp_flex_attention. Thanks to @mlazos, whose PR #155452 fixes this issue.

The third issue is that torch._higher_order_ops.flex_attention_backward will not appear in TorchFunctionMode. One workaround is to insert an autograd Function before calling cp_flex_attention:

def __torch_function__(
    self,
    func: Callable,
    types: Any,
    args: tuple[Any, ...] = (),
    kwargs: Optional[dict[str, Any]] = None,
) -> Any:
    if func == torch._higher_order_ops.flex_attention:
        # the autograd Function all-gathers KV in forward and reduce-scatters KV's grad in backward
        autograd_KV_allgather(k, v)
        return cp_flex_attention(*args, **kwargs)

Challenge 3: Make the process pt2-friendly

[unblocked] Challenge 3.1: Recompiles

When using a TorchFunctionMode, the current usage of the context_parallel API creates one TF mode instance for every input batch. For example, in TorchTitan every input batch is first sharded by context_parallel and then goes through the computation within a TorchFunctionMode.

When working with dynamo, this pattern triggers recompilations and leads to an error when the number of recompilations reaches the preset limit. One workaround is to create a single TF mode and assign it to a global variable. An alternative proposed by @mlazos is to loosen the recompile guard on TF mode.

[resolved] Challenge 3.2: current DeviceMesh is not pt2-friendly

dynamo does not fully support tracing DeviceMesh (e.g. get_local_rank()), and the past CP implementation heavily used those untraceable methods. However, this is resolved by #155441, which 1) replaces manual tensor sharding with DTensor's distribute_tensor API; and 2) replaces the use of DeviceMesh with ProcessGroup.
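
As a rough illustration of the first change (not the actual #155441 code), sharding a KV buffer along the sequence dimension can be expressed with distribute_tensor instead of manual slicing based on get_local_rank(). This assumes a recent PyTorch where torch.distributed.tensor exports distribute_tensor and Shard, and an already-initialized 4-rank process group.

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

mesh = init_device_mesh("cuda", (4,), mesh_dim_names=("cp",))
k = torch.rand(1, 1, 128, 64, device="cuda")

# DTensor computes each rank's slice; .to_local() returns this rank's seq-dim shard,
# avoiding untraceable calls like mesh.get_local_rank()
k_shard = distribute_tensor(k, mesh, placements=[Shard(2)]).to_local()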

Conclusion

The overall journey of integrating FlexAttention into Context Parallel has been bumpy, and it has exposed gaps between the current distributed components, HOPs, and torch.compile. Some of these gaps can be addressed by re-designing Context Parallel, some may require users to make minimal code changes, and some can be closed on the HOP and torch.compile side.

It would be great if HOPs directly supported the following items:

  1. [high prio] correct HOP dispatch within TorchFunctionMode (land [Dynamo] Enable torch function dispatch on HOPs #155452)
  2. [high prio, need discussion] Is the "Triton error" thrown from create_fw_bw_graph() when calling create_block_mask within a TorchFunctionMode expected, or an actual bug? If it is not expected, it would be good to support this use case.
  3. [medium prio] Is the shape check in the flex_attention python function skippable? This check forces CP to have users create block_mask_post_sharding only to meet the shape requirement (see Challenge 1.1). Ideally the only BlockMask a user needs to create is block_mask.
  4. [low prio] correct py_impl on Tensor subclasses

torch.compile support:

  1. [medium prio] correct recompilation checks on TorchFunctionMode instances
  2. [medium prio] better notice of skipped compilation. The skip of FlexAttention within a TorchDispatchMode is quite hard to catch until directly inspecting the generated artifact (see Challenge 2.2).

Distributed support:

  1. [high prio] FlexAttention + CP via a TorchFunctionMode singleton
  2. [medium prio] PT2-friendly DeviceMesh

Proposal

Proposal 1: Introduce a TorchFunctionMode singleton that replaces SDPA/FlexAttention with their context-parallel implementations

class DistributeFunction(TorchFunctionMode):
    def __torch_function__(
        self,
        func: Callable,
        types: Any,
        args: tuple[Any, ...] = (),
        kwargs: Optional[dict[str, Any]] = None,
    ) -> Any:
        if func == torch._higher_order_ops.flex_attention:
            return cp_flex_attention(*args, **(kwargs or {}))

        if func == F.scaled_dot_product_attention:
            return cp_sdpa(*args, **(kwargs or {}))

        ...

# module-level variable
_tf_mode: Optional[TorchFunctionMode] = None

# DistributeFunction is instantiated only once, and only when the `context_parallel()` API is called
def context_parallel():
    ...

    global _tf_mode
    if _tf_mode is None:
        _tf_mode = DistributeFunction()

    ...

Proposal 2: For simplicity, only support all-gather-based CP + FlexAttention for now, by inserting an autograd Function CPFlexAttentionPreOp before calling flex_attention in cp_flex_attention

import torch.distributed._functional_collectives as funcol


class CPFlexAttentionPreOp(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx: Any,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        device_mesh: DeviceMesh,
    ) -> tuple[Any, ...]:
        ctx.device_mesh = device_mesh
        # all-gather the KV shards along the seq dim (dim 2); functional collectives
        # are one possible way to express this
        global_key = funcol.all_gather_tensor(key, gather_dim=2, group=device_mesh)
        global_value = funcol.all_gather_tensor(value, gather_dim=2, group=device_mesh)
        return query, global_key, global_value

    @staticmethod
    def backward(
        ctx: Any,
        grad_query: torch.Tensor,
        grad_global_key: torch.Tensor,
        grad_global_value: torch.Tensor,
    ) -> tuple[Optional[torch.Tensor], ...]:
        # reduce-scatter the KV grads back to their local shards along the seq dim
        grad_key = funcol.reduce_scatter_tensor(
            grad_global_key, "sum", scatter_dim=2, group=ctx.device_mesh
        )
        grad_value = funcol.reduce_scatter_tensor(
            grad_global_value, "sum", scatter_dim=2, group=ctx.device_mesh
        )
        return grad_query, grad_key, grad_value, None


def cp_flex_attention(...):
    query, global_key, global_value = CPFlexAttentionPreOp.apply(query, key, value, device_mesh)
    return flex_attention(query, global_key, global_value, block_mask=cp_block_mask, ...)

Proposal 3: [need better solution] Require users to manually create both block_mask and block_mask_post_sharding. In addition, block_mask must be stored somewhere accessible by the Context Parallel module (e.g. a module-level variable in torch.distributed.tensor.experimental._attention), and block_mask_post_sharding must be passed to flex_attention calls.

This compromise comes from not being able to call create_block_mask within a TorchFunctionMode (Challenge 1.3), and it is, IMO, the most controversial part of this proposal.
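
A hypothetical user-facing sketch of Proposal 3, reusing the names from the earlier examples; _cp_global_block_mask is an illustrative module-level slot, not an existing attribute of torch.distributed.tensor.experimental._attention.

import torch.distributed.tensor.experimental._attention as cp_attention

# global mask, needed by the CP module to rewrite mask_mod
block_mask = create_block_mask(causal_mask, B=B, H=H, Q_LEN=Q_LEN, KV_LEN=KV_LEN)
# post-sharding mask, needed only to pass flex_attention's shape check (Challenge 1.1)
block_mask_post_sharding = create_block_mask(
    causal_mask, B=B, H=H, Q_LEN=Q_LEN // 4, KV_LEN=KV_LEN // 4
)

cp_attention._cp_global_block_mask = block_mask  # illustrative storage location

with context_parallel(device_mesh, buffers=(q, k, v), buffer_seq_dims=(2, 2, 2)):
    out, lse = compiled_flex_attention(
        q, k, v, block_mask=block_mask_post_sharding, return_lse=True
    )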

UI Change

WIP

cc @chauhang @penguinwu @zou3519 @ydwu4 @bdhirsh @Chillee @drisspg @yanboliang @BoyuanFeng
