import numpy as np
import torch
import torch.nn as nn
from torch.nn.functional import relu

import lib.utils as utils
from lib.plotting import *
from lib.encoder_decoder import *
from lib.likelihood_eval import *

from torch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions.normal import Normal
from torch.nn.modules.rnn import GRUCell, LSTMCell, RNNCellBase

from torch.distributions.normal import Normal
from torch.distributions import Independent
from torch.nn.parameter import Parameter
from lib.base_models import Baseline


class ODE_GRU(Baseline): #seq2seq model
	"""Sequence-to-sequence model built from an ODE-combine encoder, an ODE solver and a decoder.

	The observed sequence (concatenated with its mask) is run through
	``Encoder_y0_ode_combine``. The last latent state it returns is handed to
	``y0_diffeq_solver`` as the initial value for the requested time points, and
	the solver output is mapped by a ``Linear -> Tanh -> Linear`` decoder.
	"""

	def __init__(self, input_dim, latent_dim, device = torch.device("cpu"),
		y0_diffeq_solver = None, n_gru_units = 100, n_units = 100,
		concat_mask = False, obsrv_std = 0.1, use_binary_classif = False,
		classif_per_tp = False,
		use_poisson_proc = False, glob_dims = 0,
		n_labels = 1):
		"""Build the encoder and the decoder.

		Args:
			input_dim: number of features the decoder outputs.
			latent_dim: input width of the decoder; also the encoder latent
				size unless ``use_poisson_proc`` is set.
			device: device the encoder is moved to; also forwarded to ``Baseline``.
			y0_diffeq_solver: callable used both inside the encoder and to
				solve from the last hidden state in ``run_seq2seq``.
			n_gru_units: forwarded to ``Encoder_y0_ode_combine``.
			n_units: hidden width of the decoder.
			concat_mask: accepted for signature compatibility; not read by this class.
			obsrv_std, use_binary_classif, classif_per_tp, n_labels: forwarded
				to ``Baseline.__init__``.
			use_poisson_proc: if True the encoder latent size becomes
				``latent_dim * 2 + input_dim`` and the Poisson branch of
				``run_seq2seq`` is enabled.
			glob_dims: extra feature count added to ``input_dim`` when sizing
				the encoder input.
		"""
		Baseline.__init__(self, input_dim, latent_dim, device = device, 
			obsrv_std = obsrv_std, use_binary_classif = use_binary_classif,
			classif_per_tp = classif_per_tp,
			n_labels = n_labels)

		ode_combine_dim = latent_dim
		if use_poisson_proc:
			ode_combine_dim = latent_dim * 2 + input_dim

		self.ode_gru = Encoder_y0_ode_combine( 
			latent_dim = ode_combine_dim, 
			input_dim = (input_dim + glob_dims) * 2, # input and the mask
			y0_diffeq_solver = y0_diffeq_solver, 
			n_gru_units = n_gru_units, 
			device = device).to(device)

		self.y0_diffeq_solver = y0_diffeq_solver

		self.decoder = nn.Sequential(
		   nn.Linear(latent_dim, n_units),
		   nn.Tanh(),
		   nn.Linear(n_units, input_dim),)

		utils.init_network_weights(self.decoder)
		self.use_poisson_proc = use_poisson_proc

		self.glob_dims = glob_dims

	def get_reconstruction(self, time_steps_to_predict, data, truth_time_steps, 
		mask = None, n_traj_samples = None):
		"""Return ``run_seq2seq`` evaluated with the same arguments."""
		return self.run_seq2seq(time_steps_to_predict, data, truth_time_steps, 
			mask = mask, n_traj_samples = n_traj_samples)


	def run_seq2seq(self, time_steps_to_predict, data, truth_time_steps, 
		mask = None, n_traj_samples = None):
		"""Encode the observations, then solve and decode at ``time_steps_to_predict``.

		Args:
			time_steps_to_predict: time points passed to ``y0_diffeq_solver``;
				its first element is passed to the encoder as ``t0``.
			data: observations; concatenated with ``mask`` along the last axis.
			truth_time_steps: time points of ``data``, passed to the encoder.
			mask: observation mask; required.
			n_traj_samples: accepted for interface compatibility; not used here.

		Returns:
			``(outputs, extra_info)`` where ``outputs`` is the decoder output and
			``extra_info`` is a dict with the latent trajectory and, depending on
			the configuration, Poisson terms and label predictions.

		Raises:
			AssertionError: if ``mask`` is None, or if in Poisson mode any
				entry of the final ``int_lambda`` slice for the first
				sample/trajectory is not strictly positive.
		"""
		# The encoder input was sized for data plus mask, so the mask is mandatory.
		assert mask is not None, "ODE_GRU requires an observation mask"
		data_and_mask = torch.cat([data, mask], -1)

		_, _, latent_ys, _ = self.ode_gru.run_ode_combine(data_and_mask, truth_time_steps,
			t0 = time_steps_to_predict[0],
			save_latents = True, save_info = False)
		latent_ys = latent_ys.permute(0,2,1,3)
		last_hidden = latent_ys[:,:,-1,:]

		if self.use_poisson_proc:
			n_gp_samples, n_traj, _ = last_hidden.size()

			# append a vector of zeros to compute the integral of lambda;
			# new_zeros keeps the dtype and device of last_hidden.
			zeros = last_hidden.new_zeros([n_gp_samples, n_traj, self.input_dim])
			last_hidden_aug = torch.cat((last_hidden, zeros), -1)
		else:
			last_hidden_aug = last_hidden

		ode_sol, _ = self.y0_diffeq_solver(last_hidden_aug, time_steps_to_predict)
		# NOTE(review): in Poisson mode the decoder sees the raw solver output,
		# before the lambda components are split off below, whereas
		# ODE_GRU_rnn.get_reconstruction decodes after the split -- confirm
		# which order is intended.
		outputs = self.decoder(ode_sol)

		if self.use_poisson_proc:
			ode_func = self.ode_gru.y0_diffeq_solver.ode_func
			ode_sol, int_lambda = ode_func.get_y_int_lambda(ode_sol)
			ode_sol, log_lambda_y = ode_func.get_log_lambdas(ode_sol)

			# Measure the integral relative to its value at the first time point.
			int_lambda = int_lambda - int_lambda[:,:,0,:].unsqueeze(2)

			# Sanity check on the first sample/trajectory only.
			assert(torch.sum(int_lambda[0,0,-1,:] <= 0) == 0.)

		extra_info = {
			"gp_samples": None,
			"latent_traj": ode_sol,
			"n_calls": 0., 
			"first_point": (last_hidden, 0.0, last_hidden),
			"pred_mean_y0": outputs[0]}

		if self.use_poisson_proc:
			# integral of lambda from the last step of ODE Solver
			extra_info["int_lambda"] = int_lambda[:,:,-1,:]
			extra_info["log_lambda_y"] = log_lambda_y

		if self.use_binary_classif:
			extra_info["label_predictions"] = self.classifier(last_hidden).squeeze(-1)

		# outputs shape: [n_traj_samples, n_traj, n_tp, n_dims]
		return outputs, extra_info



class ODE_GRU_rnn(ODE_GRU):
	"""Variant of ``ODE_GRU`` that decodes the encoder's latent states directly.

	When the requested time points match the observed ones, the per-time-point
	latent states from the encoder are decoded and shifted by one step.
	Otherwise the call falls back to ``run_seq2seq``.
	"""

	def __init__(self, input_dim, latent_dim, device = torch.device("cpu"),
		y0_diffeq_solver = None, n_gru_units = 100,  n_units = 100,
		concat_mask = False, obsrv_std = 0.1, use_binary_classif = False,
		classif_per_tp = False,
		use_poisson_proc = False, glob_dims = 0,
		n_labels = 1):
		"""Forward every argument to ``ODE_GRU.__init__``.

		``n_gru_units`` is now forwarded as well; previously it was accepted
		but dropped, so the parent always used its default.
		"""
		ODE_GRU.__init__(self, input_dim = input_dim, device = device, 
			y0_diffeq_solver = y0_diffeq_solver, latent_dim = latent_dim,
			n_gru_units = n_gru_units,
			concat_mask = concat_mask,
			n_units = n_units,
			obsrv_std = obsrv_std, 
			use_binary_classif = use_binary_classif,
			classif_per_tp = classif_per_tp,
			use_poisson_proc = use_poisson_proc,
			glob_dims = glob_dims,
			n_labels = n_labels)

	def get_reconstruction(self, time_steps_to_predict, data, truth_time_steps, 
		mask = None, n_traj_samples = None):
		"""Reconstruct on the observed grid, or extrapolate via ``run_seq2seq``.

		Args:
			time_steps_to_predict: time points to produce outputs for.
			data: observations; ``data[:,0,:]`` is used as the first point
				when shifting the outputs.
			truth_time_steps: time points of ``data``.
			mask: observation mask; required.
			n_traj_samples: only forwarded to ``run_seq2seq``.

		Returns:
			``(outputs, extra_info)`` with the same keys as
			``ODE_GRU.run_seq2seq``. If ``classif_per_tp`` is set, label
			predictions are computed from every latent state, not only the last.

		Raises:
			AssertionError: if ``mask`` is None, or if in Poisson mode any
				entry of the final ``int_lambda`` slice for the first
				sample/trajectory is not strictly positive.
		"""
		# Compare the grids element-wise: a summed difference can cancel to
		# zero even when individual time points differ.
		same_grid = len(truth_time_steps) == len(time_steps_to_predict)
		same_grid = same_grid and not bool(torch.any(time_steps_to_predict != truth_time_steps))

		if not same_grid:
			return self.run_seq2seq(time_steps_to_predict, data, truth_time_steps, 
				mask = mask, n_traj_samples = n_traj_samples)

		# The encoder input was sized for data plus mask, so the mask is mandatory.
		assert mask is not None, "ODE_GRU_rnn requires an observation mask"
		data_and_mask = torch.cat([data, mask], -1)

		_, _, latent_ys, _ = self.ode_gru.run_ode_combine(data_and_mask, truth_time_steps,
			t0 = time_steps_to_predict[-1],
			save_latents = True, save_info = False)
		latent_ys = latent_ys.permute(0,2,1,3)
		last_hidden = latent_ys[:,:,-1,:]

		if self.use_poisson_proc:
			ode_func = self.ode_gru.y0_diffeq_solver.ode_func
			latent_ys, int_lambda = ode_func.get_y_int_lambda(latent_ys)
			latent_ys, log_lambda_y = ode_func.get_log_lambdas(latent_ys)

			# Measure the integral relative to its value at the first time point.
			int_lambda = int_lambda - int_lambda[:,:,0,:].unsqueeze(2)

			# Sanity check on the first sample/trajectory only.
			assert(torch.sum(int_lambda[0,0,-1,:] <= 0) == 0.)

		outputs = self.decoder(latent_ys)
		# Shift outputs for computing the loss -- we should compare the first output to the second data point, etc.
		first_point = data[:,0,:]
		if self.glob_dims > 0:
			# Drop the trailing glob_dims features before shifting.
			first_point = first_point[:,:-int(self.glob_dims)]
		outputs = utils.shift_outputs(outputs, first_point)

		extra_info = {
			"gp_samples": None,
			"latent_traj": latent_ys,
			"n_calls": 0., 
			"first_point": (latent_ys[:,:,-1,:], 0.0, latent_ys[:,:,-1,:]),
			"pred_mean_y0": outputs[0]}

		if self.use_poisson_proc:
			# integral of lambda from the last step of ODE Solver
			extra_info["int_lambda"] = int_lambda[:,:,-1,:]
			extra_info["log_lambda_y"] = log_lambda_y

		if self.use_binary_classif:
			if self.classif_per_tp:
				extra_info["label_predictions"] = self.classifier(latent_ys)
			else:
				extra_info["label_predictions"] = self.classifier(last_hidden).squeeze(-1)

		# outputs shape: [n_traj_samples, n_traj, n_tp, n_dims]
		return outputs, extra_info




