Source code for neuralop.training.trainer

from timeit import default_timer
from pathlib import Path
from typing import Union
import sys
import warnings

import torch
from torch.cuda import amp
from torch import nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
# Only import wandb and use if installed
wandb_available = False
try:
    import wandb
    wandb_available = True
except ModuleNotFoundError:
    wandb_available = False

import neuralop.mpu.comm as comm
from neuralop.losses import LpLoss
from .training_state import load_training_state, save_training_state


[docs] class Trainer: """ A general Trainer class to train neural-operators on given datasets """ def __init__( self, *, model: nn.Module, n_epochs: int, wandb_log: bool=False, device: str='cpu', mixed_precision: bool=False, data_processor: nn.Module=None, eval_interval: int=1, log_output: bool=False, use_distributed: bool=False, verbose: bool=False, ): """ Parameters ---------- model : nn.Module n_epochs : int wandb_log : bool, default is False whether to log results to wandb device : torch.device, or str 'cpu' or 'cuda' mixed_precision : bool, default is False whether to use torch.autocast to compute mixed precision data_processor : DataProcessor class to transform data, default is None if not None, data from the loaders is transform first with data_processor.preprocess, then after getting an output from the model, that is transformed with data_processor.postprocess. eval_interval : int, default is 1 how frequently to evaluate model and log training stats log_output : bool, default is False if True, and if wandb_log is also True, log output images to wandb use_distributed : bool, default is False whether to use DDP verbose : bool, default is False """ self.model = model self.n_epochs = n_epochs # only log to wandb if a run is active self.wandb_log = False if wandb_available: self.wandb_log = (wandb_log and wandb.run is not None) self.eval_interval = eval_interval self.log_output = log_output self.verbose = verbose self.use_distributed = use_distributed self.device = device # handle autocast device if isinstance(self.device, torch.device): self.autocast_device_type = self.device.type else: if "cuda" in self.device: self.autocast_device_type = "cuda" else: self.autocast_device_type = "cpu" self.mixed_precision = mixed_precision self.data_processor = data_processor # Track starting epoch for checkpointing/resuming self.start_epoch = 0
[docs] def train( self, train_loader, test_loaders, optimizer, scheduler, regularizer=None, training_loss=None, eval_losses=None, save_every: int=None, save_best: int=None, save_dir: Union[str, Path]="./ckpt", resume_from_dir: Union[str, Path]=None, ): """Trains the given model on the given dataset. If a device is provided, the model and data processor are loaded to device here. Parameters ----------- train_loader: torch.utils.data.DataLoader training dataloader test_loaders: dict[torch.utils.data.DataLoader] testing dataloaders optimizer: torch.optim.Optimizer optimizer to use during training scheduler: torch.optim.lr_scheduler learning rate scheduler to use during training training_loss: training.losses function cost function to minimize eval_losses: dict[Loss] dict of losses to use in self.eval() save_every: int, optional, default is None if provided, interval at which to save checkpoints save_best: str, optional, default is None if provided, key of metric f"{loader_name}_{loss_name}" to monitor and save model with best eval result Overrides save_every and saves on eval_interval save_dir: str | Path, default "./ckpt" directory at which to save training states if save_every and/or save_best is provided resume_from_dir: str | Path, default None if provided, resumes training state (model, optimizer, regularizer, scheduler) from state saved in `resume_from_dir` Returns ------- all_metrics: dict dictionary keyed f"{loader_name}_{loss_name}" of metric results for last validation epoch across all test_loaders """ self.optimizer = optimizer self.scheduler = scheduler if regularizer: self.regularizer = regularizer else: self.regularizer = None if training_loss is None: training_loss = LpLoss(d=2) # Warn the user if training loss is reducing across the batch if hasattr(training_loss, 'reduction'): if training_loss.reduction == "mean": warnings.warn(f"{training_loss.reduction=}. This means that the loss is " "initialized to average across the batch dim. The Trainer " "expects losses to sum across the batch dim.") if eval_losses is None: # By default just evaluate on the training loss eval_losses = dict(l2=training_loss) # accumulated wandb metrics self.wandb_epoch_metrics = None # attributes for checkpointing self.save_every = save_every self.save_best = save_best if resume_from_dir is not None: self.resume_state_from_dir(resume_from_dir) # Load model and data_processor to device self.model = self.model.to(self.device) if self.use_distributed and dist.is_initialized(): device_id = dist.get_rank() self.model = DDP(self.model, device_ids=[device_id], output_device=device_id) if self.data_processor is not None: self.data_processor = self.data_processor.to(self.device) # ensure save_best is a metric we collect if self.save_best is not None: metrics = [] for name in test_loaders.keys(): for metric in eval_losses.keys(): metrics.append(f"{name}_{metric}") assert self.save_best in metrics,\ f"Error: expected a metric of the form <loader_name>_<metric>, got {save_best}" best_metric_value = float('inf') # either monitor metric or save on interval, exclusive for simplicity self.save_every = None if self.verbose: print(f'Training on {len(train_loader.dataset)} samples') print(f'Testing on {[len(loader.dataset) for loader in test_loaders.values()]} samples' f' on resolutions {[name for name in test_loaders]}.') sys.stdout.flush() for epoch in range(self.start_epoch, self.n_epochs): train_err, avg_loss, avg_lasso_loss, epoch_train_time =\ self.train_one_epoch(epoch, train_loader, training_loss) epoch_metrics = dict( train_err=train_err, avg_loss=avg_loss, avg_lasso_loss=avg_lasso_loss, epoch_train_time=epoch_train_time ) if epoch % self.eval_interval == 0: # evaluate and gather metrics across each loader in test_loaders eval_metrics = self.evaluate_all(epoch=epoch, eval_losses=eval_losses, test_loaders=test_loaders) epoch_metrics.update(**eval_metrics) # save checkpoint if conditions are met if save_best is not None: if eval_metrics[save_best] < best_metric_value: best_metric_value = eval_metrics[save_best] self.checkpoint(save_dir) # save checkpoint if save_every and save_best is not set if self.save_every is not None: if epoch % self.save_every == 0: self.checkpoint(save_dir) return epoch_metrics
[docs] def train_one_epoch(self, epoch, train_loader, training_loss): """train_one_epoch trains self.model on train_loader for one epoch and returns training metrics Parameters ---------- epoch : int epoch number train_loader : torch.utils.data.DataLoader data loader of train examples test_loaders : dict dict of test torch.utils.data.DataLoader objects Returns ------- all_errors dict of all eval metrics for the last epoch """ self.on_epoch_start(epoch) avg_loss = 0 avg_lasso_loss = 0 self.model.train() if self.data_processor: self.data_processor.train() t1 = default_timer() train_err = 0.0 # track number of training examples in batch self.n_samples = 0 for idx, sample in enumerate(train_loader): loss = self.train_one_batch(idx, sample, training_loss) loss.backward() self.optimizer.step() train_err += loss.item() with torch.no_grad(): avg_loss += loss.item() if self.regularizer: avg_lasso_loss += self.regularizer.loss if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): self.scheduler.step(train_err) else: self.scheduler.step() epoch_train_time = default_timer() - t1 train_err /= len(train_loader) avg_loss /= self.n_samples if self.regularizer: avg_lasso_loss /= self.n_samples else: avg_lasso_loss = None lr = None for pg in self.optimizer.param_groups: lr = pg["lr"] if self.verbose and epoch % self.eval_interval == 0: self.log_training( epoch=epoch, time=epoch_train_time, avg_loss=avg_loss, train_err=train_err, avg_lasso_loss=avg_lasso_loss, lr=lr ) return train_err, avg_loss, avg_lasso_loss, epoch_train_time
def evaluate_all(self, epoch, eval_losses, test_loaders): # evaluate and gather metrics across each loader in test_loaders all_metrics = {} for loader_name, loader in test_loaders.items(): loader_metrics = self.evaluate(eval_losses, loader, log_prefix=loader_name) all_metrics.update(**loader_metrics) if self.verbose: self.log_eval(epoch=epoch, eval_metrics=all_metrics) return all_metrics
[docs] def evaluate(self, loss_dict, data_loader, log_prefix="", epoch=None): """Evaluates the model on a dictionary of losses Parameters ---------- loss_dict : dict of functions each function takes as input a tuple (prediction, ground_truth) and returns the corresponding loss data_loader : data_loader to evaluate on log_prefix : str, default is '' if not '', used as prefix in output dictionary epoch : int | None current epoch. Used when logging both train and eval default None Returns ------- errors : dict dict[f'{log_prefix}_{loss_name}] = loss for loss in loss_dict """ # Ensure model and data processor are loaded to the proper device self.model = self.model.to(self.device) if self.data_processor is not None and self.data_processor.device != self.device: self.data_processor = self.data_processor.to(self.device) self.model.eval() if self.data_processor: self.data_processor.eval() errors = {f"{log_prefix}_{loss_name}": 0 for loss_name in loss_dict.keys()} # Warn the user if any of the eval losses is reducing across the batch for _, eval_loss in loss_dict.items(): if hasattr(eval_loss, 'reduction'): if eval_loss.reduction == "mean": warnings.warn(f"{eval_loss.reduction=}. This means that the loss is " "initialized to average across the batch dim. The Trainer " "expects losses to sum across the batch dim.") self.n_samples = 0 with torch.no_grad(): for idx, sample in enumerate(data_loader): return_output = False if idx == len(data_loader) - 1: return_output = True eval_step_losses, outs = self.eval_one_batch(sample, loss_dict, return_output=return_output) for loss_name, val_loss in eval_step_losses.items(): errors[f"{log_prefix}_{loss_name}"] += val_loss for key in errors.keys(): errors[key] /= self.n_samples # on last batch, log model outputs if self.log_output: errors[f"{log_prefix}_outputs"] = wandb.Image(outs) return errors
[docs] def on_epoch_start(self, epoch): """on_epoch_start runs at the beginning of each training epoch. This method is a stub that can be overwritten in more complex cases. Parameters ---------- epoch : int index of epoch Returns ------- None """ self.epoch = epoch return None
[docs] def train_one_batch(self, idx, sample, training_loss): """Run one batch of input through model and return training loss on outputs Parameters ---------- idx : int index of batch within train_loader sample : dict data dictionary holding one batch Returns ------- loss: float | Tensor float value of training loss """ self.optimizer.zero_grad(set_to_none=True) if self.regularizer: self.regularizer.reset() if self.data_processor is not None: sample = self.data_processor.preprocess(sample) else: # load data to device if no preprocessor exists sample = { k: v.to(self.device) for k, v in sample.items() if torch.is_tensor(v) } self.n_samples += sample["y"].shape[0] if self.mixed_precision: with torch.autocast(device_type=self.autocast_device_type): out = self.model(**sample) else: out = self.model(**sample) if self.epoch == 0 and idx == 0 and self.verbose: print(f"Raw outputs of shape {out.shape}") if self.data_processor is not None: out, sample = self.data_processor.postprocess(out, sample) loss = 0.0 if self.mixed_precision: with torch.autocast(device_type=self.autocast_device_type): loss += training_loss(out, **sample) else: loss += training_loss(out, **sample) if self.regularizer: loss += self.regularizer.loss return loss
[docs] def eval_one_batch(self, sample: dict, eval_losses: dict, return_output: bool=False): """eval_one_batch runs inference on one batch and returns eval_losses for that batch. Parameters ---------- sample : dict data batch dictionary eval_losses : dict dictionary of named eval metrics return_outputs : bool whether to return model outputs for plotting by default False Returns ------- eval_step_losses : dict keyed "loss_name": step_loss_value for each loss name outputs: torch.Tensor | None optionally returns batch outputs """ if self.data_processor is not None: sample = self.data_processor.preprocess(sample) else: # load data to device if no preprocessor exists sample = { k: v.to(self.device) for k, v in sample.items() if torch.is_tensor(v) } self.n_samples += sample["y"].size(0) out = self.model(**sample) if self.data_processor is not None: out, sample = self.data_processor.postprocess(out, sample) eval_step_losses = {} for loss_name, loss in eval_losses.items(): val_loss = loss(out, **sample) eval_step_losses[loss_name] = val_loss if return_output: return eval_step_losses, out else: return eval_step_losses, None
[docs] def log_training(self, epoch:int, time: float, avg_loss: float, train_err: float, avg_lasso_loss: float=None, lr: float=None ): """Basic method to log results from a single training epoch. Parameters ---------- epoch: int time: float training time of epoch avg_loss: float average train_err per individual sample train_err: float train error for entire epoch avg_lasso_loss: float average lasso loss from regularizer, optional lr: float learning rate at current epoch """ # accumulate info to log to wandb if self.wandb_log: values_to_log = dict( train_err=train_err, time=time, avg_loss=avg_loss, avg_lasso_loss=avg_lasso_loss, lr=lr) msg = f"[{epoch}] time={time:.2f}, " msg += f"avg_loss={avg_loss:.4f}, " msg += f"train_err={train_err:.4f}" if avg_lasso_loss is not None: msg += f", avg_lasso={avg_lasso_loss:.4f}" print(msg) sys.stdout.flush() if self.wandb_log: wandb.log(data=values_to_log, step=epoch+1, commit=False)
[docs] def log_eval(self, epoch: int, eval_metrics: dict): """log_eval logs outputs from evaluation on all test loaders to stdout and wandb Parameters ---------- epoch : int current training epoch eval_metrics : dict metrics collected during evaluation keyed f"{test_loader_name}_{metric}" for each test_loader """ values_to_log = {} msg = "" for metric, value in eval_metrics.items(): if isinstance(value, float) or isinstance(value, torch.Tensor): msg += f"{metric}={value:.4f}, " if self.wandb_log: values_to_log[metric] = value msg = f"Eval: " + msg[:-2] # cut off last comma+space print(msg) sys.stdout.flush() if self.wandb_log: wandb.log(data=values_to_log, step=epoch+1, commit=True)
[docs] def resume_state_from_dir(self, save_dir): """ Resume training from save_dir created by `neuralop.training.save_training_state` Params ------ save_dir: Union[str, Path] directory in which training state is saved (see neuralop.training.training_state) """ if isinstance(save_dir, str): save_dir = Path(save_dir) # check for save model exists if (save_dir / "best_model_state_dict.pt").exists(): save_name = "best_model" elif (save_dir / "model_state_dict.pt").exists(): save_name = "model" else: raise FileNotFoundError("Error: resume_from_dir expects a model\ state dict named model.pt or best_model.pt.") # returns model, loads other modules if provided self.model, self.optimizer, self.scheduler, self.regularizer, resume_epoch =\ load_training_state(save_dir=save_dir, save_name=save_name, model=self.model, optimizer=self.optimizer, regularizer=self.regularizer, scheduler=self.scheduler) if resume_epoch is not None: if resume_epoch > self.start_epoch: self.start_epoch = resume_epoch if self.verbose: print(f"Trainer resuming from epoch {resume_epoch}")
[docs] def checkpoint(self, save_dir): """checkpoint saves current training state to a directory for resuming later. Only saves training state on the first GPU. See neuralop.training.training_state Parameters ---------- save_dir : str | Path directory in which to save training state """ if comm.get_local_rank() == 0: if self.save_best is not None: save_name = 'best_model' else: save_name = "model" save_training_state(save_dir=save_dir, save_name=save_name, model=self.model, optimizer=self.optimizer, scheduler=self.scheduler, regularizer=self.regularizer, epoch=self.epoch ) if self.verbose: print(f"[Rank 0]: saved training state to {save_dir}")