import torch
import torch.nn as nn


def cross_entropy(preds, targets, reduction='none'):
    """Compute cross-entropy between logits and (soft) target weights.

    ``preds`` is passed through a log-softmax over its last dimension,
    multiplied element-wise by ``-targets`` and summed over dimension 1.

    Args:
        preds: Tensor of unnormalized scores (logits).
        targets: Tensor of target weights, multiplied element-wise with the
            log-probabilities (presumably same shape as ``preds`` -- confirm
            against callers).
        reduction: ``"none"`` returns the summed losses unchanged;
            ``"mean"`` returns their mean as a scalar tensor.

    Returns:
        The loss tensor, reduced according to ``reduction``.

    Raises:
        ValueError: If ``reduction`` is neither ``"none"`` nor ``"mean"``.
    """
    # Fail loudly on an unsupported mode instead of silently returning None.
    if reduction not in ("none", "mean"):
        raise ValueError(
            f"Unsupported reduction {reduction!r}; expected 'none' or 'mean'."
        )
    log_probs = nn.functional.log_softmax(preds, dim=-1)
    # NOTE: the softmax runs over the last dim while the sum runs over dim 1;
    # these are the same axis only for 2-D inputs.
    loss = (-targets * log_probs).sum(1)
    if reduction == "mean":
        return loss.mean()
    return loss

def no_softmax_cross_entropy(preds, targets, reduction='none'):
    """Compute cross-entropy when ``preds`` needs no softmax.

    The element-wise term ``-targets * log(preds)`` is summed over
    dimension 1. No normalization is applied to ``preds`` here.

    Args:
        preds: Tensor whose logarithm is taken directly (presumably already
            normalized probabilities -- confirm against callers).
        targets: Tensor of target weights, multiplied element-wise with
            ``log(preds)``.
        reduction: ``"none"`` returns the summed losses unchanged;
            ``"mean"`` returns their mean as a scalar tensor.

    Returns:
        The loss tensor, reduced according to ``reduction``.

    Raises:
        ValueError: If ``reduction`` is neither ``"none"`` nor ``"mean"``.
    """
    # Fail loudly on an unsupported mode instead of silently returning None.
    if reduction not in ("none", "mean"):
        raise ValueError(
            f"Unsupported reduction {reduction!r}; expected 'none' or 'mean'."
        )
    # xlogy yields 0 wherever the target weight is 0, so a zero probability on
    # an untargeted class contributes nothing instead of 0 * -inf = NaN.
    loss = (-torch.xlogy(targets, preds)).sum(1)
    if reduction == "mean":
        return loss.mean()
    return loss

def get_CLIP_loss(CLIP_embs, CLAP_embs, temperature):
    """Compute a symmetric contrastive loss between two sets of embeddings.

    A similarity matrix ``CLAP_embs @ CLIP_embs.T`` is scaled by
    ``1 / temperature`` and scored with ``cross_entropy`` against identity
    targets, once along the rows and once along the columns (via the
    transpose). The two per-item losses are averaged and then reduced to
    their mean.

    Args:
        CLIP_embs: Embedding tensor (assumed 2-D, one row per item --
            TODO confirm against callers).
        CLAP_embs: Embedding tensor paired row-by-row with ``CLIP_embs``;
            the identity targets assume item ``i`` in one matches item ``i``
            in the other.
        temperature: Divisor applied to the similarity logits.

    Returns:
        Scalar tensor holding the mean symmetric loss.
    """
    logits = (CLAP_embs @ CLIP_embs.T) / temperature
    # Create the targets on the logits' device rather than forcing CUDA, so
    # the loss also works on CPU and on non-default GPUs.
    targets = torch.eye(logits.shape[0], device=logits.device)
    texts_loss = cross_entropy(logits, targets, reduction='none')
    images_loss = cross_entropy(logits.T, targets.T, reduction='none')
    loss = (images_loss + texts_loss) / 2.0  # shape: (batch_size)
    return loss.mean()

def get_item_L2_loss(source, target):
    """Return the mean squared L2 distance between ``source`` and ``target``.

    The element-wise difference is squared and summed over the last
    dimension; the resulting values are then averaged into a scalar.
    """
    diff = source - target
    squared_norms = diff.pow(2).sum(dim=-1)
    return squared_norms.mean()