import torch
import warnings
from ..layers.fourier_continuation import FCLegendre, FCGram
"""
differentiation.py implements utilities for computing derivatives via finite differences
and Fourier/spectral differentiation methods.
"""
class FiniteDiff:
    """A unified class for computing finite differences in 1D, 2D, or 3D.

    This class provides methods for computing derivatives using finite differences
    with support for both periodic and non-periodic boundary conditions.
    It implements the following finite difference schemes:

    - Interior points: second-order central differences.
    - Periodic boundaries: torch.roll wraps the stencil around the domain, so the
      central differences are also used at the first and last grid points.
    - Non-periodic boundaries: one-sided four-point differences at the first and
      last grid points (third-order accurate for first derivatives, second-order
      accurate for second derivatives). A non-periodic direction therefore needs
      at least 4 grid points.

    Tensor layout: the spatial axes are the trailing ``dim`` axes of the input,
    ordered as x, (y), (z); x is axis ``-dim`` and the last spatial direction is
    axis ``-1``. Any leading axes (e.g. batch or channel axes) are carried through
    unchanged. Vector fields store their ``dim`` components along axis ``-dim - 1``.

    Parameters
    ----------
    dim : int
        Dimension of the input field. Must be 1, 2, or 3.
    h : float or tuple, optional
        Grid spacing(s) for finite difference calculations, by default 1.0.
        - For 1D: single float or tuple with one element
        - For 2D: tuple (h_x, h_y) or single float for uniform spacing
        - For 3D: tuple (h_x, h_y, h_z) or single float for uniform spacing
    periodic_in_x : bool, optional
        Whether to use periodic boundary conditions in x-direction, by default True.
        When True, uses torch.roll for efficient periodic wrapping.
        When False, uses one-sided differences at the boundaries.
    periodic_in_y : bool, optional
        Whether to use periodic boundary conditions in y-direction, by default True.
        When True, uses torch.roll for efficient periodic wrapping.
        When False, uses one-sided differences at the boundaries.
        Only used for 2D and 3D fields.
    periodic_in_z : bool, optional
        Whether to use periodic boundary conditions in z-direction, by default True.
        When True, uses torch.roll for efficient periodic wrapping.
        When False, uses one-sided differences at the boundaries.
        Only used for 3D fields.

    Available Methods
    -----------------
    Derivative Methods:
    - dx(u, order=1): Compute derivative with respect to x
    - dy(u, order=1): Compute derivative with respect to y (2D/3D only)
    - dz(u, order=1): Compute derivative with respect to z (3D only)
    Vector Calculus Operators:
    - laplacian(u): Compute the Laplacian ∇²f
    - gradient(u): Compute the gradient ∇f (returns vector field)
    - divergence(u): Compute the divergence ∇·u (for vector fields)
    - curl(u): Compute the curl ∇×u (for vector fields, 2D/3D only)

    Mathematical Formulas
    ---------------------
    For first-order derivatives:
    - Interior: (f_{i+1} - f_{i-1})/(2h)  [2nd order central]
    - Left boundary: (-11f_0 + 18f_1 - 9f_2 + 2f_3)/(6h)  [3rd order forward]
    - Right boundary: (-2f_{n-4} + 9f_{n-3} - 18f_{n-2} + 11f_{n-1})/(6h)  [3rd order backward]
    For second-order derivatives:
    - Interior: (f_{i+1} - 2f_i + f_{i-1})/(h²)  [2nd order central]
    - Left boundary: (2f_0 - 5f_1 + 4f_2 - f_3)/(h²)  [2nd order forward]
    - Right boundary: (-f_{n-4} + 4f_{n-3} - 5f_{n-2} + 2f_{n-1})/(h²)  [2nd order backward]

    Examples
    --------
    >>> # 1D finite differences on a periodic grid (the endpoint 2*pi is excluded)
    >>> x = torch.linspace(0, 2*torch.pi, 101)[:-1]
    >>> u = torch.sin(x)
    >>> fd1d = FiniteDiff(dim=1, h=(x[1] - x[0]).item(), periodic_in_x=True)
    >>> du_dx = fd1d.dx(u)  # First derivative, approximately cos(x)
    >>> d2u_dx2 = fd1d.dx(u, order=2)  # Second derivative, approximately -sin(x)
    >>>
    >>> # 2D finite differences, periodic in x and non-periodic in y
    >>> x = torch.linspace(0, 2*torch.pi, 51)[:-1]
    >>> y = torch.linspace(0, 2*torch.pi, 50)
    >>> fd2d = FiniteDiff(dim=2, h=((x[1] - x[0]).item(), (y[1] - y[0]).item()),
    ...                   periodic_in_x=True, periodic_in_y=False)
    >>> X, Y = torch.meshgrid(x, y, indexing='ij')
    >>> u = torch.sin(X) * torch.cos(Y)
    >>> du_dx = fd2d.dx(u)
    >>> du_dy = fd2d.dy(u)
    >>> grad = fd2d.gradient(u)  # Shape (2, 50, 50): [du_dx, du_dy]
    >>>
    >>> # 3D finite differences, periodic in x and y, non-periodic in z
    >>> x = torch.linspace(0, 2*torch.pi, 21)[:-1]
    >>> y = torch.linspace(0, 2*torch.pi, 21)[:-1]
    >>> z = torch.linspace(0, 2*torch.pi, 20)
    >>> h = ((x[1] - x[0]).item(), (y[1] - y[0]).item(), (z[1] - z[0]).item())
    >>> fd3d = FiniteDiff(dim=3, h=h, periodic_in_x=True, periodic_in_y=True, periodic_in_z=False)
    >>> X, Y, Z = torch.meshgrid(x, y, z, indexing='ij')
    >>> u = torch.sin(X) * torch.cos(Y) * torch.sin(Z)  # 3D scalar field
    >>> du_dx = fd3d.dx(u)
    >>> du_dy = fd3d.dy(u)
    >>> du_dz = fd3d.dz(u)
    >>> laplacian = fd3d.laplacian(u)  # Sum of all second derivatives
    >>>
    >>> # Vector field operations
    >>> vx = torch.sin(X) * torch.cos(Y) * torch.sin(Z)
    >>> vy = torch.cos(X) * torch.sin(Y) * torch.cos(Z)
    >>> vz = torch.sin(X) * torch.sin(Y) * torch.cos(Z)
    >>> v = torch.stack([vx, vy, vz], dim=-4)  # 3D vector field, shape (3, 20, 20, 20)
    >>> div_v = fd3d.divergence(v)  # Scalar field
    >>> curl_v = fd3d.curl(v)  # Vector field
    """

    def __init__(
        self, dim, h=1.0, periodic_in_x=True, periodic_in_y=True, periodic_in_z=True
    ):
        """
        Initialize the FiniteDiff class for computing finite differences.

        See class docstring for detailed parameter descriptions.

        Raises
        ------
        ValueError
            If ``dim`` is not 1, 2 or 3, or if ``h`` is a sequence whose length
            differs from ``dim``.
        """
        if dim not in [1, 2, 3]:
            raise ValueError("dim must be 1, 2, or 3")
        self.dim = dim
        if isinstance(h, (int, float)):
            # Uniform spacing: one copy of the scalar per spatial direction
            self.h = (h,) * dim
        else:
            if len(h) != dim:
                raise ValueError(f"For {dim}D, h must be a float or a tuple of length {dim}")
            self.h = tuple(h)
        # Periodicity flags exist only for the directions this instance has
        self.periodic_in_x = periodic_in_x
        if dim >= 2:
            self.periodic_in_y = periodic_in_y
        if dim >= 3:
            self.periodic_in_z = periodic_in_z

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def dx(self, u, order=1):
        """
        Compute derivative with respect to x (spatial axis ``-dim``).

        Parameters
        ----------
        u : torch.Tensor
            Input tensor whose trailing ``dim`` axes are the spatial grid
        order : int, optional
            Order of the derivative (1 or 2), by default 1

        Returns
        -------
        torch.Tensor
            Derivative with respect to x, same shape as ``u``

        Raises
        ------
        ValueError
            If ``order`` is not 1 or 2, or if x is non-periodic and has fewer
            than 4 grid points.
        """
        if self.dim == 1:
            return self._dx_1d(u, order)
        if self.dim == 2:
            return self._dx_2d(u, order)
        return self._dx_3d(u, order)

    def dy(self, u, order=1):
        """
        Compute derivative with respect to y (2D: axis -1, 3D: axis -2).

        Parameters
        ----------
        u : torch.Tensor
            Input tensor whose trailing ``dim`` axes are the spatial grid
        order : int, optional
            Order of the derivative (1 or 2), by default 1

        Returns
        -------
        torch.Tensor
            Derivative with respect to y, same shape as ``u``

        Raises
        ------
        ValueError
            If the instance is 1D, if ``order`` is not 1 or 2, or if y is
            non-periodic and has fewer than 4 grid points.
        """
        if self.dim < 2:
            raise ValueError("dy is only available for 2D and 3D")
        if self.dim == 2:
            return self._dy_2d(u, order)
        return self._dy_3d(u, order)

    def dz(self, u, order=1):
        """
        Compute derivative with respect to z (3D only, axis -1).

        Parameters
        ----------
        u : torch.Tensor
            Input tensor whose trailing 3 axes are the spatial grid
        order : int, optional
            Order of the derivative (1 or 2), by default 1

        Returns
        -------
        torch.Tensor
            Derivative with respect to z, same shape as ``u``

        Raises
        ------
        ValueError
            If the instance is not 3D, if ``order`` is not 1 or 2, or if z is
            non-periodic and has fewer than 4 grid points.
        """
        if self.dim < 3:
            raise ValueError("dz is only available for 3D")
        return self._dz_3d(u, order)

    def laplacian(self, u):
        """
        Compute the Laplacian ∇²f as the sum of second derivatives.

        Parameters
        ----------
        u : torch.Tensor
            Input scalar field

        Returns
        -------
        torch.Tensor
            The Laplacian of the input tensor, same shape as ``u``
        """
        if self.dim == 1:
            return self._dx_1d(u, 2)
        if self.dim == 2:
            return self._dx_2d(u, 2) + self._dy_2d(u, 2)
        return self._dx_3d(u, 2) + self._dy_3d(u, 2) + self._dz_3d(u, 2)

    def gradient(self, u):
        """
        Compute the gradient ∇f for scalar fields.

        Parameters
        ----------
        u : torch.Tensor
            Input scalar field

        Returns
        -------
        torch.Tensor
            The gradient. In 2D/3D the components are stacked along a new axis
            ``-dim - 1``; in 1D the x-derivative is returned as-is, without a
            component axis.
        """
        if self.dim == 1:
            return self._dx_1d(u, 1)
        if self.dim == 2:
            return torch.stack([self._dx_2d(u, 1), self._dy_2d(u, 1)], dim=-3)
        return torch.stack(
            [self._dx_3d(u, 1), self._dy_3d(u, 1), self._dz_3d(u, 1)], dim=-4
        )

    def divergence(self, u):
        """
        Compute the divergence ∇·u for vector fields.

        Parameters
        ----------
        u : torch.Tensor
            Input vector field with ``dim`` components along axis ``-dim - 1``

        Returns
        -------
        torch.Tensor
            The divergence of the vector field (component axis removed)

        Raises
        ------
        ValueError
            If ``u`` does not have ``dim`` components along axis ``-dim - 1``.
        """
        self._check_vector_field(u)
        if self.dim == 1:
            return self._dx_1d(u[..., 0, :], 1)
        if self.dim == 2:
            u1, u2 = u.unbind(dim=-3)
            return self._dx_2d(u1, 1) + self._dy_2d(u2, 1)
        u1, u2, u3 = u.unbind(dim=-4)
        return self._dx_3d(u1, 1) + self._dy_3d(u2, 1) + self._dz_3d(u3, 1)

    def curl(self, u):
        """
        Compute the curl ∇×u for vector fields.

        Parameters
        ----------
        u : torch.Tensor
            Input vector field with ``dim`` components along axis ``-dim - 1``

        Returns
        -------
        torch.Tensor
            2D: the scalar ∂u_y/∂x - ∂u_x/∂y.
            3D: the curl vector, components stacked along axis -4.

        Raises
        ------
        ValueError
            If the instance is 1D, or if ``u`` does not have ``dim`` components
            along axis ``-dim - 1``.
        """
        if self.dim == 1:
            raise ValueError("Curl is not defined for 1D")
        self._check_vector_field(u)
        if self.dim == 2:
            u1, u2 = u.unbind(dim=-3)
            return self._dx_2d(u2, 1) - self._dy_2d(u1, 1)
        u1, u2, u3 = u.unbind(dim=-4)
        curl_x = self._dy_3d(u3, 1) - self._dz_3d(u2, 1)
        curl_y = self._dz_3d(u1, 1) - self._dx_3d(u3, 1)
        curl_z = self._dx_3d(u2, 1) - self._dy_3d(u1, 1)
        return torch.stack([curl_x, curl_y, curl_z], dim=-4)

    # ------------------------------------------------------------------
    # Generic stencils applied along one (negative) spatial axis
    # ------------------------------------------------------------------
    def _check_vector_field(self, u):
        """Raise ValueError unless ``u`` holds ``dim`` components along axis ``-dim - 1``."""
        # Checking ndim first turns a would-be IndexError into the intended message
        if u.ndim < self.dim + 1 or u.shape[-self.dim - 1] != self.dim:
            raise ValueError(
                f"Input must be a {self.dim}D vector field with {self.dim} components"
            )

    @staticmethod
    def _by_order(u, order, first, second):
        """Apply ``first`` (order 1) or ``second`` (order 2) to ``u``; reject other orders."""
        if order == 1:
            return first(u)
        if order == 2:
            return second(u)
        raise ValueError("Only 1st and 2nd order derivatives currently supported")

    @staticmethod
    def _axis_indexer(axis):
        """Return ``at(index)`` building an index tuple that applies ``index`` to ``axis`` only.

        ``axis`` is a negative spatial axis (-1, -2 or -3): axes after it are kept
        whole with ``slice(None)`` and axes before it are covered by ``...``, which
        reproduces indexing such as ``u[..., 0, :]`` for ``axis=-2``.
        """
        trailing = (slice(None),) * (-axis - 1)

        def at(index):
            return (Ellipsis, index) + trailing

        return at

    @staticmethod
    def _require_points(u, axis):
        """Raise ValueError if ``axis`` has fewer than the 4 points the boundary stencils read."""
        n = u.shape[axis]
        if n < 4:
            raise ValueError(
                "Non-periodic finite differences need at least 4 grid points along "
                f"axis {axis}, got {n}"
            )

    @classmethod
    def _first_derivative(cls, u, axis, h, periodic):
        """First derivative of ``u`` along ``axis`` with grid spacing ``h``."""
        if periodic:
            # Central difference (f_{i+1} - f_{i-1})/(2h) with periodic wrapping
            return (torch.roll(u, -1, dims=axis) - torch.roll(u, 1, dims=axis)) / (2.0 * h)
        cls._require_points(u, axis)
        at = cls._axis_indexer(axis)
        du = torch.zeros_like(u)
        # Interior: second-order central differences (f_{i+1} - f_{i-1})/(2h)
        du[at(slice(1, -1))] = (u[at(slice(2, None))] - u[at(slice(None, -2))]) / (2.0 * h)
        # First point: third-order forward difference (-11f_0 + 18f_1 - 9f_2 + 2f_3)/(6h)
        du[at(0)] = (
            -11 * u[at(0)] + 18 * u[at(1)] - 9 * u[at(2)] + 2 * u[at(3)]
        ) / (6.0 * h)
        # Last point: third-order backward difference
        # (-2f_{n-4} + 9f_{n-3} - 18f_{n-2} + 11f_{n-1})/(6h)
        du[at(-1)] = (
            -2 * u[at(-4)] + 9 * u[at(-3)] - 18 * u[at(-2)] + 11 * u[at(-1)]
        ) / (6.0 * h)
        return du

    @classmethod
    def _second_derivative(cls, u, axis, h, periodic):
        """Second derivative of ``u`` along ``axis`` with grid spacing ``h``."""
        h2 = h ** 2
        if periodic:
            # Central difference (f_{i+1} - 2f_i + f_{i-1})/h² with periodic wrapping
            return (torch.roll(u, -1, dims=axis) - 2 * u + torch.roll(u, 1, dims=axis)) / h2
        cls._require_points(u, axis)
        at = cls._axis_indexer(axis)
        d2u = torch.zeros_like(u)
        # Interior: second-order central differences (f_{i+1} - 2f_i + f_{i-1})/h²
        d2u[at(slice(1, -1))] = (
            u[at(slice(2, None))] - 2 * u[at(slice(1, -1))] + u[at(slice(None, -2))]
        ) / h2
        # First point: second-order forward difference (2f_0 - 5f_1 + 4f_2 - f_3)/h²
        d2u[at(0)] = (2 * u[at(0)] - 5 * u[at(1)] + 4 * u[at(2)] - u[at(3)]) / h2
        # Last point: second-order backward difference
        # (-f_{n-4} + 4f_{n-3} - 5f_{n-2} + 2f_{n-1})/h²
        d2u[at(-1)] = (-u[at(-4)] + 4 * u[at(-3)] - 5 * u[at(-2)] + 2 * u[at(-1)]) / h2
        return d2u

    # ------------------------------------------------------------------
    # Per-dimension helpers (kept for backward compatibility)
    # ------------------------------------------------------------------
    def _dx_1d(self, u, order):
        """1D derivative with respect to x."""
        return self._by_order(u, order, self._dx_1st_1d, self._dx_2nd_1d)

    def _dx_1st_1d(self, u):
        """First order derivative with respect to x (1D, axis -1)."""
        return self._first_derivative(u, -1, self.h[0], self.periodic_in_x)

    def _dx_2nd_1d(self, u):
        """Second order derivative with respect to x (1D, axis -1)."""
        return self._second_derivative(u, -1, self.h[0], self.periodic_in_x)

    def _dx_2d(self, u, order):
        """2D derivative with respect to x."""
        return self._by_order(u, order, self._dx_1st_2d, self._dx_2nd_2d)

    def _dy_2d(self, u, order):
        """2D derivative with respect to y."""
        return self._by_order(u, order, self._dy_1st_2d, self._dy_2nd_2d)

    def _dx_1st_2d(self, u):
        """First order derivative with respect to x (2D, axis -2)."""
        return self._first_derivative(u, -2, self.h[0], self.periodic_in_x)

    def _dy_1st_2d(self, u):
        """First order derivative with respect to y (2D, axis -1)."""
        return self._first_derivative(u, -1, self.h[1], self.periodic_in_y)

    def _dx_2nd_2d(self, u):
        """Second order derivative with respect to x (2D, axis -2)."""
        return self._second_derivative(u, -2, self.h[0], self.periodic_in_x)

    def _dy_2nd_2d(self, u):
        """Second order derivative with respect to y (2D, axis -1)."""
        return self._second_derivative(u, -1, self.h[1], self.periodic_in_y)

    def _dx_3d(self, u, order):
        """3D derivative with respect to x."""
        return self._by_order(u, order, self._dx_1st_3d, self._dx_2nd_3d)

    def _dy_3d(self, u, order):
        """3D derivative with respect to y."""
        return self._by_order(u, order, self._dy_1st_3d, self._dy_2nd_3d)

    def _dz_3d(self, u, order):
        """3D derivative with respect to z."""
        return self._by_order(u, order, self._dz_1st_3d, self._dz_2nd_3d)

    def _dx_1st_3d(self, u):
        """First order derivative with respect to x (3D, axis -3)."""
        return self._first_derivative(u, -3, self.h[0], self.periodic_in_x)

    def _dy_1st_3d(self, u):
        """First order derivative with respect to y (3D, axis -2)."""
        return self._first_derivative(u, -2, self.h[1], self.periodic_in_y)

    def _dz_1st_3d(self, u):
        """First order derivative with respect to z (3D, axis -1)."""
        return self._first_derivative(u, -1, self.h[2], self.periodic_in_z)

    def _dx_2nd_3d(self, u):
        """Second order derivative with respect to x (3D, axis -3)."""
        return self._second_derivative(u, -3, self.h[0], self.periodic_in_x)

    def _dy_2nd_3d(self, u):
        """Second order derivative with respect to y (3D, axis -2)."""
        return self._second_derivative(u, -2, self.h[1], self.periodic_in_y)

    def _dz_2nd_3d(self, u):
        """Second order derivative with respect to z (3D, axis -1)."""
        return self._second_derivative(u, -1, self.h[2], self.periodic_in_z)
# Backward compatibility functions
def central_diff_1d(x, h, periodic_in_x=True):
    """
    Deprecated: first derivative of a 1D field with respect to x.

    Kept only for backward compatibility. It builds
    ``FiniteDiff(dim=1, h=h, periodic_in_x=periodic_in_x)`` and returns its
    ``dx(x)``. Calling it emits a DeprecationWarning attributed to the caller.
    New code should use the FiniteDiff class directly.

    Parameters
    ----------
    x : torch.Tensor
        Field sampled on a 1D grid (last axis).
    h : float or tuple
        Grid spacing, forwarded to FiniteDiff.
    periodic_in_x : bool, optional
        Periodic boundary handling in x, by default True.

    Returns
    -------
    torch.Tensor
        The x-derivative of ``x``.
    """
    message = (
        "central_diff_1d is deprecated and may not be maintained. "
        "Please use FiniteDiff class instead for better performance and features."
    )
    warnings.warn(message, category=DeprecationWarning, stacklevel=2)
    operator = FiniteDiff(dim=1, h=h, periodic_in_x=periodic_in_x)
    return operator.dx(x)
def central_diff_2d(x, h, periodic_in_x=True, periodic_in_y=True):
    """
    Deprecated: first derivatives of a 2D field with respect to x and y.

    Kept only for backward compatibility. It builds a 2D FiniteDiff with the
    given spacing and periodicity flags and returns ``(dx(x), dy(x))``.
    Calling it emits a DeprecationWarning attributed to the caller.
    New code should use the FiniteDiff class directly.

    Parameters
    ----------
    x : torch.Tensor
        Field sampled on a 2D grid (last two axes).
    h : float or tuple
        Grid spacing(s), forwarded to FiniteDiff.
    periodic_in_x, periodic_in_y : bool, optional
        Periodic boundary handling per direction, by default True.

    Returns
    -------
    tuple of torch.Tensor
        ``(d/dx, d/dy)`` of ``x``.
    """
    message = (
        "central_diff_2d is deprecated and may not be maintained. "
        "Please use FiniteDiff class instead for better performance and features."
    )
    warnings.warn(message, category=DeprecationWarning, stacklevel=2)
    operator = FiniteDiff(
        dim=2, h=h, periodic_in_x=periodic_in_x, periodic_in_y=periodic_in_y
    )
    d_dx = operator.dx(x)
    d_dy = operator.dy(x)
    return d_dx, d_dy
def central_diff_3d(x, h, periodic_in_x=True, periodic_in_y=True, periodic_in_z=True):
    """
    Deprecated: first derivatives of a 3D field with respect to x, y and z.

    Kept only for backward compatibility. It builds a 3D FiniteDiff with the
    given spacing and periodicity flags and returns ``(dx(x), dy(x), dz(x))``.
    Calling it emits a DeprecationWarning attributed to the caller.
    New code should use the FiniteDiff class directly.

    Parameters
    ----------
    x : torch.Tensor
        Field sampled on a 3D grid (last three axes).
    h : float or tuple
        Grid spacing(s), forwarded to FiniteDiff.
    periodic_in_x, periodic_in_y, periodic_in_z : bool, optional
        Periodic boundary handling per direction, by default True.

    Returns
    -------
    tuple of torch.Tensor
        ``(d/dx, d/dy, d/dz)`` of ``x``.
    """
    message = (
        "central_diff_3d is deprecated and may not be maintained. "
        "Please use FiniteDiff class instead for better performance and features."
    )
    warnings.warn(message, category=DeprecationWarning, stacklevel=2)
    operator = FiniteDiff(
        dim=3,
        h=h,
        periodic_in_x=periodic_in_x,
        periodic_in_y=periodic_in_y,
        periodic_in_z=periodic_in_z,
    )
    d_dx = operator.dx(x)
    d_dy = operator.dy(x)
    d_dz = operator.dz(x)
    return d_dx, d_dy, d_dz
def get_non_uniform_fd_weights(
points, num_neighbors=5, derivative_indices=[0], radius=None, regularize_lstsq=False
):
"""
Compute finite difference weights for approximating the first order derivative
on an unstructured grid of points
Parameters:
-----------
points : torch tensor of shape (N, d) containing the d coordinates of the N points
num_neighbors: int for the number of nearest neighbors to include in the stencil (including the point itself)
At least 3 and at most N
derivative_indices : indices of the derivatives to compute, e.g. in 2D, [0] for x, [1] for y, [0, 1] for x and y
radius : float, the cutoff distance to use a neighbor as radius
Neighbors with distances beyond this value are not used (their weights are set to zero)
Always keep at least 3 neighbors (including the point itself)
regularize_lstsq : bool, whether to regularize the least squares system
Sometimes torch.linalg.lstsq(A, b).solution creates artifacts so can add regularizer
But regularizer can deteriorate performance when system is well-conditioned
Returns:
--------
indices : torch tensor of shape (N, k) for the indices of k nearest neighbors (including the point itself)
fd_weights : torch tensor of weights of shape (N, len(derivative_indices), k)
fd_weights[i,j,m] contains the weights for the m-th nearest neighbor
in the j-th 1st order derivative for the i-th point
"""
N = points.shape[0]
d = points.shape[1]
k = min(max(num_neighbors, 3), N)
# Get the indices of the k nearest neighbors (including the point itself)
pairwise_distances = torch.cdist(points, points, p=2)
distances, indices = torch.topk(pairwise_distances, k=k, dim=1, largest=False)
# Get mask for neighbors within cutoff radius (and always keep at least 3)
if radius is None:
radius_mask = torch.ones_like(distances, dtype=torch.bool)
else:
radius_mask = distances <= radius
radius_mask[:, :3] = True
# Initialize A to 1 since first row for each point and derivative is 1
A = torch.ones((N, d + 1, k), dtype=points.dtype, device=points.device)
# Compute coordinate differences
for i in range(d):
A[:, i + 1, :] = points[indices, i] - points[:, i].unsqueeze(1)
# Repeat it for each derivative to be evaluated so it becomes of shape (N,
# len(derivative_indices), d+1, k)
A = A.unsqueeze(1).expand(-1, len(derivative_indices), -1, -1)
# Zero out columns for neighbors that are not within the radius
A = A * radius_mask.unsqueeze(1).unsqueeze(2)
# Compute right hand side
b = torch.zeros(
(len(derivative_indices), d + 1, 1), dtype=points.dtype, device=points.device
)
for i in range(len(derivative_indices)):
b[i, derivative_indices[i] + 1] = 1
# Repeat so it becomes (N, len(derivative_indices), d+1, 1)
b = b.unsqueeze(0).expand(N, -1, -1, -1)
# Solve least squares system Aw = b
# sometimes torch.linalg.lstsq(A, b).solution creates artifacts so can add regularizer
# but regularizer can deteriorate performance when system is
# well-conditioned
if regularize_lstsq:
lambda_reg = 1e-6
I_k = torch.eye(k, dtype=A.dtype, device=A.device).unsqueeze(0).unsqueeze(0)
AT = A.transpose(-2, -1)
AT_b = torch.matmul(AT, b)
AT_A = torch.matmul(AT, A) + lambda_reg * I_k
# Use Cholesky decomposition to accelerate torch.linalg.solve(AT_A,
# AT_b).squeeze(-1)
fd_weights = torch.cholesky_solve(AT_b, torch.linalg.cholesky(AT_A)).squeeze(-1)
else:
fd_weights = torch.linalg.lstsq(A, b).solution
return indices, fd_weights.squeeze(-1)
[docs]
class FourierDiff:
"""A unified class for computing Fourier/spectral derivatives in 1D, 2D, 3D.
This class provides comprehensive methods for computing derivatives using Fourier/spectral
methods with support for both periodic and non-periodic functions through Fourier continuation:
- Periodic functions: Direct Fourier differentiation using FFT
- Non-periodic functions: Fourier continuation (FC) is used to extend functions to larger domain
on which the functions are periodic before applying Fourier differentiation with FFT.
The class also provides gradient, divergence, curl, and Laplacian operations.
Parameters
----------
dim : int
Dimension of the input field. Must be 1, 2, or 3.
L : float or tuple, optional
Length of the domain for Fourier differentiation. By default 2*pi for each dimension.
use_fc : str, optional
Whether to use Fourier continuation for non-periodic functions.
Options: False (no FC), 'Legendre', 'Gram'. By default False.
fc_degree : int, optional
Degree of the Fourier continuation polynomial matching. This is the number
of matching points on the left and right boundaries used for the Fourier
continuation procedure. By default 4.
fc_n_additional_pts : int, optional
Number of additional points to add with the Fourier continuation layer.
This extends the domain to handle non-periodic functions. By default 50.
low_pass_filter_ratio : float, optional
If not None, apply a low-pass filter to the Fourier coefficients to reduce
high-frequency noise. Should be between 0 and 1. By default None.
Available Methods
----------------
Derivative Methods:
- dx(u, order=1): Compute derivative with respect to x
- dy(u, order=1): Compute derivative with respect to y (2D/3D only)
- dz(u, order=1): Compute derivative with respect to z (3D only)
- derivative(u, order): Compute derivative with order tuple (e.g., (1,0) for ∂/∂x)
Vector Calculus Operators:
- laplacian(u): Compute the Laplacian ∇²f
- gradient(u): Compute the gradient ∇f (returns vector field)
- divergence(u): Compute the divergence ∇·u (for vector fields)
- curl(u): Compute the curl ∇×u (for vector fields, 2D/3D only)
Mathematical Background
-----------------------
For periodic functions on [0, 2π], the derivative is computed as:
- Forward transform: û_k = FFT(u)
- Derivative in Fourier space: (∂u/∂x)^_k = ik * û_k
- Inverse transform: ∂u/∂x = IFFT(ik * û_k)
For non-periodic functions, Fourier continuation extends the function to
an extended domain (e.g. [0, 2π] → [0, 2π + 2π*additional_pts/n]) on which
the function is periodic.
Examples
--------
>>> # 1D Fourier derivatives
>>> x = torch.linspace(0, 2*torch.pi, 100)
>>> u = torch.sin(x)
>>> fd1d = FourierDiff(dim=1, L=2*torch.pi, use_fc=False)
>>> du_dx = fd1d.dx(u) # First derivative
>>> d2u_dx2 = fd1d.dx(u, order=2) # Second derivative
>>>
>>> # 2D Fourier derivatives
>>> fd2d = FourierDiff(dim=2, L=(2*torch.pi, 2*torch.pi), use_fc=False)
>>> x = torch.linspace(0, 2*torch.pi, 50)
>>> y = torch.linspace(0, 2*torch.pi, 50)
>>> X, Y = torch.meshgrid(x, y, indexing='ij')
>>> u = torch.sin(X) * torch.cos(Y)
>>> du_dx = fd2d.dx(u)
>>> du_dy = fd2d.dy(u)
>>> grad = fd2d.gradient(u) # Returns [du_dx, du_dy]
>>>
>>> # 3D Fourier derivatives
>>> fd3d = FourierDiff(dim=3, L=(2*torch.pi, 2*torch.pi, 2*torch.pi), use_fc=False)
>>> x = torch.linspace(0, 2*torch.pi, 20)
>>> y = torch.linspace(0, 2*torch.pi, 20)
>>> z = torch.linspace(0, 2*torch.pi, 20)
>>> X, Y, Z = torch.meshgrid(x, y, z, indexing='ij')
>>> u = torch.sin(X) * torch.cos(Y) * torch.sin(Z) # 3D scalar field
>>> du_dx = fd3d.dx(u)
>>> du_dy = fd3d.dy(u)
>>> du_dz = fd3d.dz(u)
>>> laplacian = fd3d.laplacian(u)
>>>
>>> # Vector field operations
>>> vx = torch.sin(X) * torch.cos(Y) * torch.sin(Z)
>>> vy = torch.cos(X) * torch.sin(Y) * torch.cos(Z)
>>> vz = torch.sin(X) * torch.sin(Y) * torch.cos(Z)
>>> v = torch.stack([vx, vy, vz], dim=-4)
>>> div_v = fd3d.divergence(v)
>>> curl_v = fd3d.curl(v)
"""
def __init__(
    self,
    dim,
    L=None,
    use_fc=False,
    fc_degree=4,
    fc_n_additional_pts=50,
    low_pass_filter_ratio=None,
):
    """
    Initialize the FourierDiff class for computing Fourier derivatives.

    See class docstring for detailed parameter descriptions.

    Raises
    ------
    ValueError
        If ``dim`` is not 1, 2 or 3, if ``L`` does not provide exactly one
        length per dimension, or if ``low_pass_filter_ratio`` is negative.

    Warns
    -----
    UserWarning
        If ``use_fc`` is truthy but not ``"Legendre"`` or ``"Gram"``; such a
        value disables Fourier continuation.
    """
    if dim not in [1, 2, 3]:
        raise ValueError("dim must be 1, 2, or 3")
    self.dim = dim

    # Default domain length is 2π along every axis.
    if L is None:
        L = 2 * torch.pi
    if not isinstance(L, (tuple, list)):
        L = (L,) * dim
    if len(L) != dim:
        raise ValueError(f"For {dim}D, L must be a single float or tuple with {dim} elements")
    # 1D keeps a bare scalar; multi-D stores a tuple copy so later mutation of a
    # caller-owned list cannot silently change this instance's domain.
    self.L = L[0] if dim == 1 else tuple(L)

    # A negative ratio would turn the spectral cutoff into a negative slice index.
    if low_pass_filter_ratio is not None and low_pass_filter_ratio < 0:
        raise ValueError("low_pass_filter_ratio must be non-negative")

    self.use_fc = use_fc
    self.fc_degree = fc_degree
    self.fc_n_additional_pts = fc_n_additional_pts
    self.low_pass_filter_ratio = low_pass_filter_ratio

    # Fourier-continuation operator; only the two supported variants build one.
    self.FC = None
    if self.use_fc in ["Legendre", "Gram"]:
        FC_class = FCLegendre if self.use_fc == "Legendre" else FCGram
        self.FC = FC_class(d=self.fc_degree, n_additional_pts=self.fc_n_additional_pts)
    elif self.use_fc:
        # A truthy value such as True or a misspelled name would otherwise leave
        # self.FC as None, and every derivative would skip continuation silently.
        warnings.warn(
            f"Unrecognized use_fc={use_fc!r}; expected 'Legendre', 'Gram' or False. "
            "Fourier continuation is disabled.",
            stacklevel=2,
        )
[docs]
def compute_multiple_derivatives(self, u, derivatives):
    """
    Evaluate several spectral derivatives of ``u`` from one forward transform.

    The dimension-specific backend transforms ``u`` once and then applies every
    requested derivative multiplier, which is cheaper than calling
    :meth:`derivative` once per entry.

    Parameters
    ----------
    u : torch.Tensor
        Field to differentiate.
    derivatives : list
        Derivative orders to evaluate:
        - 1D: list of int (orders)
        - 2D: list of tuples (order_x, order_y)
        - 3D: list of tuples (order_x, order_y, order_z)

    Returns
    -------
    list of torch.Tensor
        One derivative per entry of ``derivatives``, in the same order.
    """
    backends = {
        1: self._compute_multiple_derivatives_1d,
        2: self._compute_multiple_derivatives_2d,
        3: self._compute_multiple_derivatives_3d,
    }
    compute = backends.get(self.dim)
    return None if compute is None else compute(u, derivatives)
[docs]
def derivative(self, u, order):
    """
    Compute a single Fourier derivative of ``u``.

    Parameters
    ----------
    u : torch.Tensor
        Input tensor.
    order : tuple or int
        Derivative orders, one per spatial axis:
        - 1D: (order_x,) or a bare int order_x
        - 2D: (order_x, order_y)
        - 3D: (order_x, order_y, order_z)

    Returns
    -------
    torch.Tensor
        The requested derivative of ``u``.

    Raises
    ------
    ValueError
        If ``order`` does not have exactly ``self.dim`` entries.
    """
    # Allow ``fd.derivative(u, 2)``: a bare int is treated as a 1-tuple and then
    # validated like any other specification (so it is rejected for 2D/3D).
    if isinstance(order, int):
        order = (order,)
    if len(order) != self.dim:
        raise ValueError(f"For {self.dim}D, order must be a tuple with {self.dim} elements")
    if self.dim == 1:
        # The 1D backend takes plain integer orders rather than tuples.
        derivatives = self._compute_multiple_derivatives_1d(u, [order[0]])
    elif self.dim == 2:
        derivatives = self._compute_multiple_derivatives_2d(u, [order])
    else:
        derivatives = self._compute_multiple_derivatives_3d(u, [order])
    return derivatives[0]
[docs]
def partial(self, u, direction="x", order=1):
    """
    Differentiate ``u`` along one named axis.

    Parameters
    ----------
    u : torch.Tensor
        Field to differentiate.
    direction : str, optional
        Axis name, by default 'x'. 'y' needs a 2D or 3D instance and 'z'
        needs a 3D instance.
    order : int, optional
        Number of derivatives to take along that axis, by default 1.

    Returns
    -------
    torch.Tensor
        The partial derivative of ``u``.

    Raises
    ------
    ValueError
        If ``direction`` is not available for this instance's dimension.
    """
    # Only register the axes this dimensionality actually has.
    handlers = {"x": self.dx}
    if self.dim >= 2:
        handlers["y"] = self.dy
    if self.dim >= 3:
        handlers["z"] = self.dz
    handler = handlers.get(direction)
    if handler is None:
        raise ValueError(
            f"Invalid direction '{direction}' for dimension {self.dim}"
        )
    return handler(u, order=order)
[docs]
def dx(self, u, order=1):
    """
    Differentiate ``u`` with respect to x (the first spatial axis).

    Parameters
    ----------
    u : torch.Tensor
        Field to differentiate.
    order : int, optional
        Number of x-derivatives to take, by default 1.
    """
    if self.dim == 3:
        return self._dx_3d(u, order)
    if self.dim == 2:
        return self._dx_2d(u, order)
    if self.dim == 1:
        return self._dx_1d(u, order)
    return None
[docs]
def dy(self, u, order=1):
    """
    Differentiate ``u`` with respect to y (the second spatial axis).

    Parameters
    ----------
    u : torch.Tensor
        Field to differentiate.
    order : int, optional
        Number of y-derivatives to take, by default 1.

    Raises
    ------
    ValueError
        If this instance is one-dimensional.
    """
    if self.dim < 2:
        raise ValueError("dy method only available for 2D and 3D")
    by_dim = {2: self._dy_2d, 3: self._dy_3d}
    handler = by_dim.get(self.dim)
    return None if handler is None else handler(u, order)
[docs]
def dz(self, u, order=1):
    """
    Differentiate ``u`` with respect to z (the third spatial axis).

    Parameters
    ----------
    u : torch.Tensor
        Field to differentiate.
    order : int, optional
        Number of z-derivatives to take, by default 1.

    Raises
    ------
    ValueError
        If this instance has fewer than three dimensions.
    """
    if self.dim >= 3:
        return self._dz_3d(u, order)
    raise ValueError("dz method only available for 3D")
[docs]
def laplacian(self, u):
    """
    Compute the Laplacian ∇²f, the sum of the pure second derivatives.

    All second derivatives are requested in a single call to
    :meth:`compute_multiple_derivatives`. This way ``u`` is transformed (and,
    when enabled, Fourier-continued) once rather than once per axis.

    Parameters
    ----------
    u : torch.Tensor
        Scalar field whose trailing ``self.dim`` axes are spatial.

    Returns
    -------
    torch.Tensor
        The Laplacian of ``u``.
    """
    if self.dim == 1:
        orders = [2]  # the 1D backend takes plain integer orders
    elif self.dim == 2:
        orders = [(2, 0), (0, 2)]
    else:
        orders = [(2, 0, 0), (0, 2, 0), (0, 0, 2)]
    second_derivatives = self.compute_multiple_derivatives(u, orders)
    # Accumulate left to right: d2/dx2 + d2/dy2 (+ d2/dz2).
    return sum(second_derivatives[1:], second_derivatives[0])
[docs]
def gradient(self, u):
    """
    Compute the gradient ∇f as a vector field.

    All first derivatives come from one :meth:`compute_multiple_derivatives`
    call, so ``u`` is transformed once instead of once per axis.

    Parameters
    ----------
    u : torch.Tensor
        Scalar field whose trailing ``self.dim`` axes are spatial.

    Returns
    -------
    torch.Tensor
        Components (df/dx[, df/dy[, df/dz]]) stacked on a new axis placed
        just before the spatial axes (axis -2 in 1D, -3 in 2D, -4 in 3D).
    """
    if self.dim == 1:
        # The 1D backend takes plain integer orders.
        return self.compute_multiple_derivatives(u, [1])[0].unsqueeze(-2)
    if self.dim == 2:
        return torch.stack(self.compute_multiple_derivatives(u, [(1, 0), (0, 1)]), dim=-3)
    return torch.stack(
        self.compute_multiple_derivatives(u, [(1, 0, 0), (0, 1, 0), (0, 0, 1)]),
        dim=-4,
    )
[docs]
def divergence(self, u):
    """
    Compute the divergence ∇·u of a vector field.

    Parameters
    ----------
    u : torch.Tensor
        Vector field whose component axis sits immediately before the
        ``self.dim`` spatial axes (axis ``-(self.dim + 1)``) and holds exactly
        ``self.dim`` components.

    Returns
    -------
    torch.Tensor
        Scalar field ∂u_x/∂x (+ ∂u_y/∂y (+ ∂u_z/∂z)).

    Raises
    ------
    ValueError
        If ``u`` has too few axes or the wrong number of components.
    """
    component_axis = -self.dim - 1
    # Check ndim first so a too-small tensor gets a ValueError, not an IndexError.
    if u.ndim < self.dim + 1 or u.shape[component_axis] != self.dim:
        raise ValueError(
            f"For {self.dim}D, input must have {self.dim} components in the vector dimension"
        )
    components = u.unbind(dim=component_axis)
    terms = [self.partial(c, direction=axis) for axis, c in zip("xyz", components)]
    # Accumulate left to right, matching dx(u_x) + dy(u_y) (+ dz(u_z)).
    return sum(terms[1:], terms[0])
[docs]
def curl(self, u):
    """
    Compute the curl ∇×u of a vector field (2D and 3D only).

    In 2D, for u = (u, v) stacked on axis -3, the scalar ∂v/∂x - ∂u/∂y is
    returned. In 3D, for u = (u, v, w) stacked on axis -4, the vector
    (∂w/∂y - ∂v/∂z, ∂u/∂z - ∂w/∂x, ∂v/∂x - ∂u/∂y) is returned, stacked on
    axis -4.

    Parameters
    ----------
    u : torch.Tensor
        Vector field with its component axis just before the spatial axes.

    Returns
    -------
    torch.Tensor
        Scalar curl in 2D; three-component vector field in 3D.

    Raises
    ------
    ValueError
        For 1D instances, or if ``u`` has too few axes or the wrong number of
        components.
    """
    if self.dim == 1:
        raise ValueError("curl not defined for 1D")
    # Check ndim first so a too-small tensor gets a ValueError, not an IndexError.
    if self.dim == 2 and (u.ndim < 3 or u.shape[-3] != 2):
        raise ValueError("For 2D, input must have 2 components in the vector dimension")
    if self.dim == 3 and (u.ndim < 4 or u.shape[-4] != 3):
        raise ValueError("For 3D, input must have 3 components in the vector dimension")

    if self.dim == 2:
        u_x, u_y = u.unbind(dim=-3)
        return self.dx(u_y) - self.dy(u_x)
    if self.dim == 3:
        u_x, u_y, u_z = u.unbind(dim=-4)
        # Each component contributes two partials; requesting them together
        # shares one forward transform per component (3 FFTs instead of 6).
        dux_dy, dux_dz = self.compute_multiple_derivatives(u_x, [(0, 1, 0), (0, 0, 1)])
        duy_dx, duy_dz = self.compute_multiple_derivatives(u_y, [(1, 0, 0), (0, 0, 1)])
        duz_dx, duz_dy = self.compute_multiple_derivatives(u_z, [(1, 0, 0), (0, 1, 0)])
        curl_x = duz_dy - duy_dz  # ∂w/∂y - ∂v/∂z
        curl_y = dux_dz - duz_dx  # ∂u/∂z - ∂w/∂x
        curl_z = duy_dx - dux_dy  # ∂v/∂x - ∂u/∂y
        return torch.stack([curl_x, curl_y, curl_z], dim=-4)
def _compute_multiple_derivatives_1d(self, u, orders):
    """
    Compute several 1D spectral derivatives of ``u`` along its last axis.

    Parameters
    ----------
    u : torch.Tensor
        Input tensor; the last axis is the spatial x axis.
    orders : list of int
        Derivative orders. A length-1 tuple/list such as ``(2,)`` is also
        accepted, matching the format used by :meth:`derivative`.

    Returns
    -------
    list of torch.Tensor
        One real-valued derivative per entry in ``orders`` (empty list for an
        empty ``orders``).

    Raises
    ------
    ValueError
        If ``u`` is None, or an order is a tuple/list whose length is not 1.
    """
    if u is None:
        raise ValueError("Input tensor u is None")
    if not orders:
        return []
    int_orders = []
    for order in orders:
        if isinstance(order, (tuple, list)):
            if len(order) != 1:
                raise ValueError(f"For 1D, each order must be an int or a 1-tuple, got {order!r}")
            order = order[0]
        int_orders.append(order)

    L_x = self.L
    nx = u.shape[-1]
    use_fc = self.use_fc and self.FC is not None

    if use_fc:
        FC = self.FC.to(u.device)
        u_work = FC(u.clone(), dim=1)
        # Out-of-place on purpose: with a tensor-valued L, ``*=`` would mutate
        # self.L and stretch the domain again on every call.
        L_x = L_x * ((nx + self.fc_n_additional_pts) / nx)
        # Grid size after the extension.
        nx = u_work.shape[-1]
    else:
        # rfft does not modify its input, so no defensive copy is needed.
        u_work = u
    spacing = float(L_x) / nx

    u_h = torch.fft.rfft(u_work, dim=-1)
    # Build wavenumbers in the transform's real precision so float64 inputs
    # are not limited by float32 frequencies.
    k_x = torch.fft.rfftfreq(nx, d=spacing, dtype=u_h.real.dtype, device=u_h.device) * (2 * torch.pi)

    if self.low_pass_filter_ratio is not None:
        # rfft holds only non-negative modes, so zeroing the tail keeps m < cutoff.
        cutoff = int(u_h.shape[-1] * self.low_pass_filter_ratio)
        u_h[..., cutoff:] = 0

    ik_x = 1j * k_x
    derivatives_ft = torch.stack([(ik_x ** order) * u_h for order in int_orders], dim=0)
    # n=nx restores the exact (possibly odd) length that rfft does not record.
    derivatives_real = torch.fft.irfft(derivatives_ft, dim=-1, n=nx)

    # Crop back to the original grid if Fourier continuation was used.
    if use_fc:
        derivatives_real = FC.restrict(derivatives_real, dim=1)
    return list(derivatives_real.unbind(0))
def _dx_1d(self, u, order):
    """Return the ``order``-th x-derivative of a 1D field via a single-entry backend request."""
    return self._compute_multiple_derivatives_1d(u, [order])[0]
def _compute_multiple_derivatives_2d(self, u, derivatives):
    """
    Compute several 2D spectral derivatives of ``u`` over its last two axes.

    Parameters
    ----------
    u : torch.Tensor
        Input tensor whose last two axes are (x, y).
    derivatives : list of tuple
        ``(order_x, order_y)`` pairs.

    Returns
    -------
    list of torch.Tensor
        One real-valued derivative per entry in ``derivatives`` (empty list
        for an empty ``derivatives``).

    Raises
    ------
    ValueError
        If ``u`` is None.
    """
    if u is None:
        raise ValueError("Input tensor u is None")
    if not derivatives:
        return []

    L_x, L_y = self.L[0], self.L[1]
    nx, ny = u.shape[-2], u.shape[-1]
    use_fc = self.use_fc and self.FC is not None

    if use_fc:
        FC = self.FC.to(u.device)
        u_work = FC(u.clone(), dim=2)
        # Out-of-place so tensor-valued lengths stored in self.L are never mutated.
        L_x = L_x * ((nx + self.fc_n_additional_pts) / nx)
        L_y = L_y * ((ny + self.fc_n_additional_pts) / ny)
        # Grid sizes after the extension.
        nx, ny = u_work.shape[-2], u_work.shape[-1]
    else:
        # fft2 does not modify its input, so no defensive copy is needed.
        u_work = u

    u_h = torch.fft.fft2(u_work, dim=(-2, -1))
    # Wavenumbers in the transform's real precision (float64 stays float64).
    real_dtype = u_h.real.dtype
    k_x = torch.fft.fftfreq(nx, d=float(L_x) / nx, dtype=real_dtype, device=u_h.device) * (2 * torch.pi)
    k_y = torch.fft.fftfreq(ny, d=float(L_y) / ny, dtype=real_dtype, device=u_h.device) * (2 * torch.pi)
    # (nx, ny) grids; they broadcast against any leading batch axes of u_h.
    KX, KY = torch.meshgrid(k_x, k_y, indexing="ij")

    if self.low_pass_filter_ratio is not None:
        # fft2 stores each axis as modes [0, 1, ..., -2, -1], so zeroing an index
        # tail would also drop the low negative modes (halving real signals).
        # Keep |m| < int((n // 2 + 1) * ratio) per axis: the same cutoff the 1D
        # path applies to its rfft half-spectrum.
        idx_x = torch.arange(nx, device=u_h.device)
        idx_y = torch.arange(ny, device=u_h.device)
        keep_x = torch.minimum(idx_x, nx - idx_x) < int((nx // 2 + 1) * self.low_pass_filter_ratio)
        keep_y = torch.minimum(idx_y, ny - idx_y) < int((ny // 2 + 1) * self.low_pass_filter_ratio)
        u_h = u_h * (keep_x[:, None] & keep_y[None, :]).to(u_h.dtype)

    ik_x, ik_y = 1j * KX, 1j * KY
    derivatives_ft = torch.stack(
        [(ik_x ** order_x) * (ik_y ** order_y) * u_h for order_x, order_y in derivatives],
        dim=0,
    )
    derivatives_real = torch.fft.ifft2(derivatives_ft, dim=(-2, -1)).real

    # Crop back to the original grid if Fourier continuation was used.
    if use_fc:
        derivatives_real = FC.restrict(derivatives_real, dim=2)
    return list(derivatives_real.unbind(0))
def _dx_2d(self, u, order):
    """Return the ``order``-th x-derivative of a 2D field (orders ``(order, 0)``)."""
    return self._compute_multiple_derivatives_2d(u, [(order, 0)])[0]
def _dy_2d(self, u, order):
    """Return the ``order``-th y-derivative of a 2D field (orders ``(0, order)``)."""
    return self._compute_multiple_derivatives_2d(u, [(0, order)])[0]
def _compute_multiple_derivatives_3d(self, u, derivatives):
    """
    Compute several 3D spectral derivatives of ``u`` over its last three axes.

    Parameters
    ----------
    u : torch.Tensor
        Input tensor whose last three axes are (x, y, z).
    derivatives : list of tuple
        ``(order_x, order_y, order_z)`` triples.

    Returns
    -------
    list of torch.Tensor
        One real-valued derivative per entry in ``derivatives`` (empty list
        for an empty ``derivatives``).

    Raises
    ------
    ValueError
        If ``u`` is None.
    """
    if u is None:
        raise ValueError("Input tensor u is None")
    if not derivatives:
        return []

    L_x, L_y, L_z = self.L[0], self.L[1], self.L[2]
    nx, ny, nz = u.shape[-3], u.shape[-2], u.shape[-1]
    use_fc = self.use_fc and self.FC is not None

    if use_fc:
        FC = self.FC.to(u.device)
        u_work = FC(u.clone(), dim=3)
        # Out-of-place so tensor-valued lengths stored in self.L are never mutated.
        L_x = L_x * ((nx + self.fc_n_additional_pts) / nx)
        L_y = L_y * ((ny + self.fc_n_additional_pts) / ny)
        L_z = L_z * ((nz + self.fc_n_additional_pts) / nz)
        # Grid sizes after the extension.
        nx, ny, nz = u_work.shape[-3], u_work.shape[-2], u_work.shape[-1]
    else:
        # fftn does not modify its input, so no defensive copy is needed.
        u_work = u

    u_h = torch.fft.fftn(u_work, dim=(-3, -2, -1))
    # Wavenumbers in the transform's real precision (float64 stays float64).
    real_dtype = u_h.real.dtype
    k_x = torch.fft.fftfreq(nx, d=float(L_x) / nx, dtype=real_dtype, device=u_h.device) * (2 * torch.pi)
    k_y = torch.fft.fftfreq(ny, d=float(L_y) / ny, dtype=real_dtype, device=u_h.device) * (2 * torch.pi)
    k_z = torch.fft.fftfreq(nz, d=float(L_z) / nz, dtype=real_dtype, device=u_h.device) * (2 * torch.pi)
    # (nx, ny, nz) grids; they broadcast against any leading batch axes of u_h.
    KX, KY, KZ = torch.meshgrid(k_x, k_y, k_z, indexing="ij")

    if self.low_pass_filter_ratio is not None:
        # fftn stores each axis as modes [0, 1, ..., -2, -1], so zeroing an index
        # tail would also drop the low negative modes (halving real signals).
        # Keep |m| < int((n // 2 + 1) * ratio) per axis: the same cutoff the 1D
        # path applies to its rfft half-spectrum.
        ratio = self.low_pass_filter_ratio
        idx_x = torch.arange(nx, device=u_h.device)
        idx_y = torch.arange(ny, device=u_h.device)
        idx_z = torch.arange(nz, device=u_h.device)
        keep_x = torch.minimum(idx_x, nx - idx_x) < int((nx // 2 + 1) * ratio)
        keep_y = torch.minimum(idx_y, ny - idx_y) < int((ny // 2 + 1) * ratio)
        keep_z = torch.minimum(idx_z, nz - idx_z) < int((nz // 2 + 1) * ratio)
        keep = keep_x[:, None, None] & keep_y[None, :, None] & keep_z[None, None, :]
        u_h = u_h * keep.to(u_h.dtype)

    ik_x, ik_y, ik_z = 1j * KX, 1j * KY, 1j * KZ
    derivatives_ft = torch.stack(
        [
            (ik_x ** order_x) * (ik_y ** order_y) * (ik_z ** order_z) * u_h
            for order_x, order_y, order_z in derivatives
        ],
        dim=0,
    )
    derivatives_real = torch.fft.ifftn(derivatives_ft, dim=(-3, -2, -1)).real

    # Crop back to the original grid if Fourier continuation was used.
    if use_fc:
        derivatives_real = FC.restrict(derivatives_real, dim=3)
    return list(derivatives_real.unbind(0))
def _dx_3d(self, u, order):
    """Return the ``order``-th x-derivative of a 3D field (orders ``(order, 0, 0)``)."""
    return self._compute_multiple_derivatives_3d(u, [(order, 0, 0)])[0]
def _dy_3d(self, u, order):
    """Return the ``order``-th y-derivative of a 3D field (orders ``(0, order, 0)``)."""
    return self._compute_multiple_derivatives_3d(u, [(0, order, 0)])[0]
def _dz_3d(self, u, order):
    """Return the ``order``-th z-derivative of a 3D field (orders ``(0, 0, order)``)."""
    return self._compute_multiple_derivatives_3d(u, [(0, 0, order)])[0]