from typing import Optional
from os.path import join, exists
import numpy as np
import sys
import gtimer as gt
import glob
import cv2
import torch
from logging import info, error, basicConfig, INFO
from tqdm import tqdm
from abstract import Arrayable
from envs.env import Env
from agents.agent import Agent
from interact import interact
from utils import compute_return
from mylog import logto, log_video
from parse import setup_args
from config import configure


def infer(dt: float, epoch: int, env: Env, agent: Agent,  # noqa: C901
          time_limit: Optional[float] = None,
          progress_bar: bool = False, video: bool = False
          ) -> Optional[float]:
    """Roll out ``agent`` in ``env`` for a fixed physical time and return the return.

    The agent is put in eval mode and reset, the environment is reset once, and
    the policy is then stepped ``int(time_limit / dt)`` times through ``interact``.
    Rewards and done flags are stacked along axis 0 and scored by ``compute_return``.

    Args:
        dt: physical duration of one environment step.
        epoch: epoch of the loaded policy; used in log messages and the video name.
        env: environment to run the policy in.
        agent: policy to evaluate.
        time_limit: physical time to simulate; a falsy value (None or 0) means 10.
        progress_bar: wrap the step loop in a ``tqdm`` progress bar.
        video: if True, collect ``env.render()`` after every step and log the
            stacked frames with ``log_video("infer", epoch, ...)``.

    Returns:
        The value produced by ``compute_return`` (presumably one return per env
        when ``env`` is vectorized -- confirm), or None when the rollout has zero
        steps (``time_limit / dt < 1``).
    """
    agent.eval()
    agent.reset()

    rewards, dones = [], []
    imgs = []
    time_limit = time_limit if time_limit else 10
    nb_steps = int(time_limit / dt)
    # e.g. for pendulum: 200 steps * dt 0.05 = 10 seconds of physical time
    info(f"infer> run the trained policy from epoch {epoch} on a physical time {time_limit}"
         f" ({nb_steps} steps in total)")
    obs = env.reset()
    iter_range = tqdm(range(nb_steps)) if progress_bar else range(nb_steps)
    for _ in iter_range:
        obs, reward, done = interact(env, agent, obs)
        rewards.append(reward)
        dones.append(done)
        if video:
            imgs.append(env.render())
    gt.stamp('env interaction')

    if not rewards:
        # np.stack raises on an empty list; there is nothing to score.
        error(f"infer> time_limit {time_limit} with dt {dt} gives 0 steps; nothing to evaluate")
        return None

    R = compute_return(np.stack(rewards, axis=0),
                       np.stack(dones, axis=0))
    info(f"infer> return: {R}")
    info(f"infer> return scaled to physical time {time_limit}: {R*dt}")
    if video:
        log_video("infer", epoch, np.stack(imgs, axis=0))
        gt.stamp('logging video')
    return R
    

def write_to_video(logdir, fps, max_frames=2000):
    """Convert every ``<logdir>/videos/*.npy`` frame array into an ``.mp4`` file.

    Each file is loaded with ``np.load`` and unpacked as ``(nb_frames, H, W, C)``.
    The channel axis is reversed before writing (RGB -> BGR, the order OpenCV
    writes). The output goes to ``<logdir>/videos/output_<name>.mp4``.

    Args:
        logdir: run directory containing a ``videos`` sub-folder.
        fps: rate the frames were recorded at (frames per physical second).
        max_frames: at most this many frames (after fps subsampling) are written
            per file; later frames are dropped.
    """
    from math import ceil
    from os.path import basename, splitext

    info(f"fps of the video: {fps}")
    # Faster fps causes issues in the video (especially gif), so the output
    # rate is capped at fps_max. Frames are subsampled by an integer stride
    # chosen so that the output rate stays <= fps_max. The output fps is
    # fps / stride, so the video still plays back in real time. This is computed
    # once, before the loop, so every file gets the same treatment.
    fps_max = 50
    fps_ratio = max(1, ceil(fps / fps_max))
    out_fps = fps / fps_ratio
    if fps_ratio > 1:
        info(f"fps has been changed into {out_fps:g}")

    video_dir = join(logdir, 'videos')
    for path in sorted(glob.glob(join(video_dir, '*.npy'))):
        frames = np.load(path)
        gt.stamp(f'loaded the data from {path}')
        # splitext only strips the final extension, so names with dots survive.
        fname = splitext(basename(path))[0]
        _, H, W, _ = frames.shape
        frames = frames[::fps_ratio][:max_frames]
        out = cv2.VideoWriter(join(video_dir, f"output_{fname}.mp4"),
                              cv2.VideoWriter_fourcc(*'mp4v'), out_fps, (W, H))
        try:
            for frame in frames:
                out.write(frame[:, :, ::-1])  # RGB -> BGR
        finally:
            # Always finalize the container, even if a write fails.
            out.release()
        gt.stamp(f'writing to {path}')


def main(args, write_video=False):
    """Load a trained policy and run it in the evaluation env (vec).

    Builds the agent and eval env with ``configure(args)`` and moves the agent to
    CUDA when available. Weights are loaded from ``args.agent_ckpt``, or from
    ``<logdir>/best_agent.pt`` when that is empty or None. The policy is then
    rolled out once with ``infer``, and the recorded frames can optionally be
    converted to ``.mp4`` files.

    Args:
        args: parsed arguments; reads ``logdir``, ``dt``, ``time_limit``,
            ``agent_ckpt`` and ``render_mode``.
        write_video: if True, run ``write_to_video`` on ``logdir`` after the rollout.

    Returns:
        Whatever ``infer`` returned for the rollout.
    """
    logdir = args.logdir
    dt = args.dt
    time_limit = args.time_limit

    # device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # init env
    agent, _, eval_env = configure(args)
    agent = agent.to(device)

    try:
        # NOTE(review): infer() resets the env itself. This extra reset is kept
        # so the env/RNG state matches earlier runs. obs shape [nb_envs, dim_obs]
        eval_env.reset()

        # load checkpoints if agent_ckpt is set, otherwise from logdir
        agent_file = args.agent_ckpt if args.agent_ckpt else join(logdir, 'best_agent.pt')
        cur_e = 0
        if exists(agent_file):
            # map_location lets a checkpoint saved on GPU load on a CPU-only host.
            state_dict = torch.load(agent_file, map_location=device)
            R = state_dict["return"]
            cur_e = state_dict["epoch"]
            info(f"infer> Loading agent {agent_file} with return {R}/scaled return {R*dt} at epoch {cur_e}...")
            agent.load_state_dict(state_dict)
        else:
            # Deliberately best-effort: the rollout still runs with the untrained policy.
            error(f"infer> cannot load policy from {agent_file}")

        R_infer = infer(
            dt,
            cur_e,
            eval_env,
            agent,
            time_limit,
            progress_bar=True,
            video=args.render_mode == 'rgb_array',
        )
    finally:
        eval_env.close()

    if write_video:
        # round() avoids int() truncating e.g. 19.999... to 19; at least 1 fps.
        write_to_video(logdir, max(1, round(1 / dt)), max_frames=5000)
    info(gt.report())
    return R_infer

if __name__ == '__main__':
    # Manual switch: set to True to also turn the saved .npy frames into .mp4
    # files (and reload the logging destination for this run folder).
    write_video = False

    # Step 1: load the args (from the run folder, if any) and log to stdout.
    args = setup_args()
    basicConfig(stream=sys.stdout, level=INFO)
    if write_video:
        logto(args.logdir, reload=True)

    # Step 2: load the policy, init the env, run the policy,
    # compute the return and (optionally) record the video.
    info(f"the args are {args}")
    main(args, write_video=write_video)
