from functools import partialmethod
from typing import Tuple, List, Union
Number = Union[float, int]
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..layers.embeddings import GridEmbeddingND, GridEmbedding2D
from ..layers.spectral_convolution import SpectralConv
from ..layers.padding import DomainPadding
from ..layers.fno_block import FNOBlocks
from ..layers.channel_mlp import ChannelMLP
from ..layers.complex import ComplexValued
from .base_model import BaseModel
class FNO(BaseModel, name='FNO'):
    """N-Dimensional Fourier Neural Operator. The FNO learns a mapping between
    spaces of functions discretized over regular grids using Fourier convolutions,
    as described in [1]_.

    The key component of an FNO is its SpectralConv layer (see
    ``neuralop.layers.spectral_convolution``), which is similar to a standard CNN
    conv layer but operates in the frequency domain.

    For a deeper dive into the FNO architecture, refer to :ref:`fno_intro`.

    Parameters
    ----------
    n_modes : Tuple[int]
        number of modes to keep in Fourier Layer, along each dimension.
        The dimensionality of the FNO is inferred from ``len(n_modes)``
    in_channels : int
        Number of channels in input function
    out_channels : int
        Number of channels in output function
    hidden_channels : int
        width of the FNO (i.e. number of channels); required, no default
    n_layers : int, optional
        Number of Fourier Layers, by default 4

    Documentation for more advanced parameters is below.

    Other parameters
    ------------------
    lifting_channel_ratio : int, optional
        ratio of lifting channels to hidden_channels, by default 2.
        The number of lifting channels in the lifting block of the FNO is
        lifting_channel_ratio * hidden_channels (e.g. 512 when hidden_channels=256).
        A ratio of 0 makes the lifting block a single-layer ChannelMLP.
    projection_channel_ratio : int, optional
        ratio of projection channels to hidden_channels, by default 2.
        The number of projection channels in the projection block of the FNO is
        projection_channel_ratio * hidden_channels (e.g. 512 when hidden_channels=256)
    positional_embedding : Union[str, nn.Module], optional
        Positional embedding to apply to last channels of raw input
        before being passed through the FNO. Defaults to "grid"

        * If "grid", appends a grid positional embedding with default settings to
          the last channels of raw input. Assumes the inputs are discretized
          over a grid with entry [0,0,...] at the origin and side lengths of 1.
        * If an initialized GridEmbeddingND module, uses this module directly.
          A GridEmbedding2D module is accepted only when ``len(n_modes) == 2``.
          See :class:`neuralop.layers.embeddings.GridEmbeddingND` for details.
        * If None, does nothing

        When an embedding is used, the lifting block expects
        ``in_channels + len(n_modes)`` input channels.
    non_linearity : callable, optional
        Non-Linear activation function to use, by default F.gelu
    norm : str {"ada_in", "group_norm", "instance_norm"}, optional
        Normalization layer to use, by default None
    complex_data : bool, optional
        Whether data is complex-valued (default False).
        If True, initializes complex-valued modules.
    channel_mlp_dropout : float, optional
        dropout parameter for ChannelMLP in FNO Block, by default 0
    channel_mlp_expansion : float, optional
        expansion parameter for ChannelMLP in FNO Block, by default 0.5
    channel_mlp_skip : str {'linear', 'identity', 'soft-gating'}, optional
        Type of skip connection to use in channel-mixing mlp, by default 'soft-gating'
    fno_skip : str {'linear', 'identity', 'soft-gating'}, optional
        Type of skip connection to use in FNO layers, by default 'linear'
    resolution_scaling_factor : Union[Number, List[Number]], optional
        layer-wise factor by which to scale the domain resolution of function, by default None

        * If a single number n, scales resolution by n at each layer
        * if a list of numbers [n_0, n_1,...] scales layer i's resolution by n_i.
    domain_padding : Union[Number, List[Number]], optional
        If not None, percentage of padding to use, by default None.
        To vary the percentage of padding used along each input dimension,
        pass in a list (or tuple) of percentages e.g. [p1, p2, ..., pN] such that
        p1 corresponds to the percentage of padding along dim 1, etc.
    domain_padding_mode : str {'symmetric', 'one-sided'}, optional
        How to perform domain padding, by default 'one-sided'
    fno_block_precision : str {'full', 'half', 'mixed'}, optional
        precision mode in which to perform spectral convolution, by default "full"
    stabilizer : str {'tanh'} | None, optional
        whether to use a tanh stabilizer in FNO block, by default None.
        Note: stabilizer greatly improves performance in the case
        `fno_block_precision='mixed'`.
    max_n_modes : Tuple[int] | None, optional

        * If not None, this allows to incrementally increase the number of
          modes in Fourier domain during training. Has to verify n <= N
          for (n, N) in zip(n_modes, max_n_modes).
        * If None, all the n_modes are used.

        This can be updated dynamically during training.
    factorization : str, optional
        Tensor factorization of the FNO layer weights to use, by default None.

        * If None, a dense tensor parametrizes the Spectral convolutions
        * Otherwise, the specified tensor factorization is used.
    rank : float, optional
        tensor rank to use in above factorization, by default 1.0
    fixed_rank_modes : bool, optional
        Modes to not factorize, by default False
    implementation : str {'factorized', 'reconstructed'}, optional

        * If 'factorized', implements tensor contraction with the individual factors of the decomposition
        * If 'reconstructed', implements with the reconstructed full tensorized weight.
    decomposition_kwargs : dict, optional
        extra kwargs for tensor decomposition (see `tltorch.FactorizedTensor`),
        by default None (treated as an empty dict)
    separable : bool, optional (**DEACTIVATED**)
        if True, use a depthwise separable spectral convolution, by default False
    preactivation : bool, optional (**DEACTIVATED**)
        whether to compute FNO forward pass with resnet-style preactivation, by default False
    conv_module : nn.Module, optional
        module to use for FNOBlock's convolutions, by default SpectralConv

    Examples
    ---------

    >>> from neuralop.models import FNO
    >>> model = FNO(n_modes=(12,12), in_channels=1, out_channels=1, hidden_channels=64)
    >>> model
    FNO(
      (positional_embedding): GridEmbeddingND()
      (fno_blocks): FNOBlocks(
        (convs): SpectralConv(
          (weight): ModuleList(
            (0-3): 4 x DenseTensor(shape=torch.Size([64, 64, 12, 7]), rank=None)
          )
        )
        ... torch.nn.Module printout truncated ...

    References
    -----------
    .. [1] Li, Z. et al. "Fourier Neural Operator for Parametric Partial Differential
       Equations" (2021). ICLR 2021, https://arxiv.org/pdf/2010.08895.
    """

    def __init__(
        self,
        n_modes: Tuple[int],
        in_channels: int,
        out_channels: int,
        hidden_channels: int,
        n_layers: int=4,
        lifting_channel_ratio: int=2,
        projection_channel_ratio: int=2,
        positional_embedding: Union[str, nn.Module]="grid",
        non_linearity: nn.Module=F.gelu,
        norm: str=None,
        complex_data: bool=False,
        channel_mlp_dropout: float=0,
        channel_mlp_expansion: float=0.5,
        channel_mlp_skip: str="soft-gating",
        fno_skip: str="linear",
        resolution_scaling_factor: Union[Number, List[Number]]=None,
        domain_padding: Union[Number, List[Number]]=None,
        domain_padding_mode: str="one-sided",
        fno_block_precision: str="full",
        stabilizer: str=None,
        max_n_modes: Tuple[int]=None,
        factorization: str=None,
        rank: float=1.0,
        fixed_rank_modes: bool=False,
        implementation: str="factorized",
        decomposition_kwargs: dict=None,
        separable: bool=False,
        preactivation: bool=False,
        conv_module: nn.Module=SpectralConv,
        **kwargs
    ):
        super().__init__()

        # Use a fresh dict per instance; a shared mutable default would be
        # aliased by every model (and every FNOBlocks) built with the default.
        if decomposition_kwargs is None:
            decomposition_kwargs = dict()

        self.n_dim = len(n_modes)

        # n_modes is a special property - see the class' property for the
        # underlying mechanism. When updated, the change is propagated to the
        # fno blocks.
        self._n_modes = n_modes

        self.hidden_channels = hidden_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.n_layers = n_layers

        # init lifting and projection channels using ratios w.r.t hidden channels
        self.lifting_channel_ratio = lifting_channel_ratio
        self.lifting_channels = lifting_channel_ratio * self.hidden_channels

        self.projection_channel_ratio = projection_channel_ratio
        self.projection_channels = projection_channel_ratio * self.hidden_channels

        self.non_linearity = non_linearity
        self.rank = rank
        self.factorization = factorization
        self.fixed_rank_modes = fixed_rank_modes
        self.decomposition_kwargs = decomposition_kwargs
        # Stored as plain strings (previously stray trailing commas wrapped
        # them in 1-tuples).
        self.fno_skip = fno_skip
        self.channel_mlp_skip = channel_mlp_skip
        self.implementation = implementation
        self.separable = separable
        self.preactivation = preactivation
        self.complex_data = complex_data
        self.fno_block_precision = fno_block_precision

        if positional_embedding == "grid":
            spatial_grid_boundaries = [[0., 1.]] * self.n_dim
            self.positional_embedding = GridEmbeddingND(
                in_channels=self.in_channels,
                dim=self.n_dim,
                grid_boundaries=spatial_grid_boundaries,
            )
        elif isinstance(positional_embedding, GridEmbedding2D):
            if self.n_dim == 2:
                self.positional_embedding = positional_embedding
            else:
                raise ValueError(f'Error: expected {self.n_dim}-d positional embeddings, got {positional_embedding}')
        elif isinstance(positional_embedding, GridEmbeddingND):
            self.positional_embedding = positional_embedding
        elif positional_embedding is None:
            self.positional_embedding = None
        else:
            raise ValueError(
                f"Error: tried to instantiate FNO positional embedding with {positional_embedding}, "
                "expected one of 'grid', GridEmbeddingND, GridEmbedding2D (2D only) or None"
            )

        # Treat a tuple of per-dimension percentages exactly like a list, so it
        # reaches DomainPadding instead of silently disabling padding.
        if isinstance(domain_padding, tuple):
            domain_padding = list(domain_padding)

        if domain_padding is not None and (
            (isinstance(domain_padding, list) and sum(domain_padding) > 0)
            or (isinstance(domain_padding, (float, int)) and domain_padding > 0)
        ):
            # NOTE: DomainPadding receives resolution_scaling_factor before the
            # scalar -> per-layer list expansion below (unchanged behavior).
            self.domain_padding = DomainPadding(
                domain_padding=domain_padding,
                padding_mode=domain_padding_mode,
                resolution_scaling_factor=resolution_scaling_factor,
            )
        else:
            self.domain_padding = None

        self.domain_padding_mode = domain_padding_mode

        # A single scaling factor applies to every layer.
        if isinstance(resolution_scaling_factor, (float, int)):
            resolution_scaling_factor = [resolution_scaling_factor] * self.n_layers
        self.resolution_scaling_factor = resolution_scaling_factor

        self.fno_blocks = FNOBlocks(
            in_channels=hidden_channels,
            out_channels=hidden_channels,
            n_modes=self.n_modes,
            resolution_scaling_factor=resolution_scaling_factor,
            channel_mlp_dropout=channel_mlp_dropout,
            channel_mlp_expansion=channel_mlp_expansion,
            non_linearity=non_linearity,
            stabilizer=stabilizer,
            norm=norm,
            preactivation=preactivation,
            fno_skip=fno_skip,
            channel_mlp_skip=channel_mlp_skip,
            complex_data=complex_data,
            max_n_modes=max_n_modes,
            fno_block_precision=fno_block_precision,
            rank=rank,
            fixed_rank_modes=fixed_rank_modes,
            implementation=implementation,
            separable=separable,
            factorization=factorization,
            decomposition_kwargs=decomposition_kwargs,
            conv_module=conv_module,
            n_layers=n_layers,
            **kwargs
        )

        # if adding a positional embedding, add those channels to lifting
        # (assumes the embedding appends one channel per spatial dimension)
        lifting_in_channels = self.in_channels
        if self.positional_embedding is not None:
            lifting_in_channels += self.n_dim

        # if lifting_channels is non-zero, make lifting a Channel-Mixing MLP
        # with a hidden layer of size lifting_channels
        if self.lifting_channels:
            self.lifting = ChannelMLP(
                in_channels=lifting_in_channels,
                out_channels=self.hidden_channels,
                hidden_channels=self.lifting_channels,
                n_layers=2,
                n_dim=self.n_dim,
                non_linearity=non_linearity
            )
        # otherwise, make it a single layer
        else:
            self.lifting = ChannelMLP(
                in_channels=lifting_in_channels,
                hidden_channels=self.hidden_channels,
                out_channels=self.hidden_channels,
                n_layers=1,
                n_dim=self.n_dim,
                non_linearity=non_linearity
            )
        # Convert lifting to a complex ChannelMLP if self.complex_data==True
        if self.complex_data:
            self.lifting = ComplexValued(self.lifting)

        self.projection = ChannelMLP(
            in_channels=self.hidden_channels,
            out_channels=out_channels,
            hidden_channels=self.projection_channels,
            n_layers=2,
            n_dim=self.n_dim,
            non_linearity=non_linearity,
        )
        if self.complex_data:
            self.projection = ComplexValued(self.projection)

    def forward(self, x, output_shape=None, **kwargs):
        """FNO's forward pass

        1. Applies optional positional encoding

        2. Sends inputs through a lifting layer to a high-dimensional latent space

        3. Applies optional domain padding to high-dimensional intermediate function representation

        4. Applies `n_layers` Fourier/FNO layers in sequence (SpectralConvolution + skip connections, nonlinearity)

        5. If domain padding was applied, domain padding is removed

        6. Projection of intermediate function representation to the output channels

        Parameters
        ----------
        x : tensor
            input tensor
        output_shape : {tuple, tuple list, None}, default is None
            Gives the option of specifying the exact output shape for odd shaped inputs.

            * If None, don't specify an output shape
            * If tuple, specifies the output-shape of the **last** FNO Block
            * If tuple list, specifies the exact output-shape of each FNO Block
        **kwargs
            accepted for interface compatibility; unused here
        """
        # Normalize output_shape to one entry per FNO layer.
        if output_shape is None:
            output_shape = [None]*self.n_layers
        elif isinstance(output_shape, tuple):
            output_shape = [None]*(self.n_layers - 1) + [output_shape]

        # append spatial pos embedding if set
        if self.positional_embedding is not None:
            x = self.positional_embedding(x)

        x = self.lifting(x)

        if self.domain_padding is not None:
            x = self.domain_padding.pad(x)

        for layer_idx in range(self.n_layers):
            x = self.fno_blocks(x, layer_idx, output_shape=output_shape[layer_idx])

        if self.domain_padding is not None:
            x = self.domain_padding.unpad(x)

        x = self.projection(x)

        return x

    @property
    def n_modes(self):
        """Number of Fourier modes kept along each dimension."""
        return self._n_modes

    @n_modes.setter
    def n_modes(self, n_modes):
        # Propagate the change to the FNO blocks before recording it locally.
        self.fno_blocks.n_modes = n_modes
        self._n_modes = n_modes
class FNO1d(FNO):
    """1D Fourier Neural Operator

    Convenience wrapper around :class:`FNO` with ``n_modes=(n_modes_height,)``.
    For the full list of parameters, see :class:`neuralop.models.FNO`.

    Parameters
    ----------
    n_modes_height : int
        number of Fourier modes to keep along the height
    lifting_channels, projection_channels : int, optional
        NOTE: ``FNO.__init__`` has no parameters with these names. They are
        forwarded through its ``**kwargs`` to ``FNOBlocks`` and do NOT set the
        width of the lifting/projection blocks. ``FNO`` derives those widths from
        ``lifting_channel_ratio`` / ``projection_channel_ratio`` (pass those via
        ``**kwargs``). Kept unchanged for backward compatibility.
    skip : str, optional
        NOTE: likewise forwarded through ``FNO``'s ``**kwargs``. The FNO-layer
        skip type is controlled by ``fno_skip`` (pass it via ``**kwargs``).
    decomposition_kwargs : dict, optional
        extra kwargs for tensor decomposition, by default None (empty dict)
    **kwargs
        any other :class:`FNO` keyword argument (e.g. ``positional_embedding``,
        ``fno_skip``, ``lifting_channel_ratio``), forwarded to ``FNO``.
    """

    def __init__(
        self,
        n_modes_height,
        hidden_channels,
        in_channels=3,
        out_channels=1,
        lifting_channels=256,
        projection_channels=256,
        max_n_modes=None,
        n_layers=4,
        resolution_scaling_factor=None,
        non_linearity=F.gelu,
        stabilizer=None,
        complex_data=False,
        fno_block_precision="full",
        channel_mlp_dropout=0,
        channel_mlp_expansion=0.5,
        norm=None,
        skip="soft-gating",
        separable=False,
        preactivation=False,
        factorization=None,
        rank=1.0,
        fixed_rank_modes=False,
        implementation="factorized",
        decomposition_kwargs=None,
        domain_padding=None,
        domain_padding_mode="one-sided",
        **kwargs
    ):
        # Fresh dict per instance instead of a shared mutable default.
        if decomposition_kwargs is None:
            decomposition_kwargs = dict()
        super().__init__(
            n_modes=(n_modes_height,),
            hidden_channels=hidden_channels,
            in_channels=in_channels,
            out_channels=out_channels,
            lifting_channels=lifting_channels,
            projection_channels=projection_channels,
            n_layers=n_layers,
            resolution_scaling_factor=resolution_scaling_factor,
            non_linearity=non_linearity,
            stabilizer=stabilizer,
            complex_data=complex_data,
            fno_block_precision=fno_block_precision,
            channel_mlp_dropout=channel_mlp_dropout,
            channel_mlp_expansion=channel_mlp_expansion,
            max_n_modes=max_n_modes,
            norm=norm,
            skip=skip,
            separable=separable,
            preactivation=preactivation,
            factorization=factorization,
            rank=rank,
            fixed_rank_modes=fixed_rank_modes,
            implementation=implementation,
            decomposition_kwargs=decomposition_kwargs,
            domain_padding=domain_padding,
            domain_padding_mode=domain_padding_mode,
            # Previously accepted but silently discarded.
            **kwargs,
        )
        self.n_modes_height = n_modes_height
class FNO2d(FNO):
    """2D Fourier Neural Operator

    Convenience wrapper around :class:`FNO` with
    ``n_modes=(n_modes_height, n_modes_width)``.
    For the full list of parameters, see :class:`neuralop.models.FNO`.

    Parameters
    ----------
    n_modes_height : int
        number of Fourier modes to keep along the height
    n_modes_width : int
        number of modes to keep in Fourier Layer, along the width
    lifting_channels, projection_channels : int, optional
        NOTE: ``FNO.__init__`` has no parameters with these names. They are
        forwarded through its ``**kwargs`` to ``FNOBlocks`` and do NOT set the
        width of the lifting/projection blocks. ``FNO`` derives those widths from
        ``lifting_channel_ratio`` / ``projection_channel_ratio`` (pass those via
        ``**kwargs``). Kept unchanged for backward compatibility.
    skip : str, optional
        NOTE: likewise forwarded through ``FNO``'s ``**kwargs``. The FNO-layer
        skip type is controlled by ``fno_skip`` (pass it via ``**kwargs``).
    decomposition_kwargs : dict, optional
        extra kwargs for tensor decomposition, by default None (empty dict)
    **kwargs
        any other :class:`FNO` keyword argument (e.g. ``positional_embedding``,
        ``fno_skip``, ``lifting_channel_ratio``), forwarded to ``FNO``.
    """

    def __init__(
        self,
        n_modes_height,
        n_modes_width,
        hidden_channels,
        in_channels=3,
        out_channels=1,
        lifting_channels=256,
        projection_channels=256,
        n_layers=4,
        resolution_scaling_factor=None,
        max_n_modes=None,
        non_linearity=F.gelu,
        stabilizer=None,
        complex_data=False,
        fno_block_precision="full",
        channel_mlp_dropout=0,
        channel_mlp_expansion=0.5,
        norm=None,
        skip="soft-gating",
        separable=False,
        preactivation=False,
        factorization=None,
        rank=1.0,
        fixed_rank_modes=False,
        implementation="factorized",
        decomposition_kwargs=None,
        domain_padding=None,
        domain_padding_mode="one-sided",
        **kwargs
    ):
        # Fresh dict per instance instead of a shared mutable default.
        if decomposition_kwargs is None:
            decomposition_kwargs = dict()
        super().__init__(
            n_modes=(n_modes_height, n_modes_width),
            hidden_channels=hidden_channels,
            in_channels=in_channels,
            out_channels=out_channels,
            lifting_channels=lifting_channels,
            projection_channels=projection_channels,
            n_layers=n_layers,
            resolution_scaling_factor=resolution_scaling_factor,
            non_linearity=non_linearity,
            stabilizer=stabilizer,
            complex_data=complex_data,
            fno_block_precision=fno_block_precision,
            channel_mlp_dropout=channel_mlp_dropout,
            channel_mlp_expansion=channel_mlp_expansion,
            max_n_modes=max_n_modes,
            norm=norm,
            skip=skip,
            separable=separable,
            preactivation=preactivation,
            factorization=factorization,
            rank=rank,
            fixed_rank_modes=fixed_rank_modes,
            implementation=implementation,
            decomposition_kwargs=decomposition_kwargs,
            domain_padding=domain_padding,
            domain_padding_mode=domain_padding_mode,
            # Previously accepted but silently discarded.
            **kwargs,
        )
        self.n_modes_height = n_modes_height
        self.n_modes_width = n_modes_width
class FNO3d(FNO):
    """3D Fourier Neural Operator

    Convenience wrapper around :class:`FNO` with
    ``n_modes=(n_modes_height, n_modes_width, n_modes_depth)``.
    For the full list of parameters, see :class:`neuralop.models.FNO`.

    Parameters
    ----------
    n_modes_height : int
        number of Fourier modes to keep along the height
    n_modes_width : int
        number of modes to keep in Fourier Layer, along the width
    n_modes_depth : int
        number of Fourier modes to keep along the depth
    lifting_channels, projection_channels : int, optional
        NOTE: ``FNO.__init__`` has no parameters with these names. They are
        forwarded through its ``**kwargs`` to ``FNOBlocks`` and do NOT set the
        width of the lifting/projection blocks. ``FNO`` derives those widths from
        ``lifting_channel_ratio`` / ``projection_channel_ratio`` (pass those via
        ``**kwargs``). Kept unchanged for backward compatibility.
    skip : str, optional
        NOTE: likewise forwarded through ``FNO``'s ``**kwargs``. The FNO-layer
        skip type is controlled by ``fno_skip`` (pass it via ``**kwargs``).
    decomposition_kwargs : dict, optional
        extra kwargs for tensor decomposition, by default None (empty dict)
    **kwargs
        any other :class:`FNO` keyword argument (e.g. ``positional_embedding``,
        ``fno_skip``, ``lifting_channel_ratio``), forwarded to ``FNO``.
    """

    def __init__(
        self,
        n_modes_height,
        n_modes_width,
        n_modes_depth,
        hidden_channels,
        in_channels=3,
        out_channels=1,
        lifting_channels=256,
        projection_channels=256,
        n_layers=4,
        resolution_scaling_factor=None,
        max_n_modes=None,
        non_linearity=F.gelu,
        stabilizer=None,
        complex_data=False,
        fno_block_precision="full",
        channel_mlp_dropout=0,
        channel_mlp_expansion=0.5,
        norm=None,
        skip="soft-gating",
        separable=False,
        preactivation=False,
        factorization=None,
        rank=1.0,
        fixed_rank_modes=False,
        implementation="factorized",
        decomposition_kwargs=None,
        domain_padding=None,
        domain_padding_mode="one-sided",
        **kwargs
    ):
        # Fresh dict per instance instead of a shared mutable default.
        if decomposition_kwargs is None:
            decomposition_kwargs = dict()
        super().__init__(
            n_modes=(n_modes_height, n_modes_width, n_modes_depth),
            hidden_channels=hidden_channels,
            in_channels=in_channels,
            out_channels=out_channels,
            lifting_channels=lifting_channels,
            projection_channels=projection_channels,
            n_layers=n_layers,
            resolution_scaling_factor=resolution_scaling_factor,
            non_linearity=non_linearity,
            stabilizer=stabilizer,
            complex_data=complex_data,
            fno_block_precision=fno_block_precision,
            max_n_modes=max_n_modes,
            channel_mlp_dropout=channel_mlp_dropout,
            channel_mlp_expansion=channel_mlp_expansion,
            norm=norm,
            skip=skip,
            separable=separable,
            preactivation=preactivation,
            factorization=factorization,
            rank=rank,
            fixed_rank_modes=fixed_rank_modes,
            implementation=implementation,
            decomposition_kwargs=decomposition_kwargs,
            domain_padding=domain_padding,
            domain_padding_mode=domain_padding_mode,
            # Previously accepted but silently discarded.
            **kwargs,
        )
        self.n_modes_height = n_modes_height
        self.n_modes_width = n_modes_width
        self.n_modes_depth = n_modes_depth
def partialclass(new_name, cls, *args, **kwargs):
    """Build a subclass of ``cls`` called ``new_name`` whose ``__init__`` has
    ``args``/``kwargs`` pre-bound (i.e. a class with different default values).

    Notes
    -----
    ``functools.partial(cls, **kwargs)`` looks like an obvious alternative,
    but it has two drawbacks:

    1. the result has no class name, so one would have to set it by hand:
       >>> new_class.__name__ = new_name
    2. the result is a ``functools.partial`` object, which cannot be
       inherited from.

    Creating a real subclass dynamically avoids both problems.

    Returns
    -------
    type
        a new class deriving from ``cls`` that reuses ``cls``'s docstring and
        ``forward`` method.
    """
    bound_init = partialmethod(cls.__init__, *args, **kwargs)
    namespace = dict(
        __init__=bound_init,
        __doc__=cls.__doc__,
        forward=cls.forward,
    )
    return type(new_name, (cls,), namespace)
# Tucker-factorized ("T") variants: each is a subclass of its base model whose
# ``factorization`` argument defaults to "Tucker" (still overridable by callers).
TFNO = partialclass(new_name="TFNO", cls=FNO, factorization="Tucker")
TFNO1d = partialclass(new_name="TFNO1d", cls=FNO1d, factorization="Tucker")
TFNO2d = partialclass(new_name="TFNO2d", cls=FNO2d, factorization="Tucker")
TFNO3d = partialclass(new_name="TFNO3d", cls=FNO3d, factorization="Tucker")