import torch
import numpy as np
import torch.nn.functional as F
def _centered_start(n, m):
    """Start index of the length-``m`` window centred on the zero frequency of an ``fftshift``-ed length-``n`` axis.

    After ``fftshift`` the zero frequency of a length-``n`` axis sits at index
    ``n // 2``; ``ifftshift`` of a length-``m`` axis expects it at ``m // 2``.
    Starting the window at ``n // 2 - m // 2`` satisfies both for every parity
    combination of ``n`` and ``m`` (``(n - m) // 2`` does not when ``n`` is even
    and ``m`` is odd).
    """
    return n // 2 - m // 2


def spectral_projection_divergence_free(u, domain_size, constraint_modes):
    """Apply spectral projection layer to make a velocity field divergence-free.

    Parameters
    ----------
    u : torch.Tensor
        Input velocity field of shape [batch, 2, height, width]. Dimension 1
        holds the two velocity components (u_x, u_y): u_x is the component
        along the width axis (dim 3) and u_y the component along the height
        axis (dim 2).
    domain_size : float or tuple of float
        Physical domain size. If a scalar, assumes a square domain with the same
        size for height and width. If a tuple/list, it must be
        (height_size, width_size) for non-square domains. Sizes must be positive.
    constraint_modes : tuple
        Number of Fourier modes to use for constraint resolution
        (height_modes, width_modes). Values larger than the input dimensions are
        truncated to the input dimensions.

    Returns
    -------
    torch.Tensor
        Divergence-free projected velocity field of shape [batch, 2, height, width].
        The output has the same shape as the input while satisfying ∇·u = 0.
        When ``constraint_modes`` is smaller than the grid, Fourier modes outside
        the retained centred band are set to zero (the output is low-pass
        filtered). The mean (k = 0) mode is always set to zero.

    Raises
    ------
    ValueError
        If ``u`` is not of shape [batch, 2, height, width], if ``domain_size``
        is a sequence whose length is not 2, if a domain size is not positive,
        or if a constraint mode count is smaller than 1.

    Notes
    -----
    **Mathematical formulation:** Helmholtz-Hodge projection in spectral domain:
    û_proj = û - (k·û)/|k|² * k, with k = (k_x, k_y), k_x the wavenumber along
    the width axis and k_y along the height axis. Enforces ∇·u = 0 by removing
    the irrotational component.

    **Periodicity:** Velocity fields are assumed periodic. For non-periodic fields, extend
    via Fourier continuation (neuralop.layers.fourier_continuation), project, then truncate.

    References
    ----------
    The method is based on the spectral projection approach described in:

    [1] Towards Enforcing Hard Physics Constraints in Operator Learning Frameworks
        V. Duruisseaux, M. Liu-Schiaffini, J. Berner, and A. Anandkumar
        ICML 2024 AI for Science Workshop
        https://openreview.net/pdf?id=Zvxm14Rd1F

    [2] Enforcing physical constraints in CNNs through differentiable PDE layer.
        C. M. Jiang, K. Kashinath, P. Prabhat, and P. Marcus
        ICLR 2020 Workshop on Integration of Deep Neural Models and Differential Equations, 2020.
        https://openreview.net/pdf?id=q2noHUqMkK
    """
    if u.ndim != 4 or u.shape[1] != 2:
        raise ValueError(
            f"u must have shape [batch, 2, height, width], got {tuple(u.shape)}"
        )
    height, width = u.shape[2], u.shape[3]

    # Physical domain size (height_size, width_size).
    if isinstance(domain_size, (tuple, list)):
        if len(domain_size) != 2:
            raise ValueError(
                f"domain_size must be a scalar or a (height, width) pair, got {domain_size!r}"
            )
        domain_height, domain_width = float(domain_size[0]), float(domain_size[1])
    else:
        # Square domain: same size for both dimensions.
        domain_height = domain_width = float(domain_size)
    if domain_height <= 0 or domain_width <= 0:
        raise ValueError(f"domain sizes must be positive, got {domain_size!r}")

    # Constraint modes cannot exceed the grid resolution.
    modes_h = min(int(constraint_modes[0]), height)
    modes_w = min(int(constraint_modes[1]), width)
    if modes_h < 1 or modes_w < 1:
        raise ValueError(f"constraint_modes must be >= 1, got {constraint_modes!r}")
    truncate = (modes_h, modes_w) != (height, width)

    # Start of the centred low-frequency window in the fftshift-ed spectrum.
    h0 = _centered_start(height, modes_h)
    w0 = _centered_start(width, modes_w)

    # 2D FFT over the spatial dims only.
    u_ft = torch.fft.fftn(u, dim=(2, 3))

    if truncate:
        # Keep only the lowest modes, centred on k = 0, then restore FFT order
        # so that k = 0 sits at index 0 of the cropped spectrum.
        u_ft = torch.fft.fftshift(u_ft, dim=(2, 3))
        u_ft = u_ft[:, :, h0:h0 + modes_h, w0:w0 + modes_w]
        u_ft = torch.fft.ifftshift(u_ft, dim=(2, 3))

    # Angular wavenumbers; computed directly in u's dtype to avoid a lossy
    # float32 -> float64 round trip.
    ky = 2 * np.pi * torch.fft.fftfreq(
        modes_h, d=domain_height / modes_h, dtype=u.dtype, device=u.device
    )
    kx = 2 * np.pi * torch.fft.fftfreq(
        modes_w, d=domain_width / modes_w, dtype=u.dtype, device=u.device
    )
    # Both grids have shape (modes_h, modes_w): KY varies along the height axis
    # (dim 2), KX along the width axis (dim 3), matching u_ft's layout for
    # square and non-square grids alike.
    KY, KX = torch.meshgrid(ky, kx, indexing="ij")

    # Helmholtz-Hodge projection: û_proj = û - (k·û)/|k|² * k
    k_dot_u = KX * u_ft[:, 0] + KY * u_ft[:, 1]
    k_squared = KX**2 + KY**2
    # k = 0 is the only zero of |k|²; k·û is 0 there, so any non-zero
    # denominator is exact (unlike adding an epsilon to every mode).
    k_squared[0, 0] = 1.0
    k_vec = torch.stack([KX, KY], dim=0)
    projected_u_ft = u_ft - (k_dot_u / k_squared).unsqueeze(1) * k_vec
    # The mean flow is removed explicitly.
    projected_u_ft[:, :, 0, 0] = 0.0

    if truncate:
        # Re-embed the retained modes at their original frequencies; all other
        # modes are zero. Uses the same centred window as the crop above.
        full = torch.zeros(
            (u.shape[0], 2, height, width),
            dtype=projected_u_ft.dtype,
            device=projected_u_ft.device,
        )
        full[:, :, h0:h0 + modes_h, w0:w0 + modes_w] = torch.fft.fftshift(
            projected_u_ft, dim=(2, 3)
        )
        projected_u_ft = torch.fft.ifftshift(full, dim=(2, 3))

    # Back to physical space; the input field is real-valued.
    return torch.fft.ifftn(projected_u_ft, dim=(2, 3)).real