"""Shape dataset for PyTorch."""
import itertools
import os
import random
import warnings
import geomstats.backend as gs
import meshio
import numpy as np
import scipy.io
import torch
from torch.utils.data import Dataset
import geomfum.backend as xgs
from geomfum.metric.mesh import ScipyGraphShortestPathMetric
from geomfum.shape.mesh import TriangleMesh
class ShapeDataset(Dataset):
"""ShapeDataset for loading and preprocessing shape data.
Parameters
----------
dataset_dir : str
Path to the directory containing the dataset. We assume the dataset directory to have a subfolder shapes, for shapes, corr, for correspondences and dist, for chaced distance matrices.
spectral : bool
Whether to compute the spectral features.
distances : bool
Whether to compute geodesic distance matrices. For computational reasons, these are not computed on the fly, but rather loaded from a precomputed .mat file.
k : int
Number of eigenvectors to use for the spectral features.
device : torch.device, optional
Device to move the data to.
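
    Examples
    --------
    A minimal usage sketch; the dataset path is hypothetical and must follow the
    folder layout described above:

    >>> dataset = ShapeDataset("data/my_dataset", spectral=True, k=30)
    >>> sample = dataset[0]
    >>> mesh, corr = sample["mesh"], sample["corr"]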
"""
def __init__(
self,
dataset_dir,
spectral=False,
distances=False,
correspondences=True,
k=200,
device=None,
):
self.dataset_dir = dataset_dir
self.shape_dir = os.path.join(dataset_dir, "shapes")
all_shape_files = sorted(
[
f
for f in os.listdir(self.shape_dir)
if f.lower().endswith((".off", ".ply", ".obj"))
]
)
self.shape_files = all_shape_files
self.device = (
device
if device is not None
else torch.device("cuda" if torch.cuda.is_available() else "cpu")
)
self.spectral = spectral
self.k = k
self.distances = distances
self.correspondences = correspondences
# Preload meshes (or their important features) into memory
self.meshes = {}
self.corrs = {}
for filename in self.shape_files:
ext = os.path.splitext(filename)[1][1:]
if ext not in meshio._helpers._writer_map:
warnings.warn(f"Skipped unsupported mesh file: {filename}")
continue
filepath = os.path.join(self.shape_dir, filename)
mesh = TriangleMesh.from_file(filepath)
base_name, _ = os.path.splitext(filename)
            # Preprocess: compute the Laplacian spectrum (at least k eigenpairs) and set it as the basis.
            if spectral:
                mesh.laplacian.find_spectrum(
                    spectrum_size=max(self.k, 200), set_as_basis=True
                )
                mesh.basis.use_k = self.k
self.meshes[filename] = mesh
corr_filename = base_name + ".vts"
if self.correspondences:
if os.path.exists(
os.path.join(self.dataset_dir, "corr", corr_filename)
):
# Load correspondences from file, subtract 1 to convert to zero-based indexing.
self.corrs[filename] = (
np.loadtxt(
os.path.join(self.dataset_dir, "corr", corr_filename)
).astype(np.int32)
- 1
)
else:
self.corrs[filename] = np.arange(mesh.vertices.shape[0])
def __getitem__(self, idx):
"""Retrieve a data sample by index.
Parameters
----------
idx : int
Index of the item to retrieve.
Returns
-------
shape_data: dict
Dictionary containing the shape, the correspondence and the distances if available and required.
"""
filename = self.shape_files[idx]
mesh = self.meshes[filename]
shape_data = {}
if self.correspondences:
shape_data.update({"corr": gs.array(self.corrs[filename])})
if self.distances:
mat_subfolder = os.path.join(self.dataset_dir, "dist")
base_name, _ = os.path.splitext(filename)
mat_filename = base_name + ".mat"
dist_path = os.path.join(mat_subfolder, mat_filename)
geod_distance_matrix = None
if os.path.exists(dist_path):
mat_contents = scipy.io.loadmat(dist_path)
if "D" in mat_contents:
geod_distance_matrix = mat_contents["D"]
if geod_distance_matrix is None:
metric = ScipyGraphShortestPathMetric(mesh)
geod_distance_matrix = metric.dist_matrix()
os.makedirs(os.path.dirname(dist_path), exist_ok=True)
scipy.io.savemat(
dist_path,
{"D": gs.to_numpy(geod_distance_matrix)},
)
shape_data.update({"dist_matrix": gs.array(geod_distance_matrix)})
        # Move the mesh (and, if computed, its spectral quantities) to the target device.
        mesh.vertices = xgs.to_device(mesh.vertices, self.device)
        mesh.faces = xgs.to_device(mesh.faces, self.device)
        if self.spectral:
            mesh.basis.full_vals = xgs.to_device(mesh.basis.full_vals, self.device)
            mesh.basis.full_vecs = xgs.to_device(mesh.basis.full_vecs, self.device)
            mesh.laplacian._mass_matrix = xgs.to_device(
                mesh.laplacian._mass_matrix, self.device
            )
shape_data.update({"mesh": mesh})
return shape_data
def __len__(self):
"""Get the length of the dataset."""
return len(self.shape_files)
class PairsDataset(Dataset):
"""
Dataset of pairs of shapes.
Parameters
----------
dataset : torch.utils.data.Dataset or list
Preloaded dataset or list of shape data objects.
pair_mode : str, optional
Strategy to generate pairs. Options: 'all', 'random'. Default is 'all'.
n_pairs : int, optional
Number of random pairs to generate if pair_mode is 'random'. Default is 100.
device : torch.device, optional
Device to move the data to. If None, uses CUDA if available, else CPU.
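
    Examples
    --------
    A minimal usage sketch, assuming shapes is a previously built ShapeDataset
    (or a list of shape data dictionaries):

    >>> pairs = PairsDataset(shapes, pair_mode="random", pairs_ratio=2)
    >>> pair = pairs[0]
    >>> source, target = pair["source"], pair["target"]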
"""
def __init__(self, dataset=None, pair_mode="all", pairs_ratio=100, device=None):
# Preload meshes
self.shape_data = dataset
self.pair_mode = pair_mode
self.device = (
device
if device is not None
else torch.device("cuda" if torch.cuda.is_available() else "cpu")
)
# Depending on pair_mode, choose the appropriate strategy
if pair_mode == "all":
self.pairs = self.generate_all_pairs()
elif pair_mode == "random":
self.pairs = self.generate_random_pairs(
pairs_ratio
) # You can specify the number of pairs
else:
raise ValueError(f"Unsupported pair_mode: {pair_mode}")
def generate_all_pairs(self):
"""Generate all possible pairs of shapes."""
        return list(itertools.permutations(range(len(self.shape_data)), 2))
def generate_random_pairs(self, pairs_ratio=0.5):
"""Generate random pairs of shapes.
Parameters
----------
pairs_ratio : float
Ratio of pairs to generate compared to the total number of possible pairs.
Default is 0.5, meaning half of the possible pairs will be generated.
"""
        n_shapes = len(self.shape_data)
        return random.sample(
            list(itertools.combinations(range(n_shapes), 2)),
            int(n_shapes * pairs_ratio),
        )
def __getitem__(self, idx):
"""Get item by index.
Parameters
----------
idx : int
Index of the item to retrieve.
Returns
-------
data: dict
Dictionary containing the source and target shapes.
"""
src_idx, tgt_idx = self.pairs[idx]
return {"source": self.shape_data[src_idx], "target": self.shape_data[tgt_idx]}
def __len__(self):
"""Get the length of the dataset."""
return len(self.pairs)
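

# A minimal end-to-end sketch (hypothetical path; batching is left to the caller,
# since TriangleMesh objects are not handled by PyTorch's default collate function):
#
#     shapes = ShapeDataset("data/my_dataset", spectral=True, k=30)
#     pairs = PairsDataset(shapes, pair_mode="random", pairs_ratio=2)
#     loader = torch.utils.data.DataLoader(pairs, batch_size=None, shuffle=True)
#     for pair in loader:
#         source_mesh = pair["source"]["mesh"]
#         target_mesh = pair["target"]["mesh"]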