import torch
import torch.nn as nn
import torch.nn.functional as F
from models.output_layer import *
from models.utils import *
from models import base

class LReLU(nn.Module):
    """Leaky ReLU whose output is clipped to the closed interval [-3, 3].

    Args:
        c: negative slope handed to ``F.leaky_relu`` (defaults to 1/3).
    """

    def __init__(self, c=1./3):
        super().__init__()
        self.c = c

    def forward(self, x):
        activated = F.leaky_relu(x, negative_slope=self.c)
        # Bound the activation so downstream layers never see |value| > 3.
        return activated.clamp(min=-3., max=3.)

class Model(nn.Module):
    """Sequential latent-variable model with a backward-LSTM inference network.

    A backward LSTM summarises the future of ``y``. At every time step a
    posterior over a latent ``z`` is computed from that summary plus the
    forward LSTM state, and a prior is computed from the forward state alone.
    The sampled ``z`` is projected by ``gen_mod`` and fed, together with the
    embedded input, into a forward ``LSTMCell``.

    Tensors are time-major: dim 0 is iterated as time and dim 1 is used as
    the batch size when the forward cell's state is initialised.
    """

    def __init__(self, d_data, d_emb, d_mlp, d_rnn, d_lat, n_layer, n_mlp_layer,
                 dropout=0., **kwargs):
        """Build all sub-modules.

        Args:
            d_data: feature size of the input/target sequences.
            d_emb: size of the input embedding.
            d_mlp: hidden size of the prior/posterior MLPs and of the
                projected latent fed into the forward cell.
            d_rnn: hidden size of both the backward and forward LSTMs.
            d_lat: size of the latent variable ``z``.
            n_layer: not used by this model; accepted for interface
                compatibility.
            n_mlp_layer: not used by this model; accepted for interface
                compatibility.
            dropout: dropout probability applied after the input embedding.
            **kwargs: ignored.
        """
        super().__init__()
        self.d_data = d_data
        self.d_emb = d_emb
        self.d_mlp = d_mlp
        self.d_rnn = self.d_model = d_rnn
        self.d_lat = d_lat

        self.inp_emb = nn.Sequential(
            nn.Linear(d_data, d_emb),
            nn.Dropout(dropout)
        )
        # Output heads built by the project's GaussianMixture helper:
        # forward states -> y, backward states -> x, latents -> backward states.
        self.crit = GaussianMixture(1, d_data, self.d_rnn)
        self.crit_back = GaussianMixture(1, d_data, self.d_rnn)
        self.crit_aux = GaussianMixture(1, d_rnn, d_lat, tanh=True)

        self.bwd_rnn = nn.LSTM(d_emb, d_rnn)
        # The forward cell consumes [embedded input, projected latent].
        self.fwd_rnn = nn.LSTMCell(d_emb + d_mlp, d_rnn)
        self.gen_mod = nn.Linear(d_lat, d_mlp)

        # init functions run under no_grad themselves; `.data` is unnecessary.
        nn.init.orthogonal_(self.bwd_rnn.weight_hh_l0)

        # Both MLPs emit [mu, logvar] concatenated along the last dim.
        self.prior = nn.Sequential(
            nn.Linear(d_rnn, d_mlp), LReLU(),
            nn.Linear(d_mlp, d_lat * 2)
        )
        self.post = nn.Sequential(
            nn.Linear(d_rnn * 2, d_mlp), LReLU(),
            nn.Linear(d_mlp, d_lat * 2)
        )

    def init_zero_weight(self, shape):
        """Return zeros of ``shape`` on the same device/dtype as the parameters."""
        weight = next(self.parameters())
        return weight.new_zeros(shape)

    def reparameterize(self, mu, logvar, eps=None):
        """Sample ``mu + eps * exp(0.5 * logvar)`` (reparameterisation trick).

        Args:
            mu: mean tensor.
            logvar: log-variance tensor, same shape as ``mu``.
            eps: optional standard-normal noise; drawn fresh when ``None``.

        Returns:
            A tensor shaped like ``mu`` that is differentiable w.r.t. ``mu``
            and ``logvar``.
        """
        std = torch.exp(0.5 * logvar)
        if eps is None:
            eps = torch.randn_like(std)
        return mu + eps * std

    def backward_pass(self, y):
        """Run the backward LSTM over ``y`` and return time-aligned outputs.

        The embedded sequence is reversed along dim 0, fed through
        ``bwd_rnn``, and the outputs are reversed back. Position ``t`` of the
        result therefore summarises ``y[t:]``.
        """
        y = self.inp_emb(y)
        brnn_outs_flip, _ = self.bwd_rnn(torch.flip(y, (0,)))
        return torch.flip(brnn_outs_flip, (0,))

    def forward_pass(self, x, brnn_outs):
        """Unroll the forward LSTM cell, sampling a latent ``z`` at each step.

        Args:
            x: input sequence, time on dim 0 and batch on dim 1.
            brnn_outs: output of :meth:`backward_pass`, aligned with ``x`` in
                time.

        Returns:
            ``(frnn_outs, kld, aux)``:
            - ``frnn_outs`` stacks the forward cell's hidden states over time.
            - ``kld`` is ``gaussian_kld`` over all steps.
            - ``aux`` is the auxiliary criterion predicting the (detached)
              backward states from the sampled latents.
        """
        seq_len, bsz = x.size(0), x.size(1)
        x_emb = self.inp_emb(x)

        h0 = self.init_zero_weight((bsz, self.d_rnn))
        c0 = self.init_zero_weight((bsz, self.d_rnn))
        frnn_hid = (h0, c0)
        frnn_out = h0

        # Draw all reparameterisation noise for the sequence in one call.
        noises = torch.randn(seq_len, bsz, self.d_lat,
                             device=x.device, dtype=torch.float)

        frnn_outs, z_list = [], []
        prior_mus, prior_logvars, post_mus, post_logvars = [], [], [], []
        for step in range(seq_len):
            # Prior conditions only on the previous forward state. Raw MLP
            # outputs are clamped to [-8, 8], presumably to keep exp(logvar)
            # bounded.
            prior_param = torch.clamp(self.prior(frnn_out), -8., 8.)
            prior_mu, prior_logvar = torch.chunk(prior_param, 2, -1)

            # Posterior additionally sees the backward summary of the future.
            post_inp = torch.cat([brnn_outs[step], frnn_out], -1)
            post_param = torch.clamp(self.post(post_inp), -8., 8.)
            post_mu, post_logvar = torch.chunk(post_param, 2, -1)

            z = self.reparameterize(post_mu, post_logvar, eps=noises[step])
            z_list.append(z)

            # Advance the forward cell on [embedded input, projected z].
            frnn_inp = torch.cat([x_emb[step], self.gen_mod(z)], -1)
            frnn_hid = self.fwd_rnn(frnn_inp, frnn_hid)
            frnn_out = frnn_hid[0]
            frnn_outs.append(frnn_out)

            prior_mus.append(prior_mu)
            prior_logvars.append(prior_logvar)
            post_mus.append(post_mu)
            post_logvars.append(post_logvar)

        # KL term for all steps at once, called as (posterior, prior).
        # NOTE(review): confirm this order matches gaussian_kld's KL(q || p)
        # convention.
        kld = gaussian_kld(
            [torch.stack(post_mus), torch.stack(post_logvars)],
            [torch.stack(prior_mus), torch.stack(prior_logvars)])

        frnn_outs = torch.stack(frnn_outs, 0)
        # Detach so the auxiliary loss does not push gradients into bwd_rnn.
        aux = self.crit_aux(torch.stack(z_list), brnn_outs.detach())

        return frnn_outs, kld, aux

    def forward(self, x, y, hidden=None, mask=None):
        """Compute the per-position loss terms.

        Args:
            x: input sequence ``(seq_len, batch, d_data)``.
            y: target sequence in the same layout. It is read by the backward
                LSTM and scored by the forward criterion.
            hidden: returned unchanged; no state is carried across calls.
            mask: optional multiplicative mask applied to every loss term,
                presumably ``(seq_len, batch)``; confirm with the caller.

        Returns:
            ``(nll_loss, loss_back + aux_loss, -kld_loss, hidden)``. Each loss
            is reduced over dim 2 only; time/batch reduction is left to the
            caller.

        Raises:
            ValueError: if ``x`` is not 3-dimensional.
        """
        if x.dim() != 3:
            raise ValueError(
                'expected x of shape (seq_len, batch, d_data), got %s'
                % (tuple(x.size()),))

        brnn_outs = self.backward_pass(y)
        frnn_outs, kld, aux = self.forward_pass(x, brnn_outs)

        loss = self.crit(frnn_outs, y)
        loss_back = self.crit_back(brnn_outs, x)

        # Reduce only over dim 2; summing/averaging over time (dim 0) and
        # batch (dim 1) is the caller's responsibility.
        nll_loss = loss.sum(2)
        kld_loss = kld.sum(2)
        # The aux term is averaged over its features, then rescaled by d_data.
        aux_loss = aux.mean(2) * self.d_data
        loss_back = loss_back.sum(2)
        if mask is not None:
            nll_loss = nll_loss * mask
            kld_loss = kld_loss * mask
            loss_back = loss_back * mask
            aux_loss = aux_loss * mask

        return nll_loss, loss_back + aux_loss, -kld_loss, hidden
    
if __name__ == '__main__':
    # Smoke test: build a model, report its parameter count, run one forward pass.
    d_data, d_emb, d_mlp, d_rnn, d_lat, dropout = 1, 512, 256, 1500, 256, 0.
    # Model requires n_layer / n_mlp_layer positionally even though it
    # does not use them.
    n_layer, n_mlp_layer = 1, 1
    m = Model(d_data, d_emb, d_mlp, d_rnn, d_lat, n_layer, n_mlp_layer,
              dropout=dropout)
    print('parameter number:', sum(p.nelement() for p in m.parameters()))
    # Time-major inputs: (seq_len, batch, d_data).
    x = torch.rand(32, 8, d_data)
    y = torch.rand(32, 8, d_data)

    # Model.forward returns four values: nll, back+aux loss, -kld, hidden.
    nll_loss, back_aux_loss, neg_kld_loss, _ = m(x, y, None)
    print('loss shapes:', tuple(nll_loss.shape), tuple(back_aux_loss.shape),
          tuple(neg_kld_loss.shape))

