"""Spectral descriptors."""
import geomstats.backend as gs
import geomfum.backend as xgs
import geomfum.linalg as la
from geomfum._registry import (
HeatKernelSignatureRegistry,
WaveKernelSignatureRegistry,
WhichRegistryMixins,
)
from ._base import SpectralDescriptor
[docs]
def hks_default_domain(shape, n_domain):
    """Compute HKS default domain.

    The time points are produced by ``xgs.geomspace`` between
    ``4 * log(10) / nonzero_vals[-1]`` and ``4 * log(10) / nonzero_vals[0]``,
    where ``nonzero_vals`` is read from ``shape.basis``.

    Parameters
    ----------
    shape : Shape.
        Shape with basis.
    n_domain : int
        Number of time points.

    Returns
    -------
    domain : array-like, shape=[n_domain]
        Time points, passed through ``xgs.to_device`` with the ``device``
        attribute of ``nonzero_vals`` (``None`` when it has none).
    """
    eigvals = shape.basis.nonzero_vals

    # Both bounds share the same numerator; only the eigenvalue differs.
    numerator = 4 * gs.log(10)
    t_first = numerator / eigvals[-1]
    t_last = numerator / eigvals[0]

    times = xgs.geomspace(t_first, t_last, n_domain)
    return xgs.to_device(times, getattr(eigvals, "device", None))
[docs]
class WksDefaultDomain:
    """Compute WKS domain.

    Instances are callables: ``instance(shape)`` returns the energy
    points together with the standard deviation that was used.

    Parameters
    ----------
    n_domain : int
        Number of energy points to use.
    sigma : float
        Standard deviation. If None, it is computed from ``n_overlap``,
        ``n_domain`` and the log-range of ``shape.basis.nonzero_vals``.
    n_overlap : int
        Controls Gaussian overlap. Ignored if ``sigma`` is not None.
    n_trans : int
        Number of standard deviations to translate energy bound by.
    """

    def __init__(self, n_domain, sigma=None, n_overlap=7, n_trans=2):
        self.n_domain = n_domain
        self.sigma = sigma
        self.n_overlap = n_overlap
        self.n_trans = n_trans

    def __call__(self, shape):
        """Compute WKS domain.

        Parameters
        ----------
        shape : Shape.
            Shape with basis.

        Returns
        -------
        domain : array-like, shape=[n_domain]
            Energy points, linearly spaced between the translated bounds.
        sigma : float
            Standard deviation.
        """
        eigvals = shape.basis.nonzero_vals
        device = getattr(eigvals, "device", None)

        # Energy bounds are the logs of the first and last nonzero values.
        log_low = gs.log(eigvals[0])
        log_high = gs.log(eigvals[-1])

        if self.sigma is None:
            sigma = self.n_overlap * (log_high - log_low) / self.n_domain
        else:
            sigma = self.sigma

        # Pull both bounds inwards by ``n_trans`` standard deviations.
        shift = self.n_trans * sigma
        log_low += shift
        log_high -= shift

        energies = gs.linspace(log_low, log_high, self.n_domain)
        return xgs.to_device(energies, device), sigma
[docs]
class HeatKernelSignature(WhichRegistryMixins, SpectralDescriptor):
    """Heat kernel signature.

    Parameters
    ----------
    scale : bool
        Whether to scale weights to sum to one.
    n_domain : int
        Number of domain points. Ignored if ``domain`` is not None.
    domain : callable or array-like, shape=[n_domain]
        Method to compute domain points (``f(shape)``) or
        domain points. If None, ``hks_default_domain`` is used with
        ``n_domain`` points.
    """

    _Registry = HeatKernelSignatureRegistry

    def __init__(self, scale=True, n_domain=3, domain=None):
        # Test against None explicitly: ``domain or default`` would evaluate
        # the truth value of an array-like domain, which raises for arrays
        # with more than one element.
        if domain is None:
            domain = lambda shape: hks_default_domain(shape, n_domain=n_domain)  # noqa: E731

        super().__init__(domain, use_landmarks=False)
        self.scale = scale

    def __call__(self, shape):
        """Compute descriptor.

        Parameters
        ----------
        shape : Shape.
            Shape with basis.

        Returns
        -------
        descr : array-like, shape=[n_domain, n_vertices]
            Descriptor.
        """
        # A callable domain is evaluated per shape; otherwise it is used as is.
        domain = self.domain(shape) if callable(self.domain) else self.domain

        vals_term = gs.exp(-la.scalarvecmul(domain, shape.basis.vals))
        vecs_term = xgs.square(shape.basis.vecs)

        if self.scale:
            vals_term = la.scale_to_unit_sum(vals_term)

        # descr[..., i] = sum_j vals_term[..., j] * vecs_term[i, j]
        return gs.einsum("...j,ij->...i", vals_term, vecs_term)
[docs]
class WaveKernelSignature(WhichRegistryMixins, SpectralDescriptor):
    """Wave kernel signature.

    Parameters
    ----------
    scale : bool
        Whether to scale weights to sum to one.
    sigma : float
        Standard deviation. Forwarded to ``WksDefaultDomain`` when
        ``domain`` is None. Must be provided when ``domain`` is given as
        fixed domain points, since it is then used directly.
    n_domain : int
        Number of domain points. Ignored if ``domain`` is not None.
    domain : callable or array-like, shape=[n_domain]
        Method to compute domain points or domain points. A callable is
        called as ``f(shape)`` and must return a ``(domain, sigma)`` pair.
        If None, ``WksDefaultDomain`` is used.
    """

    _Registry = WaveKernelSignatureRegistry

    def __init__(self, scale=True, sigma=None, n_domain=3, domain=None):
        # Test against None explicitly: ``domain or default`` would evaluate
        # the truth value of an array-like domain, which raises for arrays
        # with more than one element.
        if domain is None:
            domain = WksDefaultDomain(n_domain=n_domain, sigma=sigma)

        super().__init__(domain, use_landmarks=False)
        self.scale = scale
        self.sigma = sigma

    def __call__(self, shape):
        """Compute descriptor.

        Parameters
        ----------
        shape : Shape.
            Shape with basis.

        Returns
        -------
        descr : array-like, shape=[n_domain, n_vertices]
            Descriptor.
        """
        if callable(self.domain):
            # A callable domain supplies both the energy points and the
            # standard deviation to use for them.
            domain, sigma = self.domain(shape)
        else:
            domain = self.domain
            sigma = self.sigma

        # Gaussian in log-eigenvalue space, centred at each domain point:
        # one row per domain point, one column per nonzero eigenvalue.
        exp_arg = -xgs.square(gs.log(shape.basis.nonzero_vals) - domain[:, None]) / (
            2 * sigma**2
        )
        vals_term = gs.exp(exp_arg)
        vecs_term = xgs.square(shape.basis.nonzero_vecs)

        if self.scale:
            vals_term = la.scale_to_unit_sum(vals_term)

        # descr[..., i] = sum_j vals_term[..., j] * vecs_term[i, j]
        return gs.einsum("...j,ij->...i", vals_term, vecs_term)