

import copy
import pickle as pkl

import time
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns

def get_xy_linspace_for_visualize(env_name, num_test_points):
    """Return evenly spaced x and y sample coordinates for an environment.

    Args:
        env_name: Environment identifier used to pick the plotting ranges.
        num_test_points: Number of samples generated along each axis.

    Returns:
        A tuple ``(x, y)`` of 1-D arrays with ``num_test_points`` entries
        each, or ``(None, None)`` when ``env_name`` has no range defined here.
    """
    # Each entry: (environment names, (x_min, x_max), (y_min, y_max)).
    axis_ranges = (
        (('sawyer_peg_push',), (-0.6, 0.6), (0.2, 1.0)),
        (('Point4WayComplexVer2Maze-v0', 'Point4WayFarmlandMaze-v0'), (-18, 18), (-18, 18)),
        (('Point2WaySpiralMaze-v0',), (-14, 14), (-18, 18)),
        (('AntMazeComplex2Way-v0',), (-6, 6), (-10, 10)),
    )

    for names, x_range, y_range in axis_ranges:
        if env_name in names:
            x = np.linspace(x_range[0], x_range[1], num_test_points)
            y = np.linspace(y_range[0], y_range[1], num_test_points)
            return x, y

    # No range is defined for this environment.
    return None, None

def plot_obstacles(env_name, ax):
    """Draw the obstacle outlines of a maze environment as black polylines.

    Args:
        env_name: Environment identifier selecting which outlines to draw.
        ax: Object exposing ``plot(x, y, c=...)`` (e.g. a matplotlib Axes).

    Environments without an outline table here leave ``ax`` untouched.
    """
    # Each outline is a pair (x vertices, y vertices) of one polyline.
    if env_name in ['Point4WayComplexVer2Maze-v0']:
        outlines = [
            ([-2, -2, -14, -14, -6, -6], [-18, -2, -2, -14, -14, -18]),
            ([18, 2, 2, 14, 14, 18], [-2, -2, -14, -14, -6, -6]),
            ([-18, -14, -14, -2, -2, -18], [6, 6, 14, 14, 2, 2]),
            ([6, 6, 14, 14, 2, 2], [18, 14, 14, 2, 2, 18]),
        ]
    elif env_name in ['Point4WayFarmlandMaze-v0']:
        outlines = [
            ([-2, -2, -14, -14, -2], [-14, -2, -2, -14, -14]),
            ([14, 2, 2, 14, 14], [-2, -2, -14, -14, -2]),
            ([-14, -2, -2, -14, -14], [14, 14, 2, 2, 14]),
            ([14, 14, 2, 2, 14], [14, 2, 2, 14, 14]),
        ]
    elif env_name in ['Point2WaySpiralMaze-v0']:
        outlines = [
            ([-14, -14, 6, 6, -14], [-14, -10, -10, -14, -14]),
            ([6, 6, 10, 10, 6], [-14, 2, 2, -14, -14]),
            ([-2, -2, 10, 10, -2], [2, 6, 6, 2, 2]),
            ([-10, -10, 2, 2, -10], [-6, -2, -2, -6, -6]),
            ([-10, -10, -6, -6, -10], [-2, 14, 14, -2, -2]),
            ([-6, -6, 14, 14, -6], [10, 14, 14, 10, 10]),
        ]
    elif env_name in ['AntMazeComplex2Way-v0']:
        outlines = [
            ([-6, 2, 2, -6], [-6, -6, -2, -2]),
            ([6, -2, -2, 6], [2, 2, 6, 6]),
        ]
    else:
        outlines = []

    for vertices_x, vertices_y in outlines:
        ax.plot(np.array(vertices_x), np.array(vertices_y), c='black')



def _d2c_target_goals(env_name, uniform_goal_sampler, multi_target):
    """Return the desired goals (one entry per target) that condition the d2c input."""
    if uniform_goal_sampler is not None:
        # Four sampled targets in multi-target mode, a single one otherwise.
        num_goals = 4 if multi_target else 1
        return [uniform_goal_sampler.sample(num_sample=1, sample_feasible=True)
                for _ in range(num_goals)]

    # No sampler given: fall back to the fixed per-environment goal tables.
    if env_name in ['sawyer_peg_push']:
        goals = np.array([[-0.3, 0.4, 0.02],
                          [-0.3, 0.8, 0.02],
                          [0.4, 0.4, 0.02]])
    elif env_name in ['Point4WayFarmlandMaze-v0']:
        goals = np.array([[16., 16.],
                          [-16., -16.],
                          [16., -16.],
                          [-16., 16.]])
    elif env_name in ['Point4WayComplexVer2Maze-v0']:
        goals = np.array([[8., 16.],
                          [-8., -16.],
                          [16., -8.],
                          [-16., 8.]])
    elif env_name in ['Point2WaySpiralMaze-v0']:
        goals = np.array([[12., 16.],
                          [-12., -16.]])
    elif env_name in ['AntMazeComplex2Way-v0']:
        goals = np.array([[4., 8.],
                          [-4., -8.]])
    else:
        raise NotImplementedError

    # Single-target mode uses the first goal of the table. For
    # 'sawyer_peg_push' this is the same goal the single-target path always
    # used; for the maze environments it replaces a NameError, because no
    # single goal was ever defined for them.
    return goals if multi_target else goals[:1]


def visualize_d2c_all_together(agent, scatter_states, env_name, savedir_w_name, device, uniform_goal_sampler=None, multi_target=False,  env = None):
    """Save heatmaps of the agent's d2c outputs over an (x, y) grid as a JPEG.

    A 60x60 grid of positions is built from the ranges returned by
    ``get_xy_linspace_for_visualize``. When ``agent.d2c_gcrl`` is set, each
    grid point is concatenated with a desired goal before being evaluated.
    The figure has one row for ``agent.get_prob_by_d2c`` and one row per d2c
    head (sigmoid of ``agent.d2c``), and one column per target goal when
    ``agent.d2c_gcrl`` and ``multi_target`` are both set, otherwise a single
    column. The colour scale is fixed to [0, 1].

    Args:
        agent: Object providing ``d2c_gcrl``, ``d2c_normalize``, ``env_name``,
            ``normalize_obs``, ``get_prob_by_d2c`` and ``d2c`` (with ``heads``).
        scatter_states: Not used by this function.
        env_name: Environment identifier; the function returns without
            plotting when no grid range is defined for it.
        savedir_w_name: Output path without extension; ``'.jpg'`` is appended.
        device: Torch device the network input is moved to.
        uniform_goal_sampler: Optional object whose
            ``sample(num_sample=1, sample_feasible=True)`` supplies desired
            goals; fixed per-environment goals are used when it is None.
        multi_target: Whether to visualize several target goals side by side.
        env: Not used by this function.

    Raises:
        NotImplementedError: If ``agent.d2c_gcrl`` is set, no sampler is
            given and ``env_name`` has no fixed goal table.
    """
    num_test_points = 60
    x, y = get_xy_linspace_for_visualize(env_name=env_name, num_test_points=num_test_points)
    if x is None or y is None:
        return

    grid_x, grid_y = np.meshgrid(x, y)
    goal_xy = np.concatenate([np.reshape(grid_x, [-1, 1]), np.reshape(grid_y, [-1, 1])], axis=1)  # [num_test_points^2, 2]

    if env_name in ['sawyer_peg_push']:
        # Append a constant third coordinate to every grid point.
        goal_xy = np.concatenate([goal_xy, 0.01457*np.ones([goal_xy.shape[0], 1])], axis=-1)  # [num_test_points^2, 3]

    num_target_goals = 1
    if agent.d2c_gcrl:
        target_goals = _d2c_target_goals(env_name, uniform_goal_sampler, multi_target)
        num_target_goals = len(target_goals)
        # One [num_test_points^2, grid_dim + goal_dim] slab per target goal,
        # laid out goal after goal along the first axis.
        observes = np.concatenate(
            [np.concatenate([goal_xy, np.tile(goal, (goal_xy.shape[0], 1))], axis=-1)
             for goal in target_goals],
            axis=0)
    else:
        observes = goal_xy

    observes = torch.from_numpy(observes).float().to(device)

    # The probability is queried on the un-normalized input; normalization
    # (if enabled) is applied only before the raw d2c forward pass.
    probs = agent.get_prob_by_d2c(observes)
    with torch.no_grad():
        if agent.d2c_normalize:
            observes = agent.normalize_obs(observes, agent.env_name)
        preds = agent.d2c(observes).sigmoid().detach().cpu().numpy()  # [num_rows, heads]

    figure_scale = 2
    num_heads = agent.d2c.heads
    per_goal_columns = bool(agent.d2c_gcrl and multi_target)
    if per_goal_columns:
        probs = probs.reshape(num_target_goals, num_test_points**2)
        preds = preds.reshape(num_target_goals, num_test_points**2, num_heads)
    num_cols = num_target_goals if per_goal_columns else 1

    # squeeze=False keeps axs two-dimensional even with one row or column,
    # so axs[i, j] is always valid.
    fig, axs = plt.subplots(nrows=num_heads+1, ncols=num_cols, sharex=True, squeeze=False,
                            figsize=(4*num_cols*figure_scale, 12*figure_scale))
    try:
        for j in range(num_cols):
            for i in range(num_heads+1):
                # Row 0 shows the probability; row i >= 1 shows head i-1.
                if per_goal_columns:
                    if i == 0:
                        outputs = probs[j]
                        title = 'd2c_prob_target_goal'+str(j)
                    else:
                        outputs = preds[j, :, i-1]
                        title = f'd2c_head{i-1}_target_goal'+str(j)
                else:
                    if i == 0:
                        outputs = probs
                        title = 'd2c_prob_visualize'
                    else:
                        outputs = preds[:, i-1]
                        title = f'd2c_head{i-1}_visualize'

                outputs = np.reshape(outputs, [num_test_points, num_test_points])

                ax = axs[i, j]
                c = ax.pcolormesh(grid_x, grid_y, outputs, cmap='RdBu', vmin=0, vmax=1)
                plot_obstacles(env_name=env_name, ax=ax)

                ax.set_title(title)
                ax.axis([grid_x.min(), grid_x.max(), grid_y.min(), grid_y.max()])
                fig.colorbar(c, ax=ax)
                ax.axis('tight')

        fig.savefig(savedir_w_name+'.jpg')
    finally:
        # Close the figure even if drawing or saving failed.
        plt.close(fig)

def visualize_vf(agent, initial_state, scatter_states, env_name, savedir_w_name, device):
    """Plot a critic-based value estimate over an (x, y) goal grid and save it as a JPEG.

    For every point of a 60x60 grid, ``initial_state`` is concatenated with
    that point, ten actions are sampled from ``agent.actor``, and the value
    is the mean over those samples of the element-wise minimum of the two
    ``agent.critic`` outputs. The value is clipped according to the agent's
    reward configuration before being drawn with ``pcolormesh``.

    Args:
        agent: Object providing ``actor``, ``critic``, ``normalize_rl_obs``,
            ``normalize_obs``, ``rl_reward_type``, ``discount`` and the
            matching ``d2c_reward_type`` / ``sparse_reward_type``.
        initial_state: NumPy array prepended to every grid goal.
        scatter_states: A single state (1-D) or a sequence of states (2-D)
            whose first two components are marked on the plot.
        env_name: Environment identifier; the function returns without
            plotting when no grid range is defined for it.
        savedir_w_name: Output path without extension; ``'.jpg'`` is appended.
        device: Torch device the network inputs are moved to.

    Raises:
        NotImplementedError: If ``agent.rl_reward_type`` is neither ``'d2c'``
            nor ``'sparse'``.
    """
    # Timestamp taken on entry; nothing in this function reads it.
    vis_started_at = time.time()

    grid_size = 60
    xs, ys = get_xy_linspace_for_visualize(env_name=env_name, num_test_points=grid_size)
    if xs is None or ys is None:
        return

    mesh_x, mesh_y = np.meshgrid(xs, ys)
    flat_x = np.reshape(mesh_x, [-1, 1])
    flat_y = np.reshape(mesh_y, [-1, 1])
    goals_np = np.concatenate([flat_x, flat_y], axis=1)  # [grid_size^2, 2]

    if env_name in ['sawyer_peg_push']:
        # Append a constant third coordinate to every grid goal.
        z_column = 0.01457*np.ones([goals_np.shape[0], 1])
        goals_np = np.concatenate([goals_np, z_column], axis=-1)  # [grid_size^2, 3]

    goals_t = torch.from_numpy(goals_np).float().to(device)

    # Repeat the start state once per grid goal: [grid_size^2, state_dim].
    start_row = torch.from_numpy(initial_state).float().to(device)[None, :]
    start_batch = torch.tile(start_row, (grid_size**2, 1))
    actor_input = torch.cat([start_batch, goals_t], dim=-1)  # [grid_size^2, state_dim + goal_dim]

    with torch.no_grad():
        if agent.normalize_rl_obs:
            actor_input = agent.normalize_obs(actor_input, env_name)

        num_action_samples = 10
        # One copy of the observations per sampled action:
        # [num_action_samples * grid_size^2, obs_dim].
        critic_input = torch.tile(actor_input, (num_action_samples, 1, 1))
        critic_input = critic_input.view((-1, actor_input.shape[-1]))

        policy_dist = agent.actor(actor_input)
        sampled_actions = policy_dist.rsample((num_action_samples,))  # [num_action_samples, grid_size^2, act_dim]
        sampled_actions = sampled_actions.view((-1, sampled_actions.shape[-1]))

        q_first, q_second = agent.critic(critic_input, sampled_actions)
        q_min = torch.min(q_first, q_second)
        q_min = q_min.view(num_action_samples, -1, q_first.shape[-1])  # [num_action_samples, grid_size^2, 1]
        # Average over the action samples and keep the first output column.
        value = torch.mean(q_min, dim=0).detach().cpu().numpy()[:, 0]

        # Choose which reward-sign setting governs the clipping range.
        reward_kind = agent.rl_reward_type
        if reward_kind == 'd2c':
            reward_sign = agent.d2c_reward_type
        elif reward_kind == 'sparse':
            reward_sign = agent.sparse_reward_type
        else:
            raise NotImplementedError

        if reward_sign == 'positive':
            value = np.clip(value, 0, 1.0/(1.0-agent.discount))
        elif reward_sign == 'negative':
            value = np.clip(value, -1.0/(1.0-agent.discount), 0)

    # The colour scale spans the range of the (possibly clipped) values.
    color_min, color_max = value.min(), value.max()
    value_grid = np.reshape(value, [grid_size, grid_size])

    fig, ax = plt.subplots()
    mesh = ax.pcolormesh(mesh_x, mesh_y, value_grid, cmap='RdBu', vmin=color_min, vmax=color_max)

    if scatter_states.ndim == 1:
        ax.scatter(scatter_states[0], scatter_states[1], marker="*", c='black', s=10, label='Current_position')
    else:
        num_states = scatter_states.shape[0]
        for t in range(num_states):
            # Grayscale level given as a string; earlier states are lighter.
            shade = str(1.-t/num_states)
            ax.scatter(scatter_states[t, 0], scatter_states[t, 1], marker="*", c=shade, s=30, label='s_'+str(t))

    plot_obstacles(env_name=env_name, ax=ax)

    ax.set_title('vf_visualize')
    ax.axis([mesh_x.min(), mesh_x.max(), mesh_y.min(), mesh_y.max()])
    fig.colorbar(mesh, ax=ax)
    ax.axis('tight')
    plt.savefig(savedir_w_name+'.jpg')
    plt.close()
    # Timestamp taken on exit; also unused.
    vis_finished_at = time.time()
    