import time
import numpy as np

import torch
import torch.nn as nn

import lib.utils as utils
from torch.distributions.multivariate_normal import MultivariateNormal
from lib.gaussian_process import PytorchGaussianProcess

# git clone https://github.com/rtqichen/torchdiffeq.git
# Don't use adjoint method -- it is very slow
from torchdiffeq.torchdiffeq import odeint as odeint

#####################################################################################################
# Constants

# odeint_rtol = 1e-4
# odeint_atol = 1e-5

#####################################################################################################

# def sample_multiple_traj_from_ode(method, starting_point, time_steps, ode_func, n_examples = 1, context = None,
# 	device = torch.device("cpu")):
	
# 	starting_point = np.tile(starting_point, (n_examples, 1, 1))

# 	func = lambda t_local, y: ode_func(t_local, y, context = context)
# 	sol = odeint(func, torch.Tensor().new_tensor(starting_point, device = device), time_steps, 
# 		rtol=odeint_rtol, atol=odeint_atol, method = method)
# 	sol = sol.permute(1,2,0,3)
# 	return sol


def get_tp_indices(current_tp_seq, all_tp):
	"""
	Return the positions in ``all_tp`` of every value in ``current_tp_seq``.

	Both sequences must be sorted in the same direction (both increasing or both
	decreasing) and contain only unique values. Because of that, a single forward
	scan over ``all_tp`` is enough: each search resumes at the previous match.
	No reversal is needed for decreasing sequences.

	current_tp_seq: 1-D sequence of time points to look up (expected to be a subset of all_tp).
	all_tp: 1-D tensor of time points to search in.

	Returns a 1-D long tensor on ``all_tp``'s device. Its i-th entry is the index of
	``current_tp_seq[i]`` in ``all_tp``, so the result is aligned with ``current_tp_seq``
	in both sort directions. An empty ``current_tp_seq`` gives an empty tensor.

	Raises ValueError if a value of ``current_tp_seq`` is not found in ``all_tp``
	at or after the position of the previous match.
	"""
	if len(current_tp_seq) == 0:
		return torch.zeros(0, dtype = torch.long, device = all_tp.device)

	if current_tp_seq[0] > current_tp_seq[-1]:
		# decreasing order: all_tp has to be decreasing as well, otherwise the
		# forward scan below cannot find the later values
		assert(all_tp[0] > all_tp[-1])

	start = 0
	indices = []
	for tp_to_find in current_tp_seq:
		# nonzero() on a 1-D tensor returns a [k, 1] tensor of matching positions
		matches = (all_tp[start:] == tp_to_find).nonzero()
		if matches.numel() == 0:
			raise ValueError("Time point {} not found in all_tp".format(tp_to_find))
		start = start + int(matches[0, 0])
		indices.append(start)

	return torch.tensor(indices, dtype = torch.long, device = all_tp.device)


class DiffeqSolver(nn.Module):
	"""
	Solve the ODE defined by ``ode_func`` with torchdiffeq's ``odeint``.

	Besides the solution, ``forward`` returns an ``info`` dict holding the bookkeeping
	that ``ode_func`` accumulates while it is evaluated (see ``collect_and_cleanup``).
	That bookkeeping is reset after every solve.
	"""

	def __init__(self, input_dim, ode_func, method, latents, 
			odeint_rtol = 1e-4, odeint_atol = 1e-5, device = torch.device("cpu")):
		"""
		input_dim: accepted for interface compatibility; not used by this class.
		ode_func: callable passed to odeint. collect_and_cleanup also reads and resets
			its attributes samples, ode_func_ts, ode_func_norm, q_mus, q_sigmas, n_calls
			and reads save_ode_ts and q_gaussian_process.
		method: solver name passed to odeint as ``method``.
		latents: stored as self.latents (not used inside this class).
		odeint_rtol, odeint_atol: relative / absolute tolerances passed to odeint.
		device: stored as self.device (not used inside this class).
		"""
		super(DiffeqSolver, self).__init__()

		self.ode_method = method
		self.latents = latents
		self.device = device
		self.ode_func = ode_func

		self.odeint_rtol = odeint_rtol
		self.odeint_atol = odeint_atol

	def forward(self, first_point, time_steps_to_predict, ind_points_tuple = None, 
		backwards = False, cut_first_point = False, mixtures = None):
		"""
		Decode a trajectory by solving the ODE starting from ``first_point``.

		first_point: initial state, [n_traj_samples, n_traj, n_dims].
		time_steps_to_predict: 1-D tensor of time points shared by all trajectories
			in the batch. The solution is reported at these points.
		ind_points_tuple: optional tuple (inducing_points, inducing_point_var, time_steps_ind).
			Here it is only checked against None, which decides whether the GP entries
			of ``info`` are trimmed. It is not passed to ode_func.
		backwards, mixtures: kept for interface compatibility; not used by this method.
		cut_first_point: if True, drop the first time point from the prediction.
			When ode_func has a q GP and ind_points_tuple is given, also drop it from
			the GP-related entries of ``info`` that are present (not None).

		Returns (pred_y, info):
			pred_y has shape [n_traj_samples, n_traj, n_tp, n_dims], with n_tp - 1
			time points when cut_first_point is set.
			info is the dict returned by collect_and_cleanup.

		Raises ValueError if time_steps_to_predict is not 1-D; per-trajectory time
		grids are not supported.
		"""
		if time_steps_to_predict.dim() != 1:
			# The per-trajectory (2-D) path relied on a time-point mask that is no
			# longer computed; fail with a clear message instead of a NameError.
			raise ValueError("time_steps_to_predict must be a 1-D tensor shared by all "
				"trajectories, got shape {}".format(tuple(time_steps_to_predict.size())))

		n_traj_samples, n_traj = first_point.size(0), first_point.size(1)

		pred_y = odeint(self.ode_func, first_point, time_steps_to_predict, 
			rtol = self.odeint_rtol, atol = self.odeint_atol, method = self.ode_method)
		# odeint returns [n_tp, *first_point.shape]; move the time axis to dim 2
		pred_y = pred_y.permute(1, 2, 0, 3)

		# The first reported point must be the initial value. Use an absolute difference
		# so that positive and negative errors cannot cancel each other out.
		assert(torch.mean(torch.abs(pred_y[:, :, 0, :] - first_point)) < 0.001)

		info = self.collect_and_cleanup()

		assert(pred_y.size(0) == n_traj_samples)
		assert(pred_y.size(1) == n_traj)

		if cut_first_point:
			pred_y = pred_y[:, :, 1:, :]
			if (self.ode_func.q_gaussian_process is not None) and (ind_points_tuple is not None):
				self._drop_first_point_from_info(info)

		return pred_y, info

	@staticmethod
	def _drop_first_point_from_info(info):
		"""Drop the first time point from the GP-related entries of ``info`` in place, skipping entries that are None."""
		if info["gp_samples"] is not None:
			info["gp_samples"] = info["gp_samples"][:, :, 1:, :]
		# ode_func_ts is None unless ode_func.save_ode_ts is set
		if info["ode_func_ts"] is not None:
			info["ode_func_ts"] = info["ode_func_ts"][1:]
		# q_mus / q_sigmas are None when sampling from the GP prior (nothing was recorded)
		if info["q_mus"] is not None:
			info["q_mus"] = info["q_mus"][:, 1:, :]
		if info["q_sigmas"] is not None:
			info["q_sigmas"] = info["q_sigmas"][:, 1:, :]

	def collect_and_cleanup(self):
		"""
		Collect the bookkeeping accumulated by ``self.ode_func`` during a solve, then reset it.

		Returns a dict with keys:
			"ode_func_ts", "ode_func_norms": stacked evaluation times / norms when
				ode_func.save_ode_ts is truthy, otherwise None.
			"gp_sample_ts_sorted": stacked evaluation times sorted ascending.
			"gp_samples": stacked samples permuted with (1, 2, 0, 3), so the evaluation
				axis is dim 2, then reordered by evaluation time.
			"q_mus", "q_sigmas": same treatment as gp_samples, then indexed with [0]
				along the first dimension. None when ode_func recorded none.
			"n_calls": ode_func.n_calls before it is reset to 0.
		The GP-related entries (gp_sample_ts_sorted, gp_samples, q_mus, q_sigmas) are
		None when ode_func.q_gaussian_process is None.
		"""
		q_mus = q_sigmas = gp_samples = gp_sample_ts_sorted = None
		ode_func_ts = ode_func_norms = None

		if self.ode_func.save_ode_ts:
			ode_func_ts = torch.stack(self.ode_func.ode_func_ts)
			ode_func_norms = torch.stack(self.ode_func.ode_func_norm)

		if self.ode_func.q_gaussian_process is not None:
			# Order the recorded evaluations by time
			gp_sample_ts_sorted, sort_order = torch.sort(torch.stack(self.ode_func.ode_func_ts))

			gp_samples = torch.stack(self.ode_func.samples).permute(1, 2, 0, 3)
			gp_samples = gp_samples[:, :, sort_order]

			# q_mus and q_sigmas stay empty lists when we sample from the GP prior
			if len(self.ode_func.q_mus) != 0:
				q_mus = torch.stack(self.ode_func.q_mus).permute(1, 2, 0, 3)[:, :, sort_order]
				q_sigmas = torch.stack(self.ode_func.q_sigmas).permute(1, 2, 0, 3)[:, :, sort_order]

				# q is deterministic given the data, so all entries along the first
				# dimension (n_traj_samples) are the same; keep only the first one.
				q_mus = q_mus[0]
				q_sigmas = q_sigmas[0]

		# Reset the per-solve bookkeeping so the next solve starts from scratch
		self.ode_func.samples = []
		self.ode_func.ode_func_ts = []
		self.ode_func.ode_func_norm = []

		self.ode_func.q_mus = []
		self.ode_func.q_sigmas = []

		n_calls = self.ode_func.n_calls
		self.ode_func.n_calls = 0

		return {"q_mus": q_mus, 
			"q_sigmas": q_sigmas, 
			"gp_samples": gp_samples, 
			"gp_sample_ts_sorted": gp_sample_ts_sorted, 
			"n_calls": n_calls,
			"ode_func_ts": ode_func_ts,
			"ode_func_norms": ode_func_norms}

	def sample_traj_from_prior(self, starting_point_enc, time_steps_to_predict, 
		gp_dim = 1, n_traj_samples = 1, mixtures = None):
		"""
		Decode a trajectory by solving the ODE with ``ode_func.sample_next_point_from_prior``.

		starting_point_enc: initial state passed to odeint.
		time_steps_to_predict: time steps at which to sample the new trajectory.
		gp_dim, n_traj_samples: forwarded as gp_sample_size = (n_traj_samples, 1, gp_dim).
		mixtures: forwarded unchanged to sample_next_point_from_prior.

		Returns the odeint solution permuted with (1, 2, 0, 3), i.e.
		[n_traj_samples, n_traj, n_tp, n_dim].
		Afterwards it resets ode_func.prior_samples and ode_func.prior_sample_ts to
		empty lists.
		"""
		def func(t_local, y):
			return self.ode_func.sample_next_point_from_prior(t_local, y, 
				gp_sample_size = (n_traj_samples, 1, gp_dim),
				mixtures = mixtures)

		pred_y = odeint(func, starting_point_enc, time_steps_to_predict, 
			rtol = self.odeint_rtol, atol = self.odeint_atol, method = self.ode_method)
		# shape: [n_traj_samples, n_traj, n_tp, n_dim]
		pred_y = pred_y.permute(1, 2, 0, 3)

		# Clean-up the samples that we took from the prior
		self.ode_func.prior_samples = []
		self.ode_func.prior_sample_ts = []

		return pred_y


