import argparse
import pickle
import pprint
import os
import copy
from glob import glob

import numpy as np
import torch
import torchvision
from PIL import Image
from sklearn.cluster import KMeans
from sklearn.preprocessing import MinMaxScaler
from yellowbrick.cluster import KElbowVisualizer
import matplotlib.pyplot as plt

def _load_trajectory(path):
    """Load one pickled trajectory and regroup it into per-step lists.

    The pickle is expected to be a dict with the keys ``observations``
    (itself holding ``sensor`` and ``image``), ``actions``, ``rewards``,
    ``terminals`` and ``infos``. The number of steps is taken from
    ``observations["sensor"]``.

    Args:
        path: Path of the ``.pkl`` trajectory file.

    Returns:
        A dict with the keys ``frame``, ``action``, ``reward``, ``done`` and
        ``info``, each a list with one entry per step. The result is a deep
        copy, so it shares no memory with the unpickled object.
    """
    # NOTE(security): pickle.load can execute arbitrary code; only feed this
    # script trajectory files that come from a trusted source.
    with open(path, 'rb') as f:
        data = pickle.load(f)

    observations = data["observations"]
    steps = range(len(observations["sensor"]))
    trajectory = {
        "frame": [observations["image"][i] for i in steps],
        "action": [data["actions"][i] for i in steps],
        "reward": [data["rewards"][i] for i in steps],
        "done": [data["terminals"][i] for i in steps],
        "info": [data["infos"][i] for i in steps],
    }
    return copy.deepcopy(trajectory)


def main(args):
    """Merge expert trajectories of one factor and discretise their actions.

    Every ``*.pkl`` file under ``<root>/<factor>`` is loaded, all actions are
    min-max scaled and clustered with KMeans, and each step's action is
    replaced by its cluster label (stored as a length-1 integer array). The
    merged dataset is pickled to ``<out_dir>/<factor>/<file_name>`` together
    with an elbow plot (``recommand.png``) and a 3D scatter plot of the
    clusters (``cluster.png``).

    Args:
        args: Namespace with the attributes ``root``, ``factor``, ``K``
            (number of clusters), ``out_dir`` and ``file_name``.

    Raises:
        FileNotFoundError: If no trajectory file is found, or none of the
            file names contains ``"base"``.
        ValueError: If the trajectories contain no actions at all.
    """
    source_pattern = os.path.join(args.root, args.factor, '*.pkl')
    # glob returns files in arbitrary order; sort so that the integer keys
    # (and therefore the saved dataset) are reproducible across runs.
    source_files = sorted(glob(source_pattern))
    if not source_files:
        raise FileNotFoundError(
            "no trajectory files match {!r}".format(source_pattern))

    factor_traj = dict(enumerate(source_files))

    # Match on the file name only, so that a directory called e.g. "database"
    # does not turn every trajectory into a base candidate.
    base_candidates = [key for key, path in factor_traj.items()
                       if "base" in os.path.basename(path)]
    if not base_candidates:
        raise FileNotFoundError(
            "no base trajectory (file name containing 'base') matches "
            "{!r}".format(source_pattern))
    # When several files match, the last one wins (as in the original loop).
    base_key = base_candidates[-1]

    out_factor_dir = os.path.join(args.out_dir, args.factor)
    os.makedirs(out_factor_dir, exist_ok=True)

    pprint.pprint(factor_traj)

    # The base trajectory is inserted first, the others follow in key order.
    global_dict = {base_key: _load_trajectory(factor_traj[base_key])}
    for key, path in factor_traj.items():
        if key != base_key:
            global_dict[key] = _load_trajectory(path)

    # action clustering
    actions = [action
               for key in factor_traj
               for action in global_dict[key]["action"]]
    print("total acions:", len(actions))
    if not actions:
        raise ValueError("the loaded trajectories contain no actions")
    actions = np.array(actions)
    actions_scaled = MinMaxScaler().fit_transform(actions)

    # Elbow plot to help choosing K (k=(1, 10) is handed to yellowbrick as is).
    visualizer = KElbowVisualizer(KMeans(), k=(1, 10))
    visualizer.fit(actions_scaled)
    visualizer.show(outpath=os.path.join(out_factor_dir, "recommand.png"))

    n_clusters = args.K  # easy:5 medium: 5, hard: 4

    # A single fit is enough: fit_predict both fits the model and returns the
    # training labels.
    model = KMeans(n_clusters=n_clusters, random_state=10)
    labels = model.fit_predict(actions_scaled)

    # The scatter plot uses the first three action columns; the axis labels
    # suggest they are throttle / steering / brake.
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(actions[:, 0], actions[:, 1], actions[:, 2],
               c=labels, s=10, cmap="rainbow", alpha=1)
    ax.set_xlabel('throtle')
    ax.set_ylabel('steering')
    ax.set_zlabel('brake')
    fig.savefig(os.path.join(out_factor_dir, 'cluster.png'))
    plt.close(fig)

    # action labeling: predict all steps in one call, then hand the labels
    # back to the trajectories in the same order the actions were collected.
    step_labels = model.predict(actions_scaled)
    offset = 0
    for key in factor_traj:
        n_steps = len(global_dict[key]["action"])
        # Keep each label as its own length-1 array, the format previously
        # produced by predicting one sample at a time.
        global_dict[key]["action"] = [
            step_labels[i:i + 1].copy()
            for i in range(offset, offset + n_steps)
        ]
        offset += n_steps

    global_dict["classes"] = ["buildings"]
    global_dict["actions"] = list(range(n_clusters))
    with open(os.path.join(out_factor_dir, args.file_name), 'wb') as f:
        pickle.dump(global_dict, f, pickle.HIGHEST_PROTOCOL)

if __name__ == "__main__":
    # Command-line entry point: one positional trajectory root directory plus
    # optional settings, all forwarded to main() as a single namespace.
    cli = argparse.ArgumentParser(description="Expert Trajectory Preprocessing")

    # dataset parameters
    cli.add_argument("root", metavar="DIR", help="root path of trajectory")

    # (flag, value type, default) for every optional argument.
    option_table = (
        ("--factor", str, "FOV"),
        ("--K", int, 4),
        ("--out-dir", str, "default"),
        ("--file-name", str, "train_dataset.pkl"),
    )
    for flag, value_type, fallback in option_table:
        cli.add_argument(flag, type=value_type, default=fallback)

    args = cli.parse_args()
    main(args)
