Source code for neuralop.models.gino

from functools import partial
import torch
import torch.nn.functional as F
import time

from torch import nn


from ..layers.channel_mlp import ChannelMLP
from ..layers.embeddings import SinusoidalEmbedding
from ..layers.fno_block import FNOBlocks
from ..layers.spectral_convolution import SpectralConv
from ..layers.gno_block import GNOBlock

class GINO(nn.Module):
    """GINO: Geometry-informed Neural Operator.

    Learns a mapping between functions presented over arbitrary coordinate
    meshes. The model carries global integration through spectral
    convolution layers in an intermediate latent space, as described in [1]_.

    The forward pass is: input GNO (input mesh -> latent grid), pointwise
    lifting, FNO blocks on the latent grid, output GNO (latent grid ->
    output queries), pointwise projection to ``out_channels``.

    Parameters
    ----------
    in_channels : int
        feature dimension of input points
    out_channels : int
        feature dimension of output points
    latent_feature_channels : int, optional
        number of channels in optional latent feature map to concatenate
        onto latent embeddings before the latent FNO's forward pass,
        by default None
    projection_channels : int, optional
        number of hidden channels in the final pointwise projection,
        by default 256
    gno_coord_dim : int, optional
        geometric dimension of input/output queries, by default 3
    gno_pos_embed_type : literal `{'transformer', 'nerf'}` | None
        type of optional sinusoidal positional embedding to use in GNOBlock,
        by default `'transformer'`
    gno_embed_channels : int
        dimension of optional per-channel embedding to use in GNOBlock,
        by default 32
    gno_embed_max_positions : int
        max positions of optional per-channel embedding to use in GNOBlock,
        by default 10000. If `gno_pos_embed_type != 'transformer'`, value
        is unused.
    gno_radius : float, optional
        radius in input/output space for GNO neighbor search,
        by default 0.033
    in_gno_channel_mlp_hidden_layers : list, optional
        widths of hidden layers in input GNO, by default [80, 80, 80]
    out_gno_channel_mlp_hidden_layers : list, optional
        widths of hidden layers in output GNO, by default [512, 256]
    gno_channel_mlp_non_linearity : nn.Module, optional
        nonlinearity to use in gno ChannelMLP, by default F.gelu
    in_gno_transform_type : str, optional
        transform type parameter for input GNO, by default 'linear'
        see neuralop.layers.IntegralTransform.
        For 'nonlinear' and 'nonlinear_kernelonly' the input GNO outputs
        `in_channels` channels instead of `fno_in_channels`.
    out_gno_transform_type : str, optional
        transform type parameter for output GNO, by default 'linear'
        see neuralop.layers.IntegralTransform
    gno_use_open3d : bool, optional
        whether to use open3d neighbor search, by default False
        if False, uses pure-PyTorch fallback neighbor search.
        Forced to False (with a printed warning) when `gno_coord_dim != 3`.
    gno_use_torch_scatter : bool, optional
        whether to use torch_scatter's neighborhood reduction function
        or the native PyTorch implementation in IntegralTransform layers,
        by default True. If False, uses the fallback PyTorch version.
    out_gno_tanh : str | None, optional
        if `'latent_embed'` or `'both'`, tanh is applied to the latent
        embedding before it is passed to the output GNO. Any other value
        (including the default None) leaves the latent embedding unchanged
        in this class.
    fno_in_channels : int, optional
        number of input channels for FNO, by default 3
    fno_n_modes : tuple, optional
        number of modes along each dimension to use in FNO,
        by default (16, 16, 16)
    fno_hidden_channels : int, optional
        hidden channels for use in FNO, by default 64
    lifting_channels : int, optional
        number of channels in FNO's pointwise lifting, by default 256
    fno_n_layers : int, optional
        number of layers in FNO, by default 4
    fno_resolution_scaling_factor : float | None, optional
        factor by which to scale output of FNO, by default None
    fno_incremental_n_modes : list[int] | None, defaults to None
        if passed, sets n_modes separately for each FNO layer.
    fno_block_precision : str, defaults to 'full'
        data precision to compute within fno block
    fno_use_channel_mlp : bool, defaults to False
        Whether to use a ChannelMLP layer after each FNO block.
    fno_channel_mlp_dropout : float, defaults to 0
        dropout parameter of above ChannelMLP.
    fno_channel_mlp_expansion : float, defaults to 0.5
        expansion parameter of above ChannelMLP.
    fno_non_linearity : nn.Module, defaults to F.gelu
        nonlinear activation function between each FNO layer. Also used
        as the nonlinearity of the final projection ChannelMLP.
    fno_stabilizer : nn.Module | None, defaults to None
        By default None, otherwise tanh is used before FFT in the FNO block.
    fno_norm : str | None, defaults to None
        normalization layer to use in FNO (forwarded to FNOBlocks).
        The value `'ada_in'` additionally enables the adaptive-instance
        embedding handled by this class.
    fno_ada_in_features : int | None, defaults to 4
        if an adaptive mesh is used, number of channels of its positional
        embedding. If None, adaptive mesh embedding is not used.
    fno_ada_in_dim : int, defaults to 1
        dimensions of above FNO adaptive mesh.
    fno_preactivation : bool, defaults to False
        whether to use Resnet-style preactivation.
    fno_skip : str, defaults to 'linear'
        type of skip connection to use.
    fno_channel_mlp_skip : str, defaults to 'soft-gating'
        type of skip connection to use in the FNO
        'linear': conv layer
        'soft-gating': weights the channels of the input
        'identity': nn.Identity
    fno_separable : bool, defaults to False
        if True, use a depthwise separable spectral convolution.
    fno_factorization : str {'tucker', 'tt', 'cp'} | None, defaults to None
        Tensor factorization of the parameters weight to use
    fno_rank : float, defaults to 1.0
        Rank of the tensor factorization of the Fourier weights.
    fno_joint_factorization : bool, defaults to False
        Whether all the Fourier layers should be parameterized by a
        single tensor (vs one per layer).
    fno_fixed_rank_modes : bool, defaults to False
        Modes to not factorize.
    fno_implementation : str {'factorized', 'reconstructed'} | None, defaults to 'factorized'
        If factorization is not None, forward mode to use::
        * `reconstructed` : the full weight tensor is reconstructed from
          the factorization and used for the forward pass
        * `factorized` : the input is directly contracted with the factors
          of the decomposition
    fno_decomposition_kwargs : dict, defaults to dict()
        Optional additional parameters to pass to the tensor decomposition.
    fno_conv_module : nn.Module, defaults to SpectralConv
        Spectral Convolution module to use.
    **kwargs
        forwarded unchanged to FNOBlocks.

    References
    ----------
    .. [1] Li, Z., Kovachki, N., Choy, C., Li, B., Kossaifi, J., Otta, S.,
        Nabian, M., Stadler, M., Hundt, C., Azizzadenesheli, K.,
        Anandkumar, A. (2023) Geometry-Informed Neural Operator for
        Large-Scale 3D PDEs. NeurIPS 2023,
        https://proceedings.neurips.cc/paper_files/paper/2023/hash/70518ea42831f02afc3a2828993935ad-Abstract-Conference.html
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            latent_feature_channels=None,
            projection_channels=256,
            gno_coord_dim=3,
            gno_pos_embed_type='transformer',
            gno_embed_channels=32,
            gno_embed_max_positions=10000,
            gno_radius=0.033,
            in_gno_channel_mlp_hidden_layers=[80, 80, 80],
            out_gno_channel_mlp_hidden_layers=[512, 256],
            gno_channel_mlp_non_linearity=F.gelu,
            in_gno_transform_type='linear',
            out_gno_transform_type='linear',
            gno_use_open3d=False,
            gno_use_torch_scatter=True,
            out_gno_tanh=None,
            fno_in_channels=3,
            fno_n_modes=(16, 16, 16),
            fno_hidden_channels=64,
            lifting_channels=256,
            fno_n_layers=4,
            fno_resolution_scaling_factor=None,
            fno_incremental_n_modes=None,
            fno_block_precision='full',
            fno_use_channel_mlp=False,
            fno_channel_mlp_dropout=0,
            fno_channel_mlp_expansion=0.5,
            fno_non_linearity=F.gelu,
            fno_stabilizer=None,
            fno_norm=None,
            fno_ada_in_features=4,
            fno_ada_in_dim=1,
            fno_preactivation=False,
            fno_skip='linear',
            fno_channel_mlp_skip='soft-gating',
            fno_separable=False,
            fno_factorization=None,
            fno_rank=1.0,
            fno_joint_factorization=False,
            fno_fixed_rank_modes=False,
            fno_implementation='factorized',
            fno_decomposition_kwargs=dict(),
            fno_conv_module=SpectralConv,
            **kwargs
        ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.latent_feature_channels = latent_feature_channels
        self.gno_coord_dim = gno_coord_dim
        self.fno_hidden_channels = fno_hidden_channels
        self.lifting_channels = lifting_channels

        # TODO: make sure this makes sense in all contexts
        # Nonlinear input transforms keep the input channel count; otherwise
        # the input GNO maps directly to the FNO input width.
        if in_gno_transform_type in ["nonlinear", "nonlinear_kernelonly"]:
            in_gno_out_channels = self.in_channels
        else:
            in_gno_out_channels = fno_in_channels

        # Width seen by the lifting layer: GNO output plus any latent
        # features concatenated in forward().
        self.fno_in_channels = in_gno_out_channels
        if latent_feature_channels is not None:
            self.fno_in_channels += latent_feature_channels

        if self.gno_coord_dim != 3 and gno_use_open3d:
            print(f'Warning: GNO expects {self.gno_coord_dim}-d data but Open3d expects 3-d data')
            gno_use_open3d = False

        self.in_coord_dim = len(fno_n_modes)
        self.gno_out_coord_dim = len(fno_n_modes)  # gno output and fno will use same dimensions
        if self.in_coord_dim != self.gno_coord_dim:
            print(f'Warning: FNO expects {self.in_coord_dim}-d data while input GNO expects {self.gno_coord_dim}-d data')

        self.in_coord_dim_forward_order = list(range(self.in_coord_dim))
        # channels starting at 2 to permute everything after channel and batch dims
        self.in_coord_dim_reverse_order = [j + 2 for j in self.in_coord_dim_forward_order]

        self.fno_norm = fno_norm
        if self.fno_norm == "ada_in":
            if fno_ada_in_features is not None and gno_pos_embed_type is not None:
                # Embed the raw ada_in parameter sinusoidally; the embedding's
                # output width becomes the FNO's ada_in feature count.
                self.adain_pos_embed = SinusoidalEmbedding(
                    in_channels=fno_ada_in_dim,
                    num_frequencies=fno_ada_in_features,
                    max_positions=10000,
                    embedding_type=gno_pos_embed_type,
                )
                self.ada_in_dim = self.adain_pos_embed.out_channels
            else:
                # No embedding: the raw ada_in vector is used as-is.
                self.ada_in_dim = fno_ada_in_dim
                self.adain_pos_embed = None
        else:
            self.adain_pos_embed = None
            self.ada_in_dim = None

        self.gno_radius = gno_radius
        self.out_gno_tanh = out_gno_tanh

        ### input GNO
        # input to the first GNO ChannelMLP: x pos encoding,
        # y (integrand) pos encoding, potentially f_y
        self.gno_in = GNOBlock(
            in_channels=in_channels,
            out_channels=in_gno_out_channels,
            coord_dim=self.gno_coord_dim,
            pos_embedding_type=gno_pos_embed_type,
            pos_embedding_channels=gno_embed_channels,
            pos_embedding_max_positions=gno_embed_max_positions,
            radius=gno_radius,
            channel_mlp_layers=in_gno_channel_mlp_hidden_layers,
            channel_mlp_non_linearity=gno_channel_mlp_non_linearity,
            transform_type=in_gno_transform_type,
            use_open3d_neighbor_search=gno_use_open3d,
            use_torch_scatter_reduce=gno_use_torch_scatter,
        )

        ### lifting: fno_in_channels -> fno_hidden_channels
        self.lifting = ChannelMLP(
            in_channels=self.fno_in_channels,
            hidden_channels=lifting_channels,
            out_channels=fno_hidden_channels,
            n_layers=3,
        )

        ### latent FNO
        self.fno_blocks = FNOBlocks(
            n_modes=fno_n_modes,
            hidden_channels=fno_hidden_channels,
            in_channels=fno_hidden_channels,
            out_channels=fno_hidden_channels,
            positional_embedding=None,
            n_layers=fno_n_layers,
            resolution_scaling_factor=fno_resolution_scaling_factor,
            incremental_n_modes=fno_incremental_n_modes,
            fno_block_precision=fno_block_precision,
            use_channel_mlp=fno_use_channel_mlp,
            channel_mlp_expansion=fno_channel_mlp_expansion,
            channel_mlp_dropout=fno_channel_mlp_dropout,
            non_linearity=fno_non_linearity,
            stabilizer=fno_stabilizer,
            norm=fno_norm,
            ada_in_features=self.ada_in_dim,
            preactivation=fno_preactivation,
            fno_skip=fno_skip,
            channel_mlp_skip=fno_channel_mlp_skip,
            separable=fno_separable,
            factorization=fno_factorization,
            rank=fno_rank,
            joint_factorization=fno_joint_factorization,
            fixed_rank_modes=fno_fixed_rank_modes,
            implementation=fno_implementation,
            decomposition_kwargs=fno_decomposition_kwargs,
            domain_padding=None,
            domain_padding_mode=None,
            conv_module=fno_conv_module,
            **kwargs
        )

        ### output GNO
        self.gno_out = GNOBlock(
            in_channels=fno_hidden_channels,  # number of channels in f_y
            out_channels=fno_hidden_channels,
            coord_dim=self.gno_coord_dim,
            radius=self.gno_radius,
            pos_embedding_type=gno_pos_embed_type,
            pos_embedding_channels=gno_embed_channels,
            pos_embedding_max_positions=gno_embed_max_positions,
            channel_mlp_layers=out_gno_channel_mlp_hidden_layers,
            channel_mlp_non_linearity=gno_channel_mlp_non_linearity,
            transform_type=out_gno_transform_type,
            use_open3d_neighbor_search=gno_use_open3d,
            use_torch_scatter_reduce=gno_use_torch_scatter,
        )

        ### pointwise projection: fno_hidden_channels -> out_channels
        self.projection = ChannelMLP(
            in_channels=fno_hidden_channels,
            out_channels=self.out_channels,
            hidden_channels=projection_channels,
            n_layers=2,
            n_dim=1,
            non_linearity=fno_non_linearity,
        )

    def latent_embedding(self, in_p, ada_in=None):
        """Lift the latent-grid features and run them through every FNO block.

        Parameters
        ----------
        in_p : torch.Tensor
            channels-last latent features; per the original annotation,
            shape (batch, n_1, ..., n_k, channels)
        ada_in : torch.Tensor, optional
            adaptive instance parameter; per the original annotation,
            shape (fno_ada_in_dim,). A 2-d tensor has its leading
            dimension squeezed. Only consumed when `fno_norm == 'ada_in'`.

        Returns
        -------
        torch.Tensor
            channels-first latent embedding; expected shape
            (batch, fno_hidden_channels, n_1, ..., n_k)
        """
        # permute (b, n_1, ..., n_k, c) -> (b, c, n_1, ..., n_k)
        last_dim = in_p.ndim - 1
        in_p = in_p.permute(0, last_dim, *range(1, last_dim))

        # Update Ada IN embedding
        if ada_in is not None:
            if ada_in.ndim == 2:
                ada_in = ada_in.squeeze(0)
            if self.adain_pos_embed is not None:
                ada_in_embed = self.adain_pos_embed(ada_in.unsqueeze(0)).squeeze(0)
            else:
                ada_in_embed = ada_in
            if self.fno_norm == "ada_in":
                self.fno_blocks.set_ada_in_embeddings(ada_in_embed)

        # Apply FNO blocks
        in_p = self.lifting(in_p)
        for idx in range(self.fno_blocks.n_layers):
            in_p = self.fno_blocks(in_p, idx)

        return in_p

    def forward(self, x, input_geom, latent_queries, output_queries,
                latent_features=None, ada_in=None, **kwargs):
        """The GINO's forward call:
        Input GNO --> FNOBlocks --> output GNO + projection to output queries

        Parameters
        ----------
        x : torch.Tensor
            input function a defined on the input domain `input_geom`
            shape (batch, n_in, in_channels)
        input_geom : torch.Tensor
            input domain coordinate mesh
            shape (1, n_in, gno_coord_dim)
        latent_queries : torch.Tensor
            latent geometry on which to compute FNO latent embeddings
            a grid on [0,1] x [0,1] x ....
            shape (1, n_gridpts_1, .... n_gridpts_n, gno_coord_dim)
        output_queries : torch.Tensor
            points at which to query the final GNO layer to get output
            shape (batch, n_out, gno_coord_dim)
        latent_features : torch.Tensor, optional
            optional feature map to concatenate onto latent embedding
            before being passed into the latent FNO, default None
            if `latent_feature_channels` is set, must be passed.
            Its last dimension must equal `latent_feature_channels` and its
            leading (batch) dimension must be either `batch` or 1; a batch
            of 1 is repeated to match `x`.
        ada_in : torch.Tensor, optional
            adaptive scalar instance parameter, defaults to None

        Returns
        -------
        torch.Tensor
            output of the pointwise projection, permuted to channels-last;
            expected shape (batch, n_out, out_channels)

        Raises
        ------
        ValueError
            if `latent_features` is passed but the model was built without
            `latent_feature_channels`, if its channel count differs from
            `latent_feature_channels`, or if its batch dimension is
            neither 1 nor the batch size of `x`.
        """
        batch_size = x.shape[0]

        # Geometry tensors carry a dummy leading batch dimension of 1.
        input_geom = input_geom.squeeze(0)
        latent_queries = latent_queries.squeeze(0)

        # Flattened (n_latent_points, coord_dim) view of the latent grid,
        # shared by the input and output GNO. reshape (not view) so that
        # non-contiguous grids, e.g. a permuted meshgrid, are accepted.
        flat_latent_queries = latent_queries.reshape(-1, latent_queries.shape[-1])

        # Pass through input GNOBlock
        in_p = self.gno_in(y=input_geom, x=flat_latent_queries, f_y=x)

        grid_shape = latent_queries.shape[:-1]  # disregard positional encoding dim

        # shape (batch_size, grid1, ...gridn, -1)
        in_p = in_p.reshape((batch_size, *grid_shape, -1))

        if latent_features is not None:
            if self.latent_feature_channels is None:
                raise ValueError(
                    "latent_features was passed but the model was created "
                    "with latent_feature_channels=None."
                )
            if latent_features.shape[-1] != self.latent_feature_channels:
                raise ValueError(
                    f"latent_features has {latent_features.shape[-1]} channels, "
                    f"expected latent_feature_channels={self.latent_feature_channels}."
                )

            # latent features must have the same shape (except channels) as
            # latent_queries; only a singleton batch can be broadcast.
            feature_batch = latent_features.shape[0]
            if feature_batch != batch_size:
                if feature_batch != 1:
                    raise ValueError(
                        f"latent_features has batch size {feature_batch}, "
                        f"expected 1 or {batch_size}."
                    )
                latent_features = latent_features.repeat(
                    batch_size, *[1] * (latent_features.ndim - 1)
                )
            in_p = torch.cat((in_p, latent_features), dim=-1)

        # apply fno in latent space
        latent_embed = self.latent_embedding(in_p=in_p, ada_in=ada_in)

        # Integrate latent space to output queries
        # latent_embed shape (b, c, n_1, n_2, ..., n_k)
        batch_size = latent_embed.shape[0]

        # permute to (b, n_1, n_2, ...n_k, c)
        # then reshape to (b, n_1 * n_2 * ...n_k, fno_hidden_channels)
        latent_embed = latent_embed.permute(
            0, *self.in_coord_dim_reverse_order, 1
        ).reshape(batch_size, -1, self.fno_hidden_channels)

        if self.out_gno_tanh in ['latent_embed', 'both']:
            latent_embed = torch.tanh(latent_embed)

        # latent queries flattened from (d_1 x d_2 x ... d_n x n) to (n_latent x n)
        out = self.gno_out(
            y=flat_latent_queries,
            x=output_queries,
            f_y=latent_embed,
        )

        # ChannelMLP projection is applied channels-first, so move channels
        # to dim 1 and back again afterwards.
        out = out.permute(0, 2, 1)

        # Project pointwise to out channels
        out = self.projection(out).permute(0, 2, 1)

        return out