from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import os
import numpy as np
sys.path.append('.')
sys.path.append('..')
from methods.lips_bound import lips_bound_evaluation, estimate_rpi
import gym
import tensorflow as tf

class Linear_Policy_Pendulum(object):
    """Gaussian policy whose mean is linear in a transformed pendulum state.

    ``theta`` is a flat parameter vector laid out as
    ``[w (state_dim * action_dim entries), b (action_dim entries), ..., logvar]``;
    the last entry is the log-variance of the isotropic action noise.
    """
    state_dim = 3    # width of the transformed state fed to the linear map
    obs_dim = 3      # width of one raw observation row
    action_dim = 1

    def __init__(self, theta):
        n_weights = self.state_dim * self.action_dim
        self.w = theta[:n_weights].reshape(self.state_dim, self.action_dim)    # state_dim * action_dim
        self.b = theta[n_weights: n_weights + self.action_dim]                 # action_dim
        self.logvar = theta[-1]

    def get_states(self, states):
        """Map raw observations to the features used by the linear policy.

        Column 0 of the observation is passed to ``arccos`` and column 1
        provides the sign, i.e. rows are treated as ``(cos, sin, last)``
        (presumably gym's Pendulum observation -- verify against the env).
        The output rows are ``(sin, angle, last)`` and the leading
        dimensions of ``states`` are preserved.
        """
        out_shape = list(states.shape)
        out_shape[-1] = self.state_dim
        obs = states.reshape(-1, self.obs_dim)
        cos_col = obs[:, 0]
        sin_col = obs[:, 1]

        # np.sign(0) == 0 would collapse the angle to 0 when sin == 0, which
        # is wrong for cos == -1 (angle pi); use a sign that is never zero.
        direction = np.where(sin_col < 0, -1.0, 1.0)
        # Clip so that round-off pushing the cosine just outside [-1, 1]
        # cannot turn the angle into NaN.
        angle = np.arccos(np.clip(cos_col, -1.0, 1.0)) * direction

        new_states = np.zeros([obs.shape[0], self.state_dim])
        new_states[:, 0] = sin_col
        new_states[:, 1] = angle
        new_states[:, -1] = obs[:, -1]
        return new_states.reshape(out_shape)

    def get_mean(self, states):
        """Return the policy mean ``get_states(states) @ w + b``."""
        features = self.get_states(states)
        return np.matmul(features, self.w) + self.b

    def choose_action(self, state):
        """Sample one action (shape ``(action_dim,)``) for a single observation."""
        mean = self.get_mean(state[None, :]).reshape(self.action_dim)
        noise = np.exp(self.logvar / 2) * np.random.randn(self.action_dim)
        return mean + noise    # action_dim

    def choose_actions(self, states):
        """Sample one action per observation row; result is N * action_dim."""
        states = states.reshape(-1, self.obs_dim)
        n = states.shape[0]
        noise = np.exp(self.logvar / 2) * np.random.randn(n, self.action_dim)
        return self.get_mean(states) + noise    # N * action_dim

    def logpis(self, states, actions):
        """Log-density of ``actions`` under the policy, up to an additive constant.

        Computes ``-0.5 * logvar - 0.5 * ||mean - a||^2 / exp(logvar)``; the
        ``-0.5 * log(2 * pi)`` term is not included.
        """
        diff = self.get_mean(states) - actions    # N * action_dim
        return -0.5 * self.logvar - 0.5 * np.sum(diff * diff, axis=-1) / np.exp(self.logvar)    # N

def Get_Pendulum_s0_data(num_trajectory, seedID):
    """Collect ``num_trajectory`` initial observations from the pendulum env.

    The unwrapped 'Pendulum-v0' environment is configured with dt = 0.2,
    max_torque = 40 and max_speed = 200; both the environment and numpy's
    global RNG are seeded with ``seedID``. Returns the stacked results of
    ``env.reset()`` as one numpy array.
    """
    pendulum = gym.make('Pendulum-v0').unwrapped
    pendulum.dt = 0.2
    pendulum.max_torque = 40.
    pendulum.max_speed = 200.
    pendulum.seed(seedID)
    np.random.seed(seedID)

    # One reset per requested trajectory, in order.
    return np.array([pendulum.reset() for _ in range(num_trajectory)])

def Get_Pendulum_transition_data(num_trajectory, truncate_size, policy, seedID, render = False):
    """Roll out ``policy`` on the pendulum env and record the transitions.

    Args:
        num_trajectory: number of episodes to simulate.
        truncate_size: number of recorded steps per episode.
        policy: object exposing ``state_dim``, ``action_dim`` and
            ``choose_action(state)``.
        seedID: seed applied to both the environment and numpy's global RNG.
        render: when true, ``env.render()`` is called before every step.

    Returns:
        ``[S, A, SN, REW * 100]`` where S / SN have shape
        (num_trajectory, truncate_size, state_dim), A has shape
        (num_trajectory, truncate_size, action_dim) and the rewards
        (num_trajectory, truncate_size) are scaled by 100.
    """
    env = gym.make('Pendulum-v0').unwrapped
    env.dt = 0.2
    env.max_torque = 40.
    env.max_speed = 200.
    env.seed(seedID)
    np.random.seed(seedID)

    initial_length = 0    # leading steps of each episode that are simulated but not recorded
    state_dim = policy.state_dim
    action_dim = policy.action_dim
    S = np.zeros([num_trajectory, truncate_size, state_dim])
    SN = np.zeros([num_trajectory, truncate_size, state_dim])
    A = np.zeros([num_trajectory, truncate_size, action_dim])
    REW = np.zeros([num_trajectory, truncate_size])
    try:
        for i_episode in range(num_trajectory):
            state = env.reset()
            for i_iteration in range(truncate_size + initial_length):
                if render:
                    env.render()
                action = policy.choose_action(state).reshape(action_dim)
                # The done flag is ignored: episodes always run for the full length.
                next_state, reward, _, _ = env.step(action)
                if i_iteration >= initial_length:
                    slot = i_iteration - initial_length
                    S[i_episode, slot, :] = state
                    A[i_episode, slot, :] = action
                    SN[i_episode, slot, :] = next_state
                    REW[i_episode, slot] = reward
                state = next_state
    finally:
        # Release the environment (and the viewer window opened by render())
        # even if a rollout step raises.
        env.close()
    return [S, A, SN, REW*100]

class feature_q_learning_fitted(object):
    """Fitted Q-learning model (TensorFlow 1.x graph mode) exposing its features.

    Two copies of the same network are built: 'current_Q' (trainable) and
    'target_Q' (non-trainable, refreshed from 'current_Q' by
    ``update_target_op``). The Q-value is linear in ``features``, which is
    the concatenation of a learned ReLU layer and the raw state-action input.
    """

    def __init__(self, state_action_dim, feature_dim, hidden_dim = 100, Learning_rate = 1e-3, reg_weight = 1e-2, seedID = 43):
        """Build the graph, create a session and initialize all variables.

        Args:
            state_action_dim: width of the concatenated (state, action) input.
            feature_dim: width of the learned feature layer.
            hidden_dim: width of the first hidden layer.
            Learning_rate: Adam step size.
            reg_weight: multiplier of the collected regularization loss.
            seedID: graph-level TensorFlow random seed.
        """
        self.state_action_dim = state_action_dim
        self.feature_dim = feature_dim
        self.hidden_dim = hidden_dim

        tf.set_random_seed(seedID)
        # Input
        self.xs = tf.placeholder(tf.float32, [None, state_action_dim])
        self.BQs = tf.placeholder(tf.float32, [None])

        # Calculation
        self.features, self.Qs , self.lips_constant, params = self._build_q_network('current_Q', True)
        _, self.target_Qs, _, target_params = self._build_q_network('target_Q', False)

        # Loss and operation
        self.update_target_op = [target.assign(current) for current, target in zip(params, target_params)]
        self.loss = tf.reduce_mean(tf.square(self.Qs - self.BQs))
        self.reg_loss = tf.reduce_sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES, 'current_Q'))
        self.train_op = tf.train.AdamOptimizer(Learning_rate).minimize(self.loss + reg_weight * self.reg_loss)

        # sess and cpu
        self.sess = tf.Session()
        # Build the initializer once so that reset() does not add a new op to
        # the graph on every call.
        self.init_op = tf.global_variables_initializer()
        self.sess.run(self.init_op)
        self.saver = tf.train.Saver()

    def reset(self):
        """Re-run the variable initializer in this model's session."""
        self.sess.run(self.init_op)

    def save_model(self, filename = './pendulum_domain/model/feature_q_fitted.ckpt'):
        """Write the session's variables to the checkpoint ``filename``."""
        self.saver.save(self.sess, filename)

    def load_model(self, filename = './pendulum_domain/model/feature_q_fitted.ckpt'):
        """Restore the session's variables from the checkpoint ``filename``."""
        self.saver.restore(self.sess, filename)

    # restore for pendulum!!
    def _build_q_network(self, name, trainable):
        """Build one Q-network under variable scope ``name``.

        Returns ``(features, Qs, lips_constant, params)`` where ``params`` are
        the global variables found under the scope and ``lips_constant`` is
        the Euclidean norm of the output weights ``W2``.
        """
        with tf.variable_scope(name, reuse = tf.AUTO_REUSE):
            l1 = tf.layers.dense(self.xs, self.hidden_dim, tf.nn.relu, trainable = trainable)
            learned = tf.layers.dense(l1, self.feature_dim, tf.nn.relu, trainable = trainable)
            features = tf.concat([learned, self.xs], -1)
            W2 = tf.get_variable('W2', initializer = tf.zeros(shape = [self.state_action_dim + self.feature_dim, 1]), regularizer = tf.contrib.layers.l2_regularizer(0.), trainable = trainable)
            b2 = tf.get_variable('b2', initializer = tf.zeros([1]), regularizer = tf.contrib.layers.l2_regularizer(0.), trainable = trainable)
            lips_constant = tf.sqrt(tf.reduce_sum(W2 * W2))
            # Squeeze only the trailing unit axis: an unrestricted squeeze
            # would also drop the batch axis for a single-row batch and
            # return a scalar instead of a length-1 vector.
            Qs = tf.squeeze(tf.matmul(features, W2) + b2, axis = [-1])
        params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope = name)
        return features, Qs, lips_constant, params

    def get_features(self, S, A):
        """Evaluate the 'current_Q' feature vector for each (state, action) row."""
        return self.sess.run(self.features, feed_dict = {
            self.xs: np.hstack([S,A])
        })

    def get_lips_constant(self):
        """Return the current value of ``sqrt(sum(W2 ** 2))`` for 'current_Q'."""
        return self.sess.run(self.lips_constant)

    def get_Q_value(self, S, A):
        """Evaluate the 'current_Q' network for each (state, action) row."""
        return self.sess.run(self.Qs, feed_dict = {
            self.xs: np.hstack([S,A])
        })

    def get_bellman_residual(self, policy_target, xs, r, sn, gamma, repeat = 10):
        """Mean squared Bellman residual of 'current_Q' on a batch.

        The next-state value is a Monte-Carlo average over ``repeat`` actions
        sampled from ``policy_target.choose_actions(sn)``. Both sides of the
        residual are evaluated with the current (not the target) network.
        """
        Qs = self.sess.run(self.Qs, feed_dict = {
            self.xs: xs
        })
        BQs = np.zeros_like(r)
        for _ in range(repeat):
            an = policy_target.choose_actions(sn).reshape(-1, policy_target.action_dim)
            BQs += self.sess.run(self.Qs, feed_dict = {
                self.xs: np.hstack([sn, an])
            })
        BQs = gamma * BQs / repeat + r
        return np.mean(np.square(Qs - BQs))

    def get_bellman_residual_discrete(self, pi, xs, r, sn, gamma, repeat = 10):
        """Mean squared residual for one-hot discrete actions.

        ``pi[:, i]`` weights the target-network value of one-hot action ``i``
        at each next state. ``repeat`` is accepted for signature parity with
        ``get_bellman_residual`` and is not used.
        """
        Qs = self.sess.run(self.Qs, feed_dict = {
            self.xs: xs
        })
        BQs = np.zeros_like(r)
        for i in range(pi.shape[-1]):
            an = np.zeros([r.shape[0], pi.shape[-1]])
            an[:,i] = 1
            BQs_i = self.sess.run(self.target_Qs, feed_dict = {
                self.xs: np.hstack([sn, an])
            })
            BQs += pi[:,i] * BQs_i
        BQs = gamma * BQs + r
        return np.mean(np.square(Qs - BQs))

    def train(self, replay_buffer, policy_target, batch_size, repeat, gamma, fitted_iter = 2000, max_iter = 40000):
        """Fitted Q-iteration with continuous actions sampled from ``policy_target``.

        Args:
            replay_buffer: ``(S, A, SN, REW)`` arrays with matching first axis.
            policy_target: object exposing ``choose_actions(states)``.
            batch_size: minibatch size.
            repeat: number of sampled next actions averaged per target.
            gamma: discount factor.
            fitted_iter: the target network is refreshed after every
                iteration whose index is a multiple of this value.
            max_iter: total number of gradient steps.

        Returns:
            List of the minibatch losses recorded every 500 iterations.
        """
        S, A, SN, REW = replay_buffer
        N = S.shape[0]
        action_dim = A.shape[-1]

        perm = np.random.permutation(N)
        j = 0
        LOSS = []

        for step in range(max_iter):
            # Reshuffle once the current permutation cannot supply a full batch.
            if j + batch_size > N:
                perm = np.random.permutation(N)
                j = 0
            subsamples = perm[j:j+batch_size]
            xs = np.hstack([S[subsamples], A[subsamples]])
            r = REW[subsamples]
            sn = SN[subsamples]
            BQs = np.zeros_like(r)
            for i_repeat in range(repeat):
                an = policy_target.choose_actions(sn).reshape(-1, action_dim)
                BQs += self.sess.run(self.target_Qs, feed_dict = {
                    self.xs: np.hstack([sn, an])
                })
            BQs = gamma * BQs / repeat + r
            loss, reg_loss, _ = self.sess.run([self.loss, self.reg_loss, self.train_op], feed_dict = {
                self.xs: xs,
                self.BQs: BQs
            })
            j += batch_size
            if step % 500 == 0:
                print('iter = {}, loss = {}, reg_loss = {}'.format(step, loss, reg_loss))
                LOSS.append(loss)
            if step % fitted_iter == 0:
                self.sess.run(self.update_target_op)
        return LOSS

    def train_discrete_action(self, SASR_pi, batch_size, gamma, fitted_iter = 2000, max_iter = 40000):
        """Fitted Q-iteration with one-hot discrete actions.

        Args:
            SASR_pi: ``(S, A, SN, REW, pi)``; ``pi[n, i]`` is the weight of
                action ``i`` at next state ``SN[n]``.
            batch_size: minibatch size.
            gamma: discount factor.
            fitted_iter: the target network is refreshed before every
                iteration whose index is a multiple of this value.
            max_iter: total number of gradient steps.

        Returns:
            List of the minibatch losses recorded every 500 iterations.
        """
        S, A, SN, REW, pi = SASR_pi
        N = S.shape[0]
        action_dim = A.shape[-1]

        perm = np.random.permutation(N)
        j = 0
        LOSS = []
        for step in range(max_iter):
            if step % fitted_iter == 0:
                self.sess.run(self.update_target_op)
            if j + batch_size > N:
                perm = np.random.permutation(N)
                j = 0
            subsamples = perm[j:j+batch_size]
            xs = np.hstack([S[subsamples], A[subsamples]])
            r = REW[subsamples]
            sn = SN[subsamples]
            BQs = np.zeros_like(r)
            for i in range(action_dim):
                # Size the one-hot block from the actual batch: when the data
                # set is smaller than batch_size the slice is shorter and a
                # batch_size-row block would not stack with sn.
                an = np.zeros([sn.shape[0], action_dim])
                an[:,i] = 1
                BQs_i = self.sess.run(self.target_Qs, feed_dict = {
                    self.xs: np.hstack([sn, an])
                })
                BQs += pi[subsamples,i] * BQs_i
            BQs = gamma * BQs + r
            loss, reg_loss, _, Qs = self.sess.run([self.loss, self.reg_loss, self.train_op, self.Qs], feed_dict = {
                self.xs: xs,
                self.BQs: BQs
            })
            j += batch_size
            if step % 500 == 0:
                bellman_loss = self.get_bellman_residual_discrete(pi, np.hstack([S, A]), REW, SN, gamma)
                print('iter = {}, loss = {}, bellman_loss = {}'.format(step, loss, bellman_loss))
                LOSS.append(loss)
        return LOSS

def get_ground_truth(SASR_target, gamma):
    """Return the discount-normalized average reward of a batch of rollouts.

    Args:
        SASR_target: sequence whose last element is the reward array of
            shape (num_trajectory, truncate_size).
        gamma: discount factor.

    Returns:
        ``mean(REW * gamma**t) / mean(gamma**t)`` with ``t`` the step index,
        i.e. the per-trajectory average of discounted rewards normalized by
        the sum of the discount weights.
    """
    REW = SASR_target[-1]
    truncate_size = REW.shape[1]
    # Use a direct power instead of exp(t * log(gamma)): for gamma == 0 the
    # latter evaluates 0 * -inf = NaN at t == 0 and poisons the result.
    discounted = np.power(float(gamma), np.arange(truncate_size))
    return np.mean(REW * discounted) / np.mean(discounted)

class pendulum_config(object):
    """Experiment configuration and driver for the pendulum domain."""

    # domain parameters
    state_dim = 3
    action_dim = 1

    gamma = 0.95
    num_trajectory = 30
    truncate_size = 100
    eta = 10.0
    subsample_size = 500
    # Grids of trajectory counts, eta values and subsample sizes
    # (presumably swept by an external experiment script -- verify).
    NT = [1,2,4,6,10,20,30]
    ETA = [4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 12.0, 15.0]
    SSIZE = [100, 200, 300, 400, 500, 600, 800, 1000, 1500]
    result_path = './results/pendulum_results/'
    data_path = './transition_data/pendulum_data/'
    figure_name = 'pendulum.pdf'
    # NOTE(review): precomputed reference value; its provenance is not shown here.
    ground_truth = -37.3473558721

    # Q-network hyper-parameters passed to feature_q_learning_fitted.
    hidden_dim = 100
    feature_dim = 3
    Learning_rate = 1e-3

    max_iteration = 100

    dt = 0.2
    # Behavior and target policies share the linear weights and bias and
    # differ only in the last entry (log-variance): -2 versus -5.5.
    linear_theta0 = np.array([-5., -0.2/(dt*dt), -1./(3*dt), 0., -2.])
    linear_theta1 = np.array([-5., -0.2/(dt*dt), -1./(3*dt), 0., -5.5])
    policy_behavior = Linear_Policy_Pendulum(linear_theta0)
    policy_target = Linear_Policy_Pendulum(linear_theta1)

    def feature_pendulum(self, S, A):
        """Identity feature map: the concatenated (state, action) rows."""
        return np.hstack([S,A])

    def get_trasition_data(self, num_trajectory, truncate_size, policy, seedID):
        """Roll out ``policy``; thin wrapper around Get_Pendulum_transition_data.

        The misspelled method name is kept for caller compatibility.
        """
        return Get_Pendulum_transition_data(num_trajectory, truncate_size, policy, seedID)

    def interval_estimation(self, num_trajectory, eta, subsample_size, seedID):
        """Compute lower/upper bounds on the target policy's value.

        Behavior-policy transitions are generated with ``seedID``, a
        pre-trained feature network is loaded from its default checkpoint,
        and the bounds are computed twice: with the default settings and
        with ``double_sample = True``.

        Returns:
            ``(Q0_lower, Q0_lower2, Q0_upper, Q0_upper2)`` where the ``2``
            suffix marks the double-sample variant.
        """
        print('======== Current Setting for pendulum =========')
        print('---nt = {}, ts = {}, eta = {}, sample_size = {}, seed = {}---'.format(num_trajectory, self.truncate_size, eta, subsample_size, seedID))

        state_dim = self.state_dim
        action_dim = self.action_dim

        SASR_behavior = self.get_trasition_data(num_trajectory, self.truncate_size, self.policy_behavior, seedID)
        S, A, SN, REW = SASR_behavior             # dimension: nt*ts*X
        S_flat = S.reshape(-1, state_dim)         # N * state_dim
        A_flat = A.reshape(-1, action_dim)
        SN_flat = SN.reshape(-1, state_dim)
        REW_flat = REW.reshape(-1)
        replay_buffer = [S_flat, A_flat, SN_flat, REW_flat]

        max_iteration = self.max_iteration
        s0 = Get_Pendulum_s0_data(500, seedID)

        # Pass the configured hyper-parameters explicitly so the class-level
        # settings actually govern the network that is built.
        feature_model = feature_q_learning_fitted(self.state_dim + self.action_dim, self.feature_dim, hidden_dim = self.hidden_dim, Learning_rate = self.Learning_rate, seedID = 43)
        try:
            feature_model.load_model()

            Q_lower, Q_upper = lips_bound_evaluation(s0, replay_buffer, self.policy_target, feature_model.get_features, self.gamma, eta, subsample_size = subsample_size, max_iteration = max_iteration, discrete_action = False)
            Q0_lower, Q0_upper = estimate_rpi(s0, self.policy_target, feature_model.get_features, S_flat, A_flat, Q_lower, Q_upper, self.gamma, eta, discrete_action = False)

            Q_lower, Q_upper = lips_bound_evaluation(s0, replay_buffer, self.policy_target, feature_model.get_features, self.gamma, eta, subsample_size = subsample_size, double_sample = True, max_iteration = max_iteration, discrete_action = False)
            Q0_lower2, Q0_upper2 = estimate_rpi(s0, self.policy_target, feature_model.get_features, S_flat, A_flat, Q_lower, Q_upper, self.gamma, eta, discrete_action = False)
        finally:
            # A new TensorFlow session is opened on every call; close it so
            # repeated calls do not accumulate open sessions.
            feature_model.sess.close()
        print('-----end calculation-----')
        print('lower = {}, upper = {}'.format(Q0_lower, Q0_upper))
        print('double sample: lower = {}, upper = {}'.format(Q0_lower2, Q0_upper2))
        print('============================')
        sys.stdout.flush()
        return Q0_lower, Q0_lower2, Q0_upper, Q0_upper2
