import gc
import numpy as np
import sklearn as sk
import numpy as np
#import gc
import torch
import torch.nn as nn
from torch.nn.functional import relu

import lib.utils as utils
from lib.utils import get_device
from lib.plotting import *
from lib.encoder_decoder import *
from lib.likelihood_eval import *

from torch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions.normal import Normal
from torch.distributions import kl_divergence, Independent
from lib.base_models import VAE_Baseline



class ODEVAE(VAE_Baseline):
	"""Latent-ODE variational autoencoder.

	``encoder_y0`` infers a (mean, std) for the initial latent state y0,
	``diffeq_solver`` integrates sampled y0 over the requested time points, and
	``decoder`` maps the latent trajectory to predictions.

	Attributes used but not set in this class (``input_dim``, ``latent_dim``,
	``y0_prior``, ``classifier``, ``use_binary_classif``, ``classif_per_tp``)
	are presumably initialised by ``VAE_Baseline`` -- TODO confirm.
	"""

	def __init__(self, input_dim, latent_dim, encoder_y0, decoder, diffeq_solver, 
		y0_prior, device, obsrv_std = None, 
		use_binary_classif = False, use_poisson_proc = False,
		linear_classifier = False,
		classif_per_tp = False,
		n_labels = 1):
		"""
		Args:
			input_dim: number of observed data dimensions; also the number of
				zero channels appended to the latent state when use_poisson_proc
				is set (used to accumulate the integral of lambda).
			latent_dim: dimensionality of the latent state y0.
			encoder_y0: recognition network called as
				encoder_y0(data, time_steps, ..., t0=t0) -> (mean, std) of y0.
			decoder: maps latent ODE solutions to predictions.
			diffeq_solver: solves the latent ODE over given time points.
			y0_prior: prior distribution over y0 (sampled in sample_traj_from_prior).
			device, obsrv_std, use_binary_classif, linear_classifier,
			classif_per_tp, n_labels: forwarded to VAE_Baseline.
			use_poisson_proc: forwarded to VAE_Baseline and also stored here;
				enables the Poisson-process rate channels.
		"""
		super(ODEVAE, self).__init__(
			input_dim = input_dim, latent_dim = latent_dim, 
			y0_prior = y0_prior, 
			device = device, obsrv_std = obsrv_std, 
			use_binary_classif = use_binary_classif,
			classif_per_tp = classif_per_tp, 
			linear_classifier = linear_classifier,
			use_poisson_proc = use_poisson_proc,
			n_labels = n_labels)

		self.encoder_y0 = encoder_y0
		self.diffeq_solver = diffeq_solver
		self.decoder = decoder
		self.use_poisson_proc = use_poisson_proc

	@staticmethod
	def _detach_nested(value):
		"""Detach ``value`` if it supports ``.detach()``; recurse into tuples/lists; pass anything else through."""
		if hasattr(value, "detach"):
			return value.detach()
		if isinstance(value, (tuple, list)):
			detached = [ODEVAE._detach_nested(item) for item in value]
			return tuple(detached) if isinstance(value, tuple) else detached
		return value

	def get_reconstruction(self, time_steps_to_predict, truth, truth_time_steps, 
		mask = None, n_traj_samples = 1, use_gp_sample = False, t0 = None):
		"""Encode the observations, solve the latent ODE from a sampled y0 and decode.

		Args:
			time_steps_to_predict: time points to solve/decode at (presumably a
				sorted 1-D tensor -- TODO confirm).
			truth: observed data given to the encoder.
			truth_time_steps: time stamps of ``truth``.
			mask: optional observation mask. For encoders other than
				Encoder_y0_ode_per_attr it is concatenated to ``truth`` on the last axis.
			n_traj_samples: number of y0 samples drawn per trajectory.
			use_gp_sample: if True, ``self.encoder_zt(truth, truth_time_steps)`` is
				passed to the solver as ``ind_points_tuple``.
			t0: time at which y0 is inferred; defaults to time_steps_to_predict[0].
				If earlier, it is prepended to the solve grid and the solver is
				called with cut_first_point=True.

		Returns:
			(pred_x, all_extra_info): decoder output for the sampled y0, and a dict
			of diagnostics (latent trajectory, y0 statistics, mean-y0 prediction,
			solver stats, and optional Poisson / GP / classifier outputs).

		Raises:
			Exception: if the encoder type is unsupported or t0 is later than
				time_steps_to_predict[0].
		"""
		if not isinstance(self.encoder_y0,
			(Encoder_y0_ode_combine, Encoder_y0_from_rnn, Encoder_y0_ode_per_attr)):
			raise Exception("Unknown encoder type {}".format(type(self.encoder_y0).__name__))

		# Choose t0 -- the time point at which encoder_y0 produces the latent state.
		if (t0 is None) or (t0 == time_steps_to_predict[0]):
			t0 = time_steps_to_predict[0]
			cut_first_point = False
		elif torch.any(t0 < time_steps_to_predict[0]):
			# torch.any (unlike the builtin any) also accepts a 0-d tensor t0.
			time_steps_to_predict = torch.cat((t0.reshape(-1), time_steps_to_predict), 0)
			cut_first_point = True
		else:
			raise Exception("t0 must be set to the time point before any observations")

		# Run the recognition network to get mean/std of y0.
		if isinstance(self.encoder_y0, Encoder_y0_ode_per_attr):
			first_point_mu, first_point_std = self.encoder_y0(truth, truth_time_steps, mask = mask, t0=t0)
		else:
			truth_w_mask = truth if mask is None else torch.cat((truth, mask), -1)
			first_point_mu, first_point_std = self.encoder_y0(truth_w_mask, truth_time_steps, t0=t0)

		means_y0 = first_point_mu.repeat(n_traj_samples, 1, 1)
		sigma_y0 = first_point_std.repeat(n_traj_samples, 1, 1)
		first_point_enc = utils.sample_standard_gaussian(means_y0, sigma_y0)

		# Applied after sampling, so it only affects the std reported in
		# all_extra_info["first_point"].
		first_point_std = first_point_std.abs()

		# NOTE(review): a full garbage collection on every call is expensive;
		# presumably added to curb memory growth -- consider removing.
		gc.collect()

		# Mixtures of ODEs are not wired up: the solver is always called with None.
		# TODO Correct sampling from prior in case of multiple mixtures.
		mixtures = None

		ind_points_tuple = None
		if use_gp_sample:
			# NOTE(review): encoder_zt is not set in __init__; presumably provided
			# by VAE_Baseline or a subclass -- confirm.
			ind_points_tuple = self.encoder_zt(truth, truth_time_steps)

		if self.use_poisson_proc:
			n_gp_samples, n_traj, _ = first_point_enc.size()
			# append a vector of zeros to compute the integral of lambda
			zeros = torch.zeros([n_gp_samples, n_traj, self.input_dim]).to(get_device(truth))
			first_point_enc_aug = torch.cat((first_point_enc, zeros), -1)
			means_y0_aug = torch.cat((means_y0, zeros), -1)
		else:
			first_point_enc_aug = first_point_enc
			means_y0_aug = means_y0

		if torch.isnan(first_point_enc).any():
			# Debug dump before the asserts below fire.
			print("first_point_enc has Nans!")
			print(time_steps_to_predict)
			print(truth_time_steps)
			print(torch.isnan(truth).any())
			if mask is not None:
				print(torch.isnan(mask).any())
			else:
				print("mask is None")

		assert(not torch.isnan(time_steps_to_predict).any())
		assert(not torch.isnan(first_point_enc).any())
		assert(not torch.isnan(first_point_enc_aug).any())

		# Shape of sol_y [num_gp_samples, n_samples, n_timepoints, n_latents]
		sol_y, extra_info = self.diffeq_solver(
			first_point_enc_aug, time_steps_to_predict, ind_points_tuple,
			cut_first_point = cut_first_point, mixtures = mixtures)

		# Same solve started from the posterior mean of y0 instead of a sample.
		sol_mean_y, _ = self.diffeq_solver(
			means_y0_aug, time_steps_to_predict, ind_points_tuple,
			cut_first_point = cut_first_point, mixtures = mixtures)

		if self.use_poisson_proc:
			sol_y, log_lambda_y, int_lambda, _ = self.diffeq_solver.ode_func.extract_poisson_rate(sol_y)
			sol_mean_y, log_lambda_mean_y, int_lambda_from_y0_mean, _ = self.diffeq_solver.ode_func.extract_poisson_rate(sol_mean_y)

			assert(torch.sum(int_lambda[:,:,0,:]) == 0.)
			assert(torch.sum(int_lambda[0,0,-1,:] <= 0) == 0.)

		pred_x = self.decoder(sol_y)
		pred_x_mean_y0 = self.decoder(sol_mean_y)

		all_extra_info = {
			"latent_traj": sol_y.detach(),
			"ode_func_ts": extra_info["ode_func_ts"].detach(),
			"first_point": (first_point_mu, first_point_std, first_point_enc),
			"pred_mean_y0": pred_x_mean_y0[0].detach(),
			"ode_func_norms": extra_info["ode_func_norms"].detach(),
			"n_calls": extra_info["n_calls"]}

		if ind_points_tuple is not None:
			all_extra_info["ind_points_tuple"] = self._detach_nested(ind_points_tuple)
			# Assumption: GP samples, when produced, are reported by the solver in
			# its extra_info -- confirm against diffeq_solver.
			if "gp_samples" in extra_info:
				all_extra_info["gp_samples"] = extra_info["gp_samples"].detach()

		if self.use_poisson_proc:
			# integral of lambda from the last step of the ODE solver
			all_extra_info["int_lambda"] = int_lambda[:,:,-1,:]
			all_extra_info["int_lambda_from_y0_mean"] = int_lambda_from_y0_mean[:,:,-1,:]
			all_extra_info["log_lambda_y"] = log_lambda_y
			all_extra_info["log_lambda_mean_y"] = log_lambda_mean_y

		if self.use_binary_classif:
			if self.classif_per_tp:
				all_extra_info["label_predictions"] = self.classifier(sol_y)
			else:
				all_extra_info["label_predictions"] = self.classifier(first_point_enc).squeeze(-1)

		return pred_x, all_extra_info

	def sample_traj_from_prior(self, time_steps_to_predict, n_traj_samples = 1, mixtures = None):
		"""Sample y0 from the prior, solve the latent ODE and decode the result.

		Args:
			time_steps_to_predict: time points passed to the solver.
			n_traj_samples: number of prior samples to draw.
			mixtures: currently ignored; the solver is always called with
				mixtures=None.

		Returns:
			The decoder output for the sampled latent trajectories.
		"""
		# Sample y0 from the prior.
		if isinstance(self.y0_prior, MixtureOfGaussians):
			starting_point_enc = self.y0_prior.sample([n_traj_samples, 1])
		else:
			starting_point_enc = self.y0_prior.sample([n_traj_samples, 1, self.latent_dim]).squeeze(-1)

		starting_point_enc_aug = starting_point_enc
		if self.use_poisson_proc:
			n_gp_samples, n_traj, _ = starting_point_enc.size()
			# Append a vector of zeros to compute the integral of lambda. Allocate
			# it with the sample's device/dtype so torch.cat also works on GPU.
			zeros = torch.zeros(n_gp_samples, n_traj, self.input_dim,
				device = starting_point_enc.device, dtype = starting_point_enc.dtype)
			starting_point_enc_aug = torch.cat((starting_point_enc, zeros), -1)

		# Mixtures are not supported yet, so the argument is deliberately overridden.
		mixtures = None

		sol_y = self.diffeq_solver.sample_traj_from_prior(starting_point_enc_aug, time_steps_to_predict, 
			gp_dim = self.input_dim, n_traj_samples = n_traj_samples, mixtures = mixtures)

		if self.use_poisson_proc:
			# Only the latent part is decoded; the rate outputs are unused here.
			sol_y, _, _, _ = self.diffeq_solver.ode_func.extract_poisson_rate(sol_y)

		# TODO: sample the time points from the Poisson process.
		return self.decoder(sol_y)


