# Source code for neuralop.models.fnogno

import torch
import torch.nn.functional as F

# Set warning filter to show each warning only once
import warnings

warnings.filterwarnings("once", category=UserWarning)


from .base_model import BaseModel

from ..layers.channel_mlp import ChannelMLP
from ..layers.embeddings import SinusoidalEmbedding
from ..layers.fno_block import FNOBlocks
from ..layers.spectral_convolution import SpectralConv
from ..layers.gno_block import GNOBlock
from ..layers.gno_weighting_functions import dispatch_weighting_fn


class FNOGNO(BaseModel, name="FNOGNO"):
    """FNOGNO: Fourier/Geometry Neural Operator.

    Maps from a regular N-d grid to an arbitrary query point cloud. An FNO
    computes a latent embedding on the input grid. A GNO then integrates that
    embedding over a radius neighborhood of each query point. Finally, a
    pointwise ChannelMLP projects the result to ``out_channels``.

    Parameters
    ----------
    in_channels : int
        Number of input channels. Determined by the problem.
    out_channels : int
        Number of output channels. Determined by the problem.
    fno_n_modes : tuple, optional
        Number of modes to keep along each spectral dimension of the FNO block.
        Must be large enough but smaller than max_resolution//2 (Nyquist
        frequency). Its length also sets the number of spatial dimensions of
        the input grid.
        Default: (16, 16, 16)
    fno_hidden_channels : int, optional
        Number of hidden channels of the FNO block.
        Default: 64
    fno_n_layers : int, optional
        Number of FNO layers in the block.
        Default: 4
    projection_channel_ratio : int, optional
        Ratio of pointwise projection channels in the final ChannelMLP to
        fno_hidden_channels. The number of projection channels in the final
        ChannelMLP is projection_channel_ratio * fno_hidden_channels
        (i.e. 256 with the defaults).
        Default: 4
    gno_coord_dim : int, optional
        Dimension of the coordinate space in which the GNO is computed.
        Determined by the problem.
        Default: 3
    gno_pos_embed_type : Literal["transformer", "nerf"], optional
        Type of optional sinusoidal positional embedding to use in GNOBlock.
        Default: "transformer"
    gno_radius : float, optional
        Radius used to construct the neighborhood graph. A larger radius gives
        more neighbors, and therefore more global interactions, at a larger
        computational cost.
        Default: 0.033
    gno_transform_type : str, optional
        Type of kernel integral transform to apply in the GNO. The kernel
        k(x, y) is parameterized as a ChannelMLP and integrated over a
        neighborhood of x. Options:

        - "linear_kernelonly": integrand is k(x, y)
        - "linear": integrand is k(x, y) * f(y)
        - "nonlinear_kernelonly": integrand is k(x, y, f(y))
        - "nonlinear": integrand is k(x, y, f(y)) * f(y)

        Default: "linear"
    gno_weighting_function : Literal["half_cos", "bump", "quartic", "quadr", "octic"], optional
        Weighting function to use in the output GNO for Mollified Graph
        Neural Operator-based models. See
        neuralop.layers.gno_weighting_functions for more details.
        Default: None
    gno_weight_function_scale : float, optional
        Factor by which to scale the weights from the GNO weighting function.
        Unused if gno_weighting_function is None.
        Default: 1
    gno_embed_channels : int, optional
        Dimension of the optional per-channel embedding used in GNOBlock.
        Default: 32
    gno_embed_max_positions : int, optional
        Max positions of the optional per-channel embedding used in GNOBlock.
        Unused if gno_pos_embed_type != "transformer".
        Default: 10000
    gno_channel_mlp_hidden_layers : list, optional
        Dimensions of the hidden ChannelMLP layers of the GNO. The sequence is
        copied, so the caller's object is never modified.
        Default: [512, 256]
    gno_channel_mlp_non_linearity : nn.Module, optional
        Nonlinear activation function between GNO ChannelMLP layers.
        Default: F.gelu
    gno_use_open3d : bool, optional
        Whether to use Open3D for the neighbor search. If False, a simple
        fallback neighbor search is used. A warning is emitted when this is
        True and gno_coord_dim != 3.
        Default: True
    gno_use_torch_scatter : bool, optional
        Whether to use torch-scatter to perform grouped reductions in the
        IntegralTransform. If False, uses the native Python reduction in
        neuralop.layers.segment_csr.

        .. warning::

            torch-scatter is an optional dependency that conflicts with the
            newest versions of PyTorch, so you must handle the conflict
            explicitly in your environment. See
            :ref:`torch_scatter_dependency` for more information.

        Default: True
    gno_batched : bool, optional
        Whether to run in "batched" mode. If True, ``f`` carries a leading
        batch dimension and the geometry ``in_p`` is shared across the batch.
        Default: False
    fno_lifting_channel_ratio : int, optional
        Ratio of lifting channels to FNO hidden channels.
        Default: 4
    fno_resolution_scaling_factor : float, optional
        Factor by which to rescale output predictions in the original domain.
        Default: None
    fno_block_precision : str, optional
        Data precision to compute within the FNO block.
        Options: "full", "half", "mixed".
        Default: "full"
    fno_use_channel_mlp : bool, optional
        Whether to use a ChannelMLP layer after each FNO block.
        Default: True
    fno_channel_mlp_dropout : float, optional
        Dropout parameter of the above ChannelMLP.
        Default: 0
    fno_channel_mlp_expansion : float, optional
        Expansion parameter of the above ChannelMLP.
        Default: 0.5
    fno_non_linearity : nn.Module, optional
        Nonlinear activation function between FNO layers. Also used in the
        final projection ChannelMLP.
        Default: F.gelu
    fno_stabilizer : nn.Module, optional
        By default None, otherwise tanh is used before the FFT in the FNO block.
        Default: None
    fno_norm : str, optional
        Normalization layer to use in the FNO.
        Options: "ada_in", "group_norm", "instance_norm", None.
        Default: None
    fno_ada_in_features : int, optional
        Only used when fno_norm == "ada_in". If not None, the ``ada_in`` input
        is first passed through a sinusoidal embedding with this many
        frequencies. If None, ``ada_in`` is used as-is.
        Default: None
    fno_ada_in_dim : int, optional
        Dimension of the ``ada_in`` conditioning vector.
        Default: 1
    fno_preactivation : bool, optional
        Whether to use ResNet-style preactivation.
        Default: False
    fno_skip : str, optional
        Type of skip connection to use.
        Options: "linear", "identity", "soft-gating", None.
        Default: "linear"
    fno_channel_mlp_skip : str, optional
        Type of skip connection used by the ChannelMLP layers of the FNO
        blocks. Options:

        - "linear": Conv layer
        - "soft-gating": weights the channels of the input
        - "identity": nn.Identity
        - None: no skip connection

        Default: "soft-gating"
    fno_separable : bool, optional
        Whether to use a depthwise separable spectral convolution.
        Default: False
    fno_factorization : str, optional
        Tensor factorization of the weight parameters.
        Options: "tucker", "tt", "cp", None.
        Default: None
    fno_rank : float, optional
        Rank of the tensor factorization of the Fourier weights.
        Default: 1.0
    fno_fixed_rank_modes : bool, optional
        Whether to not factorize certain modes.
        Default: False
    fno_implementation : str, optional
        If factorization is not None, the forward mode to use. Options:

        - "reconstructed": the full weight tensor is reconstructed from the
          factorization and used for the forward pass
        - "factorized": the input is directly contracted with the factors of
          the decomposition

        Default: "factorized"
    fno_decomposition_kwargs : dict, optional
        Additional parameters to pass to the tensor decomposition. The mapping
        is copied, so the caller's object is never modified.
        Default: {}
    fno_conv_module : nn.Module, optional
        Spectral convolution module to use.
        Default: SpectralConv
    fno_enforce_hermitian_symmetry : bool, optional
        Whether to enforce Hermitian symmetry conditions when performing the
        inverse FFT for real-valued data in the FNO branch. Only used in
        :class:`SpectralConv`; ignored otherwise. When True, explicitly
        enforces that the 0th frequency and Nyquist frequency are real-valued
        before calling irfft. When False, relies on cuFFT's irfftn to handle
        symmetry automatically, which may fail on certain GPUs or input sizes
        and cause line artifacts.
        Default: True
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        projection_channel_ratio=4,
        gno_coord_dim=3,
        gno_pos_embed_type="transformer",
        gno_transform_type="linear",
        fno_n_modes=(16, 16, 16),
        fno_hidden_channels=64,
        fno_lifting_channel_ratio=4,
        fno_n_layers=4,
        # Other GNO params
        gno_embed_channels=32,
        gno_embed_max_positions=10000,
        gno_radius=0.033,
        gno_weighting_function=None,
        gno_weight_function_scale=1.0,
        gno_channel_mlp_hidden_layers=[512, 256],
        gno_channel_mlp_non_linearity=F.gelu,
        gno_use_open3d=True,
        gno_use_torch_scatter=True,
        gno_batched=False,
        # Other FNO params
        fno_resolution_scaling_factor=None,
        fno_block_precision="full",
        fno_use_channel_mlp=True,
        fno_channel_mlp_dropout=0,
        fno_channel_mlp_expansion=0.5,
        fno_non_linearity=F.gelu,
        fno_stabilizer=None,
        fno_norm=None,
        fno_ada_in_features=None,
        fno_ada_in_dim=1,
        fno_preactivation=False,
        fno_skip="linear",
        fno_channel_mlp_skip="soft-gating",
        fno_separable=False,
        fno_factorization=None,
        fno_rank=1.0,
        fno_fixed_rank_modes=False,
        fno_implementation="factorized",
        fno_decomposition_kwargs=dict(),
        fno_conv_module=SpectralConv,
        fno_enforce_hermitian_symmetry=True,
    ):
        super().__init__()

        # The list/dict defaults in the signature are evaluated once and
        # shared by every call. Copy them (and any caller-supplied object) so
        # that an in-place modification by a sub-module cannot leak into other
        # instances or back to the caller.
        if gno_channel_mlp_hidden_layers is not None:
            gno_channel_mlp_hidden_layers = list(gno_channel_mlp_hidden_layers)
        if fno_decomposition_kwargs is not None:
            fno_decomposition_kwargs = dict(fno_decomposition_kwargs)

        self.gno_coord_dim = gno_coord_dim
        if self.gno_coord_dim != 3 and gno_use_open3d:
            warnings.warn(
                f"GNO expects {self.gno_coord_dim}-d data but Open3d expects 3-d data",
                UserWarning,
                stacklevel=2,
            )

        # Number of spatial dimensions of the input grid.
        self.in_coord_dim = len(fno_n_modes)
        if self.in_coord_dim != self.gno_coord_dim:
            warnings.warn(
                f"FNO expects {self.in_coord_dim}-d data while GNO expects {self.gno_coord_dim}-d data",
                UserWarning,
                stacklevel=2,
            )

        # These lists hold the interior (spatial) dimension indices of the
        # input, so tensors can be permuted/reshaped without naming dims.
        self.in_coord_dim_forward_order = list(range(self.in_coord_dim))
        self.in_coord_dim_reverse_order = [
            j + 1 for j in self.in_coord_dim_forward_order
        ]

        self.gno_batched = gno_batched  # used in forward call to GNO

        # In batched mode there is an extra leading batch dim, which shifts
        # every spatial dim index by one.
        if self.gno_batched:
            self.in_coord_dim_forward_order = [
                j + 1 for j in self.in_coord_dim_forward_order
            ]
            self.in_coord_dim_reverse_order = [
                j + 1 for j in self.in_coord_dim_reverse_order
            ]

        # Adaptive instance norm setup. adain_pos_embed must be assigned on
        # every path, because latent_embedding() reads it whenever an ada_in
        # tensor is passed.
        if fno_norm == "ada_in":
            if fno_ada_in_features is not None:
                self.adain_pos_embed = SinusoidalEmbedding(
                    in_channels=fno_ada_in_dim,
                    num_frequencies=fno_ada_in_features,
                    embedding_type="transformer",
                )
                # The ada_in norm consumes the embedded vector, so its input
                # size is the embedding's output width.
                self.ada_in_dim = self.adain_pos_embed.out_channels
            else:
                # No embedding: ada_in is fed to the norm as-is.
                self.adain_pos_embed = None
                self.ada_in_dim = fno_ada_in_dim
        else:
            self.adain_pos_embed = None
            self.ada_in_dim = None

        # Lifting: input channels plus the grid coordinates (concatenated in
        # latent_embedding) -> FNO hidden channels.
        fno_lifting_channels = fno_lifting_channel_ratio * fno_hidden_channels
        self.lifting = ChannelMLP(
            in_channels=in_channels + self.in_coord_dim,
            hidden_channels=fno_lifting_channels,
            out_channels=fno_hidden_channels,
            n_layers=3,
        )

        self.fno_hidden_channels = fno_hidden_channels
        self.fno_blocks = FNOBlocks(
            n_modes=fno_n_modes,
            in_channels=fno_hidden_channels,
            out_channels=fno_hidden_channels,
            n_layers=fno_n_layers,
            resolution_scaling_factor=fno_resolution_scaling_factor,
            fno_block_precision=fno_block_precision,
            use_channel_mlp=fno_use_channel_mlp,
            channel_mlp_expansion=fno_channel_mlp_expansion,
            channel_mlp_dropout=fno_channel_mlp_dropout,
            non_linearity=fno_non_linearity,
            stabilizer=fno_stabilizer,
            norm=fno_norm,
            ada_in_features=self.ada_in_dim,
            preactivation=fno_preactivation,
            fno_skip=fno_skip,
            channel_mlp_skip=fno_channel_mlp_skip,
            separable=fno_separable,
            factorization=fno_factorization,
            rank=fno_rank,
            fixed_rank_modes=fno_fixed_rank_modes,
            implementation=fno_implementation,
            decomposition_kwargs=fno_decomposition_kwargs,
            conv_module=fno_conv_module,
            enforce_hermitian_symmetry=fno_enforce_hermitian_symmetry,
        )

        self.gno_radius = gno_radius

        # Optional mollifier weighting for the GNO integral.
        if gno_weighting_function is not None:
            weight_fn = dispatch_weighting_fn(
                gno_weighting_function,
                sq_radius=gno_radius**2,
                scale=gno_weight_function_scale,
            )
        else:
            weight_fn = None

        self.gno = GNOBlock(
            in_channels=fno_hidden_channels,
            out_channels=fno_hidden_channels,
            radius=gno_radius,
            weighting_fn=weight_fn,
            coord_dim=self.gno_coord_dim,
            pos_embedding_type=gno_pos_embed_type,
            pos_embedding_channels=gno_embed_channels,
            pos_embedding_max_positions=gno_embed_max_positions,
            channel_mlp_layers=gno_channel_mlp_hidden_layers,
            channel_mlp_non_linearity=gno_channel_mlp_non_linearity,
            transform_type=gno_transform_type,
            use_open3d_neighbor_search=gno_use_open3d,
            use_torch_scatter_reduce=gno_use_torch_scatter,
        )

        # Pointwise (1-d over query points) projection to out_channels.
        projection_channels = projection_channel_ratio * fno_hidden_channels
        self.projection = ChannelMLP(
            in_channels=fno_hidden_channels,
            out_channels=out_channels,
            hidden_channels=projection_channels,
            n_layers=2,
            n_dim=1,
            non_linearity=fno_non_linearity,
        )

    def latent_embedding(self, in_p, f, ada_in=None):
        """Lift the gridded input and run it through the FNO blocks.

        Parameters
        ----------
        in_p : torch.Tensor
            Grid coordinates of shape (n_1, n_2, ..., n_k, k). The geometry is
            the same for every batch element, so in batched mode the shape is
            unchanged.
        f : torch.Tensor
            Input function on the grid, of shape (n_1, ..., n_k, in_channels),
            or (b, n_1, ..., n_k, in_channels) when ``gno_batched`` is True.
        ada_in : torch.Tensor, optional
            Conditioning vector of shape (fno_ada_in_dim,) for ada_in
            normalization.

        Returns
        -------
        torch.Tensor
            Latent features of shape (fno_hidden_channels, n_1, ..., n_k), or
            (b, fno_hidden_channels, n_1, ..., n_k) when batched.
            NOTE(review): spatial sizes may differ if
            fno_resolution_scaling_factor is set. Confirm before relying on
            them matching ``in_p``.
        """
        if self.gno_batched:
            batch_size = f.shape[0]
            # Broadcast the shared geometry along the batch dim. expand()
            # creates a view, and torch.cat below materializes it once.
            in_p = in_p.unsqueeze(0).expand(batch_size, *in_p.shape)

        # Channels-last concatenation: input features first, then coordinates.
        x = torch.cat((f, in_p), dim=-1)

        # Move channels to dim 1: (b, c, n_1, ..., n_k).
        if self.gno_batched:
            x = x.permute(0, -1, *self.in_coord_dim_forward_order)
        else:
            x = x.permute(-1, *self.in_coord_dim_forward_order).unsqueeze(0)

        # Update Ada IN embedding
        if ada_in is not None:
            if self.adain_pos_embed is not None:
                ada_in_embed = self.adain_pos_embed(ada_in.unsqueeze(0)).squeeze(0)
            else:
                ada_in_embed = ada_in
            self.fno_blocks.set_ada_in_embeddings(ada_in_embed)

        # Lift, then apply FNO blocks layer by layer.
        x = self.lifting(x)
        for layer_idx in range(self.fno_blocks.n_layers):
            x = self.fno_blocks(x, layer_idx)

        if self.gno_batched:
            return x
        return x.squeeze(0)

    def integrate_latent(self, in_p, out_p, latent_embed):
        """Integrate latent grid features onto the query points and project.

        For each query point in ``out_p``, the GNO integrates the latent
        features located at the grid points ``in_p`` over its radius
        neighborhood. A pointwise ChannelMLP then maps the result to
        ``out_channels``.

        Parameters
        ----------
        in_p : torch.Tensor
            Grid coordinates of shape (n_1, ..., n_k, k).
        out_p : torch.Tensor
            Query coordinates of shape (n_out, gno_coord_dim).
        latent_embed : torch.Tensor
            Output of :meth:`latent_embedding`.

        Returns
        -------
        torch.Tensor
            Shape (n_out, out_channels), or (b, n_out, out_channels) when
            ``gno_batched`` is True.
        """
        # Flatten the grid into a point list, channels last:
        # (n_1*n_2*..., fno_hidden_channels),
        # or (b, n_1*n_2*..., fno_hidden_channels) if batched.
        if self.gno_batched:
            batch_size = latent_embed.shape[0]
            latent_embed = latent_embed.permute(
                0, *self.in_coord_dim_reverse_order, 1
            ).reshape((batch_size, -1, self.fno_hidden_channels))
        else:
            latent_embed = latent_embed.permute(
                *self.in_coord_dim_reverse_order, 0
            ).reshape((-1, self.fno_hidden_channels))

        # GNO output: (n_out, fno_hidden_channels), or batched (b, n_out, c).
        out = self.gno(
            y=in_p.reshape(-1, in_p.shape[-1]),
            x=out_p,
            f_y=latent_embed,
        )

        # Unbatched GNO output lacks the batch dim; add it so the projection
        # always sees (b, c, n_out).
        if out.ndim == 2:
            out = out.unsqueeze(0)
        out = out.permute(0, 2, 1)  # b, c, n_out

        # Project pointwise to out channels
        out = self.projection(out)

        if self.gno_batched:
            return out.permute(0, 2, 1)
        return out.squeeze(0).permute(1, 0)

    def forward(self, in_p, out_p, f, ada_in=None, **kwargs):
        """Map a gridded input function to predictions at query points.

        Parameters
        ----------
        in_p : torch.Tensor
            Grid coordinates of shape (n_1, ..., n_k, k).
        out_p : torch.Tensor
            Query coordinates of shape (n_out, gno_coord_dim).
        f : torch.Tensor
            Input function of shape (n_1, ..., n_k, in_channels), or
            (b, n_1, ..., n_k, in_channels) when ``gno_batched`` is True.
        ada_in : torch.Tensor, optional
            Conditioning vector for ada_in normalization.
        **kwargs
            Ignored; a UserWarning lists any unexpected keys.

        Returns
        -------
        torch.Tensor
            Shape (n_out, out_channels), or (b, n_out, out_channels) when
            batched.
        """
        if kwargs:
            warnings.warn(
                f"FNOGNO.forward() received unexpected keyword arguments: {list(kwargs.keys())}. "
                "These arguments will be ignored.",
                UserWarning,
                stacklevel=2,
            )

        # Compute latent space embedding
        latent_embed = self.latent_embedding(in_p=in_p, f=f, ada_in=ada_in)
        # Integrate latent space onto the query points
        return self.integrate_latent(
            in_p=in_p, out_p=out_p, latent_embed=latent_embed
        )