import os
import sys
import pprint
from collections import defaultdict
import time

import numpy as np
import torch
from torch.utils.data import DataLoader

from elf.options import PyOptionSpec
from rlpytorch.behavior_clone.coach_dataset import CoachDataset, compute_cache
from rlpytorch.behavior_clone import utils
from rlpytorch.behavior_clone.conv_rnn_coach import ConvRnnCoach
from rlpytorch.behavior_clone.conv_rnn_generator import ConvRnnGenerator
from rlpytorch.behavior_clone.conv_onehot_coach import ConvOneHotCoach
from rlpytorch.utils import set_all_seeds, EvalMode, Logger


def evaluate(model, device, data_loader, epoch, name, norm_loss):
    """Run one pass over ``data_loader`` and print the mean of each loss term.

    Args:
        model: model to evaluate; ``model.training`` must be false. It must
            provide ``compute_eval_loss`` (used when ``norm_loss`` is true)
            or ``compute_loss`` (used otherwise). Either is expected to
            return a ``(loss, all_losses)`` pair where ``all_losses`` maps
            names to values exposing ``.item()``.
        device: device handed to ``utils.to_device`` for every batch.
        data_loader: iterable yielding batches.
        epoch: epoch number; only used in the printed header.
        name: label for the printed header (e.g. ``'eval'``).
        norm_loss: selects ``compute_eval_loss`` over ``compute_loss``.

    Returns:
        ``np.mean`` of the per-batch values collected under the ``'loss'``
        key.
    """
    assert not model.training

    per_key = defaultdict(list)
    start = time.time()
    for batch in data_loader:
        # The return value is ignored, so to_device presumably moves the
        # batch in place -- TODO confirm against utils.to_device.
        utils.to_device(batch, device)

        compute = model.compute_eval_loss if norm_loss else model.compute_loss
        # Only the per-term breakdown is recorded; the first element of the
        # returned pair is not needed here.
        _, all_losses = compute(batch)

        for key, val in all_losses.items():
            per_key[key].append(val.item())

    elapsed = time.time() - start
    print('%s epoch: %d, time: %.2f' % (name, epoch, elapsed))
    for key, values in per_key.items():
        print('\t%s: %.5f' % (key, np.mean(values)))

    return np.mean(per_key['loss'])


def parse_args():
    """Declare this script's command-line options and parse them.

    The options are registered in the order listed in ``option_table`` and
    then merged with the option spec of ``ConvRnnGenerator``.

    Returns:
        The option map returned by ``PyOptionSpec.parse()``.
    """
    spec = PyOptionSpec()

    # NOTE: the 'data_mode' and 'train_dataset' options are not registered
    # here, and only the ConvRnnGenerator spec is merged (the ConvRnnCoach /
    # ConvOneHotCoach specs are not).
    # Each row is (adder, option name, help text, default value).
    option_table = (
        # train config
        (spec.addIntOption, 'batch_size', '', 128),
        (spec.addIntOption, 'epochs', '', 10),
        (spec.addIntOption, 'gpu', '', 0),
        (spec.addIntOption, 'seed', '', 1),
        (spec.addStrOption, 'val_dataset', 'path to val dataset',
         'data3/dev.json_min10'),
        (spec.addStrOption, 'model_folder', 'folder to save model',
         'model-test'),
        # optim
        (spec.addStrOption, 'optim', '', 'adam'),
        (spec.addFloatOption, 'lr', '', 1e-3),
        (spec.addFloatOption, 'beta1', 'for adam', 0.9),
        (spec.addFloatOption, 'beta2', 'for adam', 0.999),
        (spec.addFloatOption, 'grad_clip', 'grad clip by norm', 0.5),
        (spec.addFloatOption, 'moving_avg_decay',
         'moving avg for enemy count', 0.9995),
        (spec.addIntOption, 'max_instruction_span', '', 20),
        # checkpoint to evaluate
        (spec.addStrOption, 'model', '',
         '/private/home/hengyuan/elf2-bc-dev/src_py/rlpytorch/behavior_clone/sweep_coach/conv_rnn_gen/modelconv_rnn_gen_inst_cls_dim512/best_checkpoint.pt'),
        # debug
        (spec.addBoolOption, 'dev', 'for debug', False),
    )
    for add_option, opt_name, help_text, default in option_table:
        add_option(opt_name, help_text, default)

    spec.merge(ConvRnnGenerator.get_option_spec())

    return spec.parse()


def main():
    """Evaluate a saved model on the validation set for several candidate sizes.

    Steps:
      1. Parse options and make sure ``options.model_folder`` exists.
      2. Unless ``options.dev`` is set, replace ``sys.stdout`` with a
         ``Logger`` pointed at ``<model_folder>/train.log``.
      3. Load the checkpoint named by ``options.model`` onto the chosen
         device (CPU when ``options.gpu < 0``, otherwise ``cuda:<gpu>``).
      4. Build one ``CoachDataset`` / ``DataLoader`` per entry of
         ``instruction_counts`` from ``options.val_dataset``.
      5. For each count, set ``model.options.num_instructions`` and print
         the loss returned by ``evaluate`` (which uses
         ``model.compute_eval_loss`` because ``norm_loss`` is true).

    This script only evaluates: optimizer-related options (``lr``,
    ``beta1``, ...) are parsed but not used here.
    """
    torch.backends.cudnn.benchmark = True

    option_map = parse_args()
    options = option_map.getOptions()

    # makedirs(exist_ok=True) also creates missing parent directories and
    # avoids the check-then-create race of os.path.exists + os.mkdir.
    os.makedirs(options.model_folder, exist_ok=True)
    logger_path = os.path.join(options.model_folder, 'train.log')
    if not options.dev:
        sys.stdout = Logger(logger_path)

    print('Args:\n%s\n' % pprint.pformat(vars(options)))

    if options.gpu < 0:
        device = torch.device('cpu')
    else:
        device = torch.device('cuda:%d' % options.gpu)

    set_all_seeds(options.seed)

    model = utils.load_model(options.model).to(device)
    print(model)

    print('moving_avg_decay:', options.moving_avg_decay)

    # Candidate-instruction counts to evaluate, in reporting order.
    instruction_counts = (500, 250, 50)

    # All datasets are built first, then cached, then wrapped in loaders,
    # so the construction order matches the previous hand-unrolled code.
    datasets = [
        CoachDataset(
            options.val_dataset,
            options.moving_avg_decay,
            options.num_resource_bin,
            options.resource_bin_size,
            options.max_num_prev_cmds,
            model.inst_dict,
            options.max_instruction_span,
            num_instructions=count)
        for count in instruction_counts
    ]

    for dataset in datasets:
        compute_cache(dataset)

    loaders = [
        DataLoader(
            dataset,
            options.batch_size,
            shuffle=False,
            num_workers=1,
            pin_memory=(options.gpu >= 0))
        for dataset in datasets
    ]

    for count, loader in zip(instruction_counts, loaders):
        # Clear the stored positive candidates before changing the count;
        # presumably the model rebuilds them lazily for the new
        # num_instructions -- TODO confirm in the model implementation.
        model.pos_candidate_inst = None
        model.options.num_instructions = count
        with torch.no_grad(), EvalMode(model):
            eval_nll = evaluate(model, device, loader, 0, 'eval', True)
            # Same 'eval_nll:' label for every count (the 50 case used to
            # omit the colon).
            print('inst %d, eval_nll:' % count, eval_nll)


if __name__ == '__main__':
    # Script entry point: run the evaluation only when executed directly,
    # not when this module is imported.
    # NOTE: main() returns nothing, so interactive debugging that needs the
    # loaded model (e.g. calling model._get_neg_candidate_inst on CPU)
    # requires main() to return it first.
    main()
