import numpy as np
import torch
import torch.nn as nn
from torch.nn.functional import relu

import lib.utils as utils
from lib.plotting import *
from lib.encoder_decoder import *
from lib.likelihood_eval import *

from torch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions.normal import Normal
from torch.nn.modules.rnn import GRUCell, LSTMCell, RNNCellBase

from torch.distributions.normal import Normal
from torch.distributions import Independent
from torch.nn.parameter import Parameter


class Baseline(nn.Module):
	"""Loss and metric plumbing shared by non-variational sequence models.

	This class defines no forward pass of its own: ``compute_all_losses``
	calls ``self.get_reconstruction``, which is not defined here and has to
	be provided by a subclass.
	"""

	def __init__(self, input_dim, latent_dim, device,
		obsrv_std = 0.1, use_binary_classif = False,
		classif_per_tp = False,
		use_poisson_proc = False,
		linear_classifier = False,
		n_labels = 1):
		"""Store the configuration and optionally build the classifier head.

		Args:
			input_dim: stored as ``self.input_dim``.
			latent_dim: stored as ``self.latent_dim``; also sets the width of
				the classifier input.
			device: device the ``obsrv_std`` tensor is moved to.
			obsrv_std: observation noise std, kept as a 1-element tensor and
				handed to ``data_log_density``.
			use_binary_classif: build ``self.classifier`` and train on the
				classification loss.
			classif_per_tp: stored only; not used inside this class.
			use_poisson_proc: doubles the classifier input width and adds the
				Poisson-process term to the loss.
			linear_classifier: selects the shallower classifier head.
			n_labels: output width of the classifier.
		"""
		super().__init__()

		self.input_dim = input_dim
		self.latent_dim = latent_dim

		self.obsrv_std = torch.Tensor([obsrv_std]).to(device)
		self.device = device

		self.use_binary_classif = use_binary_classif
		self.classif_per_tp = classif_per_tp
		self.use_poisson_proc = use_poisson_proc
		self.linear_classifier = linear_classifier

		# With a Poisson process the classifier input is twice as wide.
		classifier_in = 2 * latent_dim if use_poisson_proc else latent_dim

		if use_binary_classif:
			if linear_classifier:
				# NOTE: despite the flag name this head still has one hidden
				# Tanh layer; it is only shallower than the default head.
				layers = [
					nn.Linear(classifier_in, 300),
					nn.Tanh(),
					nn.Linear(300, n_labels),
				]
			else:
				layers = [
					nn.Linear(classifier_in, 300),
					nn.Tanh(),
					nn.Linear(300, 300),
					nn.Tanh(),
					nn.Linear(300, n_labels),
				]
			self.classifier = nn.Sequential(*layers)
			utils.init_network_weights(self.classifier)


	def get_likelihood(self, truth, pred_y, mask = None):
		"""Log-density of ``truth`` under ``pred_y``, averaged over axis 0.

		Per the original comments, ``pred_y`` is
		[n_traj_samples, n_traj, n_tp, n_dim] and ``truth`` is
		[n_traj, n_tp, n_dim]. The mask, when given, is tiled along a new
		leading axis ``pred_y.size(0)`` times before being passed on.

		Returns the output of ``data_log_density``, transposed and averaged
		over its first axis (expected shape [n_traj] -- this depends on the
		layout ``data_log_density`` returns; TODO confirm).
		"""
		tiled_mask = mask
		if tiled_mask is not None:
			tiled_mask = tiled_mask.repeat(pred_y.size(0), 1, 1, 1)

		density = data_log_density(pred_y, truth,
			obsrv_std = self.obsrv_std, mask = tiled_mask)

		# Swap the two axes, then average over the (new) leading one.
		return density.permute(1,0).mean(0)


	def get_mse(self, truth, pred_y, mask = None):
		"""Mean of ``compute_mse(pred_y, truth, mask)`` over all its elements.

		The mask, when given, is tiled ``pred_y.size(0)`` times along a new
		leading axis, exactly as in ``get_likelihood``.
		"""
		tiled_mask = mask
		if tiled_mask is not None:
			tiled_mask = tiled_mask.repeat(pred_y.size(0), 1, 1, 1)

		squared_error = compute_mse(pred_y, truth, mask = tiled_mask)
		return torch.mean(squared_error)


	def compute_all_losses(self, batch_dict,
		n_tp_to_sample = None, n_traj_samples = 1, kl_coef = 1.):
		"""Run the model on a batch and collect the loss plus diagnostics.

		Reads these keys from ``batch_dict``: "tp_to_predict",
		"observed_data", "observed_tp", "observed_mask", "data_to_predict",
		"mask_predicted_data" and "labels" (which may be None).

		``n_tp_to_sample`` and ``kl_coef`` are accepted but not used here.

		Returns a dict with "loss" (the only entry that is not detached),
		"likelihood", "mse", "pois_likelihood", "ce_loss", the constant
		zeros "kl", "kl_first_p", "std_first_p", "n_calls", and -- when
		labels are present and classification is on -- "label_predictions".

		Raises:
			Exception: if the classification loss evaluates to NaN.
		"""
		# Condition on the observed points, predict at the requested ones.
		pred_x, info = self.get_reconstruction(batch_dict["tp_to_predict"],
			batch_dict["observed_data"], batch_dict["observed_tp"],
			mask = batch_dict["observed_mask"], n_traj_samples = n_traj_samples)

		target = batch_dict["data_to_predict"]
		target_mask = batch_dict["mask_predicted_data"]

		likelihood = self.get_likelihood(target, pred_x, mask = target_mask)
		mse = self.get_mse(target, pred_x, mask = target_mask)

		target_device = get_device(target)

		# Classification loss; stays a zero placeholder unless labels are
		# available and the classifier is enabled.
		ce_loss = torch.Tensor([0.]).to(target_device)
		with_labels = (batch_dict["labels"] is not None) and self.use_binary_classif

		if with_labels:
			labels = batch_dict["labels"]
			single_label = (labels.size(-1) == 1) or (labels.dim() == 1)
			if single_label:
				ce_loss = compute_CE_loss(info["label_predictions"], labels)
			else:
				ce_loss = compute_multiclass_CE_loss(
					info["label_predictions"], labels, mask = target_mask)

			if torch.isnan(ce_loss):
				print("label pred")
				print(info["label_predictions"])
				print("labels")
				print(labels)
				raise Exception("CE loss is Nan!")

		# Base objective: negative mean log-likelihood.
		loss = - torch.mean(likelihood)

		if self.use_poisson_proc:
			pois_log_likelihood = compute_poisson_proc_likelihood(
				target, pred_x, info, mask = target_mask)
			# Average over axis 1 before weighting it into the loss.
			pois_log_likelihood = torch.mean(pois_log_likelihood, 1)
			loss = loss - 0.1 * pois_log_likelihood
		else:
			pois_log_likelihood = torch.Tensor([0.]).to(target_device)

		if self.use_binary_classif:
			# The classification loss replaces (not augments) the objective.
			loss = ce_loss

		results = {
			"loss": torch.mean(loss),
			"likelihood": torch.mean(likelihood).detach(),
			"mse": torch.mean(mse).detach(),
			"pois_likelihood": torch.mean(pois_log_likelihood).detach(),
			"ce_loss": torch.mean(ce_loss).detach(),
			"kl": 0.,
			"kl_first_p": 0.,
			"std_first_p": 0.,
			"n_calls": 0.,
		}

		if with_labels:
			results["label_predictions"] = info["label_predictions"].detach()
		return results



def compute_KL_full_covmatrix(q_mus, q_cov_matrices, prior_mu, prior_cov):
	"""KL(q || prior) for a batch of full-covariance Gaussians.

	Args:
		q_mus: means of the approximate posteriors; the leading axis indexes
			the distributions and the last axis is the event dimension.
		q_cov_matrices: one covariance matrix per row of ``q_mus``.
		prior_mu: 1-D mean of the single shared prior.
		prior_cov: 2-D covariance matrix of the shared prior.

	Returns:
		Tensor with one KL value per row of ``q_mus``.

	A small multiple of the identity (1e-4) is added to every covariance
	matrix before the distributions are built.
	"""
	# Imported here because this module never imports kl_divergence
	# explicitly; relying on a wildcard import to supply it is fragile.
	from torch.distributions import kl_divergence

	n_distr = q_mus.size(0)

	# Diagonal jitter, created directly on the device of the inputs.
	jitter = 1e-4 * torch.eye(prior_cov.size(0), device = q_mus.device)

	# Broadcast the single prior to one copy per posterior (views, no copy).
	prior_mu_rep = prior_mu.unsqueeze(0).expand(n_distr, -1)
	prior_cov_rep = (prior_cov + jitter).unsqueeze(0).expand(n_distr, -1, -1)

	prior = MultivariateNormal(prior_mu_rep, covariance_matrix = prior_cov_rep)
	approx_post = MultivariateNormal(q_mus, covariance_matrix = q_cov_matrices + jitter)

	return kl_divergence(approx_post, prior)



class VAE_Baseline(nn.Module):
	"""Loss and metric plumbing shared by variational (VAE-style) models.

	This class defines no encoder/decoder of its own: ``compute_all_losses``
	calls ``self.get_reconstruction``, which is not defined here and has to
	be provided by a subclass.
	"""

	def __init__(self, input_dim, latent_dim, 
		y0_prior, device,
		obsrv_std = None, 
		use_binary_classif = False,
		classif_per_tp = False,
		use_poisson_proc = False,
		linear_classifier = False,
		n_labels = 1):
		"""Store the configuration and optionally build the classifier head.

		Args:
			input_dim: stored as ``self.input_dim``.
			latent_dim: stored as ``self.latent_dim``; also sets the width of
				the classifier input.
			y0_prior: prior over the initial latent state; either a
				``MixtureOfGaussians`` or something ``kl_divergence`` accepts.
			device: device used for the default ``obsrv_std`` tensor.
			obsrv_std: observation noise std, stored as given; when None a
				1-element tensor holding 0.1 is created on ``device``.
			use_binary_classif: build ``self.classifier`` and train on the
				classification loss.
			classif_per_tp: stored only; not used inside this class.
			use_poisson_proc: doubles the classifier input width and adds the
				Poisson-process term to the loss.
			linear_classifier: use a single linear layer as classifier.
			n_labels: output width of the classifier.
		"""
		super(VAE_Baseline, self).__init__()
		
		self.input_dim = input_dim
		self.latent_dim = latent_dim
		self.device = device

		self.obsrv_std = obsrv_std
		if obsrv_std is None:
			self.obsrv_std = torch.Tensor([0.1]).to(device)

		self.y0_prior = y0_prior
		self.use_binary_classif = use_binary_classif
		self.classif_per_tp = classif_per_tp
		self.use_poisson_proc = use_poisson_proc
		self.linear_classifier = linear_classifier

		# With a Poisson process the classifier input is twice as wide.
		y0_dim = latent_dim
		if use_poisson_proc:
			y0_dim += latent_dim

		if use_binary_classif:
			if linear_classifier:
				self.classifier = nn.Sequential(
					nn.Linear(y0_dim, n_labels))
			else:
				self.classifier = nn.Sequential(
					nn.Linear(y0_dim, 300),
					nn.Sigmoid(),
					nn.Linear(300, 300),
					nn.Sigmoid(),
					nn.Linear(300, n_labels))
			utils.init_network_weights(self.classifier)


	def get_total_likelihood(self, truth, pred_y, mask = None):
		"""Log-density of ``truth`` under ``pred_y``, averaged over axis 1.

		Per the original comments, ``pred_y`` is
		[n_traj_samples, n_traj, n_tp, n_dim] and ``truth`` is
		[n_traj, n_tp, n_dim]. Both ``truth`` and the mask (when given) are
		tiled ``pred_y.size(0)`` times along a new leading axis.

		Returns the output of ``data_log_density``, transposed and averaged
		over its second axis (expected shape [n_traj_samples] -- this depends
		on the layout ``data_log_density`` returns; TODO confirm).
		"""
		# The unpacking doubles as a check that truth is 3-dimensional.
		n_traj, n_tp, n_dim = truth.size()

		n_traj_samples = pred_y.size(0)
		truth_repeated = truth.repeat(n_traj_samples, 1, 1, 1)
		
		if mask is not None:
			mask = mask.repeat(n_traj_samples, 1, 1, 1)
		log_density_data = data_log_density(pred_y, truth_repeated, 
			obsrv_std = self.obsrv_std, mask = mask)
		log_density_data = log_density_data.permute(1,0)

		# The density is deliberately not normalised by n_tp or n_dim.
		return torch.mean(log_density_data, 1)


	def get_mse(self, truth, pred_y, mask = None):
		"""Mean of ``compute_mse(pred_y, truth, mask)`` over all its elements.

		``truth`` and the mask (when given) are tiled ``pred_y.size(0)``
		times along a new leading axis before the call.
		"""
		# The unpacking doubles as a check that truth is 3-dimensional.
		n_traj, n_tp, n_dim = truth.size()

		n_traj_samples = pred_y.size(0)
		truth_repeated = truth.repeat(n_traj_samples, 1, 1, 1)
		
		if mask is not None:
			mask = mask.repeat(n_traj_samples, 1, 1, 1)

		squared_error = compute_mse(pred_y, truth_repeated, mask = mask)
		return torch.mean(squared_error)


	@staticmethod
	def _iwae_loss(rec_likelihood, kldiv_y0, kl_coef):
		"""Negative log-sum-exp over axis 0 of (log-likelihood - kl_coef * KL).

		The KL term is subtracted: the evidence lower bound is the
		reconstruction log-likelihood minus the KL to the prior, so a larger
		KL must increase the loss. If the log-sum-exp is NaN, the negative
		mean of the same quantity is returned instead.
		"""
		elbo = rec_likelihood - kl_coef * kldiv_y0
		loss = - torch.logsumexp(elbo, 0)
		if torch.isnan(loss):
			loss = - torch.mean(elbo, 0)
		return loss


	def compute_all_losses(self, batch_dict, n_traj_samples = 1, 
		kl_coef = 1.):
		"""Run the model on a batch and collect the loss plus diagnostics.

		Reads these keys from ``batch_dict``: "tp_to_predict",
		"observed_data", "observed_tp", "observed_mask", "data_to_predict",
		"mask_predicted_data" and "labels" (which may be None).

		Args:
			batch_dict: the batch, see above.
			n_traj_samples: forwarded to ``get_reconstruction``.
			kl_coef: weight of the KL term in the loss.

		Returns a dict with "loss" (the only tensor that is not detached),
		"likelihood", "mse", "pois_likelihood", "ce_loss", "kl_first_p",
		"std_first_p", "n_calls" (taken from ``info``), and -- when labels
		are present and classification is on -- "label_predictions".

		Raises:
			Exception: if the KL term or the classification loss is NaN.
		"""
		# Imported here because this module never imports kl_divergence
		# explicitly; relying on a wildcard import to supply it is fragile.
		from torch.distributions import kl_divergence

		# Condition on the observed points, predict at the requested ones.
		pred_y, info = self.get_reconstruction(batch_dict["tp_to_predict"], 
			batch_dict["observed_data"], batch_dict["observed_tp"], 
			mask = batch_dict["observed_mask"], n_traj_samples = n_traj_samples)

		fp_mu, fp_std, fp_enc = info["first_point"]
		# abs() guarantees a non-negative scale, so no further sign check
		# is needed before building the distribution.
		fp_std = fp_std.abs()
		fp_distr = Normal(fp_mu, fp_std)

		# NOTE(review): assumes MixtureOfGaussians.kl_div returns a KL with
		# the usual (non-negative) sign, like kl_divergence -- confirm.
		if isinstance(self.y0_prior, MixtureOfGaussians):
			kldiv_y0 = self.y0_prior.kl_div(fp_distr, fp_enc)
		else:
			kldiv_y0 = kl_divergence(fp_distr, self.y0_prior)

		if torch.isnan(kldiv_y0).any():
			print(fp_mu)
			print(fp_std)
			raise Exception("kldiv_y0 is Nan!")

		# Average over axes 1 and 2. Per the original comments the input is
		# [n_traj_samples or 1, n_traj, n_latent_dims].
		kldiv_y0 = torch.mean(kldiv_y0,(1,2))

		# Reconstruction quality at the points to predict.
		rec_likelihood = self.get_total_likelihood(
			batch_dict["data_to_predict"], pred_y,
			mask = batch_dict["mask_predicted_data"])

		mse = self.get_mse(
			batch_dict["data_to_predict"], pred_y,
			mask = batch_dict["mask_predicted_data"])

		pois_log_likelihood = torch.Tensor([0.]).to(get_device(batch_dict["data_to_predict"]))
		if self.use_poisson_proc:
			pois_log_likelihood = compute_poisson_proc_likelihood(
				batch_dict["data_to_predict"], pred_y, 
				info, mask = batch_dict["mask_predicted_data"])
			# Average over axis 1 before weighting it into the loss.
			pois_log_likelihood = torch.mean(pois_log_likelihood, 1)

		# Classification loss; stays a zero placeholder unless labels are
		# available and the classifier is enabled.
		device = get_device(batch_dict["data_to_predict"])
		ce_loss = torch.Tensor([0.]).to(device)
		if (batch_dict["labels"] is not None) and self.use_binary_classif:

			if (batch_dict["labels"].size(-1) == 1) or (len(batch_dict["labels"].size()) == 1):
				ce_loss = compute_CE_loss(
					info["label_predictions"], 
					batch_dict["labels"])
			else:
				ce_loss = compute_multiclass_CE_loss(
					info["label_predictions"], 
					batch_dict["labels"],
					mask = batch_dict["mask_predicted_data"])
			
			if torch.isnan(ce_loss):
				print("label pred")
				print(info["label_predictions"])
				print("labels")
				print( batch_dict["labels"])
				raise Exception("CE loss is Nan!")

		# IWAE-style loss; the KL term is subtracted inside the helper.
		loss = self._iwae_loss(rec_likelihood, kldiv_y0, kl_coef)
			
		if self.use_poisson_proc:
			loss = loss - 0.1 * pois_log_likelihood 

		if self.use_binary_classif:
			# The classification loss replaces (not augments) the objective.
			loss =  ce_loss

		results = {}
		results["loss"] = torch.mean(loss)
		results["likelihood"] = torch.mean(rec_likelihood).detach()
		results["mse"] = torch.mean(mse).detach()
		results["pois_likelihood"] = torch.mean(pois_log_likelihood).detach()
		results["ce_loss"] = torch.mean(ce_loss).detach()
		results["kl_first_p"] =  torch.mean(kldiv_y0).detach()
		results["std_first_p"] = torch.mean(fp_std).detach()
		results["n_calls"] = info["n_calls"]

		if batch_dict["labels"] is not None and self.use_binary_classif:
			results["label_predictions"] = info["label_predictions"].detach()

		return results



