import argparse
from networks.network import Actor, Critic
from utils.utils import ReplayBuffer, make_mini_batch, convert_to_tensor

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

class PPO_Gini(nn.Module):
    """PPO agent whose actor loss adds a Gini-style term computed over whole episodes.

    The critic is trained on a clipped value loss once per minibatch. The actor
    is stepped once per epoch on the mean clipped-PPO surrogate of that epoch's
    minibatches plus ``args.lam`` times :meth:`prepare_gini`, which uses the
    episodes handed over through :meth:`put_episodes`.
    """

    # Upper cap on the per-episode importance-sampling ratio used by the Gini
    # term. Only large ratios are clipped; small ratios are left unchanged.
    GINI_IS_RATIO_CAP = 1.6

    def __init__(self, writer, device, state_dim, action_dim, args):
        """Build the actor/critic networks, their Adam optimizers and the rollout buffer.

        Args:
            writer: object with ``add_scalar(tag, value, step)`` used for logging,
                or ``None`` to disable logging.
            device: device that training tensors are moved to.
            state_dim, action_dim: observation / action dimensionality.
            args: hyper-parameter namespace. This class reads ``traj_length``,
                ``layer_num``, ``hidden_dim``, ``activation_function``,
                ``last_activation``, ``trainable_std``, ``actor_lr``, ``critic_lr``,
                ``gamma``, ``lambda_``, ``train_epoch``, ``batch_size``,
                ``entropy_coef``, ``max_clip``, ``critic_coef``, ``max_grad_norm``
                and ``lam``.
        """
        super().__init__()
        self.args = args

        self.data = ReplayBuffer(action_prob_exist=True, max_size=self.args.traj_length,
                                 state_dim=state_dim, num_action=action_dim)
        self.actor = Actor(self.args.layer_num, state_dim, action_dim, self.args.hidden_dim,
                           self.args.activation_function, self.args.last_activation,
                           self.args.trainable_std)
        self.critic = Critic(self.args.layer_num, state_dim, 1,
                             self.args.hidden_dim, self.args.activation_function,
                             last_activation=None)

        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=self.args.actor_lr)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=self.args.critic_lr)

        self.writer = writer
        # NOTE(review): the networks are not moved to `device` here; presumably
        # the caller invokes `.to(device)` on this module — confirm.
        self.device = device
        # Set by put_episodes(); required by prepare_gini().
        self.episodes = None

    def get_action(self, x):
        """Return the actor's Gaussian parameters ``(mu, sigma)`` for states ``x``."""
        mu, sigma = self.actor(x)
        return mu, sigma

    def v(self, x):
        """Return the critic's value estimate for states ``x``."""
        return self.critic(x)

    def put_data(self, transition):
        """Append one transition to the rollout buffer ``self.data``."""
        self.data.put_data(transition)

    def put_episodes(self, episodes):
        """Store the episode collection consumed by :meth:`prepare_gini`.

        ``episodes`` must be indexable by ``'state'``, ``'action'``, ``'reward'``
        and ``'log_prob'``. Each key maps to a list with one per-step sequence
        per episode.
        """
        self.episodes = episodes

    def get_gae(self, states, rewards, next_states, dones):
        """Compute GAE(lambda) advantages for a time-ordered rollout.

        Args:
            states, rewards, next_states, dones: per-transition tensors in time
                order. ``dones[t] == 1`` marks the last step of an episode.
                NOTE(review): the critic output is assumed to be (T, 1); the
                loop indexes ``delta[idx][0]``.
        Returns:
            ``(values, advantages)``. ``values`` holds the detached critic values
            V(s_t). ``advantages`` is a float tensor on ``self.device`` with one
            ``[adv]`` row per transition.
        """
        # No gradients are needed for advantage estimation.
        with torch.no_grad():
            values = self.v(states)
            td_target = rewards + self.args.gamma * self.v(next_states) * (1 - dones)
            delta = (td_target - values).cpu().numpy()

        decay = self.args.gamma * self.args.lambda_
        advantage_lst = []
        advantage = 0.0
        for idx in reversed(range(len(delta))):
            # Do not carry the accumulated advantage across an episode boundary.
            if dones[idx] == 1:
                advantage = 0.0
            advantage = decay * advantage + delta[idx][0]
            advantage_lst.append([advantage])
        advantage_lst.reverse()
        advantages = torch.tensor(advantage_lst, dtype=torch.float).to(self.device)
        return values, advantages

    def train_net(self, n_epi):
        """Run ``args.train_epoch`` epochs of updates on the buffered rollout.

        On every minibatch, the critic takes one step on a clipped value loss.
        The actor is stepped once per epoch on the mean clipped-PPO surrogate
        over that epoch's minibatches plus ``args.lam * prepare_gini()``.

        Args:
            n_epi: step index passed to ``self.writer.add_scalar`` (if a writer
                is set).
        """
        data = self.data.sample(shuffle=False)
        states, actions, rewards, next_states, dones, old_log_probs = convert_to_tensor(
            self.device, data['state'], data['action'], data['reward'],
            data['next_state'], data['done'], data['log_prob'])

        old_values, advantages = self.get_gae(states, rewards, next_states, dones)
        # Returns are built from the raw advantages; only the actor sees the
        # normalized ones.
        returns = advantages + old_values
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-3)

        for _ in range(self.args.train_epoch):
            actor_loss_lst = []

            for state, action, old_log_prob, advantage, return_, old_value in make_mini_batch(
                    self.args.batch_size, states, actions, old_log_probs,
                    advantages, returns, old_values):

                curr_mu, curr_sigma = self.get_action(state)
                value = self.v(state).float()
                curr_dist = torch.distributions.Normal(curr_mu, curr_sigma)
                # NOTE(review): if mu is (batch, action_dim), this entropy term
                # broadcasts against the (batch, 1) surrogate below, so the
                # mean averages entropy over action dims rather than summing.
                entropy = curr_dist.entropy() * self.args.entropy_coef
                curr_log_prob = curr_dist.log_prob(action).sum(1, keepdim=True)

                # Clipped PPO surrogate.
                ratio = torch.exp(curr_log_prob - old_log_prob.detach())
                surr1 = ratio * advantage
                surr2 = torch.clamp(ratio, 1 - self.args.max_clip, 1 + self.args.max_clip) * advantage
                actor_loss = (-torch.min(surr1, surr2) - entropy).mean()
                # The actor is not stepped here. Its graph is kept so that all
                # minibatch losses can be combined with the Gini term per epoch.
                actor_loss_lst.append(actor_loss)

                # PPO2-style clipped value loss: the new value is limited to
                # within max_clip of the rollout-time value.
                value_clipped = old_value + (value - old_value).clamp(-self.args.max_clip, self.args.max_clip)
                value_loss = (value - return_.detach().float()).pow(2)
                value_loss_clipped = (value_clipped - return_.detach().float()).pow(2)
                critic_loss = 0.5 * self.args.critic_coef * torch.max(value_loss, value_loss_clipped).mean()

                self.critic_optimizer.zero_grad()
                critic_loss.backward()
                nn.utils.clip_grad_norm_(self.critic.parameters(), self.args.max_grad_norm)
                self.critic_optimizer.step()

                if self.writer is not None:
                    self.writer.add_scalar("loss/actor_loss", actor_loss.item(), n_epi)
                    self.writer.add_scalar("loss/critic_loss", critic_loss.item(), n_epi)

            ppo_loss = torch.stack(actor_loss_lst).mean()
            gini_loss = self.prepare_gini()
            policy_loss = ppo_loss + self.args.lam * gini_loss
            self.actor_optimizer.zero_grad()
            policy_loss.backward()
            nn.utils.clip_grad_norm_(self.actor.parameters(), self.args.max_grad_norm)
            self.actor_optimizer.step()

    def prepare_gini(self):
        """Return the scalar Gini loss over the episodes stored by :meth:`put_episodes`.

        For each episode, this computes three quantities:

        * the discounted return;
        * the summed log-probability of its actions under the current actor,
          which stays in the autograd graph;
        * a detached importance-sampling ratio ``exp(sum(log pi_new - log pi_old))``,
          capped from above at ``GINI_IS_RATIO_CAP``.

        They are combined by :meth:`compute_gini_loss`. The per-episode terms
        are then averaged and divided by the total number of transitions.

        With fewer than two episodes, the pairwise term is empty (its mean would
        be NaN), so a zero tensor is returned instead.
        """
        episodes = self.episodes
        num_episodes = len(episodes['state'])
        if num_episodes < 2:
            return torch.zeros((), device=self.device)

        sum_log_prob_lst = []
        ret_lst = []
        ratio_lst = []
        n_trans = 0

        for i_ep in range(num_episodes):
            reward_lst = episodes['reward'][i_ep]
            state_lst = episodes['state'][i_ep]
            action_lst = episodes['action'][i_ep]
            old_log_prob_lst = episodes['log_prob'][i_ep]

            # Discounted return of the whole episode.
            n_trans += len(reward_lst)
            ret = 0.0
            for reward in reversed(reward_lst):
                ret = reward + self.args.gamma * ret
            ret_lst.append(ret)

            # Log-probability of the taken actions under the current policy.
            state = torch.as_tensor(np.array(state_lst), dtype=torch.float32, device=self.device)
            action = torch.as_tensor(np.array(action_lst), dtype=torch.float32, device=self.device)
            curr_mu, curr_sigma = self.get_action(state)
            curr_dist = torch.distributions.Normal(curr_mu, curr_sigma)
            curr_log_prob = curr_dist.log_prob(action).sum(1, keepdim=True)
            sum_log_prob_lst.append(curr_log_prob.sum())

            # Trajectory-level IS ratio. Both sides are flattened so that
            # scalar-per-step old log-probs cannot broadcast (T, 1) - (T,) into
            # a (T, T) matrix.
            old_log_prob = torch.as_tensor(np.array(old_log_prob_lst), device=self.device)
            log_ratio = curr_log_prob.detach().reshape(-1) - old_log_prob.reshape(-1)
            ratio_lst.append(torch.exp(log_ratio.sum()).item())

        is_ratio_t = torch.tensor(np.minimum(ratio_lst, self.GINI_IS_RATIO_CAP),
                                  dtype=torch.float32, device=self.device)
        # Flatten so that returns built from (1,)-shaped rewards still form a 1-D tensor.
        ret_t = torch.as_tensor(np.asarray(ret_lst, dtype=np.float32).reshape(-1),
                                device=self.device)
        sum_log_prob_t = torch.stack(sum_log_prob_lst)

        gini_loss = self.compute_gini_loss(ret_t, sum_log_prob_t, is_ratio_t)
        return gini_loss.mean() / n_trans

    def compute_gini_loss(self, ret, sum_log_prob, ratio):
        """Return per-episode Gini surrogate terms, ordered by ascending return.

        Let ``r_1 <= ... <= r_n`` be the sorted returns. The episode of rank
        ``i`` (for ``i < n``) receives the term
        ``-sum_log_prob_i * ratio_i * coef_i``, where
        ``coef_i = 2 * sum_{j >= i} (j / n) * (r_{j+1} - r_j) + r_i - r_n``.
        The top-ranked episode is omitted, because its coefficient is
        ``2 * 0 + r_n - r_n = 0``.

        Args:
            ret: 1-D tensor of episode returns.
            sum_log_prob: 1-D tensor of summed action log-probabilities per
                episode (may carry gradients).
            ratio: 1-D tensor of per-episode importance weights.
        Returns:
            Tensor of length ``len(ret) - 1``; it is empty when only one
            episode is given.
        """
        sort_ret, indices = torch.sort(ret, descending=False)
        sort_sum_log_prob = sum_log_prob[indices]
        sort_is_ratio = ratio[indices]
        sample_size = sort_ret.shape[0]

        # Rank-weighted return gaps (r_{i+1} - r_i) * i / n for i = 1..n-1,
        # created on the same device as the returns.
        rank_weights = torch.arange(1, sample_size, dtype=sort_ret.dtype,
                                    device=sort_ret.device) / sample_size
        weighted_gaps = (sort_ret[1:] - sort_ret[:-1]) * rank_weights
        # Suffix sums: sum_{j >= i} weighted_gaps_j.
        suffix_sums = weighted_gaps + torch.sum(weighted_gaps) - torch.cumsum(weighted_gaps, dim=-1)
        coef = 2. * suffix_sums + sort_ret[:-1] - sort_ret[-1]

        return -1 * sort_sum_log_prob[:-1] * coef * sort_is_ratio[:-1]

