"""Source code for neuralop.models.fno."""

import re
from functools import partialmethod

import torch.nn as nn
import torch.nn.functional as F

from ..layers.spectral_convolution import SpectralConv
from ..layers.spherical_convolution import SphericalConv
from ..layers.padding import DomainPadding
from ..layers.fno_block import FNOBlocks
from ..layers.mlp import MLP
from .base_model import BaseModel

class FNO(BaseModel, name='FNO'):
    """N-Dimensional Fourier Neural Operator

    Parameters
    ----------
    n_modes : int tuple
        number of modes to keep in Fourier Layer, along each dimension.
        The dimensionality of the FNO is inferred from ``len(n_modes)``
    hidden_channels : int
        width of the FNO (i.e. number of channels)
    in_channels : int, optional
        Number of input channels, by default 3
    out_channels : int, optional
        Number of output channels, by default 1
    lifting_channels : int, optional
        number of hidden channels of the lifting block of the FNO, by default 256.
        If 0 or None, the lifting block is a single layer without a hidden layer.
    projection_channels : int, optional
        number of hidden channels of the projection block of the FNO, by default 256
    n_layers : int, optional
        Number of Fourier Layers, by default 4
    output_scaling_factor : None, float, int or list, optional
        Output scaling factor forwarded to the FNO blocks (and to the domain
        padding), by default None. A scalar is expanded to one value per layer
        unless ``joint_factorization`` is True.
    max_n_modes : None or int tuple, default is None
        * If not None, this allows to incrementally increase the number of
          modes in Fourier domain during training. Each entry must satisfy
          n <= N for (N, n) in zip(max_n_modes, n_modes).
        * If None, all the n_modes are used.

        This can be updated dynamically during training.
    fno_block_precision : str {'full', 'half', 'mixed'}
        if 'full', the FNO Block runs in full precision
        if 'half', the FFT, contraction, and inverse FFT run in half precision
        if 'mixed', the contraction and inverse FFT run in half precision
    stabilizer : str {'tanh'} or None, optional
        By default None, otherwise tanh is used before FFT in the FNO block
    use_mlp : bool, optional
        Whether to use an MLP layer after each FNO block, by default False
    mlp_dropout : float, optional
        dropout parameter of MLP layer, by default 0
    mlp_expansion : float, optional
        expansion parameter of MLP layer, by default 0.5
    non_linearity : callable, optional
        Non-Linearity module to use, by default F.gelu
    norm : str or None, optional
        Normalization layer to use, by default None
    preactivation : bool, default is False
        if True, use resnet-style preactivation
    fno_skip : {'linear', 'identity', 'soft-gating'}, optional
        Type of skip connection to use in fno, by default 'linear'
    mlp_skip : {'linear', 'identity', 'soft-gating'}, optional
        Type of skip connection to use in mlp, by default 'soft-gating'
    separable : bool, default is False
        if True, use a depthwise separable spectral convolution
    factorization : str or None, {'tucker', 'cp', 'tt'}
        Tensor factorization of the parameters weight to use, by default None.

        * If None, a dense tensor parametrizes the Spectral convolutions
        * Otherwise, the specified tensor factorization is used.
    joint_factorization : bool, optional
        Whether all the Fourier Layers should be parametrized by a single
        tensor (vs one per layer), by default False
    rank : float or rank, optional
        Rank of the tensor factorization of the Fourier weights, by default 1.0
    fixed_rank_modes : bool or int tuple, optional
        Modes to not factorize, by default False
    implementation : {'factorized', 'reconstructed'}, optional, default is 'factorized'
        If factorization is not None, forward mode to use::

        * `reconstructed` : the full weight tensor is reconstructed from the
          factorization and used for the forward pass
        * `factorized` : the input is directly contracted with the factors of
          the decomposition
    decomposition_kwargs : dict or None, optional, default is None
        Optional additional parameters to pass to the tensor decomposition.
        None is treated as an empty dict.
    domain_padding : None, float, or List[float], optional
        If not None, percentage of padding to use, by default None.
        To vary the percentage of padding used along each input dimension,
        pass in a list of percentages e.g. [p1, p2, ..., pN] such that
        p1 corresponds to the percentage of padding along dim 1, etc.
    domain_padding_mode : {'symmetric', 'one-sided'}, optional
        How to perform domain padding, by default 'one-sided'
    fft_norm : str, optional
        by default 'forward'
    SpectralConv : nn.Module class, optional
        Spectral convolution class used inside the FNO blocks,
        by default SpectralConv
    **kwargs
        Extra keyword arguments, forwarded unchanged to ``FNOBlocks``.
    """

    def __init__(
        self,
        n_modes,
        hidden_channels,
        in_channels=3,
        out_channels=1,
        lifting_channels=256,
        projection_channels=256,
        n_layers=4,
        output_scaling_factor=None,
        max_n_modes=None,
        fno_block_precision="full",
        use_mlp=False,
        mlp_dropout=0,
        mlp_expansion=0.5,
        non_linearity=F.gelu,
        stabilizer=None,
        norm=None,
        preactivation=False,
        fno_skip="linear",
        mlp_skip="soft-gating",
        separable=False,
        factorization=None,
        rank=1.0,
        joint_factorization=False,
        fixed_rank_modes=False,
        implementation="factorized",
        decomposition_kwargs=None,
        domain_padding=None,
        domain_padding_mode="one-sided",
        fft_norm="forward",
        SpectralConv=SpectralConv,
        **kwargs
    ):
        super().__init__()
        # A fresh dict per instance: a shared mutable default would leak
        # mutations between models.
        if decomposition_kwargs is None:
            decomposition_kwargs = {}

        self.n_dim = len(n_modes)

        # See the class' property for underlying mechanism
        # When updated, change should be reflected in fno blocks
        self._n_modes = n_modes

        self.hidden_channels = hidden_channels
        self.lifting_channels = lifting_channels
        self.projection_channels = projection_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.n_layers = n_layers
        self.joint_factorization = joint_factorization
        self.non_linearity = non_linearity
        self.rank = rank
        self.factorization = factorization
        self.fixed_rank_modes = fixed_rank_modes
        self.decomposition_kwargs = decomposition_kwargs
        # Store the skip types as given (previously stray trailing commas
        # turned them into 1-tuples).
        self.fno_skip = fno_skip
        self.mlp_skip = mlp_skip
        self.fft_norm = fft_norm
        self.implementation = implementation
        self.separable = separable
        self.preactivation = preactivation
        self.fno_block_precision = fno_block_precision

        # Only build a padding module when some padding is actually requested.
        if domain_padding is not None and (
            (isinstance(domain_padding, list) and sum(domain_padding) > 0)
            or (isinstance(domain_padding, (float, int)) and domain_padding > 0)
        ):
            self.domain_padding = DomainPadding(
                domain_padding=domain_padding,
                padding_mode=domain_padding_mode,
                output_scaling_factor=output_scaling_factor,
            )
        else:
            self.domain_padding = None

        self.domain_padding_mode = domain_padding_mode

        # A scalar scaling factor is expanded to one value per layer.
        if output_scaling_factor is not None and not joint_factorization:
            if isinstance(output_scaling_factor, (float, int)):
                output_scaling_factor = [output_scaling_factor] * self.n_layers
        self.output_scaling_factor = output_scaling_factor

        self.fno_blocks = FNOBlocks(
            in_channels=hidden_channels,
            out_channels=hidden_channels,
            n_modes=self.n_modes,
            output_scaling_factor=output_scaling_factor,
            use_mlp=use_mlp,
            mlp_dropout=mlp_dropout,
            mlp_expansion=mlp_expansion,
            non_linearity=non_linearity,
            stabilizer=stabilizer,
            norm=norm,
            preactivation=preactivation,
            fno_skip=fno_skip,
            mlp_skip=mlp_skip,
            max_n_modes=max_n_modes,
            fno_block_precision=fno_block_precision,
            rank=rank,
            fft_norm=fft_norm,
            fixed_rank_modes=fixed_rank_modes,
            implementation=implementation,
            separable=separable,
            factorization=factorization,
            decomposition_kwargs=decomposition_kwargs,
            joint_factorization=joint_factorization,
            SpectralConv=SpectralConv,
            n_layers=n_layers,
            **kwargs
        )

        # if lifting_channels is passed, make lifting an MLP
        # with a hidden layer of size lifting_channels
        if self.lifting_channels:
            self.lifting = MLP(
                in_channels=in_channels,
                out_channels=self.hidden_channels,
                hidden_channels=self.lifting_channels,
                n_layers=2,
                n_dim=self.n_dim,
            )
        # otherwise, make it a linear layer
        else:
            self.lifting = MLP(
                in_channels=in_channels,
                out_channels=self.hidden_channels,
                hidden_channels=self.hidden_channels,
                n_layers=1,
                n_dim=self.n_dim,
            )
        self.projection = MLP(
            in_channels=self.hidden_channels,
            out_channels=out_channels,
            hidden_channels=self.projection_channels,
            n_layers=2,
            n_dim=self.n_dim,
            non_linearity=non_linearity,
        )

    def forward(self, x, output_shape=None, **kwargs):
        """FNO's forward pass

        Lifting -> optional domain padding -> ``n_layers`` FNO blocks ->
        optional unpadding -> projection.

        Parameters
        ----------
        x : tensor
            input tensor
        output_shape : {tuple, tuple list, None}, default is None
            Gives the option of specifying the exact output shape for odd shaped inputs.

            * If None, don't specify an output shape
            * If tuple, specifies the output-shape of the **last** FNO Block
            * If tuple list, specifies the exact output-shape of each FNO Block
        """
        # Normalize output_shape to one entry per layer.
        if output_shape is None:
            output_shape = [None] * self.n_layers
        elif isinstance(output_shape, tuple):
            output_shape = [None] * (self.n_layers - 1) + [output_shape]

        x = self.lifting(x)

        if self.domain_padding is not None:
            x = self.domain_padding.pad(x)

        for layer_idx in range(self.n_layers):
            x = self.fno_blocks(x, layer_idx, output_shape=output_shape[layer_idx])

        if self.domain_padding is not None:
            x = self.domain_padding.unpad(x)

        x = self.projection(x)

        return x

    @property
    def n_modes(self):
        # Number of Fourier modes kept along each dimension.
        return self._n_modes

    @n_modes.setter
    def n_modes(self, n_modes):
        # Keep the FNO blocks in sync with the model-level value.
        self.fno_blocks.n_modes = n_modes
        self._n_modes = n_modes
class FNO1d(FNO):
    """1D Fourier Neural Operator

    For the full list of parameters, see :class:`neuralop.models.FNO`.

    Parameters
    ----------
    n_modes_height : int
        number of Fourier modes to keep along the height
    skip : str, optional
        NOTE: ``FNO`` has no ``skip`` parameter. This value is passed on to
        ``FNOBlocks`` through ``FNO``'s ``**kwargs`` and does not set
        ``fno_skip`` / ``mlp_skip``. Pass those explicitly (via ``**kwargs``)
        to choose the skip connection types.
    decomposition_kwargs : dict or None, optional
        None is treated as an empty dict.
    **kwargs
        Any other keyword argument of :class:`FNO` (e.g. ``fno_skip``,
        ``mlp_skip``, ``SpectralConv``). Forwarded to ``FNO.__init__``.
    """

    def __init__(
        self,
        n_modes_height,
        hidden_channels,
        in_channels=3,
        out_channels=1,
        lifting_channels=256,
        projection_channels=256,
        max_n_modes=None,
        fno_block_precision="full",
        n_layers=4,
        output_scaling_factor=None,
        non_linearity=F.gelu,
        stabilizer=None,
        use_mlp=False,
        mlp_dropout=0,
        mlp_expansion=0.5,
        norm=None,
        skip="soft-gating",
        separable=False,
        preactivation=False,
        factorization=None,
        rank=1.0,
        joint_factorization=False,
        fixed_rank_modes=False,
        implementation="factorized",
        decomposition_kwargs=None,
        domain_padding=None,
        domain_padding_mode="one-sided",
        fft_norm="forward",
        **kwargs
    ):
        # Fresh dict per instance instead of a shared mutable default.
        if decomposition_kwargs is None:
            decomposition_kwargs = {}
        super().__init__(
            n_modes=(n_modes_height,),
            hidden_channels=hidden_channels,
            in_channels=in_channels,
            out_channels=out_channels,
            lifting_channels=lifting_channels,
            projection_channels=projection_channels,
            n_layers=n_layers,
            output_scaling_factor=output_scaling_factor,
            non_linearity=non_linearity,
            stabilizer=stabilizer,
            use_mlp=use_mlp,
            mlp_dropout=mlp_dropout,
            mlp_expansion=mlp_expansion,
            max_n_modes=max_n_modes,
            fno_block_precision=fno_block_precision,
            norm=norm,
            skip=skip,
            separable=separable,
            preactivation=preactivation,
            factorization=factorization,
            rank=rank,
            joint_factorization=joint_factorization,
            fixed_rank_modes=fixed_rank_modes,
            implementation=implementation,
            decomposition_kwargs=decomposition_kwargs,
            domain_padding=domain_padding,
            domain_padding_mode=domain_padding_mode,
            fft_norm=fft_norm,
            # Previously accepted but silently dropped.
            **kwargs
        )
        self.n_modes_height = n_modes_height
class FNO2d(FNO):
    """2D Fourier Neural Operator

    For the full list of parameters, see :class:`neuralop.models.FNO`.

    Parameters
    ----------
    n_modes_height : int
        number of Fourier modes to keep along the height
    n_modes_width : int
        number of modes to keep in Fourier Layer, along the width
    skip : str, optional
        NOTE: ``FNO`` has no ``skip`` parameter. This value is passed on to
        ``FNOBlocks`` through ``FNO``'s ``**kwargs`` and does not set
        ``fno_skip`` / ``mlp_skip``. Pass those explicitly (via ``**kwargs``)
        to choose the skip connection types.
    decomposition_kwargs : dict or None, optional
        None is treated as an empty dict.
    **kwargs
        Any other keyword argument of :class:`FNO` (e.g. ``fno_skip``,
        ``mlp_skip``, ``SpectralConv``). Forwarded to ``FNO.__init__``.
    """

    def __init__(
        self,
        n_modes_height,
        n_modes_width,
        hidden_channels,
        in_channels=3,
        out_channels=1,
        lifting_channels=256,
        projection_channels=256,
        n_layers=4,
        output_scaling_factor=None,
        max_n_modes=None,
        fno_block_precision="full",
        non_linearity=F.gelu,
        stabilizer=None,
        use_mlp=False,
        mlp_dropout=0,
        mlp_expansion=0.5,
        norm=None,
        skip="soft-gating",
        separable=False,
        preactivation=False,
        factorization=None,
        rank=1.0,
        joint_factorization=False,
        fixed_rank_modes=False,
        implementation="factorized",
        decomposition_kwargs=None,
        domain_padding=None,
        domain_padding_mode="one-sided",
        fft_norm="forward",
        **kwargs
    ):
        # Fresh dict per instance instead of a shared mutable default.
        if decomposition_kwargs is None:
            decomposition_kwargs = {}
        super().__init__(
            n_modes=(n_modes_height, n_modes_width),
            hidden_channels=hidden_channels,
            in_channels=in_channels,
            out_channels=out_channels,
            lifting_channels=lifting_channels,
            projection_channels=projection_channels,
            n_layers=n_layers,
            output_scaling_factor=output_scaling_factor,
            non_linearity=non_linearity,
            stabilizer=stabilizer,
            use_mlp=use_mlp,
            mlp_dropout=mlp_dropout,
            mlp_expansion=mlp_expansion,
            max_n_modes=max_n_modes,
            fno_block_precision=fno_block_precision,
            norm=norm,
            skip=skip,
            separable=separable,
            preactivation=preactivation,
            factorization=factorization,
            rank=rank,
            joint_factorization=joint_factorization,
            fixed_rank_modes=fixed_rank_modes,
            implementation=implementation,
            decomposition_kwargs=decomposition_kwargs,
            domain_padding=domain_padding,
            domain_padding_mode=domain_padding_mode,
            fft_norm=fft_norm,
            # Previously accepted but silently dropped.
            **kwargs
        )
        self.n_modes_height = n_modes_height
        self.n_modes_width = n_modes_width
class FNO3d(FNO):
    """3D Fourier Neural Operator

    For the full list of parameters, see :class:`neuralop.models.FNO`.

    Parameters
    ----------
    n_modes_height : int
        number of Fourier modes to keep along the height
    n_modes_width : int
        number of modes to keep in Fourier Layer, along the width
    n_modes_depth : int
        number of Fourier modes to keep along the depth
    skip : str, optional
        NOTE: ``FNO`` has no ``skip`` parameter. This value is passed on to
        ``FNOBlocks`` through ``FNO``'s ``**kwargs`` and does not set
        ``fno_skip`` / ``mlp_skip``. Pass those explicitly (via ``**kwargs``)
        to choose the skip connection types.
    decomposition_kwargs : dict or None, optional
        None is treated as an empty dict.
    **kwargs
        Any other keyword argument of :class:`FNO` (e.g. ``fno_skip``,
        ``mlp_skip``, ``SpectralConv``). Forwarded to ``FNO.__init__``.
    """

    def __init__(
        self,
        n_modes_height,
        n_modes_width,
        n_modes_depth,
        hidden_channels,
        in_channels=3,
        out_channels=1,
        lifting_channels=256,
        projection_channels=256,
        n_layers=4,
        output_scaling_factor=None,
        max_n_modes=None,
        fno_block_precision="full",
        non_linearity=F.gelu,
        stabilizer=None,
        use_mlp=False,
        mlp_dropout=0,
        mlp_expansion=0.5,
        norm=None,
        skip="soft-gating",
        separable=False,
        preactivation=False,
        factorization=None,
        rank=1.0,
        joint_factorization=False,
        fixed_rank_modes=False,
        implementation="factorized",
        decomposition_kwargs=None,
        domain_padding=None,
        domain_padding_mode="one-sided",
        fft_norm="forward",
        **kwargs
    ):
        # Fresh dict per instance instead of a shared mutable default.
        if decomposition_kwargs is None:
            decomposition_kwargs = {}
        super().__init__(
            n_modes=(n_modes_height, n_modes_width, n_modes_depth),
            hidden_channels=hidden_channels,
            in_channels=in_channels,
            out_channels=out_channels,
            lifting_channels=lifting_channels,
            projection_channels=projection_channels,
            n_layers=n_layers,
            output_scaling_factor=output_scaling_factor,
            non_linearity=non_linearity,
            stabilizer=stabilizer,
            max_n_modes=max_n_modes,
            fno_block_precision=fno_block_precision,
            use_mlp=use_mlp,
            mlp_dropout=mlp_dropout,
            mlp_expansion=mlp_expansion,
            norm=norm,
            skip=skip,
            separable=separable,
            preactivation=preactivation,
            factorization=factorization,
            rank=rank,
            joint_factorization=joint_factorization,
            fixed_rank_modes=fixed_rank_modes,
            implementation=implementation,
            decomposition_kwargs=decomposition_kwargs,
            domain_padding=domain_padding,
            domain_padding_mode=domain_padding_mode,
            fft_norm=fft_norm,
            # Previously accepted but silently dropped.
            **kwargs
        )
        self.n_modes_height = n_modes_height
        self.n_modes_width = n_modes_width
        self.n_modes_depth = n_modes_depth
def partialclass(new_name, cls, *args, **kwargs):
    """Create a new class with different default values

    Parameters
    ----------
    new_name : str
        name of the generated class
    cls : type
        class to inherit from
    *args, **kwargs
        arguments bound in front of the call to ``cls.__init__`` through
        :class:`functools.partialmethod`. Bound keyword arguments act as
        new defaults that callers can still override.

    Returns
    -------
    type
        a subclass of ``cls`` named ``new_name``, reusing ``cls``'s
        docstring and ``forward``

    Notes
    -----
    An obvious alternative would be to use functools.partial
    >>> new_class = partial(cls, **kwargs)

    The issue is twofold:
    1. the class doesn't have a name, so one would have to set it explicitly:
    >>> new_class.__name__ = new_name

    2. the new class will be a functools object and one cannot inherit from it.

    Instead, here, we define dynamically a new class, inheriting from the existing one.
    """
    __init__ = partialmethod(cls.__init__, *args, **kwargs)
    new_class = type(
        new_name,
        (cls,),
        {
            "__init__": __init__,
            "__doc__": cls.__doc__,
            "forward": cls.forward,
        },
    )
    return new_class


TFNO = partialclass("TFNO", FNO, factorization="Tucker")
TFNO1d = partialclass("TFNO1d", FNO1d, factorization="Tucker")
TFNO2d = partialclass("TFNO2d", FNO2d, factorization="Tucker")
TFNO3d = partialclass("TFNO3d", FNO3d, factorization="Tucker")

SFNO = partialclass("SFNO", FNO, factorization="dense", SpectralConv=SphericalConv)
# Under ``python -OO`` docstrings are stripped and ``__doc__`` is None, so the
# rewrite is skipped rather than crashing at import time. Word boundaries
# leave parameter names such as ``fno_skip`` / ``fno_block_precision``
# untouched, because SFNO still accepts them under exactly those names.
if SFNO.__doc__ is not None:
    SFNO.__doc__ = SFNO.__doc__.replace("Fourier", "Spherical Fourier", 1)
    SFNO.__doc__ = re.sub(r"\bFNO\b", "SFNO", SFNO.__doc__)
    SFNO.__doc__ = re.sub(r"\bfno\b", "sfno", SFNO.__doc__)