import os
import numpy as np

import torch
import torch.nn as nn

import lib.utils as utils
from lib.diffeq_solver import DiffeqSolver
from lib.datasets import Spirals, NavierStokes, SDESamples, Household, AirQuality
from generate_timeseries import Periodic_1d, BimodalSigmoid, Periodic_2StartingPoints, Periodic2d, PeriodicUpStraightDown, Zigzag
from torch.distributions import uniform

from torch.utils.data import DataLoader
from mujoco_physics import HopperPhysics
from physionet import PhysioNet, variable_time_collate_fn
from person_activity import PersonActivity, variable_time_collate_fn_activity

from sklearn import model_selection
import random
from sklearn.model_selection import KFold

#####################################################################################################
def parse_datasets(args, device):
	dataset_name = args.dataset

	n_total_tp = args.timepoints + args.extrap
	max_t_extrap = args.max_t / args.timepoints * n_total_tp

	if dataset_name == "pickle":
		if args.datafile is None:
			raise Exception("Please provide the pickle file with the dataset using --datafile argument.")

		if not os.path.isfile(args.datafile):
			raise Exception("Data file {} does not exist".format(args.datafile))

		data_dict = utils.load_pickle(args.datafile)["data"]
		for key, value in data_dict.items():
			if value is not None:
				data_dict[key] = value.to(device)
		return data_dict

	if dataset_name == "hopper":
		# ugly -- figure out a better way
		args.obsrv_std = 1e-3

		dataset_obj = HopperPhysics(root='data', train=True, generate=False, device = device)
		dataset = dataset_obj.get_dataset()[:args.n]
		dataset = dataset.to(device)

		if args.dim_only is not None:
			dataset = dataset[:,:,:args.dim_only]

		n_tp_data = dataset[:].shape[1]

		# Time steps that are used later on for exrapolation
		time_steps = torch.arange(start=0, end = n_tp_data, step=1).float().to(device)
		time_steps = time_steps / len(time_steps)

		dataset = dataset.to(device)
		time_steps = time_steps.to(device)

		if not args.extrap:
			# Creating dataset for interpolation
			# sample time points from different parts of the timeline, 
			# so that the model learns from different parts of hopper trajectory
			n_traj = len(dataset)
			n_tp_data = dataset.shape[1]
			n_reduced_tp = args.timepoints

			# sample time points from different parts of the timeline, 
			# so that the model learns from different parts of hopper trajectory
			start_ind = np.random.randint(0, high=n_tp_data - n_reduced_tp +1, size=n_traj)
			end_ind = start_ind + n_reduced_tp
			sliced = []
			for i in range(n_traj):
				  sliced.append(dataset[i, start_ind[i] : end_ind[i], :])
			dataset = torch.stack(sliced).to(device)
			time_steps = time_steps[:n_reduced_tp]

		# Split into train and test by the time sequences
		train_time_steps = test_time_steps = time_steps
		train_y, test_y = utils.split_train_test(dataset, train_fraq = 0.8)

		print("train_y")
		print(train_y.size())

		data_dict = {"dataset_obj": dataset_obj,
					"train_y": train_y, 
					"test_y": test_y, 
					"train_time_steps": train_time_steps, 
					"test_time_steps": test_time_steps}
		return data_dict
		

	if dataset_name == "physionet":
		train_dataset_obj = PhysioNet('data/physionet', train=True, 
										quantization = args.quantization,
										download=True, n_samples = min(10000, args.n), 
										device = device)
		# Use custom collate_fn to combine samples with arbitrary time observations.
		# Returns the dataset along with mask and time steps
		batch_size = min(min(len(train_dataset_obj), args.batch_size), args.n)
		test_dataset_obj = PhysioNet('data/physionet', train=False, 
										quantization = args.quantization,
										download=True, n_samples = min(10000, args.n), 
										device = device)

		# Combine and shuffle samples from physionet Train and physionet Test
		total_dataset = train_dataset_obj[:len(train_dataset_obj)]

		if not args.classif:
			# Only 'training' physionet samples are have labels. Therefore, if we do classifiction task, we don't need physionet 'test' samples.
			total_dataset = total_dataset + test_dataset_obj[:len(test_dataset_obj)]

		if args.classif and args.cv > 0:
			kf = KFold(n_splits=args.cv, shuffle=True, random_state=42)
			for i, (train_idx, test_idx) in enumerate(kf.split(total_dataset)):
				if i == args.fold:
					np.random.shuffle(train_idx)
					np.random.shuffle(test_idx)

					train_data = [total_dataset[i] for i in train_idx]
					test_data = [total_dataset[i] for i in test_idx] 
		else:
			# Shuffle and split
			train_data, test_data = model_selection.train_test_split(total_dataset, train_size= 0.8, 
				random_state = 42, shuffle = True)

		train_dataloader = DataLoader(train_data, batch_size= min(10000, args.n), shuffle=False, 
			collate_fn= lambda batch: variable_time_collate_fn(batch, device))
		test_dataloader = DataLoader(test_data, batch_size=min(10000, args.n), shuffle=False, 
			collate_fn= lambda batch: variable_time_collate_fn(batch, device))

		# Make the union of all time points and perform normalization across the whole dataset
		train_time_steps, train_y, train_mask, train_globs, train_labels = train_dataloader.__iter__().next()
		test_time_steps, test_y, test_mask, test_globs, test_labels = test_dataloader.__iter__().next()

		print("# Points after discretization:")
		print(torch.sum(train_mask) + torch.sum(test_mask))

		print("Total number of patients in training set:")
		print(len(train_dataset_obj.data))

		print("data dims:")
		print(train_y.size())

		print("for each patient, avg number of attributes that have at least 1 value")
		n = []
		for i in range(train_mask.size(0)):
			n.append(torch.sum(torch.mean(train_mask[i:(i+1),:,:],1) != 0))
		print(torch.mean(torch.stack(n).float()))

		print("number of non-missing values per attribute in 1st patient:")
		print(torch.sum(train_mask[:1,:,:],1))

		print("avg number of observations per time series across all patients")
		print(torch.mean(torch.sum(train_mask[:,:,:],1),0))

		# Visualize dataset
		# for i in range(10):
		# 	train_dataset_obj.visualize(tp, train_y[i], train_mask[i], 
		# 		"physionet_patient_{}.png".format(i))

		if args.dim_only is not None:
			train_y = train_y[:,:,:args.dim_only]
			test_y = test_y[:,:,:args.dim_only]
			train_mask = train_mask[:,:,:args.dim_only]
			test_mask = test_mask[:,:,:args.dim_only]

		n_tp_data = train_y.size(1)

		# Excluding global parameters like age, gender, etc.
		attr_names = [param for param in train_dataset_obj.params if param not in train_dataset_obj.global_params]
		data_dict = {"dataset_obj": train_dataloader, 
					"dataset_obj_train": train_dataloader, 
					"dataset_obj_test": test_dataloader,
					"train_y": train_y, 
					"test_y": test_y, 
					"train_time_steps": train_time_steps, # time steps for training sequences 
					"test_time_steps": test_time_steps, # time steps for test sequences 
					"train_mask": train_mask,
					"test_mask": test_mask,
					"train_labels": train_labels,
					"test_labels": test_labels,
					"train_globs": train_globs,
					"test_globs": test_globs,
					"attr": attr_names}  
		return data_dict

	if dataset_name == "activity":
		dataset_obj = PersonActivity('data/PersonActivity', 
							download=True, n_samples = 10000, device = device)
		print(dataset_obj)
		# Use custom collate_fn to combine samples with arbitrary time observations.
		# Returns the dataset along with mask and time steps
		batch_size = min(min(len(dataset_obj), args.batch_size), args.n)
		# Shuffle and split
		train_data, test_data = model_selection.train_test_split(dataset_obj, train_size= 0.8, 
			random_state = 42, shuffle = True)

		train_data = [train_data[i] for i in np.random.choice(len(train_data), len(train_data))]
		test_data = [test_data[i] for i in np.random.choice(len(test_data), len(test_data))]

		train_dataloader = DataLoader(train_data, batch_size= args.n, shuffle=False, 
			collate_fn= lambda batch: variable_time_collate_fn_activity(batch, device))
		test_dataloader = DataLoader(test_data, batch_size=args.n, shuffle=False, 
			collate_fn= lambda batch: variable_time_collate_fn_activity(batch, device))

		# Make the union of all time points 
		train_time_steps, train_y, train_mask, train_labels, train_globs = train_dataloader.__iter__().next()
		test_time_steps, test_y, test_mask, test_labels, test_globs = test_dataloader.__iter__().next()

		print("train_y")
		print(train_y.size())
		print("test_y")
		print(test_y.size())

		if args.dim_only is not None:
			train_y = train_y[:,:,:args.dim_only]
			test_y = test_y[:,:,:args.dim_only]
			train_mask = train_mask[:,:,:args.dim_only]
			test_mask = test_mask[:,:,:args.dim_only]

		n_timepoints = args.timepoints

		train_y = train_y[:,:n_timepoints]
		test_y = test_y[:,:n_timepoints]
		train_mask = train_mask[:,:n_timepoints]
		test_mask = test_mask[:,:n_timepoints]

		train_labels = train_labels[:,:n_timepoints]
		test_labels = test_labels[:,:n_timepoints]
		train_time_steps = train_time_steps[:n_timepoints]
		test_time_steps = test_time_steps[:n_timepoints]

		n_labels = train_labels.size(-1)
		class_weights = torch.zeros(n_labels)
		
		print(n_labels)
		for i in range(n_labels):
			class_weights[i] = torch.sum(torch.max(train_labels[train_mask[:,:,0] == 1.],-1)[1] == i)

		print("classes")
		print(class_weights / torch.sum(class_weights))

		class_weights = torch.sum(class_weights)/class_weights
		class_weights = class_weights/torch.max(class_weights)
		print(class_weights)

		n_tp_data = train_y.size(1)

		data_dict = {"dataset_obj": train_dataloader, 
					"dataset_obj_train": train_dataloader, 
					"dataset_obj_test": test_dataloader,
					"train_y": train_y, 
					"test_y": test_y, 
					"train_time_steps": train_time_steps, # time steps for training sequences 
					"test_time_steps": test_time_steps, # time steps for test sequences 
					"train_mask": train_mask,
					"test_mask": test_mask,
					"train_labels": train_labels,
					"test_labels": test_labels,
					"class_weights": class_weights,
					"train_globs": train_globs,
					"test_globs": test_globs}
		return data_dict

	# Small datasets that fit into memory
	if dataset_name == "house":
		dataset_obj = Household(device = device)
		dataset_obj.init_visualization()
		dataset_obj.init_dataset()
		dataset = dataset_obj.get_batches(tp_per_batch = n_total_tp, n_samples = args.n)

		time_steps_extrap = torch.arange(start=0, end = dataset.shape[1], step=1)
		time_steps_extrap = time_steps_extrap / len(time_steps_extrap)

	elif dataset_name == "airquality":
		dataset_obj = AirQuality(device = device)
		dataset_obj.init_visualization()
		dataset_obj.init_dataset()
		dataset, time_steps_extrap = dataset_obj.get_batches(tp_per_batch = n_total_tp, n_samples = args.n)
		if args.dim_only is not None:
			dataset = dataset[:,:,:args.dim_only]
		# here time_steps_extrap is 1 2d tensor with the list of time stamps for each trajectory
	else:
		# Sampling args.timepoints time points in the interval [0, args.max_t]
		# Sample points for both training sequence and explapolation (test)
		distribution = uniform.Uniform(torch.Tensor([0.0]),torch.Tensor([max_t_extrap]))
		time_steps_extrap =  distribution.sample(torch.Size([n_total_tp-1]))[:,0]
		time_steps_extrap = torch.cat((torch.Tensor([0.0]), time_steps_extrap))
		time_steps_extrap = torch.sort(time_steps_extrap)[0]

		dataset_obj = None
		###################################################################
		# Sample from ODE A(t-B) + w z; z~Wiener process
		if dataset_name == "parabola":
			dataset_obj = SDESamples(method = args.method, device = device, noise_weight = args.noise_weight, y0 = 0.)

		###################################################################
		# The dataset of trajectories splitted two-ways
		if dataset_name == "bimodal":
			split_point, union_point = bimodal.get_split_points([0., args.max_t])
			dataset_obj = BimodalSigmoid(split_point, union_point, device = device, y0 = 0.) # good noise_weight: 0.1
		##################################################################
		# Sample a periodic function
		if dataset_name == "periodic":
			# Increasing amplitude and frequency
			# dataset_obj = Periodic_1d(
			# 	init_freq = 0.5, init_amplitude = 1.,
			# 	final_amplitude = 10., final_freq = 1., 
			# 	noise_generator = GaussianProcess(1., "WienerProcess"))

			# Increasing amplitude, constant frequency
			# dataset_obj = Periodic_1d(
			# 	init_freq = 0.5, init_amplitude = 1.,
			# 	final_amplitude = 2., final_freq = 0.5, 
			# 	noise_generator = GaussianProcess(1., "WienerProcess"))

			dataset_obj = Periodic_1d(
				init_freq = None, init_amplitude = 1.,
				final_amplitude = 1., final_freq = None, 
				noise_generator = None, y0 = 1.)

		##################################################################
		# Sample a periodic function with two possible initial points

		if dataset_name == "periodic2s":
			dataset_obj = Periodic_2StartingPoints(
				init_freq = None, init_amplitude = 1.,
				final_amplitude = 1., final_freq = None, 
				noise_generator = None, y0 = 1.)

		##################################################################
		# Sample a 2d data where each dimension is a periodic function

		if dataset_name == "periodic2d":
			dataset_obj = Periodic2d(
				init_freq = None, init_amplitude = 1.,
				final_amplitude = 1., final_freq = None, 
				noise_generator = None, y0 = np.array([1.,1.]))

		##################################################################
		# Sample data going Up, Straight or Down with a periodic perturbation

		if dataset_name == "periodic3w": # stands for 3-way
			dataset_obj = PeriodicUpStraightDown(
				freq = None, amplitude = 1.,
				noise_generator = None, y0 = 1.)

		##################################################################
		# Sample zigzag dataset

		if dataset_name == "zigzag":
			dataset_obj = Zigzag(
				freq = None, amplitude = None,
				noise_generator = None, y0 = 1.)
		##################################################################

		if dataset_obj is None:
			raise Exception("Unknown dataset: {}".format(dataset_name))

		dataset = dataset_obj.sample_traj(time_steps_extrap, n_samples = args.n, 
			noise_weight = args.noise_weight)

		# cut the time dimension [:,:,0]
		# shape: [n_samples, n_tp, n_dim]
		dataset = dataset[:,:,1:]

	# Process small datasets
	dataset = dataset.to(device)
	time_steps_extrap = time_steps_extrap.to(device)

	train_time_steps = test_time_steps = time_steps_extrap
	train_y, test_y = utils.split_train_test(dataset, train_fraq = 0.8)

	data_dict = {"dataset_obj": dataset_obj,
				"train_y": train_y, 
				"test_y": test_y, 
				"train_time_steps": train_time_steps, 
				"test_time_steps": test_time_steps}

	return data_dict

