"""
Implementation of the DiffusionNet feature extractor for 3D shapes.
References
----------
.. DiffusionNet: Discretization Agnostic Learning on Surfaces.
   Nicholas Sharp, Souhaib Attaiki, Keenan Crane, Maks Ovsjanikov.
   https://arxiv.org/abs/2012.00888
.. https://github.com/dongliangcao/Self-Supervised-Multimodal-Shape-Matching by Dongliang Cao
.. https://github.com/nmwsharp/diffusion-net
"""
import torch
import torch.nn as nn
import geomfum.backend as xgs
from geomfum.descriptor.learned import BaseFeatureExtractor
# TODO: Implement batching operations. For now, DiffusionNet accepts just one mesh as input.
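# Note: the forward pass still expects a batch dimension, so a single mesh
# should be passed as verts.unsqueeze(0) (and faces.unsqueeze(0)), i.e. with
# shapes [1, V, 3] and [1, F, 3].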
"""
Implementation from
https://github.com/dongliangcao/Self-Supervised-Multimodal-Shape-Matching by Dongliang Cao
"""
class DiffusionNet(nn.Module):
"""DiffusionNet: stacked of DiffusionBlocks.
Parameters
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
hidden_channels : int
Number of hidden channels in diffusion block. Default 128.
n_block : int
Number of diffusion blocks. Default 4.
    last_activation : nn.Module or None
        Activation applied to the output. Default None.
    mlp_hidden_channels : list or None
        MLP hidden channel sizes. Default None, which means
        [hidden_channels, hidden_channels].
    output_at : str
        Mesh element at which outputs are produced, by averaging from
        vertices. One of ['vertices', 'edges', 'faces', 'global_mean'].
        Default 'vertices'.
    dropout : bool
        Whether to use dropout in the MLP. Default True.
    with_gradient_features : bool
        Whether to use SpatialGradientFeatures in each DiffusionNetBlock.
        Default True.
    with_gradient_rotations : bool
        Whether to use gradient rotations in SpatialGradientFeatures.
        Default True.
    diffusion_method : str
        Diffusion method applied in the diffusion layer.
        One of ['spectral', 'implicit_dense']. Default 'spectral'.
    k_eig : int
        Number of eigenvalues/eigenvectors used for diffusion. Default 128.
    cache_dir : str or None
        Cache directory containing pre-computed spectral operators.
        Default None.
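
    Examples
    --------
    A minimal sketch of a forward pass, assuming the spectral operators
    (``mass``, ``evals``, ``evecs``, ``gradX``, ``gradY``) have been
    precomputed elsewhere, e.g. with the utilities from
    https://github.com/nmwsharp/diffusion-net:

    >>> net = DiffusionNet(in_channels=3, out_channels=64)
    >>> feat = net(verts, faces, mass=mass, evals=evals, evecs=evecs,
    ...            gradX=gradX, gradY=gradY)  # [B, V, 64]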
"""
def __init__(
self,
in_channels,
out_channels,
hidden_channels=128,
n_block=4,
last_activation=None,
mlp_hidden_channels=None,
output_at="vertices",
dropout=True,
with_gradient_features=True,
with_gradient_rotations=True,
diffusion_method="spectral",
k_eig=128,
cache_dir=None,
):
super(DiffusionNet, self).__init__()
# sanity check
assert diffusion_method in ["spectral", "implicit_dense"], (
f"Invalid diffusion method: {diffusion_method}"
)
assert output_at in ["vertices", "edges", "faces", "global_mean"], (
f"Invalid output_at: {output_at}"
)
# basic params
self.in_channels = in_channels
self.out_channels = out_channels
self.hidden_channels = hidden_channels
self.n_block = n_block
self.cache_dir = cache_dir
# output params
self.last_activation = last_activation
self.output_at = output_at
# mlp options
if not mlp_hidden_channels:
mlp_hidden_channels = [hidden_channels, hidden_channels]
self.mlp_hidden_channels = mlp_hidden_channels
self.dropout = dropout
# diffusion options
self.diffusion_method = diffusion_method
self.k_eig = k_eig
# gradient feature options
self.with_gradient_features = with_gradient_features
self.with_gradient_rotations = with_gradient_rotations
# setup networks
# first and last linear layers
self.first_linear = nn.Linear(in_channels, hidden_channels)
self.last_linear = nn.Linear(hidden_channels, out_channels)
# diffusion blocks
blocks = []
for i_block in range(self.n_block):
block = DiffusionNetBlock(
in_channels=hidden_channels,
mlp_hidden_channels=mlp_hidden_channels,
dropout=dropout,
diffusion_method=diffusion_method,
with_gradient_features=with_gradient_features,
with_gradient_rotations=with_gradient_rotations,
)
blocks += [block]
self.blocks = nn.ModuleList(blocks)
def forward(
self,
verts,
faces=None,
feats=None,
frames=None,
mass=None,
L=None,
evals=None,
evecs=None,
gradX=None,
gradY=None,
):
"""Compute the forward pass of the DiffusionNet.
Parameters
----------
verts : torch.Tensor
Input vertices [B, V, 3].
faces : torch.Tensor, optional
Input faces [B, F, 3]. Default None.
        feats : torch.Tensor, optional
            Input features. Default None, in which case the vertex
            positions are used.
        frames : torch.Tensor, optional
            Tangent frames for vertices.
        mass : torch.Tensor, optional
            Diagonal elements of the mass matrix [B, V].
        L : torch.SparseTensor, optional
            Sparse Laplacian matrix [B, V, V].
        evals : torch.Tensor, optional
            Eigenvalues of the Laplacian matrix [B, K].
        evecs : torch.Tensor, optional
            Eigenvectors of the Laplacian matrix [B, V, K].
        gradX : torch.SparseTensor, optional
            Real part of the gradient matrix [B, V, V].
        gradY : torch.SparseTensor, optional
            Imaginary part of the gradient matrix [B, V, V].
Returns
-------
torch.Tensor
Output features.
"""
        assert verts.dim() == 3, "Only batched input is supported"
        if faces is not None:
            assert faces.dim() == 3, "Only batched input is supported"
if feats is not None:
x = feats
else:
x = verts
x = self.first_linear(x)
for block in self.blocks:
x = block(x, mass, L, evals, evecs, gradX, gradY)
x = self.last_linear(x)
if self.output_at == "vertices":
x_out = x
elif self.output_at == "faces":
x_gather = x.unsqueeze(-1).expand(-1, -1, -1, 3)
faces_gather = faces.unsqueeze(2).expand(-1, -1, x.shape[-1], -1)
x_out = torch.gather(x_gather, 1, faces_gather).mean(dim=-1)
else:
x_out = torch.sum(x * mass.unsqueeze(-1), dim=-1) / torch.sum(
mass, dim=-1, keepdim=True
)
if self.last_activation:
x_out = self.last_activation(x_out)
return x_out
class DiffusionNetBlock(nn.Module):
"""Building Block of DiffusionNet.
Parameters
----------
in_channels : int
Number of input channels.
    mlp_hidden_channels : list
        List of MLP hidden channel sizes.
    dropout : bool
        Whether to use dropout in the MLP. Default True.
    diffusion_method : str
        Method for diffusion. Default "spectral".
    with_gradient_features : bool
        Whether to use spatial gradient features. Default True.
    with_gradient_rotations : bool
        Whether to use gradient rotations in SpatialGradientFeatures.
        Default True.
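
    Notes
    -----
    A block concatenates the input features, the diffused features and
    (optionally) the spatial gradient features, feeds them through a
    MiniMLP, and adds a skip connection. A minimal sketch, assuming the
    spectral operators are precomputed:

    >>> block = DiffusionNetBlock(in_channels=128,
    ...                           mlp_hidden_channels=[128, 128])
    >>> out = block(feat, mass, L, evals, evecs, gradX, gradY)  # [B, V, 128]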
"""
def __init__(
self,
in_channels,
mlp_hidden_channels,
dropout=True,
diffusion_method="spectral",
with_gradient_features=True,
with_gradient_rotations=True,
):
super(DiffusionNetBlock, self).__init__()
self.in_channels = in_channels
self.mlp_hidden_channels = mlp_hidden_channels
self.dropout = dropout
self.with_gradient_features = with_gradient_features
self.with_gradient_rotations = with_gradient_rotations
# Diffusion block
self.diffusion = LearnedTimeDiffusion(self.in_channels, method=diffusion_method)
# concat of both diffused features and original features
self.mlp_in_channels = 2 * self.in_channels
# Spatial gradient block
if self.with_gradient_features:
self.gradient_features = SpatialGradientFeatures(
self.in_channels, with_gradient_rotations=self.with_gradient_rotations
)
# concat of gradient features
self.mlp_in_channels += self.in_channels
# MLP block
self.mlp = MiniMLP(
[self.mlp_in_channels] + self.mlp_hidden_channels + [self.in_channels],
dropout=self.dropout,
)
def forward(self, feat_in, mass, L, evals, evecs, gradX, gradY):
"""Compute the forward pass of the diffusion block.
Parameters
----------
feat_in : torch.Tensor
Input feature vector [B, V, C].
mass : torch.Tensor
Diagonal elements of mass matrix [B, V].
L : torch.SparseTensor
Sparse Laplacian matrix [B, V, V].
        evals : torch.Tensor
            Eigenvalues of the Laplacian matrix [B, K].
        evecs : torch.Tensor
            Eigenvectors of the Laplacian matrix [B, V, K].
gradX : torch.SparseTensor
Real part of gradient matrix [B, V, V].
gradY : torch.SparseTensor
Imaginary part of gradient matrix [B, V, V].
Returns
-------
torch.Tensor
Output feature vector.
"""
B = feat_in.shape[0]
assert feat_in.shape[-1] == self.in_channels, (
f"Expected feature channel: {self.in_channels}, but got: {feat_in.shape[-1]}"
)
# Diffusion block
feat_diffuse = self.diffusion(feat_in, L, mass, evals, evecs)
# Compute gradient features
if self.with_gradient_features:
# Compute gradient
feat_grads = []
for b in range(B):
# gradient after diffusion
feat_gradX = torch.mm(gradX[b, ...], feat_diffuse[b, ...])
feat_gradY = torch.mm(gradY[b, ...], feat_diffuse[b, ...])
feat_grads.append(torch.stack((feat_gradX, feat_gradY), dim=-1))
feat_grad = torch.stack(feat_grads, dim=0) # [B, V, C, 2]
# Compute gradient features
feat_grad_features = self.gradient_features(feat_grad)
# Stack inputs to MLP
feat_combined = torch.cat(
(feat_in, feat_diffuse, feat_grad_features), dim=-1
)
else:
# Stack inputs to MLP
feat_combined = torch.cat((feat_in, feat_diffuse), dim=-1)
# MLP block
feat_out = self.mlp(feat_combined)
# Skip connection
feat_out = feat_out + feat_in
return feat_out
class LearnedTimeDiffusion(nn.Module):
"""Applied diffusion with learned time per-channel.
In the spectral domain this becomes f_out = e ^ (lambda_i * t) * f_in
Parameters
----------
in_channels : int
Number of input channels.
method : str
Method to perform time diffusion. Default 'spectral'.
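
    Notes
    -----
    With method='spectral', diffusion for a learned time t is computed as
    evecs @ (exp(-evals * t) * (evecs^T @ (mass * feat))): the features are
    projected onto the Laplacian eigenbasis (with the mass-weighted inner
    product), attenuated per frequency, and projected back. With
    method='implicit_dense', a single implicit Euler step
    (M + t * L) feat_diffuse = M feat is solved densely via a Cholesky
    factorization.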
"""
def __init__(self, in_channels, method="spectral"):
super(LearnedTimeDiffusion, self).__init__()
assert method in ["spectral", "implicit_dense"], f"Invalid method: {method}"
self.in_channels = in_channels
self.diffusion_time = nn.Parameter(torch.Tensor(in_channels))
self.method = method
# init as zero
nn.init.constant_(self.diffusion_time, 0.0)
def forward(self, feat, L, mass, evals, evecs):
"""Forward pass of the diffusion layer.
Parameters
----------
feat : torch.Tensor
Feature vector [B, V, C].
L : torch.SparseTensor
Sparse Laplacian matrix [B, V, V].
mass : torch.Tensor
Diagonal elements in mass matrix [B, V].
evals : torch.Tensor
Eigenvalues of Laplacian matrix [B, K].
evecs : torch.Tensor
Eigenvectors of Laplacian matrix [B, V, K].
Returns
-------
feat_diffuse : torch.Tensor
Diffused feature vector [B, V, C].
"""
# project times to the positive half-space
# (and away from 0 in the incredibly rare chance that they get stuck)
with torch.no_grad():
self.diffusion_time.data = torch.clamp(self.diffusion_time, min=1e-8)
assert feat.shape[-1] == self.in_channels, (
f"Expected feature channel: {self.in_channels}, but got: {feat.shape[-1]}"
)
if self.method == "spectral":
# Transform to spectral
feat_spec = torch.matmul(evecs.transpose(-2, -1), feat * mass.unsqueeze(-1))
# Diffuse
diffuse_coefs = torch.exp(
-evals.unsqueeze(-1) * self.diffusion_time.unsqueeze(0)
)
feat_diffuse_spec = diffuse_coefs * feat_spec
# Transform back to feature
feat_diffuse = torch.matmul(evecs, feat_diffuse_spec)
else: # 'implicit_dense'
# Form the dense matrix (M + tL) with dims (B, C, V, V)
            mat_dense = (
                L.to_dense().unsqueeze(1).expand(-1, self.in_channels, -1, -1).clone()
            )
mat_dense *= self.diffusion_time.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
mat_dense += torch.diag_embed(mass).unsqueeze(1)
# Factor the system
cholesky_factors = torch.linalg.cholesky(mat_dense)
# Solve the system
rhs = feat * mass.unsqueeze(-1)
rhsT = rhs.transpose(1, 2).unsqueeze(-1)
sols = torch.cholesky_solve(rhsT, cholesky_factors)
feat_diffuse = sols.squeeze(-1).transpose(1, 2)
return feat_diffuse
class SpatialGradientFeatures(nn.Module):
"""Compute dot-products between input vectors. Uses a learned complex-linear layer to keep dimension down.
Parameters
----------
in_channels : int
Number of input channels.
with_gradient_rotations : bool
Whether with gradient rotations. Default True.
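
    Notes
    -----
    Viewing each per-channel gradient as a complex number g = g_x + i * g_y,
    the layer computes tanh(Re(conj(g) * A(g))), where A is a learned
    complex-linear map (A_re + i * A_im) when gradient rotations are
    enabled, and a real-linear map applied to both components otherwise.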
"""
def __init__(self, in_channels, with_gradient_rotations=True):
super(SpatialGradientFeatures, self).__init__()
self.in_channels = in_channels
self.with_gradient_rotations = with_gradient_rotations
if self.with_gradient_rotations:
self.A_re = nn.Linear(self.in_channels, self.in_channels, bias=False)
self.A_im = nn.Linear(self.in_channels, self.in_channels, bias=False)
else:
self.A = nn.Linear(self.in_channels, self.in_channels, bias=False)
def forward(self, feat_in):
"""Compute the spatial gradient features.
Parameters
----------
        feat_in : torch.Tensor
            Input feature vector [B, V, C, 2].

        Returns
        -------
        feat_out : torch.Tensor
            Output feature vector [B, V, C].
"""
        feat_a = feat_in
        if self.with_gradient_rotations:
            # learned complex-linear map (A_re + i A_im) applied to feat_in
            feat_real_b = self.A_re(feat_in[..., 0]) - self.A_im(feat_in[..., 1])
            feat_imag_b = self.A_re(feat_in[..., 1]) + self.A_im(feat_in[..., 0])
        else:
            feat_real_b = self.A(feat_in[..., 0])
            feat_imag_b = self.A(feat_in[..., 1])
        # real part of the Hermitian product between input and mapped gradients
        feat_out = feat_a[..., 0] * feat_real_b + feat_a[..., 1] * feat_imag_b
return torch.tanh(feat_out)
class MiniMLP(nn.Sequential):
"""A simple MLP with configurable hidden layer sizes.
Parameters
----------
    layer_sizes : list
        List of layer sizes.
    dropout : bool
        Whether to use dropout. Default False.
    activation : nn.Module
        Activation function class. Default nn.ReLU.
    name : str
        Module name. Default 'miniMLP'.
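
    Examples
    --------
    An illustrative configuration (the sizes are arbitrary):

    >>> mlp = MiniMLP([384, 128, 128], dropout=True)
    >>> # builds: Linear(384, 128) -> ReLU -> Dropout(0.5) -> Linear(128, 128)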
"""
def __init__(self, layer_sizes, dropout=False, activation=nn.ReLU, name="miniMLP"):
super(MiniMLP, self).__init__()
for i in range(len(layer_sizes) - 1):
is_last = i + 2 == len(layer_sizes)
# Dropout Layer
if dropout and i > 0:
self.add_module(name + "_dropout_{:03d}".format(i), nn.Dropout(p=0.5))
# Affine Layer
self.add_module(
name + "_linear_{:03d}".format(i),
nn.Linear(layer_sizes[i], layer_sizes[i + 1]),
)
# Activation Layer
if not is_last:
self.add_module(name + "_activation_{:03d}".format(i), activation())