"""Source code for ``neuralop.models.uno``: the U-shaped Neural Operator (UNO) model."""

import torch.nn as nn
import torch.nn.functional as F
import torch

# Set warning filter to show each warning only once.
# NOTE(review): this runs at import time and, since no ``module`` argument is
# given, applies the "once" action to every UserWarning in the whole process,
# not only to warnings raised from this file (such as the unexpected-kwargs
# warning emitted by UNO.forward). Confirm that a global filter is intended.
import warnings

warnings.filterwarnings("once", category=UserWarning)

from ..layers.channel_mlp import ChannelMLP
from ..layers.spectral_convolution import SpectralConv
from ..layers.skip_connections import skip_connection
from ..layers.padding import DomainPadding
from ..layers.fno_block import FNOBlocks
from ..layers.resample import resample
from ..layers.embeddings import GridEmbedding2D, GridEmbeddingND


class UNO(nn.Module):
    """U-Shaped Neural Operator

    The architecture is described in [1]_.

    Parameters
    ----------
    in_channels : int
        Number of input channels. Determined by the problem.
    out_channels : int
        Number of output channels. Determined by the problem.
    hidden_channels : int
        Initial width of the UNO. This significantly affects the number of
        parameters of the UNO. A good starting point is 64, increased if more
        expressivity is needed. Update lifting_channels and
        projection_channels accordingly since they are proportional to
        hidden_channels.
    lifting_channels : int, optional
        Number of hidden channels of the lifting block. Default: 256
    projection_channels : int, optional
        Number of hidden channels of the projection block. Default: 256
    positional_embedding : Union[str, GridEmbedding2D, GridEmbeddingND, None], optional
        Positional embedding to apply to last channels of raw input before
        being passed through the UNO. Options:

        - "grid": Appends a grid positional embedding with default settings
          to the last channels of raw input. Assumes the inputs are
          discretized over a grid with entry [0,0,...] at the origin and
          side lengths of 1.
        - GridEmbedding2D: Uses this module directly (2D problems only).
        - GridEmbeddingND: Uses this module directly (see
          `neuralop.layers.embeddings.GridEmbeddingND` for details).
        - None: Does nothing.

        Default: "grid"
    n_layers : int, optional
        Number of Fourier layers. Must equal the lengths of uno_out_channels,
        uno_n_modes and uno_scalings. Default: 4
    uno_out_channels : list
        Required. Number of output channels of each Fourier layer.
        Example: For a five layer UNO uno_out_channels can be
        [32,64,64,64,32]
    uno_n_modes : list
        Required. Number of Fourier modes to use in integral operation of
        each Fourier layer (along each dimension). The length of the first
        entry determines the number of spatial dimensions.
        Example: For a five layer UNO with 2D input the uno_n_modes can be:
        [[5,5],[5,5],[5,5],[5,5],[5,5]]
    uno_scalings : list
        Required. Scaling factors for each Fourier layer. Each entry is either
        a single number (same factor along every dimension) or a list with
        one factor per dimension.
        Example: For a five layer UNO with 2D input, the uno_scalings can be:
        [[1.0,1.0],[0.5,0.5],[1,1],[1,1],[2,2]]
    horizontal_skips_map : Dict, optional
        A dictionary {b: a, ...} denoting horizontal skip connection from
        a-th layer to b-th layer. If None, layer ``n_layers - 1 - i`` receives
        a skip from layer ``i`` for every ``i < n_layers // 2``.
        Example: For a 5 layer UNO architecture, the skip connections can be
        horizontal_skips_map = {4:0,3:1}
        Default: None
    channel_mlp_dropout : float, optional
        Dropout parameter for ChannelMLP after each FNO block. Default: 0
    channel_mlp_expansion : float, optional
        Expansion parameter for ChannelMLP after each FNO block. Default: 0.5
    non_linearity : callable, optional
        Non-linearity to use. Default: F.gelu
    norm : str, optional
        Normalization layer to use. Options: "ada_in", "group_norm",
        "instance_norm", None. Default: None
    preactivation : bool, optional
        Whether to use ResNet-style preactivation. Default: False
    fno_skip : str, optional
        Type of skip connection to use in FNO layers. Options: "linear",
        "identity", "soft-gating", None. Default: "linear"
    horizontal_skip : str, optional
        Type of skip connection to use in horizontal connections. Options:
        "linear", "identity", "soft-gating", None. Default: "linear"
    channel_mlp_skip : str, optional
        Type of skip connection to use in channel-mixing MLP. Options:
        "linear", "identity", "soft-gating", None. Default: "soft-gating"
    separable : bool, optional
        Whether to use a separable spectral convolution. Default: False
    factorization : str, optional
        Tensor factorization of the parameters weight to use. Options:
        "None", "Tucker", "CP", "TT". Other factorization methods supported
        by tltorch. Default: None
    rank : float, optional
        Rank of the tensor factorization of the Fourier weights. Default: 1.0.
        Set to float <1.0 when using TFNO (i.e. when factorization is not
        None). A TFNO with rank 0.1 has roughly 10% of the parameters of a
        dense FNO.
    fixed_rank_modes : bool, optional
        Whether to not factorize certain modes. Default: False
    integral_operator : type, optional
        Integral operator class, stored as ``self.integral_operator``.
        NOTE: it is currently NOT passed on to ``operator_block``.
        Default: SpectralConv
    operator_block : type, optional
        Class instantiated once per Fourier layer. Default: FNOBlocks
    implementation : str, optional
        If factorization is not None, forward mode to use. Options:
        "reconstructed", "factorized". Default: "factorized"
    decomposition_kwargs : dict, optional
        Additional parameters to pass to the tensor decomposition.
        Default: None (treated as an empty dict)
    domain_padding : Union[float, List[float], None], optional
        Percentage of padding to use. Padding is enabled only for a positive
        number or a list with a positive sum. Default: None
    verbose : bool, optional
        If True, print the computed end-to-end scaling factor.
        Default: False

    Raises
    ------
    AssertionError
        If uno_out_channels, uno_n_modes or uno_scalings is None, or if any
        of their lengths differs from n_layers.
    ValueError
        If positional_embedding is not a supported option, or if a
        GridEmbedding2D is given for a problem that is not 2D.

    References
    ----------
    .. [1] : Rahman, M.A., Ross, Z., Azizzadenesheli, K. "U-NO: U-shaped
        Neural Operators" (2022). TMLR 2022, https://arxiv.org/pdf/2204.11127.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels,
        lifting_channels=256,
        projection_channels=256,
        positional_embedding="grid",
        n_layers=4,
        uno_out_channels=None,
        uno_n_modes=None,
        uno_scalings=None,
        horizontal_skips_map=None,
        channel_mlp_dropout=0,
        channel_mlp_expansion=0.5,
        non_linearity=F.gelu,
        norm=None,
        preactivation=False,
        fno_skip="linear",
        horizontal_skip="linear",
        channel_mlp_skip="soft-gating",
        separable=False,
        factorization=None,
        rank=1.0,
        fixed_rank_modes=False,
        integral_operator=SpectralConv,
        operator_block=FNOBlocks,
        implementation="factorized",
        decomposition_kwargs=None,
        domain_padding=None,
        verbose=False,
    ):
        super().__init__()
        # A fresh dict per instance instead of one shared mutable default.
        if decomposition_kwargs is None:
            decomposition_kwargs = {}

        self.n_layers = n_layers
        assert uno_out_channels is not None, "uno_out_channels can not be None"
        assert uno_n_modes is not None, "uno_n_modes can not be None"
        assert uno_scalings is not None, "uno_scalings can not be None"
        assert (
            len(uno_out_channels) == n_layers
        ), "Output channels for all layers are not given"
        assert (
            len(uno_n_modes) == n_layers
        ), "number of modes for all layers are not given"
        assert (
            len(uno_scalings) == n_layers
        ), "Scaling factor for all layers are not given"

        # The number of spatial dimensions is taken from the first layer's modes.
        self.n_dim = len(uno_n_modes[0])
        self.uno_out_channels = uno_out_channels
        self.uno_n_modes = uno_n_modes
        self.uno_scalings = uno_scalings
        self.hidden_channels = hidden_channels
        self.lifting_channels = lifting_channels
        self.projection_channels = projection_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.horizontal_skips_map = horizontal_skips_map
        self.non_linearity = non_linearity
        self.rank = rank
        self.factorization = factorization
        self.fixed_rank_modes = fixed_rank_modes
        self.decomposition_kwargs = decomposition_kwargs
        self.fno_skip = fno_skip
        self.channel_mlp_skip = channel_mlp_skip
        self.implementation = implementation
        self.separable = separable
        self.preactivation = preactivation
        self.operator_block = operator_block
        # NOTE(review): stored only; the operator blocks below are built
        # without it, so a custom integral_operator currently has no effect.
        self.integral_operator = integral_operator

        # Positional embedding applied to the raw input, before lifting.
        if positional_embedding == "grid":
            spatial_grid_boundaries = [[0.0, 1.0]] * self.n_dim
            self.positional_embedding = GridEmbeddingND(
                in_channels=self.in_channels,
                dim=self.n_dim,
                grid_boundaries=spatial_grid_boundaries,
            )
        elif isinstance(positional_embedding, GridEmbedding2D):
            if self.n_dim != 2:
                raise ValueError(
                    f"Error: expected {self.n_dim}-d positional embeddings, "
                    f"got {positional_embedding}"
                )
            self.positional_embedding = positional_embedding
        elif isinstance(positional_embedding, GridEmbeddingND):
            self.positional_embedding = positional_embedding
        elif positional_embedding is None:
            self.positional_embedding = None
        else:
            raise ValueError(
                "Error: tried to instantiate UNO positional embedding with "
                f"{positional_embedding}, expected one of 'grid', "
                "GridEmbedding2D, GridEmbeddingND or None"
            )

        if self.positional_embedding is not None:
            # The positional embedding contributes n_dim extra input channels.
            in_channels += self.n_dim

        # Default skip map, e.g. n_layers = 5 -> {4: 0, 3: 1}.
        if self.horizontal_skips_map is None:
            self.horizontal_skips_map = {
                n_layers - i - 1: i for i in range(n_layers // 2)
            }

        # Overall resolution change of the network: the per-layer factors
        # multiplied together, one value per spatial dimension.
        self.end_to_end_scaling_factor = self._end_to_end_scaling(
            self.uno_scalings, self.n_dim
        )

        if verbose:
            print("calculated out factor", self.end_to_end_scaling_factor)

        if domain_padding is not None and (
            (isinstance(domain_padding, list) and sum(domain_padding) > 0)
            or (isinstance(domain_padding, (float, int)) and domain_padding > 0)
        ):
            self.domain_padding = DomainPadding(
                domain_padding=domain_padding,
                resolution_scaling_factor=self.end_to_end_scaling_factor,
            )
        else:
            self.domain_padding = None

        self.lifting = ChannelMLP(
            in_channels=in_channels,
            out_channels=self.hidden_channels,
            hidden_channels=self.lifting_channels,
            n_layers=2,
            n_dim=self.n_dim,
        )

        self.fno_blocks = nn.ModuleList([])
        self.horizontal_skips = nn.ModuleDict({})
        prev_out = self.hidden_channels

        for i in range(self.n_layers):
            if i in self.horizontal_skips_map:
                # Layer i sees its regular input concatenated (along channels)
                # with the skip output of layer horizontal_skips_map[i].
                prev_out = (
                    prev_out + self.uno_out_channels[self.horizontal_skips_map[i]]
                )

            self.fno_blocks.append(
                self.operator_block(
                    in_channels=prev_out,
                    out_channels=self.uno_out_channels[i],
                    n_modes=self.uno_n_modes[i],
                    channel_mlp_dropout=channel_mlp_dropout,
                    channel_mlp_expansion=channel_mlp_expansion,
                    resolution_scaling_factor=[self.uno_scalings[i]],
                    non_linearity=non_linearity,
                    norm=norm,
                    preactivation=preactivation,
                    fno_skip=fno_skip,
                    channel_mlp_skip=channel_mlp_skip,
                    rank=rank,
                    fixed_rank_modes=fixed_rank_modes,
                    implementation=implementation,
                    separable=separable,
                    factorization=factorization,
                    decomposition_kwargs=decomposition_kwargs,
                )
            )

            # Layers that act as a skip source get their own skip module.
            if i in self.horizontal_skips_map.values():
                self.horizontal_skips[str(i)] = skip_connection(
                    self.uno_out_channels[i],
                    self.uno_out_channels[i],
                    skip_type=horizontal_skip,
                    n_dim=self.n_dim,
                )

            prev_out = self.uno_out_channels[i]

        self.projection = ChannelMLP(
            in_channels=prev_out,
            out_channels=out_channels,
            hidden_channels=self.projection_channels,
            n_layers=2,
            n_dim=self.n_dim,
            non_linearity=non_linearity,
        )

    @staticmethod
    def _end_to_end_scaling(uno_scalings, n_dim):
        """Multiply all per-layer scaling factors into one factor per dimension.

        Each entry of ``uno_scalings`` may be a number or a list of numbers.
        Numbers and single-element lists are broadcast to ``n_dim``
        dimensions; longer lists are multiplied element-wise (``zip``
        truncates to the shortest length).

        Returns
        -------
        list
            The accumulated scaling factor for each spatial dimension.
        """
        total = [1] * n_dim
        for layer_factors in uno_scalings:
            if isinstance(layer_factors, (int, float)):
                layer_factors = [layer_factors]
            layer_factors = list(layer_factors)
            if len(layer_factors) == 1:
                # A single factor applies uniformly to every dimension.
                layer_factors = layer_factors * n_dim
            total = [a * b for a, b in zip(total, layer_factors)]
        return total

    def forward(self, x, **kwargs):
        """Apply the UNO to ``x``.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor. Channels are concatenated along dim 1 and the last
            ``n_dim`` dims are treated as spatial; presumably
            (batch, in_channels, *spatial) — TODO confirm.
        **kwargs
            Ignored; a UserWarning is emitted if any are given.

        Returns
        -------
        torch.Tensor
            Output of the projection block. The last layer is asked to produce
            the (padded) input resolution times ``end_to_end_scaling_factor``;
            domain padding, if enabled, is removed before projection.
        """
        if kwargs:
            warnings.warn(
                f"UNO.forward() received unexpected keyword arguments: {list(kwargs.keys())}. "
                "These arguments will be ignored.",
                UserWarning,
                stacklevel=2,
            )

        if self.positional_embedding is not None:
            x = self.positional_embedding(x)
        x = self.lifting(x)

        if self.domain_padding is not None:
            x = self.domain_padding.pad(x)

        # Resolution the final layer must output, rounded to whole points.
        output_shape = [
            int(round(size * factor))
            for size, factor in zip(
                x.shape[-self.n_dim :], self.end_to_end_scaling_factor
            )
        ]

        skip_outputs = {}
        cur_output = None
        for layer_idx in range(self.n_layers):
            if layer_idx in self.horizontal_skips_map:
                skip_val = skip_outputs[self.horizontal_skips_map[layer_idx]]
                # Resample the stored skip to the current spatial resolution
                # so it can be concatenated with x along the channel dim.
                resolution_scaling_factors = [
                    m / n for m, n in zip(x.shape, skip_val.shape)
                ][-self.n_dim :]
                t = resample(
                    skip_val, resolution_scaling_factors, list(range(-self.n_dim, 0))
                )
                x = torch.cat([x, t], dim=1)

            # Only the last layer is given an explicit output shape.
            if layer_idx == self.n_layers - 1:
                cur_output = output_shape
            x = self.fno_blocks[layer_idx](x, output_shape=cur_output)

            if layer_idx in self.horizontal_skips_map.values():
                skip_outputs[layer_idx] = self.horizontal_skips[str(layer_idx)](x)

        if self.domain_padding is not None:
            x = self.domain_padding.unpad(x)

        return self.projection(x)