import copy
import glob
import numpy as np
import os

import torch
import torch.nn.functional as F
from torch.optim import Adam

from common.buffers import ReplayBuffer
from common.utils import (
    get_device,
    soft_update,
    to_torch,
    to_np,
    preprocess,
)
from models.policies import TanhGaussianPolicy
from models.values import QNetwork

# https://github.com/pranz24/pytorch-soft-actor-critic/


class SAC:
    """Soft Actor-Critic agent with a twin-output critic and optional temperature tuning.

    The agent owns a replay buffer, a tanh-Gaussian policy, a critic that
    returns two Q estimates, a target copy of that critic, and a (possibly
    learned) log-temperature ``log_alpha``.

    Config fields read by this class (all through ``self.c``): ``replay_size``,
    ``pixel_obs``, ``hidden_size``, ``lr``, ``automatic_entropy_tuning``,
    ``target_entropy``, ``alpha``, ``gamma``, ``tau``, ``target_update_freq``,
    ``num_steps``, ``batch_size``, ``updates_per_step``, ``start_step``,
    ``eval_freq``, ``checkpoint_freq`` and ``num_eval_episodes``.
    """

    def __init__(self, config, env, logger):
        """Store the collaborators, reset the counters and build all models.

        Args:
            config: Object exposing the hyper-parameters listed on the class.
            env: Environment providing ``reset``, ``step``, ``observation_space``
                and ``action_space``.
            logger: Object providing ``record``, ``dump`` and a ``dir`` attribute.
        """
        self.c = config
        self.env = env
        self.logger = logger
        self.device = get_device()

        # Counters: environment steps, finished episodes, gradient updates.
        self.step = 0
        self.episode = 0
        self.updates = 0

        self.build_models()

    def build_models(self):
        """Create the replay buffer, policy, critics, optimizers and temperature."""
        obs_shape = self.env.observation_space.shape
        act_shape = self.env.action_space.shape

        # Replay buffer (pixel observations are stored as uint8, others as float32)
        self.buffer = ReplayBuffer(
            self.c.replay_size,
            obs_shape,
            act_shape,
            obs_type=np.uint8 if self.c.pixel_obs else np.float32,
        )

        # Policy
        self.policy = TanhGaussianPolicy(
            obs_shape, act_shape, self.c.hidden_size, self.env.action_space
        ).to(self.device)
        self.policy_optim = Adam(self.policy.parameters(), lr=self.c.lr)

        # Critic and its target copy
        self.critic = QNetwork(obs_shape, act_shape, self.c.hidden_size).to(self.device)
        self.critic_optim = Adam(self.critic.parameters(), lr=self.c.lr)
        self.critic_target = copy.deepcopy(self.critic)

        # Temperature
        if self.c.automatic_entropy_tuning:
            if self.c.target_entropy == "auto":
                # Target entropy is -dim(A) as given in the paper
                self.target_entropy = -torch.prod(torch.tensor(act_shape)).item()
            else:
                self.target_entropy = float(self.c.target_entropy)
            # Create the tensor directly on the target device: calling `.to()`
            # on a tensor that requires grad yields a non-leaf tensor on any
            # non-CPU device, and optimizers refuse non-leaf tensors.
            self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
            self.alpha_optim = Adam([self.log_alpha], lr=self.c.lr)
        else:
            # Fixed temperature; float() keeps log() valid for integer configs.
            self.log_alpha = torch.tensor(
                float(self.c.alpha), device=self.device
            ).log()

    def select_action(self, obs, evaluate=False):
        """Return the policy's action for a single observation as a numpy value.

        Args:
            obs: One observation; a leading batch axis is added via ``obs[None]``.
            evaluate: Forwarded to the policy as ``deterministic``.

        Note:
            This method does not disable gradient tracking itself; the callers
            in this class wrap it in ``torch.no_grad()``.
        """
        obs = to_torch(preprocess(obs[None]))
        action, _ = self.policy(obs, deterministic=evaluate)
        return to_np(action)[0]

    def update_parameters(self, obs, act, rew, next_obs, done, updates):
        """Run one gradient step on the critic, the policy and (optionally) alpha.

        Args:
            obs, act, rew, next_obs, done: A sampled batch, as returned by
                ``self.buffer.sample``. NOTE(review): ``rew``/``done`` are
                assumed to broadcast against the critic outputs -- confirm
                against ``ReplayBuffer`` and ``QNetwork``.
            updates: Index of this update; the target critic is soft-updated
                when it is a multiple of ``c.target_update_freq``.

        Returns:
            Tuple ``(q1_loss, q2_loss, policy_loss, alpha_loss)`` of Python
            numbers.
        """
        obs = to_torch(preprocess(obs))
        next_obs = to_torch(preprocess(next_obs))
        act, rew, done = map(to_torch, [act, rew, done])
        # The temperature is a constant for the critic and policy objectives;
        # detaching avoids computing a gradient for log_alpha that would be
        # discarded before the alpha step anyway.
        alpha = self.log_alpha.exp().detach()

        # Compute Q target
        with torch.no_grad():
            next_act, next_logp = self.policy(next_obs)
            next_q1_target, next_q2_target = self.critic_target(next_obs, next_act)
            min_next_q_target = torch.min(next_q1_target, next_q2_target)
            q_target = rew + (1 - done) * self.c.gamma * (
                min_next_q_target - alpha * next_logp
            )

        # Compute Q loss
        q1, q2 = self.critic(obs, act)
        q1_loss = F.mse_loss(q1, q_target)
        q2_loss = F.mse_loss(q2, q_target)
        q_loss = q1_loss + q2_loss

        # Update critic
        self.critic_optim.zero_grad()
        q_loss.backward()
        self.critic_optim.step()

        # Compute policy loss (uses the critic that was just updated)
        new_act, new_logp = self.policy(obs)
        new_q1, new_q2 = self.critic(obs, new_act)
        min_new_q = torch.min(new_q1, new_q2)
        policy_loss = ((alpha * new_logp) - min_new_q).mean()

        # Update policy. Gradients also reach the critic parameters here; they
        # are cleared by critic_optim.zero_grad() at the next update.
        self.policy_optim.zero_grad()
        policy_loss.backward()
        self.policy_optim.step()

        # Update alpha with dual descent
        if self.c.automatic_entropy_tuning:
            alpha_loss = -(
                self.log_alpha * (new_logp + self.target_entropy).detach()
            ).mean()
            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()
        else:
            alpha_loss = torch.tensor(0).to(self.device)

        # Update target critic
        if updates % self.c.target_update_freq == 0:
            soft_update(self.critic_target, self.critic, self.c.tau)

        return (
            q1_loss.item(),
            q2_loss.item(),
            policy_loss.item(),
            alpha_loss.item(),
        )

    def train(self):
        """Collect experience and train until ``self.step`` reaches ``c.num_steps``.

        The step budget is only checked between episodes, so the final episode
        always runs to completion. Evaluation and checkpointing happen at the
        end of every episode whose index is a multiple of ``c.eval_freq`` /
        ``c.checkpoint_freq``; evaluation reuses ``self.env``.
        """
        while self.step < self.c.num_steps:
            obs = self.env.reset()
            done = False
            episode_reward = 0
            episode_success = 0
            while not done:
                # Train agent
                if len(self.buffer) > self.c.batch_size:
                    # Number of updates per environment step
                    for _ in range(self.c.updates_per_step):
                        # Update parameters of all the networks
                        batch = self.buffer.sample(self.c.batch_size)
                        (
                            critic_1_loss,
                            critic_2_loss,
                            policy_loss,
                            entropy_loss,
                        ) = self.update_parameters(*batch, self.updates)
                        self.logger.record("train/critic_1_loss", critic_1_loss)
                        self.logger.record("train/critic_2_loss", critic_2_loss)
                        self.logger.record("train/policy_loss", policy_loss)
                        self.logger.record("train/entropy_loss", entropy_loss)
                        self.logger.record("train/alpha", self.log_alpha.exp().item())
                        self.updates += 1

                # Take environment step (uniform random actions before start_step)
                if self.step < self.c.start_step:
                    action = self.env.action_space.sample()
                else:
                    with torch.no_grad():
                        action = self.select_action(obs)
                next_obs, reward, done, info = self.env.step(action)
                episode_reward += reward
                episode_success += info.get("success", 0)
                self.step += 1

                # Ignore done if it comes from truncation
                real_done = 0 if info.get("TimeLimit.truncated", False) else float(done)
                self.buffer.push(obs, action, reward, next_obs, real_done)
                obs = next_obs

            if self.episode % self.c.eval_freq == 0:
                self.evaluate()

            if self.episode % self.c.checkpoint_freq == 0:
                self.save_checkpoint()

            self.logger.record("train/return", episode_reward)
            self.logger.record("train/success", float(episode_success > 0))
            self.logger.record("train/step", self.step)
            self.logger.dump(step=self.step)
            self.episode += 1

    def evaluate(self):
        """Run ``c.num_eval_episodes`` episodes with ``evaluate=True`` and log them.

        Each episode records ``test/return`` and ``test/success`` (1.0 if any
        step reported a truthy ``info["success"]``).
        """
        for _ in range(self.c.num_eval_episodes):
            obs = self.env.reset()
            done = False
            episode_reward = 0
            episode_success = 0
            while not done:
                with torch.no_grad():
                    action = self.select_action(obs, evaluate=True)
                next_obs, reward, done, info = self.env.step(action)
                episode_reward += reward
                episode_success += info.get("success", 0)
                obs = next_obs
            self.logger.record("test/return", episode_reward)
            self.logger.record("test/success", float(episode_success > 0))

    def save_checkpoint(self):
        """Write counters, networks, optimizers and ``log_alpha`` to disk.

        The file is ``models_<episode>.pt`` inside ``self.logger.dir``.
        """
        ckpt_path = os.path.join(self.logger.dir, f"models_{self.episode}.pt")
        ckpt = {
            "step": self.step,
            "episode": self.episode,
            "updates": self.updates,
            "policy": self.policy.state_dict(),
            "policy_optim": self.policy_optim.state_dict(),
            "critic": self.critic.state_dict(),
            "critic_optim": self.critic_optim.state_dict(),
            # The target network lags the critic, so it must be stored
            # separately for training to resume from the same state.
            "critic_target": self.critic_target.state_dict(),
            # Store plain data only; load_checkpoint re-enables grad as needed.
            "log_alpha": self.log_alpha.detach(),
        }
        if self.c.automatic_entropy_tuning:
            ckpt["alpha_optim"] = self.alpha_optim.state_dict()
        torch.save(ckpt, ckpt_path)

    def load_checkpoint(self):
        """Restore the agent from the highest-numbered checkpoint, if any.

        Looks for ``models_<episode>.pt`` files in ``self.logger.dir`` and
        loads the one with the largest numeric episode. Does nothing when no
        such file exists. Checkpoints without a ``critic_target`` entry are
        accepted; the target is then synchronized with the restored critic.
        """
        prefix, suffix = "models_", ".pt"
        latest_path = None
        latest_episode = -1
        for path in glob.glob(os.path.join(self.logger.dir, "models_*.pt")):
            # Parse the episode from the file name itself so the result does
            # not depend on the platform's path separator.
            name = os.path.basename(path)
            episode = name[len(prefix) : -len(suffix)]
            if episode.isdigit() and int(episode) > latest_episode:
                latest_episode = int(episode)
                latest_path = path
        if latest_path is None:
            return

        ckpt_path = latest_path
        # map_location lets a checkpoint written on one device be restored on
        # another (e.g. trained on GPU, resumed on CPU).
        ckpt = torch.load(ckpt_path, map_location=self.device)
        print(f"Loaded checkpoint from {ckpt_path}")

        self.step = ckpt["step"]
        self.episode = ckpt["episode"]
        self.updates = ckpt["updates"]
        self.policy.load_state_dict(ckpt["policy"])
        self.policy_optim.load_state_dict(ckpt["policy_optim"])
        self.critic.load_state_dict(ckpt["critic"])
        self.critic_optim.load_state_dict(ckpt["critic_optim"])
        if "critic_target" in ckpt:
            self.critic_target.load_state_dict(ckpt["critic_target"])
        else:
            self.critic_target.load_state_dict(self.critic.state_dict())

        # Rebuild log_alpha as a fresh leaf tensor on the current device so it
        # can be handed to an optimizer regardless of how it was saved.
        log_alpha = ckpt["log_alpha"].detach().clone().to(self.device)
        if self.c.automatic_entropy_tuning:
            log_alpha.requires_grad_(True)
            self.log_alpha = log_alpha
            self.alpha_optim = Adam([self.log_alpha], lr=self.c.lr)
            if "alpha_optim" in ckpt:
                self.alpha_optim.load_state_dict(ckpt["alpha_optim"])
        else:
            self.log_alpha = log_alpha
