import torch
import numpy as np
import torch.nn as nn
import gym
import os
from collections import deque
import random
from torch.utils.data import Dataset, DataLoader
import time
from skimage.util.shape import view_as_windows
import copy

class eval_mode(object):
    """Context manager that switches models to non-training mode temporarily.

    Every object passed in must expose a ``training`` attribute and a
    ``train(flag)`` method. On entry each model's current ``training`` value
    is recorded and ``train(False)`` is called on it; on exit the recorded
    values are restored in the same order. Exceptions raised inside the
    ``with`` body are not suppressed.
    """

    def __init__(self, *models):
        self.models = models

    def __enter__(self):
        self.prev_states = []
        remember = self.prev_states.append
        for net in self.models:
            # Record first, then switch: the recorded value must reflect the
            # state of this model at the moment it is visited.
            remember(net.training)
            net.train(False)

    def __exit__(self, *args):
        restore_plan = zip(self.models, self.prev_states)
        for net, was_training in restore_plan:
            net.train(was_training)
        # Returning False lets any exception from the body propagate.
        return False


def soft_update_params(net, target_net, tau):
    """Blend ``net``'s parameters into ``target_net`` in place.

    Each target parameter becomes ``tau * source + (1 - tau) * target``.
    Parameters are paired positionally via ``parameters()``; nothing is
    returned.
    """
    paired = zip(net.parameters(), target_net.parameters())
    for source, target in paired:
        blended = tau * source.data + (1 - tau) * target.data
        target.data.copy_(blended)


def set_seed_everywhere(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and, if present, CUDA)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def module_hash(module):
    """Return the sum of all entries of every tensor in ``module.state_dict()``.

    The tensors are visited in ``state_dict()`` order and their element sums
    are accumulated as Python numbers. An empty state dict yields ``0``.
    """
    return sum(entry.sum().item() for entry in module.state_dict().values())


def make_dir(dir_path):
    """Create ``dir_path`` if it does not exist yet and return the path.

    Missing parent directories are created as well, so nested paths such as
    ``runs/exp1/video`` work even when ``runs`` does not exist. An already
    existing directory is not an error.

    The call stays best-effort: any remaining ``OSError`` (for example a
    permission problem, or a regular file occupying the path) is ignored and
    the path is returned regardless, so callers that rely on this function
    never raising keep working.
    """
    try:
        os.makedirs(dir_path, exist_ok=True)
    except OSError:
        # Deliberately swallowed to keep the never-raises contract; a later
        # write into the directory will surface the real problem.
        pass
    return dir_path


def preprocess_obs(obs, bits=5):
    """Quantise and dequantise an image tensor, see https://arxiv.org/abs/1807.03039.

    ``obs`` must be a ``torch.float32`` tensor (an ``AssertionError`` is
    raised otherwise). For ``bits < 8`` the values are first floored to
    multiples of ``2**(8 - bits)``. The result is divided by ``2**bits``,
    uniform noise in ``[0, 1 / 2**bits)`` is added, and ``0.5`` is
    subtracted.
    """
    assert obs.dtype == torch.float32
    levels = 2**bits
    if bits < 8:
        step = 2**(8 - bits)
        obs = torch.floor(obs / step)
    scaled = obs / levels
    jitter = torch.rand_like(scaled) / levels
    return scaled + jitter - 0.5



class FrameStack(gym.Wrapper):
    """Gym wrapper that returns the last ``k`` observations and qpos values.

    Observations are joined along axis 0, so the advertised observation
    space has ``k`` times the wrapped env's first dimension (declared with
    bounds ``low=0`` / ``high=1`` and the wrapped space's dtype). The qpos
    history is stacked along a new leading axis.

    Note that, unlike the usual gym convention, ``reset`` returns a pair
    ``(stacked_obs, stacked_qpos)`` and ``step`` returns the stacked qpos in
    the position normally used for ``info``.
    """

    def __init__(self, env, k):
        super().__init__(env)
        self._k = k
        self._frames = deque(maxlen=k)
        self._qpos = deque(maxlen=k)
        base_shape = env.observation_space.shape
        stacked_shape = (base_shape[0] * k,) + base_shape[1:]
        self.observation_space = gym.spaces.Box(
            low=0,
            high=1,
            shape=stacked_shape,
            dtype=env.observation_space.dtype,
        )
        self._max_episode_steps = env._max_episode_steps

    def reset(self):
        first_obs = self.env.reset()
        first_qpos = self.env.get_qpos()
        # Fill both histories with k copies of the initial values.
        self._frames.extend([first_obs] * self._k)
        self._qpos.extend([first_qpos] * self._k)
        return self._get_obs(), self._get_qpos()

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        self._frames.append(obs)
        # The wrapped env is expected to report qpos through the info dict.
        self._qpos.append(info['qpos'])
        return self._get_obs(), reward, done, self._get_qpos()

    def _get_obs(self):
        assert len(self._frames) == self._k
        return np.concatenate(tuple(self._frames), axis=0)

    def _get_qpos(self):
        assert len(self._qpos) == self._k
        return np.stack(tuple(self._qpos))


class FrameStackState(gym.Wrapper):
    """Gym wrapper stacking the last ``k`` frames of a (state, frame) env.

    The wrapped env must return ``(obs_state, obs)`` pairs from both
    ``reset`` and ``step``. Frames are joined along axis 0; the state is
    passed through unstacked next to the stacked frames, while a history of
    states is kept internally and exposed through ``_get_qpos``.
    """

    def __init__(self, env, k):
        super().__init__(env)
        self._k = k
        self._frames = deque(maxlen=k)
        self._qpos = deque(maxlen=k)
        base_shape = env.observation_space.shape
        stacked_shape = (base_shape[0] * k,) + base_shape[1:]
        self.observation_space = gym.spaces.Box(
            low=0,
            high=1,
            shape=stacked_shape,
            dtype=env.observation_space.dtype,
        )
        self._max_episode_steps = env._max_episode_steps

    def reset(self):
        obs_state, obs = self.env.reset()
        # Fill both histories with k copies of the initial values.
        self._frames.extend([obs] * self._k)
        self._qpos.extend([obs_state] * self._k)
        return (obs_state, self._get_obs())

    def step(self, action):
        (obs_state, obs), reward, done, info = self.env.step(action)
        self._frames.append(obs)
        self._qpos.append(obs_state)
        stacked = self._get_obs()
        return (obs_state, stacked), reward, done, info

    def _get_obs(self):
        assert len(self._frames) == self._k
        return np.concatenate(tuple(self._frames), axis=0)

    def _get_qpos(self):
        assert len(self._qpos) == self._k
        # States are stacked on a new leading axis rather than concatenated.
        return np.stack(tuple(self._qpos))


def center_crop_image(image, output_size):
    """Return the central ``output_size`` x ``output_size`` window of one image.

    ``image`` must have exactly three dimensions; the crop is taken over the
    last two, and the first is left untouched.
    """
    _, height, width = image.shape
    row0 = (height - output_size) // 2
    col0 = (width - output_size) // 2
    rows = slice(row0, row0 + output_size)
    cols = slice(col0, col0 + output_size)
    return image[:, rows, cols]

def center_crop_images(image, output_size):
    """Return the central ``output_size`` x ``output_size`` window of a batch.

    ``image`` must have exactly four dimensions; the crop is taken over the
    last two, and the first two are left untouched.
    """
    _, _, height, width = image.shape
    row0 = (height - output_size) // 2
    col0 = (width - output_size) // 2
    rows = slice(row0, row0 + output_size)
    cols = slice(col0, col0 + output_size)
    return image[:, :, rows, cols]



class ReplayBufferEfficieint(Dataset):
    """Buffer to store environment transitions, one observation slot per step.

    Only the newest ``obs_shape[0]`` leading entries (channels) of every
    observation passed to ``add`` are stored. Next observations and frame
    stacks are not stored separately; they are rebuilt at sampling time from
    neighbouring slots, with ``eps_not_dones`` marking where an episode ends
    so that neighbours are not taken across that boundary.

    Args:
        obs_shape: shape of one stored observation. A 1-D shape is stored as
            float32, anything else as uint8.
        action_shape: shape of one action.
        capacity: number of slots in the ring buffer.
        batch_size: number of transitions returned by the samplers.
        device: torch device the sampled tensors are placed on.
        image_size: output size handed to crop/translate augmentations.
        pre_image_size: size of the centre crop taken before translation.
        transform: optional callable applied to observations in
            ``__getitem__``.
        frame_stack: number of consecutive frames joined in ``sample_rad``.
        augment_target_same_rnd: reuse the translate offsets of the current
            observation for the next observation.
        target_K: number of augmented copies of the next observation that
            ``sample_rad`` returns when translating (used when ``> 1``).
        camera_id: stored on the instance; not used by this class.
        from_states: additionally store a flat state vector per slot.
        state_dim: width of that state vector (defaults to the previously
            hard-coded 67).
    """
    def __init__(self, obs_shape, action_shape, capacity, batch_size,
        device,image_size=84, pre_image_size=100, transform=None,
        frame_stack=3, augment_target_same_rnd=True,
        target_K=1, camera_id=None, from_states=False, state_dim=67):
        self.capacity = capacity
        self.batch_size = batch_size
        self.device = device
        self.image_size = image_size
        self.pre_image_size = pre_image_size
        self.transform = transform
        self.obs_shape = obs_shape
        self.target_K = target_K
        self.camera_id = camera_id
        self.frame_stack = frame_stack
        self.from_states = from_states
        self.number_channel = obs_shape[0]
        # the proprioceptive obs is stored as float32, pixels obs as uint8
        obs_dtype = np.float32 if len(obs_shape) == 1 else np.uint8

        self.obses = np.empty((capacity, *obs_shape), dtype=obs_dtype)
        self.actions = np.empty((capacity, *action_shape), dtype=np.float32)
        self.rewards = np.empty((capacity, 1), dtype=np.float32)
        self.not_dones = np.empty((capacity, 1), dtype=np.float32)
        self.eps_not_dones = np.empty((capacity, 1), dtype=np.float32)
        if self.from_states:
            self.obses_state = np.empty((capacity, state_dim), dtype=np.float32)

        self.augment_target_same_rnd = augment_target_same_rnd
        if self.augment_target_same_rnd:
            print('using the same random index!!!')

        self.idx = 0
        self.last_save = 0
        self.full = False

    def add(self, obs, action, reward, next_obs, done, eps_done):
        """Write one step into the current slot and advance the ring pointer.

        Only the last ``number_channel`` leading entries of ``obs`` are kept.
        When ``from_states`` is set, ``next_obs`` is interpreted as the state
        vector belonging to this step (it is NOT a next observation);
        otherwise ``next_obs`` is ignored.
        """
        # Slicing only the leading axis works for pixel (C, H, W) and for
        # flat proprioceptive observations alike.
        np.copyto(self.obses[self.idx], obs[-1 * self.number_channel:])
        np.copyto(self.actions[self.idx], action)
        np.copyto(self.rewards[self.idx], reward)
        np.copyto(self.not_dones[self.idx], not done)
        np.copyto(self.eps_not_dones[self.idx], not eps_done)
        if self.from_states:
            #next_obs is actually state_obs!!!!
            np.copyto(self.obses_state[self.idx], next_obs)

        self.idx = (self.idx + 1) % self.capacity
        self.full = self.full or self.idx == 0

    def _current_and_next_idxs(self, idxs):
        """Turn sampled slots into matching (current, next) slot lists.

        A sampled slot flagged as episode end is replaced by the slot before
        it. The next slot is the following one, unless that would leave the
        filled part of the buffer or the current slot is itself flagged as
        episode end, in which case the current slot is reused.
        """
        filled = self.capacity if self.full else self.idx
        #avoid using the last one
        idxs_current = [x-1 if not self.eps_not_dones[x] else x for x in idxs]
        idxs_next = [
            x if x+1 >= filled or not self.eps_not_dones[x] else x+1
            for x in idxs_current
        ]
        return idxs_current, idxs_next

    def sample_proprio(self):
        """Sample a batch without augmentation or frame stacking.

        Returns:
            ``(obses, actions, rewards, next_obses, not_dones)`` as tensors on
            ``self.device``; observations are cast to float.
        """
        idxs = np.random.randint(
            0, self.capacity if self.full else self.idx, size=self.batch_size
        )
        # There is no separate next-observation array; the next observation
        # lives in the neighbouring slot.
        idxs_current, idxs_next = self._current_and_next_idxs(idxs)

        obses = torch.as_tensor(
            self.obses[idxs_current], device=self.device
        ).float()
        actions = torch.as_tensor(self.actions[idxs_current], device=self.device)
        rewards = torch.as_tensor(self.rewards[idxs_current], device=self.device)
        next_obses = torch.as_tensor(
            self.obses[idxs_next], device=self.device
        ).float()
        not_dones = torch.as_tensor(self.not_dones[idxs_current], device=self.device)
        return obses, actions, rewards, next_obses, not_dones

    def sample_rad(self,aug_funcs):
        """Sample a frame-stacked, augmented batch.

        Args:
            aug_funcs: mapping from augmentation name to callable (may be
                empty/None). Names containing 'crop', 'cutout', 'window' or
                'translate' are applied to the raw arrays before tensor
                conversion; every other entry is applied afterwards to the
                float tensors scaled by 1/255.

        Returns:
            ``(obses, actions, rewards, next_obses, not_dones)``; when
            ``from_states`` is set the first element is ``(obses, states)``.
        """
        # augs specified as flags
        # curl_sac organizes flags into aug funcs
        # passes aug funcs into sampler
        idxs = np.random.randint(
            0, self.capacity if self.full else self.idx, size=self.batch_size
        )
        idxs_current, idxs_next = self._current_and_next_idxs(idxs)
        # idxs_list = [next, current, 1 step back, 2 steps back, ...]
        idxs_list = [idxs_next, idxs_current]
        idxs_prev = idxs_current
        for _ in range(1, self.frame_stack):
            # Step one slot back unless that leaves the buffer start or
            # crosses a slot flagged as episode end.
            idxs_prev = [x if x-1 < 0 or not self.eps_not_dones[x-1] else x-1 for x in idxs_prev]
            idxs_list.append(idxs_prev)

        # Oldest frame first, ending with the current slot.
        obses = []
        for t in range(self.frame_stack):
            obses.append(self.obses[idxs_list[-1 - 1 * t]])
        obses = np.concatenate(obses, axis=1)
        # Same window shifted by one, ending with the next slot.
        next_obses = []
        for t in range(self.frame_stack):
            next_obses.append(self.obses[idxs_list[-2 - 1 * t]])
        next_obses = np.concatenate(next_obses, axis=1)

        og_obses = center_crop_images(obses,self.pre_image_size)
        og_next_obses = center_crop_images(next_obses,self.pre_image_size)

        if aug_funcs:
            for aug,func in aug_funcs.items():
                # apply crop and cutout first
                if 'crop' in aug or 'cutout' in aug or 'window' in aug:
                    if 'translate' not in aug:
                        obses = func(obses,self.image_size)
                        next_obses = func(next_obses,self.image_size)
                    else:
                        og_obses = func(obses, self.pre_image_size)
                        og_next_obses = func(next_obses, self.pre_image_size)
                if 'translate' in aug:
                    obses, rndm_idxs = func(og_obses, self.image_size, return_random_idxs=True)
                    if self.augment_target_same_rnd:
                        next_obses = func(og_next_obses, self.image_size, **rndm_idxs)
                    else:
                        next_obses = func(og_next_obses, self.image_size)

                    if self.target_K > 1:
                        all_next_obses = []
                        all_next_obses.append(next_obses)
                        if self.augment_target_same_rnd:
                            start = 1
                        else:
                            start = 0
                        for k in range(start, self.target_K):
                            all_next_obses.append(func(og_next_obses, self.image_size))
                        next_obses = np.stack(all_next_obses, axis=1)

        obses = torch.as_tensor(obses, device=self.device).float()
        next_obses = torch.as_tensor(next_obses, device=self.device).float()
        actions = torch.as_tensor(self.actions[idxs_current], device=self.device)
        rewards = torch.as_tensor(self.rewards[idxs_current], device=self.device)
        not_dones = torch.as_tensor(self.not_dones[idxs_current], device=self.device)

        obses = obses / 255.
        next_obses = next_obses / 255.

        # augmentations go here
        if aug_funcs:
            for aug,func in aug_funcs.items():
                # skip crop and cutout augs
                if 'crop' in aug or 'cutout' in aug or 'translate' in aug or 'window' in aug:
                    continue
                obses = func(obses)
                next_obses = func(next_obses)

        if self.from_states:
            states = torch.as_tensor(self.obses_state[idxs_current], device=self.device)
            return (obses, states), actions, rewards, next_obses, not_dones
        else:
            return obses, actions, rewards, next_obses, not_dones

    def save(self, save_dir):
        """Write the slots added since the previous ``save`` to ``save_dir``.

        The chunk is stored as ``<start>_<end>.pt`` holding the list
        ``[obses, actions, rewards, not_dones, eps_not_dones]``, followed by
        ``obses_state`` when ``from_states`` is set. Nothing is written when
        no slot was added since the last call.
        """
        if self.idx == self.last_save:
            return
        path = os.path.join(save_dir, '%d_%d.pt' % (self.last_save, self.idx))
        # NOTE(review): once the ring pointer has wrapped (idx < last_save)
        # this slice is empty; wrap-around is not handled here.
        chunk = slice(self.last_save, self.idx)
        payload = [
            self.obses[chunk],
            self.actions[chunk],
            self.rewards[chunk],
            self.not_dones[chunk],
            # Needed to rebuild next observations and frame stacks on reload.
            self.eps_not_dones[chunk],
        ]
        if self.from_states:
            payload.append(self.obses_state[chunk])
        self.last_save = self.idx
        torch.save(payload, path)

    def load(self, save_dir):
        """Restore chunks written by ``save`` in ascending start order.

        Every file name in ``save_dir`` must follow ``<start>_<end>.pt`` and
        the chunks must be contiguous starting at the current ``idx``.
        """
        chunks = sorted(os.listdir(save_dir), key=lambda x: int(x.split('_')[0]))
        for chunk in chunks:
            start, end = [int(x) for x in chunk.split('.')[0].split('_')]
            path = os.path.join(save_dir, chunk)
            # NOTE: the payload holds NumPy arrays, which requires full
            # unpickling; only load chunks produced by a trusted ``save``.
            try:
                payload = torch.load(path, weights_only=False)
            except TypeError:
                # Older torch releases do not know the weights_only keyword.
                payload = torch.load(path)
            assert self.idx == start
            self.obses[start:end] = payload[0]
            self.actions[start:end] = payload[1]
            self.rewards[start:end] = payload[2]
            self.not_dones[start:end] = payload[3]
            self.eps_not_dones[start:end] = payload[4]
            if self.from_states and len(payload) > 5:
                self.obses_state[start:end] = payload[5]
            self.idx = end

    def __getitem__(self, idx):
        """Return one randomly drawn transition; the ``idx`` argument is ignored.

        Returns:
            ``(obs, action, reward, next_obs, not_done)`` as NumPy values,
            with ``transform`` applied to both observations when it is set.
        """
        drawn = np.random.randint(
            0, self.capacity if self.full else self.idx, size=1
        )
        idxs_current, idxs_next = self._current_and_next_idxs(drawn)
        current = idxs_current[0]
        obs = self.obses[current]
        action = self.actions[current]
        reward = self.rewards[current]
        next_obs = self.obses[idxs_next[0]]
        not_done = self.not_dones[current]

        if self.transform:
            obs = self.transform(obs)
            next_obs = self.transform(next_obs)

        return obs, action, reward, next_obs, not_done

    def __len__(self):
        """Return the buffer capacity (not the number of filled slots)."""
        return self.capacity