import torch
from torch.cuda import amp
from timeit import default_timer
import pathlib
from .callbacks import PipelineCallback
import neuralop.mpu.comm as comm
from neuralop.losses import LpLoss
class Trainer:
def __init__(
self,
*,
model,
n_epochs,
wandb_log=True,
device=None,
amp_autocast=False,
data_processor=None,
callbacks=None,
log_test_interval=1,
log_output=False,
use_distributed=False,
verbose=False,
):
"""
A general Trainer class to train neural operators on given datasets.
Parameters
----------
model : nn.Module
n_epochs : int
wandb_log : bool, default is True
device : torch.device
amp_autocast : bool, default is False
        data_processor : class to transform data, default is None
            if not None, data from the loaders is first transformed with
            data_processor.preprocess; the model output is then transformed
            with data_processor.postprocess before the loss is computed.
        callbacks : list of Callback objects, default is None
            optional callbacks, wrapped into a single PipelineCallback
log_test_interval : int, default is 1
how frequently to print updates
log_output : bool, default is False
if True, and if wandb_log is also True, log output images to wandb
use_distributed : bool, default is False
whether to use DDP
verbose : bool, default is False
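
        Examples
        --------
        A minimal construction sketch; the FNO model and its arguments are
        illustrative placeholders rather than requirements of this class::

            from neuralop.models import FNO

            model = FNO(n_modes=(16, 16), hidden_channels=32)
            trainer = Trainer(model=model, n_epochs=20, device="cuda",
                              wandb_log=False, verbose=True)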
"""
if callbacks:
assert isinstance(
callbacks, list
), "Callbacks must be a list of Callback objects"
self.callbacks = PipelineCallback(callbacks=callbacks)
self.override_load_to_device = (
self.callbacks.device_load_callback_idx is not None
)
self.overrides_loss = self.callbacks.overrides_loss
else:
self.callbacks = []
self.override_load_to_device = False
self.overrides_loss = False
if verbose:
print(f"{self.override_load_to_device=}")
print(f"{self.overrides_loss=}")
if self.callbacks:
self.callbacks.on_init_start(
model=model,
n_epochs=n_epochs,
wandb_log=wandb_log,
device=device,
amp_autocast=amp_autocast,
log_test_interval=log_test_interval,
log_output=log_output,
use_distributed=use_distributed,
verbose=verbose,
)
self.model = model
self.n_epochs = n_epochs
self.wandb_log = wandb_log
self.log_test_interval = log_test_interval
self.log_output = log_output
self.verbose = verbose
self.use_distributed = use_distributed
self.device = device
self.amp_autocast = amp_autocast
self.data_processor = data_processor
if self.callbacks:
self.callbacks.on_init_end(
model=model,
n_epochs=n_epochs,
wandb_log=wandb_log,
device=device,
amp_autocast=amp_autocast,
log_test_interval=log_test_interval,
log_output=log_output,
use_distributed=use_distributed,
verbose=verbose,
)
def train(
self,
train_loader,
test_loaders,
optimizer,
scheduler,
regularizer,
training_loss=None,
eval_losses=None,
):
"""Trains the given model on the given datasets.
params:
train_loader: torch.utils.data.DataLoader
training dataloader
test_loaders: dict[torch.utils.data.DataLoader]
testing dataloaders
optimizer: torch.optim.Optimizer
optimizer to use during training
optimizer: torch.optim.lr_scheduler
learning rate scheduler to use during training
training_loss: training.losses function
cost function to minimize
eval_losses: dict[Loss]
dict of losses to use in self.eval()
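
        Examples
        --------
        A minimal usage sketch; the optimizer and scheduler are ordinary
        PyTorch objects and are chosen here purely for illustration::

            optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=trainer.n_epochs)
            errors = trainer.train(
                train_loader,
                test_loaders={"test": test_loader},
                optimizer=optimizer,
                scheduler=scheduler,
                regularizer=None,
            )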
"""
if self.callbacks:
self.callbacks.on_train_start(
train_loader=train_loader,
test_loaders=test_loaders,
optimizer=optimizer,
scheduler=scheduler,
regularizer=regularizer,
training_loss=training_loss,
eval_losses=eval_losses,
data_processor=self.data_processor,
)
if training_loss is None:
training_loss = LpLoss(d=2)
if eval_losses is None: # By default just evaluate on the training loss
eval_losses = dict(l2=training_loss)
errors = None
for epoch in range(self.n_epochs):
if self.callbacks:
self.callbacks.on_epoch_start(epoch=epoch)
avg_loss = 0
avg_lasso_loss = 0
self.model.train()
t1 = default_timer()
train_err = 0.0
            # track the running total of training samples seen this epoch
n_samples = 0
for idx, sample in enumerate(train_loader):
if self.callbacks:
self.callbacks.on_batch_start(
idx=idx, sample=sample, data_processor=self.data_processor
)
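                # set_to_none=True frees gradient tensors rather than zeroing
                # them, which can modestly reduce memory use and overhead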
optimizer.zero_grad(set_to_none=True)
if regularizer:
regularizer.reset()
if self.data_processor is not None:
sample = self.data_processor.preprocess(sample)
else:
# load data to device if no preprocessor exists
sample = {
k: v.to(self.device)
for k, v in sample.items()
if torch.is_tensor(v)
}
n_samples += sample["y"].shape[0]
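                # under autocast, eligible ops run in reduced precision;
                # torch.cuda.amp assumes a CUDA device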
if self.amp_autocast:
with amp.autocast(enabled=True):
out = self.model(**sample)
else:
out = self.model(**sample)
if self.data_processor is not None:
out, sample = self.data_processor.postprocess(out, sample)
if self.callbacks:
self.callbacks.on_before_loss(out=out)
loss = 0.0
if self.overrides_loss:
loss += self.callbacks.compute_training_loss(
out=out, **sample, amp_autocast=self.amp_autocast
)
else:
if self.amp_autocast:
with amp.autocast(enabled=True):
loss += training_loss(out, **sample)
else:
loss += training_loss(out, **sample)
if regularizer:
loss += regularizer.loss
loss.backward()
del out
optimizer.step()
train_err += loss.item()
with torch.no_grad():
avg_loss += loss.item()
if regularizer:
avg_lasso_loss += regularizer.loss
if self.callbacks:
self.callbacks.on_batch_end()
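            # ReduceLROnPlateau steps on a monitored metric (the accumulated
            # train error here); other schedulers step once per epoch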
if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
scheduler.step(train_err)
else:
scheduler.step()
epoch_train_time = default_timer() - t1
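            # train_err is normalized per batch, avg_loss per individual sample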
train_err /= len(train_loader)
avg_loss /= n_samples
if epoch % self.log_test_interval == 0:
if self.callbacks:
self.callbacks.on_before_val(
epoch=epoch,
train_err=train_err,
time=epoch_train_time,
avg_loss=avg_loss,
avg_lasso_loss=avg_lasso_loss,
)
for loader_name, loader in test_loaders.items():
errors = self.evaluate(eval_losses, loader, log_prefix=loader_name)
if self.callbacks:
self.callbacks.on_val_end()
if self.callbacks:
self.callbacks.on_epoch_end(
epoch=epoch, train_err=train_err, avg_loss=avg_loss
)
return errors
def evaluate(self, loss_dict, data_loader, log_prefix=""):
"""Evaluates the model on a dictionary of losses
Parameters
----------
loss_dict : dict of functions
            each function is called as loss(out, **sample) and returns the
            corresponding loss
        data_loader : torch.utils.data.DataLoader
            data_loader to evaluate on
log_prefix : str, default is ''
if not '', used as prefix in output dictionary
Returns
-------
errors : dict
            dict[f'{log_prefix}_{loss_name}'] = loss for loss_name, loss in loss_dict.items()
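
        Examples
        --------
        A short sketch; the 'l2' key is illustrative, and LpLoss is the
        default loss this module already imports::

            errors = trainer.evaluate({"l2": LpLoss(d=2)}, test_loader,
                                      log_prefix="test")
            print(errors["test_l2"])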
"""
if self.callbacks:
self.callbacks.on_val_epoch_start(
log_prefix=log_prefix, loss_dict=loss_dict, data_loader=data_loader
)
self.model.eval()
errors = {f"{log_prefix}_{loss_name}": 0 for loss_name in loss_dict.keys()}
n_samples = 0
with torch.no_grad():
for idx, sample in enumerate(data_loader):
if self.callbacks:
self.callbacks.on_val_batch_start(
idx=idx, sample=sample, data_processor=self.data_processor
)
if self.data_processor is not None:
sample = self.data_processor.preprocess(sample)
else:
# load data to device if no preprocessor exists
sample = {
k: v.to(self.device)
for k, v in sample.items()
if torch.is_tensor(v)
}
n_samples += sample["y"].size(0)
out = self.model(**sample)
if self.data_processor is not None:
out, sample = self.data_processor.postprocess(out, sample)
if self.callbacks:
self.callbacks.on_before_val_loss(out=out)
for loss_name, loss in loss_dict.items():
if self.overrides_loss:
val_loss = self.callbacks.compute_training_loss(out, **sample)
else:
val_loss = loss(out, **sample)
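                    # 0-d loss tensors are converted to Python floats so the
                    # running totals accumulate as plain numbers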
if val_loss.shape == ():
val_loss = val_loss.item()
errors[f"{log_prefix}_{loss_name}"] += val_loss
if self.callbacks:
self.callbacks.on_val_batch_end()
for key in errors.keys():
errors[key] /= n_samples
if self.callbacks:
self.callbacks.on_val_epoch_end(errors=errors, sample=sample, out=out)
del out
return errors