# Source code for neuralop.models.fnogno

import torch
import torch.nn.functional as F

# Set warning filter to show each warning only once
import warnings

# NOTE: no ``module=``/``message=`` restriction is given, so this filter is
# process-wide: every UserWarning (not only this module's) is shown once.
warnings.simplefilter("once", category=UserWarning)


from .base_model import BaseModel

from ..layers.channel_mlp import ChannelMLP
from ..layers.embeddings import SinusoidalEmbedding
from ..layers.fno_block import FNOBlocks
from ..layers.spectral_convolution import SpectralConv
from ..layers.gno_block import GNOBlock
from ..layers.gno_weighting_functions import dispatch_weighting_fn


class FNOGNO(BaseModel, name="FNOGNO"):
    """FNOGNO: Fourier/Geometry Neural Operator - maps from a regular N-d grid
    to an arbitrary query point cloud.

    The input function is lifted and processed by FNO blocks on the regular
    grid, then a GNO integrates that latent field onto the query points, and
    a pointwise ChannelMLP projects the result to ``out_channels``.

    Parameters
    ----------
    in_channels : int
        Number of input channels. Determined by the problem.
    out_channels : int
        Number of output channels. Determined by the problem.
    fno_n_modes : tuple, optional
        Number of modes to keep along each spectral dimension of FNO block.
        Must be large enough but smaller than max_resolution//2 (Nyquist
        frequency). Its length also sets the grid dimension k.
        Default: (16, 16, 16)
    fno_hidden_channels : int, optional
        Number of hidden channels of FNO block. Default: 64
    fno_n_layers : int, optional
        Number of FNO layers in the block. Default: 4
    projection_channel_ratio : int, optional
        Ratio of pointwise projection channels in the final ChannelMLP to
        fno_hidden_channels. The number of projection channels in the final
        ChannelMLP is computed by projection_channel_ratio * fno_hidden_channels
        (i.e. default 256). Default: 4
    gno_coord_dim : int, optional
        Dimension of coordinate space where GNO is computed. Determined by the
        problem. Default: 3
    gno_pos_embed_type : Literal["transformer", "nerf"], optional
        Type of optional sinusoidal positional embedding to use in GNOBlock.
        Default: "transformer"
    gno_radius : float, optional
        Radius parameter to construct graph. A larger radius means more
        neighbors, so more global interactions, but a larger computational
        cost. Default: 0.033
    gno_transform_type : str, optional
        Type of kernel integral transform to apply in GNO.
        Kernel k(x,y): parameterized as a ChannelMLP integrated over a
        neighborhood of x.

        Options:
            - "linear_kernelonly": Integrand is k(x, y)
            - "linear": Integrand is k(x, y) * f(y)
            - "nonlinear_kernelonly": Integrand is k(x, y, f(y))
            - "nonlinear": Integrand is k(x, y, f(y)) * f(y)

        Default: "linear"
    gno_weighting_function : Literal["half_cos", "bump", "quartic", "quadr", "octic"], optional
        Choice of weighting function to use in the output GNO for Mollified
        Graph Neural Operator-based models. See
        neuralop.layers.gno_weighting_functions for more details.
        Default: None
    gno_weight_function_scale : float, optional
        Factor by which to scale weights from GNO weighting function.
        If gno_weighting_function is None, this is not used. Default: 1
    gno_embed_channels : int, optional
        Dimension of optional per-channel embedding to use in GNOBlock.
        Default: 32
    gno_embed_max_positions : int, optional
        Max positions of optional per-channel embedding to use in GNOBlock.
        If gno_pos_embed_type != 'transformer', value is unused.
        Default: 10000
    gno_channel_mlp_hidden_layers : list, optional
        Dimension of hidden ChannelMLP layers of GNO. ``None`` selects the
        default. Default: [512, 256]
    gno_channel_mlp_non_linearity : nn.Module, optional
        Nonlinear activation function between layers. Default: F.gelu
    gno_use_open3d : bool, optional
        Whether to use Open3D functionality. If False, uses simple fallback
        neighbor search. Default: True
    gno_use_torch_scatter : bool, optional
        Whether to use torch-scatter to perform grouped reductions in the
        IntegralTransform. If False, uses native Python reduction in
        neuralop.layers.segment_csr.

        .. warning::

            torch-scatter is an optional dependency that conflicts with the
            newest versions of PyTorch, so you must handle the conflict
            explicitly in your environment. See
            :ref:`torch_scatter_dependency` for more information.

        Default: True
    gno_batched : bool, optional
        Whether to use the IntegralTransform/GNO layer in "batched" mode.
        If True, ``f`` carries a leading batch dimension while the grid
        geometry ``in_p`` is shared across the batch. Default: False
    fno_lifting_channel_ratio : int, optional
        Ratio of lifting channels to FNO hidden channels. Default: 4
    fno_resolution_scaling_factor : float, optional
        Factor by which to rescale output predictions in the original domain.
        Default: None
    fno_block_precision : str, optional
        Data precision to compute within FNO block.
        Options: "full", "half", "mixed". Default: "full"
    fno_use_channel_mlp : bool, optional
        Whether to use a ChannelMLP layer after each FNO block. Default: True
    fno_channel_mlp_dropout : float, optional
        Dropout parameter of above ChannelMLP. Default: 0
    fno_channel_mlp_expansion : float, optional
        Expansion parameter of above ChannelMLP. Default: 0.5
    fno_non_linearity : nn.Module, optional
        Nonlinear activation function between each FNO layer (also used in
        the final projection ChannelMLP). Default: F.gelu
    fno_stabilizer : nn.Module, optional
        By default None, otherwise tanh is used before FFT in the FNO block.
        Default: None
    fno_norm : str, optional
        Normalization layer to use in FNO.
        Options: "ada_in", "group_norm", "instance_norm", None. Default: None
    fno_ada_in_features : int, optional
        If an adaptive mesh is used, number of channels of its positional
        embedding. Only used when fno_norm == "ada_in". Default: None
    fno_ada_in_dim : int, optional
        Dimensions of above FNO adaptive mesh. Default: 1
    fno_preactivation : bool, optional
        Whether to use ResNet-style preactivation. Default: False
    fno_skip : str, optional
        Type of skip connection to use.
        Options: "linear", "identity", "soft-gating", None. Default: "linear"
    fno_channel_mlp_skip : str, optional
        Type of skip connection to use in the FNO block's ChannelMLP.

        Options:
            - "linear": Conv layer
            - "soft-gating": Weights the channels of the input
            - "identity": nn.Identity
            - None: No skip connection

        Default: "soft-gating"
    fno_separable : bool, optional
        Whether to use a depthwise separable spectral convolution.
        Default: False
    fno_factorization : str, optional
        Tensor factorization of the parameters weight to use.
        Options: "tucker", "tt", "cp", None. Default: None
    fno_rank : float, optional
        Rank of the tensor factorization of the Fourier weights. Default: 1.0
    fno_fixed_rank_modes : bool, optional
        Whether to not factorize certain modes. Default: False
    fno_implementation : str, optional
        If factorization is not None, forward mode to use.

        Options:
            - "reconstructed": The full weight tensor is reconstructed from
              the factorization and used for the forward pass
            - "factorized": The input is directly contracted with the factors
              of the decomposition

        Default: "factorized"
    fno_decomposition_kwargs : dict, optional
        Additional parameters to pass to the tensor decomposition. ``None``
        selects an empty dict. Default: {}
    fno_conv_module : nn.Module, optional
        Spectral convolution module to use. Default: SpectralConv
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        projection_channel_ratio=4,
        gno_coord_dim=3,
        gno_pos_embed_type="transformer",
        gno_transform_type="linear",
        fno_n_modes=(16, 16, 16),
        fno_hidden_channels=64,
        fno_lifting_channel_ratio=4,
        fno_n_layers=4,
        # Other GNO params
        gno_embed_channels=32,
        gno_embed_max_positions=10000,
        gno_radius=0.033,
        gno_weighting_function=None,
        gno_weight_function_scale=1.0,
        gno_channel_mlp_hidden_layers=None,
        gno_channel_mlp_non_linearity=F.gelu,
        gno_use_open3d=True,
        gno_use_torch_scatter=True,
        gno_batched=False,
        # Other FNO params
        fno_resolution_scaling_factor=None,
        fno_block_precision="full",
        fno_use_channel_mlp=True,
        fno_channel_mlp_dropout=0,
        fno_channel_mlp_expansion=0.5,
        fno_non_linearity=F.gelu,
        fno_stabilizer=None,
        fno_norm=None,
        fno_ada_in_features=None,
        fno_ada_in_dim=1,
        fno_preactivation=False,
        fno_skip="linear",
        fno_channel_mlp_skip="soft-gating",
        fno_separable=False,
        fno_factorization=None,
        fno_rank=1.0,
        fno_fixed_rank_modes=False,
        fno_implementation="factorized",
        fno_decomposition_kwargs=None,
        fno_conv_module=SpectralConv,
    ):
        super().__init__()

        # Build fresh containers per instance instead of sharing mutable
        # default arguments between every FNOGNO ever constructed.
        if gno_channel_mlp_hidden_layers is None:
            gno_channel_mlp_hidden_layers = [512, 256]
        if fno_decomposition_kwargs is None:
            fno_decomposition_kwargs = {}

        self.gno_coord_dim = gno_coord_dim
        if self.gno_coord_dim != 3 and gno_use_open3d:
            warnings.warn(
                f"GNO expects {self.gno_coord_dim}-d data but Open3d expects 3-d data",
                UserWarning,
                stacklevel=2,
            )

        # the number of FNO modes fixes the grid dimension k
        self.in_coord_dim = len(fno_n_modes)
        if self.in_coord_dim != self.gno_coord_dim:
            warnings.warn(
                f"FNO expects {self.in_coord_dim}-d data while GNO expects "
                f"{self.gno_coord_dim}-d data",
                UserWarning,
                stacklevel=2,
            )

        # Positions of the k grid dims, so tensors can be permuted without
        # hard-coding k:
        #   forward order: grid dims in the channel-last input (n_1..n_k, C)
        #   reverse order: grid dims in the channel-first latent (C, n_1..n_k)
        self.in_coord_dim_forward_order = list(range(self.in_coord_dim))
        self.in_coord_dim_reverse_order = [
            j + 1 for j in self.in_coord_dim_forward_order
        ]

        self.gno_batched = gno_batched
        if self.gno_batched:
            # a leading batch dim shifts every grid dim by one position
            self.in_coord_dim_forward_order = [
                j + 1 for j in self.in_coord_dim_forward_order
            ]
            self.in_coord_dim_reverse_order = [
                j + 1 for j in self.in_coord_dim_reverse_order
            ]

        if fno_norm == "ada_in":
            if fno_ada_in_features is not None:
                self.adain_pos_embed = SinusoidalEmbedding(
                    in_channels=fno_ada_in_dim,
                    num_frequencies=fno_ada_in_features,
                    embedding_type="transformer",
                )
                # the ada_in norm consumes the embedding's output channels
                self.ada_in_dim = self.adain_pos_embed.out_channels
            else:
                # No embedding: raw ada_in values feed the norm directly.
                # The attribute must still exist because latent_embedding
                # reads it whenever ada_in is passed.
                self.adain_pos_embed = None
                self.ada_in_dim = fno_ada_in_dim
        else:
            self.adain_pos_embed = None
            self.ada_in_dim = None

        # Lifting for the FNO blocks: input channels are f concatenated with
        # the k grid coordinates (see latent_embedding).
        fno_lifting_channels = fno_lifting_channel_ratio * fno_hidden_channels
        self.lifting = ChannelMLP(
            in_channels=in_channels + self.in_coord_dim,
            hidden_channels=fno_lifting_channels,
            out_channels=fno_hidden_channels,
            n_layers=3,
        )

        self.fno_hidden_channels = fno_hidden_channels
        self.fno_blocks = FNOBlocks(
            n_modes=fno_n_modes,
            in_channels=fno_hidden_channels,
            out_channels=fno_hidden_channels,
            n_layers=fno_n_layers,
            resolution_scaling_factor=fno_resolution_scaling_factor,
            fno_block_precision=fno_block_precision,
            use_channel_mlp=fno_use_channel_mlp,
            channel_mlp_expansion=fno_channel_mlp_expansion,
            channel_mlp_dropout=fno_channel_mlp_dropout,
            non_linearity=fno_non_linearity,
            stabilizer=fno_stabilizer,
            norm=fno_norm,
            ada_in_features=self.ada_in_dim,
            preactivation=fno_preactivation,
            fno_skip=fno_skip,
            channel_mlp_skip=fno_channel_mlp_skip,
            separable=fno_separable,
            factorization=fno_factorization,
            rank=fno_rank,
            fixed_rank_modes=fno_fixed_rank_modes,
            implementation=fno_implementation,
            decomposition_kwargs=fno_decomposition_kwargs,
            conv_module=fno_conv_module,
        )

        self.gno_radius = gno_radius

        if gno_weighting_function is not None:
            weight_fn = dispatch_weighting_fn(
                gno_weighting_function,
                sq_radius=gno_radius**2,
                scale=gno_weight_function_scale,
            )
        else:
            weight_fn = None

        self.gno = GNOBlock(
            in_channels=fno_hidden_channels,
            out_channels=fno_hidden_channels,
            radius=gno_radius,
            weighting_fn=weight_fn,
            coord_dim=self.gno_coord_dim,
            pos_embedding_type=gno_pos_embed_type,
            pos_embedding_channels=gno_embed_channels,
            pos_embedding_max_positions=gno_embed_max_positions,
            channel_mlp_layers=gno_channel_mlp_hidden_layers,
            channel_mlp_non_linearity=gno_channel_mlp_non_linearity,
            transform_type=gno_transform_type,
            use_open3d_neighbor_search=gno_use_open3d,
            use_torch_scatter_reduce=gno_use_torch_scatter,
        )

        projection_channels = projection_channel_ratio * fno_hidden_channels
        self.projection = ChannelMLP(
            in_channels=fno_hidden_channels,
            out_channels=out_channels,
            hidden_channels=projection_channels,
            n_layers=2,
            n_dim=1,
            non_linearity=fno_non_linearity,
        )

    def latent_embedding(self, in_p, f, ada_in=None):
        """Lift the gridded input (with its coordinates) and run the FNO blocks.

        Parameters
        ----------
        in_p : torch.Tensor
            Grid coordinates of shape (n_1, ..., n_k, k). The geometry is
            constant across the batch, so it has no batch dim even when
            ``gno_batched`` is True.
        f : torch.Tensor
            Input function on the grid, (n_1, ..., n_k, in_channels), or
            (b, n_1, ..., n_k, in_channels) if ``gno_batched``.
        ada_in : torch.Tensor, optional
            Ada IN conditioning of shape (fno_ada_in_dim,). If given, it is
            embedded (when an embedding was configured) and set on the FNO
            blocks before they run.

        Returns
        -------
        torch.Tensor
            (fno_hidden_channels, n_1, ..., n_k), or
            (b, fno_hidden_channels, n_1, ..., n_k) if ``gno_batched``.
        """
        coords = in_p
        if self.gno_batched:
            batch_size = f.shape[0]
            # tile the shared geometry along a new leading batch dim
            coords = in_p.repeat([batch_size] + [1] * in_p.ndim)

        # channel-last concat: (..., in_channels + k)
        x = torch.cat((f, coords), dim=-1)

        if self.gno_batched:
            # (b, n_1, ..., n_k, C) -> (b, C, n_1, ..., n_k)
            x = x.permute(0, -1, *self.in_coord_dim_forward_order)
        else:
            # (n_1, ..., n_k, C) -> (1, C, n_1, ..., n_k)
            x = x.permute(-1, *self.in_coord_dim_forward_order).unsqueeze(0)

        # Update Ada IN embedding
        if ada_in is not None:
            if self.adain_pos_embed is not None:
                ada_in_embed = self.adain_pos_embed(ada_in.unsqueeze(0)).squeeze(0)
            else:
                ada_in_embed = ada_in
            self.fno_blocks.set_ada_in_embeddings(ada_in_embed)

        x = self.lifting(x)
        for layer_idx in range(self.fno_blocks.n_layers):
            x = self.fno_blocks(x, layer_idx)

        return x if self.gno_batched else x.squeeze(0)

    def integrate_latent(self, in_p, out_p, latent_embed):
        """Integrate the latent field onto the query points and project it.

        The GNO integrates the latent embedding over each output point's
        neighborhood; a pointwise ChannelMLP then maps the result to
        ``out_channels``.

        Parameters
        ----------
        in_p : torch.Tensor
            Grid coordinates of shape (n_1, ..., n_k, k); flattened to
            (n_1 * ... * n_k, k) as the GNO source points.
        out_p : torch.Tensor
            Query points of shape (n_out, gno_coord_dim).
        latent_embed : torch.Tensor
            Output of :meth:`latent_embedding`.

        Returns
        -------
        torch.Tensor
            (n_out, out_channels), or (b, n_out, out_channels) if
            ``gno_batched``.
        """
        if self.gno_batched:
            batch_size = latent_embed.shape[0]
            # (b, C, n_1, ..., n_k) -> (b, n_1 * ... * n_k, C)
            latent_embed = latent_embed.permute(
                0, *self.in_coord_dim_reverse_order, 1
            ).reshape((batch_size, -1, self.fno_hidden_channels))
        else:
            # (C, n_1, ..., n_k) -> (n_1 * ... * n_k, C)
            latent_embed = latent_embed.permute(
                *self.in_coord_dim_reverse_order, 0
            ).reshape((-1, self.fno_hidden_channels))

        out = self.gno(
            y=in_p.reshape(-1, in_p.shape[-1]),
            x=out_p,
            f_y=latent_embed,
        )

        # An unbatched GNO output is 2-D, (n_out, C); add a batch dim so both
        # cases share the channel-first projection below.
        if out.ndim == 2:
            out = out.unsqueeze(0)
        out = out.permute(0, 2, 1)  # (b, C, n_out)

        # Project pointwise to out channels
        out = self.projection(out)

        if self.gno_batched:
            return out.permute(0, 2, 1)  # (b, n_out, out_channels)
        return out.squeeze(0).permute(1, 0)  # (n_out, out_channels)

    def forward(self, in_p, out_p, f, ada_in=None, **kwargs):
        """Map a gridded input function to predictions at query points.

        Parameters
        ----------
        in_p : torch.Tensor
            Grid coordinates, (n_1, ..., n_k, k); shared across the batch.
        out_p : torch.Tensor
            Query points, (n_out, gno_coord_dim).
        f : torch.Tensor
            Input function on the grid, (n_1, ..., n_k, in_channels), or
            with a leading batch dim if ``gno_batched``.
        ada_in : torch.Tensor, optional
            Ada IN conditioning, (fno_ada_in_dim,).
        **kwargs
            Ignored; a UserWarning is emitted if any are given.

        Returns
        -------
        torch.Tensor
            (n_out, out_channels), or (b, n_out, out_channels) if
            ``gno_batched``.
        """
        if kwargs:
            warnings.warn(
                f"FNOGNO.forward() received unexpected keyword arguments: {list(kwargs.keys())}. "
                "These arguments will be ignored.",
                UserWarning,
                stacklevel=2,
            )

        latent_embed = self.latent_embedding(in_p=in_p, f=f, ada_in=ada_in)
        return self.integrate_latent(
            in_p=in_p, out_p=out_p, latent_embed=latent_embed
        )