import warnings
warnings.filterwarnings('ignore')  # silence all Python warnings for the whole run
import os
# These environment variables are set before torch/gym/env are imported below,
# presumably so those libraries see them when they initialize.
os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
os.environ['MUJOCO_GL'] = 'egl'  # presumably selects headless EGL rendering for MuJoCo — confirm
import torch
import numpy as np
import gym
gym.logger.set_level(40)  # 40 == ERROR level: only gym errors are printed
import time
import random
from pathlib import Path
from cfg import parse_cfg
from env import make_env
from algorithm.tdmpc import TDMPC
from algorithm.helper import OverlappingEpisode, Episode, VideoEpisode, ReplayBuffer
import logger
torch.backends.cudnn.benchmark = True  # let cuDNN autotune conv algorithms
# Config directory (resolved against the cwd in __main__) and log root used by train().
__CONFIG__, __LOGS__ = 'cfgs', 'logs'

from animate_utils import VideoAnimateLCM, VideoAnimateDiffusion


def set_seed(seed):
    """Seed Python's ``random``, NumPy, PyTorch (CPU) and every CUDA device.

    Args:
        seed: Integer seed applied identically to each generator.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)


def evaluate(env, agent, num_episodes, step, env_step, video):
    """Roll out the agent in evaluation mode and optionally record a video.

    When ``video`` is truthy, only the first episode is recorded
    (``enabled`` is False for the rest), but ``video.save`` is still called
    after every episode.

    Args:
        env: Environment with ``reset()`` and ``step(action)`` returning
            ``(obs, reward, done, info)``.
        agent: Planner whose ``plan(obs, eval_mode=..., step=..., t0=...)``
            result supports ``.cpu().numpy()``.
        num_episodes: Number of evaluation episodes to run.
        step: Training step forwarded to ``agent.plan``.
        env_step: Environment step forwarded to ``video.save``.
        video: Optional recorder; any falsy value disables recording.

    Returns:
        ``(mean_reward, success_rate)``, both computed with ``np.nanmean``.
        Success is read from the final step's ``info['success']`` (default 0).
    """
    rewards, successes = [], []
    for episode in range(num_episodes):
        obs = env.reset()
        if video:
            video.init(env, enabled=(episode == 0))
        total, t, done = 0, 0, False
        while not done:
            is_first_step = t == 0
            action = agent.plan(obs, eval_mode=True, step=step, t0=is_first_step)
            obs, reward, done, info = env.step(action.cpu().numpy())
            total += reward
            if video:
                video.record(env)
            t += 1
        rewards.append(total)
        successes.append(int(info.get('success', 0)))
        if video:
            video.save(env_step)
    return np.nanmean(rewards), np.nanmean(successes)


def train(cfg):
    """Train TD-MPC with rewards from a video-diffusion alignment model.

    Requires a CUDA-enabled device.

    Per-step reward:
    1. Render the current frame and encode it with ``guidance.encode_imgs``.
    2. Push the latent into a rolling window of the last ``context_size``
       latents.
    3. Score the window with ``guidance.get_loop_sds_alignment``.
    4. Add the environment's ``success`` flag for each filled window slot.

    The environment's own reward (``gt_reward``) is stored alongside for
    logging.

    Args:
        cfg: Parsed config (from ``parse_cfg``). Fields read here: ``seed``,
            ``task``, ``modality``, ``exp_name``, ``noise_level``,
            ``alignment_scale``, ``recon_scale``, ``text_prompt``,
            ``train_steps``, ``episode_length``, ``seed_steps``,
            ``action_repeat``, ``eval_freq`` and ``eval_episodes``.

    Raises:
        RuntimeError: If no CUDA device is available.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not torch.cuda.is_available():
        raise RuntimeError('Training requires a CUDA-enabled device.')
    set_seed(cfg.seed)
    work_dir = Path().cwd() / __LOGS__ / cfg.task / cfg.modality / cfg.exp_name / str(cfg.seed)
    env, agent, buffer = make_env(cfg), TDMPC(cfg), ReplayBuffer(cfg)

    # Rendering settings for the frames scored by the diffusion model.
    domain, task = cfg.task.replace('-', '_').split('_', 1)
    # TODO: make sure to select the camera appropriately. Currently quadruped
    # renders from camera 2 and every other domain from camera 0.
    camera_id = dict(quadruped=2).get(domain, 0)
    # NOTE(review): an earlier comment said the video diffusion model's max
    # resolution is 256, yet frames are rendered at 512 (dog) or 480 (others).
    # Presumably guidance.encode_imgs resizes to its own input size — confirm.
    dim = dict(dog=512).get(domain, 480)
    render_kwargs = dict(height=dim, width=dim, camera_id=camera_id)

    # Set up the SDS guidance model.
    device = torch.device('cuda')
    guidance = VideoAnimateDiffusion(device)

    noise_level = cfg.noise_level
    align_scale = cfg.alignment_scale
    recon_scale = cfg.recon_scale

    # The text conditioning is fixed for the whole run, so embed it once.
    negative_prompts = "bad quality, worse quality"
    c_in = guidance.get_text_embeds(cfg.text_prompt, negative_prompts)

    # Run training.
    context_size = 4  # number of most recent frames scored together
    L = logger.Logger(work_dir, cfg)
    episode_idx, start_time = 0, time.time()
    for step in range(0, cfg.train_steps+cfg.episode_length, cfg.episode_length):
        # Collect a trajectory.
        obs = env.reset()
        episode = OverlappingEpisode(cfg, obs, context_size)
        # Rolling windows of the last `context_size` latents and success flags.
        # They start all-zero; zero-latent slots are treated as not yet filled below.
        latent_history = torch.zeros([context_size, 4, 64, 64]).half().to(device)
        success_history = torch.zeros([context_size]).half().to(device)
        # One diffusion timestep per episode, drawn from [noise_level, noise_level + 100).
        timestep = torch.randint(noise_level, noise_level + 100, [1], dtype=torch.long, device=device)
        while not episode.done:
            # Source noise shaped like the latent window, resampled at every step.
            source_noise = torch.randn_like(latent_history)
            action = agent.plan(obs, step=step, t0=episode.first)
            obs, gt_reward, done, info = env.step(action.cpu().numpy())
            # Rendered (H, W, C) frame -> (1, C, H, W) half tensor on the GPU.
            # Dividing by 255 assumes pixel values in [0, 255] — TODO confirm.
            rendered = torch.Tensor(env.render(**render_kwargs).copy()[np.newaxis, ...]).permute(0,3,1,2).half().to(device)
            latent = guidance.encode_imgs(rendered / 255.)

            # Shift both windows left by one and write the newest entry last.
            latent_history = latent_history.roll(-1, 0)
            latent_history[-1] = latent
            success_history = success_history.roll(-1, 0)
            # Same default as evaluate(): envs without a 'success' key count as 0
            # rather than raising KeyError.
            success_history[-1] = info.get('success', 0)

            frame_rewards = guidance.get_loop_sds_alignment(c_in, latent_history, timestep, alignment_scale=align_scale, recon_scale=recon_scale, noise=source_noise)
            # Keep only window slots whose latent is non-zero, i.e. already filled
            # this episode. A real latent summing to exactly 0 would also be dropped.
            # Assumes frame_rewards has one entry per window slot — confirm.
            idxs = torch.where(latent_history.sum([1,2,3]) != 0)
            # NOTE(review): this adds the env-provided success flag to the
            # model-based reward; confirm that mixing the two is intended.
            reward = frame_rewards[idxs] + success_history[idxs]

            episode += (obs, action, reward, gt_reward, done)
        assert len(episode) == cfg.episode_length
        buffer += episode

        # Update the model.
        train_metrics = {}
        if step >= cfg.seed_steps:
            num_updates = cfg.seed_steps if step == cfg.seed_steps else cfg.episode_length
            for i in range(num_updates):
                train_metrics.update(agent.update(buffer, step+i))

        # Log the training episode.
        episode_idx += 1
        env_step = int(step*cfg.action_repeat)
        common_metrics = {
            'episode': episode_idx,
            'step': step,
            'env_step': env_step,
            'total_time': time.time() - start_time,
            'episode_reward': episode.cumulative_reward,
            'episode_gt_reward': episode.cumulative_gt_reward}
        train_metrics.update(common_metrics)
        L.log(train_metrics, category='train', agent=agent)

        # Evaluate the agent periodically. Evaluation scores env rewards, so
        # the reported gt reward equals the episode reward.
        if env_step % cfg.eval_freq == 0:
            common_metrics['episode_reward'], common_metrics['episode_success'] = evaluate(env, agent, cfg.eval_episodes, step, env_step, L.video)
            common_metrics['episode_gt_reward'] = common_metrics['episode_reward']
            L.log(common_metrics, category='eval', agent=agent)

    L.finish(agent)
    print('Training completed successfully')


if __name__ == '__main__':
    # Parse the config directory named by __CONFIG__ (relative to the cwd), then train.
    config_root = Path.cwd() / __CONFIG__
    train(parse_cfg(config_root))
