Source code for neuralop.training.patching

import math

import torch
from torch import nn

import neuralop.mpu.comm as comm
from neuralop.mpu.mappings import (
    gather_from_model_parallel_region,
    scatter_to_model_parallel_region,
)


class MultigridPatching2D(nn.Module):
    """MultigridPatching2D wraps a model in multi-grid domain decomposition and patching.

    Parameters
    ----------
    model : nn.Module
        model to wrap
    levels : int, optional
        number of levels of patching to use, by default 0
    padding_fraction : float, optional
        fraction of input size to add as padding around patches, by default 0
    use_distributed : bool, optional
        whether patching is performed in distributed mode, by default False
    stitching : bool, optional
        whether patches are to be stitched back together in spatial dimensions
        during training, by default True. Stitching is always performed during
        evaluation.
    """

    def __init__(
        self,
        model: nn.Module,
        levels: int = 0,
        padding_fraction: float = 0,
        use_distributed: bool = False,
        stitching: bool = True,
    ):
        """Initialize MultigridPatching2D.

        If computation is split into distributed data or model parallel,
        adds parameter hooks to account for scattering patches across
        multiple processes.
        """
        super().__init__()
        self.levels = levels

        if isinstance(padding_fraction, (float, int)):
            padding_fraction = [padding_fraction, padding_fraction]
        self.padding_fraction = padding_fraction

        n_patches = 2**levels
        if isinstance(n_patches, int):
            n_patches = [n_patches, n_patches]
        self.n_patches = n_patches

        self.model = model
        self.use_distributed = use_distributed
        self.stitching = stitching

        # Default padding sizes; overwritten by _make_mg_patches when levels > 0
        self.padding_height = 0
        self.padding_width = 0

        if levels > 0:
            print(
                "MGPatching("
                f"n_patches={self.n_patches}, "
                f"padding_fraction={self.padding_fraction}, "
                f"levels={self.levels}, "
                f"use_distributed={use_distributed}, "
                f"stitching={stitching})"
            )

        # If distributed patches are stitched, re-scale gradients to revert DDP averaging
        if self.use_distributed and self.stitching:
            for param in model.parameters():
                param.register_hook(
                    lambda grad: grad * float(comm.get_model_parallel_size())
                )
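    # Illustrative shape arithmetic (not part of the library): with levels=L,
    # the input is split into 2**L patches along each spatial dimension, and each
    # coarser level is appended along the channel dimension, so the wrapped model
    # sees (L + 1) * in_channels channels per patch. For example, levels=2 gives
    # n_patches=[4, 4] (16 patches per sample, stacked along the batch dim) and
    # 3 * in_channels input channels.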
    def patch(self, x, y):
        """Use multi-grid domain decomposition to split `x` and `y` into patches.

        If in a distributed scheme, scatters patches across processes.

        Parameters
        ----------
        x : torch.tensor
            model input function
        y : torch.tensor
            model output function
        """
        # If outputs will not be stitched back together, patch the ground truth y as well
        if not self.stitching:
            y = make_patches(y, n=self.n_patches, p=0)

        # If not stitching, scatter truth, otherwise keep on every GPU
        if self.use_distributed:
            y = scatter_to_model_parallel_region(y, 0)

        # Create padded patches in batch dimension (identity if levels=0)
        x = self._make_mg_patches(x)

        # Split data across processes
        if self.use_distributed:
            x = scatter_to_model_parallel_region(x, 0)

        return x, y
    def unpatch(self, x, y, evaluation=False):
        """Unpatch tensors created by `self.patch`.

        Stitch patches together if in evaluation mode, or if stitching is applied.

        Parameters
        ----------
        x : torch.tensor
            tensor with the patching structure created by patching the input `x`.
            May be either the patched inputs `x` or raw model outputs
            (same shape/patching structure).
            Shape (b * n^2, c, h / n + 2 * pad_h, w / n + 2 * pad_w)
        y : torch.tensor
            tensor of ground truth `y`.
            Shape (b * n^2, c, h / n, w / n) when patched, or (b, c, h, w) when left whole
        evaluation : bool, optional
            whether in evaluation mode, by default False. If True, `x` and `y`
            are both stitched back together, regardless of other settings.
        """
        # Remove padding in the output
        if self.padding_height > 0 or self.padding_width > 0:
            x = self._unpad(x)

        # Gather patches if they are to be stitched back together
        if self.use_distributed and self.stitching:
            x = gather_from_model_parallel_region(x, dim=0)

        # Stitch patches back together if stitching is enabled or in evaluation mode
        if self.stitching or evaluation:
            x = self._stitch(x)

        # If x is not stitched during training, y is patched;
        # re-stitch y during evaluation only
        if evaluation and not self.stitching:
            y = self._stitch(y)

        return x, y
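    # Shape round trip on a single device with stitching=True, levels=1, and
    # padding_fraction=0, assuming an input x of shape (b, c, h, w)
    # (illustrative only, not part of the library):
    #   patch:   x -> (b * 4, 2 * c, h / 2, w / 2),  y unchanged
    #   model:   (b * 4, 2 * c, h / 2, w / 2) -> (b * 4, c_out, h / 2, w / 2)
    #   unpatch: (b * 4, c_out, h / 2, w / 2) -> (b, c_out, h, w)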
    def _stitch(self, x):
        """Stitch back together multi-grid patches created by `self._make_mg_patches`.

        Small patches are collated along the batch dimension as different inputs.
        Unroll the batch dimension and stick all patches from the same input
        back together in their proper locations.

        For an input of shape (b * n * n, c, h / n, w / n), produces an output
        tensor of shape (b, c, h, w).

        Parameters
        ----------
        x : torch.tensor
            input tensor, split into patches and collated along the batch dim
            shape (batch * n^2, c, h / n, w / n)
        """
        # Only 2D supported
        assert x.ndim == 4, f"Only 2D patch supported but got input with {x.ndim} dims."

        if self.n_patches[0] <= 1 and self.n_patches[1] <= 1:
            return x

        # Size with padding removed
        size = x.size()

        # if self.mode == "batch-wise":
        B = size[0] // (self.n_patches[0] * self.n_patches[1])
        W = size[3] * self.n_patches[1]
        C = size[1]
        H = size[2] * self.n_patches[0]

        # Reshape
        x = x.permute(0, 3, 2, 1)
        x = x.reshape(B, self.n_patches[0], self.n_patches[1], size[3], size[2], C)
        x = x.permute(0, 5, 1, 4, 2, 3)
        x = x.reshape(B, C, H, W)

        return x

    def _make_mg_patches(self, x):
        """Split a tensor `x` into multi-grid patches.

        The number of patches per spatial dimension into which `x` is split is
        determined by `2 ** self.levels`.

        Steps:
        1. Optionally apply padding if self.padding_fraction > 0
        2. Split each element in the batch into `2**self.levels` equally-sized
           patches along each spatial dimension
        3. For each level L up to `self.levels`, subsample the input by a factor of `2**L`
        4. Pad the subsampled patches so all level tensors are the same size
        5. Concatenate the patches along the channel dim and return

        Parameters
        ----------
        x : torch.tensor
            input tensor to split into patches
            shape (batch, c, h, w)
        """
        levels = self.levels

        if levels <= 0:
            return x

        _, _, height, width = x.shape

        padding = [
            int(round(v))
            for v in [
                height * self.padding_fraction[0],
                width * self.padding_fraction[1],
            ]
        ]
        self.padding_height = padding[0]
        self.padding_width = padding[1]

        patched = make_patches(x, n=2**self.levels, p=padding)

        s1_patched = patched.size(-2) - 2 * padding[0]
        s2_patched = patched.size(-1) - 2 * padding[1]

        for level in range(1, levels + 1):
            sub_sample = 2**level

            s1_stride = s1_patched // sub_sample
            s2_stride = s2_patched // sub_sample

            x_sub = x[:, :, ::sub_sample, ::sub_sample]

            s2_pad = (
                math.ceil(
                    (s2_patched + (2**levels - 1) * s2_stride - x_sub.size(-1)) / 2.0
                )
                + padding[1]
            )
            s1_pad = (
                math.ceil(
                    (s1_patched + (2**levels - 1) * s1_stride - x_sub.size(-2)) / 2.0
                )
                + padding[0]
            )

            if s2_pad > x_sub.size(-1):
                diff = s2_pad - x_sub.size(-1)
                x_sub = torch.nn.functional.pad(
                    x_sub, pad=[x_sub.size(-1), x_sub.size(-1), 0, 0], mode="circular"
                )
                x_sub = torch.nn.functional.pad(
                    x_sub, pad=[diff, diff, 0, 0], mode="circular"
                )
            else:
                x_sub = torch.nn.functional.pad(
                    x_sub, pad=[s2_pad, s2_pad, 0, 0], mode="circular"
                )

            if s1_pad > x_sub.size(-2):
                diff = s1_pad - x_sub.size(-2)
                x_sub = torch.nn.functional.pad(
                    x_sub, pad=[0, 0, x_sub.size(-2), x_sub.size(-2)], mode="circular"
                )
                x_sub = torch.nn.functional.pad(
                    x_sub, pad=[0, 0, diff, diff], mode="circular"
                )
            else:
                x_sub = torch.nn.functional.pad(
                    x_sub, pad=[0, 0, s1_pad, s1_pad], mode="circular"
                )

            x_sub = x_sub.unfold(-1, s2_patched + 2 * padding[1], s2_stride)
            x_sub = x_sub.unfold(-3, s1_patched + 2 * padding[0], s1_stride)

            x_sub = x_sub.permute(0, 2, 3, 4, 5, 1)
            x_sub = x_sub.reshape(
                patched.size(0),
                s2_patched + 2 * padding[1],
                s1_patched + 2 * padding[0],
                -1,
            )
            x_sub = x_sub.permute(0, 3, 2, 1)

            patched = torch.cat((patched, x_sub), 1)

        return patched

    def _unpad(self, x):
        """Remove padding around the edges (`mode=circular`) of tensor `x`.

        Parameters
        ----------
        x : torch.tensor
            padded input tensor,
            shape (b, c, h + self.padding_height * 2, w + self.padding_width * 2)

        Returns
        -------
        x : torch.tensor
            unpadded tensor, shape (b, c, h, w)
        """
        return x[
            ...,
            self.padding_height : -self.padding_height,
            self.padding_width : -self.padding_width,
        ].contiguous()
def make_patches(x, n, p=0):
    """make_patches splits `x` into `n` equally-sized patches along each spatial
    dimension, each padded with `p` pixels.

    Stacks patches along the batch dimension. Starting with an input tensor of
    shape (batch, C, s) or (batch, C, h, w), returns a corresponding patched
    output tensor of shape (n * batch, C, s / n + 2 * p) or
    (n1 * n2 * batch, C, h / n1 + 2 * p1, w / n2 + 2 * p2).

    Parameters
    ----------
    x : torch.tensor
        input tensor, before patching
    n : int
        number of patches into which to split each example in `x`
    p : int, optional
        number of pixels to use when padding `x`, by default 0
    """
    size = x.size()

    # Only 1D and 2D supported
    assert len(size) == 3 or len(size) == 4

    if len(size) == 3:
        d = 1
    else:
        d = 2

    if isinstance(p, int):
        p = [p, p]

    # Pad
    if p[0] > 0 or p[1] > 0:
        if d == 1:
            x = torch.nn.functional.pad(x, pad=p, mode="circular")
        else:
            x = torch.nn.functional.pad(
                x, pad=[p[1], p[1], p[0], p[0]], mode="circular"
            )

    if isinstance(n, int):
        n = [n, n]

    if n[0] <= 1 and n[1] <= 1:
        return x

    # Patches must be equally sized
    for j in range(d):
        assert size[-(j + 1)] % n[-(j + 1)] == 0

    # Patch
    for j in range(d):
        patch_size = size[-(j + 1)] // n[-(j + 1)]
        x = x.unfold(-(2 * j + 1), patch_size + 2 * p[-(j + 1)], patch_size)

    x = x.permute(0, 2, 3, 4, 5, 1)
    x = x.reshape(
        size[0] * n[0] * n[1],
        size[-1] // n[-1] + 2 * p[-1],
        size[-2] // n[-2] + 2 * p[-2],
        size[1],
    )
    x = x.permute(0, 3, 2, 1)

    return x
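A minimal single-device usage sketch, assuming the imports and definitions above. It is illustrative only and not part of the module; the 1x1 convolution is a hypothetical stand-in for a real model, which must accept (levels + 1) * in_channels input channels per patch.

# Illustrative usage sketch (not part of the library source).
in_channels, levels = 3, 1
demo_model = nn.Conv2d(in_channels * (levels + 1), in_channels, kernel_size=1)  # stand-in model

patcher = MultigridPatching2D(
    demo_model,
    levels=levels,
    padding_fraction=0.125,
    use_distributed=False,
    stitching=True,
)

x = torch.randn(2, in_channels, 16, 16)
y = torch.randn(2, in_channels, 16, 16)

x_p, y_p = patcher.patch(x, y)        # x_p: (8, 6, 12, 12), y_p: (2, 3, 16, 16)
out = demo_model(x_p)                 # (8, 3, 12, 12)
out, y_p = patcher.unpatch(out, y_p)  # out: (2, 3, 16, 16)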