from functools import partialmethod
from typing import Tuple, List, Union, Literal
Number = Union[float, int]
import torch
import torch.nn as nn
import torch.nn.functional as F
# Set warning filter to show each warning only once
import warnings
warnings.filterwarnings("once", category=UserWarning)
from ..layers.embeddings import GridEmbeddingND, GridEmbedding2D
from ..layers.spectral_convolution import SpectralConv
from ..layers.padding import DomainPadding
from ..layers.fno_block import FNOBlocks
from ..layers.channel_mlp import ChannelMLP
from ..layers.complex import ComplexValued
from .base_model import BaseModel
class FNO(BaseModel, name="FNO"):
"""N-Dimensional Fourier Neural Operator. The FNO learns a mapping between
spaces of functions discretized over regular grids using Fourier convolutions,
as described in [1]_.
The key component of an FNO is its SpectralConv layer (see
``neuralop.layers.spectral_convolution``), which is similar to a standard CNN
conv layer but operates in the frequency domain.
For a deeper dive into the FNO architecture, refer to :ref:`fno_intro`.
Main Parameters
---------------
n_modes : Tuple[int, ...]
Number of modes to keep in Fourier Layer, along each dimension.
The dimensionality of the FNO is inferred from len(n_modes).
        n_modes must be large enough to capture the relevant frequency content,
        but smaller than max_resolution // 2 (the Nyquist frequency).
in_channels : int
Number of channels in input function. Determined by the problem.
out_channels : int
Number of channels in output function. Determined by the problem.
hidden_channels : int
Width of the FNO (i.e. number of channels).
This significantly affects the number of parameters of the FNO.
        A good starting point is 64; increase it if more expressivity is needed.
        Since lifting_channel_ratio and projection_channel_ratio are proportional to
        hidden_channels, update them accordingly when changing this value.
n_layers : int, optional
Number of Fourier Layers. Default: 4
    Other Parameters
    ----------------
lifting_channel_ratio : Number, optional
Ratio of lifting channels to hidden_channels.
The number of lifting channels in the lifting block of the FNO is
lifting_channel_ratio * hidden_channels (e.g. default 2 * hidden_channels).
projection_channel_ratio : Number, optional
Ratio of projection channels to hidden_channels.
The number of projection channels in the projection block of the FNO is
projection_channel_ratio * hidden_channels (e.g. default 2 * hidden_channels).
positional_embedding : Union[str, nn.Module], optional
        Positional embedding appended to the channels of the raw input
        before it is passed through the FNO.
Options:
- "grid": Appends a grid positional embedding with default settings to the last channels of raw input.
Assumes the inputs are discretized over a grid with entry [0,0,...] at the origin and side lengths of 1.
        - GridEmbeddingND: Uses this module directly (see :class:`neuralop.layers.embeddings.GridEmbeddingND` for details).
- GridEmbedding2D: Uses this module directly for 2D cases.
- None: Does nothing.
Default: "grid"
non_linearity : nn.Module, optional
Non-Linear activation function module to use. Default: F.gelu
norm : Literal["ada_in", "group_norm", "instance_norm"], optional
Normalization layer to use. Options: "ada_in", "group_norm", "instance_norm", None. Default: None
complex_data : bool, optional
Whether the data is complex-valued. If True, initializes complex-valued modules. Default: False
use_channel_mlp : bool, optional
Whether to use an MLP layer after each FNO block. Default: True
channel_mlp_dropout : float, optional
Dropout parameter for ChannelMLP in FNO Block. Default: 0
channel_mlp_expansion : float, optional
Expansion parameter for ChannelMLP in FNO Block. Default: 0.5
channel_mlp_skip : Literal["linear", "identity", "soft-gating", None], optional
Type of skip connection to use in channel-mixing mlp. Options: "linear", "identity", "soft-gating", None.
Default: "soft-gating"
fno_skip : Literal["linear", "identity", "soft-gating", None], optional
Type of skip connection to use in FNO layers. Options: "linear", "identity", "soft-gating", None.
Default: "linear"
resolution_scaling_factor : Union[Number, List[Number]], optional
        Layer-wise factor by which to scale the domain resolution of the function.
Options:
- None: No scaling
- Single number n: Scales resolution by n at each layer
- List of numbers [n_0, n_1,...]: Scales layer i's resolution by n_i
Default: None
domain_padding : Union[Number, List[Number]], optional
Percentage of padding to use.
Options:
- None: No padding
- Single number: Percentage of padding to use along all dimensions
- List of numbers [p1, p2, ..., pN]: Percentage of padding along each dimension
Default: None
fno_block_precision : str, optional
Precision mode in which to perform spectral convolution.
Options: "full", "half", "mixed". Default: "full". Default: "full"
stabilizer : str, optional
        Whether to use a stabilizer in the FNO block. Options: "tanh", None. Default: None.
        A stabilizer greatly improves performance when `fno_block_precision='mixed'`.
max_n_modes : Tuple[int, ...], optional
        Maximum number of modes to keep in the Fourier domain during training.
        If None, all n_modes are used.
        If a tuple of integers, the number of active modes can be incrementally
        increased during training by updating the model's n_modes property dynamically.
factorization : str, optional
Tensor factorization of the FNO layer weights to use.
Options: "None", "Tucker", "CP", "TT"
Other factorization methods supported by tltorch. Default: None
rank : float, optional
Tensor rank to use in factorization. Default: 1.0
        Set to a float < 1.0 when using a TFNO (i.e., when factorization is not None).
A TFNO with rank 0.1 has roughly 10% of the parameters of a dense FNO.
fixed_rank_modes : bool, optional
        Whether to keep certain modes unfactorized. Default: False
implementation : str, optional
Implementation method for factorized tensors.
Options: "factorized", "reconstructed". Default: "factorized"
decomposition_kwargs : dict, optional
Extra kwargs for tensor decomposition (see `tltorch.FactorizedTensor`). Default: {}
separable : bool, optional
Whether to use a separable spectral convolution. Default: False
preactivation : bool, optional
Whether to compute FNO forward pass with resnet-style preactivation. Default: False
conv_module : nn.Module, optional
Module to use for FNOBlock's convolutions. Default: SpectralConv
Examples
---------
>>> from neuralop.models import FNO
>>> model = FNO(n_modes=(12,12), in_channels=1, out_channels=1, hidden_channels=64)
>>> model
FNO(
(positional_embedding): GridEmbeddingND()
(fno_blocks): FNOBlocks(
(convs): SpectralConv(
(weight): ModuleList(
(0-3): 4 x DenseTensor(shape=torch.Size([64, 64, 12, 7]), rank=None)
)
)
... torch.nn.Module printout truncated ...
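
    The shapes below are an illustrative assumption (a batch of 4 scalar
    functions sampled on a 32x32 grid); by default the operator maps inputs
    to outputs on the same grid:

    >>> import torch
    >>> x = torch.randn(4, 1, 32, 32)
    >>> model(x).shape
    torch.Size([4, 1, 32, 32])
    >>> # lifting/projection widths are ratios of hidden_channels (2 * 64 here):
    >>> model.lifting_channels, model.projection_channels
    (128, 128)
    >>> # n_modes can also be updated dynamically during training:
    >>> model.n_modes = (8, 8)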
References
-----------
    .. [1] Li, Z. et al. "Fourier Neural Operator for Parametric Partial
       Differential Equations". ICLR 2021. https://arxiv.org/pdf/2010.08895
"""
def __init__(
self,
n_modes: Tuple[int, ...],
in_channels: int,
out_channels: int,
hidden_channels: int,
n_layers: int = 4,
lifting_channel_ratio: Number = 2,
projection_channel_ratio: Number = 2,
positional_embedding: Union[str, nn.Module] = "grid",
non_linearity: nn.Module = F.gelu,
        norm: Literal["ada_in", "group_norm", "instance_norm", None] = None,
complex_data: bool = False,
use_channel_mlp: bool = True,
channel_mlp_dropout: float = 0,
channel_mlp_expansion: float = 0.5,
channel_mlp_skip: Literal["linear", "identity", "soft-gating", None] = "soft-gating",
fno_skip: Literal["linear", "identity", "soft-gating", None] = "linear",
resolution_scaling_factor: Union[Number, List[Number]] = None,
domain_padding: Union[Number, List[Number]] = None,
fno_block_precision: str = "full",
        stabilizer: Literal["tanh", None] = None,
max_n_modes: Tuple[int, ...] = None,
factorization: str = None,
rank: float = 1.0,
fixed_rank_modes: bool = False,
implementation: str = "factorized",
decomposition_kwargs: dict = None,
separable: bool = False,
preactivation: bool = False,
conv_module: nn.Module = SpectralConv,
):
if decomposition_kwargs is None:
decomposition_kwargs = {}
super().__init__()
self.n_dim = len(n_modes)
# n_modes is a special property - see the class' property for underlying mechanism
# When updated, change should be reflected in fno blocks
self._n_modes = n_modes
self.hidden_channels = hidden_channels
self.in_channels = in_channels
self.out_channels = out_channels
self.n_layers = n_layers
# init lifting and projection channels using ratios w.r.t hidden channels
self.lifting_channel_ratio = lifting_channel_ratio
self.lifting_channels = int(lifting_channel_ratio * self.hidden_channels)
self.projection_channel_ratio = projection_channel_ratio
self.projection_channels = int(projection_channel_ratio * self.hidden_channels)
self.non_linearity = non_linearity
self.rank = rank
self.factorization = factorization
self.fixed_rank_modes = fixed_rank_modes
self.decomposition_kwargs = decomposition_kwargs
        self.fno_skip = fno_skip
        self.channel_mlp_skip = channel_mlp_skip
self.implementation = implementation
self.separable = separable
self.preactivation = preactivation
self.complex_data = complex_data
self.fno_block_precision = fno_block_precision
## Positional embedding
if positional_embedding == "grid":
spatial_grid_boundaries = [[0.0, 1.0]] * self.n_dim
self.positional_embedding = GridEmbeddingND(
in_channels=self.in_channels,
dim=self.n_dim,
grid_boundaries=spatial_grid_boundaries,
)
elif isinstance(positional_embedding, GridEmbedding2D):
if self.n_dim == 2:
self.positional_embedding = positional_embedding
else:
raise ValueError(
f"Error: expected {self.n_dim}-d positional embeddings, got {positional_embedding}"
)
elif isinstance(positional_embedding, GridEmbeddingND):
self.positional_embedding = positional_embedding
elif positional_embedding is None:
self.positional_embedding = None
else:
raise ValueError(
f"Error: tried to instantiate FNO positional embedding with {positional_embedding},\
expected one of 'grid', GridEmbeddingND"
)
## Domain padding
if domain_padding is not None and (
(isinstance(domain_padding, list) and sum(domain_padding) > 0)
or (isinstance(domain_padding, (float, int)) and domain_padding > 0)
):
self.domain_padding = DomainPadding(
domain_padding=domain_padding,
resolution_scaling_factor=resolution_scaling_factor,
)
else:
self.domain_padding = None
## Resolution scaling factor
        if isinstance(resolution_scaling_factor, (float, int)):
            resolution_scaling_factor = [resolution_scaling_factor] * self.n_layers
        self.resolution_scaling_factor = resolution_scaling_factor
## FNO blocks
self.fno_blocks = FNOBlocks(
in_channels=hidden_channels,
out_channels=hidden_channels,
n_modes=self.n_modes,
resolution_scaling_factor=resolution_scaling_factor,
use_channel_mlp=use_channel_mlp,
channel_mlp_dropout=channel_mlp_dropout,
channel_mlp_expansion=channel_mlp_expansion,
non_linearity=non_linearity,
stabilizer=stabilizer,
norm=norm,
preactivation=preactivation,
fno_skip=fno_skip,
channel_mlp_skip=channel_mlp_skip,
complex_data=complex_data,
max_n_modes=max_n_modes,
fno_block_precision=fno_block_precision,
rank=rank,
fixed_rank_modes=fixed_rank_modes,
implementation=implementation,
separable=separable,
factorization=factorization,
decomposition_kwargs=decomposition_kwargs,
conv_module=conv_module,
n_layers=n_layers,
)
## Lifting layer
# if adding a positional embedding, add those channels to lifting
lifting_in_channels = self.in_channels
if self.positional_embedding is not None:
lifting_in_channels += self.n_dim
# if lifting_channels is passed, make lifting a Channel-Mixing MLP
# with a hidden layer of size lifting_channels
if self.lifting_channels:
self.lifting = ChannelMLP(
in_channels=lifting_in_channels,
out_channels=self.hidden_channels,
hidden_channels=self.lifting_channels,
n_layers=2,
n_dim=self.n_dim,
non_linearity=non_linearity,
)
# otherwise, make it a linear layer
else:
self.lifting = ChannelMLP(
in_channels=lifting_in_channels,
hidden_channels=self.hidden_channels,
out_channels=self.hidden_channels,
n_layers=1,
n_dim=self.n_dim,
non_linearity=non_linearity,
)
# Convert lifting to a complex ChannelMLP if self.complex_data==True
if self.complex_data:
self.lifting = ComplexValued(self.lifting)
## Projection layer
self.projection = ChannelMLP(
in_channels=self.hidden_channels,
out_channels=out_channels,
hidden_channels=self.projection_channels,
n_layers=2,
n_dim=self.n_dim,
non_linearity=non_linearity,
)
if self.complex_data:
self.projection = ComplexValued(self.projection)
def forward(self, x, output_shape=None, **kwargs):
"""FNO's forward pass
1. Applies optional positional encoding
2. Sends inputs through a lifting layer to a high-dimensional latent space
3. Applies optional domain padding to high-dimensional intermediate function representation
4. Applies `n_layers` Fourier/FNO layers in sequence (SpectralConvolution + skip connections, nonlinearity)
5. If domain padding was applied, domain padding is removed
6. Projection of intermediate function representation to the output channels
Parameters
----------
x : tensor
input tensor
        output_shape : tuple, list of tuples, or None, default is None
            Gives the option of specifying the exact output shape for odd-shaped inputs.
            * If None, no output shape is specified
            * If a tuple, specifies the output shape of the **last** FNO block
            * If a list of tuples, specifies the exact output shape of each FNO block
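
        Examples
        --------
        A minimal sketch of requesting an exact output shape for an odd-shaped
        input (the shapes here are illustrative assumptions):

        >>> import torch
        >>> model = FNO(n_modes=(12, 12), in_channels=1, out_channels=1,
        ...             hidden_channels=32)
        >>> x = torch.randn(2, 1, 33, 33)  # odd-shaped input grid
        >>> model(x, output_shape=(32, 32)).shape
        torch.Size([2, 1, 32, 32])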
"""
if kwargs:
warnings.warn(
f"FNO.forward() received unexpected keyword arguments: {list(kwargs.keys())}. "
"These arguments will be ignored.",
UserWarning,
stacklevel=2,
)
if output_shape is None:
output_shape = [None] * self.n_layers
elif isinstance(output_shape, tuple):
output_shape = [None] * (self.n_layers - 1) + [output_shape]
# append spatial pos embedding if set
if self.positional_embedding is not None:
x = self.positional_embedding(x)
x = self.lifting(x)
if self.domain_padding is not None:
x = self.domain_padding.pad(x)
for layer_idx in range(self.n_layers):
x = self.fno_blocks(x, layer_idx, output_shape=output_shape[layer_idx])
if self.domain_padding is not None:
x = self.domain_padding.unpad(x)
x = self.projection(x)
return x
@property
def n_modes(self):
return self._n_modes
@n_modes.setter
def n_modes(self, n_modes):
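        """Propagate the updated mode count to all FNO blocks, so the number
        of retained Fourier modes can be changed dynamically during training
        (see `max_n_modes` in the class docstring)."""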
self.fno_blocks.n_modes = n_modes
self._n_modes = n_modes
def partialclass(new_name, cls, *args, **kwargs):
"""Create a new class with different default values
See the Spherical FNO class in neuralop/models/sfno.py for an example.
Notes
-----
An obvious alternative would be to use functools.partial
>>> new_class = partial(cls, **kwargs)
The issue is twofold:
1. the class doesn't have a name, so one would have to set it explicitly:
>>> new_class.__name__ = new_name
2. the new class will be a functools object and one cannot inherit from it.
Instead, here, we define dynamically a new class, inheriting from the existing one.
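
    Examples
    --------
    A small sketch; `FNO8` is a hypothetical name chosen for illustration:

    >>> FNO8 = partialclass("FNO8", FNO, n_layers=8)
    >>> FNO8.__name__
    'FNO8'
    >>> issubclass(FNO8, FNO)
    True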
"""
__init__ = partialmethod(cls.__init__, *args, **kwargs)
return type(
new_name,
(cls,),
{
"__init__": __init__,
"__doc__": cls.__doc__,
"forward": cls.forward,
},
)
class TFNO(FNO):
"""Tucker Tensorized Fourier Neural Operator (TFNO).
TFNO is an FNO with Tucker factorization enabled by default.
It uses Tucker factorization of the weights, making the forward pass efficient by contracting
directly with the factors of the decomposition.
This results in a fraction of the parameters of an equivalent dense FNO.
Parameters
----------
factorization : str, optional
Tensor factorization method, by default "Tucker"
rank : float, optional
Tensor rank for factorization, by default 0.1.
A TFNO with rank 0.1 has roughly 10% of the parameters of a dense FNO.
All other parameters are inherited from FNO with identical defaults.
See FNO class docstring for the complete parameter list.
Examples
--------
>>> from neuralop.models import TFNO
>>> # Create a TFNO model with default Tucker factorization
>>> model = TFNO(n_modes=(12, 12), in_channels=1, out_channels=1, hidden_channels=64)
>>>
>>> # Equivalent FNO model with explicit factorization:
>>> model = FNO(n_modes=(12, 12), in_channels=1, out_channels=1, hidden_channels=64,
... factorization="Tucker", rank=0.1)
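    >>>
    >>> # Defaults are injected with kwargs.setdefault, so both can still be
    >>> # overridden explicitly (a hypothetical CP configuration):
    >>> model = TFNO(n_modes=(12, 12), in_channels=1, out_channels=1,
    ...              hidden_channels=64, factorization="CP", rank=0.05)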
"""
def __init__(self, *args, **kwargs):
kwargs.setdefault("factorization", "Tucker")
kwargs.setdefault("rank", 0.1)
super().__init__(*args, **kwargs)