# Source code for neuralop.layers.spectral_projection

import torch
import numpy as np
import torch.nn.functional as F


[docs] def spectral_projection_divergence_free(u, domain_size, constraint_modes): """Apply spectral projection layer to make a velocity field divergence-free. Parameters ---------- u : torch.Tensor Input velocity field [batch, 2, height, width] where the last two dimensions represent the 2D velocity components (u_x, u_y) domain_size : float or tuple of float Physical domain size. If float, assumes square domain with same size for height and width. If tuple, should be (height_size, width_size) for non-square domains. constraint_modes : tuple Number of modes to use for constraint resolution (height_modes, width_modes). If larger than input dimensions, they are truncated to the input dimensions. Returns ------- torch.Tensor Divergence-free projected velocity field with [batch, 2, height, width] shape. The output maintains the same shape as the input while satisfying ∇·u = 0. Notes ----- **Mathematical formulation:** Helmholtz-Hodge projection in spectral domain: û_proj = û - (k·û)/|k|² * k. Enforces ∇·u = 0 by removing the irrotational component. **Periodicity:** Velocity fields are assumed periodic. For non-periodic fields, extend via Fourier continuation (neuralop.layers.fourier_continuation), project, then truncate. References ---------- The method is based on the spectral projection approach described in: [1] Towards Enforcing Hard Physics Constraints in Operator Learning Frameworks V. Duruisseaux, M. Liu-Schiaffini, J. Berner, and A. Anandkumar ICML 2024 AI for Science Workshop https://openreview.net/pdf?id=Zvxm14Rd1F [2] Enforcing physical constraints in CNNs through differentiable PDE layer. C. M. Jiang, K. Kashinath, P. Prabhat, and P. Marcus ICLR 2020 Workshop on Integration of Deep Neural Models and Differential Equations, 2020. 
https://openreview.net/pdf?id=q2noHUqMkK """ device = u.device dtype = u.dtype batch_size, channels, height, width = u.shape # Get domain size if isinstance(domain_size, (int, float)): # Square domain: use same size for both dimensions domain_height = domain_width = float(domain_size) elif isinstance(domain_size, (tuple, list)) and len(domain_size) == 2: # Non-square domain: separate sizes for height and width domain_height, domain_width = float(domain_size[0]), float(domain_size[1]) # Ensure constraint modes do not exceed input dimensions constraint_modes = (min(constraint_modes[0], height), min(constraint_modes[1], width)) # 2D FFT over height and width u_ft = torch.fft.fftn(u, dim=(2, 3)) # Extract lower modes for both dimensions where the constraint is applied if (height != constraint_modes[0]) or (width != constraint_modes[1]): u_ft = torch.fft.fftshift(u_ft, dim=(2, 3)) u_ft = u_ft[:, :, (height - constraint_modes[0])//2 : (height - constraint_modes[0])//2 + constraint_modes[0], (width - constraint_modes[1])//2 : (width - constraint_modes[1])//2 + constraint_modes[1]] u_ft = torch.fft.ifftshift(u_ft, dim=(2, 3)) # Set up wavenumber grids for spectral operations ky = 2*np.pi * torch.fft.fftfreq(constraint_modes[0], d=domain_height/constraint_modes[0]).to(dtype).to(device) kx = 2*np.pi * torch.fft.fftfreq(constraint_modes[1], d=domain_width/constraint_modes[1]).to(dtype).to(device) KX, KY = torch.meshgrid(kx, ky, indexing='ij' if constraint_modes[0] == constraint_modes[1] else 'xy') # Apply Helmholtz-Hodge projection: û_proj = û - (k·û)/|k|² * k k_dot_u = KX * u_ft[:, 0, :, :] + KY * u_ft[:, 1, :, :] k_squared = KX**2 + KY**2 + 1e-8 # Add small epsilon to avoid division by zero projected_u_ft = u_ft - (k_dot_u / k_squared).unsqueeze(1) * torch.stack([KX, KY], dim=0).unsqueeze(0) # Handle zero mode explicitly projected_u_ft[:, :, 0, 0] = 0.0 # Pad zeros back to full resolution if needed if height != constraint_modes[0] or width != constraint_modes[1]: 
projected_u_ft = torch.fft.fftshift(projected_u_ft) pad_h = height - constraint_modes[0] pad_w = width - constraint_modes[1] projected_u_ft = F.pad(projected_u_ft, [pad_w//2, pad_w - pad_w//2, pad_h//2, pad_h - pad_h//2]) projected_u_ft = torch.fft.ifftshift(projected_u_ft) # Transform back to physical space return torch.fft.ifftn(projected_u_ft, dim=(2, 3)).real