import torch
import torch.nn.functional as F
from utils import *
import math
from group_utils import *

# Group-parameterization objects created once, at import time.
# NOTE(review): `LieParameterization` / `EulerParameterization` are not defined
# in this file; presumably they come from the `group_utils` star import and the
# arguments ('SOn', 3, 1) presumably select SO(3) -- confirm against group_utils.
# In the code visible here they are only referenced from commented-out lines.
group_param_lie = LieParameterization('SOn', 3, 1)
group_param_euler = EulerParameterization('SOn', 3, 1)

def cdist_mean(x1, x2, p=2.0, *args, **kwargs):
    '''
    Pairwise p-norm distance rescaled by the feature dimension.

    Computes ``torch.cdist(x1, x2, p)`` and multiplies it by ``dim ** (-1 / p)``
    where ``dim`` is the size of the last (feature) axis of ``x1``. For ``p=2``
    this turns the Euclidean norm of the difference into the root-mean-square
    of its coordinates, so the value does not grow with the code size.

    :param x1: tensor of size [P, M] or batched [B, P, M]
    :param x2: tensor of size [R, M] or batched [B, R, M]
    :param p: order of the norm; must be non-zero because of the ``-1 / p`` exponent
    :param args: extra positional arguments forwarded to ``torch.cdist``
    :param kwargs: extra keyword arguments forwarded to ``torch.cdist``
    :return:
            tensor of size [P, R] (or [B, P, R]) with the rescaled distances
    '''
    # The feature axis is the last one. Reading it from ``shape[1]`` would pick
    # the number of points for batched 3-D inputs; for 2-D inputs both agree.
    dim = x1.shape[-1]
    return torch.cdist(x1.contiguous(), x2.contiguous(), p, *args, **kwargs) * (dim ** (-1 / p))


def loss_contrastive(enc, x, hinge_thresh, temperature, cosine_sim):
    '''
    Contrastive loss over a batch that stacks two views of the same samples.

    The positive mask built below pairs row ``i`` with row ``i + batch_size``
    (``batch_size = x.shape[0] // 2``), so ``x`` is expected to hold an even
    number of rows. Every other row, except the row itself, is a negative.

    :param enc: encoder, called as ``enc(x)``; its device is looked up with ``get_device``
    :param x: input batch (tensor) with ``2 * batch_size`` rows
    :param hinge_thresh: margin of the squared hinge on pairwise Euclidean distances
                         (only used when ``cosine_sim`` is False)
    :param temperature: value dividing the cosine similarities
                        (only used when ``cosine_sim`` is True)
    :param cosine_sim: if True use the normalised / softmax branch, else the Euclidean one
    :return:
            ``(loss_sim, loss_hinge)`` in the Euclidean branch, or the softmax
            contrastive loss split into two equal halves ``(loss / 2, loss / 2)``
            in the cosine branch
    '''
    # NOTE(review): the constant 13.2 is undocumented and no other loss in this
    # file shifts its inputs before encoding. Kept to preserve behaviour --
    # confirm whether this offset is intentional.
    x = x - 13.2
    device = get_device(enc)
    x = x.to(device)
    batch_size = x.shape[0] // 2
    x = enc(x)

    eye_full = torch.eye(2 * batch_size, dtype=torch.float32, device=device)
    # 1 where two different rows are the two views of the same sample
    mask = torch.eye(batch_size, dtype=torch.float32, device=device).repeat(2, 2) - eye_full
    # 1 everywhere except on the diagonal (a row paired with itself)
    mask_ignore = 1 - eye_full

    if not cosine_sim:
        dist_score = ((x[:, None, :] - x[None, :, :]) ** 2).sum(dim=-1)
        euclid_dist = torch.sqrt(dist_score + 1e-9)
        # Mask the hinge term rather than the distance: a self-distance forced
        # to zero would itself incur the full penalty ``hinge_thresh ** 2`` and
        # leak into a mean whose denominator only counts non-self pairs.
        hinge = torch.clamp(hinge_thresh - euclid_dist, min=0) ** 2
        loss_hinge = ((hinge * mask_ignore).sum(1) / mask_ignore.sum(1)).mean()
        # mean squared distance between the two views of each sample
        loss_sim = ((dist_score * mask).sum(1) / mask.sum(1)).mean()
        return loss_sim, loss_hinge

    x = F.normalize(x, dim=-1)
    sim_score = x @ x.T / temperature
    # subtract the (detached) row maximum for numerical stability of exp()
    logits_max, _ = torch.max(sim_score, dim=1, keepdim=True)
    logits = sim_score - logits_max.detach()

    # log-probability of every column, self-pairs excluded from the normaliser
    exp_logits = torch.exp(logits) * mask_ignore
    log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-9)

    # mean log-likelihood over the positives of each row
    mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
    loss = - (temperature / 0.07) * mean_log_prob_pos
    loss = loss.mean()
    return loss / 2, loss / 2




def loss_fn_ed(
        enc,
        x_list,
        barrier_type,
        hinge_thresh,
        cosine_sim=False,
        conformal_map=False,
        rotation_map=False,
        decompositions=1,
        origin_of_rotation=None
    ):
    '''
    Symmetry (equal pairwise distances across actions) loss plus barrier loss.

    :param enc: encoder applied to every entry of ``x_list``
    :param x_list: tensor or sequence with one batch per action (at least two)
    :param barrier_type: 'log', 'inv' or 'id'; barrier applied to the pairwise
                         distances between all codes
    :param hinge_thresh: distance threshold of the hinged barrier; ``None``
                         disables the hinge and the plain barrier mean is used
    :param cosine_sim: normalise the codes and compare cosine distances
    :param conformal_map: compare cosine distances between pairwise difference vectors
    :param rotation_map: prepend ``origin_of_rotation`` as an extra row to every code batch
    :param decompositions: number of equally sized slices the code is split into
    :param origin_of_rotation: 1d code used when ``rotation_map`` is set
    :return:
            tuple ``(loss_equiv, loss_barrier)``
    '''
    device = get_device(enc)
    num_actions = len(x_list)
    assert num_actions >= 2

    # bring every per-action batch onto the encoder's device
    if torch.is_tensor(x_list):
        x_list = [x_list[idx].to(device) for idx in range(num_actions)]
    else:
        x_list = [torch.Tensor(batch).to(device) for batch in x_list]

    z_list = [enc(batch) for batch in x_list]
    if rotation_map:
        # add encoding of origin of rotation
        origin_row = origin_of_rotation[None, :]
        z_list = [torch.cat((origin_row, z), dim=0) for z in z_list]
    code_size = z_list[0].shape[1]

    if cosine_sim:
        z_list = [normlize_vector(z) for z in z_list]

    # -- symmetry loss --#
    num_pairs = num_actions * (num_actions - 1) / 2
    subcode_size = code_size // decompositions
    loss_equiv = 0
    for k in range(decompositions):
        start, stop = k * subcode_size, (k + 1) * subcode_size
        h_list = [z[:, start:stop] for z in z_list]

        # NOTE: conformal mapping will probably not work with decompositions
        if conformal_map:
            h_list = [(h[:, None, :] - h[None, :, :]).view(-1, subcode_size) for h in h_list]

        if cosine_sim or conformal_map:
            D_list = [1.0 - F.cosine_similarity(h[:, None, :], h[None, :, :], dim=2) for h in h_list]
        else:
            D_list = [cdist_mean(h, h, p=2) for h in h_list]

        # upper-triangular table of squared distance-matrix mismatches per action pair
        L_equiv = torch.zeros(num_actions, num_actions)
        for i, D_i in enumerate(D_list):
            for j in range(i + 1, num_actions):
                L_equiv[i, j] = torch.mean((D_i - D_list[j]) ** 2)
        loss_equiv += torch.sum(L_equiv) / num_pairs / decompositions

    # -- barrier loss --#
    z_all = torch.cat(z_list, dim=0)
    D_all = cdist_mean(z_all, z_all, p=2)
    # distances between different codes only (diagonal removed)
    D_off = D_all[~torch.eye(D_all.shape[0], dtype=torch.bool)]

    use_hinge_loss = hinge_thresh is not None
    if not use_hinge_loss:
        hinge_thresh = 1  # placeholder so that B_min below can still be evaluated

    if barrier_type == 'log':
        B_all = -torch.log(D_off + 1e-9)
        B_min = -math.log(hinge_thresh)
    elif barrier_type == 'inv':
        B_all = 1.0 / (D_off + 1e-9)
        B_min = 1.0 / hinge_thresh
    elif barrier_type == 'id':
        B_all = -D_off
        B_min = -hinge_thresh
    else:
        assert False, 'Unknown `barrier_type`'

    if not use_hinge_loss:
        return loss_equiv, torch.mean(B_all)  # TODO: improve?

    # only barrier values above the barrier at `hinge_thresh` are penalised
    floor = torch.zeros(1, device=device)
    loss_barrier = torch.mean(torch.maximum(floor, B_all - B_min))
    return loss_equiv, loss_barrier

def loss_fn_ed_inverse_action(
        enc,
        act_predictor,
        x_list,
        actions,
        origin_of_rotation,
        barrier_type,
        hinge_thresh,
        cosine_sim=False,
        conformal_map=False,
        rotation_map=False,
        decompositions=1
    ):
    '''
    Symmetry loss, barrier loss and an action-prediction loss.

    :param enc: encoder applied to every entry of ``x_list``
    :param act_predictor: callable ``act_predictor(h_i, h_j)`` whose output is
                          regressed onto ``actions[j] @ actions[i].T``
    :param x_list: tensor or sequence with one batch per action (at least two)
    :param actions: one matrix per action, indexable as ``actions[i]``
                    (presumably group elements such as rotation matrices -- confirm with caller)
    :param origin_of_rotation: 1d code prepended to every code batch when ``rotation_map`` is set
    :param barrier_type: 'log', 'inv' or 'id'
    :param hinge_thresh: distance threshold of the hinged barrier; ``None``
                         disables the hinge and the plain barrier mean is used
    :param cosine_sim: normalise the codes and compare cosine distances
    :param conformal_map: compare cosine distances between pairwise difference vectors
    :param rotation_map: prepend ``origin_of_rotation`` to every code batch
    :param decompositions: number of equally sized slices the code is split into
    :return:
            tuple ``(loss_equiv, loss_barrier, loss_action_pred)``
    '''
    device = get_device(enc)
    num_actions = len(x_list)

    assert num_actions >= 2
    if torch.is_tensor(x_list):
        x_list = [x_list[i].to(device) for i in range(num_actions)]
    else:
        x_list = [torch.Tensor(x).to(device) for x in x_list]
    actions = torch.Tensor(actions).to(device)
    z_list = [enc(x) for x in x_list]
    if rotation_map:
        z_list = [torch.cat((origin_of_rotation[None, :], z), dim=0) for z in z_list]  # add encoding of origin of rotation
    code_size = z_list[0].shape[1]

    if cosine_sim:
        z_list = [normlize_vector(z) for z in z_list]

    # -- symmetry loss and action prediction loss --#
    num_pairs = num_actions * (num_actions - 1) / 2.0
    subcode_size = code_size // decompositions
    loss_equiv = 0
    loss_action_pred = 0
    for k in range(decompositions):
        h_list = [z[:, k * subcode_size: (k + 1) * subcode_size] for z in z_list]

        # NOTE: conformal mapping will probably not work with decompositions
        if conformal_map:
            h_list = [(h[:, None, :] - h[None, :, :]).view(-1, subcode_size) for h in h_list]

        if cosine_sim or conformal_map:
            D_list = [1.0 - F.cosine_similarity(h[:, None, :], h[None, :, :], dim=2) for h in h_list]
        else:
            D_list = [cdist_mean(h, h, p=2) for h in h_list]

        L_equiv = torch.zeros(num_actions, num_actions)
        # One entry per action pair. Every pair has to be kept: overwriting a
        # single variable would leave only the last pair's loss while the sum
        # is still divided by the number of pairs.
        pair_action_losses = []
        for i in range(num_actions):
            for j in range(i + 1, num_actions):
                L_equiv[i, j] = torch.mean((D_list[i] - D_list[j]) ** 2)
                # Target for the pair (i, j). The leading axis of size 1
                # broadcasts against however many rows `act_predictor` returns,
                # which also covers the extra rows added by `rotation_map` /
                # `conformal_map`.
                eff_action = (actions[j] @ actions[i].T).unsqueeze(0)
                pred_action = act_predictor(h_list[i], h_list[j])
                pair_action_losses.append(torch.mean((pred_action - eff_action) ** 2))

        cur_loss_equiv = torch.sum(L_equiv) / num_pairs
        loss_equiv += cur_loss_equiv / decompositions
        loss_action_pred = loss_action_pred + sum(pair_action_losses) / num_pairs

    # -- barrier loss --#
    z_all = torch.cat(z_list, dim=0)
    D_all = cdist_mean(z_all, z_all, p=2)
    # drop the diagonal: only distances between different codes are penalised
    mask = torch.eye(D_all.shape[0], dtype=torch.bool)

    use_hinge_loss = (hinge_thresh is not None)
    if not use_hinge_loss:
        hinge_thresh = 1  # placeholder so that B_min below can still be evaluated

    if barrier_type == 'log':
        B_all = -torch.log(D_all[~mask] + 1e-9)
        B_min = -math.log(hinge_thresh)
    elif barrier_type == 'inv':
        B_all = 1.0 / (D_all[~mask] + 1e-9)
        B_min = 1.0 / hinge_thresh
    elif barrier_type == 'id':
        B_all = -D_all[~mask]
        B_min = -hinge_thresh
    else:
        # raised explicitly so the check also survives `python -O`
        raise AssertionError('Unknown `barrier_type`')

    if use_hinge_loss:
        # only barrier values above the barrier at `hinge_thresh` are penalised
        loss_barrier = torch.mean(torch.maximum(torch.zeros(1, device=device), B_all - B_min))
    else:
        loss_barrier = torch.mean(B_all)  # TODO: improve?

    return loss_equiv, loss_barrier, loss_action_pred

def normlize_vector(vector):
    '''
    Scale every vector of a batch to (approximately) unit Euclidean length.

    :param vector: batch of vectors of size [batch_size, vect_dim]
    :return:
            tensor of the same size, each vector divided by its L2 norm taken
            over the last dimension; 1e-9 is added to the norm so an all-zero
            vector maps to zeros instead of NaN
    '''
    lengths = torch.linalg.norm(vector, dim=-1, keepdim=True)
    safe_lengths = lengths + 1e-9
    return vector / safe_lengths

def get_orthonormal_frame(vectors):
    '''
        Function to do Gram-Schmidt orthonormalization
        :param vectors: takes a batch of two vectors (3d) of size [batch_size, vect_dim * 2];
                        columns 0-2 hold the first vector, columns 3-5 the second
        :return:
                A batch of 3 orthonormal vectors which defines an orthonormal frame
                size [batch_size, 3, vect_dim]
                note that vect_dim is 3 for 3d vectors
        :raises AssertionError: if the resulting vectors are not orthogonal on
                average (mean absolute cosine similarity of 1e-3 or more)
    '''
    # first axis: the first vector, normalised
    first = normlize_vector(vectors[:, 0:3])

    # second axis: remove the component along `first`, then normalise
    second = vectors[:, 3:6]
    second = second - batch_dot(first, second)[:, None] * first
    second = normlize_vector(second)
    assert torch.abs(F.cosine_similarity(second, first)).mean() < 1e-3

    # third axis: cross product of the first two. `dim` is passed explicitly:
    # without it torch.cross uses the first dimension of size 3, which is the
    # batch dimension whenever batch_size happens to be 3.
    third = torch.cross(first, second, dim=-1)
    assert torch.abs(F.cosine_similarity(third, first)).mean() < 1e-3
    assert torch.abs(F.cosine_similarity(third, second)).mean() < 1e-3

    return torch.stack([first, second, third], dim=1)

def loss_fn_cosine_so3(
        enc,
        x_list,
        barrier_type,
        hinge_thresh
):
    '''
    This function is just for SO3 and only works with code_size 6 and decomposition 2

    :param enc: encoder applied to every entry of ``x_list``
    :param x_list: tensor or sequence with one batch per action (at least two)
    :param barrier_type: 'log', 'inv' or 'id'
    :param hinge_thresh: distance threshold of the hinged barrier; ``None``
                         disables the hinge and the plain barrier mean is used
    :return:
            tuple ``(loss_equiv, loss_barrier, loss_orth)``
    '''
    device = get_device(enc)
    num_actions = len(x_list)
    assert num_actions >= 2

    # bring every per-action batch onto the encoder's device
    if torch.is_tensor(x_list):
        x_list = [x_list[idx].to(device) for idx in range(num_actions)]
    else:
        x_list = [torch.Tensor(batch).to(device) for batch in x_list]
    code_list = [enc(batch) for batch in x_list]
    batch_size = code_list[0].shape[0]

    # normalization and Gram-Schmidt orthonormalization
    # shape [num_actions, batch_size, 3, vector_dim]
    frames = torch.stack([get_orthonormal_frame(code) for code in code_list], dim=0)

    # -- symmetry loss --#
    num_pairs = num_actions * (num_actions - 1) / 2.0
    loss_equiv = 0.0
    for axis in range(3):
        # the `axis`-th frame vector of every sample, per action
        axis_codes = frames[:, :, axis, :]
        sim_list = [
            F.cosine_similarity(codes[:, None, :], codes[None, :, :], dim=2)
            for codes in axis_codes
        ]
        L_equiv = torch.zeros(num_actions, num_actions)
        for i, sim_i in enumerate(sim_list):
            for j in range(i + 1, num_actions):
                L_equiv[i, j] = torch.mean((sim_i - sim_list[j]) ** 2)
        loss_equiv += torch.sum(L_equiv) / num_pairs

    # Mean cosine similarity between all pairs of frame vectors; for an exactly
    # orthonormal frame the 3x3 similarity matrix is the identity (mean 1/3).
    pairwise_frame_sim = F.cosine_similarity(
        frames[:, :, None, :, :],
        frames[:, :, :, None, :],
        dim=-1
    )
    loss_orth = pairwise_frame_sim.mean() - 0.33333

    # -- barrier loss --#
    # flatten each frame into one code vector and pool all actions together
    codes_all = frames.reshape(num_actions * batch_size, -1)
    D_all = cdist_mean(codes_all, codes_all, p=2)
    # distances between different codes only (diagonal removed)
    D_off = D_all[~torch.eye(D_all.shape[0], dtype=torch.bool)]

    use_hinge_loss = hinge_thresh is not None
    if not use_hinge_loss:
        hinge_thresh = 1  # placeholder so that B_min below can still be evaluated

    if barrier_type == 'log':
        B_all = -torch.log(D_off + 1e-9)
        B_min = -math.log(hinge_thresh)
    elif barrier_type == 'inv':
        B_all = 1.0 / (D_off + 1e-9)
        B_min = 1.0 / hinge_thresh
    elif barrier_type == 'id':
        B_all = -D_off
        B_min = -hinge_thresh
    else:
        assert False, 'Unknown `barrier_type`'

    if not use_hinge_loss:
        return loss_equiv, torch.mean(B_all), loss_orth  # TODO: improve?

    # only barrier values above the barrier at `hinge_thresh` are penalised
    floor = torch.zeros(1, device=device)
    loss_barrier = torch.mean(torch.maximum(floor, B_all - B_min))
    return loss_equiv, loss_barrier, loss_orth


def loss_fn_cosine_so3_inverse_action(
        enc,
        x_list,
        actions
):
    '''
    This function is just for SO3 and only works with code_size 6 and decomposition 2

    :param enc: encoder applied to every entry of ``x_list``
    :param x_list: tensor or sequence with one batch per action (at least two)
    :param actions: one matrix per action, indexable as ``actions[i]``
    :return:
            tuple ``(loss_equiv, loss_inverse_action, loss_orth)``
    '''
    device = get_device(enc)
    num_actions = len(x_list)
    assert num_actions >= 2

    # bring every per-action batch onto the encoder's device
    if torch.is_tensor(x_list):
        x_list = [x_list[idx].to(device) for idx in range(num_actions)]
    else:
        x_list = [torch.Tensor(batch).to(device) for batch in x_list]
    code_list = [enc(batch) for batch in x_list]
    actions = torch.Tensor(actions).to(device)

    # normalization and Gram-Schmidt orthonormalization
    # shape [num_actions, batch_size, 3, vector_dim]
    orthonormal_codes = torch.stack(
        [get_orthonormal_frame(code) for code in code_list], dim=0
    )

    # -- symmetry loss and inverse RL loss--#
    num_pairs = num_actions * (num_actions - 1) / 2.0
    loss_equiv = 0.0
    loss_inverse_action = 0.0
    for axis in range(3):
        # shape [num_actions, batch_size, vector_dim]
        subcodes = orthonormal_codes[:, :, axis, :]
        sim_list = [
            F.cosine_similarity(subcodes[a][:, None, :], subcodes[a][None, :, :], dim=-1)
            for a in range(num_actions)
        ]
        L_equiv = torch.zeros(num_actions, num_actions)
        L_inverse_action = torch.zeros(num_actions, num_actions)
        for i in range(num_actions):
            for j in range(i + 1, num_actions):
                L_equiv[i, j] = torch.mean((sim_list[i] - sim_list[j]) ** 2)
                # matrix taking codes of action i to codes of action j
                eff_action = actions[j] @ actions[i].T
                predicted_j = subcodes[i] @ eff_action.T
                L_inverse_action[i, j] = torch.mean((predicted_j - subcodes[j]) ** 2)
        loss_equiv += torch.sum(L_equiv) / num_pairs
        loss_inverse_action += torch.sum(L_inverse_action) / num_pairs

    # Mean cosine similarity between all pairs of frame vectors; for an exactly
    # orthonormal frame the 3x3 similarity matrix is the identity (mean 1/3).
    pairwise_frame_sim = F.cosine_similarity(
        orthonormal_codes[:, :, None, :, :],
        orthonormal_codes[:, :, :, None, :],
        dim=-1
    )
    loss_orth = pairwise_frame_sim.mean() - 0.33333

    return loss_equiv, loss_inverse_action, loss_orth



