from abc import ABC, abstractmethod
from typing import List
import torch
from torch import nn


class Embedding(nn.Module, ABC):
    def __init__(self):
        super().__init__()

    @property
    @abstractmethod
    def out_channels(self):
        pass


class GridEmbedding2D(Embedding):
    """GridEmbedding2D applies a simple positional embedding as a regular 2D grid.
    It expects inputs of shape (batch, channels, d_1, d_2).

    Parameters
    ----------
    in_channels : int
        number of channels in the input. Stored so that ``out_channels``
        can be computed for composing layers.
    grid_boundaries : list, optional
        coordinate boundaries of the input grid, by default [[0, 1], [0, 1]]
    """

    def __init__(self, in_channels: int, grid_boundaries=[[0, 1], [0, 1]]):
        super().__init__()
        self.in_channels = in_channels
        self.grid_boundaries = grid_boundaries
        self._grid = None
        self._res = None

    @property
    def out_channels(self):
        return self.in_channels + 2

    def grid(self, spatial_dims, device, dtype):
        """grid generates the 2D grid needed for positional encoding
        and caches the grid associated with the most recently used resolution.

        Parameters
        ----------
        spatial_dims : torch.Size
            sizes of spatial resolution
        device : literal 'cpu' or 'cuda:*'
            where to load data
        dtype : torch.dtype
            dtype to encode data

        Returns
        -------
        tuple(torch.Tensor)
            output grids to concatenate
        """
        # handle case of multiple train resolutions
        if self._grid is None or self._res != spatial_dims:
            grid_x, grid_y = regular_grid_2d(
                spatial_dims, grid_boundaries=self.grid_boundaries
            )
            grid_x = grid_x.to(device).to(dtype).unsqueeze(0).unsqueeze(0)
            grid_y = grid_y.to(device).to(dtype).unsqueeze(0).unsqueeze(0)
            self._grid = grid_x, grid_y
            self._res = spatial_dims

        return self._grid

    def forward(self, data, batched=True):
        """Append x- and y-coordinate channels to ``data`` along dim 1."""
        if not batched:
            if data.ndim == 3:
                data = data.unsqueeze(0)
        batch_size = data.shape[0]
        x, y = self.grid(data.shape[-2:], data.device, data.dtype)
        out = torch.cat(
            (data, x.expand(batch_size, -1, -1, -1), y.expand(batch_size, -1, -1, -1)),
            dim=1,
        )
        # in the unbatched case, the dataloader will stack N
        # examples with no batch dim to create one
        if not batched and batch_size == 1:
            return out.squeeze(0)
        else:
            return out
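

# Usage sketch (illustrative, not part of the library's code): appends two
# coordinate channels to a random batch of 2D fields and checks the result.
# The ``_demo_*`` name and the input sizes are assumptions for demonstration.
def _demo_grid_embedding_2d():
    emb = GridEmbedding2D(in_channels=3)
    data = torch.randn(4, 3, 16, 16)  # (batch, channels, height, width)
    out = emb(data)
    # x- and y-coordinate grids are concatenated along the channel dim
    assert out.shape == (4, emb.out_channels, 16, 16)  # out_channels == 3 + 2
    return out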


class GridEmbeddingND(Embedding):
    """GridEmbeddingND applies a simple positional embedding as a regular ND grid.
    It expects inputs of shape (batch, channels, d_1, ..., d_n).

    Parameters
    ----------
    in_channels : int
        number of channels in the input
    dim : int
        number of dimensions of positional encoding to apply
    grid_boundaries : list, optional
        coordinate boundaries of the input grid along each dim,
        by default [[0, 1], [0, 1]]
    """

    def __init__(self, in_channels: int, dim: int = 2, grid_boundaries=[[0, 1], [0, 1]]):
        super().__init__()
        self.in_channels = in_channels
        self.dim = dim
        assert self.dim == len(grid_boundaries), (
            f"Error: expected grid_boundaries to be an iterable of length "
            f"{self.dim}, received {grid_boundaries}"
        )
        self.grid_boundaries = grid_boundaries
        self._grid = None
        self._res = None

    @property
    def out_channels(self):
        return self.in_channels + self.dim

    def grid(self, spatial_dims: torch.Size, device: str, dtype: torch.dtype):
        """grid generates the ND grid needed for positional encoding
        and caches the grid associated with the most recently used resolution.

        Parameters
        ----------
        spatial_dims : torch.Size
            sizes of spatial resolution
        device : literal 'cpu' or 'cuda:*'
            where to load data
        dtype : torch.dtype
            dtype to encode data

        Returns
        -------
        list of torch.Tensor
            output grids to concatenate
        """
        # handle case of multiple train resolutions
        if self._grid is None or self._res != spatial_dims:
            grids_by_dim = regular_grid_nd(spatial_dims, grid_boundaries=self.grid_boundaries)
            # add batch, channel dims
            grids_by_dim = [x.to(device).to(dtype).unsqueeze(0).unsqueeze(0) for x in grids_by_dim]
            self._grid = grids_by_dim
            self._res = spatial_dims

        return self._grid

    def forward(self, data, batched=True):
        """
        Parameters
        ----------
        data : torch.Tensor
            assumes shape (batch (optional), channels, x_1, x_2, ..., x_n)
        batched : bool
            whether data has a batch dim
        """
        # add batch dim if it doesn't exist
        if not batched:
            if data.ndim == self.dim + 1:
                data = data.unsqueeze(0)
        batch_size = data.shape[0]
        grids = self.grid(spatial_dims=data.shape[2:], device=data.device, dtype=data.dtype)
        grids = [x.repeat(batch_size, *[1] * (self.dim + 1)) for x in grids]
        out = torch.cat((data, *grids), dim=1)
        return out
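

# Usage sketch (illustrative, not part of the library's code): a 3D grid
# embedding appending (x, y, z) coordinate channels to a volumetric batch.
# The ``_demo_*`` name and the input sizes are assumptions for demonstration.
def _demo_grid_embedding_nd():
    emb = GridEmbeddingND(in_channels=2, dim=3, grid_boundaries=[[0, 1]] * 3)
    data = torch.randn(2, 2, 8, 8, 8)  # (batch, channels, d_1, d_2, d_3)
    out = emb(data)
    # one coordinate channel is appended per spatial dimension
    assert out.shape == (2, emb.out_channels, 8, 8, 8)  # out_channels == 2 + 3
    return out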


class SinusoidalEmbedding(Embedding):
    """
    Sinusoidal positional embedding for enriching coordinate inputs with spectral information.

    This class provides sinusoidal positional embeddings in two styles: Transformer-style
    and NeRF-style. It lifts low-dimensional coordinates into a richer spectral representation
    by encoding them as periodic functions (sines and cosines) at multiple frequencies.
    The embedding enhances a model's ability to capture fine-scale variations and high-frequency
    dynamics by providing a hierarchy of frequency components alongside the original coordinates.

    Parameters
    ----------
    in_channels : int
        Number of input channels to embed (dimensionality of input coordinates)
    num_frequencies : int, optional
        Number of frequency levels L in the embedding. Each level contributes
        a sine and cosine pair, resulting in 2L output channels per input channel.
        By default, set to the number of input channels.
    embedding_type : {'transformer', 'nerf'}, optional
        Type of embedding to apply, by default 'transformer'

        Transformer-style [1]_:
            For each input coordinate p and frequency level k (0 ≤ k < L):

            - g(p)_{2k} = sin(p / max_positions^{2k/L})
            - g(p)_{2k+1} = cos(p / max_positions^{2k/L})
        NeRF-style [2]_:
            For each input coordinate p and frequency level k (0 ≤ k < L):

            - g(p)_{2k} = sin(2^k * π * p)
            - g(p)_{2k+1} = cos(2^k * π * p)
    max_positions : int, optional
        Maximum number of positions for transformer-style encoding, by default 10000.
        Only used when embedding_type='transformer'.

    Notes
    -----
    - Input shape: (batch, n_in, in_channels) or (n_in, in_channels)
    - Output shape: (batch, n_in, 2*num_frequencies*in_channels) or (n_in, 2*num_frequencies*in_channels)
    - Ensure the highest frequency satisfies the Nyquist criterion:

      - Transformer: f_max < N/2 where N is the number of sampling points
      - NeRF: 2^{L-1} < N/2, i.e., L < 1 + log₂(N/2)

    Examples
    --------
    See `examples/layers/plot_sinusoidal_embeddings.py` for comprehensive visualizations

    References
    ----------
    .. [1] Vaswani, A. et al. "Attention Is All You Need".
           NeurIPS 2017, https://proceedings.neurips.cc/paper_files/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf
    .. [2] Mildenhall, B. et al. "NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis".
           ArXiv 2020, https://arxiv.org/pdf/2003.08934
    """

    def __init__(
        self,
        in_channels: int,
        num_frequencies: int = None,
        embedding_type: str = "transformer",
        max_positions: int = 10000,
    ):
        super().__init__()
        self.in_channels = in_channels
        # default, per the docstring: one frequency level per input channel
        if num_frequencies is None:
            num_frequencies = in_channels
        self.num_frequencies = num_frequencies

        # verify embedding type
        allowed_embeddings = ["nerf", "transformer"]
        assert embedding_type in allowed_embeddings, (
            f"Error: embedding_type expected one of {allowed_embeddings}, "
            f"received {embedding_type}"
        )
        self.embedding_type = embedding_type
        if self.embedding_type == "transformer":
            assert max_positions is not None, (
                "Error: max_positions must have an int value for transformer embedding."
            )
        self.max_positions = max_positions

    @property
    def out_channels(self):
        """
        required property for linking/composing model layers
        """
        return 2 * self.num_frequencies * self.in_channels

    def forward(self, x):
        """
        Parameters
        ----------
        x : torch.Tensor
            shape (n_in, self.in_channels) or (batch, n_in, self.in_channels)
        """
        assert x.ndim in [2, 3], (
            f"Error: expected inputs of shape (batch, n_in, {self.in_channels}) "
            f"or (n_in, {self.in_channels}), got inputs with ndim={x.ndim}, shape={x.shape}"
        )
        if x.ndim == 2:
            batched = False
            x = x.unsqueeze(0)
        else:
            batched = True
        batch_size, n_in, _ = x.shape

        if self.embedding_type == "nerf":
            # frequencies 2^k * pi for k = 0 ... L-1
            freqs = 2 ** torch.arange(0, self.num_frequencies, device=x.device) * torch.pi
        elif self.embedding_type == "transformer":
            # frequencies (1 / max_positions)^(2k / L) for k = 0 ... L-1
            freqs = torch.arange(0, self.num_frequencies, device=x.device) / self.num_frequencies * 2
            freqs = (1 / self.max_positions) ** freqs

        # outer product of coordinates and frequencies,
        # shape (batch, n_in, in_channels, num_frequencies)
        freqs = torch.einsum("bij, k -> bijk", x, freqs)
        # stack sin and cos, shape (batch, n_in, in_channels, num_frequencies, 2)
        freqs = torch.stack((freqs.sin(), freqs.cos()), dim=-1)
        # ravel the trailing dims to interleave sin and cos,
        # final shape (batch, n_in, 2 * num_frequencies * in_channels)
        freqs = freqs.view(batch_size, n_in, -1)

        if not batched:
            freqs = freqs.squeeze(0)
        return freqs
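

# Usage sketch (illustrative, not part of the library's code): lifts unbatched
# 2D coordinates into a NeRF-style spectral representation and checks the
# output width. The ``_demo_*`` name and sizes are assumptions for demonstration.
def _demo_sinusoidal_embedding():
    emb = SinusoidalEmbedding(in_channels=2, num_frequencies=4, embedding_type="nerf")
    coords = torch.rand(100, 2)  # (n_in, in_channels), unbatched
    out = emb(coords)
    # each coordinate channel yields a sin/cos pair per frequency level
    assert out.shape == (100, emb.out_channels)  # 2 * 4 * 2 == 16
    return out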


class RotaryEmbedding2D(nn.Module):
    def __init__(self, dim, min_freq=1 / 64, scale=1.0):
        """
        Applies rotary positional embedding (https://arxiv.org/abs/2104.09864)
        to the input feature tensor. The crux is that the product of two rotation
        matrices, R(theta1)^T @ R(theta2), equals R(theta2 - theta1), so attention
        scores depend only on relative positions.
        """
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.min_freq = min_freq
        self.scale = scale
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.out_channels = 2

    def forward(self, coordinates):
        """coordinates is a tensor of shape [batch_size, num_points]"""
        coordinates = coordinates * (self.scale / self.min_freq)
        freqs = torch.einsum("... i , j -> ... i j", coordinates, self.inv_freq)  # [b, n, dim // 2]
        return torch.cat((freqs, freqs), dim=-1)  # [b, n, dim]

    @staticmethod
    def apply_1d_rotary_pos_emb(t, freqs):
        return apply_rotary_pos_emb(t, freqs)

    @staticmethod
    def apply_2d_rotary_pos_emb(t, freqs_x, freqs_y):
        """Split the last dimension of features into two equal halves
        and apply 1d rotary positional embedding to each half."""
        d = t.shape[-1]
        t_x, t_y = t[..., : d // 2], t[..., d // 2 :]
        return torch.cat(
            (apply_rotary_pos_emb(t_x, freqs_x), apply_rotary_pos_emb(t_y, freqs_y)),
            dim=-1,
        )
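

# Usage sketch (illustrative, not part of the library's code): rotates query
# features by per-point (x, y) frequencies, as in attention with 2D rotary
# embeddings. The ``_demo_*`` name and sizes are assumptions for demonstration.
def _demo_rotary_embedding_2d():
    head_dim = 32  # each spatial axis rotates head_dim // 2 channels
    rot = RotaryEmbedding2D(head_dim // 2)
    batch_size, n = 2, 64
    freqs_x = rot(torch.rand(batch_size, n))  # (batch, n, head_dim // 2)
    freqs_y = rot(torch.rand(batch_size, n))
    q = torch.randn(batch_size, n, head_dim)
    q_rot = RotaryEmbedding2D.apply_2d_rotary_pos_emb(q, freqs_x, freqs_y)
    assert q_rot.shape == q.shape
    return q_rot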


# Utility functions for GridEmbedding
def regular_grid_2d(spatial_dims, grid_boundaries=[[0, 1], [0, 1]]):
    """
    Creates a 2 x height x width stack of positional encodings A, where
    A[:, i, j] = [x, y] at coordinate (i, j) on a (height, width) grid.
    """
    height, width = spatial_dims

    xt = torch.linspace(grid_boundaries[0][0], grid_boundaries[0][1], height + 1)[:-1]
    yt = torch.linspace(grid_boundaries[1][0], grid_boundaries[1][1], width + 1)[:-1]

    grid_x, grid_y = torch.meshgrid(xt, yt, indexing="ij")
    # repeat(1, 1) materializes contiguous copies of the meshgrid views
    grid_x = grid_x.repeat(1, 1)
    grid_y = grid_y.repeat(1, 1)

    return grid_x, grid_y
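

# Usage sketch (illustrative, not part of the library's code): the 2D helper
# returns one tensor of x-coordinates and one of y-coordinates.
def _demo_regular_grid_2d():
    grid_x, grid_y = regular_grid_2d((4, 6))
    assert grid_x.shape == grid_y.shape == (4, 6)
    # grid_x varies along the height axis, grid_y along the width axis
    return grid_x, grid_y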


def regular_grid_nd(
    resolutions: List[int], grid_boundaries: List[List[int]] = [[0, 1]] * 2
):
    """regular_grid_nd generates a tensor of coordinate points that
    describe a bounded regular grid.

    Creates a dim x res_d1 x ... x res_dn stack of positional encodings A, where
    A[:, c1, c2, ...] = [d1, d2, ..., dn] at coordinate (c1, c2, ..., cn)
    on a (res_d1, ..., res_dn) grid.

    Parameters
    ----------
    resolutions : List[int]
        resolution of the output grid along each dimension
    grid_boundaries : List[List[int]], optional
        list of pairs [start, end] of the boundaries of the regular grid.
        Must correspond 1-to-1 with resolutions, by default [[0, 1], [0, 1]]

    Returns
    -------
    grid : tuple(torch.Tensor)
        list of tensors describing positional encoding
    """
    assert len(resolutions) == len(grid_boundaries), (
        "Error: inputs must have same number of dimensions"
    )
    dim = len(resolutions)

    meshgrid_inputs = list()
    for res, (start, stop) in zip(resolutions, grid_boundaries):
        meshgrid_inputs.append(torch.linspace(start, stop, res + 1)[:-1])
    grid = torch.meshgrid(*meshgrid_inputs, indexing="ij")
    grid = tuple([x.repeat([1] * dim) for x in grid])
    return grid
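

# Usage sketch (illustrative, not part of the library's code): builds a 3D
# coordinate grid with per-dimension boundaries and checks the shapes.
def _demo_regular_grid_nd():
    grids = regular_grid_nd([4, 8, 16], grid_boundaries=[[0, 1], [0, 2], [-1, 1]])
    assert len(grids) == 3
    assert all(g.shape == (4, 8, 16) for g in grids)
    # grids[d][i, j, k] is the coordinate along dimension d at index (i, j, k)
    return grids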


# Utility functions for rotary embedding
# modified from https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py
def rotate_half(x):
    """
    Split x's channels into two equal halves
    and return the pair rotated by 90 degrees: (-x2, x1).
    """
    # split the last dimension of x into two equal halves
    x = x.reshape(*x.shape[:-1], 2, -1)
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(t, freqs):
    """
    Apply the rotation matrix computed from freqs to rotate t.

    t: tensor of shape [batch_size, num_points, dim]
    freqs: tensor of shape [batch_size, num_points, dim]

    Formula: see equation (34) in https://arxiv.org/pdf/2104.09864.pdf
    """
    return (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
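

# Sanity-check sketch (illustrative, not part of the library's code): rotating
# q by theta1 and k by theta2 gives the same dot product as rotating k alone by
# (theta2 - theta1), so attention scores depend only on relative position.
# The ``_demo_*`` name and the constant angles are assumptions for demonstration.
def _demo_rotary_relative_property():
    q = torch.randn(1, 1, 8)
    k = torch.randn(1, 1, 8)
    theta1 = torch.full((1, 1, 8), 0.3)
    theta2 = torch.full((1, 1, 8), 1.1)
    lhs = (apply_rotary_pos_emb(q, theta1) * apply_rotary_pos_emb(k, theta2)).sum()
    rhs = (q * apply_rotary_pos_emb(k, theta2 - theta1)).sum()
    assert torch.allclose(lhs, rhs, atol=1e-5)
    return lhs, rhs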