import torch
from torch import log as tlog
import numpy as np


class ZSampler:
    """Draw standard-normal latent codes shaped ``(K, dim, 1, 1)``."""

    def __init__(self, dim, device):
        # Latent dimensionality and the device new samples are allocated on.
        self.dim, self.device = dim, device

    def sample(self, K):
        """Return ``K`` i.i.d. N(0, 1) latent vectors of shape ``(K, dim, 1, 1)``."""
        shape = (K, self.dim, 1, 1)
        return torch.randn(*shape, device=self.device)


class NoFilter:
    """Baseline sampler: every generator draw is kept (no accept/reject step)."""

    def __init__(self, device, G, sampler, img_dims=(3, 64, 64)):
        self.device = device
        # The generator is only moved to ``device``; its train/eval mode is left as-is.
        self.G = G.to(device)
        self.sampler = sampler
        self.img_dims = img_dims  # per-sample shape of the rows stored in ``chain``
        self.chain = None  # populated by ``sample_chain``

    def transition_step(self, K):
        """Generate ``K`` images from fresh latent draws and return them on CPU."""
        with torch.no_grad():
            latents = self.sampler.sample(K).to(self.device)
            images = self.G(latents.detach())
            return images.cpu().data

    def sample_chain(self, N, K):
        """Fill ``self.chain`` with ``K * (N // K)`` generated samples, ``K`` at a time.

        Any remainder ``N % K`` is dropped. Returns 0.
        """
        total = (N // K) * K
        chain = torch.empty((total, *self.img_dims))
        for start in range(0, total, K):
            chain[start:start + K] = self.transition_step(K)
        self.chain = chain
        return 0


class ImhFilter:
    """Independent Metropolis-Hastings style filter over generator samples.

    Proposals ``x_p = G(z)`` are drawn with ``z`` from ``sampler``, independently
    of the current state ``x_s``. Each chain accepts its proposal when
    ``log u < min(0, log D.ar(x_s, x_p))`` with ``u ~ U[0, 1)``.

    NOTE(review): ``D.ar`` is assumed to return a non-negative acceptance ratio
    with one element per chain. Confirm this against the discriminator.
    """

    def __init__(self, device, D, G, sampler, img_dims=(3, 64, 64)):
        self.device = device
        self.D = D.eval().to(device)
        # G is only moved to ``device``; unlike D it is not put in eval mode.
        # NOTE(review): confirm that leaving G's train/eval mode untouched is intended.
        self.G = G.to(device)
        self.sampler = sampler
        self.img_dims = img_dims  # per-sample shape of rows stored by sample_chains
        self.chain = None  # filled by the sample_chains* methods
        self.state = None  # current state of every chain, one row per chain

    def log_prob(self, x_s, x_p):
        """Return the per-chain log acceptance probability as a 1-D CPU tensor.

        Computes ``log D.ar(x_s, x_p)``, clamps it to at most 0, and flattens it
        to length ``D.ar(...).size(0)``.
        """
        with torch.no_grad():
            log_prob = torch.log(self.D.ar(x_s, x_p)).cpu()
            log_prob = torch.clamp(log_prob, max=0.0).view(log_prob.size(0),)
        return log_prob

    def multiple_chain_step(self, x_s):
        """Propose one independent sample per chain and decide acceptance.

        Returns ``(x_p, flag_acc)``. ``x_p`` holds the proposals as produced by
        ``G`` on ``self.device``. ``flag_acc`` is a CPU boolean tensor marking
        which chains accept their proposal.
        """
        with torch.no_grad():
            K = x_s.size(0)
            z = self.sampler.sample(K).to(self.device)
            x_p = self.G(z.detach())
            log_ar = self.log_prob(x_s, x_p)
            u = torch.rand(K)
            flag_acc = (log_ar > tlog(u))
        return x_p, flag_acc

    def _accept(self, x_p, flag):
        """Write the proposals selected by ``flag`` into ``self.state`` in place.

        Returns the accepted proposals as a new CPU tensor.
        """
        # Move to CPU first so the CPU boolean mask indexes a same-device tensor.
        accepted = x_p.detach().cpu()[flag]
        # Assign through the un-masked tensor. The form
        # ``self.state[flag].data = ...`` would only rebind ``.data`` on the
        # temporary copy that boolean-mask indexing returns.
        self.state.data[flag] = accepted.to(self.state.device)
        return accepted

    def sample_chains(self, N, init):
        """Run ``K = init.size(0)`` parallel chains for ``N // K`` steps.

        After each step, the full state of all chains is written into
        ``self.chain``, a tensor with ``K * (N // K)`` rows. Chains that reject
        their proposal repeat their previous state. ``init`` itself is not
        modified. Returns 0.
        """
        with torch.no_grad():
            self.state = init.clone()
        K = init.size(0)
        n_batches = N // K
        chain = torch.empty((K * n_batches, *self.img_dims))

        for i in range(n_batches):
            x_s = self.state.to(self.device)
            x_p, flag = self.multiple_chain_step(x_s)
            self._accept(x_p, flag)
            chain[i * K:(i + 1) * K] = self.state.data
        self.chain = chain
        return 0

    def sample_chains_only_accepts(self, N, init):
        """Collect accepted proposals until at least ``N`` have been gathered.

        ``self.chain`` becomes a list of CPU tensors, one per step that accepted
        anything. The total row count may exceed ``N`` because the last step's
        accepts are kept in full. The loop does not terminate if no proposal is
        ever accepted. Returns 0.
        """
        self.chain = []
        with torch.no_grad():
            self.state = init.clone()
        i = 0
        while i < N:
            x_s = self.state.to(self.device)
            x_p, flag_acc = self.multiple_chain_step(x_s)
            n_accept = torch.sum(flag_acc).item()
            if n_accept > 0:
                self.chain.append(self._accept(x_p, flag_acc))
                i += n_accept
        return 0

    def sample_chains_reinit(self, N, K, init):
        """Like ``sample_chains_only_accepts``, but with periodic restarts.

        The chains are reset to ``init`` whenever more than ``K`` proposals have
        been accepted since the last reset. Returns 0.
        """
        self.chain = []
        with torch.no_grad():
            self.state = init.clone()
        i = 0  # total accepted proposals collected so far
        k = 0  # accepted proposals since the last reset
        while i < N:
            x_s = self.state.to(self.device)
            x_p, flag_acc = self.multiple_chain_step(x_s)
            n_accept = torch.sum(flag_acc).item()
            if n_accept > 0:
                self.chain.append(self._accept(x_p, flag_acc))
                i += n_accept
                k += n_accept
            if k > K:
                k = 0
                with torch.no_grad():
                    self.state = init.clone()
        return 0
