# Source code for neuralop.layers.embeddings

from abc import ABC, abstractmethod
from typing import List

import torch
from torch import nn


class Embedding(nn.Module, ABC):
    """Abstract base class for the embedding layers in this module.

    A concrete subclass must implement the read-only ``out_channels``
    property, reporting how many channels the embedding produces, so that
    downstream layers can be sized from it.
    """

    def __init__(self):
        # nn.Module's own initializer must run before any submodule,
        # parameter or buffer is assigned on a subclass instance.
        super(Embedding, self).__init__()

    @property
    @abstractmethod
    def out_channels(self):
        """Number of output channels; defined by each concrete subclass."""


class GridEmbedding2D(Embedding):
    """GridEmbedding2D applies a simple positional embedding as a regular 2D grid.
    It expects inputs of shape (batch, channels, d_1, d_2)

    Parameters
    ----------
    in_channels : int
        number of channels in input. Fixed for output channel interface
    grid_boundaries : list, optional
        coordinate boundaries of input grid, by default [[0, 1], [0, 1]]
    """

    def __init__(self, in_channels: int, grid_boundaries=None):
        super().__init__()
        self.in_channels = in_channels
        # Build a fresh list per instance; a list literal as the default would
        # be one object shared by every instance.
        if grid_boundaries is None:
            grid_boundaries = [[0, 1], [0, 1]]
        self.grid_boundaries = grid_boundaries
        # Most recently generated (grid_x, grid_y) and the resolution it was
        # built for.
        self._grid = None
        self._res = None

    @property
    def out_channels(self):
        """Input channels plus one x- and one y-coordinate channel."""
        return self.in_channels + 2

    def _cache_is_stale(self, spatial_dims, device, dtype):
        """Return True when the cached grid cannot be reused for these arguments."""
        if self._grid is None or self._res != spatial_dims:
            return True
        # The cache is a plain attribute (not a registered buffer), so
        # module.to(...) does not move it. Compare it against the requested
        # device/dtype, otherwise a stale grid would reach torch.cat.
        target = torch.device(device)
        return any(g.device != target or g.dtype != dtype for g in self._grid)

    def grid(self, spatial_dims, device, dtype):
        """grid generates 2D grid needed for pos encoding
        and caches the grid associated with MRU resolution

        Parameters
        ----------
        spatial_dims : torch.size
           sizes of spatial resolution
        device : literal 'cpu' or 'cuda:*'
            where to load data
        dtype : torch.dtype
            dtype to encode data

        Returns
        -------
        torch.tensor
            output grids to concatenate, each of shape (1, 1, d_1, d_2)
        """
        # Handle multiple train resolutions, and inputs that arrive on a
        # different device or with a different dtype than the cached grid.
        if self._cache_is_stale(spatial_dims, device, dtype):
            grid_x, grid_y = regular_grid_2d(
                spatial_dims, grid_boundaries=self.grid_boundaries
            )
            grid_x = grid_x.to(device).to(dtype).unsqueeze(0).unsqueeze(0)
            grid_y = grid_y.to(device).to(dtype).unsqueeze(0).unsqueeze(0)
            self._grid = grid_x, grid_y
            self._res = spatial_dims

        return self._grid

    def forward(self, data, batched=True):
        """Concatenate x and y coordinate channels onto ``data`` along dim 1.

        Parameters
        ----------
        data : torch.Tensor
            (batch, channels, d_1, d_2), or (channels, d_1, d_2) when
            ``batched`` is False
        batched : bool
            If False and ``data`` is 3-D, a batch dim is added for the
            computation. When ``batched`` is False and the resulting batch size
            is 1, the batch dim is squeezed from the output.
        """
        if not batched:
            if data.ndim == 3:
                data = data.unsqueeze(0)
        batch_size = data.shape[0]
        x, y = self.grid(data.shape[-2:], data.device, data.dtype)
        out = torch.cat(
            (data, x.expand(batch_size, -1, -1, -1), y.expand(batch_size, -1, -1, -1)),
            dim=1,
        )
        # in the unbatched case, the dataloader will stack N
        # examples with no batch dim to create one
        if not batched and batch_size == 1:
            return out.squeeze(0)
        else:
            return out
class GridEmbeddingND(nn.Module):
    """GridEmbeddingND applies a simple positional embedding as a regular ND grid.
    It expects inputs of shape (batch, channels, d_1, ..., d_n)

    Parameters
    ----------
    in_channels : int
        number of channels in input
    dim : int
        dimensions of positional encoding to apply
    grid_boundaries : list, optional
        coordinate boundaries of input grid along each dim,
        by default [[0, 1]] for each of the ``dim`` dimensions
        (i.e. [[0, 1], [0, 1]] for the default dim=2)
    """

    def __init__(self, in_channels: int, dim: int = 2, grid_boundaries=None):
        super().__init__()
        self.in_channels = in_channels
        self.dim = dim
        # Default to the unit interval along every dimension. A fixed 2-entry
        # literal default would make any dim != 2 fail the check below.
        if grid_boundaries is None:
            grid_boundaries = [[0, 1] for _ in range(dim)]
        assert self.dim == len(
            grid_boundaries
        ), f"Error: expected grid_boundaries to be an iterable of length {self.dim}, received {grid_boundaries}"
        self.grid_boundaries = grid_boundaries
        # Most recently generated per-dimension grids and the resolution they
        # were built for.
        self._grid = None
        self._res = None

    @property
    def out_channels(self):
        """Input channels plus one coordinate channel per grid dimension."""
        return self.in_channels + self.dim

    def _cache_is_stale(self, spatial_dims, device, dtype):
        """Return True when the cached grids cannot be reused for these arguments."""
        if self._grid is None or self._res != spatial_dims:
            return True
        # The cache is a plain attribute (not a registered buffer), so
        # module.to(...) does not move it. Compare it against the requested
        # device/dtype, otherwise a stale grid would reach torch.cat.
        target = torch.device(device)
        return any(g.device != target or g.dtype != dtype for g in self._grid)

    def grid(self, spatial_dims: torch.Size, device: str, dtype: torch.dtype):
        """grid generates ND grid needed for pos encoding
        and caches the grid associated with MRU resolution

        Parameters
        ----------
        spatial_dims : torch.Size
           sizes of spatial resolution
        device : literal 'cpu' or 'cuda:*'
            where to load data
        dtype : torch.dtype
            dtype to encode data

        Returns
        -------
        list of torch.tensor
            output grids to concatenate, each of shape (1, 1, d_1, ..., d_n)
        """
        # Handle multiple train resolutions, and inputs that arrive on a
        # different device or with a different dtype than the cached grids.
        if self._cache_is_stale(spatial_dims, device, dtype):
            grids_by_dim = regular_grid_nd(spatial_dims, grid_boundaries=self.grid_boundaries)
            # add batch, channel dims
            grids_by_dim = [x.to(device).to(dtype).unsqueeze(0).unsqueeze(0) for x in grids_by_dim]
            self._grid = grids_by_dim
            self._res = spatial_dims

        return self._grid

    def forward(self, data, batched=True):
        """
        Params
        --------
        data: torch.Tensor
            assumes shape (batch (optional), channels, x_1, x_2, ...x_n)
        batched: bool
            whether data has a batch dim. If False and data has ``dim + 1``
            dims, a batch dim is added; the output keeps that batch dim.
        """
        # add batch dim if it doesn't exist
        if not batched:
            if data.ndim == self.dim + 1:
                data = data.unsqueeze(0)
        batch_size = data.shape[0]
        grids = self.grid(spatial_dims=data.shape[2:], device=data.device, dtype=data.dtype)
        # expand is a view; torch.cat below materializes the copy, so the
        # result matches an explicit repeat without the extra allocation.
        grids = [x.expand(batch_size, *([-1] * (self.dim + 1))) for x in grids]
        out = torch.cat((data, *grids), dim=1)
        return out
class SinusoidalEmbedding(Embedding):
    """
    Sinusoidal positional embedding for enriching coordinate inputs with spectral information.

    This class provides sinusoidal positional embeddings in two styles: Transformer-style
    and NeRF-style. It lifts low-dimensional coordinates into a richer spectral representation
    by encoding them as periodic functions (sines and cosines) at multiple frequencies.
    The embedding enhances a model's ability to capture fine-scale variations and
    high-frequency dynamics by providing a hierarchy of frequency components alongside
    the original coordinates.

    Parameters
    ----------
    in_channels : int
        Number of input channels to embed (dimensionality of input coordinates)
    num_frequencies : int, optional
        Number of frequency levels L in the embedding. Each level contributes a sine
        and cosine pair, resulting in 2L output channels per input channel.
        By default (None), set to the number of input channels.
    embedding_type : {'transformer', 'nerf'}, optional
        Type of embedding to apply, by default 'transformer'

        Transformer-style [1]_: For each input coordinate p and frequency level k (0 ≤ k < L):
            - g(p)_{2k}   = sin(p / max_positions^{2k/L})
            - g(p)_{2k+1} = cos(p / max_positions^{2k/L})

        NeRF-style [2]_: For each input coordinate p and frequency level k (0 ≤ k < L):
            - g(p)_{2k}   = sin(2^k * π * p)
            - g(p)_{2k+1} = cos(2^k * π * p)
    max_positions : int, optional
        Maximum number of positions for transformer-style encoding, by default 10000.
        Only used when embedding_type='transformer'.

    Notes
    -----
    - Input shape: (batch, n_in, in_channels) or (n_in, in_channels)
    - Output shape: (batch, n_in, 2*num_frequencies*in_channels) or
      (n_in, 2*num_frequencies*in_channels)
    - Ensure the highest frequency satisfies the Nyquist criterion:
        - Transformer: f_max < N/2 where N is the number of sampling points
        - NeRF: 2^{L-1} < N/2, i.e., L < 1 + log₂(N/2)

    Examples
    --------
    See `examples/layers/plot_sinusoidal_embeddings.py` for comprehensive visualizations

    References
    ----------
    .. [1] Vaswani, A. et al. "Attention Is All You Need".
       NeurIPS 2017, https://proceedings.neurips.cc/paper_files/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf
    .. [2] Mildenhall, B. et al. "NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis".
       ArXiv 2020, https://arxiv.org/pdf/2003.08934
    """

    def __init__(
        self,
        in_channels: int,
        num_frequencies: int = None,
        embedding_type: str = "transformer",
        max_positions: int = 10000,
    ):
        super().__init__()
        self.in_channels = in_channels
        # Apply the documented default. Keeping None would make both forward()
        # (torch.arange(0, None)) and out_channels raise a TypeError.
        if num_frequencies is None:
            num_frequencies = in_channels
        self.num_frequencies = num_frequencies

        # verify embedding type
        allowed_embeddings = ["nerf", "transformer"]
        assert (
            embedding_type in allowed_embeddings
        ), f"Error: embedding_type expected one of {allowed_embeddings}, received {embedding_type}"
        self.embedding_type = embedding_type
        if self.embedding_type == "transformer":
            assert (
                max_positions is not None
            ), "Error: max_positions must have an int value for transformer embedding."
        self.max_positions = max_positions

    @property
    def out_channels(self):
        """
        required property for linking/composing model layers
        """
        return 2 * self.num_frequencies * self.in_channels

    def forward(self, x):
        """
        Parameters
        -----------
        x: torch.Tensor
            shape (n_in, self.in_channels) or (batch, n_in, self.in_channels)

        Returns
        -------
        torch.Tensor
            shape (n_in, out_channels) or (batch, n_in, out_channels). For each
            input channel the values are laid out as
            [sin(f_0 p), cos(f_0 p), sin(f_1 p), cos(f_1 p), ...].
        """
        assert x.ndim in [2, 3], (
            f"Error: expected inputs of shape (batch, n_in, {self.in_channels}) "
            f"or (n_in, channels), got inputs with ndim={x.ndim}, shape={x.shape}"
        )
        batched = x.ndim == 3
        if not batched:
            x = x.unsqueeze(0)
        batch_size, n_in, _ = x.shape

        # Compute in at least the default float dtype (as before), and in
        # float64 when x is float64, so double-precision inputs are not
        # embedded with single-precision frequencies.
        dtype = torch.promote_types(x.dtype, torch.get_default_dtype())
        x = x.to(dtype)
        levels = torch.arange(0, self.num_frequencies, device=x.device, dtype=dtype)

        if self.embedding_type == "nerf":
            freqs = 2.0 ** levels * torch.pi
        else:  # "transformer" (validated in __init__)
            freqs = (1 / self.max_positions) ** (levels / self.num_frequencies * 2)

        # outer product of position coordinates and wavenumbers
        # shape (b, n_in, channels, num_frequencies)
        angles = torch.einsum("bij, k -> bijk", x, freqs)
        # shape (b, n_in, channels, num_frequencies, 2): last axis is (sin, cos)
        emb = torch.stack((angles.sin(), angles.cos()), dim=-1)
        # flatten the trailing axes so that sin and cos interleave per frequency
        emb = emb.reshape(batch_size, n_in, -1)

        if not batched:
            emb = emb.squeeze(0)
        return emb
class RotaryEmbedding2D(nn.Module):
    def __init__(self, dim, min_freq=1 / 64, scale=1.0):
        """
        Applying rotary positional embedding (https://arxiv.org/abs/2104.09864) to the input feature tensor.
        The crux is that the product of two rotation matrices, R(theta1)^T R(theta2),
        is equal to R(theta2 - theta1), so rotated dot products depend only on the
        relative offset between positions.

        Parameters
        ----------
        dim : int
            Number of rotary angle channels produced by ``forward``; one inverse
            frequency is created per pair of channels (``torch.arange(0, dim, 2)``).
        min_freq : float, optional
            Coordinates are multiplied by ``scale / min_freq`` before the
            frequencies are applied, by default 1/64
        scale : float, optional
            See ``min_freq``, by default 1.0
        """
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.min_freq = min_freq
        self.scale = scale
        # Non-persistent buffer: follows module.to(device) but is not saved in
        # the state_dict, since it is fully determined by ``dim``.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # NOTE(review): fixed at 2 regardless of ``dim``; presumably one per
        # spatial axis. Confirm against callers.
        self.out_channels = 2

    def forward(self, coordinates):
        """Compute rotation angles for each coordinate.

        coordinates is tensor of [batch_size, num_points]. Returns a tensor of
        shape [batch_size, num_points, 2 * len(inv_freq)] (== dim for even dim)
        in which the per-frequency angles are repeated for both halves.
        """
        coordinates = coordinates * (self.scale / self.min_freq)
        freqs = torch.einsum("... i , j -> ... i j", coordinates, self.inv_freq)  # [b, n, d//2]
        return torch.cat((freqs, freqs), dim=-1)  # [b, n, d]

    @staticmethod
    def apply_1d_rotary_pos_emb(t, freqs):
        """Rotate all channels of ``t`` by the angles in ``freqs``."""
        return apply_rotary_pos_emb(t, freqs)

    @staticmethod
    def apply_2d_rotary_pos_emb(t, freqs_x, freqs_y):
        """Split the last dimension of features into two equal halves
        and apply 1d rotary positional embedding to each half."""
        d = t.shape[-1]
        t_x, t_y = t[..., : d // 2], t[..., d // 2 :]

        return torch.cat(
            (apply_rotary_pos_emb(t_x, freqs_x), apply_rotary_pos_emb(t_y, freqs_y)),
            dim=-1,
        )


# Utility functions for GridEmbedding


def regular_grid_2d(spatial_dims, grid_boundaries=None):
    """
    Creates a pair of (height, width) coordinate tensors (grid_x, grid_y), where
    (grid_x[i, j], grid_y[i, j]) is the (x, y) coordinate of index (i, j) on a
    (height, width) grid.

    Along each axis the points are evenly spaced starting at the lower boundary;
    the upper boundary itself is excluded. ``grid_boundaries`` defaults to
    [[0, 1], [0, 1]].
    """
    if grid_boundaries is None:
        grid_boundaries = [[0, 1], [0, 1]]
    height, width = spatial_dims

    xt = torch.linspace(grid_boundaries[0][0], grid_boundaries[0][1], height + 1)[:-1]
    yt = torch.linspace(grid_boundaries[1][0], grid_boundaries[1][1], width + 1)[:-1]

    grid_x, grid_y = torch.meshgrid(xt, yt, indexing="ij")

    # repeat(1, 1) returns an independent copy of each meshgrid output.
    grid_x = grid_x.repeat(1, 1)
    grid_y = grid_y.repeat(1, 1)

    return grid_x, grid_y


def regular_grid_nd(
    resolutions: List[int], grid_boundaries: List[List[int]] = None
):
    """regular_grid_nd generates a tensor of coordinate points that
    describe a bounded regular grid.

    Creates a tuple of ``dim`` tensors, each of shape (res_d1, ..., res_dn), where
    (A[0][c1,...,cn], ..., A[n-1][c1,...,cn]) is the coordinate of index
    (c1, ..., cn) on a (res_d1, ..., res_dn) grid. Along each axis the points are
    evenly spaced from ``start``, excluding ``end``.

    Parameters
    ----------
    resolutions : List[int]
        resolution of the output grid along each dimension
    grid_boundaries : List[List[int]], optional
        List of pairs [start, end] of the boundaries of the
        regular grid. Must correspond 1-to-1 with resolutions.
        By default [[0, 1]] for every dimension (e.g. [[0,1], [0,1]] in 2D).

    Returns
    -------
    grid: tuple(Tensor)
        tuple of tensors describing positional encoding
    """
    if grid_boundaries is None:
        grid_boundaries = [[0, 1] for _ in range(len(resolutions))]
    assert len(resolutions) == len(
        grid_boundaries
    ), "Error: inputs must have same number of dimensions"
    dim = len(resolutions)

    meshgrid_inputs = [
        torch.linspace(start, stop, res + 1)[:-1]
        for res, (start, stop) in zip(resolutions, grid_boundaries)
    ]
    grid = torch.meshgrid(*meshgrid_inputs, indexing="ij")
    # repeat with all-ones returns an independent copy of each meshgrid output.
    grid = tuple([x.repeat([1] * dim) for x in grid])
    return grid


# Utility functions for Rotary embedding
# modified from https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py
def rotate_half(x):
    """
    Split x's last dimension into two equal halves (x1, x2) and return
    cat(-x2, x1), i.e. rotate each channel pair (x1_i, x2_i) by 90 degrees.
    """
    # split the last dimension of x into two equal halves
    x = x.reshape(*x.shape[:-1], 2, -1)
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(t, freqs):
    """
    Apply rotation matrix computed based on freqs to rotate t.

    t: tensor whose last dimension has even size (it is split in half)
    freqs: rotation angles, broadcastable against t, e.g. the output of
        RotaryEmbedding2D.forward (whose halves repeat the same angles, so
        channel pair (t_i, t_{i + d/2}) is rotated by freqs_i)

    Formula: see equation (34) in https://arxiv.org/pdf/2104.09864.pdf
    """
    return (t * freqs.cos()) + (rotate_half(t) * freqs.sin())