RFC: PyTorch Low-Precision GEMMs Public API
Summary
This RFC proposes public APIs for low-precision matrix multiplications in PyTorch, addressing the current gap where users must rely on private/undocumented APIs like `_scaled_mm` and `_scaled_grouped_mm`. With the rapid adoption of FP8 and other low-precision formats (DeepSeek-V3, Llama 4, NVIDIA H100/B200, AMD MI300+), PyTorch needs robust, officially supported APIs for scaled matrix operations.
The proposal presents two approaches for supporting various scaling strategies (per-tensor, per-row, group-wise, block-wise):
Approach 1: New Dedicated Functions
- `torch.scaled_mm` - Basic scaled matrix multiply (when bias=None in scaled_addmm)
- `torch.scaled_addmm` - Scaled matrix multiply with bias addition
- `torch.scaled_bmm` - Batched scaled matrix multiply (when bias=None in scaled_baddbmm)
- `torch.scaled_baddbmm` - Batched scaled matrix multiply with bias addition
- `torch.scaled_grouped_mm` - Grouped GEMM for MoE workloads (when bias=None)
- `torch.scaled_grouped_addmm` - Grouped GEMM with bias for MoE workloads
Approach 2: Extend Existing Functions
- Add optional scaling parameters to `torch.addmm`, `torch.baddbmm`, `torch.bmm`, etc.
Both approaches have trade-offs in terms of API clarity, type safety, and user experience.
Authors
Motivation
Industry Adoption and Demand
Low-precision matrix multiplications have moved from research to production:
- DeepSeek-V3: Uses FP8 with group-wise scaling (128 elements/group for activations, 128×128 blocks for weights) [Technical Report]
- Llama 4: Pre-trained using FP8 precision, achieving 390 TFLOPs/GPU utilization [Meta Blog]
- NVIDIA H100/B200: Native hardware support for FP8 E4M3/E5M2 with various scaling strategies
Performance and Memory Benefits

Source: https://semianalysis.com/2025/06/13/amd-advancing-ai-mi350x-and-mi400-ualoe72-mi500-ual256/
| Platform | FP16 TFLOPS | FP8 TFLOPS | FP4/MX4 TFLOPS | Improvement vs FP16 | References |
|---|---|---|---|---|---|
| NVIDIA H100 | 989.5 | 1,978.5 | N/A | 2.0x | H100 Datasheet |
| NVIDIA B200 | 2,250 | 4,500 | 9,000 | 2x (FP8), 4x (FP4) | MLPerf v4.1, Blackwell Arch |
| AMD MI300X | 1,307 | 2,615 | N/A | 2x (FP8) | MI300X Specs, ROCm Blog |
| AMD MI350X | 2,300 | 4,600 | 9,200 | 4x (FP4, projected) | AMD Roadmap, SemiAnalysis |
Note: All TFLOPS values are dense; where official docs quote sparse throughput, the numbers were halved. AMD MI350X values are projections based on official announcements.
Current State
Existing Private APIs
PyTorch currently has two private APIs for low-precision GEMMs:
# Primary scaled matrix multiplication
torch._scaled_mm(
    mat1: Tensor,                           # FP8_E4M3/E5M2 or FP4_E2M1
    mat2: Tensor,                           # FP8_E4M3/E5M2 or FP4_E2M1
    scale_a: Tensor,                        # FP32, or E8M0/E4M3 for MX/NVFP4 group scaling
    scale_b: Tensor,                        # FP32, or E8M0/E4M3 for MX/NVFP4 group scaling
    bias: Optional[Tensor] = None,          # FP16/BF16 bias
    scale_result: Optional[Tensor] = None,  # For FP8 output quantization
    out_dtype: Optional[dtype] = None,      # Output precision control
    use_fast_accum: bool = False            # Hardware accumulation mode
) -> Tensor

# FP8 grouped GEMM for MoE
torch._scaled_grouped_mm(
    mat_a: Tensor,                          # FP8_E4M3 only, 2D or 3D
    mat_b: Tensor,                          # FP8_E4M3 only, 2D or 3D
    scale_a: Tensor,                        # FP32 scales
    scale_b: Tensor,                        # FP32 scales
    offs: Optional[Tensor],                 # INT32 offsets for jagged dims
    bias: Optional[Tensor] = None,          # Not supported yet
    scale_result: Optional[Tensor] = None,  # Not supported yet
    out_dtype: Optional[dtype] = None,
    use_fast_accum: bool = False
) -> Tensor
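For reference, a minimal per-tensor invocation of the existing private API looks roughly like this (illustrative sketch; assumes a CUDA device with FP8 support such as an H100/MI300-class GPU, and that `mat2` is materialized column-major as current kernels require):

```python
import torch

M, K, N = 128, 256, 64
a_hp = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b_hp = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)

# Per-tensor quantization scales (FP32 scalars)
fp8_max = torch.finfo(torch.float8_e4m3fn).max
scale_a = fp8_max / a_hp.abs().max().float()
scale_b = fp8_max / b_hp.abs().max().float()

a_fp8 = (a_hp * scale_a).to(torch.float8_e4m3fn)                         # row-major
b_fp8 = (b_hp * scale_b).to(torch.float8_e4m3fn).t().contiguous().t()    # column-major

# Reciprocal scales dequantize inside the kernel
out = torch._scaled_mm(
    a_fp8, b_fp8,
    torch.reciprocal(scale_a),
    torch.reciprocal(scale_b),
    out_dtype=torch.bfloat16,
)
```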
Current Private APIs and Scaling Support
| Operation | API | Scaling Types Supported | Hardware | Limitations |
|---|---|---|---|---|
| Basic GEMM | `_scaled_mm` | Per-tensor, Per-row, Group-wise (MX/NVFP4) | CUDA 8.9+, ROCm MI300+ | No batch support, no general block-wise |
| Grouped GEMM (FP8) | `_scaled_grouped_mm` | Per-row only | CUDA 9.0+ only | FP8_E4M3 only, no bias/scale_result |
Scaling Type Details
| Scaling Type | Support | Implementation | Example Shape | Supported Dtypes |
|---|---|---|---|---|
| Per-tensor | `_scaled_mm` only | Single scale factor | `scale_a: [1, 1]`, `scale_b: [1, 1]` | FP8_E4M3, FP8_E5M2 |
| Per-row | `_scaled_mm`, `_scaled_grouped_mm` | Row/column-wise scaling | `scale_a: [M, 1]`, `scale_b: [1, N]` | FP8_E4M3, FP8_E5M2 |
| Group-wise (MX) | `_scaled_mm` only | E8M0 scales, 32-element groups | `scale_a: [M, K//32]`, `scale_b: [K//32, N]` | FP8_E4M3, FP8_E5M2 |
| Group-wise (NVFP4) | `_scaled_mm` only | E4M3 scales, 16-element groups | `scale_a: [M, K//16]`, `scale_b: [K//16, N]` | FP4_E2M1 only |
| Block-wise (General) | Not supported | Missing - major gap | DeepSeek-V3: `scale_a: [M//128, K//128]`, `scale_b: [K//128, N//128]` | N/A |
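To make the block-wise gap concrete, here is a sketch of computing DeepSeek-V3-style 128x128 block scales on the quantization side (illustrative only; no core kernel currently consumes scales of this shape, and the names and shapes here are assumptions following the report):

```python
import torch

BLOCK = 128
K, N = 1024, 2048
w = torch.randn(K, N, dtype=torch.bfloat16)

fp8_max = torch.finfo(torch.float8_e4m3fn).max

# View the weight as a grid of 128x128 blocks and take the per-block absmax
blocks = w.reshape(K // BLOCK, BLOCK, N // BLOCK, BLOCK)
amax = blocks.abs().amax(dim=(1, 3)).float()        # [K//128, N//128]
scale_w = fp8_max / amax                            # one quant scale per block

# Quantize each block with its own scale; the GEMM would need the reciprocal of
# scale_w (shape [K//128, N//128]) to dequantize, which _scaled_mm cannot take today.
w_fp8 = (blocks * scale_w[:, None, :, None]).reshape(K, N).to(torch.float8_e4m3fn)
```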
Background
Low-Precision Matmuls on Modern Hardware
Modern AI accelerators (NVIDIA H100/B200, AMD MI300X) have introduced native support for low-precision matrix multiplications. Unlike traditional INT8 quantization, these formats use symmetric quantization with scales only (no zero points). Since release 2.1.0, the only way to actually capitalize on FP8 matmuls via PyTorch core has been the private/undocumented `_scaled_mm`. We should change this. FP8 and lower-precision formats are here to stay, and after the success that DeepSeek and others have had utilizing these formats for training and inference, I would like to propose an API for enabling this.
Shell Dtypes in PyTorch Core
PyTorch has introduced "shell dtypes" - lightweight dtype definitions that enable tensor creation and basic operations without full kernel support. These dtypes were added to core to provide a foundation for low-precision compute, even though most operations aren't directly implemented:
Shell Dtype Properties
| Dtype | Exponent Bits | Mantissa Bits | Max Value | Min Normal |
|---|---|---|---|---|
| `float8_e4m3fn` | 4 | 3 | 448 | 2^-6 |
| `float8_e5m2` | 5 | 2 | 57,344 | 2^-14 |
| `float8_e8m0fnu` | 8 | 0 | 2^127 | 2^-127 |
| `float4_e2m1fn` | 2 | 1 | 6 | 0.5 |
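These properties can be checked directly via `torch.finfo` (a quick sketch; coverage of the newer shell dtypes depends on the PyTorch version):

```python
import torch

for dt in (torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e8m0fnu):
    fi = torch.finfo(dt)
    print(f"{dt}: max={fi.max}, smallest_normal={fi.smallest_normal}")
```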
These dtypes act as containers - you can create tensors, cast to/from them, and move them between devices, but most mathematical operations will raise `NotImplementedError`:
# This works: shell dtypes support creation, casting, and device movement
x = torch.randn(16, 16, dtype=torch.float32, device="cuda")
x_fp8 = x.to(torch.float8_e4m3fn)

# This raises NotImplementedError - no direct arithmetic on shell dtypes
# y = x_fp8 + x_fp8  # Error!

# Must use specialized ops that understand the scales (per-tensor FP32 here;
# mat2 is transposed to satisfy the column-major layout requirement)
scale_a = torch.tensor(1.0, device="cuda")
scale_b = torch.tensor(1.0, device="cuda")
y = torch._scaled_mm(x_fp8, x_fp8.t(), scale_a, scale_b, out_dtype=torch.bfloat16)
Actual compute happens through specialized operations that understand the scaling, so that dequantization and quantization are handled correctly inside the kernel.
Proposed PyTorch APIs
Design Principles
- Explicit over Implicit: Scaling strategy is explicit in the API
- Hardware Agnostic: Works across different accelerators
- Future Proof: Extensible to new formats (NVFP4, etc.)
- Performance First: Aligned to hardware
- Familiar Patterns: Follows existing PyTorch API conventions
This doc presents two different approaches for exposing low-precision GEMM operations:
Approach 1: Dedicated `scaled_` Function Family
Create new public APIs specifically for scaled operations:
def torch.scaled_addmm(
    input: Optional[Tensor],     # Can be None for mm
    mat1: Tensor,
    mat2: Tensor,
    scale_mat1: Tensor,
    scale_mat2: Tensor,
    *,
    beta: float = 1.0,
    alpha: float = 1.0,
    scale_result: Optional[Tensor] = None,
    out_dtype: Optional[dtype] = None,
    # use_fast_accum: bool = False  # TBD
) -> Tensor:
    """
    Scaled matrix multiply-add: beta * input + alpha * (scale_mat1 * mat1) @ (scale_mat2 * mat2)
    Scales are applied during dequantization before accumulation in FP32.
    When out_dtype is 1 byte or less, scale_result is required for output quantization.
    When input is None, beta is ignored (acts as scaled_mm).
    """
def torch.scaled_baddbmm(
    input: Optional[Tensor],     # Can be None for bmm
    batch1: Tensor,
    batch2: Tensor,
    scale_batch1: Tensor,
    scale_batch2: Tensor,
    *,
    beta: float = 1.0,
    alpha: float = 1.0,
    scale_result: Optional[Tensor] = None,
    out_dtype: Optional[dtype] = None,
    # use_fast_accum: bool = False
) -> Tensor:
    """
    Scaled batch add matrix-matrix product: beta * input + alpha * (scale_batch1 * batch1) @ (scale_batch2 * batch2)
    Scales are applied during dequantization before accumulation in FP32.
    When input is None, acts as scaled_bmm.
    """
def torch.scaled_grouped_addmm(
    input: Optional[Tensor],     # Can be None for grouped mm
    mat1: Tensor,
    mat2: Tensor,
    scale_mat1: Tensor,
    scale_mat2: Tensor,
    *,
    beta: float = 1.0,
    alpha: float = 1.0,
    offs: Optional[Tensor] = None,  # Existing API only takes one set of offsets
    scale_result: Optional[Tensor] = None,
    out_dtype: Optional[dtype] = None,
    # use_fast_accum: bool = False
) -> Tensor:
    """
    Grouped GEMM for MoE and similar patterns: beta * input + alpha * (scale_mat1 * mat1) @ (scale_mat2 * mat2)
    Scales are applied during dequantization before accumulation in FP32.
    Supports jagged dimensions via offset tensors for variable expert sizes.
    When input is None, acts as scaled_grouped_mm.
    """
Approach 2: Overload Existing Functions
Extend current APIs to handle scaling:
def torch.addmm(
    input: Tensor,
    mat1: Tensor,
    mat2: Tensor,
    *,
    beta: float = 1.0,
    alpha: float = 1.0,
    # NEW KWARGS
    scale_input: Optional[Tensor] = None,
    scale_mat1: Optional[Tensor] = None,
    scale_mat2: Optional[Tensor] = None,
    scale_result: Optional[Tensor] = None,
    out_dtype: Optional[dtype] = None,
    use_fast_accum: bool = False
) -> Tensor:
    """
    When mat1/mat2 are FP<=8, corresponding scales become required.
    Raises ValueError if FP<=8 tensors provided without scales.
    """
def torch.baddbmm(
    input: Tensor,
    batch1: Tensor,
    batch2: Tensor,
    *,
    beta: float = 1.0,
    alpha: float = 1.0,
    # NEW KWARGS
    scale_input: Optional[Tensor] = None,
    scale_batch1: Optional[Tensor] = None,
    scale_batch2: Optional[Tensor] = None,
    scale_result: Optional[Tensor] = None,
    out_dtype: Optional[dtype] = None,
    use_fast_accum: bool = False
) -> Tensor:
    """
    When batch1/batch2 are FP<=8, corresponding scales become required.
    """
# New public function - no existing equivalent
def torch.grouped_addmm(
    input: Optional[Tensor],
    mat1: Tensor,
    mat2: Tensor,
    *,
    beta: float = 1.0,
    alpha: float = 1.0,
    offs: Optional[Tensor] = None,
    # NEW KWARGS
    scale_input: Optional[Tensor] = None,
    scale_mat1: Optional[Tensor] = None,
    scale_mat2: Optional[Tensor] = None,
    scale_result: Optional[Tensor] = None,
    out_dtype: Optional[dtype] = None,
    use_fast_accum: bool = False
) -> Tensor:
    """
    Grouped matrix multiply-add for MoE patterns.
    When mat1/mat2 are FP<=8, corresponding scales become required.
    """
Implementation Details
Scale Tensor Shapes and Strategies
The scaling strategy is determined by scale tensor shapes:
# Per-tensor scaling (all batches share same scale)
scale1.shape == [1, 1, 1]
scale2.shape == [1, 1, 1]
# Per-batch scaling (each batch gets one scale)
scale1.shape == [B, 1, 1]
scale2.shape == [B, 1, 1]
# Per-row scaling (can be shared or per-batch)
scale1.shape == [1, M, 1] # Shared across batches
scale1.shape == [B, M, 1] # Per-batch row scaling
scale2.shape == [1, 1, N] # Shared across batches
scale2.shape == [B, 1, N] # Per-batch column scaling
# Group-wise scaling (can be shared or per-batch)
scale1.shape == [1, M, K//group_size] # Shared across batches
scale1.shape == [B, M, K//group_size] # Per-batch groups
scale2.shape == [1, K//group_size, N] # Shared across batches
scale2.shape == [B, K//group_size, N] # Per-batch groups
# Block-wise scaling (can be shared or per-batch)
scale1.shape == [1, M//block_m, K//block_k] # Shared across batches
scale1.shape == [B, M//block_m, K//block_k] # Per-batch blocks
scale2.shape == [1, K//block_k, N//block_n] # Shared across batches
scale2.shape == [B, K//block_k, N//block_n] # Per-batch blocks
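For illustration, a minimal sketch of how a shape-driven dispatcher might classify these scale shapes (a hypothetical helper, not part of the proposed surface area):

```python
def infer_scaling_strategy(mat_shape, scale_shape):
    """Hypothetical helper: classify a scale tensor against a [B, M, K] operand."""
    B, M, K = mat_shape
    sb, sm, sk = scale_shape
    if (sm, sk) == (1, 1):
        return "per-tensor" if sb == 1 else "per-batch"
    if (sm, sk) == (M, 1):
        return "per-row"
    if sm == M and K % sk == 0:
        return f"group-wise (group_size={K // sk})"
    if M % sm == 0 and K % sk == 0:
        return f"block-wise (block_m={M // sm}, block_k={K // sk})"
    raise ValueError(f"unrecognized scale shape {scale_shape} for operand {mat_shape}")

print(infer_scaling_strategy((1, 1024, 4096), (1, 1024, 128)))  # group-wise (group_size=32)
```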
Output Scaling
Quick tangent
In a typical low-precision GEMM operation, there are 1-2 quantization/dequantization steps:
# Step 1: Inputs are in low precision (FP8), need dequantization for accumulation
# Conceptually: high_precision_a = fp8_a * scale_a (dequant)
# high_precision_b = fp8_b * scale_b (dequant)
# Accumulation happens in FP32: acc = high_precision_a @ high_precision_b
# Step 2: If output is FP8, need quantization from high precision accumulator
# Conceptually: fp8_output = acc * inverse_scale (quant)
This scale factor aligns the high-precision distribution with the FP8 range:
# Dynamic scaling (computed at runtime)
abs_max = tensor.abs().max()
fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448 for E4M3
scale = fp8_max / abs_max  # > 1.0 "stretches" values to use the full FP8 range,
                           # < 1.0 "compresses" values to fit in FP8

# Static calibration (pre-computed)
# Can use percentiles, moving averages, or other statistics
scale = fp8_max / calibration_stats.quantile(0.999)
When converting from high precision (accumulator) to low precision output, a scale is required:
def scaled_addmm(..., scale_result: Optional[Tensor] = None, out_dtype: Optional[dtype] = None):
    # Accumulation always happens in high precision (FP32)
    # Input scales (scale_a, scale_b) handle dequantization
    acc = (a * scale_a) @ (b * scale_b)  # Conceptually

    # Output scaling invariant
    if out_dtype in [torch.float8_e5m2, torch.float8_e4m3fn, torch.float4_e2m1fn_x2]:
        if scale_result is None:
            raise ValueError(
                f"scale_result required for {out_dtype} output. "
                "Need to quantize from FP32 accumulator to low precision."
            )
        return quantize_to_fp8(acc, scale_result, out_dtype)
    else:
        # No quantization needed, return high precision
        return acc
Therefore, to support low-precision output, users need to be able to pass a quantization scale factor for the output.
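Concretely, an FP8 → FP8 call under the proposed API would look roughly like this (a sketch reusing the FP8 inputs and dequantization scales from the Basic FP8 GEMM example below; `expected_output_absmax` is a placeholder for a value obtained from calibration or an amax history):

```python
# Sketch: FP8 output requires an explicit quantization scale under the proposal
fp8_max = torch.finfo(torch.float8_e4m3fn).max
scale_result = fp8_max / expected_output_absmax  # placeholder: from calibration / delayed scaling

out_fp8 = torch.scaled_addmm(
    None, a_fp8, b_fp8,
    torch.reciprocal(scale_a), torch.reciprocal(scale_b),
    scale_result=scale_result,
    out_dtype=torch.float8_e4m3fn,
)
```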
Example Usage
Basic FP8 GEMM
# Simple matmul with per-tensor scaling
a_hp = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")  # Original high precision
b_hp = torch.randn(K, N, dtype=torch.bfloat16, device="cuda")

# Compute scales for quantization (kept in FP32)
fp8_max = torch.finfo(torch.float8_e4m3fn).max
scale_a = fp8_max / a_hp.abs().max().float()
scale_b = fp8_max / b_hp.abs().max().float()

# Quantize to FP8
a_fp8 = (a_hp * scale_a).to(torch.float8_e4m3fn)
b_fp8 = (b_hp * scale_b).to(torch.float8_e4m3fn)

# Pass reciprocal scales for dequantization in scaled_addmm
output = torch.scaled_addmm(
    None, a_fp8, b_fp8,
    torch.reciprocal(scale_a),
    torch.reciprocal(scale_b),
    out_dtype=torch.bfloat16,
)
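As a sanity check, the quantize → scaled-matmul → high-precision-output round trip should land close to the bf16 reference (a loose bound, since per-tensor FP8 error grows with K):

```python
# Compare against the high-precision reference
reference = a_hp @ b_hp
rel_err = (output - reference).abs().max() / reference.abs().max()
assert rel_err < 0.1  # loose tolerance for per-tensor FP8 quantization
```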
Migration Path
From _scaled_mm to Public API
# Old
out = torch._scaled_mm(mat1, mat2, scale_a, scale_b)
# New
out = torch.scaled_addmm(None, mat1, mat2, scale_mat1, scale_mat2)
# Or
out = torch.addmm(None, mat1, mat2, scale_mat1=scale_mat1, scale_mat2=scale_mat2)
Sticky Points → Looking for feedback
use_fast_accum
The `use_fast_accum` flag offers 5-10% performance gains over the slow-accumulation variant. Due to the reduced precision of this accumulation mode, training recipes have tended to apply it selectively (e.g., in the forward pass but not the backward), or not use it at all. I don't know if we should bring this flag into the public APIs; maybe we can have a global setting that enables this feature.
One option could be to follow PyTorch's existing patterns for hardware-specific optimizations:
# Similar to torch.backends.cuda.matmul.allow_tf32
torch.backends.cuda.matmul.allow_fp8_fast_accum = False # Default
This would keep function signatures clean while still allowing users to opt in.
Explicit scale enum versus shape inference
The above proposed API infers the type of scaling from the shape of the input scales (and ultimately we use this info to choose between a set of kernels).
An alternative implementation is to take a belt-and-suspenders approach and also add an enum through which users manually specify the scaling they intend. I don't love this: it feels unintuitive and can lead to situations where users end up with inconsistent input states. Also, what if someone wants group-wise scaling for scale1 and block-wise for scale2? Note that I think every scaling format is effectively block-wise.
There is one outstanding counterpoint that makes this annoying. For MX inputs on B200 we need to swizzle the scales into a specific layout to work with the tcgen05 MMA atom. The specifics of the layout can be found here: https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout and a PyTorch swizzle implementation is here. Because we need to rearrange the scales, we can no longer directly infer which scale maps to which data element in the tensor.
We could perhaps just make an exception and establish a shape convention:
# Regular MX scales: [M, K//32]
# Pre-swizzled and padded MX scales: [32*ceil_div(M,128) * 16*ceil_div((K//32),4)]
Note that on AMD the scales do not need to be swizzled so the shape inference still works on that platform.
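For concreteness, a tiny helper computing the padded, pre-swizzled scale length from the convention above (hypothetical names; the formula just restates the comment):

```python
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

def swizzled_mx_scale_numel(M: int, K: int) -> int:
    # Flat element count of pre-swizzled, padded MX scales: 32*ceil(M/128) * 16*ceil((K//32)/4)
    return 32 * ceil_div(M, 128) * 16 * ceil_div(K // 32, 4)

# e.g. M = K = 4096: regular MX scales are [4096, 128] (524,288 elements);
# the swizzled, padded flat length is 32*32 * 16*32 = 524,288 here (no padding needed).
print(swizzled_mx_scale_numel(4096, 4096))
```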
Autograd support
We don't intend to specify end-to-end workflows; instead we expect these primitives to be used within specific flows. This means that these new GEMMs will not define backward formulas. Instead, we expect people to use them in conjunction with tensor subclasses and/or autograd functions to enable training.
Examples:
Torchao: https://github.com/pytorch/ao/tree/main/torchao/float8
Lingua: https://github.com/facebookresearch/lingua/blob/437d680e521873bb5971067148a69587790da853/lingua/float8.py#L83
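As a rough illustration of that composition, a training wrapper might look like the sketch below (assumptions: per-tensor dynamic scaling, backward kept entirely in the saved high-precision inputs for brevity; real recipes such as torchao's float8 handle gradients in FP8 with their own scaling):

```python
import torch

def _quantize_per_tensor(t: torch.Tensor):
    # Dynamic per-tensor scale; returns the FP8 tensor and the reciprocal (dequant) scale
    scale = torch.finfo(torch.float8_e4m3fn).max / t.abs().max().float()
    return (t * scale).to(torch.float8_e4m3fn), torch.reciprocal(scale)

class Float8Matmul(torch.autograd.Function):
    """Sketch: FP8 forward via torch._scaled_mm, plain high-precision backward."""

    @staticmethod
    def forward(ctx, a: torch.Tensor, b: torch.Tensor):
        ctx.save_for_backward(a, b)
        a_fp8, inv_scale_a = _quantize_per_tensor(a)
        b_fp8, inv_scale_b = _quantize_per_tensor(b.t().contiguous())  # row-major [N, K]
        return torch._scaled_mm(
            a_fp8, b_fp8.t(),              # mat2 passed column-major
            inv_scale_a, inv_scale_b,
            out_dtype=torch.bfloat16,
        )

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        a, b = ctx.saved_tensors
        return grad_out @ b.t(), a.t() @ grad_out
```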
Tensor Data Swizzling
During the decode phase, GEMMs are often memory bound rather than compute bound. In these cases, pre-swizzling low-precision data tensors can provide significant speedups.
The core issue is that PyTorch stores weights in row-major format, but tensor cores require a specific register layout. For standard data types (FP16/FP32), NVIDIA provides the `ldmatrix` instruction for efficient shuffling. However, for 4-bit types no such instruction exists, forcing kernels to use multiple inefficient 8-bit shared memory loads.
Pre-swizzling solves this by reordering data ahead of time, enabling:
- Single 32-bit loads instead of four 8-bit loads
- Up to 128-bit loads by interleaving multiple tiles
- Better shared memory bandwidth utilization
Examples: Machete kernel
The current recommendation is to not support pre-swizzled data tensors in the public API. This seems overly conservative given the performance benefits, but it keeps the API simpler by avoiding another layout dimension.
Layout Support
Most formats require that the inputs be in TN format (RowMajor mat1, ColMajor mat2). This requirement has been lifted for FP8 + per-tensor inputs on B200 but still remains for H100 and the MX dtypes. This can be quite restrictive and maybe counter-intuitive for PyTorch users. Just wanted to call it out.
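In PyTorch terms, satisfying the TN requirement usually means materializing `mat2` column-major before the call (a small illustrative snippet):

```python
# mat2 starts row-major [K, N]; make it column-major without changing its logical shape
mat2_cm = mat2.t().contiguous().t()
assert mat2_cm.shape == mat2.shape and mat2_cm.stride() == (1, mat2.shape[0])
```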
Future arguments
What happens if we need to add new arguments in the future? NVFP4 is a good example of this challenge. NVFP4 uses a two-level scaling strategy - the first level uses E4M3 FP8 scales for each 16-value micro-block, and the second level applies an FP32 scale at the tensor level. This creates a scaling hierarchy that doesn't fit cleanly into our current `scale_mat1`/`scale_mat2` setup.
For NVFP4 inputs and higher-precision outputs, we can probably reuse `scale_result` for the second-level scaling. But if we want to do NVFP4 → NVFP4, we'd need additional `scale_d_in` and `scale_d_out` parameters.
# This would work for NVFP4 → FP16/BF16
output = torch.scaled_addmm(
    None, mat1_nvfp4, mat2_nvfp4,
    scale_mat1, scale_mat2,
    scale_result=per_tensor_scale,   # Reuse for second-level scaling
    out_dtype=torch.bfloat16,
)

# But NVFP4 → NVFP4 would need new args
output = torch.scaled_addmm(
    None, mat1_nvfp4, mat2_nvfp4,
    scale_mat1, scale_mat2,
    scale_d_in=scale_d_in,           # NEW
    scale_d_out=scale_d_out,         # NEW
    out_dtype=torch.float4_e2m1fn_x2,
)
This is where Approach 1 (dedicated functions) might be more future-proof. Adding new parameters to `torch.scaled_addmm` is cleaner than polluting core PyTorch functions like `torch.addmm` with an ever-growing list of format-specific scaling parameters.
References
Technical Reports and Documentation
- DeepSeek-V3 Technical Report
- Llama 4 Multimodal Intelligence
- NVIDIA H100 FP8 Formats
- OCP MX Specification
- NVIDIA NVFP4 Introduction
- NVIDIA cuBLAS Block Scaling Layout
Hardware Specifications and Benchmarks
- NVIDIA H100 Datasheet
- NVIDIA MLPerf v4.1 Results
- NVIDIA Blackwell Architecture
- AMD MI300X Specifications
- AMD ROCm Performance Analysis
- AMD Roadmap Announcement
- SemiAnalysis AMD MI350X Analysis
Implementation Examples and Code
- PyTorch Shell Dtypes Issue
- _scaled_mm Gist Example
- PyTorch AO MX Formats Utils
- TorchAO Float8 Implementation
- Lingua Float8 Implementation
- Machete Mixed-Input GEMM Kernel
PyTorch Evolution
cc @ptrblck @msaroufim @eqy @jerryzh168 @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @naromero77amd