"""Conversion between pointwise and functional maps."""
import abc
import geomstats.backend as gs
import scipy
import torch
from sklearn.neighbors import NearestNeighbors
import geomfum.backend as xgs
import geomfum.wrap as _wrap # noqa (for register)
from geomfum._registry import (
SinkhornNeighborFinderRegistry,
WhichRegistryMixins,
)
from geomfum.neural_adjoint_map import NeuralAdjointMap
[docs]
class BaseP2pFromFmConverter(abc.ABC):
"""Pointwise map from functional map."""
[docs]
class P2pFromFmConverter(BaseP2pFromFmConverter):
"""Pointwise map from functional map.
Parameters
----------
neighbor_finder : NeighborFinder
Nearest neighbor finder.
adjoint : bool
Whether to use adjoint method.
bijective : bool
Whether to use bijective method. Check [VM2023]_.
References
----------
.. [OCSBG2012] Maks Ovsjanikov, Mirela Ben-Chen, Justin Solomon,
Adrian Butscher, and Leonidas Guibas.
“Functional Maps: A Flexible Representation of Maps between
Shapes.” ACM Transactions on Graphics 31, no. 4 (2012): 30:1-30:11.
https://doi.org/10.1145/2185520.2185526.
.. [VM2023] Giulio Viganò Simone Melzi. “Adjoint Bijective ZoomOut:
Efficient Upsampling for Learned Linearly-Invariant Embedding.”
The Eurographics Association, 2023. https://doi.org/10.2312/stag.20231293.
"""
def __init__(self, neighbor_finder=None, adjoint=False, bijective=False):
if neighbor_finder is None:
neighbor_finder = NearestNeighbors(
n_neighbors=1, leaf_size=40, algorithm="kd_tree", n_jobs=1
)
if neighbor_finder.n_neighbors > 1:
raise ValueError("Expects `n_neighors = 1`.")
self.neighbor_finder = neighbor_finder
self.adjoint = adjoint
self.bijective = bijective
def __call__(self, fmap_matrix, basis_a, basis_b):
"""Convert functional map.
Parameters
----------
fmap_matrix : array-like, shape=[spectrum_size_b, spectrum_size_a]
Functional map matrix.
basis_a : Basis,
Basis of the source shape.
basis_b : Basis,
Basis of the target shape.
Returns
-------
p2p : array-like, shape=[{n_vertices_b, n_vertices_a}]
Pointwise map. ``bijective`` controls shape.
"""
k2, k1 = fmap_matrix.shape
if self.adjoint:
emb1 = basis_a.full_vecs[:, :k1]
emb2 = basis_b.full_vecs[:, :k2] @ fmap_matrix
else:
emb1 = basis_a.full_vecs[:, :k1] @ fmap_matrix.T
emb2 = basis_b.full_vecs[:, :k2]
if self.bijective:
emb1, emb2 = emb2, emb1
# TODO: update neighbor finder instead
self.neighbor_finder.fit(xgs.to_device(emb1, "cpu"))
p2p_21 = self.neighbor_finder.kneighbors(
xgs.to_device(emb2, "cpu"), return_distance=False
)
return gs.from_numpy(p2p_21[:, 0])
[docs]
class BaseNeighborFinder(abc.ABC):
"""Base class for a Neighbor finder.
A simplified blueprint of ``sklearn.NearestNeighbors`` implementation.
Parameters
----------
n_neighbors : int
Number of neighbors.
"""
def __init__(self, n_neighbors=1):
self.n_neighbors = n_neighbors
[docs]
@abc.abstractmethod
def fit(self, X, y=None):
"""Store the reference points.
Parameters
----------
X : array-like, shape=[n_points_x, n_features]
Reference points.
y : Ignored
"""
[docs]
@abc.abstractmethod
def kneighbors(self, X, return_distance=True):
"""Find k nearest neighbors using Sinkhorn regularization.
Parameters
----------
X : array-like, shape=[n_points_y, n_features]
Query points.
return_distance : bool
Whether to return the distances.
Returns
-------
distances : array-like, shape=[n_points_y, n_neighbors]
Distances to the nearest neighbors, only present if
``return_distance is True``.
indices : array-like, shape=[n_points_y, n_neighbors]
Indices of the nearest neighbors.
"""
[docs]
class SinkhornNeighborFinder(WhichRegistryMixins):
"""Sinkhorn neighbor finder.
Finds neighbors based on the solution of optimal transport (OT) maps
computed with Sinkhorn regularization.
References
----------
.. [Cuturi2013] Marco Cuturi. “Sinkhorn Distances: Lightspeed Computation
of Optimal Transport.”
Advances in Neural Information Processing Systems (NIPS), 2013.
http://marcocuturi.net/SI.html
"""
_Registry = SinkhornNeighborFinderRegistry
[docs]
class SinkhornP2pFromFmConverter(P2pFromFmConverter):
"""Pointwise map from functional map using Sinkhorn filters.
Parameters
----------
neighbor_finder : SinkhornKNeighborsFinder
Nearest neighbor finder.
adjoint : bool
Whether to use adjoint method.
bijective : bool
Whether to use bijective method. Check [VM2023]_.
References
----------
.. [PRMWO2021] Gautam Pai, Jing Ren, Simone Melzi, Peter Wonka, and Maks Ovsjanikov.
"Fast Sinkhorn Filters: Using Matrix Scaling for Non-Rigid Shape Correspondence
with Functional Maps." Proceedings of the IEEE/CVF Conference on Computer Vision
and Pattern Recognition (CVPR), 2021, pp. 11956-11965.
https://hal.science/hal-03184936/document
"""
def __init__(
self,
neighbor_finder=None,
adjoint=False,
bijective=False,
):
if neighbor_finder is None:
neighbor_finder = SinkhornNeighborFinder.from_registry(which="pot")
super().__init__(
neighbor_finder=neighbor_finder,
adjoint=adjoint,
bijective=bijective,
)
[docs]
class BaseFmFromP2pConverter(abc.ABC):
"""Functional map from pointwise map."""
[docs]
class FmFromP2pConverter(BaseFmFromP2pConverter):
"""Functional map from pointwise map.
Parameters
----------
pseudo_inverse : bool
Whether to solve using pseudo-inverse.
"""
# TODO: add subsampling
def __init__(self, pseudo_inverse=False):
self.pseudo_inverse = pseudo_inverse
def __call__(self, p2p, basis_a, basis_b):
"""Convert point to point map.
Parameters
----------
p2p : array-like, shape=[n_vertices_b]
Poinwise map.
basis_a : Basis,
Basis of the source shape.
basis_b : Basis,
Basis of the target shape.
Returns
-------
fmap_matrix : array-like, shape=[spectrum_size_b, spectrum_size_a]
Functional map matrix.
"""
evects1_pb = basis_a.vecs[p2p, :]
if self.pseudo_inverse:
return basis_b.vecs.T @ (basis_b._shape.laplacian.mass_matrix @ evects1_pb)
return gs.from_numpy(scipy.linalg.lstsq(basis_b.vecs, evects1_pb)[0])
[docs]
class FmFromP2pBijectiveConverter(BaseFmFromP2pConverter):
"""Bijective functional map from pointwise map method.
References
----------
.. [VM2024] Giulio Viganò Simone Melzi. Bijective upsampling and learned embedding for point clouds correspondences.
Computers and Graphics, 2024. https://doi.org/10.1016/j.cag.2024.103985.
"""
def __call__(self, p2p, basis_a, basis_b):
"""Convert point to point map.
Parameters
----------
p2p : array-like, shape=[n_vertices_a]
Pointwise map.
basis_a : Basis,
Basis of the source shape.
basis_b : Basis,
Basis of the target shape.
Returns
-------
fmap_matrix : array-like, shape=[spectrum_size_b, spectrum_size_a]
Functional map matrix.
"""
evects2_pb = basis_b.vecs[p2p, :]
return gs.from_numpy(scipy.linalg.lstsq(evects2_pb, basis_a.vecs)[0])
[docs]
class NamFromP2pConverter(BaseFmFromP2pConverter):
"""Neural Adjoint Map from pointwise map using Neural Adjoint Maps (NAMs)."""
def __init__(self, iter_max=200, patience=10, min_delta=1e-4, device="cpu"):
"""Initialize the converter.
Parameters
----------
iter_max : int, optional
Maximum number of iterations for training the Neural Adjoint Map.
patience : int, optional
Number of iterations with no improvement after which training will be stopped.
min_delta : float, optional
Minimum change in the loss to qualify as an improvement.
device : str, optional
Device to use for the Neural Adjoint Map (e.g., 'cpu' or 'cuda').
"""
self.iter_max = iter_max
self.device = device
self.min_delta = min_delta
self.patience = patience
def __call__(self, p2p, basis_a, basis_b, optimizer=None):
"""Convert point to point map.
Parameters
----------
p2p : array-like, shape=[n_vertices_b]
Pointwise map.
basis_a : Basis,
Basis of the source shape.
basis_b : Basis,
Basis of the target shape.
optimizer : torch.optim.Optimizer, optional
Optimizer for training the Neural Adjoint Map.
Returns
-------
nam: NeuralAdjointMap , shape=[spectrum_size_b, spectrum_size_a]
Neural Adjoint Map model.
"""
evects2_pb = xgs.to_torch(basis_b.vecs[p2p, :]).to(self.device).double()
evects1 = xgs.to_torch(basis_a.vecs).to(self.device).double()
nam = NeuralAdjointMap(
input_dim=basis_a.spectrum_size,
output_dim=basis_b.spectrum_size,
device=self.device,
).double()
if optimizer is None:
optimizer = torch.optim.Adam(nam.parameters(), lr=0.01, weight_decay=1e-5)
best_loss = float("inf")
wait = 0
for _ in range(self.iter_max):
optimizer.zero_grad()
pred = nam(evects2_pb)
loss = torch.nn.functional.mse_loss(pred, evects1)
loss.backward()
optimizer.step()
if loss.item() < best_loss - self.min_delta:
best_loss = loss.item()
wait = 0
else:
wait += 1
if wait >= self.patience:
break
return nam
[docs]
class P2pFromNamConverter(BaseP2pFromFmConverter):
"""Pointwise map from Neural Adjoint Map (NAM).
Parameters
----------
neighbor_finder : NeighborFinder
Nearest neighbor finder.
"""
def __init__(self, neighbor_finder=None):
if neighbor_finder is None:
neighbor_finder = NearestNeighbors(
n_neighbors=1, leaf_size=40, algorithm="kd_tree", n_jobs=1
)
if neighbor_finder.n_neighbors > 1:
raise ValueError("Expects `n_neighors = 1`.")
self.neighbor_finder = neighbor_finder
def __call__(self, nam, basis_a, basis_b):
"""Convert neural adjoint map.
Parameters
----------
nam : NeuralAdjointMap, shape=[spectrum_size_b, spectrum_size_a]
Nam model.
basis_a : Basis,
Basis of the source shape.
basis_b : Basis,
Basis of the target shape.
Returns
-------
p2p : array-like, shape=[{n_vertices_b, n_vertices_a}]
Pointwise map.
"""
k2, k1 = nam.shape
emb1 = xgs.to_torch(basis_a.full_vecs[:, :k1]).to(nam.device).double()
emb2 = nam(xgs.to_torch(basis_b.full_vecs[:, :k2]).to(nam.device).double())
# TODO: update neighbor finder instead
self.neighbor_finder.fit(emb2.detach().cpu())
p2p_21 = self.neighbor_finder.kneighbors(
emb1.detach().cpu(), return_distance=False
)
return gs.from_numpy(p2p_21[:, 0])