# Experiment
# 
# Divide the data into disjoint bricks in the time-step domain. That is, divide
# the 49 time steps into disjoint groups and see if a classifier can learn
# them.
# 
# Classifier here is a stacked LSTM. That is, instead of predicting at the
# brick level, I'm passing it into a tree-based classifier. This should be
# easier and probably much more accurate than vanilla brick level classifier.
# Further, the thought here is that transfer learning can be used to train the
# brick level classifier.
#


import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns
import sys
import os
import argparse
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from divideandconquer.common.graphs import DataIngest
from divideandconquer.common.graphs import StackedLSTM_Network
from divideandconquer.common.trainers import SimpleTrainer
from divideandconquer.common.utils import unison_shuffled_copies
from divideandconquer.common.utils import GraphManager, runOpList
from divideandconquer.common.utils import flatten_features
from divideandconquer.common.utils import convertToDetectionLabels

class ConfigOptions:
    """Command-line configuration for this experiment.

    Constructing an instance builds the argument parser and parses the
    arguments immediately. The upper-case attributes stay ``None`` until
    ``configure()`` copies the parsed values onto them and validates them;
    ``echoConfig()`` prints the configured values.
    """

    def __init__(self, argv=None):
        """Build the parser and parse the command line.

        Args:
            argv: Optional list of argument strings to parse. When ``None``
                (the default) argparse falls back to ``sys.argv[1:]``.
        """
        # Declare every option on the instance so that the attributes exist
        # (as None) even before configure() has been called.
        self.HIDDEN_DIM0 = None
        self.HIDDEN_DIM1 = None
        self.KEEP_PROB0 = None
        self.KEEP_PROB1 = None
        self.LEARNING_RATE = None
        self.MOMENTUM = None
        self.NUM_EPOCHS = None
        self.BATCH_SIZE = None
        self.DATA_DIR = None
        self.GROUP_SIZE = None
        self.CHECKPOINT_DIR = None
        self.USE_DROPOUT = None
        self.__parsed = None

        self.parser = argparse.ArgumentParser()
        self.parser.add_argument('-d', '--data-dir', required=True,
                                 help="Input data directory")
        self.parser.add_argument('-hd0', '--hidden-dim0', required=True,
                                 help='Hidden dimension of LSTM layer0',
                                 type=int)
        self.parser.add_argument('-hd1', '--hidden-dim1', required=True,
                                 help='Hidden dimension of LSTM layer1',
                                 type=int)
        self.parser.add_argument('-kp1', '--keep-prob1', required=False,
                                 help='Dropout keep-prob for layer 1',
                                 type=float, default=1.0)
        self.parser.add_argument('-kp0', '--keep-prob0', required=False,
                                 help='Dropout keep-prob for layer 0',
                                 type=float, default=1.0)
        self.parser.add_argument('-lr', '--learning-rate', required=False,
                                 default=0.001, help="Learning rate",
                                 type=float)
        self.parser.add_argument('-m', '--momentum', required=False,
                                 default=0.9, help="Momentum for Nesterov",
                                 type=float)
        self.parser.add_argument('-bs', '--batch-size', required=False,
                                 default=128, type=int, help='Batch size')
        self.parser.add_argument('-ne', '--epochs', required=False,
                                 default=100, type=int,
                                 help='Number of epochs')
        self.parser.add_argument('-gs', '--group-size', required=False,
                                 default=8, type=int, help='group size')
        self.parser.add_argument('-c', '--checkpoint-dir', required=True,
                                 help="Directory to dump checkpoints")
        self.__parsed = self.parser.parse_args(argv)

    def configure(self):
        """Copy the parsed arguments onto the attributes and validate them.

        After validation ``CHECKPOINT_DIR`` has ``'/model'`` appended to the
        directory given on the command line.

        Raises:
            AssertionError: if a keep-probability lies outside [0, 1] or the
                checkpoint directory is not writable.
        """
        psd = self.__parsed
        self.DATA_DIR = psd.data_dir
        self.HIDDEN_DIM0 = psd.hidden_dim0
        self.HIDDEN_DIM1 = psd.hidden_dim1
        self.LEARNING_RATE = psd.learning_rate
        self.MOMENTUM = psd.momentum
        self.BATCH_SIZE = psd.batch_size
        self.NUM_EPOCHS = psd.epochs
        self.GROUP_SIZE = psd.group_size
        self.CHECKPOINT_DIR = psd.checkpoint_dir
        self.KEEP_PROB0 = psd.keep_prob0
        self.KEEP_PROB1 = psd.keep_prob1
        # We always use dropout (with default 1.0 if need be)
        self.USE_DROPOUT = True

        # Explicit raises instead of `assert` so the checks survive
        # `python -O`; the exception type is kept as AssertionError so that
        # existing behaviour on invalid input is unchanged.
        for name, value in (('KEEP_PROB0', self.KEEP_PROB0),
                            ('KEEP_PROB1', self.KEEP_PROB1)):
            if not 0.0 <= value <= 1.0:
                raise AssertionError(
                    "%s must be in [0.0, 1.0], got %r" % (name, value))
        if not os.access(self.CHECKPOINT_DIR, os.W_OK):
            raise AssertionError(
                "Directory is not writable: %s" % self.CHECKPOINT_DIR)
        self.CHECKPOINT_DIR += '/model'

    def echoConfig(self):
        """Print the configured values to stdout, one per line."""
        print("DATA_DIR:", self.DATA_DIR)
        print("HIDDEN_DIM0:", self.HIDDEN_DIM0)
        print("HIDDEN_DIM1:", self.HIDDEN_DIM1)
        print("KEEP_PROB0: ", self.KEEP_PROB0)
        print("KEEP_PROB1: ", self.KEEP_PROB1)
        print("LEARNING_RATE:", self.LEARNING_RATE)
        print("MOMENTUM:", self.MOMENTUM)
        print("BATCH_SIZE:", self.BATCH_SIZE)
        print("NUM_EPOCHS:", self.NUM_EPOCHS)
        print("GROUP_SIZE:", self.GROUP_SIZE)
        print("CHECKPOINT_DIR:", self.CHECKPOINT_DIR)


def _load_split(data_dir, name):
    """Load the array stored as ``<name>.npy`` inside ``data_dir``."""
    # os.path.join works whether or not data_dir ends with a separator.
    return np.load(os.path.join(data_dir, name + '.npy'))


def _report_accuracy(sess, dataIngest, logits, x, y, feed_dict, title):
    """Run ``logits`` over (x, y) in batches of 1000 and print the accuracy.

    Predictions and targets are both taken as the argmax over axis 1.
    """
    result = runOpList(sess, dataIngest, logits, x, y, batchSize=1000,
                       feed_dict=feed_dict)
    result = np.concatenate(result)
    predictions = np.argmax(result, axis=1)
    target = np.argmax(y, axis=1)
    print(predictions.shape, target.shape)
    acc = np.mean((predictions == target).astype(int))
    print(title, acc)


def main():
    """Train the stacked LSTM on brick-grouped data, evaluate, and checkpoint.

    Steps:
      1. Parse and echo the command-line configuration.
      2. Load the train/validation arrays from DATA_DIR and standardise them
         with the per-feature mean/std of the training set.
      3. Regroup the data with ``flatten_features(..., mode='Stacked')``.
      4. Build the graph, train until the input pipeline raises
         ``OutOfRangeError``, printing loss/accuracy every 1000 batches.
      5. Print train and validation accuracy and checkpoint the model.
    """
    config = ConfigOptions()
    config.configure()
    config.echoConfig()
    dataDir = config.DATA_DIR
    # This is the MIL data. Hence
    #   - x is of dimension [-1, NUM_SUBINSTANCE, NUM_OUTPUT]
    #   - y is of dimension [-1, NUM_SUBINSTANCE, NUM_OUTPUT]
    # NOTE that y here is the predicted labels from the dataset and not the
    # initialization label. Hence, be careful about how bag label is
    # retrieved.
    x_train = _load_split(dataDir, 'x_train')
    y_train = _load_split(dataDir, 'y_train')
    x_val = _load_split(dataDir, 'x_val')
    y_val = _load_split(dataDir, 'y_val')

    # Per-feature statistics over the last axis, computed on the training
    # set only and applied to both splits.
    flat_train = np.reshape(x_train, [-1, x_train.shape[-1]])
    mean = np.mean(flat_train, axis=0)
    std = np.std(flat_train, axis=0)
    # Avoid dividing by (near-)zero for constant features.
    std[std < 0.000001] = 1
    x_train = (x_train - mean) / std
    x_val = (x_val - mean) / std
    print('x_train ', x_train.shape)
    print('y_train ', y_train.shape)

    x_new, y_new = flatten_features(x_train, y_train,
                                    group=config.GROUP_SIZE, mode='Stacked')
    x_val_new, y_val_new = flatten_features(x_val, y_val,
                                            group=config.GROUP_SIZE,
                                            mode='Stacked')
    # Keep only the label at index 0 along axis 1 for every example.
    y_new = y_new[:, 0, :]
    y_val_new = y_val_new[:, 0, :]
    print('x_new', x_new.shape)
    print('y_new', y_new.shape)
    NUM_BRICKS = x_new.shape[1]
    NUM_BRICK_TIMESTEPS = x_new.shape[2]
    NUM_INPUT = x_new.shape[3]
    NUM_OUTPUT = y_new.shape[1]

    # Note that these are only approximate values
    # and not true bag.
    bag = np.argmax(y_new, axis=1)
    bins = np.bincount(bag)
    print(bins, '(not computed from true bags)')
    print("Trivial acc: ", np.max(bins)/sum(bins), " for class: ",
          np.argmax(bins))

    # Create the computation graph
    # NOTE(review): the network/trainer objects are constructed before
    # tf.reset_default_graph(); this order is kept from the original and
    # presumably relies on the constructors not creating any ops - confirm.
    x_dim = [None, NUM_BRICKS, NUM_BRICK_TIMESTEPS, NUM_INPUT]
    y_dim = [None, NUM_OUTPUT]
    dataIngest = DataIngest(x_dim, y_dim)
    stackedLstm_network = StackedLSTM_Network(NUM_BRICKS, NUM_BRICK_TIMESTEPS,
                                              NUM_INPUT, config.HIDDEN_DIM0,
                                              config.HIDDEN_DIM1, NUM_OUTPUT,
                                              useDropout=config.USE_DROPOUT)
    trainer = SimpleTrainer()
    tf.reset_default_graph()
    x_batch, y_batch = dataIngest()
    logits = stackedLstm_network(x_batch)
    trainer(logits, y_batch, optimizer='Nesterov',
            learningRate=config.LEARNING_RATE, momentum=config.MOMENTUM)
    # Initialize the variables (i.e. assign their default value)
    init = tf.global_variables_initializer()
    lossOp = trainer.getLossOp()
    trainOp = trainer.getTrainOp()
    accOp = trainer.getAccuracyOp()

    kp0 = stackedLstm_network.keep_prob0
    kp1 = stackedLstm_network.keep_prob1
    feed_dict_train = {kp0: config.KEEP_PROB0, kp1: config.KEEP_PROB1}
    # Dropout is disabled (keep-prob 1.0) for evaluation.
    feed_dict_test = {kp0: 1.00, kp1: 1.0}

    # The context manager guarantees the session is closed, even on error.
    with tf.Session() as sess:
        # Train
        sess.run(init)
        dataIngest.runInitializer(sess, x_new, y_new, config.BATCH_SIZE,
                                  config.NUM_EPOCHS)
        currBatch = 0
        # Clamp to 1 so that a dataset smaller than one batch does not cause
        # a division by zero when computing the epoch/batch for logging.
        batchPerEpoch = max(1, len(x_new) // config.BATCH_SIZE)
        while True:
            logThisStep = currBatch % 1000 == 0
            try:
                if logThisStep:
                    _, loss, acc = sess.run([trainOp, lossOp, accOp],
                                            feed_dict=feed_dict_train)
                else:
                    sess.run([trainOp], feed_dict=feed_dict_train)
            except tf.errors.OutOfRangeError:
                # The input pipeline is exhausted: training is finished.
                break
            if logThisStep:
                epoch, batch = divmod(currBatch, batchPerEpoch)
                print('Epoch: %5d batch: %5d (%7d) loss: %2.5f acc: %2.5f' %
                      (epoch, batch, currBatch, loss, acc))
            currBatch += 1

        # Run inference on the entire train set and validation set
        _report_accuracy(sess, dataIngest, logits, x_new, y_new,
                         feed_dict_test, "Train set accuracy: ")
        _report_accuracy(sess, dataIngest, logits, x_val_new, y_val_new,
                         feed_dict_test, "Validation set accuracy: ")

        # Dump graph
        saver = tf.train.Saver(save_relative_paths=True)
        graphManager = GraphManager()
        graphManager.checkpointModel(saver, sess, config.CHECKPOINT_DIR)


# Only launch the experiment when this file is executed as a script, so that
# importing the module does not parse sys.argv or start a training run.
if __name__ == '__main__':
    main()
