
Open to contribution: adding torch.nn.functional.scaled_dot_product_attention support for more architectures #28005

@fxmarty

Description

Feature request

In Transformers 4.36, we started adding native support for torch.nn.functional.scaled_dot_product_attention (SDPA), which is enabled by default in Transformers: https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-and-memory-efficient-attention-through-pytorchs-scaleddotproductattention

SDPA can dispatch to memory-efficient attention and flash attention on supported GPUs (currently NVIDIA-only), and to optimized kernels even on Intel CPUs.
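Although SDPA is used by default when a model supports it, it can also be requested explicitly. A minimal usage sketch, where the checkpoint name is only an example:

```python
# Minimal sketch: explicitly request the SDPA attention implementation.
# "meta-llama/Llama-2-7b-hf" is only an example checkpoint.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float16,
    attn_implementation="sdpa",
)
```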

For the record, here's a benchmark on some currently supported models:

Training benchmark, run on A100-SXM4-80GB.

| Model | Batch size | Sequence length | Time per batch ("eager", s) | Time per batch ("sdpa", s) | Speedup | Peak memory ("eager", MB) | Peak memory ("sdpa", MB) | Memory savings |
|---|---|---|---|---|---|---|---|---|
| llama2 7b | 4 | 1024 | 1.065 | 0.90 | 19.4% | 73878.28 | 45977.81 | 60.7% |
| llama2 7b | 4 | 2048 | OOM | 1.87 | / | OOM | 78394.58 | SDPA does not OOM |
| llama2 7b | 1 | 2048 | 0.64 | 0.48 | 32.0% | 55557.01 | 29795.63 | 86.4% |
| llama2 7b | 1 | 3072 | OOM | 0.75 | / | OOM | 37916.08 | SDPA does not OOM |
| llama2 7b | 1 | 4096 | OOM | 1.03 | / | OOM | 46028.14 | SDPA does not OOM |
| llama2 7b | 2 | 4096 | OOM | 2.05 | / | OOM | 78428.14 | SDPA does not OOM |

Inference benchmark, run on A100-SXM4-80GB.

| Model | Batch size | Prompt length | Num new tokens | Per-token latency "eager" (ms) | Per-token latency "sdpa" (ms) | Speedup |
|---|---|---|---|---|---|---|
| llama2 13b | 1 | 1024 | 1 (prefill) | 178.66 | 159.36 | 12.11% |
| llama2 13b | 1 | 100 | 100 | 40.35 | 37.62 | 7.28% |
| llama2 13b | 8 | 100 | 100 | 40.55 | 38.06 | 6.53% |
| Whisper v3 large | 1 | / | 62 | 20.05 | 18.90 | 6.10% |
| Whisper v3 large | 8 | / | 77 | 25.42 | 24.77 | 2.59% |
| Whisper v3 large | 16 | / | 77 | 28.51 | 26.32 | 8.34% |

Previously, we had partial support for SDPA in Optimum BetterTransformer, but we are now looking to slowly deprecate it in favor of native SDPA support directly in Transformers.

Here are the architectures for which support has been requested:

The integration could take inspiration from https://github.com/huggingface/optimum/blob/main/optimum/bettertransformer/models/decoder_models.py & https://github.com/huggingface/optimum/blob/main/optimum/bettertransformer/models/attention.py

Motivation

Faster training & inference, lower memory requirements

Your contribution

I may work on some at some point, but contributions are most welcome.

You should refer to #26572 to add SDPA support for a model, roughly following these steps (a rough sketch of the pattern follows the list):

  • Create an XxxSdpaAttention class inheriting from XxxAttention and implement the attention logic using SDPA
  • Use _prepare_4d_causal_attention_mask_for_sdpa instead of _prepare_4d_causal_attention_mask for the SDPA path
  • Use _prepare_4d_attention_mask_for_sdpa instead of _prepare_4d_attention_mask for the SDPA path
  • Add _supports_sdpa = True to XxxPreTrainedModel
  • Add an "sdpa" key to XXX_ATTENTION_CLASSES in the model's modeling file
