from typing import List, Optional, Union
from math import prod
from pathlib import Path
import torch
# Only import wandb and use if installed
wandb_available = False
try:
    import wandb

    wandb_available = True
except ModuleNotFoundError:
    # NOTE(review): wandb_login below calls wandb.login unconditionally;
    # callers should check `wandb_available` (or have wandb installed) first.
    wandb_available = False
def count_model_params(model):
    """Return the total number of parameters of a PyTorch model.

    Parameters
    ----------
    model : torch.nn.Module

    Returns
    -------
    int
        Total number of parameters over all of ``model.parameters()``.

    Notes
    -----
    One complex number is counted as two parameters (we count real and
    imaginary parts).
    """
    # Generator expression: no need to materialize a list for sum().
    return sum(
        p.numel() * 2 if p.is_complex() else p.numel() for p in model.parameters()
    )
def count_tensor_params(tensor, dims=None):
    """Return the number of parameters (elements) in a single tensor,
    optionally along certain dimensions only.

    Parameters
    ----------
    tensor : torch.Tensor
    dims : list of int or None, default is None
        If not None, the dimensions to consider when counting the number
        of parameters (elements).

    Returns
    -------
    int

    Notes
    -----
    One complex number is counted as two parameters (we count real and
    imaginary parts).
    """
    if dims is None:
        dims = list(tensor.shape)
    else:
        dims = [tensor.shape[d] for d in dims]
    n_params = prod(dims)
    # A complex element stores a real and an imaginary component.
    return 2 * n_params if tensor.is_complex() else n_params
def wandb_login(api_key_file="../config/wandb_api_key.txt", key=None):
    """Log in to Weights & Biases.

    If ``key`` is not given, it is resolved via ``get_wandb_api_key``
    (environment variable first, then ``api_key_file``).
    """
    resolved_key = get_wandb_api_key(api_key_file) if key is None else key
    wandb.login(key=resolved_key)
def set_wandb_api_key(api_key_file="../config/wandb_api_key.txt"):
    """Ensure the ``WANDB_API_KEY`` environment variable is set.

    If it is already present, nothing happens; otherwise the key is read
    from ``api_key_file`` (stripped of surrounding whitespace).
    """
    import os

    if "WANDB_API_KEY" not in os.environ:
        with open(api_key_file, "r") as f:
            os.environ["WANDB_API_KEY"] = f.read().strip()
def get_wandb_api_key(api_key_file="../config/wandb_api_key.txt"):
    """Return the W&B API key.

    The ``WANDB_API_KEY`` environment variable takes precedence; when it
    is unset, the key is read from ``api_key_file`` (whitespace stripped).
    """
    import os

    key = os.environ.get("WANDB_API_KEY")
    if key is not None:
        return key
    with open(api_key_file, "r") as f:
        return f.read().strip()
# Define the function to compute the spectrum
def spectrum_2d(signal, n_observations, normalize=True):
    """Compute the spectrum of a 2D signal using the Fast Fourier Transform (FFT).

    Parameters
    ----------
    signal : torch.Tensor
        Tensor whose first dimension is T (number of channels / time steps)
        and whose remaining elements are reshaped to
        ``(T, n_observations, n_observations)``, i.e. the spatial grid.
    n_observations : int
        Number of discretized points. Basically the resolution of the signal.
    normalize : bool
        Whether to apply normalization to the output of the 2D FFT.
        If True, normalizes the outputs by ``1/n_observations``
        (actually ``1/sqrt(n_observations * n_observations)`` via
        ``norm="ortho"``).

    Returns
    -------
    spectrum : torch.Tensor
        A 1D tensor of shape ``(n_observations,)`` representing the computed
        spectrum, averaged over T. The spectrum uses a square approximation
        to radial binning: a coefficient contributes to bin ``j`` when its
        wavenumber magnitude, measured from the top-left (zero-frequency)
        corner of the 2d FFT output, equals exactly ``j``.
    """
    T = signal.shape[0]
    signal = signal.view(T, n_observations, n_observations)

    if normalize:
        signal = torch.fft.fft2(signal, norm="ortho")
    else:
        signal = torch.fft.rfft2(
            signal, s=(n_observations, n_observations), norm="backward"
        )

    # 2d wavenumbers following the PyTorch fft convention
    k_max = n_observations // 2
    wavenumbers = torch.cat(
        (
            torch.arange(start=0, end=k_max, step=1),
            torch.arange(start=-k_max, end=0, step=1),
        ),
        0,
    ).repeat(n_observations, 1)
    k_x = wavenumbers.transpose(0, 1)
    k_y = wavenumbers

    # Wavenumber magnitude of each Fourier coefficient
    sum_k = torch.sqrt(k_x**2 + k_y**2)

    # Remove symmetric components from wavenumbers: only the top-left
    # quadrant is kept; everything else is marked -1 so it never matches
    # an (integer) wavenumber bin in the loop below.
    index = -1.0 * torch.ones((n_observations, n_observations))
    k_max1 = k_max + 1
    index[0:k_max1, 0:k_max1] = sum_k[0:k_max1, 0:k_max1]

    spectrum = torch.zeros((T, n_observations))
    for j in range(1, n_observations + 1):
        ind = torch.where(index == j)
        spectrum[:, j - 1] = (signal[:, ind[0], ind[1]].abs() ** 2).sum(dim=1)

    spectrum = spectrum.mean(dim=0)
    return spectrum
Number = Union[float, int]


def validate_scaling_factor(
    scaling_factor: Union[None, Number, List[Number], List[List[Number]]],
    n_dim: int,
    n_layers: Optional[int] = None,
) -> Union[None, List[float], List[List[float]]]:
    """Validate and normalize a scaling factor into per-dimension floats.

    Parameters
    ----------
    scaling_factor : None OR float OR list[float] Or list[list[float]]
    n_dim : int
    n_layers : int or None; defaults to None
        If None, return a single list (rather than a list of lists)
        with `factor` repeated `dim` times.

    Returns
    -------
    None, list of float, or list of list of float
        ``None`` when the input is ``None`` or cannot be validated.
    """
    if scaling_factor is None:
        return None

    # A single number is broadcast over all dimensions (and layers).
    if isinstance(scaling_factor, (float, int)):
        if n_layers is None:
            return [float(scaling_factor)] * n_dim
        return [[float(scaling_factor)] * n_dim] * n_layers

    if (
        isinstance(scaling_factor, list)
        and len(scaling_factor) > 0
        and all(isinstance(s, (float, int)) for s in scaling_factor)
    ):
        if n_layers is None and len(scaling_factor) == n_dim:
            # this is a dim-wise scaling
            return [float(s) for s in scaling_factor]
        # otherwise: one number per layer, broadcast over dimensions
        return [[float(s)] * n_dim for s in scaling_factor]

    if (
        isinstance(scaling_factor, list)
        and len(scaling_factor) > 0
        and all(isinstance(s, list) for s in scaling_factor)
    ):
        if all(
            isinstance(s_sub, (int, float)) for s in scaling_factor for s_sub in s
        ):
            # Normalize every entry to float for consistency with the
            # branches above (the original returned the input unconverted).
            return [[float(s_sub) for s_sub in s] for s in scaling_factor]

    return None
def compute_rank(tensor):
    """Return the matrix rank of ``tensor``.

    Uses ``torch.linalg.matrix_rank``; the top-level ``torch.matrix_rank``
    used previously is deprecated and removed in recent PyTorch versions.
    """
    rank = torch.linalg.matrix_rank(tensor)
    return rank
def compute_stable_rank(tensor):
    """Return the stable rank of ``tensor``: ``||T||_F^2 / ||T||_2^2``.

    The tensor is detached from the autograd graph before the norms are
    computed. Expects a 2D tensor (matrix norms are used).
    """
    tensor = tensor.detach()
    fro_norm_sq = torch.linalg.norm(tensor, ord="fro") ** 2
    spectral_norm_sq = torch.linalg.norm(tensor, ord=2) ** 2
    # Dead self-assignment (`rank = rank`) from the original removed.
    return fro_norm_sq / spectral_norm_sq
def compute_explained_variance(frequency_max, s):
    """Return the fraction of variance of ``s`` explained by its leading entries.

    Parameters
    ----------
    frequency_max : int
        Number of leading entries (e.g. singular values) to keep; the rest
        are zeroed out.
    s : torch.Tensor
        1D tensor of values, e.g. singular values.

    Returns
    -------
    torch.Tensor
        Scalar ``1 - Var(tail) / Var(s)`` where ``tail`` is ``s`` with the
        first ``frequency_max`` entries removed (zeroed).
    """
    s_current = s.clone()
    s_current[frequency_max:] = 0
    return 1 - torch.var(s - s_current) / torch.var(s)
def get_project_root():
    """Return the project root: two directory levels above this file."""
    return Path(__file__).parents[1]