from pathlib import Path
import cv2
import pickle
import pybullet
import numpy as np
import time
import matplotlib.pyplot as plt
import torchrl.networks as networks
import torchrl.policies as policies
from gym.wrappers.monitoring.video_recorder import VideoRecorder
import gym.wrappers as wrappers
import gym
import torch.nn.functional as F
from policydissect.quadrupedal.torchrl.utils import get_params
from policydissect.quadrupedal.vision4leg.get_env import get_env
import seaborn as sns
import random
import argparse
import torch
import os
import sys
sys.path.append(".")


def get_args():
    parser = argparse.ArgumentParser(description='RL')
    parser.add_argument('--seed', type=int, default=0, help='random seed (default: 1)')
    parser.add_argument('--env_seed', type=int, default=0, help='random seed (default: 1)')
    parser.add_argument('--env_name', type=str, default='A1MoveGround')
    parser.add_argument('--num_episodes', type=int, default=1, help='number of episodes')
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="config file",
    )
    parser.add_argument(
        '--save_dir', type=str, default='./snapshots', help='directory for snapshots (default: ./snapshots)'
    )
    parser.add_argument('--video_dir', type=str, default='./video')
    parser.add_argument('--log_dir', type=str, default='./log', help='directory for tensorboard logs (default: ./log)')
    parser.add_argument('--no_cuda', action='store_true', default=False, help='disables CUDA training')
    parser.add_argument(
        "--device",
        type=int,
        default=0,
        help="gpu secification",
    )
    parser.add_argument('--add_tag', type=str, default='', help='directory for snapshots (default: ./snapshots)')
    parser.add_argument('--snap_check', type=str, default='best')

    # tensorboard
    parser.add_argument(
        "--id",
        type=str,
        default=None,
        help="id for tensorboard",
    )

    args = parser.parse_args()

    args.cuda = not args.no_cuda and torch.cuda.is_available()

    return args


args = get_args()

# params = get_params(args.config)

# from metaworld_utils.meta_env import get_meta_env

np.random.seed(0)
random.seed(0)

PARAM_PATH = "{}\{}\{}\{}\params.json".format(args.log_dir, args.id, args.env_name, args.seed)
# params = get_params(args.config)
params = get_params(PARAM_PATH)

# params["env"]["env_build"]["enable_rendering"] = False
params["env"]["env_build"]["enable_rendering"] = True
env = get_env(params['env_name'], params['env'])

if hasattr(env, "_obs_normalizer"):

    NORM_PATH = "{}/{}/{}/{}/model/_obs_normalizer_{}.pkl".format(
        args.log_dir, args.id, params['env_name'], args.seed, args.snap_check
    )
    with open(NORM_PATH, 'rb') as f:
        env._obs_normalizer = pickle.load(f)
        print(env._obs_normalizer._mean)
        print(env._obs_normalizer._var)

env.eval()

# params['net']['base_type'] = networks.MLPBase
params['net']['activation_func'] = torch.nn.ReLU
# params['net']['activation_func'] = torch.nn.ReLU

obs_normalizer = env._obs_normalizer if hasattr(env, "_obs_normalizer") \
  else None

encoder = networks.LocoTransformerEncoder(
    in_channels=env.image_channels, state_input_dim=env.observation_space.shape[0], **params["encoder"]
)

pf = policies.GaussianContPolicyLocoTransformer(
    encoder=encoder,
    state_input_shape=env.observation_space.shape[0],
    visual_input_shape=(env.image_channels, 64, 64),
    output_shape=env.action_space.shape[0],
    **params["net"],
    **params["policy"]
)

PATH = "{}/{}/{}/{}/model/model_pf_{}.pth".format(args.log_dir, args.id, params['env_name'], args.seed, args.snap_check)

# PATH="log/MT10_MODULAR_4_4_2_NO_BN_GATED_REWEIGHT_RAND_CAS_PRESOFTMAX_ACTI_128/mt10/3/model/model_pf_best.pth"
# Value_PATH="log/MT10_MODULAR_4_4_2_NO_BN_GATED_REWEIGHT_RAND_CAS_PRESOFTMAX_ACTI_128/mt10/3/model/model_qf1_best.pth"
current_pf_dict = pf.state_dict()
current_pf_dict.update(torch.load(PATH, map_location="cuda:0"))
pf.load_state_dict(torch.load(PATH, map_location="cuda:0"))

pf.eval()
# pf.load_state_dict(torch.load(PATH, map_location="cuda:0"))
# qf1.load_state_dict(torch.load(Value_PATH, map_location="cuda:0"))

action_weights = []

num_episodes = args.num_episodes

success = 0
count = 0
total_success = 0

# current_path = os.path.dirname(__file__)
# xml_root_path = os.path.join(current_path,"../torchmetal/envs/updated")
# from torchmetarl.envs.mujoco_dict import SETTING_DICT, ENV_DICT, TIMELIMIT_DICT
# def make_env(env_name, env_setting, env_build_params):
#     file_path = os.path.join(
#         os.path.join(xml_root_path, env_name),
#         SETTING_DICT[env_setting])
#     env = ENV_DICT[env_name](file_path, **env_build_params)
#     return env
# xml_root_path = "/"
# os.path.join(current_path,"updated")
# MODEL_CONDITION = [
#     [0, 0],
#     [0, 1],
#     [0, 2],
#     [1, 0],
#     [1, 1],
#     [1, 2],
#     [2, 0],
#     [2, 1],
#     [2, 2],
# ]

# REWARD_COEFF = [[np.cos(np.pi / 14 * i), np.sin(np.pi / 14 * i)] for i in range(8)]

rewards = []
task_names = []

video_output_path = "{}/{}/{}/{}".format(args.video_dir, args.id, params['env_name'], args.seed)

Path(video_output_path).mkdir(parents=True, exist_ok=True)

total_reward = []
total_moving_distance = []

epi_reward = 0
epi_moving = 0

# for idx in range(1):
for _ in range(5):

    # last_weights = []
    # general_weights = []
    # cat_weights = []
    # record_id = env.pybullet_client.startStateLogging(
    #     pybullet.STATE_LOGGING_VIDEO_MP4,
    #     "./record.mp4"
    # )
    import time
    t = time.time()
    count = 0
    morpho_action_weights = []

    env.seed(args.env_seed)
    random.seed(args.env_seed)
    torch.manual_seed(args.env_seed)
    np.random.seed(args.env_seed)
    # import pybullet as p
    # log_id = p.startStateLogging(p.STATE_LOGGING_VIDEO_MP4, "racecar.mp4")
    # env = wrappers.Monitor(env_single, "view_log", video_callable=None, force=True)

    from vidgear.gears import WriteGear
    output_params = {"-vcodec": "libx264", "-crf": 0, "-preset": "fast"}
    # writer = WriteGear(
    #   output_filename=os.path.join(
    #     video_output_path, 'Output_{}{}.mp4'.format(args.snap_check, args.add_tag)),
    #   logging=True, **output_params
    # )
    # out = cv2.VideoWriter('project.mp4',cv2.VideoWriter_fourcc(*'mp4v'), 20.00, (360, 480))
    # obs = env_single.reset()
    reward = 0
    # x_sped = 0
    # y_sped = 0
    step = 0

    # env = get_env(
    #     "LaikagoMoveForward",
    #     "default",
    #     {
    #         "reward_scale": 1,
    #         "obs_norm": False,
    #         "env_build": {
    #             "enable_rendering_gui": True
    #         }
    #     }
    # )
    env.seed(0)
    random.seed(0)
    np.random.seed(0)
    obs = env.reset()
    # env.pybullet_client.configureDebugVisualizer(
    #   env.pybullet_client.COV_ENABLE_RENDERING, 1)
    # success = 0
    while True:
        # env_single.render()
        frame_1 = obs[-64 * 64:]
        frame_1 = frame_1.reshape((64, 64))

        # import matplotlib.pyplot as plt
        # plt.imshow(frame_1, cmap='gray')
        # plt.pause(0.01)
        ob_t = torch.Tensor(obs).unsqueeze(0)
        # print(ob_t)
        # ob_t[0, -4:] =
        # ob_t[0, -4:] = 1
        # ob_t[0, -3] = 0
        # ob_t[0, -1] = 0
        # embedding_input = np.zeros(env_info.num_tasks)
        # embedding_input = torch.zeros(env_to_wrap.num_tasks)
        # embedding_input[idx] = 1
        # embedding_input = torch.cat([torch.Tensor(env_single.goal.copy()), embedding_input])
        # embedding_input = embedding_input.unsqueeze(0)
        # print(ob_t)
        # action, general_weight, last_weight = pf.eval_act(ob_t, embedding_input, return_weights = True )
        action = pf.eval_act(ob_t)
        # action = pf.explore(ob_t)
        # action = action["action"].squeeze(0).cpu().numpy()
        # print(action.shape)
        # print(action)
        count += 1
        # action = np.array([0,0,0,0,0,0])
        # action = -action
        # action = np.array([1,1,1,1,1,1])
        # print(action)
        #  = weights
        # last_weight = F.softmax(last_weight, dim=-1)
        # # last_weight = last_weight.exp()
        # general_weight = F.softmax(general_weight, dim = -1)
        # print(general_weight[0].shape)

        morpho_action_weights.append(action)
        import time
        # time.sleep(0.3)

        obs, rew, done, info = env.step(action)
        epi_reward += rew
        pos = env.env.env.env.env.env.env._gym_env._robot.GetBasePosition()
        epi_moving = pos[0]
        # img = env.render(mode='rgb_array')
        #
        # sim_model = env.robot.quadruped
        # pyb = env.pybullet_client
        # root_vel_sim, root_ang_vel_sim = pyb.getBaseVelocity(sim_model)
        #
        # # import cv2
        # # img = cv2.putText(
        # #     img=np.copy(img),
        # #     text="Speed X: {:.4f}".format(root_vel_sim[0]),
        # #     org=(10, 30),
        # #     fontFace=1, fontScale=1, color=(0, 0, 255), thickness=1)
        #
        # # import cv2
        # # img = cv2.putText(
        # #     img=np.copy(img),
        # #     text="Speed Y: {:.4f}".format(root_vel_sim[1]),
        # #     org=(10, 50),
        # #     fontFace=1, fontScale=1, color=(0, 0, 255), thickness=1)
        #
        # # import cv2
        # # img = cv2.putText(
        # #     img=np.copy(img),
        # #     text="Speed Z: {:.4f}".format(root_vel_sim[2]),
        # #     org=(10, 70),
        # #     fontFace=1, fontScale=1, color=(0, 0, 255), thickness=1)
        #
        # # rot_quat = env.robot.GetBaseOrientation()
        # # import cv2
        # # img = cv2.putText(
        # #     img=np.copy(img),
        # #     text="Rot Quat 1: {:.4f}".format(rot_quat[0]),
        # #     org=(200, 30),
        # #     fontFace=1, fontScale=1, color=(0, 0, 255), thickness=1)
        #
        # # import cv2
        # # img = cv2.putText(
        # #     img=np.copy(img),
        # #     text="Rot Quat 1: {:.4f}".format(rot_quat[1]),
        # #     org=(200, 50),
        # #     fontFace=1, fontScale=1, color=(0, 0, 255), thickness=1)
        #
        # # import cv2
        # # img = cv2.putText(
        # #     img=np.copy(img),
        # #     text="Rot Quat 2: {:.4f}".format(rot_quat[2]),
        # #     org=(200, 70),
        # #     fontFace=1, fontScale=1, color=(0, 0, 255), thickness=1)
        #
        # # import cv2
        # # img = cv2.putText(
        # #     img=np.copy(img),
        # #     text="Rot Quat 3: {:.4f}".format(rot_quat[3]),
        # #     org=(200, 90),
        # #     fontFace=1, fontScale=1, color=(0, 0, 255), thickness=1)
        # # plt.text(
        # #     x=10,
        # #     y=10,
        # #     s="Speed: {}".format(root_vel_sim)
        # # )
        # # print(img.shape)
        # # import matplotlib.pyplot as plt
        # # plt.imshow(img)
        # # plt.pause(0.01)
        # # out.write(img)
        # # print(img.shape)
        # w, h, _ = img.shape
        # frame_1 = (frame_1 - np.min(frame_1)) / \
        #   (np.max(frame_1) - np.min(frame_1))
        # frame_resize = cv2.resize(frame_1, (w, w))
        # frame_resize = (frame_resize * 255).astype(img.dtype)
        # frame_resize = frame_resize.reshape((w, w, 1))
        # frame_resize = np.repeat(frame_resize, 3, axis=2)
        # # print(frame_resize)
        # # import matplotlib.pyplot as plt
        # # # plt.imshow(frame_resize)
        # # # plt.imshow(img)
        # # plt.imshow(np.concatenate([img, frame_resize], axis=1))
        # # plt.pause(0.01)
        #
        # # print(np.vstack([img, frame_resize]).shape)
        # writer.write(frame_resize, rgb_mode=True)
        # writer.write(np.concatenate(
        #   [img, frame_resize], axis=1), rgb_mode=True)
        # print(info["success"])
        # success = max(success, info["success"])
        # env.render()
        # if info["success"] and not done:
        # print("FXXk")
        # exit()
        reward += rew
        step += 1
        # x_sped += info["x_velocity"]
        # y_sped += info["y_velocity"]
        if done:
            total_reward.append(epi_reward)
            total_moving_distance.append(epi_moving)

            with open("{}.pkl".format("eval_result"), "wb+") as file:
                pickle.dump(dict(reward=total_reward, move_distance=total_moving_distance), file)
            # total_success += success
            # count += 1
            # print("EPoch:", idx, task_name, success, total_success / count)
            break
            # if not info["success"]:
            #     print("CAo")
            #     exit()
            # obs = env.reset()
            # idx += 1
    #         obs = env_single.reset()
    #         success = 0
    #         t = []
    #         t_now = 0
    # # m = [sin(t_now)]
    #         m = []
    # if len(info["rewards"]) > num_episodes:
    #     if len(info["rewards"]) == 1 and video_recorder.enabled:
    #         # save video of first episode
    #         print("Saved video.")
    #     print(info["rewards"][-1])
    #     num_episodes = le
    print("finish eposide")
    fps = count / (time.time() - t)
    print("Reward:", reward / num_episodes, "FPS:", fps)
    print("Step Counts:", step)

    # out.release()
    # writer.close()
    # print("X Speed:", REWARD_COEFF[idx], x_sped / num_episodes)
    # print("Y Speed:", REWARD_COEFF[idx], y_sped / num_episodes)
    rewards.append(reward)
    # env.pybullet_client.stopStateLogging(log_id)
    # p.stopStateLogging(log_id)

    # exit()
