import os
import matplotlib
import matplotlib.pyplot
import matplotlib.pyplot as plt

import lib.utils as utils
import numpy as np
import tarfile
import torch
from torch.utils.data import DataLoader
from torchvision.datasets.utils import download_url
from lib.utils import get_device

class PhysioNet(object):
	"""Dataset wrapper around the PhysioNet Challenge 2012 archives (set-a / set-b).

	``download()`` fetches the raw archives, parses every patient file into
	tensors and caches the result under ``<root>/PhysioNet/processed``.
	The constructor then loads the cached split. Each item of the dataset is a
	tuple ``(record_id, tt, vals, mask, glob, labels)``:

		- record_id: file name of the patient record without extension
		- tt: 1-D tensor of (quantized) observation times in hours
		- vals: (T, D) tensor of values for the non-global parameters
		- mask: (T, D) tensor, 1 where a value was observed and 0 otherwise
		- glob: values of ``global_params`` taken from the first time step
		- labels: outcome tensor for the record, or None when no outcome is known
	"""

	# Archives holding one text file per patient.
	urls = [
		'https://physionet.org/physiobank/database/challenge/2012/set-a.tar.gz',
		'https://physionet.org/physiobank/database/challenge/2012/set-b.tar.gz',
	]

	# Outcome tables; only set-a has one listed here.
	outcome_urls = ['https://physionet.org/physiobank/database/challenge/2012/Outcomes-a.txt']

	# Every parameter name that may appear in a patient file; the position in
	# this list is the column index used while parsing.
	params = [
		'Age', 'Gender', 'Height', 'ICUType', 'Weight', 'Albumin', 'ALP', 'ALT', 'AST', 'Bilirubin', 'BUN',
		'Cholesterol', 'Creatinine', 'DiasABP', 'FiO2', 'GCS', 'Glucose', 'HCO3', 'HCT', 'HR', 'K', 'Lactate', 'Mg',
		'MAP', 'MechVent', 'Na', 'NIDiasABP', 'NIMAP', 'NISysABP', 'PaCO2', 'PaO2', 'pH', 'Platelets', 'RespRate',
		'SaO2', 'SysABP', 'Temp', 'TroponinI', 'TroponinT', 'Urine', 'WBC'
	]

	# Parameters split off into the per-patient ``glob`` vector.
	global_params = ['Age', 'Gender', 'Height', 'ICUType']

	params_dict = {k: i for i, k in enumerate(params)}

	# Column names of the outcome table (after the record id).
	# NOTE: instances shadow this attribute with the loaded outcome data in __init__.
	labels = [ "SAPS-I", "SOFA", "Length_of_stay", "Survival", "In-hospital_death" ]
	labels_dict = {k: i for i, k in enumerate(labels)}

	def __init__(self, root, train=True, download=False,
		quantization = 0.1, n_samples = None, device = torch.device("cpu")):
		"""Load one split of the processed dataset.

		Args:
			root: directory under which ``PhysioNet/raw`` and ``PhysioNet/processed`` live.
			train: load set-a when True, set-b otherwise.
			download: download and process the data first if it is not cached yet.
			quantization: time step (in hours) the time stamps are rounded to; it is
				also part of the processed file name.
			n_samples: keep only the first ``n_samples`` patients when not None.
			device: when this is the CPU the files are loaded with
				``map_location='cpu'``; otherwise tensors stay on the device
				they were saved from.

		Raises:
			RuntimeError: if the processed files for this quantization are missing.
		"""
		self.root = root
		self.train = train
		self.reduce = "average"
		self.quantization = quantization

		if download:
			self.download()

		if not self._check_exists():
			raise RuntimeError('Dataset not found. You can use download=True to download it')

		data_file = self.training_file if self.train else self.test_file
		data_path = os.path.join(self.processed_folder, data_file)
		label_path = os.path.join(self.processed_folder, self.label_file)

		# Accept torch.device objects as well as plain strings such as "cpu".
		if device is not None and torch.device(device).type == "cpu":
			self.data = torch.load(data_path, map_location='cpu')
			self.labels = torch.load(label_path, map_location='cpu')
		else:
			self.data = torch.load(data_path)
			self.labels = torch.load(label_path)

		if n_samples is not None:
			self.data = self.data[:n_samples]
			if isinstance(self.labels, dict):
				# Outcomes are keyed by record id: keep those of the retained patients.
				kept_ids = {patient[0] for patient in self.data}
				self.labels = {
					record_id: label for record_id, label in self.labels.items()
					if record_id in kept_ids
				}
			else:
				# Label file written by an older version of download() (not a dict).
				self.labels = self.labels[:n_samples]

	def download(self):
		"""Download the raw data and write the processed ``.pt`` files.

		Does nothing when the processed files for the current quantization
		already exist. Tensors are created on ``cuda:0`` when CUDA is available
		and on the CPU otherwise, independent of the constructor's ``device``.
		"""
		if self._check_exists():
			return

		self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

		os.makedirs(self.raw_folder, exist_ok=True)
		os.makedirs(self.processed_folder, exist_ok=True)

		# Download outcome data
		outcomes = {}
		for url in self.outcome_urls:
			filename = url.rpartition('/')[2]
			download_url(url, self.raw_folder, filename, None)

			with open(os.path.join(self.raw_folder, filename)) as f:
				file_outcomes = self._parse_outcomes(f.readlines(), self.device)

			# Save the whole record_id -> labels mapping so that get_label()
			# can look records up by id.
			torch.save(
				file_outcomes,
				os.path.join(self.processed_folder, filename.split('.')[0] + '.pt')
			)
			outcomes.update(file_outcomes)

		# Columns holding variables that don't change like Age, Gender, etc.
		global_idx = [self.params_dict[param] for param in self.global_params]
		remaining_idx = [i for i in range(len(self.params)) if i not in global_idx]

		for url in self.urls:
			filename = url.rpartition('/')[2]
			download_url(url, self.raw_folder, filename, None)
			# NOTE(security): extractall() trusts the member paths stored in the
			# downloaded archive -- only use it with archives from a trusted source.
			with tarfile.open(os.path.join(self.raw_folder, filename), "r:gz") as tar:
				tar.extractall(self.raw_folder)

			print('Processing {}...'.format(filename))

			dirname = os.path.join(self.raw_folder, filename.split('.')[0])
			patients = []
			total = 0
			# os.listdir() returns entries in arbitrary order; sort so that the
			# stored patient order (and any n_samples prefix) is reproducible.
			for txtfile in sorted(os.listdir(dirname)):
				record_id = txtfile.split('.')[0]
				with open(os.path.join(dirname, txtfile)) as f:
					lines = f.readlines()

				total += len(lines[1:])
				tt, vals, mask = self._parse_record(lines, self.device)

				# Only training set has labels; None for every other record
				labels = outcomes.get(record_id)

				glob = vals[0, global_idx]
				vals = vals[:, remaining_idx]
				mask = mask[:, remaining_idx]

				patients.append((record_id, tt, vals, mask, glob, labels))

			print("Original number of points")
			print(total)

			torch.save(
				patients,
				os.path.join(self.processed_folder,
					filename.split('.')[0] + "_" + str(self.quantization) + '.pt')
			)

		print('Done!')

	@staticmethod
	def _parse_outcomes(lines, device):
		"""Turn the lines of an outcome table into a {record_id: label tensor} dict.

		The first line is treated as the header and skipped, as are blank lines.
		Every other line is ``record_id`` followed by comma-separated numbers.
		"""
		outcomes = {}
		for line in lines[1:]:
			if not line.strip():
				continue
			fields = line.rstrip().split(',')
			values = np.array(fields[1:]).astype(float)
			outcomes[fields[0]] = torch.Tensor(values).to(device)
		return outcomes

	def _parse_record(self, lines, device):
		"""Parse the lines of one patient file into ``(tt, vals, mask)``.

		The first line is treated as the header and skipped; every other line is
		``HH:MM,parameter,value``. Returns the 1-D tensor of quantized times and
		two (T, len(params)) tensors holding the values and the observation
		mask. Several values of one parameter that fall into the same time step
		are averaged when ``self.reduce == 'average'``; otherwise the last one wins.
		"""
		n_params = len(self.params)
		prev_time = 0
		tt = [0.]
		vals = [torch.zeros(n_params).to(device)]
		mask = [torch.zeros(n_params).to(device)]
		nobs = [torch.zeros(n_params).to(device)]
		for line in lines[1:]:
			time, param, val = line.split(',')
			# Time in hours
			time_parts = time.split(':')
			time = float(time_parts[0]) + float(time_parts[1]) / 60.
			# round up the time stamps (up to 6 min by default)
			# used for speed -- we actually don't need to quantize it in Latent ODE
			time = round(time / self.quantization) * self.quantization

			if time != prev_time:
				# A new time step starts: open a fresh row
				tt.append(time)
				vals.append(torch.zeros(n_params).to(device))
				mask.append(torch.zeros(n_params).to(device))
				nobs.append(torch.zeros(n_params).to(device))
				prev_time = time

			if param in self.params_dict:
				col = self.params_dict[param]
				n_observations = nobs[-1][col]
				if self.reduce == 'average' and n_observations > 0:
					# Running mean over the values seen so far in this time step
					prev_val = vals[-1][col]
					vals[-1][col] = (prev_val * n_observations + float(val)) / (n_observations + 1)
				else:
					vals[-1][col] = float(val)
				mask[-1][col] = 1
				nobs[-1][col] += 1
			else:
				assert param == 'RecordID', 'Read unexpected param {}'.format(param)

		tt = torch.tensor(tt).to(device)
		return tt, torch.stack(vals), torch.stack(mask)

	def _check_exists(self):
		"""Return True when the processed file of every archive in ``urls`` exists."""
		for url in self.urls:
			filename = url.rpartition('/')[2]

			if not os.path.exists(
				os.path.join(self.processed_folder,
					filename.split('.')[0] + "_" + str(self.quantization) + '.pt')
			):
				return False
		return True

	@property
	def raw_folder(self):
		"""Directory the downloaded archives are stored and extracted in."""
		return os.path.join(self.root, self.__class__.__name__, 'raw')

	@property
	def processed_folder(self):
		"""Directory holding the processed ``.pt`` files."""
		return os.path.join(self.root, self.__class__.__name__, 'processed')

	@property
	def training_file(self):
		"""File name of the processed set-a data for the current quantization."""
		return 'set-a_{}.pt'.format(self.quantization)

	@property
	def test_file(self):
		"""File name of the processed set-b data for the current quantization."""
		return 'set-b_{}.pt'.format(self.quantization)

	@property
	def label_file(self):
		"""File name of the processed outcome data."""
		return 'Outcomes-a.pt'

	def __getitem__(self, index):
		return self.data[index]

	def __len__(self):
		return len(self.data)

	def get_label(self, record_id):
		"""Return the entry stored in the loaded label data under ``record_id``."""
		return self.labels[record_id]

	def __repr__(self):
		fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
		fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
		fmt_str += '    Split: {}\n'.format('train' if self.train is True else 'test')
		fmt_str += '    Root Location: {}\n'.format(self.root)
		fmt_str += '    Quantization: {}\n'.format(self.quantization)
		fmt_str += '    Reduce: {}\n'.format(self.reduce)
		return fmt_str

	def visualize(self, timesteps, data, mask, plot_name):
		"""Plot every attribute with more than two observations and save the figure.

		Args:
			timesteps: 1-D tensor of observation times.
			data: (T, D) tensor of values.
			mask: (T, D) tensor, 1 where ``data`` holds an observed value.
			plot_name: path the figure is saved to.
		"""
		width = 15
		height = 15

		# .numpy() is only available for CPU tensors
		timesteps = timesteps.cpu()
		data = data.cpu()
		mask = mask.cpu()

		# Items produced by download() have the global parameters removed, so
		# their columns must be named after the remaining parameters only.
		if data.shape[1] == len(self.params) - len(self.global_params):
			param_names = [p for p in self.params if p not in self.global_params]
		else:
			param_names = self.params

		non_zero_attributes = (torch.sum(mask, 0) > 2).numpy()
		non_zero_idx = [i for i, keep in enumerate(non_zero_attributes) if keep]
		n_non_zero = len(non_zero_idx)

		# Index with an explicit long tensor so that an empty selection also works
		columns = torch.tensor(non_zero_idx, dtype=torch.long)
		mask = mask[:, columns]
		data = data[:, columns]

		params_non_zero = [param_names[i] for i in non_zero_idx]

		n_col = 3
		# At least one row, so that a figure is still produced when nothing qualifies
		n_row = max(1, n_non_zero // n_col + (n_non_zero % n_col > 0))
		# squeeze=False keeps ax_list two-dimensional even for a single row
		fig, ax_list = plt.subplots(n_row, n_col, figsize=(width, height), facecolor='white', squeeze=False)

		for i, param in enumerate(params_non_zero):
			observed = mask[:, i].long() == 1

			tp_cur_param = timesteps[observed]
			data_cur_param = data[observed, i]

			ax = ax_list[i // n_col, i % n_col]
			ax.plot(tp_cur_param.numpy(), data_cur_param.numpy(), marker='o')
			ax.set_title(param)

		fig.tight_layout()
		fig.savefig(plot_name)
		plt.close(fig)


def variable_time_collate_fn(batch, device = torch.device("cpu")):
	"""
	Merge patients observed at different times into dense tensors on a shared time grid.

	Expects a batch of time series data in the form of (record_id, tt, vals, mask, globs, labels) where
		- record_id is a patient id
		- tt is a 1-dimensional tensor containing T time values of observations.
		- vals is a (T, D) tensor containing observed values for D variables.
		- mask is a (T, D) tensor containing 1 where values were observed and 0 otherwise.
		- globs is a tensor containing time-invariant features of each patient, like age, gender, etc.
		- labels is a tensor of labels for the current patient, if labels are available. Otherwise None.
	Returns a tuple of tensors on `device` (M patients, T = number of distinct times in the batch):
		combined_tt: The sorted union of all time observations, divided by its maximum
			unless that maximum is 0.
		combined_vals: (M, T, D) tensor containing the observed values, passed through
			utils.normalize_data.
		combined_mask: (M, T, D) tensor containing 1 where values were observed and 0 otherwise.
		globs: (M, G) tensor of the stacked time-invariant features, passed through
			utils.normalize_data.
		labels: (M,) tensor holding label column 4 of every patient
			(index 4 is "In-hospital_death" in PhysioNet.labels); NaN where a patient has no labels.
	"""
	D = batch[0][2].shape[1]
	combined_tt, inverse_indices = torch.unique(torch.cat([ex[1] for ex in batch]), sorted=True, return_inverse=True)
	combined_tt = combined_tt.to(device)
	# The indices are used to write into tensors living on `device`, so they
	# have to be moved there as well (the inputs may come from another device).
	inverse_indices = inverse_indices.to(device)

	offset = 0
	combined_vals = torch.zeros([len(batch), len(combined_tt), D]).to(device)
	combined_mask = torch.zeros([len(batch), len(combined_tt), D]).to(device)

	# figure out how many labels there are
	# recall that some patients just have None -- labels not available
	# (the labels are the element at index 5 of each tuple)
	label_sizes = [len(ex[5]) for ex in batch if ex[5] is not None]
	# Hopefully, size of labels in all patients is the same.
	# Fall back to 5 label columns when nobody in the batch has labels.
	N_labels = max(label_sizes) if label_sizes else 5

	# NaN marks "no label available" for a patient
	combined_labels = torch.full((len(batch), N_labels), float('nan'), device=device)
	globs = []

	for b, (record_id, tt, vals, mask, glob, labels) in enumerate(batch):
		# Position of each of this patient's time stamps on the shared time grid
		indices = inverse_indices[offset:offset + len(tt)]
		offset += len(tt)

		combined_vals[b, indices] = vals.to(device)
		combined_mask[b, indices] = mask.to(device)

		if labels is not None:
			combined_labels[b] = labels.to(device)

		globs.append(glob.to(device))

	combined_vals, _, _ = utils.normalize_data(combined_vals)
	max_time = torch.max(combined_tt)
	if max_time != 0.:
		combined_tt = combined_tt / max_time

	globs = torch.stack(globs)
	globs, _, _ = utils.normalize_data(globs)

	# Use only last label -- mortality in the hospital
	return combined_tt, combined_vals, combined_mask, globs, combined_labels[:,4]


if __name__ == '__main__':
	# Smoke test: download/process the data if needed and print one collated batch.
	torch.manual_seed(1991)

	dataset = PhysioNet('data/physionet', train=False, download=True)
	dataloader = DataLoader(dataset, batch_size=10, shuffle=True, collate_fn=variable_time_collate_fn)
	# Python 3 iterators have no .next() method; use the next() builtin instead.
	print(next(iter(dataloader)))
