TypeError: message must be a callable when calling grouped_mm with incompatible batch size for offsets #157922

Description

@IvanYashchuk

🐛 Describe the bug

In [1]: import torch

In [2]: from torch._meta_registrations import grouped_mm

In [3]: a = torch.randn(10, 1024, dtype=torch.bfloat16, device="cuda")

In [4]: b = torch.randn(3, 1024, 4096, dtype=torch.bfloat16, device="cuda")

In [5]: expert_offsets = torch.zeros(4, dtype=torch.int32, device="cuda")

In [6]: grouped_mm(a, b, expert_offsets)

Error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[6], line 1
----> 1 grouped_mm(a, b, expert_offsets)

File ~/dev/pytorch/main/torch/_prims_common/wrappers.py:309, in out_wrapper.<locals>._out_wrapper.<locals>._fn(*args, **kwargs)
    307     result = fn(*args, is_out=(out is not None), **kwargs)  # type: ignore[arg-type]
    308 else:
--> 309     result = fn(*args, **kwargs)
    310 if result is NotImplemented:
    311     return NotImplemented

File ~/dev/pytorch/main/torch/_meta_registrations.py:7688, in grouped_mm(mat_a, mat_b, offs, bias, out_dtype)
   7679 @register_meta(aten._grouped_mm)
   7680 @out_wrapper()
   7681 def grouped_mm(
   (...)
   7686     out_dtype: Optional[torch.dtype] = None,
   7687 ) -> Tensor:
-> 7688     return _meta_grouped_mm_common(
   7689         mat_a,
   7690         mat_b,
   7691         scale_a=None,
   7692         scale_b=None,
   7693         offs=offs,
   7694         bias=bias,
   7695         scale_result=None,
   7696         out_dtype=out_dtype,
   7697     )

File ~/dev/pytorch/main/torch/_meta_registrations.py:7676, in _meta_grouped_mm_common(mat_a, mat_b, scale_a, scale_b, offs, bias, scale_result, out_dtype, use_fast_accum)
   7666 torch._check(
   7667     bias is None,
   7668     lambda: "Bias tensor provided, but it is not supported yet.",
   7669 )
   7671 torch._check(
   7672     out_dtype is None or out_dtype == torch.bfloat16,
   7673     lambda: "If output dtype provided, it must be torch.bfloat16.",
   7674 )
-> 7676 return _create_grouped_mm_output_tensor(mat_a, mat_b, offs, out_dtype)

File ~/dev/pytorch/main/torch/_meta_registrations.py:7484, in _create_grouped_mm_output_tensor(mat1, mat2, offs, out_dtype)
   7482         out_size = [offs.size(0), mat1.size(0), mat2.size(1)]
   7483     else:
-> 7484         torch._check(
   7485             offs.size(0) == mat2.size(0), "matrix batch sizes have to match"
   7486         )
   7487         out_size = [mat1.size(0), mat2.size(-1)]
   7488 else:

File ~/dev/pytorch/main/torch/__init__.py:1696, in _check(cond, message)
   1681 def _check(cond, message=None):  # noqa: F811
   1682     r"""Throws error containing an optional message if the specified condition
   1683     is False.
   1684 
   (...)
   1694             message. Default: ``None``
   1695     """
-> 1696     _check_with(RuntimeError, cond, message)

File ~/dev/pytorch/main/torch/__init__.py:1674, in _check_with(error_type, cond, message)
   1672 else:
   1673     if not callable(message):
-> 1674         raise TypeError("message must be a callable")
   1676     message_evaluated = str(message())
   1678 raise error_type(message_evaluated)

TypeError: message must be a callable
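
The TypeError masks the real shape error. The failing check at torch/_meta_registrations.py:7484 passes its message as a plain string, but torch._check accepts only None or a callable producing the error text (the callable is evaluated lazily, only when the condition fails, as the _check_with source in the traceback shows). A minimal demonstration of that contract, separate from grouped_mm:

import torch

# With a callable message, the intended RuntimeError is raised.
try:
    torch._check(False, lambda: "sizes do not match")
except RuntimeError as e:
    print(type(e).__name__, e)  # RuntimeError sizes do not match

# With a plain string, _check_with rejects the message before the
# shape error can surface, reproducing the failure above.
try:
    torch._check(False, "sizes do not match")
except TypeError as e:
    print(type(e).__name__, e)  # TypeError message must be a callable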

Versions

In [x]: torch.__version__
Out[x]: '2.9.0a0+git7381c77'
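
Suggested fix (an assumption, not a confirmed patch): wrap the message in a lambda, matching the neighboring checks in _meta_grouped_mm_common visible in the traceback above:

# Hypothetical patch to the check in _create_grouped_mm_output_tensor
# (torch/_meta_registrations.py): pass a callable so torch._check can
# raise the intended error for mismatched batch sizes.
torch._check(
    offs.size(0) == mat2.size(0),
    lambda: "matrix batch sizes have to match",
)

With that change, the repro above should fail with RuntimeError: matrix batch sizes have to match instead of the TypeError.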

    Labels

    triaged: This issue has been looked at by a team member, and triaged and prioritized into an appropriate module.

