Training an FNO with incremental meta-learning

This example uses the small Darcy-Flow dataset shipped with the package to demonstrate the Incremental FNO meta-learning algorithm.

import torch
import matplotlib.pyplot as plt
import sys
from neuralop.models import FNO
from neuralop.data.datasets import load_darcy_flow_small
from neuralop.utils import count_model_params
from neuralop.training import AdamW
from neuralop.training.incremental import IncrementalFNOTrainer
from neuralop.data.transforms.data_processors import IncrementalDataProcessor
from neuralop import LpLoss, H1Loss

Loading the Darcy flow dataset

train_loader, test_loaders, output_encoder = load_darcy_flow_small(
    n_train=100,
    batch_size=16,
    test_resolutions=[16, 32],
    n_tests=[100, 50],
    test_batch_sizes=[32, 32],
)
Loading test db for resolution 16 with 100 samples
Loading test db for resolution 32 with 50 samples

Choose device

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Set up the incremental FNO model. We start with 2 modes in each dimension and choose to update the modes with the incremental gradient-explained algorithm.

incremental = True
if incremental:
    starting_modes = (2, 2)
else:
    starting_modes = (16, 16)
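
To give some intuition for what a small number of modes means, here is a minimal one-dimensional sketch of truncating a signal to its lowest Fourier frequencies. It is only an illustration written with the standard library: the helper names (`dft`, `inverse_dft`, `keep_lowest_modes`) are hypothetical, and the exact mode-counting convention used inside neuralop may differ.

```python
import cmath
import math


def dft(signal):
    """Naive discrete Fourier transform of a sequence."""
    n = len(signal)
    return [
        sum(signal[t] * cmath.exp(-2j * math.pi * k * t / n) for t in range(n))
        for k in range(n)
    ]


def inverse_dft(coeffs):
    """Naive inverse DFT, returning the real part of each sample."""
    n = len(coeffs)
    return [
        (sum(coeffs[k] * cmath.exp(2j * math.pi * k * t / n) for k in range(n)) / n).real
        for t in range(n)
    ]


def keep_lowest_modes(signal, n_modes):
    """Zero every Fourier coefficient whose frequency magnitude is >= n_modes."""
    n = len(signal)
    coeffs = dft(signal)
    for k in range(n):
        freq = min(k, n - k)  # frequency magnitude of bin k
        if freq >= n_modes:
            coeffs[k] = 0.0
    return inverse_dft(coeffs)


n_points = 16
slow = [math.cos(2 * math.pi * t / n_points) for t in range(n_points)]      # frequency 1
fast = [math.cos(2 * math.pi * 5 * t / n_points) for t in range(n_points)]  # frequency 5
mixed = [s + f for s, f in zip(slow, fast)]

# Keeping 2 modes (frequencies 0 and 1) removes the fast component.
filtered_few = keep_lowest_modes(mixed, n_modes=2)
# Keeping 6 modes (frequencies 0 to 5) preserves both components.
filtered_many = keep_lowest_modes(mixed, n_modes=6)

dev_few = max(abs(a - b) for a, b in zip(filtered_few, slow))
dev_many = max(abs(a - b) for a, b in zip(filtered_many, mixed))
```

With few modes only the slowly varying part of the signal is represented, which is why starting small and growing the number of modes during training can be a reasonable strategy.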

Set up the model

model = FNO(
    max_n_modes=(16, 16),
    n_modes=starting_modes,
    hidden_channels=32,
    in_channels=1,
    out_channels=1,
)
model = model.to(device)
n_params = count_model_params(model)

Set up the optimizer and scheduler

optimizer = AdamW(model.parameters(), lr=8e-3, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)
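
As a rough guide to what this scheduler does, the sketch below evaluates the closed-form cosine annealing rule with the values used above (`lr=8e-3`, `T_max=30`). It assumes a minimum learning rate of 0 and one scheduler step per epoch; the function name `cosine_annealing_lr` is made up for this illustration.

```python
import math


def cosine_annealing_lr(epoch, base_lr=8e-3, t_max=30, eta_min=0.0):
    """Closed-form cosine annealing: base_lr at epoch 0, eta_min at epoch t_max."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))


# Learning rate for each of the 20 training epochs used in this example.
schedule = [cosine_annealing_lr(epoch) for epoch in range(20)]
```

Because `T_max=30` is larger than the 20 training epochs, the learning rate under these assumptions is still well above zero when training stops.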


# To use incremental resolution, use the IncrementalDataProcessor.
# When it is passed to the trainer, the trainer automatically updates the resolution.
# Incremental_resolution : bool, default is False
#    if True, increase the resolution of the input incrementally
#    uses the incremental_res_gap parameter
#    uses the subsampling_rates parameter - a list of resolutions to use
#    uses the dataset_indices parameter - a list of indices of the dataset to slice to regularize the input resolution
#    uses the dataset_resolution parameter - the resolution of the input
#    uses the epoch_gap parameter - the number of epochs to wait before increasing the resolution
#    uses the verbose parameter - if True, print the resolution and the number of modes
data_transform = IncrementalDataProcessor(
    in_normalizer=None,
    out_normalizer=None,
    device=device,
    subsampling_rates=[2, 1],
    dataset_resolution=16,
    dataset_indices=[2, 3],
    epoch_gap=10,
    verbose=True,
)

data_transform = data_transform.to(device)
Original Incre Res: change index to 0
Original Incre Res: change sub to 2
Original Incre Res: change res to 8
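
The arithmetic behind these messages can be sketched as follows. This is a simplified model that is consistent with the printed output (resolution 16 subsampled by 2 gives 8), not the library's actual implementation, and `resolution_schedule` is a hypothetical helper.

```python
def resolution_schedule(epoch, subsampling_rates, dataset_resolution, epoch_gap):
    """Return (index, subsampling rate, resolution) assumed to be active at `epoch`."""
    # Assumption: move to the next rate every `epoch_gap` epochs, then stay on the last one.
    index = min(epoch // epoch_gap, len(subsampling_rates) - 1)
    rate = subsampling_rates[index]
    return index, rate, dataset_resolution // rate


# Settings taken from the IncrementalDataProcessor call above.
settings = dict(subsampling_rates=[2, 1], dataset_resolution=16, epoch_gap=10)
stages = {epoch: resolution_schedule(epoch, **settings) for epoch in (0, 9, 10, 19)}
```

Under these assumptions the first 10 epochs run at resolution 8 and the remaining epochs at the full resolution of 16.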

Set up the losses

l2loss = LpLoss(d=2, p=2)
h1loss = H1Loss(d=2)
train_loss = h1loss
eval_losses = {"h1": h1loss, "l2": l2loss}
print("\n### N PARAMS ###\n", n_params)
print("\n### OPTIMIZER ###\n", optimizer)
print("\n### SCHEDULER ###\n", scheduler)
print("\n### LOSSES ###")
print("\n### INCREMENTAL RESOLUTION + GRADIENT EXPLAINED ###")
print(f"\n * Train: {train_loss}")
print(f"\n * Test: {eval_losses}")
sys.stdout.flush()
### N PARAMS ###
 2110305

### OPTIMIZER ###
 AdamW (
Parameter Group 0
    betas: (0.9, 0.999)
    correct_bias: True
    eps: 1e-06
    initial_lr: 0.008
    lr: 0.008
    weight_decay: 0.0001
)

### SCHEDULER ###
 <torch.optim.lr_scheduler.CosineAnnealingLR object at 0x7f55e4bfba00>

### LOSSES ###

### INCREMENTAL RESOLUTION + GRADIENT EXPLAINED ###

 * Train: <neuralop.losses.data_losses.H1Loss object at 0x7f55e4bfb3a0>

 * Test: {'h1': <neuralop.losses.data_losses.H1Loss object at 0x7f55e4bfb3a0>, 'l2': <neuralop.losses.data_losses.LpLoss object at 0x7f55e4bfb9a0>}
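
For readers unfamiliar with these metrics, the sketch below shows a common definition of a relative L2 error on a flattened sample. It assumes the `l2` metric is a relative error of this kind; the batch reduction and the derivative terms that an H1 (Sobolev) loss adds are left out, and `relative_l2_error` is a hypothetical helper rather than part of neuralop.

```python
import math


def relative_l2_error(prediction, target):
    """Norm of (prediction - target) divided by the norm of target."""
    diff_norm = math.sqrt(sum((p - t) ** 2 for p, t in zip(prediction, target)))
    target_norm = math.sqrt(sum(t ** 2 for t in target))
    return diff_norm / target_norm


target = [3.0, 4.0]      # norm 5
prediction = [3.0, 1.0]  # differs from the target by (0, -3), norm 3
error = relative_l2_error(prediction, target)  # 3 / 5
```

A relative error does not depend on the scale of the target, which makes values at different resolutions easier to compare.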

Set up the IncrementalFNOTrainer. Other options include setting incremental_loss_gap = True. If one wants to use incremental resolution, set it to True. In this example we only update the modes and not the resolution. When using incremental resolution, keep in mind that the number of modes initially set should be strictly less than the resolution. Again, these are the parameters for the various incremental settings:

incremental_grad : bool, default is False

    if True, use the base incremental algorithm, which is based on gradient variance
    uses the incremental_grad_eps parameter - sets the threshold for gradient variance
    uses the incremental_buffer parameter - sets the number of buffer modes used to calculate the gradient variance
    uses the incremental_max_iter parameter - sets the initial number of iterations
    uses the incremental_grad_max_iter parameter - sets the maximum number of iterations over which to accumulate the gradients

incremental_loss_gap : bool, default is False

    if True, use the incremental algorithm based on loss gap
    uses the incremental_loss_eps parameter
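
One way to picture how a threshold such as incremental_grad_eps and a buffer such as incremental_buffer could interact is sketched below. This is a simplified illustration under stated assumptions, not the trainer's code: the per-mode strengths are invented numbers, and `explained_ratio` and `next_n_modes` are hypothetical helpers.

```python
def explained_ratio(strengths, n_modes):
    """Fraction of the total squared strength carried by the first n_modes entries."""
    total = sum(s ** 2 for s in strengths)
    if total == 0.0:
        return 1.0
    kept = max(n_modes, 0)  # guard against a buffer larger than the mode count
    return sum(s ** 2 for s in strengths[:kept]) / total


def next_n_modes(strengths, current_modes, max_modes, eps, buffer):
    """Add one mode when the non-buffer modes explain less than eps of the strength."""
    ratio = explained_ratio(strengths, current_modes - buffer)
    if ratio < eps and current_modes < max_modes:
        return current_modes + 1
    return current_modes


# Invented per-mode gradient strengths for 4 active modes, with 1 buffer mode.
strengths = [4.0, 2.0, 1.0, 0.5]
grown = next_n_modes(strengths, current_modes=4, max_modes=16, eps=0.99, buffer=1)
unchanged = next_n_modes(strengths, current_modes=4, max_modes=16, eps=0.90, buffer=1)
```

In this sketch the first three modes explain 21 / 21.25 of the squared strength, so a threshold of 0.99 adds a mode while a threshold of 0.90 does not. A threshold close to 1, like the 0.9999 used in this example, would under the same assumptions make growth more likely.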

# Finally pass all of these to the Trainer
trainer = IncrementalFNOTrainer(
    model=model,
    n_epochs=20,
    data_processor=data_transform,
    device=device,
    verbose=True,
    incremental_loss_gap=False,
    incremental_grad=True,
    incremental_grad_eps=0.9999,
    incremental_loss_eps=0.001,
    incremental_buffer=5,
    incremental_max_iter=1,
    incremental_grad_max_iter=2,
)

Train the model

trainer.train(
    train_loader,
    test_loaders,
    optimizer,
    scheduler,
    regularizer=False,
    training_loss=train_loss,
    eval_losses=eval_losses,
)
Training on 100 samples
Testing on [50, 50] samples         on resolutions [16, 32].
Raw outputs of shape torch.Size([16, 1, 8, 8])
[0] time=0.24, avg_loss=0.8115, train_err=11.5929
Eval: 16_h1=0.7332, 16_l2=0.5739, 32_h1=0.7863, 32_l2=0.5720
[1] time=0.23, avg_loss=0.6661, train_err=9.5159
Eval: 16_h1=0.8889, 16_l2=0.7005, 32_h1=1.1195, 32_l2=0.7407
[2] time=0.23, avg_loss=0.6497, train_err=9.2815
Eval: 16_h1=0.6372, 16_l2=0.4883, 32_h1=0.6967, 32_l2=0.5000
[3] time=0.23, avg_loss=0.5559, train_err=7.9411
Eval: 16_h1=0.6112, 16_l2=0.4348, 32_h1=0.7432, 32_l2=0.4530
[4] time=0.23, avg_loss=0.4852, train_err=6.9312
Eval: 16_h1=0.5762, 16_l2=0.4037, 32_h1=0.7138, 32_l2=0.4262
[5] time=0.24, avg_loss=0.4393, train_err=6.2764
Eval: 16_h1=0.5515, 16_l2=0.3826, 32_h1=0.7143, 32_l2=0.4146
[6] time=0.24, avg_loss=0.4039, train_err=5.7703
Eval: 16_h1=0.5421, 16_l2=0.3832, 32_h1=0.7289, 32_l2=0.4221
[7] time=0.23, avg_loss=0.3626, train_err=5.1807
Eval: 16_h1=0.5418, 16_l2=0.3902, 32_h1=0.7402, 32_l2=0.4312
[8] time=0.23, avg_loss=0.3563, train_err=5.0894
Eval: 16_h1=0.5598, 16_l2=0.3874, 32_h1=0.7716, 32_l2=0.4260
[9] time=0.23, avg_loss=0.3525, train_err=5.0354
Eval: 16_h1=0.4780, 16_l2=0.3197, 32_h1=0.6826, 32_l2=0.3563
Incre Res Update: change index to 1
Incre Res Update: change sub to 1
Incre Res Update: change res to 16
[10] time=0.29, avg_loss=0.4253, train_err=6.0757
Eval: 16_h1=0.4069, 16_l2=0.2959, 32_h1=0.4904, 32_l2=0.2928
[11] time=0.28, avg_loss=0.3745, train_err=5.3500
Eval: 16_h1=0.3820, 16_l2=0.2869, 32_h1=0.4769, 32_l2=0.3026
[12] time=0.29, avg_loss=0.3405, train_err=4.8636
Eval: 16_h1=0.3404, 16_l2=0.2598, 32_h1=0.4410, 32_l2=0.2731
[13] time=0.29, avg_loss=0.3090, train_err=4.4136
Eval: 16_h1=0.3231, 16_l2=0.2452, 32_h1=0.4245, 32_l2=0.2586
[14] time=0.29, avg_loss=0.2896, train_err=4.1368
Eval: 16_h1=0.3130, 16_l2=0.2380, 32_h1=0.4161, 32_l2=0.2522
[15] time=0.29, avg_loss=0.2789, train_err=3.9843
Eval: 16_h1=0.3072, 16_l2=0.2324, 32_h1=0.4151, 32_l2=0.2455
[16] time=0.29, avg_loss=0.2690, train_err=3.8434
Eval: 16_h1=0.3042, 16_l2=0.2305, 32_h1=0.4100, 32_l2=0.2425
[17] time=0.29, avg_loss=0.2637, train_err=3.7674
Eval: 16_h1=0.2954, 16_l2=0.2229, 32_h1=0.4023, 32_l2=0.2354
[18] time=0.29, avg_loss=0.2557, train_err=3.6533
Eval: 16_h1=0.2756, 16_l2=0.2105, 32_h1=0.3780, 32_l2=0.2269
[19] time=0.29, avg_loss=0.2395, train_err=3.4208
Eval: 16_h1=0.2735, 16_l2=0.2106, 32_h1=0.3738, 32_l2=0.2303

{'train_err': 3.420767681939261, 'avg_loss': 0.2394537377357483, 'avg_lasso_loss': None, 'epoch_train_time': 0.2902170739999974, '16_h1': tensor(0.2735), '16_l2': tensor(0.2106), '32_h1': tensor(0.3738), '32_l2': tensor(0.2303)}

Plot the prediction and compare it with the ground truth. Note that we trained at a very small resolution for a very small number of epochs. In practice, we would train at a larger resolution, on many more samples.

However, for practicality, we created a minimal example that i) fits in just a few MB of memory and ii) can be trained quickly on CPU.

In practice, we would train a Neural Operator on one or multiple GPUs.

test_samples = test_loaders[32].dataset

fig = plt.figure(figsize=(7, 7))
for index in range(3):
    data = test_samples[index]
    # Input x
    x = data["x"].to(device)
    # Ground-truth
    y = data["y"].to(device)
    # Model prediction
    out = model(x.unsqueeze(0))
    ax = fig.add_subplot(3, 3, index * 3 + 1)
    x = x.cpu().squeeze().detach().numpy()
    y = y.cpu().squeeze().detach().numpy()
    ax.imshow(x, cmap="gray")
    if index == 0:
        ax.set_title("Input x")
    plt.xticks([], [])
    plt.yticks([], [])

    ax = fig.add_subplot(3, 3, index * 3 + 2)
    ax.imshow(y.squeeze())
    if index == 0:
        ax.set_title("Ground-truth y")
    plt.xticks([], [])
    plt.yticks([], [])

    ax = fig.add_subplot(3, 3, index * 3 + 3)
    ax.imshow(out.cpu().squeeze().detach().numpy())
    if index == 0:
        ax.set_title("Model prediction")
    plt.xticks([], [])
    plt.yticks([], [])

fig.suptitle("Inputs, ground-truth output and prediction.", y=0.98)
plt.tight_layout()
fig.show()
[Figure: Inputs, ground-truth output and prediction. Columns: Input x, Ground-truth y, Model prediction.]

Total running time of the script: (0 minutes 7.205 seconds)
