import torch
from torch import nn


# define siren layer & Siren model
class Sine(nn.Module):
    """Element-wise sine activation with a frequency multiplier.

    The forward pass evaluates ``sin(w0 * x)``.

    Args:
        w0 (float): Frequency scale applied to the input before the sine
            (omega_0 in the SIREN paper). Defaults to 1.0.
    """

    def __init__(self, w0=1.0):
        super().__init__()
        self.w0 = w0

    def forward(self, x):
        # Scale first, then apply the sine; operand order matches sin(w0 * x).
        scaled = self.w0 * x
        return torch.sin(scaled)


def create_activation(config):
    """Build a new activation module from a config object.

    Args:
        config: Object exposing a ``type`` attribute that names the
            activation. Supported values are ``"relu"`` and ``"silu"``.

    Returns:
        nn.Module: A freshly constructed activation module (a new instance
        on every call).

    Raises:
        NotImplementedError: If ``config.type`` is ``"siren"`` (recognized
            but not supported by this factory) or any other unknown value.
    """
    # Map names to constructors rather than instances so each call returns
    # a new module instead of sharing one object across models.
    factories = {
        "relu": nn.ReLU,
        "silu": nn.SiLU,
    }
    activation_type = config.type
    if activation_type == "siren":
        # The siren option is intentionally disabled here.
        # NOTE(review): enabling it would presumably construct
        # Sine(config.siren_w0); confirm the intended w0 source first.
        raise NotImplementedError(
            "activation type 'siren' is not supported by create_activation"
        )
    try:
        factory = factories[activation_type]
    except (KeyError, TypeError):
        # TypeError covers unhashable values, which the original
        # equality chain also rejected with NotImplementedError.
        raise NotImplementedError(
            f"unknown activation type {activation_type!r}; "
            f"expected one of {sorted(factories)}"
        ) from None
    return factory()
