from abc import ABC

from contextlib import ExitStack
from functools import partial
from src import poutine
from src.contrib.util import iter_plates_to_shape, lexpand, rexpand, rmv
from src.util import is_bad
# from torchdiffeq import odeint

import src
import src.distributions as dist
import torch

# Small positive tolerance (2**-22). ExperimentModel.__init__ wraps it in a
# tensor as `self.epsilon`; CESModel uses that tensor to build the two limit
# arguments of its censored emission distribution.
EPS = 2**-22


class ExperimentModel(ABC):
    """
    Shared interface for the probabilistic experiment models.

    The methods below read ``var_names``, ``var_dim``, ``obs_label`` and
    ``n_parallel`` from the instance, so subclasses have to set those
    attributes and override :meth:`make_model` and :meth:`reset`.
    """

    def __init__(self):
        # Tensor copy of the module-level tolerance constant.
        self.epsilon = torch.tensor(EPS)

    def sanity_check(self):
        """Assert that ``var_dim`` is positive and ``var_names`` non-empty."""
        assert self.var_dim > 0
        assert len(self.var_names) > 0

    def make_model(self):
        """Return a callable ``model(design)``; subclasses must override."""
        raise NotImplementedError

    def reset(self, n_parallel):
        """Re-initialise the model state; subclasses must override."""
        raise NotImplementedError

    def run_experiment(self, design, theta):
        """
        Run the model at ``design`` with its sample sites conditioned on
        ``theta`` and return a detached copy of the model's output.
        """
        # Fix the latent sample sites to the supplied values.
        conditioned = src.condition(self.make_model(), data=theta)

        # Evaluate the conditioned model on the design and cut the result
        # loose from the autograd graph.
        outcome = conditioned(design)
        return outcome.detach().clone()

    def get_likelihoods(self, y, design, thetas):
        """
        Return the ``log_prob`` recorded at the observation site when the
        model is conditioned on ``thetas`` and on the outcome ``y``.

        ``y`` and ``design`` are both passed through ``lexpand`` with the
        leading size of the first latent variable in ``thetas``.
        """
        n_theta = thetas[self.var_names[0]].shape[0]

        # Condition on the latents and on the expanded observation.
        data = dict(thetas)
        data[self.obs_label] = lexpand(y, n_theta)
        conditioned = src.condition(self.make_model(), data=data)

        trace = poutine.trace(conditioned).get_trace(lexpand(design, n_theta))
        trace.compute_log_prob()
        return trace.nodes[self.obs_label]["log_prob"]

    def sample_theta(self, num_theta):
        """
        Trace the model once on an all-zero design and return a dict mapping
        every name in ``var_names`` to the value sampled at that site.
        """
        design_shape = (num_theta, self.n_parallel, 1, 1, self.var_dim)
        placeholder = torch.zeros(design_shape)
        trace = poutine.trace(self.make_model()).get_trace(placeholder)
        return {name: trace.nodes[name]["value"] for name in self.var_names}


class CESModel(ExperimentModel):
    """
    Experiment model with three latent sites (``rho``, ``alpha``, ``u``) and
    a censored sigmoid-normal observation site.

    The prior parameters are stored twice: the ``init_*_model`` tensors keep
    the initial values, while ``rho_con_model``, ``alpha_con_model``,
    ``u_mu_model`` and ``u_sig_model`` are detached clones of them and are
    the tensors :meth:`make_model` actually reads.
    """

    def __init__(self, init_rho_model=None, init_alpha_model=None,
                 init_mu_model=None, init_sig_model=None, n_parallel=1,
                 obs_sd=0.005, obs_label="y", n_elbo_samples=100,
                 n_elbo_steps=100, elbo_lr=0.04):
        """
        Args:
            init_rho_model: Dirichlet concentration for ``rho``; defaults to
                ``torch.ones(n_parallel, 1, 2)``.
            init_alpha_model: Dirichlet concentration for ``alpha``; defaults
                to ``torch.ones(n_parallel, 1, 3)``.
            init_mu_model: LogNormal location for ``u``; defaults to
                ``torch.ones(n_parallel, 1)``.
            init_sig_model: LogNormal scale for ``u``; defaults to
                ``3. * torch.ones(n_parallel, 1)``.
            n_parallel: leading size used for the default prior tensors.
            obs_sd: factor in the observation standard deviation.
            obs_label: name of the observation sample site.
            n_elbo_samples, n_elbo_steps, elbo_lr: stored on the instance;
                not read anywhere inside this class.
        """
        super().__init__()
        self.n_parallel, self.elbo_lr = n_parallel, elbo_lr
        self.n_elbo_samples, self.n_elbo_steps = n_elbo_samples, n_elbo_steps
        self.obs_sd = obs_sd
        self.obs_label = obs_label
        self._set_prior_params(init_rho_model, init_alpha_model,
                               init_mu_model, init_sig_model)
        self.param_names = [
            "rho_con",
            "alpha_con",
            "u_mu",
            "u_sig",
        ]
        self.var_names = ["rho", "alpha", "u"]
        self.var_dim = 6
        self.sanity_check()

    def _set_prior_params(self, init_rho_model, init_alpha_model,
                          init_mu_model, init_sig_model):
        """
        Store the initial prior tensors (defaults sized by ``self.n_parallel``
        where ``None`` is given) and refresh the working clones.

        The checks use ``is not None`` on purpose: testing a tensor for
        truthiness raises for multi-element tensors and would treat an
        all-zero single-element tensor as "not provided".
        """
        n = self.n_parallel
        self.init_rho_model = init_rho_model \
            if init_rho_model is not None else torch.ones(n, 1, 2)
        self.init_alpha_model = init_alpha_model \
            if init_alpha_model is not None else torch.ones(n, 1, 3)
        self.init_mu_model = init_mu_model \
            if init_mu_model is not None else torch.ones(n, 1)
        self.init_sig_model = init_sig_model \
            if init_sig_model is not None else 3. * torch.ones(n, 1)
        self.rho_con_model = self.init_rho_model.detach().clone()
        self.alpha_con_model = self.init_alpha_model.detach().clone()
        self.u_mu_model = self.init_mu_model.detach().clone()
        self.u_sig_model = self.init_sig_model.detach().clone()

    def reset(self, init_rho_model=None, init_alpha_model=None,
              init_mu_model=None, init_sig_model=None, n_parallel=None):
        """
        Re-initialise the prior parameters for a new ``n_parallel``.

        Nothing happens when ``n_parallel`` is ``None`` (any ``init_*``
        arguments are ignored in that case). Note that ``n_parallel`` is the
        last parameter here, unlike ``ExperimentModel.reset``, so it should
        be passed by keyword.
        """
        if n_parallel is None:
            return
        self.n_parallel = n_parallel
        self._set_prior_params(init_rho_model, init_alpha_model,
                               init_mu_model, init_sig_model)

    def make_model(self):
        """
        Build and return ``model(design)``.

        The returned callable raises ``ArithmeticError`` when ``is_bad``
        flags the design, samples the ``rho``, ``alpha`` and ``u`` sites and
        returns the value sampled at the ``obs_label`` site.
        """
        def model(design):
            if is_bad(design):
                raise ArithmeticError("bad design, contains nan or inf")
            # Everything except the last two axes of the design is treated
            # as batch shape and wrapped in plates.
            batch_shape = design.shape[:-2]
            with ExitStack() as stack:
                for plate in iter_plates_to_shape(batch_shape):
                    stack.enter_context(plate)

                # First Dirichlet component, mapped affinely from (0, 1)
                # into (0.01, 1.0).
                rho_shape = batch_shape + (self.rho_con_model.shape[-1],)
                rho = 0.01 + 0.99 * src.sample(
                    "rho",
                    dist.Dirichlet(self.rho_con_model.expand(rho_shape))
                ).select(-1, 0)
                alpha_shape = batch_shape + (self.alpha_con_model.shape[-1],)
                alpha = src.sample(
                    "alpha",
                    dist.Dirichlet(self.alpha_con_model.expand(alpha_shape))
                )
                u = src.sample(
                    "u",
                    dist.LogNormal(
                        self.u_mu_model.expand(batch_shape),
                        self.u_sig_model.expand(batch_shape)
                    )
                )
                rho = rexpand(rho, design.shape[-2])
                u = rexpand(u, design.shape[-2])

                # The last design axis holds two 3-vectors, compared below.
                d1, d2 = design[..., 0:3], design[..., 3:6]
                # NOTE(review): `rmv` is presumably a batched matrix-vector
                # product over the last axis - confirm in src.contrib.util.
                u1rho = (rmv(d1.pow(rho.unsqueeze(-1)), alpha)).pow(1. / rho)
                u2rho = (rmv(d2.pow(rho.unsqueeze(-1)), alpha)).pow(1. / rho)
                mean = u * (u1rho - u2rho)
                # Observation noise grows with the Euclidean distance
                # between the two 3-vectors.
                sd = u * self.obs_sd * (
                        1 + torch.norm(d1 - d2, dim=-1, p=2))

                emission_dist = dist.CensoredSigmoidNormal(
                    mean, sd, 1 - self.epsilon, self.epsilon
                ).to_event(1)
                y = src.sample(self.obs_label, emission_dist)
                return y

        return model

    def get_params(self):
        """
        Return the four working prior tensors, each reshaped to
        ``(n_parallel, -1)`` and concatenated along the last axis.
        """
        return torch.cat(
            [
                self.rho_con_model.reshape(self.n_parallel, -1),
                self.alpha_con_model.reshape(self.n_parallel, -1),
                self.u_mu_model.reshape(self.n_parallel, -1),
                self.u_sig_model.reshape(self.n_parallel, -1),
            ],
            dim=-1
        )


def holling2(a, th, t, n):
    """
    Return ``-(a * n) / (1 + a * n * th)``.

    ``t`` is accepted but never read; it keeps the signature at ``(t, n)``
    once ``a`` and ``th`` have been bound.
    """
    rate = a * n
    denominator = 1 + rate * th
    return -rate / denominator


def holling3(a, th, t, n):
    """
    Return ``-(a * n * n) / (1 + a * n * n * th)``.

    ``t`` is accepted but never read; it keeps the signature at ``(t, n)``
    once ``a`` and ``th`` have been bound (as PreyModel does with
    ``functools.partial``).
    """
    rate = a * n * n
    denominator = 1 + rate * th
    return -rate / denominator


class PreyModel(ExperimentModel):
    """
    Experiment model with two log-normal latent sites (``a`` and ``th``)
    whose observation is a Binomial count.

    The design is used as the initial state of the ``holling3`` ODE, which
    is integrated from time 0 to ``tau``; the relative decrease of the state
    is the Binomial success probability and the design is the total count.
    """

    def __init__(self, a_mu=None, a_sig=None, th_mu=None, th_sig=None, tau=24.,
                 n_parallel=1, obs_sd=0.005, obs_label="y"):
        """
        Args:
            a_mu, a_sig: LogNormal location/scale for ``a``; default to
                ``torch.ones(n_parallel, 1, 1)`` times -1.4 and 1.35.
            th_mu, th_sig: LogNormal location/scale for ``th``; same
                defaults as for ``a``.
            tau: end time of the ODE integration.
            n_parallel: leading size used for the default prior tensors.
            obs_sd: stored on the instance; not read inside this class.
            obs_label: name of the observation sample site.
        """
        super().__init__()
        self.a_mu = a_mu if a_mu is not None \
            else torch.ones(n_parallel, 1, 1) * -1.4
        self.a_sig = a_sig if a_sig is not None \
            else torch.ones(n_parallel, 1, 1) * 1.35
        self.th_mu = th_mu if th_mu is not None \
            else torch.ones(n_parallel, 1, 1) * -1.4
        self.th_sig = th_sig if th_sig is not None \
            else torch.ones(n_parallel, 1, 1) * 1.35
        self.tau = tau
        self.n_parallel = n_parallel
        self.obs_sd = obs_sd
        self.obs_label = obs_label
        self.var_names = ["a", "th"]
        self.var_dim = 1
        self.sanity_check()

    @staticmethod
    def _rk4_final_state(func, y0, t_end, step_size=1.):
        """
        Integrate ``dy/dt = func(t, y)`` from ``t = 0`` to ``t_end`` with a
        fixed-step fourth-order Runge-Kutta scheme (3/8 rule) and return the
        state at ``t_end``.

        Steps have length ``step_size`` except for a shorter last step when
        ``t_end`` is not a multiple of it. ``y0`` is returned unchanged when
        ``t_end <= 0``.

        This replaces the former ``odeint(..., method="rk4",
        options={'step_size': 1.})`` call, whose import is commented out in
        this module (the call raised ``NameError``).
        NOTE(review): written to mirror torchdiffeq's fixed-grid ``rk4``
        solver; confirm numerical agreement if torchdiffeq is available.

        Raises:
            ValueError: if ``step_size`` is not positive.
        """
        if step_size <= 0:
            raise ValueError("step_size must be positive")
        state = y0
        t0 = 0.
        while t0 < t_end:
            t1 = min(t0 + step_size, t_end)
            dt = t1 - t0
            k1 = func(t0, state)
            k2 = func(t0 + dt / 3., state + dt * k1 / 3.)
            k3 = func(t0 + 2. * dt / 3., state + dt * (k2 - k1 / 3.))
            k4 = func(t1, state + dt * (k1 - k2 + k3))
            state = state + (k1 + 3. * (k2 + k3) + k4) * dt * 0.125
            t0 = t1
        return state

    def make_model(self):
        """
        Build and return ``model(design)``.

        The returned callable raises ``ArithmeticError`` when ``is_bad``
        flags the design, samples the ``a`` and ``th`` sites and returns the
        value sampled at the ``obs_label`` site.
        """
        def model(design):
            if is_bad(design):
                raise ArithmeticError("bad design, contains nan or inf")
            design = design.float()
            # Everything except the last two axes of the design is treated
            # as batch shape and wrapped in plates.
            batch_shape = design.shape[:-2]
            with ExitStack() as stack:
                for plate in iter_plates_to_shape(batch_shape):
                    stack.enter_context(plate)
                a_shape = batch_shape + self.a_mu.shape[-1:]
                a = src.sample(
                    "a",
                    dist.LogNormal(
                        self.a_mu.expand(a_shape),
                        self.a_sig.expand(a_shape)
                    ).to_event(1)
                )
                # Repeat the latent along the design's second-to-last axis.
                a = a.expand(a.shape[:-1] + design.shape[-2:-1])
                th_shape = batch_shape + self.th_mu.shape[-1:]
                th = src.sample(
                    "th",
                    dist.LogNormal(
                        self.th_mu.expand(th_shape),
                        self.th_sig.expand(th_shape)
                    ).to_event(1)
                )
                th = th.expand(th.shape[:-1] + design.shape[-2:-1])

                # Bind the latents so the derivative has signature (t, n).
                diff_func = partial(
                    holling3,
                    a.flatten(),
                    th.flatten())
                # State at time tau, starting from the design at time 0.
                n_t = self._rk4_final_state(
                    diff_func,
                    design.flatten(),
                    float(self.tau),
                    step_size=1.).reshape(design.shape)
                # Relative decrease of the state over [0, tau].
                p_t = (design - n_t) / design
                emission_dist = dist.Binomial(design.reshape(a.shape),
                                              p_t.reshape(a.shape)).to_event(1)
                n = src.sample(
                    self.obs_label, emission_dist
                )
                return n

        return model

    def reset(self, n_parallel):
        """Set ``n_parallel`` and restore the default prior tensors."""
        self.n_parallel = n_parallel
        self.a_mu = torch.ones(n_parallel, 1, 1) * -1.4
        self.a_sig = torch.ones(n_parallel, 1, 1) * 1.35
        self.th_mu = torch.ones(n_parallel, 1, 1) * -1.4
        self.th_sig = torch.ones(n_parallel, 1, 1) * 1.35


class SourceModel(ExperimentModel):
    """
    Experiment model with a single Normal latent site ``theta`` and a Normal
    observation whose mean is the log of ``b`` plus a sum of
    ``alpha / (m + squared distance between theta and the design)`` terms.
    """

    def __init__(self, d=2, k=2, theta_mu=None, theta_sig=None, alpha=None,
                 b=1e-1, m=1e-4, n_parallel=1, obs_sd=0.5, obs_label="y"):
        """
        Args:
            d, k: sizes of the last two axes of the default ``theta`` prior
                tensors, which have shape ``(n_parallel, 1, k, d)``.
            theta_mu, theta_sig: Normal location/scale for ``theta``;
                default to zeros and ones.
            alpha: numerator of each summed term; defaults to
                ``torch.ones(n_parallel, 1, k)``.
            b: constant added to the summed terms before the log.
            m: constant added to the squared distance in the denominator.
            n_parallel: leading size used for the default tensors.
            obs_sd: standard deviation of the observation.
            obs_label: name of the observation sample site.
        """
        super().__init__()
        if theta_mu is None:
            theta_mu = torch.zeros(n_parallel, 1, k, d)
        if theta_sig is None:
            theta_sig = torch.ones(n_parallel, 1, k, d)
        if alpha is None:
            alpha = torch.ones(n_parallel, 1, k)
        self.theta_mu = theta_mu
        self.theta_sig = theta_sig
        self.alpha = alpha
        self.d = d
        self.k = k
        self.b = b
        self.m = m
        self.obs_sd = obs_sd
        self.obs_label = obs_label
        self.n_parallel = n_parallel
        self.var_names = ["theta"]
        self.var_dim = d
        self.sanity_check()

    def make_model(self):
        """
        Build and return ``model(design)``.

        The returned callable raises ``ArithmeticError`` when ``is_bad``
        flags the design, samples the ``theta`` site and returns the value
        sampled at the ``obs_label`` site.
        """
        def model(design):
            if is_bad(design):
                raise ArithmeticError("bad design, contains nan or inf")
            # All but the last two design axes are batch axes.
            plate_shape = design.shape[:-2]
            with ExitStack() as plates:
                for plate in iter_plates_to_shape(plate_shape):
                    plates.enter_context(plate)

                full_shape = plate_shape + self.theta_mu.shape[-2:]
                prior = dist.Normal(
                    self.theta_mu.expand(full_shape),
                    self.theta_sig.expand(full_shape)
                ).to_event(2)
                theta = src.sample("theta", prior)

                # Squared Euclidean distance over the last axis.
                sq_dist = torch.square(theta - design).sum(dim=-1)
                terms = self.alpha / (self.m + sq_dist)
                total = self.b + terms.sum(dim=-1, keepdims=True)
                likelihood = dist.Normal(
                    torch.log(total), self.obs_sd
                ).to_event(1)
                return src.sample(self.obs_label, likelihood)

        return model

    def reset(self, n_parallel):
        """Set ``n_parallel`` and restore the default prior tensors."""
        self.n_parallel = n_parallel
        theta_shape = (n_parallel, 1, self.k, self.d)
        self.theta_mu = torch.zeros(*theta_shape)
        self.theta_sig = torch.ones(*theta_shape)
        self.alpha = torch.ones(n_parallel, 1, self.k)



class LinGaussSEModel(ExperimentModel):
    """
    Experiment model whose latents describe a linear structural equation
    model: per-node order scores, a Bernoulli edge matrix, Normal edge
    coefficients and per-node noise scales. The observation is drawn from a
    ``LinearANSEM`` built from those latents and the intervention encoded in
    the design.
    """

    def __init__(self, d=2, k=2, theta_mu=None, theta_sig=None, alpha=None,
                 b=1e-1, m=1e-4, n_parallel=1, obs_sd=0.5, obs_label="y"):
        """
        Args:
            d: number of nodes; sizes every prior tensor.
            n_parallel: leading size of the prior tensors.
            obs_label: name of the observation sample site.
            k, b, m, obs_sd: stored on the instance; not read by
                ``make_model`` in this class.
            theta_mu, theta_sig, alpha: accepted but unused.
        """
        super().__init__()
        self._set_default_priors(n_parallel, d)
        self.d = d
        self.k = k
        self.b = b
        self.m = m
        self.obs_sd = obs_sd
        self.obs_label = obs_label
        self.n_parallel = n_parallel
        self.var_names = ["topological_order_scores", "full_graph", "coefficients", "noise_scales"]
        self.var_dim = 2 * d
        self.sanity_check()

    def _set_default_priors(self, n_parallel, d):
        """Assign the fixed default prior tensors for ``d`` nodes."""
        per_node = (n_parallel, 1, d)
        per_edge = (n_parallel, 1, d, d)
        self.top_order_score_means = torch.zeros(*per_node)
        self.top_order_score_stds = torch.ones(*per_node)
        self.graph_edge_probs = 0.75 * torch.ones(*per_edge)
        self.lin_coefficient_mean = torch.zeros(*per_edge)
        self.lin_coefficient_std = torch.ones(*per_edge)
        self.noise_scale_mean = 0.25 * torch.ones(*per_node)
        self.noise_scale_std = 0.25 * torch.ones(*per_node)

    def make_model(self):
        """
        Build and return ``model(design)``.

        The returned callable raises ``ArithmeticError`` when ``is_bad``
        flags the design, samples the four latent sites and returns the
        value sampled at the ``obs_label`` site.
        """
        def model(design):
            if is_bad(design):
                raise ArithmeticError("bad design, contains nan or inf")
            # All but the last two design axes are batch axes.
            plate_shape = design.shape[:-2]
            with ExitStack() as plates:
                for plate in iter_plates_to_shape(plate_shape):
                    plates.enter_context(plate)

                # ---- SEM latent variables: graph -------------------------
                node_shape = plate_shape + self.top_order_score_means.shape[-1:]
                scores = src.sample(
                    "topological_order_scores",
                    dist.Normal(
                        self.top_order_score_means.expand(node_shape),
                        self.top_order_score_stds.expand(node_shape),
                    ).to_event(1)
                )
                # Entry (i, j) is True when node j scores higher than node i.
                order_mask = scores.unsqueeze(-2) > scores.unsqueeze(-1)

                edge_shape = plate_shape + self.graph_edge_probs.shape[-2:]
                edges = src.sample(
                    "full_graph",
                    dist.Bernoulli(
                        probs=self.graph_edge_probs.expand(edge_shape),
                    ).to_event(2)
                )
                # Keep only the sampled edges allowed by the score ordering.
                graph = order_mask * edges

                # ---- SEM latent variables: coefficients and noise --------
                weight_shape = plate_shape + self.lin_coefficient_mean.shape[-2:]
                weights = src.sample(
                    "coefficients",
                    dist.Normal(
                        self.lin_coefficient_mean.expand(weight_shape),
                        self.lin_coefficient_std.expand(weight_shape),
                    ).to_event(2)
                )
                scale_shape = plate_shape + self.noise_scale_mean.shape[-1:]
                raw_scales = src.sample(
                    "noise_scales",
                    dist.Normal(
                        self.noise_scale_mean.expand(scale_shape),
                        self.noise_scale_std.expand(scale_shape)
                    ).to_event(1)
                )
                # Softplus turns the unconstrained draw into a positive scale.
                noise_scales = torch.nn.functional.softplus(raw_scales)

                # ---- Intervention mask and values ------------------------
                # The first d design entries flag intervened nodes (> 0);
                # the remaining entries carry the intervention values.
                do_mask = (design[..., :self.d] > 0).float()
                do_values = design[..., self.d:]

                # ---- Sample y --------------------------------------------
                sem = LinearANSEM(
                    graph,
                    weights,
                    dist.Normal(0.0, noise_scales),
                    do_mask.squeeze(-2),
                    do_values.squeeze(-2),
                )
                return src.sample(self.obs_label, sem)

        return model

    def reset(self, n_parallel):
        """Set ``n_parallel`` and restore the default prior tensors."""
        self.n_parallel = n_parallel
        self._set_default_priors(n_parallel, self.d)


class LinearANSEM(dist.TorchDistribution):
    """
    Linear Additive Noise Structural Equation Model in Pyro.

    Every node is a linear function of the other nodes (weights
    ``graph * lin_coefficients``) plus exogenous noise. Nodes flagged in
    ``intervention_mask`` are clamped to ``intervention_values`` instead.
    Batch and event shape are taken from ``exogenous_noise_dist.to_event(1)``.
    """

    has_rsample = True

    def __init__(self, graph, lin_coefficients, exogenous_noise_dist, intervention_mask, intervention_values):
        """
        Args:
            graph: adjacency mask; its last axis length is used as the
                number of nodes.
            lin_coefficients: edge weights, multiplied elementwise by
                ``graph``.
            exogenous_noise_dist: per-node noise distribution; its last
                batch axis becomes the event axis.
            intervention_mask: 1 for intervened nodes, 0 otherwise.
            intervention_values: values assigned to intervened nodes.
        """
        self.exogenous_noise_dist = exogenous_noise_dist
        surrogate_dist = self.exogenous_noise_dist.to_event(1)
        super().__init__(batch_shape=surrogate_dist.batch_shape, event_shape=surrogate_dist.event_shape, validate_args=False)
        self.graph = graph
        self.lin_coefficients = lin_coefficients
        self.num_nodes = self.graph.shape[-1]
        self.intervention_mask = intervention_mask
        self.intervention_values = intervention_values

    def predict(self, y):
        """Return the linear prediction of every node from the values ``y``."""
        # Left multiplication of y -> rows represent the "from" and columns represent the "to" in our matrices
        return torch.matmul(y.unsqueeze(-2), self.graph * self.lin_coefficients).squeeze(-2)

    def rsample(self, sample_shape=torch.Size([])):
        """
        Draw a reparameterised sample by sampling the noise once and then
        applying the structural equations ``num_nodes`` times, re-clamping
        the intervened nodes after every pass.
        """
        z = self.exogenous_noise_dist.to_event(1).rsample(sample_shape)
        clamped = self.intervention_mask * self.intervention_values
        free = 1 - self.intervention_mask
        # Start from the intervention values (zeros on all other nodes).
        sample = torch.ones_like(z) * clamped
        # NOTE(review): num_nodes passes are presumably enough for values to
        # propagate along the longest path of an acyclic graph - confirm.
        for _ in range(self.num_nodes):
            sample = clamped + free * (self.predict(sample) + z)
        return sample

    def log_prob(self, y):
        """
        Return the log-density of ``y``: the noise log-density of the
        residuals ``y - prediction``, summed over non-intervened nodes.
        """
        predict = self.intervention_mask * self.intervention_values + (1 - self.intervention_mask) * self.predict(y)
        z = y - predict
        log_prob = self.exogenous_noise_dist.log_prob(z)
        # Do not add log prob terms for nodes which received a do-intervention
        return (log_prob * (1 - self.intervention_mask)).sum(-1)

    def expand(self, batch_shape, _instance=None):
        """
        Return a copy whose noise distribution is expanded to
        ``batch_shape`` plus the node axis.

        ``batch_shape`` may be a list, tuple or ``torch.Size``. ``_instance``
        is only used for this object: forwarding it to the noise
        distribution's ``expand`` would make that call fill in (or reject) an
        object of the wrong type.
        """
        new = self._get_checked_instance(LinearANSEM, _instance)
        # Normalise to tuples so list / tuple / torch.Size inputs all work.
        node_axis = tuple(self.exogenous_noise_dist.batch_shape[-1:])
        noise_shape = torch.Size(tuple(batch_shape) + node_axis)
        exogenous_noise_dist = self.exogenous_noise_dist.expand(noise_shape)
        new.__init__(self.graph, self.lin_coefficients, exogenous_noise_dist, self.intervention_mask, self.intervention_values)
        return new