from typing import List, Optional, Tuple, Union
from ..utils import validate_scaling_factor
import torch
from torch import nn
import tensorly as tl
from tensorly.plugins import use_opt_einsum
from tltorch.factorized_tensors.core import FactorizedTensor
from .einsum_utils import einsum_complexhalf
from .base_spectral_conv import BaseSpectralConv
from .resample import resample
tl.set_backend("pytorch")
use_opt_einsum("optimal")
einsum_symbols = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
def _contract_dense(x, weight, separable=False):
    """Contract ``x`` with a dense spectral weight via a generated einsum.

    Symbols are drawn from ``einsum_symbols``: batch, in_channels, then one
    per remaining dimension. For non-separable weights a fresh out_channels
    symbol is inserted right after the in_channels one.
    """
    n_dims = tl.ndim(x)
    # batch-size, in_channels, x, y...
    input_syms = list(einsum_symbols[:n_dims])
    # weight carries no batch dimension
    weight_syms = input_syms[1:]
    if separable:
        # separable: single channel dim, output symbols == input symbols
        output_syms = [input_syms[0], *weight_syms]
    else:
        # insert a new symbol for the output channels
        weight_syms = [weight_syms[0], einsum_symbols[n_dims], *weight_syms[1:]]
        output_syms = [input_syms[0], *weight_syms[1:]]
    equation = (
        "".join(input_syms) + "," + "".join(weight_syms) + "->" + "".join(output_syms)
    )
    if not torch.is_tensor(weight):
        weight = weight.to_tensor()
    # half-precision complex requires the specialized einsum implementation
    if x.dtype == torch.complex32:
        return einsum_complexhalf(equation, x, weight)
    return tl.einsum(equation, x, weight)
def _contract_dense_separable(x, weight, separable):
if not torch.is_tensor(weight):
weight = weight.to_tensor()
return x * weight
def _contract_cp(x, cp_weight, separable=False):
    """Contract ``x`` directly with the factors of a CP-decomposed weight.

    ``cp_weight`` is expected to expose ``.weights`` (the rank weights) and
    ``.factors`` (one matrix per mode), as tltorch CP tensors do.
    """
    n_dims = tl.ndim(x)
    in_syms = str(einsum_symbols[:n_dims])
    rank_sym = einsum_symbols[n_dims]
    out_channel_sym = einsum_symbols[n_dims + 1]
    result_syms = list(in_syms)
    if separable:
        # only an in-channels factor
        factor_syms = [einsum_symbols[1] + rank_sym]
    else:
        # distinct out-channels symbol in the result; in + out channel factors
        result_syms[1] = out_channel_sym
        factor_syms = [einsum_symbols[1] + rank_sym, out_channel_sym + rank_sym]
    # one factor per remaining (spatial/frequency) mode
    factor_syms.extend(sym + rank_sym for sym in in_syms[2:])
    eq = f'{in_syms},{rank_sym},{",".join(factor_syms)}->{"".join(result_syms)}'
    if x.dtype == torch.complex32:
        return einsum_complexhalf(eq, x, cp_weight.weights, *cp_weight.factors)
    return tl.einsum(eq, x, cp_weight.weights, *cp_weight.factors)
def _contract_tucker(x, tucker_weight, separable=False):
    """Contract ``x`` directly with the core and factors of a Tucker weight.

    ``tucker_weight`` is expected to expose ``.core`` and ``.factors``,
    as tltorch Tucker tensors do.
    """
    n_dims = tl.ndim(x)
    in_syms = str(einsum_symbols[:n_dims])
    out_channel_sym = einsum_symbols[n_dims]
    result_syms = list(in_syms)
    if separable:
        # one core index per non-batch mode; one factor per such mode
        core_syms = einsum_symbols[n_dims + 1 : 2 * n_dims]
        factor_syms = [data_sym + core_sym
                       for data_sym, core_sym in zip(in_syms[1:], core_syms)]
    else:
        # one extra core index for the out-channels dimension
        core_syms = einsum_symbols[n_dims + 1 : 2 * n_dims + 1]
        result_syms[1] = out_channel_sym
        # channel factors: in-channels first, then out-channels
        factor_syms = [
            einsum_symbols[1] + core_syms[0],
            out_channel_sym + core_syms[1],
        ]
        # spatial/frequency factors
        factor_syms += [data_sym + core_sym
                        for data_sym, core_sym in zip(in_syms[2:], core_syms[2:])]
    eq = f'{in_syms},{core_syms},{",".join(factor_syms)}->{"".join(result_syms)}'
    if x.dtype == torch.complex32:
        return einsum_complexhalf(eq, x, tucker_weight.core, *tucker_weight.factors)
    return tl.einsum(eq, x, tucker_weight.core, *tucker_weight.factors)
def _contract_tt(x, tt_weight, separable=False):
    """Contract ``x`` directly with the cores of a tensor-train weight.

    ``tt_weight`` is expected to expose ``.factors`` — one 3rd-order core
    of shape (rank_i, mode_i, rank_{i+1}) per mode — as tltorch TT tensors do.
    """
    n_dims = tl.ndim(x)
    in_syms = list(einsum_symbols[:n_dims])
    weight_syms = in_syms[1:]  # weight has no batch dimension
    if not separable:
        # add a dedicated out-channels symbol after in-channels
        weight_syms.insert(1, einsum_symbols[n_dims])
        result_syms = [in_syms[0], *weight_syms[1:]]
    else:
        result_syms = list(in_syms)
    rank_syms = list(einsum_symbols[n_dims + 1 :])
    # each TT core binds (left rank, mode, right rank)
    core_syms = [
        rank_syms[i] + mode_sym + rank_syms[i + 1]
        for i, mode_sym in enumerate(weight_syms)
    ]
    eq = (
        "".join(in_syms)
        + ","
        + ",".join(core_syms)
        + "->"
        + "".join(result_syms)
    )
    if x.dtype == torch.complex32:
        return einsum_complexhalf(eq, x, *tt_weight.factors)
    return tl.einsum(eq, x, *tt_weight.factors)
def get_contract_fun(weight, implementation="reconstructed", separable=False):
    """Generic ND implementation of Fourier Spectral Conv contraction

    Parameters
    ----------
    weight : torch.Tensor or tensorly-torch's FactorizedTensor
    implementation : {'reconstructed', 'factorized'}, default is 'reconstructed'
        whether to reconstruct the weight and do a forward pass (reconstructed)
        or contract directly the factors of the factorized weight with the input (factorized)
    separable : bool
        if True, performs an elementwise (per-channel) contraction with
        individual tensor factors.
        if False, contracts in_channels against a separate out_channels
        dimension of the weight.

    Returns
    -------
    function : (x, weight) -> x * weight in Fourier space

    Raises
    ------
    ValueError
        If ``implementation`` is not one of {'reconstructed', 'factorized'},
        or the factorized weight's type is not dense/tucker/tt/cp.
    """
    if implementation == "reconstructed":
        if separable:
            return _contract_dense_separable
        else:
            return _contract_dense
    elif implementation == "factorized":
        # plain tensors always use the dense contraction
        if torch.is_tensor(weight):
            return _contract_dense
        elif isinstance(weight, FactorizedTensor):
            # dispatch on the factorization name suffix (e.g. 'ComplexTucker')
            if weight.name.lower().endswith("dense"):
                return _contract_dense
            elif weight.name.lower().endswith("tucker"):
                return _contract_tucker
            elif weight.name.lower().endswith("tt"):
                return _contract_tt
            elif weight.name.lower().endswith("cp"):
                return _contract_cp
            else:
                raise ValueError(f"Got unexpected factorized weight type {weight.name}")
        else:
            raise ValueError(
                f"Got unexpected weight type of class {weight.__class__.__name__}"
            )
    else:
        raise ValueError(
            f'Got implementation={implementation}, expected "reconstructed" or "factorized"'
        )
# Scalar type alias used in resolution_scaling_factor annotations below.
Number = Union[int, float]
class SpectralConv(BaseSpectralConv):
"""SpectralConv implements the Spectral Convolution component of a Fourier layer
described.
It is implemented as described in [1]_ and [2]_.
Parameters
----------
in_channels : int
Number of input channels
out_channels : int
Number of output channels
n_modes : int or int tuple
Number of modes to use for contraction in Fourier domain during training.
.. warning::
We take care of the redundancy in the Fourier modes, therefore, for an input
of size I_1, ..., I_N, please provide modes M_K such that M_K <= I_K for each dimension K.
We will automatically keep the right amount of modes: specifically, for the
last mode only, if you specify M_N modes we will use M_N // 2 + 1 modes
as the real FFT is redundant along that last dimension. See the theory guide for mode truncation details.
.. note::
Provided modes should be even integers. odd numbers will be rounded to the closest even number.
This can be updated dynamically during training.
complex_data : bool, optional
Whether data takes on complex values in the spatial domain, by default False.
If True, uses different logic for FFT contraction and uses full FFT instead of real-valued.
max_n_modes : int tuple or None, optional
If not None, maximum number of modes to keep in Fourier Layer along each dim
(n_modes cannot be increased beyond that). If None, all n_modes are used.
By default None.
bias : bool, optional
Whether to add a learnable bias to the output, by default True.
separable : bool, optional
Whether to use separable implementation of contraction.
If True, contracts factors of factorized tensor weight individually.
By default False.
resolution_scaling_factor : float, list of float, or None, optional
Scaling factor(s) for resolution scaling. If provided, the output resolution
will be scaled by this factor along each spatial dimension.
By default None.
fno_block_precision : str, optional
Precision mode for FNO block operations. Options: 'full', 'half', 'mixed'.
By default 'full'.
rank : float, optional
Rank of the tensor factorization of the Fourier weights, by default 1.0.
Ignored if ``factorization is None``.
factorization : str or None, optional
Tensor factorization type. Options: {'tucker', 'cp', 'tt'}.
If None, a single dense weight is learned for the FNO.
Otherwise, that weight, used for the contraction in the Fourier domain
is learned in factorized form. In that case, `factorization` is the
tensor factorization of the parameters weight used.
By default None.
implementation : {'factorized', 'reconstructed'}, optional
If factorization is not None, forward mode to use:
* `reconstructed` : the full weight tensor is reconstructed from the
factorization and used for the forward pass
* `factorized` : the input is directly contracted with the factors of
the decomposition
Ignored if ``factorization is None``.
By default 'reconstructed'.
enforce_hermitian_symmetry : bool, optional
Whether to enforce Hermitian symmetry conditions when performing inverse FFT
for real-valued data. When True, explicitly enforces that the 0th frequency
and Nyquist frequency are real-valued before calling irfft.
When False, relies on cuFFT's irfftn to handle symmetry automatically,
which may fail on certain GPUs or input sizes, causing line artifacts.
Setting to True splits the inverse FFT into ifftn along (n-1) dimensions
followed by irfft on the last dimension, with a small computational overhead.
By default True.
fixed_rank_modes : bool, optional
Modes to not factorize, by default False.
Ignored if ``factorization is None``.
decomposition_kwargs : dict or None, optional
Optional additional parameters to pass to the tensor decomposition.
Ignored if ``factorization is None``.
By default None.
init_std : float or 'auto', optional
Standard deviation to use for weight initialization, by default 'auto'.
If 'auto', uses (2 / (in_channels + out_channels)) ** 0.5.
fft_norm : str, optional
FFT normalization parameter, by default 'forward'.
device : torch.device or None, optional
Device to place the layer on, by default None.
References
----------
.. [1] Li, Z. et al. "Fourier Neural Operator for Parametric Partial Differential
Equations" (2021). ICLR 2021, https://arxiv.org/pdf/2010.08895.
.. [2] Kossaifi, J., Kovachki, N., Azizzadenesheli, K., Anandkumar, A. "Multi-Grid
Tensorized Fourier Neural Operator for High-Resolution PDEs" (2024).
TMLR 2024, https://openreview.net/pdf?id=AWiDlO63bH.
"""
def __init__(
    self,
    in_channels,
    out_channels,
    n_modes,
    complex_data=False,
    max_n_modes=None,
    bias=True,
    separable=False,
    resolution_scaling_factor: Optional[Union[Number, List[Number]]] = None,
    fno_block_precision="full",
    rank=1.0,
    factorization=None,
    implementation="reconstructed",
    enforce_hermitian_symmetry=True,
    fixed_rank_modes=False,
    decomposition_kwargs: Optional[dict] = None,
    init_std="auto",
    fft_norm="forward",
    device=None,
):
    """Build the spectral convolution. See the class docstring for parameters."""
    super().__init__(device=device)
    self.in_channels = in_channels
    self.out_channels = out_channels
    # must be assigned before self.n_modes: the n_modes property setter
    # reads self.complex_data to decide whether to halve the last mode
    self.complex_data = complex_data
    # n_modes is the total number of modes kept along each dimension
    self.n_modes = n_modes
    self.order = len(self.n_modes)
    if max_n_modes is None:
        max_n_modes = self.n_modes
    elif isinstance(max_n_modes, int):
        # NOTE(review): an int becomes a length-1 list even when order > 1 —
        # confirm callers always pass a full tuple for multi-dim convs
        max_n_modes = [max_n_modes]
    self.max_n_modes = max_n_modes
    self.fno_block_precision = fno_block_precision
    self.rank = rank
    self.factorization = factorization
    self.implementation = implementation
    self.enforce_hermitian_symmetry = enforce_hermitian_symmetry
    self.resolution_scaling_factor: Union[
        None, List[List[float]]
    ] = validate_scaling_factor(resolution_scaling_factor, self.order)
    if init_std == "auto":
        # Xavier-style std derived from channel counts
        init_std = (2 / (in_channels + out_channels)) ** 0.5
    if isinstance(fixed_rank_modes, bool):
        if fixed_rank_modes:
            # If bool, keep the number of layers fixed
            fixed_rank_modes = [0]
        else:
            fixed_rank_modes = None
    self.fft_norm = fft_norm
    if factorization is None:
        factorization = "Dense"  # No factorization
    if separable:
        if in_channels != out_channels:
            raise ValueError(
                "To use separable Fourier Conv, in_channels must be equal "
                f"to out_channels, but got in_channels={in_channels} and "
                f"out_channels={out_channels}",
            )
        # separable weight has a single shared channel dimension
        weight_shape = (in_channels, *max_n_modes)
    else:
        weight_shape = (in_channels, out_channels, *max_n_modes)
    self.separable = separable
    tensor_kwargs = decomposition_kwargs if decomposition_kwargs is not None else {}
    # Create/init spectral weight tensor (complex-valued, possibly factorized)
    self.weight = FactorizedTensor.new(
        weight_shape,
        rank=self.rank,
        factorization=factorization,
        fixed_rank_modes=fixed_rank_modes,
        **tensor_kwargs,
        dtype=torch.cfloat,
    )
    self.weight.normal_(0, init_std)
    # pick the einsum contraction matching the weight's storage format
    self._contract = get_contract_fun(
        self.weight, implementation=implementation, separable=separable
    )
    if bias:
        # one bias per output channel, broadcast over all spatial dims
        self.bias = nn.Parameter(
            init_std * torch.randn(*(tuple([self.out_channels]) + (1,) * self.order))
        )
    else:
        self.bias = None
@property
def n_modes(self):
    """Kept Fourier modes per dim; for real data the setter already reduced the last entry to n // 2 + 1."""
    return self._n_modes
@n_modes.setter
def n_modes(self, n_modes):
    """Normalize to a list and, for real data, keep only the non-redundant last-dim modes."""
    # an int should only happen for a 1D FNO
    modes = [n_modes] if isinstance(n_modes, int) else list(n_modes)
    # The real FFT is conjugate-symmetric, so for real spatial data the last
    # dimension has a redundancy: only n // 2 + 1 coefficients are kept.
    # As a design choice we apply that correction here so users never deal
    # with the +1. With a full (complex) FFT no information can be cut from
    # the last mode this way.
    if not self.complex_data:
        modes[-1] = modes[-1] // 2 + 1
    self._n_modes = modes
def forward(self, x: torch.Tensor, output_shape: Optional[Tuple[int]] = None):
    """Generic forward pass for the Factorized Spectral Conv

    Parameters
    ----------
    x : torch.Tensor
        input activation of size (batch_size, channels, d1, ..., dN)
    output_shape : tuple of int, optional
        if provided, overrides the spatial output size (takes precedence
        over resolution_scaling_factor), by default None

    Returns
    -------
    tensorized_spectral_conv(x)
    """
    batchsize, channels, *mode_sizes = x.shape
    fft_size = list(mode_sizes)
    if not self.complex_data:
        fft_size[-1] = fft_size[-1] // 2 + 1  # Redundant last coefficient in real spatial data
    # FFT over the trailing `order` (spatial) dimensions only
    fft_dims = list(range(-self.order, 0))
    if self.fno_block_precision == "half":
        x = x.half()
    if self.complex_data:
        x = torch.fft.fftn(x, norm=self.fft_norm, dim=fft_dims)
        dims_to_fft_shift = fft_dims
    else:
        x = torch.fft.rfftn(x, norm=self.fft_norm, dim=fft_dims)
        # When x is real in spatial domain, the last half of the last dim is redundant.
        # See :ref:`fft_shift_explanation` for discussion of the FFT shift.
        dims_to_fft_shift = fft_dims[:-1]
    # center the zero frequency so modes can be sliced symmetrically below
    if self.order > 1:
        x = torch.fft.fftshift(x, dim=dims_to_fft_shift)
    if self.fno_block_precision == "mixed":
        # if 'mixed', the above fft runs in full precision, but the
        # following operations run at half precision
        x = x.chalf()
    if self.fno_block_precision in ["half", "mixed"]:
        out_dtype = torch.chalf
    else:
        out_dtype = torch.cfloat
    # full-size frequency buffer; only the kept modes are written below
    out_fft = torch.zeros(
        [batchsize, self.out_channels, *fft_size], device=x.device, dtype=out_dtype
    )
    # if current modes are less than max, start indexing modes closer to the center of the weight tensor
    starts = [
        (max_modes - min(size, n_mode))
        for (size, n_mode, max_modes) in zip(fft_size, self.n_modes, self.max_n_modes)
    ]
    # if contraction is separable, weights have shape (channels, modes_x, ...)
    # otherwise they have shape (in_channels, out_channels, modes_x, ...)
    if self.separable:
        slices_w = [slice(None)]  # channels
    else:
        slices_w = [slice(None), slice(None)]  # in_channels, out_channels
    if self.complex_data:
        # trim `start` entries per dim, split between both ends (centered modes)
        slices_w += [
            slice(start // 2, -start // 2) if start else slice(start, None)
            for start in starts
        ]
    else:
        # The last mode already has redundant half removed in real FFT,
        # so it is trimmed from the end only
        slices_w += [
            slice(start // 2, -start // 2) if start else slice(start, None)
            for start in starts[:-1]
        ]
        slices_w += [slice(None, -starts[-1]) if starts[-1] else slice(None)]
    slices_w = tuple(slices_w)
    weight = self.weight[slices_w]
    ### Pick the first n_modes modes of FFT signal along each dim
    # if separable conv, weight tensor only has one channel dim
    if self.separable:
        weight_start_idx = 1
    # otherwise drop first two dims (in_channels, out_channels)
    else:
        weight_start_idx = 2
    slices_x = [slice(None), slice(None)]  # Batch_size, channels
    for all_modes, kept_modes in zip(fft_size, list(weight.shape[weight_start_idx:])):
        # After fft-shift, the 0th frequency is located at n // 2 in each direction
        # We select n_modes modes around the 0th frequency (kept at index n//2) by grabbing indices
        # n//2 - n_modes//2 to n//2 + n_modes//2 if n_modes is even
        # n//2 - n_modes//2 to n//2 + n_modes//2 + 1 if n_modes is odd
        center = all_modes // 2
        negative_freqs = kept_modes // 2
        positive_freqs = kept_modes // 2 + kept_modes % 2
        # this slice represents the desired indices along each dim
        slices_x += [slice(center - negative_freqs, center + positive_freqs)]
    # Last dim is not fft-shifted for real data, so its kept modes start at 0.
    # NOTE(review): for complex_data the last dim IS fft-shifted, yet this
    # start-anchored slice is applied unconditionally — confirm intended.
    if weight.shape[-1] < fft_size[-1]:
        slices_x[-1] = slice(None, weight.shape[-1])
    else:
        slices_x[-1] = slice(None)
    slices_x = tuple(slices_x)
    out_fft[slices_x] = self._contract(
        x[slices_x], weight, separable=self.separable
    )
    # resolution scaling / explicit output shape are realized through the
    # size argument of the inverse FFT below
    if self.resolution_scaling_factor is not None and output_shape is None:
        mode_sizes = tuple([round(s * r) for (s, r) in zip(mode_sizes, self.resolution_scaling_factor)])
    if output_shape is not None:
        mode_sizes = output_shape
    if self.order > 1:
        out_fft = torch.fft.ifftshift(out_fft, dim=fft_dims[:-1])
    # Inverse FFT
    if self.complex_data:
        # For complex data, we can use ifftn.
        x = torch.fft.ifftn(out_fft, s=mode_sizes, dim=fft_dims, norm=self.fft_norm)
    else:
        # For real data, we need to enforce Hermitian symmetry conditions for irfft.
        # On certain GPUs and for certain input sizes, this is not handled within irfftn in cuFFT,
        # and as a result causes line artifacts.
        # To fix this, we split the ifftn into a ifftn in (n-1) dimensions and a irfft in the last dimension,
        # although it incurs a small additional computational cost.
        if self.enforce_hermitian_symmetry:
            out_fft = torch.fft.ifftn(out_fft, s=mode_sizes[:-1], dim=fft_dims[:-1], norm=self.fft_norm)
            # Enforce Hermitian symmetry conditions for irfft
            # 0th frequency must be real
            out_fft[..., 0].imag.zero_()
            # Nyquist frequency must be real if the spatial size is even
            if mode_sizes[-1] % 2 == 0:
                out_fft[..., -1].imag.zero_()
            # Now that the Hermitian symmetry conditions are enforced, we can use irfft on the last dimension.
            x = torch.fft.irfft(out_fft, n=mode_sizes[-1], dim=fft_dims[-1], norm=self.fft_norm)
        else:
            # If Hermitian symmetry is not a concern, we can use irfftn on all dimensions.
            x = torch.fft.irfftn(out_fft, s=mode_sizes, dim=fft_dims, norm=self.fft_norm)
    if self.bias is not None:
        x = x + self.bias
    return x