17
17
import unittest
18
18
19
19
import pytest
20
+ from parameterized import parameterized
20
21
21
22
from transformers import EsmConfig , is_torch_available
22
23
from transformers .testing_utils import (
@@ -352,7 +353,7 @@ def test_inference_masked_lm(self):
352
353
with torch .no_grad ():
353
354
model = EsmForMaskedLM .from_pretrained ("facebook/esm2_t6_8M_UR50D" )
354
355
model .eval ()
355
- input_ids = torch .tensor ([[0 , 1 , 2 , 3 , 4 , 5 ]])
356
+ input_ids = torch .tensor ([[0 , 1 , 2 , 3 , 4 , 5 ]]). to ( model . device )
356
357
output = model (input_ids )[0 ]
357
358
358
359
vocab_size = 33
@@ -370,17 +371,19 @@ def test_inference_no_head(self):
370
371
model = EsmModel .from_pretrained ("facebook/esm2_t6_8M_UR50D" )
371
372
model .eval ()
372
373
373
- input_ids = torch .tensor ([[0 , 6 , 4 , 13 , 5 , 4 , 16 , 12 , 11 , 7 , 2 ]])
374
+ input_ids = torch .tensor ([[0 , 6 , 4 , 13 , 5 , 4 , 16 , 12 , 11 , 7 , 2 ]]). to ( model . device )
374
375
output = model (input_ids )[0 ]
375
376
# compare the actual values for a slice.
376
377
expected_slice = torch .tensor (
377
378
[[[0.1444 , 0.5413 , 0.3248 ], [0.3034 , 0.0053 , 0.3108 ], [0.3228 , - 0.2499 , 0.3415 ]]]
378
379
)
379
380
torch .testing .assert_close (output [:, :3 , :3 ], expected_slice , rtol = 1e-4 , atol = 1e-4 )
380
381
382
+ @parameterized .expand ([({"load_in_8bit" : True },), ({"load_in_4bit" : True },)])
381
383
@require_bitsandbytes
382
- def test_inference_bitsandbytes (self ):
383
- model = EsmForMaskedLM .from_pretrained ("facebook/esm2_t36_3B_UR50D" , load_in_8bit = True )
384
+ def test_inference_bitsandbytes (self , bnb_kwargs ):
385
+ model = EsmForMaskedLM .from_pretrained ("facebook/esm2_t36_3B_UR50D" , ** bnb_kwargs )
386
+ model .eval ()
384
387
385
388
input_ids = torch .tensor ([[0 , 6 , 4 , 13 , 5 , 4 , 16 , 12 , 11 , 7 , 2 ]]).to (model .device )
386
389
# Just test if inference works
0 commit comments