Commit 996880a

Add SDPA test for ESMFold.

1 parent 72750f6 · commit 996880a

File tree

1 file changed: +198 -3 lines changed


tests/models/esm/test_modeling_esmfold.py

Lines changed: 198 additions & 3 deletions
@@ -14,12 +14,16 @@
 # limitations under the License.
 """Testing suite for the PyTorch ESM model."""
 
+import inspect
+import tempfile
 import unittest
 
+import numpy as np
 from parameterized import parameterized
 
 from transformers import EsmConfig, is_torch_available
 from transformers.testing_utils import TestCasePlus, require_torch, require_torch_sdpa, slow, torch_device
+from transformers.utils import is_torch_bf16_available_on_device, is_torch_fp16_available_on_device
 
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
@@ -268,11 +272,202 @@ def test_torchscript_simple(self):
     def test_multi_gpu_data_parallel_forward(self):
         pass
 
+    # Modified from test_modeling_common.py as ESMFold doesn't support output hidden states
     @parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
     @require_torch_sdpa
-    @unittest.skip("ESMFold doesn't support output hidden states in normal way which is required in this test.")
-    def test_eager_matches_sdpa_inference(self):
-        pass
+    def test_eager_matches_sdpa_inference(self, torch_dtype: str):
+        if torch_dtype == "float16" and not is_torch_fp16_available_on_device(torch_device):
+            self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)")
+
+        if torch_dtype == "bfloat16" and not is_torch_bf16_available_on_device(torch_device):
+            self.skipTest(
+                f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)"
+            )
+
+        # Not sure whether it's fine to put torch.XXX in a decorator if torch is not available so hacking it here instead.
+        if torch_dtype == "float16":
+            torch_dtype = torch.float16
+        elif torch_dtype == "bfloat16":
+            torch_dtype = torch.bfloat16
+        elif torch_dtype == "float32":
+            torch_dtype = torch.float32
+
+        atols = {
+            ("cpu", False, torch.float32): 1e-6,
+            ("cpu", False, torch.float16): 5e-3,
+            ("cpu", False, torch.bfloat16): 1e-2,
+            ("cpu", True, torch.float32): 1e-6,
+            ("cpu", True, torch.float16): 5e-3,
+            ("cpu", True, torch.bfloat16): 1e-2,
+            ("cuda", False, torch.float32): 1e-6,
+            ("cuda", False, torch.bfloat16): 1e-2,
+            ("cuda", False, torch.float16): 5e-3,
+            ("cuda", True, torch.float32): 1e-6,
+            ("cuda", True, torch.bfloat16): 1e-2,
+            ("cuda", True, torch.float16): 5e-3,
+        }
+        rtols = {
+            ("cpu", False, torch.float32): 1e-4,
+            ("cpu", False, torch.float16): 5e-3,
+            ("cpu", False, torch.bfloat16): 1e-2,
+            ("cpu", True, torch.float32): 1e-4,
+            ("cpu", True, torch.float16): 5e-3,
+            ("cpu", True, torch.bfloat16): 1e-2,
+            ("cuda", False, torch.float32): 1e-4,
+            ("cuda", False, torch.bfloat16): 1e-2,
+            ("cuda", False, torch.float16): 5e-3,
+            ("cuda", True, torch.float32): 1e-4,
+            ("cuda", True, torch.bfloat16): 3e-2,
+            ("cuda", True, torch.float16): 5e-3,
+        }
+
+        def get_mean_reldiff(failcase, x, ref, atol, rtol):
+            return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}"
+
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        config.layer_norm_eps = 1.0
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model.save_pretrained(tmpdirname)
+
+                # Note: the half precision will only be applied to backbone model
+                model_sdpa = model_class.from_pretrained(tmpdirname)
+                model_sdpa = model_sdpa.eval().to(torch_device)
+
+                model_eager = model_class.from_pretrained(
+                    tmpdirname,
+                    attn_implementation="eager",
+                )
+                model_eager = model_eager.eval().to(torch_device)
+
+                model_sdpa.esm.to(torch_dtype)
+                model_eager.esm.to(torch_dtype)
+
+                # We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 16 times the model,
+                # but it would be nicer to have an efficient way to use parameterized.expand
+                fail_cases = []
+                for padding_side in ["left", "right"]:
+                    for use_mask in [False, True]:
+                        # TODO: if we can also check with `batch_size=1` without being flaky?
+                        for batch_size in [7]:
+                            dummy_input = inputs_dict[model.main_input_name]
+
+                            if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
+                                dummy_input = dummy_input.to(torch_dtype)
+
+                            dummy_input = dummy_input[:batch_size]
+                            if dummy_input.shape[0] != batch_size:
+                                if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
+                                    extension = torch.rand(
+                                        batch_size - dummy_input.shape[0],
+                                        *dummy_input.shape[1:],
+                                        dtype=torch_dtype,
+                                        device=torch_device,
+                                    )
+                                    dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device)
+                                else:
+                                    extension = torch.randint(
+                                        high=5,
+                                        size=(batch_size - dummy_input.shape[0], *dummy_input.shape[1:]),
+                                        dtype=dummy_input.dtype,
+                                        device=torch_device,
+                                    )
+                                    dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device)
+
+                            if not use_mask:
+                                dummy_attention_mask = None
+                            else:
+                                dummy_attention_mask = inputs_dict.get("attention_mask", None)
+                                if dummy_attention_mask is None:
+                                    seqlen = dummy_input.shape[-1]
+                                    dummy_attention_mask = (
+                                        torch.ones(batch_size, seqlen).to(torch.int64).to(torch_device)
+                                    )
+
+                                dummy_attention_mask = dummy_attention_mask[:batch_size]
+                                if dummy_attention_mask.shape[0] != batch_size:
+                                    extension = torch.ones(
+                                        batch_size - dummy_attention_mask.shape[0],
+                                        *dummy_attention_mask.shape[1:],
+                                        dtype=dummy_attention_mask.dtype,
+                                        device=torch_device,
+                                    )
+                                    dummy_attention_mask = torch.cat((dummy_attention_mask, extension), dim=0)
+                                    dummy_attention_mask = dummy_attention_mask.to(torch_device)
+
+                                dummy_attention_mask[:] = 1
+                                if padding_side == "left":
+                                    dummy_attention_mask[-1, :2] = 0
+                                    dummy_attention_mask[-1, 2:] = 1
+                                elif padding_side == "right":
+                                    dummy_attention_mask[-1, -2:] = 0
+                                    dummy_attention_mask[-1, :-2] = 1
+
+                            for enable_kernels in [False, True]:
+                                failcase = f"padding_side={padding_side}, use_mask={use_mask}, enable_kernels={enable_kernels}"
+                                processed_inputs = {
+                                    model.main_input_name: dummy_input,
+                                }
+
+                                # Otherwise fails for e.g. WhisperEncoderModel
+                                if "attention_mask" in inspect.signature(model_eager.forward).parameters:
+                                    processed_inputs["attention_mask"] = dummy_attention_mask
+
+                                # TODO: test gradients as well (& for FA2 as well!)
+                                with torch.no_grad():
+                                    with torch.backends.cuda.sdp_kernel(
+                                        enable_flash=enable_kernels,
+                                        enable_math=True,
+                                        enable_mem_efficient=enable_kernels,
+                                    ):
+                                        prepared_inputs = self._prepare_for_class(processed_inputs, model_class)
+                                        outputs_eager = model_eager(**prepared_inputs)
+                                        outputs_sdpa = model_sdpa(**prepared_inputs)
+
+                                logits_eager = outputs_eager.lm_logits
+                                logits_sdpa = outputs_sdpa.lm_logits
+
+                                if torch_device in ["cpu", "cuda"]:
+                                    atol = atols[torch_device, enable_kernels, torch_dtype]
+                                    rtol = rtols[torch_device, enable_kernels, torch_dtype]
+                                else:
+                                    atol = 1e-7
+                                    rtol = 1e-4
+
+                                # Masked tokens output slightly deviates - we don't mind that.
+                                if use_mask:
+                                    _logits_sdpa = torch.zeros_like(input=logits_sdpa)
+                                    _logits_eager = torch.zeros_like(input=logits_eager)
+
+                                    _logits_sdpa[:-1] = logits_sdpa[:-1]
+                                    _logits_eager[:-1] = logits_eager[:-1]
+
+                                    if padding_side == "left":
+                                        _logits_sdpa[-1:, 2:] = logits_sdpa[-1:, 2:]
+                                        _logits_eager[-1:, 2:] = logits_eager[-1:, 2:]
+
+                                    elif padding_side == "right":
+                                        _logits_sdpa[-1:, 2:] = logits_sdpa[-1:, :-2]
+                                        _logits_eager[-1:, 2:] = logits_eager[-1:, :-2]
+
+                                    logits_sdpa = _logits_sdpa
+                                    logits_eager = _logits_eager
+
+                                results = [
+                                    torch.allclose(_logits_sdpa, _logits_eager, atol=atol, rtol=rtol)
+                                    for (_logits_sdpa, _logits_eager) in zip(logits_sdpa, logits_eager)
+                                ]
+                                # If 80% batch elements have matched results, it's fine
+                                if np.mean(results) < 0.8:
+                                    fail_cases.append(
+                                        get_mean_reldiff(failcase, logits_sdpa, logits_eager, atol, rtol)
+                                    )
+
+        self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases))
 
 
 @require_torch
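For readers less familiar with the pattern, the sketch below illustrates, outside the transformers test harness, the core comparison this test performs: run the same attention computation once through PyTorch's scaled_dot_product_attention (SDPA) and once through a plain "eager" softmax(QK^T / sqrt(d)) V implementation, then compare the outputs under dtype-dependent tolerances and report the mean relative difference. This is a minimal, self-contained sketch, not part of the commit; the shapes, tolerance values, and the eager_attention helper are illustrative assumptions.

# Minimal sketch (not from the commit): compare SDPA against an eager attention
# implementation, mirroring the tolerance-based check used in the test above.
import torch
import torch.nn.functional as F


def eager_attention(q, k, v):
    # Plain "eager" reference: softmax(QK^T / sqrt(d)) V
    scale = q.shape[-1] ** -0.5
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    return torch.matmul(torch.softmax(scores, dim=-1), v)


torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
# (batch, heads, seq_len, head_dim) -- illustrative shapes
q, k, v = (torch.randn(2, 4, 16, 32, dtype=dtype, device=device) for _ in range(3))

out_sdpa = F.scaled_dot_product_attention(q, k, v)  # requires PyTorch >= 2.0
out_eager = eager_attention(q, k, v)

# Looser tolerances in half precision, tighter in float32 (illustrative values,
# in the spirit of the atols/rtols tables in the test above).
atol, rtol = (5e-3, 5e-3) if dtype == torch.float16 else (1e-6, 1e-4)
mean_reldiff = ((out_sdpa - out_eager).abs() / (out_eager.abs() + 1e-12)).mean()
print(f"allclose={torch.allclose(out_sdpa, out_eager, atol=atol, rtol=rtol)}, "
      f"mean relative difference={mean_reldiff:.3e}")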
