Feature request
In Transformers 4.36, we started adding native support for torch.nn.functional.scaled_dot_product_attention (SDPA), enabled by default in Transformers: https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-and-memory-efficient-attention-through-pytorchs-scaleddotproductattention
SDPA allows dispatching to memory-efficient attention and flash attention kernels on supported GPUs (currently NVIDIA-only), and it also works on Intel CPUs.
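For context, here is a minimal sketch of how SDPA is selected when loading a model. The model name and dtype are just examples, and the kernel-selection context manager is optional (its API has shifted across PyTorch versions); `attn_implementation="sdpa"` is the default for supported models since 4.36, so passing it explicitly only makes the choice visible:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # example checkpoint
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    attn_implementation="sdpa",  # default when the architecture supports it
).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("Hello, my name is", return_tensors="pt").to("cuda")

# Optionally restrict which SDPA backend PyTorch dispatches to (flash attention here).
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    outputs = model.generate(**inputs, max_new_tokens=20)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```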
For the record, here's a benchmark on some currently supported models:
Training benchmark, run on A100-SXM4-80GB.
| Model | Batch size | Sequence length | Time per batch ("eager", s) | Time per batch ("sdpa", s) | Speedup | Peak memory ("eager", MB) | Peak memory ("sdpa", MB) | Memory savings |
|---|---|---|---|---|---|---|---|---|
| llama2 7b | 4 | 1024 | 1.065 | 0.90 | 19.4% | 73878.28 | 45977.81 | 60.7% |
| llama2 7b | 4 | 2048 | OOM | 1.87 | / | OOM | 78394.58 | SDPA does not OOM |
| llama2 7b | 1 | 2048 | 0.64 | 0.48 | 32.0% | 55557.01 | 29795.63 | 86.4% |
| llama2 7b | 1 | 3072 | OOM | 0.75 | / | OOM | 37916.08 | SDPA does not OOM |
| llama2 7b | 1 | 4096 | OOM | 1.03 | / | OOM | 46028.14 | SDPA does not OOM |
| llama2 7b | 2 | 4096 | OOM | 2.05 | / | OOM | 78428.14 | SDPA does not OOM |
Inference benchmark, run on A100-SXM4-80GB.
| Model | Batch size | Prompt length | Num new tokens | Per token latency "eager" (ms) | Per token latency "sdpa" (ms) | Speedup |
|---|---|---|---|---|---|---|
| llama2 13b | 1 | 1024 | 1 (prefill) | 178.66 | 159.36 | 12.11% |
| llama2 13b | 1 | 100 | 100 | 40.35 | 37.62 | 7.28% |
| llama2 13b | 8 | 100 | 100 | 40.55 | 38.06 | 6.53% |
| Whisper v3 large | 1 | / | 62 | 20.05 | 18.90 | 6.10% |
| Whisper v3 large | 8 | / | 77 | 25.42 | 24.77 | 2.59% |
| Whisper v3 large | 16 | / | 77 | 28.51 | 26.32 | 8.34% |
Previously, we had partial support for SDPA in Optimum BetterTransformer, but we are now looking to slowly deprecate it in favor of native SDPA support directly in Transformers.
Here are the architectures for which support has been requested:
- Codegen (BetterTransformer not supporting CodeGen2 optimum#1050)
- LLAVA (Can optimum.bettertransformer supports LLAVA model? optimum#1592)
- Marian (BetterTransforer not Support Marian optimum#1142)
- Mistral (Add support for mistral type Model to use Mistral and Zephyr optimum#1553)
- LongT5 (longT5 BetterTransformer implementation optimum#1506)
- ViT (Add support for mistral type Model to use Mistral and Zephyr optimum#1553)
The integration could take inspiration from https://github.com/huggingface/optimum/blob/main/optimum/bettertransformer/models/decoder_models.py & https://github.com/huggingface/optimum/blob/main/optimum/bettertransformer/models/attention.py
Motivation
Faster training & inference, lower memory requirements
Your contribution
I may work on some at some point, but contributions are most welcome.
You should refer to #26572 to add SDPA support for a model, roughly following these steps (a minimal sketch is given after the list):
- Create an `XxxSdpaAttention` class inheriting from `XxxAttention` and implement the attention logic using SDPA
- Use `_prepare_4d_causal_attention_mask_for_sdpa` instead of `_prepare_4d_causal_attention_mask` for SDPA
- Use `_prepare_4d_attention_mask_for_sdpa` instead of `_prepare_4d_attention_mask` for SDPA
- Add `_supports_sdpa = True` to `XxxPreTrainedModel`
- Add an `"sdpa"` key to `XXX_ATTENTION_CLASSES` in the model modeling file
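The sketch below illustrates the pattern under simplifying assumptions: `XxxAttention`, `XxxSdpaAttention`, and `XXX_ATTENTION_CLASSES` are placeholder names, and the real implementations (e.g. `LlamaSdpaAttention`) additionally handle the 4D mask helpers above, KV caching, rotary embeddings, and grouped-query attention, all omitted here:

```python
from typing import Optional

import torch
from torch import nn


class XxxAttention(nn.Module):
    """Minimal stand-in for the model's existing eager attention class (projections only)."""

    def __init__(self, hidden_size: int = 64, num_heads: int = 4, dropout: float = 0.0):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.dropout = dropout
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.o_proj = nn.Linear(hidden_size, hidden_size)


class XxxSdpaAttention(XxxAttention):
    """Same projections as XxxAttention, but the attention math goes through SDPA."""

    def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
        bsz, q_len, _ = hidden_states.size()
        query = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        # In a real modeling file the 4D mask comes from _prepare_4d_causal_attention_mask_for_sdpa;
        # passing None lets SDPA use its is_causal fast path instead.
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=attention_mask is None and q_len > 1,
        )
        attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, -1)
        return self.o_proj(attn_output)


# The modeling file then maps implementation names to classes ...
XXX_ATTENTION_CLASSES = {"eager": XxxAttention, "sdpa": XxxSdpaAttention}
# ... and XxxPreTrainedModel sets `_supports_sdpa = True`.

if __name__ == "__main__":
    attn = XxxSdpaAttention()
    print(attn(torch.randn(2, 8, 64)).shape)  # torch.Size([2, 8, 64])
```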