Source code for neuralop.training.trainer

import torch
from torch.cuda import amp
from torch import nn
from timeit import default_timer
from pathlib import Path
from typing import Union
import sys

# Only import wandb and use if installed
wandb_available = False
try:
    import wandb
    wandb_available = True
except ModuleNotFoundError:
    wandb_available = False

import neuralop.mpu.comm as comm
from neuralop.losses import LpLoss
from .training_state import load_training_state, save_training_state


class Trainer:
    """
    A general Trainer class to train neural-operators on given datasets
    """

    def __init__(
        self,
        *,
        model: nn.Module,
        n_epochs: int,
        wandb_log: bool = False,
        device: str = "cpu",
        mixed_precision: bool = False,
        data_processor: nn.Module = None,
        eval_interval: int = 1,
        log_output: bool = False,
        use_distributed: bool = False,
        verbose: bool = False,
    ):
        """
        Parameters
        ----------
        model : nn.Module
        n_epochs : int
        wandb_log : bool, default is False
            whether to log results to wandb
        device : torch.device, or str 'cpu' or 'cuda'
        mixed_precision : bool, default is False
            whether to use torch.autocast to compute mixed precision
        data_processor : DataProcessor class to transform data, default is None
            if not None, data from the loaders is transformed first with
            data_processor.preprocess, then after getting an output from the
            model, that is transformed with data_processor.postprocess.
        eval_interval : int, default is 1
            how frequently to evaluate model and log training stats
        log_output : bool, default is False
            if True, and if wandb_log is also True, log output images to wandb
        use_distributed : bool, default is False
            whether to use DDP
        verbose : bool, default is False
        """
        self.model = model
        self.n_epochs = n_epochs
        # only log to wandb if a run is active
        self.wandb_log = False
        if wandb_available:
            self.wandb_log = wandb_log and wandb.run is not None
        self.eval_interval = eval_interval
        self.log_output = log_output
        self.verbose = verbose
        self.use_distributed = use_distributed
        self.device = device
        # handle autocast device
        if isinstance(self.device, torch.device):
            self.autocast_device_type = self.device.type
        else:
            if "cuda" in self.device:
                self.autocast_device_type = "cuda"
            else:
                self.autocast_device_type = "cpu"
        self.mixed_precision = mixed_precision
        self.data_processor = data_processor
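
A minimal construction sketch (illustrative, not part of the module source; the FNO hyperparameters and the import paths are assumptions based on the package layout):

from neuralop.models import FNO
from neuralop.training import Trainer

model = FNO(n_modes=(16, 16), hidden_channels=32,
            in_channels=1, out_channels=1)

# wandb_log only takes effect when a wandb run is already active
# (e.g. after wandb.init()); otherwise the Trainer disables it silently.
trainer = Trainer(model=model,
                  n_epochs=20,
                  device="cuda",
                  eval_interval=2,
                  verbose=True)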

    def train(
        self,
        train_loader,
        test_loaders,
        optimizer,
        scheduler,
        regularizer=None,
        training_loss=None,
        eval_losses=None,
        save_every: int = None,
        save_best: str = None,
        save_dir: Union[str, Path] = "./ckpt",
        resume_from_dir: Union[str, Path] = None,
    ):
        """Trains the given model on the given datasets.

        Parameters
        ----------
        train_loader : torch.utils.data.DataLoader
            training dataloader
        test_loaders : dict[torch.utils.data.DataLoader]
            testing dataloaders
        optimizer : torch.optim.Optimizer
            optimizer to use during training
        scheduler : torch.optim.lr_scheduler
            learning rate scheduler to use during training
        training_loss : training.losses function
            cost function to minimize
        eval_losses : dict[Loss]
            dict of losses to use in self.eval()
        save_every : int, optional, default is None
            if provided, interval at which to save checkpoints
        save_best : str, optional, default is None
            if provided, key of metric f"{loader_name}_{loss_name}"
            to monitor and save model with best eval result.
            Overrides save_every and saves on eval_interval
        save_dir : str | Path, default "./ckpt"
            directory at which to save training states if
            save_every and/or save_best is provided
        resume_from_dir : str | Path, default None
            if provided, resumes training state (model, optimizer,
            regularizer, scheduler) from state saved in `resume_from_dir`

        Returns
        -------
        all_metrics : dict
            dictionary keyed f"{loader_name}_{loss_name}" of metric
            results for last validation epoch across all test_loaders
        """
        self.optimizer = optimizer
        self.scheduler = scheduler
        if regularizer:
            self.regularizer = regularizer
        else:
            self.regularizer = None

        if training_loss is None:
            training_loss = LpLoss(d=2)

        if eval_losses is None:  # by default just evaluate on the training loss
            eval_losses = dict(l2=training_loss)

        # accumulated wandb metrics
        self.wandb_epoch_metrics = None

        # attributes for checkpointing
        self.save_every = save_every
        self.save_best = save_best
        if resume_from_dir is not None:
            self.resume_state_from_dir(resume_from_dir)

        # ensure save_best is a metric we collect
        if self.save_best is not None:
            metrics = []
            for name in test_loaders.keys():
                for metric in eval_losses.keys():
                    metrics.append(f"{name}_{metric}")
            assert self.save_best in metrics,\
                f"Error: expected a metric of the form <loader_name>_<metric>, got {save_best}"
            best_metric_value = float("inf")
            # either monitor metric or save on interval, exclusive for simplicity
            self.save_every = None

        if self.verbose:
            print(f"Training on {len(train_loader.dataset)} samples")
            print(f"Testing on {[len(loader.dataset) for loader in test_loaders.values()]} samples"
                  f" on resolutions {[name for name in test_loaders]}.")
            sys.stdout.flush()

        for epoch in range(self.n_epochs):
            train_err, avg_loss, avg_lasso_loss, epoch_train_time =\
                self.train_one_epoch(epoch, train_loader, training_loss)
            epoch_metrics = dict(
                train_err=train_err,
                avg_loss=avg_loss,
                avg_lasso_loss=avg_lasso_loss,
                epoch_train_time=epoch_train_time
            )

            if epoch % self.eval_interval == 0:
                # evaluate and gather metrics across each loader in test_loaders
                eval_metrics = self.evaluate_all(epoch=epoch,
                                                 eval_losses=eval_losses,
                                                 test_loaders=test_loaders)
                epoch_metrics.update(**eval_metrics)
                # save checkpoint if conditions are met
                if save_best is not None:
                    if eval_metrics[save_best] < best_metric_value:
                        best_metric_value = eval_metrics[save_best]
                        self.checkpoint(save_dir)

            # save checkpoint if save_every is set and save_best is not
            if self.save_every is not None:
                if epoch % self.save_every == 0:
                    self.checkpoint(save_dir)

        return epoch_metrics
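
Continuing the sketch above, a full training run might look as follows (the dataloaders train_loader and test_loaders = {"128": ...} are hypothetical; the optimizer and scheduler are plain PyTorch objects):

from neuralop.losses import H1Loss, LpLoss

optimizer = torch.optim.AdamW(model.parameters(), lr=8e-3, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)

# save_best must be keyed f"{loader_name}_{loss_name}", e.g. "128_h1"
metrics = trainer.train(train_loader=train_loader,
                        test_loaders=test_loaders,
                        optimizer=optimizer,
                        scheduler=scheduler,
                        training_loss=H1Loss(d=2),
                        eval_losses={"h1": H1Loss(d=2), "l2": LpLoss(d=2)},
                        save_best="128_h1",
                        save_dir="./ckpt")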

    def train_one_epoch(self, epoch, train_loader, training_loss):
        """train_one_epoch trains self.model on train_loader
        for one epoch and returns training metrics

        Parameters
        ----------
        epoch : int
            epoch number
        train_loader : torch.utils.data.DataLoader
            data loader of train examples
        training_loss : training.losses function
            cost function to minimize

        Returns
        -------
        train_err, avg_loss, avg_lasso_loss, epoch_train_time
            training metrics for this epoch
        """
        self.on_epoch_start(epoch)
        avg_loss = 0
        avg_lasso_loss = 0
        self.model.train()
        if self.data_processor:
            self.data_processor.train()
        t1 = default_timer()
        train_err = 0.0

        # track number of training examples in batch
        self.n_samples = 0

        for idx, sample in enumerate(train_loader):
            loss = self.train_one_batch(idx, sample, training_loss)
            loss.backward()
            self.optimizer.step()

            train_err += loss.item()
            with torch.no_grad():
                avg_loss += loss.item()
                if self.regularizer:
                    avg_lasso_loss += self.regularizer.loss

        if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            self.scheduler.step(train_err)
        else:
            self.scheduler.step()

        epoch_train_time = default_timer() - t1

        train_err /= len(train_loader)
        avg_loss /= self.n_samples
        if self.regularizer:
            avg_lasso_loss /= self.n_samples
        else:
            avg_lasso_loss = None
        lr = None
        for pg in self.optimizer.param_groups:
            lr = pg["lr"]
        if self.verbose and epoch % self.eval_interval == 0:
            self.log_training(
                epoch=epoch,
                time=epoch_train_time,
                avg_loss=avg_loss,
                train_err=train_err,
                avg_lasso_loss=avg_lasso_loss,
                lr=lr
            )

        return train_err, avg_loss, avg_lasso_loss, epoch_train_time
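
Note the branch on the scheduler type above: ReduceLROnPlateau is stepped with the accumulated train_err, while any other scheduler is stepped once per epoch with no arguments. Both of the following standard PyTorch schedulers are therefore compatible (sketch):

# stepped as scheduler.step(train_err) once per epoch
plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min",
                                                     factor=0.5, patience=5)

# stepped as scheduler.step() once per epoch
step_lr = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)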

    def evaluate_all(self, epoch, eval_losses, test_loaders):
        # evaluate and gather metrics across each loader in test_loaders
        all_metrics = {}
        for loader_name, loader in test_loaders.items():
            loader_metrics = self.evaluate(eval_losses, loader,
                                           log_prefix=loader_name)
            all_metrics.update(**loader_metrics)
        if self.verbose:
            self.log_eval(epoch=epoch, eval_metrics=all_metrics)
        return all_metrics

    def evaluate(self, loss_dict, data_loader, log_prefix="", epoch=None):
        """Evaluates the model on a dictionary of losses

        Parameters
        ----------
        loss_dict : dict of functions
            each function takes as input a tuple (prediction, ground_truth)
            and returns the corresponding loss
        data_loader : data_loader to evaluate on
        log_prefix : str, default is ''
            if not '', used as prefix in output dictionary
        epoch : int | None
            current epoch. Used when logging both train and eval,
            default None

        Returns
        -------
        errors : dict
            dict[f'{log_prefix}_{loss_name}'] = loss for loss in loss_dict
        """
        self.model.eval()
        if self.data_processor:
            self.data_processor.eval()

        errors = {f"{log_prefix}_{loss_name}": 0 for loss_name in loss_dict.keys()}

        self.n_samples = 0
        with torch.no_grad():
            for idx, sample in enumerate(data_loader):
                return_output = False
                if idx == len(data_loader) - 1:
                    return_output = True
                eval_step_losses, outs = self.eval_one_batch(sample, loss_dict,
                                                             return_output=return_output)

                for loss_name, val_loss in eval_step_losses.items():
                    errors[f"{log_prefix}_{loss_name}"] += val_loss

        for key in errors.keys():
            errors[key] /= self.n_samples

        # on last batch, log model outputs (only when wandb logging is active)
        if self.log_output and self.wandb_log:
            errors[f"{log_prefix}_outputs"] = wandb.Image(outs)

        return errors
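
evaluate can also be called on its own after training. Per-batch losses are summed and normalized by the total number of samples seen, and each result is keyed with the given prefix (the loader name "128" is hypothetical):

eval_metrics = trainer.evaluate(loss_dict={"l2": LpLoss(d=2)},
                                data_loader=test_loaders["128"],
                                log_prefix="128")
print(eval_metrics["128_l2"])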

    def on_epoch_start(self, epoch):
        """on_epoch_start runs at the beginning of each training epoch.
        This method is a stub that can be overwritten in more complex cases.

        Parameters
        ----------
        epoch : int
            index of epoch

        Returns
        -------
        None
        """
        self.epoch = epoch
        return None
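
Because on_epoch_start is a deliberate stub, epoch-dependent behavior can be added by subclassing. A sketch, assuming a regularizer object with a hypothetical `enabled` flag:

class WarmupTrainer(Trainer):
    def on_epoch_start(self, epoch):
        # hypothetical warm-up: keep the regularizer off for the first 5 epochs
        if self.regularizer is not None:
            self.regularizer.enabled = epoch >= 5
        return super().on_epoch_start(epoch)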

    def train_one_batch(self, idx, sample, training_loss):
        """Run one batch of input through model
        and return training loss on outputs

        Parameters
        ----------
        idx : int
            index of batch within train_loader
        sample : dict
            data dictionary holding one batch
        training_loss : training.losses function
            cost function to minimize

        Returns
        -------
        loss : torch.Tensor
            training loss on the batch, kept as a tensor
            so that train_one_epoch can call loss.backward()
        """
        self.optimizer.zero_grad(set_to_none=True)
        if self.regularizer:
            self.regularizer.reset()

        if self.data_processor is not None:
            sample = self.data_processor.preprocess(sample)
        else:
            # load data to device if no preprocessor exists
            sample = {
                k: v.to(self.device)
                for k, v in sample.items()
                if torch.is_tensor(v)
            }

        self.n_samples += sample["y"].shape[0]

        if self.mixed_precision:
            with torch.autocast(device_type=self.autocast_device_type):
                out = self.model(**sample)
        else:
            out = self.model(**sample)

        if self.epoch == 0 and idx == 0 and self.verbose:
            print(f"Raw outputs of shape {out.shape}")

        if self.data_processor is not None:
            out, sample = self.data_processor.postprocess(out, sample)

        loss = 0.0
        if self.mixed_precision:
            with torch.autocast(device_type=self.autocast_device_type):
                loss += training_loss(out, **sample)
        else:
            loss += training_loss(out, **sample)

        if self.regularizer:
            loss += self.regularizer.loss

        return loss
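
Since train_one_batch unpacks each batch as self.model(**sample) and training_loss(out, **sample), batch dicts need keys matching both signatures, conventionally "x" for inputs and "y" for targets, and the model and loss must tolerate each other's extra keyword arguments. A minimal compatible dataset, as an illustration:

class DictDataset(torch.utils.data.Dataset):
    """Yields {"x": input, "y": target} dicts in the layout the Trainer expects."""

    def __init__(self, x, y):
        self.x, self.y = x, y

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return {"x": self.x[idx], "y": self.y[idx]}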

    def eval_one_batch(self,
                       sample: dict,
                       eval_losses: dict,
                       return_output: bool = False):
        """eval_one_batch runs inference on one batch
        and returns eval_losses for that batch.

        Parameters
        ----------
        sample : dict
            data batch dictionary
        eval_losses : dict
            dictionary of named eval metrics
        return_output : bool
            whether to return model outputs for plotting,
            by default False

        Returns
        -------
        eval_step_losses : dict
            keyed "loss_name": step_loss_value for each loss name
        outputs : torch.Tensor | None
            optionally returns batch outputs
        """
        if self.data_processor is not None:
            sample = self.data_processor.preprocess(sample)
        else:
            # load data to device if no preprocessor exists
            sample = {
                k: v.to(self.device)
                for k, v in sample.items()
                if torch.is_tensor(v)
            }
        self.n_samples += sample["y"].size(0)

        out = self.model(**sample)
        if self.data_processor is not None:
            out, sample = self.data_processor.postprocess(out, sample)

        eval_step_losses = {}
        for loss_name, loss in eval_losses.items():
            val_loss = loss(out, **sample)
            eval_step_losses[loss_name] = val_loss

        if return_output:
            return eval_step_losses, out
        else:
            return eval_step_losses, None

    def log_training(self,
                     epoch: int,
                     time: float,
                     avg_loss: float,
                     train_err: float,
                     avg_lasso_loss: float = None,
                     lr: float = None):
        """Basic method to log results from a single training epoch.

        Parameters
        ----------
        epoch : int
        time : float
            training time of epoch
        avg_loss : float
            average train_err per individual sample
        train_err : float
            train error for entire epoch
        avg_lasso_loss : float
            average lasso loss from regularizer, optional
        lr : float
            learning rate at current epoch
        """
        # accumulate info to log to wandb
        if self.wandb_log:
            values_to_log = dict(
                train_err=train_err,
                time=time,
                avg_loss=avg_loss,
                avg_lasso_loss=avg_lasso_loss,
                lr=lr)

        msg = f"[{epoch}] time={time:.2f}, "
        msg += f"avg_loss={avg_loss:.4f}, "
        msg += f"train_err={train_err:.4f}"
        if avg_lasso_loss is not None:
            msg += f", avg_lasso={avg_lasso_loss:.4f}"

        print(msg)
        sys.stdout.flush()

        if self.wandb_log:
            wandb.log(data=values_to_log,
                      step=epoch + 1,
                      commit=False)

    def log_eval(self,
                 epoch: int,
                 eval_metrics: dict):
        """log_eval logs outputs from evaluation
        on all test loaders to stdout and wandb

        Parameters
        ----------
        epoch : int
            current training epoch
        eval_metrics : dict
            metrics collected during evaluation,
            keyed f"{test_loader_name}_{metric}" for each test_loader
        """
        values_to_log = {}
        msg = ""
        for metric, value in eval_metrics.items():
            if isinstance(value, float) or isinstance(value, torch.Tensor):
                msg += f"{metric}={value:.4f}, "
            if self.wandb_log:
                values_to_log[metric] = value

        msg = "Eval: " + msg[:-2]  # cut off last comma+space
        print(msg)
        sys.stdout.flush()

        if self.wandb_log:
            wandb.log(data=values_to_log,
                      step=epoch + 1,
                      commit=True)

    def resume_state_from_dir(self, save_dir):
        """
        Resume training from save_dir created by
        `neuralop.training.save_training_state`

        Parameters
        ----------
        save_dir : Union[str, Path]
            directory in which training state is saved
            (see neuralop.training.training_state)
        """
        if isinstance(save_dir, str):
            save_dir = Path(save_dir)

        # check that a saved model exists
        if (save_dir / "best_model_state_dict.pt").exists():
            save_name = "best_model"
        elif (save_dir / "model_state_dict.pt").exists():
            save_name = "model"
        else:
            raise FileNotFoundError(
                "Error: resume_from_dir expects a model state dict named "
                "model_state_dict.pt or best_model_state_dict.pt."
            )
        # returns model, loads other modules in-place if provided
        self.model = load_training_state(save_dir=save_dir,
                                         save_name=save_name,
                                         model=self.model,
                                         optimizer=self.optimizer,
                                         regularizer=self.regularizer,
                                         scheduler=self.scheduler)
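
Resuming is driven from train(): pointing resume_from_dir at a directory written by a previous checkpoint call reloads the model, and the optimizer, scheduler, and regularizer in place, before the first epoch. A sketch, reusing the objects from the earlier examples:

metrics = trainer.train(train_loader=train_loader,
                        test_loaders=test_loaders,
                        optimizer=optimizer,
                        scheduler=scheduler,
                        save_every=5,
                        save_dir="./ckpt",
                        resume_from_dir="./ckpt")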

    def checkpoint(self, save_dir):
        """checkpoint saves current training state
        to a directory for resuming later.
        See neuralop.training.training_state

        Parameters
        ----------
        save_dir : str | Path
            directory in which to save training state
        """
        if self.save_best is not None:
            save_name = "best_model"
        else:
            save_name = "model"
        save_training_state(save_dir=save_dir,
                            save_name=save_name,
                            model=self.model,
                            optimizer=self.optimizer,
                            scheduler=self.scheduler,
                            regularizer=self.regularizer)
        if self.verbose:
            print(f"Saved training state to {save_dir}")
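
Outside the Trainer, the same helpers can reload a saved state; per the call in resume_state_from_dir above, load_training_state returns the model and restores any other modules passed to it in place (paths here are illustrative):

from neuralop.training.training_state import load_training_state

model = load_training_state(save_dir="./ckpt",
                            save_name="best_model",  # or "model" for interval checkpoints
                            model=model)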