# Source code for neuralop.layers.complex

"""
Functionality for handling complex-valued spatial data
"""

from copy import deepcopy

import torch
from torch import nn
import torch.nn.functional as F


def CGELU(x: torch.Tensor, dtype: torch.dtype = torch.cfloat) -> torch.Tensor:
    """Complex GELU activation function

    Follows the formulation of CReLU from [1]_: GELU is applied to the
    real and imaginary parts of the input separately, and the two results
    are then recombined into a single complex tensor.

    Parameters
    ----------
    x : torch.Tensor (dtype=complex)
        pre-activation inputs
    dtype : torch.dtype, optional
        complex dtype that both activated parts are cast to before being
        recombined, by default ``torch.cfloat``. Pass ``torch.cdouble`` to
        keep double precision for double-precision inputs.

    Returns
    -------
    torch.Tensor
        ``GELU(x.real) + 1j * GELU(x.imag)``, with dtype ``dtype``

    References
    ----------
    .. [1] :

    Trabelsi, C., et al. (2018). "Deep Complex Networks".
        ICLR 2018, https://openreview.net/pdf?id=H1T2hmZAb.
    """
    # Each part is activated on its own as a real tensor, then promoted to
    # the requested complex dtype so the two halves can be summed.
    real_act = F.gelu(x.real).type(dtype)
    imag_act = F.gelu(x.imag).type(dtype)
    return real_act + 1j * imag_act


def ctanh(x: torch.Tensor, dtype: torch.dtype = torch.cfloat) -> torch.Tensor:
    """Complex-valued tanh stabilizer

    Applies tanh to the real and imaginary parts of the input separately,
    then combines the two results into a single complex tensor. Note that
    this is a component-wise tanh, not ``torch.tanh`` evaluated directly
    on the complex input.

    Parameters
    ----------
    x : torch.Tensor (dtype=complex)
        complex tensor to stabilize
    dtype : torch.dtype, optional
        complex dtype that both parts are cast to before being recombined,
        by default ``torch.cfloat``. Pass ``torch.cdouble`` to keep double
        precision for double-precision inputs.

    Returns
    -------
    torch.Tensor
        ``tanh(x.real) + 1j * tanh(x.imag)``, with dtype ``dtype``
    """
    real_part = torch.tanh(x.real).type(dtype)
    imag_part = torch.tanh(x.imag).type(dtype)
    return real_part + 1j * imag_part


def apply_complex(real_func, imag_func, x, dtype=torch.cfloat):
    """Apply a pair of real-valued functions to a complex input.

    The two functions are combined following the rule for multiplying
    complex numbers, with ``real_func`` in the role of the real component
    and ``imag_func`` in the role of the imaginary component:

        real part: real_func(x.real) - imag_func(x.imag)
        imag part: real_func(x.imag) + imag_func(x.real)

    Parameters
    ----------
    real_func : callable
        function (e.g., conv) acting as the real component of the operator
    imag_func : callable
        function (e.g., conv) acting as the imaginary component of the
        operator
    x : torch.Tensor (dtype=complex)
        complex input
    dtype : torch.dtype, optional
        complex dtype both output parts are cast to before being combined,
        by default ``torch.cfloat``

    Returns
    -------
    torch.Tensor
        complex tensor assembled from the two parts above
    """
    x_re, x_im = x.real, x.imag

    # Each function is invoked twice; the call order below
    # (real_func, imag_func, real_func, imag_func) is kept deliberately in
    # case the wrapped callables are stateful.
    out_re = real_func(x_re) - imag_func(x_im)
    out_im = real_func(x_im) + imag_func(x_re)

    return out_re.type(dtype) + 1j * out_im.type(dtype)


class ComplexValued(nn.Module):
    """
    Wrapper class that converts a standard nn.Module that operates on real data
    into a module that operates on complex-valued spatial data.

    Parameters
    ----------
    module : nn.Module
        real-valued module to wrap. It is deep-copied twice, once for the
        real component (``self.fr``) and once for the imaginary component
        (``self.fi``), so the wrapper holds its own copies rather than the
        module that was passed in. Both copies start out identical to
        ``module``.
    """

    def __init__(self, module):
        super().__init__()
        # Separate deep copies so the real and imaginary components are
        # distinct submodules rather than two references to one module.
        self.fr = deepcopy(module)
        self.fi = deepcopy(module)

    def forward(self, x):
        """Apply the wrapped module pair to a complex input.

        Parameters
        ----------
        x : torch.Tensor (dtype=complex)
            complex input

        Returns
        -------
        torch.Tensor
            result of ``apply_complex(self.fr, self.fi, x)``
        """
        return apply_complex(self.fr, self.fi, x)