from functools import partial
import torch
import torch.nn.functional as F
import time
# Set warning filter to show each warning only once
import warnings
warnings.filterwarnings("once", category=UserWarning)
from .base_model import BaseModel
from ..layers.channel_mlp import ChannelMLP
from ..layers.embeddings import SinusoidalEmbedding
from ..layers.fno_block import FNOBlocks
from ..layers.spectral_convolution import SpectralConv
from ..layers.gno_block import GNOBlock
from ..layers.gno_weighting_functions import dispatch_weighting_fn
class GINO(BaseModel):
"""
GINO: Geometry-informed Neural Operator - learns a mapping between
functions presented over arbitrary coordinate meshes. The model carries
global integration through spectral convolution layers in an intermediate
latent space, as described in [1]_. Optionally enables a weighted output
GNO for use in a Mollified Graph Neural Operator scheme, as introduced in [2]_.
Parameters
----------
in_channels : int
Feature dimension of input points. Determined by the problem.
out_channels : int
Feature dimension of output points. Determined by the problem.
fno_n_modes : tuple, optional
Number of modes along each dimension to use in FNO. Default: (16, 16, 16)
Must be large enough to resolve the relevant frequencies, but smaller than max_resolution//2 (the Nyquist frequency) on the latent grid.
fno_hidden_channels : int, optional
Hidden channels for use in FNO. Default: 64
fno_n_layers : int, optional
Number of layers in FNO. Default: 4
latent_feature_channels : int, optional
Number of channels in optional latent feature map to concatenate onto latent embeddings
before the FNO's forward pass. Default: None
projection_channel_ratio : int, optional
Ratio of pointwise projection channels in the final ChannelMLP to fno_hidden_channels.
The number of projection channels in the final ChannelMLP is computed by
projection_channel_ratio * fno_hidden_channels (i.e. default 256). Default: 4
gno_coord_dim : int, optional
Geometric dimension of input/output queries. Determined by the problem. Default: 3
in_gno_radius : float, optional
Radius in input space for GNO neighbor search. Default: 0.033
A larger radius means more neighbors and hence more global interactions, at a larger computational cost.
out_gno_radius : float, optional
Radius in output space for GNO neighbor search. Default: 0.033
A larger radius means more neighbors and hence more global interactions, at a larger computational cost.
gno_weighting_function : Literal["half_cos", "bump", "quartic", "quadr", "octic"], optional
Choice of weighting function to use in the output GNO for Mollified Graph Neural Operator-based models.
See neuralop.layers.gno_weighting_functions for more details. Default: None
gno_weight_function_scale : float, optional
Factor by which to scale weights from GNO weighting function.
If gno_weighting_function is None, this is not used. Default: 1
in_gno_transform_type : str, optional
Transform type parameter for input GNO.
Options: "linear", "nonlinear", "nonlinear_kernelonly". Default: "linear"
See neuralop.layers.gno_block for more details.
out_gno_transform_type : str, optional
Transform type parameter for output GNO.
Options: "linear", "nonlinear", "nonlinear_kernelonly". Default: "linear"
See neuralop.layers.gno_block for more details.
in_gno_pos_embed_type : str, optional
Type of optional sinusoidal positional embedding to use in the input GNOBlock.
Options: "transformer", "nerf". Default: "transformer"
out_gno_pos_embed_type : str, optional
Type of optional sinusoidal positional embedding to use in the output GNOBlock.
Options: "transformer", "nerf". Default: "transformer"
fno_in_channels : int, optional
Number of input channels for FNO. Default: 3
fno_lifting_channel_ratio : int, optional
Ratio of lifting channels to fno_hidden_channels.
The number of lifting channels in the lifting block of the FNO is
fno_lifting_channel_ratio * fno_hidden_channels (i.e. default 128). Default: 2
gno_embed_channels : int, optional
Number of channels in the optional positional embedding used in the GNOBlocks. Default: 32
gno_embed_max_positions : int, optional
Max positions of the optional positional embedding used in the GNOBlocks.
Only used when the corresponding GNO positional embedding type is 'transformer'. Default: 10000
in_gno_channel_mlp_hidden_layers : list, optional
Widths of hidden layers in input GNO. Default: [80, 80, 80]
out_gno_channel_mlp_hidden_layers : list, optional
Widths of hidden layers in output GNO. Default: [512, 256]
gno_channel_mlp_non_linearity : nn.Module, optional
Nonlinearity to use in GNO ChannelMLP. Default: F.gelu
gno_use_open3d : bool, optional
Whether to use Open3D neighbor search.
If False, uses pure-PyTorch fallback neighbor search. Default: True
gno_use_torch_scatter : bool, optional
Whether to use torch-scatter to perform grouped reductions in the IntegralTransform.
If False, uses native Python reduction in neuralop.layers.segment_csr. Default: True
.. warning::
torch-scatter is an optional dependency that conflicts with the newest versions of PyTorch,
so you must handle the conflict explicitly in your environment. See :ref:`torch_scatter_dependency`
for more information.
out_gno_tanh : str, optional
Whether to use tanh to stabilize the inputs to the output GNO.
If "latent_embed" or "both", tanh is applied to the latent embedding before the output GNO. Default: None
fno_resolution_scaling_factor : float, optional
Factor by which to scale output of FNO. Default: None
fno_block_precision : str, optional
Data precision to compute within FNO block. Options: "full", "half", "mixed". Default: "full"
fno_use_channel_mlp : bool, optional
Whether to use a ChannelMLP layer after each FNO block. Default: True
fno_channel_mlp_dropout : float, optional
Dropout parameter of above ChannelMLP. Default: 0
fno_channel_mlp_expansion : float, optional
Expansion parameter of above ChannelMLP. Default: 0.5
fno_non_linearity : nn.Module, optional
Nonlinear activation function between each FNO layer. Default: F.gelu
fno_stabilizer : nn.Module, optional
If not None, a tanh is applied before the FFT in each FNO block. Default: None
fno_norm : str, optional
Normalization layer to use in FNO. Options: "ada_in", "group_norm", "instance_norm", None. Default: None
fno_ada_in_features : int, optional
If an adaptive mesh is used, number of channels of its positional embedding.
If None, adaptive mesh embedding is not used. Default: 4
fno_ada_in_dim : int, optional
Dimensions of above FNO adaptive mesh. Default: 1
fno_preactivation : bool, optional
Whether to use ResNet-style preactivation. Default: False
fno_skip : str, optional
Type of skip connection to use. Options: "linear", "identity", "soft-gating", None. Default: "linear"
fno_channel_mlp_skip : str, optional
Type of skip connection to use in the FNO.
Options: "linear", "identity", "soft-gating", None. Default: "soft-gating"
fno_separable : bool, optional
Whether to use a separable spectral convolution. Default: False
fno_factorization : str, optional
Tensor factorization of the parameters weight to use.
Options: "tucker", "tt", "cp", None. Default: None
fno_rank : float, optional
Rank of the tensor factorization of the Fourier weights. Default: 1.0
Set to float <1.0 when using TFNO (i.e. when factorization is not None).
A TFNO with rank 0.1 has roughly 10% of the parameters of a dense FNO.
fno_fixed_rank_modes : bool, optional
Whether to leave certain modes unfactorized. Default: False
fno_implementation : str, optional
If factorization is not None, forward mode to use.
Options: "reconstructed", "factorized". Default: "factorized"
fno_decomposition_kwargs : dict, optional
Additional parameters to pass to the tensor decomposition. Default: {}
fno_conv_module : nn.Module, optional
Spectral convolution module to use. Default: SpectralConv
References
-----------
.. [1] : Li, Z., Kovachki, N., Choy, C., Li, B., Kossaifi, J., Otta, S.,
Nabian, M., Stadler, M., Hundt, C., Azizzadenesheli, K., Anandkumar, A. (2023)
Geometry-Informed Neural Operator for Large-Scale 3D PDEs. NeurIPS 2023,
https://proceedings.neurips.cc/paper_files/paper/2023/hash/70518ea42831f02afc3a2828993935ad-Abstract-Conference.html
.. [2] : Lin, R. et al. Placeholder reference for Mollified Graph Neural Operators.
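Examples
--------
A minimal usage sketch. The point counts, latent grid size, and radii below
are illustrative assumptions rather than recommended settings, the import path
assumes the class is exported as ``neuralop.models.GINO``, and the pure-PyTorch
neighbor search is selected to avoid the optional Open3D dependency:
>>> import torch
>>> from neuralop.models import GINO
>>> model = GINO(in_channels=1, out_channels=1,
...              in_gno_radius=0.1, out_gno_radius=0.1,
...              gno_use_open3d=False)
>>> input_geom = torch.rand(1, 1000, 3)       # (1, n_in, gno_coord_dim)
>>> coords = torch.linspace(0, 1, 32)
>>> latent_queries = torch.stack(
...     torch.meshgrid(coords, coords, coords, indexing="ij"), dim=-1
... ).unsqueeze(0)                            # (1, 32, 32, 32, 3), a grid on [0, 1]^3
>>> output_queries = torch.rand(1, 500, 3)    # (1, n_out, gno_coord_dim)
>>> x = torch.rand(4, 1000, 1)                # (batch, n_in, in_channels)
>>> out = model(input_geom, latent_queries, output_queries, x=x)
>>> # out has shape (batch, n_out, out_channels), i.e. (4, 500, 1) here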
"""
def __init__(
self,
in_channels,
out_channels,
latent_feature_channels=None,
projection_channel_ratio=4,
gno_coord_dim=3,
in_gno_radius=0.033,
out_gno_radius=0.033,
in_gno_transform_type="linear",
out_gno_transform_type="linear",
gno_weighting_function=None,
gno_weight_function_scale=1,
in_gno_pos_embed_type="transformer",
out_gno_pos_embed_type="transformer",
fno_in_channels=3,
fno_n_modes=(16, 16, 16),
fno_hidden_channels=64,
fno_lifting_channel_ratio=2,
fno_n_layers=4,
# Other GNO Params
gno_embed_channels=32,
gno_embed_max_positions=10000,
in_gno_channel_mlp_hidden_layers=[80, 80, 80],
out_gno_channel_mlp_hidden_layers=[512, 256],
gno_channel_mlp_non_linearity=F.gelu,
gno_use_open3d=True,
gno_use_torch_scatter=True,
out_gno_tanh=None,
# Other FNO Params
fno_resolution_scaling_factor=None,
fno_block_precision="full",
fno_use_channel_mlp=True,
fno_channel_mlp_dropout=0,
fno_channel_mlp_expansion=0.5,
fno_non_linearity=F.gelu,
fno_stabilizer=None,
fno_norm=None,
fno_ada_in_features=4,
fno_ada_in_dim=1,
fno_preactivation=False,
fno_skip="linear",
fno_channel_mlp_skip="soft-gating",
fno_separable=False,
fno_factorization=None,
fno_rank=1.0,
fno_fixed_rank_modes=False,
fno_implementation="factorized",
fno_decomposition_kwargs=dict(),
fno_conv_module=SpectralConv,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.latent_feature_channels = latent_feature_channels
self.gno_coord_dim = gno_coord_dim
self.fno_hidden_channels = fno_hidden_channels
self.lifting_channels = fno_lifting_channel_ratio * fno_hidden_channels
# If the input GNO performs a nonlinear kernel, the GNO's output
# features must be the same dimension as its input.
# Otherwise, the kernel's MLP performs a lifting operation
# to lift the inputs to ``fno_in_channels`` channels.
if in_gno_transform_type in ["nonlinear", "nonlinear_kernelonly"]:
in_gno_out_channels = self.in_channels
else:
in_gno_out_channels = fno_in_channels
# The actual input channels to the FNO are computed here.
self.fno_in_channels = in_gno_out_channels
if latent_feature_channels is not None:
self.fno_in_channels += latent_feature_channels
if self.gno_coord_dim != 3 and gno_use_open3d:
warnings.warn(
f'GNO expects {self.gno_coord_dim}-d data but Open3D only supports 3-d data; falling back to pure-PyTorch neighbor search.',
UserWarning,
stacklevel=2,
)
gno_use_open3d = False
self.in_coord_dim = len(fno_n_modes)
self.gno_out_coord_dim = len(fno_n_modes) # the GNO output and the FNO use the same spatial dimension
if self.in_coord_dim != self.gno_coord_dim:
warnings.warn(
f'FNO expects {self.in_coord_dim}-d data while input GNO expects {self.gno_coord_dim}-d data',
UserWarning,
stacklevel=2,
)
self.in_coord_dim_forward_order = list(range(self.in_coord_dim))
# tensor indices starting at 2 to permute everything after channel and batch dims
self.in_coord_dim_reverse_order = [j + 2 for j in self.in_coord_dim_forward_order]
self.fno_norm = fno_norm
if self.fno_norm == "ada_in":
if fno_ada_in_features is not None and out_gno_pos_embed_type is not None:
self.adain_pos_embed = SinusoidalEmbedding(
in_channels=fno_ada_in_dim,
num_frequencies=fno_ada_in_features,
max_positions=10000,
embedding_type=out_gno_pos_embed_type,
)
self.ada_in_dim = self.adain_pos_embed.out_channels
else:
self.ada_in_dim = fno_ada_in_dim
self.adain_pos_embed = None
else:
self.adain_pos_embed = None
self.ada_in_dim = None
self.in_gno_radius = in_gno_radius
self.out_gno_radius = out_gno_radius
self.out_gno_tanh = out_gno_tanh
### input GNO
# input to the first GNO ChannelMLP: `x` pos encoding,
# `y` (integrand) pos encoding, potentially `f_y`
self.gno_in = GNOBlock(
in_channels=in_channels,
out_channels=in_gno_out_channels,
coord_dim=self.gno_coord_dim,
pos_embedding_type=in_gno_pos_embed_type,
pos_embedding_channels=gno_embed_channels,
pos_embedding_max_positions=gno_embed_max_positions,
radius=in_gno_radius,
reduction="mean",
weighting_fn=None,
channel_mlp_layers=in_gno_channel_mlp_hidden_layers,
channel_mlp_non_linearity=gno_channel_mlp_non_linearity,
transform_type=in_gno_transform_type,
use_torch_scatter_reduce=gno_use_torch_scatter,
use_open3d_neighbor_search=gno_use_open3d,
)
### Lifting layer before FNOBlocks
self.lifting = ChannelMLP(
in_channels=self.fno_in_channels,
hidden_channels=self.lifting_channels,
out_channels=fno_hidden_channels,
n_layers=2,
)
### FNOBlocks in latent space
# input: `in_p` intermediate embeddings,
# possibly concatenated feature channels `latent_features`
self.fno_blocks = FNOBlocks(
n_modes=fno_n_modes,
in_channels=fno_hidden_channels,
out_channels=fno_hidden_channels,
n_layers=fno_n_layers,
resolution_scaling_factor=fno_resolution_scaling_factor,
fno_block_precision=fno_block_precision,
use_channel_mlp=fno_use_channel_mlp,
channel_mlp_expansion=fno_channel_mlp_expansion,
channel_mlp_dropout=fno_channel_mlp_dropout,
non_linearity=fno_non_linearity,
stabilizer=fno_stabilizer,
norm=fno_norm,
ada_in_features=self.ada_in_dim,
preactivation=fno_preactivation,
fno_skip=fno_skip,
channel_mlp_skip=fno_channel_mlp_skip,
separable=fno_separable,
factorization=fno_factorization,
rank=fno_rank,
fixed_rank_modes=fno_fixed_rank_modes,
implementation=fno_implementation,
decomposition_kwargs=fno_decomposition_kwargs,
conv_module=fno_conv_module,
)
### output GNO
if gno_weighting_function is not None:
# the weighting function is parameterized by the squared output GNO radius
weight_fn = dispatch_weighting_fn(
gno_weighting_function,
sq_radius=out_gno_radius**2,
scale=gno_weight_function_scale,
)
else:
weight_fn = None
self.gno_out = GNOBlock(
in_channels=fno_hidden_channels, # number of channels in f_y
out_channels=fno_hidden_channels,
coord_dim=self.gno_coord_dim,
radius=self.out_gno_radius,
reduction="sum",
weighting_fn=weight_fn,
pos_embedding_type=out_gno_pos_embed_type,
pos_embedding_channels=gno_embed_channels,
pos_embedding_max_positions=gno_embed_max_positions,
channel_mlp_layers=out_gno_channel_mlp_hidden_layers,
channel_mlp_non_linearity=gno_channel_mlp_non_linearity,
transform_type=out_gno_transform_type,
use_torch_scatter_reduce=gno_use_torch_scatter,
use_open3d_neighbor_search=gno_use_open3d,
)
projection_channels = projection_channel_ratio * fno_hidden_channels
self.projection = ChannelMLP(
in_channels=fno_hidden_channels,
out_channels=self.out_channels,
hidden_channels=projection_channels,
n_layers=2,
n_dim=1,
non_linearity=fno_non_linearity,
)
def latent_embedding(self, in_p, ada_in=None):
# in_p : (batch, n_1, ..., n_k, in_channels + k)
# ada_in : (fno_ada_in_dim, )
# returns : (batch, fno_hidden_channels, n_1, ..., n_k)
# permute (b, n_1, ..., n_k, c) -> (b,c, n_1,...n_k)
in_p = in_p.permute(0, len(in_p.shape) - 1, *list(range(1,len(in_p.shape)-1)))
#Update Ada IN embedding
if ada_in is not None:
if ada_in.ndim == 2:
ada_in = ada_in.squeeze(0)
if self.adain_pos_embed is not None:
ada_in_embed = self.adain_pos_embed(ada_in.unsqueeze(0)).squeeze(0)
else:
ada_in_embed = ada_in
if self.fno_norm == "ada_in":
self.fno_blocks.set_ada_in_embeddings(ada_in_embed)
# Apply FNO blocks
in_p = self.lifting(in_p)
for idx in range(self.fno_blocks.n_layers):
in_p = self.fno_blocks(in_p, idx)
return in_p
def forward(
self,
input_geom,
latent_queries,
output_queries,
x=None,
latent_features=None,
ada_in=None,
**kwargs,
):
"""The GINO's forward call:
Input GNO --> FNOBlocks --> output GNO + projection to output queries.
.. note ::
GINO currently supports batching **only in cases where the geometry of
inputs and outputs is shared across the entire batch**. Inputs can have a batch dim
in ``x`` and ``latent_features``, but it must be shared for both.
Parameters
----------
input_geom : torch.Tensor
input domain coordinate mesh
shape (1, n_in, gno_coord_dim)
latent_queries : torch.Tensor
latent geometry on which to compute FNO latent embeddings
a grid on [0,1] x [0,1] x ....
shape (1, n_gridpts_1, .... n_gridpts_n, gno_coord_dim)
output_queries : torch.Tensor | dict[torch.Tensor]
points at which to query the final GNO layer to get output.
shape (1, n_out, gno_coord_dim) per tensor.
* if a tensor, the model will output a tensor.
* if a dict of tensors, the model will return a dict of outputs, so
that ``output[key]`` corresponds to the model queried at
``output_queries[key]``.
x : torch.Tensor, optional
input function defined on the input domain `input_geom`
shape (batch, n_in, in_channels). Default None
latent_features : torch.Tensor, optional
optional feature map to concatenate onto latent embedding
before being passed into the latent FNO, default None
if `latent_feature_channels` is set, must be passed
ada_in : torch.Tensor, optional
adaptive scalar instance parameter, defaults to None
Returns
-------
out : torch.Tensor | dict[torch.Tensor]
Function over the output query coordinates
* tensor if ``output_queries`` is a tensor
* dict if ``output_queries`` is a dict
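Examples
--------
A sketch of querying several output point sets in one call. The key names and
sizes are purely illustrative, and ``model``, ``input_geom``, ``latent_queries``,
and ``x`` are assumed to be defined as in the class-level example:
>>> queries = {"surface": torch.rand(1, 200, 3),
...            "volume": torch.rand(1, 800, 3)}
>>> out = model(input_geom, latent_queries, queries, x=x)
>>> # out["surface"]: (batch, 200, out_channels)
>>> # out["volume"]: (batch, 800, out_channels)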
"""
# Ensure input functions on the input geom and latent geom
# have compatible batch sizes
if x is None:
batch_size = 1
else:
batch_size = x.shape[0]
if latent_features is not None:
assert (
self.latent_feature_channels is not None
), "if passing latent features, latent_feature_channels must be set."
assert latent_features.shape[-1] == self.latent_feature_channels
# batch, n_gridpts_1, .... n_gridpts_n, gno_coord_dim
assert (
latent_features.ndim == self.gno_coord_dim + 2
), f"Latent features must be of shape (batch, n_gridpts_1, ...n_gridpts_n, gno_coord_dim), got {latent_features.shape}"
# latent features must have the same shape (except channels) as latent_queries
if latent_features.shape[0] != batch_size:
if latent_features.shape[0] == 1:
latent_features = latent_features.repeat(batch_size, *[1]*(latent_features.ndim-1))
else:
raise ValueError(f"latent_features batch size {latent_features.shape[0]} must be 1 or {batch_size}.")
input_geom = input_geom.squeeze(0)
latent_queries = latent_queries.squeeze(0)
# Pass through input GNOBlock
in_p = self.gno_in(
y=input_geom, x=latent_queries.view((-1, latent_queries.shape[-1])), f_y=x
)
grid_shape = latent_queries.shape[:-1] # disregard the coordinate dimension
# shape (batch_size, grid1, ...gridn, -1)
in_p = in_p.view((batch_size, *grid_shape, -1))
if latent_features is not None:
in_p = torch.cat((in_p, latent_features), dim=-1)
# apply the FNO in the latent space
latent_embed = self.latent_embedding(in_p=in_p, ada_in=ada_in)
# Integrate latent space to output queries
# latent_embed shape (b, c, n_1, n_2, ..., n_k)
batch_size = latent_embed.shape[0]
# permute to (b, n_1, n_2, ...n_k, c)
# then reshape to (b, n_1 * n_2 * ...n_k, out_channels)
latent_embed = latent_embed.permute(0, *self.in_coord_dim_reverse_order, 1).reshape(batch_size, -1, self.fno_hidden_channels)
if self.out_gno_tanh in ["latent_embed", "both"]:
latent_embed = torch.tanh(latent_embed)
# integrate over the latent space
# if output queries is a dict, query the output gno separately
# with each tensor of query points
if isinstance(output_queries, dict):
out = {}
for key, out_p in output_queries.items():
out_p = out_p.squeeze(0)
sub_output = self.gno_out(
y=latent_queries.reshape((-1, latent_queries.shape[-1])),
x=out_p,
f_y=latent_embed,
)
sub_output = sub_output.permute(0, 2, 1)
# Project pointwise to out channels
# (b, n_out, out_channels)
sub_output = self.projection(sub_output).permute(0, 2, 1)
out[key] = sub_output
else:
output_queries = output_queries.squeeze(0)
# latent_queries has shape (d_1, d_2, ..., d_n, n); flatten to (d_1 * d_2 * ... * d_n, n)
out = self.gno_out(
y=latent_queries.reshape((-1, latent_queries.shape[-1])),
x=output_queries,
f_y=latent_embed,
)
out = out.permute(0, 2, 1)
# Project pointwise to out channels
# (b, n_out, out_channels)
out = self.projection(out).permute(0, 2, 1)
return out