import os
import matplotlib
import matplotlib.pyplot
import matplotlib.pyplot as plt

import os
import time
import argparse
import numpy as np
from sklearn import model_selection
from scipy.misc import imread
import pandas

import torch
import torch.nn as nn
from torch.nn.functional import relu
import torch.optim as optim
from torch.nn.utils.spectral_norm import spectral_norm

import lib.utils as utils
#from lib.plotting import visualize
from lib.diffeq_solver import DiffeqSolver

from torch.distributions.normal import Normal

from lib.gaussian_process import GaussianProcess, test_gaussian_process, test_multidim_sampling
#from lib.odes import euler_solver as odeint
# git clone https://github.com/rtqichen/torchdiffeq.git
from torchdiffeq.torchdiffeq import odeint

class SDESamples():
	"""Sample noisy 1-d trajectories by integrating a randomly forced ODE.

	The right-hand side (see ``true_ode_func``) is
	``true_A * (t - true_B) + noise_weight * noise`` with ``true_A = 10`` and
	``true_B = 0.5``; the noise is drawn from
	``GaussianProcess(1., "WienerProcess")``.
	"""

	def __init__(self, method = "dopri5", device = torch.device("cpu"), noise_weight = 0., y0 = 0.):
		"""
		Args:
			method: integration method name passed to ``odeint``.
			device: torch device for every tensor created by this object.
			noise_weight: default weight of the noise term (``self.gp_weight``).
			y0: initial value, stored as ``torch.Tensor([y0])``.
		"""
		# This class takes no `type` argument; the previous code stored the
		# builtin `type` here by copy-paste. Keep the attribute, but make it inert.
		self.type = None

		self.true_A = torch.Tensor([10.]).to(device)
		self.true_B = torch.Tensor([0.5]).to(device)

		self.noise = GaussianProcess(1., "WienerProcess") # WhiteNoise(1.)
		self.gp_weight = torch.Tensor([float(noise_weight)]).to(device)

		self.y0 = torch.Tensor([y0]).to(device)

		self.device = device
		self.method = method

	def true_ode_func(self, t, y, context = None, time_steps = None, n_samples = 1):
		"""Evaluate the randomly forced ODE right-hand side at time ``t``.

		Args:
			t: scalar time tensor supplied by the ODE solver.
			y: current state; only its shape is used (one noise value per element).
			context: optional dict. If given, it must contain "noise_weight", which
				overrides ``self.gp_weight`` (e.g. 0 to draw the noise-less model).
			time_steps, n_samples: unused; kept for interface compatibility.

		Returns:
			Tensor shaped like ``y``:
			``true_A * (t - true_B) + gp_weight * noise``.
		"""
		# NOTE(review): the model was once described as x' = A x + B (see the
		# commented-out variant below), but the implemented drift is A * (t - B)
		# and does not depend on y — confirm which is intended.
		# A call to test_multidim_sampling() used to run here on every solver
		# step; it is a self-test, not part of the model, so it was removed.

		# detach/cpu so the time can be handed to the numpy-based noise sampler
		# even when the solver runs on a GPU.
		t_np = t.detach().cpu().numpy()
		noise = self.noise.sample_multidim([t_np], dim = np.prod(y.size()))
		noise = np.reshape(noise, y.size())

		gp_weight = self.gp_weight
		if context is not None:
			# Lets the caller override the noise weight, e.g. to draw the noise-less model.
			assert "noise_weight" in context, "context must provide 'noise_weight'"
			gp_weight = context["noise_weight"]

		#ode_f = self.true_A * y + self.true_B + self.gp_weight * torch.Tensor(noise)
		return self.true_A * (t - self.true_B) + gp_weight * torch.Tensor(noise).to(self.device)

	def sample_traj(self, time_steps, n_samples = 1, noise_weight = 1., y0_noise_std = 0.1):
		"""Integrate ``n_samples`` noisy trajectories over ``time_steps``.

		Args:
			time_steps: 1-d sequence/tensor of evaluation times.
			n_samples: number of trajectories.
			noise_weight: weight of the noise term in the ODE.
			y0_noise_std: std of the Gaussian perturbation added to each initial value.

		Returns:
			Tensor [n_samples, n_tp, n_dim + 1] (same layout as Periodic_1d):
			``[:, :, 0]`` holds the time stamps and the rest holds the solution.
		"""
		y0 = self.y0
		if len(y0.size()) < 2:
			# One initial state per sample: [n_samples, n_dim]. torch.repeat (not
			# np.tile) keeps the tensor on its device.
			y0 = y0.repeat(n_samples, 1)

		y0_noise = Normal(torch.Tensor([0.]).to(self.device), torch.Tensor([y0_noise_std]).to(self.device)).sample((n_samples,))
		# Out-of-place: when self.y0 is already 2-d, `y0` aliases it, and `+=`
		# would permanently perturb the stored initial value.
		y0 = y0 + y0_noise

		func = lambda t_local, y: self.true_ode_func(t_local, y, context = {"noise_weight": noise_weight})
		try:
			sol = odeint(func, y0, time_steps,
				rtol = 1e-4, atol = 1e-6, method = self.method)
		finally:
			# Reset the noise process even if integration fails (presumably this
			# discards the path sampled for this call — confirm in GaussianProcess).
			self.noise.clear()
		sol = sol.permute(1, 0, 2)

		ts = torch.as_tensor(time_steps, dtype = torch.float32, device = self.device)
		ts = ts.repeat(n_samples, 1).unsqueeze(2)
		return torch.cat((ts, sol), 2)

class Spirals():
	"""Spiral trajectories produced by ``utils.make_dataset``, plus plotting helpers."""

	def __init__(self, type = "chiralspiral", device = torch.device("cpu")):
		"""
		Args:
			type: dataset name passed to ``utils.make_dataset``: "spiral" for
				counter-clockwise spirals only, "chiralspiral" for spirals in
				both directions.
			device: torch device the returned tensors are moved to.
		"""
		self.type = type
		self.device = device

	def get_dataset(self, n_timepoints, n_examples = 200):
		"""Build the spiral dataset.

		Args:
			n_timepoints: number of leading time points kept per spiral.
			n_examples: number of leading spirals kept.

		Returns:
			[spirals, chirality]: ``spirals`` is a float tensor of the first
			``n_examples`` spirals, cut to ``n_timepoints`` and with feature
			channel 0 (time) removed; ``chirality`` holds the matching labels.
		"""
		print("Reading " + self.type + " dataset...")
		spirals, chir = utils.make_dataset(dataset_type=self.type)

		# Keep only the first n_examples spirals and their labels.
		spirals = spirals[:n_examples]
		chir = chir[:n_examples]
		# Cut each series to n_timepoints and drop feature channel 0 (time).
		spirals = spirals[:, :n_timepoints, 1:]

		spirals_torch = torch.from_numpy(spirals).float().to(self.device)
		chir_torch = torch.from_numpy(chir).to(self.device)
		return [spirals_torch, chir_torch]

	def init_visualization(self):
		"""Create a non-blocking 2x4 figure; the first four axes are used by ``visualize``."""
		self.fig = plt.figure(figsize=(16, 8), facecolor='white')
		self.ax_traj = self.fig.add_subplot(241, frameon=False)
		self.ax_phase = self.fig.add_subplot(242, frameon=False)
		self.ax_vecfield = self.fig.add_subplot(243, frameon=False)
		self.ax_latents_pca = self.fig.add_subplot(244, frameon=False)
		self.ax_extrap_loss = self.fig.add_subplot(245, frameon=False)
		self.ax_new_seq_loss = self.fig.add_subplot(246, frameon=False)
		self.ax_new_plot = self.fig.add_subplot(247, frameon=False)
		self.ax_gp_samples = self.fig.add_subplot(248, frameon=False)

		plt.show(block=False)

	def visualize(self,truth, prediction, t, ode_grad_func = None,
			itr = None, ode_args = None,
			encodings = None, chirality = None):
		"""Delegate plotting to ``lib.plotting.visualize`` on the first four axes."""
		# The module-level `from lib.plotting import visualize` is commented out,
		# so the bare name `visualize` was undefined here (NameError). Importing
		# lazily also sidesteps any import cycle with lib.plotting.
		from lib.plotting import visualize as plot_visualize

		plot_visualize(self.fig,
			[self.ax_traj, self.ax_phase, self.ax_vecfield, self.ax_latents_pca],
			truth, prediction, t, ode_grad_func = ode_grad_func,
			itr = itr, ode_args = ode_args,
			encodings = encodings,
			chirality = chirality)


class NavierStokes():
	"""Navier-Stokes simulation frames loaded from a pickle, one flattened frame per time step."""

	def __init__(self, device = torch.device("cpu")):
		self.device = device

	def get_dataset(self, n_timepoints, downscale = 1):
		"""Load the frames as a single trajectory.

		Args:
			n_timepoints: maximum number of frames to keep.
			downscale: stride applied to both spatial axes.

		Returns:
			[data, None]: ``data`` is a float tensor [1, n_frames, n_pixels] on
			``self.device``; the second element is a placeholder for labels.
		"""
		self.n_timepoints = n_timepoints

		# Assumes the pickle holds an array of 2-d frames [n_frames, rows, cols]
		# (the slicing below needs at least 3 dims). The last frame is dropped.
		data = utils.load_pickle("navier_stokes_dataset.pickle")[:-1]
		# Keep at most n_timepoints frames (this argument used to be stored but
		# ignored) and subsample both spatial axes.
		data = data[:n_timepoints, ::downscale, ::downscale]
		# Flatten each frame and add a small positive offset. Out-of-place, so the
		# loaded array is never modified and integer data does not fail to cast.
		data = np.reshape(data, (data.shape[0], -1)) + 0.01

		# Leading batch axis: a single trajectory.
		data = np.expand_dims(np.asarray(data), 0)
		return [torch.from_numpy(data).float().to(self.device), None]

	def init_visualization(self):
		self.fig = plt.figure(figsize=(8,4), facecolor='white')
		self.ax_pred = self.fig.add_subplot(121, frameon=False)
		self.ax_truth = self.fig.add_subplot(122, frameon=False)

	def visualize(self, truth, prediction, time, ode_grad_func = None,
		itr = None, ode_args = None,
		encodings = None, chirality = None):
		"""Show prediction vs. truth frame by frame and save each frame to disk.

		Args:
			truth, prediction: tensors [T, D] for one example, where T == len(time)
				and D is a flattened square image (D == dim * dim).
			time: time points; only their count is used.
			ode_grad_func, itr, ode_args, encodings, chirality: unused; they keep
				the signature in line with ``Spirals.visualize``.

		Side effects:
			Writes prediction.png / truth.png (last frame) to the working
			directory, and navier_stokes_png/NNN for every frame (in matplotlib's
			default savefig format).
		"""
		assert prediction.size()[0] == len(time)
		assert truth.size() == prediction.size()

		# Frames are stored flattened; assume they are square.
		dim = int(np.sqrt(prediction.size()[1]))
		# detach/cpu so tensors that require grad or live on a GPU can be plotted.
		prediction = prediction.detach().cpu().reshape(-1, dim, dim).numpy()
		truth = truth.detach().cpu().reshape(-1, dim, dim).numpy()

		matplotlib.image.imsave('prediction.png', prediction[-1])
		matplotlib.image.imsave('truth.png', truth[-1])

		out_dir = 'navier_stokes_png'
		# savefig fails if the target directory does not exist.
		os.makedirs(out_dir, exist_ok = True)

		for t in range(len(time)):
			# Clear both panels: plt.cla() only cleared the current axes, so
			# images kept piling up on the other one.
			self.ax_pred.cla()
			self.ax_truth.cla()

			self.fig.suptitle('Time step ' + str(t))

			self.ax_pred.imshow(prediction[t], vmin=0)
			self.ax_truth.imshow(truth[t], vmin=0)

			self.ax_pred.set_title('Prediction')
			self.ax_truth.set_title('Truth')

			for ax in (self.ax_pred, self.ax_truth):
				ax.set_xticks([])
				ax.set_yticks([])

			self.fig.savefig(os.path.join(out_dir, '{:03d}'.format(t)))
			plt.draw()
			plt.pause(0.001)

		# Make a movie: ffmpeg  -r 10 -i %03d.png -vb 20M  output.webm


class Household():
	"""Household power-consumption table (';'-separated text file) with batching and plotting helpers."""

	def __init__(self, device = torch.device("cpu")):
		self.device = device

	def init_dataset(self, filename = "data/household_power_consumption.txt"):
		"""Load and clean the household power-consumption table.

		The first row of the file is treated as a header and dropped, and the
		first two columns (date and time) are removed. '?' marks missing values.

		Returns:
			float64 array [n_rows, n_cols - 2], also stored in ``self.dataset``.
			NaN entries are set to 0, column 2 (voltage) is divided by its
			maximum, and only rows whose column 1 is observed (not -1) are kept.
		"""
		household = pandas.read_csv(filename, sep = ';', header = None)

		# Drop the header row and mark missing data ('?') with the sentinel -1.
		household = household.iloc[1:].replace('?', -1)

		# Drop the date and time columns and convert the rest to floats.
		# (np.float was removed in NumPy 1.24; use an explicit float64.)
		household = np.array(household.iloc[:, 2:], dtype = np.float64)

		# Replace NaN first: np.max over a column holding a single NaN is NaN,
		# which used to wipe out the whole voltage column.
		household[np.isnan(household)] = 0.
		# Normalize the third column (voltage) by its maximum.
		household[:, 2] = household[:, 2] / np.max(household[:, 2])

		# Time stamps are not added here; the data is regular apart from missing
		# rows, so equally spaced time points could be assigned if needed.

		# Keep only rows whose column 1 is observed.
		observed_rows = household[:, 1] != -1
		household = household[observed_rows]

		self.dataset = household
		return household

	def get_batches(self, tp_per_batch = None, n_samples = None):
		"""Split ``self.dataset`` into consecutive chunks of ``tp_per_batch`` rows.

		Trailing rows that do not fill a whole chunk are dropped. Without
		``tp_per_batch``, the whole dataset becomes a single batch. ``n_samples``
		keeps only the first batches.

		Returns:
			float tensor [n_batches, tp_per_batch, n_dims] on ``self.device``.
		"""
		if tp_per_batch is not None:
			n_batches = self.dataset.shape[0] // tp_per_batch

			n_dims = self.dataset.shape[1]
			batches = np.reshape(self.dataset[:(n_batches * tp_per_batch)], (n_batches, tp_per_batch, n_dims))
		else:
			# Add extra dimension for the batch
			batches = np.expand_dims(self.dataset,0)

		if n_samples is not None:
			batches = batches[:n_samples]
		return torch.Tensor(batches).to(self.device)

	def init_visualization(self):
		self.fig = plt.figure(figsize=(8,4), facecolor='white')
		self.ax = self.fig.add_subplot(111, frameon=False)

	def visualize(self, dataset, dim = 6, max_obs = 10000):
		"""Plot feature ``dim`` of batch 0 over its first ``max_obs`` steps, skipping -1 entries.

		Args:
			dataset: array/tensor [n_batches, n_timepoints, n_dims], e.g. from ``get_batches``.
		"""
		dataset = dataset[:, :max_obs]

		batch_id = 0
		time_steps = np.arange(dataset.shape[1])
		non_missing = (dataset[batch_id, :, dim] != -1)

		self.ax.plot(time_steps[non_missing], dataset[batch_id, non_missing, dim])
		plt.draw()

			 


class AirQuality():
	"""Air-quality table (';'-separated CSV) with batching and plotting helpers."""

	def __init__(self, device = torch.device("cpu")):
		self.device = device

	def init_dataset(self, filename = "data/AirQualityUCI.csv"):
		"""Load, normalize and clean the air-quality table.

		Steps:
		- drop the header row;
		- drop the first two columns (date, time) and the last two columns;
		- z-score every remaining column using only values != -200 (the
		  missing-value marker);
		- prepend a time-stamp column i / n_rows * 1000;
		- delete the column at index 3;
		- drop every row that still contains -200.

		Returns:
			float64 array, also stored in ``self.dataset``; column 0 holds the
			time stamps.
		"""
		dataset = pandas.read_csv(filename, sep = ';', header = None)

		# Drop the header row (column names).
		dataset = dataset.iloc[1:]
		# Drop the date and time columns, and the last two columns (presumably
		# empty fields from a trailing ';;' on each line — TODO confirm).
		dataset = dataset.iloc[:, 2:-2]

		# np.float was removed in NumPy 1.24; use an explicit float64.
		dataset = np.array(dataset, dtype = np.float64)

		# Missing data is denoted as -200.
		missing = -200

		# Normalize all columns, ignoring missing values both when computing the
		# statistics and when rescaling.
		for i in range(dataset.shape[1]):
			col = dataset[:, i]  # a view: writes go straight into `dataset`
			observed = col != missing
			if not observed.any():
				# Nothing to normalize (avoids mean/std of an empty slice).
				continue
			values = col[observed]
			mean = np.mean(values)
			std = np.std(values)
			if std == 0:
				# Constant column: only center it instead of dividing by zero.
				std = 1.
			col[observed] = (values - mean) / std

		# Add the column with time stamps. The data is regular apart from missing
		# rows, so assign equally spaced time points. float32 matches the previous
		# torch-based computation.
		n_rows = dataset.shape[0]
		time_steps = np.arange(n_rows, dtype = np.float32) / float(n_rows) * 1000
		dataset = np.concatenate((time_steps[:, np.newaxis], dataset), axis = 1)

		# Remove the fourth column because of missing values (column 0 now holds
		# the time stamps).
		ind_to_delete = 3
		dataset = np.concatenate((dataset[:, :ind_to_delete], dataset[:, (ind_to_delete+1):]), axis = 1)

		# Drop every row that still has a missing value in any column.
		incomplete_rows = np.any(dataset == missing, axis = 1)
		dataset = dataset[~incomplete_rows]

		self.dataset = dataset
		return dataset

	def get_batches(self, tp_per_batch = None, n_samples = None):
		"""Split ``self.dataset`` into chunks and separate out the time column.

		Trailing rows that do not fill a whole chunk are dropped. Without
		``tp_per_batch``, the whole dataset becomes a single batch. ``n_samples``
		keeps only the first batches.

		Returns:
			(batches, time_steps):
			- ``batches`` is a float tensor [n_batches, tp_per_batch, n_dims - 1];
			- ``time_steps`` is [tp_per_batch, n_batches], with each column shifted
			  to start at 0.
		"""
		if tp_per_batch is not None:
			n_batches = self.dataset.shape[0] // tp_per_batch

			n_dims = self.dataset.shape[1]
			batches = np.reshape(self.dataset[:(n_batches * tp_per_batch)], (n_batches, tp_per_batch, n_dims))
		else:
			# Add extra dimension for the batch
			batches = np.expand_dims(self.dataset,0)

		if n_samples is not None:
			batches = batches[:n_samples]

		batches = torch.Tensor(batches).to(self.device)

		# shape: [n_timepoints, n_traj]
		time_steps = torch.t(batches[:,:,0])
		# Make all sets of time points start from zero
		time_steps = time_steps - time_steps[0]

		batches = batches[:,:,1:]
		return batches, time_steps

	def init_visualization(self):
		self.fig = plt.figure(figsize=(8,4), facecolor='white')
		self.ax = self.fig.add_subplot(111, frameon=False)

	def visualize(self, batches, time_steps, dim = 6, batch_id = 0):
		"""Plot feature ``dim`` of up to three batches, starting at ``batch_id``, against their time steps."""
		one_batch = batches[batch_id:(batch_id+3)]
		time_steps_one = time_steps[:,batch_id:(batch_id+3)]

		for i in range(one_batch.shape[0]):
			self.ax.plot(time_steps_one[:,i].numpy(), one_batch[i,:,dim].numpy())
		plt.draw()

			 
