from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import os
import numpy as np
import torch
import argparse
import random

from data_utils import SST5Processor, SST2Processor, IMDBProcessor, TRECProcessor

from classifier import SequenceClassifier


# Run on the GPU when PyTorch reports one as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('device:', device)

# Command-line interface of the training script. Each help text states only
# what this script does with the value; options that are forwarded to
# data_utils / classifier are described as such.
parser = argparse.ArgumentParser(
    description='Train a SequenceClassifier on a text classification task '
                'and report dev/test accuracy.')

# Task and reproducibility.
parser.add_argument('--task', default='sst-2',
                    choices=['sst-5', 'sst-2', 'imdb', 'trec'],
                    help='Dataset to train and evaluate on.')
parser.add_argument('--seed', default=159, type=int,
                    help='Seed for the random number generators.')

# Optimisation.
parser.add_argument('--epochs', default=30, type=int,
                    help='Number of training epochs.')
parser.add_argument('--learning_rate', default=4e-5, type=float,
                    help='Learning rate passed to the classifier optimizer.')
parser.add_argument('--warmup_proportion', default=0.1, type=float,
                    help='Warm-up proportion passed to the classifier '
                         'optimizer.')
parser.add_argument('--min_epochs', default=0, type=int,
                    help='Results from epochs before this one are discarded '
                         'when selecting the best dev epoch.')
# NOTE(review): presumably the learning rate of the per-example weights that
# are enabled unless --plain is given -- confirm against classifier.load_data.
parser.add_argument('--w_lr', default=0.01, type=float,
                    help='Value passed as w_lr when loading the training '
                         'split.')

parser.add_argument('--batch_size', default=2, type=int,
                    help='Batch size used for the train, dev and test '
                         'splits.')
# NOTE(review): presumably the number of gradient-accumulation steps --
# confirm against classifier.train_epoch.
parser.add_argument('--grad_step', default=8, type=int,
                    help='Value passed as grad_step to every training epoch.')

# Data selection.
parser.add_argument('--num_per_class', default=100, type=int,
                    help='Training examples requested per label; a negative '
                         'value passes no per-label request to the '
                         'processor.')
parser.add_argument('--dev_num_per_class', default=10, type=int,
                    help='Dev examples requested per label; a negative value '
                         'passes no per-label request to the processor.')
parser.add_argument('--imbalance_rate', default=1.0, type=float,
                    help='Factor applied to the training count of the first '
                         'label only (result truncated to an integer).')
parser.add_argument('--noise_rate', default=0., type=float,
                    help='Noise rate passed to the training example loader.')

parser.add_argument('--plain', action='store_true',
                    help='Train in plain mode (training data is loaded with '
                         'has_w=False).')

# Output locations.
parser.add_argument('--group_id', default='0', type=str,
                    help='Suffix of the top-level output directory '
                         '(output_<group_id>).')
parser.add_argument('--name', default='name', type=str,
                    help='Run name appended to the run output directory.')
parser.add_argument('--result_fn', default='results.txt', type=str,
                    help='File to which the final result line is appended.')

args = parser.parse_args()
print(args)

# Seed every random number generator this script has imported. Seeding only
# `random` and NumPy would leave PyTorch's generator unseeded, so --seed alone
# would not pin anything that draws from torch's RNG.
random.seed(args.seed)
np.random.seed(seed=args.seed)
torch.manual_seed(args.seed)
if torch.cuda.is_available():
    # Explicit for PyTorch versions where manual_seed() does not also seed
    # the CUDA generators.
    torch.cuda.manual_seed_all(args.seed)

# Maps every accepted --task value to the data_utils processor class that
# main() instantiates for it.
task_processor_dict = {
    'sst-5': SST5Processor,
    'sst-2': SST2Processor,
    'imdb': IMDBProcessor,
    'trec': TRECProcessor,
}

# Artefacts of one run live in output_<group_id>/<run description>/, where the
# run description encodes the main settings plus the user-supplied run name.
output_dir = os.path.join(
    'output_{}'.format(args.group_id),
    '{}_{}epochs_lr{}_tr{}_dev{}_{}imbalance_{}noise_{}_{}'.format(
        args.task, args.epochs, args.learning_rate,
        args.num_per_class, args.dev_num_per_class,
        args.imbalance_rate, args.noise_rate,
        'plain' if args.plain else '',
        args.name))

# Create the directory tree through the os module instead of shelling out to
# `mkdir`: command-line values are no longer interpreted by a shell (spaces or
# metacharacters in --name / --group_id are safe), and a real failure raises
# here instead of being silently ignored. The try/except form is used rather
# than exist_ok=True so the script keeps working on the Python 2/3 range its
# __future__ imports target.
try:
    os.makedirs(output_dir)
except OSError:
    # An already existing directory is fine (re-runs reuse it); anything else
    # (permissions, a file in the way, ...) is a genuine error.
    if not os.path.isdir(output_dir):
        raise

# Per-run score log written by _print_log(); 'w+' truncates the log of any
# previous run that used the same directory. It stays open for the lifetime
# of the process.
log_file = open(os.path.join(output_dir, 'scores.txt'), 'w+')


def _print_log(*inputs):
    """Print ``inputs`` to stdout and write the same line to the run log.

    The arguments are formatted exactly as ``print`` formats them. The log
    file is flushed after every call.
    """
    # file=None makes print() write to sys.stdout.
    for stream in (None, log_file):
        print(*inputs, file=stream)
    log_file.flush()


def main():
    """Train a SequenceClassifier on the selected task and record its scores.

    Builds the train/dev/test examples with the processor registered for
    ``args.task``, trains for ``args.epochs`` epochs and evaluates on the dev
    and test splits after every epoch. The reported test accuracy is the one
    measured in the epoch with the best dev accuracy (the earliest such epoch
    on ties); epochs with a 0-based index below ``args.min_epochs`` are
    excluded from that selection.

    Per-epoch dev scores and the final summary go to stdout and the run log
    via ``_print_log``; a ``best_dev_acc,test_acc,best_epoch`` line
    (accuracies in percent, epoch 1-based) is appended to ``args.result_fn``.
    If no epoch qualifies, the accuracies keep their -1 sentinel and the epoch
    stays 0.
    """
    processor = task_processor_dict[args.task]()
    labels = processor.get_labels()

    # train: a negative --num_per_class hands None to the processor.
    if args.num_per_class < 0:
        num_per_class = None
    else:
        num_per_class = {label: args.num_per_class for label in labels}
        # Only the first label's count is rescaled by --imbalance_rate.
        num_per_class[labels[0]] = int(
            args.num_per_class * args.imbalance_rate)
    train_examples = processor.get_train_examples(
        num_per_class=num_per_class, noise_rate=args.noise_rate)

    # dev: same convention, without imbalance or label noise.
    if args.dev_num_per_class < 0:
        dev_num_per_class = None
    else:
        dev_num_per_class = {label: args.dev_num_per_class for label in labels}
    dev_examples = processor.get_dev_examples(
        num_per_class=dev_num_per_class)

    # test: always loaded without a per-class request.
    test_examples = processor.get_test_examples()

    classifier = SequenceClassifier(label_list=labels, device=device)

    # TODO: t_total is pinned to -1 instead of the real number of optimizer
    # steps, i.e. args.epochs * (len(train_examples) //
    # (args.batch_size * args.grad_step) + 1).
    # NOTE(review): presumably -1 disables the warm-up schedule, which would
    # make --warmup_proportion a no-op -- confirm in
    # SequenceClassifier.get_optimizer.
    classifier.get_optimizer(
        learning_rate=args.learning_rate,
        warmup_proportion=args.warmup_proportion,
        t_total=-1)

    classifier.load_data(
        'train', train_examples, args.batch_size, shuffle=True,
        has_w=not args.plain, w_lr=args.w_lr)
    classifier.load_data(
        'dev', dev_examples, args.batch_size, shuffle=False)
    classifier.load_data(
        'test', test_examples, args.batch_size, shuffle=False)

    # -1 / 0 are "nothing selected yet" sentinels.
    best_dev_acc = -1.
    best_dev_epoch = 0
    final_test_acc = -1.
    for epoch in range(args.epochs):
        classifier.train_epoch(plain=args.plain, grad_step=args.grad_step)

        dev_acc = classifier.evaluate('dev')
        test_acc = classifier.evaluate('test')

        # Strict '>' keeps the earliest epoch when several tie on dev
        # accuracy. The test score is only ever taken from the epoch chosen
        # on dev, never selected on the test split itself.
        if dev_acc > best_dev_acc:
            best_dev_acc = dev_acc
            best_dev_epoch = epoch + 1
            final_test_acc = test_acc

        _print_log('Epoch {}, Dev Acc: {:.4f}, Best Ever: {:.4f}'.format(
            epoch, 100. * dev_acc, 100. * best_dev_acc))

        # Epochs before --min_epochs are still trained and logged, but must
        # not be selectable, so forget whatever they recorded.
        if epoch < args.min_epochs:
            best_dev_acc = -1.
            best_dev_epoch = 0
            final_test_acc = -1.

    # Route the summary through _print_log so that scores.txt also records
    # the outcome of the run, not just the per-epoch dev scores.
    _print_log('Final Test Acc: {:.4f}, epoch {}'.format(
        100. * final_test_acc, best_dev_epoch))

    # Append mode: several runs can share one results file. Leaving the
    # `with` block closes (and therefore flushes) the file.
    with open(args.result_fn, 'a+') as result_file:
        s_full = '{:.4f},{:.4f},{}'.format(
            100. * best_dev_acc, 100. * final_test_acc,
            best_dev_epoch)
        print('{}'.format(s_full), file=result_file)


if __name__ == '__main__':
    main()
