Source code for geomfum.wrap.pot
"""Python Optimal Trasport wrapper."""
import geomstats.backend as gs
import ot
import geomfum.backend as xgs
from geomfum.convert import BaseNeighborFinder
[docs]
class PotSinkhornNeighborFinder(BaseNeighborFinder):
"""Neighbor finder based on Optimal Transport maps computed with Sinkhorn regularization.
Parameters
----------
n_neighbors : int, default=1
Number of neighbors to find.
lambd : float, default=1e-1
Regularization parameter for Sinkhorn algorithm.
method : str, default="sinkhorn"
Method to use for Sinkhorn algorithm.
max_iter : int, default=100
Maximum number of iterations for Sinkhorn algorithm.
References
----------
.. [Cuturi2013] Marco Cuturi. "Sinkhorn Distances: Lightspeed Computation of Optimal Transport."
Advances in Neural Information Processing Systems (NIPS), 2013.
http://marcocuturi.net/SI.html
"""
def __init__(self, n_neighbors=1, lambd=1e-1, method="sinkhorn", max_iter=100):
super().__init__(n_neighbors=n_neighbors)
self.lambd = lambd
self.max_iter = max_iter
self.method = method
self.X_ = None
[docs]
def fit(self, X):
"""Store the reference points.
Parameters
----------
X : array-like, shape=[n_points_x, n_features]
Reference points.
"""
self.X_ = X
return self
[docs]
def kneighbors(self, X, return_distance=True):
"""Find k nearest neighbors using Sinkhorn regularization.
Parameters
----------
X : array-like, shape=[n_points_y, n_features]
Query points.
return_distance : bool
Whether to return the distances.
Returns
-------
distances : array-like, shape=[n_points_y, n_neighbors]
Distances to the nearest neighbors.
indices : array-like, shape=[n_points_y, n_neighbors]
Indices of the nearest neighbors.
"""
M = gs.exp(-self.lambd * ot.dist(X, self.X_))
n, m = M.shape
a = gs.ones(n) / n
b = gs.ones(m) / m
# TODO: implement as sinkhorn solver?
Gs = ot.sinkhorn(a, b, M, self.lambd, self.method, self.max_iter)
indices = xgs.argsort(Gs, axis=1)[:, : self.n_neighbors]
if not return_distance:
return indices
distances = gs.array([M[i, indices[i]] for i in range(X.shape[0])])
return distances, indices