
Flex Attention is incompatible with selective AC #147879

@fegin

Description

🐛 Describe the bug

When using FlexAttention with selective activation checkpointing, the following error is raised:

  traceback : Traceback (most recent call last):
    File "/data/users/chienchin/mywork/pytorch/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 354, in wrapper
      return f(*args, **kwargs)
    File "/data/users/chienchin/fbsource/fbcode/pytorch/torchtitan/train.py", line 306, in main
      pred = model(input_ids)
    File "/data/users/chienchin/mywork/pytorch/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
    File "/data/users/chienchin/mywork/pytorch/torch/nn/modules/module.py", line 1857, in _call_impl
      return inner()
    File "/data/users/chienchin/mywork/pytorch/torch/nn/modules/module.py", line 1805, in inner
      result = forward_call(*args, **kwargs)
    File "/data/users/chienchin/fbsource/fbcode/pytorch/torchtitan/torchtitan/models/llama/model.py", line 478, in forward
      h = layer(h, self.freqs_cis)
    File "/data/users/chienchin/mywork/pytorch/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
    File "/data/users/chienchin/mywork/pytorch/torch/nn/modules/module.py", line 1857, in _call_impl
      return inner()
    File "/data/users/chienchin/mywork/pytorch/torch/nn/modules/module.py", line 1805, in inner
      result = forward_call(*args, **kwargs)
    File "/data/users/chienchin/mywork/pytorch/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 171, in forward
      return self.checkpoint_fn(  # type: ignore[misc]
    File "/data/users/chienchin/mywork/pytorch/torch/_compile.py", line 51, in inner
      return disable_fn(*args, **kwargs)
    File "/data/users/chienchin/mywork/pytorch/torch/_dynamo/eval_frame.py", line 764, in _fn
      return fn(*args, **kwargs)
    File "/data/users/chienchin/mywork/pytorch/torch/utils/checkpoint.py", line 495, in checkpoint
      ret = function(*args, **kwargs)
    File "/data/users/chienchin/mywork/pytorch/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
    File "/data/users/chienchin/mywork/pytorch/torch/nn/modules/module.py", line 1762, in _call_impl
      return forward_call(*args, **kwargs)
    File "/data/users/chienchin/fbsource/fbcode/pytorch/torchtitan/torchtitan/models/llama/model.py", line 359, in forward
      h = x + self.attention(self.attention_norm(x), freqs_cis)
    File "/data/users/chienchin/mywork/pytorch/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
    File "/data/users/chienchin/mywork/pytorch/torch/nn/modules/module.py", line 1762, in _call_impl
      return forward_call(*args, **kwargs)
    File "/data/users/chienchin/fbsource/fbcode/pytorch/torchtitan/torchtitan/models/llama/model.py", line 230, in forward
      output = flex_attention(xq, xk, xv, block_mask=self.block_mask)
    File "/data/users/chienchin/mywork/pytorch/torch/nn/attention/flex_attention.py", line 1357, in flex_attention
      out, lse = torch.compile(
    File "/data/users/chienchin/mywork/pytorch/torch/_dynamo/eval_frame.py", line 585, in _fn
      return fn(*args, **kwargs)
    File "/data/users/chienchin/mywork/pytorch/torch/nn/attention/flex_attention.py", line 1345, in _flex_attention_hop_wrapper
      return flex_attention_hop(*args, **kwargs)
    File "/data/users/chienchin/mywork/pytorch/torch/_higher_order_ops/flex_attention.py", line 92, in __call__
      return super().__call__(
    File "/data/users/chienchin/mywork/pytorch/torch/_ops.py", line 471, in __call__
      return wrapper()
    File "/data/users/chienchin/mywork/pytorch/torch/_ops.py", line 467, in wrapper
      return self.dispatch(
    File "/data/users/chienchin/mywork/pytorch/torch/_ops.py", line 455, in dispatch
      return kernel(*args, **kwargs)
    File "/data/users/chienchin/mywork/pytorch/torch/_higher_order_ops/flex_attention.py", line 744, in flex_attention_autograd
      out, logsumexp = FlexAttentionAutogradOp.apply(
    File "/data/users/chienchin/mywork/pytorch/torch/autograd/function.py", line 575, in apply
      return super().apply(*args, **kwargs)  # type: ignore[misc]
    File "/data/users/chienchin/mywork/pytorch/torch/_higher_order_ops/flex_attention.py", line 610, in forward
      out, logsumexp = flex_attention(
    File "/data/users/chienchin/mywork/pytorch/torch/_higher_order_ops/flex_attention.py", line 92, in __call__
      return super().__call__(
    File "/data/users/chienchin/mywork/pytorch/torch/_ops.py", line 471, in __call__
      return wrapper()
    File "/data/users/chienchin/mywork/pytorch/torch/_ops.py", line 462, in wrapper
      return torch.overrides.handle_torch_function(
    File "/data/users/chienchin/mywork/pytorch/torch/overrides.py", line 1721, in handle_torch_function
      result = mode.__torch_function__(public_api, types, args, kwargs)
    File "/data/users/chienchin/mywork/pytorch/torch/_dynamo/_trace_wrapped_higher_order_op.py", line 142, in __torch_function__
      return func(*args, **(kwargs or {}))
    File "/data/users/chienchin/mywork/pytorch/torch/_higher_order_ops/flex_attention.py", line 92, in __call__
      return super().__call__(
    File "/data/users/chienchin/mywork/pytorch/torch/_ops.py", line 471, in __call__
      return wrapper()
    File "/data/users/chienchin/mywork/pytorch/torch/_ops.py", line 467, in wrapper
      return self.dispatch(
    File "/data/users/chienchin/mywork/pytorch/torch/_ops.py", line 365, in dispatch
      raise NotImplementedError(
  NotImplementedError: There was no rule registered for HOP flex_attention and mode <torch.utils.checkpoint._CachingTorchDispatchMode object at 0x7f3e5cc0fac0>. We recommend filing an issue.

This issue can be reproduced with pytorch/torchtitan#887 by running CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --model.use_flex_attn.

Note that full activation checkpointing doesn't cause this issue.
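For reference, below is a minimal standalone sketch (hypothetical, not taken from the report: the attn_block helper, the policy_fn, the tensor shapes, and the CUDA device are illustrative assumptions) that should exercise the same path. Selective activation checkpointing installs a _CachingTorchDispatchMode around the checkpointed forward via create_selective_checkpoint_contexts, and the flex_attention higher-order op has no dispatch rule registered for that mode, so the call fails inside the forward. Plain checkpoint without a context_fn does not install that mode, which matches the observation above that full activation checkpointing works.

  # Hypothetical minimal repro sketch (not from the original report): flex_attention
  # called inside selective activation checkpointing. Assumes a CUDA device and a
  # recent build where create_selective_checkpoint_contexts is available.
  from functools import partial

  import torch
  from torch.nn.attention.flex_attention import flex_attention
  from torch.utils.checkpoint import (
      CheckpointPolicy,
      checkpoint,
      create_selective_checkpoint_contexts,
  )

  def policy_fn(ctx, op, *args, **kwargs):
      # Arbitrary SAC policy for illustration: save matmuls, recompute everything else.
      if op == torch.ops.aten.mm.default:
          return CheckpointPolicy.MUST_SAVE
      return CheckpointPolicy.PREFER_RECOMPUTE

  def attn_block(q, k, v):
      return flex_attention(q, k, v)

  q, k, v = (
      torch.randn(1, 2, 128, 64, device="cuda", requires_grad=True) for _ in range(3)
  )

  # Full activation checkpointing: no caching dispatch mode is installed, so this works.
  out_full = checkpoint(attn_block, q, k, v, use_reentrant=False)

  # Selective activation checkpointing: create_selective_checkpoint_contexts installs a
  # _CachingTorchDispatchMode around the forward; the flex_attention HOP has no rule
  # registered for that mode, so this is expected to raise the NotImplementedError
  # shown in the traceback above.
  out_sac = checkpoint(
      attn_block,
      q,
      k,
      v,
      use_reentrant=False,
      context_fn=partial(create_selective_checkpoint_contexts, policy_fn),
  )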

Versions

nightly

cc @chauhang @penguinwu @zou3519 @ydwu4 @bdhirsh @Chillee @drisspg @yanboliang @BoyuanFeng

Metadata

Assignees

No one assigned

    Labels

    module: flex attention
    module: higher order operators (torch.cond and similar)
    module: pt2-dispatcher (PT2 dispatcher-related issues, e.g., aotdispatch, functionalization, faketensor, custom-op)
    oncall: pt2
    triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

