U-NO on Darcy-Flow

Training a U-shaped Neural Operator (U-NO) on the small Darcy-Flow example that ships with the package.

import torch
import matplotlib.pyplot as plt
import sys
from neuralop.models import UNO
from neuralop import Trainer
from neuralop.training import AdamW
from neuralop.data.datasets import load_darcy_flow_small
from neuralop.utils import count_model_params
from neuralop import LpLoss, H1Loss

device = 'cpu'

Loading the Darcy-Flow dataset

train_loader, test_loaders, data_processor = load_darcy_flow_small(
    n_train=1000, batch_size=32,
    test_resolutions=[16, 32], n_tests=[100, 50],
    test_batch_sizes=[32, 32],
)
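
Before building the model, it helps to peek at a single batch to confirm the tensor layout. A minimal sanity check, assuming (as the data processor below expects) that each batch is a dict with 'x' and 'y' entries:

# Inspect one training batch; the shapes below are what we expect for this dataset.
sample_batch = next(iter(train_loader))
print(sample_batch['x'].shape)  # expected: torch.Size([32, 1, 16, 16])
print(sample_batch['y'].shape)  # expected: torch.Size([32, 1, 16, 16])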

model = UNO(in_channels=1,
            out_channels=1,
            hidden_channels=64,
            projection_channels=64,
            uno_out_channels=[32, 64, 64, 64, 32],
            uno_n_modes=[[16, 16], [8, 8], [8, 8], [8, 8], [16, 16]],
            uno_scalings=[[1.0, 1.0], [0.5, 0.5], [1, 1], [2, 2], [1, 1]],
            horizontal_skips_map=None,
            channel_mlp_skip="linear",
            n_layers=5,
            domain_padding=0.2)
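
The uno_scalings entries are per-layer resolution scaling factors, and they are what gives the network its U shape: the spatial resolution is contracted in the second layer and expanded back in the fourth. A quick sketch of the cumulative effect on a 16 x 16 input (illustrative only; it ignores the extra resolution added by domain_padding and simply assumes each factor multiplies the previous layer's output resolution):

# Illustrative: track how each layer's scaling factors rescale a 16 x 16 input.
res = [16, 16]
for i, (sx, sy) in enumerate([[1.0, 1.0], [0.5, 0.5], [1, 1], [2, 2], [1, 1]]):
    res = [int(res[0] * sx), int(res[1] * sy)]
    print(f"layer {i}: {res[0]} x {res[1]}")
# -> 16x16, 8x8, 8x8, 16x16, 16x16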

model = model.to(device)

n_params = count_model_params(model)
print(f'\nOur model has {n_params} parameters.')
sys.stdout.flush()
Loading test db for resolution 16 with 100 samples
Loading test db for resolution 32 with 50 samples
fno_skip='linear'
channel_mlp_skip='linear'
(the two lines above are printed once per layer; five times in total)

Our model has 2700097 parameters.

Create the optimizer and scheduler

optimizer = AdamW(model.parameters(),
                  lr=8e-3,
                  weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)
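
CosineAnnealingLR decays the learning rate from its initial value toward eta_min (0 by default) along half a cosine cycle: lr_t = eta_min + 0.5 * (lr_0 - eta_min) * (1 + cos(pi * t / T_max)). A small sketch that previews the schedule with a throwaway optimizer, so the real training state is left untouched:

# Preview the cosine-annealed learning rate over the 30 scheduler steps.
_opt = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=8e-3)
_sched = torch.optim.lr_scheduler.CosineAnnealingLR(_opt, T_max=30)
for step in range(30):
    if step % 10 == 0:
        print(f"step {step}: lr = {_sched.get_last_lr()[0]:.5f}")
    _opt.step()   # scheduler.step() should follow optimizer.step()
    _sched.step()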

Create the losses

l2loss = LpLoss(d=2, p=2)
h1loss = H1Loss(d=2)

train_loss = h1loss
eval_losses = {'h1': h1loss, 'l2': l2loss}
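
LpLoss here measures the relative Lp error over the 2D domain, i.e. ||pred - y||_p / ||y||_p, and H1Loss additionally penalizes the error in first derivatives, so it rewards predictions that are accurate and spatially smooth. As a rough illustration of the L2 case (a hand-rolled sketch, not the library's exact reduction or quadrature weighting):

# Hand-rolled relative L2 error for a single (pred, truth) pair; illustrative only.
def relative_l2(pred, truth):
    return torch.linalg.vector_norm(pred - truth) / torch.linalg.vector_norm(truth)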
print('\n### MODEL ###\n', model)
print('\n### OPTIMIZER ###\n', optimizer)
print('\n### SCHEDULER ###\n', scheduler)
print('\n### LOSSES ###')
print(f'\n * Train: {train_loss}')
print(f'\n * Test: {eval_losses}')
sys.stdout.flush()
### MODEL ###
 UNO(
  (positional_embedding): GridEmbeddingND()
  (domain_padding): DomainPadding()
  (lifting): ChannelMLP(
    (fcs): ModuleList(
      (0): Conv1d(3, 256, kernel_size=(1,), stride=(1,))
      (1): Conv1d(256, 64, kernel_size=(1,), stride=(1,))
    )
  )
  (fno_blocks): ModuleList(
    (0): FNOBlocks(
      (convs): ModuleList(
        (0): SpectralConv(
          (weight): DenseTensor(shape=torch.Size([64, 32, 16, 9]), rank=None)
        )
      )
      (fno_skips): ModuleList(
        (0): Flattened1dConv(
          (conv): Conv1d(64, 32, kernel_size=(1,), stride=(1,), bias=False)
        )
      )
      (channel_mlp): ModuleList(
        (0): ChannelMLP(
          (fcs): ModuleList(
            (0): Conv1d(32, 16, kernel_size=(1,), stride=(1,))
            (1): Conv1d(16, 32, kernel_size=(1,), stride=(1,))
          )
        )
      )
      (channel_mlp_skips): ModuleList(
        (0): Flattened1dConv(
          (conv): Conv1d(64, 32, kernel_size=(1,), stride=(1,), bias=False)
        )
      )
    )
    (1): FNOBlocks(
      (convs): ModuleList(
        (0): SpectralConv(
          (weight): DenseTensor(shape=torch.Size([32, 64, 8, 5]), rank=None)
        )
      )
      (fno_skips): ModuleList(
        (0): Flattened1dConv(
          (conv): Conv1d(32, 64, kernel_size=(1,), stride=(1,), bias=False)
        )
      )
      (channel_mlp): ModuleList(
        (0): ChannelMLP(
          (fcs): ModuleList(
            (0): Conv1d(64, 32, kernel_size=(1,), stride=(1,))
            (1): Conv1d(32, 64, kernel_size=(1,), stride=(1,))
          )
        )
      )
      (channel_mlp_skips): ModuleList(
        (0): Flattened1dConv(
          (conv): Conv1d(32, 64, kernel_size=(1,), stride=(1,), bias=False)
        )
      )
    )
    (2): FNOBlocks(
      (convs): ModuleList(
        (0): SpectralConv(
          (weight): DenseTensor(shape=torch.Size([64, 64, 8, 5]), rank=None)
        )
      )
      (fno_skips): ModuleList(
        (0): Flattened1dConv(
          (conv): Conv1d(64, 64, kernel_size=(1,), stride=(1,), bias=False)
        )
      )
      (channel_mlp): ModuleList(
        (0): ChannelMLP(
          (fcs): ModuleList(
            (0): Conv1d(64, 32, kernel_size=(1,), stride=(1,))
            (1): Conv1d(32, 64, kernel_size=(1,), stride=(1,))
          )
        )
      )
      (channel_mlp_skips): ModuleList(
        (0): Flattened1dConv(
          (conv): Conv1d(64, 64, kernel_size=(1,), stride=(1,), bias=False)
        )
      )
    )
    (3): FNOBlocks(
      (convs): ModuleList(
        (0): SpectralConv(
          (weight): DenseTensor(shape=torch.Size([128, 64, 8, 5]), rank=None)
        )
      )
      (fno_skips): ModuleList(
        (0): Flattened1dConv(
          (conv): Conv1d(128, 64, kernel_size=(1,), stride=(1,), bias=False)
        )
      )
      (channel_mlp): ModuleList(
        (0): ChannelMLP(
          (fcs): ModuleList(
            (0): Conv1d(64, 32, kernel_size=(1,), stride=(1,))
            (1): Conv1d(32, 64, kernel_size=(1,), stride=(1,))
          )
        )
      )
      (channel_mlp_skips): ModuleList(
        (0): Flattened1dConv(
          (conv): Conv1d(128, 64, kernel_size=(1,), stride=(1,), bias=False)
        )
      )
    )
    (4): FNOBlocks(
      (convs): ModuleList(
        (0): SpectralConv(
          (weight): DenseTensor(shape=torch.Size([96, 32, 16, 9]), rank=None)
        )
      )
      (fno_skips): ModuleList(
        (0): Flattened1dConv(
          (conv): Conv1d(96, 32, kernel_size=(1,), stride=(1,), bias=False)
        )
      )
      (channel_mlp): ModuleList(
        (0): ChannelMLP(
          (fcs): ModuleList(
            (0): Conv1d(32, 16, kernel_size=(1,), stride=(1,))
            (1): Conv1d(16, 32, kernel_size=(1,), stride=(1,))
          )
        )
      )
      (channel_mlp_skips): ModuleList(
        (0): Flattened1dConv(
          (conv): Conv1d(96, 32, kernel_size=(1,), stride=(1,), bias=False)
        )
      )
    )
  )
  (horizontal_skips): ModuleDict(
    (0): Flattened1dConv(
      (conv): Conv1d(32, 32, kernel_size=(1,), stride=(1,), bias=False)
    )
    (1): Flattened1dConv(
      (conv): Conv1d(64, 64, kernel_size=(1,), stride=(1,), bias=False)
    )
  )
  (projection): ChannelMLP(
    (fcs): ModuleList(
      (0): Conv1d(32, 64, kernel_size=(1,), stride=(1,))
      (1): Conv1d(64, 1, kernel_size=(1,), stride=(1,))
    )
  )
)

### OPTIMIZER ###
 AdamW (
Parameter Group 0
    betas: (0.9, 0.999)
    correct_bias: True
    eps: 1e-06
    initial_lr: 0.008
    lr: 0.008
    weight_decay: 0.0001
)

### SCHEDULER ###
 <torch.optim.lr_scheduler.CosineAnnealingLR object at 0x7f7c5a556210>

### LOSSES ###

 * Train: <neuralop.losses.data_losses.H1Loss object at 0x7f7c5a26c830>

 * Test: {'h1': <neuralop.losses.data_losses.H1Loss object at 0x7f7c5a26c830>, 'l2': <neuralop.losses.data_losses.LpLoss object at 0x7f7c5a555f90>}

Create the trainer

trainer = Trainer(model=model,
                  n_epochs=20,
                  device=device,
                  data_processor=data_processor,
                  wandb_log=False,
                  eval_interval=3,
                  use_distributed=False,
                  verbose=True)

Actually train the model on our small Darcy-Flow dataset

trainer.train(train_loader=train_loader,
              test_loaders=test_loaders,
              optimizer=optimizer,
              scheduler=scheduler,
              regularizer=False,
              training_loss=train_loss,
              eval_losses=eval_losses)
Training on 1000 samples
Testing on [50, 50] samples on resolutions [16, 32].
Raw outputs of shape torch.Size([32, 1, 16, 16])
[0] time=8.79, avg_loss=0.6674, train_err=20.8575
Eval: 16_h1=0.4282, 16_l2=0.2592, 32_h1=0.8389, 32_l2=0.4061
[3] time=8.45, avg_loss=0.2407, train_err=7.5210
Eval: 16_h1=0.2271, 16_l2=0.1486, 32_h1=0.5589, 32_l2=0.3187
[6] time=8.48, avg_loss=0.2033, train_err=6.3525
Eval: 16_h1=0.2451, 16_l2=0.1579, 32_h1=0.5141, 32_l2=0.2719
[9] time=8.42, avg_loss=0.1775, train_err=5.5458
Eval: 16_h1=0.2719, 16_l2=0.1930, 32_h1=0.5460, 32_l2=0.3395
[12] time=8.43, avg_loss=0.1878, train_err=5.8674
Eval: 16_h1=0.2472, 16_l2=0.1522, 32_h1=0.5442, 32_l2=0.3111
[15] time=8.47, avg_loss=0.1612, train_err=5.0375
Eval: 16_h1=0.2528, 16_l2=0.1543, 32_h1=0.4823, 32_l2=0.2373
[18] time=8.44, avg_loss=0.1379, train_err=4.3093
Eval: 16_h1=0.2232, 16_l2=0.1293, 32_h1=0.4878, 32_l2=0.2673

{'train_err': 4.486526936292648, 'avg_loss': 0.14356886196136476, 'avg_lasso_loss': None, 'epoch_train_time': 8.462918849999937}

Plot the prediction and compare it with the ground truth. Note that we trained at a very small resolution for a very small number of epochs; in practice, we would train at a larger resolution on many more samples.

However, for practicality, we created a minimal example that (i) fits in just a few MB of memory and (ii) can be trained quickly on a CPU.

In practice, we would train a neural operator on one or multiple GPUs.
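
Only the device selection needs to change for that, since the model and the Trainer both accept a device argument. A minimal sketch:

# Pick a GPU when available, falling back to CPU otherwise.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)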

test_samples = test_loaders[32].dataset

fig = plt.figure(figsize=(7, 7))
for index in range(3):
    data = test_samples[index]
    data = data_processor.preprocess(data, batched=False)
    # Input x
    x = data['x']
    # Ground-truth
    y = data['y']
    # Model prediction
    out = model(x.unsqueeze(0).to(device)).cpu()

    ax = fig.add_subplot(3, 3, index*3 + 1)
    ax.imshow(x[0], cmap='gray')
    if index == 0:
        ax.set_title('Input x')
    plt.xticks([], [])
    plt.yticks([], [])

    ax = fig.add_subplot(3, 3, index*3 + 2)
    ax.imshow(y.squeeze())
    if index == 0:
        ax.set_title('Ground-truth y')
    plt.xticks([], [])
    plt.yticks([], [])

    ax = fig.add_subplot(3, 3, index*3 + 3)
    ax.imshow(out.squeeze().detach().numpy())
    if index == 0:
        ax.set_title('Model prediction')
    plt.xticks([], [])
    plt.yticks([], [])

fig.suptitle('Inputs, ground-truth output and prediction.', y=0.98)
plt.tight_layout()
fig.show()
[Figure: for each of three test samples, the input x, the ground-truth y, and the model prediction, shown side by side.]

Total running time of the script: (2 minutes 55.878 seconds)
