
[RFC]: PyTorch Low-Precision GEMMs Public API #157950


RFC: PyTorch Low-Precision GEMMs Public API

Summary

This RFC proposes public APIs for low-precision matrix multiplications in PyTorch, addressing the current gap where users must rely on private/undocumented APIs like _scaled_mm and _scaled_grouped_mm. With the rapid adoption of FP8 and other low-precision formats (DeepSeek-V3, Llama 4, NVIDIA H100/B200, AMD MI300+), PyTorch needs robust, officially supported APIs for scaled matrix operations.

The proposal presents two approaches for supporting various scaling strategies (per-tensor, per-row, group-wise, block-wise):

Approach 1: New Dedicated Functions

  • torch.scaled_mm - Basic scaled matrix multiply (when bias=None in scaled_addmm)
  • torch.scaled_addmm - Scaled matrix multiply with bias addition
  • torch.scaled_bmm - Batched scaled matrix multiply (when bias=None in scaled_baddbmm)
  • torch.scaled_baddbmm - Batched scaled matrix multiply with bias addition
  • torch.scaled_grouped_mm - Grouped GEMM for MoE workloads (when bias=None)
  • torch.scaled_grouped_addmm - Grouped GEMM with bias for MoE workloads

Approach 2: Extend Existing Functions

  • Add optional scaling parameters to torch.addmm, torch.baddbmm, torch.bmm, etc.

Both approaches have trade-offs in terms of API clarity, type safety, and user experience.

Authors

@drisspg

Motivation

Industry Adoption and Demand

Low-precision matrix multiplications have moved from research to production:

  • DeepSeek-V3: Uses FP8 with group-wise scaling (128 elements/group for activations, 128×128 blocks for weights) [Technical Report]
  • Llama 4: Pre-trained using FP8 precision, achieving 390 TFLOPs/GPU utilization [Meta Blog]
  • NVIDIA H100/B200: Native hardware support for FP8 E4M3/E5M2 with various scaling strategies

Performance and Memory Benefits

[Figure: FP8/FP4 throughput scaling across recent NVIDIA and AMD accelerators; see the table below]

Source: https://semianalysis.com/2025/06/13/amd-advancing-ai-mi350x-and-mi400-ualoe72-mi500-ual256/

| Platform | FP16 TFLOPS | FP8 TFLOPS | FP4/MX4 TFLOPS | Improvement vs FP16 | References |
|---|---|---|---|---|---|
| NVIDIA H100 | 989.5 | 1,978.5 | N/A | 2.0x | H100 Datasheet |
| NVIDIA B200 | 2,250 | 4,500 | 9,000 | 2x (FP8), 4x (FP4) | MLPerf v4.1, Blackwell Arch |
| AMD MI300X | 1,307 | 2,615 | N/A | 2x (FP8) | MI300X Specs, ROCm Blog |
| AMD MI350X | 2,300 | 4,600 | 9,200 | 4x (FP4, projected) | AMD Roadmap, SemiAnalysis |

Note: All TFLOPS values are dense; most official docs quote sparse figures, so those were divided by 2. AMD MI350X values are projections based on official announcements.

Current State

Existing Private APIs

PyTorch currently has two private APIs for low-precision GEMMs:

# Primary scaled matrix multiplication
torch._scaled_mm(
    mat1: Tensor,                    # FP8_E4M3/E5M2 or FP4_E2M1 
    mat2: Tensor,                    # FP8_E4M3/E5M2 or FP4_E2M1
    scale_a: Tensor,                 # FP32 or E8M0/E4M3 for MX/NVFP4 group scaling
    scale_b: Tensor,                 # FP32 or E8M0/E4M3 for MX/NVFP4 group scaling
    bias: Optional[Tensor] = None,   # FP16/BF16 bias
    scale_result: Optional[Tensor] = None,  # For FP8 output quantization
    out_dtype: Optional[dtype] = None,      # Output precision control
    use_fast_accum: bool = False            # Hardware accumulation mode
) -> Tensor

# FP8 grouped GEMM for MoE
torch._scaled_grouped_mm(
    mat_a: Tensor,          # FP8_E4M3 only, 2D or 3D
    mat_b: Tensor,          # FP8_E4M3 only, 2D or 3D
    scale_a: Tensor,        # FP32 scales
    scale_b: Tensor,        # FP32 scales
    offs: Optional[Tensor], # INT32 offsets for jagged dims
    bias: Optional[Tensor] = None,        # Not supported yet
    scale_result: Optional[Tensor] = None, # Not supported yet
    out_dtype: Optional[dtype] = None,
    use_fast_accum: bool = False
) -> Tensor

Current Private APIs and Scaling Support

| Operation | API | Scaling Types Supported | Hardware | Limitations |
|---|---|---|---|---|
| Basic GEMM | _scaled_mm | Per-tensor, Per-row, Group-wise (MX/NVFP4) | CUDA 8.9+, ROCm MI300+ | No batch support, no general block-wise |
| Grouped GEMM (FP8) | _scaled_grouped_mm | Per-row only | CUDA 9.0+ only | FP8_E4M3 only, no bias/scale_result |

Scaling Type Details

| Scaling Type | Support | Implementation | Example Shape | Supported Dtypes |
|---|---|---|---|---|
| Per-tensor | _scaled_mm only | Single scale factor | scale_a: [1, 1], scale_b: [1, 1] | FP8_E4M3, FP8_E5M2 |
| Per-row | _scaled_mm, _scaled_grouped_mm | Row/column-wise scaling | scale_a: [M, 1], scale_b: [1, N] | FP8_E4M3, FP8_E5M2 |
| Group-wise (MX) | _scaled_mm only | E8M0 scales, 32-element groups | scale_a: [M, K//32], scale_b: [K//32, N] | FP8_E4M3, FP8_E5M2 |
| Group-wise (NVFP4) | _scaled_mm only | E4M3 scales, 16-element groups | scale_a: [M, K//16], scale_b: [K//16, N] | FP4_E2M1 only |
| Block-wise (General) | Not supported | Missing - major gap | DeepSeek-V3: scale_a: [M//128, K//128], scale_b: [K//128, N//128] | N/A |

Background

Low-Precision Matmuls on Modern Hardware

Modern AI accelerators (NVIDIA H100/B200, AMD MI300X) have introduced native support for low-precision matrix multiplications. Unlike traditional INT8 quantization, these formats use symmetric quantization with scales only (no zero points). Since release 2.1.0, the only way to actually capitalize on FP8 matmul via PyTorch core has been the private, undocumented _scaled_mm. We should change this. FP8 and lower-precision formats are here to stay, and after the success that DeepSeek and others have had using these formats for training and inference, I would like to propose an API for enabling this.

Shell Dtypes in PyTorch Core

PyTorch has introduced "shell dtypes" - lightweight dtype definitions that enable tensor creation and basic operations without full kernel support. These dtypes were added to core to provide a foundation for low-precision compute, even though most operations aren't directly implemented:

Shell Dtype Properties

| Dtype | Exponent Bits | Mantissa Bits | Max Value | Min Normal |
|---|---|---|---|---|
| float8_e4m3fn | 4 | 3 | 448 | 2^-6 |
| float8_e5m2 | 5 | 2 | 57,344 | 2^-14 |
| float8_e8m0fnu | 8 | 0 | 2^127 | 2^-127 |
| float4_e2m1fn | 2 | 1 | 6 | 0.5 |
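
The values above can be queried at runtime via torch.finfo (shown here only for the two FP8 compute dtypes; coverage for the scale and packed FP4 dtypes may differ across versions):

import torch

# Query format properties directly from PyTorch
for dt in (torch.float8_e4m3fn, torch.float8_e5m2):
    info = torch.finfo(dt)
    print(dt, info.max, info.tiny)  # max representable value, smallest normal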

These dtypes act as containers - you can create tensors, cast to/from them, and move them between devices, but most mathematical operations will raise NotImplementedError:

# This works
x = torch.randn(32, 32, device="cuda", dtype=torch.float32)
x_fp8 = x.to(torch.float8_e4m3fn)

# This raises NotImplementedError - no direct arithmetic
# y = x_fp8 + x_fp8  # Error!

# Must use specialized ops: per-tensor FP32 scales; mat2 must be column-major (hence .t())
scale_a = torch.tensor(1.0, device="cuda")
scale_b = torch.tensor(1.0, device="cuda")
y = torch._scaled_mm(x_fp8, x_fp8.t(), scale_a, scale_b, out_dtype=torch.bfloat16)

Actual compute happens through specialized operations that understand the scaling, so that dequantization and quantization are handled correctly inside the kernel.
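
For intuition, here is a minimal, non-performant emulation of those numerics (the helper name is illustrative and not part of any proposed API; real kernels fuse all of these steps):

import torch

def emulated_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype=torch.bfloat16):
    # Dequantize both operands to FP32 using their scales, accumulate in FP32,
    # then cast to the requested output dtype.
    acc = (a_fp8.to(torch.float32) * scale_a) @ (b_fp8.to(torch.float32) * scale_b)
    return acc.to(out_dtype)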

Proposed PyTorch APIs

Design Principles

  1. Explicit over Implicit: Scaling strategy is explicit in the API
  2. Hardware Agnostic: Works across different accelerators
  3. Future Proof: Extensible to new formats (NVFP4, etc.)
  4. Performance First: Aligned with what the hardware supports natively
  5. Familiar Patterns: Follows existing PyTorch API conventions

This doc presents two different approaches for exposing low-precision GEMM operations:

Approach 1: Dedicated scaled_ Function Family

Create new public APIs specifically for scaled operations:

def torch.scaled_addmm(
    input: Optional[Tensor],  # Can be None for mm
    mat1: Tensor,
    mat2: Tensor,
    scale_mat1: Tensor,
    scale_mat2: Tensor,
    *,
    beta: float = 1.0,
    alpha: float = 1.0,
    scale_result: Optional[Tensor] = None,
    out_dtype: Optional[dtype] = None,
    # use_fast_accum: bool = False # TBD
) -> Tensor:
    """
    Scaled matrix multiply-add: beta * input + alpha * (scale_mat1 * mat1) @ (scale_mat2 * mat2)
    
    Scales are applied during dequantization before accumulation in FP32.
    When out_dtype is 1 byte or less, scale_result is required for output quantization.
    When input is None, beta is ignored (acts as scaled_mm).
    """

def torch.scaled_baddbmm(
    input: Optional[Tensor],  # Can be None for bmm
    batch1: Tensor,
    batch2: Tensor,
    scale_batch1: Tensor,
    scale_batch2: Tensor,
    *,
    beta: float = 1.0,
    alpha: float = 1.0,
    scale_result: Optional[Tensor] = None,
    out_dtype: Optional[dtype] = None,
    # use_fast_accum: bool = False
) -> Tensor:
    """
    Scaled batch add matrix-matrix product: beta * input + alpha * (scale_batch1 * batch1) @ (scale_batch2 * batch2)
    
    Scales are applied during dequantization before accumulation in FP32.
    When input is None, acts as scaled_bmm.
    """

def torch.scaled_grouped_addmm(
    input: Optional[Tensor],  # Can be None for grouped mm
    mat1: Tensor,
    mat2: Tensor,
    scale_mat1: Tensor,
    scale_mat2: Tensor,
    *,
    beta: float = 1.0,
    alpha: float = 1.0,
    offs: Optional[Tensor] = None,  # Existing API supports only one set of offsets
    scale_result: Optional[Tensor] = None,
    out_dtype: Optional[dtype] = None,
    # use_fast_accum: bool = False
) -> Tensor:
    """
    Grouped GEMM for MoE and similar patterns: beta * input + alpha * (scale_mat1 * mat1) @ (scale_mat2 * mat2)
    
    Scales are applied during dequantization before accumulation in FP32.
    Supports jagged dimensions via offset tensors for variable expert sizes.
    When input is None, acts as scaled_grouped_mm.
    """

Approach 2: Overload Existing Functions

Extend current APIs to handle scaling:

def torch.addmm(
    input: Tensor,
    mat1: Tensor,
    mat2: Tensor,
    *,
    beta: float = 1.0,
    alpha: float = 1.0,
    # NEW KWARGS 
    scale_input: Optional[Tensor] = None,
    scale_mat1: Optional[Tensor] = None,
    scale_mat2: Optional[Tensor] = None,
    scale_result: Optional[Tensor] = None,
    out_dtype: Optional[dtype] = None,
    use_fast_accum: bool = False
) -> Tensor:
    """
    When mat1/mat2 are FP<=8, corresponding scales become required.
    Raises ValueError if FP<=8 tensors provided without scales.
    """

def torch.baddbmm(
    input: Tensor,
    batch1: Tensor,
    batch2: Tensor,
    *,
    beta: float = 1.0,
    alpha: float = 1.0,
    # NEW KWARGS
    scale_input: Optional[Tensor] = None,
    scale_batch1: Optional[Tensor] = None,
    scale_batch2: Optional[Tensor] = None,
    scale_result: Optional[Tensor] = None,
    out_dtype: Optional[dtype] = None,
    use_fast_accum: bool = False
) -> Tensor:
    """
    When batch1/batch2 are FP<=8, corresponding scales become required.
    """

# New Public function - no existing equivalent
def torch.grouped_addmm(
    input: Optional[Tensor],
    mat1: Tensor,
    mat2: Tensor,
    *,
    beta: float = 1.0,
    alpha: float = 1.0,
    offs: Optional[Tensor] = None,
    # NEW KWARGS
    scale_input: Optional[Tensor] = None,
    scale_mat1: Optional[Tensor] = None,
    scale_mat2: Optional[Tensor] = None,
    scale_result: Optional[Tensor] = None,
    out_dtype: Optional[dtype] = None,
    use_fast_accum: bool = False
) -> Tensor:
    """
    Grouped matrix multiply-add for MoE patterns.
    
    When mat1/mat2 are FP<=8, corresponding scales become required.
    """

Implementation Details

Scale Tensor Shapes and Strategies

The scaling strategy is determined by scale tensor shapes:

# Per-tensor scaling (all batches share same scale)
scale1.shape == [1, 1, 1]
scale2.shape == [1, 1, 1]

# Per-batch scaling (each batch gets one scale)
scale1.shape == [B, 1, 1]
scale2.shape == [B, 1, 1]

# Per-row scaling (can be shared or per-batch)
scale1.shape == [1, M, 1]  # Shared across batches
scale1.shape == [B, M, 1]  # Per-batch row scaling
scale2.shape == [1, 1, N]  # Shared across batches  
scale2.shape == [B, 1, N]  # Per-batch column scaling

# Group-wise scaling (can be shared or per-batch)
scale1.shape == [1, M, K//group_size]  # Shared across batches
scale1.shape == [B, M, K//group_size]  # Per-batch groups
scale2.shape == [1, K//group_size, N]  # Shared across batches
scale2.shape == [B, K//group_size, N]  # Per-batch groups

# Block-wise scaling (can be shared or per-batch)
scale1.shape == [1, M//block_m, K//block_k]  # Shared across batches
scale1.shape == [B, M//block_m, K//block_k]  # Per-batch blocks
scale2.shape == [1, K//block_k, N//block_n]  # Shared across batches
scale2.shape == [B, K//block_k, N//block_n]  # Per-batch blocks
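
To make the shape-driven dispatch concrete, here is an illustrative (and deliberately incomplete) classifier for a mat1-style scale of shape [B or 1, rows, cols]; the function name and return strings are made up for this sketch:

def classify_scaling(scale_shape, M, K, group_size=32):
    # Map a 3D scale shape onto the scaling strategies described above.
    b, rows, cols = scale_shape
    if rows == 1 and cols == 1:
        return "per-tensor" if b == 1 else "per-batch"
    if rows == M and cols == 1:
        return "per-row"
    if rows == M and cols == K // group_size:
        return "group-wise"
    if rows < M and cols < K:
        return "block-wise"
    raise ValueError(f"Unrecognized scale shape {scale_shape}")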

Output Scaling

Quick tangent

In a typical low-precision GEMM operation, there are 1-2 quantization/dequantization steps:

# Step 1: Inputs are in low precision (FP8), need dequantization for accumulation
# Conceptually: high_precision_a = fp8_a * scale_a (dequant)
#               high_precision_b = fp8_b * scale_b (dequant)
# Accumulation happens in FP32: acc = high_precision_a @ high_precision_b

# Step 2: If output is FP8, need quantization from high precision accumulator
# Conceptually: fp8_output = acc * inverse_scale (quant)

This scale factor aligns the high-precision distribution with the FP8 range:

# Dynamic scaling (computed at runtime)
abs_max = tensor.abs().max()
fp8_max = torch.finfo(torch.float8_e4m3fn).max  # e.g., 448 for E4M3

scale = fp8_max / abs_max  # > 1.0, "stretches" values to use full FP8 range
                           # < 1.0, "compresses" values to fit in FP8

# Static calibration (pre-computed)
# Can use percentiles, moving averages, or other statistics
scale = fp8_max / calibration_stats.quantile(0.999) 

When converting from high precision (accumulator) to low precision output, a scale is required:

def scaled_addmm(..., scale_result: Optional[Tensor] = None, out_dtype: Optional[dtype] = None):
    # Accumulation always happens in high precision (FP32)
    # Input scales (scale_a, scale_b) handle dequantization
    acc = (a * scale_a) @ (b * scale_b)  # Conceptually
    
    # Output scaling invariant
    if out_dtype in [torch.float8_e5m2, torch.float8_e4m3fn, torch.float4_e2m1fn_x2]:
        if scale_result is None:
            raise ValueError(
                f"scale_result required for {out_dtype} output. "
                "Need to quantize from FP32 accumulator to low precision."
            )
        return quantize_to_fp8(acc, scale_result, out_dtype)
    else:
        # No quantization needed, return high precision
        return acc

Therefore, to support low-precision output, users need to be able to supply a quantization scale factor for the output.
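
For completeness, a minimal emulation of the quantize_to_fp8 step referenced in the pseudocode above (shown for FP8 out_dtypes; the helper is illustrative):

import torch

def quantize_to_fp8(acc: torch.Tensor, scale_result: torch.Tensor, out_dtype: torch.dtype) -> torch.Tensor:
    # Quant: rescale the FP32 accumulator into the target range, clamp, then cast.
    fmax = torch.finfo(out_dtype).max
    return (acc * scale_result).clamp(min=-fmax, max=fmax).to(out_dtype)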

Example Usage

Basic FP8 GEMM

# Simple matmul with per-tensor scaling
M, K, N = 1024, 512, 2048  # example sizes
a_hp = torch.randn(M, K, dtype=torch.bfloat16)  # Original high precision
b_hp = torch.randn(K, N, dtype=torch.bfloat16)

# Compute scales for quantization
fp8_max = torch.finfo(torch.float8_e4m3fn).max 
scale_a = fp8_max / a_hp.abs().max()
scale_b = fp8_max / b_hp.abs().max()

# Quantize to FP8
a_fp8 = (a_hp * scale_a).to(torch.float8_e4m3fn)
b_fp8 = (b_hp * scale_b).to(torch.float8_e4m3fn)

# Pass reciprocal scales for dequantization in scaled_addmm
output = torch.scaled_addmm(None, a_fp8, b_fp8, 
                           torch.reciprocal(scale_a), 
                           torch.reciprocal(scale_b), 
                           out_dtype=torch.bfloat16)

Migration Path

From _scaled_mm to Public API

# Old
out = torch._scaled_mm(mat1, mat2, scale_a, scale_b)

# New 
out = torch.scaled_addmm(None, mat1, mat2, scale_mat1, scale_mat2)

# Or
out = torch.addmm(None, mat1, mat2, scale_mat1=scale_mat1, scale_mat2=scale_mat2)

Sticky Points → Looking for feedback

use_fast_accum

The use_fast_accum flag offers 5-10% performance gains over the slower (more precise) accumulation variant. Due to the reduced precision of this accumulation mode, training workloads have tended either to use it differently between the forward pass and inference, or to not use it at all. I am not sure whether we should bring this argument into the public APIs; perhaps a global setting that enables the feature would be enough.

One option could be to follow PyTorch's existing patterns for hardware-specific optimizations:

# Similar to torch.backends.cuda.matmul.allow_tf32
torch.backends.cuda.matmul.allow_fp8_fast_accum = False  # Default

This would keep function signatures clean while still allowing users to opt in.

Explicit scale enum versus shape inference

The proposed API above infers the type of scaling from the shapes of the input scales (and ultimately we use this information to choose between a set of kernels). An alternative is a belt-and-suspenders approach that also adds an enum so users can explicitly specify the scaling they intend. I don't love this. It feels unintuitive and can lead to situations where users end up with inconsistent input states. Also, what if someone wants group-wise scaling for scale1 and block-wise for scale2? Note that I think every scaling format is ultimately a form of block-wise scaling.

There is one outstanding counterpoint that makes this annoying. For MX inputs on B200 we need to swizzle the scales into a specific layout to work with the TCgen5 MMA atom. The specifics of the layout can be found here: https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout and a PyTorch swizzle implementation is here. Because we need to re-arrange the scales, we can no longer directly infer which scale maps to which data element in the tensor.

We can perhaps just make an exception and establish a shape convention:

# Regular MX scales: [M, K//32]
# Pre-swizzled and padded MX scales: [32*ceil_div(M,128) * 16*ceil_div((K//32),4)]

Note that on AMD the scales do not need to be swizzled so the shape inference still works on that platform.
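
A small sketch of the size arithmetic implied by that convention (the helper names are made up; the constants follow the cuBLAS layout linked above):

def ceil_div(a: int, b: int) -> int:
    return -(a // -b)

def swizzled_mx_scale_numel(M: int, K: int, group_size: int = 32) -> int:
    # Pre-swizzled and padded MX scale element count per the convention above:
    # 32 * ceil_div(M, 128) * 16 * ceil_div(K // group_size, 4)
    return (32 * ceil_div(M, 128)) * (16 * ceil_div(K // group_size, 4))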

Autograd support

We don't intend to specify the end-to-end workflows and instead expect these primitives to be utilized within specific flows. This means that these new GEMMs will not define backward formulas. Instead, we expect that people will use them in conjunction with tensor subclasses and/or autograd functions to enable training.

Examples:

  • TorchAO: https://github.com/pytorch/ao/tree/main/torchao/float8
  • Lingua: https://github.com/facebookresearch/lingua/blob/437d680e521873bb5971067148a69587790da853/lingua/float8.py#L83
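
As an illustration of what we expect users to build on top of these primitives, here is a minimal (and deliberately simplified) autograd.Function that quantizes to FP8 in the forward pass and keeps the backward in high precision; it assumes the proposed torch.scaled_addmm and per-tensor dynamic scaling (real recipes such as torchao.float8 also quantize gradients and manage scales more carefully):

import torch

class Fp8Matmul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        fp8_max = torch.finfo(torch.float8_e4m3fn).max
        scale_a = fp8_max / a.abs().max()
        scale_b = fp8_max / b.abs().max()
        a_fp8 = (a * scale_a).to(torch.float8_e4m3fn)
        b_fp8 = (b * scale_b).to(torch.float8_e4m3fn)
        ctx.save_for_backward(a, b)
        # Sketch only: relies on the proposed public API
        return torch.scaled_addmm(None, a_fp8, b_fp8,
                                  torch.reciprocal(scale_a),
                                  torch.reciprocal(scale_b),
                                  out_dtype=a.dtype)

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        a, b = ctx.saved_tensors
        # High-precision backward for simplicity
        return grad_out @ b.t(), a.t() @ grad_out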

Tensor Data Swizzling

During the decode phase, GEMMs are often memory bound rather than compute bound. In these cases, pre-swizzling low-precision data tensors can provide significant speedups.

The core issue is that PyTorch stores weights in row-major format, but tensor cores require a specific register layout. For standard data types (FP16/FP32), NVIDIA provides the ldmatrix instruction for efficient shuffling. However, for 4-bit types, no such instruction exists, forcing kernels to use multiple inefficient 8-bit shared memory loads.

Pre-swizzling solves this by reordering data ahead of time, enabling:

  • Single 32-bit loads instead of four 8-bit loads
  • Up to 128-bit loads by interleaving multiple tiles
  • Better shared memory bandwidth utilization

Examples: Machete kernel

The current recommendation is to not support pre-swizzled data tensors in the public API. This seems overly conservative given the performance benefits, but keeps the API simpler by avoiding another layout dimension.

Layout Support

Most formats require that the inputs be in TN format (RowMajor mat1, ColMajor mat2). This requirement has been lifted for FP8 + per-tensor inputs on B200 but still remains for H100 and the MX dtypes. This can be quite restrictive and maybe counter-intuitive for PyTorch users. Just wanted to call it out.
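
As a concrete illustration of the TN requirement (shapes arbitrary; this only shows the strides, not a full call):

import torch

M, K, N = 128, 256, 64
a = torch.randn(M, K, dtype=torch.bfloat16).to(torch.float8_e4m3fn)  # row-major [M, K]
b = torch.randn(N, K, dtype=torch.bfloat16).to(torch.float8_e4m3fn)  # row-major [N, K]
b_col_major = b.t()          # view with shape [K, N] and strides (1, K): column-major
assert b_col_major.stride() == (1, K)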

Future arguments

What happens if we need to add new arguments in the future? NVFP4 is a good example of this challenge. NVFP4 uses a two-level scaling strategy - first level uses E4M3 FP8 for each 16-value micro-block, second level applies an FP32 scale at the tensor level. This creates a scaling hierarchy that doesn't fit cleanly into our current scale_mat1/scale_mat2 setup.

For NVFP4 inputs and higher precision outputs, we can probably reuse scale_result for the second-level scaling. But if we want to do NVFP4 → NVFP4, we'd need additional scale_d_in and scale_d_out parameters.

# This would work for NVFP4 → BF16
output = torch.scaled_addmm(None, mat1_nvfp4, mat2_nvfp4, 
                           scale_mat1, scale_mat2,
                           scale_result=per_tensor_scale,  # Reuse for second-level
                           out_dtype=torch.bfloat16)

# But NVFP4 → NVFP4 would need new args
output = torch.scaled_addmm(None, mat1_nvfp4, mat2_nvfp4, 
                           scale_mat1, scale_mat2,
                           scale_d_in=scale_d_in,    # NEW 
                           scale_d_out=scale_d_out,  # NEW
                           out_dtype=torch.float4_e2m1fn_x2)

This is where Approach 1 (dedicated functions) might be more future-proof. Adding new parameters to torch.scaled_addmm is cleaner than polluting core PyTorch functions like torch.addmm with an ever-growing list of format-specific scaling parameters.

References

Technical Reports and Documentation

Hardware Specifications and Benchmarks

Implementation Examples and Code

PyTorch Evolution

cc @ptrblck @msaroufim @eqy @jerryzh168 @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @naromero77amd
