import torch.nn as nn
from antgine.modules.quantization.functions.binarize import reswish_binarize


class ReswishBinarize(nn.Module):
    """
        Reswish binary module (module wrapper on :func:`antgine.modules.quantization.functions.binarize.reswish_binarize`).

        Stores ``beta`` at construction time and passes it to
        ``reswish_binarize`` on every forward call, so the function can be used
        as a layer inside ``nn.Sequential`` and similar containers.
    """
    def __init__(self, beta: float):
        """
        :param float beta: Reswish's beta value. It is stored as a plain
            attribute (not a parameter or buffer), so it is not trained and
            is not part of ``state_dict()``.
        """
        super().__init__()
        self.beta = beta

    def forward(self, inputs):
        """
        Apply :func:`antgine.modules.quantization.functions.binarize.reswish_binarize`.

        :param inputs: Value forwarded unchanged as the first argument to
            ``reswish_binarize``. It is presumably a ``torch.Tensor``; confirm
            against ``reswish_binarize``.
        :return: Whatever ``reswish_binarize(inputs, self.beta)`` returns.
        """
        return reswish_binarize(inputs, self.beta)

    def extra_repr(self):
        """
        Expose ``beta`` in the module's ``repr``, e.g. ``ReswishBinarize(beta=2.0)``.

        :return: Summary string used by ``nn.Module.__repr__``.
        :rtype: str
        """
        return f"beta={self.beta}"
