import torch.nn as nn
import torch.nn.functional as F
import torch
from ..layers.channel_mlp import ChannelMLP
from ..layers.spectral_convolution import SpectralConv
from ..layers.skip_connections import skip_connection
from ..layers.padding import DomainPadding
from ..layers.fno_block import FNOBlocks
from ..layers.resample import resample
from ..layers.embeddings import GridEmbedding2D, GridEmbeddingND
class UNO(nn.Module):
"""U-Shaped Neural Operator, as described in [1]_.
Parameters
----------
in_channels : int
Number of input channels
out_channels : int
Number of output channels
hidden_channels : int
initial width of the UNO (i.e. number of channels)
lifting_channels : int, optional
number of hidden channels of the lifting block of the FNO, by default 256
projection_channels : int, optional
number of hidden channels of the projection block of the FNO, by default 256
positional_embedding : str literal | GridEmbedding2D | GridEmbeddingND | None
If "grid", appends a grid positional embedding with default settings to
the last channels of the raw input. Assumes the inputs are discretized
over a grid with entry [0,0,...] at the origin and side lengths of 1.
If an initialized GridEmbedding module is passed, it is used directly;
see `neuralop.embeddings.GridEmbeddingND` for details.
If None, no positional embedding is applied. By default "grid".
n_layers : int, optional
Number of Fourier Layers, by default 4
uno_out_channels: list
Number of output channels of each Fourier layer.
Example: for a five-layer UNO, uno_out_channels can be [32,64,64,64,32]
uno_n_modes: list
Number of Fourier modes to use in the integral operation of each Fourier layer (along each dimension).
Example: for a five-layer UNO with 2D inputs, uno_n_modes can be [[5,5],[5,5],[5,5],[5,5],[5,5]]
uno_scalings: list
Scaling factors for each Fourier layer.
Example: for a five-layer UNO with 2D inputs, uno_scalings can be [[1.0,1.0],[0.5,0.5],[1,1],[1,1],[2,2]]
horizontal_skips_map: dict, optional
a map {..., b: a, ...} denoting a horizontal skip connection from the
a-th layer to the b-th layer. If None, a default skip map is applied.
Example: for a five-layer UNO architecture, the skip connections can be
horizontal_skips_map = {4: 0, 3: 1}
incremental_n_modes : None or int tuple, default is None
* If not None, this allows one to incrementally increase the number of modes in the Fourier domain
during training. Must satisfy n <= m for (n, m) in zip(incremental_n_modes, n_modes).
* If None, all the n_modes are used.
This can be updated dynamically during training.
channel_mlp_dropout: float, optional
dropout parameter for the ChannelMLP after each FNO block
channel_mlp_expansion: float, optional
expansion parameter for the ChannelMLP after each FNO block
non_linearity : nn.Module, optional
Non-Linearity module to use, by default F.gelu
norm : nn.Module, optional
Normalization layer to use, by default None
preactivation : bool, default is False
if True, use resnet-style preactivation
fno_skip : {'linear', 'identity', 'soft-gating'}, optional
Type of skip connection to use inside the FNO blocks, by default 'linear'
horizontal_skip : {'linear', 'identity', 'soft-gating'}, optional
Type of horizontal skip connection to use, by default 'linear'
channel_mlp_skip : {'linear', 'identity', 'soft-gating'}, optional
Type of skip connection to use for the channel MLP, by default 'soft-gating'
separable : bool, default is False
if True, use a depthwise separable spectral convolution
factorization : str or None, {'tucker', 'cp', 'tt'}
Tensor factorization of the parameter weights to use, by default None.
* If None, a dense tensor parametrizes the Spectral convolutions
* Otherwise, the specified tensor factorization is used.
joint_factorization : bool, optional
Whether all the Fourier Layers should be parametrized by a single tensor (vs one per layer), by default False
rank : float, optional
Rank of the tensor factorization of the Fourier weights, by default 1.0
fixed_rank_modes : bool, optional
Modes to not factorize, by default False
implementation : {'factorized', 'reconstructed'}, optional, default is 'factorized'
If factorization is not None, forward mode to use::
* `reconstructed` : the full weight tensor is reconstructed from the factorization and used for the forward pass
* `factorized` : the input is directly contracted with the factors of the decomposition
decomposition_kwargs : dict, optional, default is {}
Optionally, additional parameters to pass to the tensor decomposition
domain_padding : None or float, optional
If not None, percentage of padding to use, by default None
domain_padding_mode : {'symmetric', 'one-sided'}, optional
How to perform domain padding, by default 'one-sided'
fft_norm : str, optional
FFT normalization mode, by default 'forward'
References
-----------
.. [1] Rahman, M.A., Ross, Z., Azizzadenesheli, K. "U-NO: U-shaped
Neural Operators" (2022). TMLR 2022, https://arxiv.org/pdf/2204.11127.
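Examples
--------
A minimal, illustrative 2D instantiation. This is a sketch: the
hyperparameter values below are placeholders rather than recommended
settings, and the import path assumes UNO is exported from
`neuralop.models`.
>>> import torch
>>> from neuralop.models import UNO
>>> model = UNO(
...     in_channels=3,
...     out_channels=1,
...     hidden_channels=64,
...     n_layers=5,
...     uno_out_channels=[32, 64, 64, 64, 32],
...     uno_n_modes=[[5, 5]] * 5,
...     uno_scalings=[[1.0, 1.0], [0.5, 0.5], [1, 1], [1, 1], [2, 2]],
... )
>>> x = torch.randn(4, 3, 64, 64)  # (batch, in_channels, height, width)
>>> out = model(x)  # end-to-end scaling is 1, so out has shape (4, 1, 64, 64)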
"""
def __init__(
self,
in_channels,
out_channels,
hidden_channels,
lifting_channels=256,
projection_channels=256,
positional_embedding="grid",
n_layers=4,
uno_out_channels=None,
uno_n_modes=None,
uno_scalings=None,
horizontal_skips_map=None,
incremental_n_modes=None,
channel_mlp_dropout=0,
channel_mlp_expansion=0.5,
non_linearity=F.gelu,
norm=None,
preactivation=False,
fno_skip="linear",
horizontal_skip="linear",
channel_mlp_skip="soft-gating",
separable=False,
factorization=None,
rank=1.0,
fixed_rank_modes=False,
integral_operator=SpectralConv,
operator_block=FNOBlocks,
implementation="factorized",
decomposition_kwargs=dict(),
domain_padding=None,
domain_padding_mode="one-sided",
verbose=False,
**kwargs
):
super().__init__()
self.n_layers = n_layers
assert uno_out_channels is not None, "uno_out_channels can not be None"
assert uno_n_modes is not None, "uno_n_modes can not be None"
assert uno_scalings is not None, "uno_scalings can not be None"
assert (
len(uno_out_channels) == n_layers
), "Output channels for all layers are not given"
assert (
len(uno_n_modes) == n_layers
), "number of modes for all layers are not given"
assert (
len(uno_scalings) == n_layers
), "Scaling factor for all layers are not given"
self.n_dim = len(uno_n_modes[0])
self.uno_out_channels = uno_out_channels
self.uno_n_modes = uno_n_modes
self.uno_scalings = uno_scalings
self.hidden_channels = hidden_channels
self.lifting_channels = lifting_channels
self.projection_channels = projection_channels
self.in_channels = in_channels
self.out_channels = out_channels
self.horizontal_skips_map = horizontal_skips_map
self.non_linearity = non_linearity
self.rank = rank
self.factorization = factorization
self.fixed_rank_modes = fixed_rank_modes
self.decomposition_kwargs = decomposition_kwargs
self.fno_skip = fno_skip
self.channel_mlp_skip = channel_mlp_skip
self.implementation = implementation
self.separable = separable
self.preactivation = preactivation
self._incremental_n_modes = incremental_n_modes
self.operator_block = operator_block
self.integral_operator = integral_operator
# create positional embedding at the beginning of the model
if positional_embedding == "grid":
spatial_grid_boundaries = [[0., 1.]] * self.n_dim
self.positional_embedding = GridEmbeddingND(in_channels=self.in_channels,
dim=self.n_dim,
grid_boundaries=spatial_grid_boundaries)
elif isinstance(positional_embedding, GridEmbedding2D):
if self.n_dim == 2:
self.positional_embedding = positional_embedding
else:
raise ValueError(f'Error: expected {self.n_dim}-d positional embeddings, got {positional_embedding}')
elif isinstance(positional_embedding, GridEmbeddingND):
self.positional_embedding = positional_embedding
elif positional_embedding is None:
self.positional_embedding = None
else:
raise ValueError(f"Error: tried to instantiate FNO positional embedding with {positional_embedding},\
expected one of \'grid\', GridEmbeddingND")
if self.positional_embedding is not None:
in_channels += self.n_dim
# constructing default skip maps
if self.horizontal_skips_map is None:
self.horizontal_skips_map = {}
for i in range(n_layers // 2):
# example: if n_layers = 5, then {4: 0, 3: 1}
self.horizontal_skips_map[n_layers - i - 1] = i
# self.uno_scalings may be a 1d list specifying a uniform scaling factor at each layer
# or a 2d list, where each row specifies scaling factors along each dimension.
# To get the final (end-to-end) scaling factors we need to multiply
# the scaling factors (a list) of all layers.
self.end_to_end_scaling_factor = [1] * len(self.uno_scalings[0])
# multiplying scaling factors
for k in self.uno_scalings:
self.end_to_end_scaling_factor = [
i * j for (i, j) in zip(self.end_to_end_scaling_factor, k)
]
# a list with a single element is replaced by the scalar.
if len(self.end_to_end_scaling_factor) == 1:
self.end_to_end_scaling_factor = self.end_to_end_scaling_factor[0]
if isinstance(self.end_to_end_scaling_factor, (float, int)):
self.end_to_end_scaling_factor = [
self.end_to_end_scaling_factor
] * self.n_dim
if verbose:
print("calculated out factor", self.end_to_end_scaling_factor)
if domain_padding is not None and (
(isinstance(domain_padding, list) and sum(domain_padding) > 0)
or (isinstance(domain_padding, (float, int)) and domain_padding > 0)
):
self.domain_padding = DomainPadding(
domain_padding=domain_padding,
padding_mode=domain_padding_mode,
resolution_scaling_factor=self.end_to_end_scaling_factor,
)
else:
self.domain_padding = None
self.domain_padding_mode = domain_padding_mode
self.lifting = ChannelMLP(
in_channels=in_channels,
out_channels=self.hidden_channels,
hidden_channels=self.lifting_channels,
n_layers=2,
n_dim=self.n_dim,
)
self.fno_blocks = nn.ModuleList([])
self.horizontal_skips = torch.nn.ModuleDict({})
prev_out = self.hidden_channels
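# build the U-shaped stack of FNO blocks; a layer that is the target of a
# horizontal skip receives the skipped layer's channels as extra input channels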
for i in range(self.n_layers):
if i in self.horizontal_skips_map.keys():
prev_out = (
prev_out + self.uno_out_channels[self.horizontal_skips_map[i]]
)
print(f"{fno_skip=}")
print(f"{channel_mlp_skip=}")
self.fno_blocks.append(
self.operator_block(
in_channels=prev_out,
out_channels=self.uno_out_channels[i],
n_modes=self.uno_n_modes[i],
channel_mlp_dropout=channel_mlp_dropout,
channel_mlp_expansion=channel_mlp_expansion,
resolution_scaling_factor=[self.uno_scalings[i]],
non_linearity=non_linearity,
norm=norm,
preactivation=preactivation,
fno_skip=fno_skip,
channel_mlp_skip=channel_mlp_skip,
incremental_n_modes=incremental_n_modes,
rank=rank,
SpectralConv=self.integral_operator,
fixed_rank_modes=fixed_rank_modes,
implementation=implementation,
separable=separable,
factorization=factorization,
decomposition_kwargs=decomposition_kwargs,
)
)
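# register a horizontal skip module for layers whose output is reused by a later layer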
if i in self.horizontal_skips_map.values():
self.horizontal_skips[str(i)] = skip_connection(
self.uno_out_channels[i],
self.uno_out_channels[i],
skip_type=horizontal_skip,
n_dim=self.n_dim,
)
prev_out = self.uno_out_channels[i]
self.projection = ChannelMLP(
in_channels=prev_out,
out_channels=out_channels,
hidden_channels=self.projection_channels,
n_layers=2,
n_dim=self.n_dim,
non_linearity=non_linearity,
)
def forward(self, x, **kwargs):
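"""UNO's forward pass: optionally append positional embeddings, lift to
the hidden width, optionally pad the domain, apply the stack of FNO blocks
with horizontal skip connections, then unpad and project to out_channels.
Parameters
----------
x : torch.Tensor
    input tensor of shape (batch, in_channels, d_1, ..., d_n)
"""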
if self.positional_embedding is not None:
x = self.positional_embedding(x)
x = self.lifting(x)
if self.domain_padding is not None:
x = self.domain_padding.pad(x)
output_shape = [
int(round(i * j))
for (i, j) in zip(x.shape[-self.n_dim :], self.end_to_end_scaling_factor)
]
skip_outputs = {}
cur_output = None
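# iterate through the stack of FNO blocks; when a layer is the target of a
# horizontal skip, resample the stored skip output to the current spatial
# resolution and concatenate it along the channel dimension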
for layer_idx in range(self.n_layers):
if layer_idx in self.horizontal_skips_map.keys():
skip_val = skip_outputs[self.horizontal_skips_map[layer_idx]]
resolution_scaling_factors = [
m / n for (m, n) in zip(x.shape, skip_val.shape)
]
resolution_scaling_factors = resolution_scaling_factors[-1 * self.n_dim :]
t = resample(
skip_val, resolution_scaling_factors, list(range(-self.n_dim, 0))
)
x = torch.cat([x, t], dim=1)
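# the final block is asked to produce the end-to-end output resolution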
if layer_idx == self.n_layers - 1:
cur_output = output_shape
x = self.fno_blocks[layer_idx](x, output_shape=cur_output)
if layer_idx in self.horizontal_skips_map.values():
skip_outputs[layer_idx] = self.horizontal_skips[str(layer_idx)](x)
if self.domain_padding is not None:
x = self.domain_padding.unpad(x)
x = self.projection(x)
return x