import torch
import torch.nn as nn
from .scale import Scale


class ScaledModule(nn.Module):
    """
        Wraps a module and passes its output through a ``Scale``.
    """
    def __init__(self, module):
        """
        :param nn.Module module: Wrapped module. It must expose a ``weight``
            attribute; the size of that weight's first dimension sets the
            length of the tensor handed to ``Scale``.
        """
        super().__init__()
        self.module = module
        leading_dim = module.weight.size(0)
        # torch.empty allocates without initializing, so the values are
        # arbitrary at this point.
        # NOTE(review): presumably Scale or a later init step overwrites
        # them - confirm.
        self.scale = Scale(torch.empty(leading_dim))

    def forward(self, xs):
        """
        :param xs: Input forwarded unchanged to the wrapped module.
        :return: Output of the wrapped module after ``self.scale`` is applied.
        """
        module_out = self.module(xs)
        return self.scale(module_out)
