import gym
import CustomNChain
import numpy as np
import time
import torch

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

from fqf_iqn_qrdqn.agent import QRDQNAgent, EEAgent, MYAgent, PQRAgent
# from fqf_iqn_qrdqn.env import WarpFramePyTorch
from gym import spaces, wrappers
import cv2
# Disable OpenCV's OpenCL acceleration (presumably so the cv2 calls in the
# frame wrapper run on the plain CPU path; confirm original motivation).
cv2.ocl.setUseOpenCL(False)

import ray
from ray import tune
import wandb
from ray.tune.integration.wandb import wandb_mixin

import os
# Working directory captured at import time; run() places its logs under it.
current_path = os.getcwd()


class WarpFramePyTorch(gym.ObservationWrapper):
    """Observation wrapper producing an 84x84 single-channel, channel-first frame.

    The value handed to :meth:`observation` by the wrapped env is ignored.
    The frame is fetched from ``self.get_obs()`` instead; that name is not
    defined here and is presumably forwarded to the wrapped CustomNChain env
    (TODO confirm). The frame is converted from RGB to grayscale, resized
    with area interpolation, and given a leading channel axis.
    """

    def __init__(self, env):
        """
        Wrap ``env`` and advertise a (1, 84, 84) Box observation space.
        :param env: (Gym Environment) the environment to wrap
        """
        gym.ObservationWrapper.__init__(self, env)
        self.width, self.height = 84, 84
        target_shape = (1, self.height, self.width)
        # Keep the wrapped env's dtype for the new observation space.
        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=target_shape,
            dtype=env.observation_space.dtype,
        )

    def observation(self, state=None):
        """
        Build the processed observation for the current step.
        :param state: value supplied by ObservationWrapper; unused, because the
            frame is read from ``self.get_obs()``
        :return: (np.ndarray) grayscale frame of shape (1, height, width)
        """
        rgb = self.get_obs()
        gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
        # cv2.resize takes dsize as (width, height) and returns (height, width).
        resized = cv2.resize(gray, (self.width, self.height),
                             interpolation=cv2.INTER_AREA)
        # Prepend the channel axis: (H, W) -> (1, H, W).
        return resized[np.newaxis, :, :]

def print_Q_table(self, num_states=5):
    """Print Q(s, a) for consecutive states of a fresh CustomNChain env.

    ``self`` is the agent whose ``online_net`` is queried. A separate env is
    built from the module-level ``small``/``large``/``std`` settings and reset
    to ``start_state=0``. Action 0 is then taken repeatedly, and the online
    network's Q-values are printed at each visited state.

    :param num_states: (int) number of consecutive states to print (default 5)
    """
    env2 = gym.make('CustomNChain-v0', small=small, large=large, std=std)
    env2 = WarpFramePyTorch(env2)
    net = self.online_net
    # Remember the caller's mode so printing does not leave the network in
    # eval mode for the training that follows.
    was_training = net.training
    # Send inputs to wherever the network lives instead of assuming CUDA.
    device = next(net.parameters()).device
    net.eval()
    try:
        state = env2.reset(start_state=0)
        print("#########################################")
        for i in range(num_states):
            observation = torch.ByteTensor(
                np.expand_dims(state, axis=0)).to(device).float() / 255
            with torch.no_grad():
                Q_value = net.calculate_q(observation).tolist()

            print("Q(s{},a) :".format(i), Q_value)

            state = env2.step(0)[0]
        print("#########################################")
    finally:
        net.train(was_training)
        # The throwaway env is no longer needed; release its resources.
        env2.close()

def plot_Q(self, states=None, state_embeddings=None, ax=None):
    """Plot the return distributions Z(s, a0) and Z(s, a1) of an agent.

    ``self`` is the agent. The online network's quantiles are averaged over
    the batch dimension and drawn as KDE curves (red: a0, blue: a1). A dashed
    vertical line marks each action's mean. When the agent type yields an
    action-selection estimate (QRDQN / DLTV / p-DLTV / PQR branches below),
    a dot on the x-axis marks that estimate's mean for each action.

    :param states: (torch.Tensor) batch of observations, or None
    :param state_embeddings: (torch.Tensor) precomputed embeddings, or None
    :param ax: (matplotlib Axes) axes to draw on; defaults to ``plt.gca()``
    """
    assert states is not None or state_embeddings is not None
    if ax is None:
        ax = plt.gca()

    net = self.online_net
    # Restore the caller's train/eval mode afterwards.
    was_training = net.training
    net.eval()
    perturbed_quantiles = None
    try:
        with torch.no_grad():
            quantiles = net(states=states, state_embeddings=state_embeddings)
            # Expected shape: (batch, N quantiles, num_actions), e.g. (1, 200, 6).

            if self.name == 'QRDQN':
                perturbed_quantiles = quantiles.mean(dim=0).T.cpu()

            elif self.name == 'DLTV':
                # Left-truncated variance: spread of the upper half of each
                # action's quantiles. Sort along the quantile axis (dim=1);
                # the default last-dim sort would mix values across actions.
                upper = torch.sort(quantiles, dim=1).values[:, self.N // 2: self.N]
                LTV = upper.std(dim=1)
                assert LTV.shape == (1, self.env.action_space.n)

                perturbed_quantiles = (
                    quantiles.mean(dim=0) + self.coefficient_C.get() * LTV)
                perturbed_quantiles = perturbed_quantiles.T.cpu()

            elif self.name == 'p-DLTV':
                C = self.coefficient_C.get()
                perturbed_quantiles = net.calculate_p_LTV(
                    states=states, state_embeddings=state_embeddings,
                    coefficient_C=C)
                perturbed_quantiles = perturbed_quantiles.T.cpu()

            elif self.name == 'PQR':
                perturbed_quantiles = net.calculate_P(states=states, xi=self.xi)
                perturbed_quantiles = perturbed_quantiles.T.cpu()
    finally:
        net.train(was_training)

    quantiles = quantiles.mean(dim=0)

    assert quantiles.shape == (200, self.env.action_space.n)
    quantiles = quantiles.T.cpu()

    df = pd.DataFrame({'a0': quantiles[0], 'a1': quantiles[1]})
    sns.distplot(df['a0'], color='Red', ax=ax, hist=False)
    sns.distplot(df['a1'], color='Blue', ax=ax, hist=False)

    for col, color in (('a0', 'Red'), ('a1', 'Blue')):
        ax.axvline(df[col].mean(), color=color, linestyle='dashed', linewidth=2,
                   label='{} : mean {:,.2f}, std {:,.2f}'.format(
                       col, df[col].mean(), df[col].std()))
    ax.set_ylim(0, 0.4)

    # Mark the action-selection estimate only when this agent type produced
    # one; unknown agent names simply get no marker instead of a NameError.
    if perturbed_quantiles is not None:
        df2 = pd.DataFrame({'a0': perturbed_quantiles[0], 'a1': perturbed_quantiles[1]})
        ax.plot(df2['a0'].mean(), 0, marker='o', ms=5, mec='r', mfc='r', clip_on=False)
        ax.plot(df2['a1'].mean(), 0, marker='o', ms=5, mec='b', mfc='b', clip_on=False)




def _make_agent(agent_name, env, test_env, log_dir, start_steps, gamma, seed):
    """Construct the agent named ``agent_name`` with the shared NChain settings.

    :raises ValueError: if ``agent_name`` is not a supported agent
    """
    agent_classes = {
        'QRDQN': QRDQNAgent,
        'DLTV': EEAgent,
        'p-DLTV': MYAgent,
        'PQR': PQRAgent,
    }
    try:
        agent_cls = agent_classes[agent_name]
    except KeyError:
        raise ValueError('Unknown agent_name {!r}; expected one of {}'.format(
            agent_name, sorted(agent_classes))) from None
    return agent_cls(
        env=env, test_env=test_env, log_dir=log_dir, cuda=True,
        batch_size=32, start_steps=start_steps, epsilon_train=0.01,
        epsilon_decay_steps=2500, update_interval=4,
        target_update_interval=100, eval_interval=2500,
        num_eval_steps=50, max_episode_steps=50, gamma=gamma, seed=seed)


def _save_q_snapshot(agent, env, log_dir, seed, step):
    """Plot Z(s_i, a) for the first five chain states and save the figure.

    Resets ``env`` to ``start_state=0`` and walks the chain with action 0,
    drawing one subplot per visited state via ``plot_Q``. The PNG is written
    to ``<log_dir>/rebutall_img/<agent.name>/c=inf_<seed>_<agent.name>-<step>.png``.
    """
    state = env.reset(start_state=0)

    fig, axes = plt.subplots(5, 1, figsize=(5, 10), sharex=False)
    fig.subplots_adjust(hspace=0.5)
    for i, ax in enumerate(axes):
        state = torch.ByteTensor(np.expand_dims(state, axis=0)).cuda().float() / 255
        plot_Q(agent, states=state, ax=ax)
        ax.set_title('PDF of Z(s{}, a)'.format(i))
        ax.set_xlim([2, 20])
        ax.legend()
        ax.set_xlabel('')
        state = env.step(0)[0]
    fig.suptitle(agent.name, fontsize=20)

    file_name = 'c=inf_{}_{}-{}.png'.format(seed, agent.name, step)
    path_to_save = os.path.join(log_dir, 'rebutall_img', agent.name)

    print('##########################################')
    print(path_to_save)
    print('##########################################')

    os.makedirs(path_to_save, exist_ok=True)
    fig.savefig(os.path.join(path_to_save, file_name))
    plt.close(fig)


@wandb_mixin
def run(config):
    """Train one agent on CustomNChain and save Z(s, a) snapshots.

    Reads ``agent_name``, ``seed`` and ``lr`` from ``config``. Training and
    test envs are built from the module-level ``small``/``large``/``std``
    settings, then the agent is trained one episode at a time. Every 10
    episodes the agent is evaluated and its Q-table printed. The first time
    ``agent.steps`` exceeds each entry of ``record_step``, a distribution
    snapshot is saved. Returns after the last snapshot.

    :raises ValueError: if ``config['agent_name']`` is not a supported agent
    """
    # Presumably re-imported so the env is registered inside the Ray worker.
    import CustomNChain  # noqa: F401

    gamma = 0.9
    # NOTE(review): 'lr' is read but never forwarded to the agent
    # constructors; confirm the agent classes accept ``lr`` before wiring it.
    lr = config['lr']
    seed = config['seed']
    agent_name = config['agent_name']
    start_steps = 2000

    env = WarpFramePyTorch(
        gym.make('CustomNChain-v0', small=small, large=large, std=std))
    test_env = WarpFramePyTorch(
        gym.make('CustomNChain-v0', small=small, large=large, std=std))

    log_dir = os.path.join(current_path, 'logs', 'CustomNChain')

    agent = _make_agent(agent_name, env, test_env, log_dir,
                        start_steps, gamma, seed)

    record_step = [5000, 10000, 15000, 20000]
    k = 0  # index of the next snapshot step still to be recorded

    while k < len(record_step):
        agent.train_episode()

        if agent.episodes % 10 == 0:
            agent.evaluate()
            print_Q_table(agent)

        if agent.steps > record_step[k]:
            _save_q_snapshot(agent, env, log_dir, seed, record_step[k])
            k += 1

    env.close()
    test_env.close()



if __name__ == '__main__':
    ray.init()
    # Chain reward settings; run() and print_Q_table() read these as module
    # globals. NOTE(review): the meaning of each entry (per-state reward
    # means / noise std) is defined by CustomNChain — confirm there.
    small, large, std = [10, 10], [5, 13], 0.1

    # One Tune sweep (lr x seed grid) per agent type.
    for agent_name in ['QRDQN', 'DLTV', 'p-DLTV', 'PQR']:
        wandb_cfg = {
            'project': 'rebuttal-NChain_{}_{}-6noop'.format(small, large),
            'group': agent_name,
        }
        trial_config = {
            'lr': tune.grid_search([5e-5]),
            'seed': tune.grid_search([122, 123]),
            'agent_name': agent_name,
            'wandb': wandb_cfg,
        }
        analysis = tune.run(
            run,
            config=trial_config,
            fail_fast=True,
            resources_per_trial={"cpu": 2, "gpu": 1},
            name="NChain-experiment",
        )



    