import torch
import torch.nn as nn
import torch.nn.functional as F


class Discriminator(nn.Module):
    """GRU discriminator that scores every transition of a time-major sequence.

    ``params`` must provide ``'x_dim'``, ``'h_dim'`` and ``'n_layers'``.
    For input of shape (T, batch_size, x_dim), the GRU summarises the frames
    ``data[:t+1]`` and an MLP scores that summary together with the next frame
    ``data[t+1]``, giving one sigmoid probability per transition.
    """

    def __init__(self, params):
        super().__init__()

        self.params = params
        self.x_dim = params['x_dim']
        self.h_dim = params['h_dim']
        self.n_layers = params['n_layers']

        self.gru = nn.GRU(self.x_dim, self.h_dim, self.n_layers)

        # Input: [GRU output after frames 0..t, frame t+1] concatenated on the feature axis.
        self.discrim = nn.Sequential(
            nn.Linear(self.h_dim + self.x_dim, self.h_dim),
            nn.ReLU(),
            nn.Linear(self.h_dim, 1),
            nn.Sigmoid())

    def forward(self, data):
        """Score each transition ``data[t] -> data[t+1]``.

        Args:
            data: tensor of shape (T, batch_size, x_dim) with T >= 2.

        Returns:
            Tensor of shape (T - 1, batch_size, 1) of sigmoid probabilities.
        """
        assert data.size(-1) == self.x_dim
        assert data.size(0) >= 2, 'need at least two time steps to form a transition'

        # new_zeros inherits data's device *and* dtype, so the model still works
        # after .double()/.half() (torch.zeros would always be float32).
        h = data.new_zeros(self.n_layers, data.size(1), self.h_dim)

        prefix = data[:-1]  # frames 0..T-2, fed through the GRU
        following = data[1:]  # the frame that follows each prefix
        states, _ = self.gru(prefix, h)  # (T-1, batch_size, h_dim)
        prob = self.discrim(torch.cat([states, following], 2))
        return prob


class Generator(nn.Module):
    """Free-running GRU generator whose initial hidden state is decoded from noise.

    ``params`` must provide ``'x_dim'``, ``'z_dim'``, ``'h_dim'`` and ``'n_layers'``.
    Sequences are time-major: (T, batch_size, x_dim).
    """

    def __init__(self, params):
        super().__init__()

        self.params = params
        self.x_dim = params['x_dim']
        self.z_dim = params['z_dim']
        self.h_dim = params['h_dim']
        self.n_layers = params['n_layers']

        # Maps z to one h_dim-sized initial state per GRU layer.
        self.dec_z = nn.Linear(self.z_dim, self.h_dim*self.n_layers)

        self.gru = nn.GRU(self.x_dim, self.h_dim, self.n_layers)

        # Projects the top-layer hidden state to an output frame.
        self.dec_output = nn.Linear(self.h_dim, self.x_dim)

    def _init_hidden(self, z):
        """Decode z of shape (batch_size, z_dim) into a GRU state (n_layers, batch_size, h_dim)."""
        batch_size = z.size(0)
        # dec_z's output is batch-major: row b holds [layer 0 | layer 1 | ...] for
        # sample b. Split per sample first, then move the layer axis to the front.
        # A direct .view(n_layers, batch_size, h_dim) would mix different samples'
        # z across layers whenever n_layers > 1. Weights trained with that layout
        # will therefore behave differently here.
        h = self.dec_z(z).view(batch_size, self.n_layers, self.h_dim)
        return h.transpose(0, 1).contiguous()

    def forward(self, data, z=None):
        """Return the free-running reconstruction loss of ``data`` given noise ``z``.

        The GRU is unrolled for T steps, feeding back its own outputs (not
        ``data``), and the per-step MSE against ``data[t]`` is summed over time.

        Args:
            data: tensor of shape (T, batch_size, x_dim).
            z: optional noise of shape (batch_size, z_dim). When omitted, it is
                sampled from a standard normal on ``data``'s device and dtype.

        Returns:
            Sum over time steps of the element-averaged MSE (the float 0.0 when T == 0).
        """
        assert data.size(-1) == self.x_dim

        T = data.size(0)
        batch_size = data.size(1)

        # Sample z if not provided
        if z is None:
            z = torch.randn(batch_size, self.z_dim, device=data.device, dtype=data.dtype)
        assert z.size(-1) == self.z_dim
        assert z.size(0) == batch_size, 'z and data must share the batch dimension'

        h = self._init_hidden(z)

        loss = 0.0
        for t in range(T):
            out = self.dec_output(h[-1])  # (batch_size, x_dim) from the top layer
            # Default reduction='mean' averages over batch *and* feature elements.
            loss += F.mse_loss(out, data[t])
            # Free-running: the model's own output is the next GRU input.
            _, h = self.gru(out.unsqueeze(0), h)

        return loss

    def generate(self, batch_size, T, device='cpu'):
        """Sample ``batch_size`` sequences of length ``T`` from standard-normal noise.

        Returns:
            Tensor of shape (T, batch_size, x_dim).
        """
        # Match the parameter dtype so sampling also works after .double()/.half().
        z = torch.randn(batch_size, self.z_dim, device=device, dtype=self.dec_z.weight.dtype)
        return self.decode(z, T)

    def decode(self, z, T):
        """Unroll the generator for ``T`` steps from noise ``z`` of shape (batch_size, z_dim).

        Returns:
            Tensor of shape (T, batch_size, x_dim).
        """
        assert z.size(-1) == self.z_dim
        assert T > 0

        h = self._init_hidden(z)

        samples = []
        for _ in range(T):
            out = self.dec_output(h[-1]).unsqueeze(0)  # (1, batch_size, x_dim)
            samples.append(out)
            _, h = self.gru(out, h)

        return torch.cat(samples, 0)

