neuralop.layers.coda_blocks.CODABlocks

class neuralop.layers.coda_blocks.CODABlocks(n_modes, n_heads=1, token_codimension=1, head_codimension=None, codimension_size=None, per_channel_attention=True, permutation_eq=True, norm='instance_norm', temperature=1.0, nonlinear_attention=False, scale=None, output_scaling_factor=None, incremental_n_modes=None, non_linearity=<built-in function gelu>, use_channel_mlp=True, channel_mlp_expansion=1.0, fno_skip='linear', channel_mlp_skip='linear', preactivation=False, separable=False, factorization='tucker', rank=1.0, joint_factorization=False, conv_module=<class 'neuralop.layers.spectral_convolution.SpectralConv'>, fixed_rank_modes=False, implementation='factorized', decomposition_kwargs=None, **_kwargs)[source]

Co-domain Attention Blocks (CODABlocks) implement the transformer architecture in the operator learning framework, as described in the CoDA-NO (Codomain Attention Neural Operator) paper. A usage sketch follows the parameter listing below.

Parameters:
n_modes : list

Number of modes for each dimension used in the K, Q, and V operators.

n_heads : int

Number of heads for the attention mechanism.

token_codimension : int

Co-dimension of each variable, i.e. number of output channels associated with each variable.

head_codimension : int

Co-dimension of each output token for each head.

codimension_size : int

Size of the codimension for the whole function. Only used when permutation_eq=False.

per_channel_attention : bool, optional

Whether to use per-channel attention. Default is True (overrides token_codimension, setting it to 1).

permutation_eq : bool, optional

Whether to apply a permutation-equivariant mixer layer after the attention mechanism.

norm : literal {‘instance_norm’} or None

Normalization module to use. If ‘instance_norm’, instance normalization is applied to the token outputs of the attention module. Default is ‘instance_norm’.

temperature : float

Temperature parameter for the attention mechanism.

nonlinear_attention : bool, optional

Whether to use a non-linear activation in the K, Q, and V operators.

scale : int

Scale for downsampling the Q and K functions before computing the attention matrix. A higher scale downsamples more.

output_scaling_factor : float

Scaling factor for the output.

Methods

compute_attention(tokens, batch_size)

Compute the key-query-value variant of the attention matrix for input token functions.

forward(x)

CoDANO's forward pass.

Other Parameters:
incremental_n_modes : list

Incremental number of modes for each dimension (for incremental training).

use_channel_mlp : bool, optional

Whether to use MLP layers to parameterize skip connections. Default is True.

channel_mlp_expansion : float, optional

Expansion parameter for self.channel_mlp. Default is 1.0.

non_linearity : callable

Non-linearity function to be used.

preactivation : bool, optional

Whether to use preactivation. Default is False.

fno_skip : str, optional

Type of skip connection to be used. Default is ‘linear’.

channel_mlp_skip : str, optional

Module to use for ChannelMLP skip connections. Default is ‘linear’.

separable : bool, optional

Whether to use separable convolutions. Default is False.

factorization : str, optional

Type of factorization to be used. Default is ‘tucker’.

rank : float, optional

Rank of the factorization. Default is 1.0.

conv_module : callable

Spectral convolution module to be used.

joint_factorization : bool, optional

Whether to factorize all SpectralConv weights as a single tensor. Default is False.
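
The following is a minimal usage sketch based on the constructor signature and the input layout documented above. The grid size, number of modes, and variable count are illustrative assumptions rather than values prescribed by the library.

    import torch
    from neuralop.layers.coda_blocks import CODABlocks

    # Illustrative sizes (assumptions): 3 variables on a 32 x 32 grid.
    b, t, d, h, w = 2, 3, 1, 32, 32   # batch, tokens, token codimension, grid

    # With the default per_channel_attention=True, token_codimension is
    # overridden to 1, so each channel is treated as its own token.
    block = CODABlocks(
        n_modes=[16, 16],             # modes per spatial dimension for K, Q, V
        n_heads=2,
        token_codimension=d,
    )

    x = torch.randn(b, t * d, h, w)   # documented input layout: (b, t * d, h, w, ...)
    out = block(x)                    # CoDANO forward pass over the token functions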

compute_attention(tokens, batch_size)[source]

Compute the key-query-value variant of the attention matrix for input token functions.

Parameters:
tokens : torch.Tensor

Input tokens with shape (b * t, d, h, w, …), where: b is the batch size, t is the number of tokens, d is the token codimension, and h, w, … are the domain dimensions. Assumes input tokens have been normalized.

batch_size : int

The size of the batch.
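
Note that this layout differs from the forward-pass input: tokens are stacked along the batch dimension rather than the channel dimension. The reshape below is a small sketch of that convention with illustrative sizes; the direct call to compute_attention is shown commented out and assumes an already-instantiated CODABlocks named block.

    import torch

    b, t, d, h, w = 2, 3, 1, 32, 32       # illustrative sizes (assumptions)
    tokens = torch.randn(b, t, d, h, w)   # one codimension-d function per token

    # compute_attention expects tokens stacked along the batch axis: (b * t, d, h, w, ...)
    tokens_flat = tokens.reshape(b * t, d, h, w)

    # Assuming `block` is an instantiated CODABlocks and the tokens have already
    # been normalized, as required above:
    # attention_out = block.compute_attention(tokens_flat, batch_size=b)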

forward(x)[source]

CoDANO’s forward pass.

  • If self.permutation_eq == True, computes the permutation-equivariant forward pass, where the mixer FNO block is applied to each token separately, making the final result equivariant to any permutation of tokens.

  • If self.permutation_eq == False, the mixer is applied to the whole function at once, and tokens are treated as channels within the same function (a configuration sketch for this case appears below).

Parameters:
x : torch.Tensor

Input tensor with shape (b, t * d, h, w, …), where b is the batch size, t is the number of tokens, and d is the token codimension.
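
For the permutation_eq == False branch described above, the whole codimension is processed as a single function, and codimension_size must be supplied to the constructor; here it is assumed to equal the total number of channels t * d. This is a hedged configuration sketch with illustrative sizes, not a prescribed setup.

    import torch
    from neuralop.layers.coda_blocks import CODABlocks

    b, t, d, h, w = 2, 3, 2, 32, 32       # illustrative sizes (assumptions)

    block = CODABlocks(
        n_modes=[16, 16],
        n_heads=2,
        token_codimension=d,
        per_channel_attention=False,      # keep the given token_codimension
        permutation_eq=False,             # mixer acts on the whole function
        codimension_size=t * d,           # assumed: total codimension of the function
    )

    x = torch.randn(b, t * d, h, w)       # same documented input layout
    out = block(x)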