Source code for neuralop.training.incremental

""" 
Trainer for Incremental-FNO
"""
import torch
from torch import nn

from .trainer import Trainer
from ..models import FNO, TFNO
from ..utils import compute_explained_variance

class IncrementalFNOTrainer(Trainer):
    """Trainer subclass implementing the Incremental-FNO logic described in [1]_.

    After every training epoch one of two mutually exclusive algorithms may
    increase the number of Fourier modes used by the first spectral
    convolution of the model (``model.fno_blocks.convs[0]``):

    * ``incremental_loss_gap``: add a mode when the epoch loss changes by at
      most ``incremental_loss_eps`` between two consecutive epochs.
    * ``incremental_grad``: add a mode when the explained ratio computed by
      ``compute_explained_variance`` falls below ``incremental_grad_eps``.

    Parameters
    ----------
    model : nn.Module
        model to train; must be an instance of ``FNO`` or ``TFNO``
    n_epochs, wandb_log, device, mixed_precision, data_processor, \
    eval_interval, log_output, use_distributed, verbose
        forwarded unchanged to the base ``Trainer`` constructor
    incremental_grad : bool, default False
        enable the explained-ratio algorithm (``grad_explained``)
    incremental_loss_gap : bool, default False
        enable the loss-gap algorithm (``loss_gap``)
    incremental_grad_eps : float, default 0.001
        explained-ratio threshold below which a mode is added
    incremental_buffer : int, default 5
        subtracted from the current number of modes before it is passed to
        ``compute_explained_variance``
    incremental_max_iter : int, default 1
        stored on the instance; not read by any method of this class
    incremental_grad_max_iter : int, default 10
        number of ``grad_explained`` calls that accumulate before the
        explained ratio is evaluated
    incremental_loss_eps : float, default 0.001
        threshold on the absolute difference between consecutive epoch losses

    Raises
    ------
    AssertionError
        if ``model`` is neither ``FNO`` nor ``TFNO``, or if not exactly one
        of ``incremental_grad`` / ``incremental_loss_gap`` is True

    References
    -----------
    .. [1]:
        George, R., Zhao, J., Kossaifi, J., Li, Z., and Anandkumar, A. (2024)
        "Incremental Spatial and Spectral Learning of Neural Operators for
        Solving Large-Scale PDEs". TMLR,
        https://openreview.net/pdf?id=xI6cPQObp0.
    """

    def __init__(
        self,
        model: nn.Module,
        n_epochs: int,
        wandb_log: bool = False,
        device: str = 'cpu',
        mixed_precision: bool = False,
        data_processor: nn.Module = None,
        eval_interval: int = 1,
        log_output: bool = False,
        use_distributed: bool = False,
        verbose: bool = False,
        incremental_grad: bool = False,
        incremental_loss_gap: bool = False,
        incremental_grad_eps: float = 0.001,
        incremental_buffer: int = 5,
        incremental_max_iter: int = 1,
        incremental_grad_max_iter: int = 10,
        incremental_loss_eps: float = 0.001,
    ):
        # Check the ``model`` argument directly: ``self.model`` does not exist
        # until the base constructor has run, so reading it here would raise
        # AttributeError instead of the intended assertion message.
        assert isinstance(model, (FNO, TFNO)), (
            "Error: IncrementalFNOTrainer is designed to work with FNO or "
            "TFNO, instead got a model of type "
            f"{model.__class__.__name__}"
        )

        super().__init__(
            model=model,
            n_epochs=n_epochs,
            wandb_log=wandb_log,
            device=device,
            mixed_precision=mixed_precision,
            data_processor=data_processor,
            eval_interval=eval_interval,
            log_output=log_output,
            use_distributed=use_distributed,
            verbose=verbose,
        )

        self.incremental_loss_gap = incremental_loss_gap
        self.incremental_grad = incremental_grad
        self.incremental = self.incremental_loss_gap or self.incremental_grad
        assert self.incremental, \
            "Error: IncrementalTrainer expects at least one incremental algorithm to be True."
        assert not (self.incremental_loss_gap and self.incremental_grad), \
            "Error: IncrementalTrainer expects only one incremental algorithm to be True."

        self.incremental_grad_eps = incremental_grad_eps
        self.incremental_buffer = incremental_buffer
        self.incremental_max_iter = incremental_max_iter
        self.incremental_grad_max_iter = incremental_grad_max_iter
        self.incremental_loss_eps = incremental_loss_eps
        # History of per-epoch losses consumed by ``loss_gap``.
        self.loss_list = []

    # Main step function: which algorithm to run
    def incremental_update(self, loss=None):
        """Run whichever incremental algorithm is enabled.

        Parameters
        ----------
        loss : float | scalar torch.Tensor, optional
            epoch training loss; the loss-gap algorithm is skipped when None
        """
        if self.incremental_loss_gap and loss is not None:
            self.loss_gap(loss)
        if self.incremental_grad:
            self.grad_explained()

    def train_one_epoch(self, epoch, train_loader, training_loss):
        """train_one_epoch inherits from the base Trainer's method
        and adds the computation of the incremental-FNO algorithm
        before returning the training epoch's metrics.

        Parameters
        ----------
        epoch : int
            epoch of training
        train_loader : DataLoader
        training_loss : callable
            loss function to train with

        Returns
        -------
        train_err, avg_loss, avg_lasso_loss, epoch_train_time
        """
        self.training = True
        if self.data_processor:
            self.data_processor.epoch = epoch

        train_err, avg_loss, avg_lasso_loss, epoch_train_time = \
            super().train_one_epoch(epoch, train_loader, training_loss)

        # Possibly grow the number of modes before handing the metrics back.
        self.incremental_update(avg_loss)

        return train_err, avg_loss, avg_lasso_loss, epoch_train_time

    # Algorithm 1: Incremental
    def loss_gap(self, loss):
        """Add one mode when the epoch loss has (almost) stopped changing.

        The loss is appended to ``self.loss_list``. When the absolute
        difference between the two most recent losses is at most
        ``incremental_loss_eps`` and the current number of modes is below the
        maximum, the number of modes is increased by one. The same mode count
        is then written to every dimension of ``convs[0].n_modes``.

        Parameters
        -----------
        loss : float | scalar torch.Tensor
            scalar value of epoch's training loss
        """
        self.loss_list.append(loss)
        conv = self.model.fno_blocks.convs[0]
        self.ndim = len(conv.n_modes)

        # method 1: loss_gap
        incremental_modes = conv.n_modes[0]
        max_modes = conv.max_n_modes[0]
        if len(self.loss_list) > 1:
            loss_change = abs(self.loss_list[-1] - self.loss_list[-2])
            if loss_change <= self.incremental_loss_eps:
                if incremental_modes < max_modes:
                    incremental_modes += 1

        conv.n_modes = tuple([incremental_modes] * self.ndim)

    # Algorithm 2: Gradient based explained ratio
    def grad_explained(self):
        """Add one mode when the explained ratio of the current modes is low.

        For the first ``incremental_grad_max_iter`` calls of a cycle the
        weight tensor of ``convs[0]`` is summed into ``self.accumulated_grad``.
        On the following call the Frobenius norm of the accumulated tensor is
        computed per mode index, ``compute_explained_variance`` is applied to
        those norms, and one mode is added when the result is below
        ``incremental_grad_eps`` (capped at ``max_n_modes[0]``). The
        accumulator and iteration counter are then reset.
        """
        conv = self.model.fno_blocks.convs[0]

        # for mode 1
        # State is created lazily on first call, once the model is available.
        # NOTE(review): despite its name, ``accumulated_grad`` accumulates the
        # weight tensor itself rather than ``weight.grad`` -- confirm intent.
        if not hasattr(self, 'accumulated_grad'):
            self.accumulated_grad = torch.zeros_like(conv.weight)
        if not hasattr(self, 'grad_iter'):
            self.grad_iter = 1

        self.ndim = len(conv.n_modes)
        if self.grad_iter <= self.incremental_grad_max_iter:
            self.grad_iter += 1
            self.accumulated_grad += conv.weight
            return

        # Only the first spectral convolution drives the decision, so the
        # ratio is computed once. Indexing ``convs`` by spatial dimension
        # would recompute the same value and could run past the layer list.
        max_modes = conv.max_n_modes[0]
        incremental_modes = conv.n_modes[0]
        weight = self.accumulated_grad

        n_strengths = min(weight.shape[1], incremental_modes)
        strength_vector = [
            torch.norm(weight[:, mode_index, :], p='fro')
            for mode_index in range(n_strengths)
        ]
        explained_ratio = compute_explained_variance(
            incremental_modes - self.incremental_buffer,
            torch.Tensor(strength_vector))
        if explained_ratio < self.incremental_grad_eps:
            if incremental_modes < max_modes:
                incremental_modes += 1

        # update the modes and frequency dimensions
        self.grad_iter = 1
        self.accumulated_grad = torch.zeros_like(conv.weight)
        conv.n_modes = tuple([incremental_modes] * self.ndim)