import itertools

import pytest
import torch

from pyro.infer.util import torch_exp
from pyro.ops.einsum import contract
from tests.common import assert_equal


@pytest.mark.parametrize('min_size', [1, 2])
@pytest.mark.parametrize('equation', [
    ',ab->ab',
    'ab,,bc->a',
    'ab,,bc->b',
    'ab,,bc->c',
    'ab,,bc->ac',
    'ab,,b,bc->ac',
    'a,ab->ab',
    'ab,b,bc->a',
    'ab,b,bc->b',
    'ab,b,bc->c',
    'ab,b,bc->ac',
    'ab,bc->ac',
    'ab,bc,cd->',
    'ab,bc,cd->a',
    'ab,bc,cd->b',
    'ab,bc,cd->c',
    'ab,bc,cd->d',
    'ab,bc,cd->ac',
    'ab,bc,cd->ad',
    'ab,bc,cd->bc',
    'a,a,ab,b,b,b,b->a',
])
@pytest.mark.parametrize('infinite', [False, True], ids=['finite', 'infinite'])
def test_einsum(equation, min_size, infinite):
    """Compare the ``pyro.ops.einsum.torch_log`` backend against a reference.

    The reference exponentiates every operand with ``torch_exp``, contracts
    them with the plain ``'torch'`` backend, and takes ``.log()`` of the
    result. The log backend is fed the original operands directly, and the
    two results must agree.

    :param equation: einsum equation of the form ``'in1,in2,...->out'``; an
        empty input spec (e.g. the leading ``''`` in ``',ab->ab'``) yields a
        scalar operand.
    :param min_size: size given to the alphabetically first dimension symbol;
        every later symbol gets the next integer up.
    :param infinite: when True, every operand is filled with ``-inf``
        (presumably to exercise the log(0) case); otherwise operands are
        standard-normal samples.
    """
    lhs, _ = equation.split('->')
    input_dims = lhs.split(',')

    # Each distinct symbol gets its own size: min_size, min_size + 1, ...
    letters = sorted(set(equation).difference(',->'))
    dim_size = {letter: size
                for letter, size in zip(letters, itertools.count(min_size))}

    shapes = []
    for dims in input_dims:
        shapes.append(torch.Size([dim_size[d] for d in dims]))

    def make_operand(shape):
        # Operands live in log space for the torch_log backend.
        if infinite:
            return torch.full(shape, -float('inf'))
        return torch.randn(shape)

    operands = [make_operand(shape) for shape in shapes]

    exp_operands = [torch_exp(x) for x in operands]
    expected = contract(equation, *exp_operands, backend='torch').log()
    actual = contract(equation, *operands, backend='pyro.ops.einsum.torch_log')
    assert_equal(actual, expected)
