import os
import numpy as np

import torch
import torch.nn as nn
from torch.nn.functional import relu

import lib.utils as utils
from lib.sdevae import ODEVAE
from lib.encoder_decoder import *
from lib.diffeq_solver import DiffeqSolver

from lib.gaussian_process import PytorchGaussianProcess, GaussianProcess
from torch.distributions.normal import Normal
from lib.ode_func import ODEFunc_one_ODE, ODEFunc_ODEmixture, ODEFunc_spiral, ODEFunc_LotkaVolterra, ODEFunc_w_Poisson

#####################################################################################################
# Constants

# NOTE: neither constant is used when kernel_param is being learned.
# Kernel parameter for the GP Wiener-process prior.
prior_kernel_param = 1.0
# Kernel parameter for the GP of the distribution Q.
q_kernel_param = 1.0

#####################################################################################################

class ODEMixtures(nn.Module):
	"""Two-layer MLP that turns a latent vector into mixture weights.

	The network is Linear(latent_dim, 100) -> ReLU -> Linear(100, n_ode_mixtures)
	followed by a softmax over the last axis, so the weights produced for
	every input vector are non-negative and sum to one.

	Args:
		latent_dim: size of the last axis of the input passed to `forward`.
		n_ode_mixtures: number of mixture weights produced per input vector.
	"""
	def __init__(self, latent_dim, n_ode_mixtures):
		super(ODEMixtures, self).__init__()

		self.mixture_nn = nn.Sequential(
			nn.Linear(latent_dim, 100),
			nn.ReLU(),
			nn.Linear(100, n_ode_mixtures),
			# The axis is given explicitly: nn.Softmax() without `dim` is deprecated
			# and guesses the axis from the input rank (axis 0 for 3-D input),
			# which would normalize across the wrong axis for batched sequences.
			nn.Softmax(dim = -1),)

		utils.init_network_weights(self.mixture_nn)

	def forward(self, data):
		"""Return the mixture weights for `data` (last axis: n_ode_mixtures)."""
		return self.mixture_nn(data)

class KernelParamModule(nn.Module):
	"""Holds a single trainable scalar `kernel_param`, drawn from U[0, 2)."""

	def __init__(self):
		super().__init__()
		# One-element tensor filled in place with a uniform sample, then
		# registered as a parameter of this module.
		initial_value = torch.empty(1).uniform_(0.0, 2.0)
		self.kernel_param = nn.Parameter(initial_value)

	def forward(self):
		"""Return the parameter itself (not a copy)."""
		return self.kernel_param


class StdParamModule(nn.Module):
	"""Holds a single trainable scalar `std_param`, drawn from U[0, 1)."""

	def __init__(self):
		super().__init__()
		# One-element tensor filled in place with a uniform sample, then
		# registered as a parameter of this module.
		initial_value = torch.empty(1).uniform_(0.0, 1.0)
		self.std_param = nn.Parameter(initial_value)

	def forward(self):
		"""Return the parameter itself (not a copy)."""
		return self.std_param

#####################################################################################################

def create_ODEVAE_model(args, input_dim, obsrv_std, device, 
	glob_dims = 0, classif_per_tp = False, n_labels = 1,
	output_dim = None):
	"""Build an ODEVAE model from the options in `args`.

	The pieces are created in this order: the optional Gaussian processes for
	z(t), the generative ODE function, the y0 encoder, the decoder, the ODE
	solver and the prior over y0. The order of module construction is kept
	stable on purpose, since each constructed module may draw from the global
	random number generator.

	Args:
		args: option namespace. Attributes read here: zt_gp, qsigma, q, psigma,
			p, decoder, latents, n_odes, poisson, poisson_layers, units,
			gen_layers, fix_gen, dataset, restore_gen, datafile, y0_gp,
			rec_dims, y0_ode_mean, y0_ode_combine, y0_ode_prod,
			y0_ode_per_attr, y0_ode_sparse, reuse_ode, rec_layers, gru_units,
			y0_rnn, method, y0_prior, mix_components, classif, linear_classif.
		input_dim: data dimensionality; sizes the encoders and, unless
			`output_dim` is given, the decoder output.
		obsrv_std: forwarded unchanged to ODEVAE.
		device: device that every created module is moved to / created on.
		glob_dims: added to `input_dim` when sizing the encoder input.
		classif_per_tp: forwarded unchanged to ODEVAE.
		n_labels: forwarded unchanged to ODEVAE.
		output_dim: if not None, replaces `input_dim` as the dimensionality
			used for the decoder, the generative solver and the model.

	Returns:
		The ODEVAE instance, moved to `device`.

	Raises:
		Exception: if `args.restore_gen` is set for the "pickle" dataset and
			the pickle has no "gen_ode_func" entry, or if no prior was restored
			and `args.y0_prior` is neither "standard" nor "mix".
	"""
	# One kernel parameter shared by the GP of q and the GP prior p.
	# NOTE(review): whether it is actually trained depends on how the GP and
	# the optimizer pick it up -- not visible from this function.
	kernel_param_module_zt = KernelParamModule().to(device)

	q_gaussian_process = None
	gp_prior = None

	if args.zt_gp:
		q_gaussian_process = PytorchGaussianProcess(
			kernel_param = kernel_param_module_zt(),
			sigma_obs = args.qsigma, 
			kernel = args.q, device = device)

		gp_prior = PytorchGaussianProcess(
			kernel_param = kernel_param_module_zt(),
			sigma_obs = args.psigma, 
			kernel = args.p, device = device)

	# Use one ODE or a mixture of ODEs 
	# Makes more sense to use with y0 inferred from GP. In this case, y0 depends on the whole time series
	# NOTE(review): mixture_nn is created below but never handed to the model.
	# It is kept because removing the construction would change the random
	# initialisation of every module created after it.
	mixture_nn = None

	if args.decoder == "spiral":
		gen_ode_func = ODEFunc_spiral(
			input_dim = input_dim, 
			latent_dim = args.latents, 
			device = device)

	elif args.decoder == "lotkavol":
		gen_ode_func = ODEFunc_LotkaVolterra(
			input_dim = input_dim, 
			latent_dim = args.latents, 
			device = device)
	elif args.n_odes is None:
		# With a GP over z(t) the ODE network additionally receives input_dim values.
		dim = args.latents if q_gaussian_process is None else (args.latents + input_dim)
		if args.poisson:
			lambda_net = utils.create_net(dim, input_dim, 
				n_layers = args.poisson_layers, n_units = args.units, nonlinear = nn.Tanh)

			# ODE function produces the gradient for latent state and for poisson rate
			ode_func_net = utils.create_net(dim * 2, args.latents * 2, 
				n_layers = args.gen_layers, n_units = args.units, nonlinear = nn.Tanh)

			gen_ode_func = ODEFunc_w_Poisson(
				input_dim = input_dim, 
				latent_dim = args.latents * 2,
				ode_func_net = ode_func_net,
				lambda_net = lambda_net,
				q_gaussian_process = q_gaussian_process,
				gp_prior = gp_prior,
				fix_decoder = args.fix_gen,
				device = device).to(device)
		else:
			ode_func_net = utils.create_net(dim, args.latents, 
				n_layers = args.gen_layers, n_units = args.units, nonlinear = nn.Tanh)

			gen_ode_func = ODEFunc_one_ODE(
				input_dim = input_dim, 
				latent_dim = args.latents, 
				ode_func_net = ode_func_net,
				q_gaussian_process = q_gaussian_process,
				gp_prior = gp_prior,
				fix_decoder = args.fix_gen,
				device = device).to(device)
	else:
		gen_ode_func = ODEFunc_ODEmixture(
			input_dim = input_dim, 
			latent_dim = args.latents, 
			q_gaussian_process = q_gaussian_process,
			gp_prior = gp_prior,
			n_ode_dynamics = args.n_odes,
			device = device).to(device)
		mixture_nn = ODEMixtures(args.latents, args.n_odes)

	restore_from_pickle = (args.dataset == "pickle" and args.restore_gen)

	if restore_from_pickle:
		res = utils.get_item_from_pickle(args.datafile, "gen_ode_func")

		if res is not None:
			print("Loading gen_ode_func from pickle...")
			gen_ode_func = res
		else:
			raise Exception("ERROR: decoder is not in pickle file")

	# NOTE(review): encoder_zt is not used afterwards; kept so that the random
	# initialisation of the modules created after it stays the same.
	encoder_zt = Encoder_zt(input_dim, device).to(device)
	# Default y0 encoder; replaced below when a GP / ODE / RNN encoder is requested.
	encoder_y0 = Encoder_y0_from_x0(args.latents, input_dim, device).to(device)

	# Separate kernel parameter for the GP used by the y0 encoder.
	kernel_param_module_y0 = KernelParamModule().to(device)

	y0_gauss_proc = None
	if args.y0_gp:
		y0_gauss_proc = PytorchGaussianProcess(
			kernel_param = kernel_param_module_y0(),
			sigma_obs = 0.1, 
			kernel = "OU", device = device)

		encoder_y0 = Encoder_y0_gp(args.latents, input_dim, y0_gauss_proc, device).to(device)

	y0_diffeq_solver = None
	n_rec_dims = args.rec_dims
	enc_input_dim = int(input_dim + glob_dims) * 2 # we concatenate the mask
	gen_data_dim = input_dim
	if output_dim is not None:
		gen_data_dim = output_dim

	y0_dim = args.latents
	if args.poisson:
		y0_dim += args.latents # predict the initial poisson rate

	if args.y0_ode_mean or args.y0_ode_combine or args.y0_ode_prod or args.y0_ode_per_attr or args.y0_ode_sparse:
		if args.reuse_ode:
			# Recognition model shares the generative ODE function.
			rec_ode_func = gen_ode_func
			n_rec_dims = args.latents
		else:
			ode_func_net = utils.create_net(n_rec_dims, n_rec_dims, 
				n_layers = args.rec_layers, n_units = args.units, nonlinear = nn.Tanh)

			rec_ode_func = ODEFunc_one_ODE(
				input_dim = enc_input_dim, 
				latent_dim = n_rec_dims,
				ode_func_net = ode_func_net,
				q_gaussian_process = None,
				gp_prior = None,
				save_ode_ts = False,
				device = device).to(device)

		y0_diffeq_solver = DiffeqSolver(enc_input_dim, rec_ode_func, "euler", args.latents, 
			odeint_rtol = 1e-3, odeint_atol = 1e-4, device = device)
		
		# The checks below are independent on purpose: when several flags are
		# set, every matching encoder is constructed and the last one wins.
		if args.y0_ode_mean:
			encoder_y0 = Encoder_y0_mean_odes_yi(n_rec_dims, enc_input_dim, y0_diffeq_solver, device = device).to(device)
		if args.y0_ode_prod:
			encoder_y0 = Encoder_y0_gauss_product_odes_yi(n_rec_dims, enc_input_dim, y0_diffeq_solver, 
				y0_dim = y0_dim, device = device).to(device)
		if args.y0_ode_combine:
			encoder_y0 = Encoder_y0_ode_combine(n_rec_dims, enc_input_dim, y0_diffeq_solver, 
				y0_dim = y0_dim, n_gru_units = args.gru_units, device = device).to(device)
		if args.y0_ode_per_attr:
			encoder_y0 = Encoder_y0_ode_per_attr(n_rec_dims, input_dim, y0_diffeq_solver, 
				y0_dim = y0_dim, device = device).to(device)
		if args.y0_ode_sparse:
			encoder_y0 = Encoder_y0_ode_sparse(n_rec_dims, enc_input_dim, y0_diffeq_solver, 
				y0_dim = y0_dim, n_gru_units = args.gru_units, device = device).to(device)

	# The RNN encoder, when requested, overrides any encoder chosen above.
	if args.y0_rnn:
		encoder_y0 = Encoder_y0_from_rnn(y0_dim, enc_input_dim,
			lstm_output_size = n_rec_dims, device = device).to(device)

	if args.decoder == "spiral" or args.decoder == "pickle":
		decoder = Decoder_nonlinear(args.latents, gen_data_dim, fix_decoder = args.fix_gen).to(device)
	else:
		decoder = Decoder(args.latents, gen_data_dim, fix_decoder = args.fix_gen).to(device)
	
	if restore_from_pickle:
		res = utils.get_item_from_pickle(args.datafile, "decoder")
		if res is not None:
			print("Loading decoder from pickle...")
			decoder.load_state_dict(res)
			decoder.fix_decoder()

	diffeq_solver = DiffeqSolver(gen_data_dim, gen_ode_func, args.method, args.latents, 
		odeint_rtol = 1e-3, odeint_atol = 1e-4, device = device)

	# A prior restored from the pickle takes precedence. Previously the
	# restored value was always overwritten by the branches below, so the
	# restore had no effect. A None result is treated as "not in the pickle",
	# as in the other restore branches, and falls back to args.y0_prior.
	y0_prior = None
	if restore_from_pickle:
		y0_prior = utils.get_item_from_pickle(args.datafile, "y0_prior")
		if y0_prior is not None:
			print("Loading y0_prior from pickle...")

	if y0_prior is None:
		if args.y0_prior == "standard":
			y0_prior = Normal(torch.Tensor([0.0]).to(device), torch.Tensor([1.]).to(device))
		elif args.y0_prior == "mix":
			n_components = args.mix_components
			y0_prior = MixtureOfGaussians(args.latents, n_components, device = device)
		else:
			raise Exception("Unknown y0 prior: {}".format(args.y0_prior))

	model = ODEVAE(
		input_dim = gen_data_dim, 
		latent_dim = args.latents, 
		encoder_y0 = encoder_y0, 
		decoder = decoder, 
		diffeq_solver = diffeq_solver, 
		y0_prior = y0_prior, 
		device = device,
		obsrv_std = obsrv_std,
		use_poisson_proc = args.poisson, 
		use_binary_classif = args.classif,
		linear_classifier = args.linear_classif,
		classif_per_tp = classif_per_tp,
		n_labels = n_labels).to(device)

	return model
