Source code for neuralop.layers.integral_transform

import torch
from torch import nn
import torch.nn.functional as F

from .channel_mlp import LinearChannelMLP
from .segment_csr import segment_csr


class IntegralTransform(nn.Module):
    """Integral Kernel Transform (GNO)

    Computes one of the following:
        (a) \int_{A(x)} k(x, y) dy
        (b) \int_{A(x)} k(x, y) * f(y) dy
        (c) \int_{A(x)} k(x, y, f(y)) dy
        (d) \int_{A(x)} k(x, y, f(y)) * f(y) dy

    x : Points for which the output is defined
    y : Points for which the input is defined
    A(x) : A subset of all points y (depending on each x)
        over which to integrate
    k : A kernel parametrized as an MLP (LinearChannelMLP)
    f : Input function to integrate against, given
        on the points y

    If f is not given, a transform of type (a)
    is computed. Otherwise transforms (b), (c),
    or (d) are computed. The sets A(x) are specified
    as a graph in CSR format.

    Parameters
    ----------
    channel_mlp : torch.nn.Module, default None
        MLP parametrizing the kernel k. Input dimension should be
        dim x + dim y or dim x + dim y + dim f.
        The MLP should not be pointwise and should only operate across
        channels to preserve the discretization-invariance of the
        kernel integral.
    channel_mlp_layers : list, default None
        List of layer sizes specifying an MLP which
        parametrizes the kernel k. The MLP will be
        instantiated by the LinearChannelMLP class.
    channel_mlp_non_linearity : callable, default torch.nn.functional.gelu
        Non-linear function to be used by the
        LinearChannelMLP class. Only used if channel_mlp_layers
        is given and channel_mlp is None.
    transform_type : str, default 'linear'
        Which integral transform to compute. The mapping is:
            'linear_kernelonly' -> (a)
            'linear' -> (b)
            'nonlinear_kernelonly' -> (c)
            'nonlinear' -> (d)
        If the input f is not given, then (a) is computed
        by default, independently of this parameter.
    use_torch_scatter : bool, default True
        Whether to use torch_scatter's implementation of segment_csr
        or our native PyTorch version. torch_scatter should be
        installed by default, but there are known versioning issues
        on some Linux builds of CPU-only PyTorch. Try setting to False
        if you experience an error from torch_scatter.
    """

    def __init__(
        self,
        channel_mlp=None,
        channel_mlp_layers=None,
        channel_mlp_non_linearity=F.gelu,
        transform_type="linear",
        use_torch_scatter=True,
    ):
        super().__init__()

        assert channel_mlp is not None or channel_mlp_layers is not None

        self.transform_type = transform_type
        self.use_torch_scatter = use_torch_scatter

        if self.transform_type not in [
            "linear_kernelonly",
            "linear",
            "nonlinear_kernelonly",
            "nonlinear",
        ]:
            raise ValueError(
                f"Got transform_type={transform_type} but expected one of "
                "[linear_kernelonly, linear, nonlinear_kernelonly, nonlinear]"
            )

        if channel_mlp is None:
            self.channel_mlp = LinearChannelMLP(
                layers=channel_mlp_layers,
                non_linearity=channel_mlp_non_linearity,
            )
        else:
            self.channel_mlp = channel_mlp

    # Notes on forward():
    # - Assumes x = y if x is not specified.
    # - The integral is taken w.r.t. the neighbors.
    # - If no weights are given, a Monte-Carlo approximation is made.
    # NOTE: For transforms of type (b) or (d), the kernel's output
    # channels must be the same as the channels of f.
    def forward(self, y, neighbors, x=None, f_y=None, weights=None):
        """Compute a kernel integral transform.

        Parameters
        ----------
        y : torch.Tensor of shape [n, d1]
            n points of dimension d1 specifying
            the space to integrate over.
            If batched, these must remain constant
            over the whole batch, so no batch dim is needed.
        neighbors : dict
            The sets A(x) given in CSR format. The
            dict must contain the keys "neighbors_index"
            and "neighbors_row_splits". For descriptions
            of the two, see NeighborSearch.
            If batch > 1, the neighbors must be constant
            across the entire batch.
        x : torch.Tensor of shape [m, d2], default None
            m points of dimension d2 over which the
            output function is defined. If None, x = y.
        f_y : torch.Tensor of shape [batch, n, d3] or [n, d3], default None
            Function to integrate the kernel against, defined
            on the points y. The kernel is assumed diagonal,
            hence its output shape must be d3 for the
            transforms (b) or (d). If None, (a) is computed.
        weights : torch.Tensor of shape [n,], default None
            Weights for each point y, proportional to the
            volume around f(y) being integrated. For example,
            suppose d1=1 and let y_1 < y_2 < ... < y_{n+1}
            be some points. Then, for a Riemann sum,
            the weights are y_{j+1} - y_j. If None,
            1/|A(x)| is used.

        Returns
        -------
        out_features : torch.Tensor of shape [batch, m, d4] or [m, d4]
            Output function given on the points x.
            d4 is the output size of the kernel k.
        """

        if x is None:
            x = y

        # gather the coordinates of every neighbor y_j, flattened over all x_i
        rep_features = y[neighbors["neighbors_index"]]

        # batching only matters if f_y (latent embedding) values are provided
        batched = False
        # f_y has a batch dim iff batched=True
        if f_y is not None:
            if f_y.ndim == 3:
                batched = True
                batch_size = f_y.shape[0]
                in_features = f_y[:, neighbors["neighbors_index"], :]
            elif f_y.ndim == 2:
                batched = False
                in_features = f_y[neighbors["neighbors_index"]]

        # number of neighbors of each point x_i
        num_reps = (
            neighbors["neighbors_row_splits"][1:]
            - neighbors["neighbors_row_splits"][:-1]
        )

        # repeat each x_i once per neighbor so it pairs with every y_j in A(x_i)
        self_features = torch.repeat_interleave(x, num_reps, dim=0)

        agg_features = torch.cat([rep_features, self_features], dim=-1)
        if f_y is not None and (
            self.transform_type == "nonlinear_kernelonly"
            or self.transform_type == "nonlinear"
        ):
            if batched:
                # repeat agg features for every example in the batch
                agg_features = agg_features.repeat(
                    [batch_size] + [1] * agg_features.ndim
                )
            agg_features = torch.cat([agg_features, in_features], dim=-1)

        rep_features = self.channel_mlp(agg_features)

        if f_y is not None and self.transform_type != "nonlinear_kernelonly":
            rep_features = rep_features * in_features

        if weights is not None:
            assert weights.ndim == 1, "Weights must be of dimension 1 in all cases"
            nbr_weights = weights[neighbors["neighbors_index"]]
            # repeat weights along batch dim if batched
            if batched:
                nbr_weights = nbr_weights.repeat(
                    [batch_size] + [1] * nbr_weights.ndim
                )
            rep_features = nbr_weights * rep_features
            reduction = "sum"
        else:
            reduction = "mean"

        splits = neighbors["neighbors_row_splits"]
        if batched:
            splits = splits.repeat([batch_size] + [1] * splits.ndim)

        # reduce over each neighborhood A(x_i)
        out_features = segment_csr(
            rep_features, splits, reduce=reduction, use_scatter=self.use_torch_scatter
        )

        return out_features
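
The snippet below is a minimal usage sketch, not part of the library source: it builds a tiny CSR neighbor graph by hand and applies a type (b) 'linear' transform. The point counts and layer sizes are illustrative.

import torch
from neuralop.layers.integral_transform import IntegralTransform

n, m, d = 4, 2, 2      # input points, output points, spatial dim
d_f = 3                # channels of f

y = torch.randn(n, d)
x = torch.randn(m, d)
f_y = torch.randn(n, d_f)

# A(x_0) = {y_0, y_1} and A(x_1) = {y_1, y_2, y_3}, stored in CSR format
neighbors = {
    "neighbors_index": torch.tensor([0, 1, 1, 2, 3]),
    "neighbors_row_splits": torch.tensor([0, 2, 5]),
}

# For type (b) the kernel sees (x, y), so its input size is d + d;
# the kernel is diagonal, so its output size must match d_f.
gno = IntegralTransform(
    channel_mlp_layers=[2 * d, 16, d_f],
    transform_type="linear",
    use_torch_scatter=False,  # avoids torch_scatter versioning issues
)

out = gno(y, neighbors, x=x, f_y=f_y)
print(out.shape)  # torch.Size([2, 3])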
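
Under the same assumptions, a sketch of the batched, type (d) 'nonlinear' path: f_y gains a leading batch dim while y, x, and the neighbor graph stay fixed across the batch, as forward() requires.

batch = 8
f_y_batched = torch.randn(batch, n, d_f)

# The kernel now sees (x, y, f(y)), so its input size is d + d + d_f.
gno_nl = IntegralTransform(
    channel_mlp_layers=[2 * d + d_f, 16, d_f],
    transform_type="nonlinear",
    use_torch_scatter=False,
)

out = gno_nl(y, neighbors, x=x, f_y=f_y_batched)
print(out.shape)  # torch.Size([8, 2, 3])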