import torch.nn as nn
import torch.nn.functional as F
import torch
import numpy as np
from ..layers.channel_mlp import ChannelMLP
from ..layers.spectral_convolution import SpectralConv
from ..layers.skip_connections import skip_connection
from ..layers.padding import DomainPadding
from ..layers.coda_layer import CODALayer
from ..layers.resample import resample
from ..layers.embeddings import GridEmbedding2D, GridEmbeddingND
class CODANO(nn.Module):
"""Codomain Attention Neural Operators (CoDA-NO)
It uses a specialized attention mechanism in the codomain space for data in
infinite dimensional spaces as described in [1]_. The model treats each input channel as a variable of the physical system
and uses attention mechanism to model the interactions between the variables. The model uses lifting and projection modules
to map the input variables to a higher-dimensional space and then back to the output space. The model also supports positional
encoding and static channel information for additional context of the physical system such as external force or inlet condition.
Parameters
----------
n_layers : int
The number of codomain attention layers. Default: 4
n_modes : list
The number of Fourier modes to use in integral operators in the CoDA-NO block along each dimension.
Example: For a 5-layer 2D CoDA-NO, n_modes=[[16, 16], [16, 16], [16, 16], [16, 16], [16, 16]]
    Other Parameters
    ----------------
output_variable_codimension : int, optional
The number of output channels (or output codomain dimension) corresponding to each input variable (or input channel).
Example: For an input with 3 variables (channels) and output_variable_codimension=2, the output will have 6 channels (3 variables × 2 codimension). Default: 1
lifting_channels : int, optional
Number of intermediate channels in the lifting block.
The lifting module projects each input variable (i.e., each input channel) into a
higher-dimensional space determined by hidden_variable_codimension.
If lifting_channels is None, lifting is not performed and the input channels are
directly used as tokens for codomain attention. Default: 64
hidden_variable_codimension : int, optional
The number of hidden channels corresponding to each input variable (or channel). Each input channel is independently lifted
to hidden_variable_codimension channels by the lifting block. Default: 32
projection_channels : int, optional
The number of intermediate channels in the projection block of the CODANO.
If projection_channels=None, projection is not performed and the output of
the last CoDA block is returned directly. Default: 64
use_positional_encoding : bool, optional
        Indicates whether to use variable-specific positional encoding. If True, a learnable positional encoding is concatenated
        to each variable (each input channel) before the lifting operation.
        The positional encoding used here is a function-space generalization of the learnable positional encoding
        used in BERT [2]_. In CoDA-NO, the positional encoding is a function on the domain that is learned directly
        in Fourier space. Default: False
positional_encoding_dim : int, optional
        The dimension (number of channels) of the positional encoding learned for each input variable
        (i.e., each input channel). Default: 8
positional_encoding_modes : list, optional
Number of Fourier modes used in positional encoding along each dimension. The positional embeddings are functions and are directly learned
in Fourier space. This parameter must be specified when use_positional_encoding=True.
Example: For a 2D input, positional_encoding_modes could be [16, 16]. Default: None
static_channel_dim : int, optional
The number of channels for static information, such as boundary conditions in PDEs. These channels are concatenated with
each variable before the lifting operation and used to provide additional information
regarding the physical setup of the system.
When static_channel_dim > 0, additional information must be provided during the forward pass.
        For example, static_channel_dim=1 can be used to provide a mask of the domain indicating
        a hole or obstacle in the domain. Default: 0
variable_ids : list[str], optional
The names of the variables in the dataset.
This parameter is only required when use_positional_encoding=True to initialize learnable
positional embeddings for each unique physical variable in the dataset.
For example:
        If the dataset consists only of the Navier-Stokes equations, variable_ids=['u_x', 'u_y', 'p'], representing the velocity
        components in the x and y directions and the pressure, respectively. Note that each input channel is treated as a physical
        variable of the PDE.
        Note that the velocity variable has two channels (codimension=2), so the velocity field is split
        into the components u_x and u_y; the same must be done for every variable with codimension > 1.
        If the dataset consists of multiple PDEs, such as the Navier-Stokes and heat equations, variable_ids=['u_x', 'u_y', 'p', 'T'],
        where 'T' is the temperature variable of the heat equation and 'u_x', 'u_y', 'p' are the velocity components and pressure
        of the Navier-Stokes equations. This is required when learning a single solver for multiple different PDEs.
        This parameter is not required when use_positional_encoding=False. Default: None
per_layer_scaling_factors : list, optional
        The output scaling factor for each CoDA-NO block along each dimension. The output of each block
        is resampled according to its scaling factor and then passed to the following block.
Example: For a 2D input and n_layers=5, per_layer_scaling_factors=[[1, 1], [0.5, 0.5], [1, 1], [2, 2], [1, 1]],
which downsamples the output of the second layer by a factor of 2 and upsamples the output of the fourth layer by a factor of 2.
The resolution of the output of the CODANO model is determined by the product of the scaling factors of all the layers. Default: None
n_heads : list, optional
The number of attention heads for each layer.
Example: For a 4-layer CoDA-NO, n_heads=[2, 2, 2, 2]. Default: None (single attention head for each codomain attention block)
attention_scaling_factors : list, optional
        Scaling factors used in the codomain attention mechanism to resample the key and query functions before
        the attention matrix is computed. They have no effect on the value functions,
        i.e., they do not change the output shape of the block.
        Example: For a 5-layer CoDA-NO, attention_scaling_factors=[0.5, 0.5, 0.5, 0.5, 0.5] downsamples the key and query functions,
        reducing their resolution by a factor of 2. Default: None (no scaling)
conv_module : nn.Module, optional
The convolution module to use in the CoDANO_block. Default: SpectralConv
nonlinear_attention : bool, optional
Indicates whether to use a non-linear attention mechanism, employing non-linear key, query, and value operators. Default: False
non_linearity : callable, optional
The non-linearity to use in the codomain attention block. Default: F.gelu
attention_token_dim : int, optional
The number of channels in each token function. attention_token_dim must divide hidden_variable_codimension. Default: 1
per_channel_attention : bool, optional
        Indicates whether to use per-channel attention in the codomain attention layers. Default: False
enable_cls_token : bool, optional
Indicates whether to use a learnable CLASS token during the attention mechanism. We use a function-space generalization of the
learnable [class] token used in vision transformers such as ViT, which is learned directly in Fourier space.
The [class] function is realized on the input grid by performing an inverse Fourier transform of the learned Fourier coefficients.
Then, the [class] token function is added to the set of input token functions before passing to the codomain attention layer. It aggregates
information from all the other tokens through the attention mechanism. The output token corresponding to the [class] token is discarded in the
output of the last CoDA block. Default: False
use_horizontal_skip_connection : bool, optional
Indicates whether to use horizontal skip connections, similar to U-shaped architectures. Default: False
horizontal_skips_map : dict, optional
A mapping that specifies horizontal skip connections between layers. Only required when use_horizontal_skip_connection=True.
Example: For a 5-layer architecture, horizontal_skips_map={4: 0, 3: 1} creates skip connections from layer 0 to layer 4 and layer 1 to layer 3. Default: None
domain_padding : float, optional
        The padding fraction used to zero-pad the spatial domain of each channel. Default: 0.25
layer_kwargs : dict, optional
Additional arguments for the CoDA blocks. Default: {}
    References
    ----------
    .. [1] Rahman, Md Ashiqur, et al. "Pretraining Codomain Attention Neural Operators for Solving Multiphysics PDEs."
        NeurIPS 2024. https://arxiv.org/pdf/2403.12553.
    .. [2] Devlin, Jacob, et al. "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding."
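
    Examples
    --------
    A minimal usage sketch; the hyperparameters and shapes below are illustrative
    rather than prescriptive:

    >>> model = CODANO(
    ...     n_layers=2,
    ...     n_modes=[[16, 16], [16, 16]],
    ...     hidden_variable_codimension=4,
    ...     lifting_channels=16,
    ...     projection_channels=16,
    ... )
    >>> x = torch.randn(2, 3, 64, 64)  # batch of 2, 3 variables (channels), 64x64 grid
    >>> out = model(x)  # expected shape: (2, 3 * output_variable_codimension, 64, 64)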
"""
def __init__(
self,
output_variable_codimension=1,
lifting_channels: int = 64,
hidden_variable_codimension=32,
projection_channels: int = 64,
use_positional_encoding=False,
positional_encoding_dim=8,
positional_encoding_modes=None,
static_channel_dim=0,
variable_ids=None,
use_horizontal_skip_connection=False,
horizontal_skips_map=None,
n_layers=4,
n_modes=None,
per_layer_scaling_factors=None,
n_heads=None,
attention_scaling_factors=None,
conv_module=SpectralConv,
nonlinear_attention=False,
non_linearity=F.gelu,
attention_token_dim=1,
per_channel_attention=False,
layer_kwargs={},
domain_padding=0.25,
enable_cls_token=False,
):
super().__init__()
self.n_layers = n_layers
        assert len(n_modes) == n_layers, "number of modes must be given for every layer"
        assert (
            n_heads is None or len(n_heads) == n_layers
        ), "number of attention heads must be given for every layer"
        assert (
            per_layer_scaling_factors is None
            or len(per_layer_scaling_factors) == n_layers
        ), "scaling factors must be given for every layer"
        assert (
            attention_scaling_factors is None
            or len(attention_scaling_factors) == n_layers
        ), "attention scaling factors must be given for every layer"
if use_positional_encoding:
            assert positional_encoding_dim > 0, "positional_encoding_dim must be positive when use_positional_encoding=True"
            assert (
                positional_encoding_modes is not None
            ), "positional_encoding_modes must be provided when use_positional_encoding=True"
else:
positional_encoding_dim = 0
if attention_scaling_factors is None:
attention_scaling_factors = [1] * n_layers
input_variable_codimension = 1 # each channel is a variable
if lifting_channels is None:
self.lifting = False
else:
lifting_variable_codimension = lifting_channels
self.lifting = True
if projection_channels is None:
self.projection = False
else:
projection_variable_codimension = projection_channels
self.projection = True
extended_variable_codimemsion = (
input_variable_codimension + static_channel_dim + positional_encoding_dim
)
if not self.lifting:
hidden_variable_codimension = extended_variable_codimemsion
assert (
hidden_variable_codimension % attention_token_dim == 0
), "attention token dim should divide hidden variable codimension"
self.n_dim = len(n_modes[0])
if n_heads is None:
n_heads = [1] * n_layers
if per_layer_scaling_factors is None:
per_layer_scaling_factors = [[1] * self.n_dim] * n_layers
self.input_variable_codimension = input_variable_codimension
self.hidden_variable_codimension = hidden_variable_codimension
self.n_modes = n_modes
self.per_layer_scale_factors = per_layer_scaling_factors
self.non_linearity = non_linearity
self.n_heads = n_heads
self.enable_cls_token = enable_cls_token
self.positional_encoding_dim = positional_encoding_dim
self.variable_ids = variable_ids
self.attention_scalings = attention_scaling_factors
self.positional_encoding_modes = positional_encoding_modes
self.static_channel_dim = static_channel_dim
self.layer_kwargs = layer_kwargs
self.use_positional_encoding = use_positional_encoding
self.use_horizontal_skip_connection = use_horizontal_skip_connection
self.horizontal_skips_map = horizontal_skips_map
self.output_variable_codimension = output_variable_codimension
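        # the positional encoding is realized with irfftn, which stores a one-sided
        # spectrum along the last dimension, so only half of the modes are kept there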
if self.positional_encoding_modes is not None:
self.positional_encoding_modes[-1] = self.positional_encoding_modes[-1] // 2
# calculating scaling
if self.per_layer_scale_factors is not None:
self.end_to_end_scaling = [1] * len(self.per_layer_scale_factors[0])
# multiplying scaling factors
for k in self.per_layer_scale_factors:
self.end_to_end_scaling = [
i * j for (i, j) in zip(self.end_to_end_scaling, k)
]
else:
self.end_to_end_scaling = [1] * self.n_dim
# Setting up domain padding for encoder and reconstructor
if domain_padding is not None and domain_padding > 0:
self.domain_padding = DomainPadding(
domain_padding=domain_padding,
resolution_scaling_factor=self.end_to_end_scaling,
)
else:
self.domain_padding = None
self.extended_variable_codimemsion = extended_variable_codimemsion
if self.lifting:
self.lifting = ChannelMLP(
in_channels=extended_variable_codimemsion,
out_channels=self.hidden_variable_codimension,
hidden_channels=lifting_variable_codimension,
n_layers=2,
n_dim=self.n_dim,
)
else:
self.hidden_variable_codimension = self.extended_variable_codimemsion
self.attention_layers = nn.ModuleList([])
for i in range(self.n_layers):
self.attention_layers.append(
CODALayer(
n_modes=self.n_modes[i],
n_heads=self.n_heads[i],
scale=self.attention_scalings[i],
token_codimension=attention_token_dim,
per_channel_attention=per_channel_attention,
nonlinear_attention=nonlinear_attention,
resolution_scaling_factor=self.per_layer_scale_factors[i],
conv_module=conv_module,
non_linearity=self.non_linearity,
**self.layer_kwargs,
)
)
if self.use_horizontal_skip_connection:
# horizontal skip connections
            # linear projection of the concatenated tokens from skip connections
self.skip_map_module = nn.ModuleDict()
for k in self.horizontal_skips_map.keys():
self.skip_map_module[str(k)] = ChannelMLP(
in_channels=2 * self.hidden_variable_codimension,
out_channels=self.hidden_variable_codimension,
hidden_channels=None,
n_layers=1,
non_linearity=nn.Identity(),
n_dim=self.n_dim,
)
if self.projection:
self.projection = ChannelMLP(
in_channels=self.hidden_variable_codimension,
out_channels=output_variable_codimension,
hidden_channels=projection_variable_codimension,
n_layers=2,
n_dim=self.n_dim,
)
else:
self.projection = None
if enable_cls_token:
self.cls_token = nn.Parameter(
torch.randn(
1,
self.hidden_variable_codimension,
*self.n_modes[0],
dtype=torch.cfloat,
)
)
if use_positional_encoding:
self.positional_encoding = nn.ParameterDict()
for i in self.variable_ids:
self.positional_encoding[i] = nn.Parameter(
torch.randn(
1,
positional_encoding_dim,
*self.positional_encoding_modes,
dtype=torch.cfloat,
)
)
def _extend_positional_encoding(self, new_var_ids):
"""
        Add variable-specific positional encodings for new variables. This function is required
        when adapting a pre-trained model to a new dataset/PDE that contains additional variables.
Parameters
----------
new_var_ids : list[str]
IDs of the new variables to add positional encoding.
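
        Examples
        --------
        A sketch of adapting a pre-trained model to a dataset with an additional
        temperature variable (the name 'T' is illustrative):

        >>> model._extend_positional_encoding(['T'])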
"""
for i in new_var_ids:
self.positional_encoding[i] = nn.Parameter(
torch.randn(
1,
self.positional_encoding_dim,
*self.positional_encoding_modes,
dtype=torch.cfloat,
)
)
self.variable_ids += new_var_ids
def _get_positional_encoding(self, x, input_variable_ids):
"""
Returns the positional encoding for the input variables.
Parameters
----------
x : torch.Tensor
input tensor of shape (batch_size, num_inp_var, H, W, ...)
input_variable_ids : list[str]
The names of the variables corresponding to the channels of input 'x'.
"""
encoding_list = []
for i in input_variable_ids:
encoding_list.append(
torch.fft.irfftn(self.positional_encoding[i], s=x.shape[-self.n_dim :])
)
return torch.stack(encoding_list, dim=1)
def _get_cls_token(self, x):
"""
Returns the learnable cls token for the input variables.
Parameters
----------
x : torch.Tensor
input tensor of shape (batch_size, num_inp_var, H, W, ...)
This is used to determine the shape of the cls token.
"""
cls_token = torch.fft.irfftn(self.cls_token, s=x.shape[-self.n_dim :])
repeat_shape = [1 for _ in x.shape]
repeat_shape[0] = x.shape[0]
cls_token = cls_token.repeat(*repeat_shape)
return cls_token
def _extend_variables(self, x, static_channel, input_variable_ids):
"""
Extend the input variables by concatenating the static channel and positional encoding.
Parameters
----------
x : torch.Tensor
input tensor of shape (batch_size, num_inp_var, H, W, ...)
static_channel : torch.Tensor
static channel tensor of shape (batch_size, static_channel_dim, H, W, ...)
input_variable_ids : list[str]
The names of the variables corresponding to the channels of input 'x'.
"""
x = x.unsqueeze(2)
if static_channel is not None:
repeat_shape = [1 for _ in x.shape]
repeat_shape[1] = x.shape[1]
static_channel = static_channel.unsqueeze(1).repeat(*repeat_shape)
x = torch.cat([x, static_channel], dim=2)
if self.use_positional_encoding:
positional_encoding = self._get_positional_encoding(x, input_variable_ids)
repeat_shape = [1 for _ in x.shape]
repeat_shape[0] = x.shape[0]
x = torch.cat([x, positional_encoding.repeat(*repeat_shape)], dim=2)
return x
def forward(self, x: torch.Tensor, static_channel=None, input_variable_ids=None):
"""
Parameters
----------
x : torch.Tensor
input tensor of shape (batch_size, num_inp_var, H, W, ...)
static_channel : torch.Tensor
static channel tensor of shape (batch_size, static_channel_dim, H, W, ...)
These channels provide additional information regarding the physical setup of the system.
Must be provided when `static_channel_dim > 0`.
input_variable_ids : list[str]
The names of the variables corresponding to the channels of input 'x'.
This parameter is required when `use_positional_encoding=True`.
            For example, if the input x is a snapshot of the velocity field of a fluid flow, then input_variable_ids=['u_x', 'u_y'].
            The ids must be in the same order as the channels of 'x', i.e., input_variable_ids[0] corresponds to the
            first channel x[:, 0, ...].
Returns
-------
torch.Tensor
output tensor of shape (batch_size, output_variable_codimension*num_inp_var, H, W, ...)
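
        Examples
        --------
        A call sketch, assuming a model constructed with use_positional_encoding=True
        and variable_ids=['u_x', 'u_y'] (the variable names are illustrative):

        >>> # x has shape (batch_size, 2, H, W), one channel per variable
        >>> out = model(x, input_variable_ids=['u_x', 'u_y'])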
"""
batch, num_inp_var, *spatial_shape = (
x.shape
) # num_inp_var is the number of channels in the input
# input validation
        if self.static_channel_dim > 0:
            if static_channel is None:
                raise ValueError(
                    f"Expected a static channel with {self.static_channel_dim} channels, but got None"
                )
            if static_channel.shape[1] != self.static_channel_dim:
                raise ValueError(
                    f"Expected static channel dimension {self.static_channel_dim}, but got {static_channel.shape[1]}"
                )
if self.use_positional_encoding:
assert (
input_variable_ids is not None
), "variable_ids are not provided for the input"
assert x.shape[1] == len(
input_variable_ids
), f"Expected number of variables in input is {len(input_variable_ids)}, but got {x.shape[1]}"
# position encoding and static channels are concatenated with the input
# variables
x = self._extend_variables(x, static_channel, input_variable_ids)
# input variables are lifted to a higher-dimensional space
if self.lifting:
x = x.reshape(
batch * num_inp_var, self.extended_variable_codimemsion, *spatial_shape
)
x = self.lifting(x)
x = x.reshape(
batch, num_inp_var * self.hidden_variable_codimension, *spatial_shape
)
# getting the learnable CLASS token
if self.enable_cls_token:
cls_token = self._get_cls_token(x)
x = torch.cat(
[
cls_token,
x,
],
dim=1,
)
num_inp_var += 1
# zero padding the domain of the input
if self.domain_padding is not None:
x = self.domain_padding.pad(x)
# calculating the output shape
output_shape = [
int(round(i * j))
for (i, j) in zip(x.shape[-self.n_dim :], self.end_to_end_scaling)
]
# forward pass through the Codomain Attention layers
skip_outputs = {}
for layer_idx in range(self.n_layers):
if (
self.horizontal_skips_map is not None
and layer_idx in self.horizontal_skips_map.keys()
):
# `horizontal skip connections`
# tokens from skip connections are concatenated with the
# current token and then linearly projected
# to the `hidden_variable_codimension`
skip_val = skip_outputs[self.horizontal_skips_map[layer_idx]]
resolution_scaling_factors = [
m / n for (m, n) in zip(x.shape, skip_val.shape)
]
resolution_scaling_factors = resolution_scaling_factors[
-1 * self.n_dim :
]
t = resample(
skip_val,
resolution_scaling_factors,
list(range(-self.n_dim, 0)),
output_shape=x.shape[-self.n_dim :],
)
x = x.reshape(
batch * num_inp_var,
self.hidden_variable_codimension,
*x.shape[-self.n_dim :],
)
t = t.reshape(
batch * num_inp_var,
self.hidden_variable_codimension,
*t.shape[-self.n_dim :],
)
x = torch.cat([x, t], dim=1)
x = self.skip_map_module[str(layer_idx)](x)
x = x.reshape(
batch,
num_inp_var * self.hidden_variable_codimension,
*x.shape[-self.n_dim :],
)
if layer_idx == self.n_layers - 1:
cur_output_shape = output_shape
else:
cur_output_shape = None
x = self.attention_layers[layer_idx](x, output_shape=cur_output_shape)
# storing the outputs for skip connections
if (
self.horizontal_skips_map is not None
and layer_idx in self.horizontal_skips_map.values()
):
skip_outputs[layer_idx] = x.clone()
# removing the padding
if self.domain_padding is not None:
x = self.domain_padding.unpad(x)
# projecting the hidden variables to the output variables
if self.projection:
x = x.reshape(
batch * num_inp_var,
self.hidden_variable_codimension,
*x.shape[-self.n_dim :],
)
x = self.projection(x)
x = x.reshape(
batch,
num_inp_var * self.output_variable_codimension,
*x.shape[-self.n_dim :],
)
else:
return x
# discarding the CLASS token
if self.enable_cls_token:
x = x[:, self.output_variable_codimension :, ...]
return x