Description
🐛 Describe the bug
Using torch.func.jacfwd on a function that contains a concatenation-type operator (e.g., torch.stack, torch.cat, torch.vstack) triggers an internal assertion:
RuntimeError: !self.is_mps() INTERNAL ASSERT FAILED at "/Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/TensorShape.cpp":1408, please report a bug to PyTorch. as_strided_tensorimpl does not work with MPS; call self.as_strided(...) instead
Minimal repro:
import torch

def example(x, y):
    return torch.cat((x, y))

jac = torch.func.jacfwd(example)
x = torch.tensor([0.0], device="mps")
jac(x, x)
Note that torch.func.jacrev does not cause the error. This looks related to #111547.
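As a possible stopgap, the same Jacobian can be requested through reverse mode. The sketch below assumes the default argnums=0 and falls back to CPU when MPS is not available; the device selection and variable names are illustrative and not part of the original repro.

```python
import torch


def example(x, y):
    return torch.cat((x, y))


# Assumption: use MPS when present, otherwise fall back to CPU so the
# snippet still runs on machines without an Apple GPU.
device = "mps" if torch.backends.mps.is_available() else "cpu"

x = torch.tensor([0.0], device=device)

# Reverse-mode Jacobian with respect to the first argument (argnums=0).
jac = torch.func.jacrev(example)(x, x)

# Output has shape (2,) and the input has shape (1,), so jac has shape (2, 1).
print(jac.shape)
```

Only the first argument is differentiated here, so the second half of the concatenated output contributes a zero row.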
Output when run with TORCH_SHOW_CPP_STACKTRACES=1:
RuntimeError: !self.is_mps() INTERNAL ASSERT FAILED at "/Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/TensorShape.cpp":1408, please report a bug to PyTorch. as_strided_tensorimpl does not work with MPS; call self.as_strided(...) instead
Exception raised from as_strided_tensorimpl at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/TensorShape.cpp:1408 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>) + 52 (0x101d9bfd8 in libc10.dylib)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>> const&) + 140 (0x101d98c4c in libc10.dylib)
frame #2: c10::detail::torchInternalAssertFail(char const*, char const*, unsigned int, char const*, char const*) + 72 (0x101d98e4c in libc10.dylib)
frame #3: at::native::as_strided_tensorimpl(at::Tensor const&, c10::ArrayRef<long long>, c10::ArrayRef<long long>, std::__1::optional<long long>) + 472 (0x1387f50b8 in libtorch_cpu.dylib)
frame #4: c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&, c10::ArrayRef<c10::SymInt>, c10::ArrayRef<c10::SymInt>, std::__1::optional<c10::SymInt>), &at::(anonymous namespace)::(anonymous namespace)::wrapper_ZeroTensor__as_strided(at::Tensor const&, c10::ArrayRef<c10::SymInt>, c10::ArrayRef<c10::SymInt>, std::__1::optional<c10::SymInt>)>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, c10::ArrayRef<c10::SymInt>, c10::ArrayRef<c10::SymInt>, std::__1::optional<c10::SymInt>>>, at::Tensor (at::Tensor const&, c10::ArrayRef<c10::SymInt>, c10::ArrayRef<c10::SymInt>, std::__1::optional<c10::SymInt>)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<c10::SymInt>, c10::ArrayRef<c10::SymInt>, std::__1::optional<c10::SymInt>) + 104 (0x13a22a374 in libtorch_cpu.dylib)
frame #5: at::_ops::as_strided::call(at::Tensor const&, c10::ArrayRef<c10::SymInt>, c10::ArrayRef<c10::SymInt>, std::__1::optional<c10::SymInt>) + 476 (0x138c2ad68 in libtorch_cpu.dylib)
frame #6: at::Tensor::as_strided(c10::ArrayRef<long long>, c10::ArrayRef<long long>, std::__1::optional<long long>) const + 236 (0x13801424c in libtorch_cpu.dylib)
frame #7: at::native::expand(at::Tensor const&, c10::ArrayRef<long long>, bool) + 348 (0x1387f4178 in libtorch_cpu.dylib)
frame #8: c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<c10::SymInt>, bool), &torch::ADInplaceOrView::(anonymous namespace)::expand(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<c10::SymInt>, bool)>, at::Tensor, c10::guts::typelist::typelist<c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<c10::SymInt>, bool>>, at::Tensor (c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<c10::SymInt>, bool)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<c10::SymInt>, bool) + 116 (0x13c1a40a4 in libtorch_cpu.dylib)
frame #9: c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<c10::SymInt>, bool), &torch::autograd::VariableType::(anonymous namespace)::expand(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<c10::SymInt>, bool)>, at::Tensor, c10::guts::typelist::typelist<c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<c10::SymInt>, bool>>, at::Tensor (c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<c10::SymInt>, bool)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<c10::SymInt>, bool) + 996 (0x13b8a2f24 in libtorch_cpu.dylib)
frame #10: c10::impl::make_boxed_from_unboxed_functor<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<c10::SymInt>, bool), &torch::autograd::VariableType::(anonymous namespace)::expand(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<c10::SymInt>, bool)>, at::Tensor, c10::guts::typelist::typelist<c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<c10::SymInt>, bool>>, false>::call(c10::OperatorKernel*, c10::OperatorHandle const&, c10::DispatchKeySet, std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>*) + 112 (0x13b8a3dd4 in libtorch_cpu.dylib)
frame #11: c10::Dispatcher::callBoxed(c10::OperatorHandle const&, std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>*) const + 144 (0x138032c14 in libtorch_cpu.dylib)
frame #12: at::functorch::Interpreter::sendToNextInterpreter(c10::OperatorHandle const&, std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>*, bool) + 76 (0x1382bf918 in libtorch_cpu.dylib)
frame #13: at::functorch::dynamicLayerBack(c10::OperatorHandle const&, std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>*, bool) + 212 (0x1382be400 in libtorch_cpu.dylib)
frame #14: at::_ops::expand::call(at::Tensor const&, c10::ArrayRef<c10::SymInt>, bool) + 528 (0x1391328f4 in libtorch_cpu.dylib)
frame #15: at::functorch::ensure_has_bdim(at::Tensor const&, bool, c10::SymInt) + 280 (0x13817c7f8 in libtorch_cpu.dylib)
frame #16: at::functorch::(anonymous namespace)::cat_batching_rule(c10::IListRef<at::Tensor> const&, long long) + 948 (0x1382c0a98 in libtorch_cpu.dylib)
frame #17: __decay(c10::guts::infer_function_traits<c10::impl::detail::WrapFunctionIntoRuntimeFunctor_<at::Tensor (*)(c10::IListRef<at::Tensor> const&, long long), at::Tensor, c10::guts::typelist::typelist<c10::IListRef<at::Tensor> const&, long long>>>::type::return_type) c10::impl::call_functor_with_args_from_stack_<c10::impl::detail::WrapFunctionIntoRuntimeFunctor_<at::Tensor (*)(c10::IListRef<at::Tensor> const&, long long), at::Tensor, c10::guts::typelist::typelist<c10::IListRef<at::Tensor> const&, long long>>, false, 0ul, 1ul, c10::IListRef<at::Tensor> const&, long long>(c10::OperatorKernel*, c10::DispatchKeySet, std::__1::vector<c10::IValue, c10::DispatchKeySet::allocator<std::__1::vector>>*, c10::DispatchKeySet::integer_sequence<unsigned long, 0ul, 1ul>, c10::guts::typelist::typelist<c10::IListRef<at::Tensor> const&, long long>*) + 152 (0x13806c7ec in libtorch_cpu.dylib)
frame #18: c10::impl::make_boxed_from_unboxed_functor<c10::impl::detail::WrapFunctionIntoRuntimeFunctor_<at::Tensor (*)(c10::IListRef<at::Tensor> const&, long long), at::Tensor, c10::guts::typelist::typelist<c10::IListRef<at::Tensor> const&, long long>>, false>::call(c10::OperatorKernel*, c10::OperatorHandle const&, c10::DispatchKeySet, std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>*) + 40 (0x13806c698 in libtorch_cpu.dylib)
frame #19: c10::Dispatcher::callBoxed(c10::OperatorHandle const&, std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>*) const + 144 (0x138032c14 in libtorch_cpu.dylib)
frame #20: at::functorch::Interpreter::process(c10::OperatorHandle const&, std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>*) + 76 (0x1382bf7c4 in libtorch_cpu.dylib)
frame #21: void c10::BoxedKernel::make_boxed_function<&at::functorch::dynamicLayerFrontFallback(c10::OperatorHandle const&, std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>*)>(c10::OperatorKernel*, c10::OperatorHandle const&, c10::DispatchKeySet, std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>*) + 376 (0x1382bde0c in libtorch_cpu.dylib)
frame #22: c10::Dispatcher::callBoxed(c10::OperatorHandle const&, std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>*) const + 144 (0x138032c14 in libtorch_cpu.dylib)
frame #23: at::functorch::autogradBasedTransformSendToNext(c10::OperatorHandle const&, std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>*, at::functorch::Interpreter const&, at::functorch::TransformType, std::__1::optional<bool>, std::__1::optional<bool>, bool) + 1216 (0x13817aabc in libtorch_cpu.dylib)
frame #24: at::functorch::Interpreter::sendToNextInterpreter(c10::OperatorHandle const&, std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>*, bool) + 104 (0x1382bf934 in libtorch_cpu.dylib)
frame #25: at::functorch::dynamicLayerBack(c10::OperatorHandle const&, std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>*, bool) + 212 (0x1382be400 in libtorch_cpu.dylib)
frame #26: c10::impl::BoxedKernelWrapper<at::Tensor (c10::IListRef<at::Tensor> const&, long long), void>::call(c10::BoxedKernel const&, c10::OperatorHandle const&, c10::DispatchKeySet, c10::IListRef<at::Tensor> const&, long long) + 80 (0x138c3904c in libtorch_cpu.dylib)
frame #27: c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (c10::DispatchKeySet, c10::IListRef<at::Tensor> const&, long long), &torch::autograd::VariableType::(anonymous namespace)::cat(c10::DispatchKeySet, c10::IListRef<at::Tensor> const&, long long)>, at::Tensor, c10::guts::typelist::typelist<c10::DispatchKeySet, c10::IListRef<at::Tensor> const&, long long>>, at::Tensor (c10::DispatchKeySet, c10::IListRef<at::Tensor> const&, long long)>::call(c10::OperatorKernel*, c10::DispatchKeySet, c10::IListRef<at::Tensor> const&, long long) + 640 (0x13b36c170 in libtorch_cpu.dylib)
frame #28: at::_ops::cat::call(c10::IListRef<at::Tensor> const&, long long) + 304 (0x138c38400 in libtorch_cpu.dylib)
frame #29: torch::autograd::generated::details::cat_jvp(c10::IListRef<at::Tensor> const&, long long) + 876 (0x13cd89a34 in libtorch_cpu.dylib)
frame #30: c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (c10::DispatchKeySet, c10::IListRef<at::Tensor> const&, long long), &torch::autograd::VariableType::(anonymous namespace)::cat(c10::DispatchKeySet, c10::IListRef<at::Tensor> const&, long long)>, at::Tensor, c10::guts::typelist::typelist<c10::DispatchKeySet, c10::IListRef<at::Tensor> const&, long long>>, at::Tensor (c10::DispatchKeySet, c10::IListRef<at::Tensor> const&, long long)>::call(c10::OperatorKernel*, c10::DispatchKeySet, c10::IListRef<at::Tensor> const&, long long) + 1040 (0x13b36c300 in libtorch_cpu.dylib)
frame #31: c10::impl::make_boxed_from_unboxed_functor<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (c10::DispatchKeySet, c10::IListRef<at::Tensor> const&, long long), &torch::autograd::VariableType::(anonymous namespace)::cat(c10::DispatchKeySet, c10::IListRef<at::Tensor> const&, long long)>, at::Tensor, c10::guts::typelist::typelist<c10::DispatchKeySet, c10::IListRef<at::Tensor> const&, long long>>, false>::call(c10::OperatorKernel*, c10::OperatorHandle const&, c10::DispatchKeySet, std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>*) + 148 (0x13b36cad8 in libtorch_cpu.dylib)
frame #32: c10::Dispatcher::callBoxed(c10::OperatorHandle const&, std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>*) const + 144 (0x138032c14 in libtorch_cpu.dylib)
frame #33: at::functorch::autogradBasedTransformProcess(c10::OperatorHandle const&, std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>*, long long, at::functorch::TransformType) + 528 (0x13817a014 in libtorch_cpu.dylib)
frame #34: at::functorch::Interpreter::process(c10::OperatorHandle const&, std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>*) + 104 (0x1382bf7e0 in libtorch_cpu.dylib)
frame #35: void c10::BoxedKernel::make_boxed_function<&at::functorch::dynamicLayerFrontFallback(c10::OperatorHandle const&, std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>*)>(c10::OperatorKernel*, c10::OperatorHandle const&, c10::DispatchKeySet, std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>*) + 376 (0x1382bde0c in libtorch_cpu.dylib)
frame #36: c10::impl::BoxedKernelWrapper<at::Tensor (c10::IListRef<at::Tensor> const&, long long), void>::call(c10::BoxedKernel const&, c10::OperatorHandle const&, c10::DispatchKeySet, c10::IListRef<at::Tensor> const&, long long) + 80 (0x138c3904c in libtorch_cpu.dylib)
frame #37: at::_ops::cat::call(c10::IListRef<at::Tensor> const&, long long) + 416 (0x138c38470 in libtorch_cpu.dylib)
frame #38: torch::autograd::THPVariable_cat(_object*, _object*, _object*) + 568 (0x1039cde9c in libtorch_python.dylib)
<omitting python frames>
frame #49: start + 6000 (0x182f06b4c in dyld)
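The frames above pass through cat_jvp and then cat_batching_rule, which matches how jacfwd is documented to work: a vmap over jvp. The sketch below spells that composition out by hand on CPU, where this path is not expected to hit the MPS assertion, so it can serve as a reference result. The helper name jacfwd_by_hand is hypothetical, and whether the same composition reproduces the failure with device="mps" is an assumption that has not been verified here.

```python
import torch
from torch.func import jvp, vmap


def example(x, y):
    return torch.cat((x, y))


def jacfwd_by_hand(f, x, y):
    """Hypothetical helper: forward-mode Jacobian of f w.r.t. x only."""
    n = x.numel()
    # One standard basis vector per input element, each shaped like x.
    basis = torch.eye(n, dtype=x.dtype, device=x.device).view(n, *x.shape)

    def push_forward(tangent):
        # y is captured as a constant, so it carries no tangent.
        _, out_tangent = jvp(lambda a: f(a, y), (x,), (tangent,))
        return out_tangent

    # vmap over the basis vectors; the batch dimension comes out first,
    # so move it last to get the usual (output..., input...) layout.
    return vmap(push_forward)(basis).movedim(0, -1)


x_cpu = torch.tensor([0.0])  # CPU on purpose: reference result only
manual = jacfwd_by_hand(example, x_cpu, x_cpu)
reference = torch.func.jacfwd(example)(x_cpu, x_cpu)
print(manual.shape, torch.equal(manual, reference))
```

If the hand-written composition also fails once the tensors are moved to MPS, that would suggest the problem sits in the interaction between the jvp tangent for the non-differentiated input and the vmap batching rule, rather than in jacfwd itself.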
Versions
Collecting environment information...
PyTorch version: 2.7.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 15.4.1 (arm64)
GCC version: Could not collect
Clang version: 17.0.0 (clang-1700.0.13.3)
CMake version: version 3.31.5
Libc version: N/A
Python version: 3.12.8 (main, Mar 24 2025, 16:58:11) [Clang 16.0.0 (clang-1600.0.26.6)] (64-bit runtime)
Python platform: macOS-15.4.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Apple M4 Max
Versions of relevant libraries:
[pip3] numpy==2.2.5
[pip3] torch==2.7.0
[conda] Could not collect
cc: @kulinseth @albanD @malfet
cc @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen @zou3519 @Chillee @samdow @kshitij12345