"""
Author: Anonymous
Code for uncertainty-aware self-training for few-label learning.
"""

import sys
import os

def read_global_seed(environ=None, default=0):
    """Return the integer global seed taken from ``PYTHONHASHSEED``.

    The seed is used by this script to make random sampling reproducible.
    ``PYTHONHASHSEED`` may be unset or set to the literal ``"random"``; neither
    can be converted with ``int()``. In those cases a warning is written to
    stderr and ``default`` is returned instead of crashing.

    Args:
        environ: mapping to read the variable from; defaults to ``os.environ``.
        default: seed returned when the variable is missing or not an integer.

    Returns:
        The seed as an ``int``.
    """
    if environ is None:
        environ = os.environ
    raw = environ.get("PYTHONHASHSEED")
    try:
        return int(raw)
    except (TypeError, ValueError):
        # TypeError: variable unset (raw is None); ValueError: e.g. "random".
        print("PYTHONHASHSEED is unset or not an integer (%r); using %d as the global seed"
              % (raw, default), file=sys.stderr)
        return default


GLOBAL_SEED = read_global_seed()
print ("Global seed ", GLOBAL_SEED)

from uast import generate_sequence_data, train_model

import argparse
from bert import bert_tokenization
import numpy as np
import random

def parse_args(argv=None):
    """Define the command-line interface and return the parsed options as a dict.

    Args:
        argv: list of argument strings; ``None`` means ``sys.argv[1:]``.

    Returns:
        dict mapping each option name (e.g. ``"task"``, ``"seq_len"``) to its value.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", required=True, help="name of the task")
    parser.add_argument("--path", required=True, help="path of data")
    parser.add_argument("--model_dir", required=True, help="path of model directory")
    parser.add_argument("--seq_len", required=True, type=int, help="sequence length")
    parser.add_argument("--sup_batch_size", nargs="?", type=int, default=4, help="batch size for fine-tuning base model")
    parser.add_argument("--sample_size", nargs="?", type=int, default=16384, help="number of unlabeled samples for evaluating uncertainty on in each self-training iteration")
    parser.add_argument("--unsup_size", nargs="?", type=int, default=4096, help="number of pseudo-labeled samples used in each self-training iteration")
    parser.add_argument("--sample_scheme", required=True, help="Sampling scheme to use")
    parser.add_argument("--sup_labels", nargs="?", type=int, default=60, help="number of labeled samples per class for training and validation")
    parser.add_argument("--T", nargs="?", type=int, default=30, help="number of masked models for uncertainty estimation")
    parser.add_argument("--alpha", nargs="?", type=float, default=0.1, help="hyper-parameter for confident training loss")
    parser.add_argument("--valid_split", nargs="?", type=float, default=0.5, help="percentage of labeled samples to use for validation for each class")
    parser.add_argument("--sup_epochs", nargs="?", type=int, default=70, help="number of epochs for fine-tuning base model")
    parser.add_argument("--unsup_epochs", nargs="?", type=int, default=25, help="number of self-training iterations")
    parser.add_argument("--N_base", nargs="?", type=int, default=10, help="number of times to run base model fine-tuning for best base model selection")
    return vars(parser.parse_args(argv))


def select_labeled_subset(x_all, y_all, per_class, seed):
    """Draw up to ``per_class`` labeled examples from every class.

    Each distinct label is visited in sorted order. Its row indices are
    shuffled with a fresh ``random.Random(seed)``, and the first ``per_class``
    are kept, so the selection is reproducible for a given seed.

    Args:
        x_all: inputs, indexable by an integer index array.
        y_all: 1-D numpy array of labels aligned with ``x_all``.
        per_class: maximum number of examples kept per class.
        seed: seed for the per-class shuffle.

    Returns:
        ``(x, y)`` numpy arrays of the selected inputs and their labels. A
        class with fewer than ``per_class`` examples contributes all of
        them, and ``y`` always has exactly one label per row of ``x``.
    """
    x_sel = []
    y_sel = []
    for label in sorted(set(y_all)):
        indx = np.where(y_all == label)[0]
        random.Random(seed).shuffle(indx)
        indx = indx[:per_class]
        x_sel.extend(x_all[indx])
        # Size the labels by the rows actually taken, not by per_class:
        # a class smaller than per_class would otherwise receive surplus
        # labels and misalign x and y.
        y_sel.extend(np.full(len(indx), label))
    return np.array(x_sel), np.array(y_sel)


def main():
    """Parse CLI options, load the task data and launch self-training.

    Reads ``train.tsv``, ``test.tsv`` and ``<task>.txt`` (unlabeled) from
    ``<path>/datasets/<task>/``. Keeps at most ``--sup_labels`` labeled
    examples per class and passes everything to ``train_model``.
    """
    args = parse_args()
    print (args)

    task_name = args["task"]
    MAX_SEQ_LENGTH = args["seq_len"]
    sup_batch_size = args["sup_batch_size"]
    unsup_size = args["unsup_size"]
    sample_size = args["sample_size"]
    path = args["path"]
    model_dir = args["model_dir"]
    sample_scheme = args["sample_scheme"]
    sup_labels = args["sup_labels"]
    T = args["T"]
    alpha = args["alpha"]
    valid_split = args["valid_split"]
    sup_epochs = args["sup_epochs"]
    unsup_epochs = args["unsup_epochs"]
    N_base = args["N_base"]

    vocab_file = path + "/pre-trained-data/vocab.txt"
    # Task-specific checkpoint directory. An alternative location for a base
    # model is path + "/pre-trained-data/uncased_L-12_H-768_A-12".
    bert_model_file = path + "/bert_output/" + task_name

    print ("Task " , task_name)
    print ("Max Seq Length ", MAX_SEQ_LENGTH)
    print ("Sup Batch Size ", sup_batch_size)
    print ("Unsup Size ", unsup_size)
    print ("Sample Size ", sample_size)
    print ("Path ", path)
    print ("Directory of script ", os.path.dirname(os.path.abspath(__file__)))
    print ("Number of labels for supervision for each class ", sup_labels)
    print ("Sampling scheme ", sample_scheme)
    print ("T ", T)
    print ("Alpha ", alpha)
    print ("Valid split ", valid_split)
    print ("Sup epochs ", sup_epochs)
    print ("Unsup epochs ", unsup_epochs)
    print ("N_base ", N_base)

    data_dir = path + "/datasets/" + task_name

    tokenizer = bert_tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=True)

    x_train_all, y_train_all = generate_sequence_data(MAX_SEQ_LENGTH, data_dir + "/train.tsv", tokenizer)
    x_test, y_test = generate_sequence_data(MAX_SEQ_LENGTH, data_dir + "/test.tsv", tokenizer)
    x_unlabeled, _ = generate_sequence_data(MAX_SEQ_LENGTH, data_dir + "/" + task_name + ".txt", tokenizer, unlabeled=True)

    # Shift class ids down by one when no label 0 is present, so that
    # 1-based ids become 0-based (assumes the ids start at 1 in that case
    # — TODO confirm against the datasets).
    if 0 not in set(y_train_all):
        y_train_all -= 1
        y_test -= 1
    labels = set(y_train_all)
    print ("Labels ", labels)

    x_train, y_train = select_labeled_subset(x_train_all, y_train_all, sup_labels, GLOBAL_SEED)

    train_model(MAX_SEQ_LENGTH, tokenizer, sup_batch_size, unsup_size, sample_size, x_train, y_train, x_test, y_test, x_unlabeled, model_dir, bert_model_file, sample_scheme, T, alpha, valid_split, sup_epochs, unsup_epochs, N_base)


if __name__ == '__main__':
    main()
