import os
import logging
import pickle

import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import math 
import glob
import re
from shutil import copyfile
import sklearn as sk

def makedirs(dirname):
	"""Create `dirname` (including missing parents) unless the path already exists."""
	if os.path.exists(dirname):
		return
	os.makedirs(dirname)


def save_checkpoint(state, save, epoch):
	"""Write `state` with torch.save to `<save>/checkpt-<epoch, 4 digits>.pth`.

	The folder `save` is created first when it does not exist yet.
	"""
	folder_missing = not os.path.exists(save)
	if folder_missing:
		os.makedirs(save)
	checkpoint_name = 'checkpt-%04d.pth' % epoch
	torch.save(state, os.path.join(save, checkpoint_name))

	
def get_logger(logpath, filepath, package_files=None,
	displaying=True, saving=True, debug=False):
	"""Configure the root logger and return it.

	Args:
		logpath: file that receives the log when `saving` is true (opened with mode 'w').
		filepath: value logged as the first message; its contents are not read.
		package_files: optional iterable of file paths; each path and then the
			full text of the file is written to the log.
		displaying: attach a console (stream) handler.
		saving: attach a file handler writing to `logpath`.
		debug: use DEBUG level instead of INFO.

	Returns:
		The root logger (`logging.getLogger()`).

	Note: every call adds new handlers to the root logger, so calling this
	more than once in a process makes each message appear multiple times.
	"""
	# None instead of a mutable list default; an absent value means "no files".
	if package_files is None:
		package_files = ()

	logger = logging.getLogger()
	level = logging.DEBUG if debug else logging.INFO
	logger.setLevel(level)

	if saving:
		info_file_handler = logging.FileHandler(logpath, mode='w')
		info_file_handler.setLevel(level)
		logger.addHandler(info_file_handler)
	if displaying:
		console_handler = logging.StreamHandler()
		console_handler.setLevel(level)
		logger.addHandler(console_handler)

	logger.info(filepath)

	for f in package_files:
		logger.info(f)
		with open(f, 'r') as package_f:
			logger.info(package_f.read())

	return logger


class AverageMeter(object):
	"""Keeps the most recent value and the count-weighted mean of all values seen"""

	def __init__(self):
		self.reset()

	def reset(self):
		"""Forget everything: latest value, mean, weighted sum and count go back to 0."""
		self.val = self.avg = self.sum = self.count = 0

	def update(self, val, n=1):
		"""Record `val` as observed `n` times and refresh the running mean."""
		self.val = val
		self.count += n
		self.sum += val * n
		self.avg = self.sum / self.count


class RunningAverageMeter(object):
	"""Keeps the most recent value and an exponential moving average of the values seen"""

	def __init__(self, momentum=0.99):
		self.momentum = momentum
		self.reset()

	def reset(self):
		"""Clear the state; `val is None` marks that nothing has been recorded yet."""
		self.val = None
		self.avg = 0

	def update(self, val):
		"""Blend `val` into the average; the very first value seeds the average directly."""
		if self.val is not None:
			self.avg = self.avg * self.momentum + val * (1 - self.momentum)
		else:
			self.avg = val
		self.val = val


def inf_generator(iterable):
	"""Yield the items of `iterable` forever, restarting it each time it runs out.

	Lets a DataLoader drive a single endless training loop:
		for i, (x, y) in enumerate(inf_generator(train_loader)):
	"""
	while True:
		# A fresh iterator is requested on every pass over the data.
		for item in iterable:
			yield item

def dump_pickle(data, filename):
	"""Pickle `data` into `filename` (binary mode, replacing any existing file)."""
	with open(filename, 'wb') as out_stream:
		pickle.dump(data, out_stream)

def load_pickle(filename):
	"""Return the object unpickled from `filename`.

	NOTE: unpickling can execute arbitrary code; only use on trusted files.
	"""
	with open(filename, 'rb') as in_stream:
		return pickle.load(in_stream)

def make_dataset(dataset_type = "spiral",**kwargs):
	"""Load one of the pickled spiral datasets.

	Args:
		dataset_type: "spiral" (data/spirals.pickle) or
			"chiralspiral" (data/chiral-spirals.pickle).
		**kwargs: accepted for call compatibility; not used.

	Returns:
		Tuple of the pickle's "dataset" and "chiralities" entries.

	Raises:
		Exception: if `dataset_type` is not one of the known names.
	"""
	data_paths = {
		"spiral": "data/spirals.pickle",
		"chiralspiral": "data/chiral-spirals.pickle",
	}
	if dataset_type not in data_paths:
		raise Exception("Unknown dataset type " + str(dataset_type))

	# Read and unpickle the file once; both entries come from the same object.
	contents = load_pickle(data_paths[dataset_type])
	return contents["dataset"], contents["chiralities"]


def split_last_dim(data):
	"""Split `data` into two halves along its last dimension.

	The first half holds the first `size // 2` entries of the last dimension,
	the second half the remainder (one longer when the size is odd). Works
	for tensors of any rank >= 1 rather than only 2-D and 3-D input.

	Returns:
		Tuple `(first_half, second_half)` of views into `data`.
	"""
	half = data.size()[-1] // 2
	# Ellipsis indexing keeps every leading dimension untouched.
	return data[..., :half], data[..., half:]


def init_network_weights(net, std = 0.1):
	"""Re-initialise every nn.Linear inside `net` in place.

	Weights are drawn from a normal distribution with mean 0 and standard
	deviation `std`; biases are set to 0. Linear layers created with
	`bias=False` have no bias tensor and only get their weights initialised.
	"""
	for m in net.modules():
		if not isinstance(m, nn.Linear):
			continue
		nn.init.normal_(m.weight, mean=0, std=std)
		if m.bias is not None:
			nn.init.constant_(m.bias, val=0)


def flatten(x, dim):
	"""Keep the first `dim` dimensions of `x` and merge all remaining ones into one."""
	kept_dims = x.size()[:dim]
	return x.reshape(*kept_dims, -1)


def subsample_timepoints(data, time_steps = None, mask = None, n_tp_to_sample = None):
	"""Randomly blank out time points so each trajectory keeps `n_tp_to_sample` of them.

	For every index along dim 0 of `data`, `len(time_steps) - n_tp_to_sample`
	positions along dim 1 are drawn without replacement (numpy global RNG)
	and set to 0 in both `data` and `mask`. Both tensors are modified IN PLACE.
	When `mask` is None an all-ones mask shaped like `data` is created.
	When `n_tp_to_sample` is None nothing is dropped.

	Returns:
		Tuple `(data, time_steps, mask)`; `time_steps` is passed through unchanged.
	"""
	if mask is None:
		mask = torch.ones_like(data).to(get_device(data))

	if n_tp_to_sample is not None:
		n_tp = len(time_steps)
		assert(n_tp_to_sample <= n_tp)
		n_to_drop = n_tp - n_tp_to_sample

		for traj_idx in range(data.size(0)):
			dropped = np.random.choice(np.arange(n_tp), n_to_drop, replace = False)
			dropped = sorted(dropped)

			data[traj_idx, dropped] = 0.
			mask[traj_idx, dropped] = 0.

	return data, time_steps, mask


def get_device(tensor):
	"""Return where `tensor` lives: `tensor.get_device()` for CUDA tensors, else the CPU device."""
	if tensor.is_cuda:
		return tensor.get_device()
	return torch.device("cpu")

def sample_standard_gaussian(mu, sigma):
	"""Draw `mu + sigma * eps` with `eps` sampled from a standard normal.

	One noise value is drawn per element of `mu`, on the device of `mu`;
	`mu` and `sigma` are converted to float before the combination.
	"""
	device = get_device(mu)

	zero = torch.Tensor([0.]).to(device)
	one = torch.Tensor([1.]).to(device)
	unit_normal = torch.distributions.normal.Normal(zero, one)

	# The distribution has a single-element batch, so samples carry an extra
	# trailing dimension that is squeezed away here.
	noise = unit_normal.sample(mu.size()).squeeze(-1)
	return noise * sigma.float() + mu.float()


def split_train_test(data, train_fraq = 0.8):
	"""Cut `data` along dim 0: the first `train_fraq` share is train, the rest is test."""
	n_train = int(data.size(0) * train_fraq)
	return data[:n_train], data[n_train:]

def split_train_test_data_and_time(data, time_steps, train_fraq = 0.8):
	"""Split `data` along dim 0 and the 2-D `time_steps` along dim 1 at the same index.

	The cut index is `int(data.size(0) * train_fraq)`.
	NOTE(review): the index is derived from the number of rows of `data` but is
	applied to the second dimension of `time_steps` -- confirm this is intended.

	Returns:
		`(data_train, data_test, train_time_steps, test_time_steps)`
	"""
	cut = int(data.size(0) * train_fraq)
	data_train, data_test = data[:cut], data[cut:]

	assert(len(time_steps.size()) == 2)
	train_time_steps, test_time_steps = time_steps[:, :cut], time_steps[:, cut:]

	return data_train, data_test, train_time_steps, test_time_steps



def get_next_batch(data_dict, itr, batch_size):
	"""Slice batch number `itr` out of `data_dict`.

	The batch index wraps around (`itr % num_batches`), so any non-negative
	`itr` is valid. "observed_data" and "data_to_predict" are sliced along
	dim 0; "observed_tp" and "tp_to_predict" are passed through unsliced.
	The optional entries "observed_mask", "mask_predicted_data", "labels" and
	"globs" are sliced the same way when present and not None.

	Returns:
		A dict based on get_dict_template() holding the batch.
	"""
	data = data_dict["observed_data"]
	num_batches = int(np.ceil(len(data) / batch_size))
	idx = itr % num_batches

	start = idx * batch_size
	end = start + batch_size

	batch_dict = get_dict_template()

	batch_dict["observed_data"] = data[start:end]
	batch_dict["observed_tp"] = data_dict["observed_tp"]
	batch_dict["data_to_predict"] = data_dict["data_to_predict"][start:end]
	batch_dict["tp_to_predict"] = data_dict["tp_to_predict"]

	# "globs" must be looked up in data_dict: the template never contains that
	# key, so checking batch_dict would silently drop it from every batch.
	for key in ("observed_mask", "mask_predicted_data", "labels", "globs"):
		if (key in data_dict) and (data_dict[key] is not None):
			batch_dict[key] = data_dict[key][start:end]

	return batch_dict


def get_ckpt_model(ckpt_path, model, device):
	"""Load matching weights from a checkpoint into `model` and move it to `device`.

	The checkpoint must be a dict with a 'state_dict' entry. Entries whose
	keys the model does not have are ignored; model parameters that the
	checkpoint lacks keep their current values.

	Raises:
		Exception: if `ckpt_path` does not exist.
	"""
	if not os.path.exists(ckpt_path):
		raise Exception("Checkpoint " + ckpt_path + " does not exist.")

	# Load onto the CPU first so GPU-saved checkpoints also open on CPU-only
	# machines; the model is moved to `device` at the end.
	checkpt = torch.load(ckpt_path, map_location="cpu")
	state_dict = checkpt['state_dict']
	model_dict = model.state_dict()

	# 1. filter out unnecessary keys
	state_dict = {k: v for k, v in state_dict.items() if k in model_dict}
	# 2. overwrite entries in the existing state dict
	model_dict.update(state_dict)
	# 3. load the merged state dict (loading the filtered one would fail the
	#    strict key check whenever the checkpoint is missing a parameter)
	model.load_state_dict(model_dict)
	model.to(device)


def update_learning_rate(optimizer, decay_rate = 0.999, lowest = 1e-3):
	"""Multiply each param group's 'lr' by `decay_rate`, never letting it fall below `lowest`."""
	for group in optimizer.param_groups:
		decayed = group['lr'] * decay_rate
		group['lr'] = max(decayed, lowest)


def linspace_vector(start, end, n_points):
	"""Interpolate linearly from `start` to `end` in `n_points` steps.

	`start` and `end` must have the same size. With a single element the
	result is a plain `torch.linspace`. Otherwise one linspace is built per
	index along dim 0 and the result has shape (n_points, start.size(0)),
	i.e. column i runs from start[i] to end[i].
	"""
	n_elements = np.prod(start.size())

	assert(start.size() == end.size())
	if n_elements == 1:
		# start and end each hold a single value
		return torch.linspace(start, end, n_points)

	# start and end are vectors: one row per component, then transpose
	pieces = [torch.Tensor()]
	for i in range(0, start.size(0)):
		pieces.append(torch.linspace(start[i], end[i], n_points))
	stacked = torch.cat(pieces, 0)
	return torch.t(stacked.reshape(start.size(0), n_points))

def reverse(tensor):
	"""Return a copy of `tensor` with the order of its first dimension reversed."""
	last = tensor.size(0) - 1
	descending = list(range(last, -1, -1))
	return tensor[descending]


def create_net(n_inputs, n_outputs, n_layers = 1, n_units = 100, nonlinear = nn.Tanh):
	"""Build a fully connected nn.Sequential network.

	Layout: Linear(n_inputs, n_units), then `n_layers` repetitions of
	[nonlinear, Linear(n_units, n_units)], then [nonlinear,
	Linear(n_units, n_outputs)]. `nonlinear` is a module class that is
	instantiated once per activation.
	"""
	layers = [nn.Linear(n_inputs, n_units)]
	for _ in range(n_layers):
		layers += [nonlinear(), nn.Linear(n_units, n_units)]
	layers += [nonlinear(), nn.Linear(n_units, n_outputs)]
	return nn.Sequential(*layers)


def get_item_from_pickle(pickle_file, item_name):
	"""Load `pickle_file` and return its entry `item_name`, or None when that key is absent."""
	contents = load_pickle(pickle_file)
	return contents[item_name] if item_name in contents else None

def split_by_time(dataset, time_steps_extrap, n_timepoints):
	"""Return the first `n_timepoints` of the data/time axis alongside the full versions.

	Returns:
		`(dataset[:, :n_timepoints], dataset, time_steps_extrap[:n_timepoints], time_steps_extrap)`
	"""
	time_steps = time_steps_extrap[:n_timepoints]
	true_y = dataset[:, :n_timepoints]
	return true_y, dataset, time_steps, time_steps_extrap

def get_dict_template():
	"""Return a fresh batch dictionary with every standard key present and set to None."""
	template_keys = (
		"observed_data",
		"observed_tp",
		"data_to_predict",
		"tp_to_predict",
		"observed_mask",
		"mask_predicted_data",
		"labels",
	)
	return dict.fromkeys(template_keys)
	

# def normalize_data(data):
# 	reshaped = data.reshape(-1, data.size(-1))

# 	att_mean = torch.mean(reshaped, 0)
# 	att_std = torch.std(reshaped, 0)
# 	# we don't want to divide by zero
# 	att_std[ att_std == 0.] = 1.

# 	if (att_std != 0.).all():
# 		data_norm = (data - att_mean) / att_std
# 	else:
# 		raise Exception("Zero!")

# 	if torch.isnan(data_norm).any():
# 		raise Exception("nans!")

# 	return data_norm, att_mean, att_std


def normalize_data(data):
	"""Shift each attribute (last dimension) by its minimum and divide by its maximum.

	Minimum and maximum are taken over all other dimensions. Maxima equal to
	0 are replaced by 1 before dividing.
	NOTE(review): the divisor is the maximum, not (max - min), so the result
	only spans [0, 1] when the minimum is 0 -- confirm this is intended.

	Returns:
		`(data_norm, att_min, att_max)`; `att_max` is the adjusted divisor.

	Raises:
		Exception: if the divisor still contains a zero or the result has NaNs.
	"""
	flat = data.reshape(-1, data.size(-1))

	att_min = flat.min(0)[0]
	att_max = flat.max(0)[0]
	# we don't want to divide by zero
	att_max[att_max == 0.] = 1.

	if not (att_max != 0.).all():
		raise Exception("Zero!")
	data_norm = (data - att_min) / att_max

	if torch.isnan(data_norm).any():
		raise Exception("nans!")

	return data_norm, att_min, att_max



def update_metric_if_larger(results_file_name, model_name, metric_name, value):
	"""Store `value` in the results CSV if it beats the stored one.

	The CSV has a "model" column plus one column per metric. The file and
	the metric column are created when missing. An existing entry for
	(`model_name`, `metric_name`) is replaced only when it is empty/NaN or
	strictly smaller than `value`; an unknown model gets a new row.
	The table is written back to `results_file_name` in every case.

	Args:
		value: number or single-element torch.Tensor.

	Returns:
		`(table, value_updated)` -- the DataFrame as written and whether
		`value` was stored.
	"""
	if not os.path.isfile(results_file_name):
		with open(results_file_name, 'w') as f:
			f.write("model," + metric_name + "\n")

	res_dict = pd.read_csv(results_file_name, sep=",")
	if metric_name not in res_dict.columns:
		res_dict[metric_name] = None
		print("Added new metric: " + metric_name)

	if isinstance(value, torch.Tensor):
		value = value.cpu().numpy()
	value = float(value)

	value_updated = False
	is_model = res_dict["model"] == model_name

	if is_model.any():
		# Work on the scalar, not the one-element Series.
		prev_value = res_dict.loc[is_model, metric_name].item()

		if pd.isna(prev_value) or float(prev_value) < value:
			res_dict.loc[is_model, metric_name] = value
			print("updated results table")
			value_updated = True
	else:
		# DataFrame.append was removed in pandas 2.0; concat is the replacement.
		new_row = pd.DataFrame([{"model": model_name, metric_name: value}])
		res_dict = pd.concat([res_dict, new_row], ignore_index=True)
		print("added a new model to results table")
		value_updated = True

	print(res_dict)

	res_dict.to_csv(results_file_name, index=False)
	return res_dict, value_updated

def update_value(results_file_name, model_name, metric_name, value):
	"""Unconditionally store `value` for (`model_name`, `metric_name`) in a CSV table.

	The CSV has a "model" column plus one column per metric. The file, the
	metric column and the model row are created when missing. The table is
	written back to `results_file_name`.

	Returns:
		The DataFrame as written.
	"""
	if not os.path.isfile(results_file_name):
		with open(results_file_name, 'w') as f:
			f.write("model," + metric_name + "\n")

	res_dict = pd.read_csv(results_file_name, sep=",")
	if metric_name not in res_dict.columns:
		res_dict[metric_name] = None

	is_model = res_dict["model"] == model_name
	if is_model.any():
		res_dict.loc[is_model, metric_name] = value
		print("updated!!")
	else:
		# DataFrame.append was removed in pandas 2.0; concat is the replacement.
		new_row = pd.DataFrame([{"model": model_name, metric_name: value}])
		res_dict = pd.concat([res_dict, new_row], ignore_index=True)
		print("added new model!!")

	print(res_dict)

	res_dict.to_csv(results_file_name, index=False)
	return res_dict



def merge_results_tables(source_folder, dest_folder):
	"""Merge the per-experiment result tables of `source_folder` into `dest_folder`.

	For every `results_<experim>.csv` in the source folder:
	  * if the destination has no such results file, the results file, the
	    experiment-ID file and the run script are copied over first;
	  * every (model, metric) value of the source is offered to the
	    destination table via update_metric_if_larger, and when it wins the
	    matching experiment ID is written via update_value;
	  * the source run script is appended to the destination run script.

	The experiment-ID file is named `experim_<experim>.csv` when the
	experiment name contains "classif" or "poisson", otherwise
	`experiments_<experim>.csv`. The run script is `run_<experim>.sh`.

	NOTE(review): after a fresh copy the run script is still appended to its
	own copy, duplicating its lines -- confirm whether that is intended.
	"""
	prefix, suffix = 'results_', '.csv'
	source_folder_files = glob.glob(
		os.path.join(glob.escape(source_folder), prefix + '*' + suffix))

	# Strip the fixed prefix/suffix from the file name itself. Regex
	# substitution would treat '.' and the folder path as patterns and could
	# remove text from the middle of the experiment name.
	experiment_types = [
		os.path.basename(s)[len(prefix):-len(suffix)] for s in source_folder_files]

	print(experiment_types)

	for experim in experiment_types:
		res_source_file = source_folder + '/results_' + experim + '.csv'
		res_dest_file = dest_folder + '/results_' + experim + '.csv'

		if "classif" in experim or "poisson" in experim:
			exp_prefix = '/experim_'
		else:
			exp_prefix = '/experiments_'
		exp_source_file = source_folder + exp_prefix + experim + '.csv'
		exp_dest_file = dest_folder + exp_prefix + experim + '.csv'

		run_source_file = source_folder + '/run_' + experim + '.sh'
		run_dest_file = dest_folder + '/run_' + experim + '.sh'

		if not os.path.exists(res_dest_file):
			copyfile(res_source_file, res_dest_file)
			copyfile(exp_source_file, exp_dest_file)
			copyfile(run_source_file, run_dest_file)

		res_df = pd.read_csv(res_source_file, sep=",")
		exp_df = pd.read_csv(exp_source_file, sep=",")

		print(res_df)

		for metric_name in res_df.columns:
			if metric_name == "model":
				continue

			for model_name in res_df["model"]:
				value = res_df.loc[res_df["model"] == model_name, metric_name].item()
				experimentID = exp_df.loc[exp_df["model"] == model_name, metric_name].item()

				res_dict, value_updated = update_metric_if_larger(
						res_dest_file, model_name,
						metric_name, value)

				if value_updated:
					res_dict = update_value(
						exp_dest_file, model_name,
						metric_name, experimentID)

		with open(run_dest_file, "a") as dest_file:
			with open(run_source_file, "r") as s_file:
				for line in s_file:
					dest_file.write("\n" + line)


def get_run_command(experimentID, dataset, to_rerun_file = None):
	"""Recover the command line of an experiment from its log file.

	Reads `logs/generate_and_fit_1d_toy_VI_<dataset>_<experimentID>.log` and
	takes its third line as the command. The command is returned prefixed
	with "python3 " when it starts with "generate_and_fit_1d_toy_VI.py".

	Returns None when the log file is missing, or when the third line is not
	such a command; in the latter case the experiment ID is also appended to
	`to_rerun_file` (if given).
	"""
	log_file = "logs/generate_and_fit_1d_toy_VI_" + dataset + "_" + str(experimentID) + ".log"

	if not os.path.exists(log_file):
		print("File " + log_file + " does not exist")
		return

	with open(log_file, "r") as log_handle:
		lines = log_handle.readlines()

	run_command = lines[2].strip() if len(lines) >= 3 else ""

	if run_command.startswith("generate_and_fit_1d_toy_VI.py"):
		return "python3 " + run_command

	print("WARNING: Experiment " + str(experimentID) + " does not have run command -- needs re-run")
	if to_rerun_file is not None:
		with open(to_rerun_file, 'a') as rerun_handle:
			rerun_handle.write(str(experimentID) + "\n")
	return None


def gather_figures(experim_file_name, dataset, dest_folder,
	model_names = None,
	metrics_to_gather = None, 
	filename_to_gather = "plots/{}/reconstr_{}_traj_*_test.pdf",
	rerun_plots = False):
	"""Copy the plots of the experiments listed in an experiment-ID table into `dest_folder`.

	Args:
		experim_file_name: CSV with a "model" column and one column per metric;
			each cell holds an experiment ID.
		dataset: dataset name, passed on to get_run_command.
		dest_folder: target folder. It is created if needed and is joined to
			the file name by plain concatenation, so it should end with a
			path separator.
		model_names: models to gather; default is the whole first column.
		metrics_to_gather: columns to gather; default is every column but the first.
		filename_to_gather: glob pattern with two `{}` slots, both filled
			with the experiment ID.
		rerun_plots: regenerate plots even when matching files already exist.

	When no plot matches (or `rerun_plots` is set) the experiment's command is
	re-run through the shell with `--load <id> --do-not-store`. Each plot is
	copied with the experiment ID in its file name replaced by
	`<model_name><metric_name>`.
	"""
	makedirs(dest_folder)

	if not os.path.isfile(experim_file_name):
		print("File " + experim_file_name + " does not exist")
		return

	res_dict = pd.read_csv(experim_file_name, sep=",")

	print(res_dict)

	if metrics_to_gather is None:
		metrics_to_gather = res_dict.columns[1:]
		print("Gathering metrics: " + str(metrics_to_gather))

	if model_names is None:
		# if model names are not specified, take all models
		model_names = res_dict.iloc[:,0]
		print("Gathering models: " + str(model_names))

	for metric_name in metrics_to_gather:
		for model_name in model_names:
			is_model = res_dict["model"] == model_name
			if not is_model.any():
				print("Model " + model_name + " not found")
				continue

			if metric_name not in res_dict.columns:
				print("Column " + metric_name + " not found")
				continue

			experimID = res_dict.loc[is_model, metric_name].item()

			if math.isnan(experimID):
				print("WARNING: Experiment " + str(metric_name) + " / " +  str(model_name) + " is missing -- needs re-run")
				continue

			experimID = int(experimID)
			id_text = str(experimID)
			plot_regexp = filename_to_gather.format(experimID, experimID)

			plots = glob.glob(plot_regexp)

			if rerun_plots or (len(plots) == 0):
				run_command = get_run_command(experimID, dataset)
				if run_command is not None:
					os.system(run_command + " --load " + id_text + " --do-not-store")

					print(plots)
					plots = glob.glob(plot_regexp)
					print(plots)

			for pl in plots:
				basename = os.path.basename(pl)

				# Swap the first occurrence of the ID in the file name for
				# the model and metric names.
				id_start = basename.find(id_text)
				prefix = basename[:id_start]
				suffix = basename[(id_start + len(id_text)):]

				copyfile(pl, dest_folder + prefix + str(model_name) + str(metric_name) + suffix)



def gather_all_figures():
	"""Run gather_figures for the fixed list of spiral ("pickle") and periodic figure sets."""
	pickle_csv = "results/experiments_pickle_{}_n_subsampled_points_likelihood.csv"
	periodic_csv = "results/experiments_periodic_{}_n_subsampled_points_likelihood.csv"
	paper_dir = "plots_for_paper/"

	all_rnn_models = ["classic_rnn_cell_gru", "classic_rnn_cell_gru_input_decay", "classic_rnn_cell_expdecay_input_decay", "ode_gru_rnn"]
	spiral_metrics = ["10", "20", "30", "50"]

	# Spiral dataset: (experiment kind, destination subfolder, metrics, models).
	# A models entry of None means "all models in the table".
	# There is deliberately no spiral_extrap job for the two RNN baselines.
	spiral_jobs = [
		("interp", "spiral_interp_30stp/", ["30"], None),
		("interp", "spiral_interp_20stp/", ["20"], None),
		("extrap", "spiral_extrap_30stp/", ["30"], None),
		("extrap", "spiral_extrap_20stp/", ["20"], None),
		("interp", "spiral_interp/", spiral_metrics, ["y0_ode_combine"]),
		("extrap", "spiral_extrap/", spiral_metrics, ["y0_ode_combine"]),
		("interp", "spiral_interp/", spiral_metrics, ["classic_rnn_cell_gru", "ode_gru_rnn"]),
		("extrap_test", "spiral_extrap_test/", spiral_metrics, all_rnn_models),
	]

	for kind, subfolder, metrics, models in spiral_jobs:
		gather_figures(
			experim_file_name = pickle_csv.format(kind),
			dataset = "pickle",
			dest_folder = paper_dir + subfolder,
			metrics_to_gather = list(metrics),
			model_names = models,
			rerun_plots = True)

	# Periodic dataset: one pass per trajectory index, existing plots only.
	periodic_jobs = [
		("extrap_test", all_rnn_models),
		("extrap", ["y0_ode_combine", "y0_rnn", "rnn_vae"]),
	]

	for j in range(10):
		for kind, models in periodic_jobs:
			gather_figures(
				experim_file_name = periodic_csv.format(kind),
				dataset = "periodic",
				dest_folder = paper_dir + "periodic_extrap_future/",
				metrics_to_gather = ["10", "20", "30", "50", "80"],
				model_names = list(models),
				filename_to_gather = "plots/{}/reconstr_{}_traj_" + str(j) + "_extrap_future.pdf",
				rerun_plots = False)



def gather_mujoco_figures():
	"""Run gather_figures for the hopper reconstruction images, then the ground-truth images.

	Each pass covers trajectory indices 0-9 and frame suffixes 010-090 for
	metric "30". The reconstruction pass uses all models in the table, the
	ground-truth pass only "y0_ode_combine".
	"""
	experiment = "experiments_hopper_interp_n_subsampled_points_likelihood"
	experim_csv = "results/" + experiment + ".csv"
	dest = "plots_for_paper/" + experiment + "/"

	# (file-name stem, models); None selects every model in the table.
	passes = (
		("reconstr_traj_", None),
		("true_traj_", ["y0_ode_combine"]),
	)

	for stem, models in passes:
		for j in range(10):
			for i in range(10, 100, 10):
				gather_figures(
					experim_file_name = experim_csv,
					dataset = "hopper",
					dest_folder = dest,
					metrics_to_gather = ["30"],
					model_names = models,
					filename_to_gather = "hopper_imgs/{}/" + stem + str(j) + "_{}-0" + str(i) + ".jpg",
					rerun_plots = False)




def get_commands_for_best_runs():
	"""Re-run the experiments listed in every `results/experim*.csv` and record their commands.

	For each table, every (model, metric) cell holding an experiment ID is
	resolved to its command via get_run_command. The command is executed
	through the shell with `--load <id>` and appended to a run script in
	`results/`. IDs whose command cannot be recovered are collected in
	"need_rerun.txt" by get_run_command.

	NOTE(review): the run-script name keeps the underscore that follows the
	file-name prefix, giving `results/run__<rest>.sh` with a double
	underscore -- confirm this matches what readers of the script expect.
	"""
	to_rerun_file = "need_rerun.txt"

	for experim_file_name in glob.glob("results/experim*.csv"):
		basename = os.path.basename(experim_file_name)
		first_us = basename.find("_")
		second_us = basename[first_us+1:].find("_") + first_us + 1

		destfile = "results/run_" + basename[first_us:-4]  + ".sh"
		# The dataset name is the text between the first two underscores.
		dataset = basename[first_us+1 : second_us]

		if not os.path.isfile(experim_file_name):
			print("File " + experim_file_name + " does not exist")
			return

		res_dict = pd.read_csv(experim_file_name, sep=",")

		print(res_dict)

		metrics_to_gather = res_dict.columns[1:]
		model_names = res_dict.iloc[:,0]

		for metric_name in metrics_to_gather:
			for model_name in model_names:
				is_model = res_dict["model"] == model_name
				if not is_model.any():
					print("Model " + model_name + " not found")
					continue

				if metric_name not in res_dict.columns:
					print("Column " + metric_name + " not found")
					continue

				experimID = res_dict.loc[is_model, metric_name].item()

				if math.isnan(experimID):
					print("WARNING: Experiment " + str(metric_name) + " / " +  str(model_name) + " is missing -- needs re-run")
					continue

				experimID = int(experimID)
				run_command = get_run_command(experimID, dataset, to_rerun_file)
				if run_command is None:
					continue

				# Re-run so the best results are stored as well, then log the command.
				os.system(run_command + " --load " + str(experimID))
				with open(destfile, 'a') as script:
					script.write(run_command + "\n")
				


def shift_outputs(outputs, first_datapoint = None):
	"""Drop the last step along dim 2 of the 4-D `outputs`, optionally prepending a first step.

	`first_datapoint`, when given, must be 2-D (n_traj, n_dims). It is
	reshaped to (1, n_traj, 1, n_dims) and concatenated in front along dim 2.
	"""
	shifted = outputs[:,:,:-1,:]
	if first_datapoint is None:
		return shifted

	n_traj, n_dims = first_datapoint.size()
	head = first_datapoint.reshape(1, n_traj, 1, n_dims)
	return torch.cat((head, shifted), 2)



def split_data_extrap(data_dict, data_type = "train", concat_globs = False, dataset = ""):
	"""Build an extrapolation task: observe the start of each series, predict the rest.

	Reads `<data_type>_y` and `<data_type>_time_steps` from `data_dict`, plus
	the optional `<data_type>_mask`, `<data_type>_labels` and
	`<data_type>_globs`. The first half of the time axis (dim 1) is
	observed; for `dataset == "hopper"` only the first third.

	With `concat_globs`, the globs are repeated along the time axis and
	appended to the observed data on the last dimension; an existing observed
	mask is extended with ones for those extra entries, a missing mask
	stays None.

	Returns:
		Dict with "observed_data", "observed_tp", "data_to_predict",
		"tp_to_predict", "observed_mask", "mask_predicted_data", "labels"
		and, when available, "globs".
	"""
	y = data_dict[data_type + "_y"]
	time_steps = data_dict[data_type + "_time_steps"]

	n_observed_tp = y.size(1) // 2
	if dataset == "hopper":
		n_observed_tp = y.size(1) // 3

	split_dict = {"observed_data": y[:,:n_observed_tp,:].clone(),
				"observed_tp": time_steps[:n_observed_tp].clone(),
				"data_to_predict": y[:,n_observed_tp:,:].clone(),
				"tp_to_predict": time_steps[n_observed_tp:].clone()}

	split_dict["observed_mask"] = None
	split_dict["mask_predicted_data"] = None
	split_dict["labels"] = None

	mask = data_dict.get(data_type + "_mask")
	if mask is not None:
		split_dict["observed_mask"] = mask[:, :n_observed_tp].clone()
		split_dict["mask_predicted_data"] = mask[:, n_observed_tp:].clone()

	labels = data_dict.get(data_type + "_labels")
	if labels is not None:
		split_dict["labels"] = labels.clone()

	if data_type + "_globs" in data_dict:
		split_dict["globs"] = data_dict[data_type + "_globs"]

	if concat_globs:
		n_tp = split_dict["observed_data"].size(1)
		globs_repeated = split_dict["globs"].unsqueeze(1).repeat(1,n_tp,1)
		split_dict["observed_data"] = torch.cat((split_dict["observed_data"], globs_repeated), -1)
		# Without a mask there is nothing to extend; concatenating None would crash.
		if split_dict["observed_mask"] is not None:
			globs_mask = torch.ones(globs_repeated.size(), device=globs_repeated.device)
			split_dict["observed_mask"] = torch.cat((split_dict["observed_mask"], globs_mask), -1)

	return split_dict




def split_data_interp(data_dict, data_type = "train", concat_globs = False):
	"""Build an interpolation task: the full series is both the input and the target.

	Reads `<data_type>_y` and `<data_type>_time_steps` from `data_dict`, plus
	the optional `<data_type>_mask`, `<data_type>_labels` and
	`<data_type>_globs`. Observed and to-predict entries are independent
	clones of the same tensors.

	With `concat_globs`, the globs are repeated along the time axis and
	appended to the observed data on the last dimension (the plotting code
	has to account for these extra entries); an existing observed mask is
	extended with ones, a missing mask stays None.

	Returns:
		Dict with "observed_data", "observed_tp", "data_to_predict",
		"tp_to_predict", "observed_mask", "mask_predicted_data", "labels"
		and, when available, "globs".
	"""
	y = data_dict[data_type + "_y"]
	time_steps = data_dict[data_type + "_time_steps"]

	split_dict = {"observed_data": y.clone(),
				"observed_tp": time_steps.clone(),
				"data_to_predict": y.clone(),
				"tp_to_predict": time_steps.clone()}

	split_dict["observed_mask"] = None
	split_dict["mask_predicted_data"] = None
	split_dict["labels"] = None

	mask = data_dict.get(data_type + "_mask")
	if mask is not None:
		split_dict["observed_mask"] = mask.clone()
		split_dict["mask_predicted_data"] = mask.clone()

	labels = data_dict.get(data_type + "_labels")
	if labels is not None:
		split_dict["labels"] = labels.clone()

	if data_type + "_globs" in data_dict:
		split_dict["globs"] = data_dict[data_type + "_globs"]

	if concat_globs:
		n_tp = split_dict["observed_data"].size(1)
		globs_repeated = split_dict["globs"].unsqueeze(1).repeat(1,n_tp,1)
		split_dict["observed_data"] = torch.cat((split_dict["observed_data"], globs_repeated), -1)
		# Without a mask there is nothing to extend; concatenating None would crash.
		if split_dict["observed_mask"] is not None:
			globs_mask = torch.ones(globs_repeated.size(), device=globs_repeated.device)
			split_dict["observed_mask"] = torch.cat((split_dict["observed_mask"], globs_mask), -1)

	return split_dict



def split_data_for_pred_from_labels(data_dict, data_type = "train", concat_globs = False):
	"""Build a task that predicts the series `<data_type>_y` from `<data_type>_labels`.

	The labels become the observed data and the series the prediction
	target; both use the full `<data_type>_time_steps`. When a mask
	`<data_type>_mask` exists, a time point counts as observed for every
	label if the mask has at least one non-zero entry there; the original
	mask is kept for the predicted data. "labels" in the result is None.

	With `concat_globs`, the globs are repeated along the time axis and
	appended to the observed data on the last dimension; an existing observed
	mask is converted to float and extended with ones, a missing mask
	stays None.

	Returns:
		Dict with "observed_data", "observed_tp", "data_to_predict",
		"tp_to_predict", "observed_mask", "mask_predicted_data", "labels"
		and, when available, "globs".
	"""
	labels = data_dict[data_type + "_labels"]
	time_steps = data_dict[data_type + "_time_steps"]
	n_labels = labels.size(-1)

	split_dict = {"observed_data": labels.clone(),
				"observed_tp": time_steps.clone(),
				"data_to_predict": data_dict[data_type + "_y"].clone(),
				"tp_to_predict": time_steps.clone()}

	split_dict["observed_mask"] = None
	split_dict["mask_predicted_data"] = None
	split_dict["labels"] = None

	data_mask = data_dict.get(data_type + "_mask")
	if data_mask is not None:
		# Find the time points with at least one observation
		observed_tp = torch.sum(data_mask.clone(), -1) > 0

		# repeat the mask for each label to mark that the label for this time point is present
		split_dict["observed_mask"] = observed_tp.repeat(n_labels, 1,1).permute(1,2,0)
		split_dict["mask_predicted_data"] = data_mask.clone()

	if data_type + "_globs" in data_dict:
		split_dict["globs"] = data_dict[data_type + "_globs"]

	if concat_globs:
		n_tp = split_dict["observed_data"].size(1)
		globs_repeated = split_dict["globs"].unsqueeze(1).repeat(1,n_tp,1)
		split_dict["observed_data"] = torch.cat((split_dict["observed_data"], globs_repeated), -1)
		# Without a mask there is nothing to extend; calling .float() on None would crash.
		if split_dict["observed_mask"] is not None:
			globs_mask = torch.ones(globs_repeated.size(), device=globs_repeated.device)
			split_dict["observed_mask"] = torch.cat(
				(split_dict["observed_mask"].float(), globs_mask), -1)

	return split_dict



def subsample_observed_data(data_dict, n_tp_to_sample):
	"""Return a shallow copy of `data_dict` whose observed part has been subsampled.

	Clones of "observed_data", "observed_tp" and (if not None)
	"observed_mask" are handed to subsample_timepoints, so the tensors in
	`data_dict` are left untouched. All other entries are shared with the input.
	"""
	observed_mask = data_dict["observed_mask"]

	# Subsample time points
	data, time_steps, mask = subsample_timepoints(
		data_dict["observed_data"].clone(),
		time_steps = data_dict["observed_tp"].clone(),
		mask = None if observed_mask is None else observed_mask.clone(),
		n_tp_to_sample = n_tp_to_sample)

	new_data_dict = dict(data_dict)
	new_data_dict.update(
		observed_data = data.clone(),
		observed_tp = time_steps.clone(),
		observed_mask = mask.clone())
	return new_data_dict





def compute_loss_all_batches(model, test_dict, args,
	experimentID,
	n_traj_samples = 1, kl_coef = 1., 
	max_samples_for_eval = None):
	"""Evaluate `model` batch by batch on `test_dict` and average its loss terms.

	Args:
		model: object providing `compute_all_losses(batch_dict, n_traj_samples=, kl_coef=)`
			that returns a dict of loss terms (and "label_predictions" when
			classifying).
		test_dict: data dictionary understood by get_next_batch.
		args: needs `batch_size`, `classif` and, when classifying,
			`dataset` (and `fold` for "physionet").
		experimentID: names the folder `plots/<experimentID>/` that receives
			the physionet predictions CSV.
		n_traj_samples, kl_coef: passed through to the model.
		max_samples_for_eval: stop once at least this many samples
			(processed batches * batch_size) have been evaluated.

	Returns:
		Dict with the per-batch mean of "loss", "likelihood", "mse",
		"kl_first_p", "std_first_p", "n_calls", "pois_likelihood" and
		"ce_loss" (0 for terms the model did not report). When classifying
		it also holds "auc" for physionet or "accuracy" for activity.
	"""
	total = {key: 0 for key in (
		"loss", "likelihood", "mse", "kl_first_p", "std_first_p",
		"n_calls", "pois_likelihood", "ce_loss")}

	n_test_batches = 0
	device = get_device(test_dict["data_to_predict"])
	classif_predictions = torch.Tensor([]).to(device)
	all_test_labels =  torch.Tensor([]).to(device)

	batch_size = args.batch_size
	n_traj = test_dict["observed_data"].size(0)
	# Round up. A blanket "+ 1" adds a wrapped-around duplicate of the first
	# batch whenever n_traj is a multiple of batch_size.
	n_batches = int(math.ceil(n_traj / batch_size))

	for i in range(n_batches):
		print("Computing loss... " + str(i))

		batch_dict = get_next_batch(test_dict, i, batch_size)
		results  = model.compute_all_losses(batch_dict,
			n_traj_samples = n_traj_samples, kl_coef = kl_coef)

		if args.classif:
			# Predictions are concatenated along dim 1, labels along dim 0.
			classif_predictions = torch.cat((classif_predictions,
				results["label_predictions"]),1)
			all_test_labels = torch.cat((all_test_labels,
				batch_dict["labels"]),0)

		for key in total:
			if key in results:
				var = results[key]
				if isinstance(var, torch.Tensor):
					var = var.detach()
				total[key] += var

		n_test_batches += 1

		# for speed: stop once enough samples have actually been processed
		if max_samples_for_eval is not None:
			if n_test_batches * batch_size >= max_samples_for_eval:
				break

	if n_test_batches > 0:
		for key in total:
			total[key] = total[key] / n_test_batches

	if args.classif:
		# Import the submodule explicitly; a bare "import sklearn" does not
		# guarantee that "sklearn.metrics" is loaded.
		from sklearn import metrics as sk_metrics

		n_traj_samples = classif_predictions.size(0)

		if args.dataset == "physionet":
			all_test_labels = all_test_labels.reshape(-1)

			# "== 0" negates the isnan mask for both bool and uint8 tensors;
			# "1 - mask" is rejected for bool tensors.
			idx_not_nan = torch.isnan(all_test_labels) == 0
			classif_predictions = classif_predictions[:,idx_not_nan]
			all_test_labels = all_test_labels[idx_not_nan]

			dirname = "plots/" + str(experimentID) + "/"
			os.makedirs(dirname, exist_ok=True)

			# Store the first sample's predictions next to the labels.
			fold_file = dirname + "fold_" + str(args.fold) + ".csv"
			with open(fold_file, "w") as f:
				d = pd.DataFrame(
					torch.cat(
						(classif_predictions[0].unsqueeze(-1),
						all_test_labels.unsqueeze(-1)),
					1).cpu().numpy())
				d.to_csv(f, index=False)

			total["auc"] = 0.
			if torch.sum(all_test_labels) != 0.:
				print("Fold " + str(args.fold))
				print("Number of labeled examples: {}".format(len(all_test_labels.reshape(-1))))
				print("Number of examples with mortality 1: {}".format(torch.sum(all_test_labels == 1.)))

				# For each trajectory, we get n_traj_samples samples from y0 -- compute loss on all of them
				all_test_labels = all_test_labels.repeat(n_traj_samples, 1)

				# Cannot compute AUC with only 1 class
				total["auc"] = sk_metrics.roc_auc_score(all_test_labels.cpu().numpy().reshape(-1),
					classif_predictions.cpu().numpy().reshape(-1))
			else:
				print("Warning: Couldn't compute AUC -- all examples are from the same class")

		if args.dataset == "activity":
			all_test_labels = all_test_labels.repeat(n_traj_samples, 1,1,1)

			# Only time points with at least one non-zero label entry are scored.
			labeled_tp = torch.sum(all_test_labels, -1) > 0.

			all_test_labels = all_test_labels[labeled_tp]
			classif_predictions = classif_predictions[labeled_tp]

			# classif_predictions and all_test_labels are one-hot encoded -- convert to class ids
			_, pred_class_id = torch.max(classif_predictions, -1)
			_, class_labels = torch.max(all_test_labels, -1)

			total["accuracy"] = sk_metrics.accuracy_score(
					class_labels.cpu().numpy(),
					pred_class_id.cpu().numpy())

	return total





