
[ESM] Add support for sdpa. #34954


Open · wants to merge 4 commits into base: main
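For context (not part of the diff): once merged, the new path is selected through the standard `attn_implementation` argument and dispatches to the `EsmSdpaSelfAttention` class added below. A minimal, hedged usage sketch, assuming the small ESM2 checkpoint already used in the test file:

```python
# Illustrative usage after this PR lands; checkpoint and input sequence are assumptions.
import torch
from transformers import AutoTokenizer, EsmModel

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
model = EsmModel.from_pretrained(
    "facebook/esm2_t6_8M_UR50D",
    attn_implementation="sdpa",  # routes attention through torch.nn.functional.scaled_dot_product_attention
).eval()

inputs = tokenizer("MKTAYIAKQR", return_tensors="pt")
with torch.no_grad():
    hidden = model(**inputs).last_hidden_state
print(hidden.shape)  # (batch, seq_len, hidden_size)
```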
194 changes: 177 additions & 17 deletions src/transformers/models/esm/modeling_esm.py
@@ -20,9 +20,14 @@

import torch
import torch.utils.checkpoint
from packaging import version
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...modeling_attn_mask_utils import (
_prepare_4d_attention_mask_for_sdpa,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
@@ -31,7 +36,13 @@
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import auto_docstring, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging
from ...utils import (
auto_docstring,
get_torch_version,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging,
)
from .configuration_esm import EsmConfig


@@ -89,7 +100,6 @@ def __init__(self, dim: int):
super().__init__()
# Generate and save the inverse frequency buffer (non trainable)
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
inv_freq = inv_freq
self.register_buffer("inv_freq", inv_freq)

self._seq_len_cached = None
@@ -103,12 +113,20 @@ def _update_cos_sin_tables(self, x, seq_dimension=2):
# or if we're on a new device (possibly due to tracing for instance)
if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
self._seq_len_cached = seq_len
t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
freqs = torch.outer(t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
t = torch.arange(x.shape[seq_dimension], device=x.device).float()

# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = torch.outer(t.float(), self.inv_freq.float())
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)

self._cos_cached = emb.cos()[None, None, :, :]
self._sin_cached = emb.sin()[None, None, :, :]
cos = emb.cos()[None, None, :, :]
sin = emb.sin()[None, None, :, :]

self._cos_cached = cos.to(dtype=x.dtype)
self._sin_cached = sin.to(dtype=x.dtype)

return self._cos_cached, self._sin_cached

@@ -374,7 +392,7 @@ def forward(
if head_mask is not None:
attention_probs = attention_probs * head_mask

context_layer = torch.matmul(attention_probs.to(value_layer.dtype), value_layer)
context_layer = torch.matmul(attention_probs, value_layer)

context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
@@ -387,6 +405,113 @@ def forward(
return outputs


class EsmSdpaSelfAttention(EsmSelfAttention):
def __init__(self, config, position_embedding_type=None):
super().__init__(config, position_embedding_type)
self.attention_dropout_prob = config.attention_probs_dropout_prob
self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")

def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
if self.position_embedding_type not in ["absolute", "rotary"] or output_attentions or head_mask is not None:
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented.
logger.warning_once(
"EsmSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
"non-absolute or non-rotary `position_embedding_type` or `output_attentions=True` or `head_mask`. "
"Falling back to the manual attention implementation, but specifying the manual implementation will "
"be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument "
'`attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)

bsz, tgt_len, _ = hidden_states.size()

query_layer = self.transpose_for_scores(self.query(hidden_states))

# If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention
# mask needs to be such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None

current_states = encoder_hidden_states if is_cross_attention else hidden_states
attention_mask = encoder_attention_mask if is_cross_attention else attention_mask

# Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning
if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]:
key_layer, value_layer = past_key_value
else:
key_layer = self.transpose_for_scores(self.key(current_states))
value_layer = self.transpose_for_scores(self.value(current_states))
if past_key_value is not None and not is_cross_attention:
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)

# Scale the query for rotary embeddings
query_layer = query_layer * self.attention_head_size**-0.5

if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
Review comment (Collaborator) on lines +408 to +470:
I think you are still missing the spot: we don't need 3 different classes anymore, see https://github.com/huggingface/transformers/blob/tp-cb/src/transformers/models/llama/modeling_llama.py#L249-L249
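For readers following the review thread, the linked Llama code uses a single attention class that resolves its kernel through a registry instead of keeping separate eager/SDPA/FlashAttention2 subclasses. A rough sketch of that pattern, assuming the `ALL_ATTENTION_FUNCTIONS` registry from current Transformers main; whether ESM is refactored this way is not decided in this diff:

```python
# Sketch of the dispatch pattern the reviewer points to (see modeling_llama.py).
# ALL_ATTENTION_FUNCTIONS maps "sdpa", "flash_attention_2", ... to kernel functions;
# "eager" keeps the hand-written softmax(QK^T)V implementation.
from typing import Callable

from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


def resolve_attention_interface(config, eager_attention_forward: Callable) -> Callable:
    if config._attn_implementation == "eager":
        return eager_attention_forward
    return ALL_ATTENTION_FUNCTIONS[config._attn_implementation]
```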

# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)

if self.position_embedding_type == "rotary":
query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)

# SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
# attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
# Reference: https://github.com/pytorch/pytorch/issues/112577
if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:
query_layer = query_layer.contiguous()
key_layer = key_layer.contiguous()
value_layer = value_layer.contiguous()

# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create
# a causal mask in case tgt_len == 1.
is_causal = (
True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False
)

attn_output = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
attn_mask=attention_mask,
dropout_p=self.attention_dropout_prob if self.training else 0.0,
is_causal=is_causal,
scale=1.0, # Scale is already applied to query_layer
)

attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size)

outputs = (attn_output,)
if self.is_decoder:
outputs = outputs + (past_key_value,)
return outputs


class EsmSelfOutput(nn.Module):
def __init__(self, config):
super().__init__()
@@ -515,6 +640,7 @@ def forward(
ESM_ATTENTION_CLASSES = {
"eager": EsmSelfAttention,
"flash_attention_2": EsmFlashAttention2,
"sdpa": EsmSdpaSelfAttention,
}


@@ -799,6 +925,7 @@ class EsmPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["EsmLayer", "EsmFoldTriangularSelfAttentionBlock", "EsmEmbeddings"]
_supports_flash_attn_2 = True
_supports_sdpa = True

# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights with BertLMPredictionHead->EsmLMHead
def _init_weights(self, module):
@@ -851,6 +978,9 @@ def __init__(self, config, add_pooling_layer=True):
in_features=config.num_hidden_layers * config.num_attention_heads, bias=True
)

self.attn_implementation = config._attn_implementation
self.position_embedding_type = config.position_embedding_type

# Initialize weights and apply final processing
self.post_init()

@@ -932,9 +1062,38 @@ def forward(
if attention_mask is None:
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)

embedding_output = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
)

use_sdpa_attention_masks = (
self.attn_implementation == "sdpa"
and self.position_embedding_type in ["absolute", "rotary"]
and head_mask is None
and not output_attentions
)

# Expand the attention mask
if self.config._attn_implementation == "flash_attention_2":
extended_attention_mask = attention_mask

elif use_sdpa_attention_masks and attention_mask.dim() == 2:
# Expand the attention mask for SDPA.
# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
if self.config.is_decoder:
extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
input_shape,
embedding_output,
past_key_values_length,
)
else:
extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
attention_mask, embedding_output.dtype, tgt_len=seq_length
)
else:
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
@@ -947,7 +1106,15 @@ def forward(
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)

if use_sdpa_attention_masks and encoder_attention_mask.dim() == 2:
# Expand the attention mask for SDPA.
# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length
)
else:
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None

@@ -958,13 +1125,6 @@
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

embedding_output = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
)
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
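One detail worth calling out in the new forward pass: the query is pre-multiplied by `head_dim ** -0.5` before the rotary embedding is applied (mirroring the eager path), so the SDPA call passes `scale=1.0` to avoid scaling twice. A standalone sketch of the equivalence, assuming PyTorch >= 2.1 where the `scale` argument is available:

```python
# Standalone check (not part of the diff): pre-scaling the query and passing scale=1.0
# to SDPA matches SDPA's default 1/sqrt(head_dim) scaling.
import torch
import torch.nn.functional as F

bsz, num_heads, seq_len, head_dim = 2, 4, 16, 32
q = torch.randn(bsz, num_heads, seq_len, head_dim)
k = torch.randn(bsz, num_heads, seq_len, head_dim)
v = torch.randn(bsz, num_heads, seq_len, head_dim)

default_out = F.scaled_dot_product_attention(q, k, v)  # SDPA scales internally
prescaled_out = F.scaled_dot_product_attention(q * head_dim**-0.5, k, v, scale=1.0)  # query already scaled

torch.testing.assert_close(default_out, prescaled_out, rtol=1e-5, atol=1e-5)
```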
11 changes: 7 additions & 4 deletions tests/models/esm/test_modeling_esm.py
@@ -17,6 +17,7 @@
import unittest

import pytest
from parameterized import parameterized

from transformers import EsmConfig, is_torch_available
from transformers.testing_utils import (
@@ -352,7 +353,7 @@ def test_inference_masked_lm(self):
with torch.no_grad():
model = EsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")
model.eval()
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]]).to(model.device)
output = model(input_ids)[0]

vocab_size = 33
@@ -370,17 +371,19 @@ def test_inference_no_head(self):
model = EsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D")
model.eval()

input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]])
input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]]).to(model.device)
output = model(input_ids)[0]
# compare the actual values for a slice.
expected_slice = torch.tensor(
[[[0.1444, 0.5413, 0.3248], [0.3034, 0.0053, 0.3108], [0.3228, -0.2499, 0.3415]]]
)
torch.testing.assert_close(output[:, :3, :3], expected_slice, rtol=1e-4, atol=1e-4)

@parameterized.expand([({"load_in_8bit": True},), ({"load_in_4bit": True},)])
@require_bitsandbytes
def test_inference_bitsandbytes(self):
model = EsmForMaskedLM.from_pretrained("facebook/esm2_t36_3B_UR50D", load_in_8bit=True)
def test_inference_bitsandbytes(self, bnb_kwargs):
model = EsmForMaskedLM.from_pretrained("facebook/esm2_t36_3B_UR50D", **bnb_kwargs)
model.eval()

input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]]).to(model.device)
# Just test if inference works
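Beyond the updated integration tests above, a quick local sanity check (not included in this PR) is to compare eager and SDPA outputs on the same checkpoint; the tolerance below is an assumption matching the existing slice test:

```python
# Hedged local check: eager vs. SDPA forward passes should agree closely on the
# small ESM2 checkpoint used by the integration tests.
import torch
from transformers import EsmModel

input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]])

eager_model = EsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D", attn_implementation="eager").eval()
sdpa_model = EsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D", attn_implementation="sdpa").eval()

with torch.no_grad():
    eager_out = eager_model(input_ids).last_hidden_state
    sdpa_out = sdpa_model(input_ids).last_hidden_state

torch.testing.assert_close(eager_out, sdpa_out, rtol=1e-4, atol=1e-4)
```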