Closed
Labels
triaged
Description
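🐛 Describe the bug
When the batch-size check in the meta registration for aten._grouped_mm fails, it raises TypeError: message must be a callable instead of the intended RuntimeError ("matrix batch sizes have to match"). The cause is that _create_grouped_mm_output_tensor passes a plain string to torch._check, which only accepts a callable message. In the reproducer below, expert_offsets has 4 entries while b has a batch (group) dimension of 3, so the check fails: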
In [1]: import torch
In [2]: from torch._meta_registrations import grouped_mm
In [3]: a = torch.randn(10, 1024, dtype=torch.bfloat16, device="cuda")
In [4]: b = torch.randn(3, 1024, 4096, dtype=torch.bfloat16, device="cuda")
In [5]: expert_offsets = torch.zeros(4, dtype=torch.int32, device="cuda")
In [6]: grouped_mm(a, b, expert_offsets)
Error:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[6], line 1
----> 1 grouped_mm(a, b, expert_offsets)
File ~/dev/pytorch/main/torch/_prims_common/wrappers.py:309, in out_wrapper.<locals>._out_wrapper.<locals>._fn(*args, **kwargs)
307 result = fn(*args, is_out=(out is not None), **kwargs) # type: ignore[arg-type]
308 else:
--> 309 result = fn(*args, **kwargs)
310 if result is NotImplemented:
311 return NotImplemented
File ~/dev/pytorch/main/torch/_meta_registrations.py:7688, in grouped_mm(mat_a, mat_b, offs, bias, out_dtype)
7679 @register_meta(aten._grouped_mm)
7680 @out_wrapper()
7681 def grouped_mm(
(...)
7686 out_dtype: Optional[torch.dtype] = None,
7687 ) -> Tensor:
-> 7688 return _meta_grouped_mm_common(
7689 mat_a,
7690 mat_b,
7691 scale_a=None,
7692 scale_b=None,
7693 offs=offs,
7694 bias=bias,
7695 scale_result=None,
7696 out_dtype=out_dtype,
7697 )
File ~/dev/pytorch/main/torch/_meta_registrations.py:7676, in _meta_grouped_mm_common(mat_a, mat_b, scale_a, scale_b, offs, bias, scale_result, out_dtype, use_fast_accum)
7666 torch._check(
7667 bias is None,
7668 lambda: "Bias tensor provided, but it is not supported yet.",
7669 )
7671 torch._check(
7672 out_dtype is None or out_dtype == torch.bfloat16,
7673 lambda: "If output dtype provided, it must be torch.bfloat16.",
7674 )
-> 7676 return _create_grouped_mm_output_tensor(mat_a, mat_b, offs, out_dtype)
File ~/dev/pytorch/main/torch/_meta_registrations.py:7484, in _create_grouped_mm_output_tensor(mat1, mat2, offs, out_dtype)
7482 out_size = [offs.size(0), mat1.size(0), mat2.size(1)]
7483 else:
-> 7484 torch._check(
7485 offs.size(0) == mat2.size(0), "matrix batch sizes have to match"
7486 )
7487 out_size = [mat1.size(0), mat2.size(-1)]
7488 else:
File ~/dev/pytorch/main/torch/__init__.py:1696, in _check(cond, message)
1681 def _check(cond, message=None): # noqa: F811
1682 r"""Throws error containing an optional message if the specified condition
1683 is False.
1684
(...)
1694 message. Default: ``None``
1695 """
-> 1696 _check_with(RuntimeError, cond, message)
File ~/dev/pytorch/main/torch/__init__.py:1674, in _check_with(error_type, cond, message)
1672 else:
1673 if not callable(message):
-> 1674 raise TypeError("message must be a callable")
1676 message_evaluated = str(message())
1678 raise error_type(message_evaluated)
TypeError: message must be a callable
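To make the failure mode concrete without needing CUDA, here is a minimal pure-Python sketch. check_with and check are hypothetical stand-ins that reproduce only the torch._check / torch._check_with behavior visible in the traceback: a non-callable message raises TypeError before the intended RuntimeError. grouped_mm_2d_3d_out_size is a hypothetical helper covering just the 2D x 3D branch shown at lines 7484-7487, with a candidate fix applied (the message wrapped in a lambda). The default message for message=None is a placeholder, since the real one is not shown in the traceback.

```python
from typing import Callable, List, Optional, Sequence


def check_with(error_type: type, cond: bool,
               message: Optional[Callable[[], str]] = None) -> None:
    """Hypothetical stand-in for torch._check_with (only the behavior in the traceback)."""
    if cond:
        return
    if message is None:
        # Placeholder text: the real default message is not shown in the traceback.
        raise error_type("Expected cond to be True, but got False.")
    if not callable(message):
        # This is the branch the meta function hits when given a plain string.
        raise TypeError("message must be a callable")
    raise error_type(str(message()))


def check(cond: bool, message: Optional[Callable[[], str]] = None) -> None:
    """Hypothetical stand-in for torch._check, which forwards with RuntimeError."""
    check_with(RuntimeError, cond, message)


def grouped_mm_2d_3d_out_size(mat1: Sequence[int], mat2: Sequence[int],
                              offs_len: int) -> List[int]:
    """Hypothetical shape logic for the 2D x 3D branch, with the message as a lambda."""
    check(offs_len == mat2[0], lambda: "matrix batch sizes have to match")
    return [mat1[0], mat2[-1]]


# Buggy form (string message): the shape mismatch surfaces as a TypeError.
try:
    check(4 == 3, "matrix batch sizes have to match")
except TypeError as e:
    print(type(e).__name__, e)  # TypeError message must be a callable

# Fixed form (callable message), using the shapes from the reproducer:
# offs has 4 entries but mat2 has 3 groups, so the intended RuntimeError is raised.
try:
    grouped_mm_2d_3d_out_size([10, 1024], [3, 1024, 4096], 4)
except RuntimeError as e:
    print(type(e).__name__, e)  # RuntimeError matrix batch sizes have to match

# With matching sizes the check passes and the output shape is [M, N].
print(grouped_mm_2d_3d_out_size([10, 1024], [3, 1024, 4096], 3))  # [10, 4096]
```

Applied to the real code, the same idea would mean passing lambda: "matrix batch sizes have to match" as the second argument to torch._check in _create_grouped_mm_output_tensor, so that a mismatched offs tensor reports the shape error rather than a TypeError.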
Versions
In [x]: torch.__version__
Out[x]: '2.9.0a0+git7381c77'