Source code for geomfum.descriptor.learned
"""Implementation of the learned descriptor.
The learned descriptor is a descriptor that uses a neural network to compute features.
"""
import abc
import geomstats.backend as gs
import torch
import torch.nn as nn
from geomfum._registry import FeatureExtractorRegistry, WhichRegistryMixins
from geomfum.descriptor._base import Descriptor
[docs]
class BaseFeatureExtractor(abc.ABC):
"""Base class for feature extractor."""
def __init__(self):
super().__init__()
[docs]
def load_from_path(self, path):
"""Load model parameters from the provided file path.
Parameters
----------
path : str
Path to the saved model parameters
"""
try:
self.model.load_state_dict(torch.load(path, map_location=self.device))
except FileNotFoundError as e:
raise FileNotFoundError(f"Model file not found: {path}") from e
except Exception as e:
raise ValueError(f"Failed to load model from {path}: {e}") from e
[docs]
def save(self, path):
"""Save model parameters to the specified file path.
Parameters
----------
path : str
Path to the saved model parameters
"""
torch.save(self.model.state_dict(), path)
[docs]
class FeatureExtractor(WhichRegistryMixins):
"""Feature extractor."""
_Registry = FeatureExtractorRegistry
[docs]
class LearnedDescriptor(Descriptor, abc.ABC, nn.Module):
"""Learned descriptor.
Parameters
----------
feature_extractor: Feature Extractor
Feature extractor to use.
"""
def __init__(self, feature_extractor=None):
super().__init__()
self.feature_extractor = feature_extractor
if self.feature_extractor is None:
self.feature_extractor = FeatureExtractor.from_registry(
which="diffusionnet"
)
[docs]
def forward(self, shape):
"""Compute descriptor.
Parameters
----------
shape : Shape.
Shape.
Returns
-------
features : array-like, shape=[..., n_features, n_vertices]
Descriptors of the shape, where `n_features` is the number of features extracted by the feature extractor.
"""
features = self.feature_extractor(shape)
features = gs.array(features.squeeze().double()).T
return features