Source code for geomfum.wrap.pyrmt
"""pyRMT wrapper."""
import geomstats.backend as gs
import numpy as np
from PyRMT import RMTMesh
import geomfum.linalg as la
from geomfum.shape.hierarchical import HierarchicalShape
from geomfum.shape.mesh import TriangleMesh
[docs]
class PyrmtHierarchicalMesh(HierarchicalShape):
"""Hierarchical mesh from PyRMT.
Based on [MBMR2023]_.
Parameters
----------
mesh : TriangleMesh
High-resolution mesh.
min_n_samples : int
Minimum number of vertices in low-resolution mesh.
References
----------
.. [MBMR2023] Filippo Maggioli, Daniele Baieri, Simone Melzi, and Emanuele Rodolà.
“ReMatching: Low-Resolution Representations for Scalable Shape
Correspondence.” arXiv, October 30, 2023.
https://doi.org/10.48550/arXiv.2305.09274.
"""
def __init__(self, mesh, min_n_samples):
if min_n_samples > mesh.n_vertices:
raise ValueError(
f"Number of samples ({min_n_samples}) is greater than number of"
f"vertices of high-resolution mesh ({mesh.n_vertices})"
)
low = self._remesh(mesh, min_n_samples)
super().__init__(low=low, high=mesh)
def _remesh(self, mesh, min_n_samples):
vertices = gs.to_numpy(mesh.vertices)
faces = gs.to_numpy(mesh.faces)
if vertices.dtype != np.float64:
vertices = vertices.astype(np.float64)
if not vertices.flags.f_contiguous:
vertices = np.asfortranarray(vertices, dtype=np.float64)
if faces.dtype != np.int32:
faces = faces.astype(np.int32)
if not faces.flags.f_contiguous:
faces = np.asfortranarray(faces, dtype=np.int32)
rhigh = RMTMesh(vertices, faces)
rhigh.make_manifold()
rlow = rhigh.remesh(min_n_samples)
rlow.clean_up()
self._rhigh = rhigh
self._rlow = rlow
self._baryc_map = rlow.baryc_map(vertices)
return TriangleMesh(gs.array(rlow.vertices), gs.array(rlow.triangles))
[docs]
def scalar_low_high(self, scalar):
"""Transfer scalar from low-resolution to high.
Parameters
----------
scalar : array-like, shape=[..., low.n_vertices]
Scalar map on the low-resolution shape.
Returns
-------
high_scalar : array-like, shape=[..., high.n_vertices]
Scalar map on the high-resolution shape.
"""
return gs.asarray(la.matvecmul(self._baryc_map, scalar))