Training an SFNO on the spherical Shallow Water equations

In this example, we use the small Spherical Shallow Water Equations dataset shipped with the package to train a Spherical Fourier Neural Operator (SFNO).

This tutorial demonstrates how to train neural operators on spherical domains, which is crucial for many geophysical applications such as weather prediction, ocean modeling, and climate science. The SFNO extends the FNO architecture to data on the sphere by using a spherical harmonic transform in place of the standard Fourier transform.

The Shallow Water Equations describe the motion of a thin layer of fluid and are fundamental in atmospheric and oceanic dynamics.
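
For reference, one common form of the rotating shallow water equations is given below; the dataset generator may use a different but equivalent formulation (e.g. in terms of vorticity and divergence):

\frac{\partial \mathbf{u}}{\partial t} + (\mathbf{u}\cdot\nabla)\mathbf{u} + f\,\hat{k}\times\mathbf{u} = -g\,\nabla h,
\qquad
\frac{\partial h}{\partial t} + \nabla\cdot(h\,\mathbf{u}) = 0,

where u is the horizontal velocity, h the fluid height, f the Coriolis parameter, and g the gravitational acceleration.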

Import dependencies

We import the necessary modules for training a Spherical Fourier Neural Operator

import torch
import matplotlib.pyplot as plt
import sys
from neuralop.models import SFNO
from neuralop import Trainer
from neuralop.training import AdamW
from neuralop.data.datasets import load_spherical_swe
from neuralop.utils import count_model_params
from neuralop import LpLoss, H1Loss

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Loading the Spherical Shallow Water Equations dataset

We load the spherical shallow water equations dataset with multiple resolutions. The dataset contains velocity and height fields on the sphere, which are the fundamental variables in shallow water dynamics.

train_loader, test_loaders = load_spherical_swe(
    n_train=200,
    batch_size=32,
    train_resolution=(32, 64),
    test_resolutions=[(32, 64), (64, 128)],
    n_tests=[40, 40],
    test_batch_sizes=[40, 40],
)
Loading train dataloader at resolution (32, 64) with 200 samples and batch-size=32
Loading test dataloader at resolution (32, 64) with 40 samples and batch-size=40
Loading test dataloader at resolution (64, 128) with 40 samples and batch-size=40
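
Before building the model, we can peek at one batch to confirm the tensor layout. This is a quick sketch; it assumes each batch is a dict with "x" and "y" keys, which is also what the visualization code at the end of this example relies on.

# Inspect one training batch (assumes dict batches with "x" and "y" keys)
sample = next(iter(train_loader))
print(sample["x"].shape)  # expected: torch.Size([32, 3, 32, 64])
print(sample["y"].shape)  # expected: torch.Size([32, 3, 32, 64])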

Creating the Spherical FNO model

model = SFNO(
    n_modes=(16, 32),
    in_channels=3,
    out_channels=3,
    hidden_channels=64,
    domain_padding=[0.05, 0.05],
    n_layers=2,
)
model = model.to(device)

# Count and display the number of parameters
n_params = count_model_params(model)
print(f"\nOur model has {n_params} parameters.")
sys.stdout.flush()
Our model has 296707 parameters.
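
As a quick sanity check of the input/output contract, we can push a dummy tensor at the training resolution through the model. This is a sketch; the expected shape is consistent with the raw output shape printed during training below.

# Dummy forward pass: SFNO should map (batch, 3, 32, 64) -> (batch, 3, 32, 64)
with torch.no_grad():
    dummy = torch.randn(1, 3, 32, 64, device=device)
    print(model(dummy).shape)  # expected: torch.Size([1, 3, 32, 64])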

Creating the optimizer and scheduler

We use the AdamW optimizer with weight decay, together with a cosine-annealing learning-rate schedule

optimizer = AdamW(model.parameters(), lr=5e-3, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)
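
To preview how the cosine schedule decays the learning rate over the 30 epochs, you can step a throwaway optimizer/scheduler pair so the real ones above are left untouched. This is a minimal sketch using the standard torch.optim.AdamW purely for illustration.

# Preview the cosine-annealed learning rate without touching the real optimizer
_opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=5e-3)
_sched = torch.optim.lr_scheduler.CosineAnnealingLR(_opt, T_max=30)
for epoch in range(30):
    _opt.step()
    _sched.step()
    if epoch % 10 == 0:
        print(f"epoch {epoch}: lr={_sched.get_last_lr()[0]:.2e}")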

Setting up loss functions

We train with an L2 loss (LpLoss with d=2, p=2). Setting reduction="sum" sums the per-sample losses over the batch rather than averaging them

l2loss = LpLoss(d=2, p=2, reduction="sum")

train_loss = l2loss
eval_losses = {"l2": l2loss}
print("\n### MODEL ###\n", model)
print("\n### OPTIMIZER ###\n", optimizer)
print("\n### SCHEDULER ###\n", scheduler)
print("\n### LOSSES ###")
print(f"\n * Train: {train_loss}")
print(f"\n * Test: {eval_losses}")
sys.stdout.flush()
### MODEL ###
 SFNO(
  (positional_embedding): GridEmbeddingND()
  (domain_padding): DomainPadding()
  (fno_blocks): FNOBlocks(
    (convs): ModuleList(
      (0-1): 2 x SphericalConv(
        (weight): ComplexDenseTensor(shape=torch.Size([64, 64, 16]), rank=None)
        (sht_handle): SHT(
          (_SHT_cache): ModuleDict()
          (_iSHT_cache): ModuleDict()
        )
      )
    )
    (fno_skips): ModuleList(
      (0-1): 2 x Flattened1dConv(
        (conv): Conv1d(64, 64, kernel_size=(1,), stride=(1,), bias=False)
      )
    )
    (channel_mlp): ModuleList(
      (0-1): 2 x ChannelMLP(
        (fcs): ModuleList(
          (0): Conv1d(64, 32, kernel_size=(1,), stride=(1,))
          (1): Conv1d(32, 64, kernel_size=(1,), stride=(1,))
        )
      )
    )
    (channel_mlp_skips): ModuleList(
      (0-1): 2 x SoftGating()
    )
  )
  (lifting): ChannelMLP(
    (fcs): ModuleList(
      (0): Conv1d(5, 128, kernel_size=(1,), stride=(1,))
      (1): Conv1d(128, 64, kernel_size=(1,), stride=(1,))
    )
  )
  (projection): ChannelMLP(
    (fcs): ModuleList(
      (0): Conv1d(64, 128, kernel_size=(1,), stride=(1,))
      (1): Conv1d(128, 3, kernel_size=(1,), stride=(1,))
    )
  )
)

### OPTIMIZER ###
 AdamW (
Parameter Group 0
    betas: (0.9, 0.999)
    correct_bias: True
    eps: 1e-06
    initial_lr: 0.005
    lr: 0.005
    weight_decay: 0.0001
)

### SCHEDULER ###
 <torch.optim.lr_scheduler.CosineAnnealingLR object at 0x7fb5f4a5b4d0>

### LOSSES ###

 * Train: <neuralop.losses.data_losses.LpLoss object at 0x7fb5f4a5b750>

 * Test: {'l2': <neuralop.losses.data_losses.LpLoss object at 0x7fb5f4a5b750>}
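
Before handing these to the Trainer, we can evaluate the loss once on random tensors to see the calling convention. This is a sketch; it assumes LpLoss is invoked as loss(prediction, target), which matches how the Trainer calls it during training below.

# Evaluate the L2 loss on a random prediction/target pair (illustration only)
pred = torch.randn(4, 3, 32, 64)
target = torch.randn(4, 3, 32, 64)
print(l2loss(pred, target))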

Creating the trainer

We create a Trainer object that handles the training loop and periodic evaluation

trainer = Trainer(
    model=model,
    n_epochs=30,
    device=device,
    wandb_log=False,  # Disable Weights & Biases logging
    eval_interval=5,  # Evaluate every 5 epochs
    use_distributed=False,  # Single GPU/CPU training
    verbose=True,  # Print training progress
)

Training the SFNO model

We train the model on the spherical shallow water equations dataset. The trainer will handle the forward pass through the SFNO, compute the L2 loss, backpropagate, and evaluate on test data.

trainer.train(
    train_loader=train_loader,
    test_loaders=test_loaders,
    optimizer=optimizer,
    scheduler=scheduler,
    regularizer=False,
    training_loss=train_loss,
    eval_losses=eval_losses,
)
Training on 200 samples
Testing on [40, 40] samples         on resolutions [(32, 64), (64, 128)].
/opt/hostedtoolcache/Python/3.13.7/x64/lib/python3.13/site-packages/torch/nn/modules/module.py:1786: UserWarning: FNO.forward() received unexpected keyword arguments: ['y']. These arguments will be ignored.
  return forward_call(*args, **kwargs)
Raw outputs of shape torch.Size([32, 3, 32, 64])
/home/runner/work/neuraloperator/neuraloperator/neuralop/training/trainer.py:536: UserWarning: LpLoss.__call__() received unexpected keyword arguments: ['x']. These arguments will be ignored.
  loss += training_loss(out, **sample)
[0] time=3.66, avg_loss=2.5228, train_err=72.0788
Eval: (32, 64)_l2=1.8109, (64, 128)_l2=2.2560
[5] time=3.55, avg_loss=0.8854, train_err=25.2974
Eval: (32, 64)_l2=0.8866, (64, 128)_l2=2.5050
[10] time=3.59, avg_loss=0.6637, train_err=18.9629
Eval: (32, 64)_l2=0.6947, (64, 128)_l2=2.4491
[15] time=3.60, avg_loss=0.5366, train_err=15.3302
Eval: (32, 64)_l2=0.5969, (64, 128)_l2=2.3813
[20] time=3.82, avg_loss=0.4739, train_err=13.5407
Eval: (32, 64)_l2=0.5328, (64, 128)_l2=2.4176
[25] time=3.87, avg_loss=0.4327, train_err=12.3628
Eval: (32, 64)_l2=0.5153, (64, 128)_l2=2.4114

{'train_err': 12.268861498151507, 'avg_loss': 0.4294101524353027, 'avg_lasso_loss': None, 'epoch_train_time': 3.93271066300008}
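
At this point the trained weights can be saved with standard PyTorch serialization (a sketch; the file name is arbitrary).

# Save the trained weights; reload later into an SFNO built with the same arguments
torch.save(model.state_dict(), "sfno_swe.pt")
# model.load_state_dict(torch.load("sfno_swe.pt", map_location=device))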

Visualizing SFNO predictions on spherical data

We visualize the model's predictions on the spherical shallow water equations data. Note that we trained at a very small resolution for a very small number of epochs; in practice, we would train at a larger resolution on many more samples.

However, for practicality, we created a minimal example that: i) fits in just a few MB of memory, and ii) can be trained quickly on CPU.

In practice, we would train a Neural Operator on one or multiple GPUs.

fig = plt.figure(figsize=(14, 7))
for index, resolution in enumerate([(32, 64), (64, 128)]):
    test_samples = test_loaders[resolution].dataset
    data = test_samples[0]
    # Input x
    x = data["x"]
    # Ground-truth
    y = data["y"][0, ...].numpy()
    # Model prediction: SFNO output
    x_in = x.unsqueeze(0).to(device)
    out = model(x_in).squeeze()[0, ...].detach().cpu().numpy()
    x = x[0, ...].detach().numpy()

    # Plot input fields
    ax = fig.add_subplot(2, 3, index * 3 + 1)
    ax.imshow(x)
    ax.set_title(f"Input x {resolution}")
    plt.xticks([], [])
    plt.yticks([], [])

    # Compute the min and max to use consistent color mapping
    vmin = y.min()
    vmax = y.max()

    # Plot ground-truth fields
    ax = fig.add_subplot(2, 3, index * 3 + 2)
    im_gt = ax.imshow(y, vmin=vmin, vmax=vmax)
    ax.set_title("Ground-truth y")
    plt.xticks([], [])
    plt.yticks([], [])

    # Plot model prediction
    ax = fig.add_subplot(2, 3, index * 3 + 3)
    im_pred = ax.imshow(out, vmin=vmin, vmax=vmax)
    ax.set_title("SFNO prediction")
    plt.xticks([], [])
    plt.yticks([], [])

fig.suptitle("SFNO predictions on spherical shallow water equations", y=0.98, fontsize=24)
plt.tight_layout()
fig.show()
[Figure: SFNO predictions on the spherical shallow water equations, showing input x, ground-truth y, and the SFNO prediction at resolutions (32, 64) and (64, 128).]
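
When running non-interactively (for example on a headless machine), the figure can be written to disk instead of shown (the file name below is arbitrary).

# Save the comparison figure to disk instead of displaying it interactively
fig.savefig("sfno_swe_predictions.png", dpi=150, bbox_inches="tight")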

Total running time of the script: (2 minutes 3.089 seconds)

Gallery generated by Sphinx-Gallery