import numpy as np
import torch
import torch.nn as nn
from torch.nn.functional import relu
import lib.utils as utils
from torch.distributions import Categorical, Normal
import lib.utils as utils
from torch.nn.modules.rnn import LSTM, GRU
from lib.utils import get_device

# def gaussian_prod(mean1, std1, mean2, std2):
# 	means = torch.stack((mean1, mean2), -1)
# 	stds  = torch.stack((std1, std2), -1)

# 	# take the product of two gaussians
	
# 	# general case
# 	# prod_std = 1. / torch.sum(1. / stds,-1)
# 	# prod_mean = torch.sum(means / stds,-1) * prod_std

# 	# product of two gaussians 
# 	prod_std = std1 * std2 / (std1 + std2)
# 	prod_mean = (std1 * mean2 + std2 * mean1) / (std1 + std2)

# 	return prod_mean, prod_std


# GRU description: 
# http://www.wildml.com/2015/10/recurrent-neural-network-tutorial-part-4-implementing-a-grulstm-rnn-with-python-and-theano/
class GRU_unit(nn.Module):
	"""GRU-style cell that updates a latent (mean, std) pair given an observation.

	Three sub-networks are used: an update gate, a reset gate and a network that
	proposes a candidate state. Each one may be supplied by the caller; any that
	is left as ``None`` is built here as a small MLP over the concatenation
	``[y_mean, y_std, x]`` (size ``latent_dim * 2 + input_dim``) and initialised
	with ``utils.init_network_weights``.

	Args:
		latent_dim: size of the latent mean (and of the latent std).
		input_dim: size of the observation vector ``x``.
		update_gate, reset_gate, new_state_net: optional replacement modules.
		n_units: hidden width of the default sub-networks.
		device: accepted for interface compatibility; not used by this class.
	"""

	def __init__(self, latent_dim, input_dim, 
		update_gate = None,
		reset_gate = None,
		new_state_net = None,
		n_units = 100,
		device = torch.device("cpu")):
		super(GRU_unit, self).__init__()

		# all default sub-networks consume [y_mean, y_std, x]
		joint_dim = latent_dim * 2 + input_dim

		# The build order (update, reset, candidate) is kept so that random
		# initialisation consumes the RNG in the same sequence.
		if update_gate is not None:
			self.update_gate = update_gate
		else:
			self.update_gate = self._default_net(joint_dim, n_units, latent_dim, gated = True)

		if reset_gate is not None:
			self.reset_gate = reset_gate
		else:
			self.reset_gate = self._default_net(joint_dim, n_units, latent_dim, gated = True)

		if new_state_net is not None:
			self.new_state_net = new_state_net
		else:
			# the candidate network emits mean and std side by side, hence latent_dim * 2
			self.new_state_net = self._default_net(joint_dim, n_units, latent_dim * 2, gated = False)

	@staticmethod
	def _default_net(in_dim, n_units, out_dim, gated):
		# Linear -> Tanh -> Linear, followed by a Sigmoid when used as a gate.
		layers = [
			nn.Linear(in_dim, n_units),
			nn.Tanh(),
			nn.Linear(n_units, out_dim),
		]
		if gated:
			layers.append(nn.Sigmoid())
		net = nn.Sequential(*layers)
		utils.init_network_weights(net)
		return net

	def forward(self, y_mean, y_std, x):
		"""Return the updated ``(mean, std)`` after seeing observation ``x``.

		The three inputs are concatenated along the last axis, so they must agree
		on all leading dimensions.
		"""
		gate_input = torch.cat([y_mean, y_std, x], -1)
		keep = self.update_gate(gate_input)
		reset = self.reset_gate(gate_input)

		# the reset gate scales the previous state before the candidate is proposed
		candidate_input = torch.cat([y_mean * reset, y_std * reset, x], -1)
		cand_mean, cand_std = utils.split_last_dim(self.new_state_net(candidate_input))
		cand_std = cand_std.abs()

		# blend between the previous state (weight `keep`) and the candidate
		take_new = 1 - keep
		new_y = take_new * cand_mean + keep * y_mean
		new_y_std = (take_new * cand_std + keep * y_std).abs()
		return new_y, new_y_std


def check_t0(t0, gt_time_steps):
	"""Validate ``t0`` against the observed time stamps and pick a direction.

	Returns ``(t0, run_backwards)``:
	  * ``t0 is None``                 -> ``(gt_time_steps[0], True)``
	  * ``t0 <= gt_time_steps[0]``     -> ``(t0, True)``
	  * ``t0 >= gt_time_steps[-1]``    -> ``(t0, False)``

	Raises:
		Exception: if ``t0`` lies strictly between the first and last time stamp.
	"""
	first_tp = gt_time_steps[0]

	if t0 is None:
		return first_tp, True
	if t0 <= first_tp:
		return t0, True
	if t0 >= gt_time_steps[-1]:
		return t0, False

	raise Exception("Error: t0 provided to ODE Combine must be "
		"either before or after any observed time point.")


def make_triplets(gt_data, gt_time_steps):
	"""Build sliding windows of three consecutive observations.

	For every interior time index i the points x_(i-1), x_i, x_(i+1) are
	flattened into one vector.

	Args:
		gt_data: tensor of size [n_traj, n_tp, n_dims].
		gt_time_steps: time stamps, indexable along the time axis.

	Returns:
		(triplets, triplet_time_steps) where triplets has size
		[n_traj, n_tp - 2, 3 * n_dims] and triplet_time_steps is
		``gt_time_steps[1:-1]`` (the stamp of each window's centre point).
	"""
	n_traj, n_tp, _ = gt_data.size()

	# one window per interior point; `start` is the index of the window's first element
	windows = [
		gt_data[:, start:start + 3, :].reshape(n_traj, -1)
		for start in range(n_tp - 2)
	]
	return torch.stack(windows, 1), gt_time_steps[1:-1]



class Encoder_y0_from_rnn(nn.Module):
	"""Encode an observed trajectory into (mean, std) of y0 with a GRU.

	The sequence is fed to ``torch.nn.GRU`` (optionally reversed, optionally with
	the time gap to the next step prepended to every input), and the output at
	the last step is mapped to ``latent_dim * 2`` values that are split into a
	mean and a non-negative std.

	Args:
		latent_dim: size of the produced mean / std.
		input_dim: number of features per time step fed to the GRU, before the
			optional delta-t channel is added.
		lstm_output_size: hidden size of the GRU (name kept for compatibility).
		use_delta_t: if True, one extra input channel carries the time gap.
		device: device the GRU is moved to at construction.
	"""

	def __init__(self, latent_dim, input_dim, lstm_output_size = 20, 
		use_delta_t = True, device = torch.device("cpu")):
		
		super(Encoder_y0_from_rnn, self).__init__()
	
		self.gru_rnn_output_size = lstm_output_size
		self.latent_dim = latent_dim
		self.input_dim = input_dim
		self.device = device
		self.use_delta_t = use_delta_t
		# filled in by forward(); defined here so the attribute always exists
		self.extra_info = None

		self.hiddens_to_y0 = nn.Sequential(
		   nn.Linear(self.gru_rnn_output_size, 50),
		   nn.Tanh(),
		   nn.Linear(50, latent_dim * 2),)

		utils.init_network_weights(self.hiddens_to_y0)

		if use_delta_t:
			# one extra channel for the time gap
			self.input_dim += 1
		self.gru_rnn = GRU(self.input_dim, self.gru_rnn_output_size).to(device)

	def forward(self, gt_data, gt_time_steps, mask = None, t0 = None):
		"""Return ``(mean, std)`` of y0, each of size [1, n_traj, latent_dim].

		Args:
			gt_data: observations, [n_traj, n_tp, n_dims].
			gt_time_steps: time stamps of the observations.
			mask: optional tensor concatenated to the data along the feature axis.
				NOTE(review): this widens the GRU input, so ``input_dim`` given to
				the constructor presumably already counts the mask channels -- confirm.
			t0: only used (via ``check_t0``) to decide whether the sequence is
				processed in reverse.
		"""
		# gt_data shape: [n_traj, n_tp, n_dims]
		# shape required for rnn: (seq_len, batch, input_size)
		n_traj = gt_data.size(0)
		
		assert(not torch.isnan(gt_data).any())
		assert(not torch.isnan(gt_time_steps).any())

		data = gt_data.permute(1,0,2) 
		if mask is not None:
			assert(not torch.isnan(mask).any())
			mask = mask.permute(1,0,2) 
			data = torch.cat((data, mask), -1)

		t0, run_backwards = check_t0(t0, gt_time_steps)

		if run_backwards:
			# Look at data in the reverse order: from later points to the first
			data = utils.reverse(data)

		assert(not torch.isnan(data).any())

		if self.use_delta_t:
			delta_t = gt_time_steps[1:] - gt_time_steps[:-1]
			if run_backwards:
				# keep the gaps aligned with the reversed data
				delta_t = utils.reverse(delta_t)
			# Append a zero gap for the final step. The zero is created from
			# delta_t itself so it always shares its device and dtype, even when
			# the module has been moved after construction.
			delta_t = torch.cat((delta_t, delta_t.new_zeros(1)))
			delta_t = delta_t.unsqueeze(1).repeat((1,n_traj)).unsqueeze(-1)
			data = torch.cat((delta_t, data),-1)

		assert(not torch.isnan(data).any())

		outputs, _ = self.gru_rnn(data)
		assert(not torch.isnan(outputs).any())

		# GRU output shape: (seq_len, batch, num_directions * hidden_size)
		last_output = outputs[-1]

		self.extra_info ={"rnn_outputs": outputs, "time_points": gt_time_steps}

		mean, std = utils.split_last_dim(self.hiddens_to_y0(last_output))
		std = std.abs()

		assert(not torch.isnan(mean).any())
		assert(not torch.isnan(std).any())

		return mean.unsqueeze(0), std.unsqueeze(0)





class Encoder_y0_ode_combine(nn.Module):
	"""Derive y0 by alternating ODE propagation and GRU updates over observations.

	Between two consecutive observations the latent state is moved by the ODE
	(``y0_diffeq_solver``); at each observation the propagated state and the data
	point are merged by ``GRU_update``. The state after the last processed
	observation is passed through ``transform_y0`` to give (mean, std) of y0.

	Args:
		latent_dim: size of the latent state used inside the encoder.
		input_dim: number of data features per time step.
		y0_diffeq_solver: object called as ``solver(y, time_points)`` returning
			``(solution, _)`` and exposing ``ode_func(t, y)``; the solution is
			indexed as ``[:, :, time, :]``.
		y0_dim: size of the produced y0; defaults to ``latent_dim``.
		GRU_update: optional update cell; a ``GRU_unit`` is built when omitted.
		n_gru_units: hidden width of the default ``GRU_unit``.
		device: device for the default ``GRU_unit``.
	"""

	def __init__(self, latent_dim, input_dim, y0_diffeq_solver = None, 
		y0_dim = None, GRU_update = None, 
		n_gru_units = 100, 
		device = torch.device("cpu")):
		
		super(Encoder_y0_ode_combine, self).__init__()

		if y0_dim is None:
			self.y0_dim = latent_dim
		else:
			self.y0_dim = y0_dim

		if GRU_update is None:
			self.GRU_update = GRU_unit(latent_dim, input_dim, 
				n_units = n_gru_units, 
				device=device).to(device)
		else:
			self.GRU_update = GRU_update

		self.y0_diffeq_solver = y0_diffeq_solver
		self.latent_dim = latent_dim
		self.input_dim = input_dim
		self.device = device
		self.extra_info = None

		self.transform_y0 = nn.Sequential(
		   nn.Linear(latent_dim * 2, 100),
		   nn.Tanh(),
		   nn.Linear(100, self.y0_dim * 2),)
		utils.init_network_weights(self.transform_y0)


	def forward(self, gt_data, gt_time_steps, t0 = None, save_info = False):
		"""Return ``(mean_y0, std_y0)`` for the given observations.

		Args:
			gt_data: observations, [n_traj, n_tp, n_dims].
			gt_time_steps: time stamps of the observations.
			t0: time stamp the encoding is made for; must not lie strictly
				inside the observed interval (see ``check_t0``).
			save_info: if True, per-step diagnostics are stored in ``self.extra_info``.
		"""
		n_traj, n_tp, n_dims = gt_data.size()
		if len(gt_time_steps) == 1:
			# A single observation: nothing to integrate, just one GRU update
			# from a zero state. Allocate next to the data, as run_ode_combine does.
			device = get_device(gt_data)
			prev_y = torch.zeros((1, n_traj, self.latent_dim)).to(device)
			prev_std = torch.zeros((1, n_traj, self.latent_dim)).to(device)

			xi = gt_data[:,0,:].unsqueeze(0)

			last_yi, last_yi_std = self.GRU_update(prev_y, prev_std, xi)
			extra_info = None
		else:
			last_yi, last_yi_std, _, extra_info = self.run_ode_combine(
				gt_data, gt_time_steps, t0 = t0, 
				save_latents = False, save_info = save_info)

		means_y0 = last_yi.reshape(1, n_traj, self.latent_dim)
		std_y0 = last_yi_std.reshape(1, n_traj, self.latent_dim)

		mean_y0, std_y0 = utils.split_last_dim( self.transform_y0( torch.cat((means_y0, std_y0), -1)))
		std_y0 = std_y0.abs()
		if save_info:
			self.extra_info = extra_info
		return mean_y0, std_y0


	def run_ode_combine(self, gt_data, gt_time_steps, 
		t0 = None, save_latents = False, save_info = False):
		"""Walk over the observations, alternating ODE propagation and GRU updates.

		Args:
			gt_data: observations, [n_traj, n_tp, n_dims].
			gt_time_steps: time stamps of the observations.
			t0: only used (via ``check_t0``) to choose the iteration direction;
				the state is not propagated from the last visited observation to t0.
			save_latents: if True, the state after every update is collected.
			save_info: if True, per-step diagnostics are collected.

		Returns:
			``(yi, yi_std, latent_ys, extra_info)``: the final state and std;
			``latent_ys`` is the stacked per-step states (an empty list when
			``save_latents`` is False); ``extra_info`` is a list of dicts (empty
			when ``save_info`` is False).

		Raises:
			RuntimeError: if the first point of an ODE solution differs from the
				state it was started from.
		"""
		n_traj, n_tp, n_dims = gt_data.size()
		extra_info = []
		
		t0, run_backwards = check_t0(t0, gt_time_steps)

		device = get_device(gt_data)

		prev_y = torch.zeros((1, n_traj, self.latent_dim)).to(device)
		prev_std = torch.zeros((1, n_traj, self.latent_dim)).to(device)

		# Start slightly after the last observation and step towards it.
		# NOTE(review): this initialisation assumes the backward direction; when
		# run_backwards is False the loop still starts from these values -- confirm
		# that the forward direction is actually exercised and behaves as intended.
		prev_t, t_i = gt_time_steps[-1] + 0.01,  gt_time_steps[-1]

		# gaps shorter than 1/50 of the observed interval are handled with one Euler step
		interval_length = gt_time_steps[-1] - gt_time_steps[0]
		minimum_step = interval_length / 50

		assert(not torch.isnan(gt_data).any())
		assert(not torch.isnan(gt_time_steps).any())

		latent_ys = []
		# Run ODE backwards and combine the y(t) estimates using gating
		time_points_iter = range(0, len(gt_time_steps))
		if run_backwards:
			time_points_iter = reversed(time_points_iter)

		for i in time_points_iter:
			if (prev_t - t_i) < minimum_step:
				# short gap: a single explicit Euler step instead of the solver
				time_points = torch.stack((prev_t, t_i))
				inc = self.y0_diffeq_solver.ode_func(prev_t, prev_y) * (t_i - prev_t)

				if torch.isnan(inc).any():
					print("inc is None!!")
					print(i)
					print(self.y0_diffeq_solver.ode_func(prev_t, prev_y))
					print(torch.isnan(prev_y).any())
				
				assert(not torch.isnan(t_i - prev_t).any())
				assert(not torch.isnan(prev_t).any())
				assert(not torch.isnan(prev_y).any())
				assert(not torch.isnan(inc).any())

				ode_sol = prev_y + inc
				# mimic the solver's layout: time on axis 2, first entry is the start state
				ode_sol = torch.stack((prev_y, ode_sol), 2).to(device)

				assert(not torch.isnan(ode_sol).any())
			else:
				n_intermediate_tp = max(2, ((prev_t - t_i) / minimum_step).int())

				time_points = utils.linspace_vector(prev_t, t_i, n_intermediate_tp)

				assert(not torch.isnan(time_points).any())
				assert(not torch.isnan(prev_y).any())

				ode_sol, _ = self.y0_diffeq_solver(prev_y, time_points)

				assert(not torch.isnan(ode_sol).any())

			# The solution must start at the state it was given. The absolute
			# difference is used so that positive and negative deviations cannot
			# cancel, and a failure is raised instead of terminating the interpreter.
			start_mismatch = torch.mean(torch.abs(ode_sol[:, :, 0, :]  - prev_y))
			if start_mismatch >= 0.001:
				raise RuntimeError(
					"Error: first point of the ODE is not equal to initial value "
					"(mean absolute difference: {})".format(float(start_mismatch)))

			yi_ode = ode_sol[:, :, -1, :]
			xi = gt_data[:,i,:].unsqueeze(0)
			
			assert(not torch.isnan(xi).any())
			assert(not torch.isnan(yi_ode).any())
			assert(not torch.isnan(prev_std).any())

			yi, yi_std = self.GRU_update(yi_ode, prev_std, xi)

			assert(not torch.isnan(yi).any())
			assert(not torch.isnan(yi_std).any())

			prev_y, prev_std = yi, yi_std			
			# On the last backward iteration (i == 0) the index i-1 wraps around
			# to the final time stamp; the value is not used because the loop ends.
			prev_t, t_i = gt_time_steps[i],  gt_time_steps[i-1]

			if save_latents:
				latent_ys.append(yi)

			if save_info:
				# relative size of the correction made by the GRU update; only
				# needed for diagnostics, so only computed here
				yi_change = torch.norm(yi - yi_ode) / torch.norm(yi)
				d = {"yi_ode": yi_ode.detach(), #"yi_from_data": yi_from_data,
					 "yi": yi.detach(), "yi_std": yi_std.detach(), 
					 "time_points": time_points.detach(), "ode_sol": ode_sol.detach(),
					 "hidden_changes": yi_change}
				extra_info.append(d)

		if save_latents:
			latent_ys = torch.stack(latent_ys, 1)
			assert(not torch.isnan(latent_ys).any())

		assert(not torch.isnan(yi).any())
		assert(not torch.isnan(yi_std).any())

		return yi, yi_std, latent_ys, extra_info









class Encoder_y0_ode_per_attr(nn.Module):
	"""Encode every data dimension separately, then combine into y0 (for physionet).

	One ``Encoder_y0_ode_combine`` is run per attribute (data dimension) on the
	time points where that attribute is observed. The per-attribute latent means
	and stds are concatenated and mapped by ``transform_y0`` to the y0 of the
	generative model.

	Args:
		latent_dim: latent size of every per-attribute encoder.
		input_dim: number of attributes (size of the last axis of the data).
		y0_diffeq_solver: solver shared by all per-attribute encoders.
		y0_dim: size of the produced y0; defaults to ``latent_dim``.
		n_gru_units: hidden width of the GRU units in the sub-encoders.
		device: device the sub-encoders are created on.
	"""

	def __init__(self, latent_dim, input_dim, y0_diffeq_solver = None, 
		y0_dim = None, n_gru_units = 100, device = torch.device("cpu")):
		super(Encoder_y0_ode_per_attr, self).__init__()

		# Resolve the default first: the output layer below is sized from it.
		if y0_dim is None:
			y0_dim = latent_dim
		self.y0_dim = y0_dim

		# A ModuleList (not a plain list) so the sub-encoders' parameters are
		# registered: visible to optimizers, .to() and state_dict().
		self.encoders_per_dim = nn.ModuleList([
			Encoder_y0_ode_combine(latent_dim, 1, 
				y0_diffeq_solver, n_gru_units = n_gru_units,
				y0_dim = latent_dim, device = device).to(device)
			for _ in range(input_dim)
		])

		self.transform_y0 = nn.Sequential(
		   nn.Linear(latent_dim * input_dim * 2, 100),
		   nn.Tanh(),
		   nn.Linear(100, 100),
		   nn.Tanh(),
		   nn.Linear(100, self.y0_dim * 2),)
		utils.init_network_weights(self.transform_y0)

		self.y0_diffeq_solver = y0_diffeq_solver
		self.latent_dim = latent_dim
		self.input_dim = input_dim
		self.device = device
		self.extra_info = None

	def forward(self, gt_data, gt_time_steps, mask, t0 = None):
		"""Return ``(means_y0, stds_y0)``, each of size [1, n_traj, y0_dim].

		Args:
			gt_data: observations, last axis of size ``input_dim``.
			gt_time_steps: time stamps shared by all trajectories.
			mask: required; indexed as ``mask[traj, :, attr]``, entries equal to 1
				mark observed values.
			t0: time stamp of y0, forwarded to the per-attribute encoders.
		"""
		n_traj = gt_data.size(0)
		assert(gt_data.size(-1) == self.input_dim)

		assert(mask is not None)
		device = get_device(gt_data)

		means_y0, stds_y0 = [], []
		for traj_id in range(n_traj):
			latent_means = []
			latent_stds = []
			for param_id in range(self.input_dim):
				tp_mask = mask[traj_id,:, param_id].byte()

				if torch.sum(tp_mask) == 0.:
					# attribute never observed in this trajectory: use a zero encoding
					zeros = torch.zeros((1,1,self.latent_dim)).to(device)
					latent_means.append(zeros)
					latent_stds.append(zeros)
					continue

				observed = tp_mask == 1.
				tp_cur_param = gt_time_steps[observed]
				data_cur_param = gt_data[traj_id, observed, param_id]

				# the per-attribute encoder relies on strictly increasing time stamps
				assert (tp_cur_param[1:] > tp_cur_param[:-1]).all()

				encoder = self.encoders_per_dim[param_id]
				means, stds_ = encoder(data_cur_param.reshape(1,-1,1), tp_cur_param, t0=t0)

				latent_means.append(means)
				latent_stds.append(stds_)
	
			latent_means = torch.stack(latent_means, -1).reshape(1,1,-1)
			latent_stds = torch.stack(latent_stds, -1).reshape(1,1,-1)

			# NOTE(review): unlike Encoder_y0_ode_combine, the std produced here is
			# not passed through abs(); confirm that callers handle the sign.
			mean_y0, std_y0 = utils.split_last_dim( self.transform_y0( torch.cat((latent_means, latent_stds), -1)))

			means_y0.append(mean_y0)
			stds_y0.append(std_y0)

		means_y0 = torch.stack(means_y0, 0).reshape(1, n_traj, self.y0_dim)
		stds_y0 = torch.stack(stds_y0, 0).reshape(1, n_traj, self.y0_dim)
		return means_y0, stds_y0


class Encoder_zt(nn.Module):
	"""Encode data points into inducing points for the distribution q over z(t).

	Each data point is encoded on its own, so the inducing points share the time
	stamps of the observations. (Using a context of three points was considered,
	but it is unclear which time stamp such an inducing point would get.)
	"""

	def __init__(self, input_dim, device = torch.device("cpu")):
		super().__init__()

		hidden_units = 100
		# output holds the inducing point and its std side by side
		self.encoder_zt_gp = nn.Sequential(
			nn.Linear(input_dim, hidden_units),
			nn.Tanh(),
			nn.Linear(hidden_units, input_dim * 2),
		)
		utils.init_network_weights(self.encoder_zt_gp)

		self.input_dim = input_dim
		self.device = device

	def forward(self, gt_data, gt_time_steps, mask = None):
		"""Return ``(ind_points, std, gt_time_steps)``.

		``mask`` is accepted for interface compatibility and is not used.
		"""
		encoded = self.encoder_zt_gp(gt_data)
		ind_points, ind_point_std = utils.split_last_dim(encoded)
		return (ind_points, ind_point_std, gt_time_steps)


class Encoder_y0_from_x0(nn.Module):
	"""Map a data point x0 to (mean, std) of y0 with a single linear layer."""

	def __init__(self, latent_dim, input_dim, device = torch.device("cpu")):
		super().__init__()

		# the layer emits mean and std side by side
		self.encoder_y0_from_x0 = nn.Linear(input_dim, latent_dim * 2)
		utils.init_network_weights(self.encoder_y0_from_x0)

		self.latent_dim = latent_dim
		self.input_dim = input_dim
		self.device = device

	def forward(self, x0):
		"""Return ``(mean, std)``; the std is made non-negative with ``abs``."""
		encoded = self.encoder_y0_from_x0(x0)
		mean, raw_std = utils.split_last_dim(encoded)
		return mean, raw_std.abs()


class Encoder_y0_from_z0(nn.Module):
	"""Map (mean, std) of z0 to (mean, std) of y0 with a small MLP.

	Args:
		latent_dim: size of the produced mean / std.
		input_dim: size of the last axis of ``z0_mean`` and of ``z0_std``.
		device: stored on the instance; not otherwise used by this class.
	"""

	def __init__(self, latent_dim, input_dim, device = torch.device("cpu")):
		super(Encoder_y0_from_z0, self).__init__()
		
		self.encoder_y0_from_z0 = nn.Sequential(
		   nn.Linear(input_dim * 2, 50),
		   nn.Tanh(),
		   nn.Linear(50, latent_dim * 2),)

		utils.init_network_weights(self.encoder_y0_from_z0)

		self.latent_dim = latent_dim
		self.input_dim = input_dim
		self.device = device


	def forward(self, z0_mean, z0_std):
		"""Return ``(mean, std)`` of y0; the std is made non-negative with ``abs``."""
		# Concatenate along the feature (last) axis, which is the axis the first
		# Linear layer consumes. For 2-D inputs this equals the previous dim=1;
		# it additionally works for inputs with extra leading axes.
		z0_concat = torch.cat((z0_mean, z0_std), -1)
		mean, std = utils.split_last_dim(self.encoder_y0_from_z0(z0_concat))
		std = std.abs()
		return mean, std


class Encoder_y0_gp(nn.Module):
	"""Encode observations into inducing points and regress y0 with a GP.

	Args:
		latent_dim: size of y0.
		input_dim: number of data features per time step.
		y0_gauss_proc: GP for sampling y0 conditioned on the whole observed
			trajectory; must provide ``do_regression_from_ind_points``.
		device: stored on the instance; not otherwise used by this class.
	"""

	def __init__(self, latent_dim, input_dim, y0_gauss_proc = None, device = torch.device("cpu")):
		super().__init__()

		self.encoder_y0_gp = nn.Sequential(
			nn.Linear(input_dim, 100),
			nn.Tanh(),
			nn.Linear(100, latent_dim * 2),
		)
		utils.init_network_weights(self.encoder_y0_gp)

		self.y0_gauss_proc = y0_gauss_proc
		self.latent_dim = latent_dim
		self.input_dim = input_dim
		self.device = device

	def forward(self, gt_data, gt_time_steps, y0_time_stamp, mask = None):
		"""Return ``(means_y0, std_y0)``, each reshaped to [1, n_traj, latent_dim].

		``mask`` is accepted for interface compatibility and is not used.
		"""
		# only the first half of the encoder output is passed to the GP
		ind_points, _ = utils.split_last_dim(self.encoder_y0_gp(gt_data))

		n_traj, _, n_dims = gt_data.size()
		assert(self.input_dim == n_dims)

		regression = self.y0_gauss_proc.do_regression_from_ind_points(
			y0_time_stamp, ind_points, gt_time_steps)
		means_flat, cov_matrix_flat = regression
		assert(means_flat.size(0) == self.latent_dim * n_traj)

		out_shape = (1, n_traj, self.latent_dim)
		# NOTE(review): the second output is the GP's second return value reshaped
		# as-is; whether it is a std or a (co)variance depends on the GP -- confirm.
		return means_flat.reshape(out_shape), cov_matrix_flat.reshape(out_shape)

	def get_y0_time_stamp(self, time_steps):
		"""Return a one-element tensor placed 10% of the observed span before the first time step."""
		earliest = time_steps.min()
		span = time_steps.max() - earliest
		return earliest.reshape(1,) - 0.1 * span


class BaseEncoder_y0_independent_yi(nn.Module):
	"""Shared machinery for encoders that estimate y0 independently from every y_i.

	Every interior observation (with its two neighbours) is encoded into a latent
	state y_i, and the ODE is run from each y_i back to the first time stamp.
	Subclasses decide how the resulting y0 estimates are combined.
	"""

	def __init__(self, latent_dim, input_dim, y0_diffeq_solver = None, device = torch.device("cpu")):
		super().__init__()

		# input is a flattened triplet of data points, hence input_dim * 3
		self.encoder_yi_from_xi = nn.Sequential(
			nn.Linear(input_dim * 3, 50),
			nn.Tanh(),
			nn.Linear(50, latent_dim * 2),
		)
		utils.init_network_weights(self.encoder_yi_from_xi)

		self.y0_diffeq_solver = y0_diffeq_solver
		self.latent_dim = latent_dim
		self.input_dim = input_dim
		self.device = device
		self.extra_info = None

	def get_y_t0_estimates(self, gt_data, gt_time_steps):
		"""Return ``(y_t0s, std_y_tis)``.

		``y_t0s`` stacks, along axis 1, the ODE solution at the first time stamp
		obtained from every interior point; ``std_y_tis`` is the non-negative std
		half of the triplet encoding. The ODE paths are stored in
		``self.extra_info`` for plotting.
		"""
		triplets, triplet_time_steps = make_triplets(gt_data, gt_time_steps)
		encoded = self.encoder_yi_from_xi(triplets.to(self.device))
		mean_y_tis, std_y_tis = utils.split_last_dim(encoded)
		std_y_tis = std_y_tis.abs()

		t_0 = gt_time_steps[0]
		# Intermediate points are used for plotting only
		n_intermediate_tp = 20

		y_t0s = []
		ode_paths = []
		for idx in range(len(triplet_time_steps)):
			start_state = mean_y_tis[:, idx, :].unsqueeze(0)
			time_points = utils.linspace_vector(triplet_time_steps[idx], t_0, n_intermediate_tp)

			sol, _ = self.y0_diffeq_solver(start_state, time_points)
			sol = sol.squeeze(0)

			# keep only the ODE solution at t_0
			y_t0s.append(sol[:,-1,:])
			ode_paths.append({"time_points": time_points, "sol": sol})

		y_t0s = torch.stack(y_t0s, 1).to(self.device)
		self.extra_info = ode_paths

		return y_t0s, std_y_tis




class Encoder_y0_mean_odes_yi(BaseEncoder_y0_independent_yi):
	"""Combine the independent y0 estimates by plain averaging.

	The mean of y0 is the average of the ODE-propagated estimates; the std is the
	average of the stds encoded at the y_i (i.e. it is assumed not to change
	between y_i and y0).
	"""

	def __init__(self, latent_dim, input_dim, y0_diffeq_solver = None, device = torch.device("cpu")):
		super().__init__(latent_dim, input_dim, y0_diffeq_solver, device)

	def forward(self, gt_data, gt_time_steps, mask = None):
		"""Return ``(means_y0, std_y0)``, each of size [1, n_traj, latent_dim].

		``mask`` is accepted for interface compatibility and is not used.
		"""
		n_traj, _, _ = gt_data.size()
		y_t0s, std_y_tis = self.get_y_t0_estimates(gt_data, gt_time_steps)

		# average over the estimates obtained from the different y_ti
		avg_y0 = torch.mean(y_t0s, 1)
		avg_std = torch.mean(std_y_tis, 1)

		out_shape = (1, n_traj, self.latent_dim)
		return avg_y0.reshape(out_shape), avg_std.reshape(out_shape)



class Encoder_y0_gauss_product_odes_yi(BaseEncoder_y0_independent_yi):
	"""Combine the independent y0 estimates as a product of Gaussians.

	Reference for the product formula:
	https://ccrma.stanford.edu/~jos/sasp/Product_Two_Gaussian_PDFs.html

	NOTE(review): the formula weights by the reciprocal of the encoded spread
	values as they are; whether those are meant as stds or variances should be
	confirmed against the rest of the model.
	"""

	def __init__(self, latent_dim, input_dim, y0_diffeq_solver = None, y0_dim = None, device = torch.device("cpu")):
		super().__init__(latent_dim, input_dim, y0_diffeq_solver, device)

		self.y0_dim = latent_dim if y0_dim is None else y0_dim

	def forward(self, gt_data, gt_time_steps, mask = None):
		"""Return ``(means_y0, std_y0)`` truncated to the first ``y0_dim`` latent dims.

		``mask`` is accepted for interface compatibility and is not used.
		"""
		n_traj, _, _ = gt_data.size()
		# both have shape [n_traj, n_tp-2, n_latent_dims]; n_tp-2 because triplets of time points are used
		y_t0s, std_y_tis = self.get_y_t0_estimates(gt_data, gt_time_steps)

		# combined spread: reciprocal of the summed reciprocals
		combined_std = 1. / torch.sum(1. / std_y_tis, 1)
		# combined mean: reciprocal-weighted sum of the estimates, rescaled
		combined_mean = torch.sum(y_t0s / std_y_tis, 1) * combined_std

		out_shape = (1, n_traj, self.latent_dim)
		combined_mean = combined_mean.reshape(out_shape)
		combined_std = combined_std.reshape(out_shape)

		keep = self.y0_dim
		return combined_mean[:,:,:keep], combined_std[:,:,:keep]



class Decoder(nn.Module):
	"""Linear map from the latent space (where the ODE is solved) to data space.

	An earlier version used a 50-unit ReLU hidden layer; the current decoder is
	a single ``nn.Linear`` wrapped in ``nn.Sequential``.

	Args:
		latent_dim: size of the latent state.
		input_dim: size of a data point.
		fix_decoder: if True, all parameters are frozen right after construction.
	"""

	def __init__(self, latent_dim, input_dim, fix_decoder = False):
		super().__init__()

		linear_map = nn.Sequential(nn.Linear(latent_dim, input_dim))
		utils.init_network_weights(linear_map)
		self.decoder = linear_map

		if fix_decoder:
			for weight in self.parameters():
				weight.requires_grad = False

	def forward(self, data):
		"""Decode latent states into data space."""
		return self.decoder(data)

	def fix_decoder(self):
		"""Freeze every parameter of the decoder."""
		for weight in self.parameters():
			weight.requires_grad_(False)


# Decoder for density plots
class Decoder_nonlinear(nn.Module):
	"""Non-linear decoder used for density plots.

	A 200-unit sigmoid MLP maps latent states to data space, and the result is
	multiplied by a learnable scalar ``scale`` (initialised to 5, the value noted
	for Ay**3; 2 was noted for Ay).

	Notes kept from the original author: the std sigma in the likelihood should
	be set to 0.5; Sigmoid gives a smooth likelihood landscape with one local
	minimum, whereas ReLU gives many small local minima.
	"""

	def __init__(self, latent_dim, input_dim, fix_decoder = False):
		super().__init__()

		hidden_units = 200  # 100 or 200 units
		self.decoder = nn.Sequential(
			nn.Linear(latent_dim, hidden_units),
			nn.Sigmoid(),  # Tanh or Sigmoid
			nn.Linear(hidden_units, input_dim),
		)

		self.scale = nn.Parameter(torch.Tensor(1))
		nn.init.constant_(self.scale, 5.)

		if fix_decoder:
			for weight in self.parameters():
				weight.requires_grad = False

	def forward(self, data):
		"""Decode latent states and apply the learnable scale."""
		decoded = self.decoder(data)
		return decoded * self.scale

	def fix_decoder(self):
		"""Freeze every parameter, including ``scale``."""
		for weight in self.parameters():
			weight.requires_grad_(False)



class MixtureOfGaussians(nn.Module):
	"""Learnable mixture-of-Gaussians prior with equally weighted components.

	Args:
		latent_dim: size of a sample.
		n_components: number of mixture components.
		device: device the component means and stds are created on.
	"""

	def __init__(self, latent_dim, n_components, device = torch.device("cpu")):
		super(MixtureOfGaussians, self).__init__()
	
		# Create the storage on the target device first and wrap it afterwards:
		# calling .to(device) on an nn.Parameter returns a plain tensor for any
		# non-CPU device, which would leave the prior unregistered and untrained.
		self.prior_means = nn.Parameter(torch.empty(n_components, latent_dim, device = device))
		self.prior_stds = nn.Parameter(torch.empty(n_components, latent_dim, device = device))

		nn.init.normal_(self.prior_means,  mean=0, std=0.01)
		nn.init.normal_(self.prior_stds,  mean=0, std=0.01)

		self.n_components = n_components
		self.latent_dim = latent_dim
		self.device = device

	def get_mixtures(self):
		"""Return the mixture weights: equal probability for every component."""
		prob = 1 / self.n_components
		mixtures = torch.tensor([ prob ] * self.n_components)
		return mixtures

	def sample(self, shape):
		"""Draw samples of size ``(*shape, latent_dim)`` from the prior."""
		mixtures = self.get_mixtures()

		n_samples = np.prod(shape)

		prior_means = self.prior_means.repeat(n_samples, 1, 1)
		prior_stds = self.prior_stds.repeat(n_samples, 1, 1)
		sample = self.sample_mixture_gaussians(mixtures, prior_means, prior_stds)
		sample = sample.reshape(*shape, self.latent_dim)
		return sample


	def sample_mixture_gaussians(self, mixtures, means, stds):
		"""Sample from the Gaussians in ``means`` / ``stds`` (component on axis 1).

		NOTE(review): a single component id is drawn per call, so all samples
		produced by one call come from the same component -- confirm that this
		is intended.
		"""
		gaussian_id = Categorical(mixtures).sample()
		sample = utils.sample_standard_gaussian(means[:,gaussian_id], stds[:,gaussian_id].abs())
		return sample


	def log_prob(self, sample):
		"""Return the log-density of ``sample`` under the mixture, per latent dimension.

		The result has the same shape as ``sample``. For every latent dimension
		it is ``log sum_k w_k * N(x | mean_k, |std_k|)``, evaluated with
		``logsumexp`` for numerical stability. (A weighted sum of the component
		log-densities is not the log-density of a mixture.)
		"""
		log_weights = torch.log(self.get_mixtures()).to(sample.device)

		# component axis is stacked last so the weights broadcast over it
		component_log_probs = torch.stack([
			Normal(self.prior_means[k], self.prior_stds[k].abs()).log_prob(sample)
			for k in range(self.n_components)
		], -1)

		return torch.logsumexp(component_log_probs + log_weights, -1)


	def kl_div(self, q_distr, q_sample):
		"""Single-sample estimate of KL(q || prior): ``log q(x) - log p(x)`` at ``q_sample``."""
		q_sample_density = q_distr.log_prob(q_sample)
		prior_sample_density = self.log_prob(q_sample)
		return q_sample_density - prior_sample_density



