Training an FNO with incremental meta-learning

A demo of the Incremental FNO meta-learning algorithm on our small Darcy-Flow dataset.

This tutorial demonstrates incremental meta-learning for neural operators, which allows the model to gradually increase its complexity during training. This approach can lead to:

  • Better convergence properties

  • More stable training dynamics

  • Improved generalization

  • Reduced computational requirements during early training

The incremental approach starts with a small number of Fourier modes and gradually increases the model capacity as training progresses.
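
To make the idea concrete, here is a minimal conceptual sketch (not the library's implementation, which adapts the schedule automatically): training starts with only a few Fourier modes and the mode count grows over time, up to a fixed maximum.

max_n_modes = (8, 8)          # capacity the model may eventually reach
n_modes = (2, 2)              # capacity at the start of training
for epoch in range(20):
    # ... run one training epoch with the current number of modes ...
    if epoch % 5 == 4:        # hypothetical fixed schedule; the real criteria are adaptive
        n_modes = tuple(min(m + 2, m_max) for m, m_max in zip(n_modes, max_n_modes))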

Import dependencies

We import the necessary modules for incremental FNO training.

import torch
import matplotlib.pyplot as plt
import sys
from neuralop.models import FNO
from neuralop.data.datasets import load_darcy_flow_small
from neuralop.utils import count_model_params
from neuralop.training import AdamW
from neuralop.training.incremental import IncrementalFNOTrainer
from neuralop.data.transforms.data_processors import IncrementalDataProcessor
from neuralop import LpLoss, H1Loss

Loading the Darcy-Flow dataset

We load the Darcy-Flow dataset with multiple resolutions for incremental training.

train_loader, test_loaders, output_encoder = load_darcy_flow_small(
    n_train=100,
    batch_size=16,
    test_resolutions=[16, 32],
    n_tests=[100, 50],
    test_batch_sizes=[32, 32],
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Loading test db for resolution 16 with 100 samples
Loading test db for resolution 32 with 50 samples
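
As a quick optional check (not part of the original script), we can confirm that the loaded batches have the layout the rest of the tutorial assumes: dictionaries with keys "x" and "y", each a tensor of shape [batch, channels, height, width].

batch = next(iter(train_loader))
print(batch["x"].shape, batch["y"].shape)  # e.g. torch.Size([16, 1, 16, 16]) for both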

Configuring incremental training

We set up the incremental FNO model with a small starting number of modes. The model will gradually increase its capacity during training. We choose to update the modes using the incremental gradient-explained algorithm.

incremental = True
if incremental:
    starting_modes = (2, 2)  # Start with very few modes
else:
    starting_modes = (8, 8)  # Standard number of modes

Creating the incremental FNO model

We create an FNO model with a maximum number of modes that can be reached during incremental training. The model starts with fewer modes and grows.

model = FNO(
    max_n_modes=(8, 8),  # Maximum modes the model can reach
    n_modes=starting_modes,  # Starting number of modes
    hidden_channels=32,
    in_channels=1,
    out_channels=1,
)
model = model.to(device)
n_params = count_model_params(model)
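
As an optional sanity check (not in the original script), a forward pass on a dummy input at the base resolution should return an output with the same spatial shape, even though the model currently uses only (2, 2) modes.

x_dummy = torch.randn(4, 1, 16, 16, device=device)
print(model(x_dummy).shape)  # expected: torch.Size([4, 1, 16, 16])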

Setting up the optimizer and scheduler

We use the AdamW optimizer with weight decay for regularization, together with a cosine-annealing learning-rate schedule.

optimizer = AdamW(model.parameters(), lr=8e-3, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)
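
The trainer is expected to advance the scheduler once per epoch (an assumption about the Trainer internals); with T_max=30 and only 20 training epochs, the learning rate traces roughly the first two thirds of one cosine decay starting from 8e-3. The standard PyTorch pattern looks like this:

for _ in range(3):           # illustration only; during training the trainer drives this
    optimizer.step()         # in real training this follows loss.backward()
    scheduler.step()
    print(scheduler.get_last_lr())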

Configuring incremental data processing

To use incremental resolution, one should use the IncrementalDataProcessor. When it is passed to the trainer, the trainer automatically updates the input resolution during training.

Key parameters for incremental resolution:

  • incremental_resolution: bool, default is False. If True, increase the resolution of the input incrementally

  • incremental_res_gap: the gap parameter controlling resolution updates

  • subsampling_rates: a list of subsampling factors to apply to the input (a rate of 2 halves the resolution)

  • dataset_indices: the tensor dimensions to subsample (here the spatial dimensions, 2 and 3, of a [batch, channel, height, width] tensor)

  • dataset_resolution: the full resolution of the dataset

  • epoch_gap: the number of epochs to wait before increasing the resolution

  • verbose: if True, print the resolution and the number of modes

data_transform = IncrementalDataProcessor(
    in_normalizer=None,
    out_normalizer=None,
    device=device,
    subsampling_rates=[2, 1],  # Resolution scaling factors
    dataset_resolution=16,  # Base resolution
    dataset_indices=[2, 3],  # Dataset indices for regularization
    epoch_gap=10,  # Epochs between resolution updates
    verbose=True,  # Print progress information
)

data_transform = data_transform.to(device)
Original Incre Res: change index to 0
Original Incre Res: change sub to 2
Original Incre Res: change res to 8
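
The processor reports that training starts at a subsampling rate of 2, i.e. at resolution 8 rather than the full 16. A minimal illustration of what such subsampling amounts to (an assumption about the mechanism; the processor may implement it differently):

x_full = torch.randn(16, 1, 16, 16)   # a batch at the full dataset resolution
x_coarse = x_full[:, :, ::2, ::2]     # keep every other grid point along dims 2 and 3
print(x_coarse.shape)                 # torch.Size([16, 1, 8, 8])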

Setting up loss functions

We use the H1 loss for training and evaluate with both the H1 and L2 losses.

l2loss = LpLoss(d=2, p=2)
h1loss = H1Loss(d=2)
train_loss = h1loss
eval_losses = {"h1": h1loss, "l2": l2loss}
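
Both losses are callables that take a prediction and a target of matching shape. A quick check on dummy tensors (not part of the original script, and assuming the standard (prediction, target) call signature):

y_pred = torch.randn(4, 1, 16, 16)
y_true = torch.randn(4, 1, 16, 16)
print(l2loss(y_pred, y_true), h1loss(y_pred, y_true))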

Displaying training configuration

We display the number of model parameters, the optimizer, the scheduler, and the loss functions to verify our incremental training setup.

print("\n### N PARAMS ###\n", n_params)
print("\n### OPTIMIZER ###\n", optimizer)
print("\n### SCHEDULER ###\n", scheduler)
print("\n### LOSSES ###")
print("\n### INCREMENTAL RESOLUTION + GRADIENT EXPLAINED ###")
print(f"\n * Train: {train_loss}")
print(f"\n * Test: {eval_losses}")
sys.stdout.flush()
### N PARAMS ###
 537441

### OPTIMIZER ###
 AdamW (
Parameter Group 0
    betas: (0.9, 0.999)
    correct_bias: True
    eps: 1e-06
    initial_lr: 0.008
    lr: 0.008
    weight_decay: 0.0001
)

### SCHEDULER ###
 <torch.optim.lr_scheduler.CosineAnnealingLR object at 0x7fb5e7d1de00>

### LOSSES ###

### INCREMENTAL RESOLUTION + GRADIENT EXPLAINED ###

 * Train: <neuralop.losses.data_losses.H1Loss object at 0x7fb5da514a50>

 * Test: {'h1': <neuralop.losses.data_losses.H1Loss object at 0x7fb5da514a50>, 'l2': <neuralop.losses.data_losses.LpLoss object at 0x7fb5e7d1dcd0>}

Configuring the IncrementalFNOTrainer

We set up the IncrementalFNOTrainer with various incremental learning options. Here we enable the gradient-based criterion (incremental_grad=True); setting incremental_loss_gap=True instead would use the loss-gap criterion. Because we also pass the IncrementalDataProcessor configured above, this example updates both the Fourier modes and the input resolution during training. When using incremental resolution, keep in mind that the number of modes initially set should be strictly less than the resolution. A rough sketch of the gradient-based mode update follows the trainer configuration below.

Key parameters for incremental training:

  • incremental_grad: bool, default is False. If True, use the base incremental algorithm based on gradient variance

  • incremental_grad_eps: threshold for gradient variance

  • incremental_buffer: number of buffer modes to calculate gradient variance

  • incremental_max_iter: initial number of iterations

  • incremental_grad_max_iter: maximum iterations to accumulate gradients

  • incremental_loss_gap: bool, default is False. If True, use the incremental algorithm based on loss gap

  • incremental_loss_eps: threshold for loss gap

# Create the IncrementalFNOTrainer with our configuration
trainer = IncrementalFNOTrainer(
    model=model,
    n_epochs=20,
    data_processor=data_transform,
    device=device,
    verbose=True,
    incremental_loss_gap=False,  # Use gradient-based incremental learning
    incremental_grad=True,  # Enable gradient-based mode updates
    incremental_grad_eps=0.9999,  # Gradient variance threshold
    incremental_loss_eps=0.001,  # Loss gap threshold
    incremental_buffer=5,  # Buffer modes for gradient calculation
    incremental_max_iter=1,  # Initial iterations
    incremental_grad_max_iter=2,  # Maximum gradient accumulation iterations
)
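
To give a feel for the gradient-based criterion referenced above, here is a rough, simplified sketch. It is illustrative only: the actual IncrementalFNOTrainer implementation, attribute names, and bookkeeping differ. The idea is to measure how much of the accumulated gradient energy the leading frequency modes explain, and to grow the number of modes (by up to incremental_buffer, capped at the maximum) when the current modes do not reach the incremental_grad_eps threshold.

def incremental_modes(strength, current_modes, max_modes, eps=0.9999, buffer=5):
    # strength[k]: accumulated gradient magnitude of the k-th frequency mode
    explained = torch.cumsum(strength, dim=0) / strength.sum()
    needed = int((explained < eps).sum()) + 1   # leading modes needed to reach eps
    if needed >= current_modes:
        # current modes do not explain enough gradient energy: grow, bounded by the maximum
        return min(current_modes + buffer, max_modes)
    return current_modes

# Hypothetical usage with made-up gradient strengths for 8 frequencies:
strength = torch.tensor([5.0, 3.0, 1.0, 0.5, 0.2, 0.1, 0.05, 0.01])
print(incremental_modes(strength, current_modes=2, max_modes=8))  # grows to 7 here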

Training the incremental FNO model

We train the model using incremental meta-learning. The trainer will:

  1. Start with a small number of Fourier modes

  2. Gradually increase the model capacity based on gradient variance

  3. Monitor the incremental learning progress

  4. Evaluate on test data throughout training

trainer.train(
    train_loader,
    test_loaders,
    optimizer,
    scheduler,
    regularizer=False,
    training_loss=train_loss,
    eval_losses=eval_losses,
)
Training on 100 samples
Testing on [50, 50] samples         on resolutions [16, 32].
/opt/hostedtoolcache/Python/3.13.7/x64/lib/python3.13/site-packages/torch/utils/data/dataloader.py:668: UserWarning: 'pin_memory' argument is set as true but no accelerator is found, then device pinned memory won't be used.
  warnings.warn(warn_msg)
/opt/hostedtoolcache/Python/3.13.7/x64/lib/python3.13/site-packages/torch/nn/modules/module.py:1786: UserWarning: FNO.forward() received unexpected keyword arguments: ['y']. These arguments will be ignored.
  return forward_call(*args, **kwargs)
Raw outputs of shape torch.Size([16, 1, 8, 8])
/home/runner/work/neuraloperator/neuraloperator/neuralop/training/trainer.py:536: UserWarning: H1Loss.__call__() received unexpected keyword arguments: ['x']. These arguments will be ignored.
  loss += training_loss(out, **sample)
[0] time=0.23, avg_loss=0.9379, train_err=13.3989
/home/runner/work/neuraloperator/neuraloperator/neuralop/training/trainer.py:581: UserWarning: LpLoss.__call__() received unexpected keyword arguments: ['x']. These arguments will be ignored.
  val_loss = loss(out, **sample)
Eval: 16_h1=0.8951, 16_l2=0.5764, 32_h1=0.9397, 32_l2=0.5756
[1] time=0.21, avg_loss=0.8235, train_err=11.7637
Eval: 16_h1=0.8356, 16_l2=0.4704, 32_h1=1.0422, 32_l2=0.4983
[2] time=0.22, avg_loss=0.6751, train_err=9.6441
/home/runner/work/neuraloperator/neuraloperator/neuralop/training/incremental.py:244: UserWarning: Converting a tensor with requires_grad=True to a scalar may lead to unexpected behavior.
Consider using tensor.detach() first. (Triggered internally at /pytorch/torch/csrc/autograd/generated/python_variable_methods.cpp:836.)
  torch.Tensor(strength_vector),
Eval: 16_h1=0.8407, 16_l2=0.4864, 32_h1=1.0523, 32_l2=0.5230
[3] time=0.21, avg_loss=0.6172, train_err=8.8165
Eval: 16_h1=0.7064, 16_l2=0.3904, 32_h1=0.9171, 32_l2=0.4216
[4] time=0.22, avg_loss=0.5630, train_err=8.0428
Eval: 16_h1=0.7321, 16_l2=0.4232, 32_h1=0.9270, 32_l2=0.4512
[5] time=0.22, avg_loss=0.5213, train_err=7.4477
Eval: 16_h1=0.8455, 16_l2=0.4952, 32_h1=1.0422, 32_l2=0.5151
[6] time=0.22, avg_loss=0.5604, train_err=8.0060
Eval: 16_h1=0.8722, 16_l2=0.4394, 32_h1=1.2960, 32_l2=0.4788
[7] time=0.22, avg_loss=0.5087, train_err=7.2676
Eval: 16_h1=0.6518, 16_l2=0.3533, 32_h1=0.9138, 32_l2=0.3902
[8] time=0.22, avg_loss=0.4225, train_err=6.0350
Eval: 16_h1=0.7022, 16_l2=0.3866, 32_h1=0.9870, 32_l2=0.4312
[9] time=0.22, avg_loss=0.4486, train_err=6.4087
Eval: 16_h1=0.6325, 16_l2=0.3323, 32_h1=0.9428, 32_l2=0.3703
Incre Res Update: change index to 1
Incre Res Update: change sub to 1
Incre Res Update: change res to 16
[10] time=0.30, avg_loss=0.5505, train_err=7.8647
Eval: 16_h1=0.5113, 16_l2=0.2967, 32_h1=0.6253, 32_l2=0.2858
[11] time=0.29, avg_loss=0.4654, train_err=6.6486
Eval: 16_h1=0.5035, 16_l2=0.3235, 32_h1=0.6616, 32_l2=0.3382
[12] time=0.29, avg_loss=0.4690, train_err=6.7004
Eval: 16_h1=0.4196, 16_l2=0.2600, 32_h1=0.5457, 32_l2=0.2662
[13] time=0.29, avg_loss=0.3872, train_err=5.5309
Eval: 16_h1=0.3993, 16_l2=0.2503, 32_h1=0.5213, 32_l2=0.2664
[14] time=0.29, avg_loss=0.3615, train_err=5.1648
Eval: 16_h1=0.4214, 16_l2=0.2614, 32_h1=0.5556, 32_l2=0.2738
[15] time=0.30, avg_loss=0.3779, train_err=5.3988
Eval: 16_h1=0.3647, 16_l2=0.2222, 32_h1=0.4885, 32_l2=0.2301
[16] time=0.29, avg_loss=0.3327, train_err=4.7523
Eval: 16_h1=0.3624, 16_l2=0.2319, 32_h1=0.4977, 32_l2=0.2450
[17] time=0.29, avg_loss=0.3194, train_err=4.5629
Eval: 16_h1=0.3759, 16_l2=0.2348, 32_h1=0.4892, 32_l2=0.2361
[18] time=0.30, avg_loss=0.3482, train_err=4.9740
Eval: 16_h1=0.3483, 16_l2=0.2251, 32_h1=0.4770, 32_l2=0.2444
[19] time=0.29, avg_loss=0.3083, train_err=4.4047
Eval: 16_h1=0.3184, 16_l2=0.2044, 32_h1=0.4294, 32_l2=0.2149

{'train_err': 4.404683658054897, 'avg_loss': 0.3083278560638428, 'avg_lasso_loss': None, 'epoch_train_time': 0.29014379700004156, '16_h1': tensor(0.3184), '16_l2': tensor(0.2044), '32_h1': tensor(0.4294), '32_l2': tensor(0.2149)}

Visualizing incremental FNO predictions

We visualize the model’s predictions after incremental training. Note that we trained at a very low resolution and for only a few epochs; in practice, we would train at a higher resolution and on many more samples.

However, to keep things practical, this minimal example: i) fits in just a few MB of memory, and ii) can be trained quickly on a CPU.

In practice we would train a Neural Operator on one or multiple GPUs.

test_samples = test_loaders[32].dataset

fig = plt.figure(figsize=(7, 7))
for index in range(3):
    data = test_samples[index]
    # Input x
    x = data["x"].to(device)
    # Ground-truth
    y = data["y"].to(device)
    # Model prediction: incremental FNO output
    out = model(x.unsqueeze(0))

    # Plot input x
    ax = fig.add_subplot(3, 3, index * 3 + 1)
    x = x.cpu().squeeze().detach().numpy()
    y = y.cpu().squeeze().detach().numpy()
    ax.imshow(x, cmap="gray")
    if index == 0:
        ax.set_title("Input x")
    plt.xticks([], [])
    plt.yticks([], [])

    # Plot ground-truth y
    ax = fig.add_subplot(3, 3, index * 3 + 2)
    ax.imshow(y.squeeze())
    if index == 0:
        ax.set_title("Ground-truth y")
    plt.xticks([], [])
    plt.yticks([], [])

    # Plot model prediction
    ax = fig.add_subplot(3, 3, index * 3 + 3)
    ax.imshow(out.cpu().squeeze().detach().numpy())
    if index == 0:
        ax.set_title("Incremental FNO prediction")
    plt.xticks([], [])
    plt.yticks([], [])

fig.suptitle("Incremental FNO predictions on Darcy-Flow data", y=0.98)
plt.tight_layout()
fig.show()
[Figure: Incremental FNO predictions on Darcy-Flow data. Columns show the input x, the ground-truth y, and the incremental FNO prediction.]

Total running time of the script: (0 minutes 11.301 seconds)
