Neural Operators in PyTorch




neuraloperator is a comprehensive library for learning neural operators in PyTorch. It is the official implementation for Fourier Neural Operators and Tensorized Neural Operators.

Unlike regular neural networks, neural operators learn mappings between function spaces, and this library provides all the tools needed to do so on your own data.

Neural operators are also resolution invariant, so a trained operator can be applied to data of any resolution.
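To make the resolution claim concrete, here is a minimal NumPy sketch of a one-dimensional spectral convolution, the building block behind Fourier Neural Operators. It is not the library's implementation: `spectral_conv_1d` and `sample` are hypothetical helper names, the weights are random rather than trained, and the input is a band-limited periodic function chosen so that the comparison is exact. Because the weights are attached to a fixed number of Fourier modes rather than to grid points, the same weights can be applied to a 32-point and a 64-point sampling of the same function.

```python
import numpy as np

def spectral_conv_1d(u, weights):
    """Multiply the lowest Fourier modes of u by complex weights (hypothetical helper)."""
    n = u.shape[-1]
    k = weights.shape[0]
    # norm="forward" makes the coefficients independent of the grid size n
    u_hat = np.fft.rfft(u, norm="forward")
    out_hat = np.zeros_like(u_hat)
    out_hat[:k] = u_hat[:k] * weights  # modes at index k and above are dropped
    return np.fft.irfft(out_hat, n=n, norm="forward")

def sample(n):
    """Sample a band-limited periodic function on n equispaced points of [0, 1)."""
    x = np.arange(n) / n
    return np.sin(2 * np.pi * x) + 0.5 * np.cos(2 * np.pi * 3 * x)

rng = np.random.default_rng(0)
weights = rng.normal(size=4) + 1j * rng.normal(size=4)  # stand-in for learned weights

coarse = spectral_conv_1d(sample(32), weights)  # 32-point grid
fine = spectral_conv_1d(sample(64), weights)    # 64-point grid, same weights

# Every second point of the fine grid coincides with a coarse grid point
max_gap = np.max(np.abs(fine[::2] - coarse))
print(max_gap)
```

On the shared grid points the two outputs agree to floating-point precision. For inputs that are not band-limited, the agreement is only approximate and improves as the resolution grows.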

Quickstart

This guide will walk you through the standard ML workflow of loading data, creating a neural operator, training it on your data and saving the trained model for later use.

First, install the library with pip install neuraloperator (see Installing NeuralOperator for more options).

To create a Fourier Neural Operator model:

from neuralop.models import FNO

model = FNO(n_modes=(16, 16), hidden_channels=64,
            in_channels=3, out_channels=1)

To save the weights of the trained model:

model.save_checkpoint(save_folder='./checkpoints/', save_name='example_fno')

And to load the weights later:

from neuralop.models import FNO
model = FNO.from_checkpoint(save_folder='./checkpoints/', save_name='example_fno')

neuraloperator comes prepackaged with an example dataset of flows governed by the Darcy flow equation.

To import the data:

import torch
from neuralop.datasets import load_darcy_flow_small

train_loader, test_loaders, data_processor = load_darcy_flow_small(
    n_train=1000, batch_size=32,
    test_resolutions=[32], n_tests=[100],
    test_batch_sizes=[32],
    positional_encoding=True
)

Similar to the API provided by torchvision, this dataset includes training and test data for use in standard PyTorch training loops, as well as a preprocessor object that automates the transforms to convert the data into the form best understood by the model.

We provide a Trainer object that automates the logic of a basic neural operator training loop to speed up experimentation (see the auto_examples gallery for more information).

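import torch
from neuralop import LpLoss, H1Loss
from neuralop.training import Trainer

# Optimizer and learning-rate scheduler.
# The hyperparameter values here are illustrative assumptions, not tuned recommendations.
optimizer = torch.optim.Adam(model.parameters(), lr=8e-3, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)

# Losses for 2D data: train on H1, evaluate on H1 and L2 (also an illustrative choice)
l2loss = LpLoss(d=2, p=2)
h1loss = H1Loss(d=2)
train_loss = h1loss
eval_losses = {'h1': h1loss, 'l2': l2loss}

# Create the trainer
trainer = Trainer(model=model, n_epochs=20,
                  data_processor=data_processor,
                  wandb_log=False,
                  log_test_interval=3,
                  use_distributed=False,
                  verbose=True)

# Train the model
trainer.train(train_loader=train_loader,
              test_loaders=test_loaders,
              optimizer=optimizer,
              scheduler=scheduler,
              regularizer=False,
              training_loss=train_loss,
              eval_losses=eval_losses)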

Tensorization is also provided out of the box: you can improve the previous models by simply using a Tucker Tensorized FNO with just a few parameters:

from neuralop.models import TFNO

operator = TFNO(n_modes=(16, 16), hidden_channels=64,
                in_channels=3,
                out_channels=1,
                factorization='tucker',
                implementation='factorized',
                rank=0.05)

This will use a Tucker factorization of the weights. The forward pass is efficient because the inputs are contracted directly with the factors of the decomposition. The Fourier layers will have 5% of the parameters of an equivalent dense Fourier Neural Operator!
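As a rough check on the parameter saving, the sketch below counts the entries of a Tucker decomposition (one core tensor plus one factor matrix per mode) and searches for ranks that fit a 5% budget. The weight shape `(64, 64, 16, 16)` and the helper names `tucker_param_count` and `ranks_for_budget` are assumptions made for illustration, and the bisection used to pick the ranks is a simple stand-in, not necessarily the rule the library applies.

```python
from math import prod

def tucker_param_count(shape, ranks):
    """Entries in a Tucker decomposition: the core plus one (dim x rank) factor per mode."""
    core = prod(ranks)
    factors = sum(d * r for d, r in zip(shape, ranks))
    return core + factors

def ranks_for_budget(shape, fraction):
    """Bisect a uniform scaling of the mode sizes so the count fits fraction * dense size."""
    budget = fraction * prod(shape)
    lo, hi = 0.0, 1.0
    for _ in range(50):
        mid = (lo + hi) / 2
        ranks = [max(1, round(mid * d)) for d in shape]
        if tucker_param_count(shape, ranks) <= budget:
            lo = mid  # fits the budget, try larger ranks
        else:
            hi = mid  # too large, shrink
    return [max(1, round(lo * d)) for d in shape]

# Assumed shape of one dense weight tensor: (in_channels, out_channels, modes_x, modes_y)
shape = (64, 64, 16, 16)
ranks = ranks_for_budget(shape, fraction=0.05)
dense = prod(shape)
compressed = tucker_param_count(shape, ranks)
print(ranks, compressed, compressed / dense)
```

The printed ratio stays at or below 0.05 by construction. Most of the budget goes to the core tensor, which is why the ranks can remain a sizeable fraction of each mode's size.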




If you use NeuralOperator, please cite the following papers:

Li, Z., Kovachki, N., Azizzadenesheli, K., Liu, B., Bhattacharya, K., Stuart, A., and Anandkumar, A., “Fourier Neural Operator for Parametric Partial Differential Equations”, ICLR, 2021.
https://arxiv.org/abs/2010.08895.

Kovachki, N., Li, Z., Liu, B., Azizzadenesheli, K., Bhattacharya, K., Stuart, A., and Anandkumar, A., “Neural Operator: Learning Maps Between Function Spaces”, JMLR, 2021.
https://arxiv.org/abs/2108.08481.