🚀 The feature, motivation and pitch
NVLink SHARP is an engine in NVSwitch that can perform collectives (e.g. all-reduce).
This feature reduces GPU SM consumption by as much as 6x (from 24 SMs down to 4), while boosting performance by 2x (its mechanism is similar to a one-shot all-reduce, hence the 2x theoretical speedup).
To leverage this feature, please see this doc:
https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/bufferreg.html#nvlink-sharp-buffer-registration
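As a rough illustration (not PyTorch code), the registration flow described in that doc looks like the following at the NCCL level. This is a minimal sketch: the communicator is assumed to be created elsewhere (e.g. via ncclCommInitRank), NCCL >= 2.19 is assumed, and error checking is omitted.

```cpp
#include <cuda_runtime.h>
#include <nccl.h>

// Sketch of the NVLink SHARP user-buffer-registration flow from the NCCL doc.
// Assumes `comm` and `stream` were set up elsewhere.
void allreduce_with_registered_buffers(ncclComm_t comm, cudaStream_t stream,
                                       size_t count) {
  void* sendbuf = nullptr;
  void* recvbuf = nullptr;

  // Buffers must come from ncclMemAlloc (not cudaMalloc) to be eligible
  // for NVLink SHARP offload.
  ncclMemAlloc(&sendbuf, count * sizeof(float));
  ncclMemAlloc(&recvbuf, count * sizeof(float));

  // Register the buffers with the communicator once, then reuse them.
  void* sendHandle = nullptr;
  void* recvHandle = nullptr;
  ncclCommRegister(comm, sendbuf, count * sizeof(float), &sendHandle);
  ncclCommRegister(comm, recvbuf, count * sizeof(float), &recvHandle);

  // Collectives on registered ncclMemAlloc buffers can be offloaded to the
  // NVSwitch SHARP engine, freeing GPU SMs.
  ncclAllReduce(sendbuf, recvbuf, count, ncclFloat, ncclSum, comm, stream);
  cudaStreamSynchronize(stream);

  // Deregister before freeing.
  ncclCommDeregister(comm, sendHandle);
  ncclCommDeregister(comm, recvHandle);
  ncclMemFree(sendbuf);
  ncclMemFree(recvbuf);
}
```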
As shown in the sketch above, it requires the input / output buffers to be allocated through a NCCL API -- ncclMemAlloc. Such allocation is now enabled on the PyTorch side by a stack of PRs that allow the CUDACachingAllocator to use different memory allocation backends. See:
original RFC: #124807 and
PR impl: #133603.
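One possible way to hook this up -- a minimal sketch assuming the pluggable-allocator path, rather than whatever backend integration the PRs above ultimately expose -- is a small shared library that forwards the allocator's alloc/free hooks to ncclMemAlloc / ncclMemFree. The library name (libnccl_alloc.so) and function names below are hypothetical, chosen for illustration:

```cpp
// Sketch of a pluggable-allocator shim that routes CUDA allocations through
// ncclMemAlloc / ncclMemFree. Compile into a shared library, e.g.:
//   nvcc -shared -Xcompiler -fPIC nccl_alloc.cpp -lnccl -o libnccl_alloc.so
// The extern "C" alloc/free signatures follow the hooks expected by
// torch.cuda.memory.CUDAPluggableAllocator.
#include <cuda_runtime.h>
#include <nccl.h>
#include <sys/types.h>

extern "C" {

void* nccl_alloc(ssize_t size, int device, cudaStream_t stream) {
  void* ptr = nullptr;
  int prev_device = 0;
  cudaGetDevice(&prev_device);
  cudaSetDevice(device);
  // Allocate through NCCL so the buffer is eligible for NVLink SHARP
  // registration later (e.g. via ncclCommRegister inside ProcessGroupNCCL).
  ncclMemAlloc(&ptr, static_cast<size_t>(size));
  cudaSetDevice(prev_device);
  return ptr;
}

void nccl_free(void* ptr, ssize_t size, int device, cudaStream_t stream) {
  ncclMemFree(ptr);
}

}  // extern "C"
```

On the Python side, such a library could then be loaded with torch.cuda.memory.CUDAPluggableAllocator("libnccl_alloc.so", "nccl_alloc", "nccl_free") and either installed globally via torch.cuda.memory.change_current_allocator(...) or scoped to specific allocations through the backend/pool mechanism introduced by the PRs above.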
Target use
A first target for the feature could be DDP (in cases where we manage the gradient buckets internally).
A second target would be TP -- for example, "async-tp", though we'd need to confirm whether "async-tp" actually performs an all-reduce. Otherwise, if "general" TP is in Inductor's hands, we could ask Inductor to allocate dedicated memory for the result of the matmul.
Cc: @syed-ahmed
Alternatives
No response
Additional context
No response
cc @XilunWu @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @ptrblck @msaroufim