Source code for neuralop.layers.differential_conv

import torch
import torch.nn as nn
import torch.nn.functional as F


class FiniteDifferenceConvolution(nn.Module):
    """Finite Difference Convolution Layer

    This is the finite difference convolution layer introduced in [1]_.
    It computes a finite difference convolution on a regular grid, which
    converges to a directional derivative as the grid is refined.

    For a learned kernel ``k`` the output at grid point ``i`` is
    ``sum_j k_j * (x_{i+j} - x_i) / grid_width``. Constant inputs are
    therefore mapped to zero, whatever the kernel weights are.

    Parameters
    ----------
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    n_dim : int
        Number of dimensions in the input domain. Used to select
        ``torch.nn.Conv{n_dim}d`` / ``torch.nn.functional.conv{n_dim}d``,
        so it must be 1, 2 or 3.
    kernel_size : int, optional
        Odd kernel size used for convolutional finite difference stencil,
        by default 3
    groups : int, optional
        Splitting number of channels, by default 1
    padding : literal {'periodic', 'replicate', 'reflect', 'zeros'}, optional
        Mode of padding to use on input, by default 'periodic'.
        It is passed to ``torch.nn.Conv{n_dim}d`` as ``padding_mode``
        ('periodic' maps to torch's 'circular').

    Raises
    ------
    ValueError
        If ``kernel_size`` is not odd.
    NotImplementedError
        If ``padding`` is not one of the supported modes.

    References
    ----------
    .. [1] : Liu-Schiaffini, M., et al. (2024). "Neural Operators with
        Localized Integral and Differential Kernels". ICML 2024,
        https://arxiv.org/abs/2402.16845.
    """

    # User-facing padding names -> torch Conv ``padding_mode`` names.
    _PADDING_MODES = {
        "periodic": "circular",
        "replicate": "replicate",
        "reflect": "reflect",
        "zeros": "zeros",
    }

    def __init__(
        self,
        in_channels,
        out_channels,
        n_dim,
        kernel_size=3,
        groups=1,
        padding="periodic",
    ):
        super().__init__()
        conv_module = getattr(nn, f"Conv{n_dim}d")
        self.conv_function = getattr(F, f"conv{n_dim}d")
        # Kept as an alias of ``conv_function`` for backward compatibility.
        self.F_conv_module = self.conv_function

        # An explicit raise (not ``assert``) so the check survives ``python -O``.
        # With padding="same", an even kernel would be padded asymmetrically
        # and the stencil would no longer be centred on the output point.
        if kernel_size % 2 != 1:
            raise ValueError(
                f"Kernel size should be odd, got kernel_size={kernel_size}"
            )

        padding_mode = (
            self._PADDING_MODES.get(padding) if isinstance(padding, str) else None
        )
        if padding_mode is None:
            raise NotImplementedError("Desired padding mode is not currently supported")

        self.kernel_size = kernel_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.groups = groups
        self.n_dim = n_dim
        self.padding_mode = padding_mode
        self.pad_size = kernel_size // 2

        self.conv = conv_module(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding="same",
            padding_mode=self.padding_mode,
            bias=False,
            groups=groups,
        )
        # Same Parameter object as ``self.conv.weight``, exposed at top level.
        self.weight = self.conv.weight

    def forward(self, x, grid_width):
        """FiniteDifferenceConvolution's forward pass.

        Computes ``(conv(x, k) - x * sum(k)) / grid_width``, where the
        second term is obtained by convolving ``x`` with the kernel summed
        over its spatial dimensions (a 1x...x1 kernel).

        A related but not numerically identical alternative would be to
        centre the kernel by subtracting its mean pointwise:
        ``conv(x, kernel - mean(kernel)) / grid_width``.

        Parameters
        ----------
        x : torch.tensor
            input tensor, shape (batch, in_channels, d_1, d_2, ...d_n)
        grid_width : float
            discretization size of input grid

        Returns
        -------
        torch.tensor
            shape (batch, out_channels, d_1, d_2, ...d_n), since the
            convolution uses ``padding="same"``.
        """
        stencil_out = self.conv(x)

        # Collapse the kernel's spatial axes: the result is a pointwise kernel
        # whose response is x_i * sum_j k_j. Subtracting it turns the stencil
        # into sum_j k_j * (x_{i+j} - x_i), which vanishes on constant input.
        spatial_dims = tuple(range(2, 2 + self.n_dim))
        kernel_sum = torch.sum(self.conv.weight, dim=spatial_dims, keepdim=True)
        center_term = self.conv_function(x, kernel_sum, groups=self.groups)

        return (stencil_out - center_term) / grid_width