Source code for neuralop.datasets.data_transforms

from abc import ABCMeta, abstractmethod

import torch
from neuralop.training.patching import MultigridPatching2D


class DataProcessor(torch.nn.Module, metaclass=ABCMeta):
    def __init__(self):
        """DataProcessor exposes functionality for pre-
        and post-processing data during training or inference.

    To be a valid DataProcessor within the Trainer, the
    following methods must be implemented:

        - to(device): load necessary information to device, in keeping
            with PyTorch convention
        - preprocess(data): processes data from a new batch before being
            put through a model's forward pass
        - postprocess(out): processes the outputs of a model's forward pass
            before loss and backward pass
        - wrap(self, model):
            wraps a model in preprocess and postprocess steps to create one forward pass
        - forward(self, x):
        forward pass, provided that a model has been wrapped
        """
        super().__init__()

    @abstractmethod
    def to(self, device):
        pass

    @abstractmethod
    def preprocess(self, x):
        pass

    @abstractmethod
    def postprocess(self, x):
        pass

    @abstractmethod
    def wrap(self, model):
        pass

    @abstractmethod
    def forward(self, x):
        pass



class DefaultDataProcessor(DataProcessor):
    def __init__(
        self, in_normalizer=None, out_normalizer=None, positional_encoding=None
    ):
        """A simple processor to pre/post process data before training/inferencing a model.

        Parameters
        ----------
        in_normalizer : Transform, optional, default is None
            normalizer (e.g. StandardScaler) for the input samples
        out_normalizer : Transform, optional, default is None
            normalizer (e.g. StandardScaler) for the target and predicted samples
        positional_encoding : Processor, optional, default is None
            class that appends a positional encoding to the input
        """
        super().__init__()
        self.in_normalizer = in_normalizer
        self.out_normalizer = out_normalizer
        self.positional_encoding = positional_encoding
        self.device = "cpu"

    def wrap(self, model):
        self.model = model
        return self
    def to(self, device):
        if self.in_normalizer is not None:
            self.in_normalizer = self.in_normalizer.to(device)
        if self.out_normalizer is not None:
            self.out_normalizer = self.out_normalizer.to(device)
        self.device = device
        return self
    def preprocess(self, data_dict, batched=True):
        x = data_dict["x"].to(self.device)
        y = data_dict["y"].to(self.device)

        if self.in_normalizer is not None:
            x = self.in_normalizer.transform(x)
        if self.positional_encoding is not None:
            x = self.positional_encoding(x, batched=batched)
        # nn.Module's boolean `training` flag (not the train() method, which is
        # always truthy) gates whether targets are normalized
        if self.out_normalizer is not None and self.training:
            y = self.out_normalizer.transform(y)

        data_dict["x"] = x
        data_dict["y"] = y
        return data_dict

    def postprocess(self, output, data_dict):
        y = data_dict["y"]
        if self.out_normalizer is not None and not self.training:
            output = self.out_normalizer.inverse_transform(output)
            y = self.out_normalizer.inverse_transform(y)
        data_dict["y"] = y
        return output, data_dict
    def forward(self, **data_dict):
        data_dict = self.preprocess(data_dict)
        output = self.model(data_dict["x"])
        # postprocess takes and returns both the output and the data dict
        output, data_dict = self.postprocess(output, data_dict)
        return output, data_dict
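

# Usage sketch (editor's example, needing only torch): wiring a
# DefaultDataProcessor around a model with a stand-in normalizer. The _Scale
# class is hypothetical; in practice you would pass one of the library's
# Transform classes exposing transform/inverse_transform/to.
def _example_default_processor():
    class _Scale:
        """Hypothetical normalizer implementing the Transform-like interface."""

        def __init__(self, factor):
            self.factor = factor

        def transform(self, t):
            return t / self.factor

        def inverse_transform(self, t):
            return t * self.factor

        def to(self, device):
            return self

    model = torch.nn.Conv2d(1, 1, kernel_size=1)
    processor = DefaultDataProcessor(
        in_normalizer=_Scale(2.0), out_normalizer=_Scale(2.0)
    ).wrap(model)

    processor.train()  # training=True: targets are normalized, output stays encoded
    out, batch = processor(x=torch.randn(4, 1, 16, 16), y=torch.randn(4, 1, 16, 16))

    processor.eval()  # training=False: output and target are decoded
    out, batch = processor(x=torch.randn(4, 1, 16, 16), y=torch.randn(4, 1, 16, 16))
    return out, batch
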

class IncrementalDataProcessor(torch.nn.Module):
    def __init__(self, in_normalizer=None, out_normalizer=None,
                 positional_encoding=None, device='cpu',
                 subsampling_rates=[2, 1], dataset_resolution=16,
                 dataset_indices=[2, 3], epoch_gap=10, verbose=False):
        """An incremental processor to pre/post process data before training/inferencing a model.

        In particular, this processor regularizes the input resolution in the
        spatial domain based on the sub_list and dataset_indices, over a fixed
        number of epochs. The resolution is increased incrementally, as in
        curriculum learning, which is useful when training models on large,
        high-resolution datasets.

        Parameters
        ----------
        in_normalizer : Transform, optional, default is None
            normalizer (e.g. StandardScaler) for the input samples
        out_normalizer : Transform, optional, default is None
            normalizer (e.g. StandardScaler) for the target and predicted samples
        positional_encoding : Processor, optional, default is None
            class that appends a positional encoding to the input
        device : str, optional, default is 'cpu'
            device 'cuda' or 'cpu' where computations are performed
        subsampling_rates : list, optional, default is [2, 1]
            list of subsampling rates to use
        dataset_resolution : int, optional, default is 16
            resolution of the input data
        dataset_indices : list, optional, default is [2, 3]
            list of spatial-dimension indices of the dataset to slice in order
            to regularize the input resolution
        epoch_gap : int, optional, default is 10
            number of epochs to wait before increasing the resolution
        verbose : bool, optional, default is False
            if True, print the current resolution
        """
        super().__init__()
        self.in_normalizer = in_normalizer
        self.out_normalizer = out_normalizer
        self.positional_encoding = positional_encoding
        self.device = device
        self.sub_list = subsampling_rates
        self.dataset_resolution = dataset_resolution
        self.dataset_indices = dataset_indices
        self.epoch_gap = epoch_gap
        self.verbose = verbose
        self.mode = "Train"

        self.epoch = 0
        self.current_index = 0
        self.current_logged_epoch = 0
        self.current_sub = self.index_to_sub_from_table(self.current_index)
        self.current_res = int(self.dataset_resolution / self.current_sub)

        print(f'Original Incre Res: change index to {self.current_index}')
        print(f'Original Incre Res: change sub to {self.current_sub}')
        print(f'Original Incre Res: change res to {self.current_res}')

    def wrap(self, model):
        self.model = model
        return self

    def to(self, device):
        if self.in_normalizer is not None:
            self.in_normalizer = self.in_normalizer.to(device)
        if self.out_normalizer is not None:
            self.out_normalizer = self.out_normalizer.to(device)
        self.device = device
        return self

    def epoch_wise_res_increase(self, epoch):
        # Update the current_sub and current_res values based on the epoch
        if epoch % self.epoch_gap == 0 and epoch != 0 and (
                self.current_logged_epoch != epoch):
            self.current_index += 1
            self.current_sub = self.index_to_sub_from_table(self.current_index)
            self.current_res = int(self.dataset_resolution / self.current_sub)
            self.current_logged_epoch = epoch

            if self.verbose:
                print(f'Incre Res Update: change index to {self.current_index}')
                print(f'Incre Res Update: change sub to {self.current_sub}')
                print(f'Incre Res Update: change res to {self.current_res}')

    def index_to_sub_from_table(self, index):
        # Get the sub value from the sub_list based on the index; clamp to the
        # last entry once the schedule is exhausted
        if index >= len(self.sub_list):
            return self.sub_list[-1]
        else:
            return self.sub_list[index]

    def regularize_input_res(self, x, y):
        # Subsample the spatial dimensions listed in dataset_indices using the
        # current subsampling rate
        for idx in self.dataset_indices:
            indexes = torch.arange(0, x.size(idx), self.current_sub, device=self.device)
            x = x.index_select(dim=idx, index=indexes)
            y = y.index_select(dim=idx, index=indexes)
        return x, y

    def step(self, loss=None, epoch=None, x=None, y=None):
        if x is not None and y is not None:
            self.epoch_wise_res_increase(epoch)
            return self.regularize_input_res(x, y)

    def preprocess(self, data_dict, batched=True):
        x = data_dict['x'].to(self.device)
        y = data_dict['y'].to(self.device)

        if self.in_normalizer is not None:
            x = self.in_normalizer.transform(x)
        if self.positional_encoding is not None:
            x = self.positional_encoding(x, batched=batched)
        if self.out_normalizer is not None and self.training:
            y = self.out_normalizer.transform(y)

        if self.mode == "Train":
            x, y = self.step(epoch=self.epoch, x=x, y=y)

        data_dict['x'] = x
        data_dict['y'] = y
        return data_dict

    def postprocess(self, output, data_dict):
        y = data_dict['y']
        if self.out_normalizer is not None and not self.training:
            output = self.out_normalizer.inverse_transform(output)
            y = self.out_normalizer.inverse_transform(y)
        data_dict['y'] = y
        return output, data_dict

    def forward(self, **data_dict):
        data_dict = self.preprocess(data_dict)
        output = self.model(data_dict['x'])
        output, data_dict = self.postprocess(output, data_dict)
        return output, data_dict
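

# Schedule sketch (editor's example): with subsampling_rates=[2, 1],
# dataset_resolution=16 and epoch_gap=10, training runs at 8x8 for epochs 0-9
# and switches to the full 16x16 resolution from epoch 10 on. The shapes in
# the comments are illustrative.
def _example_incremental_schedule():
    processor = IncrementalDataProcessor(
        subsampling_rates=[2, 1],
        dataset_resolution=16,
        dataset_indices=[2, 3],
        epoch_gap=10,
        verbose=True,
    )
    x = torch.randn(4, 1, 16, 16)
    y = torch.randn(4, 1, 16, 16)
    for epoch in range(20):
        processor.epoch = epoch
        x_sub, y_sub = processor.step(epoch=epoch, x=x, y=y)
        # epochs 0-9: x_sub has shape (4, 1, 8, 8); epochs 10+: (4, 1, 16, 16)
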

class MGPatchingDataProcessor(DataProcessor):
    def __init__(
        self,
        model: torch.nn.Module,
        levels: int,
        padding_fraction: float,
        stitching: float,
        device: str = "cpu",
        in_normalizer=None,
        out_normalizer=None,
        positional_encoding=None,
    ):
        """MGPatchingDataProcessor
        Applies multigrid patching to inputs out-of-place
        with an optional output encoder/other data transform

        Parameters
        ----------
        model: nn.Module
            model to wrap in MultigridPatching2D
        levels : int
            mg_patching level parameter for MultigridPatching2D
        padding_fraction : float
            mg_padding_fraction parameter for MultigridPatching2D
        stitching : float
            mg_patching_stitching parameter for MultigridPatching2D
        device : str, optional
            device 'cuda' or 'cpu' where computations are performed
        in_normalizer : neuralop.datasets.transforms.Transform, optional
            normalizer to encode model inputs, by default None
        out_normalizer : neuralop.datasets.transforms.Transform, optional
            normalizer to decode model outputs, by default None
        positional_encoding : neuralop.datasets.transforms.PositionalEmbedding2D, optional
            appends pos encoding to x if used, by default None
        """
        super().__init__()
        self.levels = levels
        self.padding_fraction = padding_fraction
        self.stitching = stitching
        self.patcher = MultigridPatching2D(
            model=model,
            levels=self.levels,
            padding_fraction=self.padding_fraction,
            stitching=self.stitching,
        )
        self.device = device

        # set normalizers to none by default
        self.in_normalizer, self.out_normalizer = None, None
        if in_normalizer:
            self.in_normalizer = in_normalizer.to(self.device)
        if out_normalizer:
            self.out_normalizer = out_normalizer.to(self.device)
        self.positional_encoding = positional_encoding
        self.model = None
    def to(self, device):
        self.device = device
        if self.in_normalizer:
            self.in_normalizer = self.in_normalizer.to(self.device)
        if self.out_normalizer:
            self.out_normalizer = self.out_normalizer.to(self.device)
        # return self for chaining, consistent with the other processors
        return self
    def wrap(self, model):
        self.model = model
        return self
    def preprocess(self, data_dict, batched=True):
        """
        Preprocess data, assuming that if an encoder exists,
        it has encoded all data during data loading

        Params
        ------
        data_dict: dict
            dictionary keyed with 'x', 'y' etc
            represents one batch of data input to a model
        batched: bool
            whether the first dimension of 'x', 'y' represents batching
        """
        data_dict = {
            k: v.to(self.device) for k, v in data_dict.items() if torch.is_tensor(v)
        }
        x, y = data_dict["x"], data_dict["y"]
        if self.in_normalizer:
            x = self.in_normalizer.transform(x)
        # guard the target transform separately so a missing out_normalizer
        # does not raise when only an in_normalizer is provided
        if self.out_normalizer:
            y = self.out_normalizer.transform(y)
        if self.positional_encoding is not None:
            x = self.positional_encoding(x, batched=batched)
        data_dict["x"], data_dict["y"] = self.patcher.patch(x, y)
        return data_dict
    def postprocess(self, out, data_dict):
        """
        Postprocess model outputs, including decoding
        if an encoder exists.

        Params
        ------
        data_dict: dict
            dictionary keyed with 'x', 'y' etc
            represents one batch of data input to a model
        out: torch.Tensor
            model output predictions
        """
        y = data_dict["y"]
        out, y = self.patcher.unpatch(out, y)

        if self.out_normalizer:
            y = self.out_normalizer.inverse_transform(y)
            out = self.out_normalizer.inverse_transform(out)

        data_dict["y"] = y
        return out, data_dict
    def forward(self, **data_dict):
        data_dict = self.preprocess(data_dict)
        output = self.model(**data_dict)
        output, data_dict = self.postprocess(output, data_dict)
        return output, data_dict
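

# Wiring sketch (editor's example): constructing an MGPatchingDataProcessor
# around a model. The parameter values below are illustrative, not recommended
# settings; consult MultigridPatching2D for the exact semantics of levels,
# padding_fraction and stitching. Note that preprocess patches both inputs and
# targets, and postprocess unpatches and decodes them again.
def _example_mg_patching(model: torch.nn.Module):
    processor = MGPatchingDataProcessor(
        model=model,
        levels=1,                # number of additional multigrid levels
        padding_fraction=0.078,  # fraction of each patch used as padding
        stitching=True,          # stitch patches back together after the model
    )
    processor = processor.wrap(model).to("cpu")
    return processor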