import torch
from torch import nn
def skip_connection(
in_features, out_features, n_dim=2, bias=False, skip_type="soft-gating"
):
"""A wrapper for several types of skip connections.
Returns an nn.Module skip connection, one of {'identity', 'linear', 'soft-gating'}.
Parameters
----------
in_features : int
number of input features
out_features : int
number of output features
n_dim : int, default is 2
Dimensionality of the input (excluding batch-size and channels).
``n_dim=2`` corresponds to a 2D input (e.g. height and width).
bias : bool, optional
whether to use a bias, by default False
skip_type : {'identity', 'linear', 'soft-gating'}
kind of skip connection to use, by default "soft-gating"
Returns
-------
nn.Module
module that takes in x and returns skip(x)
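Examples
--------
A minimal usage sketch (the shapes and parameter values here are illustrative):
>>> import torch
>>> skip = skip_connection(in_features=4, out_features=4, n_dim=2)
>>> x = torch.randn(2, 4, 8, 8)
>>> skip(x).shape
torch.Size([2, 4, 8, 8])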
"""
if skip_type.lower() == "soft-gating":
return SoftGating(
in_features=in_features,
out_features=out_features,
bias=bias,
n_dim=n_dim,
)
elif skip_type.lower() == "linear":
return Flattened1dConv(
    in_channels=in_features,
    out_channels=out_features,
    kernel_size=1,
    bias=bias,
)
elif skip_type.lower() == "identity":
return nn.Identity()
else:
raise ValueError(
    f"Got skip-connection type={skip_type}, expected one of "
    "{'soft-gating', 'linear', 'identity'}."
)
class SoftGating(nn.Module):
"""Applies soft-gating by weighting the channels of the given input
Given an input x of size `(batch-size, channels, height, width)`,
this returns `x * w `
where w is of shape `(1, channels, 1, 1)`
Parameters
----------
in_features : int
number of input channels
out_features : None, optional
provided only for API compatibility with nn.Linear
n_dim : int, default is 2
Dimensionality of the input (excluding batch-size and channels).
``n_dim=2`` corresponds to a 2D input (e.g. height and width).
bias : bool, default is False
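Examples
--------
A minimal usage sketch (shapes are illustrative):
>>> import torch
>>> gate = SoftGating(in_features=3, n_dim=2)
>>> x = torch.randn(2, 3, 8, 8)
>>> gate(x).shape
torch.Size([2, 3, 8, 8])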
"""
def __init__(self, in_features, out_features=None, n_dim=2, bias=False):
super().__init__()
if out_features is not None and in_features != out_features:
raise ValueError(
f"Got in_features={in_features} and out_features={out_features}, "
"but these two must be the same for soft-gating"
)
self.in_features = in_features
self.out_features = out_features
self.weight = nn.Parameter(torch.ones(1, self.in_features, *(1,) * n_dim))
if bias:
self.bias = nn.Parameter(torch.ones(1, self.in_features, *(1,) * n_dim))
else:
self.bias = None
def forward(self, x):
"""Applies soft-gating to a batch of activations"""
if self.bias is not None:
return self.weight * x + self.bias
else:
return self.weight * x
class Flattened1dConv(nn.Module):
def __init__(self, in_channels: int, out_channels: int, kernel_size, bias=False):
"""Flattened3dConv is a Conv-based skip layer for
input tensors of ndim > 3 (batch, channels, d1, ...) that flattens all dimensions
past the batch and channel dims into one dimension, applies the Conv,
and un-flattens.
Parameters
----------
in_channels : int
in_channels of Conv1d
out_channels : int
out_channels of Conv1d
kernel_size : int
kernel_size of Conv1d
bias : bool, optional
bias of Conv1d, by default False
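Examples
--------
A minimal usage sketch on a 3D spatial input (assuming ``kernel_size=1``,
so the flattened length is preserved by the conv):
>>> import torch
>>> conv = Flattened1dConv(in_channels=3, out_channels=5, kernel_size=1)
>>> x = torch.randn(2, 3, 4, 4, 4)
>>> conv(x).shape
torch.Size([2, 5, 4, 4, 4])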
"""
super().__init__()
self.conv = nn.Conv1d(in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
bias=bias)
def forward(self, x):
# x.shape: (b, c, d1, ..., dn), with n >= 1 spatial dims
size = list(x.shape)
# flatten all spatial dims into a single dim
x = x.view(*size[:2], -1)
x = self.conv(x)
# un-flatten back to (b, out_channels, d1, ..., dn)
x = x.view(size[0], self.conv.out_channels, *size[2:])
return x