Source code for neuralop.layers.channel_mlp

import torch
from torch import nn
import torch.nn.functional as F


class ChannelMLP(nn.Module):
    """Multi-layer perceptron applied channel-wise across spatial dimensions.

    ChannelMLP applies a series of 1D convolutions and nonlinearities to the
    channel dimension of input tensors, making it invariant to spatial
    resolution. This is particularly useful in neural operators, where the
    spatial dimensions may vary but the channel processing should remain
    consistent.

    The implementation uses 1D convolutions with kernel size 1, which
    effectively perform linear transformations on the channel dimension while
    preserving spatial structure. This approach is more efficient than
    reshaping to 2D and using fully connected layers.

    Parameters
    ----------
    in_channels : int
        Number of input channels
    out_channels : int, optional
        Number of output channels. If None, defaults to in_channels.
    hidden_channels : int, optional
        Number of hidden channels in intermediate layers. If None, defaults
        to in_channels.
    n_layers : int, optional
        Number of linear layers in the MLP, by default 2
    n_dim : int, optional
        Spatial dimension of input (unused but kept for compatibility),
        by default 2
    non_linearity : callable, optional
        Nonlinear activation function to apply between layers, by default F.gelu
    dropout : float, optional
        Dropout probability applied after each layer (except the last).
        If 0, no dropout is applied, by default 0.0
    """

    def __init__(
        self,
        in_channels,
        out_channels=None,
        hidden_channels=None,
        n_layers=2,
        n_dim=2,
        non_linearity=F.gelu,
        dropout=0.0,
    ):
        super().__init__()
        self.n_layers = n_layers
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        self.hidden_channels = (
            in_channels if hidden_channels is None else hidden_channels
        )
        self.non_linearity = non_linearity
        self.dropout = (
            nn.ModuleList([nn.Dropout(dropout) for _ in range(n_layers)])
            if dropout > 0.0
            else None
        )

        # Build the MLP layers using 1D convolutions with kernel size 1.
        # This effectively performs linear transformations on the channel
        # dimension while preserving spatial structure, and is more efficient
        # than fully connected layers on flattened inputs.
        self.fcs = nn.ModuleList()
        for i in range(n_layers):
            if i == 0 and i == (n_layers - 1):
                # Single layer: input -> output
                self.fcs.append(nn.Conv1d(self.in_channels, self.out_channels, 1))
            elif i == 0:
                # First layer: input -> hidden
                self.fcs.append(nn.Conv1d(self.in_channels, self.hidden_channels, 1))
            elif i == (n_layers - 1):
                # Last layer: hidden -> output
                self.fcs.append(nn.Conv1d(self.hidden_channels, self.out_channels, 1))
            else:
                # Internal layers: hidden -> hidden
                self.fcs.append(
                    nn.Conv1d(self.hidden_channels, self.hidden_channels, 1)
                )

    def forward(self, x):
        """
        Forward pass through the channel MLP.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch, in_channels, *spatial_dims)

        Returns
        -------
        torch.Tensor
            Output tensor of shape (batch, out_channels, *spatial_dims)
        """
        reshaped = False
        size = list(x.shape)

        # Handle high-dimensional inputs (4D+) by flattening spatial dimensions.
        # This allows the 1D convolutions to process all spatial positions uniformly.
        if x.ndim > 3:
            # Flatten spatial dimensions: (batch, channels, x1, x2, ...) -> (batch, channels, -1)
            # Use reshape() instead of view() to handle non-contiguous tensors
            x = x.reshape((*size[:2], -1))
            reshaped = True

        # Apply MLP layers with nonlinearity and dropout
        for i, fc in enumerate(self.fcs):
            x = fc(x)  # Linear transformation (1D conv with kernel size 1)
            if i < self.n_layers - 1:
                # Apply nonlinearity to all layers except the last
                x = self.non_linearity(x)
            if self.dropout is not None:
                x = self.dropout[i](x)

        # Restore original spatial dimensions if input was reshaped
        if reshaped:
            x = x.reshape((size[0], self.out_channels, *size[2:]))

        return x
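
# A minimal usage sketch for ChannelMLP (the channel counts and grid sizes
# below are illustrative assumptions, not part of the module above). Because
# the kernel-size-1 convolutions act only on the channel dimension, the same
# weights apply to inputs of any spatial resolution or dimensionality:
#
#     mlp = ChannelMLP(in_channels=32, out_channels=16, hidden_channels=64, n_layers=3)
#     y_2d = mlp(torch.randn(4, 32, 28, 28))   # (batch, channels, h, w) -> (4, 16, 28, 28)
#     y_1d = mlp(torch.randn(4, 32, 100))      # same weights on a 1D grid -> (4, 16, 100)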

class LinearChannelMLP(torch.nn.Module):
    """
    Multi-layer perceptron using fully connected layers for channel processing.

    This is an alternative implementation of ChannelMLP that uses standard
    Linear layers instead of 1D convolutions.

    Parameters
    ----------
    layers : list of int
        List defining the architecture:
        [in_channels, hidden1, hidden2, ..., out_channels]
        Must have at least 2 elements (input and output channels)
    non_linearity : callable, optional
        Nonlinear activation function to apply between layers, by default F.gelu
    dropout : float, optional
        Dropout probability applied after each layer (except the last).
        If 0, no dropout is applied, by default 0.0
    """

    def __init__(self, layers, non_linearity=F.gelu, dropout=0.0):
        super().__init__()

        self.n_layers = len(layers) - 1
        assert self.n_layers >= 1, (
            "Error: trying to instantiate a LinearChannelMLP "
            "with only one linear layer."
        )

        self.fcs = nn.ModuleList()
        self.non_linearity = non_linearity
        self.dropout = (
            nn.ModuleList([nn.Dropout(dropout) for _ in range(self.n_layers)])
            if dropout > 0.0
            else None
        )

        # Build linear layers based on the provided architecture
        for j in range(self.n_layers):
            self.fcs.append(nn.Linear(layers[j], layers[j + 1]))

    def forward(self, x):
        """
        Forward pass through the linear channel MLP.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch, in_channels) or
            (batch*spatial, in_channels).
            Note: input must be pre-reshaped to 2D format.

        Returns
        -------
        torch.Tensor
            Output tensor of shape (batch, out_channels) or
            (batch*spatial, out_channels).
            Note: output needs to be reshaped back to spatial dimensions
            if needed.
        """
        # Apply linear layers with nonlinearity and dropout
        for i, fc in enumerate(self.fcs):
            x = fc(x)  # Linear transformation
            if i < self.n_layers - 1:
                # Apply nonlinearity to all layers except the last
                x = self.non_linearity(x)
            if self.dropout is not None:
                x = self.dropout[i](x)

        return x
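
# A minimal usage sketch for LinearChannelMLP (channel sizes below are
# illustrative assumptions). nn.Linear acts on the last dimension, so spatially
# structured inputs are typically flattened to (batch * n_points, channels)
# before the call and reshaped back afterwards, as noted in forward():
#
#     mlp = LinearChannelMLP(layers=[32, 64, 16])
#     x = torch.randn(4, 32, 28, 28)                    # (batch, channels, h, w)
#     x_flat = x.permute(0, 2, 3, 1).reshape(-1, 32)    # -> (batch*h*w, 32)
#     y = mlp(x_flat).reshape(4, 28, 28, 16)            # -> (4, h, w, 16)
#     y = y.permute(0, 3, 1, 2)                         # back to (4, 16, 28, 28)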