
import math
import random
import numpy as np
import os
import sys
from tqdm import tqdm
# sys.path.append('..')

from collections import namedtuple
import argparse
from itertools import count, chain
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from utils import *

from RL.env_binary_question import BinaryRecommendEnv
from RL.env_enumerated_question import EnumeratedRecommendEnv
from RL.RL_evaluate import dqn_evaluate
from gcn import GraphEncoder
from encoder import TransGate
from torch.autograd import Variable
from torch.distributions import Categorical
import warnings
from encoder import NegRanker

# Install a process-wide filter that ignores every Python warning from here on.
warnings.filterwarnings("ignore")
# Environment class to instantiate for each dataset name (looked up in ``train``).
# Only the non-star YELP setting uses the enumerated-question environment.
EnvDict = {
    LAST_FM: BinaryRecommendEnv,
    LAST_FM_STAR: BinaryRecommendEnv,
    YELP: EnumeratedRecommendEnv,
    YELP_STAR: BinaryRecommendEnv
    }
# Key into ``kg.G`` holding each dataset's attribute nodes; ``main`` uses the
# number of entries under this key as ``args.attr_num``.
FeatureDict = {
    LAST_FM: 'feature',
    LAST_FM_STAR: 'feature',
    YELP: 'large_feature',
    YELP_STAR: 'feature'
}

# One step of experience (state, action, next_state, reward, next candidates).
# NOTE(review): not referenced by the policy-gradient loop in this file, which
# keeps per-episode reward / log-prob lists instead; possibly a DQN leftover.
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward', 'next_cand'))

class DQN(nn.Module):
    """Score every candidate action for a given encoded state.

    The network has a dueling-style layout (a V(s) head and an advantage
    head), but only the advantage head contributes to the returned scores.
    The callers in this file apply softmax / argmax / sort over the scores,
    and all of these are unaffected by a per-state offset such as V(s).
    """

    def __init__(self, state_size, action_size, hidden_size=100):
        super(DQN, self).__init__()
        # ``state_size`` is accepted for interface compatibility but unused:
        # the encoded state is expected to have ``hidden_size`` features.
        # V(s) head. Not used by ``forward``; kept (and constructed first, as
        # before) so seeded initialisation and saved checkpoints are unchanged.
        self.fc2_value = nn.Linear(hidden_size, hidden_size)
        self.out_value = nn.Linear(hidden_size, 1)
        # Advantage head A(s, a): scores the concatenation [state ; action].
        self.fc2_advantage = nn.Linear(hidden_size + action_size, hidden_size)
        self.out_advantage = nn.Linear(hidden_size, 1)

    def forward(self, x, y, choose_action=True):
        """Return one score per candidate action.

        :param x: encoded state. With ``choose_action=True`` it must have size 1
            along dim 1 (it is tiled to match ``y``), i.e. ``[N, 1, hidden_size]``.
            Otherwise it must already be ``[N, K, hidden_size]``.
        :param y: candidate action embeddings ``[N, K, action_size]``.
        :param choose_action: tile the single state across the K candidates.
        :return: action scores ``[N, K]``.
        """
        if choose_action:
            # Pair the same state with each of the K candidate actions.
            x = x.repeat(1, y.size(1), 1)

        state_cat_action = torch.cat((x, y), dim=2)
        hidden = F.relu(self.fc2_advantage(state_cat_action))
        return self.out_advantage(hidden).squeeze(dim=2)  # [N, K]


class Agent(nn.Module):
    """Policy-gradient agent built from a graph encoder, an action scorer and an item ranker.

    ``gcn_net`` encodes the dialogue state and owns the entity ``embedding``
    table; ``policy_net`` scores candidate features/items against that state.
    ``target_net``, ``loss_func``, ``tau`` and the ``EPS_*`` values are
    created here but not read by any method of this class.
    """

    def __init__(self, device, state_size, action_size, hidden_size, gcn_net, learning_rate, l2_norm, PADDING_ID,
                 EPS_START=0.9, EPS_END=0.1, EPS_DECAY=0.0001, tau=0.01):
        super().__init__()
        # Exploration schedule (stored for compatibility).
        self.EPS_START, self.EPS_END, self.EPS_DECAY = EPS_START, EPS_END, EPS_DECAY
        self.steps_done = 0
        self.device = device
        # Sub-modules are created in a fixed order so that seeded weight
        # initialisation stays reproducible.
        self.gcn_net = gcn_net
        self.policy_net = DQN(state_size, action_size, hidden_size)
        self.target_net = DQN(state_size, action_size, hidden_size)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
        self.loss_func = nn.MSELoss()
        self.PADDING_ID = PADDING_ID
        self.tau = tau
        self.item_ranker = NegRanker(hidden_size=state_size, action_size=action_size)
        trainable = chain(
            self.policy_net.parameters(),
            self.gcn_net.parameters(),
            self.item_ranker.parameters(),
        )
        self.optimizer = optim.Adam(trainable, lr=learning_rate, weight_decay=l2_norm)

    def select_action(self, state, cand_feature, cand_item, action_space, is_test=False, is_last_turn=False):
        """Score the candidate features and items and pick one of them.

        :param state: raw dialogue state, encoded by ``gcn_net``.
        :param cand_feature: candidate feature ids.
        :param cand_item: candidate item ids.
        :param action_space: indexable; ``action_space[1]`` is read as a list of ids.
        :param is_test: pick greedily (argmax) instead of sampling.
        :param is_last_turn: in test mode, return ``action_space[1]`` directly.
        :return: test mode -> ``(action, ranked_ids)``;
                 training -> ``(action, ranked_ids, log_prob, 0, item_id_tensor)``.
        """
        state_emb = self.gcn_net([state])

        # Each id list becomes a batch of one: shape (1, n).
        feature_ids = torch.LongTensor([cand_feature]).to(self.device)
        item_ids = torch.LongTensor([cand_item]).to(self.device)
        embeddings = torch.cat(
            (self.gcn_net.embedding(feature_ids), self.gcn_net.embedding(item_ids)), 1
        )
        candidates = torch.cat((feature_ids, item_ids), 1)

        self.steps_done += 1
        scores = self.policy_net(state_emb, embeddings)
        dist = Categorical(scores.softmax(1))

        if not is_test:
            # Training: sample an action and keep its log-probability for REINFORCE.
            choice = dist.sample()
            log_prob = dist.log_prob(choice)
            ranked = candidates[0][scores.sort(1, True)[1].tolist()]
            return candidates[0][choice], ranked.tolist(), log_prob, 0, item_ids

        if len(action_space[1]) <= 10 or is_last_turn:
            # At most 10 ids in action_space[1] (or last turn): hand that list back as-is.
            return torch.tensor(action_space[1][0], device=self.device, dtype=torch.long), action_space[1]
        ranked = candidates[0][scores.sort(1, True)[1].tolist()]
        return candidates[0][scores.argmax().item()], ranked.tolist()

    def save_model(self, data_name, filename, epoch_user):
        """Persist the policy, encoder and item-ranker weights via ``save_rl_agent``."""
        checkpoint = {
            'policy': self.policy_net.state_dict(),
            'gcn': self.gcn_net.state_dict(),
            'item_ranker': self.item_ranker.state_dict(),
        }
        save_rl_agent(dataset=data_name, model=checkpoint, filename=filename, epoch_user=epoch_user)

    def load_model(self, data_name, filename, epoch_user):
        """Restore the weights written by ``save_model`` (``target_net`` is not touched)."""
        checkpoint = load_rl_agent(dataset=data_name, filename=filename, epoch_user=epoch_user)
        for key, module in (('policy', self.policy_net), ('gcn', self.gcn_net), ('item_ranker', self.item_ranker)):
            module.load_state_dict(checkpoint[key])


def _discounted_returns(reward_list, gamma):
    """Return the discounted return-to-go for every step of one episode.

    ``returns[t] = reward_list[t] + gamma * returns[t + 1]``, accumulated
    back-to-front in a single pass.
    """
    returns = []
    running = 0
    for r in reversed(reward_list):
        running = r + gamma * running
        returns.append(running)
    returns.reverse()
    return returns


def _embedding_snapshot(agent):
    """Return the encoder's entity-embedding weights as a NumPy array.

    On CPU the array may share memory with the weights.
    """
    return agent.gcn_net.embedding.weight.data.cpu().detach().numpy()


def train(args, kg, dataset, filename):
    """Train the agent with REINFORCE (Monte-Carlo policy gradient).

    Builds the environment, graph encoder and ``Agent``, optionally restores a
    checkpoint, then for ``args.max_steps`` steps runs ``args.sample_times``
    episodes with one optimizer step per episode. It evaluates every
    ``args.eval_num`` steps and saves every ``args.save_num`` steps. When
    ``args.eval_num == 1`` it only evaluates once and returns.

    :param args: parsed command-line namespace built in ``main``.
    :param kg: knowledge graph, forwarded to the environment, encoder and evaluator.
    :param dataset: dataset object, forwarded to the environment and evaluator.
    :param filename: name under which checkpoints / evaluation results are stored.
    """
    env = EnvDict[args.data_name](kg, dataset, args.data_name, args.embed, seed=args.seed, max_turn=args.max_turn,
                                  cand_num=args.cand_num, cand_item_num=args.cand_item_num, attr_num=args.attr_num,
                                  mode='train', ask_num=args.ask_num, entropy_way=args.entropy_method,
                                  fm_epoch=args.fm_epoch, args=args)
    set_random_seed(args.seed)
    # Entity table: env.ui_embeds rows, then env.feature_emb rows, then one
    # all-zero row whose index (the last one) is used as PADDING_ID below.
    embed = torch.FloatTensor(np.concatenate(
        (env.ui_embeds, env.feature_emb, np.zeros((1, env.ui_embeds.shape[1]))), axis=0))

    encoder_cls = TransGate if args.transgate else GraphEncoder
    gcn_net = encoder_cls(device=args.device, entity=embed.size(0), emb_size=embed.size(1), kg=kg,
                          embeddings=embed, fix_emb=args.fix_emb, seq=args.seq, gcn=args.gcn,
                          hidden_size=args.hidden).to(args.device)

    agent = Agent(device=args.device, state_size=args.hidden, action_size=embed.size(1),
                  hidden_size=args.hidden, gcn_net=gcn_net, learning_rate=args.learning_rate,
                  l2_norm=args.l2_norm, PADDING_ID=embed.size(0) - 1).to(args.device)

    if args.load_rl_epoch != 0:
        # NOTE(review): cand_item_num is hard-coded to 10 here. Checkpoints are
        # saved under ``filename``, which main builds from args.cand_item_num.
        # Presumably this is deliberate (reuse a model trained with 10 candidate
        # items); confirm before relying on resume with other settings.
        load_filename = 'train-data-{}-RL-cand_num-{}-cand_item_num-{}-embed-{}-seq-{}-gcn-{}-transgate-{}-rank-{}_PG'.format(
            args.data_name, args.cand_num, 10, args.embed, args.seq, args.gcn, args.transgate, args.item_rank)
        print('Staring loading rl model in epoch {}'.format(args.load_rl_epoch))
        agent.load_model(data_name=args.data_name, filename=load_filename, epoch_user=args.load_rl_epoch)

    # This optimizer (over every agent parameter) is the one stepped below;
    # ``agent.optimizer`` is not used in this function.
    optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, weight_decay=args.l2_norm)
    test_performance = []

    if args.eval_num == 1:
        # Evaluation-only run.
        SR15_mean = dqn_evaluate(args, kg, dataset, agent, filename, 0)
        test_performance.append(SR15_mean)
        print('Test SR15: ', SR15_mean)
        return

    # 0-based index of the final dialogue turn. This was previously hard-coded
    # as 14, which is only correct for the default --max_turn 15.
    last_turn = args.max_turn - 1

    for train_step in tqdm(range(1, args.max_steps + 1)):
        SR5, SR10, SR15, AvgT, Rank, total_reward = 0., 0., 0., 0., 0., 0.

        for i_episode in range(args.sample_times):
            print('\n================new tuple:{}===================='.format(i_episode))
            if not args.fix_emb:
                # Give the environment the encoder's current embeddings.
                state, cand, action_space = env.reset(_embedding_snapshot(agent))
            else:
                state, cand, action_space = env.reset()

            epi_reward = 0
            reward_list = []
            log_probs = []

            for t in count():  # one user dialogue; runs until the env reports done
                action, sorted_actions, log_prob, _, _ = agent.select_action(
                    state, cand[0], cand[1], action_space, is_last_turn=t >= last_turn)
                if not args.fix_emb:
                    next_state, next_cand, action_space, reward, done, success, _ = env.step(
                        action.item(), sorted_actions, _embedding_snapshot(agent))
                else:
                    next_state, next_cand, action_space, reward, done, success, _ = env.step(
                        action.item(), sorted_actions)

                epi_reward += reward
                reward_list.append(reward)
                log_probs.append(log_prob.reshape(1))

                if done:
                    if reward == 1:  # recommended successfully at turn t
                        SR15 += 1
                        if t < 10:
                            SR10 += 1
                        if t < 5:
                            SR5 += 1
                        # NOTE(review): looks like a DCG-style discount on the
                        # success turn, further scaled by ``done`` used as a
                        # number -- confirm what env.step returns in ``done``.
                        Rank += (1/math.log(t+3,2) + (1/math.log(t+2,2)-1/math.log(t+3,2))/math.log(done+1,2))
                    AvgT += t+1
                    total_reward += epi_reward
                    break

                state = next_state
                cand = next_cand

            # REINFORCE loss for this episode: -sum_t log pi(a_t | s_t) * G_t.
            returns = _discounted_returns(reward_list, args.gamma)
            loss = torch.sum(torch.mul(torch.cat(log_probs), torch.Tensor(returns).to(args.device)).mul(-1))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        enablePrint()  # Enable print function
        print('SR5:{}, SR10:{}, SR15:{}, AvgT:{}, Rank:{}, rewards:{} '
                  'Total epoch_uesr:{}'.format(SR5 / args.sample_times, SR10 / args.sample_times, SR15 / args.sample_times,
                                                AvgT / args.sample_times, Rank / args.sample_times, total_reward / args.sample_times, args.sample_times))

        if train_step % args.eval_num == 0:
            SR15_mean = dqn_evaluate(args, kg, dataset, agent, filename, train_step)
            test_performance.append(SR15_mean)
            print('Test SR15: ', SR15_mean)

        if train_step % args.save_num == 0:
            agent.save_model(data_name=args.data_name, filename=filename, epoch_user=train_step)
    print(test_performance)


def main():
    """Parse command-line options, load the KG and dataset, and run ``train``."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', '-seed', type=int, default=1, help='random seed.')
    parser.add_argument('--gpu', type=str, default='0', help='gpu device.')
    parser.add_argument('--epochs', '-me', type=int, default=50000, help='the number of RL train epoch')
    parser.add_argument('--fm_epoch', type=int, default=0, help='the epoch of FM embedding')
    parser.add_argument('--batch_size', type=int, default=128, help='batch size.')
    parser.add_argument('--gamma', type=float, default=0.999, help='reward discount factor.')
    parser.add_argument('--learning_rate', type=float, default=1e-4, help='learning rate.')
    parser.add_argument('--l2_norm', type=float, default=1e-6, help='l2 regularization.')
    # Fixed help text: this is the hidden size given to the graph encoder and
    # the Agent (see train), not a number of samples.
    parser.add_argument('--hidden', type=int, default=100, help='hidden size of the graph encoder and policy network')
    parser.add_argument('--memory_size', type=int, default=50000, help='size of memory ')

    parser.add_argument('--data_name', type=str, default=LAST_FM, choices=[LAST_FM, LAST_FM_STAR, YELP, YELP_STAR],
                        help='One of {LAST_FM, LAST_FM_STAR, YELP, YELP_STAR}.')
    parser.add_argument('--entropy_method', type=str, default='weight_entropy', help='entropy_method is one of {entropy, weight entropy}')
    # 'weighted entropy' performs better; plain 'entropy' is a cheaper alternative when time cost matters.
    parser.add_argument('--max_turn', type=int, default=15, help='max conversation turn')
    parser.add_argument('--attr_num', type=int, help='the number of attributes')
    parser.add_argument('--mode', type=str, default='train', help='the mode in [train, test]')
    parser.add_argument('--ask_num', type=int, default=1, help='the number of features asked in a turn')
    parser.add_argument('--load_rl_epoch', type=int, default=0, help='the epoch of loading RL model')

    parser.add_argument('--sample_times', type=int, default=100, help='the epoch of sampling')
    parser.add_argument('--max_steps', type=int, default=500, help='max training steps')
    parser.add_argument('--eval_num', type=int, default=100, help='the number of steps to evaluate RL model and metric')
    parser.add_argument('--save_num', type=int, default=100, help='the number of steps to save RL model and metric')
    parser.add_argument('--observe_num', type=int, default=500, help='the number of steps to print metric')
    parser.add_argument('--cand_num', type=int, default=10, help='candidate sampling number')
    parser.add_argument('--cand_item_num', type=int, default=10, help='candidate item sampling number')
    # store_false: defaults to True; passing the flag sets it to False.
    parser.add_argument('--fix_emb', action='store_false', help='fix embedding or not')
    parser.add_argument('--embed', type=str, default='transe', help='pretrained embeddings')
    parser.add_argument('--seq', type=str, default='transformer', choices=['rnn', 'transformer', 'mean'], help='sequential learning method')
    # store_false: defaults to True; passing the flag sets it to False.
    parser.add_argument('--gcn', action='store_false', help='use GCN or not')

    parser.add_argument('--item_rank', action='store_true')
    parser.add_argument('--transgate', action='store_true')

    args = parser.parse_args()
    # NOTE: only effective if CUDA has not been initialised yet in this process.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    # NOTE(review): a torch.device on GPU but the plain string 'cpu' otherwise;
    # kept as-is in case downstream code compares against the string.
    args.device = torch.device('cuda') if torch.cuda.is_available() else 'cpu'
    print(args.device)
    print('data_set:{}'.format(args.data_name))
    kg = load_kg(args.data_name)
    # Any --attr_num given on the command line is overridden by the number of
    # attribute nodes in the KG.
    feature_name = FeatureDict[args.data_name]
    feature_length = len(kg.G[feature_name].keys())
    print('dataset:{}, feature_length:{}'.format(args.data_name, feature_length))
    args.attr_num = feature_length
    print('args.attr_num:', args.attr_num)
    print('args.entropy_method:', args.entropy_method)

    dataset = load_dataset(args.data_name)
    # NOTE: the format string has 8 placeholders. The trailing entropy_method
    # argument is ignored by str.format, so it is not part of the name. This is
    # left unchanged so existing checkpoint paths remain valid.
    filename = 'train-data-{}-RL-cand_num-{}-cand_item_num-{}-embed-{}-seq-{}-gcn-{}-transgate-{}-rank-{}_PG'.format(
        args.data_name, args.cand_num, args.cand_item_num, args.embed, args.seq, args.gcn, args.transgate, args.item_rank, args.entropy_method)
    train(args, kg, dataset, filename)


if __name__ == '__main__':
    main()