import os
import gc
import sys
import matplotlib
import matplotlib.pyplot
import matplotlib.pyplot as plt

import time
import datetime
import argparse
import numpy as np
from random import SystemRandom
from sklearn import model_selection

import torch
import torch.nn as nn
from torch.nn.functional import relu
import torch.optim as optim

import lib.utils as utils
from lib.plotting import *

from lib.rnn_baselines import *
from lib.ode_gru import *
from lib.create_sdevae_model import create_ODEVAE_model, StdParamModule
from lib.parse_datasets import parse_datasets
from lib.ode_func import ODEFunc_one_ODE, ODEFunc_w_Poisson
from lib.diffeq_solver import DiffeqSolver
from mujoco_physics import HopperPhysics

import sklearn as sk
import pandas as pd
from sklearn.neural_network import MLPClassifier

from lib.utils import compute_loss_all_batches

# ---------------------------------------------------------------------------
# Command-line interface for the ODE-based generative model of noisy data.
# All experiment settings are parsed once, at import time (see parse_args()
# at the end of this section), into the module-level `args` namespace that
# the functions and the __main__ section read directly.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser('ODEVAE')
# --- Time grid / training loop / ODE solver ---------------------------------
parser.add_argument('-t', '--timepoints', type=int, default=100, help="Total number of time-points")
# Retired integer -e/--extrap option (superseded by the boolean --extrap flag below):
# parser.add_argument('-e', '--extrap', type=int, default=100, 
# 	help="Extra time points to extrapolate to. Max-t is extrapolated respectively according to the number of time points.")
parser.add_argument('--niters', type=int, default=300)
parser.add_argument('--viz', action='store_true')
# NOTE(review): --cuda is parsed, but the device selection right after
# parse_args() uses CUDA whenever it is available, regardless of this flag.
parser.add_argument('--cuda', action='store_true')
parser.add_argument('--method', default='dopri5', help="Type of ODE SOlver")
parser.add_argument('--lr',  type=float, default=1e-2, help="Starting learning rate.")

# Retired duplicate of --noise-weight (the active definition is further down):
#parser.add_argument('--noise-weight',  type=float, default=1., help="Coefficient for the noise that is added to the ground truth dataset")
parser.add_argument('--n-traj-samples', type=int, default=3)
parser.add_argument('-l', '--latents', type=int, default=6)

# Retired Gaussian-process options for the approximate posterior Q and prior P:
# parser.add_argument('-q', default="WienerProcess", help="GP kernel for approx posterior Q. Available kernels: WienerProcess, RBF, OU")
# parser.add_argument('-p', default="WienerProcess", help="GP kernel for prior P. Available kernels: WienerProcess, RBF, OU")

# parser.add_argument('--qsigma', type=float, default=0.1, help="Variance for approx posterior Q")
# parser.add_argument('--psigma', type=float, default=0.1, help="Variance for prior P")

# --- Loss weighting -----------------------------------------------------------
parser.add_argument('--kl', type=float, default=1., help="KL coefficient")
# not used
parser.add_argument('--fit-kl', action='store_true', help="Perform sanity check: can we make KL divergence go to zero?")

# --- Dataset size, sampling and experiment storage ---------------------------
# -n: number of training examples (printed as "Sampling dataset of N training examples").
parser.add_argument('-n',  type=int, default=100)
# NOTE(review): the help text mentions args.max_tp, but the parsed attribute is args.max_t.
parser.add_argument('--max-t',  type=float, default=1., help="We subsample points in the interval [0, args.max_tp]")
parser.add_argument('-b', '--batch-size', type=int, default=50)
parser.add_argument('-s', '--sample-tp', type=int, default=None, help="How many points to sub-sample. If None, subsample random number of points")
parser.add_argument('--save', type=str, default='experiments/', help="Path for save checkpoints")
parser.add_argument('--load', type=str, default=None, help="ID of the experiment to load for evaluation. If None, run a new experiment.")

# NOTE(review): the two help literals below are concatenated without a separator,
# so the rendered help reads "...pickle.pickle: Load ...".
parser.add_argument('--dataset', type=str, default='parabola', 
	help="Dataset to load. Available: physionet, activity, hopper, parabola, periodic, bimodal, periodic2s, periodic2d, periodic3w, zigzag, house, airquality, pickle."
	"pickle: Load dataset from pickle file provided in --datafile argument")
parser.add_argument('--datafile', type=str, default=None) 

# --- Encoder (recognition model) selection ------------------------------------
# not used
parser.add_argument('--n-odes', type=int, default=None, help="Number of ODEs")
# not used
parser.add_argument('--zt-gp',  action='store_true', help="Use z(t) GP. If this parameter is NOT set, "
	"ODE depends only on y0 and samples from z(t) are not concatenated to ODE function.")

# not used
parser.add_argument('--y0-gp',  action='store_true', help="Use y0 derived through GP. ")
# not used
parser.add_argument('--y0-ode-mean',  action='store_true', help="Use y0 derived from ODE. y0 is a mean of y0 estimates from ODEs run backwards from different yis to y0.")
# not used
parser.add_argument('--y0-ode-prod',  action='store_true', help="Use y0 derived from ODE. y0 is a product of gaussians from different y_tis. The estimate of y0 from each y_ti is derived by running ODE backwards yi to y0.")

# The next three flags also determine the model name (see get_model_name).
parser.add_argument('--y0-ode-combine',  action='store_true', help="Use y0 derived from ODE. y0 is derived by running ODE backwards and combining estimates of different yis through gating.")
parser.add_argument('--y0-ode-per-attr',  action='store_true', help="Use y0 derived from ODE. Run ode-combine model for each attribute (dimension) of the input data.")
parser.add_argument('--y0-rnn',  action='store_true', help="Use RNN to derive y0")

# not used
parser.add_argument('--y0-ode-sparse',  action='store_true', help="Use y0 derived from ODE. Run one ODE for all attributes and have GRU unit for each attribute to make and update for ODE state")

# --- Prior over the initial latent state y0 -----------------------------------
parser.add_argument('--y0-prior', default="standard", help="Type of y0 prior: standard gaussian ('standard') or mixture of gaussians ('mix')")
parser.add_argument('-m', '--mix-components', type=int, default=1, help="Number of mixture components, if using mixture of gaussians y0 prior ('--y0-prior mix')")

# not used
parser.add_argument('--reuse-ode',  action='store_true', help="Use the same ODE function for recognition and generative parts")


# --- Decoder / generative model -----------------------------------------------
parser.add_argument('--fix-gen', action='store_true', help="Do not train the generative model -- fix the weights.")
parser.add_argument('--decoder', type=str, default='nn', help="Type of decoder: neural net ('nn') or a decoder with a fixed ODE (e.g. 'spiral')")

# --- RNN baselines (each also determines the model name) ----------------------
parser.add_argument('--rnn-seq2seq', action='store_true', help="Run RNN baseline: seq2seq style with encoder and decoder")
parser.add_argument('--classic-rnn', action='store_true', help="Run RNN baseline: classic RNN that sees true points at every point. Used for interpolation only.")
parser.add_argument('--rnn-vae', action='store_true', help="Run RNN baseline: seq2seq model with sampling of the h0 and ELBO loss.")
parser.add_argument('--rnn-cell', default="gru", help="RNN Cell type. Available: gru (default), expdecay")
parser.add_argument('--input-decay', action='store_true', help="For RNN: use the input that is the weighted average of impirical mean and previous value (like in GRU-D)")

# --- ODE-RNN baselines ----------------------------------------------------------
parser.add_argument('--ode-gru', action='store_true', help="Run ODE-RNN baseline: seq2seq style with encoder and decoder")
parser.add_argument('--ode-gru-rnn', action='store_true', help="Run ODE-RNN baseline: RNN-style that sees true points at every point. Used for interpolation only.")

# --- Network sizes --------------------------------------------------------------
parser.add_argument('--rec-layers', type=int, default=1, help="Number of layers in ODE func in recognition ODE")
parser.add_argument('--gen-layers', type=int, default=1, help="Number of layers in ODE func in generative ODE")

# not used
parser.add_argument('--poisson-layers', type=int, default=1, help="Number of layers predicting lambda (rate for poisson process)")

parser.add_argument('-u', '--units', type=int, default=100, help="Number of units per layer in ODE func")
parser.add_argument('-g', '--gru-units', type=int, default=100, help="Number of units per layer in each of GRU update networks")

parser.add_argument('--rec-dims', type=int, default=20, help="Dimensionality of the recognition model (ODE or RNN)."
	"If --reuse-ode is used, --rec-dims parameter is overwritten and set to be equal to '--latents, i.e. number of dimentions in generative ODE.")

parser.add_argument('--restore-gen', action='store_true', help="Restore parameters of generative model and y0 prior from pickle file."
	"Useful if want to fix the weights of decoder to the ones that generated the data.")

parser.add_argument('--dim-only', type=int, default=None, help="Train only on specified number of dimensions in the dataset."
	" If None, train on all dimensions. Used for AirQuality dataset.")

# --- Likelihood and auxiliary losses -------------------------------------------
parser.add_argument('--obsrv-std', type=float, default=0.02, help="Std for computing the likelihood of the data")
parser.add_argument('--obsrv-std-learn', action='store_true', help="Treat Std of data likelihood as a parameter")

parser.add_argument('--poisson', action='store_true', help="Model poisson-process likelihood for the density of events in addition to reconstruction.")
parser.add_argument('--classif', action='store_true', help="Include binary classification loss -- used for Physionet dataset for hospiral mortality")

# --- Task set-up: interpolation vs. extrapolation ------------------------------
parser.add_argument('--linear-classif', action='store_true', help="If using a classifier, use a linear classifier instead of 1-layer NN")
parser.add_argument('--extrap', action='store_true', help="Set extrapolation mode. If this flag is not set, run interpolation mode.")
parser.add_argument('--extrap-test', action='store_true', help="Do extrapolation at test time. For RNN it means training to do reconstruction but do exrapolation at test time.")

parser.add_argument('--noise-weight', type=float, default=0.1, help="Noise amplitude for generated traejctories")
parser.add_argument('--name', type=str, default="", help="Reason for running the experiment. Not used.")

# --- Cross-validation and result bookkeeping ------------------------------------
parser.add_argument('--cv', type=int, default=5, help="Use cross-validation or classification. The parameter specifies number of folds.")
parser.add_argument('--fold', type=int, default=0, help="Fold id for cross-validation.")

parser.add_argument('--do-not-store', action='store_true', help="Test run. Do not store results")

parser.add_argument('--concat-globs', action='store_true', help="Condition the model on time-invariant features. Used for Physionet. "
	"'train_globs' and 'test_globs' attributes must be present in data dictionary")

parser.add_argument('--quantization', type=float, default=0.1, help="Quantization on the physionet dataset. Value 1 means quantization by 1 hour, value 0.1 means quantization by 0.1 hour = 6 min")
parser.add_argument('--predict-from-labels',  action='store_true', help="Predict the sequence from labels. Tried on Human Activity dataset.")

# Parsed at import time; the rest of the module reads this global namespace.
args = parser.parse_args()

# Compute device. NOTE: the --cuda flag is parsed but deliberately not
# consulted here -- CUDA device 0 is used whenever torch reports it available.
if torch.cuda.is_available():
	device = torch.device("cuda:0")
else:
	device = torch.device("cpu")

# Name of this script with its last three characters (".py") dropped; the
# __main__ section later appends the dataset name to it.
file_name = os.path.split(__file__)[1][:-len(".py")]

# Output locations: checkpoint directory (--save) and two plot folders.
utils.makedirs(args.save)
utils.makedirs("plots/cond_on_ind_points/")
utils.makedirs("plots/samples_same_traj/")

#####################################################################################################

import sys, psutil, gc, os
def memReport():
	"""Return the total element count of all live, GC-tracked torch tensors.

	Scans every object the garbage collector tracks, keeps those for which
	``torch.is_tensor`` is true, and sums the product of their dimensions.
	The result counts elements (not bytes) and is returned as a float.
	"""
	n_elements = 0.
	live_tensors = (candidate for candidate in gc.get_objects() if torch.is_tensor(candidate))
	for tensor in live_tensors:
		n_elements += np.prod(list(tensor.size()))
	return n_elements
	
def cpuStats():
	"""Print the memory used by the current process, in GiB.

	Reads the first field of ``psutil.Process(pid).memory_info()`` for this
	process (resident set size on the common platforms -- NOTE(review): the
	tuple layout is platform-specific in psutil) and divides it by 2**30.
	Returns None.
	"""
	current_process = psutil.Process(os.getpid())
	used_gib = current_process.memory_info()[0] / 2. ** 30
	print('memory GB:', used_gib)



def get_model_name(args):
	"""Build the model identifier used as the row key in the results tables.

	The base name comes from the model-selection flags. When several are set,
	the later one in this order wins:
	rnn_seq2seq, classic_rnn, rnn_vae, ode_gru, ode_gru_rnn,
	y0_rnn, y0_ode_per_attr, y0_ode_combine.
	RNN baselines get a "_cell_<rnn_cell>" suffix (unless a later flag
	overrides the name), and "_poisson" / "_input_decay" are appended when
	those flags are set.

	Args:
		args: parsed command-line namespace.

	Returns:
		str: the model name.

	Raises:
		Exception: if no model-selection flag is set. The suffix-only flags
			(--poisson, --input-decay) do not count as selecting a model.
	"""
	model_name = ""
	if args.rnn_seq2seq:
		model_name = "rnn_seq2seq"
	if args.classic_rnn:
		model_name = "classic_rnn"
	if args.rnn_vae:
		model_name = "rnn_vae"

	if args.rnn_seq2seq or args.classic_rnn or args.rnn_vae:
		model_name += "_cell_" + args.rnn_cell

	if args.ode_gru:
		model_name = "ode_gru_seq2seq"
	if args.ode_gru_rnn:
		model_name = "ode_gru_rnn"

	if args.y0_rnn:
		model_name = "y0_rnn"
	if args.y0_ode_per_attr:
		model_name = "y0_ode_per_attr"
	if args.y0_ode_combine:
		model_name = "y0_ode_combine"

	# Validate before appending suffixes: otherwise `--poisson` alone would
	# produce "_poisson" and slip past this check.
	if model_name == "":
		raise Exception("Model name not found")

	if args.poisson:
		model_name += "_poisson"
	if args.input_decay:
		model_name += "_input_decay"
	return model_name






def _update_best_result(results_file, experim_file, script_file, metric, value, experimentID):
	"""Record `value` for (model_name, metric) if it beats the stored best.

	When utils.update_metric_if_larger reports that the table was updated,
	also store `experimentID` in `experim_file` and append the command that
	reproduces this run to the shell script `script_file`.

	Reads the module-level globals `model_name` and `input_command`, which
	are assigned in the __main__ section.
	"""
	_, value_updated = utils.update_metric_if_larger(results_file, model_name,
		metric, value)

	if value_updated:
		utils.update_value(experim_file, model_name, metric, experimentID)

		with open(script_file, "a") as res_file:
			res_file.write("\n ./submit_job.sh python3 " + input_command)


def store_best_results(args, train_dict, test_dict, experimentID, res_files):
	"""Evaluate the model on the test split, print metrics, update best-result tables.

	Args:
		args: parsed command-line namespace.
		train_dict: training split. Currently unused (train-set evaluation
			is disabled); kept for interface compatibility.
		test_dict: test split handed to the evaluation routine.
		experimentID: ID written to the "experiment" tables when a metric
			improves on the stored best.
		res_files: mapping from logical names to result CSV / shell-script
			paths (the keys built in the __main__ section).

	Relies on the module-level globals `model`, `kl_coef`, `n_timepoints`,
	`model_name` and `input_command` assigned in the __main__ section;
	calling it before they exist raises NameError.

	With --do-not-store, nothing is evaluated or written.
	"""
	if args.do_not_store:
		print("WARNING: Results are not stored!!")
		return

	# Physionet / activity are evaluated with utils.compute_loss_all_batches;
	# all other datasets with a single model.compute_all_losses call.
	if args.dataset == "physionet" or args.dataset == "activity":
		test_res = compute_loss_all_batches(
			model, test_dict, args, experimentID = experimentID, 
			n_traj_samples = args.n_traj_samples, kl_coef = kl_coef)
	else:
		test_res  = model.compute_all_losses(
			test_dict,
			n_traj_samples = args.n_traj_samples, 
			kl_coef = kl_coef)

	print("Experiment " + str(experimentID))
	message = '[TEST] Loss {:.6f} | Likelihood {:.6f} | MSE {:.6f} | KL fp {:.4f} | FP STD {:.4f}| # time points {} |'.format(
		test_res["loss"], test_res["likelihood"], test_res["mse"], test_res["kl_first_p"], test_res["std_first_p"], n_timepoints)
	print(message)

	if "auc" in test_res:
		print("Classification AUC (test, fold {}): {:.4f}".format(args.fold, test_res["auc"]))

	if "accuracy" in test_res:
		print("Classification accuracy (test, fold {}): {:.4f}".format(args.fold, test_res["accuracy"]))

	if args.poisson and "pois_likelihood" in test_res:
		print("Poisson likelihood: {}".format(test_res["pois_likelihood"]))

	if "ce_loss" in test_res:
		print("CE loss: {}".format(test_res["ce_loss"]))

	###### Update tables with results
	# update_metric_if_larger keeps the larger value, so the MSE is stored negated.
	neg_mse = -test_res["mse"].cpu().numpy()

	_update_best_result(res_files["results_file_name"], res_files["best_experim_file"],
		res_files["script_run_best"], "neg_MSE", neg_mse, experimentID)

	# Same metric, tabulated per number of sub-sampled time points (--sample-tp).
	_update_best_result(res_files["results_per_missing_rate"], res_files["best_per_missing_rate"],
		res_files["script_run_best_per_missing_rate"], str(args.sample_tp), neg_mse, experimentID)

	if "auc" in test_res:
		_update_best_result(res_files["classif_results"], res_files["classif_experim"],
			res_files["script_run_best_classif"], "auc_fold" + str(args.fold),
			test_res["auc"], experimentID)

	if "accuracy" in test_res:
		_update_best_result(res_files["classif_results"], res_files["classif_experim"],
			res_files["script_run_best_classif"], "accuracy_fold" + str(args.fold),
			test_res["accuracy"], experimentID)

	if args.poisson and "pois_likelihood" in test_res:
		_update_best_result(res_files["poisson_results"], res_files["poisson_experim"],
			res_files["script_run_best_poisson"], "pois_likelihood",
			test_res["pois_likelihood"], experimentID)



def save_latents(model, data_dict, experimentID = 0., name = ""):
	"""Encode one batch of `data_dict` and pickle its latent representation.

	Takes the batch returned by utils.get_next_batch(data_dict, 0, 10000),
	runs model.get_reconstruction on its observed part with a single
	trajectory sample, and dumps a dict with keys "first_point_enc",
	"labels" and "latent_traj" to
	plots/<experimentID>/latents_<experimentID>_<name>.pickle.
	The target directory is created if it does not exist.

	Args:
		model: model whose get_reconstruction(...) returns
			(reconstructions, info), with info["first_point"] a 3-tuple
			(mean, std, encoding) and info["latent_traj"].
		data_dict: split dictionary accepted by utils.get_next_batch.
		experimentID: used in the output directory and file name.
		name: file-name suffix (e.g. "train" or "test").
	"""
	batch_dict = utils.get_next_batch(data_dict, 0, 10000)

	time_steps = batch_dict["tp_to_predict"]
	observed_data = batch_dict["observed_data"]
	observed_time_steps = batch_dict["observed_tp"]
	observed_mask = batch_dict["observed_mask"]
	labels = batch_dict["labels"]

	# Only the encoder outputs in `info` are saved; the reconstructions are discarded.
	_, info = model.get_reconstruction(
		time_steps, observed_data, observed_time_steps,
		mask = observed_mask,
		n_traj_samples = 1)

	# Unpacked as (mean, std, encoding); only the encoding is kept.
	_, _, first_point_enc = info["first_point"]

	info_to_save = {"first_point_enc": first_point_enc,
		"labels": labels,
		"latent_traj": info["latent_traj"]}

	pickle_file = "plots/" + str(experimentID) + "/latents_" + \
		str(experimentID) + "_" + name + ".pickle"
	# plots/<experimentID>/ is not guaranteed to exist yet (on the physionet
	# path this runs before any plotting function).
	os.makedirs(os.path.dirname(pickle_file), exist_ok = True)
	utils.dump_pickle(info_to_save, pickle_file)






def run_classif_on_latents(model, experimentID = 0):
	"""Grid-search MLP classifiers on pickled latent encodings and print test AUCs.

	Loads the train/test pickles written by save_latents for `experimentID`,
	uses element [0] of "first_point_enc" as features and "labels" as
	targets, drops examples whose label is NaN, fits an MLPClassifier for
	every (n_units, learning-rate) pair and prints its ROC AUC on the test
	split. The process exits (sys.exit) after the sweep; scores are not stored.

	Args:
		model: unused; kept for interface compatibility.
		experimentID: selects
			plots/<experimentID>/latents_<experimentID>_{train,test}.pickle.

	Raises:
		Exception: if every train label, or every test label, is NaN.
	"""
	def _latents_path(split):
		# Same naming scheme save_latents writes.
		return "plots/" + str(experimentID) + "/latents_" + \
			str(experimentID) + "_" + split + ".pickle"

	train_info = utils.load_pickle(_latents_path("train"))
	test_info = utils.load_pickle(_latents_path("test"))

	train_data, test_data = train_info["first_point_enc"][0].cpu(), test_info["first_point_enc"][0].cpu()
	# Labels are moved to the CPU too, so the masks below index CPU features.
	train_labels, test_labels = train_info["labels"].cpu(), test_info["labels"].cpu()

	print("data")
	print(train_data.size())
	print(train_labels.size())

	# Some examples have no label (stored as NaN); keep only labelled ones.
	# `~` is required: `1 - <bool tensor>` raises in current PyTorch.
	# NOTE(review): assumes one label per example (1-D labels) so the mask
	# selects rows of the features -- confirm for multi-label datasets.
	train_mask = ~torch.isnan(train_labels)
	if not train_mask.any():
		raise Exception("No training examples -- all nans")

	test_mask = ~torch.isnan(test_labels)
	if not test_mask.any():
		raise Exception("No test examples -- all nans")

	# Loop-invariant: the masked data is identical for every configuration.
	train_x, train_y = train_data[train_mask], train_labels[train_mask]
	test_x = test_data[test_mask]
	test_labels_non_missing = test_labels[test_mask].numpy()

	for n_units in range(100, 1000, 200):
		for lr in range(1, 5):
			# NOTE(review): sklearn treats every entry of hidden_layer_sizes as a
			# hidden layer, so the feature size and the trailing 2 become extra
			# hidden layers here, not the input/output layers.
			logreg = MLPClassifier(solver = 'adam', alpha=1e-5,
				learning_rate_init=(1 / 10**lr),
				hidden_layer_sizes=(train_data.size(-1), n_units, n_units, 2), 
				random_state=42, max_iter = 1000)
			logreg.fit(train_x, train_y)

			test_pred = logreg.predict_proba(test_x)[:,1]
			test_auc = sk.metrics.roc_auc_score(test_labels_non_missing, test_pred)
			print(n_units)
			print(lr)
			print("Classification AUC on pre-trained latents: {}".format(test_auc))

	# TODO: save the scores to the results table instead of exiting.
	sys.exit()



if __name__ == '__main__':
	torch.manual_seed(1992)
	np.random.seed(1992)

	starting_time = datetime.datetime.today().strftime('%Y-%m-%d') + " " + datetime.datetime.now().time().strftime("%H:%M:%S")

	experimentID = args.load
	if experimentID is None:
		# Make a new experiment ID
		experimentID = int(SystemRandom().random()*100000)
	ckpt_path = os.path.join(args.save, "experiment_" + str(experimentID) + '.ckpt')

	start = time.time()
	print("Sampling dataset of {} training examples".format(args.n))
	
	input_command = sys.argv
	ind = [i for i in range(len(input_command)) if input_command[i] == "--load"]
	if len(ind) == 1:
		ind = ind[0]
		input_command = input_command[:ind] + input_command[(ind+2):]
	input_command = " ".join(input_command)

	# Regular data points
	#time_steps = torch.Tensor().new_tensor(np.arange(args.timepoints).astype(float)).to(device) / args.timepoints

	dataset_name = args.dataset
	if args.dataset == "pickle":
		name = utils.get_item_from_pickle(args.datafile, "dataset_name")
	if args.dataset == "physionet":
		dataset_name += "_" + str(args.quantization)
	file_name = file_name + "_" + dataset_name

	data_set_up = "interp"
	if args.extrap:
		data_set_up = "extrap"
	elif args.extrap_test:
		data_set_up = "extrap_test"
	else:
		data_set_up = "interp"

	fill_str = "{}_{}".format(dataset_name, data_set_up)
	
	utils.makedirs("results/")
	res_files = {
		"results_file_name": "results/results_" + fill_str + ".csv",
		"best_experim_file": "results/experiments_" + fill_str + ".csv",

		"results_per_missing_rate": "results/results_" + fill_str + "_n_subsampled_points_likelihood.csv",
		"best_per_missing_rate": "results/experiments_" + fill_str + "_n_subsampled_points_likelihood.csv",
		
		"classif_results": "results/results_{}_classif.csv".format(dataset_name),
		"classif_experim": "results/experim_{}_classif.csv".format(dataset_name),

		"poisson_results": "results/results_{}_poisson.csv".format(dataset_name),
		"poisson_experim": "results/experim_{}_poisson.csv".format(dataset_name),

		"script_run_best": "results/run_" + fill_str + ".sh",
		"script_run_best_per_missing_rate": "results/run_" + fill_str + "_n_subsampled_points_likelihood.sh",
		"script_run_best_classif": "results/run_{}_classif.sh".format(dataset_name),
		"script_run_best_poisson": "results/run_{}_poisson.sh".format(dataset_name),

		"pred_from_labels_results": "results/results_{}_pred_from_labels.csv".format(dataset_name),
		"pred_from_labels_experim": "results/experim_{}_pred_from_labels.csv".format(dataset_name)
	}


	model_name = get_model_name(args)
	metric_name = "Likelihood"

	density_plot_dir = "plots/y0_density_plots/" + str(experimentID) + "/"
	##################################################################
	
	extrapolation_mode = args.extrap
	data_dict = parse_datasets(args, device)




	# # #Train MLP classifier on the mean value
	# # #Perform 5-fold cross validation and compute AUC on classification task
	# train_data = torch.cat((data_dict["train_globs"].cpu(), torch.mean(data_dict["train_y"].cpu().float(), 1)) ,-1)
	# test_data = torch.cat((data_dict["test_globs"].cpu(), torch.mean(data_dict["test_y"].cpu().float(), 1)) ,-1)


	# all_data = torch.cat((train_data, test_data),0)
	# all_labels = torch.cat((data_dict["train_labels"], data_dict["test_labels"]),0)

	# n_tr_ex = len(all_labels)

	# from sklearn.neural_network import MLPClassifier
	# from sklearn.model_selection import KFold
	# kf = KFold(n_splits=5, shuffle=True, random_state=42)

	# #logreg = sk.linear_model.LogisticRegression(C=1e5, solver = 'lbfgs')
	# test_auc_list = []
	# for i in range(1,10):

	# 	labels = []
	# 	pred = []

	# 	print("======================")
	# 	for train_index, test_index in kf.split(all_data):
	# 		train_data, test_data = all_data[train_index], all_data[test_index]
	# 		train_labels, test_labels = all_labels[train_index], all_labels[test_index]

	# 		logreg = sk.linear_model.LogisticRegression(C=1e5, solver = 'lbfgs',
	# 			max_iter=10000, class_weight = 'balanced')
	# 		# logreg = MLPClassifier(solver = 'adam', alpha=1e-5,
	# 		# 	learning_rate_init=0.0001,
	# 		# 	hidden_layer_sizes=(train_data.size(-1), 500, 2), 
	# 		# 	random_state=i, max_iter = 1000)

	# 		# Half of the samples are missing a label
	# 		idx_not_nan = 1 - torch.isnan(train_labels)

	# 		if torch.sum(idx_not_nan) == 0:
	# 			continue

	# 		# Create an instance of Logistic Regression Classifier and fit the data.
	# 		logreg.fit(train_data[idx_not_nan], train_labels[idx_not_nan].cpu())

	# 		train_pred = logreg.predict_proba(train_data[idx_not_nan])[:,1]
	# 		print("Train auc {}".format(sk.metrics.roc_auc_score(train_labels[idx_not_nan].cpu(), train_pred)))

	# 		idx_not_nan = 1 - torch.isnan(test_labels)
	# 		if torch.sum(idx_not_nan) == 0:
	# 			continue

	# 		test_pred = logreg.predict_proba(test_data[idx_not_nan])[:,1]
		
	# 		labels.append(test_labels[idx_not_nan].cpu())
	# 		pred.append(test_pred)

	# 	labels = torch.cat(labels).numpy()
	# 	pred = np.concatenate(pred)

	# 	test_auc = sk.metrics.roc_auc_score(labels, pred)
	# 	print("Test auc {}".format(test_auc))
	# 	test_auc_list.append(test_auc)
	# print(np.mean(test_auc_list))
	# print(np.std(test_auc_list))
	# print(np.max(test_auc_list))
	# exit()

	if args.extrap:
		train_dict = utils.split_data_extrap(data_dict, "train", concat_globs = args.concat_globs, dataset = args.dataset)
		test_dict = utils.split_data_extrap(data_dict, "test", concat_globs = args.concat_globs, dataset = args.dataset)
	elif args.extrap_test:
		train_dict = utils.split_data_interp(data_dict, "train", concat_globs = args.concat_globs)
		test_dict = utils.split_data_extrap(data_dict, "test", concat_globs = args.concat_globs, dataset = args.dataset)
	elif args.predict_from_labels:
		train_dict = utils.split_data_for_pred_from_labels(data_dict, "train", concat_globs = True)
		test_dict = utils.split_data_for_pred_from_labels(data_dict, "test", concat_globs = True)
	else:
		train_dict = utils.split_data_interp(data_dict, "train", concat_globs = args.concat_globs)
		test_dict = utils.split_data_interp(data_dict, "test", concat_globs = args.concat_globs)


	# Subsample points and add mask
	train_dict = utils.subsample_observed_data(train_dict, n_tp_to_sample = args.sample_tp)
	test_dict = utils.subsample_observed_data(test_dict, n_tp_to_sample = args.sample_tp)

	classif_per_tp = False
	if ("train_labels" in data_dict) and (data_dict["train_labels"] is not None):
		if (len(data_dict["train_labels"].size()) == 3):
			# do classification per time point rather than on a time series as a whole
			classif_per_tp = True

	_, n_timepoints, input_dim = data_dict["train_y"].size()
	
	n_labels = 1
	if ("train_labels" in data_dict) and (data_dict["train_labels"] is not None):
		if len(data_dict["train_labels"].size()) != 1:
			n_labels = data_dict["train_labels"].size(-1)

	glob_dims = 0
	if args.concat_globs:
		glob_dims = data_dict["train_globs"].size(-1)

	#print("Time for get_dataset {}".format(time.time() - start))
	# traj_set: torch object of size [N, T, D]. 
	# N -- number of training examples
	# T -- number of time points
	# D -- number of dimensions

	##################################################################
	# Create the model
	obsrv_std = StdParamModule()() if args.obsrv_std_learn else torch.Tensor([args.obsrv_std]).to(device)

	if args.rnn_seq2seq:
		# Create RNN model
		concat_mask = True #(args.dataset == "physionet")
		model = RNN_Seq2Seq(input_dim, args.latents, 
			rec_dims = args.rec_dims, 
			concat_mask = concat_mask, obsrv_std = obsrv_std,
			use_binary_classif = args.classif,
			classif_per_tp = classif_per_tp,
			n_units = args.units,
			input_space_decay = args.input_decay,
			linear_classifier = args.linear_classif,
			device = device,
			glob_dims = glob_dims,
			n_labels = n_labels).to(device)
	elif args.rnn_vae:
		if args.dataset == "pickle" and args.restore_gen:
			y0_prior = utils.get_item_from_pickle(args.datafile, "y0_prior")
		if args.y0_prior == "standard":
			y0_prior = Normal(torch.Tensor([0.0]).to(device), torch.Tensor([1.]).to(device))
		elif args.y0_prior == "mix":
			n_components = args.mix_components
			y0_prior = MixtureOfGaussians(args.latents, n_components, device = device)
		else:
			raise Exception("Unknown y0 prior: {}".format(args.y0_prior))

		# Create RNN-VAE model
		concat_mask = True #(args.dataset == "physionet")
		model = RNN_VAE(input_dim, args.latents, 
			device = device, 
			rec_dims = args.rec_dims, 
			concat_mask = concat_mask, 
			obsrv_std = obsrv_std,
			y0_prior = y0_prior,
			use_binary_classif = args.classif,
			classif_per_tp = classif_per_tp,
			linear_classifier = args.linear_classif,
			n_units = args.units,
			input_space_decay = args.input_decay,
			glob_dims = glob_dims,
			cell = args.rnn_cell,
			n_labels = n_labels).to(device)


	elif args.classic_rnn:
		if args.extrap:
			raise Exception("Classic RNN cannot be used for extrapolation. Please use --rnn-seq2seq instead" 
				"OR --extrap-test flag to perform extrapolation at test time")
		# Create RNN model
		concat_mask = True #(args.dataset == "physionet")
		model = Classic_RNN(input_dim, args.latents, device, 
			concat_mask = concat_mask, obsrv_std = obsrv_std,
			n_units = args.units,
			use_binary_classif = args.classif,
			classif_per_tp = classif_per_tp,
			linear_classifier = args.linear_classif,
			input_space_decay = args.input_decay,
			glob_dims = glob_dims,
			cell = args.rnn_cell,
			n_labels = n_labels).to(device)
	elif args.ode_gru or args.ode_gru_rnn:
		# Create ODE-GRU model
		concat_mask = True
		n_ode_gru_dims = args.latents
				
		if args.poisson:
			lambda_net = utils.create_net(n_ode_gru_dims, input_dim, 
				n_layers = args.poisson_layers, n_units = args.units, nonlinear = nn.Tanh)

			# ODE function produces the gradient for latent state and for poisson rate
			ode_func_net = utils.create_net(n_ode_gru_dims * 2, n_ode_gru_dims * 2, 
				n_layers = args.rec_layers, n_units = args.units, nonlinear = nn.Tanh)

			rec_ode_func = ODEFunc_w_Poisson(
				input_dim = input_dim, 
				latent_dim = n_ode_gru_dims * 2,
				ode_func_net = ode_func_net,
				lambda_net = lambda_net,
				q_gaussian_process = None,
				gp_prior = None,
				device = device).to(device)
		else:
			ode_func_net = utils.create_net(n_ode_gru_dims, n_ode_gru_dims, 
				n_layers = args.rec_layers, n_units = args.units, nonlinear = nn.Tanh)

			rec_ode_func = ODEFunc_one_ODE(
				input_dim = input_dim, 
				latent_dim = n_ode_gru_dims,
				ode_func_net = ode_func_net,
				q_gaussian_process = None,
				gp_prior = None,
				save_ode_ts = False,
				device = device).to(device)

		y0_diffeq_solver = DiffeqSolver(input_dim, rec_ode_func, "euler", args.latents, 
			odeint_rtol = 1e-3, odeint_atol = 1e-4, device = device)
		
		if args.ode_gru:
			# seq2seq model -- deterministic
			model = ODE_GRU(input_dim, n_ode_gru_dims, device = device, 
				y0_diffeq_solver = y0_diffeq_solver, n_gru_units = args.gru_units,
				concat_mask = concat_mask, obsrv_std = obsrv_std,
				use_binary_classif = args.classif,
				classif_per_tp = classif_per_tp,
				use_poisson_proc = args.poisson,
				n_units = args.units, 
				glob_dims = glob_dims,
				n_labels = n_labels).to(device)
		if args.ode_gru_rnn:
			if args.extrap:
				raise Exception("ODE-GRU in RNN-style cannot be used for extrapolation. Please use --ode-gru instead."
					"OR --extrap-test flag to perform extrapolation at test time")
		
			model = ODE_GRU_rnn(input_dim, n_ode_gru_dims, device = device, 
				y0_diffeq_solver = y0_diffeq_solver, n_gru_units = args.gru_units,
				concat_mask = concat_mask, obsrv_std = obsrv_std,
				use_poisson_proc = args.poisson,
				use_binary_classif = args.classif,
				classif_per_tp = classif_per_tp,
				glob_dims = glob_dims,
				n_labels = n_labels).to(device)
	else:
		model = create_ODEVAE_model(args, input_dim, obsrv_std, device, 
			classif_per_tp = classif_per_tp,
			glob_dims = glob_dims,
			n_labels = n_labels)
	##################################################################

	model_parameters = filter(lambda p: p.requires_grad, model.parameters())
	n_params = sum([np.prod(p.size()) for p in model_parameters])
	print("Number of model parameters")
	print(n_params)

	if args.viz:
		viz = Visualizations(device)
		if args.latents == 2 and (not args.rnn_seq2seq):
			viz.init_viz_for_one_density_plot()

	##################################################################
	
	#Load checkpoint and evaluate the model
	# Evaluation-only mode: load the checkpoint at ckpt_path into `model`,
	# draw dataset-specific evaluation plots, then terminate the script so
	# the training section below never runs.
	if args.load is not None:
		utils.get_ckpt_model(ckpt_path, model, device)

		kl_coef = args.kl
		with torch.no_grad():
			# Plots of the latent initial-state (y0) space, per dataset.
			if args.dataset == "hopper":
				# Nothing here; hopper is rendered with mujoco further below.
				pass

			elif args.dataset == "activity":
				plot_metric_versus_n_points(model, test_dict, args, 
					metric = "accuracy", experimentID = experimentID)

				batch_dict = utils.get_next_batch(test_dict, 0, 100)
				plot_y0_space_activity(model, batch_dict,
					experimentID = experimentID, itr = "test", n_traj_to_show = 100)

			elif args.dataset == "physionet":
				batch_dict = utils.get_next_batch(test_dict, 0, 1000)
				save_latents(model, train_dict, experimentID = experimentID, name="train")
				save_latents(model, test_dict, experimentID = experimentID, name="test")

				plot_y0_space_physionet(model, batch_dict,
					experimentID = experimentID, itr = "test", n_traj_to_show = 1000)
			else:
				batch_dict = utils.get_next_batch(test_dict, 0, 1000)
				plot_y0_space(model, batch_dict, args.dataset,
					experimentID = experimentID, itr = "test", n_traj_to_show = 1000)

			if isinstance(model, ODEVAE):
				plot_h0_stdev(model, data_dict["test_y"], data_dict["test_time_steps"], experimentID, "test")
				plot_traj_from_prior(model, data_dict["test_time_steps"], experimentID, "test")

			if args.dataset in ("periodic", "pickle") and (args.extrap or args.extrap_test):
				plot_metric_versus_n_points(model, test_dict, args, 
					metric = "mse", experimentID = experimentID)

				# Predict 100 points from the first time point to predict up to 10x that time.
				start = test_dict["tp_to_predict"][0]
				plot_reconstructions(model, test_dict, experimentID, "extrap_future", 
					width = 7, mark_train_test = True,
					time_steps_to_predict = torch.linspace(start, start*10, 100).to(device))

			if args.dataset == "periodic":
				make_predictions_w_samples_same_traj(model, experimentID, device= device,
					extrap = (args.extrap or args.extrap_test))

			# `viz` only exists when --viz was passed, so every use of it is guarded.
			if (args.viz and args.latents == 2 and (not args.rnn_seq2seq)
					and not args.classic_rnn and not args.ode_gru_rnn):
				# Density plots of the 2-D latent space on the test set.
				viz.draw_one_density_plot(model, test_dict,
					experimentID, -99, density_plot_dir, 
					log_scale=False, multiply_by_poisson = False)
				
				if args.poisson:
					viz.draw_one_density_plot(model, test_dict,
						experimentID, -99, density_plot_dir, 
						log_scale=False, multiply_by_poisson = True)

			######## PLOTTING physionet
			if args.dataset == "physionet":
				# Reconstructions on one batch of training data.
				batch_dict = utils.get_next_batch(train_dict, 0, args.batch_size)

				if isinstance(model, ODEVAE) and not args.poisson:
					plot_ode_performance(model, batch_dict, experimentID, "train")

				plot_reconstructions_per_patient(model, batch_dict,
					attr_list =  data_dict["attr"], 
					experimentID = experimentID, itr = "train",
					n_traj_to_show = 25)

				# Reconstructions on one batch of test data.
				batch_dict = utils.get_next_batch(test_dict, 0, args.batch_size)

				if isinstance(model, ODEVAE):
					plot_ode_performance(model, batch_dict, experimentID, "test")

				plot_reconstructions_per_patient(model,
					batch_dict,
					attr_list =  data_dict["attr"], 
					experimentID = experimentID, itr = "test",
					n_traj_to_show = 25)

			else:
				plot_reconstruct_encoding_t0_ti(model, data_dict["test_y"], data_dict["test_time_steps"], experimentID, "test")

				if isinstance(model, ODEVAE):
					plot_ode_performance(model, train_dict, experimentID, "train")
					plot_ode_performance(model, test_dict, experimentID, "test")

			plot_reconstructions(model, train_dict, experimentID, "train")
			plot_reconstructions(model, test_dict, experimentID, "test")

			######## PLOTTING mujoco
			# For hopper, the mujoco rendering is the last step of the run.
			if args.dataset == "hopper":
				vizualize_mujoco(model, data_dict["dataset_obj"], 
					test_dict, experimentID)
				sys.exit()

			if args.dataset != "physionet":
				# TODO: fix --- something wrong with cuda tensors
				if args.viz:
					viz.save_all_plots(test_dict, model = model, 
						loss_list = None, kl_list = None, call_count_d = None, 
						plot_name = file_name, experimentID = experimentID)

				if isinstance(model, ODEVAE):
					# Make a video where we condition on inducing points one-by-one
					make_conditioning_on_ind_points_video(model, data_dict["test_y"], data_dict["test_time_steps"], experimentID)
					convert_to_movie("plots/samples_same_traj/samples_same_traj_" + str(experimentID) + "_%04d.png", 
						"plots/samples_same_traj_" + str(experimentID) + ".mp4")

		# Evaluation-only run: never fall through into training.
		sys.exit()

	##################################################################
	# Training

	# One log file per experiment under logs/; the directory is created on demand.
	if not os.path.exists("logs/"):
		utils.makedirs("logs/")
	log_path = "logs/{}_{}.log".format(file_name, experimentID)
	logger = utils.get_logger(logpath = log_path, filepath = os.path.abspath(__file__))
	# Record the parsed arguments and the invoking command line first.
	logger.info(args)
	logger.info(input_command)
	logger.info(input_command)

	optimizer = optim.Adamax(model.parameters(), lr = args.lr)

	# Histories accumulated during training and handed to the plotting code.
	loss_list = {key: [] for key in ("test_seq_sampled", "extrap_sampled", "test_seq_all", "extrap_all")}
	kl_list, gp_param_list, y0_gp_param_list = [], [], []
	n_calls_d = {"itr": [], "n_calls": [], "n_tp": []}

	# Timers and plot counters.
	start = time.time()
	time_draw = time_per_iter = 0.
	reconstr_plot_count = density_plot_count = 0

	# The batch size is capped at the number of training series, and there is
	# always at least one batch per epoch.
	n_train = data_dict["train_y"].size(0)
	batch_size = min(args.batch_size, n_train)
	num_batches = max(1, int(np.ceil(n_train / batch_size)))

	# KL annealing: the KL term is disabled for the first `wait_until_kl_inc`
	# epochs; after k further epochs its weight is args.kl * (1 - 0.993**k).
	wait_until_kl_inc = 50
	# Evaluate, checkpoint and (optionally) plot every `n_iters_to_viz` epochs.
	n_iters_to_viz = max(1, args.niters // 10)

	# Main training loop; one epoch is `num_batches` iterations, so
	# `itr // num_batches` is the current epoch index.
	for itr in range(1, num_batches * (args.niters + 1)):
		optimizer.zero_grad()
		utils.update_learning_rate(optimizer, decay_rate = 0.999, lowest = args.lr / 10)

		if itr // num_batches < wait_until_kl_inc:
			kl_coef = 0.
		else:
			kl_coef = args.kl * (1-0.993** (itr // num_batches - wait_until_kl_inc))

		batch_dict = utils.get_next_batch(train_dict, itr, batch_size)

		train_res = model.compute_all_losses(
			batch_dict,
			n_traj_samples = args.n_traj_samples,
			kl_coef = kl_coef)
		gc.collect()

		# Number of ODE solver calls per iteration, used by the plotting code.
		n_calls_d["itr"].append(itr)
		n_calls_d["n_tp"].append(len(data_dict["train_time_steps"].detach()))
		n_calls_d["n_calls"].append(train_res["n_calls"])

		train_res["loss"].backward()
		optimizer.step()
		gc.collect()

		# Timer for the evaluation below; it starts after the optimizer step.
		start = time.time()

		if itr % (n_iters_to_viz * num_batches) == 0:
			print(cpuStats())

			with torch.no_grad():
				if args.dataset == "physionet" or args.dataset == "activity":
					test_res = compute_loss_all_batches(model, test_dict, args,
						experimentID = experimentID,
						n_traj_samples = args.n_traj_samples, kl_coef = kl_coef)
				else:
					# Compute the likelihood on the unseen time series. GP z(t) is conditioned on sub-sampled observations.
					test_res  = model.compute_all_losses(
						test_dict, 
						n_traj_samples = args.n_traj_samples, 
						kl_coef = kl_coef)

				loss_list["test_seq_sampled"].append(test_res["loss"].detach())
				kl_list.append(test_res["kl_first_p"])

				# NOTE: measured from the end of the optimizer step, so this covers
				# the cpuStats() report and the test evaluation, not the training step.
				time_per_iter = time.time() - start
				message = 'Epoch {:04d} [Test seq (cond on sampled tp)] | Loss {:.6f} | Likelihood {:.6f} | KL fp {:.4f} | FP STD {:.4f}| # time points {} | Time per iter {:.6f} |'.format(
					itr//num_batches, 
					test_res["loss"].detach(), test_res["likelihood"].detach(), 
					test_res["kl_first_p"], test_res["std_first_p"],
					n_timepoints, time_per_iter)

				logger.info("Experiment " + str(experimentID))
				logger.info(message)
				logger.info("KL coef: {}".format(kl_coef))
				logger.info("Train loss (one batch): {}".format(train_res["loss"].detach()))
				logger.info("Train CE loss (one batch): {}".format(train_res["ce_loss"].detach()))
				
				# Optional metrics, logged only when the loss function reports them.
				if "auc" in test_res:
					logger.info("Classification AUC (test, fold {}): {:.4f}".format(args.fold, test_res["auc"]))

				if "mse" in test_res:
					logger.info("Test MSE: {:.4f}".format(test_res["mse"]))

				if "accuracy" in train_res:
					logger.info("Classification accuracy (TRAIN, fold {}): {:.4f}".format(args.fold, train_res["accuracy"]))

				if "accuracy" in test_res:
					logger.info("Classification accuracy (test, fold {}): {:.4f}".format(args.fold, test_res["accuracy"]))

				if "pois_likelihood" in test_res:
					logger.info("Poisson likelihood: {}".format(test_res["pois_likelihood"]))

				if "ce_loss" in test_res:
					logger.info("CE loss: {}".format(test_res["ce_loss"]))

			# Checkpoint at every evaluation step.
			torch.save({
				'args': args,
				'state_dict': model.state_dict(),
			}, ckpt_path)

			print("Checking best results")
			store_best_results(args, train_dict, test_dict, experimentID, res_files)

			# Already inside the evaluation-step branch; only --viz is left to check.
			if args.viz:
				with torch.no_grad():
					start_draw = time.time()

					print("plotting....")
					if args.dataset != "physionet" and not args.rnn_seq2seq and not args.classic_rnn and not args.ode_gru_rnn:
						viz.draw_all_plots_one_dim(test_dict, model, loss_list, 
							kl_list, gp_param_list, y0_gp_param_list, n_calls_d, 
							plot_name = file_name + "_itr_" + str(itr // num_batches),
							experimentID = experimentID,
							save=True)

					print("making reconstructions on the TEST set....")

					if args.dataset == "physionet":
						batch_dict = utils.get_next_batch(test_dict, itr, batch_size)

						plot_reconstructions_per_patient(model,
							batch_dict,
							attr_list =  data_dict["attr"], 
							experimentID = experimentID, itr =itr // num_batches)
					else:
						plot_reconstructions(model, 
							test_dict, experimentID, itr // num_batches)

					# Same model restriction as the other 2-D density plots in this script.
					if args.latents == 2 and (not args.rnn_seq2seq) and not args.classic_rnn and not args.ode_gru_rnn:
						# Plot density plots in training data so that we have more samples in the plot
						viz.draw_one_density_plot(model, test_dict,
							experimentID, density_plot_count, density_plot_dir, 
							log_scale=False, multiply_by_poisson = False)
						
						if args.poisson:
							viz.draw_one_density_plot(model, test_dict,
								experimentID, density_plot_count, density_plot_dir, 
								log_scale=False, multiply_by_poisson = True)

						density_plot_count += 1

					time_draw = time.time() - start_draw
					plt.pause(0.01)

					# Previously the draw time was appended to `message` after it had
					# already been logged, so it was never reported.
					logger.info("Draw time {:.6f}".format(time_draw))


	# Save the final weights together with the arguments used to train them.
	torch.save(dict(args = args, state_dict = model.state_dict()), ckpt_path)

	if args.viz:
		# Model families for which the latent-trajectory plots are skipped.
		rnn_like = args.rnn_seq2seq or args.classic_rnn or args.ode_gru_rnn
		plots_apply = args.dataset != "physionet" and not rnn_like

		with torch.no_grad():
			if plots_apply:
				viz.save_all_plots(test_dict, model, 
					loss_list = loss_list, kl_list = kl_list, 
					gp_param_list = gp_param_list,
					y0_gp_param_list = y0_gp_param_list,
					call_count_d = n_calls_d, plot_name = file_name, 
					experimentID = experimentID,
					n_tp_to_sample = args.sample_tp)

			# Mujoco (hopper) rendering is not done here; to produce it, rerun
			# the script with --load after training finishes.

			if args.latents == 2 and not rnn_like:
				viz.draw_one_density_plot(model, test_dict, experimentID, 
					density_plot_count, density_plot_dir, log_scale=False)

			if plots_apply:
				convert_to_movie("plots/samples_same_traj/samples_same_traj_" + str(experimentID) + "_%04d.png", 
					"plots/samples_same_traj_" + str(experimentID) + ".mp4")

			# Stitch the per-trajectory density frames into one movie per trajectory.
			if args.latents == 2:
				for traj_id in range(10):
					movie_stem = density_plot_dir + "y0_density_traj_{}_{}".format(traj_id, experimentID)
					convert_to_movie(movie_stem + "_%03d.png", movie_stem + ".mp4", rate = 10)



	# Closing log entries: experiment id, start time and wall-clock finish time.
	logger.info("Experiment {}".format(experimentID))
	logger.info("Starting Time: " + starting_time)
	finish_date = datetime.datetime.today().strftime('%Y-%m-%d')
	finish_clock = datetime.datetime.now().time().strftime("%H:%M:%S")
	logger.info("Time of finishing: " + finish_date + " " + finish_clock)


