from absl import app
from absl import flags
from absl import logging
import numpy as np
import os.path as path
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend as K
from sklearn.metrics import brier_score_loss
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import csv

from transformations.reader.matrix import test_argument_and_file, load_and_log
from timed_logger import TimedLogger

FLAGS = flags.FLAGS

flags.DEFINE_string("path", ".", "Path to the matrices directory")
flags.DEFINE_string("features_train", None, "Name of the train features numpy matrix exported file (npy)")
flags.DEFINE_string("features_test", None, "Name of the test features numpy matrix exported file (npy)")
flags.DEFINE_string("labels_train", None, "Name of the train labels numpy matrix exported file (npy)")
flags.DEFINE_string("labels_test", None, "Name of the test labels numpy matrix exported file (npy)")
flags.DEFINE_list("l2_regs", None, "L2 regularization (list) of the last layer")
flags.DEFINE_integer("epochs", 100, "Number of epochs to train")
flags.DEFINE_list("sgd_lrs", None, "SGD learning rate (list)")
flags.DEFINE_float("sgd_momentum", 0.9, "SGD momentum")
# FLAGS.batch_size is read by the training and CSV-row helpers below, so it must
# exist. Only define it when no imported module has registered a flag with the
# same name already (absl raises DuplicateFlagError on a conflicting redefinition).
# 32 matches the batch size Keras' fit() uses when none is given.
if "batch_size" not in FLAGS:
    flags.DEFINE_integer("batch_size", 32, "Mini-batch size used for SGD training")

flags.DEFINE_string("output_file", None, "File to write the output in CSV format (including headers)")
flags.DEFINE_bool("output_overwrite", True, "Writes (if True) or appends (if False) to the specified output file if any")


def _write_result(rows, overwrite):
    if len(rows) == 0:
        return
    writeheader = False
    if overwrite or not path.exists(FLAGS.output_file):
        writeheader = True
    with open(FLAGS.output_file, mode='w+' if overwrite else 'a+') as f:
        fieldnames = list(rows[0].keys())
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        if writeheader:
            writer.writeheader()
        for r in rows:
            writer.writerow(r)


def _get_csv_row(l2_reg, sgd_lr, loss, brier_score, norm, accuracy):
    """Assemble the CSV result record for one trained configuration.

    Per-run values come from the arguments; epochs, batch size and momentum are
    taken from the command-line flags. Key insertion order matters: it becomes
    the CSV column order when the row is handed to the CSV writer.
    """
    row = dict(
        l2_reg=l2_reg,
        epochs=FLAGS.epochs,
        batch_size=FLAGS.batch_size,
        sgd_lr=sgd_lr,
        sgd_momentum=FLAGS.sgd_momentum,
        cross_entropy_loss=loss,
        brier_score=brier_score,
        # Column name spelled as-is on purpose: existing CSV consumers may
        # depend on this exact header.
        forbenius_norm=norm,
        accuracy=accuracy,
    )
    return row


def labels_encoded(labels):
    """One-hot encode ``labels`` against their own sorted distinct values.

    Returns a list with one float numpy vector per input label. Each vector has
    as many entries as there are distinct label values, and a single 1 at the
    position of that label within the sorted distinct values.
    """
    uniques = list(np.unique(labels))
    width = len(uniques)

    def _one_hot(label):
        # Zero vector with a 1 at the label's rank among the distinct values.
        row = np.zeros(width)
        row[uniques.index(label)] = 1
        return row

    return [_one_hot(label) for label in labels]


def brier_multi(probs, labels):
    """Return the multi-class Brier score of predicted class probabilities.

    The score is the mean over samples of the squared Euclidean distance
    between each probability row and the one-hot vector of its true label.

    Args:
        probs: array-like of shape (n_samples, n_classes); column j holds the
            predicted probability of class j.
        labels: array-like with n_samples entries (an (n, 1) column is accepted).
            If the labels contain exactly n_classes distinct values, they are
            mapped to columns in sorted order. Otherwise they must already be
            integer class indices in [0, n_classes). This covers the case where
            some class never occurs in ``labels``, e.g. a small test split.

    Returns:
        The Brier score (a numpy float).

    Raises:
        ValueError: if ``probs`` is not 2-D, its row count differs from the
            number of labels, or the labels cannot be aligned with its columns.
    """
    probs = np.asarray(probs, dtype=float)
    labels = np.asarray(labels).ravel()
    if probs.ndim != 2 or probs.shape[0] != labels.size:
        raise ValueError(
            "probs must be (n_samples, n_classes) with one row per label; "
            "got probs shape {} and {} labels".format(probs.shape, labels.size))
    n_classes = probs.shape[1]

    classes = np.unique(labels)
    is_numeric = (np.issubdtype(labels.dtype, np.integer)
                  or np.issubdtype(labels.dtype, np.floating))
    if classes.size == n_classes:
        # Every column is represented: map each label to its rank among the
        # sorted distinct values (same mapping as labels_encoded).
        indices = np.searchsorted(classes, labels)
    elif (is_numeric
          and np.all(np.mod(labels, 1) == 0)
          and np.all((labels >= 0) & (labels < n_classes))):
        # Some classes are missing from ``labels``; encoding against only the
        # present values would give targets narrower than ``probs``. Treat the
        # labels as column indices instead.
        indices = labels.astype(int)
    else:
        raise ValueError(
            "labels have {} distinct values but probs has {} columns, and the "
            "labels are not integer class indices in [0, {})".format(
                classes.size, n_classes, n_classes))

    targets = np.eye(n_classes)[indices]
    return np.mean(np.sum((probs - targets) ** 2, axis=1))


def train_model_cross_entropy(features_train, labels_train, features_test, labels_test, classes, dimension, l2_reg, sgd_lr):
    """Train a single softmax (linear) layer with SGD and evaluate it on the test split.

    Args:
        features_train: training features, one row of ``dimension`` values per sample.
        labels_train: training labels as integer class indices in [0, classes),
            the encoding expected by the sparse categorical cross-entropy loss.
        features_test: test features with the same layout as ``features_train``.
        labels_test: test labels, encoded like ``labels_train``.
        classes: number of output classes (width of the softmax layer).
        dimension: number of input features.
        l2_reg: L2 regularization factor applied to the layer's kernel.
        sgd_lr: SGD learning rate (momentum and epochs come from flags).

    Returns:
        The CSV row dict built by ``_get_csv_row``, holding the test
        cross-entropy loss, Brier score, Frobenius norm of the kernel and test
        accuracy.
    """
    # Every call builds a fresh model. Reset Keras' global state first so memory
    # does not keep growing over the hyper-parameter grid.
    K.clear_session()

    with TimedLogger("Training the linear layer with SGD(lr={}, momentum={}), batch_size={}, L2_Reg={}, epochs={}".format(sgd_lr, FLAGS.sgd_momentum, FLAGS.batch_size, l2_reg, FLAGS.epochs)):

        model = keras.models.Sequential([
            keras.layers.Dense(classes, input_shape=(dimension,), activation='softmax', kernel_regularizer=tf.keras.regularizers.l2(l2_reg))
        ])
        model.compile(optimizer=keras.optimizers.SGD(learning_rate=sgd_lr, momentum=FLAGS.sgd_momentum), loss='sparse_categorical_crossentropy', metrics=['accuracy'])
        # The test split is passed as validation data only for per-epoch
        # reporting; no callbacks act on it.
        model.fit(features_train, labels_train, epochs=FLAGS.epochs, batch_size=FLAGS.batch_size, validation_data=(features_test, labels_test))

    # get_weights()[0] is the Dense kernel (bias excluded); the default matrix
    # norm in numpy is the Frobenius norm.
    norm = np.linalg.norm(model.layers[-1].get_weights()[0])
    # [loss, accuracy], following the metrics passed to compile().
    loss_accuracy = model.evaluate(features_test, labels_test, verbose=0)
    # The output layer is a softmax, so predict() already returns class
    # probabilities. Sequential.predict_proba was removed in TF 2.6.
    predictions = model.predict(features_test, verbose=0)
    brier_score = brier_multi(predictions, labels_test)

    logging.log(logging.INFO, "Loss and accuracy cross_entropy: {0}".format(loss_accuracy))
    logging.log(logging.INFO, "Brier score cross_entropy: {0}".format(brier_score))
    logging.log(logging.INFO, "Frobenius norm of weights cross_entropy: {}.".format(round(norm,3)))

    return _get_csv_row(l2_reg, sgd_lr, loss_accuracy[0], brier_score, norm, loss_accuracy[1])


def main(argv):
    """Load the matrices, normalize the features and grid-search the linear classifier.

    A fresh model is trained for every (l2_reg, sgd_lr) pair from --l2_regs x
    --sgd_lrs (each sorted ascending). When --output_file is set, one CSV row
    is written per model. The first write honours --output_overwrite; all
    later writes append.

    Args:
        argv: leftover positional command-line arguments from absl (unused).

    Raises:
        AttributeError: if the loaded feature/label matrices have inconsistent
            dimensions or sample counts.
        ValueError: if a value in --l2_regs or --sgd_lrs is not a float.
    """
    del argv  # Unused; all configuration comes from flags.

    # Parse the grids up front so a malformed value fails before the
    # (potentially large) matrices are loaded.
    l2_regs = sorted(float(x) for x in FLAGS.l2_regs)
    sgd_lrs = sorted(float(x) for x in FLAGS.sgd_lrs)

    for flag_name in ("features_train", "features_test", "labels_train", "labels_test"):
        test_argument_and_file(FLAGS.path, flag_name)

    train_features, dim_train, samples_train = load_and_log(FLAGS.path, "features_train")
    test_features, dim_test, samples_test = load_and_log(FLAGS.path, "features_test")

    if dim_test != dim_train:
        raise AttributeError("Train and test features do not have the same dimension!")

    train_labels, dim_train_labels, samples_train_labels = load_and_log(FLAGS.path, "labels_train")
    if dim_train_labels != 1:
        raise AttributeError("Train labels file does not point to a vector!")
    if samples_train_labels != samples_train:
        raise AttributeError("Train features and labels files does not have the same amount of samples!")
    test_labels, dim_test_labels, samples_test_labels = load_and_log(FLAGS.path, "labels_test")
    # Validate the test labels' own dimension, not the train labels' again.
    if dim_test_labels != 1:
        raise AttributeError("Test labels file does not point to a vector!")
    if samples_test_labels != samples_test:
        raise AttributeError("Test features and labels files does not have the same amount of samples!")

    # Map raw label values onto contiguous integer class indices shared by both
    # splits, as needed by the sparse categorical cross-entropy loss used in
    # training. searchsorted on the sorted unique values is the vectorized
    # equivalent of looking each label up with list.index.
    unique_classes = np.unique(np.concatenate((train_labels, test_labels)))
    train_labels = np.searchsorted(unique_classes, train_labels)
    test_labels = np.searchsorted(unique_classes, test_labels)

    with TimedLogger("Normalizing features using MinMaxScaler"):
        # Fit the scaler on the training split only and reuse it for the test split.
        minmax = MinMaxScaler(feature_range=(-1, 1), copy=True)
        train_features = minmax.fit_transform(train_features)
        test_features = minmax.transform(test_features)

    classes = len(unique_classes)

    overwrite = FLAGS.output_overwrite
    for l2_reg in l2_regs:
        for sgd_lr in sgd_lrs:
            row = train_model_cross_entropy(train_features, train_labels, test_features, test_labels, classes, dim_train, l2_reg, sgd_lr)

            if FLAGS.output_file:
                _write_result([row], overwrite)
                # Only the first write may truncate; later rows are appended.
                overwrite = False

if __name__ == "__main__":
    # The hyper-parameter grids default to None and main() iterates over them,
    # so have absl reject a missing value with a usage message up front.
    flags.mark_flags_as_required(["l2_regs", "sgd_lrs"])
    app.run(main)
