Source code for neuralop.layers.skip_connections

import torch
from torch import nn


def skip_connection(
    in_features, out_features, n_dim=2, bias=False, skip_type="soft-gating"
):
    """A wrapper for several types of skip connections.

    Returns an nn.Module skip connection, one of
    {'identity', 'linear', 'soft-gating'}. The ``skip_type`` string is
    matched case-insensitively.

    Parameters
    ----------
    in_features : int
        number of input features
    out_features : int
        number of output features
    n_dim : int, default is 2
        Dimensionality of the input (excluding batch-size and channels).
        ``n_dim=2`` corresponds to having Module2D.
        Only forwarded to the 'soft-gating' skip.
    bias : bool, optional
        whether to use a bias, by default False
        (ignored by the 'identity' skip)
    skip_type : {'identity', 'linear', 'soft-gating'}
        kind of skip connection to use, by default "soft-gating"

    Returns
    -------
    nn.Module
        module that takes in x and returns skip(x)

    Raises
    ------
    ValueError
        if ``skip_type`` is not one of the supported kinds
    """
    # Normalize once instead of lower-casing on every comparison.
    kind = skip_type.lower()

    if kind == "soft-gating":
        return SoftGating(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            n_dim=n_dim,
        )
    if kind == "linear":
        # A kernel of size 1 makes the conv a per-position linear map
        # over the channel dimension.
        return Flattened1dConv(
            in_channels=in_features,
            out_channels=out_features,
            kernel_size=1,
            bias=bias,
        )
    if kind == "identity":
        return nn.Identity()

    # The message lists exactly the values accepted above.
    raise ValueError(
        f"Got skip-connection type={skip_type}, expected one of"
        f" {'soft-gating', 'linear', 'identity'}."
    )
class SoftGating(nn.Module):
    """Applies soft-gating by weighting the channels of the given input

    Given an input x of size `(batch-size, channels, height, width)`,
    this returns `x * w` where w is of shape `(1, channels, 1, 1)`.
    When ``bias=True`` a per-channel bias of the same shape as w is added,
    i.e. the module returns `x * w + b`.

    Both the weight and (if enabled) the bias are initialized to ones.

    Parameters
    ----------
    in_features : int
        number of channels of the input
    out_features : None
        this is provided for API compatibility with nn.Linear only;
        if given, it must be equal to ``in_features``
    n_dim : int, default is 2
        Dimensionality of the input (excluding batch-size and channels).
        ``n_dim=2`` corresponds to having Module2D.
    bias : bool, default is False
        whether to add a learnable per-channel bias

    Raises
    ------
    ValueError
        if ``out_features`` is given and differs from ``in_features``
    """

    def __init__(self, in_features, out_features=None, n_dim=2, bias=False):
        super().__init__()
        if out_features is not None and in_features != out_features:
            raise ValueError(
                f"Got in_features={in_features} and out_features={out_features} "
                "but these two must be the same for soft-gating"
            )
        self.in_features = in_features
        self.out_features = out_features

        # One scalar per channel, with singleton batch and spatial dims so
        # the parameters broadcast against (batch, channels, d_1, ..., d_n).
        param_shape = (1, self.in_features) + (1,) * n_dim
        self.weight = nn.Parameter(torch.ones(param_shape))
        if bias:
            self.bias = nn.Parameter(torch.ones(param_shape))
        else:
            self.bias = None

    def forward(self, x):
        """Applies soft-gating to a batch of activations"""
        gated = self.weight * x
        if self.bias is not None:
            return gated + self.bias
        return gated
class Flattened1dConv(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, bias=False):
        """Flattened1dConv is a Conv-based skip layer for
        input tensors of ndim >= 3 (batch, channels, d1, ...) that flattens all
        dimensions past the batch and channel dims into one dimension,
        applies a Conv1d, and un-flattens.

        Parameters
        ----------
        in_channels : int
            in_channels of Conv1d
        out_channels : int
            out_channels of Conv1d
        kernel_size : int
            kernel_size of Conv1d. The Conv1d is built without padding, so
            the flattened length is only preserved (and the output can only
            be un-flattened to the input's spatial shape) for
            ``kernel_size=1``.
        bias : bool, optional
            bias of Conv1d, by default False
        """
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            bias=bias,
        )

    def forward(self, x):
        """Apply the Conv1d over the flattened spatial dimensions of x.

        Parameters
        ----------
        x : torch.Tensor
            input of shape (batch, in_channels, d1, ..., dn)

        Returns
        -------
        torch.Tensor
            output of shape (batch, out_channels, d1, ..., dn)
        """
        # x.shape: b, c, x1, ..., xn
        batch, channels, *spatial = x.shape

        # Flatten everything past the channel dim. ``reshape`` (rather than
        # ``view``) also accepts non-contiguous inputs, e.g. the result of a
        # permute/transpose; it copies only when a view is impossible.
        flat = x.reshape(batch, channels, -1)
        flat = self.conv(flat)

        # Restore the original spatial layout: b, out_channels, x1, x2, ...
        return flat.reshape(batch, self.conv.out_channels, *spatial)