import numpy as np
import torch
import torch.nn as nn
from torch.nn.functional import relu

import lib.utils as utils
from lib.utils import get_device
from lib.plotting import *
from lib.encoder_decoder import *
from lib.likelihood_eval import *

from torch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions.normal import Normal
from torch.nn.modules.rnn import GRUCell, LSTMCell, RNNCellBase

from torch.distributions.normal import Normal
from torch.distributions import Independent
from torch.nn.parameter import Parameter
from lib.base_models import Baseline, VAE_Baseline



# Exp decay between hidden states
class GRUCellExpDecay(RNNCellBase):
	"""GRU cell that exponentially decays the previous hidden state before updating it.

	Each input row is split in two: the leading columns are the regular GRU input,
	and the trailing ``input_size_for_decay`` columns are fed to a learned linear
	layer (``self.decay``). That layer's output, clamped to [0, 1000], is a decay
	rate ``r``; the previous hidden state is multiplied by ``exp(-r)`` and then a
	standard GRU update (reset / update / new gates) is applied.
	"""

	def __init__(self, input_size, input_size_for_decay, hidden_size, device, bias=True):
		super(GRUCellExpDecay, self).__init__(input_size, hidden_size, bias, num_chunks=3)

		self.device = device
		self.input_size_for_decay = input_size_for_decay
		# Linear map from the trailing "decay" columns of the input to a scalar decay rate.
		self.decay = nn.Sequential(nn.Linear(input_size_for_decay, 1),)
		utils.init_network_weights(self.decay)

	def gru_exp_decay_cell(self, input, hidden, w_ih, w_hh, b_ih, b_hh):
		"""Apply one decayed GRU step.

		Args:
			input: tensor whose dim 1 holds the features. IMPORTANT: the last
				``self.input_size_for_decay`` columns must be the cumulative time
				deltas; the remaining (leading) columns are the GRU input.
			hidden: previous hidden state, one row per input row.
			w_ih, w_hh: input-to-hidden / hidden-to-hidden weights with the reset,
				update and new gate blocks stacked along dim 0.
			b_ih, b_hh: matching biases, or ``None`` for a cell built with ``bias=False``.

		Returns:
			The new hidden state.
		"""
		cum_delta_ts = input[:, -self.input_size_for_decay:]
		data = input[:, :-self.input_size_for_decay]

		# The decay rate is clamped to [0, 1000], so the factor exp(-rate) stays in (0, 1]
		# (numerically it can underflow to 0 near the upper bound).
		decay_rate = self.decay(cum_delta_ts).clamp(min=0., max=1000.)
		hidden = hidden * torch.exp(-decay_rate)

		# nn.functional.linear accepts a None bias, so bias=False cells work too.
		gi = nn.functional.linear(data, w_ih, b_ih)
		gh = nn.functional.linear(hidden, w_hh, b_hh)
		i_r, i_i, i_n = gi.chunk(3, 1)
		h_r, h_i, h_n = gh.chunk(3, 1)

		resetgate = torch.sigmoid(i_r + h_r)
		inputgate = torch.sigmoid(i_i + h_i)
		newgate = torch.tanh(i_n + resetgate * h_n)
		# Same as (1 - inputgate) * newgate + inputgate * hidden.
		return newgate + inputgate * (hidden - newgate)

	def forward(self, input, hx=None):
		"""Run one step; ``hx`` defaults to zeros of shape (input.size(0), hidden_size)."""
		if hx is None:
			hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)

		return self.gru_exp_decay_cell(
			input, hx,
			self.weight_ih, self.weight_hh,
			self.bias_ih, self.bias_hh
		)




def run_rnn(inputs, delta_ts, cell, first_hidden=None, 
	mask = None, feed_previous=False, n_steps=0,
	decoder = None, input_decay_params = None,
	append_to_previous = None, 
	feed_previous_w_prob = 0.):
	"""Unroll ``cell`` over dim 1 (time) of ``inputs``.

	Args:
		inputs: tensor read per step as ``inputs[:, i]``.
		delta_ts: per-step time gaps, read as ``delta_ts[:, i]`` and concatenated
			onto each step's input (for non exp-decay cells).
		cell: recurrent cell called as ``cell(step_input, hidden)``.
		first_hidden: optional initial hidden state. When given, it is emitted as
			the first element of ``all_hiddens`` and one fewer step is run.
		mask: optional observation mask, read as ``mask[:, i, :]`` and concatenated
			onto each step's input.
		feed_previous: if True, every step after the first feeds ``decoder(hidden)``
			instead of ``inputs[:, i]``.
		n_steps: number of hidden states to emit; 0 means ``inputs.size(1)``.
		decoder: maps a hidden state back to input space; required whenever
			predictions may be fed back.
		input_decay_params: optional ``(w, b)`` pair; when given, ``inputs`` are
			first imputed with ``impute_using_input_decay``.
		append_to_previous: optional tensor concatenated onto each fed-back prediction.
		feed_previous_w_prob: probability, at each step after the first, of feeding
			``decoder(hidden)`` instead of the true input (ignored if ``feed_previous``).

	Returns:
		``(last_hidden, all_hiddens)``, where ``all_hiddens`` stacks the emitted
		hidden states, permuted to (batch, step, hidden) with a leading singleton dim.
	"""
	may_feed_back = feed_previous or feed_previous_w_prob
	if may_feed_back and decoder is None:
		raise Exception("feed_previous is set to True -- please specify RNN decoder")

	if n_steps == 0:
		n_steps = inputs.size(1)

	if may_feed_back and mask is None:
		mask = torch.ones((inputs.size(0), n_steps, inputs.size(-1))).to(get_device(inputs))

	# Cumulative gaps are only consumed by exp-decay cells; skip the (loop-heavy)
	# computation otherwise. A missing mask means every value is observed.
	use_cum_delta = isinstance(cell, GRUCellExpDecay)
	cum_delta_ts = None
	if use_cum_delta:
		decay_mask = mask if mask is not None else torch.ones_like(inputs)
		cum_delta_ts = get_cum_delta_ts(inputs, delta_ts, decay_mask)

	if input_decay_params is not None:
		w_input_decay, b_input_decay = input_decay_params
		inputs = impute_using_input_decay(inputs, delta_ts, mask,
			w_input_decay, b_input_decay)

	def _fed_back_prediction(prev_hidden):
		# Decode the previous hidden state into the next step's input.
		prediction = decoder(prev_hidden)
		if append_to_previous is not None:
			prediction = torch.cat((prediction, append_to_previous), -1)
		return prediction

	all_hiddens = []
	hidden = first_hidden

	if hidden is not None:
		all_hiddens.append(hidden)
		n_steps -= 1

	for i in range(n_steps):
		if i == 0:
			rnn_input = inputs[:,i]
		elif feed_previous:
			rnn_input = _fed_back_prediction(hidden)
		elif feed_previous_w_prob > 0 and np.random.uniform() < feed_previous_w_prob:
			# Scheduled sampling: feed the prediction back with probability feed_previous_w_prob.
			rnn_input = _fed_back_prediction(hidden)
		else:
			rnn_input = inputs[:,i]

		if mask is not None:
			rnn_input = torch.cat((rnn_input, mask[:,i,:]), -1)

		time_feature = cum_delta_ts[:,i] if use_cum_delta else delta_ts[:,i]
		input_w_t = torch.cat((rnn_input, time_feature), -1).squeeze(1)

		hidden = cell(input_w_t, hidden)
		all_hiddens.append(hidden)

	all_hiddens = torch.stack(all_hiddens, 0)
	all_hiddens = all_hiddens.permute(1,0,2).unsqueeze(0)
	return hidden, all_hiddens







def impute_using_input_decay(data, delta_ts, mask, w_input_decay, b_input_decay):
	n_traj, n_tp, n_dims = data.size()

	cum_delta_ts = delta_ts.repeat(1, 1, n_dims)
	missing_index = np.where(mask.cpu().numpy() == 0)

	data_last_obsv = np.copy(data.cpu().numpy())
	for idx in range(missing_index[0].shape[0]):
		i = missing_index[0][idx] 
		j = missing_index[1][idx]
		k = missing_index[2][idx]

		if j != 0 and j != (n_tp-1):
		 	cum_delta_ts[i,j+1,k] = cum_delta_ts[i,j+1,k] + cum_delta_ts[i,j,k]
		if j != 0:
			data_last_obsv[i,j,k] = data_last_obsv[i,j-1,k] # last observation
	cum_delta_ts = cum_delta_ts / cum_delta_ts.max() # normalize
	
	data_last_obsv = torch.Tensor(data_last_obsv).to(get_device(data))

	zeros = torch.zeros([n_traj, n_tp, n_dims]).to(get_device(data))
	decay = torch.exp( - torch.min( torch.max(zeros, 
		w_input_decay * cum_delta_ts + b_input_decay), zeros + 1000 ))

	data_means = torch.mean(data, 1).unsqueeze(1)

	data_imputed = data * mask + (1-mask) * (decay * data_last_obsv + (1-decay) * data_means)
	return data_imputed



def get_cum_delta_ts(data, delta_ts, mask):
	"""Accumulate time gaps across missing observations and normalise them.

	``delta_ts`` is tiled along its last dim to match ``data``. For every interior
	time index ``j`` (``0 < j < n_tp - 1``) at which ``mask`` is 0, the gap at ``j``
	(already accumulated) is added to the gap at ``j + 1``, so a run of missing
	values keeps growing the gap. The result is divided by its global maximum;
	that division is skipped when the maximum is 0, which would otherwise
	produce NaN/inf (e.g. a single time point).

	Args:
		data: 3-D tensor (trajectories, time, dims); only its shape is used.
		delta_ts: 3-D per-step gaps, tiled with ``repeat(1, 1, n_dims)``.
		mask: 3-D observation mask (0 = missing).

	Returns:
		New tensor of accumulated, normalised gaps; inputs are not modified.
	"""
	_, n_tp, n_dims = data.size()

	cum_delta_ts = delta_ts.repeat(1, 1, n_dims)
	missing = (mask == 0).to(cum_delta_ts.device)

	# Sequential over time so that gaps accumulate through runs of missing values;
	# vectorised over trajectories and dims.
	last_j = min(n_tp - 1, missing.size(1), cum_delta_ts.size(1) - 1)
	for j in range(1, last_j):
		cum_delta_ts[:, j + 1] = torch.where(
			missing[:, j],
			cum_delta_ts[:, j + 1] + cum_delta_ts[:, j],
			cum_delta_ts[:, j + 1])

	max_delta = cum_delta_ts.max()
	if max_delta != 0:
		cum_delta_ts = cum_delta_ts / max_delta # normalize
	return cum_delta_ts





class Classic_RNN(Baseline):
	"""RNN baseline that runs one cell over the observations and decodes every hidden state.

	When the prediction times equal the observed times, the cell is run over the
	data (``run_rnn`` with ``feed_previous_w_prob`` 0.5, or 0 when
	``use_binary_classif``) and each hidden state is decoded. Otherwise
	``run_seq2seq`` encodes the observations and then decodes autoregressively
	at ``time_steps_to_predict``, using the same cell for both.
	"""
	def __init__(self, input_dim, latent_dim, device, 
		concat_mask = False, obsrv_std = 0.1, 
		use_binary_classif = False,
		linear_classifier = False,
		classif_per_tp = False,
		input_space_decay = False,
		cell = "gru", n_units = 100,
		n_labels = 1,
		glob_dims = int(0)):
		
		super(Classic_RNN, self).__init__(input_dim, latent_dim, device, 
			obsrv_std = obsrv_std, 
			use_binary_classif = use_binary_classif,
			classif_per_tp = classif_per_tp,
			linear_classifier = linear_classifier,
			n_labels = n_labels)

		self.glob_dims = glob_dims
		self.concat_mask = concat_mask
		
		# Per-step input width: data (+ global features), doubled when the mask is concatenated.
		encoder_dim = int(input_dim + glob_dims)
		if concat_mask:
			encoder_dim = encoder_dim * 2

		# Maps a hidden state back to data space (outputs, and inputs when fed back).
		self.decoder = nn.Sequential(
			nn.Linear(latent_dim, n_units),
			nn.Tanh(),
			nn.Linear(n_units, input_dim),)
		utils.init_network_weights(self.decoder)

		if cell == "gru":
			self.rnn_cell = GRUCell(encoder_dim + 1, latent_dim) # +1 for delta t
		elif cell == "expdecay":
			self.rnn_cell = GRUCellExpDecay(
				input_size = encoder_dim, 
				input_size_for_decay = input_dim,
				hidden_size = latent_dim, 
				device = device)
		else:
			raise Exception("Unknown RNN cell: {}".format(cell))

		if input_space_decay:
			# Create the parameters on the target device and zero-initialise them:
			# torch.Tensor(...) leaves memory uninitialised, and Parameter(...).to(device)
			# returns a plain, unregistered (hence untrained) tensor when the device changes.
			decay_dim = int(input_dim + glob_dims)
			self.w_input_decay = Parameter(torch.zeros(1, decay_dim, device=self.device))
			self.b_input_decay = Parameter(torch.zeros(1, decay_dim, device=self.device))
		self.input_space_decay = input_space_decay

		# The final encoder hidden state is used directly as the decoder's initial state.
		self.y0_net = lambda hidden_state: hidden_state


	def get_reconstruction(self, time_steps_to_predict, data, truth_time_steps, 
		mask = None, n_traj_samples = 1):
		"""Reconstruct ``data`` at ``time_steps_to_predict``; ``mask`` is required.

		Delegates to ``run_seq2seq`` whenever the prediction times differ from the
		observed times. ``n_traj_samples`` is only forwarded to ``run_seq2seq``.

		Returns:
			``(outputs, extra_info)``.
		"""
		assert(mask is not None)

		# If time_steps_to_predict differs from truth_time_steps, run seq2seq (decode from
		# previous RNN predictions). Compare element-wise: a zero *sum* of differences
		# does not mean the two time grids are equal.
		do_extrap = len(truth_time_steps) != len(time_steps_to_predict)
		do_extrap = do_extrap or bool((time_steps_to_predict != truth_time_steps).any())

		if do_extrap:
			return self.run_seq2seq(time_steps_to_predict, data, truth_time_steps, 
				rnn_enc = self.rnn_cell, rnn_dec = self.rnn_cell,
				mask = mask, n_traj_samples = n_traj_samples)

		# for classic RNN time_steps_to_predict should be the same as  truth_time_steps
		assert(len(truth_time_steps) == len(time_steps_to_predict))

		batch_size = data.size(0)
		zero_delta_t = torch.Tensor([0.]).to(self.device)

		# Gap to the next time point; the last step gets a zero gap.
		delta_ts = truth_time_steps[1:] - truth_time_steps[:-1]
		delta_ts = torch.cat((delta_ts, zero_delta_t))
		if len(delta_ts.size()) == 1:
			# delta_ts are shared for all trajectories in a batch
			assert(data.size(1) == delta_ts.size(0))
			delta_ts = delta_ts.unsqueeze(-1).repeat((batch_size,1,1))

		input_decay_params = None
		if self.input_space_decay:
			input_decay_params = (self.w_input_decay, self.b_input_decay)

		globs = None
		if self.glob_dims > 0:
			globs = data[:,0,-int(self.glob_dims):]

		hidden_state, all_hiddens = run_rnn(data, delta_ts, 
			cell = self.rnn_cell, mask = mask,
			input_decay_params = input_decay_params,
			feed_previous_w_prob = (0. if self.use_binary_classif else 0.5),
			decoder = self.decoder,
			append_to_previous = globs)

		outputs = self.decoder(all_hiddens)
		# Shift outputs for computing the loss -- we should compare the first output to the second data point, etc.
		first_point = data[:,0,:]
		if self.glob_dims > 0:
			first_point = first_point[:,:-int(self.glob_dims)]
		outputs = utils.shift_outputs(outputs, first_point)

		# Relative change of the hidden state between consecutive steps.
		hidden_changes = all_hiddens[:,:,1:] - all_hiddens[:,:,:-1]
		hidden_changes = torch.norm(hidden_changes, dim = -1) / torch.norm(all_hiddens, dim = -1)[:,:,1:]

		extra_info = {
			"gp_samples": None,
			"latent_traj": all_hiddens,
			"n_calls": 0., 
			"first_point": (hidden_state.unsqueeze(0), 0.0, hidden_state.unsqueeze(0)),
			"pred_mean_y0": outputs.squeeze(0),
			"hidden_changes": hidden_changes}

		if self.use_binary_classif:
			if self.classif_per_tp:
				extra_info["label_predictions"] = self.classifier(all_hiddens)
			else:
				extra_info["label_predictions"] = self.classifier(hidden_state).reshape(1,-1)

		# outputs shape: [n_traj_samples, n_traj, n_tp, n_dims]
		return outputs, extra_info


	def run_seq2seq(self, time_steps_to_predict, data, truth_time_steps, 
		rnn_enc, rnn_dec, mask = None, n_traj_samples = 1):
		"""Encode the observations with ``rnn_enc``, then decode autoregressively with ``rnn_dec``.

		The encoder runs over reversed data when the first prediction time precedes
		the last observed time. ``mask`` is required; ``n_traj_samples`` is unused here.

		Returns:
			``(outputs, extra_info)``; ``extra_info`` also carries the encoder-side
			reconstruction of the observed data.
		"""
		assert(mask is not None)

		batch_size = data.size(0)
		zero_delta_t = torch.Tensor([0.]).to(self.device)
	
		# run encoder backwards
		run_backwards = bool(time_steps_to_predict[0] < truth_time_steps[-1])

		if run_backwards:
			# Look at data in the reverse order: from later points to the first
			data = utils.reverse(data)

		delta_ts = truth_time_steps[1:] - truth_time_steps[:-1]
		if run_backwards:
			# we are going backwards in time
			delta_ts = utils.reverse(delta_ts)

		delta_ts = torch.cat((delta_ts, zero_delta_t))
		if len(delta_ts.size()) == 1:
			# delta_ts are shared for all trajectories in a batch
			assert(data.size(1) == delta_ts.size(0))
			delta_ts = delta_ts.unsqueeze(-1).repeat((batch_size,1,1))

		input_decay_params = None
		if self.input_space_decay:
			input_decay_params = (self.w_input_decay, self.b_input_decay)

		hidden_state, all_hiddens_enc = run_rnn(data, delta_ts, 
			cell = rnn_enc, mask = mask, 
			input_decay_params = input_decay_params)

		reconstr_observed_data = self.decoder(all_hiddens_enc)
		# Shift outputs for computing the loss -- we should compare the first output to the second data point, etc.
		first_point = data[:,0,:]
		if self.glob_dims > 0:
			first_point = first_point[:,:-int(self.glob_dims)]
		reconstr_observed_data = utils.shift_outputs(reconstr_observed_data, first_point)

		hidden_state = self.y0_net(hidden_state)

		# Decoder: zero gap for the first prediction, then gaps between prediction times.
		delta_ts = torch.cat((zero_delta_t, time_steps_to_predict[1:] - time_steps_to_predict[:-1]))
		if len(delta_ts.size()) == 1:
			delta_ts = delta_ts.unsqueeze(-1).repeat((batch_size,1,1))

		globs = None
		if self.glob_dims > 0:
			globs = data[:,0,-int(self.glob_dims):]

		_, all_hiddens = run_rnn(data, delta_ts,
			cell = rnn_dec,
			first_hidden = hidden_state, feed_previous = True, 
			n_steps = time_steps_to_predict.size(0), 
			decoder = self.decoder,
			append_to_previous = globs)

		# Decoder outputs are not shifted (the first one is decoded from the initial state).
		outputs = self.decoder(all_hiddens)

		extra_info = {
			"gp_samples": None,
			"latent_traj": all_hiddens,
			"latent_traj_enc": all_hiddens_enc,
			"reconstr_observed_data": reconstr_observed_data,
			"n_calls": 0., 
			"first_point": (hidden_state.unsqueeze(0), 0.0, hidden_state.unsqueeze(0)),
			"pred_mean_y0": outputs.squeeze(0)}

		if self.use_binary_classif:
			if self.classif_per_tp:
				extra_info["label_predictions"] = self.classifier(all_hiddens)
			else:
				extra_info["label_predictions"] = self.classifier(hidden_state).reshape(1,-1)

		# outputs shape: [n_traj_samples, n_traj, n_tp, n_dims]
		return outputs, extra_info





class RNN_Seq2Seq(Classic_RNN):
	"""Seq2seq RNN: separate encoder (``rec_dims``) and decoder (``latent_dim``) cells.

	``get_reconstruction`` always goes through the inherited ``run_seq2seq``, using
	``rnn_cell_enc`` / ``rnn_cell_dec``. ``y0_net`` maps the encoder's final hidden
	state to the decoder's initial state.
	"""
	def __init__(self, input_dim, latent_dim, rec_dims, device, 
		concat_mask = False, obsrv_std = 0.1, 
		input_space_decay = False,
		use_binary_classif = False,
		linear_classifier = False,
		cell = "gru", n_units = 100,
		glob_dims = 0,
		n_labels = 1):
		
		super(RNN_Seq2Seq, self).__init__(input_dim, latent_dim, device, 
			obsrv_std = obsrv_std, use_binary_classif = use_binary_classif,
			linear_classifier = linear_classifier,
			n_labels = n_labels)

		self.glob_dims = glob_dims
		self.concat_mask = concat_mask
		
		# Per-step input width: data (+ global features), doubled when the mask is concatenated.
		encoder_dim = int(input_dim + glob_dims)
		if concat_mask:
			encoder_dim = encoder_dim * 2

		if cell == "gru":
			self.rnn_cell_enc = GRUCell(encoder_dim + 1, rec_dims) # +1 for delta t
			self.rnn_cell_dec = GRUCell(encoder_dim + 1, latent_dim) # +1 for delta t
		elif cell == "expdecay":
			self.rnn_cell_enc = GRUCellExpDecay(
				input_size = encoder_dim, 
				input_size_for_decay = input_dim,
				hidden_size = rec_dims, 
				device = device)
			self.rnn_cell_dec = GRUCellExpDecay(
				input_size = encoder_dim, 
				input_size_for_decay = input_dim,
				hidden_size = latent_dim, 
				device = device)
		else:
			raise Exception("Unknown RNN cell: {}".format(cell))

		self.decoder = nn.Sequential(
			nn.Linear(latent_dim, n_units),
			nn.Tanh(),
			nn.Linear(n_units, input_dim),)
		utils.init_network_weights(self.decoder)

		# Encoder state (rec_dims) -> decoder initial state (latent_dim).
		self.y0_net = nn.Sequential(
			nn.Linear(rec_dims, n_units),
			nn.Tanh(),
			nn.Linear(n_units, latent_dim),)
		utils.init_network_weights(self.y0_net)

		# input_space_decay was previously accepted but never forwarded to the parent,
		# so it was silently ignored. Set it up here with registered, zero-initialised
		# parameters on the target device (read by run_seq2seq via self.input_space_decay).
		if input_space_decay:
			decay_dim = int(input_dim + glob_dims)
			self.w_input_decay = Parameter(torch.zeros(1, decay_dim, device=self.device))
			self.b_input_decay = Parameter(torch.zeros(1, decay_dim, device=self.device))
		self.input_space_decay = input_space_decay


	def get_reconstruction(self, time_steps_to_predict, data, truth_time_steps, 
		mask = None, n_traj_samples = 1):
		"""Reconstruct/extrapolate via ``run_seq2seq`` with the encoder and decoder cells."""
		return self.run_seq2seq(time_steps_to_predict, data, truth_time_steps, 
			rnn_enc = self.rnn_cell_enc, rnn_dec = self.rnn_cell_dec,
			mask = mask, n_traj_samples = n_traj_samples)





class RNN_VAE(VAE_Baseline):
	"""Variational RNN: RNN encoder -> Gaussian initial state y0 -> autoregressive RNN decoder.

	The encoder (``rec_dims`` hidden units) runs over the observations, reversed
	when the first prediction time precedes the last observed time. ``y0_net``
	maps its final state to a mean and a std for y0, one sample is drawn, and the
	decoder cell unrolls from that sample while feeding back its own predictions.
	"""
	def __init__(self, input_dim, latent_dim, rec_dims, 
		y0_prior, device, 
		concat_mask = False, obsrv_std = 0.1, 
		input_space_decay = False,
		use_binary_classif = False,
		classif_per_tp =False,
		linear_classifier = False, 
		cell = "gru", n_units = 100,
		glob_dims= 0,
		n_labels = 1):
	
		super(RNN_VAE, self).__init__(
			input_dim = input_dim, latent_dim = latent_dim, 
			y0_prior = y0_prior, 
			device = device, obsrv_std = obsrv_std, 
			use_binary_classif = use_binary_classif, 
			classif_per_tp = classif_per_tp,
			linear_classifier = linear_classifier,
			n_labels = n_labels)

		self.glob_dims = glob_dims
		self.concat_mask = concat_mask

		# Per-step input width: data (+ global features), doubled when the mask is concatenated.
		encoder_dim = int(input_dim + glob_dims)
		if concat_mask:
			encoder_dim = encoder_dim * 2

		if cell == "gru":
			self.rnn_cell_enc = GRUCell(encoder_dim + 1, rec_dims) # +1 for delta t
			self.rnn_cell_dec = GRUCell(encoder_dim + 1, latent_dim) # +1 for delta t
		elif cell == "expdecay":
			self.rnn_cell_enc = GRUCellExpDecay(
				input_size = encoder_dim,
				input_size_for_decay = input_dim,
				hidden_size = rec_dims, 
				device = device)
			self.rnn_cell_dec = GRUCellExpDecay(
				input_size = encoder_dim,
				input_size_for_decay = input_dim,
				hidden_size = latent_dim, 
				device = device) 
		else:
			raise Exception("Unknown RNN cell: {}".format(cell))

		# Encoder state -> [mean, std] of y0, split on the last dimension.
		self.y0_net = nn.Sequential(
			nn.Linear(rec_dims, n_units),
			nn.Tanh(),
			nn.Linear(n_units, latent_dim * 2),)
		utils.init_network_weights(self.y0_net)

		self.decoder = nn.Sequential(
			nn.Linear(latent_dim, n_units),
			nn.Tanh(),
			nn.Linear(n_units, input_dim),)
		utils.init_network_weights(self.decoder)

		if input_space_decay:
			# Create the parameters on the target device and zero-initialise them:
			# torch.Tensor(...) leaves memory uninitialised, and Parameter(...).to(device)
			# returns a plain, unregistered (hence untrained) tensor when the device changes.
			decay_dim = int(input_dim + glob_dims)
			self.w_input_decay = Parameter(torch.zeros(1, decay_dim, device=self.device))
			self.b_input_decay = Parameter(torch.zeros(1, decay_dim, device=self.device))
		self.input_space_decay = input_space_decay

	def get_reconstruction(self, time_steps_to_predict, data, truth_time_steps, 
		mask = None, n_traj_samples = 1):
		"""Encode ``data``, sample y0 once, and decode at ``time_steps_to_predict``.

		``mask`` is required. ``n_traj_samples`` is not used: a single y0 sample is drawn.

		Returns:
			``(outputs, extra_info)``; ``extra_info["first_point"]`` holds the y0
			mean, std and sample.
		"""
		assert(mask is not None)

		batch_size = data.size(0)
		zero_delta_t = torch.Tensor([0.]).to(self.device)
	
		# run encoder backwards
		run_backwards = bool(time_steps_to_predict[0] < truth_time_steps[-1])

		if run_backwards:
			# Look at data in the reverse order: from later points to the first
			data = utils.reverse(data)

		delta_ts = truth_time_steps[1:] - truth_time_steps[:-1]
		if run_backwards:
			# we are going backwards in time
			delta_ts = utils.reverse(delta_ts)

		delta_ts = torch.cat((delta_ts, zero_delta_t))
		if len(delta_ts.size()) == 1:
			# delta_ts are shared for all trajectories in a batch
			assert(data.size(1) == delta_ts.size(0))
			delta_ts = delta_ts.unsqueeze(-1).repeat((batch_size,1,1))

		input_decay_params = None
		if self.input_space_decay:
			input_decay_params = (self.w_input_decay, self.b_input_decay)

		hidden_state, _ = run_rnn(data, delta_ts, 
			cell = self.rnn_cell_enc, mask = mask,
			input_decay_params = input_decay_params)

		y0_mean, y0_std = utils.split_last_dim(self.y0_net(hidden_state))
		y0_std = y0_std.abs()
		y0_sample = utils.sample_standard_gaussian(y0_mean, y0_std)

		# Decoder: zero gap for the first prediction, then gaps between prediction times.
		delta_ts = torch.cat((zero_delta_t, time_steps_to_predict[1:] - time_steps_to_predict[:-1]))
		if len(delta_ts.size()) == 1:
			delta_ts = delta_ts.unsqueeze(-1).repeat((batch_size,1,1))

		globs = None
		if self.glob_dims > 0:
			globs = data[:,0,-int(self.glob_dims):]

		# No input imputation for the decoder (as in Classic_RNN.run_seq2seq): run_rnn
		# builds an all-ones mask when feeding predictions back, so imputation leaves the
		# data unchanged, yet the mismatch between prediction steps and observed steps
		# could make it raise.
		_, all_hiddens = run_rnn(data, delta_ts,
			cell = self.rnn_cell_dec,
			first_hidden = y0_sample, feed_previous = True, 
			n_steps = time_steps_to_predict.size(0),
			decoder = self.decoder,
			append_to_previous = globs)

		outputs = self.decoder(all_hiddens)
		# Shift outputs for computing the loss -- we should compare the first output to the second data point, etc.
		first_point = data[:,0,:]
		if self.glob_dims > 0:
			first_point = first_point[:,:-int(self.glob_dims)]
		outputs = utils.shift_outputs(outputs, first_point)

		extra_info = {
			"gp_samples": None,
			"latent_traj": all_hiddens,
			"n_calls": 0., 
			"first_point": (y0_mean.unsqueeze(0), y0_std.unsqueeze(0), y0_sample.unsqueeze(0)),
			"pred_mean_y0": outputs.squeeze(0)}

		if self.use_binary_classif:
			if self.classif_per_tp:
				extra_info["label_predictions"] = self.classifier(all_hiddens)
			else:
				extra_info["label_predictions"] = self.classifier(y0_mean).reshape(1,-1)

		# outputs shape: [n_traj_samples, n_traj, n_tp, n_dims]
		return outputs, extra_info



