# Source code for neuralop.models.uqno

from copy import deepcopy

import torch
import torch.nn as nn


from .base_model import BaseModel


class UQNO(BaseModel, name="UQNO"):
    """Uncertainty Quantification Neural Operator

    General N-dim (alpha, delta) Risk-Controlling Neural Operator, as
    described in [1]_. The UQNO is trained to map input functions to a
    residual function E(a, x) that describes the predicted error between
    the ground truth and the outputs of a trained model. E(a, x) is then
    used in combination with a calibrated scaling factor to predict
    calibrated uncertainty bands around the predictions of the trained
    model.

    Parameters
    ----------
    base_model : nn.Module
        Pre-trained solution operator. Determined by the problem.
    residual_model : nn.Module, optional
        Architecture to train as the UQNO's quantile model.
        If None, a deep copy of base_model is used. Default: None

    References
    ----------
    .. [1] : Ma, Z., Pitt, D., Azizzadenesheli, K., and Anandkumar, A. (2024).
        "Calibrated Uncertainty Quantification for Operator Learning via
        Conformal Prediction". TMLR,
        https://openreview.net/pdf?id=cGpegxy12T.
    """

    def __init__(
        self,
        base_model: nn.Module,
        residual_model: nn.Module = None,
        **kwargs,
    ):
        super().__init__()
        self.base_model = base_model
        # Default the quantile (residual) model to an independent deep copy
        # of the base architecture, so training it never mutates the
        # pre-trained base weights.
        if residual_model is None:
            residual_model = deepcopy(base_model)
        self.residual_model = residual_model

    def forward(self, *args, **kwargs):
        """
        Forward pass returns the solution u(a,x) and the uncertainty ball
        E(a,x) as a pair for pointwise quantile loss.

        Returns
        -------
        (solution, quantile) : tuple
            ``solution`` is the frozen base model's output; ``quantile`` is
            the residual model's predicted error function.
        """
        self.base_model.eval()  # base-model weights are frozen
        # Another way to handle this would be to use LoRA, or similar:
        # i.e. freeze the weights, and train a low-rank matrix of weight
        # perturbations.
        with torch.no_grad():
            solution = self.base_model(*args, **kwargs)
        # The residual model is the component being trained, so its forward
        # pass stays outside no_grad to let gradients flow.
        quantile = self.residual_model(*args, **kwargs)
        return (solution, quantile)