🐛 Describe the bug
Reproduction Script
import torch
import torch.nn as nn


class MyModel(nn.Module):
    def __init__(self, rnn_dims=512):
        super().__init__()
        self.rnn = nn.LSTM(input_size=rnn_dims, hidden_size=rnn_dims, batch_first=True)

    def forward(self, x):
        h0 = torch.zeros(1, x.size(0), self.rnn.hidden_size, device=x.device, dtype=x.dtype)
        c0 = torch.zeros(1, x.size(0), self.rnn.hidden_size, device=x.device, dtype=x.dtype)
        out, h1 = self.rnn(x, (h0, c0))
        return h1  # tuple (h_n, c_n)


if __name__ == "__main__":
    torch.manual_seed(0)
    model = MyModel().eval().to("cuda")
    input_tensor = torch.rand(1, 1, 512, dtype=torch.float32).to("cuda")

    try:
        with torch.no_grad():
            eager_output = model(input_tensor)
        print("Eager output:", eager_output[0].shape, eager_output[1].shape)
    except Exception as e:
        print(e)

    try:
        compiled_model = torch.compile(model, fullgraph=True, backend="inductor")
        with torch.no_grad():
            compiled_output = compiled_model(input_tensor)
        print("Compiled output:", compiled_output[0].shape, compiled_output[1].shape)
    except Exception as e:
        print(e)
Error logs
Output on stable (2.7.1) and latest main (2.9.0a0+gitfcc682b)
Output (Stable)
Eager output: torch.Size([1, 1, 512]) torch.Size([1, 1, 512])
TorchDynamo purposely graph breaks on RNN, GRU, LSTMs
from user code:
File "for_test.py", line 10, in forward
h0 = torch.zeros(1, x.size(0), self.rnn.hidden_size, device=x.device, dtype=x.dtype)
Output (main branch)
Eager output: torch.Size([1, 1, 512]) torch.Size([1, 1, 512])
Attempted to wrap RNN, GRU, or LSTM
Explanation: Dynamo does not support RNN, GRU, or LSTM.
Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues.
Developer debug context: LSTM(512, 512, batch_first=True)
from user code:
File "test_coverage/for_test.py", line 10, in forward
h0 = torch.zeros(1, x.size(0), self.rnn.hidden_size, device=x.device, dtype=x.dtype)
What’s Confusing?
- This model runs just fine under eager execution.
- It's a very basic LSTM-based model and doesn't use anything exotic or fused.
- Yet torch.compile breaks, even on the latest main (torch==2.9.0a0+gitfcc682b).
- I understand some ops are not supported yet, but shouldn't a warning or fallback happen instead of an outright failure (see the fallback sketch after this list)? Or is there a plan to support nn.LSTM and friends via fake tensor rules in the future?
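For what it's worth, a plain torch.compile call without fullgraph=True looks like the fallback I was hoping for: Dynamo should graph-break around the LSTM, run it eagerly, and compile the rest, so no exception is raised. A minimal sketch, reusing model and input_tensor from the reproduction script:

# fullgraph defaults to False, so Dynamo is allowed to graph-break on the
# unsupported nn.LSTM call and fall back to eager for that op.
compiled_model = torch.compile(model, backend="inductor")
with torch.no_grad():
    fallback_output = compiled_model(input_tensor)
print("Fallback output:", fallback_output[0].shape, fallback_output[1].shape)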
Versions
Both torch==2.9.0a0+gitfcc682b and torch==2.7.1a0+gitfcc682b have the same problem.
PyTorch version: 2.7.1a0+gite2d141d
Is debug build: True
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 12.3.0-1ubuntu1~22.04) 12.3.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35
Python version: 3.10.16 (main, Dec 11 2024, 16:24:50) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.8.0-59-generic-x86_64-with-glibc2.35
Is CUDA available: True
Thanks for the great work on torch.compile; I appreciate any clarification on this!
cc @mikaylagawarecki @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames