InductorError: CppCompileError: C++ compile error on a function with a single item call #158060

@StrongerXi

🐛 Describe the bug

This was discovered in #157499, which tries to turn on capture_scalar_outputs by default. Specifically, the following test fails: PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=0 python test/inductor/test_torchinductor_opinfo.py TestInductorOpInfoCUDA.test_comprehensive_item_cuda_float64. I made a minimal repro below.

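import torch

# Let Dynamo trace scalar outputs such as `.item()` instead of graph-breaking on them.
torch._dynamo.config.capture_scalar_outputs = True


@torch.compile(backend="inductor")
def f(x):
    return x.item()


x = torch.tensor(42, device='cuda', dtype=torch.float32)  # passes if `dtype=torch.int32`
res = f(x)  # raises InductorError: CppCompileError (see logs below)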

Error logs

Traceback (most recent call last):
  File "/home/ryanguo99/scratch/comp.py", line 11, in <module>
    res = f(x)
          ^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_dynamo/eval_frame.py", line 797, in compile_wrapper
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_inductor/compile_fx.py", line 952, in _compile_fx_inner
    raise InductorError(e, currentframe()).with_traceback(
  File "/home/ryanguo99/repos/pytorch/torch/_inductor/compile_fx.py", line 936, in _compile_fx_inner
    mb_compiled_graph = fx_codegen_and_compile(
                        ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_inductor/compile_fx.py", line 1622, in fx_codegen_and_compile
    return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_inductor/compile_fx.py", line 1485, in codegen_and_compile
    compiled_module = graph.compile_to_module()
                      ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_inductor/graph.py", line 2289, in compile_to_module
    return self._compile_to_module()
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_inductor/graph.py", line 2299, in _compile_to_module
    mod = self._compile_to_module_lines(wrapper_code)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_inductor/graph.py", line 2367, in _compile_to_module_lines
    mod = PyCodeCache.load_by_key_path(
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_inductor/codecache.py", line 3237, in load_by_key_path
    mod = _reload_python_module(key, path, set_sys_modules=in_toplevel)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_inductor/runtime/compile_tasks.py", line 31, in _reload_python_module
    exec(code, mod.__dict__, mod.__dict__)
  File "/tmp/torchinductor_ryanguo99/a4/ca4ycqgdw5jjzhco5gtkixdcywxykmibbkvchgzr7cdklcchmaiw.py", line 48, in <module>
    async_compile.wait(globals())
  File "/home/ryanguo99/repos/pytorch/torch/_inductor/async_compile.py", line 547, in wait
    self._wait_futures(scope)
  File "/home/ryanguo99/repos/pytorch/torch/_inductor/async_compile.py", line 567, in _wait_futures
    kernel = result.result()
             ^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_inductor/codecache.py", line 3971, in result
    return self.result_fn()
           ^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_inductor/codecache.py", line 2714, in future
    result = get_result()
             ^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_inductor/codecache.py", line 2510, in load_fn
    future.result()
  File "/home/ryanguo99/.conda/envs/comfyui/lib/python3.11/concurrent/futures/_base.py", line 456, in result
    return self.__get_result()
           ^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/.conda/envs/comfyui/lib/python3.11/concurrent/futures/_base.py", line 401, in __get_result
    raise self._exception
  File "/home/ryanguo99/.conda/envs/comfyui/lib/python3.11/concurrent/futures/thread.py", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_inductor/codecache.py", line 2540, in _worker_compile_cpp
    builder.build()
  File "/home/ryanguo99/repos/pytorch/torch/_inductor/cpp_builder.py", line 1712, in build
    run_compile_cmd(build_cmd, cwd=_build_tmp_dir)
  File "/home/ryanguo99/repos/pytorch/torch/_inductor/cpp_builder.py", line 402, in run_compile_cmd
    _run_compile_cmd(cmd_line, cwd)
  File "/home/ryanguo99/repos/pytorch/torch/_inductor/cpp_builder.py", line 397, in _run_compile_cmd
    raise exc.CppCompileError(cmd, output) from e
torch._inductor.exc.InductorError: CppCompileError: C++ compile error

Command:
g++ /tmp/torchinductor_ryanguo99/h7/ch7dz3bknb6c6x6ktdzmbfykruashnh4wujzmz4az2lidms6rs4l.main.cpp -D TORCH_INDUCTOR_CPP_WRAPPER -D STANDALONE_TORCH_HEADER -D C10_USING_CUSTOM_GENERATED_MACROS -D CPU_CAPABILITY_AVX512 -shared -fPIC -O3 -DNDEBUG -fno-trapping-math -funsafe-math-optimizations -ffinite-math-only -fno-signed-zeros -fno-math-errno -fexcess-precision=fast -fno-finite-math-only -fno-unsafe-math-optimizations -ffp-contract=off -fno-tree-loop-vectorize -march=native -Wall -std=c++17 -Wno-unused-variable -Wno-unknown-pragmas -pedantic -fopenmp -include /tmp/torchinductor_ryanguo99/precompiled_headers/celw6u5lto4enac6gwgc7wzodghtjir5ewe5bk76ucgat424h3oa.h -I/home/ryanguo99/.conda/envs/comfyui/include/python3.11 -I/home/ryanguo99/repos/pytorch/torch/include -I/home/ryanguo99/repos/pytorch/torch/include/torch/csrc/api/include -mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma -o /tmp/torchinductor_ryanguo99/h7/ch7dz3bknb6c6x6ktdzmbfykruashnh4wujzmz4az2lidms6rs4l.main.so -ltorch -ltorch_cpu -ltorch_python -lgomp -L/home/ryanguo99/.conda/envs/comfyui/lib -L/home/ryanguo99/repos/pytorch/torch/lib

Output:
/tmp/torchinductor_ryanguo99/h7/ch7dz3bknb6c6x6ktdzmbfykruashnh4wujzmz4az2lidms6rs4l.main.cpp: In function ‘void kernel(double*)’:
/tmp/torchinductor_ryanguo99/h7/ch7dz3bknb6c6x6ktdzmbfykruashnh4wujzmz4az2lidms6rs4l.main.cpp:8:29: error: ‘zuf0’ was not declared in this scope
    8 |                 auto tmp0 = zuf0;
      |                             ^~~~
In file included from /home/ryanguo99/repos/pytorch/torch/include/ATen/cpu/vec/vec512/vec512.h:14,
                 from /home/ryanguo99/repos/pytorch/torch/include/ATen/cpu/vec/vec.h:4,
                 from /home/ryanguo99/repos/pytorch/torch/include/ATen/cpu/vec/functional_base.h:6,
                 from /home/ryanguo99/repos/pytorch/torch/include/ATen/cpu/vec/functional.h:3,
                 from /home/ryanguo99/repos/pytorch/torch/include/torch/csrc/inductor/cpp_prefix.h:44,
                 from /tmp/torchinductor_ryanguo99/precompiled_headers/celw6u5lto4enac6gwgc7wzodghtjir5ewe5bk76ucgat424h3oa.h:1:
/home/ryanguo99/repos/pytorch/torch/include/ATen/cpu/vec/vec512/vec512_int.h: In instantiation of ‘at::vec::CPU_CAPABILITY::Vectorized<T> at::vec::CPU_CAPABILITY::shift_512_8(const at::vec::CPU_CAPABILITY::Vectorized<T>&, const at::vec::CPU_CAPABILITY::Vectorized<T>&) [with bool left_shift = true; T = signed char; typename std::enable_if<(is_same_v<T, signed char> || is_same_v<T, unsigned char>), int>::type <anonymous> = 0]’:
/home/ryanguo99/repos/pytorch/torch/include/ATen/cpu/vec/vec512/vec512_int.h:2067:27:   required from here
/home/ryanguo99/repos/pytorch/torch/include/ATen/cpu/vec/vec512/vec512_int.h:1859:7: warning: overflow in conversion from ‘int’ to ‘char’ changes value from ‘128’ to ‘'\37777777600'’ [-Woverflow]
 1859 |       0x80,
      |       ^~~~
/home/ryanguo99/repos/pytorch/torch/include/ATen/cpu/vec/vec512/vec512_int.h:1861:7: warning: overflow in conversion from ‘int’ to ‘char’ changes value from ‘128’ to ‘'\37777777600'’ [-Woverflow]
 1861 |       0x80,
      |       ^~~~
/home/ryanguo99/repos/pytorch/torch/include/ATen/cpu/vec/vec512/vec512_int.h:1863:7: warning: overflow in conversion from ‘int’ to ‘char’ changes value from ‘128’ to ‘'\37777777600'’ [-Woverflow]
 1863 |       0x80,
      |       ^~~~
/home/ryanguo99/repos/pytorch/torch/include/ATen/cpu/vec/vec512/vec512_int.h:1865:7: warning: overflow in conversion from ‘int’ to ‘char’ changes value from ‘128’ to ‘'\37777777600'’ [-Woverflow]
 1865 |       0x80,
      |       ^~~~
/home/ryanguo99/repos/pytorch/torch/include/ATen/cpu/vec/vec512/vec512_int.h:1867:7: warning: overflow in conversion from ‘int’ to ‘char’ changes value from ‘128’ to ‘'\37777777600'’ [-Woverflow]
 1867 |       0x80,
      |       ^~~~
/home/ryanguo99/repos/pytorch/torch/include/ATen/cpu/vec/vec512/vec512_int.h:1869:7: warning: overflow in conversion from ‘int’ to ‘char’ changes value from ‘128’ to ‘'\37777777600'’ [-Woverflow]
 1869 |       0x80,
      |       ^~~~
/home/ryanguo99/repos/pytorch/torch/include/ATen/cpu/vec/vec512/vec512_int.h:1871:7: warning: overflow in conversion from ‘int’ to ‘char’ changes value from ‘128’ to ‘'\37777777600'’ [-Woverflow]
 1871 |       0x80,
      |       ^~~~
/home/ryanguo99/repos/pytorch/torch/include/ATen/cpu/vec/vec512/vec512_int.h:1873:7: warning: overflow in conversion from ‘int’ to ‘char’ changes value from ‘128’ to ‘'\37777777600'’ [-Woverflow]
 1873 |       0x80,
      |       ^~~~
/home/ryanguo99/repos/pytorch/torch/include/ATen/cpu/vec/vec512/vec512_int.h:1875:7: warning: overflow in conversion from ‘int’ to ‘char’ changes value from ‘128’ to ‘'\37777777600'’ [-Woverflow]
 1875 |       0x80,
      |       ^~~~
/home/ryanguo99/repos/pytorch/torch/include/ATen/cpu/vec/vec512/vec512_int.h:1877:7: warning: overflow in conversion from ‘int’ to ‘char’ changes value from ‘128’ to ‘'\37777777600'’ [-Woverflow]
 1877 |       0x80,
      |       ^~~~
/home/ryanguo99/repos/pytorch/torch/include/ATen/cpu/vec/vec512/vec512_int.h:1879:7: warning: overflow in conversion from ‘int’ to ‘char’ changes value from ‘128’ to ‘'\37777777600'’ [-Woverflow]
 1879 |       0x80,
      |       ^~~~
/home/ryanguo99/repos/pytorch/torch/include/ATen/cpu/vec/vec512/vec512_int.h:1881:7: warning: overflow in conversion from ‘int’ to ‘char’ changes value from ‘128’ to ‘'\37777777600'’ [-Woverflow]
 1881 |       0x80,
      |       ^~~~
/home/ryanguo99/repos/pytorch/torch/include/ATen/cpu/vec/vec512/vec512_int.h:1883:7: warning: overflow in conversion from ‘int’ to ‘char’ changes value from ‘128’ to ‘'\37777777600'’ [-Woverflow]
 1883 |       0x80,
      |       ^~~~
/home/ryanguo99/repos/pytorch/torch/include/ATen/cpu/vec/vec512/vec512_int.h:1885:7: warning: overflow in conversion from ‘int’ to ‘char’ changes value from ‘128’ to ‘'\37777777600'’ [-Woverflow]
 1885 |       0x80,
      |       ^~~~
......
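
The only actual error in the output above is the undeclared `zuf0` at `main.cpp:8:29`; everything after it is `-Woverflow` warning noise from `vec512_int.h`. As a triage aid, here is a minimal sketch that keeps only the `error:` diagnostics from g++ output. It assumes the usual `path:line:col: kind: message` layout seen in the log. The helper name `extract_errors`, the shortened paths, and the trimmed warning text in `sample` are placeholders for illustration, not part of PyTorch or of the original log.

```python
import re

# g++ diagnostics in the log have the shape "path:line:col: error|warning: message".
DIAG_RE = re.compile(
    r"^(?P<path>.+?):(?P<line>\d+):(?P<col>\d+): "
    r"(?P<kind>error|warning): (?P<msg>.*)$"
)


def extract_errors(output):
    """Return (line, col, message) for every `error:` diagnostic, skipping warnings."""
    errors = []
    for raw in output.splitlines():
        m = DIAG_RE.match(raw)
        if m and m.group("kind") == "error":
            errors.append((int(m.group("line")), int(m.group("col")), m.group("msg")))
    return errors


# Shortened stand-in for the compiler output (paths and warning text abbreviated).
sample = "\n".join([
    "/tmp/example.main.cpp: In function ‘void kernel(double*)’:",
    "/tmp/example.main.cpp:8:29: error: ‘zuf0’ was not declared in this scope",
    "    8 |                 auto tmp0 = zuf0;",
    "/path/to/vec512_int.h:1859:7: warning: overflow in conversion from ‘int’ to ‘char’ [-Woverflow]",
    " 1859 |       0x80,",
])

for line, col, msg in extract_errors(sample):
    print(f"{line}:{col}: {msg}")
```

Passing the `Output:` section of a `CppCompileError` message to `extract_errors` should leave just the `zuf0` diagnostic, assuming the compiler emits diagnostics in the same format.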

Versions

main fcc682b, python 3.11

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @aakhundov
