Edit: As discussed below, a log-matmul kernel seems more acceptable.
🚀 Feature
Implement GPU/CPU kernels for efficient computation of the entropy-regularized Wasserstein distance (Sinkhorn algorithm).
Motivation
The regularized Wasserstein distance between samples is a frequently used metric in deep learning models.
One intuitive way to think about it is as an alternative to the KL divergence for measuring discrepancies between probability measures, one that is aware of distances in the underlying space.
Since Cuturi (2013) and Frogner et al. (2015), it has developed into a very useful general-purpose tool. Notable recent applications include SwAV, but it is also of interest more generally for comparing probability distributions or interpolating between them.
One key computational challenge is that a stable implementation requires (more or less) a matrix multiplication with logsumexp reductions in place of the usual dot products, and this is much easier to achieve with a custom kernel. (Logsumexp is also much more stable than factoring out large scalars, which has been proposed as a makeshift stabilization.)
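For reference, a minimal sketch of such a log-space matmul in plain PyTorch (the name `log_matmul` is mine, not an existing API); the broadcasted `(n, k, m)` intermediate it materializes is exactly the memory cost a fused kernel would avoid:

```python
import torch

def log_matmul(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    # Computes log(exp(A) @ exp(B)) stably:
    #   out[i, j] = logsumexp_k(A[i, k] + B[k, j])
    # The broadcast below materializes an (n, k, m) intermediate,
    # which a fused kernel would not need to allocate.
    return torch.logsumexp(A.unsqueeze(-1) + B.unsqueeze(-3), dim=-2)
```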
Pitch
I propose to implement a Sinkhorn kernel; a sketch of the computation it would fuse follows.
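To illustrate what the kernel would accelerate, here is a rough sketch of log-domain Sinkhorn iterations in plain PyTorch (the name `sinkhorn_log`, the default `eps`, and the fixed iteration count are illustrative assumptions, not a proposed API):

```python
import torch

def sinkhorn_log(C, mu, nu, eps=0.1, n_iter=100):
    # Entropy-regularized OT in the log domain (illustrative sketch).
    # C: (n, m) cost matrix; mu: (n,), nu: (m,) marginals summing to 1.
    M = -C / eps                       # log of the Gibbs kernel K = exp(-C/eps)
    log_mu, log_nu = mu.log(), nu.log()
    log_u = torch.zeros_like(mu)       # log scaling vectors, initialized to 0
    log_v = torch.zeros_like(nu)
    for _ in range(n_iter):
        # Each update is a logsumexp "matmul" against M along one axis;
        # this is the part a fused kernel would make cheap.
        log_u = log_mu - torch.logsumexp(M + log_v[None, :], dim=1)
        log_v = log_nu - torch.logsumexp(M + log_u[:, None], dim=0)
    log_P = log_u[:, None] + M + log_v[None, :]  # log of the transport plan
    return (log_P.exp() * C).sum()               # transport cost <P, C>

# Example usage with a squared-Euclidean cost between random point clouds:
mu = torch.full((5,), 0.2)
nu = torch.full((7,), 1.0 / 7)
C = torch.cdist(torch.randn(5, 2), torch.randn(7, 2)) ** 2
dist = sinkhorn_log(C, mu, nu)
```

Note that each iteration touches the full `(n, m)` kernel matrix twice, so the logsumexp reductions dominate the runtime, which is why fusing them matters.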
Alternatives
Implement it in a specialized library; however, I think the applications are numerous and general enough to warrant inclusion in PyTorch proper.
Additional context
Related: #51777