from functools import reduce
import warnings

import torch

import pyro
import pyro.poutine as poutine
from pyro.poutine.util import prune_subsample_sites


def _guess_max_plate_nesting(model, args, kwargs):
    """
    Infers ``max_plate_nesting`` from a single trial run of ``model``
    (no enumeration is applied). Only one execution is inspected, so the
    guess is only trustworthy for models with static structure.

    :param model: Python callable containing Pyro primitives.
    :param tuple args: positional arguments forwarded to ``model``.
    :param dict kwargs: keyword arguments forwarded to ``model``.
    :return: the negated smallest ``dim`` among the vectorized frames found
        in the ``cond_indep_stack`` of the traced sample sites, or ``0`` if
        there is no such frame.
    :rtype: int
    """
    with poutine.block():
        trial_trace = poutine.trace(model).get_trace(*args, **kwargs)

    # Track the most negative (i.e. outermost) vectorized plate dim seen.
    smallest_dim = None
    for node in trial_trace.nodes.values():
        if node["type"] != "sample":
            continue
        for frame in node["cond_indep_stack"]:
            if not frame.vectorized:
                continue
            if smallest_dim is None or frame.dim < smallest_dim:
                smallest_dim = frame.dim

    return 0 if smallest_dim is None else -smallest_dim


def _predictive_sequential(model, posterior_samples, model_args, model_kwargs,
                           num_samples, sample_sites, return_trace=False):
    """
    Runs ``model`` ``num_samples`` times, one draw per run, each time
    conditioned on the corresponding slice of ``posterior_samples``.

    :param model: Python callable containing Pyro primitives.
    :param dict posterior_samples: maps site names to values indexed along
        their first dimension, one entry per draw.
    :param tuple model_args: positional arguments forwarded to ``model``.
    :param dict model_kwargs: keyword arguments forwarded to ``model``.
    :param int num_samples: number of model executions.
    :param sample_sites: iterable of site names whose values are collected
        (unused when ``return_trace`` is ``True``).
    :param bool return_trace: if ``True``, return the list of per-draw traces
        instead of stacked values.
    :return: a list of traces, or a dict mapping each name in ``sample_sites``
        to the ``torch.stack`` of its per-draw values.
    """
    # Slice out the conditioning values for every draw before running the model.
    per_draw_latents = []
    for idx in range(num_samples):
        per_draw_latents.append({name: values[idx] for name, values in posterior_samples.items()})

    results = []
    for latents in per_draw_latents:
        conditioned = poutine.condition(model, latents)
        draw_trace = poutine.trace(conditioned).get_trace(*model_args, **model_kwargs)
        if return_trace:
            results.append(draw_trace)
            continue
        results.append({name: draw_trace.nodes[name]['value'] for name in sample_sites})

    if return_trace:
        return results

    stacked = {}
    for name in sample_sites:
        stacked[name] = torch.stack([draw[name] for draw in results])
    return stacked


def _predictive(model, posterior_samples, num_samples, return_sites=None,
                return_trace=False, parallel=False, model_args=(), model_kwargs=None):
    """
    Runs ``model`` conditioned on ``posterior_samples`` and collects the
    requested predictive samples (or the vectorized trace).

    :param model: Python callable containing Pyro primitives.
    :param dict posterior_samples: maps site names to values whose first
        dimension indexes the draws.
    :param int num_samples: number of predictive draws.
    :param return_sites: selects which sites are returned. A list, tuple or
        set returns exactly the listed sites among the trace's stochastic and
        observation nodes (plus ``'_RETURN'`` if listed); ``None`` returns the
        sites that are not in ``posterior_samples``; any other value (e.g.
        the empty string) returns all of those sites.
    :param bool return_trace: if ``True``, return the trace of the model
        wrapped in an outermost ``_num_predictive_samples`` plate; in that
        case ``return_sites`` and ``parallel`` are not consulted.
    :param bool parallel: if ``True``, draw all samples in one vectorized
        model execution; otherwise delegate to ``_predictive_sequential``.
    :param tuple model_args: positional arguments forwarded to ``model``.
    :param dict model_kwargs: keyword arguments forwarded to ``model``;
        ``None`` (the default) means no keyword arguments.
    :return: a trace if ``return_trace`` is ``True``, else a dict mapping site
        names to values.
    """
    # Use a None sentinel rather than a shared mutable ``{}`` default.
    if model_kwargs is None:
        model_kwargs = {}

    max_plate_nesting = _guess_max_plate_nesting(model, model_args, model_kwargs)
    model_trace = prune_subsample_sites(poutine.trace(model).get_trace(*model_args, **model_kwargs))

    # Samples used by the vectorized path: singleton dims are inserted after
    # the leading sample dim so that a sample with fewer than
    # ``max_plate_nesting`` trailing dims ends up with ``max_plate_nesting + 1``
    # dims in total. The sequential path below uses ``posterior_samples`` as given.
    reshaped_samples = {}
    for name, sample in posterior_samples.items():
        sample_shape = sample.shape[1:]
        padding = (1,) * (max_plate_nesting - len(sample_shape))
        reshaped_samples[name] = sample.reshape((num_samples,) + padding + sample_shape)

    def _vectorized_fn(fn):
        """
        Wraps a callable inside an outermost :class:`~pyro.plate` to parallelize
        sampling from the posterior predictive.

        :param fn: arbitrary callable containing Pyro primitives.
        :return: wrapped callable.
        """

        def wrapped_fn(*args, **kwargs):
            with pyro.plate("_num_predictive_samples", num_samples, dim=-max_plate_nesting-1):
                return fn(*args, **kwargs)

        return wrapped_fn

    def _vectorized_trace():
        """Traces one run of the plate-wrapped model conditioned on the reshaped samples."""
        conditioned = poutine.condition(_vectorized_fn(model), reshaped_samples)
        return poutine.trace(conditioned).get_trace(*model_args, **model_kwargs)

    if return_trace:
        return _vectorized_trace()

    explicit_sites = isinstance(return_sites, (list, tuple, set))

    return_site_shapes = {}
    for site in model_trace.stochastic_nodes + model_trace.observation_nodes:
        if explicit_sites:
            wanted = site in return_sites
        else:
            # Any non-None, non-collection value asks for every site;
            # None asks only for sites that were not supplied by the caller.
            wanted = (return_sites is not None) or (site not in reshaped_samples)
        if wanted:
            return_site_shapes[site] = (num_samples,) + model_trace.nodes[site]['value'].shape

    # handle _RETURN site: a non-tensor return value has no target shape
    if explicit_sites and '_RETURN' in return_sites:
        value = model_trace.nodes['_RETURN']['value']
        shape = (num_samples,) + value.shape if torch.is_tensor(value) else None
        return_site_shapes['_RETURN'] = shape

    if not parallel:
        return _predictive_sequential(model, posterior_samples, model_args, model_kwargs, num_samples,
                                      return_site_shapes.keys(), return_trace=False)

    trace = _vectorized_trace()
    predictions = {}
    for site, shape in return_site_shapes.items():
        value = trace.nodes[site]['value']
        if site == '_RETURN' and shape is None:
            predictions[site] = value
            continue
        # A value with fewer elements than the target shape is expanded to
        # it; otherwise it is reshaped to the target shape.
        if value.numel() < reduce((lambda x, y: x * y), shape):
            predictions[site] = value.expand(shape)
        else:
            predictions[site] = value.reshape(shape)

    return predictions


class Predictive(object):
    """
    This class is used to construct predictive distribution. The predictive distribution is obtained
    by running model conditioned on latent samples from `posterior_samples`.

    .. warning::
        The interface for the :class:`Predictive` class is experimental, and
        might change in the future.

    :param model: Python callable containing Pyro primitives.
    :param dict posterior_samples: dictionary of samples from the posterior.
    :param callable guide: optional guide to get posterior samples of sites not present
        in `posterior_samples`.
    :param int num_samples: number of samples to draw from the predictive distribution.
        This argument has no effect if ``posterior_samples`` is non-empty, in which case,
        the leading dimension size of samples in ``posterior_samples`` is used.
    :param return_sites: sites to return; by default only sample sites not present
        in `posterior_samples` are returned.
    :type return_sites: list, tuple, or set
    :param bool parallel: predict in parallel by wrapping the existing model
        in an outermost `plate` messenger. Note that this requires that the model has
        all batch dims correctly annotated via :class:`~pyro.plate`. Default is `False`.
    :raises ValueError: if neither ``posterior_samples`` nor ``num_samples`` is given, or
        if ``num_samples`` is not given and ``posterior_samples`` is empty.
    """
    def __init__(self, model, posterior_samples=None, guide=None, num_samples=None,
                 return_sites=None, parallel=False):
        if posterior_samples is None and num_samples is None:
            raise ValueError("Either posterior_samples or num_samples must be specified.")

        posterior_samples = {} if posterior_samples is None else posterior_samples

        # The leading dimension of the provided samples takes precedence over
        # an explicitly passed ``num_samples``.
        for sample in posterior_samples.values():
            batch_size = sample.shape[0]
            if num_samples is None:
                num_samples = batch_size
            elif num_samples != batch_size:
                warnings.warn("Sample's leading dimension size {} is different from the "
                              "provided {} num_samples argument. Defaulting to {}."
                              .format(batch_size, num_samples, batch_size), UserWarning)
                num_samples = batch_size

        if num_samples is None:
            raise ValueError("No sample sites in posterior samples to infer `num_samples`.")

        if return_sites is not None:
            assert isinstance(return_sites, (list, tuple, set)), \
                "return_sites must be a list, tuple, or set."

        self.model = model
        self.posterior_samples = posterior_samples
        self.num_samples = num_samples
        self.guide = guide
        self.return_sites = return_sites
        self.parallel = parallel

    def _get_posterior_samples(self, args, kwargs):
        """
        Returns the samples on which the model is conditioned: the stored
        ``posterior_samples`` when no guide was given, otherwise the result of
        running the guide conditioned on them.

        :param tuple args: positional arguments forwarded to the guide.
        :param dict kwargs: keyword arguments forwarded to the guide.
        """
        if self.guide is None:
            return self.posterior_samples
        # use return_sites='' as a special signal to return all sites
        return _predictive(self.guide, self.posterior_samples, self.num_samples,
                           return_sites='', parallel=self.parallel,
                           model_args=args, model_kwargs=kwargs)

    def get_samples(self, *args, **kwargs):
        """
        Returns dict of samples from the predictive distribution. By default, only sample sites not
        contained in `posterior_samples` are returned. This can be modified by changing the
        `return_sites` keyword argument of this :class:`Predictive` instance.

        :param args: model arguments.
        :param kwargs: model keyword arguments.
        """
        posterior_samples = self._get_posterior_samples(args, kwargs)
        return _predictive(self.model, posterior_samples, self.num_samples,
                           return_sites=self.return_sites, parallel=self.parallel,
                           model_args=args, model_kwargs=kwargs)

    def get_vectorized_trace(self, *args, **kwargs):
        """
        Returns a single vectorized `trace` from the predictive distribution. Note that this
        requires that the model has all batch dims correctly annotated via :class:`~pyro.plate`.

        :param args: model arguments.
        :param kwargs: model keyword arguments.
        """
        posterior_samples = self._get_posterior_samples(args, kwargs)
        return _predictive(self.model, posterior_samples, self.num_samples,
                           return_trace=True, model_args=args, model_kwargs=kwargs)
