U-NO on Darcy-Flow

Training a U-shaped Neural Operator (U-NO) on the small Darcy-Flow example we ship with the package.

This tutorial demonstrates the U-NO architecture, which combines the resolution invariance of neural operators with the multi-scale feature extraction of U-Net architectures. The U-NO uses skip connections and multi-resolution processing to capture both local and global features in the data, making it particularly effective for complex PDE problems.

Import dependencies

We import the modules needed to build, train, and evaluate the UNO model.

import torch
import matplotlib.pyplot as plt
import sys
from neuralop.models import UNO
from neuralop import Trainer
from neuralop.training import AdamW
from neuralop.data.datasets import load_darcy_flow_small
from neuralop.utils import count_model_params
from neuralop import LpLoss, H1Loss

device = 'cpu'

Loading the Darcy-Flow dataset

We load the Darcy-Flow dataset with multiple resolutions for training and testing.

train_loader, test_loaders, data_processor = load_darcy_flow_small(
        n_train=1000, batch_size=32,
        test_resolutions=[16, 32], n_tests=[100, 50],
        test_batch_sizes=[32, 32],
)
Loading test db for resolution 16 with 100 samples
Loading test db for resolution 32 with 50 samples

Creating the U-NO model

We create a U-shaped Neural Operator with the following architecture:

- in_channels: Number of input channels
- out_channels: Number of output channels
- hidden_channels: Width of the hidden layers
- uno_out_channels: Channel dimensions for each layer in the U-Net structure
- uno_n_modes: Fourier modes for each layer (decreasing then increasing)
- uno_scalings: Scaling factors for each layer
- domain_padding: Padding to handle boundary effects

model = UNO(in_channels=1,
            out_channels=1,
            hidden_channels=64,
            projection_channels=64,
            uno_out_channels=[32, 64, 64, 64, 32],  # output channels of each layer
            uno_n_modes=[[16, 16], [8, 8], [8, 8], [8, 8], [16, 16]],  # Fourier modes per layer
            uno_scalings=[[1.0, 1.0], [0.5, 0.5], [1, 1], [2, 2], [1, 1]],  # downsample, then upsample
            horizontal_skips_map=None,  # use the default skip connections
            channel_mlp_skip="linear",
            n_layers=5,
            domain_padding=0.2)

model = model.to(device)

# Count and display the number of parameters
n_params = count_model_params(model)
print(f'\nOur model has {n_params} parameters.')
sys.stdout.flush()
fno_skip='linear'
channel_mlp_skip='linear'
fno_skip='linear'
channel_mlp_skip='linear'
fno_skip='linear'
channel_mlp_skip='linear'
fno_skip='linear'
channel_mlp_skip='linear'
fno_skip='linear'
channel_mlp_skip='linear'

Our model has 2700097 parameters.
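
To see how uno_scalings produces the U shape, the following sketch tracks the spatial resolution after each layer for a 16x16 input. It assumes that each layer simply multiplies the incoming resolution by its scaling factor and it ignores domain_padding. The helper resolution_per_layer is hypothetical and not part of neuralop.

```python
# Illustrative only: track the spatial resolution through the U-shaped stack.
# Assumption: each layer rescales its input resolution by its scaling factor;
# domain padding is ignored. `resolution_per_layer` is a hypothetical helper.

uno_scalings = [[1.0, 1.0], [0.5, 0.5], [1, 1], [2, 2], [1, 1]]


def resolution_per_layer(input_resolution, scalings):
    """Return the (height, width) of the feature map after each layer."""
    height, width = input_resolution
    resolutions = []
    for scale_h, scale_w in scalings:
        height = round(height * scale_h)
        width = round(width * scale_w)
        resolutions.append((height, width))
    return resolutions


resolutions = resolution_per_layer((16, 16), uno_scalings)
for layer, (height, width) in enumerate(resolutions):
    print(f"layer {layer}: {height}x{width}")
```

Under these assumptions the resolution halves at the second layer and is restored at the fourth, so the first and last layers operate at full resolution with the larger number of Fourier modes.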

Creating the optimizer and scheduler

We use the AdamW optimizer with weight decay for regularization, together with a cosine-annealing learning-rate scheduler.

optimizer = AdamW(model.parameters(),
                  lr=8e-3,
                  weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)
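
As a rough illustration of what this scheduler does, the sketch below evaluates the closed-form cosine-annealing formula with the values used in this tutorial. It assumes the scheduler is stepped once per epoch and that eta_min keeps PyTorch's default of 0. The function cosine_lr is a hypothetical helper, not a library call.

```python
import math

base_lr = 8e-3   # lr passed to AdamW in this tutorial
t_max = 30       # T_max passed to CosineAnnealingLR
n_epochs = 20    # n_epochs passed to the Trainer in this tutorial
eta_min = 0.0    # assumed: PyTorch's default minimum learning rate


def cosine_lr(epoch):
    """Closed-form cosine annealing: learning rate after `epoch` scheduler steps."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2


# Assumption: one scheduler step per training epoch
schedule = [cosine_lr(epoch) for epoch in range(n_epochs + 1)]
print(f"start: {schedule[0]:.4f}, after {n_epochs} epochs: {schedule[-1]:.4f}")
```

Because T_max (30) is larger than the number of epochs (20), the learning rate under these assumptions does not reach zero by the end of training. It stops at a quarter of its initial value.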

Setting up loss functions

We use the H1 loss for training and report both the H1 and L2 losses for evaluation.

l2loss = LpLoss(d=2, p=2)
h1loss = H1Loss(d=2)

train_loss = h1loss
eval_losses={'h1': h1loss, 'l2': l2loss}
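
To make the L2 metric concrete, here is a minimal pure-Python sketch of a relative L2 error on flattened fields. It assumes that the reported L2 values are relative errors (the norm of the difference divided by the norm of the target); the function relative_l2 is hypothetical and is not the LpLoss implementation. The H1 loss additionally penalizes errors in the first derivatives of the field.

```python
import math


def relative_l2(prediction, target):
    """Relative L2 error ||prediction - target||_2 / ||target||_2 on flat lists."""
    diff_norm = math.sqrt(sum((p - t) ** 2 for p, t in zip(prediction, target)))
    target_norm = math.sqrt(sum(t ** 2 for t in target))
    return diff_norm / target_norm


target = [3.0, 4.0]     # ||target||_2 = 5
perfect = [3.0, 4.0]    # identical to the target
all_zero = [0.0, 0.0]   # predicts nothing at all

print(relative_l2(perfect, target))
print(relative_l2(all_zero, target))
```

Read this way, a perfect prediction scores 0 and an all-zero prediction scores 1, which gives a scale for interpreting the evaluation numbers printed during training.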

Displaying configuration

We print the model architecture, optimizer, scheduler, and loss functions.

print('\n### MODEL ###\n', model)
print('\n### OPTIMIZER ###\n', optimizer)
print('\n### SCHEDULER ###\n', scheduler)
print('\n### LOSSES ###')
print(f'\n * Train: {train_loss}')
print(f'\n * Test: {eval_losses}')
sys.stdout.flush()
### MODEL ###
 UNO(
  (positional_embedding): GridEmbeddingND()
  (domain_padding): DomainPadding()
  (lifting): ChannelMLP(
    (fcs): ModuleList(
      (0): Conv1d(3, 256, kernel_size=(1,), stride=(1,))
      (1): Conv1d(256, 64, kernel_size=(1,), stride=(1,))
    )
  )
  (fno_blocks): ModuleList(
    (0): FNOBlocks(
      (convs): ModuleList(
        (0): SpectralConv(
          (weight): DenseTensor(shape=torch.Size([64, 32, 16, 9]), rank=None)
        )
      )
      (fno_skips): ModuleList(
        (0): Flattened1dConv(
          (conv): Conv1d(64, 32, kernel_size=(1,), stride=(1,), bias=False)
        )
      )
      (channel_mlp): ModuleList(
        (0): ChannelMLP(
          (fcs): ModuleList(
            (0): Conv1d(32, 16, kernel_size=(1,), stride=(1,))
            (1): Conv1d(16, 32, kernel_size=(1,), stride=(1,))
          )
        )
      )
      (channel_mlp_skips): ModuleList(
        (0): Flattened1dConv(
          (conv): Conv1d(64, 32, kernel_size=(1,), stride=(1,), bias=False)
        )
      )
    )
    (1): FNOBlocks(
      (convs): ModuleList(
        (0): SpectralConv(
          (weight): DenseTensor(shape=torch.Size([32, 64, 8, 5]), rank=None)
        )
      )
      (fno_skips): ModuleList(
        (0): Flattened1dConv(
          (conv): Conv1d(32, 64, kernel_size=(1,), stride=(1,), bias=False)
        )
      )
      (channel_mlp): ModuleList(
        (0): ChannelMLP(
          (fcs): ModuleList(
            (0): Conv1d(64, 32, kernel_size=(1,), stride=(1,))
            (1): Conv1d(32, 64, kernel_size=(1,), stride=(1,))
          )
        )
      )
      (channel_mlp_skips): ModuleList(
        (0): Flattened1dConv(
          (conv): Conv1d(32, 64, kernel_size=(1,), stride=(1,), bias=False)
        )
      )
    )
    (2): FNOBlocks(
      (convs): ModuleList(
        (0): SpectralConv(
          (weight): DenseTensor(shape=torch.Size([64, 64, 8, 5]), rank=None)
        )
      )
      (fno_skips): ModuleList(
        (0): Flattened1dConv(
          (conv): Conv1d(64, 64, kernel_size=(1,), stride=(1,), bias=False)
        )
      )
      (channel_mlp): ModuleList(
        (0): ChannelMLP(
          (fcs): ModuleList(
            (0): Conv1d(64, 32, kernel_size=(1,), stride=(1,))
            (1): Conv1d(32, 64, kernel_size=(1,), stride=(1,))
          )
        )
      )
      (channel_mlp_skips): ModuleList(
        (0): Flattened1dConv(
          (conv): Conv1d(64, 64, kernel_size=(1,), stride=(1,), bias=False)
        )
      )
    )
    (3): FNOBlocks(
      (convs): ModuleList(
        (0): SpectralConv(
          (weight): DenseTensor(shape=torch.Size([128, 64, 8, 5]), rank=None)
        )
      )
      (fno_skips): ModuleList(
        (0): Flattened1dConv(
          (conv): Conv1d(128, 64, kernel_size=(1,), stride=(1,), bias=False)
        )
      )
      (channel_mlp): ModuleList(
        (0): ChannelMLP(
          (fcs): ModuleList(
            (0): Conv1d(64, 32, kernel_size=(1,), stride=(1,))
            (1): Conv1d(32, 64, kernel_size=(1,), stride=(1,))
          )
        )
      )
      (channel_mlp_skips): ModuleList(
        (0): Flattened1dConv(
          (conv): Conv1d(128, 64, kernel_size=(1,), stride=(1,), bias=False)
        )
      )
    )
    (4): FNOBlocks(
      (convs): ModuleList(
        (0): SpectralConv(
          (weight): DenseTensor(shape=torch.Size([96, 32, 16, 9]), rank=None)
        )
      )
      (fno_skips): ModuleList(
        (0): Flattened1dConv(
          (conv): Conv1d(96, 32, kernel_size=(1,), stride=(1,), bias=False)
        )
      )
      (channel_mlp): ModuleList(
        (0): ChannelMLP(
          (fcs): ModuleList(
            (0): Conv1d(32, 16, kernel_size=(1,), stride=(1,))
            (1): Conv1d(16, 32, kernel_size=(1,), stride=(1,))
          )
        )
      )
      (channel_mlp_skips): ModuleList(
        (0): Flattened1dConv(
          (conv): Conv1d(96, 32, kernel_size=(1,), stride=(1,), bias=False)
        )
      )
    )
  )
  (horizontal_skips): ModuleDict(
    (0): Flattened1dConv(
      (conv): Conv1d(32, 32, kernel_size=(1,), stride=(1,), bias=False)
    )
    (1): Flattened1dConv(
      (conv): Conv1d(64, 64, kernel_size=(1,), stride=(1,), bias=False)
    )
  )
  (projection): ChannelMLP(
    (fcs): ModuleList(
      (0): Conv1d(32, 64, kernel_size=(1,), stride=(1,))
      (1): Conv1d(64, 1, kernel_size=(1,), stride=(1,))
    )
  )
)

### OPTIMIZER ###
 AdamW (
Parameter Group 0
    betas: (0.9, 0.999)
    correct_bias: True
    eps: 1e-06
    initial_lr: 0.008
    lr: 0.008
    weight_decay: 0.0001
)

### SCHEDULER ###
 <torch.optim.lr_scheduler.CosineAnnealingLR object at 0x7f185829b4d0>

### LOSSES ###

 * Train: <neuralop.losses.data_losses.H1Loss object at 0x7f1853fce120>

 * Test: {'h1': <neuralop.losses.data_losses.H1Loss object at 0x7f1853fce120>, 'l2': <neuralop.losses.data_losses.LpLoss object at 0x7f1858817890>}
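
The SpectralConv weight shapes in the printout above can be reproduced from the constructor arguments. The sketch below is an interpretation of the printed model, not library code. It assumes that, with horizontal_skips_map=None, layer n_layers - 1 - i also receives the output of layer i (concatenated along the channel dimension), and that the last Fourier dimension stores modes // 2 + 1 entries because of the symmetry of the real FFT.

```python
hidden_channels = 64
uno_out_channels = [32, 64, 64, 64, 32]
uno_n_modes = [[16, 16], [8, 8], [8, 8], [8, 8], [16, 16]]
n_layers = 5

# Assumed default skip map when horizontal_skips_map=None:
# layer n_layers - 1 - i also receives the output of layer i.
skips = {n_layers - 1 - i: i for i in range(n_layers // 2)}  # {4: 0, 3: 1}

weight_shapes = []
previous_out = hidden_channels  # the lifting layer outputs hidden_channels
for layer in range(n_layers):
    in_channels = previous_out
    if layer in skips:
        # Skip connection: channels of the earlier layer are concatenated
        in_channels += uno_out_channels[skips[layer]]
    modes_x, modes_y = uno_n_modes[layer]
    # Assumed: the last dimension keeps modes_y // 2 + 1 entries (real FFT)
    weight_shapes.append(
        (in_channels, uno_out_channels[layer], modes_x, modes_y // 2 + 1)
    )
    previous_out = uno_out_channels[layer]

for layer, shape in enumerate(weight_shapes):
    print(f"layer {layer}: {shape}")
```

Under these assumptions, the 128 and 96 input channels of the last two blocks are explained by the skip connections (64 + 64 and 64 + 32), which also matches the two entries listed under horizontal_skips.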

Creating the trainer

We create a Trainer object that handles the training loop for the U-NO.

trainer = Trainer(model=model,
                  n_epochs=20,
                  device=device,
                  data_processor=data_processor,
                  wandb_log=False,        # Disable Weights & Biases logging
                  eval_interval=3,        # Evaluate every 3 epochs
                  use_distributed=False,  # Single GPU/CPU training
                  verbose=True)           # Print training progress
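
As a quick check of what eval_interval=3 means over 20 epochs, the sketch below lists the epochs on which evaluation would run. It assumes that evaluation is triggered whenever the epoch index is divisible by eval_interval, which is inferred from the training log in this tutorial rather than taken from the Trainer documentation.

```python
n_epochs = 20
eval_interval = 3

# Assumption: evaluation runs on epochs where epoch % eval_interval == 0
eval_epochs = [epoch for epoch in range(n_epochs) if epoch % eval_interval == 0]
print(eval_epochs)
```

These are the same epoch indices that appear in square brackets in the training output of this tutorial.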

Training the U-NO model

We train the model on our Darcy-Flow dataset. The trainer will:

1. Run the forward pass through the U-NO
2. Compute the H1 loss
3. Backpropagate and update weights
4. Evaluate on test data every 3 epochs

trainer.train(train_loader=train_loader,
              test_loaders=test_loaders,
              optimizer=optimizer,
              scheduler=scheduler,
              regularizer=False,
              training_loss=train_loss,
              eval_losses=eval_losses)
Training on 1000 samples
Testing on [50, 50] samples         on resolutions [16, 32].
/opt/hostedtoolcache/Python/3.13.8/x64/lib/python3.13/site-packages/torch/utils/data/dataloader.py:668: UserWarning: 'pin_memory' argument is set as true but no accelerator is found, then device pinned memory won't be used.
  warnings.warn(warn_msg)
/opt/hostedtoolcache/Python/3.13.8/x64/lib/python3.13/site-packages/torch/nn/modules/module.py:1786: UserWarning: UNO.forward() received unexpected keyword arguments: ['y']. These arguments will be ignored.
  return forward_call(*args, **kwargs)
Raw outputs of shape torch.Size([32, 1, 16, 16])
/home/runner/work/neuraloperator/neuraloperator/neuralop/training/trainer.py:507: UserWarning: H1Loss.__call__() received unexpected keyword arguments: ['x']. These arguments will be ignored.
  loss += training_loss(out, **sample)
[0] time=8.52, avg_loss=0.7192, train_err=22.4742
/home/runner/work/neuraloperator/neuraloperator/neuralop/training/trainer.py:557: UserWarning: LpLoss.__call__() received unexpected keyword arguments: ['x']. These arguments will be ignored.
  val_loss = loss(out, **sample)
Eval: 16_h1=0.4670, 16_l2=0.3055, 32_h1=0.7125, 32_l2=0.3924
[3] time=8.29, avg_loss=0.2413, train_err=7.5420
Eval: 16_h1=0.2711, 16_l2=0.1796, 32_h1=0.5284, 32_l2=0.3361
[6] time=8.34, avg_loss=0.2083, train_err=6.5092
Eval: 16_h1=0.3397, 16_l2=0.2732, 32_h1=0.5667, 32_l2=0.4138
[9] time=8.21, avg_loss=0.1860, train_err=5.8128
Eval: 16_h1=0.2224, 16_l2=0.1386, 32_h1=0.5346, 32_l2=0.3237
[12] time=8.23, avg_loss=0.2075, train_err=6.4833
Eval: 16_h1=0.2227, 16_l2=0.1353, 32_h1=0.4988, 32_l2=0.2996
[15] time=8.30, avg_loss=0.1757, train_err=5.4895
Eval: 16_h1=0.2521, 16_l2=0.1504, 32_h1=0.4995, 32_l2=0.3142
[18] time=8.31, avg_loss=0.1341, train_err=4.1897
Eval: 16_h1=0.2161, 16_l2=0.1299, 32_h1=0.5307, 32_l2=0.3298

{'train_err': 4.030218005180359, 'avg_loss': 0.1289669761657715, 'avg_lasso_loss': None, 'epoch_train_time': 8.31089495499998}
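
The two training metrics in the returned dictionary differ by a constant factor. The sketch below checks one reading of the numbers: avg_loss is the epoch's total loss divided by the number of training samples, and train_err is the same total divided by the number of batches. This is an observation that fits the logged values, not a statement from the Trainer documentation.

```python
import math

n_train = 1000
batch_size = 32
n_batches = math.ceil(n_train / batch_size)  # batches per epoch

# Values copied from the dictionary returned by trainer.train(...)
final = {'train_err': 4.030218005180359, 'avg_loss': 0.1289669761657715}

# Assumed: avg_loss = total / n_train and train_err = total / n_batches
total_loss = final['avg_loss'] * n_train
train_err_from_total = total_loss / n_batches

print(n_batches)
print(abs(train_err_from_total - final['train_err']) < 1e-9)
```

If this reading is right, train_err is simply avg_loss scaled by n_train / n_batches (here 31.25), so both track the same quantity.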

Visualizing U-NO predictions

We visualize the model’s predictions on the 32x32 Darcy-Flow test set. Note that we trained at a very low resolution (16x16) for only 20 epochs. In practice, we would train at a higher resolution, on many more samples, and on one or more GPUs.

For practicality, however, we created a minimal example that: i) fits in just a few MB of memory and ii) can be trained quickly on CPU.

test_samples = test_loaders[32].dataset

fig = plt.figure(figsize=(7, 7))
for index in range(3):
    data = test_samples[index]
    data = data_processor.preprocess(data, batched=False)
    # Input x
    x = data['x']
    # Ground-truth
    y = data['y']
    # Model prediction: U-NO output
    out = model(x.unsqueeze(0).to(device)).cpu()

    # Plot input x
    ax = fig.add_subplot(3, 3, index*3 + 1)
    ax.imshow(x[0], cmap='gray')
    if index == 0:
        ax.set_title('Input x')
    plt.xticks([], [])
    plt.yticks([], [])

    # Plot ground-truth y
    ax = fig.add_subplot(3, 3, index*3 + 2)
    ax.imshow(y.squeeze())
    if index == 0:
        ax.set_title('Ground-truth y')
    plt.xticks([], [])
    plt.yticks([], [])

    # Plot model prediction
    ax = fig.add_subplot(3, 3, index*3 + 3)
    ax.imshow(out.squeeze().detach().numpy())
    if index == 0:
        ax.set_title('U-NO prediction')
    plt.xticks([], [])
    plt.yticks([], [])

fig.suptitle('U-NO predictions on 32x32 Darcy-Flow data', y=0.98)
plt.tight_layout()
fig.show()
[Figure: U-NO predictions on 32x32 Darcy-Flow data. Columns: Input x, Ground-truth y, U-NO prediction.]

Total running time of the script: (2 minutes 52.550 seconds)
