Training an FNO on Darcy-Flow

We train a Fourier Neural Operator (FNO) on our small Darcy-Flow example.

Note that this dataset is much smaller than one we would use in practice. The small Darcy-flow dataset is built so that each epoch trains on a CPU in a few seconds, whereas normally we would train on one or more GPUs.

import torch
import matplotlib.pyplot as plt
import sys
from neuralop.models import FNO
from neuralop import Trainer
from neuralop.training import AdamW
from neuralop.data.datasets import load_darcy_flow_small
from neuralop.utils import count_model_params
from neuralop import LpLoss, H1Loss

device = 'cpu'

Let’s load the small Darcy-flow dataset.

train_loader, test_loaders, data_processor = load_darcy_flow_small(
        n_train=1000, batch_size=32,
        test_resolutions=[16, 32], n_tests=[100, 50],
        test_batch_sizes=[32, 32],
)
data_processor = data_processor.to(device)
Loading test db for resolution 16 with 100 samples
Loading test db for resolution 32 with 50 samples

We create a simple FNO model:

model = FNO(n_modes=(16, 16),
             in_channels=1,
             out_channels=1,
             hidden_channels=32,
             projection_channel_ratio=2)
model = model.to(device)

n_params = count_model_params(model)
print(f'\nOur model has {n_params} parameters.')
sys.stdout.flush()
Our model has 1192801 parameters.
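
The printed total can be reproduced by hand from the layer shapes in the model summary printed in the training section. The tally below is a sketch under two assumptions that this tutorial does not state: each complex-valued spectral weight counts as two real parameters, and each of the four spectral convolutions and each of the four soft-gating skips carries one extra vector of 32 per-channel parameters (these do not appear in the printed summary).

```python
# Hand tally of the FNO parameters from the layer shapes in the model summary.
# Assumptions (not confirmed by this tutorial): complex weights count twice,
# and SpectralConv / SoftGating each hold one 32-entry per-channel vector.
hidden, n_layers = 32, 4

def conv1d_params(c_in, c_out, bias=True):
    """Parameters of a kernel-size-1 Conv1d."""
    return c_in * c_out + (c_out if bias else 0)

# SpectralConv weight has shape [32, 32, 16, 9]; the 9 equals 16 // 2 + 1,
# which is consistent with a real-input FFT along the last axis.
spectral_weights = n_layers * (hidden * hidden * 16 * 9) * 2  # complex -> x2
spectral_bias = n_layers * hidden            # assumed per-channel bias
fno_skips = n_layers * conv1d_params(hidden, hidden, bias=False)
channel_mlp = n_layers * (conv1d_params(32, 16) + conv1d_params(16, 32))
soft_gating = n_layers * hidden              # assumed per-channel gate
lifting = conv1d_params(3, 64) + conv1d_params(64, 32)
projection = conv1d_params(32, 64) + conv1d_params(64, 1)

total = (spectral_weights + spectral_bias + fno_skips + channel_mlp
         + soft_gating + lifting + projection)
print(total)
```

Under these assumptions the tally matches the printed count, and the spectral weights account for roughly 99% of it, so n_modes and hidden_channels are the settings that drive model size.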

Training setup

Create the optimizer and the learning-rate scheduler:

optimizer = AdamW(model.parameters(),
                  lr=8e-3,
                  weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)
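
To see what this schedule does over the run, the closed-form cosine annealing rule can be evaluated directly. This is a minimal sketch, assuming the scheduler is stepped once per epoch and the minimum learning rate is left at 0; the helper name cosine_annealing_lr is hypothetical.

```python
import math

def cosine_annealing_lr(epoch, base_lr=8e-3, t_max=30, eta_min=0.0):
    """Closed-form cosine annealing: learning rate after `epoch` steps."""
    cosine = math.cos(math.pi * epoch / t_max)
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + cosine)

# Values at a few epochs; the Trainer in this example runs for 20 epochs.
for epoch in (0, 10, 20, 30):
    print(epoch, round(cosine_annealing_lr(epoch), 6))
```

Because T_max=30 is longer than the 20 training epochs used in this example, the learning rate under these assumptions only decays to a quarter of its initial value by the end of training rather than all the way to zero.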

Then create the losses:

l2loss = LpLoss(d=2, p=2)
h1loss = H1Loss(d=2)

train_loss = h1loss
eval_losses={'h1': h1loss, 'l2': l2loss}
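
The difference between the two losses is easiest to see on a toy problem. The sketch below is a 1D, pure-Python illustration and not neuralop's implementation: it assumes both losses are relative errors, and that the H1 error adds the squared finite-difference derivative of the error to the squared error itself. The helper names are hypothetical.

```python
import math

def rel_l2(pred, true):
    """Relative L2 error: ||pred - true|| / ||true||."""
    num = sum((p - t) ** 2 for p, t in zip(pred, true))
    den = sum(t ** 2 for t in true)
    return math.sqrt(num / den)

def central_diff(u, h):
    """Central finite difference on a periodic grid with spacing h."""
    n = len(u)
    return [(u[(i + 1) % n] - u[i - 1]) / (2 * h) for i in range(n)]

def rel_h1(pred, true, h):
    """Relative H1-style error: values and first derivatives both count."""
    err = [p - t for p, t in zip(pred, true)]
    d_err, d_true = central_diff(err, h), central_diff(true, h)
    num = sum(e ** 2 for e in err) + sum(e ** 2 for e in d_err)
    den = sum(t ** 2 for t in true) + sum(t ** 2 for t in d_true)
    return math.sqrt(num / den)

n = 64
h = 1.0 / n
grid = [i * h for i in range(n)]
true = [math.sin(2 * math.pi * t) for t in grid]
# Prediction = truth plus a small, higher-frequency wiggle.
pred = [y + 0.1 * math.sin(2 * math.pi * 5 * t) for y, t in zip(true, grid)]

l2_err = rel_l2(pred, true)
h1_err = rel_h1(pred, true, h)
print(f"relative L2: {l2_err:.3f}, relative H1: {h1_err:.3f}")
```

The small high-frequency wiggle barely moves the L2 error but is penalized much more heavily by the H1-style error, which is one reason to train on a derivative-aware loss while still reporting both.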

Training the model

print('\n### MODEL ###\n', model)
print('\n### OPTIMIZER ###\n', optimizer)
print('\n### SCHEDULER ###\n', scheduler)
print('\n### LOSSES ###')
print(f'\n * Train: {train_loss}')
print(f'\n * Test: {eval_losses}')
sys.stdout.flush()
### MODEL ###
 FNO(
  (positional_embedding): GridEmbeddingND()
  (fno_blocks): FNOBlocks(
    (convs): ModuleList(
      (0-3): 4 x SpectralConv(
        (weight): DenseTensor(shape=torch.Size([32, 32, 16, 9]), rank=None)
      )
    )
    (fno_skips): ModuleList(
      (0-3): 4 x Flattened1dConv(
        (conv): Conv1d(32, 32, kernel_size=(1,), stride=(1,), bias=False)
      )
    )
    (channel_mlp): ModuleList(
      (0-3): 4 x ChannelMLP(
        (fcs): ModuleList(
          (0): Conv1d(32, 16, kernel_size=(1,), stride=(1,))
          (1): Conv1d(16, 32, kernel_size=(1,), stride=(1,))
        )
      )
    )
    (channel_mlp_skips): ModuleList(
      (0-3): 4 x SoftGating()
    )
  )
  (lifting): ChannelMLP(
    (fcs): ModuleList(
      (0): Conv1d(3, 64, kernel_size=(1,), stride=(1,))
      (1): Conv1d(64, 32, kernel_size=(1,), stride=(1,))
    )
  )
  (projection): ChannelMLP(
    (fcs): ModuleList(
      (0): Conv1d(32, 64, kernel_size=(1,), stride=(1,))
      (1): Conv1d(64, 1, kernel_size=(1,), stride=(1,))
    )
  )
)

### OPTIMIZER ###
 AdamW (
Parameter Group 0
    betas: (0.9, 0.999)
    correct_bias: True
    eps: 1e-06
    initial_lr: 0.008
    lr: 0.008
    weight_decay: 0.0001
)

### SCHEDULER ###
 <torch.optim.lr_scheduler.CosineAnnealingLR object at 0x7fcedb0491d0>

### LOSSES ###

 * Train: <neuralop.losses.data_losses.H1Loss object at 0x7fcedb048410>

 * Test: {'h1': <neuralop.losses.data_losses.H1Loss object at 0x7fcedb048410>, 'l2': <neuralop.losses.data_losses.LpLoss object at 0x7fcedb04b110>}

Create the trainer:

trainer = Trainer(model=model, n_epochs=20,
                  device=device,
                  data_processor=data_processor,
                  wandb_log=False,
                  eval_interval=3,
                  use_distributed=False,
                  verbose=True)

Then train the model on our small Darcy-Flow dataset:

trainer.train(train_loader=train_loader,
              test_loaders=test_loaders,
              optimizer=optimizer,
              scheduler=scheduler,
              regularizer=False,
              training_loss=train_loss,
              eval_losses=eval_losses)
Training on 1000 samples
Testing on [50, 50] samples         on resolutions [16, 32].
Raw outputs of shape torch.Size([32, 1, 16, 16])
[0] time=2.48, avg_loss=0.5903, train_err=18.4462
Eval: 16_h1=0.3639, 16_l2=0.2512, 32_h1=0.5425, 32_l2=0.2628
[3] time=2.54, avg_loss=0.2077, train_err=6.4899
Eval: 16_h1=0.2191, 16_l2=0.1505, 32_h1=0.4976, 32_l2=0.1840
[6] time=2.47, avg_loss=0.1768, train_err=5.5240
Eval: 16_h1=0.2902, 16_l2=0.2180, 32_h1=0.5098, 32_l2=0.2476
[9] time=2.47, avg_loss=0.1399, train_err=4.3722
Eval: 16_h1=0.2089, 16_l2=0.1298, 32_h1=0.5057, 32_l2=0.1746
[12] time=2.50, avg_loss=0.1350, train_err=4.2196
Eval: 16_h1=0.2000, 16_l2=0.1225, 32_h1=0.4881, 32_l2=0.1636
[15] time=2.48, avg_loss=0.1203, train_err=3.7584
Eval: 16_h1=0.2045, 16_l2=0.1263, 32_h1=0.5199, 32_l2=0.1752
[18] time=2.50, avg_loss=0.0929, train_err=2.9023
Eval: 16_h1=0.1954, 16_l2=0.1193, 32_h1=0.5096, 32_l2=0.1704

{'train_err': 2.8530129194259644, 'avg_loss': 0.09129641342163086, 'avg_lasso_loss': None, 'epoch_train_time': 2.4829514150000023}
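
The log reports both avg_loss and train_err, and the two differ by a constant factor. The printed numbers are consistent with avg_loss being normalized per sample and train_err per batch; this is an inference from the output above, not documented behavior, and it assumes the last partial batch is kept.

```python
import math

n_train, batch_size = 1000, 32
n_batches = math.ceil(n_train / batch_size)   # 32 batches, the last one partial
ratio = n_train / n_batches                   # samples per batch on average

# Final metrics copied from the dictionary returned by trainer.train above.
avg_loss = 0.09129641342163086
train_err = 2.8530129194259644

print(ratio, avg_loss * ratio, train_err)
```

If that reading is right, train_err is simply avg_loss scaled by the average batch size, so the two curves carry the same information.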

Visualizing predictions

Let’s take a look at our model’s predictions. Note again that in this example we train at a very low resolution, on few samples, and for very few epochs. In practice, we would train at a higher resolution and on many more samples.

test_samples = test_loaders[16].dataset

fig = plt.figure(figsize=(7, 7))
for index in range(3):
    data = test_samples[index]
    data = data_processor.preprocess(data, batched=False)
    # Input x
    x = data['x']
    # Ground-truth
    y = data['y']
    # Model prediction
    out = model(x.unsqueeze(0))

    ax = fig.add_subplot(3, 3, index*3 + 1)
    ax.imshow(x[0], cmap='gray')
    if index == 0:
        ax.set_title('Input x')
    plt.xticks([], [])
    plt.yticks([], [])

    ax = fig.add_subplot(3, 3, index*3 + 2)
    ax.imshow(y.squeeze())
    if index == 0:
        ax.set_title('Ground-truth y')
    plt.xticks([], [])
    plt.yticks([], [])

    ax = fig.add_subplot(3, 3, index*3 + 3)
    ax.imshow(out.squeeze().detach().numpy())
    if index == 0:
        ax.set_title('Model prediction')
    plt.xticks([], [])
    plt.yticks([], [])

fig.suptitle('Inputs, ground-truth output and prediction (16x16).', y=0.98)
plt.tight_layout()
fig.show()
Figure: Inputs, ground-truth output and prediction (16x16). Panels: Input x, Ground-truth y, Model prediction.

Zero-shot super-evaluation

Beyond making predictions at the resolution it was trained on, the FNO is invariant to the discretization of its input, which means we can directly make predictions on higher-resolution inputs and get higher-resolution outputs.
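
The idea can be illustrated with a 1D toy layer that acts only on the lowest Fourier modes. This is a sketch under stated assumptions and not neuralop's SpectralConv: the input is band-limited, the transform is a naive DFT, and the weights are hypothetical stand-ins for learned parameters. The same weights are applied to a 16-point and a 32-point sampling of one function.

```python
import cmath
import math

def fourier_coeffs(samples, n_modes):
    """First n_modes DFT coefficients, normalized by the number of samples."""
    n = len(samples)
    return [sum(s * cmath.exp(-2j * math.pi * k * i / n)
                for i, s in enumerate(samples)) / n
            for k in range(n_modes)]

def spectral_layer(samples, weights):
    """Scale the lowest modes by `weights`, then evaluate on the input grid."""
    n = len(samples)
    coeffs = fourier_coeffs(samples, len(weights))
    out = []
    for i in range(n):
        value = (weights[0] * coeffs[0]).real
        for k in range(1, len(weights)):
            # Factor 2 accounts for the conjugate (negative-frequency) mode.
            value += 2 * (weights[k] * coeffs[k]
                          * cmath.exp(2j * math.pi * k * i / n)).real
        out.append(value)
    return out

def signal(t):
    """Band-limited test function on the unit interval."""
    return math.sin(2 * math.pi * t) + 0.5 * math.cos(2 * math.pi * 3 * t)

weights = [0.5, 1.0 - 0.5j, 0.3j, -0.8]     # hypothetical "learned" weights
coarse = spectral_layer([signal(i / 16) for i in range(16)], weights)
fine = spectral_layer([signal(i / 32) for i in range(32)], weights)

# Every coarse grid point coincides with every second fine grid point.
max_gap = max(abs(c - f) for c, f in zip(coarse, fine[::2]))
print(f"max difference at shared points: {max_gap:.2e}")
```

Because the weights live on Fourier modes rather than on grid points, the layer defines the same function at both resolutions. Real data is not band-limited and a full FNO has other components, so agreement across resolutions is only approximate in practice.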

test_samples = test_loaders[32].dataset

fig = plt.figure(figsize=(7, 7))
for index in range(3):
    data = test_samples[index]
    data = data_processor.preprocess(data, batched=False)
    # Input x
    x = data['x']
    # Ground-truth
    y = data['y']
    # Model prediction
    out = model(x.unsqueeze(0))

    ax = fig.add_subplot(3, 3, index*3 + 1)
    ax.imshow(x[0], cmap='gray')
    if index == 0:
        ax.set_title('Input x')
    plt.xticks([], [])
    plt.yticks([], [])

    ax = fig.add_subplot(3, 3, index*3 + 2)
    ax.imshow(y.squeeze())
    if index == 0:
        ax.set_title('Ground-truth y')
    plt.xticks([], [])
    plt.yticks([], [])

    ax = fig.add_subplot(3, 3, index*3 + 3)
    ax.imshow(out.squeeze().detach().numpy())
    if index == 0:
        ax.set_title('Model prediction')
    plt.xticks([], [])
    plt.yticks([], [])

fig.suptitle('Inputs, ground-truth output and prediction (32x32).', y=0.98)
plt.tight_layout()
fig.show()
Figure: Inputs, ground-truth output and prediction (32x32). Panels: Input x, Ground-truth y, Model prediction.

We only trained the model on data at a resolution of 16x16, and with no modifications or special prompting, we were able to perform inference on higher-resolution input data and get higher-resolution predictions! In practice, we often want to evaluate neural operators at multiple resolutions to track a model’s zero-shot super-evaluation performance throughout training. That’s why many of our datasets, including the small Darcy-flow dataset showcased here, are parameterized with a list of test_resolutions to choose from.

However, as you can see, these predictions are noisier than we would expect from a model evaluated at the resolution at which it was trained. Other approaches build on the FNO’s discretization invariance to scale its outputs and train a true super-resolution capability.
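
The gap can also be read off the training log. The snippet below is a small sketch that parses the last evaluation line printed during training (epoch 18) and compares the two test resolutions; the regular expression assumes the "resolution_metric=value" format shown in that log.

```python
import re

# Last evaluation line copied from the training log above (epoch 18).
log_line = "Eval: 16_h1=0.1954, 16_l2=0.1193, 32_h1=0.5096, 32_l2=0.1704"

# Group the values as metrics[metric_name][resolution].
metrics = {}
for res, name, value in re.findall(r"(\d+)_(\w+)=([\d.]+)", log_line):
    metrics.setdefault(name, {})[int(res)] = float(value)

for name, by_res in metrics.items():
    factor = by_res[32] / by_res[16]
    print(f"{name}: 32x32 error is {factor:.2f}x the 16x16 error")
```

For this run the H1 error at 32x32 is about 2.6 times the 16x16 value, while the L2 error grows by about 1.4 times, which suggests that most of the degradation is in the derivatives of the prediction.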

Total running time of the script: (0 minutes 51.037 seconds)
