import pytest
import torch
from torch.distributions import constraints

import pyro
import pyro.distributions as dist
import pyro.optim
import pyro.poutine as poutine
from pyro.optim.multi import MixedMultiOptimizer, Newton, PyroMultiOptimizer, TorchMultiOptimizer
from tests.common import assert_equal

FACTORIES = [
    lambda: PyroMultiOptimizer(pyro.optim.Adam({'lr': 0.05})),
    lambda: TorchMultiOptimizer(torch.optim.Adam, {'lr': 0.05}),
    lambda: Newton(trust_radii={'z': 0.2}),
    lambda: MixedMultiOptimizer([(['y'], PyroMultiOptimizer(pyro.optim.Adam({'lr': 0.05}))),
                                 (['x', 'z'], Newton())]),
    lambda: MixedMultiOptimizer([(['y'], pyro.optim.Adam({'lr': 0.05})),
                                 (['x', 'z'], Newton())]),
]


@pytest.mark.parametrize('factory', FACTORIES)
def test_optimizers(factory):
    optim = factory()

    def model(loc, cov):
        x = pyro.param("x", torch.randn(2))
        y = pyro.param("y", torch.randn(3, 2))
        z = pyro.param("z", torch.randn(4, 2).abs(), constraint=constraints.greater_than(-1))
        pyro.sample("obs_x", dist.MultivariateNormal(loc, cov), obs=x)
        with pyro.plate("y_plate", 3):
            pyro.sample("obs_y", dist.MultivariateNormal(loc, cov), obs=y)
        with pyro.plate("z_plate", 4):
            pyro.sample("obs_z", dist.MultivariateNormal(loc, cov), obs=z)

    loc = torch.tensor([-0.5, 0.5])
    cov = torch.tensor([[1.0, 0.09], [0.09, 0.1]])
    for step in range(200):
        tr = poutine.trace(model).get_trace(loc, cov)
        loss = -tr.log_prob_sum()
        params = {name: site['value'].unconstrained()
                  for name, site in tr.nodes.items()
                  if site['type'] == 'param'}
        optim.step(loss, params)

    for name in ["x", "y", "z"]:
        actual = pyro.param(name)
        expected = loc.expand(actual.shape)
        assert_equal(actual, expected, prec=1e-2,
                     msg='{} in correct: {} vs {}'.format(name, actual, expected))


def test_multi_optimizer_disjoint_ok():
    parts = [(['w', 'x'], pyro.optim.Adam({'lr': 0.1})),
             (['y', 'z'], pyro.optim.Adam({'lr': 0.01}))]
    MixedMultiOptimizer(parts)


def test_multi_optimizer_overlap_error():
    parts = [(['x', 'y'], pyro.optim.Adam({'lr': 0.1})),
             (['y', 'z'], pyro.optim.Adam({'lr': 0.01}))]
    with pytest.raises(ValueError):
        MixedMultiOptimizer(parts)
