# Source code for geomfum.metric.mesh

"""Module containing metrics to calcualte distances on a mesh."""

import abc

import geomstats.backend as gs
import networkx as nx
from scipy.sparse.csgraph import shortest_path
import geomfum.backend as xgs
import networkx as nx

from geomfum._registry import HeatDistanceMetricRegistry, WhichRegistryMixins
from geomfum.numerics.graph import single_source_partial_dijkstra_path_length


def to_nx_edge_graph(shape):
    """Convert a shape to a weighted networkx edge graph.

    Every vertex of ``shape`` becomes a node labelled by its integer index,
    including vertices that belong to no edge. Every edge of ``shape``
    becomes an undirected edge weighted by the Euclidean distance between
    its two endpoints.

    Parameters
    ----------
    shape : Shape
        Shape.

    Returns
    -------
    graph : networkx.Graph
        Weighted edge graph (edge attribute ``"weight"``).
    """
    # TODO: move to utils? circular imports
    vertex_a, vertex_b = shape.edges.T
    lengths = VertexEuclideanMetric(shape).dist(vertex_a, vertex_b)

    graph = nx.Graph()
    # Register every vertex up front. Otherwise a vertex without edges would
    # be absent from the graph, and lookups such as
    # ``nx.adjacency_matrix(graph, nodelist=range(n_vertices))`` or a
    # Dijkstra search starting from that vertex would raise.
    graph.add_nodes_from(range(shape.n_vertices))
    # ``tolist`` yields plain Python ints/floats as node labels and weights.
    graph.add_weighted_edges_from(
        zip(
            gs.to_numpy(vertex_a).tolist(),
            gs.to_numpy(vertex_b).tolist(),
            gs.to_numpy(lengths).tolist(),
        )
    )
    return graph
class Metric(abc.ABC):
    """Abstract distance function attached to a shape.

    Parameters
    ----------
    shape : Shape
        Shape regarded as the underlying manifold.
    """

    def __init__(self, shape):
        # Concrete metrics read the shape back through this attribute.
        self._shape = shape

    @abc.abstractmethod
    def dist(self, point_a, point_b):
        """Compute the distance between two (batches of) points.

        Parameters
        ----------
        point_a : array-like, shape=[...]
            Index of the first point.
        point_b : array-like, shape=[...]
            Index of the second point.

        Returns
        -------
        dist : array-like, shape=[...,]
            Distance between ``point_a`` and ``point_b``.
        """
class FinitePointSetMetric(Metric, abc.ABC):
    """Abstract metric on a finite set of points addressed by index.

    In addition to ``dist``, concrete subclasses must provide the full
    pairwise distance matrix and single-source distances.
    """

    @abc.abstractmethod
    def dist_matrix(self):
        """Compute distances between every pair of points of the shape.

        Returns
        -------
        dist_matrix : array-like, shape=[n_vertices, n_vertices]
            Pairwise distance matrix.
        """

    @abc.abstractmethod
    def dist_from_source(self, source_point):
        """Compute distances from one or several source points.

        Parameters
        ----------
        source_point : array-like, shape=[...]
            Index of the source point(s).

        Returns
        -------
        dist : array-like, shape=[...] or list-like[array-like]
            Distances to the targets.
        target_point : array-like, shape=[n_targets] or list-like[array-like]
            Indices of the targets the distances refer to.
        """
class VertexEuclideanMetric(FinitePointSetMetric):
    """Euclidean (straight-line) metric between vertices of a mesh."""

    def dist(self, point_a, point_b):
        """Distance between mesh vertices.

        Parameters
        ----------
        point_a : array-like, shape=[...]
            Index of source point.
        point_b : array-like, shape=[...]
            Index of target point.

        Returns
        -------
        dist : array-like, shape=[...]
            Distance.
        """
        vertices = self._shape.vertices
        diff = vertices[point_a] - vertices[point_b]
        # The last axis holds the coordinates; reduce over it.
        return gs.linalg.norm(diff, axis=diff.ndim - 1)

    def dist_from_source(self, source_point):
        """Distance from source point(s) to every vertex.

        Parameters
        ----------
        source_point : array-like, shape=[...]
            Index of source point.

        Returns
        -------
        dist : array-like, shape=[..., n_vertices]
            Distance from each source to every vertex.
        target_point : array-like, shape=[..., n_vertices]
            Target index (``arange(n_vertices)`` broadcast to ``dist``).
        """
        vertices = self._shape.vertices
        source_vertices = vertices[source_point]
        if source_vertices.ndim > 1:
            # Insert the target axis just before the coordinate axis, so that a
            # batch of sources of any rank, [..., dim], broadcasts against all
            # vertices, [n_vertices, dim], into [..., n_vertices, dim].
            source_vertices = gs.expand_dims(
                source_vertices, source_vertices.ndim - 1
            )
        diff = source_vertices - vertices
        dist = gs.linalg.norm(diff, axis=diff.ndim - 1)

        target_point = gs.arange(self._shape.n_vertices)
        if dist.ndim > 1:
            # Batched sources: repeat the target indices for each source.
            target_point = gs.broadcast_to(
                target_point, dist.shape[:-1] + target_point.shape
            )
        return dist, target_point

    def dist_matrix(self):
        """Distance between mesh vertices.

        Returns
        -------
        dist_matrix : array-like, shape=[n_vertices, n_vertices]
            Distance matrix.
        """
        return self.dist_from_source(gs.arange(self._shape.n_vertices))[0]
class _SingleDispatchMixins: def dist(self, point_a, point_b): """Distance between mesh vertices. Parameters ---------- point_a : array-like, shape=[...] Index of source point. point_b : array-like, shape=[...] Index of target point. Returns ------- dist : array-like, shape=[...,] Distance. """ point_a = gs.asarray(point_a) point_b = gs.asarray(point_b) if point_a.ndim == 0 and point_b.ndim == 0: return self._dist_single(point_a, point_b) point_a, point_b = gs.broadcast_arrays(point_a, point_b) return gs.stack( [ self._dist_single(point_a_, point_b_) for point_a_, point_b_ in zip(point_a, point_b) ] ) def dist_from_source(self, source_point): """Distance between mesh vertices. Parameters ---------- source_point : array-like, shape=[...] Index of source point. Returns ------- dist : array-like, shape=[...,] or list[array-like] Distance. target_point : array-like, shape=[n_targets,] or list[array-like] Target index. """ source_point = gs.asarray(source_point) if source_point.ndim == 0: return self._dist_from_source_single(source_point) out = [ self._dist_from_source_single(source_index_) for source_index_ in source_point ] return list(zip(*out)) @abc.abstractmethod def _dist_from_source_single(self, source_point): pass @abc.abstractmethod def _dist_single(self, point_a, point_b): pass class _NxDijkstraMixins(_SingleDispatchMixins): def dist_matrix(self): """Distance between mesh vertices. Returns ------- dist_matrix : array-like, shape=[n_vertices, n_vertices] Distance matrix. Notes ----- * infinitely slow """ all_pairs = nx.all_pairs_dijkstra_path_length(self._graph) n_vertices = self._shape.n_vertices dist_mat = gs.empty((n_vertices, n_vertices)) for node_index, all_dict in all_pairs: dists = gs.array(list(all_dict.values())) indices = gs.array(list(all_dict.keys())) dist_mat[node_index, indices] = dists return dist_mat def _dist_single(self, point_a, point_b): """Distance between mesh vertices. 
Parameters ---------- point_a : array-like, shape=() Index of source point. point_b : array-like, shape=() Index of target point. Returns ------- dist : numeric Distance. """ try: dist, _ = nx.single_source_dijkstra( self._graph, point_a.item(), target=point_b.item(), cutoff=None, weight="weight", ) except nx.NetworkXNoPath: dist = float("inf") return gs.asarray(dist)
class GraphShortestPathMetric(_NxDijkstraMixins, FinitePointSetMetric):
    """Shortest-path length on the mesh edge graph (single-source Dijkstra).

    Parameters
    ----------
    shape : Shape
        Shape.
    cutoff : float
        Length (sum of edge weights) at which the search is stopped.
    """

    # TODO: add scipy-based implementation?
    def __init__(self, shape, cutoff=None):
        self.cutoff = cutoff
        super().__init__(shape)
        self._graph = to_nx_edge_graph(shape)

    def _dist_from_source_single(self, source_point):
        """Distances from one source to every vertex reached by the search.

        Parameters
        ----------
        source_point : array-like, shape=()
            Index of source point.

        Returns
        -------
        dist : array-like, shape=[n_targets]
            Distance.
        target_point : array-like, shape=[n_targets]
            Target index.

        Notes
        -----
        Results are ordered by increasing target index.
        """
        reached = nx.single_source_dijkstra_path_length(
            self._graph, source_point.item(), cutoff=self.cutoff, weight="weight"
        )
        # Target indices are unique, so sorting the items orders by target.
        ordered = sorted(reached.items())
        lengths = [length for _, length in ordered]
        targets = [target for target, _ in ordered]
        return gs.asarray(lengths), gs.asarray(targets)
class KClosestGraphShortestPathMetric(_NxDijkstraMixins, FinitePointSetMetric):
    """Edge-graph shortest-path metric restricted to the closest vertices.

    Parameters
    ----------
    shape : Shape
        Shape.
    k_closest : int
        Number of nodes to find distances to (including the source itself).
    """

    def __init__(self, shape, k_closest=5):
        self.k_closest = k_closest
        super().__init__(shape)
        self._graph = to_nx_edge_graph(shape)

    def _dist_from_source_single(self, source_point):
        """Distances from one source to its ``k_closest`` nearest vertices.

        Parameters
        ----------
        source_point : array-like, shape=()
            Index of source point.

        Returns
        -------
        dist : array-like, shape=[n_closest]
            Distance.
        target_point : array-like, shape=[n_closest,]
            Target index, in the order returned by the partial search.
        """
        nearest = single_source_partial_dijkstra_path_length(
            self._graph, source_point.item(), self.k_closest, weight="weight"
        )
        distances = list(nearest.values())
        targets = list(nearest.keys())
        return gs.array(distances), gs.array(targets)
class HeatDistanceMetric(WhichRegistryMixins):
    """Heat-method distance metric between vertices of a mesh.

    The ``_Registry`` attribute connects ``WhichRegistryMixins`` to
    ``HeatDistanceMetricRegistry``; presumably concrete implementations are
    selected from that registry (see ``WhichRegistryMixins``).

    References
    ----------
    .. [CWW2017] Crane, K., Weischedel, C., Wardetzky, M., 2017.
        The heat method for distance computation.
        Commun. ACM 60, 90–99.
        https://doi.org/10.1145/3131280
    """

    _Registry = HeatDistanceMetricRegistry
class _ScipyShortestPathMixins(_SingleDispatchMixins):
    """Shortest-path distances computed with ``scipy.sparse.csgraph``.

    Host classes must define ``self._graph`` (networkx graph whose nodes are
    vertex indices) and ``self._shape``.
    """

    def _csgraph(self):
        """Weighted adjacency matrix of ``self._graph``, rows ordered by vertex index."""
        return nx.adjacency_matrix(
            self._graph, nodelist=range(self._shape.n_vertices)
        ).tolil()

    def dist_matrix(self):
        """Distance between mesh vertices.

        Returns
        -------
        dist_matrix : array-like, shape=[n_vertices, n_vertices]
            Distance matrix; ``inf`` where no path exists.
        """
        dist_mat = shortest_path(self._csgraph(), directed=False)
        return gs.array(dist_mat)

    def _dist_single(self, point_a, point_b):
        """Distance between two mesh vertices.

        Without this override, ``dist`` would fall through to the abstract
        stub of ``_SingleDispatchMixins`` and silently return ``None``.

        Parameters
        ----------
        point_a : array-like, shape=()
            Index of source point.
        point_b : array-like, shape=()
            Index of target point.

        Returns
        -------
        dist : array-like, shape=()
            Distance; ``inf`` when no path exists.
        """
        dist = shortest_path(
            self._csgraph(), directed=False, indices=point_a.item()
        )
        return gs.asarray(dist[point_b.item()])
class ScipyGraphShortestPathMetric(_ScipyShortestPathMixins, FinitePointSetMetric):
    """Shortest path on edge graph of mesh with Scipy shortest path solver.

    Parameters
    ----------
    shape : Shape
        Shape.
    cutoff : float
        Length (sum of edge weights) at which the search is stopped. When
        given, ``dist_from_source`` only returns targets within this length.
    """

    def __init__(self, shape, cutoff=None):
        self.cutoff = cutoff
        super().__init__(shape)
        self._graph = to_nx_edge_graph(shape)

    def _dist_from_source_single(self, source_point):
        """Distances from one source vertex.

        Parameters
        ----------
        source_point : array-like, shape=()
            Index of source point.

        Returns
        -------
        dist : array-like, shape=[n_targets]
            Distance. Without cutoff, one entry per vertex (``inf`` if
            unreachable); with cutoff, only targets within the cutoff.
        target_point : array-like, shape=[n_targets]
            Target index, in increasing order.
        """
        csgraph = nx.adjacency_matrix(
            self._graph, nodelist=range(self._shape.vertices.shape[0])
        ).tolil()

        if self.cutoff is None:
            dist = shortest_path(csgraph, directed=False, indices=source_point)
            return gs.array(list(dist)), gs.arange(len(dist))

        # Only needed when a cutoff is requested.
        from scipy.sparse.csgraph import dijkstra

        # ``limit`` aborts the search beyond ``cutoff``; those targets come
        # back as inf. Like GraphShortestPathMetric, return only the targets
        # within the cutoff, ordered by index.
        dist = dijkstra(
            csgraph, directed=False, indices=source_point, limit=self.cutoff
        )
        (within,) = (dist <= self.cutoff).nonzero()
        return gs.array(dist[within].tolist()), gs.array(within.tolist())