from .common.mpi_running_mean_std import RunningMeanStd
import experiments.utils.tf_util as U
import tensorflow as tf
import gym
from .common.distributions import make_pdtype

class MlpPolicy(object):
    """Feed-forward actor-critic policy built in an explicit TF1 graph/session.

    All ops and variables are created in `graph` under the variable scope
    `name` and evaluated through `sess`.  The actor ('pol') and the critic
    ('vf') are independent MLPs over the raw observation; every hidden layer
    is dense -> layer norm -> ReLU.
    """
    recurrent = False

    def __init__(self, sess, graph, name, *args, **kwargs):
        """Build the policy.

        Args:
            sess: session used to evaluate the policy.
            graph: graph in which all ops and variables are created.
            name: variable scope owning every variable of this policy.
            *args, **kwargs: forwarded to `_init`.
        """
        self.sess = sess
        self.graph = graph

        with self.sess.as_default():
            with self.graph.as_default():
                with tf.variable_scope(name):
                    self._init(*args, **kwargs)
                    self.scope = tf.get_variable_scope().name

    def _init(self, ob_space, ac_space, policy_net_layers, v_net_layers, gaussian_fixed_var=True):
        """Create the observation placeholder, critic, actor and the `_act` function.

        Args:
            ob_space: observation space; must be a `gym.spaces.Box`.
            ac_space: action space, passed to `make_pdtype`.
            policy_net_layers: hidden-layer widths of the actor MLP.
            v_net_layers: hidden-layer widths of the critic MLP.
            gaussian_fixed_var: for Box action spaces, use a tanh mean head plus a
                state-independent log-std variable instead of a single dense head.
        """
        assert isinstance(ob_space, gym.spaces.Box)

        self.pdtype = pdtype = make_pdtype(ac_space)
        sequence_length = None  # leading (batch) dimension left unspecified

        ob = U.get_placeholder_with_graph(name="ob", dtype=tf.float32,
                                          shape=[sequence_length] + list(ob_space.shape),
                                          graph=self.graph)
        self.ob_phd = ob

        # Observation normalisation (RunningMeanStd "obfilter") is disabled:
        # both networks consume the raw observation.
        with tf.variable_scope('vf'):
            last_out = self._mlp_trunk(ob, v_net_layers)
            self.vpred = tf.layers.dense(last_out, 1, name='final',
                                         kernel_initializer=U.normc_initializer(1.0))[:, 0]

        with tf.variable_scope('pol'):
            last_out = self._mlp_trunk(ob, policy_net_layers)
            pdparam = self._pd_params(last_out, ac_space, pdtype, gaussian_fixed_var)

        self.pd = pdtype.pdfromflat(pdparam)

        # Non-recurrent policy: no hidden state is carried between steps.
        self.state_in = []
        self.state_out = []

        # `stochastic` selects between sampling and the distribution mode at run time.
        stochastic = tf.placeholder(dtype=tf.bool, shape=())
        ac = U.switch(stochastic, self.pd.sample(), self.pd.mode())
        self._act = U.function([stochastic, ob], [ac, self.vpred], sess=self.sess, graph=self.graph)

    @staticmethod
    def _mlp_trunk(x, layer_sizes):
        """Apply hidden layers fc1..fcN (dense -> layer norm -> ReLU) to `x`.

        Must be called inside the caller's variable scope; layer names are
        relative to that scope.
        """
        for i, hid_size in enumerate(layer_sizes):
            x = tf.layers.dense(x, hid_size, name="fc%i" % (i + 1),
                                kernel_initializer=U.normc_initializer(1.0))
            x = tf.contrib.layers.layer_norm(x)
            x = tf.nn.relu(x)
        return x

    @staticmethod
    def _pd_params(last_out, ac_space, pdtype, gaussian_fixed_var):
        """Map the actor trunk output to the flat action-distribution parameters.

        With `gaussian_fixed_var` and a Box action space, the first half is a
        tanh-squashed dense layer and the second half a trainable [1, d]
        variable (zero-initialised) shared across the batch.  Otherwise a
        single dense layer emits all `pdtype.param_shape()[0]` parameters.
        """
        n_params = pdtype.param_shape()[0]
        if gaussian_fixed_var and isinstance(ac_space, gym.spaces.Box):
            mean = tf.nn.tanh(tf.layers.dense(last_out, n_params // 2, name='final',
                                              kernel_initializer=U.normc_initializer(0.01)))
            logstd = tf.get_variable(name="logstd", shape=[1, n_params // 2],
                                     initializer=tf.zeros_initializer())
            # `mean * 0.0 + logstd` broadcasts the [1, d] log-std to the batch size.
            return tf.concat([mean, mean * 0.0 + logstd], axis=1)
        return tf.layers.dense(last_out, n_params, name='final',
                               kernel_initializer=U.normc_initializer(0.01))

    def act(self, stochastic, ob):
        """Evaluate the policy on one observation.

        Args:
            stochastic: sample from the action distribution if True, else take its mode.
            ob: a single observation, without batch dimension.

        Returns:
            (action, value estimate) for `ob`.
        """
        with self.graph.as_default():
            ac1, vpred1 = self._act(stochastic, ob[None])
            return ac1[0], vpred1[0]

    def _scoped_collection(self, key, subscope=''):
        """Return graph collection `key`, restricted to this policy's (sub)scope.

        Queries `self.graph` directly so the result does not depend on which
        graph happens to be the default at call time.  The trailing '/' stops
        sibling scopes sharing a name prefix (e.g. 'pi' vs 'pi_old') from matching.
        """
        return self.graph.get_collection(key, scope=self.scope + subscope + '/')

    def get_variables(self):
        return self._scoped_collection(tf.GraphKeys.GLOBAL_VARIABLES)

    def get_trainable_variables(self):
        return self._scoped_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

    def get_initial_state(self):
        return []

    def get_policy_variables(self):
        return self._scoped_collection(tf.GraphKeys.TRAINABLE_VARIABLES, '/pol')

    def get_critic_variables(self):
        return self._scoped_collection(tf.GraphKeys.TRAINABLE_VARIABLES, '/vf')


"""
    Policy for the oprs-v1 shaping method:
    an MLP policy in which the shaping weight function f is an input to both the policy and the shaped critic.
"""
class MlpPolicyOprsV1(object):
    """Actor-critic policy for oprs-v1 shaping (TF1 graph mode).

    The shaping-weight tensor `f_output` is concatenated to the observation and
    fed to both the actor ('pol') and the shaped critic ('vf_shaped').  A second
    critic ('vf_true') sees only the observation.  Every hidden layer is
    dense -> layer norm -> ReLU.
    """
    recurrent = False

    def __init__(self, sess, graph, name, *args, **kwargs):
        """Build the policy.

        Args:
            sess: session used to evaluate the policy.
            graph: graph in which all ops and variables are created.
            name: variable scope owning every variable of this policy.
            *args, **kwargs: forwarded to `_init`.
        """
        self.sess = sess
        self.graph = graph

        with self.sess.as_default():
            with self.graph.as_default():
                with tf.variable_scope(name):
                    self._init(*args, **kwargs)
                    self.scope = tf.get_variable_scope().name

    def _init(self, ob_space, ac_space, policy_net_layers, v_net_layers, gaussian_fixed_var=True,
              f_output=None):
        """Create the observation placeholder, both critics, the actor and `_act`.

        Args:
            ob_space: observation space; must be a `gym.spaces.Box`.
            ac_space: action space, passed to `make_pdtype`.
            policy_net_layers: hidden-layer widths of the actor MLP.
            v_net_layers: hidden-layer widths of both critic MLPs.
            gaussian_fixed_var: for Box action spaces, use a tanh mean head plus a
                state-independent log-std variable.
            f_output: shaping-weight tensor concatenated (axis 1) to the observation;
                required.  NOTE(review): `_act` feeds only `ob`, so this presumably
                must be computable from `ob` inside `graph` — confirm with caller.
        """
        assert isinstance(ob_space, gym.spaces.Box)
        assert f_output is not None

        self.pdtype = pdtype = make_pdtype(ac_space)
        sequence_length = None  # leading (batch) dimension left unspecified
        self.f_output = f_output

        ob = U.get_placeholder_with_graph(name="ob", dtype=tf.float32,
                                          shape=[sequence_length] + list(ob_space.shape),
                                          graph=self.graph)
        self.ob_phd = ob

        # Observation normalisation (RunningMeanStd "obfilter") is disabled.

        # Shaped value function: input is [ob, f].
        with tf.variable_scope('vf_shaped'):
            last_out = self._mlp_trunk(tf.concat([ob, self.f_output], axis=1), v_net_layers)
            self.vpred = tf.layers.dense(last_out, 1, name='final',
                                         kernel_initializer=U.normc_initializer(1.0))[:, 0]

        # True value function: input is ob only.
        with tf.variable_scope('vf_true'):
            last_out = self._mlp_trunk(ob, v_net_layers)
            self.vpred_true = tf.layers.dense(last_out, 1, name='final',
                                              kernel_initializer=U.normc_initializer(1.0))[:, 0]

        with tf.variable_scope('pol'):
            last_out = self._mlp_trunk(tf.concat([ob, self.f_output], axis=1), policy_net_layers)
            pdparam = self._pd_params(last_out, ac_space, pdtype, gaussian_fixed_var)

        self.pd = pdtype.pdfromflat(pdparam)

        # Non-recurrent policy: no hidden state is carried between steps.
        self.state_in = []
        self.state_out = []

        # `stochastic` selects between sampling and the distribution mode at run time.
        stochastic = tf.placeholder(dtype=tf.bool, shape=())
        ac = U.switch(stochastic, self.pd.sample(), self.pd.mode())
        self._act = U.function([stochastic, ob], [ac, self.vpred, self.vpred_true],
                               sess=self.sess, graph=self.graph)

    @staticmethod
    def _mlp_trunk(x, layer_sizes):
        """Apply hidden layers fc1..fcN (dense -> layer norm -> ReLU) to `x`.

        Must be called inside the caller's variable scope; layer names are
        relative to that scope.
        """
        for i, hid_size in enumerate(layer_sizes):
            x = tf.layers.dense(x, hid_size, name="fc%i" % (i + 1),
                                kernel_initializer=U.normc_initializer(1.0))
            x = tf.contrib.layers.layer_norm(x)
            x = tf.nn.relu(x)
        return x

    @staticmethod
    def _pd_params(last_out, ac_space, pdtype, gaussian_fixed_var):
        """Map the actor trunk output to the flat action-distribution parameters.

        With `gaussian_fixed_var` and a Box action space, the first half is a
        tanh-squashed dense layer and the second half a trainable [1, d]
        variable (zero-initialised) shared across the batch.  Otherwise a
        single dense layer emits all `pdtype.param_shape()[0]` parameters.
        """
        n_params = pdtype.param_shape()[0]
        if gaussian_fixed_var and isinstance(ac_space, gym.spaces.Box):
            mean = tf.nn.tanh(tf.layers.dense(last_out, n_params // 2, name='final',
                                              kernel_initializer=U.normc_initializer(0.01)))
            logstd = tf.get_variable(name="logstd", shape=[1, n_params // 2],
                                     initializer=tf.zeros_initializer())
            # `mean * 0.0 + logstd` broadcasts the [1, d] log-std to the batch size.
            return tf.concat([mean, mean * 0.0 + logstd], axis=1)
        return tf.layers.dense(last_out, n_params, name='final',
                               kernel_initializer=U.normc_initializer(0.01))

    def act(self, stochastic, ob):
        """Evaluate the policy on one observation.

        Args:
            stochastic: sample from the action distribution if True, else take its mode.
            ob: a single observation, without batch dimension.

        Returns:
            (action, shaped value estimate, true value estimate) for `ob`.
        """
        with self.graph.as_default():
            ac1, vpred1, vpred_true = self._act(stochastic, ob[None])
            return ac1[0], vpred1[0], vpred_true[0]

    def _scoped_collection(self, key, subscope=''):
        """Return graph collection `key`, restricted to this policy's (sub)scope.

        Queries `self.graph` directly so the result does not depend on which
        graph happens to be the default at call time.  The trailing '/' stops
        sibling scopes sharing a name prefix (e.g. 'pi' vs 'pi_old') from matching.
        """
        return self.graph.get_collection(key, scope=self.scope + subscope + '/')

    def get_variables(self):
        return self._scoped_collection(tf.GraphKeys.GLOBAL_VARIABLES)

    def get_trainable_variables(self):
        return self._scoped_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

    def get_initial_state(self):
        return []

    def get_policy_variables(self):
        return self._scoped_collection(tf.GraphKeys.TRAINABLE_VARIABLES, '/pol')

    def get_critic_variables(self):
        return self._scoped_collection(tf.GraphKeys.TRAINABLE_VARIABLES, '/vf_shaped')

    def get_true_critic_variables(self):
        return self._scoped_collection(tf.GraphKeys.TRAINABLE_VARIABLES, '/vf_true')


# """
#     for shaping method oprs-v2
#     MLP policy with shaping weight function f as input of shaped critic
#     f is not an input of policy
# """
# class MlpPolicyOprsV2(object):
#     recurrent = False
#     def __init__(self, sess, graph, name, *args, **kwargs):
#         self.sess = sess
#         self.graph = graph
#
#         with self.sess.as_default():
#             with self.graph.as_default():
#                 with tf.variable_scope(name):
#                     self._init(*args, **kwargs)
#                     self.scope = tf.get_variable_scope().name
#
#     def _init(self, ob_space, ac_space, policy_net_layers, v_net_layers, gaussian_fixed_var=True,
#               f_output=None):
#         assert isinstance(ob_space, gym.spaces.Box)
#         # assert f_output is not None
#
#         self.pdtype = pdtype = make_pdtype(ac_space)
#         sequence_length = None
#         self.f_output = f_output
#
#         # ob = U.get_placeholder(name="ob", dtype=tf.float32, shape=[sequence_length] + list(ob_space.shape))
#         ob = U.get_placeholder_with_graph(name="ob", dtype=tf.float32,
#                                           shape=[sequence_length] + list(ob_space.shape),
#                                           graph=self.graph)
#         self.ob_phd = ob
#
#         # with tf.variable_scope("obfilter"):
#         #     self.ob_rms = RunningMeanStd(shape=ob_space.shape, sess=self.sess, graph=self.graph)
#
#         """
#             the shaped value function
#         """
#         with tf.variable_scope('vf_shaped'):
#             # obz = tf.clip_by_value((ob - self.ob_rms.mean) / self.ob_rms.std, -5.0, 5.0)
#             # last_out = obz
#             if self.f_output is not None:
#                 last_out = tf.concat([ob, self.f_output], axis=1) #ob
#             else:
#                 last_out = ob
#
#             for i in range(len(v_net_layers)):
#                 hid_size = v_net_layers[i]
#                 last_out = tf.layers.dense(last_out, hid_size, name="fc%i" % (i + 1),
#                                            kernel_initializer=U.normc_initializer(1.0)
#                                            )
#                 last_out = tf.contrib.layers.layer_norm(last_out)
#                 last_out = tf.nn.relu(last_out)
#
#             self.vpred = tf.layers.dense(last_out, 1, name='final',
#                                          kernel_initializer=U.normc_initializer(1.0)
#                                          )[:,0]
#
#         with tf.variable_scope('vf_true'):
#             # obz = tf.clip_by_value((ob - self.ob_rms.mean) / self.ob_rms.std, -5.0, 5.0)
#             # last_out = obz
#             last_out = ob # ob
#             for i in range(len(v_net_layers)):
#                 hid_size = v_net_layers[i]
#                 last_out = tf.layers.dense(last_out, hid_size, name="fc%i" % (i + 1),
#                                            kernel_initializer=U.normc_initializer(1.0)
#                                            )
#                 last_out = tf.contrib.layers.layer_norm(last_out)
#                 last_out = tf.nn.relu(last_out)
#
#             self.vpred_true = tf.layers.dense(last_out, 1, name='final',
#                                          kernel_initializer=U.normc_initializer(1.0)
#                                          )[:, 0]
#
#         with tf.variable_scope('pol'):
#             # last_out = obz
#             last_out = ob
#             for i in range(len(policy_net_layers)):
#                 hid_size = policy_net_layers[i]
#                 last_out = tf.layers.dense(last_out, hid_size, name='fc%i' % (i + 1),
#                                            kernel_initializer=U.normc_initializer(1.0)
#                                            )
#                 last_out = tf.contrib.layers.layer_norm(last_out)
#                 last_out = tf.nn.relu(last_out)
#
#             if gaussian_fixed_var and isinstance(ac_space, gym.spaces.Box):
#                 mean = tf.nn.tanh(tf.layers.dense(last_out, pdtype.param_shape()[0]//2, name='final',
#                                        kernel_initializer=U.normc_initializer(0.01)
#                                        ))
#                 logstd = tf.get_variable(name="logstd", shape=[1, pdtype.param_shape()[0]//2],
#                                          initializer=tf.zeros_initializer()
#                                          )
#                 # logstd = tf.layers.dense(last_out, pdtype.param_shape()[0]//2, name='logstd',
#                 #                        #kernel_initializer=U.normc_initializer(0.01)
#                 #                        )
#                 pdparam = tf.concat([mean, mean * 0.0 + logstd], axis=1)
#             else:
#                 pdparam = tf.layers.dense(last_out, pdtype.param_shape()[0], name='final',
#                                           kernel_initializer=U.normc_initializer(0.01)
#                                           )
#
#         self.pd = pdtype.pdfromflat(pdparam)
#
#         self.state_in = []
#         self.state_out = []
#
#         stochastic = tf.placeholder(dtype=tf.bool, shape=())
#         ac = U.switch(stochastic, self.pd.sample(), self.pd.mode())
#         self._act = U.function([stochastic, ob], [ac, self.vpred, self.vpred_true],
#                                sess=self.sess, graph=self.graph)
#
#     def act(self, stochastic, ob):
#         with self.graph.as_default():
#             ac1, vpred1, vpred_true = self._act(stochastic, ob[None])
#             return ac1[0], vpred1[0], vpred_true[0]
#
#     def get_variables(self):
#         return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, self.scope)
#
#     def get_trainable_variables(self):
#         return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.scope)
#
#     def get_initial_state(self):
#         return []
#
#     def get_policy_variables(self):
#         return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.scope+'/pol')
#
#     def get_critic_variables(self):
#         return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.scope+'/vf_shaped')
#
#     def get_true_critic_variables(self):
#         return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.scope+'/vf_true')

"""
    Policy for the oprs-v2 shaping method:
    an MLP policy in which the shaping weight function f is intended as an input to the shaped critic only;
    f is not an input to the policy.
"""
class MlpPolicyOprsV2(object):
    """Actor-critic policy for oprs-v2 shaping (TF1 graph mode).

    The actor ('pol') sees only the observation.  There are two critics,
    'vf_shaped' and 'vf_true'.  NOTE: concatenating f into the shaped critic is
    currently disabled, so both critics take only the observation and differ
    only in their (separate) variables.  Actions come from `act`; value
    estimates come from `v_predict`.  Every hidden layer is
    dense -> layer norm -> ReLU.
    """
    recurrent = False

    def __init__(self, sess, graph, name, *args, **kwargs):
        """Build the policy.

        Args:
            sess: session used to evaluate the policy.
            graph: graph in which all ops and variables are created.
            name: variable scope owning every variable of this policy.
            *args, **kwargs: forwarded to `_init`.
        """
        self.sess = sess
        self.graph = graph

        with self.sess.as_default():
            with self.graph.as_default():
                with tf.variable_scope(name):
                    self._init(*args, **kwargs)
                    self.scope = tf.get_variable_scope().name

    def _init(self, ob_space, ac_space, policy_net_layers, v_net_layers, gaussian_fixed_var=True,
              f_output=None):
        """Create the observation placeholder, both critics, the actor, `_act` and `_v_pred`.

        Args:
            ob_space: observation space; must be a `gym.spaces.Box`.
            ac_space: action space, passed to `make_pdtype`.
            policy_net_layers: hidden-layer widths of the actor MLP.
            v_net_layers: hidden-layer widths of both critic MLPs.
            gaussian_fixed_var: for Box action spaces, use a tanh mean head plus a
                state-independent log-std variable.
            f_output: optional shaping-weight input.  When given it is an extra
                input of `_v_pred` (so it must be feedable, e.g. a placeholder),
                although the shaped critic currently does not consume it.
        """
        assert isinstance(ob_space, gym.spaces.Box)

        self.pdtype = pdtype = make_pdtype(ac_space)
        sequence_length = None  # leading (batch) dimension left unspecified
        self.f_output = f_output

        ob = U.get_placeholder_with_graph(name="ob", dtype=tf.float32,
                                          shape=[sequence_length] + list(ob_space.shape),
                                          graph=self.graph)
        self.ob_phd = ob

        # Observation normalisation (RunningMeanStd "obfilter") is disabled.

        # Shaped value function.  Concatenating `f_output` to `ob` here is
        # disabled, so this critic currently reads `ob` only.
        with tf.variable_scope('vf_shaped'):
            last_out = self._mlp_trunk(ob, v_net_layers)
            self.vpred = tf.layers.dense(last_out, 1, name='final',
                                         kernel_initializer=U.normc_initializer(1.0))[:, 0]

        # True value function: input is ob only.
        with tf.variable_scope('vf_true'):
            last_out = self._mlp_trunk(ob, v_net_layers)
            self.vpred_true = tf.layers.dense(last_out, 1, name='final',
                                              kernel_initializer=U.normc_initializer(1.0))[:, 0]

        with tf.variable_scope('pol'):
            last_out = self._mlp_trunk(ob, policy_net_layers)
            pdparam = self._pd_params(last_out, ac_space, pdtype, gaussian_fixed_var)

        self.pd = pdtype.pdfromflat(pdparam)

        # Non-recurrent policy: no hidden state is carried between steps.
        self.state_in = []
        self.state_out = []

        # `stochastic` selects between sampling and the distribution mode at run time.
        stochastic = tf.placeholder(dtype=tf.bool, shape=())
        ac = U.switch(stochastic, self.pd.sample(), self.pd.mode())
        self._act = U.function([stochastic, ob], [ac],
                               sess=self.sess, graph=self.graph)
        # `None` is not a valid function input, so only feed `f_output` when it exists.
        v_inputs = [ob] if f_output is None else [ob, f_output]
        self._v_pred = U.function(v_inputs, [self.vpred, self.vpred_true],
                                  sess=self.sess, graph=self.graph)

    @staticmethod
    def _mlp_trunk(x, layer_sizes):
        """Apply hidden layers fc1..fcN (dense -> layer norm -> ReLU) to `x`.

        Must be called inside the caller's variable scope; layer names are
        relative to that scope.
        """
        for i, hid_size in enumerate(layer_sizes):
            x = tf.layers.dense(x, hid_size, name="fc%i" % (i + 1),
                                kernel_initializer=U.normc_initializer(1.0))
            x = tf.contrib.layers.layer_norm(x)
            x = tf.nn.relu(x)
        return x

    @staticmethod
    def _pd_params(last_out, ac_space, pdtype, gaussian_fixed_var):
        """Map the actor trunk output to the flat action-distribution parameters.

        With `gaussian_fixed_var` and a Box action space, the first half is a
        tanh-squashed dense layer and the second half a trainable [1, d]
        variable (zero-initialised) shared across the batch.  Otherwise a
        single dense layer emits all `pdtype.param_shape()[0]` parameters.
        """
        n_params = pdtype.param_shape()[0]
        if gaussian_fixed_var and isinstance(ac_space, gym.spaces.Box):
            mean = tf.nn.tanh(tf.layers.dense(last_out, n_params // 2, name='final',
                                              kernel_initializer=U.normc_initializer(0.01)))
            logstd = tf.get_variable(name="logstd", shape=[1, n_params // 2],
                                     initializer=tf.zeros_initializer())
            # `mean * 0.0 + logstd` broadcasts the [1, d] log-std to the batch size.
            return tf.concat([mean, mean * 0.0 + logstd], axis=1)
        return tf.layers.dense(last_out, n_params, name='final',
                               kernel_initializer=U.normc_initializer(0.01))

    def act(self, stochastic, ob):
        """Return the action for a single observation `ob` (no batch dimension).

        Args:
            stochastic: sample from the action distribution if True, else take its mode.
            ob: a single observation.
        """
        with self.graph.as_default():
            # `_act` has the output list [ac], so it returns a 1-element list;
            # unpack it before dropping the batch dimension of size 1.
            ac1, = self._act(stochastic, ob[None])
            return ac1[0]

    def v_predict(self, ob, f_output=None):
        """Return (shaped value, true value) for a single observation.

        Args:
            ob: a single observation, without batch dimension.
            f_output: the matching shaping-weight value.  It is fed only when the
                policy was built with an `f_output` tensor, and ignored otherwise.
        """
        with self.graph.as_default():
            if self.f_output is None:
                vpred1, vpred_true = self._v_pred(ob[None])
            else:
                vpred1, vpred_true = self._v_pred(ob[None], f_output[None])
            return vpred1[0], vpred_true[0]

    def _scoped_collection(self, key, subscope=''):
        """Return graph collection `key`, restricted to this policy's (sub)scope.

        Queries `self.graph` directly so the result does not depend on which
        graph happens to be the default at call time.  The trailing '/' stops
        sibling scopes sharing a name prefix (e.g. 'pi' vs 'pi_old') from matching.
        """
        return self.graph.get_collection(key, scope=self.scope + subscope + '/')

    def get_variables(self):
        return self._scoped_collection(tf.GraphKeys.GLOBAL_VARIABLES)

    def get_trainable_variables(self):
        return self._scoped_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

    def get_initial_state(self):
        return []

    def get_policy_variables(self):
        return self._scoped_collection(tf.GraphKeys.TRAINABLE_VARIABLES, '/pol')

    def get_critic_variables(self):
        return self._scoped_collection(tf.GraphKeys.TRAINABLE_VARIABLES, '/vf_shaped')

    def get_true_critic_variables(self):
        return self._scoped_collection(tf.GraphKeys.TRAINABLE_VARIABLES, '/vf_true')