Source code for neuralop.models.rno

from typing import Tuple, List, Union, Literal

# Scalar alias used in the annotations below for ratio, resolution-scaling
# and padding arguments, which accept either a float or an int.
Number = Union[float, int]

import torch
import torch.nn as nn
import torch.nn.functional as F

# Set warning filter to show each warning only once
import warnings

warnings.filterwarnings("once", category=UserWarning)

from ..layers.rno_block import RNOBlock
from ..layers.padding import DomainPadding
from ..layers.fno_block import FNOBlocks
from ..layers.channel_mlp import ChannelMLP
from ..layers.spectral_convolution import SpectralConv
from ..layers.complex import ComplexValued
from ..layers.embeddings import GridEmbeddingND, GridEmbedding2D
from .base_model import BaseModel

class RNO(BaseModel, name='RNO'):
    r"""N-Dimensional Recurrent Neural Operator.

    The RNO has an identical architecture to the finite-dimensional GRU,
    with the exception that linear matrix-vector multiplications are
    replaced by Fourier layers (Li et al., 2021), and for regression
    problems, the output nonlinearity is replaced by a SELU activation.

    The operation of the GRU is as follows:

        z_t = sigmoid(W_z x + U_z h_{t-1} + b_z)
        r_t = sigmoid(W_r x + U_r h_{t-1} + b_r)
        \hat h_t = selu(W_h x_t + U_h (r_t * h_{t-1}) + b_h)
        h_t = (1 - z_t) * h_{t-1} + z_t * \hat h_t,

    where * is element-wise, the b_i's are bias functions, and W_i, U_i
    are linear Fourier layers.

    Paper:

    .. [RNO] Liu-Schiaffini, M., Singer, C. E., Kovachki, N., Schneider, T.,
        Azizzadenesheli, K., & Anandkumar, A. (2023). Tipping point
        forecasting in non-stationary dynamics on function spaces.
        arXiv preprint arXiv:2308.08794.

    Main Parameters
    ---------------
    n_modes : Tuple[int, ...]
        Number of modes to keep in Fourier Layer, along each dimension.
        The dimensionality of the RNO is inferred from ``len(n_modes)``.
        n_modes must be large enough but smaller than max_resolution//2
        (Nyquist frequency)
    in_channels : int
        Number of input channels in input function. Determined by the problem.
    out_channels : int
        Number of output channels in output function. Determined by the problem.
    hidden_channels : int
        Width of the RNO (i.e. number of channels).
        This significantly affects the number of parameters of the RNO.
        For 1D problems, 24-48 channels are typically sufficient.
        For 2D/3D problems without Tucker/CP/TT factorization, start with
        16-32 channels to avoid excessive parameters.
        Increase if more expressivity is needed.
        Update lifting_channel_ratio and projection_channel_ratio accordingly
        since they are proportional to hidden_channels.
    n_layers : int, optional
        Number of RNO layers to use. Default: 4
    rno_skip : bool, optional
        Whether to use skip connections between RNO layers.
        When True, adds the input to each layer's output:
        x_{l+1} = x_l + RNOBlock(x_l). Default: False.
        Unlike FNO where skip connections improve gradient flow, RNO's
        recurrent structure already provides temporal gradient pathways,
        so skip connections are optional and may not always improve
        performance.

    Other parameters
    ----------------
    lifting_channel_ratio : Number, optional
        Ratio of lifting channels to hidden_channels.
        The number of lifting channels in the lifting block of the RNO is
        lifting_channel_ratio * hidden_channels (e.g. default 2 * hidden_channels).
    projection_channel_ratio : Number, optional
        Ratio of projection channels to hidden_channels.
        The number of projection channels in the projection block of the RNO is
        projection_channel_ratio * hidden_channels (e.g. default 2 * hidden_channels).
    positional_embedding : Union[str, nn.Module], optional
        Type of positional embedding to use. Options:
        - "grid": Use default grid-based positional embedding
        - GridEmbeddingND: Custom N-dimensional grid embedding
        - GridEmbedding2D: Custom 2D grid embedding (only for 2D problems)
        - None: No positional embedding
        Default: "grid"
    non_linearity : nn.Module, optional
        Non-Linear activation function module to use. Default: F.gelu
    norm : Literal["ada_in", "group_norm", "instance_norm"], optional
        Normalization layer to use.
        Options: "ada_in", "group_norm", "instance_norm", None. Default: None
    complex_data : bool, optional
        Whether the data is complex-valued. If True, initializes
        complex-valued modules. Default: False
    use_channel_mlp : bool, optional
        Whether to use an MLP layer after each FNO block. Default: True
    channel_mlp_dropout : float, optional
        Dropout parameter for ChannelMLP in FNO Block. Default: 0
    channel_mlp_expansion : float, optional
        Expansion parameter for ChannelMLP in FNO Block. Default: 0.5
    channel_mlp_skip : Literal["linear", "identity", "soft-gating", None], optional
        Type of skip connection to use in channel-mixing mlp.
        Options: "linear", "identity", "soft-gating", None.
        Default: "soft-gating"
    fno_skip : Literal["linear", "identity", "soft-gating", None], optional
        Type of skip connection to use in FNO layers.
        Options: "linear", "identity", "soft-gating", None. Default: "linear"
    return_sequences : bool, optional
        Whether the final RNO layer returns the full sequence of hidden
        states or just the final state. Intermediate layers always return
        sequences. Default: False
    resolution_scaling_factor : Union[Number, List[Number]], optional
        Layer-wise factor by which to scale the domain resolution of function.
        Options:
        - None: No scaling
        - Single number n: Scales resolution by n at each layer
        - List of numbers [n_0, n_1,...]: Scales layer i's resolution by n_i
        Default: None
    domain_padding : Union[Number, List[Number]], optional
        Percentage of padding to use. Options:
        - None: No padding
        - Single number: Percentage of padding to use along all dimensions
        - List of numbers [p1, p2, ..., pN]: Percentage of padding along
          each dimension (a tuple is accepted and treated as a list)
        Default: None
    fno_block_precision : str, optional
        Precision mode in which to perform spectral convolution.
        Options: "full", "half", "mixed". Default: "full".
    stabilizer : str, optional
        Whether to use a stabilizer in FNO block. Options: "tanh", None.
        Default: None.
        stabilizer greatly improves performance in the case
        `fno_block_precision='mixed'`.
    max_n_modes : Tuple[int, ...], optional
        Maximum number of modes to use in Fourier domain during training.
        None means that all the n_modes are used.
        Tuple of integers: Incrementally increase the number of modes
        during training. This can be updated dynamically during training.
    factorization : str, optional
        Tensor factorization of the FNO layer weights to use.
        Options: "None", "Tucker", "CP", "TT"
        Other factorization methods supported by tltorch. Default: None
    rank : float, optional
        Tensor rank to use in factorization. Default: 1.0
        Set to float <1.0 when using TFNO (i.e. when factorization is not None).
        A TFNO with rank 0.1 has roughly 10% of the parameters of a dense FNO.
    fixed_rank_modes : bool, optional
        Whether to not factorize certain modes. Default: False
    implementation : str, optional
        Implementation method for factorized tensors.
        Options: "factorized", "reconstructed". Default: "factorized"
    decomposition_kwargs : dict, optional
        Extra kwargs for tensor decomposition (see `tltorch.FactorizedTensor`).
        Default: {}
    separable : bool, optional
        Whether to use a separable spectral convolution. Default: False
    preactivation : bool, optional
        Whether to compute FNO forward pass with resnet-style preactivation.
        Default: False
    conv_module : nn.Module, optional
        Module to use for FNOBlock's convolutions. Default: SpectralConv
    """

    def __init__(
        self,
        n_modes: Tuple[int, ...],
        in_channels: int,
        out_channels: int,
        hidden_channels: int,
        n_layers: int = 4,
        rno_skip: bool = False,
        lifting_channel_ratio: Number = 2,
        projection_channel_ratio: Number = 2,
        positional_embedding: Union[str, nn.Module] = "grid",
        non_linearity: nn.Module = F.gelu,
        norm: Literal["ada_in", "group_norm", "instance_norm"] = None,
        complex_data: bool = False,
        use_channel_mlp: bool = True,
        channel_mlp_dropout: float = 0,
        channel_mlp_expansion: float = 0.5,
        channel_mlp_skip: Literal["linear", "identity", "soft-gating", None] = "soft-gating",
        fno_skip: Literal["linear", "identity", "soft-gating", None] = "linear",
        return_sequences: bool = False,
        resolution_scaling_factor: Union[Number, List[Number]] = None,
        domain_padding: Union[Number, List[Number]] = None,
        fno_block_precision: str = "full",
        stabilizer: str = None,
        max_n_modes: Tuple[int, ...] = None,
        factorization: str = None,
        rank: float = 1.0,
        fixed_rank_modes: bool = False,
        implementation: str = "factorized",
        decomposition_kwargs: dict = None,
        separable: bool = False,
        preactivation: bool = False,
        conv_module: nn.Module = SpectralConv,
    ):
        # Avoid sharing one mutable default dict between instances.
        if decomposition_kwargs is None:
            decomposition_kwargs = {}

        super().__init__()
        self.n_dim = len(n_modes)

        # NOTE(review): the upstream comment describes n_modes as a property
        # whose updates propagate to the blocks; no such property is defined
        # in this class as shown -- confirm whether a base class provides it.
        self.n_modes = n_modes
        self.hidden_channels = hidden_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.n_layers = n_layers
        self.rno_skip = rno_skip

        # init lifting and projection channels using ratios w.r.t hidden channels
        self.lifting_channel_ratio = lifting_channel_ratio
        self.lifting_channels = int(lifting_channel_ratio * self.hidden_channels)

        self.projection_channel_ratio = projection_channel_ratio
        self.projection_channels = int(projection_channel_ratio * self.hidden_channels)

        self.non_linearity = non_linearity
        self.rank = rank
        self.factorization = factorization
        self.fixed_rank_modes = fixed_rank_modes
        self.decomposition_kwargs = decomposition_kwargs
        # Store the plain values (previously wrapped in accidental 1-tuples
        # because of stray trailing commas).
        self.fno_skip = fno_skip
        self.channel_mlp_skip = channel_mlp_skip
        self.implementation = implementation
        self.separable = separable
        self.preactivation = preactivation
        self.complex_data = complex_data
        self.fno_block_precision = fno_block_precision

        ## Domain padding
        # A tuple of per-dimension percentages is treated like a list so it
        # is not silently ignored.
        if isinstance(domain_padding, tuple):
            domain_padding = list(domain_padding)
        if domain_padding is not None and (
            (isinstance(domain_padding, list) and sum(domain_padding) > 0)
            or (isinstance(domain_padding, (float, int)) and domain_padding > 0)
        ):
            self.domain_padding = DomainPadding(
                domain_padding=domain_padding,
                resolution_scaling_factor=resolution_scaling_factor,
            )
        else:
            self.domain_padding = None

        ## Positional embedding
        if positional_embedding == "grid":
            spatial_grid_boundaries = [[0.0, 1.0]] * self.n_dim
            self.positional_embedding = GridEmbeddingND(
                in_channels=self.in_channels,
                dim=self.n_dim,
                grid_boundaries=spatial_grid_boundaries,
            )
        elif isinstance(positional_embedding, GridEmbedding2D):
            if self.n_dim == 2:
                self.positional_embedding = positional_embedding
            else:
                raise ValueError(
                    f"Error: expected {self.n_dim}-d positional embeddings, got {positional_embedding}"
                )
        elif isinstance(positional_embedding, GridEmbeddingND):
            self.positional_embedding = positional_embedding
        elif positional_embedding is None:
            self.positional_embedding = None
        else:
            raise ValueError(
                f"Error: tried to instantiate RNO positional embedding with {positional_embedding}, "
                f"expected one of 'grid', GridEmbeddingND, GridEmbedding2D, or None"
            )

        ## Resolution scaling factor
        # A single number is broadcast to every layer; a list is used as given;
        # a falsy value means "no scaling" for every layer.
        if resolution_scaling_factor:
            if isinstance(resolution_scaling_factor, (float, int)):
                resolution_scaling_factor = [resolution_scaling_factor] * self.n_layers
        else:
            resolution_scaling_factor = [None] * self.n_layers
        self.resolution_scaling_factor = resolution_scaling_factor

        ## Return sequences configuration
        # Intermediate layers always return sequences, only last layer is configurable
        return_sequences_list = [True] * (self.n_layers - 1) + [return_sequences]
        self.return_sequences = return_sequences_list

        module_list = [
            RNOBlock(
                n_modes=self.n_modes,
                hidden_channels=hidden_channels,
                return_sequences=return_sequences_list[i],
                resolution_scaling_factor=self.resolution_scaling_factor[i],
                use_channel_mlp=use_channel_mlp,
                channel_mlp_dropout=channel_mlp_dropout,
                channel_mlp_expansion=channel_mlp_expansion,
                non_linearity=non_linearity,
                stabilizer=stabilizer,
                norm=norm,
                preactivation=preactivation,
                fno_skip=fno_skip,
                channel_mlp_skip=channel_mlp_skip,
                complex_data=complex_data,
                max_n_modes=max_n_modes,
                fno_block_precision=fno_block_precision,
                rank=rank,
                fixed_rank_modes=fixed_rank_modes,
                implementation=implementation,
                separable=separable,
                factorization=factorization,
                decomposition_kwargs=decomposition_kwargs,
                conv_module=conv_module,
            )
            for i in range(n_layers)
        ]
        self.layers = nn.ModuleList(module_list)

        ## Lifting layer
        # if adding a positional embedding, add those channels to lifting
        lifting_in_channels = self.in_channels
        if self.positional_embedding is not None:
            lifting_in_channels += self.n_dim

        # if lifting_channels is set, make lifting a Channel-Mixing MLP
        # with a hidden layer of size lifting_channels
        if self.lifting_channels:
            self.lifting = ChannelMLP(
                in_channels=lifting_in_channels,
                out_channels=self.hidden_channels,
                hidden_channels=self.lifting_channels,
                n_layers=2,
                n_dim=self.n_dim,
                non_linearity=non_linearity,
            )
        # otherwise, make it a linear layer
        else:
            self.lifting = ChannelMLP(
                in_channels=lifting_in_channels,
                hidden_channels=self.hidden_channels,
                out_channels=self.hidden_channels,
                n_layers=1,
                n_dim=self.n_dim,
                non_linearity=non_linearity,
            )

        # Convert lifting to a complex ChannelMLP if self.complex_data==True
        if self.complex_data:
            self.lifting = ComplexValued(self.lifting)

        ## Projection layer
        self.projection = ChannelMLP(
            in_channels=self.hidden_channels,
            out_channels=out_channels,
            hidden_channels=self.projection_channels,
            n_layers=2,
            n_dim=self.n_dim,
            non_linearity=non_linearity,
        )
        if self.complex_data:
            self.projection = ComplexValued(self.projection)

    def forward(
        self,
        x,
        init_hidden_states=None,
        return_hidden_states=False,
        keep_states_padded=False,
    ):
        """Forward pass for the Recurrent Neural Operator.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor with shape (batch, timesteps, in_channels, *spatial_dims),
            where len(spatial_dims) == self.n_dim. The channel dimension MUST be
            at index 2 and the time dimension MUST be at index 1.
        init_hidden_states : list[torch.Tensor] | None
            Optional list of per-layer initial hidden states; must contain
            exactly ``n_layers`` entries (individual entries may be None).
            Each tensor should have shape (batch, hidden_channels, *spatial_dims_h),
            where spatial_dims_h are the unpadded spatial dimensions at each
            layer (accounting for resolution scaling).
            If domain_padding is enabled, hidden states will be automatically
            padded like the input x, unless keep_states_padded is True.
            If None, all hidden states are initialized internally.
        return_hidden_states : bool, optional
            Whether to return the final hidden states for each layer in
            addition to the prediction. Default: False
        keep_states_padded : bool, optional
            If True, treats provided init_hidden_states as already padded
            (no additional padding applied) and returns final_hidden_states
            in padded form (no unpadding). Use this for efficient
            autoregressive loops to preserve boundary information.
            Default: False

        Returns
        -------
        pred : torch.Tensor
            Output tensor with shape (batch, out_channels, *spatial_dims_out).
        final_hidden_states : list[torch.Tensor], optional
            Returned only if return_hidden_states=True.

        Raises
        ------
        ValueError
            If ``x`` does not have rank ``3 + n_dim``, if ``x.shape[2]`` differs
            from ``in_channels``, or if ``init_hidden_states`` does not hold
            exactly ``n_layers`` entries.
        """
        # Strict input validation to avoid silent errors from resolution invariance
        expected_rank = 3 + self.n_dim
        if x.ndim != expected_rank:
            raise ValueError(
                f"RNO.forward expected input of rank {expected_rank} = "
                f"(batch, timesteps, channels, {self.n_dim} spatial dims), got rank {x.ndim} with shape {tuple(x.shape)}"
            )
        if x.shape[2] != self.in_channels:
            raise ValueError(
                f"RNO.forward expected x.shape[2] == in_channels ({self.in_channels}); "
                f"got {x.shape[2]}. Input must be shaped as (batch, timesteps, in_channels, *spatial_dims)."
            )

        # x shape (batch, timesteps, in_channels, *spatial_dims)
        batch_size, timesteps = x.shape[:2]

        if init_hidden_states is None:
            init_hidden_states = [None] * self.n_layers
        else:
            # Fail early instead of hitting an IndexError mid-pass (too few
            # states) or silently ignoring surplus states (too many).
            if len(init_hidden_states) != self.n_layers:
                raise ValueError(
                    f"RNO.forward expected {self.n_layers} initial hidden states "
                    f"(one per layer); got {len(init_hidden_states)}."
                )
            # If domain padding is enabled AND states are not already padded
            # (keep_states_padded=False), pad them.
            if self.domain_padding and not keep_states_padded:
                init_hidden_states = [
                    h if h is None else self.domain_padding.pad(h)
                    for h in init_hidden_states
                ]

        # Reshape for processing: (batch*timesteps, channels, *spatial_dims)
        x = x.reshape(batch_size * timesteps, *x.shape[2:])

        # append spatial pos embedding if set
        if self.positional_embedding is not None:
            x = self.positional_embedding(x)

        x = self.lifting(x)
        x = x.reshape(batch_size, timesteps, *x.shape[1:])

        if self.domain_padding:
            # DomainPadding expects (batch, channels, *spatial_dims), so reshape
            # to remove timestep dim
            x = x.reshape(batch_size * timesteps, *x.shape[2:])
            x = self.domain_padding.pad(x)
            x = x.reshape(batch_size, timesteps, *x.shape[1:])

        final_hidden_states = []
        for i in range(self.n_layers):
            pred_x = self.layers[i](x, init_hidden_states[i])
            if i < self.n_layers - 1:
                # Intermediate layers return the full sequence; the last time
                # step is recorded as that layer's final state.
                if self.rno_skip:
                    x = x + pred_x
                else:
                    x = pred_x
                final_hidden_states.append(x[:, -1])
            else:
                # NOTE(review): if the last layer was built with
                # return_sequences=True its output presumably still carries a
                # time axis -- confirm the unpad/projection below handle that.
                x = pred_x
                final_hidden_states.append(x)

        h = final_hidden_states[-1]

        if self.domain_padding:
            # The projection always runs on the unpadded last-layer output so
            # the prediction has the correct spatial size.
            h = self.domain_padding.unpad(h)

            # For final_hidden_states, we unpad UNLESS explicitly told to keep
            # them padded
            if return_hidden_states and not keep_states_padded:
                final_hidden_states = [
                    self.domain_padding.unpad(state) for state in final_hidden_states
                ]

        pred = self.projection(h)

        if return_hidden_states:
            return pred, final_hidden_states
        else:
            return pred

    def predict(self, x, num_steps, grid_function=None):
        """Autoregressively predict future time steps.

        Performs autoregressive rollout by iteratively feeding predictions
        back as input. At each step, the model predicts the next time step
        from the current input, then uses that prediction as input for the
        subsequent prediction.

        Parameters
        ----------
        x : torch.Tensor
            Initial input sequence with shape
            (batch, timesteps, in_channels, *spatial_dims).
        num_steps : int
            Number of future time steps to predict autoregressively.
        grid_function : callable, optional
            Function that generates extra channels (e.g., spatial coordinates)
            to concatenate with each prediction along the channel axis before
            feeding it back as input. Called as
            ``grid_function(shape, device)`` with
            ``shape = (batch, 1, 1, *spatial_dims)``; the returned tensor is
            concatenated with the prediction on dim 2.
            Use this when your model was trained with concatenated positional
            information. Default: None

        Returns
        -------
        torch.Tensor
            Predictions of every step stacked along dim 1, i.e. shape
            (batch, num_steps, out_channels, *spatial_dims) given the
            per-step output shape documented in ``forward``.

        Notes
        -----
        Hidden states are carried across prediction steps in padded form
        (``keep_states_padded=True``), enabling the RNO to leverage its
        recurrent structure during autoregressive generation.
        The tensor fed back into ``forward`` (prediction plus optional grid
        channels) must have ``in_channels`` channels, since ``forward``
        validates the channel dimension.
        """
        output = []
        states = [None] * self.n_layers

        for _ in range(num_steps):
            pred, states = self.forward(
                x, states, return_hidden_states=True, keep_states_padded=True
            )
            output.append(pred)

            # Re-insert a length-1 time axis so the prediction can be fed back.
            x = pred.reshape((pred.shape[0], 1, *pred.shape[1:]))
            if grid_function:
                # Pass every spatial dim (x.shape[3:]) rather than only the
                # last two, so 1D and 3D problems get a correctly shaped grid.
                grid = grid_function(
                    (x.shape[0], x.shape[1], 1, *x.shape[3:]), x.device
                )
                x = torch.cat((x, grid), dim=2)

        return torch.stack(output, dim=1)