import torch


class Buffer:
    """Bounded store of samples paired with per-sample log-weights.

    The buffer keeps at most ``max_size`` entries. When it overflows, the
    oldest entries (those at the front along dim 0) are dropped so that the
    most recently inserted ones survive.

    Attributes are plain tensors and are expected to be aligned along dim 0
    (``samples[i]`` goes with ``log_p[i]``); this alignment is assumed, not
    checked.
    """

    def __init__(self, samples, log_p, max_size):
        """Create a buffer from initial ``samples`` and their ``log_p``.

        Args:
            samples: tensor of initial samples, indexed along dim 0.
            log_p: tensor of unnormalized log-weights, one per sample.
            max_size: positive capacity of the buffer.
        """
        assert max_size > 0, "max_size must be positive"
        self.samples = samples
        self.log_p = log_p
        self.max_size = max_size
        # Enforce the capacity from the start, not only after the first
        # insert, so len(buffer) <= max_size always holds.
        self._truncate()

    def _truncate(self):
        """Drop the oldest entries so that at most ``max_size`` remain."""
        if len(self.samples) > self.max_size:
            self.samples = self.samples[-self.max_size:]
            self.log_p = self.log_p[-self.max_size:]

    def insert(self, samples, log_p):
        """Append new ``samples`` / ``log_p`` and evict the oldest overflow."""
        self.samples = torch.cat([self.samples, samples], dim=0)
        self.log_p = torch.cat([self.log_p, log_p], dim=0)
        self._truncate()

    def sample(self, n):
        """Draw ``n`` stored samples at random, weighted by ``log_p``.

        Indices are drawn independently (so repeats are possible) from a
        categorical distribution whose logits are ``self.log_p``.
        """
        # Use the public class path; the original reached Categorical through
        # the ``multinomial`` submodule, which only works by import accident.
        dist = torch.distributions.Categorical(logits=self.log_p)
        ids = dist.sample((n,))
        return self.samples[ids]

    def __len__(self):
        """Return the number of samples currently stored."""
        return len(self.samples)


def independent_MH(target, proposal, warm_up, num_samples, init_sample):
    """Run an independence Metropolis-Hastings chain.

    Candidates are drawn from ``proposal`` in batches of ``2 * num_samples``
    (a fresh batch is drawn whenever the current one is exhausted). A
    candidate ``x'`` replaces the current state ``x`` when

        log target(x') - log target(x) + log proposal(x) - log proposal(x')

    exceeds ``log(u)`` for ``u ~ Uniform[0, 1)``.

    Args:
        target: object exposing ``log_prob(sample)``.
        proposal: object exposing ``sample(n)`` (indexable result) and
            ``log_prob(sample)``.
        warm_up: number of leading chain states to discard.
        num_samples: total chain length, counting ``init_sample`` as the
            first state. Note that ``num_samples - warm_up`` states are
            returned, not ``num_samples``.
        init_sample: starting state of the chain.

    Returns:
        The retained chain states concatenated with ``torch.cat`` (dim 0)
        and detached from the autograd graph.
    """
    proposal_samples = proposal.sample(2 * num_samples)
    prev_sample = init_sample
    samples = [init_sample]
    i = 0
    while len(samples) < num_samples:
        next_sample = proposal_samples[i]
        # Built as a new value (same grouping as before) so the tensor
        # returned by target.log_prob is never modified in place.
        log_ratio = (
            (target.log_prob(next_sample) - target.log_prob(prev_sample))
            + (proposal.log_prob(prev_sample) - proposal.log_prob(next_sample))
        )
        # Draw on CPU (same RNG stream as before) and follow the log-ratio's
        # device instead of hard-coding CUDA, so CPU-only runs work too.
        u = torch.rand(1)
        if torch.is_tensor(log_ratio):
            u = u.to(log_ratio.device)
        if log_ratio > torch.log(u):
            prev_sample = next_sample
        samples.append(prev_sample)
        i += 1
        if i == len(proposal_samples):
            # Current batch of candidates is used up; draw a new one.
            i = 0
            proposal_samples = proposal.sample(2 * num_samples)
    return torch.cat(samples[warm_up:]).detach()


def disc_MH(generator, discriminator, num_samples, init_samples, noise_dim):
    """Run a batch of discriminator-driven Metropolis-Hastings chains.

    Each round draws Gaussian noise, maps it through ``generator`` to get one
    proposal per chain, and accepts proposal ``i`` when
    ``discriminator.acceptance_ratio(next, prev)[i]`` exceeds a fresh
    ``Uniform[0, 1)`` draw. Rounds repeat until the total number of accepted
    proposals, summed over all chains, reaches ``num_samples``.

    Args:
        generator: callable mapping a ``(batch, noise_dim)`` noise tensor to
            a batch of proposals.
        discriminator: object exposing ``acceptance_ratio(next, prev)``;
            assumed to return one value per chain -- TODO confirm shape
            against the discriminator implementation.
        num_samples: stop once this many proposals were accepted in total.
        init_samples: tensor with the starting state of every chain along
            dim 0. It is not modified.
        noise_dim: dimensionality of the generator's noise input.

    Returns:
        ``(samples, accepts)``: for chain ``i``, ``samples[i]`` is the list of
        distinct states visited (CPU tensors, starting with the initial
        state) and ``accepts[i][k]`` is the number of rounds the chain spent
        in ``samples[i][k]``.
    """
    batch_size = len(init_samples)
    with torch.no_grad():
        # Clone so the recorded initial states never alias the caller's
        # tensor (``.cpu()`` is a no-op view for CPU tensors).
        samples = [[init_samples[i].cpu().clone()] for i in range(batch_size)]
        accepts = [[1] for _ in range(batch_size)]
        num_accepted = 0
        # Work on a copy: the masked assignment below is in place and would
        # otherwise overwrite the caller's ``init_samples``.
        prev_samples = init_samples.clone()
        while num_accepted < num_samples:
            z = torch.randn([batch_size, noise_dim]).to(init_samples.device)
            next_samples = generator(z)
            p = discriminator.acceptance_ratio(next_samples, prev_samples)
            u = torch.rand_like(p)
            accepted = p > u
            # One host transfer per round instead of one per chain.
            flags = accepted.reshape(batch_size).tolist()
            next_cpu = next_samples.cpu()
            for i, was_accepted in enumerate(flags):
                if was_accepted:
                    samples[i].append(next_cpu[i].clone())
                    accepts[i].append(1)
                else:
                    # Rejected: the chain stays on its current state.
                    accepts[i][-1] += 1
            num_accepted += sum(flags)
            prev_samples[accepted] = next_samples[accepted]
    return samples, accepts

