# Created on 2018/12
# Author: Kaituo XU

import torch
import torch.nn as nn
import torch.nn.functional as F
import math


# Small epsilon constant. Not referenced in this chunk; presumably used
# elsewhere in the file to guard divisions / logs against zero (TODO confirm).
EPS = 1e-8

def overlap_and_add(signal, frame_step):
    """Reconstruct a signal from a framed representation.

    Adds potentially overlapping frames of a signal with shape
    `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`.
    The resulting tensor has shape `[..., output_size]` where

        output_size = (frames - 1) * frame_step + frame_length

    Args:
        signal: A [..., frames, frame_length] Tensor. Rank must be at least 2;
            any number of leading (outer) dimensions is supported.
        frame_step: An integer denoting overlap offsets. Must be less than or
            equal to frame_length.

    Returns:
        A Tensor with shape [..., output_size] containing the overlap-added
        frames of signal's inner-most two dimensions, created with
        ``signal.new_zeros`` (same dtype and device as ``signal``).
        output_size = (frames - 1) * frame_step + frame_length

    Based on https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/signal/python/ops/reconstruction_ops.py
    """
    outer_dimensions = signal.size()[:-2]
    frames, frame_length = signal.size()[-2:]

    # Cut every frame into sub-frames of length gcd(frame_length, frame_step).
    # Each sub-frame then maps onto exactly one output sub-frame slot, so the
    # whole overlap-add collapses into a single index_add_ along that axis.
    subframe_length = math.gcd(frame_length, frame_step)
    subframe_step = frame_step // subframe_length
    subframes_per_frame = frame_length // subframe_length
    output_size = frame_step * (frames - 1) + frame_length
    output_subframes = output_size // subframe_length

    # [..., frames * subframes_per_frame, subframe_length]; reshape (not view)
    # so non-contiguous inputs are accepted too.
    subframe_signal = signal.reshape(*outer_dimensions, -1, subframe_length)

    # frame[i] is the output sub-frame slot that input sub-frame i is added
    # into. Built directly on the signal's device (CPU or GPU) as int64.
    frame = torch.arange(0, output_subframes, device=signal.device)
    frame = frame.unfold(0, subframes_per_frame, subframe_step).reshape(-1)

    result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
    result.index_add_(-2, frame, subframe_signal)
    return result.view(*outer_dimensions, -1)


def remove_pad(inputs, inputs_lengths):
    """Strip the trailing padding from every item of a padded batch.

    Args:
        inputs: torch.Tensor, [B, C, T] or [B, T], B is batch size.
        inputs_lengths: torch.Tensor (or sequence of ints), [B]; the number of
            valid samples kept from the start of each item's last axis.

    Returns:
        results: a list of B numpy arrays (detached and moved to CPU); each
        item is [C, length] for 3-D input or [length] for 2-D input.

    Raises:
        ValueError: if ``inputs`` is neither 2-D nor 3-D.
    """
    if inputs.dim() not in (2, 3):
        raise ValueError(
            "remove_pad expects a [B, C, T] or [B, T] tensor, got %d dims" % inputs.dim())
    # `...` keeps every leading axis, so one slice serves both layouts:
    # [C, T] -> [C, length] and [T] -> [length]. It also avoids reshaping with
    # -1, which fails for zero-length items.
    return [item[..., :length].detach().cpu().numpy()
            for item, length in zip(inputs, inputs_lengths)]


class ConvTasNet(nn.Module):
    def __init__(self, args, N=256, L=20, B=256, H=512, P=3, X=8, R=4, C=2, norm_type="gLN", causal=False,
                 mask_nonlinear='relu'):
        """
        Args:
            args: Configuration object; stored as ``self.args`` and forwarded
                unchanged as the first argument of TemporalConvNet.
            N: Number of filters in autoencoder
            L: Length of the filters (in samples)
            B: Number of channels in bottleneck 1 * 1-conv block
            H: Number of channels in convolutional blocks
            P: Kernel size in convolutional blocks
            X: Number of convolutional blocks in each repeat
            R: Number of repeats
            C: Number of speakers
            norm_type: BN, gLN, cLN
            causal: causal or non-causal
            mask_nonlinear: use which non-linear function to generate mask
        """
        super(ConvTasNet, self).__init__()
        # Hyper-parameter
        self.args = args
        self.N, self.L, self.B, self.H, self.P, self.X, self.R, self.C = N, L, B, H, P, X, R, C
        self.norm_type = norm_type
        self.causal = causal
        self.mask_nonlinear = mask_nonlinear
        # Components
        self.encoder = Encoder(L, N)
        self.separator = TemporalConvNet(args, N, B, H, P, X, R, C, norm_type, causal, mask_nonlinear)
        # Maps B channels back to N channels (applied to the separator output
        # in the commented-out forward below).
        self.mask_conv1x1 = nn.Conv1d(B, N, 1, bias=False)
        self.decoder = Decoder(N, L)
        # init: Xavier-normal for every parameter with more than one dimension
        # (weight matrices / conv kernels); 1-D parameters keep their defaults.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_normal_(p)

    # NOTE(review): this class currently has no active forward(); the previous
    # implementation is kept below for reference only. Calling the module
    # falls through to nn.Module's default forward.
    # def forward(self, mixture, hidden_outputs):
    #     """
    #     Args:
    #         mixture: [M, T], M is batch size, T is #samples
    #     Returns:
    #         est_source: [M, C, T]
    #     """
    #     mixture_w = self.encoder(mixture)
    #     est_mask= self.separator(mixture_w, hidden_outputs)
    #     score = self.mask_conv1x1(est_mask)
    #     if self.mask_nonlinear == 'softmax':
    #         est_mask = F.softmax(score, dim=1)
    #     elif self.mask_nonlinear == 'relu':
    #         est_mask = F.relu(score)
    #     est_source = self.decoder(mixture_w, est_mask)

    #     # T changed after conv1d in encoder, fix it here
    #     T_origin = mixture.size(-1)
    #     T_conv = est_source.size(-1)
    #     est_source = F.pad(est_source, (0, T_origin-T_conv))
    #     return est_source

    @classmethod
    def load_model(cls, path, args=None):
        """Build a model from a checkpoint file produced via ``serialize``.

        Args:
            path: Path of the file passed to ``torch.load``.
            args: Configuration object passed as the constructor's first
                argument (it is not stored in the checkpoint).
        """
        # Load to CPU regardless of the device the tensors were saved from.
        package = torch.load(path, map_location=lambda storage, loc: storage)
        return cls.load_model_from_package(package, args=args)

    @classmethod
    def load_model_from_package(cls, package, args=None):
        """Build a model from a dict produced by ``serialize``.

        Args:
            package: Dict holding the hyper-parameters and ``'state_dict'``.
            args: Configuration object passed as the constructor's first
                argument (it is not stored in the package).
        """
        # `args` is __init__'s first positional parameter, so the
        # hyper-parameters must be passed by keyword to land in the right slots.
        model = cls(args, N=package['N'], L=package['L'], B=package['B'],
                    H=package['H'], P=package['P'], X=package['X'],
                    R=package['R'], C=package['C'],
                    norm_type=package['norm_type'], causal=package['causal'],
                    mask_nonlinear=package['mask_nonlinear'])
        model.load_state_dict(package['state_dict'])
        return model

    @staticmethod
    def serialize(model, optimizer, epoch, tr_loss=None, cv_loss=None):
        """Pack hyper-parameters, model/optimizer state and epoch into a dict.

        ``model.args`` is not included; supply it again when loading. When
        ``tr_loss`` is given, ``cv_loss`` is stored alongside it (even if None).
        """
        package = {
            # hyper-parameter
            'N': model.N, 'L': model.L, 'B': model.B, 'H': model.H,
            'P': model.P, 'X': model.X, 'R': model.R, 'C': model.C,
            'norm_type': model.norm_type, 'causal': model.causal,
            'mask_nonlinear': model.mask_nonlinear,
            # state
            'state_dict': model.state_dict(),
            'optim_dict': optimizer.state_dict(),
            'epoch': epoch
        }
        if tr_loss is not None:
            package['tr_loss'] = tr_loss
            package['cv_loss'] = cv_loss
        return package










class Conditional_TemporalBlock(nn.Module):
    """Residual 1-D conv block whose depthwise stage is conditioned on a query.

    The input goes through conv1x1 -> PReLU -> norm, then a
    Conditional_DepthwiseSeparableConv together with the query, and the result
    is added back to the input (residual connection).
    """

    def __init__(self, config, in_channels, out_channels, kernel_size,
                 stride, padding, dilation, norm_type="gLN", causal=False):
        """
        Args:
            config: Object exposing ``middle_separation_mode``; the block's
                layers are only kept when it is truthy.
            in_channels: Bottleneck channels B; also the output channels of the
                depthwise-separable conv, as required by the residual sum.
            out_channels: Hidden channels H of the block.
            kernel_size, stride, padding, dilation: Depthwise conv settings.
            norm_type: Normalization name passed to ``chose_norm``.
            causal: Whether the depthwise-separable conv is causal.
        """
        super(Conditional_TemporalBlock, self).__init__()
        # [M, B, K] -> [M, H, K]
        conv1x1 = nn.Conv1d(in_channels, out_channels, 1, bias=False)
        prelu = nn.PReLU()
        norm = chose_norm(norm_type, out_channels)
        # Put together
        if config.middle_separation_mode:  # conditional 1-D conv block
            # [M, H, K] -> [M, B, K]
            dsconv = Conditional_DepthwiseSeparableConv(out_channels, in_channels, kernel_size,
                                                        stride, padding, dilation, norm_type,
                                                        causal)
            self.net = nn.Sequential(conv1x1, prelu, norm)
            self.dsconv = dsconv
        else:
            # No layers in this mode; forward() reports that explicitly instead
            # of failing with an attribute lookup error.
            self.net = None
            self.dsconv = None

    def forward(self, xs):
        """
        Args:
            xs: Two-element sequence ``[x, query]``:
                x: [M, topk, B, K]
                query: [M, topk, spk_emb]
        Returns:
            ``[out, query]``: out is [M, topk, B, K] (same shape as x) and
            query is passed through unchanged (presumably so blocks can be
            chained in an nn.Sequential).
        Raises:
            RuntimeError: if the block was built with
                ``config.middle_separation_mode`` disabled.
        """
        if self.dsconv is None:
            raise RuntimeError(
                "Conditional_TemporalBlock was built with config.middle_separation_mode "
                "disabled and has no layers to run")
        x, query = xs[0], xs[1]
        original_shape = x.size()
        assert x.shape[:2] == query.shape[:2]
        # Fold topk into the batch axis: [M * topk, B, K]. reshape (not view)
        # also accepts inputs that are non-contiguous across M/topk.
        x = x.reshape(-1, x.shape[-2], x.shape[-1])

        residual = x
        out = self.dsconv(self.net(x), query)
        # TODO: when P = 3 here works fine, but when P = 2 maybe need to pad?
        out = (out + residual).view(*original_shape)  # back to [M, topk, B, K]

        # Returning without a final F.relu looked better than
        # F.relu(out + residual) in the author's experiments.
        return [out, query]

class Conditional_DepthwiseSeparableConv(nn.Module):
    """Depthwise-separable conv with query-conditioned feature modulation.

    After the depthwise conv, every channel is scaled and shifted by linear
    projections of the query (``a * x + b``, the FiLM-style conditioning used
    in Wavesplit). The result then goes through [Chomp1d ->] PReLU -> norm ->
    pointwise conv.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride, padding, dilation, norm_type="gLN", causal=False,
                 spk_emb_dim=512):
        """
        Args:
            in_channels: Channels H seen by the depthwise conv and modulation.
            out_channels: Output channels B of the pointwise conv.
            kernel_size, stride, padding, dilation: Depthwise conv settings.
            norm_type: Normalization name passed to ``chose_norm``.
            causal: If True, a Chomp1d(padding) trims the depthwise output.
            spk_emb_dim: Size of the query's last axis (speaker embedding);
                defaults to 512, the value previously hard-coded.
        """
        super(Conditional_DepthwiseSeparableConv, self).__init__()
        # Use `groups` option to implement depthwise convolution
        # [M, H, K] -> [M, H, K'] (K' depends on padding)
        self.depthwise_conv = nn.Conv1d(in_channels, in_channels, kernel_size,
                                        stride=stride, padding=padding,
                                        dilation=dilation, groups=in_channels,
                                        bias=False)
        layers = []
        if causal:
            layers.append(Chomp1d(padding))
        layers.append(nn.PReLU())
        layers.append(chose_norm(norm_type, in_channels))
        # [M, H, K] -> [M, B, K]
        layers.append(nn.Conv1d(in_channels, out_channels, 1, bias=False))
        self.net = nn.Sequential(*layers)

        # Query projections producing the per-channel scale (a) and shift (b).
        self.linear_a = nn.Linear(spk_emb_dim, in_channels)  # [BS, H]
        self.linear_b = nn.Linear(spk_emb_dim, in_channels)  # [BS, H]

    def forward(self, x, query):
        """
        Args:
            x: [M*topk, H, K]
            query: [M, topK, Spk_EMB]; flattened to [M*topk, Spk_EMB] so it
                lines up with the batch axis of x.
        Returns:
            result: [M*topk, B, K]
        """
        # Conditioning taken from Wavesplit: End-to-End Speech Separation by
        # Speaker Clustering.
        x = self.depthwise_conv(x)  # [M*topk, H, K'] (trimmed later if causal)
        # reshape (not view) also accepts a non-contiguous query.
        flat_query = query.reshape(-1, query.shape[-1])  # [M*topk, Spk_emb]
        scale = self.linear_a(flat_query).unsqueeze(-1)  # [M*topk, H, 1]
        shift = self.linear_b(flat_query).unsqueeze(-1)  # [M*topk, H, 1]
        return self.net(scale * x + shift)


if __name__ == "__main__":
    # Manual smoke test of the individual components; prints intermediate tensors.
    torch.manual_seed(123)
    M, N, L, T = 2, 3, 4, 12
    K = 2 * T // L - 1
    B, H, P, X, R, C, norm_type, causal = 2, 3, 3, 3, 2, 2, "gLN", False
    # NOTE(review): `args` is forwarded to TemporalConvNet, whose definition is
    # not in this chunk; None is only a placeholder - confirm what it must hold.
    args = None
    # torch.randint returns int64 by default; request float so the tensors can
    # flow through Conv1d layers with float parameters.
    mixture = torch.randint(3, (M, T), dtype=torch.float)

    # test Encoder
    encoder = Encoder(L, N)
    encoder.conv1d_U.weight.data = torch.randint(
        2, encoder.conv1d_U.weight.size(), dtype=torch.float)
    mixture_w = encoder(mixture)
    print('mixture', mixture)
    print('U', encoder.conv1d_U.weight)
    print('mixture_w', mixture_w)
    print('mixture_w size', mixture_w.size())

    # test TemporalConvNet (leading `args`, as in ConvTasNet.__init__)
    separator = TemporalConvNet(args, N, B, H, P, X, R, C, norm_type=norm_type, causal=causal)
    # NOTE(review): the commented-out ConvTasNet.forward calls
    # separator(mixture_w, hidden_outputs); confirm whether a second argument is needed.
    est_mask = separator(mixture_w)
    print('est_mask', est_mask)
    print('model', separator)

    # test Decoder
    decoder = Decoder(N, L)
    # NOTE(review): the (B, K, C, N) mask layout is unverified against Decoder.
    est_mask = torch.randint(2, (B, K, C, N), dtype=torch.float)
    est_source = decoder(mixture_w, est_mask)
    print('est_source', est_source)

    # test Conv-TasNet construction. ConvTasNet has no active forward (it is
    # commented out), so conv_tasnet(mixture) would fail; only build and print it.
    conv_tasnet = ConvTasNet(args, N, L, B, H, P, X, R, C, norm_type=norm_type)
    print('conv_tasnet', conv_tasnet)

