import torch
from torch import log as tlog

class MarkovZNoFilter:
    """Run K parallel latent Markov chains through ``P`` with no accept/reject filtering.

    ``P(z)`` is called to produce an ``(x, z_next)`` pair for every chain, and
    ``P.G`` is used to render the initial latents (reshaped to ``(K, nz, 1, 1)``)
    when the initial state is recorded. Results are stored on the CPU in
    ``self.x_chain`` (shape ``(K, T, *img_dims)``) and ``self.z_chain``
    (shape ``(K, T, nz)``).
    """

    def __init__(self, P, device, img_dims=(3, 64, 64)):
        self.device = device
        # Put the transition model in eval mode on the target device once, up front.
        self.P = P.eval().to(device)
        self.img_dims = img_dims
        self.x_chain = None
        self.z_chain = None

    def sample_chains(self, T, K, init, append_init=False):
        """Sample ``T`` stored steps for ``K`` chains starting from latents ``init``.

        Args:
            T: number of time steps stored per chain.
            K: number of chains; assumed to equal ``init.size(0)`` (TODO confirm).
            init: initial latents, indexed as ``(K, nz)``; it is not modified.
            append_init: if True, slot 0 holds the initial state (its image is
                rendered with ``P.G``), and only ``T - 1`` transitions are sampled.

        Returns:
            0. The sampled chains are written to ``self.x_chain`` / ``self.z_chain``.
        """
        # Keep the running state on the model's device so that every P call,
        # not only the optional P.G call on the initial state, receives inputs
        # on the device the model lives on.
        z_s = init.clone().to(self.device)
        self.x_chain = torch.zeros((K, T, *self.img_dims))
        self.z_chain = torch.zeros((K, T, z_s.size(1)))

        # Pure sampling: no autograd graph is needed for any model call.
        with torch.no_grad():
            if append_init:
                x_0 = self.P.G(z_s.view(*z_s.size(), 1, 1))
                self.x_chain[:, 0, :, :, :] = x_0.data.cpu()
                self.z_chain[:, 0, :] = z_s.data.cpu()

            for i in range(int(append_init), T):
                x_p, z_p = self.P(z_s)
                self.x_chain[:, i, :, :, :] = x_p.data.cpu()
                self.z_chain[:, i, :] = z_p.data.cpu()
                z_s = z_p
        return 0


class MmhFilter:
    """Filter proposals from ``P`` with an accept/reject test scored by ``D``.

    ``P(z)`` returns a ``(x_proposal, z_proposal)`` pair for every chain.
    ``D.log_test(x_s, x_p)`` returns a per-chain score that is clamped at 0 and
    compared against ``log(u)`` with ``u ~ U[0, 1)``. A proposal is therefore
    accepted with probability ``min(1, exp(score))``, which is the
    Metropolis-Hastings acceptance form.

    Attributes:
        oa: list of CPU tensors of accepted proposals, one entry per step with
            at least one acceptance. It accumulates across calls and is never
            reset here.
        rejections: list of CPU tensors of rejected proposals, one entry per
            step of ``sample_keep_states``.
        x_chain, z_chain: per-step chain states recorded by
            ``sample_keep_states`` (``None`` until it runs).
    """

    def __init__(self, D, P, device, img_dims=(3, 64, 64)):
        self.device = device
        self.D = D.eval().to(device)
        self.P = P.eval().to(device)
        self.img_dims = img_dims
        self.oa = []
        # sample_keep_states appends to this list, so it must exist before the first call.
        self.rejections = []
        self.x_chain = None
        self.z_chain = None

    def log_prob(self, x_s, x_p):
        """Return the per-chain log acceptance probability as a 1-D CPU tensor.

        The discriminator score is clamped at 0 so the implied probability never
        exceeds 1.
        """
        with torch.no_grad():
            log_test = self.D.log_test(x_s, x_p).cpu()
            log_ar = torch.clamp(log_test, max=0.0).view(log_test.size(0), )
        return log_ar

    def multiple_chain_step(self, z_s, x_s):
        """Propose one step for every chain and decide which proposals are accepted.

        Returns:
            ``(x_p, z_p, flag_acc)``: the proposals as returned by ``P``, and a
            1-D CPU bool tensor marking the accepted chains.
        """
        K = x_s.size(0)
        with torch.no_grad():
            # Move the latent state to the model's device, as is done for x below.
            x_p, z_p = self.P(z_s.to(self.device))
            log_prob = self.log_prob(x_s.to(self.device), x_p.to(self.device))
            u = torch.rand(K)
            flag_acc = log_prob > tlog(u)
        return x_p, z_p, flag_acc

    def _accept(self, x_s, z_s, x_p, z_p, flag_acc):
        """Record accepted proposals in ``self.oa``, copy them into the states in place, and return the count."""
        x_p_cpu = x_p.cpu()
        n_accept = torch.sum(flag_acc).item()
        if n_accept > 0:
            self.oa.append(x_p_cpu.data[flag_acc])
        x_s[flag_acc] = x_p_cpu[flag_acc]
        z_s[flag_acc] = z_p.cpu()[flag_acc]
        return n_accept

    def sample_chains(self, N, T, z_init, x_init):
        """Step all chains until ``N`` proposals are accepted in total.

        Accepted proposals are appended to ``self.oa``. Sampling gives up after
        ``T + 1`` proposal rounds (the check is ``t > T`` after incrementing).
        ``z_init`` / ``x_init`` are cloned and not modified.

        Returns:
            0.
        """
        count_accept = 0
        t = 0

        z_s = z_init.clone()
        x_s = x_init.clone()

        with torch.no_grad():
            while count_accept < N:
                x_p, z_p, flag_acc = self.multiple_chain_step(z_s, x_s)
                count_accept += self._accept(x_s, z_s, x_p, z_p, flag_acc)
                t += 1

                if t > T:
                    print('Get%d of %d, max T reached' % (count_accept, N))
                    return 0
        print('----------')
        print(count_accept)
        return 0

    def sample_keep_states(self, T, K, z_init, x_init):
        """Run exactly ``T`` proposal rounds for ``K`` chains, recording every state.

        After each round, the current (possibly unchanged) states are stored in
        ``self.x_chain[:, t]`` / ``self.z_chain[:, t]``. Rejected proposals are
        appended to ``self.rejections``, and accepted ones to ``self.oa``.

        Returns:
            0.
        """
        count_accept = 0

        z_s = z_init.clone()
        x_s = x_init.clone()

        self.x_chain = torch.zeros((K, T, *self.img_dims))
        self.z_chain = torch.zeros((K, T, z_s.size(1)))

        with torch.no_grad():
            for t in range(T):
                x_p, z_p, flag_acc = self.multiple_chain_step(z_s, x_s)
                self.rejections.append(x_p.cpu().data[~flag_acc])
                count_accept += self._accept(x_s, z_s, x_p, z_p, flag_acc)

                self.x_chain[:, t, :, :, :] = x_s.data.clone()
                self.z_chain[:, t, :] = z_s.data.clone()
        print('----------')
        print(count_accept)
        return 0
