import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.autograd.function import once_differentiable

from maskrcnn_benchmark import _C


# TODO: Use JIT to replace CUDA implementation in the future.
class _SigmoidFocalLoss(Function):
    """Autograd wrapper around the compiled ``_C.sigmoid_focalloss_*`` kernels.

    ``forward`` delegates to ``_C.sigmoid_focalloss_forward`` and ``backward``
    to ``_C.sigmoid_focalloss_backward``. Only ``logits`` receives a gradient;
    ``targets``, ``gamma`` and ``alpha`` are treated as non-differentiable.
    """

    @staticmethod
    def forward(ctx, logits, targets, gamma, alpha):
        """Compute the loss with the compiled kernel and stash state for backward.

        Args:
            ctx: autograd context used to carry state to ``backward``.
            logits: prediction tensor; ``logits.shape[1]`` is taken as the
                number of classes.
            targets: label tensor passed through unchanged to the kernel.
            gamma: focal-loss focusing parameter forwarded to the kernel.
            alpha: focal-loss balancing parameter forwarded to the kernel.

        Returns:
            Whatever ``_C.sigmoid_focalloss_forward`` returns (presumably the
            unreduced per-element losses -- confirm against the extension).
        """
        ctx.save_for_backward(logits, targets)
        num_classes = logits.shape[1]
        # Non-tensor state is kept as plain attributes on the context.
        ctx.num_classes = num_classes
        ctx.gamma = gamma
        ctx.alpha = alpha

        losses = _C.sigmoid_focalloss_forward(
            logits, targets, num_classes, gamma, alpha
        )
        return losses

    @staticmethod
    @once_differentiable
    def backward(ctx, d_loss):
        """Return the gradient w.r.t. ``logits`` computed by the compiled kernel.

        Args:
            ctx: autograd context populated by ``forward``.
            d_loss: upstream gradient of the loss output.

        Returns:
            A 4-tuple with one slot per ``forward`` input
            ``(logits, targets, gamma, alpha)``; only the first is non-``None``.
        """
        logits, targets = ctx.saved_tensors
        num_classes = ctx.num_classes
        gamma = ctx.gamma
        alpha = ctx.alpha
        # The kernel is handed a contiguous gradient buffer.
        d_loss = d_loss.contiguous()
        d_logits = _C.sigmoid_focalloss_backward(
            logits, targets, d_loss, num_classes, gamma, alpha
        )
        # Exactly one gradient per forward() input: logits, targets, gamma, alpha.
        return d_logits, None, None, None


# Functional entry point for the custom autograd op above; call it as
# sigmoid_focal_loss_cuda(logits, targets, gamma, alpha).
sigmoid_focal_loss_cuda = _SigmoidFocalLoss.apply


def sigmoid_focal_loss_cpu(logits, targets, gamma, alpha):
    """Pure-PyTorch sigmoid focal loss, returned per element (no reduction).

    Args:
        logits: tensor whose dimension 1 indexes classes
            (``logits.shape[1]`` is used as the number of classes).
        targets: one label per row of ``logits``. Label ``k`` in
            ``1..num_classes`` marks column ``k - 1`` as the positive class;
            label ``0`` makes every column a negative; negative labels
            contribute zero loss in every column.
        gamma: exponent of the modulating factor.
        alpha: weight applied to positive terms; negatives get ``1 - alpha``.

    Returns:
        Tensor broadcast to the shape of ``logits`` with the unreduced loss.
    """
    num_classes = logits.shape[1]
    # 1-based class ids laid out along the class dimension, in the targets' dtype.
    class_range = torch.arange(
        1, num_classes + 1, dtype=targets.dtype, device=targets.device
    ).unsqueeze(0)

    t = targets.unsqueeze(1)
    p = torch.sigmoid(logits)
    # log(p) and log(1 - p) computed via logsigmoid: taking torch.log of a
    # saturated sigmoid yields -inf, and 0 * -inf in the masked sum below
    # would turn the whole loss into NaN.
    log_p = F.logsigmoid(logits)
    log_one_minus_p = F.logsigmoid(-logits)
    term1 = (1 - p) ** gamma * log_p
    term2 = p ** gamma * log_one_minus_p

    is_positive = (t == class_range).float()
    # Negative labels are excluded from the negative term as well.
    is_negative = ((t != class_range) & (t >= 0)).float()
    return -is_positive * term1 * alpha - is_negative * term2 * (1 - alpha)


class SigmoidFocalLoss(nn.Module):
    """Module wrapper that returns the summed sigmoid focal loss.

    The CUDA implementation is used when the logits live on a CUDA device,
    otherwise the pure-PyTorch CPU implementation is used.
    """

    def __init__(self, gamma, alpha):
        super().__init__()
        self.gamma, self.alpha = gamma, alpha

    def forward(self, logits, targets):
        """Compute the focal loss for ``logits``/``targets`` and sum it to a scalar."""
        focal_fn = sigmoid_focal_loss_cuda if logits.is_cuda else sigmoid_focal_loss_cpu
        return focal_fn(logits, targets, self.gamma, self.alpha).sum()

    def __repr__(self):
        # Rendered as ClassName(gamma=<gamma>, alpha=<alpha>).
        return f"{self.__class__.__name__}(gamma={self.gamma!s}, alpha={self.alpha!s})"


def token_sigmoid_softmax_focal_loss(pred_logits, targets, alpha, gamma, text_mask=None):
    """Softmax (cross-entropy style) focal loss over the last dimension.

    Because this is the cross-entropy formulation there is no frequent /
    infrequent class split, so ``alpha`` is accepted for signature parity
    with the other variants but is currently not used.

    Args:
        pred_logits: 3-D tensor of logits (batch x from x to).
        targets: 3-D tensor of the same layout; each row along the last
            dimension is normalised to sum to one (all-zero rows stay zero).
        alpha: unused, see above.
        gamma: exponent applied to the focal weight.
        text_mask: optional 2-D tensor; entries ``> 0`` mark positions along
            the last dimension that take part in the softmax. The last
            position is always kept (reserved for "no object"). The caller's
            tensor is not modified.

    Returns:
        Unreduced loss tensor with the shape of ``pred_logits``.

    Raises:
        AssertionError: if a tensor does not have the expected rank.
    """
    assert (targets.dim() == 3)
    assert (pred_logits.dim() == 3)  # batch x from x to

    # Turn the targets into a probability map along the last dimension.
    targets = targets.float()
    target_num = targets.sum(-1) + 1e-8  # numerical stability
    targets = targets / target_num.unsqueeze(-1)  # T(x)

    if text_mask is not None:
        assert (text_mask.dim() == 2)
        # The comparison allocates a new boolean tensor, so forcing the last
        # position on below does not write into the caller's mask.
        keep = text_mask > 0
        keep[:, -1] = True  # reserve the last token for non object
        # View the mask along the image channel without materialising copies.
        keep = keep.unsqueeze(1).expand(-1, pred_logits.size(1), -1)
        # Large negative fill so masked positions get ~0 probability.
        pred_logits = pred_logits.masked_fill(~keep, -1000000)

    out_prob = pred_logits.softmax(-1)

    # Denominator for the weight; zeros are replaced by one to avoid 0-division.
    filled_targets = targets.masked_fill(targets == 0, 1.0)

    weight = torch.clamp(targets - out_prob, min=0.001) / filled_targets
    weight = torch.pow(weight, gamma)

    # Only positions with a positive target contribute to the loss.
    loss_ce = - targets * weight * pred_logits.log_softmax(-1)
    return loss_ce


def token_sigmoid_binary_focal_loss_v2(pred_logits, targets, alpha, gamma, text_mask=None):
    """Binary focal loss built from an explicit (negative, positive) probability pair.

    Args:
        pred_logits: 3-D tensor of logits (batch x from x to).
        targets: 3-D tensor of 0/1 labels; cast to ``long`` and used as an
            index that picks the negative (0) or positive (1) loss.
        alpha: weight of the positive loss; the negative loss gets ``1 - alpha``.
        gamma: exponent of the modulating factor.
        text_mask: optional 2-D tensor. Only its rank is checked here; it is
            not applied to the loss.

    Returns:
        Unreduced loss tensor with a trailing singleton dimension appended
        to the shape of ``pred_logits``.
    """
    assert (targets.dim() == 3)
    assert (pred_logits.dim() == 3)  # batch x from x to
    if text_mask is not None:
        assert (text_mask.dim() == 2)

    prob_pos = pred_logits.sigmoid()
    # Last axis: index 0 holds P(negative), index 1 holds P(positive).
    # The small epsilon keeps the logarithm below finite.
    class_probs = torch.stack([1 - prob_pos, prob_pos], dim=-1) + 1e-8
    modulator = torch.pow(-class_probs + 1.0, gamma)

    loss_if_negative = - modulator[..., 0] * torch.log(class_probs[..., 0]) * (1 - alpha)
    loss_if_positive = - modulator[..., 1] * torch.log(class_probs[..., 1]) * alpha
    per_class = torch.stack([loss_if_negative, loss_if_positive], dim=-1)

    # Pick, per element, the loss that matches its binary label.
    label_index = targets.long().unsqueeze(-1)
    return torch.gather(per_class, dim=-1, index=label_index)


def token_sigmoid_binary_focal_loss(pred_logits, targets, alpha, gamma, text_mask=None):
    """Binary (sigmoid) focal loss, returned per element without reduction.

    Adapted from
    https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py
    (loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002).

    Args:
        pred_logits: 3-D float tensor of logits (batch x from x to).
        targets: 3-D float tensor of binary labels for each element of
            ``pred_logits`` (0 for the negative class, 1 for the positive class).
        alpha: weighting factor to balance positive vs negative examples.
            A negative value disables the weighting.
        gamma: exponent of the modulating factor ``(1 - p_t)`` that balances
            easy vs hard examples.
        text_mask: optional 2-D tensor; entries ``> 0`` select which positions
            along the last dimension are kept.

    Returns:
        The unreduced loss. With ``text_mask`` it is a 1-D tensor holding
        only the selected elements; otherwise it has the shape of
        ``pred_logits``.

    Raises:
        AssertionError: if a tensor does not have the expected rank.
    """
    assert (targets.dim() == 3)
    assert (pred_logits.dim() == 3)  # batch x from x to

    if text_mask is not None:
        assert (text_mask.dim() == 2)
        # View the mask along the image channel dimension; expand() avoids
        # allocating one copy of the mask per channel.
        keep = (text_mask > 0).unsqueeze(1).expand(-1, pred_logits.size(1), -1)
        # masked_select flattens: from here on both tensors are 1-D.
        pred_logits = torch.masked_select(pred_logits, keep)
        targets = torch.masked_select(targets, keep)

    p = torch.sigmoid(pred_logits)
    ce_loss = F.binary_cross_entropy_with_logits(pred_logits, targets, reduction="none")
    # Probability assigned to the true class of each element.
    p_t = p * targets + (1 - p) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    return loss


class TokenSigmoidFocalLoss(nn.Module):
    """Module wrapper that sums one of the token-level focal loss variants.

    The variant is chosen per call through ``version``: ``"binary"``,
    ``"softmax"`` or ``"binaryv2"``.
    """

    def __init__(self, alpha, gamma):
        super().__init__()
        self.alpha, self.gamma = alpha, gamma

    def forward(self, logits, targets, text_masks=None, version="binary", **kwargs):
        """Evaluate the selected loss variant and reduce it with a sum.

        Extra keyword arguments are forwarded unchanged to the loss function.
        Raises ``NotImplementedError`` for an unknown ``version``.
        """
        # Reject unknown variants before touching any implementation.
        if version not in ("binary", "softmax", "binaryv2"):
            raise NotImplementedError

        if version == "softmax":
            focal_fn = token_sigmoid_softmax_focal_loss
        elif version == "binaryv2":
            focal_fn = token_sigmoid_binary_focal_loss_v2
        else:
            focal_fn = token_sigmoid_binary_focal_loss

        per_element = focal_fn(logits, targets, self.alpha, self.gamma, text_masks, **kwargs)
        return per_element.sum()

    def __repr__(self):
        # Rendered as ClassName(gamma=<gamma>, alpha=<alpha>).
        return f"{self.__class__.__name__}(gamma={self.gamma!s}, alpha={self.alpha!s})"
