import sys
from abc import ABCMeta, abstractmethod
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow_addons.losses.focal_loss import SigmoidFocalCrossEntropy
from utils.metrics import CostWeightGenerator

# Names of the loss functions known to this module.
# LossManager stores this list as `self.ls_available`; per the LossManager
# docstring, the name of any newly implemented loss must be added here.
AVAIL_LOSSES = [
    "CategoricalCrossentropy",
    "SphericalEmbeddingConstraint",
    "SphericalEmbeddingConstraint4Weight",
    "Multiplet",
    "LLLR",
    "LSEL",
    "LSIF",
    "LSIFwC", 
    "DSKL", 
    "BARR",
    "Logistic",
    "NGA-LSEL",
    "CombinedMargin",
    "BinaryCrossentropy",
    "SigmoidFocalCrossentropy",
]

class CategoricalXentLoss():
    """
    Softmax cross-entropy on logits with optional per-sample re-weighting.
    RNN-incompatible.
    # Remark
    Cost-sensitive loss weighting is supported,
    following the class-balanced loss:
    [Cui, Yin, et al.
    "Class-balanced loss based on effective number of samples."
    Proceedings of the IEEE/CVF Conference
    on Computer Vision and Pattern Recognition. 2019.]
    (https://arxiv.org/abs/1901.05555).
    # Args (constructor)
    classwise_sample_sizes: Forwarded to CostWeightGenerator
        together with beta.
    beta: None or a float larger than 0.
        Larger beta leads to more discriminative weights.
        If beta = 1, then weights are simply
        the inverse class frequencies (1 / N_k,
        where N_k is the sample size of class k).
        If beta = -1, every sample gets weight 1
        (non-cost-sensitive learning).
    label_smoothing: Forwarded to
        tf.keras.losses.CategoricalCrossentropy.
    """
    def __init__(self,
        classwise_sample_sizes=None, beta=-1, label_smoothing=0):
        self.beta = beta
        # Per-sample losses are requested (Reduction.NONE);
        # the mean is taken in __call__.
        xent = tf.keras.losses.CategoricalCrossentropy(
            from_logits=True,
            reduction=tf.keras.losses.Reduction.NONE,
            name='categorical_crossentropy',
            label_smoothing=label_smoothing)
        self.loss_function = xent
        self.cwg = CostWeightGenerator(classwise_sample_sizes, beta)

    def __call__(self, labels, logits):
        """
        # Args
        labels: A tf.int32 Tensor with shape (batch,). Non-one-hot labels.
        logits: A tf.float32 Tensor with shape (batch, num classes). Logits.
        # Returns
        loss: A float scalar Tensor.
        """
        depth = logits.shape[1]

        # beta == -1 switches re-weighting off (constant weight 1);
        # otherwise the weights come from the CostWeightGenerator.
        if self.beta != -1:
            sample_weights = self.cwg(labels=labels) # shape = (batch,)
        else:
            sample_weights = tf.constant(1., dtype=tf.float32)

        onehot = tf.one_hot(indices=labels, depth=depth, axis=1)
            # (batch, num classes)
        per_sample = self.loss_function(
            y_true=onehot,
            y_pred=logits,
            sample_weight=sample_weights) # shape = (batch,)

        return tf.reduce_mean(per_sample) # scalar


class BinaryXentLoss():
    """
    Element-wise binary cross-entropy over the foreground classes.
    RNN-incompatible.
    Background class is assumed to exist in the dataset.
    Background class is defined to be the LAST class.
    E.g., If total num of classes = 18 (0,1,...,17),
            the background class is 17.
    However, logits and labels_oh do not include the background class.
    Therefore, logits.shape[1] = labels_oh.shape[1] = num all classes - 1.
    # Args
    balance_fgbg: A bool. Loss reweighting for foreground-background
        balancing.
    balance_fgclasses: A bool. Loss reweighting for foreground
        classes.
    labels_oh: A tf.int32 Tensor with shape (batch, num fg classes).
        One-hot labels. An all-zero row is treated as a background sample.
    logits: A tf.float32 Tensor with shape (batch, num fg classes).
        Logits.
    classwise_sample_sizes: A list of ints. Len = num_classes INCLUDING
        the background class = logits.shape[-1] + 1.
    beta: A float larger than 0.
        Larger beta leads to
        more discriminative weights.
        - If beta = 1, then weights are simply
            the inverse class frequencies (1 / N_k,
            where N_k is the sample size of class k).
        - If beta = -1, no re-weighting is applied (all weights are 1).
    label_smoothing: Forwarded to tf.keras.losses.BinaryCrossentropy.
    # Returns
    loss: A float scalar Tensor.
    # Remark
    Cost-sensitive loss weighting is supported,
    following the class-balanced loss:
    [Cui, Yin, et al.
    "Class-balanced loss based on effective number of samples."
    Proceedings of the IEEE/CVF Conference
    on Computer Vision and Pattern Recognition. 2019.]
    (https://arxiv.org/abs/1901.05555).
    """
    def __init__(self,
        balance_fgbg,
        balance_fgclasses,
        classwise_sample_sizes=None, beta=-1, label_smoothing=0):
        self.balance_fgbg = balance_fgbg
        self.balance_fgclasses = balance_fgclasses
        self.beta = beta
        self.loss_function = tf.keras.losses.BinaryCrossentropy(
            from_logits=True,
            reduction=tf.keras.losses.Reduction.NONE,
            name='binary_crossentropy',
            label_smoothing=label_smoothing)

        # Two-class generator: all foreground samples pooled vs. background.
        fg_sizes = classwise_sample_sizes[:-1]
        bg_size = classwise_sample_sizes[-1]
        self.cwg_fgbg = CostWeightGenerator(
            [int(sum(fg_sizes)), bg_size], beta)
        # Generator over the foreground classes only.
        self.cwg_fgclass = CostWeightGenerator(
            classwise_sample_sizes[:-1], beta)

        # Largest batch size seen so far; used to cache the stacked
        # foreground-class weights in __call__.
        self.max_batch_size = 0
        self.num_classes_all = len(classwise_sample_sizes)

    def __call__(self, labels_oh, logits):
        # Only (batch, num fg classes) inputs are handled.
        if len(logits.shape) > 2:
            msg = "BCE loss for time series data not currentrly supported."
            raise NotImplementedError(msg)

        num_fg = self.num_classes_all - 1
        assert num_fg == labels_oh.shape[1]
        assert num_fg == logits.shape[1]

        def to_class_major_column(x):
            # (batch, num fg classes) -> (num fg classes * batch, 1),
            # ordered class by class.
            return tf.reshape(tf.transpose(x, [1, 0]), [-1, 1])

        # Rebuild the cached weight stack only when a larger batch shows up.
        batch = logits.shape[0]
        if batch > self.max_batch_size:
            self.max_batch_size = batch
            self.weights_fgclasses_org = tf.stack(
                [self.cwg_fgclass()] * self.max_batch_size, axis=0)
                # (max_batch_size, num fg classes)

        # Foreground/background weights
        if self.balance_fgbg and self.beta != -1:
            # 1 for background samples (all-zero label rows), else 0.
            is_background = tf.cast(
                tf.reduce_all(tf.equal(labels_oh, 0), axis=1),
                tf.int32) # (batch,)
            per_sample = self.cwg_fgbg(labels=is_background) # (batch,)
            weights_fgbg = tf.tile(per_sample, [logits.shape[1]])
                # (num fg classes * batch,)
        else:
            weights_fgbg = tf.constant(1., dtype=tf.float32)

        # Foreground-class weights
        if self.balance_fgclasses and self.beta != -1:
            stacked = self.weights_fgclasses_org[:batch]
                # (batch, num fg classes)
            weights_fgclasses = tf.reshape(
                tf.transpose(stacked, [1, 0]), [-1])
                # (num fg classes * batch,)
        else:
            weights_fgclasses = tf.constant(1., dtype=tf.float32)

        combined = weights_fgbg * weights_fgclasses

        # Each (class, sample) pair is scored as one binary problem.
        elementwise = self.loss_function(
            y_true=to_class_major_column(labels_oh),
            y_pred=to_class_major_column(logits),
            sample_weight=combined)

        return tf.reduce_mean(elementwise) # scalar


class SigmoidFocalXentLoss():
    """
    Sigmoid focal cross-entropy over the foreground classes.
    RNN-compatible.
    Background class is assumed to exist in the dataset.
    Background class is defined to be the LAST class.
    E.g., If total num of classes = 18 (0,1,...,17), 
        the background class is 17.
    However, logits and labels_oh do not include the background class.
    Therefore, logits.shape[-1] = labels_oh.shape[-1] = num all classes - 1.
    # Args
    balance_fgbg: A bool. Loss reweighting for foreground-background
        balancing.
    balance_fgclasses: A bool. Loss reweighting for foreground
        classes.
    classwise_sample_sizes: A list of ints. Len = num_classes INCLUDING
        the background class = logits.shape[-1] + 1.
    beta: A float larger than 0. 
        Larger beta leads to 
        more discriminative weights.
        - If beta = 1, then weights are simply 
            the inverse class frequencies (1 / N_k, 
            where N_k is the sample size of class k).
        - If beta = -1, no re-weighting is applied (all weights are 1).
    alpha, gamma: Forwarded to tfa.losses.SigmoidFocalCrossEntropy.
    # Returns
    loss: A float scalar Tensor.
    # Remark 
    Cost-sensitive loss weighting is supported, 
    following the class-balanced loss:
    [Cui, Yin, et al. 
    "Class-balanced loss based on effective number of samples." 
    Proceedings of the IEEE/CVF Conference 
    on Computer Vision and Pattern Recognition. 2019.]
    (https://arxiv.org/abs/1901.05555).
    """
    def __init__(self, 
        balance_fgbg,
        balance_fgclasses,
        classwise_sample_sizes=None, beta=-1, alpha=0.25, gamma=2.0):
        self.balance_fgbg = balance_fgbg
        self.balance_fgclasses = balance_fgclasses
        self.beta = beta
        # Per-element losses are requested (Reduction.NONE);
        # the mean is taken in __call__.
        self.loss_function = tfa.losses.SigmoidFocalCrossEntropy(
            from_logits=True, 
            reduction=tf.keras.losses.Reduction.NONE,
            name='sigmoid_focal_crossentropy',
            alpha=alpha,
            gamma=gamma)
        self.cwg_fgbg = CostWeightGenerator(
            [int(sum(classwise_sample_sizes[:-1])), 
                classwise_sample_sizes[-1]], 
            beta) # Two classes: pooled foreground vs. background
        self.cwg_fgclass = CostWeightGenerator(
            classwise_sample_sizes[:-1], 
            beta) # Num fg classes
        # Largest (flattened) batch size seen so far; used to cache the
        # stacked foreground-class weights in __call__.
        self.max_batch_size = 0
        self.num_classes_all = len(classwise_sample_sizes)

    def __call__(self, labels_oh, logits):
        """
        # Args
        labels_oh: One-hot labels with shape 
            (batch, num fg classes) 
            (non-temporal data, or time series data with
            videowise annotations)
            or (batch, duration, num fg classes) 
            (time series data with framewise annotations; 
            only handled when logits are 3-dimensional, too). 
            An all-zero label row is treated as a background sample.
        logits: Logits with shape 
            (batch, num fg classes) (non-temporal data)
            or (batch, duration, num fg classes) (time series data).
        # Return
        loss: A scalar.
        # Raises
        ValueError: If logits are 3-dimensional and labels_oh is
            neither 2- nor 3-dimensional.
        """
        # The class axis is the LAST axis; for time series inputs
        # axis 1 is the duration, so shape[1] must not be used here.
        assert self.num_classes_all - 1 == labels_oh.shape[-1]
        assert self.num_classes_all - 1 == logits.shape[-1]

        # RNN-compatible: fold the time axis into the batch axis
        if len(logits.shape) == 3:
            duration = logits.shape[1]

            # Reshape logits
            logits = tf.transpose(logits, [1, 0, 2])
                # (duration, batch, num fg classes)
            logits = tf.reshape(logits, [-1, logits.shape[2]])
                # (duration * batch, num fg classes)

            # Reshape labels_oh.
            # The rank is len(labels_oh.shape); len(labels_oh) would be
            # the size of the first (batch) axis.
            labels_rank = len(labels_oh.shape)
            if labels_rank == 3: # framewise annotation
                labels_oh = tf.transpose(labels_oh, [1, 0, 2])
                labels_oh = tf.reshape(labels_oh, [-1, labels_oh.shape[2]])
                    # (duration * batch, num fg classes)

            elif labels_rank == 2: # videowise annotation
                # Repeat the video label for every frame (time-major,
                # matching the layout of the reshaped logits).
                labels_oh = tf.tile(labels_oh, [duration, 1])
                    # (duration * batch, num fg classes)

            else:
                msg = "len(labels_oh.shape) must be 2 or 3. Got {}".\
                    format(len(labels_oh.shape))
                raise ValueError(msg)

        # Preprocessing: rebuild the cached weight stack only when
        # a larger (flattened) batch shows up.
        tmp_batch_size = logits.shape[0]
        if self.max_batch_size < tmp_batch_size:
            self.max_batch_size = tmp_batch_size
            self.weights_fgclasses_org = tf.stack(
                [self.cwg_fgclass()] * self.max_batch_size, axis=0)
                # (max_batch_size, num fg classes)

        # Calc loss weights (loss reweighting)
        if self.balance_fgbg and self.beta != -1:
            tmp_labels = tf.cast(
                tf.reduce_all(tf.equal(labels_oh, 0), axis=1), 
                tf.int32) # (tmp_batch, )
                # If a sample in the batch comes from 
                # background class, tmp_label = 1.
            weights_fgbg = self.cwg_fgbg(labels=tmp_labels) 
                # (tmp_batch, )
            weights_fgbg = tf.tile(weights_fgbg, [logits.shape[1]]) 
                # (num fg classes * tmp_batch, )
        else:
            weights_fgbg = tf.constant(1., dtype=tf.float32)

        if self.balance_fgclasses and self.beta != -1:
            weights_fgclasses = self.weights_fgclasses_org[:tmp_batch_size]
                # (tmp_batch, num fg classes, )
            weights_fgclasses = tf.transpose(weights_fgclasses, [1, 0])
                # (num fg classes, tmp_batch)
            weights_fgclasses = tf.reshape(weights_fgclasses, [-1])
                # (num fg classes * tmp_batch, )
        else:
            weights_fgclasses = tf.constant(1., dtype=tf.float32)

        weights = weights_fgbg * weights_fgclasses 
            # (num fg classes * tmp_batch, ) or a scalar

        # Calc loss: every (class, sample) pair is one binary problem,
        # laid out class by class to match the weights above.
        y_true = tf.reshape(
            tf.transpose(labels_oh, [1, 0]), 
            [-1, 1])
        y_pred = tf.reshape(
            tf.transpose(logits, [1, 0]),
            [-1, 1])
        loss = self.loss_function(
            y_true=y_true, 
            y_pred=y_pred,
            sample_weight=weights
            ) # shape=(num fg classes * tmp_batch,)
        loss = tf.reduce_mean(loss) # scalar

        return loss


class SphericalEmbeddingConstraint():
    """ Spherical Embedding Constraint (SEC) [NeurIPS2020]
        https://arxiv.org/abs/2011.02785
        Penalizes the variance of the squared feature norms.
        RNN-compatible.
    """
    def __call__(self, feature):
        """
        # Args
        feature: A tf.float32 Tensor with shape (..., feat dims).
        # Returns
        sec: A scalar Tensor. Mean squared deviation of the (scaled)
            squared norms from their mean.
        mu: A scalar Tensor. Mean (scaled) squared norm;
            no gradient flows through it.
        """
        dim = feature.shape[-1]
        flat = tf.reshape(feature, [-1, dim]) # (N, feat dims)

        # Squared norm per vector, divided by feat dims
        # to avoid the curse of dimensionality
        # (no such division in the original paper).
        scaled_sq = (flat ** 2) / dim
        sq_norm = tf.reduce_sum(scaled_sq, axis=-1) # (N,)

        # Batch mean, treated as a constant target.
        mu = tf.stop_gradient(
            tf.reduce_mean(sq_norm, axis=0, keepdims=True)) # (1,)

        deviation = sq_norm - mu
        return tf.reduce_mean(deviation ** 2), mu[0] # scalars


class SphericalEmbeddingConstraint4Weight():
    """ Spherical Embedding Constraint (SEC) [NeurIPS2020]
        https://arxiv.org/abs/2011.02785
        is applied to the last Dense layer.
        RNN-compatible.
    # Note
    The scale of weights may vary significantly,
    depending on the initialization.
    """
    def __call__(self, weight):
        """
        # Args
        weight: A Tensor with shape (feat dims, num classes).
        # Returns
        sec: A scalar Tensor. Mean squared deviation of the classwise
            squared norms from their mean.
        mu: A scalar Tensor. Mean classwise squared norm;
            no gradient flows through it.
        """
        # Squared norm of each class column.
        # Unlike SphericalEmbeddingConstraint, there is deliberately
        # no division by feat dims here: with it, sec can be
        # around 1e-12.
        sq_norm = tf.reduce_sum(weight ** 2, axis=0) # (num classes,)

        # Mean over classes, treated as a constant target.
        mu = tf.stop_gradient(
            tf.reduce_mean(sq_norm, axis=0, keepdims=True)) # (1,)

        deviation = sq_norm - mu
        return tf.reduce_mean(deviation ** 2), mu[0] # scalars


class MultipletLoss():
    def __init__(self,
        classwise_sample_sizes=None, beta=None, label_smoothing=0):
        """
        RNN-compatible. CNN-incompatible.
        # Args
        beta: A float larger than 0. Larger beta leads to 
            more discriminative weights.
            If beta = 1, then weights are simply 
            the inverse class frequencies (1 / N_k, 
            where N_k is the sample size of class k).
            If beta = -1, no class-dependent re-weighting is applied.
            This is useful for non-cost-sensitive learning.
        classwise_sample_sizes: A list of integers. 
            The length is equal to the number of classes.
        label_smoothing: Forwarded to 
            tf.keras.losses.CategoricalCrossentropy.
        """
        self.classwise_sample_sizes = classwise_sample_sizes
        self.beta = beta
        # Per-sample losses are requested (Reduction.NONE);
        # the mean is taken in __call__.
        self.loss_function = tf.keras.losses.CategoricalCrossentropy(
            from_logits=True, 
            reduction=tf.keras.losses.Reduction.NONE,
            name='categorical_crossentropy',
            label_smoothing=label_smoothing)
        self.cwg = CostWeightGenerator(classwise_sample_sizes, beta)


    def __call__(self, logits_slice, labels_slice):
        """ 
        Multiplet loss for density estimation of time-series data.
        Cost-sensitive loss weighting is supported, 
        following the class-balanced loss:
        [Cui, Yin, et al. 
        "Class-balanced loss based on effective number of samples." 
        Proceedings of the IEEE/CVF Conference 
        on Computer Vision and Pattern Recognition. 2019.]
        (https://arxiv.org/abs/1901.05555).
        # Args
        logits_slice: An logit Tensor with shape 
            ((effective) batch size, order of SPRT + 1, num classes). 
            This is the output of LSTMModel.call(inputs, training).
        labels_slice: An integer label Tensor with shape 
            ((effective) batch size,) (videowise annotation)
            or ((effective) batch size, order of SPRT + 1) 
            (framewise annotation). 
        # Return
        multiplet: A scalar Tensor. Mean of the weighted
            per-frame cross-entropy losses.
        # Raises
        ValueError: If labels_slice is neither 1- nor 2-dimensional.
        """
        effbs, order_tandem, num_classes = logits_slice.shape
        order_tandem -= 1

        # Flatten logits time-major: row t * B + b is frame t of sample b
        logits = tf.transpose(logits_slice, [1, 0, 2]) # (T, B, C)
        logits = tf.reshape(logits, [-1, num_classes]) # (T*B, C)

        # Bring the labels into the same time-major (T*B,) layout
        if len(labels_slice.shape) == 1: # videowise annotation
            labels = tf.tile(labels_slice, [order_tandem + 1,]) # (T*B,)
        elif len(labels_slice.shape) == 2: # framewise annotation 
            # (B, T) -> (T, B) -> (T*B,), so that every label lines up
            # with its logit row. Reshaping the integer labels directly
            # to (-1, num_classes) would neither one-hot encode them
            # nor match the time-major order of the logits.
            labels = tf.transpose(labels_slice, [1, 0]) # (T, B)
            labels = tf.reshape(labels, [-1]) # (T*B,)
        else:
            msg = "len(labels_slice.shape) must be 1 or 2. Got {}.".\
                format(len(labels_slice.shape))
            raise ValueError(msg)
        labels = tf.one_hot(
            indices=labels, depth=num_classes, axis=1) # (T*B, C)

        # Re-weighting
        if self.beta == -1:
            # Uniform weight 1 / (B * T).
            # NOTE(review): reduce_mean below divides by T*B once more;
            # kept as is to preserve the scale of the loss.
            weights = tf.cast(1. / effbs / (order_tandem + 1), dtype=tf.float32)

        else:
            # The one-hot labels are handed to the generator,
            # as in the videowise case.
            weights = self.cwg(labels=labels) # shape = (T*B,)

        # Calc multiplet losses
        multiplet = self.loss_function(
            y_true=labels, 
            y_pred=logits,
            sample_weight=weights) # shape=(T*B,)
        multiplet = tf.reduce_mean(multiplet) # scalar
            # A scalar averaged wrt. B*T

        return multiplet


class CombinedMarginLoss():
    """ Combined-margin softmax loss (SphereFace, ArcFace and CosFace margins).
    Ref: https://github.com/auroua/InsightFace_TF/blob/master/losses/face_losses.py
    RNN-compatible.
    # Margins
    logits of correct classes = s (cos(a theta + m) - b)
    logits of the others = s cos(theta)
    """
    def __init__(self, classwise_sample_sizes=None, beta=-1, 
        margin_a=1.35, margin_m=0.5, margin_b=0.4, s=64):
        """
        # Args
        classwise_sample_sizes: A list of ints. Its length is used as
            the number of classes; it is also forwarded to
            CostWeightGenerator together with beta.
        beta: If beta = -1, no class-dependent re-weighting is applied;
            otherwise per-sample weights come from CostWeightGenerator.
        margin_a: s (cos(a theta + m) - b). SphereFace. Defaults to 1.35.
        margin_m: s (cos(a theta + m) - b). ArcFace. Defaults to 0.5.
        margin_b: s (cos(a theta + m) - b). CosFace. Defaults to 0.4.
        s: A scalar. Scale of the logits. Defaults to 64.
        """
        # Error handling
        assert margin_a > 0
        assert margin_b >= 0
        assert margin_m >= 0
        assert s > 0

        # Initialize
        self.beta = beta
        self.cwg = CostWeightGenerator(classwise_sample_sizes, beta)
        self.num_classes = len(classwise_sample_sizes)
        self.margin_a = margin_a
        self.margin_m = margin_m
        self.margin_b = margin_b
        self.s = s
        self.XentLoss = tf.keras.losses.CategoricalCrossentropy(
            from_logits=True, 
            reduction=tf.keras.losses.Reduction.NONE) # (T*B or B, )

    def CombinedMarginLoss_func(self, 
        embedding, labels, weightmtx, sample_weight):
        """
        Compute the margin-adjusted softmax cross-entropy 
        for already flattened inputs.
        # Args
        embedding: Bottleneck features. Shape = 
            (batch or batch*duration, feat_dims)
        labels: Integer (tf.int32) labels with shape 
            (batch or batch*duration, )
        weightmtx: A Tensor with shape (feat_dims, num_classes).
        sample_weight: Per-sample weights forwarded to the
            cross-entropy loss.
        # Returns
        loss: A scalar Tensor.
        """
        # Calc cos
        embedding_unit = tf.nn.l2_normalize(embedding, axis=1) # (T*B or B, D)
        weights_unit = tf.nn.l2_normalize(weightmtx, axis=0) # (D, C)
        cos_t = tf.matmul(embedding_unit, weights_unit) # (T*B or B, C)
        
        # Extract cos of correct classes
        ordinal = tf.constant(list(range(0, embedding.shape[0])), tf.int32) # (T*B or B, )
        ordinal_y = tf.stack([ordinal, labels], axis=1) # (T*B or B, 2)
            # [[0, labels[0]], [1, labels[1]], ..., [T*B-1 or B-1, labels[T*B-1 or B-1]]]
        zy = self.s * cos_t 
            # (T*B or B, C)
            # zy = scaled cosines (logits) without any margin
        sel_cos_t = tf.gather_nd(zy, ordinal_y) 
            # (T*B or B, )
            # sel_cos_t = scaled cosines of correct classes
            
        # Combine margins
        if self.margin_a != 1.0 or self.margin_m != 0.0 or self.margin_b != 0.0:
            if self.margin_a == 1.0 and self.margin_m == 0.0:
                # Pure additive cosine margin; no arccos needed.
                new_zy = sel_cos_t - self.s * self.margin_b # CosFace
                    # (T*B or B, )
            else:
                cos_value = sel_cos_t / self.s
                # Clip to guard arccos against round-off outside [-1, 1].
                t = tf.math.acos(tf.clip_by_value(cos_value, -1., 1.))
                    # (T*B or B, )

                if self.margin_a != 1.0:
                    t = t * self.margin_a # SphereFace
                if self.margin_m > 0.0:
                    t = t + self.margin_m # ArcFace

                # cos(t) starts to increase again beyond pi. 
                # Use the monotonically decreasing extension 
                # psi(t) = (-1)^k cos(t) - 2k (SphereFace) with k = 1:
                # -cos(t) - 2 = cos(t - pi) - 2 for pi < t <= 2 pi. 
                # It is continuous at t = pi (both sides give -1);
                # an offset of -1 would jump from -1 up to 0 there.
                cond = tf.greater(t, np.pi)
                body = tf.where(
                    cond, 
                    tf.math.cos(t - np.pi) - 2, 
                    tf.math.cos(t)) # (T*B or B, )
                
                if self.margin_b > 0.0:
                    body = body - self.margin_b # CosFace

                new_zy = body * self.s
                    # (T*B or B, )

            # Calc logits
            # updated_logits can be directly sent to the softmax loss.
            updated_logits = tf.add( # result shape = (T*B or B, C)
                zy, 
                tf.scatter_nd(
                    ordinal_y, 
                    tf.subtract(new_zy, sel_cos_t), # to cancel zy of correct classes
                    zy.shape)) 
                    # tf.scatter_nd of 
                    # (T*B or B, 2), 
                    # (T*B or B), 
                    # (T*B or B, C) 
                    # gives (T*B or B, C)
        else:
            # All margins are neutral: plain scaled cosine logits.
            updated_logits = zy

        loss = self.XentLoss(
            y_true=tf.one_hot(
                indices=labels, depth=self.num_classes, axis=1),
            y_pred=updated_logits,
            sample_weight=sample_weight) # shape=(T*B or B,)

        loss = tf.reduce_mean(loss) # scalar

        return loss

    def __call__(self, labels, bottleneck, weightmtx):
        """
        # Args
        labels: Integer labels. Shape = (batch, ) 
            (non-temporal data or time series data with
            videowise annotations) 
            or (batch, duration) 
            (time series data with framewise annotations).
        bottleneck: Bottleneck features. 
            Shape = (batch, feat_dims) or (batch, duration, feat_dims)
        weightmtx: A Tensor with shape (feat_dims, num_classes).
        # Returns
        loss: A scalar Tensor.
        # Raises
        ValueError: If bottleneck is 3-dimensional and labels is
            neither 1- nor 2-dimensional.
        """
        # Reshape if necessary: flatten time series time-major
        if len(bottleneck.shape) == 3:
            feat_dims = bottleneck.shape[-1]
            duration = bottleneck.shape[1]
            bottleneck = tf.transpose(bottleneck, (1, 0, 2))
            bottleneck = tf.reshape(bottleneck, (-1, feat_dims)) # (T*B, D)
            if len(labels.shape) == 1:
                # Repeat the video label for every frame.
                labels = tf.tile(labels, (duration,)) # (T*B, )
            elif len(labels.shape) == 2:
                labels = tf.transpose(labels, [1, 0])
                labels = tf.reshape(labels, [-1]) # (T*B,)
            else:
                msg = "len(labels.shape) must be 1 or 2. Got {}".\
                    format(len(labels.shape))
                raise ValueError(msg)
            
        # Re-weighting
        if self.beta == -1:
            # Uniform weight 1 / (number of samples).
            # NOTE(review): CombinedMarginLoss_func takes the mean over
            # the samples once more; kept as is to preserve the scale
            # of the loss.
            sample_weight = tf.cast(1. / labels.shape[0], dtype=tf.float32) 
                # scalar
        else:
            sample_weight = self.cwg(labels=labels)
                # (T*B or B,)
        
        # Calc loss
        loss = self.CombinedMarginLoss_func(
            embedding=bottleneck,
            labels=labels,
            weightmtx=weightmtx,
            sample_weight=sample_weight) # scalar

        return loss


class LossManager():
    """
    Builds the loss functions requested by name and, when called, 
    evaluates all of them and sums them into a single objective.
    """
    # Losses that are evaluated on the output of sequential_concat and 
    # return (loss, loss1, loss2). The value is the suffix of the 
    # self.loss1_* / self.loss2_* attributes that keep loss1 and loss2.
    _SEQUENTIAL_LOSS_ATTR_SUFFIX = {
        "LLLR": "LLLR",
        "LSEL": "LSEL",
        "LSIF": "LSIF",
        "LSIFwC": "LSIFwC",
        "DSKL": "DSKL",
        "BARR": "BARR",
        "Logistic": "Logistic",
        "NGA-LSEL": "NGALSEL",
    }
    # Sequential losses whose call additionally takes kwargs["multLam"].
    _LOSSES_TAKING_MULTLAM = ("LSIFwC", "BARR")

    def __init__(self, names_loss, list_kwargs, 
        classwise_sample_sizes=None, 
        framewise_annotation=False):
        """
        # Note
        The orders of names_loss and list_kwargs 
        must match each other. They are paired with zip, so 
        surplus entries of the longer list are silently ignored.
        # Args
        names_loss: A list of str. Each must be in AVAIL_LOSSES.
        list_kwargs: A list of dictionaries. 
            Used for the arguments of each loss. 
            Every dictionary must have the key "prefactor", 
            which is read in __call__.
        classwise_sample_sizes: None (no loss re-weighting) 
            or a list of integers. The length must be equal to 
            the number of classes. Passed as is to the loss classes.
        framewise_annotation: A bool for time series data. 
            Only forwarded to sequential_concat for the LSEL loss.
        # Raises
        ValueError: If a name is not in AVAIL_LOSSES, or if it is in 
            AVAIL_LOSSES but no constructor is registered below.
        # Remark
        When you implement yet another loss function,
        1. Add a name to AVAIL_LOSSES (and info_available_list.yaml).
        2. Add an elif sentence under # Get loss functions in __init__. 
        3. Add an elif under # CASE sentence in __call__ 
           (or an entry in _SEQUENTIAL_LOSS_ATTR_SUFFIX if the loss 
           follows the sequential_concat pattern). 
        (4. Add self.X = None to __init__ if you defined a new class variable X.)
        """
        # Initialization
        self.names_loss = names_loss
        self.list_kwargs = list_kwargs
        self.classwise_sample_sizes = classwise_sample_sizes
        #self.num_classes = len(classwise_sample_sizes)
        self.framewise_annotation = framewise_annotation

        # Auxiliary outputs of the losses, overwritten in __call__.
        self.mu_sec = None
        self.mu_sec4w = None
        self.loss1_LLLR = None
        self.loss2_LLLR = None
        self.loss1_LSEL = None
        self.loss2_LSEL = None
        self.loss1_LSIF = None
        self.loss2_LSIF = None
        self.loss1_LSIFwC = None
        self.loss2_LSIFwC = None
        self.loss1_DSKL = None
        self.loss2_DSKL = None
        self.loss1_BARR = None
        self.loss2_BARR = None
        self.loss1_Logistic = None
        self.loss2_Logistic = None
        self.loss1_NGALSEL = None
        self.loss2_NGALSEL = None

        # Assert
        self.ls_available = AVAIL_LOSSES 
        for v in names_loss:
            if v not in self.ls_available:
                raise ValueError("{} not in the list of available losses: {}".\
                    format(v, self.ls_available))

        # Get loss functions
        self.loss_functions = []
        for iter_name_loss, iter_kwargs in zip(names_loss, list_kwargs): 
            if iter_name_loss == "CategoricalCrossentropy":
                lossf = CategoricalXentLoss(
                    classwise_sample_sizes=classwise_sample_sizes,
                    beta=iter_kwargs["beta"],
                    label_smoothing=iter_kwargs["label_smoothing"]) 

            elif iter_name_loss == "SphericalEmbeddingConstraint":
                lossf = SphericalEmbeddingConstraint()

            elif iter_name_loss == "SphericalEmbeddingConstraint4Weight":
                lossf = SphericalEmbeddingConstraint4Weight()

            elif iter_name_loss == "CombinedMargin":
                lossf = CombinedMarginLoss(
                    classwise_sample_sizes=classwise_sample_sizes,
                    beta=iter_kwargs["beta"],
                    margin_a=iter_kwargs["margin_a"],
                    margin_m=iter_kwargs["margin_m"],
                    margin_b=iter_kwargs["margin_b"],
                    s=iter_kwargs["s"],
                    )

            elif iter_name_loss == "BinaryCrossentropy":
                lossf = BinaryXentLoss(
                    balance_fgbg=iter_kwargs["balance_fgbg"],
                    balance_fgclasses=iter_kwargs["balance_fgclasses"],
                    classwise_sample_sizes=classwise_sample_sizes,
                    beta=iter_kwargs["beta"],
                    label_smoothing=iter_kwargs["label_smoothing"]) 

            elif iter_name_loss == "SigmoidFocalCrossentropy":
                lossf = SigmoidFocalXentLoss(
                    balance_fgbg=iter_kwargs["balance_fgbg"],
                    balance_fgclasses=iter_kwargs["balance_fgclasses"],
                    classwise_sample_sizes=classwise_sample_sizes,
                    beta=iter_kwargs["beta"],
                    alpha=iter_kwargs["alpha"],
                    gamma=iter_kwargs["gamma"])

            else:
                # The name passed the availability check above, so it is 
                # listed in AVAIL_LOSSES but has no constructor here.
                raise ValueError(
                    "{} is listed as an available loss, but "
                    "LossManager.__init__ has no constructor for it. "
                    "Available losses: {}".\
                    format(iter_name_loss, self.ls_available))

            self.loss_functions.append(lossf)

    @staticmethod
    def _assert_rank2_logits(logits):
        """ Assert that logits is 2-dimensional, i.e., not a time series. """
        msg = "Logits must have shape (batch, num_classes)."+\
            "Got logits.shape = {}".format(logits.shape)
        assert len(logits.shape) == 2, msg

    def _sequential_loss(self, name_loss, loss_function, kwargs, 
        logits, labels):
        """
        Evaluate one loss of the sequential_concat family and store 
        its two auxiliary outputs in self.loss1_* and self.loss2_*.
        # Args
        name_loss: A str. A key of _SEQUENTIAL_LOSS_ATTR_SUFFIX.
        loss_function: The callable built for name_loss. 
            Must return (loss, loss1, loss2).
        kwargs: The dictionary of list_kwargs paired with name_loss.
        logits, labels: As given to __call__.
        # Returns
        loss: First output of loss_function.
        """
        # Only LSEL forwards the framewise annotation flag;
        # the other losses rely on the default of sequential_concat.
        concat_kwargs = {}
        if name_loss == "LSEL":
            concat_kwargs["framewise_annotation"] =\
                self.framewise_annotation
        logits_concat, labels_concat = sequential_concat(
            logits, labels, self.duration, **concat_kwargs)

        call_kwargs = {}
        if name_loss in self._LOSSES_TAKING_MULTLAM:
            call_kwargs["multLam"] = kwargs["multLam"]
        loss, loss1, loss2 = loss_function(
            logits_concat, labels_concat, **call_kwargs)

        suffix = self._SEQUENTIAL_LOSS_ATTR_SUFFIX[name_loss]
        setattr(self, "loss1_" + suffix, loss1)
        setattr(self, "loss2_" + suffix, loss2)
        return loss

    def __call__(self, labels, logits, bottleneck, model):
        """
        Evaluate every registered loss and accumulate 
        total_loss = sum_i prefactor_i * loss_i.
        # Args
        labels: One-hot or non-one-hot labels.
            A tf.int32 Tensor with shape 
            (batch,) for images
            ((batch, num_classes(num foreground classes 
            for BinaryCrossentropy and SigmoidFocalCrossentropy))),
            (batch, ) for sequences, and
            (batch * (duration - order_tandem), ) 
            for sequences with sequential_slice.
            For framewise annotations, shape = 
            (batch, duration,) or (batch, duration, num fg classes)
        logits: Logits.
            A tf.float32 Tensor with shape 
            (batch, num_classes(num foreground classes for 
            BinaryCrossentropy and SigmoidFocalCrossentropy))
            for images,
            (batch, duration, num_classes)
            for sequences, and
            (batch * (duration - order_tandem), order_tandem + 1, num_classes)
            for sequences with sequential_slice. 
        bottleneck: Bottleneck feature vectors. 
            A tf.float32 Tensor with shape 
            (batch, feat_dims) for images,
            (batch, duration, feat_dims) for sequences, and
            (batch * (duration - order_tandem), order_tandem + 1, feat_dims)
            for sequences with sequential_slice.            
        model: tf.keras.Model. model.layers[-1].weights[0] is used 
            as the weight matrix for SphericalEmbeddingConstraint4Weight 
            and CombinedMargin.
        # Returns
        total_loss: A scalar Tensor. The objective.
        losses: A list of losses, in the order of names_loss. 
            This is the same list object as self.losses.
        # Raises
        AssertionError: If logits is not 2-dimensional for 
            CategoricalCrossentropy, BinaryCrossentropy, or 
            SigmoidFocalCrossentropy.
        NotImplementedError: For BinaryCombinedMargin.
        ValueError: If a loss name has no case below.
        """
        # Initialization
        self.total_loss = 0.
        self.losses = []
        # Second axis of logits; handed to sequential_concat as duration.
        self.duration = logits.shape[1]

        # Calc losses
        for iter_name_loss, iter_loss_function, iter_kwargs\
            in zip(self.names_loss, self.loss_functions, self.list_kwargs):
            # CASE sentence
            if iter_name_loss == "CategoricalCrossentropy": 
                # RNN-incompatible as of 20210915
                # Use Multiplet loss with order_tandem=-1 or None
                self._assert_rank2_logits(logits)
                loss = iter_loss_function(
                    labels=labels, 
                    logits=logits)

            elif iter_name_loss in (
                "BinaryCrossentropy", "SigmoidFocalCrossentropy"): 
                # RNN-incompatible as of 20211019
                self._assert_rank2_logits(logits)
                loss = iter_loss_function(
                    labels_oh=labels, 
                    logits=logits)

            elif iter_name_loss == "SphericalEmbeddingConstraint": 
                # RNN-compatible as of 20210915
                loss, self.mu_sec = iter_loss_function(bottleneck)

            elif iter_name_loss == "SphericalEmbeddingConstraint4Weight": 
                # RNN-compatible as of 20210915
                weight = model.layers[-1].weights[0]
                loss, self.mu_sec4w = iter_loss_function(weight)

            elif iter_name_loss == "Multiplet":
                loss = iter_loss_function(logits, labels)
                    # logits.shape=(B, T, num_classes)
                    # Would be logits_slice.
                    # labels.shape=(B(, T))
                    # Would be labels_slice.

            elif iter_name_loss in self._SEQUENTIAL_LOSS_ATTR_SUFFIX:
                # LLLR, LSEL (framewise annotation-compatible), LSIF, 
                # LSIFwC, DSKL, BARR, Logistic, and NGA-LSEL.
                loss = self._sequential_loss(
                    iter_name_loss, iter_loss_function, iter_kwargs,
                    logits, labels)

            elif iter_name_loss == "CombinedMargin":
                weightmtx = model.layers[-1].weights[0] # RNN-compatible
                loss = iter_loss_function(
                    labels=labels,
                    bottleneck=bottleneck, 
                    weightmtx=weightmtx)

            elif iter_name_loss == "BinaryCombinedMargin": 
                raise NotImplementedError()

            else:
                raise ValueError("{} not in the list of available losses {}.".\
                    format(iter_name_loss, self.ls_available))

            # Common to all losses: record and add the weighted loss.
            self.losses.append(loss)
            self.total_loss += iter_kwargs["prefactor"] * loss

        return self.total_loss, self.losses


def calc_loss(loss_manager, model, x, y, flag_weight_decay, training, 
    calc_grad, prefactor_weight_decay): 
    """
    Run the model on a batch, evaluate the objective, and 
    optionally differentiate it w.r.t. the trainable variables.
    # Args
    loss_manager: A LossManager instance.
    model: A tf.keras.Model object. Called as model(x, training) 
        and must return (logits, bottleneck).
    x: A Tensor with shape=(batch, H, W, C).
    y: A Tensor with shape (batch,).
    flag_weight_decay: A bool. Whether to calculate weight decay.
    training: A bool. Training flag for BatchNorm and Dropout etc.
    calc_grad: A bool. Whether to calculate gradients.
    prefactor_weight_decay: Multiplier of each weight decay term 
        when it is added to total_loss.
    # Returns
    total_loss: The objective given by loss_manager, plus the 
        weighted weight decay terms if flag_weight_decay.
    losses: The list of losses given by loss_manager. 
        If flag_weight_decay, tf.nn.l2_loss of each trainable 
        variable is appended (one entry per variable).
    logits: A tf.float32 Tensor with shape (batch, num classes).
    bottleneck: A tf.float32 Tensor with shape (batch, feat dims).
    gradients: Gradients of total_loss w.r.t. 
        model.trainable_variables, or None if not calc_grad.
    """
    def _forward():
        # Logits and bottleneck feature vectors:
        # (batch, num_classes) and (batch, bottleneck feat dims)
        out_logits, out_bottleneck = model(x, training)

        objective, loss_list = loss_manager(
            labels=y, 
            logits=out_logits, 
            bottleneck=out_bottleneck, 
            model=model)

        # Weight decay: L2 penalty of every trainable variable.
        if flag_weight_decay:
            for var in model.trainable_variables:
                penalty = tf.nn.l2_loss(var)
                loss_list.append(penalty)
                objective += prefactor_weight_decay * penalty

        return objective, loss_list, out_logits, out_bottleneck

    if not calc_grad:
        # No gradient calculation
        total_loss, losses, logits, bottleneck = _forward()
        return total_loss, losses, logits, bottleneck, None

    # The whole forward pass, incl. weight decay, is recorded on the tape.
    with tf.GradientTape() as tape:
        total_loss, losses, logits, bottleneck = _forward()
    gradients = tape.gradient(total_loss, model.trainable_variables)

    return total_loss, losses, logits, bottleneck, gradients


def calc_loss_rnns(loss_manager, model, x, y, flag_weight_decay,
    calc_grad, prefactor_weight_decay):
    """
    Run a sequence model on a batch, evaluate the objective, and 
    optionally differentiate it w.r.t. the trainable variables.
    # Args
    loss_manager: A LossManager instance.
    model: A tf.keras.Model object. Called as model(x, y) 
        and must return (logits, bottleneck, labels).
    x: A Tensor with shape=(batch, T, D). 
    y: A Tensor with shape (batch,).
    flag_weight_decay: A bool. Whether to calculate weight decay.
    calc_grad: A bool. Whether to calculate gradients.
    prefactor_weight_decay: Multiplier of each weight decay term 
        when it is added to total_loss.
    # Returns
    total_loss: The objective given by loss_manager, plus the 
        weighted weight decay terms if flag_weight_decay.
    losses: The list of losses given by loss_manager. 
        If flag_weight_decay, tf.nn.l2_loss of each trainable 
        variable is appended (one entry per variable).
    logits: First output of model. 
        (batch, duration, num_classes), or 
        (batch*(duration-order_tandem), order_tandem+1, num_classes) 
        if order_tandem != None.
    bottleneck: Second output of model.
    y: Third output of model, i.e., the labels actually 
        used for the losses.
    gradients: Gradients of total_loss w.r.t. 
        model.trainable_variables, or None if not calc_grad.
    """
    def _forward():
        # The model returns the labels as well; those returned labels, 
        # not the input y, are used for the losses.
        out_logits, out_bottleneck, out_labels = model(x, y)

        objective, loss_list = loss_manager(
            labels=out_labels, 
            logits=out_logits, 
            bottleneck=out_bottleneck, 
            model=model)

        # Weight decay: L2 penalty of every trainable variable.
        if flag_weight_decay:
            for var in model.trainable_variables:
                penalty = tf.nn.l2_loss(var)
                loss_list.append(penalty)
                objective += prefactor_weight_decay * penalty

        return objective, loss_list, out_logits, out_bottleneck, out_labels

    if not calc_grad:
        # No gradient calculation
        total_loss, losses, logits, bottleneck, y = _forward()
        return total_loss, losses, logits, bottleneck, y, None

    # The whole forward pass, incl. weight decay, is recorded on the tape.
    with tf.GradientTape() as tape:
        total_loss, losses, logits, bottleneck, y = _forward()
    gradients = tape.gradient(total_loss, model.trainable_variables)

    return total_loss, losses, logits, bottleneck, y, gradients


def calc_loss_movrnn(loss_manager, model, x, y, flag_weight_decay, training, 
    calc_grad, prefactor_weight_decay): 
    """
    Convert framewise labels with movinet_label_generator, run the 
    model on a batch of chunks, evaluate the objective, and optionally 
    differentiate it w.r.t. the trainable variables.
    # Args
    loss_manager: A LossManager instance.
    model: A tf.keras.Model object. Called as model(x, training) 
        and must return (logits, bottleneck).
    x: A Tensor with shape (batch_size, num_frames, H, W, C).
      First axis represents chunks.
      Second axis represents num of frames in a chunk. 
    y: A Tensor with shape (batch_size, num_frames).
      First axis represents chunks.
      Second axis represents num of frames in a chunk. 
    flag_weight_decay: A bool. Whether to calculate weight decay.
    training: A bool. Training flag for BatchNorm and Dropout etc.
    calc_grad: A bool. Whether to calculate gradients.
    prefactor_weight_decay: Multiplier of each weight decay term 
        when it is added to total_loss.
    # Returns
    total_loss: The objective given by loss_manager, plus the 
        weighted weight decay terms if flag_weight_decay.
    losses: The list of losses given by loss_manager. 
        If flag_weight_decay, tf.nn.l2_loss of each trainable 
        variable is appended (one entry per variable).
    logits: A tf.float32 Tensor with shape (batch, num classes).
    bottleneck: A tf.float32 Tensor with shape (batch, feat dims).
    gradients: Gradients of total_loss w.r.t. 
        model.trainable_variables, or None if not calc_grad.
    """
    # Labels for the losses are derived from y once, before the 
    # forward pass and outside the gradient tape.
    labels = movinet_label_generator(y)
        # (batch_size,) 

    def _forward():
        out_logits, out_bottleneck = model(x, training)
            # (batch_size, num_classes), (batch_size, units)
            # batch_size axis corresponds to the output of each RNN cell (temporal axis); 
            # i.e., Each component in the batch_size axis represents a single timestamp.

        objective, loss_list = loss_manager(
            labels=labels, 
            logits=out_logits, 
            bottleneck=out_bottleneck, 
            model=model)

        # Weight decay: L2 penalty of every trainable variable.
        if flag_weight_decay:
            for var in model.trainable_variables:
                penalty = tf.nn.l2_loss(var)
                loss_list.append(penalty)
                objective += prefactor_weight_decay * penalty

        return objective, loss_list, out_logits, out_bottleneck

    if not calc_grad:
        # No gradient calculation
        total_loss, losses, logits, bottleneck = _forward()
        return total_loss, losses, logits, bottleneck, None

    # The whole forward pass, incl. weight decay, is recorded on the tape.
    with tf.GradientTape() as tape:
        total_loss, losses, logits, bottleneck = _forward()
    gradients = tape.gradient(total_loss, model.trainable_variables)

    return total_loss, losses, logits, bottleneck, gradients
