import random
import os
import gzip
import pickle
from collections import defaultdict
from hashing_atari import AtariStateHash
from reservoir_sampling import UpdateBuffer


class MixingTimeAgent:
    """Roll out policies over a set of environments and log return statistics.

    Two measurements are supported (see ``run``):

    * ``get_asymptotic_reward_rate`` - average reward per step over
      ``asymptotic_steps`` environment steps.
    * ``accumulate_returns`` - per-step rewards pushed into an
      ``UpdateBuffer`` keyed by a hashed observation index, with periodic
      checkpoints written through the buffer.

    The active environment is switched every ``tau`` steps using
    ``sample_task``.
    """

    def __init__(
            self,
            args,
            envs,
            rl_agents,
            is_atari,
            tau,
            max_start_states,
            asymptotic_steps,
            directory,
    ):
        """Store the configuration and prepare (or resume) the return buffer.

        Args:
            args: namespace of run options; the attributes read by this class
                are ``stochastic``, ``deterministic``, ``frequency``,
                ``start_from_latest_checkpoint``, ``only_accumulate_returns``,
                ``only_asymptotic_returns``, ``env_transition_type`` and
                ``use_uniform_policy``.
            envs: indexable collection of environments.
            rl_agents: indexable collection of policies, aligned with ``envs``.
            is_atari: when true, acting is stochastic unless
                ``args.deterministic`` is set.
            tau: number of steps between environment switches.
            max_start_states: forwarded to ``UpdateBuffer``.
            asymptotic_steps: total number of steps to run (coerced to int).
            directory: folder where logs and checkpoints are written.
        """
        self.args = args
        self.envs = envs
        self.tau = tau
        self.directory = directory
        self.rl_agents = rl_agents
        self.stochastic = args.stochastic or (is_atari and not args.deterministic)
        self.deterministic = not self.stochastic
        # Normalise to int so that range() works and checkpoint file names are
        # formatted as integers even if the caller passes e.g. 1e6.
        self.asymptotic_steps = int(asymptotic_steps)
        self.reporting = int(self.asymptotic_steps // args.frequency)
        self.hash_state = AtariStateHash(image_hash="average")

        if self.args.start_from_latest_checkpoint and self.args.only_accumulate_returns:
            start, returns_dict = self.load_latest_checkpoint()
            logs = f"\n Resuming Epsilon Return Mixing Time from checkpoint {start} \n"
        else:
            start = 0
            logs = "\n Calculating the Epsilon Return Mixing Time \n"
            returns_dict = defaultdict(list)
            self.save_files(returns_dict, logs, name="accumulate_return")
        print(logs)
        self.start = start
        self.buffer = UpdateBuffer(returns_dict=returns_dict, age=self.start, max_start_states=max_start_states)
        # Old checkpoints are only deleted once more than two have been written.
        self.start_cleaning = 2 * self.reporting
        # NOTE(review): the hash -> index table starts empty even when resuming,
        # so indices handed out after a resume may not line up with the ones
        # used to build the loaded returns_dict - confirm against UpdateBuffer.
        self.hash_idx = {}
        self.current_idx = 0

        # Reset the env
        self.prev_idx = 0
        self.reset_env(self.prev_idx)

    def sample_task(self):
        """
        Samples a new task based on the environment transition type.
        This method supports random and cyclic task sampling.

        In "cycles" mode the next index after ``self.prev_idx`` is returned
        (wrapping around) and remembered in ``self.prev_idx``; otherwise an
        index is drawn uniformly at random.
        """
        n_envs = len(self.envs)
        if self.args.env_transition_type == "cycles":
            idx = (self.prev_idx + 1) % n_envs
            # Keep prev_idx inside [0, n_envs) so it stays usable as an index.
            self.prev_idx = idx
        else:
            idx = random.randrange(n_envs)
        return idx

    def reset_env(self, task_idx):
        """
        Resets the environment according to the task index

        Returns the fresh observation and a cleared (``None``) policy state.
        """
        obs = self.envs[task_idx].reset()
        state = None
        return obs, state

    def accumulate_returns(self):
        """Step the environments and feed every reward into ``self.buffer``.

        Runs from ``self.start`` up to ``asymptotic_steps``. Every
        ``self.reporting`` steps the buffer is asked to save a checkpoint and
        the checkpoint written two reports earlier is deleted to save disk
        space.
        """
        # Reset the environment
        env_idx = self.sample_task()
        obs, state = self.reset_env(env_idx)
        clean_checkpoint = self.reporting if self.start == 0 else self.start - self.reporting
        self.start_cleaning = self.start_cleaning if self.start == 0 else self.start
        for i in range(self.start, int(self.asymptotic_steps)):
            if self.args.use_uniform_policy:
                action = [self.envs[env_idx].action_space.sample()]
            else:
                action, state = self.rl_agents[env_idx].predict(
                    obs, state=state, deterministic=self.deterministic
                )
            obs, reward, done, _ = self.envs[env_idx].step(action)
            rew = reward.tolist()[0]

            # Hash the observation
            hash_value = self.hash_state(obs)
            hash_idx = self.hash2idx(hash_value)

            # Do reservoir update of the dict
            self.buffer(rew, hash_idx)

            if (i + 1) % self.tau == 0:  # switch to next environment
                # Reset the current environment -> Doing this because sometimes it throws some weird reset errors.
                _ = self.envs[env_idx].reset()
                # Sample a new task and reset the env
                env_idx = self.sample_task()
                obs, state = self.reset_env(env_idx)
            elif done:  # reset the current environment
                obs, state = self.reset_env(env_idx)

            if (i + 1) % self.reporting == 0:
                logs = "Checkpoint {}\n".format(i + 1)
                self.buffer.save(self.directory, logs, name="accumulate_return", index=i + 1)
                print(logs)
                # Delete the prev checkpoint files to save the disk space
                if (i + 1) > self.start_cleaning:
                    # When resuming from the very first checkpoint the
                    # candidate index is 0, for which no file was ever
                    # written; skip it instead of failing in clean().
                    if clean_checkpoint > 0:
                        self.clean(clean_checkpoint)
                    clean_checkpoint += self.reporting

    def get_asymptotic_reward_rate(self):
        """Estimate the average reward per step over ``asymptotic_steps``.

        Progress lines of the form ``Checkpoint <step> | Reward rate <rate>``
        are appended to ``asymptotic_reward.txt`` every ``self.reporting``
        steps. When ``args.start_from_latest_checkpoint`` is set, the run
        continues from the last such line and the cumulative reward is
        rebuilt from the logged rate so the average stays correct.

        Returns:
            The total reward divided by ``asymptotic_steps``.

        Raises:
            FileNotFoundError / ValueError: when resuming is requested but no
                usable checkpoint line exists.
        """
        header = "Calculating the Asymptotic Return"
        print(header)
        total_steps = int(self.asymptotic_steps)
        state = None
        env_idx = self.prev_idx
        obs = self.envs[env_idx].reset()
        asymptotic_return_dict = {}

        # Read the checkpoint BEFORE appending anything to the log file,
        # otherwise the freshly written header becomes the last line.
        if self.args.start_from_latest_checkpoint:
            start, logged_rate = self._read_asymptotic_checkpoint()
            reward_sum = logged_rate * start
        else:
            start = 0
            reward_sum = 0
        # Terminate the header with a newline so the first checkpoint line is
        # not glued onto it in the log file.
        self.save_files(asymptotic_return_dict, header + "\n", name="asymptotic_reward")

        for i in range(start, total_steps):
            if self.args.use_uniform_policy:
                action = [self.envs[env_idx].action_space.sample()]
            else:
                action, state = self.rl_agents[env_idx].predict(
                    obs, state=state, deterministic=self.deterministic
                )
            obs, rew, done, _ = self.envs[env_idx].step(action)
            reward_sum += rew.tolist()[0]
            if (i + 1) % self.tau == 0:
                env_idx = self.sample_task()
                obs, state = self.reset_env(env_idx)
            elif done:
                obs, state = self.reset_env(env_idx)

            if (i + 1) % self.reporting == 0:
                logs = "Checkpoint {} | Reward rate {} \n".format(
                    i + 1, reward_sum / (i + 1)
                )
                print(logs)
                self.save_files(asymptotic_return_dict, logs, name="asymptotic_reward")

        asymptotic_return = reward_sum / total_steps
        asymptotic_return_dict["Asymptotic Return"] = asymptotic_return
        logs = "Asymptotic Return : {}\n".format(asymptotic_return)
        self.save_files(asymptotic_return_dict, logs, name="asymptotic_reward")
        return asymptotic_return

    def hash2idx(self, hash_value):
        """Map a hash to a stable small integer, assigning new ones in order."""
        if hash_value not in self.hash_idx:
            self.hash_idx[hash_value] = self.current_idx
            self.current_idx += 1
        return self.hash_idx[hash_value]

    def save_files(self, results, logs, name, index=None):
        """Pickle ``results`` (if given) and append ``logs`` to ``<name>.txt``.

        Args:
            results: object to dump into a gzip-compressed pickle; skipped
                when ``None``.
            logs: text appended to the log file.
            name: base file name (without extension) inside ``self.directory``.
            index: optional suffix; the pickle is then named
                ``<name>_<index>.pkl.gz``. The log file name never carries
                the index.
        """
        if results is not None:
            file_name = name if index is None else name + '_{}'.format(index)
            pickle_path = os.path.join(self.directory, file_name + '.pkl.gz')
            with gzip.open(pickle_path, "wb") as f_dict:
                pickle.dump(results, f_dict, protocol=pickle.HIGHEST_PROTOCOL)
        with open(os.path.join(self.directory, name + ".txt"), "a") as f_logs:
            f_logs.write(logs)

    def run(self):
        """Run the measurement(s) selected by the ``only_*`` flags in args."""
        if self.args.only_accumulate_returns:
            self.accumulate_returns()
        elif self.args.only_asymptotic_returns:
            _ = self.get_asymptotic_reward_rate()
        else:
            _ = self.get_asymptotic_reward_rate()
            self.accumulate_returns()

    def _read_asymptotic_checkpoint(self):
        """Return ``(step, reward_rate)`` from the newest checkpoint log line.

        The log is scanned from the end for a line containing
        ``Checkpoint <step> | Reward rate <rate>``; header/summary lines and
        malformed (e.g. partially written) lines are skipped. A header fused
        in front of the checkpoint text on the same line is tolerated.

        Raises:
            ValueError: if the log contains no parseable checkpoint line.
        """
        log_path = os.path.join(self.directory, "asymptotic_reward.txt")
        with open(log_path, "r") as f:
            lines = f.readlines()
        marker = "Checkpoint "
        for line in reversed(lines):
            pos = line.rfind(marker)
            if pos == -1:
                continue
            # Expected fields: Checkpoint <step> | Reward rate <rate>
            fields = line[pos:].split()
            try:
                return int(fields[1]), float(fields[5])
            except (IndexError, ValueError):
                continue
        raise ValueError("no checkpoint entry found in {}".format(log_path))

    def load_latest_checkpoint(self):
        """Locate the most recent checkpoint for the active mode.

        Returns:
            ``(step, None)`` when ``args.only_asymptotic_returns`` is set,
            otherwise ``(step, returns_dict)`` where ``returns_dict`` is
            unpickled from ``accumulate_return_<step>.pkl.gz``.
        """
        if self.args.only_asymptotic_returns:
            step, _ = self._read_asymptotic_checkpoint()
            return step, None

        log_path = os.path.join(self.directory, "accumulate_return.txt")
        with open(log_path, "r") as f:
            # Ignore blank lines so a trailing newline does not hide the
            # last real entry.
            entries = [line for line in f if line.strip()]
        last_checkpoint = int(float(entries[-1].split()[-1]))
        pickle_path = os.path.join(
            self.directory, "accumulate_return_{}.pkl.gz".format(last_checkpoint)
        )
        # pickle can execute arbitrary code while loading: only point this at
        # checkpoint files produced by this program.
        with gzip.open(pickle_path, 'rb') as f:
            returns_dict = pickle.load(f)
        return last_checkpoint, returns_dict

    def clean(self, checkpoint):
        """Delete ``accumulate_return_<checkpoint>.pkl.gz`` from the directory.

        Raises:
            FileNotFoundError: if that checkpoint file does not exist.
        """
        path = os.path.join(
            self.directory, "accumulate_return_{}.pkl.gz".format(checkpoint)
        )
        # os.remove raises FileNotFoundError (including the path) by itself.
        os.remove(path)

