# Create a synthetic dataset
from __future__ import absolute_import, division
from __future__ import print_function
import os
import matplotlib
import numpy as np
import numpy.random as npr
from scipy.special import expit as sigmoid
import pickle
import matplotlib.pyplot as plt
import matplotlib.image
import torch
from lib.gaussian_process import GaussianProcess

# ======================================================================================

def get_next_val(init, t, tmin, tmax, final = None):
	"""
	Linearly interpolate a parameter from `init` at time `tmin` to `final` at time `tmax`.

	Args:
		init: value of the parameter at t = tmin.
		t: time at which to evaluate the parameter.
		tmin, tmax: start and end of the time interval.
		final: value at t = tmax. If None, the parameter is constant and `init` is returned.

	Returns:
		The interpolated value at time `t` (linearly extrapolated outside [tmin, tmax]).
	"""
	if final is None:
		return init
	# A degenerate interval (e.g. a single time point) has no slope: stay at the initial value.
	if tmax == tmin:
		return init
	# Measure elapsed time from tmin so the value equals `init` at t = tmin
	# even when the time grid does not start at 0.
	return init + (final - init) / (tmax - tmin) * (t - tmin)


def generate_periodic(time_steps, init_freq, init_amplitude, starting_point,
	final_freq = None, final_amplitude = None, phi_offset = 0.):
	"""
	Build a sine trajectory whose frequency and amplitude may drift over time.

	Frequency and amplitude at each time point come from get_next_val. The phase
	starts at `phi_offset` and is accumulated point by point. Each step adds
	2*pi * (frequency at the current point) * (time since the previous point),
	which is a right-endpoint Riemann sum, so non-uniform grids are handled.

	Args:
		time_steps: 1-d sequence of time points (numpy array or torch tensor).
		init_freq, final_freq: frequency endpoints passed to get_next_val
			(final_freq=None keeps the frequency constant).
		init_amplitude, final_amplitude: amplitude endpoints passed to get_next_val
			(final_amplitude=None keeps the amplitude constant).
		starting_point: vertical offset added to every value.
		phi_offset: phase at the first time point.

	Returns:
		np.ndarray with one row [t, value] per time point.
	"""
	t_first, t_last = time_steps.min(), time_steps.max()

	rows = []
	phase = phi_offset
	previous_t = time_steps[0]
	for current_t in time_steps:
		amplitude = get_next_val(init_amplitude, current_t, t_first, t_last, final_amplitude)
		frequency = get_next_val(init_freq, current_t, t_first, t_last, final_freq)
		step = current_t - previous_t
		# Integrate the instantaneous frequency to obtain the phase.
		phase = phase + 2 * np.pi * frequency * step
		rows.append([current_t, amplitude * np.sin(phase) + starting_point])
		previous_t = current_t
	return np.array(rows)

def assign_value_or_sample(value, sampling_interval = (0., 1.)):
	"""
	Return `value`, or a uniform random draw from `sampling_interval` if it is None.

	Args:
		value: fixed value to use, or None to request a random sample.
		sampling_interval: (low, high) pair; a sample lies in [low, high).
			An immutable tuple default avoids the shared-mutable-default pitfall.

	Returns:
		`value` unchanged when it is not None (including falsy values such as 0),
		otherwise low + U * (high - low) with U drawn by np.random.random().
	"""
	if value is not None:
		return value
	low, high = sampling_interval[0], sampling_interval[1]
	return np.random.random() * (high - low) + low

class TimeSeries:
	"""
	Base class for the synthetic time-series generators in this file.

	Holds the noise generator and the torch device, and provides simple plotting
	helpers plus `add_noise`, which subclasses apply to their sampled trajectories.
	"""
	def __init__(self, 
		noise_generator = GaussianProcess(1., "WienerProcess"), 
		device = torch.device("cpu")):
		"""
		Args:
			noise_generator: object providing `sample_multidim` and `clear`
				(both used by `add_noise`). Subclasses skip adding noise when it is None.
				NOTE(review): the default GaussianProcess is built once, when the
				class is defined, and is shared by every instance that relies on the
				default. `add_noise` calls `clear()` on it after each use.
			device: torch device on which tensors are created.
		"""
		self.noise_generator = noise_generator
		self.device = device
		self.y0 = None

	def init_visualization(self):
		"""Create a white 10x4 figure with one frameless axis and show it without blocking."""
		self.fig = plt.figure(figsize=(10, 4), facecolor='white')
		self.ax = self.fig.add_subplot(111, frameon=False)
		plt.show(block=False)

	def visualize(self, truth):
		"""Plot truth[:,1] against truth[:,0] on the axis created by `init_visualization`."""
		self.ax.plot(truth[:,0], truth[:,1])

	def add_noise(self, traj_list, time_steps, noise_weight):
		"""
		Return a copy of `traj_list` with noise added to its value channel.

		Args:
			traj_list: trajectories indexed as [sample, time point, channel], where
				channel 0 holds time stamps and channel 1 holds values. May be a torch
				tensor or a numpy array. A numpy array is first converted to a float
				tensor on `self.device`.
			time_steps: torch tensor of the time points shared by all trajectories
				(`.numpy()` is called on it, so it must live on the CPU).
			noise_weight: scale factor applied to the sampled noise.

		Returns:
			A new torch tensor of the same shape. The time channel and the first
			time point of every trajectory are left unchanged.
		"""
		# Several subclasses assemble their trajectories with numpy. `.size(0)` and
		# `.clone()` below need a tensor, so convert here instead of crashing.
		if not torch.is_tensor(traj_list):
			traj_list = torch.Tensor().new_tensor(np.asarray(traj_list), device = self.device)

		n_samples = traj_list.size(0)

		# Add noise to all the points except the first point.
		# The transposed sample must broadcast to [n_samples, n_timesteps - 1].
		noise = self.noise_generator.sample_multidim(time_steps.numpy()[1:], dim = n_samples)
		noise = torch.Tensor(np.transpose(noise)).to(self.device)

		traj_list_w_noise = traj_list.clone()
		# Dimension [:,:,0] is a time dimension -- do not add noise to that
		traj_list_w_noise[:,1:,1] += noise_weight * noise
		# NOTE(review): presumably resets the generator's cached state -- confirm in lib.gaussian_process.
		self.noise_generator.clear()
		return traj_list_w_noise


class Periodic_1d(TimeSeries):
	"""1-d sine trajectories whose frequency and amplitude can drift over time."""
	def __init__(self, device = torch.device("cpu"),
		init_freq = 0.5, init_amplitude = 1.,
		final_amplitude = 10., final_freq = 1.,
		noise_generator = GaussianProcess(1., "WienerProcess"),
		y0 = 0.):
		"""
		Any of init_freq, init_amplitude, final_amplitude, final_freq passed as None
		is drawn at random for every trajectory. All trajectories share the time
		points and start near y0 (each gets its own small Gaussian jitter).
		"""
		super(Periodic_1d, self).__init__(noise_generator, device)

		self.init_freq, self.final_freq = init_freq, final_freq
		self.init_amplitude, self.final_amplitude = init_amplitude, final_amplitude
		self.y0 = y0

	def sample_traj(self, time_steps, n_samples = 1, noise_weight = 1.):
		"""
		Draw `n_samples` independent periodic trajectories.

		Returns a float tensor on self.device of shape [n_samples, n_timesteps, 2].
		[:,:,0] holds the time stamps and [:,:,1] the values. Noise is added when
		self.noise_generator is set.
		"""
		def draw_one():
			# Random draws happen in this fixed order: frequencies, amplitudes, y0 jitter.
			freq_start = assign_value_or_sample(self.init_freq, [0.5,1.])
			freq_end = assign_value_or_sample(self.final_freq, [0.5,1.])
			amp_start = assign_value_or_sample(self.init_amplitude, [0.,1.])
			amp_end = assign_value_or_sample(self.final_amplitude, [0.,1.])
			start_value = self.y0 + np.random.normal(loc=0., scale=0.1)
			return generate_periodic(time_steps, init_freq = freq_start,
				init_amplitude = amp_start, starting_point = start_value,
				final_amplitude = amp_end, final_freq = freq_end)

		samples = np.array([draw_one() for _ in range(n_samples)])
		samples = torch.Tensor().new_tensor(samples, device = self.device)

		if self.noise_generator is None:
			return samples
		return self.add_noise(samples, time_steps, noise_weight)


class BimodalSigmoid(TimeSeries):
	"""
	Two families of sigmoid trajectories: a "top" one centred at split_point and
	a "bottom" one centred at union_point.
	"""
	def __init__(self,
		split_point,
		union_point,
		amplitude = 1.,
		sigmoid_steepness = 10.,
		noise_generator = GaussianProcess(1., "WienerProcess"),
		device = torch.device("cpu"),
		y0 = 0.,):
		"""
		Args:
			split_point: centre of the "top" sigmoid.
			union_point: centre of the "bottom" sigmoid. With split_point < union_point
				and a positive steepness, the two families coincide before split_point,
				separate between the two points and rejoin after union_point.
			amplitude: height of both sigmoids.
			sigmoid_steepness: multiplier applied to (t - centre) inside the sigmoid.
			noise_generator: noise source used when adding noise; None disables noise.
			device: torch device on which tensors are created.
			y0: stored on the instance; not used by sample_traj.
		"""
		super(BimodalSigmoid, self).__init__(noise_generator, device)

		self.split_point = split_point
		self.union_point = union_point
		self.amplitude = amplitude
		self.sigmoid_steepness = sigmoid_steepness
		self.y0 = y0

	@staticmethod
	def get_split_points(time_steps_inteval, split_point_fr = 0.3, union_point_fr = 0.7):
		"""
		Place the split and union points at fixed fractions of a time interval.

		Args:
			time_steps_inteval: (start, end) pair of the time interval.
			split_point_fr: fraction of the interval at which the split point lies.
			union_point_fr: fraction of the interval at which the union point lies.

		Returns:
			(split_point, union_point)
		"""
		interval = time_steps_inteval[1] - time_steps_inteval[0]
		# Point where two trajectories split (roughly)
		split_point = time_steps_inteval[0] + split_point_fr * interval
		# Point where two trajectories reunite (roughly)
		union_point = time_steps_inteval[0] + union_point_fr * interval
		return split_point, union_point

	def sample_traj(self, time_steps, n_samples = 1,  noise_weight = 1.):
		"""
		Sample `n_samples` trajectories, half "top" and half "bottom" sigmoids.

		For an odd `n_samples` the extra trajectory is a "top" one. This covers
		n_samples == 1 and fixes the size mismatch that odd counts >= 3 used to hit.
		The result is shuffled so the two families are not interleaved in order.

		Args:
			time_steps: 1-d torch tensor of time points shared by all trajectories.
			n_samples: number of trajectories to return.
			noise_weight: scale of the noise added when self.noise_generator is set.

		Returns:
			float tensor on self.device of shape [n_samples, n_timesteps, 2].
			[:,:,0] holds the time stamps and [:,:,1] the values.
		"""
		# Both curves are deterministic, so compute them once.
		traj_top = self.amplitude * torch.sigmoid((time_steps - self.split_point) * self.sigmoid_steepness)
		traj_bottom = self.amplitude * torch.sigmoid((time_steps - self.union_point) * self.sigmoid_steepness)

		traj_list = []
		if n_samples % 2 == 1:
			traj_list.append(traj_top.unsqueeze(1))
		for i in range(n_samples // 2):
			traj_list.append(traj_top.unsqueeze(1))
			traj_list.append(traj_bottom.unsqueeze(1))

		# shape: [n_samples, n_timesteps, 2]
		# traj_list[:,:,0] -- time stamps
		# traj_list[:,:,1] -- values at the time stamps
		traj_list = torch.stack(traj_list)
		traj_list = torch.cat((time_steps.view(1,-1,1).repeat(n_samples,1,1), traj_list),2)

		if self.noise_generator is not None:
			traj_list = self.add_noise(traj_list, time_steps, noise_weight)

		# shuffle list for that the "top" and "bottom" trajectories don't go in order
		r = torch.randperm(traj_list.size(0))
		traj_list = traj_list[r]

		return torch.Tensor().new_tensor(traj_list, device = self.device)


class Periodic_2StartingPoints(TimeSeries):
	"""
	Periodic trajectories generated in groups that share frequency, amplitude and
	starting point but differ in phase. With the offsets pi/2 and 3*pi/2,
	sin(pi/2) = 1 and sin(3*pi/2) = -1, so the two members of a group start on
	opposite sides of the shared starting point.
	"""
	def __init__(self, device = torch.device("cpu"), 
		init_freq = 0.5, init_amplitude = 1.,
		final_amplitude = 10., final_freq = 1., 
		noise_generator = GaussianProcess(1., "WienerProcess"),
		y0 = 0.):
		"""
		Any of init_freq, init_amplitude, final_amplitude, final_freq passed as None
		is drawn at random once per group. All trajectories share the time points.
		Each group starts near y0 (with its own small Gaussian jitter).
		"""
		super(Periodic_2StartingPoints, self).__init__(noise_generator, device)

		self.init_freq = init_freq
		self.init_amplitude = init_amplitude
		self.final_amplitude = final_amplitude
		self.final_freq = final_freq
		self.phi_offsets = [1/2 * np.pi, 3/2 * np.pi]
		self.y0 = y0

	def sample_traj(self, time_steps, n_samples = 1, noise_weight = 1.):
		"""
		Sample exactly `n_samples` periodic trajectories.

		Trajectories are produced in groups of len(self.phi_offsets). Members of a
		group share the sampled parameters and differ only in phase. The last group
		is cut short when n_samples is not a multiple of the group size.

		Args:
			time_steps: torch tensor of time points shared by all trajectories.
			n_samples: number of trajectories to return.
			noise_weight: scale of the noise added when self.noise_generator is set.

		Returns:
			float tensor on self.device of shape [n_samples, n_timesteps, 2].
			[:,:,0] holds the time stamps and [:,:,1] the values.
		"""
		n_offsets = len(self.phi_offsets)
		# Ceil division: the fewest groups that yield at least n_samples trajectories.
		n_groups = -(-n_samples // n_offsets)

		traj_list = []
		for i in range(n_groups):
			init_freq = assign_value_or_sample(self.init_freq, [0.5,1.])
			final_freq = assign_value_or_sample(self.final_freq, [0.5,1.])
			init_amplitude = assign_value_or_sample(self.init_amplitude, [0.,1.])
			final_amplitude = assign_value_or_sample(self.final_amplitude, [0.,1.])

			noisy_y0 = self.y0 + np.random.normal(loc=0., scale=0.1)

			# Create periodic trajectories with different phases phi
			for offset in self.phi_offsets:
				traj = generate_periodic(time_steps, init_freq = init_freq, 
					init_amplitude = init_amplitude, starting_point = noisy_y0, 
					final_amplitude = final_amplitude, final_freq = final_freq,
					phi_offset = offset)
				traj_list.append(traj)

		# Drop the surplus of the last group so exactly n_samples are returned.
		# shape: [n_samples, n_timesteps, 2]
		traj_list = np.array(traj_list[:n_samples])
		# add_noise needs a tensor (.size(0) / .clone()), so convert before it.
		traj_list = torch.Tensor().new_tensor(traj_list, device = self.device)
		if self.noise_generator is not None:
			traj_list = self.add_noise(traj_list, time_steps, noise_weight)
		return traj_list

class Periodic2d(TimeSeries):
	"""
	Multi-dimensional periodic trajectories. Each dimension is an independent 1-d
	periodic trajectory starting near the matching entry of y0.
	"""
	def __init__(self, device = torch.device("cpu"), 
		init_freq = 0.5, init_amplitude = 1.,
		final_amplitude = 10., final_freq = 1., 
		phi_offset = None,
		noise_generator = GaussianProcess(1., "WienerProcess"),
		y0 = np.array([0.,0.])):
		"""
		Any of init_freq, init_amplitude, final_amplitude, final_freq, phi_offset
		passed as None is drawn at random for every sample and every dimension.
		phi_offset is drawn from [0, 2*pi). All trajectories share the time points.
		len(y0) sets the number of output dimensions.
		"""
		super(Periodic2d, self).__init__(noise_generator, device)

		self.init_freq = init_freq
		self.init_amplitude = init_amplitude
		self.final_amplitude = final_amplitude
		self.final_freq = final_freq
		self.phi_offset = phi_offset
		self.y0 = y0

	def sample_traj(self, time_steps, n_samples = 1, noise_weight = 1.):
		"""
		Sample `n_samples` multi-dimensional periodic trajectories.

		Args:
			time_steps: torch tensor of time points shared by all trajectories.
			n_samples: number of trajectories to return.
			noise_weight: scale of the noise added (per dimension) when
				self.noise_generator is set.

		Returns:
			float tensor on self.device of shape [n_samples, n_timesteps, 1 + len(self.y0)].
			[:,:,0] holds the time stamps and [:,:,1:] the values of each dimension.
		"""
		n_dims = len(self.y0)

		# Rows are appended sample-major: (s0,d0), (s0,d1), ..., (s1,d0), ...
		traj_list = []
		for i in range(n_samples):
			for y0_one_dim in self.y0:
				init_freq = assign_value_or_sample(self.init_freq, [0.5,1.])
				final_freq = assign_value_or_sample(self.final_freq, [0.5,1.])
				init_amplitude = assign_value_or_sample(self.init_amplitude, [0.,1.])
				final_amplitude = assign_value_or_sample(self.final_amplitude, [0.,1.])
				phi_offset =  assign_value_or_sample(self.phi_offset, [0., 2*np.pi])

				noisy_y0 = y0_one_dim + np.random.normal(loc=0., scale=0.1)

				traj = generate_periodic(time_steps, init_freq = init_freq, 
					init_amplitude = init_amplitude, starting_point = noisy_y0, 
					final_amplitude = final_amplitude, final_freq = final_freq,
					phi_offset = phi_offset)
				traj_list.append(traj)

		# shape: [n_samples * n_dims, n_timesteps, 2]
		# add_noise needs a tensor (.size(0) / .clone()), so convert before it.
		traj_list = torch.Tensor().new_tensor(np.array(traj_list), device = self.device)
		if self.noise_generator is not None:
			traj_list = self.add_noise(traj_list, time_steps, noise_weight)

		n_tp = traj_list.size(1)
		# All rows share the time points: keep one row per sample.
		# shape: [n_samples, n_timesteps, 1]
		timestamps = traj_list[::n_dims, :, :1]
		# Group rows by sample (matching the append order above), then move the
		# dimension axis last. shape: [n_samples, n_timesteps, n_dims]
		values = traj_list[:, :, 1].reshape(n_samples, n_dims, n_tp).permute(0, 2, 1)
		# shape: [n_samples, n_timesteps, 1 + n_dims]
		return torch.cat((timestamps, values), 2)




class PeriodicUpStraightDown(TimeSeries):
	"""Constant-frequency sine trajectories superimposed on linear trends."""
	def __init__(self, device = torch.device("cpu"), 
		freq = 0.5, amplitude = 1.,
		noise_generator = GaussianProcess(1., "WienerProcess"),
		y0 = 0.):
		"""
		If freq or amplitude is None it is drawn at random for every trajectory
		(freq from [1, 1.5), amplitude from [0, 1)). All trajectories share the
		time points and start near y0 (each gets its own small Gaussian jitter).
		"""
		super(PeriodicUpStraightDown, self).__init__(noise_generator, device)

		self.freq = freq
		self.amplitude = amplitude
		self.y0 = y0

	def sample_traj(self, time_steps, n_samples = 1, noise_weight = 1., slopes = (-2., 0., 2.)):
		"""
		Sample exactly `n_samples` periodic trajectories with linear trends.

		The trend slope cycles through `slopes` in order. With the default that is
		down, flat, up, down, ...

		Args:
			time_steps: 1-d torch tensor of time points (`.numpy()` is called on it).
			n_samples: number of trajectories to return.
			noise_weight: scale of the noise added when self.noise_generator is set.
			slopes: non-empty sequence of trend slopes to cycle through.

		Returns:
			float tensor on self.device of shape [n_samples, n_timesteps, 2].
			[:,:,0] holds the time stamps and [:,:,1] the values.

		Raises:
			ValueError: if `slopes` is empty.
		"""
		slopes = list(slopes)
		if not slopes:
			raise ValueError("slopes must contain at least one value")
		time_values = time_steps.numpy()

		traj_list = []
		for i in range(n_samples):
			sl = slopes[i % len(slopes)]
			noisy_y0 = self.y0 + np.random.normal(loc=0., scale=0.1)

			freq = assign_value_or_sample(self.freq, [1., 1.5])
			amplitude = assign_value_or_sample(self.amplitude, [0.,1.])

			# Constant frequency and amplitude (init == final).
			traj = generate_periodic(
				time_steps, 
				init_freq = freq, 
				init_amplitude = amplitude, 
				final_freq = freq,
				final_amplitude = amplitude,
				starting_point = 0.01)

			# Shift to the noisy starting point and add the linear trend.
			traj[:,1] = traj[:,1] + noisy_y0 + sl * time_values
			traj_list.append(traj)

		# shape: [n_samples, n_timesteps, 2]
		# add_noise needs a tensor (.size(0) / .clone()), so convert before it.
		traj_list = torch.Tensor().new_tensor(np.array(traj_list), device = self.device)
		if self.noise_generator is not None:
			traj_list = self.add_noise(traj_list, time_steps, noise_weight)
		return traj_list


def generate_zigzag(time_steps, freq, amplitude, starting_point = 0.):
	"""
	Evaluate a triangle (zigzag) wave at the given time points.

	The wave rises linearly from -amplitude at t = 0 to +amplitude at t = period/2,
	then falls back to -amplitude at t = period, and repeats. Each time point is
	evaluated independently. Unsorted, negative, or half-period-aligned times
	therefore keep their correct values and stay aligned with their timestamps.

	Args:
		time_steps: 1-d numpy array or (CPU) torch tensor of time points.
		freq: wave frequency (period = 1 / freq).
		amplitude: peak absolute value of the wave.
		starting_point: vertical offset added to every value.

	Returns:
		np.ndarray of shape [len(time_steps), 2]: column 0 time, column 1 value.
	"""
	t = np.asarray(time_steps)
	period = 1. / freq
	half_period = period / 2
	slope = amplitude / (period / 4)

	# Index of the half-period each point falls in. Even = rising, odd = falling.
	segment = np.floor(t / half_period)
	delta_t = t - segment * half_period
	rising = (segment % 2 == 0)
	values = np.where(rising, -amplitude + slope * delta_t, amplitude - slope * delta_t)
	return np.stack((t, values + starting_point), axis=-1)


class Zigzag(TimeSeries):
	"""Triangle-wave (zigzag) trajectories shifted to a jittered starting point."""
	def __init__(self, device = torch.device("cpu"), 
		freq = 0.5, amplitude = 1.,
		noise_generator = GaussianProcess(1., "WienerProcess"),
		y0 = 0.):
		"""
		If freq or amplitude is None it is drawn at random for every trajectory
		(freq from [1, 1.5), amplitude from [0, 1)). All trajectories share the
		time points and are shifted by y0 plus a small Gaussian jitter.
		"""
		super(Zigzag, self).__init__(noise_generator, device)

		self.freq = freq
		self.amplitude = amplitude
		self.y0 = y0

	def sample_traj(self, time_steps, n_samples = 1, noise_weight = 1., slopes = (-2.,0.,2.)):
		"""
		Sample exactly `n_samples` zigzag trajectories.

		Args:
			time_steps: 1-d torch tensor of time points shared by all trajectories.
			n_samples: number of trajectories to return.
			noise_weight: scale of the noise added when self.noise_generator is set.
			slopes: accepted for signature compatibility; it does not affect the
				generated trajectories.

		Returns:
			float tensor on self.device of shape [n_samples, n_timesteps, 2].
			[:,:,0] holds the time stamps and [:,:,1] the values.
		"""
		traj_list = []
		for _ in range(n_samples):
			noisy_y0 = self.y0 + np.random.normal(loc=0., scale=0.1)

			freq = assign_value_or_sample(self.freq, [1., 1.5])
			amplitude = assign_value_or_sample(self.amplitude, [0.,1.])

			traj = generate_zigzag(time_steps, freq, amplitude)
			traj[:,1] = traj[:,1] + noisy_y0
			traj_list.append(traj)

		# shape: [n_samples, n_timesteps, 2]
		# traj_list[:,:,0] -- time stamps
		# traj_list[:,:,1] -- values at the time stamps
		traj_list = np.array(traj_list)
		traj_list = torch.Tensor().new_tensor(traj_list, device = self.device)

		if self.noise_generator is not None:
			traj_list = self.add_noise(traj_list, time_steps, noise_weight)
		return traj_list



if __name__ == "__main__":
	# Quick visual check: draw the first sample of a noiseless zigzag dataset.
	demo_dataset = Zigzag(freq = 2, amplitude = 2.)
	demo_times = torch.Tensor(np.linspace(0, 10))
	demo_traj = demo_dataset.sample_traj(demo_times, n_samples = 10, noise_weight = 0.)

	demo_dataset.init_visualization()
	demo_dataset.visualize(demo_traj[0].numpy())

	# Keep the non-blocking figure window alive long enough to inspect it.
	plt.pause(9000)
