# Silence every Python warning. Done first so the filter is already active
# while the modules below are being imported.
import warnings
warnings.filterwarnings('ignore')
import os
# Set before torch/numpy are imported further down.
# NOTE(review): presumably works around an MKL threading-layer incompatibility
# error at import time — confirm against the MKL / mkl-service docs.
os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
#os.environ['MUJOCO_GL'] = 'egl'
import torch
import numpy as np
import gym
# 40 corresponds to gym.logger.ERROR, i.e. only errors are reported by gym.
gym.logger.set_level(40)
import time
import random
from pathlib import Path
from cfg import parse_cfg
from env import make_env
from algorithm.tdmpc import TDMPC
from algorithm.helper import VideoEpisode, ReplayBuffer
import logger
# Let cuDNN benchmark its convolution algorithms and reuse the fastest one.
torch.backends.cudnn.benchmark = True
# Directory names resolved against the current working directory: configs are
# read from `cfgs` (see the __main__ guard) and run output is written under
# `logs` (see train()).
__CONFIG__, __LOGS__ = 'cfgs', 'logs'

# Guidance model used by train() to embed the text prompt, encode rendered
# frames, and score collected episodes.
from dreamfusion.guidance.vid_utils import VideoDiffusion


def set_seed(seed):
    """Seed Python's `random`, NumPy, and PyTorch (CPU and all CUDA devices).

    Args:
        seed: Integer seed handed unchanged to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    # Same order as a sequence of direct calls: random, numpy, torch, cuda.
    for apply_seed in seeders:
        apply_seed(seed)


def evaluate(env, agent, num_episodes, step, env_step, video):
    """Roll out the agent in eval mode and return its mean episode reward.

    Args:
        env: Environment exposing ``reset()`` and ``step(action)``; ``step``
            must return a 4-tuple ``(obs, reward, done, info)``.
        agent: Planner exposing ``plan(obs, eval_mode, step, t0)`` whose
            result supports ``.cpu().numpy()``.
        num_episodes: Number of evaluation episodes to run.
        step: Training step forwarded to ``agent.plan``.
        env_step: Value passed to ``video.save`` after each episode.
        video: Recorder with ``init``/``record``/``save``, or a falsy value
            to skip recording. Recording is only enabled for episode 0.

    Returns:
        ``np.nanmean`` of the per-episode summed rewards.
    """
    returns = []
    for episode_number in range(num_episodes):
        obs = env.reset()
        if video:
            video.init(env, enabled=(episode_number == 0))

        episode_return = 0
        done = False
        is_first_step = True  # tells the planner when a new episode starts
        while not done:
            action = agent.plan(obs, eval_mode=True, step=step, t0=is_first_step)
            obs, reward, done, _ = env.step(action.cpu().numpy())
            episode_return += reward
            if video:
                video.record(env)
            is_first_step = False

        returns.append(episode_return)
        if video:
            video.save(env_step)
    return np.nanmean(returns)


def train(cfg):
    """Run the TD-MPC training loop with rewards from video-diffusion guidance.

    Each iteration collects one episode with the planner, encodes one rendered
    frame per step through the guidance model, sets the episode reward to
    ``guidance.get_sds_alignment`` of the prompt embedding and the episode
    latents, adds the episode to the replay buffer, updates the agent, and
    logs. The agent is evaluated whenever ``env_step`` is a multiple of
    ``cfg.eval_freq``. Requires a CUDA-enabled device.

    Args:
        cfg: Parsed experiment config. Fields read directly here: ``seed``,
            ``task``, ``modality``, ``exp_name``, ``text_prompt``,
            ``train_steps``, ``episode_length``, ``seed_steps``,
            ``action_repeat``, ``eval_freq`` and ``eval_episodes``. It is
            also passed on to the env, agent, buffer, episode and logger.

    Raises:
        AssertionError: If CUDA is unavailable, or a collected episode's
            length differs from ``cfg.episode_length``.
        ValueError: If ``cfg.task`` contains neither ``-`` nor ``_`` (the
            domain/task split then yields a single piece).
    """
    assert torch.cuda.is_available()
    set_seed(cfg.seed)

    logs_root = Path().cwd() / __LOGS__
    work_dir = logs_root / cfg.task / cfg.modality / cfg.exp_name / str(cfg.seed)

    # Construction order is kept as env -> agent -> buffer.
    env = make_env(cfg)
    agent = TDMPC(cfg)
    buffer = ReplayBuffer(cfg)

    # Pixel rendering settings for the frames fed to the guidance model.
    domain, _task = cfg.task.replace('-', '_').split('_', 1)
    # TODO: make sure to select the camera appropriately.
    # 0 for humanoid+dog front, 1 for dog side.
    camera_by_domain = {'quadruped': 2}
    # For video diffusion, max resolution is 256.
    resolution_by_domain = {'dog': 256}
    side = resolution_by_domain.get(domain, 256)
    render_kwargs = {
        'height': side,
        'width': side,
        'camera_id': camera_by_domain.get(domain, 0),
    }

    # Guidance (SDS) model and the text-prompt embedding it is scored against.
    device = torch.device('cuda')
    guidance = VideoDiffusion(device, t_range=[0.02, 0.98])
    c_in = guidance.get_text_embeds([cfg.text_prompt])

    # Disabled alternative conditioning, kept for reference:
    #conditional_text = guidance.get_text_embeds([cfg.text_prompt])
    #unconditional_text = guidance.get_text_embeds([""])
    #baseline_text = guidance.get_text_embeds(["a humanoid"])
    #c_in = torch.cat([unconditional_text, conditional_text, baseline_text])
    #del guidance.tokenizer
    #del guidance.text_encoder

    run_log = logger.Logger(work_dir, cfg)
    episode_idx = 0
    start_time = time.time()
    # One extra episode_length so the final step value equals cfg.train_steps
    # (when train_steps is a multiple of episode_length).
    last_step_exclusive = cfg.train_steps + cfg.episode_length
    for step in range(0, last_step_exclusive, cfg.episode_length):
        # --- Collect one trajectory ---
        obs = env.reset()
        episode = VideoEpisode(cfg, obs)
        while not episode.done:
            action = agent.plan(obs, step=step, t0=episode.first)
            obs, gt_reward, done, _ = env.step(action.cpu().numpy())
            # Add a leading batch axis, then move the last axis to position 1
            # (presumably HWC -> NCHW; confirm against env.render's output).
            frame = env.render(**render_kwargs).copy()
            pixels = torch.Tensor(frame[np.newaxis, ...])
            rendered = pixels.permute(0, 3, 1, 2).half().to(device)
            latent = guidance.encode_imgs(rendered)
            episode += (obs, action, latent, gt_reward, done)
        assert len(episode) == cfg.episode_length

        # Hyperparameters can be set here; a stable source noise can
        # optionally be passed in as well.
        episode.reward = guidance.get_sds_alignment(c_in, episode.latent)
        episode.cumulative_reward = episode.reward.sum().cpu().numpy()
        buffer += episode

        # --- Update the model ---
        train_metrics = {}
        if step >= cfg.seed_steps:
            # The first update round runs seed_steps updates; later rounds run
            # one update per collected environment step.
            if step == cfg.seed_steps:
                num_updates = cfg.seed_steps
            else:
                num_updates = cfg.episode_length
            for offset in range(num_updates):
                train_metrics.update(agent.update(buffer, step + offset))

        # --- Log the training episode ---
        episode_idx += 1
        env_step = int(step * cfg.action_repeat)
        common_metrics = dict(
            episode=episode_idx,
            step=step,
            env_step=env_step,
            total_time=time.time() - start_time,
            episode_reward=episode.cumulative_reward,
            episode_gt_reward=episode.cumulative_gt_reward,
        )
        train_metrics.update(common_metrics)
        run_log.log(train_metrics, category='train', agent=agent)

        # --- Periodic evaluation ---
        if env_step % cfg.eval_freq == 0:
            # NOTE(review): only 'episode_reward' is overwritten, so the eval
            # record still carries the training episode's 'episode_gt_reward'.
            common_metrics['episode_reward'] = evaluate(
                env, agent, cfg.eval_episodes, step, env_step, run_log.video)
            run_log.log(common_metrics, category='eval', agent=agent)

    run_log.finish(agent)
    print('Training completed successfully')


if __name__ == '__main__':
    # Configs are loaded from the `cfgs` directory under the working directory.
    config_dir = Path().cwd() / __CONFIG__
    train(parse_cfg(config_dir))
