14 | 14 | # limitations under the License.
15 | 15 | """Testing suite for the PyTorch ESM model."""
16 | 16 |
| 17 | +import inspect |
| 18 | +import tempfile |
17 | 19 | import unittest
18 | 20 |
| 21 | +import numpy as np |
19 | 22 | from parameterized import parameterized
20 | 23 |
21 | 24 | from transformers import EsmConfig, is_torch_available
22 | 25 | from transformers.testing_utils import TestCasePlus, require_torch, require_torch_sdpa, slow, torch_device
| 26 | +from transformers.utils import is_torch_bf16_available_on_device, is_torch_fp16_available_on_device |
23 | 27 |
24 | 28 | from ...test_configuration_common import ConfigTester
25 | 29 | from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
@@ -268,11 +272,202 @@ def test_torchscript_simple(self):
268 | 272 | def test_multi_gpu_data_parallel_forward(self):
269 | 273 | pass
270 | 274 |
| 275 | + # Modified from test_modeling_common.py as ESMFold doesn't support output hidden states |
271 | 276 | @parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
272 | 277 | @require_torch_sdpa
273 | | - @unittest.skip("ESMFold doesn't support output hidden states in normal way which is required in this test.") |
274 | | - def test_eager_matches_sdpa_inference(self): |
275 | | - pass |
| 278 | + def test_eager_matches_sdpa_inference(self, torch_dtype: str): |
| 279 | + if torch_dtype == "float16" and not is_torch_fp16_available_on_device(torch_device): |
| 280 | + self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)") |
| 281 | + |
| 282 | + if torch_dtype == "bfloat16" and not is_torch_bf16_available_on_device(torch_device): |
| 283 | + self.skipTest( |
| 284 | + f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)" |
| 285 | + ) |
| 286 | + |
| 287 | + # Not sure whether it's fine to put torch.XXX in a decorator if torch is not available, so we convert the dtype string here instead. |
| 288 | + if torch_dtype == "float16": |
| 289 | + torch_dtype = torch.float16 |
| 290 | + elif torch_dtype == "bfloat16": |
| 291 | + torch_dtype = torch.bfloat16 |
| 292 | + elif torch_dtype == "float32": |
| 293 | + torch_dtype = torch.float32 |
| 294 | + |
| 295 | + atols = { |
| 296 | + ("cpu", False, torch.float32): 1e-6, |
| 297 | + ("cpu", False, torch.float16): 5e-3, |
| 298 | + ("cpu", False, torch.bfloat16): 1e-2, |
| 299 | + ("cpu", True, torch.float32): 1e-6, |
| 300 | + ("cpu", True, torch.float16): 5e-3, |
| 301 | + ("cpu", True, torch.bfloat16): 1e-2, |
| 302 | + ("cuda", False, torch.float32): 1e-6, |
| 303 | + ("cuda", False, torch.bfloat16): 1e-2, |
| 304 | + ("cuda", False, torch.float16): 5e-3, |
| 305 | + ("cuda", True, torch.float32): 1e-6, |
| 306 | + ("cuda", True, torch.bfloat16): 1e-2, |
| 307 | + ("cuda", True, torch.float16): 5e-3, |
| 308 | + } |
| 309 | + rtols = { |
| 310 | + ("cpu", False, torch.float32): 1e-4, |
| 311 | + ("cpu", False, torch.float16): 5e-3, |
| 312 | + ("cpu", False, torch.bfloat16): 1e-2, |
| 313 | + ("cpu", True, torch.float32): 1e-4, |
| 314 | + ("cpu", True, torch.float16): 5e-3, |
| 315 | + ("cpu", True, torch.bfloat16): 1e-2, |
| 316 | + ("cuda", False, torch.float32): 1e-4, |
| 317 | + ("cuda", False, torch.bfloat16): 1e-2, |
| 318 | + ("cuda", False, torch.float16): 5e-3, |
| 319 | + ("cuda", True, torch.float32): 1e-4, |
| 320 | + ("cuda", True, torch.bfloat16): 3e-2, |
| 321 | + ("cuda", True, torch.float16): 5e-3, |
| 322 | + } |
| 323 | + |
| 324 | + def get_mean_reldiff(failcase, x, ref, atol, rtol): |
| 325 | + return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}" |
| 326 | + |
| 327 | + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() |
| 328 | + |
| 329 | + config.layer_norm_eps = 1.0 |
| 330 | + |
| 331 | + for model_class in self.all_model_classes: |
| 332 | + model = model_class(config) |
| 333 | + |
| 334 | + with tempfile.TemporaryDirectory() as tmpdirname: |
| 335 | + model.save_pretrained(tmpdirname) |
| 336 | + |
| 337 | + # Note: the half precision will only be applied to the backbone model |
| 338 | + model_sdpa = model_class.from_pretrained(tmpdirname) |
| 339 | + model_sdpa = model_sdpa.eval().to(torch_device) |
| 340 | + |
| 341 | + model_eager = model_class.from_pretrained( |
| 342 | + tmpdirname, |
| 343 | + attn_implementation="eager", |
| 344 | + ) |
| 345 | + model_eager = model_eager.eval().to(torch_device) |
| 346 | + |
| 347 | + model_sdpa.esm.to(torch_dtype) |
| 348 | + model_eager.esm.to(torch_dtype) |
| 349 | + |
| 350 | + # We use these for loops instead of parameterized.expand just to avoid loading/saving the model 16 times, |
| 351 | + # but it would be nicer to have an efficient way to use parameterized.expand |
| 352 | + fail_cases = [] |
| 353 | + for padding_side in ["left", "right"]: |
| 354 | + for use_mask in [False, True]: |
| 355 | + # TODO: check whether we can also test with `batch_size=1` without the test becoming flaky |
| 356 | + for batch_size in [7]: |
| 357 | + dummy_input = inputs_dict[model.main_input_name] |
| 358 | + |
| 359 | + if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]: |
| 360 | + dummy_input = dummy_input.to(torch_dtype) |
| 361 | + |
| 362 | + dummy_input = dummy_input[:batch_size] |
| 363 | + if dummy_input.shape[0] != batch_size: |
| 364 | + if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]: |
| 365 | + extension = torch.rand( |
| 366 | + batch_size - dummy_input.shape[0], |
| 367 | + *dummy_input.shape[1:], |
| 368 | + dtype=torch_dtype, |
| 369 | + device=torch_device, |
| 370 | + ) |
| 371 | + dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device) |
| 372 | + else: |
| 373 | + extension = torch.randint( |
| 374 | + high=5, |
| 375 | + size=(batch_size - dummy_input.shape[0], *dummy_input.shape[1:]), |
| 376 | + dtype=dummy_input.dtype, |
| 377 | + device=torch_device, |
| 378 | + ) |
| 379 | + dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device) |
| 380 | + |
| 381 | + if not use_mask: |
| 382 | + dummy_attention_mask = None |
| 383 | + else: |
| 384 | + dummy_attention_mask = inputs_dict.get("attention_mask", None) |
| 385 | + if dummy_attention_mask is None: |
| 386 | + seqlen = dummy_input.shape[-1] |
| 387 | + dummy_attention_mask = ( |
| 388 | + torch.ones(batch_size, seqlen).to(torch.int64).to(torch_device) |
| 389 | + ) |
| 390 | + |
| 391 | + dummy_attention_mask = dummy_attention_mask[:batch_size] |
| 392 | + if dummy_attention_mask.shape[0] != batch_size: |
| 393 | + extension = torch.ones( |
| 394 | + batch_size - dummy_attention_mask.shape[0], |
| 395 | + *dummy_attention_mask.shape[1:], |
| 396 | + dtype=dummy_attention_mask.dtype, |
| 397 | + device=torch_device, |
| 398 | + ) |
| 399 | + dummy_attention_mask = torch.cat((dummy_attention_mask, extension), dim=0) |
| 400 | + dummy_attention_mask = dummy_attention_mask.to(torch_device) |
| 401 | + |
| 402 | + dummy_attention_mask[:] = 1 |
| 403 | + if padding_side == "left": |
| 404 | + dummy_attention_mask[-1, :2] = 0 |
| 405 | + dummy_attention_mask[-1, 2:] = 1 |
| 406 | + elif padding_side == "right": |
| 407 | + dummy_attention_mask[-1, -2:] = 0 |
| 408 | + dummy_attention_mask[-1, :-2] = 1 |
| 409 | + |
| 410 | + for enable_kernels in [False, True]: |
| 411 | + failcase = f"padding_side={padding_side}, use_mask={use_mask}, enable_kernels={enable_kernels}" |
| 412 | + processed_inputs = { |
| 413 | + model.main_input_name: dummy_input, |
| 414 | + } |
| 415 | + |
| 416 | + # Otherwise fails for e.g. WhisperEncoderModel |
| 417 | + if "attention_mask" in inspect.signature(model_eager.forward).parameters: |
| 418 | + processed_inputs["attention_mask"] = dummy_attention_mask |
| 419 | + |
| 420 | + # TODO: test gradients as well (& for FA2 as well!) |
| 421 | + with torch.no_grad(): |
| 422 | + with torch.backends.cuda.sdp_kernel( |
| 423 | + enable_flash=enable_kernels, |
| 424 | + enable_math=True, |
| 425 | + enable_mem_efficient=enable_kernels, |
| 426 | + ): |
| 427 | + prepared_inputs = self._prepare_for_class(processed_inputs, model_class) |
| 428 | + outputs_eager = model_eager(**prepared_inputs) |
| 429 | + outputs_sdpa = model_sdpa(**prepared_inputs) |
| 430 | + |
| 431 | + logits_eager = outputs_eager.lm_logits |
| 432 | + logits_sdpa = outputs_sdpa.lm_logits |
| 433 | + |
| 434 | + if torch_device in ["cpu", "cuda"]: |
| 435 | + atol = atols[torch_device, enable_kernels, torch_dtype] |
| 436 | + rtol = rtols[torch_device, enable_kernels, torch_dtype] |
| 437 | + else: |
| 438 | + atol = 1e-7 |
| 439 | + rtol = 1e-4 |
| 440 | + |
| 441 | + # Masked tokens output slightly deviates - we don't mind that. |
| 442 | + if use_mask: |
| 443 | + _logits_sdpa = torch.zeros_like(input=logits_sdpa) |
| 444 | + _logits_eager = torch.zeros_like(input=logits_eager) |
| 445 | + |
| 446 | + _logits_sdpa[:-1] = logits_sdpa[:-1] |
| 447 | + _logits_eager[:-1] = logits_eager[:-1] |
| 448 | + |
| 449 | + if padding_side == "left": |
| 450 | + _logits_sdpa[-1:, 2:] = logits_sdpa[-1:, 2:] |
| 451 | + _logits_eager[-1:, 2:] = logits_eager[-1:, 2:] |
| 452 | + |
| 453 | + elif padding_side == "right": |
| 454 | + _logits_sdpa[-1:, 2:] = logits_sdpa[-1:, :-2] |
| 455 | + _logits_eager[-1:, 2:] = logits_eager[-1:, :-2] |
| 456 | + |
| 457 | + logits_sdpa = _logits_sdpa |
| 458 | + logits_eager = _logits_eager |
| 459 | + |
| 460 | + results = [ |
| 461 | + torch.allclose(_logits_sdpa, _logits_eager, atol=atol, rtol=rtol) |
| 462 | + for (_logits_sdpa, _logits_eager) in zip(logits_sdpa, logits_eager) |
| 463 | + ] |
| 464 | + # If 80% of the batch elements have matching results, it's fine |
| 465 | + if np.mean(results) < 0.8: |
| 466 | + fail_cases.append( |
| 467 | + get_mean_reldiff(failcase, logits_sdpa, logits_eager, atol, rtol) |
| 468 | + ) |
| 469 | + |
| 470 | + self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases)) |
276 | 471 |
277 | 472 |
278 | 473 | @require_torch
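
For reference, the core pattern the new test automates is loading the same weights twice, once per attention implementation, and comparing the resulting logits within loose tolerances. Below is a minimal standalone sketch of that pattern; it uses the plain ESM-2 masked-LM model rather than ESMFold, and it assumes the small public `facebook/esm2_t6_8M_UR50D` checkpoint and a transformers version in which ESM accepts `attn_implementation="sdpa"`.

```python
import torch
from transformers import AutoTokenizer, EsmForMaskedLM

checkpoint = "facebook/esm2_t6_8M_UR50D"  # assumed small public ESM-2 checkpoint
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", return_tensors="pt")

# Same weights, two attention implementations.
model_eager = EsmForMaskedLM.from_pretrained(checkpoint, attn_implementation="eager").eval()
model_sdpa = EsmForMaskedLM.from_pretrained(checkpoint, attn_implementation="sdpa").eval()

with torch.no_grad():
    logits_eager = model_eager(**inputs).logits
    logits_sdpa = model_sdpa(**inputs).logits

# Loose fp32 tolerances, mirroring the CPU/float32 values used in the test above.
print(torch.allclose(logits_sdpa, logits_eager, atol=1e-6, rtol=1e-4))
```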