Source code for geomfum.descriptor.pipeline

"""Descriptor pipeline."""

import abc

import geomstats.backend as gs
import torch

import geomfum.linalg as la
from geomfum.descriptor.learned import LearnedDescriptor

from ._base import Descriptor


[docs] class Subsampler(abc.ABC): """Subsampler.""" @abc.abstractmethod def __call__(self, array): """Subsample array. Parameters ---------- array : array-like Array to subsample. Returns ------- array : array-like Subsampled array. """
[docs] class ArangeSubsampler(Subsampler): """Subsampler based on arange method. Parameters ---------- subsample_step : int Arange step. axis : int Axis from which to subsample. """ def __init__(self, subsample_step=1, axis=0): self.subsample_step = subsample_step self.axis = axis def __call__(self, array): """Subsample array based on arange method. Parameters ---------- array : array-like, shape=[..., n, ...] Array to subsample. Returns ------- array : array-like, shape=[..., d, ...] Subsampled array. """ indices = gs.arange(0, array.shape[self.axis], self.subsample_step) slc = [slice(None)] * array.ndim slc[self.axis] = indices return array[tuple(slc)]
[docs] class Normalizer(abc.ABC): """Normalizer.""" @abc.abstractmethod def __call__(self, shape, array): """Normalize array. Parameters ---------- shape : Shape Shape. array : array-like Array to normalize. Returns ------- array : array-like Normalized array. """
[docs] class L2InnerNormalizer(Normalizer): """L2 inner normalizer.""" def __call__(self, shape, array): """Normalize array with respect to L2 inner product. Parameters ---------- shape : Shape Shape. array : array-like, shape=[..., n] Array to normalize. Returns ------- array : array-like, shape=[..., n] Normalized array. """ coeff = gs.sqrt( gs.einsum( "...n,...n->...", array, la.matvecmul(shape.laplacian.mass_matrix, array), ), ) return la.scalarvecmul(1 / coeff, array)
[docs] class DescriptorPipeline: """Descriptor pipeline. Parameters ---------- steps : list or tuple Steps to apply. Include: descriptor, subsampler, normalizer. """ def __init__(self, steps): self.steps = steps def _update_descr(self, current, new): if current is None: return new return gs.vstack([current, new])
[docs] def apply(self, shape): """Apply descriptor pipeline. Parameters ---------- shape : Shape Shape to apply pipeline to. Returns ------- descr : array-like, shape=[..., n] Descriptor. """ descr = None for step in self.steps: if isinstance(step, Descriptor): if isinstance(step, LearnedDescriptor): with torch.no_grad(): new = step(shape) descr = self._update_descr(descr, new) else: descr = self._update_descr(descr, step(shape)) elif isinstance(step, Subsampler): descr = step(descr) elif isinstance(step, Normalizer): descr = step(shape, descr) else: raise ValueError("Unknown step type.") return descr