# Source code for neuralop.utils

import os
from math import prod
from pathlib import Path
from typing import List, Optional, Union

import torch

# Only import wandb and use if installed
wandb_available = False
try:
    import wandb

    wandb_available = True
except ModuleNotFoundError:
    wandb_available = False


def count_model_params(model):
    """Returns the total number of parameters of a PyTorch model.

    Parameters
    ----------
    model : torch.nn.Module

    Returns
    -------
    int
        Sum of ``numel()`` over every tensor yielded by ``model.parameters()``.

    Notes
    -----
    One complex number is counted as two parameters (we count real and
    imaginary parts).
    """
    # A generator avoids materializing a throwaway list of per-parameter counts.
    return sum(
        p.numel() * 2 if p.is_complex() else p.numel() for p in model.parameters()
    )
def count_tensor_params(tensor, dims=None):
    """Returns the number of parameters (elements) in a single tensor,
    optionally along certain dimensions only.

    Parameters
    ----------
    tensor : torch.Tensor
    dims : int, list of int, or None, default is None
        If not None, the dimension(s) to consider when counting the number
        of parameters (elements). Negative indices are allowed. A single int
        is treated as ``[dims]``.

    Returns
    -------
    int
        Product of the selected dimension sizes (doubled for complex tensors).

    Notes
    -----
    One complex number is counted as two parameters (we count real and
    imaginary parts).
    """
    if dims is None:
        sizes = tensor.shape
    else:
        # accept a bare int as shorthand for a one-element list of dims
        if isinstance(dims, int):
            dims = [dims]
        sizes = [tensor.shape[d] for d in dims]
    n_params = prod(sizes)
    if tensor.is_complex():
        return 2 * n_params
    return n_params
def wandb_login(api_key_file="../config/wandb_api_key.txt", key=None):
    """Log in to Weights & Biases.

    Parameters
    ----------
    api_key_file : str or Path, default is "../config/wandb_api_key.txt"
        File the API key is read from when ``key`` is None and the
        ``WANDB_API_KEY`` environment variable is not set.
    key : str or None, default is None
        API key to use directly. If None, it is looked up with
        ``get_wandb_api_key(api_key_file)``.

    Raises
    ------
    ImportError
        If the optional ``wandb`` package is not installed.
    """
    # ``wandb`` is an optional dependency imported (or not) at module load time;
    # fail with an actionable message instead of a bare NameError.
    if not wandb_available:
        raise ImportError(
            "wandb is not installed; install it (`pip install wandb`) to use wandb_login"
        )
    if key is None:
        key = get_wandb_api_key(api_key_file)
    wandb.login(key=key)


def _read_wandb_api_key_file(api_key_file):
    """Return the contents of ``api_key_file`` with surrounding whitespace stripped."""
    with open(api_key_file, "r", encoding="utf-8") as f:
        return f.read().strip()


def set_wandb_api_key(api_key_file="../config/wandb_api_key.txt"):
    """Set the ``WANDB_API_KEY`` environment variable from a file if it is unset.

    An already-set ``WANDB_API_KEY`` is left untouched.

    Parameters
    ----------
    api_key_file : str or Path, default is "../config/wandb_api_key.txt"
        File containing the API key; surrounding whitespace is stripped.
    """
    if "WANDB_API_KEY" not in os.environ:
        os.environ["WANDB_API_KEY"] = _read_wandb_api_key_file(api_key_file)


def get_wandb_api_key(api_key_file="../config/wandb_api_key.txt"):
    """Return the Weights & Biases API key.

    The ``WANDB_API_KEY`` environment variable takes precedence (returned
    as-is); otherwise the key is read from ``api_key_file`` and stripped of
    surrounding whitespace.

    Parameters
    ----------
    api_key_file : str or Path, default is "../config/wandb_api_key.txt"

    Returns
    -------
    str
    """
    key = os.environ.get("WANDB_API_KEY")
    if key is not None:
        return key
    return _read_wandb_api_key_file(api_key_file)


# Define the function to compute the spectrum
def spectrum_2d(signal, n_observations, normalize=True):
    """Compute the radially binned energy spectrum of a 2D signal with the FFT.

    Parameters
    ----------
    signal : torch.Tensor
        T discretized 2D fields of size ``n_observations x n_observations``.
        The first dimension must be T (any number of time steps or channels)
        and the tensor must hold ``T * n_observations * n_observations``
        elements in total, e.g. shape ``(T, n_observations, n_observations)``
        or ``(T, n_observations * n_observations)``.
    n_observations : int
        Number of discretization points along each spatial axis, i.e. the
        resolution of the signal.
    normalize : bool, default is True
        If True, the 2D FFT is computed with ``norm="ortho"``, which scales its
        output by ``1/sqrt(n_observations * n_observations)``
        (= ``1/n_observations``). If False, the unscaled ``norm="backward"``
        FFT is used.

    Returns
    -------
    spectrum : torch.Tensor
        A 1D tensor of shape ``(n_observations,)``. Entry ``j - 1`` is the
        energy ``|F(k_x, k_y)|**2`` summed over every coefficient whose
        wavenumber magnitude ``sqrt(k_x**2 + k_y**2)`` rounds to ``j``,
        averaged over the T fields. Only coefficients with
        ``0 <= k_x, k_y <= n_observations // 2`` are binned (the other FFT
        quadrants are ignored), and the zero-wavenumber (mean) coefficient is
        excluded.
    """
    n_fields = signal.shape[0]
    # reshape rather than view so that non-contiguous inputs are accepted too
    fields = signal.reshape(n_fields, n_observations, n_observations)

    # The full fft2 is used in both cases; for real input its first
    # n_observations // 2 + 1 columns coincide with what rfft2 returns.
    coeffs = torch.fft.fft2(fields, norm="ortho" if normalize else "backward")

    # In PyTorch's FFT layout, row/column index i in 0..k_max holds a wavenumber
    # of magnitude i (index k_max is the Nyquist mode when n is even), so this
    # quadrant can be addressed directly by index.
    k_max = n_observations // 2
    k = torch.arange(k_max + 1, device=coeffs.device, dtype=torch.float32)
    radius = torch.sqrt(k[:, None] ** 2 + k[None, :] ** 2)
    # Assign each coefficient to its nearest integer radial shell. Comparing the
    # raw (generally irrational) radius to integers would silently drop every
    # off-axis coefficient except Pythagorean triples such as (3, 4).
    shell = torch.round(radius).long()

    power = coeffs[:, : k_max + 1, : k_max + 1].abs() ** 2

    spectrum = torch.zeros((n_fields, n_observations))
    # shells start at 1: the zero-wavenumber (mean) coefficient is left out
    for j in range(1, n_observations + 1):
        spectrum[:, j - 1] = power[:, shell == j].sum(dim=1)

    return spectrum.mean(dim=0)
Number = Union[float, int]


def validate_scaling_factor(
    scaling_factor: Union[None, Number, List[Number], List[List[Number]]],
    n_dim: int,
    n_layers: Optional[int] = None,
) -> Union[None, List[float], List[List[float]]]:
    """Normalize a scaling-factor specification into a list (of lists) of floats.

    Parameters
    ----------
    scaling_factor : None OR float OR list[float] OR list[list[float]]
        * None: None is returned.
        * a single number: repeated ``n_dim`` times (once per layer when
          ``n_layers`` is given).
        * a list of numbers: treated as per-dimension factors when
          ``n_layers`` is None and ``len(scaling_factor) == n_dim``.
          Otherwise each entry is one layer's factor, repeated ``n_dim`` times.
        * a list of lists of numbers: returned as a new list of lists of
          floats.
    n_dim : int
        Number of dimensions each factor applies to.
    n_layers : int or None, default is None
        If None and ``scaling_factor`` is a single number, return a single
        list (rather than a list of lists) with the factor repeated ``n_dim``
        times.

    Returns
    -------
    None, list[float] or list[list[float]]
        None if ``scaling_factor`` is None or not one of the accepted forms.
        Every returned inner list is a distinct object, so modifying one
        layer's factors does not affect the others.
    """
    if scaling_factor is None:
        return None

    if isinstance(scaling_factor, (float, int)):
        if n_layers is None:
            return [float(scaling_factor)] * n_dim
        # Build a fresh inner list per layer: ``[inner] * n_layers`` would make
        # every layer alias the same list object.
        return [[float(scaling_factor)] * n_dim for _ in range(n_layers)]

    if not isinstance(scaling_factor, list) or len(scaling_factor) == 0:
        return None

    if all(isinstance(s, (float, int)) for s in scaling_factor):
        if n_layers is None and len(scaling_factor) == n_dim:
            # this is a dim-wise scaling
            return [float(s) for s in scaling_factor]
        # one factor per layer, broadcast over all dimensions
        return [[float(s)] * n_dim for s in scaling_factor]

    if all(
        isinstance(s, list) and all(isinstance(s_sub, (int, float)) for s_sub in s)
        for s in scaling_factor
    ):
        # copy and convert, so the caller's nested lists are never shared
        return [[float(s_sub) for s_sub in s] for s in scaling_factor]

    return None
def compute_rank(tensor):
    """Compute the numerical matrix rank of a tensor.

    Uses ``torch.linalg.matrix_rank``. The former ``torch.matrix_rank`` was
    deprecated and has been removed from recent PyTorch releases, where
    calling it fails.

    Parameters
    ----------
    tensor : torch.Tensor
        A matrix, or a batch of matrices of shape ``(*, m, n)``.

    Returns
    -------
    torch.Tensor
        Integer rank(s), with shape ``(*)`` for a batch of matrices.
    """
    return torch.linalg.matrix_rank(tensor)
def compute_stable_rank(tensor):
    """Compute the stable rank ``||A||_F**2 / ||A||_2**2`` of a matrix.

    The input is detached from the autograd graph first, so no gradient
    flows through the result.

    Parameters
    ----------
    tensor : torch.Tensor
        Matrix passed to ``torch.linalg.norm`` with ``ord="fro"`` and ``ord=2``.

    Returns
    -------
    torch.Tensor
        Squared Frobenius norm divided by squared spectral norm.
    """
    matrix = tensor.detach()
    frobenius_sq = torch.linalg.norm(matrix, ord="fro").pow(2)
    spectral_sq = torch.linalg.norm(matrix, ord=2).pow(2)
    return frobenius_sq / spectral_sq
def compute_explained_variance(frequency_max, s):
    """Compute ``1 - var(residual) / var(s)``.

    ``residual`` is ``s`` with its first ``frequency_max`` entries zeroed
    (i.e. the part of ``s`` from index ``frequency_max`` onward). The variance
    is ``torch.var`` with its default settings.

    Parameters
    ----------
    frequency_max : int
        Number of leading entries of ``s`` treated as kept.
    s : torch.Tensor
        Presumably a 1D tensor of singular values; truncation slices along
        dim 0 (confirm against callers).

    Returns
    -------
    torch.Tensor
    """
    kept = s.clone()
    kept[frequency_max:] = 0
    residual = s - kept
    ratio = torch.var(residual) / torch.var(s)
    return 1 - ratio
def get_project_root():
    """Return the parent of the directory containing this module, as a ``Path``."""
    module_dir = Path(__file__).parent
    return module_dir.parent