
torch.compile fails with InternalTorchDynamoError when slicing torch.linalg.svd results #157945

@LiSsHhUuAaIi

🐛 Describe the bug

When a model slices the result of torch.linalg.svd (for example, torch.linalg.svd(x)[:2]), torch.compile fails under every backend tried (eager, aot_eager, aot_eager_decomp_partition, inductor), raising:

torch._dynamo.exc.InternalTorchDynamoError: TypeError: VariableTracker.__init__() got an unexpected keyword argument 'dynamic_attributes'

The same model runs correctly without torch.compile.

To reproduce

import torch
import torch.nn as nn

class SVDCompressor(nn.Module):
    def __init__(self, k=10):
        super().__init__()
        self.k = k
    
    def forward(self, x):
        U, S = torch.linalg.svd(x)[:2]
        reduced = U[:, :, :self.k] @ torch.diag_embed(S[:, :self.k])
        return reduced

input = torch.randn(4, 8, 6)
model = SVDCompressor(k=5)

out1 = model(input.clone()) # works
out2 = torch.compile(model, backend="eager")(input.clone()) # fails
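The traceback shows the failure inside the slice itself (`call_getitem` → `getitem_const` → `clone`), so a possible workaround, assuming only slicing of the returned named tuple is affected, is to unpack all three outputs instead. This is a sketch; `SVDCompressorUnpacked` is a hypothetical name, and it has not been verified against the 2.8.0.dev20250605 nightly.

```python
import torch
import torch.nn as nn


class SVDCompressorUnpacked(nn.Module):
    """Hypothetical variant of SVDCompressor that avoids slicing the SVD result."""

    def __init__(self, k=10):
        super().__init__()
        self.k = k

    def forward(self, x):
        # Unpack (U, S, Vh) directly instead of slicing with [:2];
        # assumption: tuple unpacking does not go through the failing getitem/clone path.
        U, S, _ = torch.linalg.svd(x)
        return U[:, :, :self.k] @ torch.diag_embed(S[:, :self.k])


x = torch.randn(4, 8, 6)
model = SVDCompressorUnpacked(k=5)
out = model(x)  # eager result with shape (4, 8, 5), same as SVDCompressor
```

In eager mode this computes exactly what SVDCompressor computes; the open question is only whether `torch.compile(model, backend="eager")(x)` traces past this line on the affected build.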

Error logs

Traceback (most recent call last):
  File "test.py", line 18, in <module>
    out2 = torch.compile(model, backend="eager")(input.clone()) # fails
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\eval_frame.py", line 372, in __call__
    return super().__call__(*args, **kwargs)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\nn\modules\module.py", line 1767, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\nn\modules\module.py", line 1778, in _call_impl
    return forward_call(*args, **kwargs)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\eval_frame.py", line 699, in compile_wrapper
    return fn(*args, **kwargs)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\nn\modules\module.py", line 1767, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\nn\modules\module.py", line 1778, in _call_impl
    return forward_call(*args, **kwargs)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\convert_frame.py", line 1469, in __call__
    return self._torchdynamo_orig_callable(
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\convert_frame.py", line 1248, in __call__
    result = self._inner_convert(
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\convert_frame.py", line 625, in __call__
    return _compile(
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\convert_frame.py", line 1144, in _compile
    raise InternalTorchDynamoError(
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\convert_frame.py", line 1092, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_utils_internal.py", line 97, in wrapper_function
    return function(*args, **kwargs)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\convert_frame.py", line 779, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\convert_frame.py", line 818, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\bytecode_transformation.py", line 1424, in transform_code_object
    transformations(instructions, code_options)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\convert_frame.py", line 265, in _fn
    return fn(*args, **kwargs)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\convert_frame.py", line 743, in transform
    tracer.run()
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\symbolic_convert.py", line 3484, in run
    super().run()
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\symbolic_convert.py", line 1359, in run
    while self.step():
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\symbolic_convert.py", line 1263, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\symbolic_convert.py", line 831, in wrapper
    return inner_fn(self, inst)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\symbolic_convert.py", line 422, in impl
    self.push(fn_var.call_function(self, self.popn(nargs), {}))
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\variables\builtin.py", line 1168, in call_function
    return handler(tx, args, kwargs)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\variables\builtin.py", line 988, in builtin_dispatch
    rv = fn(tx, args, kwargs)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\variables\builtin.py", line 866, in call_self_handler
    result = self_handler(tx, *args, **kwargs)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\variables\builtin.py", line 1776, in call_getitem
    return args[0].call_method(tx, "__getitem__", args[1:], kwargs)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\variables\lists.py", line 972, in call_method
    return super().call_method(tx, name, args, kwargs)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\variables\lists.py", line 714, in call_method
    return super().call_method(tx, name, args, kwargs)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\variables\lists.py", line 150, in call_method
    return self.getitem_const(tx, value)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\variables\lists.py", line 112, in getitem_const
    return self.clone(
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\variables\base.py", line 259, in clone
    return self.__class__(**args)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\variables\lists.py", line 888, in __init__
    super().__init__(items, **kwargs)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\variables\lists.py", line 77, in __init__
    super().__init__(**kwargs)
torch._dynamo.exc.InternalTorchDynamoError: TypeError: VariableTracker.__init__() got an unexpected keyword argument 'dynamic_attributes'

from user code:
   File "test.py", line 10, in forward
    U, S = torch.linalg.svd(x)[:2]

Versions

Collecting environment information...
PyTorch version: 2.8.0.dev20250605+cpu
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 11
GCC version: Could not collect
Clang version: Could not collect
CMake version: version 4.0.2
Libc version: N/A

Python version: 3.10.10 (tags/v3.10.10:aad5f6a, Feb 7 2023, 17:20:36) [MSC v.1929 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.26100-SP0
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames

Metadata

Assignees

No one assigned

    Labels

    module: dynamo, oncall: pt2, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
