Source code for neuralop.layers.differential_conv

import torch
import torch.nn as nn
import torch.nn.functional as F

class FiniteDifferenceConvolution(nn.Module):
    """Finite Difference Convolution Layer introduced in [1]_.

    "Neural Operators with Localized Integral and Differential Kernels"
    (ICML 2024) https://arxiv.org/abs/2402.16845

    Computes a finite difference convolution on a regular grid,
    which converges to a directional derivative as the grid is refined.

    Parameters
    ----------
    in_channels : int
        number of input channels
    out_channels : int
        number of output channels
    n_dim : int
        number of dimensions in the input domain; selects
        ``torch.nn.Conv{n_dim}d`` and ``torch.nn.functional.conv{n_dim}d``
    kernel_size : int, optional
        odd kernel size used for the convolutional finite difference
        stencil, by default 3
    groups : int, optional
        number of groups the channels are split into (forwarded to the
        underlying convolution), by default 1
    padding : {'periodic', 'replicate', 'reflect', 'zeros'}, optional
        mode of padding to use on the input, by default 'periodic'.
        'periodic' is translated to torch's 'circular' padding mode.
        See `torch.nn.functional.pad`.

    Raises
    ------
    AssertionError
        if ``kernel_size`` is even
    NotImplementedError
        if ``padding`` is not one of the supported modes

    References
    ----------
    .. [1] Liu-Schiaffini, M., et al. (2024). "Neural Operators with
       Localized Integral and Differential Kernels". ICML 2024,
       https://arxiv.org/abs/2402.16845.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        n_dim,
        kernel_size=3,
        groups=1,
        padding='periodic',
    ):
        super().__init__()

        # Resolve the dimension-specific module class and functional conv.
        conv_module = getattr(nn, f"Conv{n_dim}d")
        self.conv_function = getattr(F, f"conv{n_dim}d")
        # Alias kept so existing users of ``F_conv_module`` keep working.
        self.F_conv_module = self.conv_function

        # NOTE: kept as an assert so callers still see AssertionError.
        assert kernel_size % 2 == 1, "Kernel size should be odd"
        self.kernel_size = kernel_size
        self.in_channels = in_channels
        self.groups = groups
        self.n_dim = n_dim

        # Translate the user-facing padding name into torch's padding_mode.
        if padding == 'periodic':
            self.padding_mode = 'circular'
        elif padding in ('replicate', 'reflect', 'zeros'):
            self.padding_mode = padding
        else:
            raise NotImplementedError("Desired padding mode is not currently supported")
        self.pad_size = kernel_size // 2

        # padding='same' keeps the spatial shape of the input unchanged.
        self.conv = conv_module(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding='same',
            padding_mode=self.padding_mode,
            bias=False,
            groups=groups,
        )
        # Expose the stencil weights directly on this module as well.
        self.weight = self.conv.weight

    def forward(self, x, grid_width):
        """FiniteDifferenceConvolution's forward pass.

        Computes ``(conv(x) - conv_1x1(x, sum(kernel))) / grid_width``,
        where the second term convolves ``x`` with the kernel summed over
        its spatial dimensions.

        Alternatively, one could center the conv kernel by subtracting the
        mean pointwise in the kernel:
        ``conv(x, kernel - mean(kernel)) / grid_width``

        Parameters
        ----------
        x : torch.tensor
            input tensor, shape (batch, in_channels, d_1, d_2, ...d_n)
        grid_width : float
            discretization size of input grid

        Returns
        -------
        torch.tensor
            the finite difference convolution of ``x``
        """
        stencil_out = self.conv(x)

        # Collapse every spatial axis of the kernel (axes 2 .. 2 + n_dim - 1)
        # into a single per-(out, in) coefficient, kept as a size-1 kernel.
        spatial_dims = tuple(range(2, 2 + self.n_dim))
        kernel_sum = torch.sum(self.conv.weight, dim=spatial_dims, keepdim=True)

        # Pointwise (size-1 kernel) convolution of x with the summed weights.
        pointwise_out = self.conv_function(x, kernel_sum, groups=self.groups)

        return (stencil_out - pointwise_out) / grid_width