neuralop.layers.coda_blocks.CODABlocks
- class neuralop.layers.coda_blocks.CODABlocks(n_modes, n_heads=1, token_codimension=1, head_codimension=None, codimension_size=None, per_channel_attention=True, permutation_eq=True, norm='instance_norm', temperature=1.0, nonlinear_attention=False, scale=None, output_scaling_factor=None, incremental_n_modes=None, non_linearity=<built-in function gelu>, use_channel_mlp=True, channel_mlp_expansion=1.0, fno_skip='linear', channel_mlp_skip='linear', preactivation=False, separable=False, factorization='tucker', rank=1.0, joint_factorization=False, conv_module=<class 'neuralop.layers.spectral_convolution.SpectralConv'>, fixed_rank_modes=False, implementation='factorized', decomposition_kwargs=None, **_kwargs)[source]
Co-domain Attention Blocks (CODABlocks) implement the transformer architecture in the operator learning framework, as described in [1].
- Parameters:
- n_modes : list
Number of Fourier modes for each dimension used in the K, Q, V operators.
- n_heads : int
Number of heads for the attention mechanism.
- token_codimension : int
Co-dimension of each variable, i.e., the number of output channels associated with each variable.
- head_codimension : int
Co-dimension of each output token for each head.
- codimension_size : int
Size of the co-dimension of the whole function. Only used when permutation_eq is False.
- per_channel_attention : bool, optional
Whether to use per-channel attention. Default is True (this overrides token_codimension to 1).
- permutation_eq : bool, optional
Whether to use a permutation-equivariant mixer layer after the attention mechanism.
- norm : {'instance_norm'} or None
Normalization module to be used. If 'instance_norm', instance normalization is applied to the token outputs of the attention module. Default is 'instance_norm'.
- temperature : float
Temperature parameter for the attention mechanism.
- nonlinear_attention : bool, optional
Whether to use a non-linear activation for the K, Q, V operators.
- scale : int
Scale for downsampling the Q and K functions before computing the attention matrix. A higher scale downsamples more.
- output_scaling_factor : float
Scaling factor for the output.
Methods

compute_attention(tokens, batch_size)
Compute the key-query-value variant of the attention matrix for input token functions.
forward(x)
CoDANO's forward pass.
- Other Parameters:
- incremental_n_modes : list
Incremental number of modes for each dimension (for incremental training).
- use_channel_mlp : bool, optional
Whether to use MLP layers to parameterize skip connections. Default is True.
- channel_mlp_expansion : float, optional
Expansion parameter for self.channel_mlp. Default is 1.0.
- non_linearity : callable
Non-linearity function to be used.
- preactivation : bool, optional
Whether to use preactivation. Default is False.
- fno_skip : str, optional
Type of skip connection to be used. Default is 'linear'.
- channel_mlp_skip : str, optional
Module to use for ChannelMLP skip connections. Default is 'linear'.
- separable : bool, optional
Whether to use separable convolutions. Default is False.
- factorization : str, optional
Type of factorization to be used. Default is 'tucker'.
- rank : float, optional
Rank of the factorization. Default is 1.0.
- conv_module : callable
Spectral convolution module to be used.
- joint_factorization : bool, optional
Whether to factorize all SpectralConv weights as a single tensor. Default is False.
References

[1] M. A. Rahman, R. J. George, M. Elleithy, D. Leibovici, Z. Li, B. Bonev, C. White, J. Berner, R. A. Yeh, J. Kossaifi, K. Azizzadenesheli, and A. Anandkumar. "Pretraining Codomain Attention Neural Operators for Solving Multiphysics PDEs" (2024). https://arxiv.org/abs/2403.12553
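A minimal usage sketch (not taken from the library docs; the 2D problem size, mode counts, and head count below are illustrative assumptions):

import torch
from neuralop.layers.coda_blocks import CODABlocks

# 2D co-domain attention block: 16 Fourier modes per dimension,
# 2 attention heads, one channel (co-dimension 1) per physical variable.
block = CODABlocks(n_modes=[16, 16], n_heads=2, token_codimension=1)

# forward() expects (batch, tokens * token_codimension, height, width):
# here 4 samples, 3 variables of co-dimension 1, on a 64 x 64 grid.
x = torch.randn(4, 3 * 1, 64, 64)
out = block(x)
print(out.shape)  # expected to match the input shape, torch.Size([4, 3, 64, 64])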
- compute_attention(tokens, batch_size)[source]
Compute the key-query-value variant of the attention matrix for input token functions.
- Parameters:
- tokens : torch.Tensor
Input tokens with shape (b * t, d, h, w, …), where: b is the batch size, t is the number of tokens, d is the token codimension, and h, w, … are the domain dimensions. Assumes input tokens have been normalized.
- batch_size : int
The size of the batch.
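A hedged sketch of the token layout described above (sizes are illustrative; the reshape only illustrates how forward()'s input maps onto the per-token layout, and the method call is shown schematically):

import torch

b, t, d, h, w = 4, 3, 1, 64, 64       # batch, tokens, token co-dimension, grid
x = torch.randn(b, t * d, h, w)       # layout accepted by forward()
tokens = x.reshape(b * t, d, h, w)    # per-token layout expected by compute_attention
# Inside the block, roughly: block.compute_attention(tokens, batch_size=b)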
- forward(x)[source]
CoDANO’s forward pass.
If self.permutation_eq == True, computes the permutation-equivariant forward pass: the mixer FNO block is applied to each token separately, making the final result equivariant to any permutation of tokens. If self.permutation_eq == False, the mixer is applied to the whole function at once, and tokens are treated as channels within the same function.
- Parameters:
- x : torch.Tensor
Input tensor with shape (b, t * d, h, w, …), where b is the batch size, t is the number of tokens, and d is the token codimension.
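A hedged sketch of the permutation-equivariant branch (the shapes and hyperparameters are assumptions, not library defaults): permuting the input tokens should permute the output tokens the same way, up to numerical tolerance.

import torch
from neuralop.layers.coda_blocks import CODABlocks

b, t, d, h, w = 2, 3, 1, 32, 32                      # illustrative sizes
block = CODABlocks(n_modes=[8, 8], n_heads=2, token_codimension=d, permutation_eq=True)

x = torch.randn(b, t * d, h, w)
perm = torch.tensor([2, 0, 1])                       # reorder the 3 tokens
x_perm = x.reshape(b, t, d, h, w)[:, perm].reshape(b, t * d, h, w)

with torch.no_grad():
    y, y_perm = block(x), block(x_perm)

# If the layer is permutation-equivariant, the permuted input should yield
# a correspondingly permuted output (up to numerical tolerance).
expected = y.reshape(b, t, d, h, w)[:, perm].reshape(b, t * d, h, w)
print(torch.allclose(y_perm, expected, atol=1e-4))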