Source code for geomfum.forward_functional_map

"""Optimization of the functional map with a forward pass."""

import abc

import geomstats.backend as gs
import torch
import torch.nn as nn

import geomfum.backend as xgs


class ForwardFunctionalMap(abc.ABC, nn.Module):
    """Class for the forward pass of the functional map.

    Parameters
    ----------
    lmbda : float
        Weight of the mask (default: 1e3). When it is 0 the mask is
        ignored and a plain least-squares problem is solved.
    resolvent_gamma : float
        Exponent applied to the normalized eigenvalues when building the
        resolvent mask (default: 1).
    bijective : bool
        Whether we compute the map in both directions (default: True).

    Notes
    -----
    The computation lives in ``__call__`` rather than ``forward``, so the
    hook machinery of ``torch.nn.Module.__call__`` is not run.
    """

    def __init__(self, lmbda=1e3, resolvent_gamma=1, bijective=True):
        super().__init__()
        self.lmbda = lmbda
        self.resolvent_gamma = resolvent_gamma
        self.bijective = bijective

    def _compute_functional_map(self, sdescr_a, sdescr_b, mask):
        """Compute the functional map between two shapes.

        For every row ``i`` of ``mask`` the system
        ``(A^T A + lmbda * diag(mask[i])) x = (B^T A)[i]`` is solved through
        an explicit matrix inverse, with ``A = sdescr_a`` and
        ``B = sdescr_b``; the solutions are stacked as rows of the result.

        Parameters
        ----------
        sdescr_a : array-like, shape=[n_descr, spectrum_size_a]
            Spectral descriptors on first basis.
        sdescr_b : array-like, shape=[n_descr, spectrum_size_b]
            Spectral descriptors on second basis.
        mask : array-like, shape=[spectrum_size_b, spectrum_size_a]
            Mask for the functional map, as returned by ``_compute_mask``.

        Returns
        -------
        fmap : array-like, shape=[spectrum_size_b, spectrum_size_a]
            Functional map from shape a to shape b.
        """
        At_A = sdescr_a.T @ sdescr_a
        Bt_A = sdescr_b.T @ sdescr_a

        n_rows = mask.shape[0]
        unregularized = self.lmbda == 0

        # Without regularization the system matrix does not depend on the
        # row, so it is inverted once instead of once per iteration.
        inv_At_A = None
        if unregularized and n_rows > 0:
            inv_At_A = gs.linalg.inv(At_A)

        fmap_rows = []
        for i in range(n_rows):
            rhs = Bt_A[i, :].reshape(-1, 1)
            if unregularized:
                map_row = inv_At_A @ rhs
            else:
                # The i-th mask row acts as a diagonal penalty on row i.
                mask_i = xgs.diag(mask[i, :].flatten())
                map_row = gs.linalg.inv(At_A + self.lmbda * mask_i) @ rhs
            fmap_rows.append(map_row.T)

        return gs.concatenate(fmap_rows, 0)

    def __call__(self, mesh_a, mesh_b, descr_a, descr_b):
        """Compute the functional map between two shapes.

        Parameters
        ----------
        mesh_a : TriangleMesh
            Mesh object representing the first shape. Its ``basis`` must
            provide ``vals`` and ``project``.
        mesh_b : TriangleMesh
            Mesh object representing the second shape. Its ``basis`` must
            provide ``vals`` and ``project``.
        descr_a : array-like, shape=[D, ...]
            Descriptors on the first shape; they are projected onto the
            basis of ``mesh_a`` before use.
        descr_b : array-like, shape=[D, ...]
            Descriptors on the second shape; they are projected onto the
            basis of ``mesh_b`` before use.

        Returns
        -------
        fmap_12 : array-like, shape=[spectrum_size_b, spectrum_size_a]
            Functional map from shape a to shape b.
        fmap_21 : array-like, shape=[spectrum_size_a, spectrum_size_b] or None
            Functional map from shape b to shape a if bijective, otherwise
            None.
        """
        evals_a = mesh_a.basis.vals
        sdescr_a = mesh_a.basis.project(descr_a)

        evals_b = mesh_b.basis.vals
        sdescr_b = mesh_b.basis.project(descr_b)

        mask = self._compute_mask(evals_a, evals_b, self.resolvent_gamma)
        fmap_12 = self._compute_functional_map(sdescr_a, sdescr_b, mask)

        fmap_21 = None
        if self.bijective:
            # Same computation with the roles of the two shapes swapped.
            mask = self._compute_mask(evals_b, evals_a, self.resolvent_gamma)
            fmap_21 = self._compute_functional_map(sdescr_b, sdescr_a, mask)

        return fmap_12, fmap_21

    def _compute_mask(self, evals_a, evals_b, resolvant_gamma):
        """Compute the mask for the functional map.

        Both spectra are divided by their common maximum and raised to the
        power ``resolvant_gamma``, giving ``g_a`` and ``g_b``. The mask is
        ``M_re ** 2 + M_im ** 2`` with

        ``M_re = g_b / (g_b ** 2 + 1) - g_a / (g_a ** 2 + 1)`` and
        ``M_im = 1 / (g_b ** 2 + 1) - 1 / (g_a ** 2 + 1)``.

        Parameters
        ----------
        evals_a : array-like, shape=[spectrum_size_a]
            Eigenvalues of the first shape.
        evals_b : array-like, shape=[spectrum_size_b]
            Eigenvalues of the second shape.
        resolvant_gamma : float
            Exponent applied to the normalized eigenvalues.

        Returns
        -------
        mask : array-like, shape=[spectrum_size_b, spectrum_size_a]
            Mask for the functional map.
        """
        evals_a = gs.array(evals_a)
        evals_b = gs.array(evals_b)

        # Normalize both spectra by the largest eigenvalue of either shape.
        scaling_factor = max(max(evals_a), max(evals_b))
        evals_a, evals_b = evals_a / scaling_factor, evals_b / scaling_factor

        # Row vector for a, column vector for b: every expression below
        # broadcasts to [spectrum_size_b, spectrum_size_a].
        evals_gamma_a = gs.power(evals_a, resolvant_gamma)[None, :]
        evals_gamma_b = gs.power(evals_b, resolvant_gamma)[:, None]

        denom_a = xgs.square(evals_gamma_a) + 1
        denom_b = xgs.square(evals_gamma_b) + 1

        M_re = evals_gamma_b / denom_b - evals_gamma_a / denom_a
        M_im = 1 / denom_b - 1 / denom_a

        return xgs.square(M_re) + xgs.square(M_im)