#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""Description
Chris Burges, Tal Shaked, Erin Renshaw, Ari Lazier, Matt Deeds, Nicole Hamilton, and Greg Hullender. 2005.
Learning to rank using gradient descent. In Proceedings of the 22nd ICML. 89–96.
"""

import torch
import torch.nn.functional as F

from ptranking.base.ranker import NeuralRanker
from ptranking.ltr_adhoc.eval.parameter import ModelParameter

class RankNet(NeuralRanker):
    '''
    RankNet trained with a pairwise cross-entropy loss. When teacher predictions are supplied,
    the data loss is mixed with a distillation loss towards the teacher's pairwise preferences.

    Chris Burges, Tal Shaked, Erin Renshaw, Ari Lazier, Matt Deeds, Nicole Hamilton, and Greg Hullender. 2005.
    Learning to rank using gradient descent. In Proceedings of the 22nd ICML. 89–96.
    '''
    def __init__(self, sf_para_dict=None, model_para_dict=None, gpu=False, device=None, lr=None):
        '''
        :param sf_para_dict: forwarded unchanged to the NeuralRanker constructor
        :param model_para_dict: dict providing 'sigma' (scale applied to pairwise score differences);
                                when None, sigma falls back to 1.0
        :param gpu: forwarded unchanged to the NeuralRanker constructor
        :param device: forwarded unchanged to the NeuralRanker constructor
        :param lr: forwarded unchanged to the NeuralRanker constructor
        '''
        super(RankNet, self).__init__(id='RankNet', sf_para_dict=sf_para_dict, gpu=gpu, device=device, lr=lr)
        # Use the configured sigma instead of discarding it; 1.0 is only the fallback
        # for the declared default model_para_dict=None (which previously raised TypeError).
        self.sigma = 1.0 if model_para_dict is None else model_para_dict['sigma']

    def inner_train(self, batch_pred, batch_label, mask, teacher_pred=None, **kwargs):
        '''
        Perform one optimisation step on a batch and report the loss components.

        :param batch_pred: [batch, ranking_size] each row represents the relevance predictions for documents within a ltr_adhoc
        :param batch_label: [batch, ranking_size] each row represents the standard relevance grades for documents within a ltr_adhoc
        :param mask: per-document weights, expanded to a pairwise mask via an outer product
                     (presumably 1 for real documents and 0 for padding -- confirm against the caller)
        :param teacher_pred: optional teacher scores, combined pairwise exactly like batch_pred;
                             when None, the data loss also stands in for the teacher term
        :param kwargs: must contain "pri_dict", whose json_dict["mix_alpha"][0] is the weight of the
                       data loss (the teacher term is weighted by 1 - mix_alpha)
        :return: tuple (total_loss, data_loss, teacher_loss) of Python numbers; teacher_loss is 0
                 when teacher_pred is None
        '''
        mix_alpha = kwargs["pri_dict"].json_dict["mix_alpha"][0]

        # pairwise differences w.r.t. predictions, i.e., s_i - s_j; only pairs with i < j are kept as logits
        batch_s_ij = torch.unsqueeze(batch_pred, dim=2) - torch.unsqueeze(batch_pred, dim=1)
        batch_logits = self.sigma * torch.triu(batch_s_ij, diagonal=1)

        # pairwise differences w.r.t. standard labels, clamped to [-1, 1]
        # (for integer grades this yields S_{ij} in {-1, 0, 1})
        batch_std_diffs = torch.unsqueeze(batch_label, dim=2) - torch.unsqueeze(batch_label, dim=1)
        batch_Sij = torch.clamp(batch_std_diffs, min=-1.0, max=1.0)
        batch_std_p_ij = 0.5 * (1.0 + batch_Sij)  # target probability that i should be ranked above j

        # pairwise mask: a pair counts only to the extent both of its documents are unmasked
        batch_square_mask = torch.unsqueeze(mask, dim=2) * torch.unsqueeze(mask, dim=1)

        # per-query weight equal to the largest label in the row; a row whose maximum label is 0
        # therefore contributes nothing to the data loss
        qg_mask = torch.max(batch_label, dim=1)[0].unsqueeze(-1).unsqueeze(-1)

        # NOTE(review): entries zeroed by triu (diagonal and lower triangle) still carry their weight,
        # so each adds a constant weight * log(2) to data_loss. They receive no gradient, so training
        # is unaffected, but the reported value is offset by that constant.
        data_loss = F.binary_cross_entropy_with_logits(
            input=batch_logits,
            target=torch.triu(batch_std_p_ij, diagonal=1),
            weight=qg_mask * batch_square_mask,
            reduction='sum',
        )
        _rec_data_loss = data_loss.item()

        teacher_loss = data_loss
        if teacher_pred is not None:
            batch_teacher_pred_diffs = torch.unsqueeze(teacher_pred, dim=2) - torch.unsqueeze(teacher_pred, dim=1)
            batch_teacher_diffs = torch.sigmoid(self.sigma * batch_teacher_pred_diffs)
            teacher_target = torch.triu(batch_teacher_diffs, diagonal=1)

            # cross-entropy of the student against the teacher's pairwise probabilities ...
            student_vs_teacher = F.binary_cross_entropy_with_logits(
                input=batch_logits,
                target=teacher_target,
                weight=batch_square_mask,
                reduction='sum',
            )
            # ... minus the teacher's own cross-entropy, so the term is zero when the student's
            # pairwise differences match the teacher's
            teacher_vs_teacher = F.binary_cross_entropy_with_logits(
                input=self.sigma * torch.triu(batch_teacher_pred_diffs, diagonal=1),
                target=teacher_target,
                weight=batch_square_mask,
                reduction='sum',
            )
            teacher_loss = student_vs_teacher - teacher_vs_teacher
            _rec_teacher_loss = teacher_loss.item()
        else:
            _rec_teacher_loss = 0

        batch_loss = mix_alpha * data_loss + (1 - mix_alpha) * teacher_loss
        _rec_total_loss = batch_loss.item()

        # sanity check kept from the original implementation: the mixed loss is expected to be non-negative
        assert _rec_total_loss >= 0, batch_loss

        self.optimizer.zero_grad()
        batch_loss.backward()
        self.optimizer.step()

        return (_rec_total_loss, _rec_data_loss, _rec_teacher_loss)

###### Parameter of RankNet ######

class RankNetParameter(ModelParameter):
    ''' Parameter settings (default, string identifier, grid search) for RankNet '''
    def __init__(self, debug=False, para_json=None):
        super(RankNetParameter, self).__init__(model_id='RankNet', para_json=para_json)
        # debug switches grid_search to a small built-in set of sigma values
        self.debug = debug

    def default_para_dict(self):
        """
        Build, store and return the default RankNet setting (sigma = 1.0)
        :return: the dict stored as self.ranknet_para_dict
        """
        default_setting = dict(model_id=self.model_id, sigma=1.0)
        self.ranknet_para_dict = default_setting
        return default_setting

    def to_para_string(self, log=False, given_para_dict=None):
        """
        Render the sigma setting as a short identifier string
        :param log: if True, join name and value with ':' (log style); otherwise with '_'
        :param given_para_dict: explicit setting to render (e.g. the maximum setting w.r.t. grid-search);
                                when None, the most recently stored setting is used
        :return: e.g. 'Sigma_1' or 'Sigma:1'
        """
        if given_para_dict is None:
            para_dict = self.ranknet_para_dict
        else:
            para_dict = given_para_dict

        separator = ':' if log else '_'
        sigma_text = '{:,g}'.format(para_dict['sigma'])
        return separator.join(('Sigma', sigma_text))

    def grid_search(self):
        """
        Yield one parameter dict per candidate sigma, storing each as the current setting
        """
        if self.use_json:
            # candidates come from the json configuration
            sigma_candidates = self.json_dict['sigma']
        elif self.debug:
            sigma_candidates = [5.0, 1.0]
        else:
            sigma_candidates = [1.0]  # 1.0, 10.0, 50.0, 100.0

        for sigma in sigma_candidates:
            current_setting = dict(model_id=self.model_id, sigma=sigma)
            self.ranknet_para_dict = current_setting
            yield current_setting
