"""
Author: Anonymous
Code for uncertainty-aware self-training for few-label learning.
"""

import random
import logging 
import numpy as np
from collections import defaultdict
from numpy.random import seed
from tensorflow.keras.preprocessing import sequence
from string import punctuation
from tensorflow.python.client import device_lib
import os 

import models
import sys
import re
import math

from tensorflow.keras import backend as kb
from tensorflow.keras import optimizers
from tensorflow.keras.models import load_model

import tensorflow.keras as K

import tensorflow as tf

from sklearn.utils import shuffle

from tensorflow.keras.utils import multi_gpu_model, to_categorical
from tensorflow.keras.losses import CategoricalCrossentropy, MeanSquaredError

import time
import sampler

# from tensorflow.keras.utils import plot_model

logger = logging.getLogger(__name__)


def read_global_seed(default=0):
    """Return the PYTHONHASHSEED environment variable as an int.

    The value seeds the train/dev shuffle in `train_model`. If the variable
    is unset or not an integer (Python also accepts ``PYTHONHASHSEED=random``),
    a warning is logged and `default` is returned instead of crashing at
    import time.

    Args:
        default: seed used when PYTHONHASHSEED is missing or non-numeric.

    Returns:
        The integer seed.
    """
    raw = os.getenv("PYTHONHASHSEED")
    try:
        return int(raw)
    except (TypeError, ValueError):
        # TypeError: variable unset (raw is None); ValueError: e.g. "random".
        logger.warning("PYTHONHASHSEED=%r is not an integer; using seed %d", raw, default)
        return default


GLOBAL_SEED = read_global_seed()


class CosineLRSchedule:
    """
    Cosine annealing with warm restarts, described in paper
    "SGDR: stochastic gradient descent with warm restarts"
    https://arxiv.org/abs/1608.03983

    Changes the learning rate, oscillating it between `lr_high` and `lr_low`.
    Within a cycle the rate follows a half cosine from the cycle's maximum
    (epoch offset 0) down to `lr_low` (epoch offset `period`); on the next
    epoch it jumps back up (a warm restart) and a new cycle begins.

    With every restart:
        * the period is multiplied by `period_mult`
        * the maximum learning rate is multiplied by `high_lr_mult`

    A cycle restarts once its epoch counter reaches or passes the period, so
    non-integer periods (e.g. from `period_mult=1.5`) restart correctly.

    This class is supposed to be used with
    `keras.callbacks.LearningRateScheduler`.
    """
    def __init__(self, lr_high: float, lr_low: float, initial_period: int = 50,
                 period_mult: float = 2, high_lr_mult: float = 0.97):
        if initial_period <= 0:
            raise ValueError("initial_period must be positive, got {}".format(initial_period))
        self._lr_high = lr_high
        self._lr_low = lr_low
        self._initial_period = initial_period
        self._period_mult = period_mult
        self._high_lr_mult = high_lr_mult

    def __call__(self, epoch, lr):
        """Keras `LearningRateScheduler` hook; the current `lr` is ignored."""
        return self.get_lr_for_epoch(epoch)

    def get_lr_for_epoch(self, epoch):
        """Return the learning rate for the 0-based `epoch`.

        Raises:
            ValueError: if `epoch` is negative.
        """
        if epoch < 0:
            raise ValueError("epoch must be non-negative, got {}".format(epoch))
        t_cur = 0
        lr_max = self._lr_high
        period = self._initial_period
        # Replay the schedule from epoch 0, applying every restart on the way.
        for _ in range(epoch):
            # `>=` (not `==`) so that non-integer periods still trigger a restart.
            if t_cur >= period:
                period *= self._period_mult
                lr_max *= self._high_lr_mult
                t_cur = 0
            else:
                t_cur += 1
        return (self._lr_low +
                0.5 * (lr_max - self._lr_low) *
                (1 + math.cos(math.pi * t_cur / period)))


def create_learning_rate_scheduler(max_learn_rate=5e-5,
                                   end_learn_rate=1e-7,
                                   warmup_epoch_count=10,
                                   total_epoch_count=90):
    """Build a Keras `LearningRateScheduler` with linear warm-up and exponential decay.

    For epochs ``0 .. warmup_epoch_count - 1`` the rate grows linearly:
    ``max_learn_rate / warmup_epoch_count * (epoch + 1)``. Afterwards it decays
    geometrically from `max_learn_rate` towards `end_learn_rate`, which is
    reached at ``epoch == total_epoch_count``.

    Returns:
        A `tf.keras.callbacks.LearningRateScheduler` (verbose=1) wrapping the schedule.
    """

    def _schedule(epoch):
        if epoch >= warmup_epoch_count:
            log_ratio = math.log(end_learn_rate / max_learn_rate)
            decay_steps = epoch - warmup_epoch_count + 1
            decay_span = total_epoch_count - warmup_epoch_count + 1
            return float(max_learn_rate * math.exp(log_ratio * decay_steps / decay_span))
        per_epoch = max_learn_rate / warmup_epoch_count
        return float(per_epoch * (epoch + 1))

    return tf.keras.callbacks.LearningRateScheduler(_schedule, verbose=1)



def get_available_gpus():
    """Return the number of local devices whose type TensorFlow reports as 'GPU'."""
    devices = device_lib.list_local_devices()
    return sum(1 for device in devices if device.device_type == 'GPU')

def generate_sequence_data(MAX_SEQUENCE_LENGTH, input_file, tokenizer, unlabeled=False):
    """Read a tab-separated file and encode each line as a fixed-length row of token ids.

    Each non-blank line is ``text<TAB>label`` (or just ``text`` when
    `unlabeled`). The text is tokenized with `tokenizer.tokenize` and truncated
    to ``MAX_SEQUENCE_LENGTH - 2`` tokens. It is then wrapped in ``[CLS]`` /
    ``[SEP]`` and converted with `tokenizer.convert_tokens_to_ids`. Finally it
    is right-padded with id 0 up to `MAX_SEQUENCE_LENGTH`. Blank or
    whitespace-only lines are skipped.

    Args:
        MAX_SEQUENCE_LENGTH: length of every output row, including [CLS]/[SEP].
        input_file: path to the file, decoded as ISO-8859-1.
        tokenizer: object providing `tokenize(str)` and `convert_tokens_to_ids(list)`.
        unlabeled: if True the label column is ignored and every label is -1.

    Returns:
        (X, y): `X` is a numpy array with one padded id row per line; `y` is a
        1-D numpy array of integer labels (all -1 when `unlabeled`).

    Raises:
        ValueError: if a labeled line has no label column or a non-integer label.
    """
    X = []
    y = []
    label_count = defaultdict(int)
    max_tokens = MAX_SEQUENCE_LENGTH - 2  # leave room for [CLS] and [SEP]

    with open(input_file, encoding="ISO-8859-1") as f:
        for line_no, line in enumerate(f, start=1):
            stripped = line.strip()
            if not stripped:
                # Blank lines carry no example; for labeled data they used to crash.
                continue
            fields = stripped.split('\t')

            input_tokens = tokenizer.tokenize(fields[0].strip())
            if len(input_tokens) > max_tokens:
                input_tokens = input_tokens[:max_tokens]
            input_tokens = ["[CLS]"] + input_tokens + ["[SEP]"]
            input_ids = tokenizer.convert_tokens_to_ids(input_tokens)
            input_ids = input_ids + [0]*(MAX_SEQUENCE_LENGTH - len(input_tokens))
            X.append(input_ids)

            if unlabeled:
                y.append(-1)
                continue
            if len(fields) < 2:
                raise ValueError("{}:{}: expected '<text>\\t<label>' but found no label column"
                                 .format(input_file, line_no))
            label = int(fields[1].strip())
            y.append(label)
            label_count[label] += 1

    for key in label_count.keys():
        print ("Count of instances with label {} is {}".format(key, label_count[key]))

    X = np.array(X)
    y = np.array(y)

    print ('X shape:', X.shape)
    print ('y shape:', y.shape)

    return X, y


class MetricsCallback(K.callbacks.Callback):
    """Keras callback that prints `model.evaluate` on a held-out set after every epoch.

    Args:
        test_data: indexable pair ``(inputs, targets)``; element 0 and element 1
            are passed positionally to `self.model.evaluate`.
    """

    def __init__(self, test_data):
        super().__init__()
        self.test_data = test_data

    def on_epoch_end(self, epoch, logs=None):
        # `logs` defaults to None instead of a shared mutable `{}`; it is not used here.
        inputs, targets = self.test_data[0], self.test_data[1]
        print ("Eval at end of epoch {}: {}".format(epoch, self.model.evaluate(inputs, targets, verbose=0)))

def chunks(x, batch_size=32):
    """Split the sliceable sequence `x` into consecutive slices of `batch_size` items.

    Every slice has `batch_size` items except possibly the last one, which
    holds the remainder. The result is an eager list (not a generator), so it
    can be indexed and re-iterated.

    Raises:
        ValueError: if `batch_size` is less than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer, got {}".format(batch_size))
    return [x[start:start + batch_size] for start in range(0, len(x), batch_size)]

def mc_dropout_evaluate(model, gpus, classes, x, y=None, T=30, batch_size=256, training=True):
    """Run `T` forward passes of `model` over `x` and summarise the predictions.

    Each pass calls `model` with ``training=training`` (presumably True so
    that dropout stays active, i.e. Monte-Carlo dropout — confirm) on batches
    of ``batch_size * gpus`` rows distributed with a
    `tf.distribute.MirroredStrategy`.

    Args:
        model: Keras model, called as ``model([x_batch, zeros], training=training)``.
        gpus: number of replicas; scales the global batch size and is the number
            of per-replica results collected for every batch.
        classes: size of the last axis of every stored prediction.
        x: model inputs, one row per instance (presumably token-id rows — confirm).
        y: optional gold labels of shape (len(x),); when given, the accuracy of
            the majority-vote prediction is computed and printed.
        T: number of forward passes.
        batch_size: per-replica batch size.
        training: value forwarded as the model's `training` flag.

    Returns:
        Tuple ``(y_mean, y_var, y_pred, acc, y_T)``:
            y_mean: mean over the T passes, shape (len(x), classes).
            y_var: variance over the T passes, shape (len(x), classes).
            y_pred: majority vote of the per-pass argmax, shape (len(x),).
            acc: majority-vote accuracy against `y`, or None when `y` is None.
            y_T: every pass's outputs, shape (T, len(x), classes).
    """
    # One (len(x), classes) slice per forward pass.
    y_T = np.zeros((T,len(x), classes))
    acc = None

    print ("Yielding predictions looping over ...")
    # NOTE(review): a fresh MirroredStrategy is created on every call.
    strategy = tf.distribute.MirroredStrategy()
    data=tf.data.Dataset.from_tensor_slices(x).batch(batch_size*gpus)
    dist_data = strategy.experimental_distribute_dataset(data)

    for i in range(T):

        print ("{}".format(i), end =" ")

        y_pred = []

        with strategy.scope():
            def eval_step(inputs):
                # Second model input is all zeros with the batch's shape (presumably
                # segment / token-type ids — confirm). `[:,0]` keeps index 0 of axis 1;
                # since each kept row must fill a (classes,) slot of y_T, the model
                # output presumably has an extra axis before the class axis — TODO confirm.
                return model([inputs, np.zeros((inputs.shape[0], inputs.shape[1]))], training=training).numpy()[:,0]

            def distributed_eval_step(dataset_inputs):
                # NOTE(review): `experimental_run_v2` is deprecated in later TF 2.x releases
                # in favour of `strategy.run` — confirm against the pinned TF version.
                return strategy.experimental_run_v2(eval_step, args=(dataset_inputs,))

            for batch in dist_data:
                pred = distributed_eval_step(batch)
                # Concatenate per-replica results in replica order. Assumes `pred`
                # exposes `.values` with `gpus` entries — TODO confirm for one device.
                for gpu in range(gpus):
                    y_pred.extend(pred.values[gpu])

        y_T[i] = np.array(y_pred)

    #compute mean
    y_mean = np.mean(y_T, axis=0)
    assert y_mean.shape == (len(x), classes)

    #compute majority prediction
    # argmax per pass, then the most frequent class per instance; np.argmax over the
    # bincount breaks ties in favour of the lowest class id.
    y_pred = np.array([np.argmax(np.bincount(row)) for row in np.transpose(np.argmax(y_T, axis=-1))])
    assert y_pred.shape == (len(x),)

    if y is not None:
        assert y.shape == (len(x),)
        acc = (float)(len(np.where(y_pred==y)[0]))/len(y)
        print ("EVAL ACC: {}".format(acc))

    #compute variance
    y_var = np.var(y_T, axis=0)
    assert y_var.shape == (len(x), classes)

    return y_mean, y_var, y_pred, acc, y_T

def train_model(MAX_SEQUENCE_LENGTH, tokenizer, sup_batch_size, unsup_size, sample_size, x_train_all, y_train_all, x_test, y_test, x_unlabeled, model_dir, bert_model_file, sample_scheme, T, alpha, valid_split, sup_epochs, unsup_epochs, N_base, self_training_rounds=25):
    """Train a BERT classifier on labeled data, then self-train it on pseudo-labeled data.

    Steps:
      1. Shuffle the labeled data (seeded with GLOBAL_SEED) and keep the last
         `valid_split` fraction as the dev set.
      2. Base model: load ``model_dir/strong_model.h5`` if it exists. Otherwise
         train `N_base` models and keep the one with the lowest dev loss, then
         save it to that file.
      3. Run `self_training_rounds` rounds. Each round draws `sample_size`
         unlabeled rows and scores them with MC dropout (`T` passes) unless
         the scheme is uniform without 'bald'. It then picks `unsup_size` rows
         with the sampler chosen by `sample_scheme` and fine-tunes on their
         pseudo-labels. The weights of round r are cached as
         ``model_dir/model_{r}_{sample_scheme}.h5``; if that file exists it is
         loaded instead of training.

    The model is evaluated on test and dev before every round and once more
    after the last round. At the end it prints the test accuracy of the model
    with the best dev accuracy ("Final eval acc") and the highest test accuracy
    seen ("Best eval acc").

    `sample_scheme` substrings: 'uni' (uniform sampling), 'bald' with 'eas'/'dif'
    and optionally 'clas' (sampler choice), 'soft' (use the MC-dropout majority
    vote as pseudo-labels), 'weight' (sample weights ``-log(w + 1e-10) * alpha``).

    Raises:
        ValueError: if `N_base` < 1 or `sample_scheme` selects no sampler.
    """
    labels = set(y_train_all)
    print ("Class labels ", labels)

    if N_base < 1:
        raise ValueError("N_base must be at least 1, got {}".format(N_base))

    # Resolve the sampling strategy once, before any expensive training, so an
    # unsupported scheme fails fast. Later matches override earlier ones.
    weight = 'weight' in sample_scheme
    f_ = None
    if 'bald' in sample_scheme and 'eas' in sample_scheme:
        f_ = sampler.sample_by_bald_easiness
    if 'bald' in sample_scheme and 'eas' in sample_scheme and 'clas' in sample_scheme:
        f_ = sampler.sample_by_bald_class_easiness
    if 'bald' in sample_scheme and 'dif' in sample_scheme:
        f_ = sampler.sample_by_bald_difficulty
    if 'bald' in sample_scheme and 'dif' in sample_scheme and 'clas' in sample_scheme:
        f_ = sampler.sample_by_bald_class_difficulty
    if 'uni' not in sample_scheme and f_ is None:
        raise ValueError("Unsupported sample_scheme {!r}: expected 'uni' or 'bald' combined "
                         "with 'eas'/'dif'".format(sample_scheme))

    def _inputs(x):
        # The model takes [ids, zeros]; the zeros are presumably segment ids — confirm.
        return [x, np.zeros((len(x), MAX_SEQUENCE_LENGTH))]

    x_train_all, y_train_all = shuffle(x_train_all, y_train_all, random_state=GLOBAL_SEED)
    train_size = int((1. - valid_split)*len(x_train_all))
    x_train, y_train = x_train_all[:train_size], y_train_all[:train_size]
    x_dev, y_dev = x_train_all[train_size:], y_train_all[train_size:]

    print("X Train Shape " + str(x_train.shape) + ' ' + str(y_train.shape))
    print("X Dev Shape " + str(x_dev.shape) + ' ' + str(y_dev.shape))
    print("X Test Shape " + str(x_test.shape) + ' ' + str(y_test.shape))

    strategy = tf.distribute.MirroredStrategy()
    gpus = strategy.num_replicas_in_sync
    print('Number of devices: {}'.format(gpus))

    # Train the base model N_base times with different initializations and keep
    # the one with the lowest validation loss (or load a cached one).
    base_model_file = os.path.join(model_dir, "strong_model.h5")
    best_base_model = None
    best_validation_loss = np.inf
    for counter in range(N_base):
        with strategy.scope():
            strong_model = models.construct_bert(bert_model_file, MAX_SEQUENCE_LENGTH, len(labels))
            strong_model.compile(optimizer=K.optimizers.Adam(),
                                 loss=K.losses.SparseCategoricalCrossentropy(from_logits=True),
                                 metrics=[K.metrics.SparseCategoricalAccuracy(name="acc")])
            if counter == 0:
                strong_model.summary()  # prints itself; returns None

        if os.path.exists(base_model_file):
            strong_model.load_weights(base_model_file)
            best_base_model = strong_model
            print ("Model file loaded from {}".format(base_model_file))
            break

        strong_model.fit(x=_inputs(x_train), y=y_train, batch_size=sup_batch_size*gpus, shuffle=True, epochs=sup_epochs,
                         callbacks=[create_learning_rate_scheduler(max_learn_rate=1e-5, end_learn_rate=1e-7, warmup_epoch_count=20, total_epoch_count=sup_epochs),
                                    K.callbacks.EarlyStopping(patience=20, restore_best_weights=True)],
                         validation_data=(_inputs(x_dev), y_dev))

        val_loss = strong_model.evaluate(_inputs(x_dev), y_dev)
        print ("Validation loss for run {} : {}".format(counter, val_loss))
        if val_loss[0] < best_validation_loss:
            best_base_model = strong_model
            best_validation_loss = val_loss[0]

    strong_model = best_base_model
    print ("Best validation loss for base model {}: {}".format(best_validation_loss, strong_model.evaluate(_inputs(x_dev), y_dev)))

    if not os.path.exists(base_model_file):
        strong_model.save_weights(base_model_file)
        print ("Model file saved to {}".format(base_model_file))

    best_val_acc = 0.
    best_eval_acc = 0.
    max_eval_acc = 0.

    # One extra pass so the model produced by the last round is evaluated too;
    # that pass stops right after evaluation.
    for epoch in range(self_training_rounds + 1):
        is_final_pass = epoch == self_training_rounds
        if not is_final_pass:
            print ("Starting loop {}".format(epoch))

        eval_acc = strong_model.evaluate(_inputs(x_test), y_test, verbose=0)[-1]
        val_acc = strong_model.evaluate(_inputs(x_dev), y_dev, verbose=0)[-1]
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_eval_acc = eval_acc
        if eval_acc > max_eval_acc:
            max_eval_acc = eval_acc

        print ("Evaluation acc {}".format(eval_acc))

        if is_final_pass:
            break

        model_file = os.path.join(model_dir, "model_{}_{}.h5".format(epoch, sample_scheme))
        if os.path.exists(model_file):
            strong_model.load_weights(model_file)
            print ("Model file loaded from {}".format(model_file))
            continue

        # Draw the candidate pool; capped at the pool size because sampling
        # without replacement cannot return more rows than exist.
        pool_size = min(sample_size, len(x_unlabeled))
        print ("Evaluating uncertainty on {} number of unlabeled instances".format(pool_size))
        x_unlabeled_sample = x_unlabeled[np.random.choice(len(x_unlabeled), pool_size, replace=False)]
        print (x_unlabeled_sample[:5])

        if 'uni' in sample_scheme and 'bald' not in sample_scheme:
            y_mean, y_var, y_pred, y_T = None, None, None, None
        else:
            y_mean, y_var, y_pred, _, y_T = mc_dropout_evaluate(strong_model, gpus, len(labels), x_unlabeled_sample, T=T)

        # Hard pseudo-labels from a deterministic forward pass, unless the scheme
        # asks for 'soft' (MC-dropout majority vote) labels and those were computed.
        if 'soft' not in sample_scheme or y_pred is None:
            y_pred = strong_model.predict(_inputs(x_unlabeled_sample), batch_size=256)
            y_pred = np.argmax(y_pred, axis=-1).flatten()

        if 'uni' in sample_scheme:
            print ("Sampling uniformly")
            indices = np.random.choice(len(x_unlabeled_sample), unsup_size, replace=False)
            x_batch = x_unlabeled_sample[indices]
            y_batch = y_pred[indices]
            x_weight = np.ones(len(y_batch))
        else:
            x_batch, y_batch, x_weight = f_(tokenizer, x_unlabeled_sample, y_mean, y_var, y_pred, unsup_size, len(labels), weight=weight, y_T=y_T)

        if not weight:
            print ("Not using weight.")
            x_weight = np.ones(len(x_batch))
            print ("Weights ", x_weight[:10])
        else:
            # NOTE(review): with 'uni' the incoming weights are all ones, so this
            # maps them to ~0 and the round effectively does not train — confirm intent.
            print ("Using Weights ", x_weight[:10])
            x_weight = -np.log(x_weight+1e-10)*alpha
            print ("Weights ", x_weight[:10])

        strong_model.fit(x=_inputs(x_batch), y=y_batch, validation_data=(_inputs(x_dev), y_dev), batch_size=32*gpus, shuffle=True, sample_weight=x_weight, epochs=unsup_epochs,
                         callbacks=[create_learning_rate_scheduler(max_learn_rate=1e-5, end_learn_rate=1e-7, warmup_epoch_count=3, total_epoch_count=unsup_epochs),
                                    K.callbacks.EarlyStopping(patience=5, restore_best_weights=True)])

        if not os.path.exists(model_file):
            strong_model.save_weights(model_file)
            print ("Model file saved to {}".format(model_file))

    print ("Final eval acc {}".format(best_eval_acc))
    print ("Best eval acc {}".format(max_eval_acc))
