import numpy as np
import torch
import torch.nn as nn
from torch.nn.utils.spectral_norm import spectral_norm

import lib.utils as utils
from lib.gaussian_process import PytorchGaussianProcess

#####################################################################################################

def get_q_for_one_time_point(gaussian_process, t_local, gp_sample_size, prev_samples = None, prev_sample_ts = None):
	"""
	Get distribution Q for one time point from a GP, optionally conditioned on previous samples.

	gaussian_process: object providing `gp_regression` and `get_mean_and_cov`; both must return (mean, covariance)
	t_local: tensor with the single time point to query (reshaped with view(1) before it is passed to the GP)
	gp_sample_size: sizes (n_traj_samples, n_traj, n_dim) that the mean is reshaped to
	prev_samples: earlier samples, shape [n_traj_samples, n_traj, n_timepoints, n_dim]. Only read if prev_sample_ts is non-empty
	prev_sample_ts: time stamps of prev_samples. None or an empty sequence means there is nothing to condition on

	Returns (q_mean, q_cov_matrix). q_mean has shape [n_traj_samples, n_traj, n_dim];
	q_cov_matrix is the covariance returned by the GP, tiled with repeat() using the size of q_mean.
	"""
	# if we already have samples (or inducing points) to condition this GP on
	if (prev_sample_ts is not None) and (len(prev_sample_ts) > 0):
		# make time dimension to be first
		# shape before: [n_traj_samples, n_traj, n_timepoints, n_dim]
		# shape after: [n_timepoints, n_traj_samples, n_traj, n_dim]
		prev_samples = prev_samples.permute(2,0,1,3)

		# Double-check number of time points
		assert(prev_samples.size(0) == len(prev_sample_ts))

		# shape after: [n_timepoints, n_traj_samples * n_traj * n_dim]
		prev_samples = prev_samples.contiguous().view(len(prev_sample_ts), -1)

		# initialize GP with inducing points and predict posterior over point t_local
		# q_cov_matrix has shape [1,1]
		q_mean, q_cov_matrix = gaussian_process.gp_regression(
			t_local.view(1), prev_tp = prev_sample_ts, prev_samples = prev_samples)

	else:
		# nothing to condition on: use the mean and covariance of the GP itself
		q_mean, q_cov_matrix = gaussian_process.get_mean_and_cov(t_local.view(1))

		# Count the elements with integer arithmetic.
		# A float tensor product is inexact for large sizes and would break the view() below.
		n_elements = 1
		for dim_size in gp_sample_size:
			n_elements *= int(dim_size)
		q_mean = q_mean.repeat(1, n_elements)

	# q is only for 1 time point
	# q has dimensionality data_dim

	q_mean = q_mean.view(1, *gp_sample_size)
	# shape: [n_traj_samples, n_traj, n_tp, n_dim] where n_tp = 1
	q_mean = q_mean.permute(1,2,0,3)

	# shape of both q_mu and q_sigma to comply with ODEFuncLocalQ: [n_traj_samples, n_traj, n_dim]
	q_mean = q_mean.squeeze(2)
	q_cov_matrix = q_cov_matrix.repeat(q_mean.size())
	return q_mean, q_cov_matrix


class ODEFunc(nn.Module):
	"""
	Base class for ODE functions whose gradient may depend on a sample from a Gaussian process.

	The approximate posterior Q and the prior are Gaussian processes supplied by the caller.
	Subclasses have to implement get_ode_gradient_nn(t_local, y, gp_posterior_sample, mixtures).
	"""
	def __init__(self, input_dim, latent_dim, 
		q_gaussian_process = None, gp_prior = None, 
		save_ode_ts = True, device = torch.device("cpu")):
		"""
		input_dim: dimensionality of the input
		latent_dim: dimensionality used for ODE. Analog of a continuous latent state (not stored by this base class)
		q_gaussian_process: Gaussian process for sampling z(t). Can be None, if we don't use samples from z(t)
		gp_prior: Gaussian process used by sample_next_point_from_prior. Can be None
		save_ode_ts: if True, forward() records the time point and the norm of the gradient on every call
		device: device for the tensors created by this class
		"""
		super(ODEFunc, self).__init__()

		self.q_gaussian_process = q_gaussian_process
		self.gp_prior = gp_prior

		# History of forward(): one entry per call. samples and sample_ts are kept aligned.
		self.samples = []
		self.sample_ts = []
		self.ode_func_ts = []
		self.ode_func_norm = []

		# History of sample_next_point_from_prior(): one entry per call
		self.prior_samples = []
		self.prior_sample_ts = []
		
		# Parameters of Q returned by get_approx_posterior(): one entry per call
		self.q_mus = []
		self.q_sigmas = []

		self.input_dim = input_dim
		self.device = device

		self.n_calls = 0
		self.save_ode_ts = save_ode_ts

	def get_ode_gradient_nn(self, t_local, y, gp_posterior_sample, mixtures = None):
		"""
		Compute dy/dt. Has to be provided by a subclass.

		t_local: current time point
		y: value at the current time point
		gp_posterior_sample: sample from the GP at t_local or None
		mixtures: optional mixture weights; their use is defined by the subclass
		"""
		raise NotImplementedError("Subclasses of ODEFunc must implement get_ode_gradient_nn")

	def get_approx_posterior(self, t_local, y, inducing_points, time_steps_ind):
		"""
		Get distribution Q (which is a GP) over the noise term at the particular time point

		t_local: current time point 
		y: the value of ODE solution at the previous point, shape [n_traj_samples, n_samples, n_latent_dims]
		inducing_points: inducing points for gaussian process that is our q, shape [n_samples, n_tp, n_dim]
		time_steps_ind: time stamps for the corresponding data in 'inducing_points' 

		Returns (q_mean, q_cov_matrix) from get_q_for_one_time_point and stores them in self.q_mus / self.q_sigmas
		"""
		assert(inducing_points is not None)
		assert(time_steps_ind is not None)

		assert(inducing_points.size()[0] == y.size()[1])

		n_samples, n_tp, n_dim = inducing_points.size()
		n_traj_samples, _, _ = y.size()

		# shape after: [n_traj_samples, n_samples, n_tp, n_dim]
		inducing_points = inducing_points.repeat(n_traj_samples, 1, 1, 1)

		points_for_q = inducing_points
		points_for_q_ts = time_steps_ind

		# condition on the previous samples from GP
		if len(self.samples) > 0:
			# for speed, condition only up to n_prev_samples_to_use previous points
			n_prev_samples_to_use = 50

			# shape after permute: [n_traj_samples, n_samples, n_prev, n_dim]
			gp_samples = torch.stack(self.samples[-n_prev_samples_to_use:])
			gp_samples = gp_samples.permute(1,2,0,3)

			# Time stamps of the same samples. They are recorded in forward().
			# Gradients are not propagated through them.
			recent_ts = self.sample_ts[-n_prev_samples_to_use:]
			gp_sample_ts = torch.stack([torch.as_tensor(ts).reshape(()) for ts in recent_ts])
			gp_sample_ts = gp_sample_ts.detach().to(device = self.device, dtype = time_steps_ind.dtype)

			points_for_q = torch.cat((inducing_points, gp_samples), 2)
			points_for_q_ts = torch.cat((time_steps_ind, gp_sample_ts))
			assert(len(points_for_q_ts) == points_for_q.size(2))

		q_mean, q_cov_matrix = get_q_for_one_time_point(
			self.q_gaussian_process, t_local,
			gp_sample_size = (n_traj_samples, n_samples, n_dim),
			prev_samples = points_for_q, prev_sample_ts = points_for_q_ts)

		assert(q_mean.size() == (n_traj_samples, n_samples, n_dim))

		self.q_mus.append(q_mean)
		self.q_sigmas.append(q_cov_matrix)

		return q_mean, q_cov_matrix


	def sample_q(self, q_mu, q_cov):
		"""
		Get a sample from approximate posterior over the noise at the current time point.
		Delegates to utils.sample_standard_gaussian(q_mu, q_cov).
		"""
		# since we are taking samples from q only for t_local, covariance matrix is 1x1
		return utils.sample_standard_gaussian(q_mu, q_cov)

	def forward(self, t_local, y, context = None, backwards = False, mixtures = None):
		"""
		Perform one step in solving ODE. Given current data point y and current time point t_local, returns gradient dy/dt at this time point

		t_local: current time point
		y: value at the current time point
		context: "context" for the distribution Q to infer the posterior over z at time point t_local: a pair (inducing_points, time_steps_ind).
			Can be None, then samples from GP are not used
		backwards: if True, the gradient is negated
		mixtures: passed to get_ode_gradient_nn unchanged
		"""
		gp_posterior_sample = None
		if (self.q_gaussian_process is not None) and (context is not None):
			inducing_points, time_steps_ind = context
			
			q_mu, q_cov = self.get_approx_posterior(t_local, y, inducing_points, time_steps_ind)
			gp_posterior_sample = self.sample_q(q_mu, q_cov)

		# Record the sample together with its time stamp:
		# get_approx_posterior conditions on both lists in the following calls
		self.samples.append(gp_posterior_sample)
		self.sample_ts.append(t_local)

		self.n_calls += 1
		grad = self.get_ode_gradient_nn(t_local, y, gp_posterior_sample, mixtures)
		if backwards:
			grad = -grad

		if self.save_ode_ts:
			self.ode_func_ts.append(t_local.detach())
			self.ode_func_norm.append(torch.norm(grad).detach())

		return grad


	def sample_next_point_from_prior(self, t_local, y, gp_sample_size, mixtures = None):
		"""
		Perform one step in solving ODE using a sample from the GP prior. 
		Given current data point y and current time point t_local, return gradient dy/dt at this time point

		t_local: current time point
		y: value at the current time point
		gp_sample_size: sizes (n_traj_samples, n_traj, n_dim) of the sample from the prior
		mixtures: passed to get_ode_gradient_nn unchanged
		"""
		gp_sample = None
		if self.gp_prior is not None:
			prev_samples, prev_tp = None, None
			if len(self.prior_sample_ts) > 0:
				prev_tp = torch.stack(self.prior_sample_ts)
				# shape after permute: [n_traj_samples, n_traj, n_timepoints, n_dim]
				prev_samples = torch.stack(self.prior_samples).permute(1,2,0,3)

			prior_mu_one_tp, prior_cov_one_tp = get_q_for_one_time_point(
				self.gp_prior, t_local,
				gp_sample_size = gp_sample_size,
				prev_samples = prev_samples,
				prev_sample_ts = prev_tp)

			gp_sample = self.sample_q(prior_mu_one_tp, prior_cov_one_tp)

		self.prior_samples.append(gp_sample)
		self.prior_sample_ts.append(t_local)

		return self.get_ode_gradient_nn(t_local, y, gp_sample, mixtures)


#####################################################################################################

class ODEFunc_one_ODE(ODEFunc):
	"""
	ODE function with a single gradient function.

	The gradient is computed either by the network `ode_func_net` or, if fix_decoder is True,
	by a fixed randomly drawn linear system (oscillating spring).
	"""
	def __init__(self, input_dim, latent_dim, ode_func_net, q_gaussian_process = None, gp_prior = None, 
		fix_decoder = False, save_ode_ts = True, device = torch.device("cpu")):
		"""
		input_dim: dimensionality of the input
		latent_dim: dimensionality used for ODE. Analog of a continuous latent state
		ode_func_net: network that computes the gradient. Its weights are initialized with utils.init_network_weights unless fix_decoder is True
		q_gaussian_process, gp_prior, save_ode_ts, device: passed to ODEFunc
		fix_decoder: if True, freeze the parameters of ode_func_net and use a fixed linear ODE instead of the network
		"""
		super(ODEFunc_one_ODE, self).__init__(input_dim, latent_dim, q_gaussian_process, gp_prior, save_ode_ts, device)

		if fix_decoder:
			# Don't update the parameters of the decoder to make the comparison between ode- and rnn-based decoder fair.
			for param in ode_func_net.parameters():
				param.requires_grad = False

			# Oscillating spring
			# ODE from here: https://en.wikipedia.org/wiki/Examples_of_differential_equations
			# In the example: a = -0.25, b = 1.
			# For 2d: [[0, 1], [-a**2-b**2, 2*a]]
			# The matrix has one extra column filled with zeros, so the input of the
			# gradient function is expected to have latent_dim + 1 features.
			ode_matrix = np.zeros((latent_dim, latent_dim + 1))
			ode_matrix[0,:2] = [0, 1]
			for i in range(latent_dim - 1):
				a = np.random.uniform(low=-2., high=0.) # must be negative
				b = np.random.uniform(low=-1., high=1.)
				# Fill the row below the previous one: writing to row i would overwrite the [0, 1] row
				ode_matrix[i + 1, i:(i + 2)] = [-a**2-b**2, 2*a]

			# torch.as_tensor accepts any target device, unlike the legacy torch.Tensor constructor
			# shape after permute: [latent_dim + 1, latent_dim]
			ode_matrix = torch.as_tensor(ode_matrix, dtype = torch.float32, device = self.device).permute(1,0)

			def func(x):
				# the matrix is broadcast over the leading dimensions of x
				return torch.matmul(x, ode_matrix)
			self.gradient_net = func

		else:
			utils.init_network_weights(ode_func_net)
			self.gradient_net = ode_func_net

	def get_ode_gradient_nn(self, t_local, y, gp_posterior_sample, mixtures = None):
		"""
		Compute dy/dt with the gradient function.

		t_local: not used
		y: value at the current time point
		gp_posterior_sample: sample from the GP, concatenated to y along the last dimension. Can be None
		mixtures: not used
		"""
		inputs = y if (gp_posterior_sample is None) else torch.cat((y, gp_posterior_sample),-1)
		return self.gradient_net(inputs)


class ODEFunc_ODEmixture(ODEFunc):
	"""
	ODE function whose gradient is a weighted sum of the outputs of several gradient networks.
	The weights ("mixtures") are supplied on every call.
	"""
	def __init__(self, input_dim, latent_dim, q_gaussian_process = None, 
		gp_prior = None, n_ode_dynamics = 1, save_ode_ts = True, device = torch.device("cpu")):
		"""
		input_dim: dimensionality of the input
		latent_dim: dimensionality used for ODE. Analog of a continuous latent state
		q_gaussian_process, gp_prior, save_ode_ts, device: passed to ODEFunc
		n_ode_dynamics: number of gradient networks in the mixture
		"""
		super(ODEFunc_ODEmixture, self).__init__(input_dim, latent_dim, q_gaussian_process, gp_prior, save_ode_ts, device)

		# nn.ModuleList registers the networks as sub-modules.
		# With a plain python list their parameters are invisible to parameters(), state_dict() and to().
		self.gradient_nets = nn.ModuleList()
		dim = latent_dim + input_dim # If we don't add the sample from z(t), just append zeros to the input
		for _ in range(n_ode_dynamics):
			ode_instance = nn.Sequential(
				nn.Linear(dim, 100),
				nn.Tanh(),
				nn.Linear(100, latent_dim),)
			utils.init_network_weights(ode_instance)
			self.gradient_nets.append(ode_instance)

	def get_ode_gradient_nn(self, t_local, y, gp_posterior_sample, mixtures):
		"""
		Compute dy/dt as the sum of the network outputs weighted by `mixtures`.

		t_local: not used
		y: value at the current time point
		gp_posterior_sample: sample from the GP, concatenated to y along the last dimension.
			If None, zeros of size input_dim are appended instead
		mixtures: weights of the networks; the last dimension has one entry per network

		Raises ValueError if mixtures is None
		"""
		if mixtures is None:
			raise ValueError("ODEFunc_ODEmixture requires mixture weights, but mixtures is None")

		if gp_posterior_sample is None:
			# The networks are built for latent_dim + input_dim features: pad with zeros
			gp_posterior_sample = y.new_zeros(*y.size()[:-1], self.input_dim)
		inputs = torch.cat((y, gp_posterior_sample),-1)

		assert(mixtures.size(-1) == len(self.gradient_nets))
		res = torch.zeros_like(y).to(self.device)
		for i, gradient_net in enumerate(self.gradient_nets):
			res = res + gradient_net(inputs) * mixtures[:,:,i].unsqueeze(2)
		return res



class ODEFunc_spiral():
	"""
	Fixed (not learned) ODE function that produces 2d spirals: dy/dt = (y**3) A + bias.
	It has no Gaussian processes and no trainable parameters.
	"""
	def __init__(self, input_dim, latent_dim, device = torch.device("cpu")):
		"""
		input_dim: dimensionality of the input
		latent_dim: dimensionality used for ODE. Has to be equal to 2
		device: stored for compatibility with other ODE classes
		"""

		if latent_dim != 2:
			raise Exception("Number of latents should be equal to 2")

		self.input_dim = input_dim
		self.latent_dim = latent_dim
		self.device = device

		self.n_calls = 0

		# Not used in this class
		# for compatibility with other ODE classes
		self.samples = []
		self.ode_func_ts = []

		self.prior_samples = []
		self.prior_sample_ts = []

		self.q_gaussian_process = None
		self.gp_prior = None

	def get_ode_gradient_nn(self, t_local, y, gp_posterior_sample, mixtures = None):
		"""
		Compute dy/dt of the spiral.

		t_local: not used
		y: value at the current time point; the last dimension has to be 2
		gp_posterior_sample: not used
		mixtures: not used
		"""
		# http://www.math.psu.edu/tseng/class/Math251/Notes-PhasePlane.pdf
		# smaller coef -- less curvy trajectory, larger coef -- more curvy
		coef = 3. # coef = 3.0 or 8.0 for Ay. coef = 2.0 in ode-demo
		minor_coef = -0.1 # minor_coef = 0.1 for Ay exploding spiral. minor_coef = -0.1 in ode-demo. Positive minor_coef means exploding spiral; negative -- imploding
		bias = 1. # adding a bias term causes all trajectories to be periodic
		# We can also increase the variance in y0_prior to get more various trajectories

		# Build the matrix with the dtype and the device of y, so that float64 inputs
		# and inputs on another device do not fail in the matrix product
		coef_dtype = y.dtype if y.is_floating_point() else torch.float32
		A = torch.tensor([[minor_coef, coef], [-coef, minor_coef]], dtype = coef_dtype, device = y.device)
		# A is broadcast over the leading dimensions of y
		return torch.matmul(y**3, A) + bias

	def __call__(self, t_local, y, context = None, backwards = False, mixtures = None):
		"""
		Perform one step in solving ODE. Given current data point y and current time point t_local, returns gradient dy/dt at this time point

		t_local: current time point
		y: value at the current time point
		context: not used; accepted for compatibility with other ODE classes
		backwards: if True, the gradient is negated, same as in ODEFunc.forward
		mixtures: not used
		"""
		self.n_calls += 1
		grad = self.get_ode_gradient_nn(t_local, y, None)
		if backwards:
			grad = -grad
		return grad


	def sample_next_point_from_prior(self, t_local, y, gp_sample_size, mixtures = None):
		"""
		Perform one step in solving ODE. 
		Given current data point y and current time point t_local, return gradient dy/dt at this time point.
		There is no prior to sample from in this class: the result is the plain gradient.

		t_local: current time point
		y: value at the current time point
		gp_sample_size: not used
		mixtures: not used
		"""
		return self.get_ode_gradient_nn(t_local, y, None)




class ODEFunc_LotkaVolterra():
	"""
	Fixed (not learned) ODE function with Lotka-Volterra dynamics of two variables.
	It has no Gaussian processes and no trainable parameters.
	"""
	def __init__(self, input_dim, latent_dim, device = torch.device("cpu")):
		"""
		input_dim: dimensionality of the input
		latent_dim: dimensionality used for ODE. Has to be equal to 2
		device: stored for compatibility with other ODE classes
		"""

		if latent_dim != 2:
			raise Exception("Number of latents should be equal to 2")

		self.input_dim = input_dim
		self.latent_dim = latent_dim
		self.device = device

		self.n_calls = 0

		# Not used in this class
		# for compatibility with other ODE classes
		self.samples = []
		self.ode_func_ts = []

		self.prior_samples = []
		self.prior_sample_ts = []

		self.q_gaussian_process = None
		self.gp_prior = None

	def get_ode_gradient_nn(self, t_local, y, gp_posterior_sample, mixtures = None):
		"""
		Compute dy/dt of the Lotka-Volterra system:
			dy1/dt = alpha * y1 - beta * y1 * y2
			dy2/dt = delta * y1 * y2 - gamma * y2

		t_local: not used
		y: value at the current time point, 3d tensor with (y1, y2) in the last dimension
		gp_posterior_sample: not used
		mixtures: not used

		Returns a tensor of the same shape as y
		"""
		y1, y2 = y[:,:,0], y[:,:,1]

		alpha = 1
		beta = 0.2
		delta = 0.5
		gamma = 0.2

		dy1 = alpha * y1 - beta * y1 * y2
		dy2 = delta * y1 * y2 - gamma * y2
		# Stack along the last dimension so that the gradient has the layout of y.
		# The default dim = 0 would put the two components in front of the batch dimensions.
		return torch.stack([dy1, dy2], dim = -1)

	def __call__(self, t_local, y, context = None, backwards = False, mixtures = None):
		"""
		Perform one step in solving ODE. Given current data point y and current time point t_local, returns gradient dy/dt at this time point

		t_local: current time point
		y: value at the current time point
		context: not used; accepted for compatibility with other ODE classes
		backwards: if True, the gradient is negated, same as in ODEFunc.forward
		mixtures: not used
		"""
		self.n_calls += 1
		grad = self.get_ode_gradient_nn(t_local, y, None)
		if backwards:
			grad = -grad
		return grad


	def sample_next_point_from_prior(self, t_local, y, gp_sample_size, mixtures = None):
		"""
		Perform one step in solving ODE. 
		Given current data point y and current time point t_local, return gradient dy/dt at this time point.
		There is no prior to sample from in this class: the result is the plain gradient.

		t_local: current time point
		y: value at the current time point
		gp_sample_size: not used
		mixtures: not used
		"""
		return self.get_ode_gradient_nn(t_local, y, None)



#####################################################################################################

class ODEFunc_w_Poisson(ODEFunc):
	"""
	ODE function over an augmented state: the latent state followed by the integral of the Poisson rate lambda(t).

	The last dimension of the augmented state holds latent_dim latent features followed by input_dim
	features of the integral. The second half of the latent features is the input of lambda_net.
	"""
	
	def __init__(self, input_dim, latent_dim, ode_func_net,
		lambda_net,
		q_gaussian_process = None, gp_prior = None, 
		fix_decoder = False, save_ode_ts = True, device = torch.device("cpu")):
		"""
		input_dim: dimensionality of the input
		latent_dim: dimensionality used for ODE. Analog of a continuous latent state
		ode_func_net: network for the gradient of the latent state, wrapped into ODEFunc_one_ODE
		lambda_net: network that maps the second half of the latent state to log of the Poisson rate
		q_gaussian_process, gp_prior, save_ode_ts, device: passed to ODEFunc
		fix_decoder: passed to ODEFunc_one_ODE
		"""
		super(ODEFunc_w_Poisson, self).__init__(input_dim, latent_dim, q_gaussian_process, gp_prior, save_ode_ts, device)

		# The inner ODE is created without the Gaussian processes
		self.latent_ode = ODEFunc_one_ODE(input_dim = input_dim, 
			latent_dim = latent_dim, 
			ode_func_net = ode_func_net,
			fix_decoder = fix_decoder, device = device)

		self.latent_dim = latent_dim
		self.lambda_net = lambda_net
		# The computation of poisson likelihood can become numerically unstable. 
		# The integral lambda(t) dt can take large values. In fact, it is equal to the expected number of events on the interval [0,T]
		# Exponent of lambda can also take large values
		# So we divide lambda by the constant and then multiply the integral of lambda by the constant
		self.const_for_lambda = torch.Tensor([100.]).to(device)

	def extract_poisson_rate(self, augmented, final_result = True):
		"""
		Split the augmented state into its parts.

		augmented: tensor whose last dimension has latent_dim + input_dim features.
			The number of leading dimensions is arbitrary
		final_result: if True, the integral of lambda is multiplied by const_for_lambda.
			Use False inside of the ODE solver, i.e. while the integral is still being computed

		Returns (y, log_lambdas, int_lambda, y_latent_lam):
			y: first part of the latent state (the latents for performing reconstruction)
			log_lambdas: output of lambda_net on the second part of the latent state
			int_lambda: the last input_dim features of `augmented`
			y_latent_lam: the whole latent state, i.e. `augmented` without int_lambda
		"""
		assert(augmented.size(-1) == self.latent_dim + self.input_dim)		
		latent_lam_dim = self.latent_dim // 2

		# Ellipsis indexing handles 3d and 4d inputs (and any other rank) in the same way
		int_lambda = augmented[..., -self.input_dim:]

		# Multiply the integral over lambda by a constant 
		# only when we have finished the integral computation (i.e. this is not a call in get_ode_gradient_nn)
		if final_result:
			int_lambda = int_lambda * self.const_for_lambda

		y_latent_lam = augmented[..., :-self.input_dim]

		log_lambdas = self.lambda_net(y_latent_lam[..., -latent_lam_dim:])
		y = y_latent_lam[..., :-latent_lam_dim]

		# Latents for performing reconstruction (y) have the same size as latent poisson rate (log_lambdas)
		assert(y.size(-1) == latent_lam_dim)

		return y, log_lambdas, int_lambda, y_latent_lam


	def get_ode_gradient_nn(self, t_local, augmented, gp_posterior_sample, mixtures = None):
		"""
		Compute the gradient of the augmented state: gradient of the latent state followed by lambda(t) / const_for_lambda.

		t_local: current time point
		augmented: augmented state at the current time point, see extract_poisson_rate
		gp_posterior_sample: not forwarded to the latent ODE (see the note below)
		mixtures: passed to the latent ODE
		"""
		y, log_lam, int_lambda, y_latent_lam = self.extract_poisson_rate(augmented, final_result = False)

		# Pass mixtures by keyword. The third and the fourth positional parameters of
		# ODEFunc.forward are `context` and `backwards`, so a positional call would use
		# mixtures as the `backwards` flag.
		# NOTE(review): gp_posterior_sample is not forwarded. latent_ode is created without a
		# Gaussian process, so its forward() ignores the context anyway. Confirm that the
		# sample is not supposed to be an input of the gradient network.
		dydt_dldt = self.latent_ode(t_local, y_latent_lam, mixtures = mixtures)

		# division of lambda by the constant, done in the log space
		log_lam = log_lam - torch.log(self.const_for_lambda)
		return torch.cat((dydt_dldt, torch.exp(log_lam)),-1)




