Skip to content

Commit 7400d4b

Browse files
committed
Explicitly move input_ids to desired device in test_modeling_esm.
1 parent d3f7e03 commit 7400d4b

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

tests/models/esm/test_modeling_esm.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import unittest
1818

1919
import pytest
20+
from parameterized import parameterized
2021

2122
from transformers import EsmConfig, is_torch_available
2223
from transformers.testing_utils import (
@@ -352,7 +353,7 @@ def test_inference_masked_lm(self):
352353
with torch.no_grad():
353354
model = EsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")
354355
model.eval()
355-
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
356+
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]]).to(model.device)
356357
output = model(input_ids)[0]
357358

358359
vocab_size = 33
@@ -370,17 +371,19 @@ def test_inference_no_head(self):
370371
model = EsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D")
371372
model.eval()
372373

373-
input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]])
374+
input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]]).to(model.device)
374375
output = model(input_ids)[0]
375376
# compare the actual values for a slice.
376377
expected_slice = torch.tensor(
377378
[[[0.1444, 0.5413, 0.3248], [0.3034, 0.0053, 0.3108], [0.3228, -0.2499, 0.3415]]]
378379
)
379380
torch.testing.assert_close(output[:, :3, :3], expected_slice, rtol=1e-4, atol=1e-4)
380381

382+
@parameterized.expand([({"load_in_8bit": True},), ({"load_in_4bit": True},)])
381383
@require_bitsandbytes
382-
def test_inference_bitsandbytes(self):
383-
model = EsmForMaskedLM.from_pretrained("facebook/esm2_t36_3B_UR50D", load_in_8bit=True)
384+
def test_inference_bitsandbytes(self, bnb_kwargs):
385+
model = EsmForMaskedLM.from_pretrained("facebook/esm2_t36_3B_UR50D", **bnb_kwargs)
386+
model.eval()
384387

385388
input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]]).to(model.device)
386389
# Just test if inference works

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy