"""
Callbacks store all non-essential logic
required to run specific training scripts.
The callbacks in this module follow the form and
logic of callbacks in Pytorch-Lightning (https://lightning.ai/docs/pytorch/stable)
"""
import sys
from pathlib import Path
from typing import List, Union

import torch
import wandb

from .training_state import save_training_state
from neuralop.utils import compute_explained_variance
class Callback:
    """
    Base callback class. Each abstract method is called in the trainer's
    training loop at the appropriate time.

    Callbacks are stateful, meaning they keep track of a state and
    update it throughout the lifetime of a Trainer class.
    Storing the state as a dict enables the Callback to keep track of
    references to underlying parts of the Trainer's process, such as
    models, cost functions, and output encoders.
    """
def __init__(self):
self.state_dict = {}
def _update_state_dict(self, **kwargs):
self.state_dict.update(kwargs)
    def on_init_start(self, *args, **kwargs):
        pass
def on_init_end(self, *args, **kwargs):
pass
def on_before_train(self, *args, **kwargs):
pass
def on_train_start(self, *args, **kwargs):
pass
def on_epoch_start(self, *args, **kwargs):
pass
def on_batch_start(self, *args, **kwargs):
pass
def on_load_to_device(self, *args, **kwargs):
pass
def on_before_forward(self, *args, **kwargs):
pass
def on_before_loss(self, *args, **kwargs):
pass
def compute_training_loss(self, *args, **kwargs):
raise NotImplementedError
def on_batch_end(self, *args, **kwargs):
pass
def on_epoch_end(self, *args, **kwargs):
pass
def on_train_end(self, *args, **kwargs):
pass
def on_before_val(self, *args, **kwargs):
pass
def on_val_epoch_start(self, *args, **kwargs):
pass
def on_val_batch_start(self, *args, **kwargs):
pass
    def on_before_val_loss(self, *args, **kwargs):
        pass
def compute_val_loss(self, *args, **kwargs):
pass
def on_val_batch_end(self, *args, **kwargs):
pass
def on_val_epoch_end(self, *args, **kwargs):
pass
def on_val_end(self, *args, **kwargs):
pass
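
# A minimal sketch of a custom callback (illustration only; this class is not
# part of the module). Subclasses override just the hooks they need and can
# pass values between hooks through the shared state dict:
#
#     import time
#
#     class TimerCallback(Callback):
#         def on_epoch_start(self, epoch):
#             self._update_state_dict(epoch=epoch, epoch_start=time.time())
#
#         def on_epoch_end(self, *args, **kwargs):
#             elapsed = time.time() - self.state_dict["epoch_start"]
#             print(f"epoch {self.state_dict['epoch']} took {elapsed:.2f}s")
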
class PipelineCallback(Callback):
def __init__(self, callbacks: List[Callback]):
"""
PipelineCallback handles logic for the case in which
a user passes more than one Callback to a trainer.
Parameters
----------
callbacks : List[Callback]
list of Callbacks to use in Trainer
"""
        super().__init__()
        self.callbacks = callbacks
overrides_device_load = [
"on_load_to_device" in c.__class__.__dict__.keys() for c in callbacks
]
        assert (
            sum(overrides_device_load) < 2
        ), "At most one callback can override device loading."
if sum(overrides_device_load) == 1:
self.device_load_callback_idx = overrides_device_load.index(True)
print("using custom callback to load data to device.")
else:
self.device_load_callback_idx = None
print("using standard method to load data to device.")
        # unless loss computation is overridden, fall back to the default loss computation
overrides_loss = [
"compute_training_loss" in c.__class__.__dict__.keys() for c in callbacks
]
if sum(overrides_loss) >= 1:
self.overrides_loss = True
print("using custom callback to compute loss.")
else:
self.overrides_loss = False
print("using standard method to compute loss.")
    def _update_state_dict(self, **kwargs):
        for c in self.callbacks:
            c._update_state_dict(**kwargs)
def on_init_start(self, *args, **kwargs):
for c in self.callbacks:
c.on_init_start(*args, **kwargs)
def on_init_end(self, *args, **kwargs):
for c in self.callbacks:
c.on_init_end(*args, **kwargs)
def on_before_train(self, *args, **kwargs):
for c in self.callbacks:
c.on_before_train(*args, **kwargs)
def on_train_start(self, *args, **kwargs):
for c in self.callbacks:
c.on_train_start(*args, **kwargs)
def on_epoch_start(self, *args, **kwargs):
for c in self.callbacks:
c.on_epoch_start(*args, **kwargs)
def on_batch_start(self, *args, **kwargs):
for c in self.callbacks:
c.on_batch_start(*args, **kwargs)
    def on_load_to_device(self, *args, **kwargs):
        if self.device_load_callback_idx is not None:
            self.callbacks[self.device_load_callback_idx].on_load_to_device(
                *args, **kwargs
            )
def on_before_forward(self, *args, **kwargs):
for c in self.callbacks:
c.on_before_forward(*args, **kwargs)
def on_before_loss(self, *args, **kwargs):
for c in self.callbacks:
c.on_before_loss(*args, **kwargs)
    def compute_training_loss(self, *args, **kwargs):
        if self.overrides_loss:
            for c in self.callbacks:
                c.compute_training_loss(*args, **kwargs)
def on_batch_end(self, *args, **kwargs):
for c in self.callbacks:
c.on_batch_end(*args, **kwargs)
def on_epoch_end(self, *args, **kwargs):
for c in self.callbacks:
c.on_epoch_end(*args, **kwargs)
def on_train_end(self, *args, **kwargs):
for c in self.callbacks:
c.on_train_end(*args, **kwargs)
def on_before_val(self, *args, **kwargs):
for c in self.callbacks:
c.on_before_val(*args, **kwargs)
def on_val_epoch_start(self, *args, **kwargs):
for c in self.callbacks:
c.on_val_epoch_start(*args, **kwargs)
def on_val_batch_start(self, *args, **kwargs):
for c in self.callbacks:
c.on_val_batch_start(*args, **kwargs)
def on_before_val_loss(self, *args, **kwargs):
for c in self.callbacks:
c.on_before_val_loss(*args, **kwargs)
    def compute_val_loss(self, *args, **kwargs):
        if self.overrides_loss:
            for c in self.callbacks:
                c.compute_val_loss(*args, **kwargs)
def on_val_batch_end(self, *args, **kwargs):
for c in self.callbacks:
c.on_val_batch_end(*args, **kwargs)
def on_val_epoch_end(self, *args, **kwargs):
for c in self.callbacks:
c.on_val_epoch_end(*args, **kwargs)
def on_val_end(self, *args, **kwargs):
for c in self.callbacks:
c.on_val_end(*args, **kwargs)
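
# Usage sketch: a Trainer that receives several callbacks can wrap them in a
# single PipelineCallback, which fans every hook out to each wrapped callback
# (arguments below are illustrative placeholders):
#
#     pipeline = PipelineCallback(
#         callbacks=[BasicLoggerCallback(), CheckpointCallback(save_dir="./ckpts")]
#     )
#     pipeline.on_epoch_start(epoch=0)  # dispatched to every wrapped callback
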
class BasicLoggerCallback(Callback):
    """
    Callback that implements the simple logging functionality
    expected when passing `verbose=True` to a Trainer. Optionally
    starts a Weights & Biases run if `wandb_kwargs` is provided.
    """
def __init__(self, wandb_kwargs=None):
super().__init__()
if wandb_kwargs:
wandb.init(**wandb_kwargs)
def on_init_end(self, *args, **kwargs):
self._update_state_dict(**kwargs)
def on_train_start(self, **kwargs):
self._update_state_dict(**kwargs)
train_loader = self.state_dict["train_loader"]
test_loaders = self.state_dict["test_loaders"]
verbose = self.state_dict["verbose"]
n_train = len(train_loader.dataset)
self._update_state_dict(n_train=n_train)
if not isinstance(test_loaders, dict):
test_loaders = dict(test=test_loaders)
if verbose:
print(f"Training on {n_train} samples")
print(
f"Testing on {[len(loader.dataset) for loader in test_loaders.values()]} samples"
f" on resolutions {[name for name in test_loaders]}."
)
sys.stdout.flush()
def on_epoch_start(self, epoch):
self._update_state_dict(epoch=epoch)
def on_batch_start(self, idx, **kwargs):
self._update_state_dict(idx=idx)
def on_before_loss(self, out, **kwargs):
if (
self.state_dict["epoch"] == 0
and self.state_dict["idx"] == 0
and self.state_dict["verbose"]
):
print(f"Raw outputs of size {out.shape=}")
def on_before_val(self, epoch, train_err, time, avg_loss, avg_lasso_loss, **kwargs):
# track training err and val losses to print at interval epochs
msg = f"[{epoch}] time={time:.2f}, avg_loss={avg_loss:.4f}, train_err={train_err:.4f}"
values_to_log = dict(train_err=train_err, time=time, avg_loss=avg_loss)
self._update_state_dict(msg=msg, values_to_log=values_to_log)
self._update_state_dict(avg_lasso_loss=avg_lasso_loss)
def on_val_epoch_end(self, errors, **kwargs):
for loss_name, loss_value in errors.items():
if isinstance(loss_value, float):
self.state_dict["msg"] += f", {loss_name}={loss_value:.4f}"
else:
loss_value = {i: e.item() for (i, e) in enumerate(loss_value)}
self.state_dict["msg"] += f", {loss_name}={loss_value}"
self.state_dict["values_to_log"][loss_name] = loss_value
def on_val_end(self, *args, **kwargs):
if self.state_dict.get("regularizer", False):
avg_lasso = self.state_dict.get("avg_lasso_loss", 0.0)
avg_lasso /= self.state_dict.get("n_epochs")
self.state_dict["msg"] += f", avg_lasso={avg_lasso:.5f}"
print(self.state_dict["msg"])
sys.stdout.flush()
if self.state_dict.get("wandb_log", False):
for pg in self.state_dict["optimizer"].param_groups:
lr = pg["lr"]
self.state_dict["values_to_log"]["lr"] = lr
wandb.log(
self.state_dict["values_to_log"],
step=self.state_dict["epoch"] + 1,
commit=True,
)
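
# Usage sketch (assumes a configured Weights & Biases account; the kwargs are
# standard `wandb.init` arguments, chosen here for illustration):
#
#     logger = BasicLoggerCallback(
#         wandb_kwargs=dict(project="my-fno-runs", name="burgers-example")
#     )
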
class CheckpointCallback(Callback):
    def __init__(
        self,
        save_dir: Union[Path, str] = "./checkpoints",
        save_best: str = None,
        save_interval: int = 1,
        save_optimizer: bool = False,
        save_scheduler: bool = False,
        save_regularizer: bool = False,
        resume_from_dir: Union[Path, str] = None,
    ):
"""CheckpointCallback handles saving and resuming
training state from checkpoint .pt save files.
Parameters
----------
save_dir : Union[Path, str], optional
folder in which to save checkpoints, by default './checkpoints'
save_best : str, optional
metric to monitor for best value in order to save state
save_interval : int, optional
interval on which to save/check metric, by default 1
save_optimizer : bool, optional
whether to save optimizer state, by default False
save_scheduler : bool, optional
whether to save scheduler state, by default False
save_regularizer : bool, optional
whether to save regularizer state, by default False
resume_from_dir : Union[Path, str], optional
folder from which to resume training state.
Expects saved states in the form: (all but model optional)
(best_model.pt or model.pt), optimizer.pt, scheduler.pt, regularizer.pt
All state files present will be loaded.
if some metric was monitored during checkpointing,
the file name will be best_model.pt.
"""
super().__init__()
if isinstance(save_dir, str):
save_dir = Path(save_dir)
if not save_dir.exists():
save_dir.mkdir(parents=True)
self.save_dir = save_dir
self.save_interval = save_interval
self.save_best = save_best
self.save_optimizer = save_optimizer
self.save_scheduler = save_scheduler
self.save_regularizer = save_regularizer
        if resume_from_dir:
            if isinstance(resume_from_dir, str):
                resume_from_dir = Path(resume_from_dir)
            assert resume_from_dir.exists()
        # always set the attribute so on_train_start can safely check it
        self.resume_from_dir = resume_from_dir
def on_init_end(self, *args, **kwargs):
self._update_state_dict(**kwargs)
def on_train_start(self, *args, **kwargs):
self._update_state_dict(**kwargs)
verbose = self.state_dict.get("verbose", False)
if self.save_best:
assert self.state_dict[
"eval_losses"
], "Error: cannot monitor a metric if no validation metrics exist."
assert (
self.save_best in self.state_dict["eval_losses"].keys()
), "Error: cannot monitor a metric outside of eval_losses."
self.best_metric_value = float("inf")
else:
self.best_metric_value = None
# load state dict if resume_from_dir is given
if self.resume_from_dir:
saved_modules = [x.stem for x in self.resume_from_dir.glob("*.pt")]
            assert (
                "best_model_state_dict" in saved_modules
                or "model_state_dict" in saved_modules
            ), "Error: CheckpointCallback expects a model state dict named model_state_dict.pt or best_model_state_dict.pt."
# no need to handle exceptions if assertion that either model file exists passes
if "best_model_state_dict" in saved_modules:
if hasattr(self.state_dict["model"], "load_checkpoint"):
self.state_dict["model"].load_checkpoint(
save_folder=self.resume_from_dir, save_name="best_model"
)
                else:
                    self.state_dict["model"].load_state_dict(
                        torch.load(self.resume_from_dir / "best_model_state_dict.pt")
                    )
                if verbose:
                    print("Loading model state from best_model_state_dict.pt")
else:
if hasattr(self.state_dict["model"], "load_checkpoint"):
self.state_dict["model"].load_checkpoint(
save_folder=self.resume_from_dir, save_name="model"
)
                else:
                    self.state_dict["model"].load_state_dict(
                        torch.load(self.resume_from_dir / "model_state_dict.pt")
                    )
                if verbose:
                    print("Loading model state from model_state_dict.pt")
# load all of optimizer, scheduler, regularizer if they exist
for module in ["optimizer", "scheduler", "regularizer"]:
if module in saved_modules:
self.state_dict[module].load_state_dict(
torch.load(self.resume_from_dir / f"{module}.pt")
)
def on_epoch_start(self, *args, **kwargs):
self._update_state_dict(**kwargs)
def on_val_epoch_start(self, *args, **kwargs):
self._update_state_dict(**kwargs)
def on_val_epoch_end(self, *args, **kwargs):
"""
Update state dict with errors
"""
self._update_state_dict(**kwargs)
def on_epoch_end(self, *args, **kwargs):
"""
Save state to dir if all conditions are met
"""
        if self.save_best:
            log_prefix = self.state_dict["log_prefix"]
            current_metric_value = self.state_dict["errors"][
                f"{log_prefix}_{self.save_best}"
            ]
            metric_cond = current_metric_value < self.best_metric_value
            if metric_cond:
                # remember the new best value so later epochs only save on improvement
                self.best_metric_value = current_metric_value
        else:
            metric_cond = True
# Save states to save_dir
if self.state_dict["epoch"] % self.save_interval == 0 and metric_cond:
# save model or best_model.pt no matter what
if self.save_best:
model_name = "best_model"
else:
model_name = "model"
save_training_state(
self.save_dir,
model_name,
model=self.state_dict["model"],
optimizer=self.state_dict.get("optimizer", None),
regularizer=self.state_dict.get("regularizer", None),
scheduler=self.state_dict.get("scheduler", None),
)
if self.state_dict["verbose"]:
print(f"Saved training state to {self.save_dir}")
class IncrementalCallback(Callback):
    """
    Callback that implements the incremental-FNO algorithm,
    in both the gradient-explained and loss-gap variants.

    Parameters
    ----------
    incremental_grad : bool, default is False
        if True, use the base incremental algorithm, which is based on gradient variance:
        uses the incremental_grad_eps parameter to set the threshold for gradient variance;
        uses the incremental_buffer parameter to set the number of buffer modes over which to compute the gradient variance;
        uses the incremental_max_iter parameter to set the initial number of iterations;
        uses the incremental_grad_max_iter parameter to set the maximum number of iterations over which to accumulate the gradients
    incremental_loss_gap : bool, default is False
        if True, use the incremental algorithm based on the loss gap;
        uses the incremental_loss_eps parameter as the plateau threshold
    """
def __init__(self,
incremental_grad: bool = False,
incremental_loss_gap: bool = False,
incremental_grad_eps: float = 0.001,
incremental_buffer: int = 5,
incremental_max_iter: int = 1,
incremental_grad_max_iter: int = 10,
incremental_loss_eps: float = 0.001
):
super().__init__()
self.incremental_loss_gap = incremental_loss_gap
self.incremental_grad = incremental_grad
self.incremental = self.incremental_loss_gap or self.incremental_grad
assert self.incremental, "Error: IncrementalCallback expects at least one incremental algorithm to be True."
assert not (self.incremental_loss_gap and self.incremental_grad), "Error: IncrementalCallback expects only one incremental algorithm to be True."
self.incremental_grad_eps = incremental_grad_eps
self.incremental_buffer = incremental_buffer
self.incremental_max_iter = incremental_max_iter
self.incremental_grad_max_iter = incremental_grad_max_iter
self.incremental_loss_eps = incremental_loss_eps
self.loss_list = []
self.mode = "Train"
def on_init_end(self, *args, **kwargs):
self._update_state_dict(**kwargs)
def on_train_start(self, **kwargs):
self._update_state_dict(**kwargs)
train_loader = self.state_dict['train_loader']
test_loaders = self.state_dict['test_loaders']
verbose = self.state_dict['verbose']
n_train = len(train_loader.dataset)
self._update_state_dict(n_train=n_train)
if not isinstance(test_loaders, dict):
test_loaders = dict(test=test_loaders)
if verbose:
print(f'Training on {n_train} samples')
print(f'Testing on {[len(loader.dataset) for loader in test_loaders.values()]} samples'
f' on resolutions {[name for name in test_loaders]}.')
sys.stdout.flush()
def on_epoch_start(self, epoch):
self._update_state_dict(epoch=epoch)
def on_epoch_end(self, epoch, **kwargs):
self._update_state_dict(epoch=epoch)
        print(f'Currently the model is using n_modes = {self.state_dict["model"].fno_blocks.convs.n_modes}')
def on_batch_start(self, idx, **kwargs):
self._update_state_dict(idx=idx)
self.mode = "Train"
self.data = self.state_dict['data_processor']
if self.data is not None:
self.data.mode = self.mode
self.data.epoch = self.state_dict['epoch']
def on_before_loss(self, out, **kwargs):
if self.state_dict['epoch'] == 0 and self.state_dict['idx'] == 0 \
and self.state_dict['verbose']:
print(f'Raw outputs of size {out.shape=}')
def on_before_val(self, epoch, train_err, time, avg_loss, avg_lasso_loss, **kwargs):
# track training err and val losses to print at interval epochs
msg = f'[{epoch}] time={time:.2f}, avg_loss={avg_loss:.4f}, train_err={train_err:.4f}'
self.step(avg_loss)
self._update_state_dict(msg=msg)
self._update_state_dict(avg_lasso_loss=avg_lasso_loss)
def on_val_epoch_end(self, errors, **kwargs):
for loss_name, loss_value in errors.items():
if isinstance(loss_value, float):
self.state_dict['msg'] += f', {loss_name}={loss_value:.4f}'
else:
loss_value = {i:e.item() for (i, e) in enumerate(loss_value)}
self.state_dict['msg'] += f', {loss_name}={loss_value}'
def on_val_batch_start(self, *args, **kwargs):
self.mode = "Validation"
if self.data is not None:
self.data.mode = self.mode
self.data.epoch = self.state_dict['epoch']
def on_val_end(self, *args, **kwargs):
if self.state_dict.get('regularizer', False):
avg_lasso = self.state_dict.get('avg_lasso_loss', 0.)
avg_lasso /= self.state_dict.get('n_epochs')
self.state_dict['msg'] += f', avg_lasso={avg_lasso:.5f}'
print(self.state_dict['msg'])
sys.stdout.flush()
# Main step function: which algorithm to run
def step(self, loss=None):
if self.incremental_loss_gap and loss is not None:
self.loss_gap(loss)
if self.incremental_grad:
self.grad_explained()
    # Algorithm 1: incremental via loss gap
    def loss_gap(self, loss):
        self.loss_list.append(loss)
        self.ndim = len(self.state_dict['model'].fno_blocks.convs.n_modes)
        # current and maximum mode counts along the first frequency dimension
        incremental_modes = self.state_dict['model'].fno_blocks.convs.n_modes[0]
        max_modes = self.state_dict['model'].fno_blocks.convs.max_n_modes[0]
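        # grow by one mode (per dimension) when the loss has plateaued, i.e.
        # successive losses differ by less than incremental_loss_eps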
if len(self.loss_list) > 1:
if abs(self.loss_list[-1] - self.loss_list[-2]) <= self.incremental_loss_eps:
if incremental_modes < max_modes:
incremental_modes += 1
modes_list = tuple([incremental_modes] * self.ndim)
self.state_dict['model'].fno_blocks.convs.n_modes = modes_list
    # Algorithm 2: gradient-based explained ratio
    def grad_explained(self):
        # lazily initialize the weight accumulator and iteration counter
        if not hasattr(self, 'accumulated_grad'):
            self.accumulated_grad = torch.zeros_like(
                self.state_dict['model'].fno_blocks.convs.weight[0])
        if not hasattr(self, 'grad_iter'):
            self.grad_iter = 1
self.ndim = len(self.state_dict['model'].fno_blocks.convs.n_modes)
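        # accumulation phase: sum the first block's spectral weights for
        # incremental_grad_max_iter calls before evaluating mode strengths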
if self.grad_iter <= self.incremental_grad_max_iter:
self.grad_iter += 1
self.accumulated_grad += self.state_dict['model'].fno_blocks.convs.weight[0]
else:
incremental_final = []
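            # decision phase: per frequency dimension, measure the Frobenius
            # strength of each active mode of the accumulated weights; if the
            # variance explained by all but the last incremental_buffer modes
            # is below incremental_grad_eps, grow the mode count by one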
for i in range(self.ndim):
max_modes = self.state_dict['model'].fno_blocks.convs.max_n_modes[i]
incremental_modes = self.state_dict['model'].fno_blocks.convs.n_modes[i]
weight = self.accumulated_grad
strength_vector = []
for mode_index in range(
min(weight.shape[1], incremental_modes)):
strength = torch.norm(
weight[:, mode_index, :], p='fro')
strength_vector.append(strength)
                explained_ratio = compute_explained_variance(
                    incremental_modes - self.incremental_buffer, torch.Tensor(strength_vector))
                if explained_ratio < self.incremental_grad_eps:
if incremental_modes < max_modes:
incremental_modes += 1
incremental_final.append(incremental_modes)
            # reset the accumulator and apply the (possibly grown) mode counts
            self.grad_iter = 1
            self.accumulated_grad = torch.zeros_like(
                self.state_dict['model'].fno_blocks.convs.weight[0])
            main_modes = incremental_final[0]
            modes_list = tuple([main_modes] * self.ndim)
            self.state_dict['model'].fno_blocks.convs.n_modes = modes_list
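
# Usage sketch: enable exactly one of the two incremental variants (the
# constructor asserts this). Here, the loss-gap variant grows the FNO's
# active modes whenever successive average losses differ by less than
# incremental_loss_eps:
#
#     incremental = IncrementalCallback(
#         incremental_loss_gap=True,
#         incremental_loss_eps=1e-3,
#     )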