import torch
import torch.nn as nn
import numpy as np

import exp_utils as PQ
import rl_utils

# Default device for buffers created in this module: CUDA when available, otherwise CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


class ExplorationPolicy(nn.Module, rl_utils.BasePolicy):
    """Policy wrapper that samples noisy actions and screens them with ``crabs.U``.

    For a single input state, 100 candidate actions are drawn around the
    wrapped policy's output, with noise scaled by factors that decay
    logarithmically from 1 down to 1e-3. The first candidate (largest noise
    first) whose ``crabs.U`` value is non-positive is returned. When no
    candidate qualifies, the action comes from ``crabs.policy`` instead.
    """

    def __init__(self, policy, crabs):
        super().__init__()
        self.policy = policy
        self.crabs = crabs
        # Not read anywhere in this class; kept because other code may access them.
        self.last_L = 0
        self.last_U = 0

    @torch.no_grad()
    def forward(self, states: torch.Tensor):
        """Select one action for a batch holding exactly one state.

        Args:
            states: Tensor whose first dimension has length 1.

        Returns:
            The selected action tensor.
            NOTE(review): the candidate branch returns ``candidates[i]``, which
            has no leading batch dimension — confirm callers expect that.
        """
        dev = states.device
        assert len(states) == 1
        dist = self.policy(states)

        # Setup shared by both kinds of policy output.
        n_candidates = 100
        tiled_states = states.repeat([n_candidates, 1])
        # Column of noise multipliers: 10**0 ... 10**-3, one per candidate.
        noise_scale = torch.logspace(0, -3, n_candidates, base=10., device=dev)[:, None]

        if isinstance(dist, rl_utils.distributions.TanhGaussian):
            center, spread = dist.mean, dist.stddev
            noise = torch.randn([n_candidates, *center.shape[1:]], device=dev)
            # Perturb before the tanh so that candidates stay inside (-1, 1).
            candidates = (center + noise * spread * noise_scale).tanh()
        else:
            # The policy output itself is used as the center of the perturbation.
            center = dist
            noise = torch.randn([n_candidates, *center.shape[1:]], device=dev)
            candidates = center + noise * noise_scale

        u_values = self.crabs.U(tiled_states, candidates).detach().cpu().numpy()
        if u_values.min() <= 0:
            # Lowest candidate index (i.e. largest noise) with U <= 0.
            accepted_rows = np.nonzero(u_values <= 0)[0]
            first_accepted = accepted_rows.min()
            PQ.meters['expl/backup'] += first_accepted
            return candidates[first_accepted]

        # No candidate passed the check: use the policy held by `crabs`.
        fallback_action = self.crabs.policy(tiled_states[0])
        PQ.meters['expl/backup'] += n_candidates
        return fallback_action

    @rl_utils.torch_utils.maybe_numpy
    def get_actions(self, states):
        """Return the action chosen by calling this module on ``states``."""
        return self(states)


# Running total of exploration samples requested so far. `eval_and_explore`
# adds `n_samples` to it on every call and passes it as the step index to
# `evaluate` (before the increment) and to `explore` (after the increment).
n_samples_so_far = 0


def explore(t, runner, n_samples, expl_policy, buf, crabs):
    """Collect exploration samples, add them to ``buf`` and log diagnostics.

    Args:
        t: Step index used in log messages and as ``global_step`` for the
            summary writer.
        runner: Object whose ``run`` method collects the samples;
            ``runner.envs[0]`` is handed to the temporary buffer.
        n_samples: Number of samples requested from ``runner.run``; also the
            capacity of the temporary buffer.
        expl_policy: Policy passed to ``runner.run``.
        buf: Buffer that receives the new samples through ``buf += tmp``.
        crabs: Object providing ``barrier``, ``U`` and
            ``uncertainty.ensemble.get_nlls``, used for the diagnostics.

    Returns:
        The temporary replay buffer holding only the newly collected samples.
    """
    tmp = rl_utils.TorchReplayBuffer(runner.envs[0], device=device, max_buf_size=n_samples)
    ep_infos = runner.run(expl_policy, n_samples, buffer=tmp)
    buf += tmp

    states, actions, next_states = tmp['state'], tmp['action'], tmp['next_state']
    h = crabs.barrier(states)
    u = crabs.U(states, actions)
    nh = crabs.barrier(next_states)

    nlls = crabs.uncertainty.ensemble.get_nlls(states, actions, next_states)
    PQ.log.debug(f'expl as val: {nlls}')

    # Convert the counts to plain Python numbers so the log line shows "3"
    # rather than "tensor(3, ...)" and the writer receives scalars.
    # Model failure: U(state, action) <= 0 although barrier(next_state) > 0.
    n_model_failures = ((nh > 0) & (u <= 0)).sum().item()
    # Crabs failure: barrier(state) <= 0 although U(state, action) > 0.
    n_crabs_failures = ((h <= 0) & (u > 0)).sum().item()

    merged_infos = rl_utils.runner.merge_episode_stats(ep_infos)
    # Episodes without an 'episode.unsafe' entry contribute 0.
    n_expl_unsafe_trajs = sum(info.get('episode.unsafe', 0) for info in ep_infos)
    max_h = h.max().item()
    PQ.log.info(f"[explore] # {t}: failure = [model = {n_model_failures}, crabs = {n_crabs_failures}], "
                f"expl trajs return: {merged_infos['return']}, max L = {max_h:.6f}")

    if n_expl_unsafe_trajs > 0:
        PQ.log.critical(f'[explore] {n_expl_unsafe_trajs} unsafe trajectories!')

    PQ.writer.add_scalar('explore/n_trajs', len(ep_infos), global_step=t)
    PQ.writer.add_scalar('explore/n_unsafe_trajs', n_expl_unsafe_trajs, global_step=t)
    PQ.writer.add_scalar('explore/n_model_failures', n_model_failures, global_step=t)
    PQ.writer.add_scalar('explore/n_crabs_failures', n_crabs_failures, global_step=t)
    PQ.writer.add_scalar('explore/policy_return', merged_infos['return'][0], global_step=t)

    return tmp


def eval_and_explore(policy, expl_policy, runners, n_samples, buf, crabs):
    """Evaluate ``policy``, then collect ``n_samples`` exploration samples.

    The module-level counter ``n_samples_so_far`` is passed to ``evaluate``
    as it stands, then increased by ``n_samples`` and passed to ``explore``
    as the step index.

    Args:
        policy: Policy handed to ``evaluate``.
        expl_policy: Policy handed to ``explore`` for data collection.
        runners: Mapping with the keys ``'evaluate'`` and ``'explore'``.
        n_samples: Number of exploration samples to collect.
        buf: Buffer forwarded to ``explore``.
        crabs: Object forwarded to ``explore``.

    Returns:
        Whatever ``explore`` returns.
    """
    global n_samples_so_far

    # Function-local import — presumably avoids a circular import; confirm
    # before moving it to module level.
    from .debugger import evaluate

    eval_runner = runners['evaluate']
    evaluate(n_samples_so_far, eval_runner, policy, "eval_and_explore", n_eval_samples=10_000)

    # Advance the counter only once evaluation has finished.
    n_samples_so_far = n_samples_so_far + n_samples
    expl_runner = runners['explore']
    return explore(n_samples_so_far, expl_runner, n_samples, expl_policy, buf, crabs)
