Note
Go to the end to download the full example code.
U-NO on Darcy-Flow
Training a U-shaped Neural Operator on the small Darcy-Flow example we ship with the package
import torch
import matplotlib.pyplot as plt
import sys
from neuralop.models import UNO
from neuralop import Trainer
from neuralop.training import AdamW
from neuralop.data.datasets import load_darcy_flow_small
from neuralop.utils import count_model_params
from neuralop import LpLoss, H1Loss
device = 'cpu'
Loading the Darcy Flow dataset
# Small Darcy-Flow benchmark: 1000 training samples plus two test sets, at
# resolutions 16 and 32 (100 and 50 samples respectively). `test_loaders` is
# indexed by resolution later on (e.g. `test_loaders[32]`).
train_loader, test_loaders, data_processor = load_darcy_flow_small(
    n_train=1000,
    batch_size=32,
    test_resolutions=[16, 32],
    n_tests=[100, 50],
    test_batch_sizes=[32, 32],
)
# U-shaped Neural Operator built from five FNO blocks. Each per-layer list
# gives, for every block, its output width (uno_out_channels), the number of
# retained Fourier modes (uno_n_modes) and the resolution scale factor
# (uno_scalings: 0.5 in block 1 and 2 in block 3, a down- then up-sampling path).
model = UNO(
    in_channels=1,
    out_channels=1,
    hidden_channels=64,
    projection_channels=64,
    uno_out_channels=[32, 64, 64, 64, 32],
    uno_n_modes=[[16, 16], [8, 8], [8, 8], [8, 8], [16, 16]],
    uno_scalings=[[1.0, 1.0], [0.5, 0.5], [1, 1], [2, 2], [1, 1]],
    horizontal_skips_map=None,
    channel_mlp_skip="linear",
    n_layers=5,
    domain_padding=0.2,
).to(device)

# Report the model size and flush so the message appears before later output.
n_params = count_model_params(model)
print(f'\nOur model has {n_params} parameters.')
sys.stdout.flush()
Loading test db for resolution 16 with 100 samples
Loading test db for resolution 32 with 50 samples
fno_skip='linear'
channel_mlp_skip='linear'
fno_skip='linear'
channel_mlp_skip='linear'
fno_skip='linear'
channel_mlp_skip='linear'
fno_skip='linear'
channel_mlp_skip='linear'
fno_skip='linear'
channel_mlp_skip='linear'
Our model has 2700097 parameters.
Create the optimizer
# AdamW with decoupled weight decay, followed by a cosine-annealed learning
# rate whose half-period is 30 scheduler steps.
# NOTE(review): T_max (30) is larger than the n_epochs (20) passed to the
# Trainer; if the scheduler is stepped once per epoch, the learning rate never
# reaches its minimum. Confirm this is intended.
optimizer = AdamW(
    model.parameters(),
    lr=8e-3,
    weight_decay=1e-4,
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)
Creating the losses
# Losses over 2-D fields (d=2): the model is optimised with the H1 loss, and
# both H1 and L2 (p=2) errors are reported at evaluation time.
l2loss = LpLoss(d=2, p=2)
h1loss = H1Loss(d=2)
train_loss = h1loss
eval_losses = {"h1": h1loss, "l2": l2loss}

# Print a summary of the experiment set-up, then flush stdout.
print("\n### MODEL ###\n", model)
print("\n### OPTIMIZER ###\n", optimizer)
print("\n### SCHEDULER ###\n", scheduler)
print("\n### LOSSES ###")
print(f"\n * Train: {train_loss}")
print(f"\n * Test: {eval_losses}")
sys.stdout.flush()
### MODEL ###
UNO(
(positional_embedding): GridEmbeddingND()
(domain_padding): DomainPadding()
(lifting): ChannelMLP(
(fcs): ModuleList(
(0): Conv1d(3, 256, kernel_size=(1,), stride=(1,))
(1): Conv1d(256, 64, kernel_size=(1,), stride=(1,))
)
)
(fno_blocks): ModuleList(
(0): FNOBlocks(
(convs): ModuleList(
(0): SpectralConv(
(weight): DenseTensor(shape=torch.Size([64, 32, 16, 9]), rank=None)
)
)
(fno_skips): ModuleList(
(0): Flattened1dConv(
(conv): Conv1d(64, 32, kernel_size=(1,), stride=(1,), bias=False)
)
)
(channel_mlp): ModuleList(
(0): ChannelMLP(
(fcs): ModuleList(
(0): Conv1d(32, 16, kernel_size=(1,), stride=(1,))
(1): Conv1d(16, 32, kernel_size=(1,), stride=(1,))
)
)
)
(channel_mlp_skips): ModuleList(
(0): Flattened1dConv(
(conv): Conv1d(64, 32, kernel_size=(1,), stride=(1,), bias=False)
)
)
)
(1): FNOBlocks(
(convs): ModuleList(
(0): SpectralConv(
(weight): DenseTensor(shape=torch.Size([32, 64, 8, 5]), rank=None)
)
)
(fno_skips): ModuleList(
(0): Flattened1dConv(
(conv): Conv1d(32, 64, kernel_size=(1,), stride=(1,), bias=False)
)
)
(channel_mlp): ModuleList(
(0): ChannelMLP(
(fcs): ModuleList(
(0): Conv1d(64, 32, kernel_size=(1,), stride=(1,))
(1): Conv1d(32, 64, kernel_size=(1,), stride=(1,))
)
)
)
(channel_mlp_skips): ModuleList(
(0): Flattened1dConv(
(conv): Conv1d(32, 64, kernel_size=(1,), stride=(1,), bias=False)
)
)
)
(2): FNOBlocks(
(convs): ModuleList(
(0): SpectralConv(
(weight): DenseTensor(shape=torch.Size([64, 64, 8, 5]), rank=None)
)
)
(fno_skips): ModuleList(
(0): Flattened1dConv(
(conv): Conv1d(64, 64, kernel_size=(1,), stride=(1,), bias=False)
)
)
(channel_mlp): ModuleList(
(0): ChannelMLP(
(fcs): ModuleList(
(0): Conv1d(64, 32, kernel_size=(1,), stride=(1,))
(1): Conv1d(32, 64, kernel_size=(1,), stride=(1,))
)
)
)
(channel_mlp_skips): ModuleList(
(0): Flattened1dConv(
(conv): Conv1d(64, 64, kernel_size=(1,), stride=(1,), bias=False)
)
)
)
(3): FNOBlocks(
(convs): ModuleList(
(0): SpectralConv(
(weight): DenseTensor(shape=torch.Size([128, 64, 8, 5]), rank=None)
)
)
(fno_skips): ModuleList(
(0): Flattened1dConv(
(conv): Conv1d(128, 64, kernel_size=(1,), stride=(1,), bias=False)
)
)
(channel_mlp): ModuleList(
(0): ChannelMLP(
(fcs): ModuleList(
(0): Conv1d(64, 32, kernel_size=(1,), stride=(1,))
(1): Conv1d(32, 64, kernel_size=(1,), stride=(1,))
)
)
)
(channel_mlp_skips): ModuleList(
(0): Flattened1dConv(
(conv): Conv1d(128, 64, kernel_size=(1,), stride=(1,), bias=False)
)
)
)
(4): FNOBlocks(
(convs): ModuleList(
(0): SpectralConv(
(weight): DenseTensor(shape=torch.Size([96, 32, 16, 9]), rank=None)
)
)
(fno_skips): ModuleList(
(0): Flattened1dConv(
(conv): Conv1d(96, 32, kernel_size=(1,), stride=(1,), bias=False)
)
)
(channel_mlp): ModuleList(
(0): ChannelMLP(
(fcs): ModuleList(
(0): Conv1d(32, 16, kernel_size=(1,), stride=(1,))
(1): Conv1d(16, 32, kernel_size=(1,), stride=(1,))
)
)
)
(channel_mlp_skips): ModuleList(
(0): Flattened1dConv(
(conv): Conv1d(96, 32, kernel_size=(1,), stride=(1,), bias=False)
)
)
)
)
(horizontal_skips): ModuleDict(
(0): Flattened1dConv(
(conv): Conv1d(32, 32, kernel_size=(1,), stride=(1,), bias=False)
)
(1): Flattened1dConv(
(conv): Conv1d(64, 64, kernel_size=(1,), stride=(1,), bias=False)
)
)
(projection): ChannelMLP(
(fcs): ModuleList(
(0): Conv1d(32, 64, kernel_size=(1,), stride=(1,))
(1): Conv1d(64, 1, kernel_size=(1,), stride=(1,))
)
)
)
### OPTIMIZER ###
AdamW (
Parameter Group 0
betas: (0.9, 0.999)
correct_bias: True
eps: 1e-06
initial_lr: 0.008
lr: 0.008
weight_decay: 0.0001
)
### SCHEDULER ###
<torch.optim.lr_scheduler.CosineAnnealingLR object at 0x7f7c5a556210>
### LOSSES ###
* Train: <neuralop.losses.data_losses.H1Loss object at 0x7f7c5a26c830>
* Test: {'h1': <neuralop.losses.data_losses.H1Loss object at 0x7f7c5a26c830>, 'l2': <neuralop.losses.data_losses.LpLoss object at 0x7f7c5a555f90>}
Create the trainer
# Trainer on `device`: 20 epochs, evaluation every 3 epochs, verbose logging,
# with Weights & Biases logging and distributed training both turned off.
trainer = Trainer(
    model=model,
    n_epochs=20,
    device=device,
    data_processor=data_processor,
    wandb_log=False,
    eval_interval=3,
    use_distributed=False,
    verbose=True,
)
Actually train the model on our small Darcy-Flow dataset
# Run the training loop: optimise `train_loss` and report every loss in
# `eval_losses` on each test loader. No regularizer is used.
trainer.train(
    train_loader=train_loader,
    test_loaders=test_loaders,
    optimizer=optimizer,
    scheduler=scheduler,
    regularizer=False,
    training_loss=train_loss,
    eval_losses=eval_losses,
)
Training on 1000 samples
Testing on [50, 50] samples on resolutions [16, 32].
/opt/hostedtoolcache/Python/3.13.7/x64/lib/python3.13/site-packages/torch/utils/data/dataloader.py:666: UserWarning: 'pin_memory' argument is set as true but no accelerator is found, then device pinned memory won't be used.
warnings.warn(warn_msg)
Raw outputs of shape torch.Size([32, 1, 16, 16])
[0] time=8.79, avg_loss=0.6674, train_err=20.8575
Eval: 16_h1=0.4282, 16_l2=0.2592, 32_h1=0.8389, 32_l2=0.4061
[3] time=8.45, avg_loss=0.2407, train_err=7.5210
Eval: 16_h1=0.2271, 16_l2=0.1486, 32_h1=0.5589, 32_l2=0.3187
[6] time=8.48, avg_loss=0.2033, train_err=6.3525
Eval: 16_h1=0.2451, 16_l2=0.1579, 32_h1=0.5141, 32_l2=0.2719
[9] time=8.42, avg_loss=0.1775, train_err=5.5458
Eval: 16_h1=0.2719, 16_l2=0.1930, 32_h1=0.5460, 32_l2=0.3395
[12] time=8.43, avg_loss=0.1878, train_err=5.8674
Eval: 16_h1=0.2472, 16_l2=0.1522, 32_h1=0.5442, 32_l2=0.3111
[15] time=8.47, avg_loss=0.1612, train_err=5.0375
Eval: 16_h1=0.2528, 16_l2=0.1543, 32_h1=0.4823, 32_l2=0.2373
[18] time=8.44, avg_loss=0.1379, train_err=4.3093
Eval: 16_h1=0.2232, 16_l2=0.1293, 32_h1=0.4878, 32_l2=0.2673
{'train_err': 4.486526936292648, 'avg_loss': 0.14356886196136476, 'avg_lasso_loss': None, 'epoch_train_time': 8.462918849999937}
Plot the prediction and compare it with the ground truth. Note that we trained at a very low resolution for a very small number of epochs. In practice, we would train at a higher resolution, on many more samples.
However, for practicality, we created a minimal example that (i) fits in just a few MB of memory and (ii) can be trained quickly on a CPU.
In practice, we would train a Neural Operator on one or more GPUs.
# Visualise a few resolution-32 test samples: one row per sample, with the
# input, the ground truth and the model's prediction side by side.
test_samples = test_loaders[32].dataset
n_samples = 3  # number of rows (samples) in the figure

# Inference only: put the model in evaluation mode before predicting.
model.eval()

fig = plt.figure(figsize=(7, 7))
for index in range(n_samples):
    data = test_samples[index]
    data = data_processor.preprocess(data, batched=False)
    # Input x
    x = data['x']
    # Ground-truth
    y = data['y']
    # Model prediction: add a batch dimension of 1; no gradients are needed.
    with torch.no_grad():
        out = model(x.unsqueeze(0).to(device))

    # preprocess() and the model may leave tensors on the training device;
    # matplotlib needs host memory, so move everything to the CPU first.
    panels = (
        (x[0].cpu(), 'Input x', {'cmap': 'gray'}),
        (y.squeeze().cpu(), 'Ground-truth y', {}),
        (out.squeeze().cpu().numpy(), 'Model prediction', {}),
    )
    for col, (image, title, imshow_kwargs) in enumerate(panels):
        ax = fig.add_subplot(n_samples, 3, index * 3 + col + 1)
        ax.imshow(image, **imshow_kwargs)
        if index == 0:
            # Column titles only on the first row.
            ax.set_title(title)
        ax.set_xticks([])
        ax.set_yticks([])

fig.suptitle('Inputs, ground-truth output and prediction.', y=0.98)
plt.tight_layout()
fig.show()

Total running time of the script: (2 minutes 55.878 seconds)