#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""Description
Viewing the prediction of relevance as a conventional regression problem.
"""

import torch
import torch.nn.functional as F

from ptranking.base.ranker import NeuralRanker


def rankMSE_loss_function(batch_pred=None, batch_label=None, TL_AF=None, learn_from_data=True):
	'''
	Pointwise ranking loss computed with binary cross-entropy on logits.

	Note: despite the "MSE" in the name, the loss used here is
	F.binary_cross_entropy_with_logits, not a mean square error.

	:param batch_pred: tensor of predicted scores, treated as logits.
	:param batch_label: tensor of targets with the same shape as batch_pred.
		When learn_from_data is True these are used directly as BCE targets;
		otherwise they are treated as teacher logits and turned into soft
		targets with a sigmoid.
	:param TL_AF: when equal to 'S' or 'ST' (and learn_from_data is True) the
		predictions are multiplied by the maximum label value.
		Presumably an identifier of the scoring function's tail activation - TODO confirm.
	:param learn_from_data: True to learn from labels, False to learn from
		teacher scores (distillation).
	:return: a scalar loss tensor, or a one-element zero tensor (float32, on
		the device of batch_pred, not attached to the autograd graph) when the
		maximum of batch_label is 0.
	'''
	max_label = torch.max(batch_label)

	# For query groups with only negative samples, the loss is defined as 0.
	# The zero is created on the device of the predictions instead of being
	# hard-wired to CUDA, so CPU-only runs work as well.
	if max_label == 0:
		return torch.zeros(1, dtype=torch.float32, device=batch_pred.device)

	if learn_from_data:
		if 'S' == TL_AF or 'ST' == TL_AF:  # map to the same relevance level
			batch_pred = batch_pred * max_label
		return F.binary_cross_entropy_with_logits(batch_pred, batch_label)

	# Tensor.detach() is not in-place: rebind the name so that no gradient
	# flows back into the teacher scores.
	batch_label = batch_label.detach()
	soft_target = torch.sigmoid(batch_label)

	# Cross-entropy(student, teacher) minus the teacher's own entropy, i.e. the
	# KL divergence between the teacher and student Bernoulli distributions.
	student_term = F.binary_cross_entropy_with_logits(batch_pred, soft_target)
	teacher_term = F.binary_cross_entropy_with_logits(batch_label, soft_target)
	return student_term - teacher_term


class RankMSE(NeuralRanker):
	'''
	Pointwise neural ranker trained with rankMSE_loss_function, optionally
	mixing the label-based loss with a loss against teacher predictions.
	'''

	def __init__(self, sf_para_dict=None, gpu=False, device=None):
		super(RankMSE, self).__init__(id='RankMSE', sf_para_dict=sf_para_dict, gpu=gpu, device=device)
		# Cached once; forwarded to the loss function on every training call.
		self.TL_AF = self.get_tl_af()

	def inner_train(self, batch_pred, batch_label, teacher_pred=None, **kwargs):
		'''
		Compute the mixed loss for one batch, back-propagate it, and optionally
		apply an optimizer step.

		:param batch_pred: relevance predictions for the batch
			(presumably [batch, ranking_size] - TODO confirm against caller)
		:param batch_label: standard relevance labels for the batch
			(presumably [batch, ranking_size] - TODO confirm against caller)
		:param teacher_pred: optional teacher predictions; when None the data
			loss is used in place of the teacher loss
		:param kwargs: must provide 'pri_dict' (whose json_dict["mix_alpha"][0]
			is the weight of the data loss) and 'take_gradient' (whether to
			step and reset the optimizer after this batch)
		:return: tuple (mixed loss, data loss, teacher loss) as Python numbers;
			the teacher loss is 0 when teacher_pred is None
		'''
		mix_alpha = kwargs["pri_dict"].json_dict["mix_alpha"][0]

		data_loss = rankMSE_loss_function(batch_pred, batch_label, TL_AF=self.TL_AF)
		_rec_data_loss = data_loss.item()

		if teacher_pred is None:
			# Without a teacher the mixture below reduces to the data loss.
			teacher_loss = data_loss
			_rec_teacher_loss = 0
		else:
			teacher_loss = rankMSE_loss_function(
				batch_pred,
				teacher_pred,
				TL_AF=self.TL_AF,
				learn_from_data=False,
			)
			_rec_teacher_loss = teacher_loss.item()

		mixed_loss = mix_alpha * data_loss + (1 - mix_alpha) * teacher_loss
		_rec_batch_loss = mixed_loss.item()

		assert _rec_batch_loss >= 0

		# Gradients are accumulated on every call with a positive loss; the
		# optimizer only steps (and is reset) when the caller requests it.
		if mixed_loss > 0:
			mixed_loss.backward()
		if kwargs['take_gradient']:
			self.optimizer.step()
			self.optimizer.zero_grad()

		return (_rec_batch_loss, _rec_data_loss, _rec_teacher_loss)
