from typing import List, Optional, Union
import warnings
import torch
from torch import nn
import torch.nn.functional as F
from .channel_mlp import ChannelMLP
from .fno_block import SubModule
from .differential_conv import FiniteDifferenceConvolution
from .discrete_continuous_convolution import EquidistantDiscreteContinuousConv2d
from .normalization_layers import AdaIN, InstanceNorm
from .skip_connections import skip_connection
from .spectral_convolution import SpectralConv
from ..utils import validate_scaling_factor
Number = Union[int, float]
class LocalNOBlocks(nn.Module):
"""Local Neural Operator blocks with localized integral and differential kernels.
It is implemented as described in [3]_.
This class implements neural operator blocks that combine Fourier neural operators
with localized integral and differential kernels to capture both global and local
features in PDE solutions [3]_. The architecture addresses the over-smoothing limitations
of purely global FNOs while maintaining resolution-independence through principled
local operations.
The key innovation is the integration of two types of local operations:
1. Differential kernels: Learn finite difference stencils that converge to
differential operators under appropriate scaling.
2. Local integral kernels: Use discrete-continuous convolutions with locally
supported kernels to capture local interactions.
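
As an illustrative (not implementation-specific) 1D instance of the first point,
the zero-sum stencil :math:`(-\tfrac{1}{2},\, 0,\, \tfrac{1}{2})`, rescaled by the
grid spacing :math:`h`, computes the centered difference

.. math:: \frac{u(x + h) - u(x - h)}{2h} \;\longrightarrow\; u'(x) \quad \text{as } h \to 0,

so learned stencils with this scaling behave like differential operators
independently of the sampling resolution.
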
Parameters
----------
in_channels : int
Number of input channels to Fourier layers
out_channels : int
Number of output channels after Fourier layers
n_modes : int, List[int]
Number of modes to keep along each dimension in frequency space.
Can either be specified as an int (for all dimensions)
or an iterable with one number per dimension
default_in_shape : Tuple[int]
Default input shape for spatiotemporal dimensions
resolution_scaling_factor : Optional[Union[Number, List[Number]]], optional
Factor by which to scale outputs for super-resolution, by default None
n_layers : int, optional
Number of neural operator layers to apply in sequence, by default 1
disco_layers : bool or List[bool], optional
Whether to include local integral kernel connections at each layer.
If a single bool, applies to all layers. If a list, must match n_layers.
disco_kernel_shape : Union[int, List[int]], optional
Kernel shape for local integral operations. Single int for isotropic kernels,
two ints for anisotropic kernels, by default [2,4]
domain_length : torch.Tensor, optional
Physical domain extent/length. Assumes square domain [-1, 1]^2 by default
disco_groups : int, optional
Number of groups in local integral convolution, by default 1
disco_bias : bool, optional
Whether to use bias for integral kernel, by default True
radius_cutoff : float, optional
Cutoff radius (relative to domain_length) for local integral kernel, by default None
diff_layers : bool or List[bool], optional
Whether to include differential kernel connections at each layer.
If a single bool, applies to all layers. If a list, must match n_layers.
conv_padding_mode : str, optional
Padding mode for spatial convolution kernels.
Options: 'periodic', 'circular', 'replicate', 'reflect', 'zeros'. By default 'periodic'
fin_diff_kernel_size : int, optional
Kernel size for finite difference convolution (must be odd), by default 3
mix_derivatives : bool, optional
Whether to mix derivatives across channels, by default True
max_n_modes : int or List[int], optional
Maximum number of modes to keep along each dimension, by default None
local_no_block_precision : str, optional
Floating point precision for computations, by default "full"
use_channel_mlp : bool, optional
Whether to use MLP layer after each block, by default False
channel_mlp_dropout : float, optional
Dropout probability for the channel MLP, by default 0
channel_mlp_expansion : float, optional
Expansion factor for channel MLP, by default 0.5
non_linearity : callable, optional
Nonlinear activation function between layers, by default F.gelu
stabilizer : Literal["tanh"], optional
Stabilizing module between layers. Options: "tanh". By default None
norm : Literal["ada_in", "group_norm", "instance_norm"], optional
Normalization layer to use, by default None
ada_in_features : int, optional
Number of features for adaptive instance normalization, by default None
preactivation : bool, optional
Whether to call forward pass with pre-activation, by default False
if True, call nonlinear activation and norm before Fourier convolution
if False, call activation and norms after Fourier convolutions
local_no_skip : str, optional
Module to use for Local NO skip connections, by default "linear"
Options: "linear", "identity", "soft-gating", None.
If None, no skip connection is added. See layers.skip_connections for more details
channel_mlp_skip : str, optional
Module to use for ChannelMLP skip connections, by default "soft-gating"
Options: "linear", "identity", "soft-gating", None.
If None, no skip connection is added. See layers.skip_connections for more details

Other Parameters
----------------
complex_data : bool, optional
Whether the data takes complex values in space, by default False
separable : bool, optional
Separable parameter for SpectralConv, by default False
factorization : str, optional
Tensor factorization to use for the SpectralConv weights
(e.g. "tucker", "cp", "tt"), by default None. If None, dense weights are used.
rank : float, optional
Rank parameter for SpectralConv, by default 1.0
conv_module : BaseConv, optional
Convolution module for Local NO block, by default SpectralConv
joint_factorization : bool, optional
Whether to factorize all SpectralConv weights as one tensor, by default False
fixed_rank_modes : bool, optional
Fixed rank modes parameter for SpectralConv, by default False
implementation : str, optional
Implementation method for SpectralConv, by default "factorized".
Options: "factorized", "reconstructed".
decomposition_kwargs : dict, optional
Keyword arguments for tensor decomposition in SpectralConv, by default dict()
Notes
-----
- Differential kernels are only implemented for dimensions ≤ 3
- Local integral kernels are only implemented for 2D domains
References
----------
.. [1] Li, Z. et al. "Fourier Neural Operator for Parametric Partial Differential
Equations" (2021). ICLR 2021, https://arxiv.org/pdf/2010.08895.
.. [2] Kossaifi, J., Kovachki, N., Azizzadenesheli, K., Anandkumar, A. "Multi-Grid
Tensorized Fourier Neural Operator for High-Resolution PDEs" (2024).
TMLR 2024, https://openreview.net/pdf?id=AWiDlO63bH.
.. [3] Liu-Schiaffini M., Berner J., Bonev B., Kurth T., Azizzadenesheli K., Anandkumar A.;
"Neural Operators with Localized Integral and Differential Kernels" (2024).
ICML 2024, https://arxiv.org/pdf/2402.16845.
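
Examples
--------
A minimal usage sketch. The batch size, channel count, modes, and grid shape below
are arbitrary choices for illustration, and the import path assumes this module is
importable as ``neuralop.layers.local_no_block``::

    import torch
    from neuralop.layers.local_no_block import LocalNOBlocks

    block = LocalNOBlocks(
        in_channels=32,
        out_channels=32,
        n_modes=(16, 16),           # 2D problem: one mode count per dimension
        default_in_shape=(64, 64),  # resolution the local kernels are built for
        n_layers=2,
        diff_layers=True,           # differential kernel in every layer
        disco_layers=True,          # local integral kernel in every layer
    )

    x = torch.randn(4, 32, 64, 64)  # (batch, channels, height, width)
    for i in range(block.n_layers):
        x = block(x, index=i)       # apply the i-th Local NO block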
"""
def __init__(
self,
in_channels,
out_channels,
n_modes,
default_in_shape,
resolution_scaling_factor=None,
n_layers=1,
disco_layers=True,
disco_kernel_shape=[2, 4],
radius_cutoff=None,
domain_length=[2, 2],
disco_groups=1,
disco_bias=True,
diff_layers=True,
conv_padding_mode="periodic",
fin_diff_kernel_size=3,
mix_derivatives=True,
max_n_modes=None,
local_no_block_precision="full",
use_channel_mlp=False,
channel_mlp_dropout=0,
channel_mlp_expansion=0.5,
non_linearity=F.gelu,
stabilizer=None,
norm=None,
ada_in_features=None,
preactivation=False,
local_no_skip="linear",
channel_mlp_skip="soft-gating",
separable=False,
factorization=None,
rank=1.0,
conv_module=SpectralConv,
fixed_rank_modes=False,
implementation="factorized",
decomposition_kwargs=dict(),
fft_norm="forward",
):
super().__init__()
if isinstance(n_modes, int):
n_modes = [n_modes]
self._n_modes = n_modes
assert len(n_modes) == len(default_in_shape), "n_modes and default_in_shape must specify the same number of spatiotemporal dimensions"
# If a single bool is passed for disco_layers or diff_layers, set values for all layers
if isinstance(disco_layers, bool):
disco_layers = [disco_layers] * n_layers
if isinstance(diff_layers, bool):
diff_layers = [diff_layers] * n_layers
if len(n_modes) > 3 and True in diff_layers:
raise NotImplementedError("Differential convs not implemented for dimensions higher than 3.")
if len(n_modes) != 2 and True in disco_layers:
raise NotImplementedError("Local conv layers only implemented for dimension 2.")
if conv_padding_mode not in ['circular', 'periodic', 'zeros'] and True in disco_layers:
warnings.warn("Local conv layers only support periodic or zero padding, defaulting to zero padding for local convs.")
self.n_dim = len(n_modes)
self.resolution_scaling_factor: Union[
None, List[List[float]]
] = validate_scaling_factor(resolution_scaling_factor, self.n_dim, n_layers)
self.max_n_modes = max_n_modes
self.local_no_block_precision = local_no_block_precision
self.in_channels = in_channels
self.out_channels = out_channels
self.n_layers = n_layers
self.non_linearity = non_linearity
self.stabilizer = stabilizer
self.rank = rank
self.factorization = factorization
self.fixed_rank_modes = fixed_rank_modes
self.decomposition_kwargs = decomposition_kwargs
self.local_no_skip = local_no_skip
self.channel_mlp_skip = channel_mlp_skip
self.use_channel_mlp = use_channel_mlp
self.channel_mlp_expansion = channel_mlp_expansion
self.channel_mlp_dropout = channel_mlp_dropout
self.fft_norm = fft_norm
self.implementation = implementation
self.separable = separable
self.preactivation = preactivation
self.ada_in_features = ada_in_features
self.diff_layers = diff_layers
self.conv_padding_mode = conv_padding_mode
self.default_in_shape = default_in_shape
self.fin_diff_kernel_size = fin_diff_kernel_size
self.mix_derivatives = mix_derivatives
self.disco_layers = disco_layers
self.disco_kernel_shape = disco_kernel_shape
self.radius_cutoff = radius_cutoff
self.domain_length = domain_length
self.disco_groups = disco_groups
self.disco_bias = disco_bias
self.periodic = self.conv_padding_mode in ["circular", "periodic"]
assert len(diff_layers) == n_layers, (
f"diff_layers must either be a single bool or a list of bools of length n_layers, got {len(diff_layers)=}"
)
assert len(disco_layers) == n_layers, (
f"disco_layers must either be a single bool or a list of bools of length n_layers, got {len(disco_layers)=}"
)
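# One (global) convolution module per layer: by default a SpectralConv, which forms
# the Fourier/global branch of each Local NO block.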
self.convs = nn.ModuleList(
[
conv_module(
self.in_channels,
self.out_channels,
self.n_modes,
resolution_scaling_factor=None if resolution_scaling_factor is None else self.resolution_scaling_factor[i],
max_n_modes=max_n_modes,
rank=rank,
fixed_rank_modes=fixed_rank_modes,
implementation=implementation,
separable=separable,
factorization=factorization,
decomposition_kwargs=decomposition_kwargs,
)
for i in range(n_layers)
]
)
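# Optional skip connections applied around the convolution path of each layer.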
if local_no_skip is not None:
self.local_no_skips = nn.ModuleList(
[
skip_connection(
self.in_channels,
self.out_channels,
skip_type=local_no_skip,
n_dim=self.n_dim,
)
for _ in range(n_layers)
]
)
else:
self.local_no_skips = None
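# Differential branch: one finite-difference convolution per layer flagged in
# diff_layers. With mix_derivatives=False, a grouped (depthwise) convolution keeps
# each channel's derivatives separate.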
self.diff_groups = 1 if mix_derivatives else in_channels
self.differential = nn.ModuleList(
[
FiniteDifferenceConvolution(
self.in_channels,
self.out_channels,
self.n_dim,
self.fin_diff_kernel_size,
self.diff_groups,
self.conv_padding_mode,
)
for _ in range(sum(self.diff_layers))
]
)
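# Local integral branch: one discrete-continuous (DISCO) convolution per layer
# flagged in disco_layers, constructed for inputs at default_in_shape (2D only).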
self.local_convs = nn.ModuleList(
[
EquidistantDiscreteContinuousConv2d(
self.in_channels,
self.out_channels,
in_shape=self.default_in_shape,
out_shape=self.default_in_shape,
kernel_shape=self.disco_kernel_shape,
domain_length=self.domain_length,
radius_cutoff=self.radius_cutoff,
periodic=self.periodic,
groups=self.disco_groups,
bias=self.disco_bias,
)
for _ in range(sum(self.disco_layers))
]
)
# Map layer index -> position in self.differential (-1 if that layer has no differential kernel)
self.differential_idx_list = []
j = 0
for i in range(n_layers):
if self.diff_layers[i]:
self.differential_idx_list.append(j)
j += 1
else:
self.differential_idx_list.append(-1)
assert max(self.differential_idx_list) == sum(self.diff_layers) - 1
# Map layer index -> position in self.local_convs (-1 if that layer has no local integral kernel)
self.disco_idx_list = []
j = 0
for i in range(n_layers):
if self.disco_layers[i]:
self.disco_idx_list.append(j)
j += 1
else:
self.disco_idx_list.append(-1)
assert max(self.disco_idx_list) == sum(self.disco_layers) - 1
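# Optional per-layer channel MLP (pointwise), with its own skip connection.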
if use_channel_mlp:
self.mlp = nn.ModuleList(
[
ChannelMLP(
in_channels=self.out_channels,
hidden_channels=round(self.out_channels * channel_mlp_expansion),
dropout=channel_mlp_dropout,
n_dim=self.n_dim,
)
for _ in range(n_layers)
]
)
if channel_mlp_skip is not None:
self.channel_mlp_skips = nn.ModuleList(
[
skip_connection(
self.in_channels,
self.out_channels,
skip_type=channel_mlp_skip,
n_dim=self.n_dim,
)
for _ in range(n_layers)
]
)
else:
self.channel_mlp_skips = None
else:
self.mlp = None
# Each block will have 2 norms if we also use an MLP
self.n_norms = 1 if self.mlp is None else 2
if norm is None:
self.norm = None
elif norm == "instance_norm":
self.norm = nn.ModuleList(
[InstanceNorm() for _ in range(n_layers * self.n_norms)]
)
elif norm == "group_norm":
self.norm = nn.ModuleList(
[
nn.GroupNorm(num_groups=1, num_channels=self.out_channels)
for _ in range(n_layers * self.n_norms)
]
)
elif norm == "ada_in":
self.norm = nn.ModuleList(
[
AdaIN(ada_in_features, out_channels)
for _ in range(n_layers * self.n_norms)
]
)
else:
raise ValueError(
f"Got norm={norm} but expected None or one of "
"[instance_norm, group_norm, ada_in]"
)
def set_ada_in_embeddings(self, *embeddings):
"""Sets the embeddings of each Ada-IN norm layers
Parameters
----------
embeddings : tensor or list of tensor
if a single embedding is given, it will be used for each norm layer
otherwise, each embedding will be used for the corresponding norm layer
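
Examples
--------
Illustrative call, assuming ``blocks`` was built with ``norm="ada_in"`` and
``ada_in_features=64`` (both values hypothetical)::

    emb = torch.randn(64)              # size must match ada_in_features
    blocks.set_ada_in_embeddings(emb)  # same embedding reused for every AdaIN layer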
"""
if len(embeddings) == 1:
for norm in self.norm:
norm.set_embedding(embeddings[0])
else:
for norm, embedding in zip(self.norm, embeddings):
norm.set_embedding(embedding)
def forward(self, x, index=0, output_shape=None):
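"""Apply the ``index``-th Local NO block to ``x``.

Parameters
----------
x : torch.Tensor
input activations of shape (batch, channels, d_1, ..., d_n)
index : int, optional
index of the block (layer) to apply, by default 0
output_shape : tuple, optional
requested spatial shape of the block output, by default None
"""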
if self.preactivation:
return self.forward_with_preactivation(x, index, output_shape)
else:
return self.forward_with_postactivation(x, index, output_shape)
def forward_with_postactivation(self, x, index=0, output_shape=None):
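# Post-activation ordering: compute the skip branches from the input, then sum the
# spectral, differential, and local integral branches, normalize, add the skip,
# and finally apply the nonlinearity and (optionally) the channel MLP.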
if self.local_no_skips is not None:
x_skip_local_no = self.local_no_skips[index](x)
x_skip_local_no = self.convs[index].transform(x_skip_local_no, output_shape=output_shape)
if self.mlp is not None and self.channel_mlp_skips is not None:
x_skip_mlp = self.channel_mlp_skips[index](x)
x_skip_mlp = self.convs[index].transform(x_skip_mlp, output_shape=output_shape)
if self.stabilizer == "tanh":
x = torch.tanh(x)
x_local_no = self.convs[index](x, output_shape=output_shape)
if self.differential_idx_list[index] != -1:
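# Ratio of default to current resolution (proportional to the grid spacing),
# used to rescale the finite-difference kernel.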
grid_width_scaling_factor = 1 / (x.shape[-1] / self.default_in_shape[0])
x_differential = self.differential[self.differential_idx_list[index]](x, grid_width_scaling_factor)
x_differential = self.convs[index].transform(x_differential, output_shape=output_shape)
else:
x_differential = 0
if self.disco_idx_list[index] != -1:
x_localconv = self.local_convs[self.disco_idx_list[index]](x)
x_localconv = self.convs[index].transform(x_localconv, output_shape=output_shape)
else:
x_localconv = 0
x_local_no_diff_disco = x_local_no + x_differential + x_localconv
if self.norm is not None:
x_local_no_diff_disco = self.norm[self.n_norms * index](x_local_no_diff_disco)
x = (
x_local_no_diff_disco + x_skip_local_no
if self.local_no_skips is not None
else x_local_no_diff_disco
)
if (self.mlp is not None) or (index < (self.n_layers - 1)):
x = self.non_linearity(x)
if self.mlp is not None:
if self.channel_mlp_skips is not None:
x = self.mlp[index](x) + x_skip_mlp
else:
x = self.mlp[index](x)
if self.norm is not None:
x = self.norm[self.n_norms * index + 1](x)
if index < (self.n_layers - 1):
x = self.non_linearity(x)
return x
def forward_with_preactivation(self, x, index=0, output_shape=None):
# Apply non-linear activation (and norm)
# before this block's convolution/forward pass:
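# The differential, local integral, and optional skip branches are then transformed
# to the output shape and added to the output of the spectral convolution.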
x = self.non_linearity(x)
if self.norm is not None:
x = self.norm[self.n_norms * index](x)
if self.differential_idx_list[index] != -1:
grid_width_scaling_factor = 1 / (x.shape[-1] / self.default_in_shape[0])
x_differential = self.differential[self.differential_idx_list[index]](x, grid_width_scaling_factor)
else:
x_differential = 0
if self.disco_idx_list[index] != -1:
x_localconv = self.local_convs[self.disco_idx_list[index]](x)
else:
x_localconv = 0
if self.local_no_skips is not None:
x_skip_local_no = self.local_no_skips[index](x)
x_skip_local_local_no = self.convs[index].transform(x_skip_local_no + x_differential + x_localconv, output_shape=output_shape)
else:
x_skip_local_local_no = self.convs[index].transform(x_differential + x_localconv, output_shape=output_shape)
if self.mlp is not None:
if self.channel_mlp_skips is not None:
x_skip_mlp = self.channel_mlp_skips[index](x)
x_skip_mlp = self.convs[index].transform(x_skip_mlp, output_shape=output_shape)
if self.stabilizer == "tanh":
x = torch.tanh(x)
x_local_no = self.convs[index](x, output_shape=output_shape)
x = (
x_local_no + x_skip_local_local_no
if self.local_no_skips is not None
else x_local_no
)
if self.mlp is not None:
if index < (self.n_layers - 1):
x = self.non_linearity(x)
if self.norm is not None:
x = self.norm[self.n_norms * index + 1](x)
if self.channel_mlp_skips is not None:
x = self.mlp[index](x) + x_skip_mlp
else:
x = self.mlp[index](x)
return x
@property
def n_modes(self):
return self._n_modes
@n_modes.setter
def n_modes(self, n_modes):
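"""Update the number of retained Fourier modes in every convolution layer."""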
for i in range(self.n_layers):
self.convs[i].n_modes = n_modes
self._n_modes = n_modes
def get_block(self, indices):
"""Returns a sub-NO Block layer from the jointly parametrized main block
The parametrization of an LocalNOBlock layer is shared with the main one.
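
Examples
--------
Illustrative indexing, assuming ``blocks`` is an instance built with ``n_layers > 1``::

    sub_block = blocks.get_block(0)  # or equivalently blocks[0]
    y = sub_block(x)                 # runs layer 0 of the shared parametrization on x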
"""
if self.n_layers == 1:
raise ValueError(
"A single layer is parametrized, directly use the main class."
)
return SubModule(self, indices)
def __getitem__(self, indices):
return self.get_block(indices)