Summary
In CUDA 12.9, cuBLAS released support for an expanded set of scaling strategies beyond just per-tensor: https://developer.nvidia.com/blog/boosting-matrix-multiplication-speed-and-flexibility-with-nvidia-cublas-12-9/
Currently on CUDA:
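For context, these strategies differ only in the granularity of the scale tensors. A minimal illustration of the scale shapes for an (M, K) fp8 operand (the 1x128 and 128x128 granularities are the ones the blog post describes; this is shape bookkeeping only, not a specific API):

```python
import torch

M, K = 512, 1024
scale_per_tensor = torch.ones(())                  # one scale for the whole tensor
scale_per_row    = torch.ones(M, 1)                # one scale per row
scale_groupwise  = torch.ones(M, K // 128)         # one scale per 1x128 group
scale_blockwise  = torch.ones(M // 128, K // 128)  # one scale per 128x128 block
```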
SM89
_scaled_mm dispatches to one of two backends on SM89:
- Per-Tensor scaling -> cuBLASLt
- Per-Row scaling -> RowWise CUTLASS kernel
- GroupWise scaling -> Not supported | some support in AO
- BlockWise scaling -> Not supported | some support in AO
H100
_scaled_mm dispatches to one of two backends on H100 (a usage sketch follows this list):
- Per-Tensor scaling -> cuBLASLt
- Per-Row scaling -> RowWise CUTLASS kernel
- GroupWise scaling -> Not supported | some support in AO
- BlockWise scaling -> Not supported | some support in AO
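A minimal sketch of the two code paths that exist today, assuming an fp8-capable GPU; shapes and dtypes follow the public torch._scaled_mm signature:

```python
import torch

M, K, N = 128, 256, 64
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
# _scaled_mm expects the second operand in column-major layout.
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()

# Per-tensor scaling: scalar fp32 scales -> dispatches to cuBLASLt.
out = torch._scaled_mm(
    a, b,
    scale_a=torch.tensor(1.0, device="cuda"),
    scale_b=torch.tensor(1.0, device="cuda"),
    out_dtype=torch.bfloat16,
)

# Per-row scaling: (M, 1) / (1, N) fp32 scales -> dispatches to the
# RowWise CUTLASS kernel.
out = torch._scaled_mm(
    a, b,
    scale_a=torch.ones(M, 1, device="cuda"),
    scale_b=torch.ones(1, N, device="cuda"),
    out_dtype=torch.bfloat16,
)
```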
B200
_scaled_mm dispatches to one of two backends on B200:
- Per-Tensor scaling -> cuBLASLt
- Per-Row scaling -> RowWise CUTLASS kernel (template is not optimal)
- GroupWise scaling -> MXFP8 block scaling is supported via cuBLASLt (see the scale-construction sketch after this list)
- BlockWise scaling -> Not supported
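For reference, the MX recipe scales each 32-element group along K by a power of two. A rough sketch of producing such group-wise scales (the real cuBLASLt path additionally requires an e8m0 scale dtype and a blocked scale layout, which are glossed over here):

```python
import torch

M, K = 128, 256
a = torch.randn(M, K)
groups = a.view(M, K // 32, 32)
amax = groups.abs().amax(dim=-1, keepdim=True).clamp(min=2**-126)
# MX (e8m0) scales are pure powers of two; 448 is the fp8 e4m3 max.
scales = torch.exp2(torch.ceil(torch.log2(amax / 448.0)))
a_fp8 = (groups / scales).to(torch.float8_e4m3fn).view(M, K)
```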
We should add new cuBLAS bindings to enable this more performant code path.
Blockers
Ideally we would remove the CUTLASS templates entirely, since the cuBLAS kernels appear, per NVIDIA's claims, to be universally more performant. The main blocker is that we would lose support for SM89 hardware.
We also don't currently ship a prebuilt version of PyTorch for CUDA 12.9.
cc @ptrblck @msaroufim @eqy @jerryzh168 @yanbing-j @vkuzo @albanD @kadeng @penguinwu @ngimel, @lw