# Usage: python X.py -d LABELS_DIR -o OUT_DIR -i MIL_IN_PREFIX [options]
#
# Because the MIL code is legacy, getting this to work takes a two-step
# workaround:
# 1. Use the legacy MIL code to produce the labels. The label-extraction
#    notebook I created is in mil/Google.
# 2. Run this script on those labels.

import numpy as np
import sys
import argparse
import time

import divideandconquer.common.utils as utils
from divideandconquer.common.dataprocessor_mil import GoogleDataMIL

class ConfigOptions:
    '''
    Command-line configuration for this script.

    Constructing the object parses the command line; configure() then copies
    the parsed values onto the upper-case attributes read by the rest of the
    file. NUM_SUBINSTANCE and NUM_OUTPUT are not command-line options and are
    filled in by the caller (main() sets them from the loaded label arrays).
    '''
    def __init__(self, argv=None):
        '''
        argv: list of argument strings to parse. None (the default) makes
            argparse parse sys.argv[1:], as before.
        '''
        self.LABELS_DIR = None
        self.TOP_K = None
        self.TOP_K_SHRINK = None
        self.BOTTOM_K_SHRINK = None
        self.NUM_SUBINSTANCE = None
        self.NUM_OUTPUT = None
        self.OUT_DIR = None
        self.IN_PREFIX = None
        self.NEG_FRACTION = None
        # Initialised here so echo() does not fail if called before configure().
        self.NO_NEG = None

        self.__parsed = None
        parser = argparse.ArgumentParser(description='Configuration Options')
        parser.add_argument('-d', '--label-dir', required=True,
                            help='Directory containing predicted labels and '
                                 'bag labels')
        parser.add_argument('-o', '--out-dir', required=True,
                            help='Directory to dump data to')
        parser.add_argument('-i', '--in-dir', required=True,
                            help='Directory/prefix to read MIL data from')
        parser.add_argument('-t', '--top-k', required=False, default=None,
                            type=int,
                            help='The value of k for top-k for each '
                                 'sub-instance')
        parser.add_argument('-l', '--shrink-bottom-k', required=False,
                            default=2, type=int,
                            help='Used when extracting negative sub-instances '
                                 'from positive bags. Sub-instances within '
                                 'this margin of the top-k run are skipped; '
                                 'everything further out is considered '
                                 'negative.')
        parser.add_argument('-s', '--shrink-top-k', required=False, default=2,
                            type=int,
                            help='Used when extracting positive sub-instances. '
                                 'If the top-k run is still at least top-k '
                                 'long after removing this many sub-instances '
                                 'from each end, only the shrunk run is taken.')
        parser.add_argument('-nn', '--no-negatives-from-positive',
                            dest='noneg', action='store_true', default=False,
                            help='Exclude negative instances drawn from '
                                 'positive bags')
        parser.add_argument('-nf', '--neg-fraction', default=1.0, type=float,
                            help='Fraction of negatives to sample from each '
                                 'bag')
        self.__parsed = parser.parse_args(argv)

    def configure(self):
        '''Copy the parsed command-line values onto the config attributes.'''
        psd = self.__parsed
        self.LABELS_DIR = psd.label_dir
        self.OUT_DIR = psd.out_dir
        self.IN_PREFIX = psd.in_dir
        self.TOP_K = psd.top_k
        self.BOTTOM_K_SHRINK = psd.shrink_bottom_k
        self.TOP_K_SHRINK = psd.shrink_top_k
        self.NO_NEG = psd.noneg
        self.NEG_FRACTION = psd.neg_fraction

    def echo(self):
        '''Print every configuration value (useful when reading logs).'''
        print('Labels Dir: ', self.LABELS_DIR)
        print('Output Dir: ', self.OUT_DIR)
        print("MIL in Dir: ", self.IN_PREFIX)
        print('Top-k: ', self.TOP_K)
        print('Bottom shrink amount : ', self.BOTTOM_K_SHRINK)
        print('Top-k shrink amount : ', self.TOP_K_SHRINK)
        print("Num subinstance: ", self.NUM_SUBINSTANCE)
        print("Num output: ", self.NUM_OUTPUT)
        print("Exclude negative from positives: ", self.NO_NEG)
        print("Negative sampling fraction: ", self.NEG_FRACTION)

def load_labels_files_bag(labelsDir, name):
    '''
    Load the three label arrays saved for one data split.

    labelsDir is prepended verbatim to each file name, so a directory path
    must end with a separator. name selects the split ('train', 'test' or
    'val'). Returns the arrays stored in lab_<name>.npy, fil_<name>.npy and
    bag_<name>.npy, in that order.
    '''
    assert name in ('train', 'test', 'val')
    loaded = [np.load(labelsDir + kind + name + '.npy')
              for kind in ('lab_', 'fil_', 'bag_')]
    return tuple(loaded)

def extractNonNegativeIndices(predictedLabels, bag, minSubsequenceLen,
                                   numClass):
    '''
    For every bag, locate the best-scoring run of a single positive class in
    the predicted sub-instance labels.

    predictedLabels: integer array [numBags, numSubinstance] with the
        predicted class of every sub-instance.
    bag: integer array [numBags] of true bag labels. It is only validated
        here (rank and dtype); it does not affect the result. Callers compare
        the returned prediction with bag themselves to drop misclassified
        bags.
    minSubsequenceLen: minimum run length (the top-k value) needed for a bag
        to be assigned the class of its best run.
    numClass: number of classes including the negative class 0; classes
        1 .. numClass - 1 are scanned.

    NOTE(review): assumes utils.getLengthScores(labels, val=x) returns, per
    position, the length of the run of class x ending at that position. That
    is how start/end are derived below; confirm against utils.

    Returns a list with one (prediction, length, start, end) tuple per bag,
    where end is exclusive. prediction is the class of the longest run if
    that run is at least minSubsequenceLen long, otherwise 0. length, start
    and end always describe that longest run, even when prediction is 0.
    Ties between classes go to the lowest class id (np.argmax).

    Raises ValueError if numClass < 2.
    '''
    assert(predictedLabels.ndim == 2)
    assert(bag.ndim == 1)
    assert(np.issubdtype(predictedLabels.dtype.type, np.integer))
    assert(np.issubdtype(bag.dtype.type, np.integer))
    if numClass < 2:
        raise ValueError('numClass must be at least 2 (class 0 plus at least '
                         'one positive class), got %r' % (numClass,))
    runLengths = []
    runEnds = []
    for x in range(1, numClass):
        scores = utils.getLengthScores(predictedLabels, val=x)
        # Longest score for class x in each bag and where it occurs.
        runLengths.append(np.max(scores, axis=1))
        runEnds.append(np.argmax(scores, axis=1))
    # Shape [numBags, numClass - 1]; column c holds class c + 1.
    scoreList = np.stack(runLengths, axis=1)
    indexList = np.stack(runEnds, axis=1)
    assert(scoreList.shape == (predictedLabels.shape[0], numClass - 1))
    assert(indexList.shape == scoreList.shape)
    rows = np.arange(scoreList.shape[0])
    best = np.argmax(scoreList, axis=1)
    length = scoreList[rows, best]
    labels = best + 1
    indices = indexList[rows, best]
    # Only runs that reach the top-k threshold count as a positive prediction.
    prediction = np.where(length >= minSubsequenceLen, labels, 0)
    start = indices.astype(int) - length.astype(int) + 1
    end = indices.astype(int) + 1
    return list(zip(prediction.astype(int), length.astype(int),
                    start.astype(int), end.astype(int)))

def createNonNegativeData(nonNegIndices, X, bag, files, name, config):
    '''
    Extract the sub-instances of correctly classified positive bags and save
    them to disk.

    nonNegIndices: output of extractNonNegativeIndices, one
        (prediction, length, start, end) tuple per bag; end is exclusive.
    X: 4-D array indexed as X[bag, sub-instance, :, :].
    bag: true label of each bag.
    files: file name of each bag; same length as X.
    name: 'train', 'test' or 'val'; used as the output-file suffix.
    config: reads OUT_DIR (prepended verbatim to the output names), TOP_K
        and TOP_K_SHRINK.

    A bag contributes only if its true label is non-zero and equals its
    prediction. It contributes the sub-instances in [start, end). If the run
    is still at least TOP_K long after trimming TOP_K_SHRINK from each end,
    only the trimmed run is used.

    Writes pos_x_<name>.npy, pos_y_<name>.npy and pos_files_<name>.npy.
    When no bag qualifies, the arrays are empty but pos_x keeps X's trailing
    shape and pos_y has an integer dtype, so they can still be concatenated
    and bincounted downstream.
    '''
    outDir = config.OUT_DIR
    assert name in ['train', 'test', 'val']
    assert len(X) == len(files)
    assert X.ndim == 4
    rows, cols, ys = [], [], []
    for i, (label, length, start, end) in enumerate(nonNegIndices):
        bag_label = bag[i]
        # ignore negative bags
        if bag_label == 0:
            continue
        # ignore misclassified bags
        if label != bag_label:
            continue
        # Trim the edges of the run only if it stays at least TOP_K long.
        if length - 2 * config.TOP_K_SHRINK >= config.TOP_K:
            start += config.TOP_K_SHRINK
            end -= config.TOP_K_SHRINK
        span = range(start, end)
        rows.extend([i] * len(span))
        cols.extend(span)
        ys.extend([label] * len(span))
    rows = np.asarray(rows, dtype=int)
    cols = np.asarray(cols, dtype=int)
    # Integer-array indexing yields shape (n, *X.shape[2:]) even when n == 0,
    # unlike np.array([]) which would collapse to a float array of shape (0,).
    new_x = X[rows, cols]
    new_y = np.asarray(ys, dtype=int)
    new_files = np.asarray(files)[rows]
    np.save(outDir + 'pos_x_' + name, new_x)
    np.save(outDir + 'pos_y_' + name, new_y)
    print('Pos y bincount: ', np.bincount(new_y))
    np.save(outDir + 'pos_files_' + name, new_files)
    print('x ', new_x.shape)
    print('y ', new_y.shape)
    print('f ', new_files.shape)

def createNegativeData(nonNegIndices, X, bag, files, name, config):
    '''
    Extract negative (label 0) sub-instances and save them to disk.

    nonNegIndices: output of extractNonNegativeIndices, one
        (prediction, length, start, end) tuple per bag; end is exclusive.
    X: 4-D array indexed as X[bag, sub-instance, :, :].
    bag: true label of each bag.
    files: file name of each bag; same length as X.
    name: 'train', 'test' or 'val'; used as the output-file suffix.
    config: reads OUT_DIR, NEG_FRACTION, NUM_SUBINSTANCE, NO_NEG and
        BOTTOM_K_SHRINK.

    Misclassified bags (prediction != bag label) are skipped. The remaining
    bags contribute as follows:
      * Negative bags: int(NEG_FRACTION * NUM_SUBINSTANCE) sub-instances
        (at least 1) drawn at random, without replacement unless more are
        requested than the bag holds.
      * Positive bags (unless NO_NEG): every sub-instance outside
        [start - BOTTOM_K_SHRINK, end + BOTTOM_K_SHRINK).

    Writes neg_x_<name>.npy, neg_y_<name>.npy and neg_files_<name>.npy. When
    nothing is selected, neg_x keeps X's trailing shape so it can still be
    concatenated with the positive set.
    '''
    outDir = config.OUT_DIR
    assert name in ['train', 'test', 'val']
    assert len(X) == len(files)
    assert X.ndim == 4
    numSubinstance = X.shape[1]
    numNegSamples = int(config.NEG_FRACTION * config.NUM_SUBINSTANCE)
    if numNegSamples <= 0:
        print("WARNING: Sample fraction provided is too low. Will sample one "
              "instance per bag")
        numNegSamples = 1
    # Sampling with replacement would duplicate some sub-instances and drop
    # others; only allow it when more samples are requested than exist.
    replace = numNegSamples > numSubinstance
    rows, cols = [], []
    count = 0   # negatives drawn from negative bags
    count2 = 0  # negatives drawn from correctly classified positive bags
    for i, (label, length, start, end) in enumerate(nonNegIndices):
        bag_label = bag[i]
        # misclassified
        if label != bag_label:
            continue
        # Correctly classified negative bag: random sub-instances.
        if label == 0:
            picked = np.random.choice(numSubinstance, size=numNegSamples,
                                      replace=replace)
            rows.extend([i] * len(picked))
            cols.extend(picked)
            count += len(picked)
            continue
        # Correctly classified positive bag
        if config.NO_NEG:
            continue
        # Everything outside the detected run plus a BOTTOM_K_SHRINK margin
        # on either side is treated as negative.
        outside = (list(range(0, start - config.BOTTOM_K_SHRINK)) +
                   list(range(end + config.BOTTOM_K_SHRINK, numSubinstance)))
        rows.extend([i] * len(outside))
        cols.extend(outside)
        count2 += len(outside)
    rows = np.asarray(rows, dtype=int)
    cols = np.asarray(cols, dtype=int)
    # Integer-array indexing keeps shape (n, *X.shape[2:]) even when n == 0.
    new_x = X[rows, cols]
    new_y = np.zeros(len(rows), dtype=int)
    new_files = np.asarray(files)[rows]
    np.save(outDir + 'neg_x_' + name, new_x)
    np.save(outDir + 'neg_y_' + name, new_y)
    np.save(outDir + 'neg_files_' + name, new_files)
    print('x ', new_x.shape)
    print('y ', new_y.shape)
    print('f ', new_files.shape)
    print("Negatives from negatives", count)
    print("Negatives from positives", count2)

def processData(X, Y, labels, files, bag, config, name):
    '''
    Build and save the final x/y/files arrays for one data split.

    X: 4-D sub-instance data. Y: 3-D MIL labels (only its rank is checked).
    labels: 3-D predicted class scores, reduced with argmax over the last
        axis to per-sub-instance predictions.
    files, bag: 1-D per-bag file names and true bag labels.
    config: ConfigOptions with OUT_DIR, TOP_K, NUM_OUTPUT and the shrink /
        sampling options filled in.
    name: 'train', 'test' or 'val'.

    Steps: (1) save the positive sub-instances, (2) save the negative
    sub-instances, (3) reload both, concatenate them, shuffle them with
    utils.unison_shuffled_copies, convert y with utils.to_onehot and save
    x_<name>.npy, y_<name>.npy and files_<name>.npy under OUT_DIR.
    '''
    assert name in ['train', 'test', 'val']
    assert X.ndim == 4
    assert Y.ndim == 3
    assert labels.ndim == 3
    assert files.ndim == 1
    assert bag.ndim == 1

    # STEP 1: Extract non-negative subinstances and dump them
    print("Positive")
    class_ = np.argmax(labels, axis=-1)
    nonneg = extractNonNegativeIndices(class_, bag, config.TOP_K,
                                       config.NUM_OUTPUT)
    createNonNegativeData(nonneg, X, bag, files, name=name, config=config)
    # STEP 2: Extract negative subinstances and dump them
    print("Negative")
    createNegativeData(nonneg, X, bag, files, name=name, config=config)

    # STEP 3: Reload both halves, join and shuffle them, and save the final
    # dataset for this split.
    def load_concat(fileNames):
        # Load each .npy file from OUT_DIR and concatenate along axis 0.
        # Named fileNames so it does not shadow the outer `files` argument.
        return np.concatenate(
            [np.load(config.OUT_DIR + f) for f in fileNames], axis=0)

    x_ = load_concat(['pos_x_%s.npy' % name, 'neg_x_%s.npy' % name])
    y_ = load_concat(['pos_y_%s.npy' % name, 'neg_y_%s.npy' % name])
    files_ = load_concat(['pos_files_%s.npy' % name,
                          'neg_files_%s.npy' % name])
    assert len(x_) == len(y_)
    assert len(files_) == len(y_)
    x_, y_, files_ = utils.unison_shuffled_copies((x_, y_, files_))
    # Label the bincount with the split actually being processed
    # ("Train", "Val" or "Test").
    print("%s bincount " % name.capitalize(), end='')
    print(np.bincount(y_))
    y_ = utils.to_onehot(y_, config.NUM_OUTPUT)
    np.save(config.OUT_DIR + 'x_%s' % name, x_)
    np.save(config.OUT_DIR + 'y_%s' % name, y_)
    np.save(config.OUT_DIR + 'files_%s' % name, files_)
    print("Final stats")
    print("x :", x_.shape)
    print("y :", y_.shape)
    print("files: ", files_.shape)

def main():
    '''
    Script entry point.

    Parses the command line, loads the predicted labels for every split,
    chooses TOP_K, loads the MIL data, checks that both sources agree on the
    bag labels, then builds the train, val and test datasets.

    Raises ValueError if the bag labels derived from the MIL data differ
    from the saved bag labels for any split.
    '''
    config = ConfigOptions()
    config.configure()

    # Load data and echo some stats. This will be useful for debugging
    lab_train, fil_train, bag_train = load_labels_files_bag(config.LABELS_DIR, 'train')
    lab_test, fil_test, bag_test = load_labels_files_bag(config.LABELS_DIR, 'test')
    lab_val, fil_val, bag_val = load_labels_files_bag(config.LABELS_DIR, 'val')
    print('Train: label - ', lab_train.shape, ' file - ', fil_train.shape,
          ' bag - ', bag_train.shape)
    print('Test: label - ', lab_test.shape, ' file - ', fil_test.shape,
          ' bag - ', bag_test.shape)
    print('Val: label - ', lab_val.shape, ' file - ', fil_val.shape,
          ' bag - ', bag_val.shape)
    config.NUM_SUBINSTANCE = lab_train.shape[1]
    config.NUM_OUTPUT = lab_train.shape[2]
    # NOTE(review): the ssl is chosen on the *test* split; consider using
    # the val split to avoid selecting a hyper-parameter on test data.
    predictions = np.argmax(lab_test, axis=-1)
    df = utils.analysisModelMultiClass(predictions, bag_test,
                                       config.NUM_SUBINSTANCE,
                                       config.NUM_OUTPUT, silent=True)
    # NOTE(review): assumes row r of df corresponds to ssl r + 1; confirm
    # against utils.analysisModelMultiClass.
    acc, ssl = np.max(df.acc.values), np.argmax(df.acc.values) + 1
    print("Max accuracy is %f @ssl %d" % (acc, ssl))
    print()
    # Honour an explicit --top-k; otherwise use the best ssl found above.
    if config.TOP_K is None:
        config.TOP_K = ssl
    config.echo()

    # Load the data
    start = time.time()
    dataP = GoogleDataMIL()
    ret = dataP.getData(config.IN_PREFIX, seed=42)
    x_train, y_train, x_val, y_val, files_train, files_val = ret
    x_test, y_test, files_test = dataP.getTest(config.IN_PREFIX, seed=42)
    end = time.time()
    print("Data loaded in %ds" % (end - start))
    # The bag labels implied by the MIL data must match the saved bag
    # labels. Raise rather than assert so the check survives `python -O`.
    for split, y_split, savedBag in (('test', y_test, bag_test),
                                     ('train', y_train, bag_train),
                                     ('val', y_val, bag_val)):
        if not np.array_equal(np.argmax(y_split[:, 0, :], axis=1), savedBag):
            raise ValueError('Bag labels of the %s MIL data do not match '
                             'bag_%s.npy' % (split, split))

    print()
    print("Processing train data")
    print(bag_train.shape)
    processData(x_train, y_train, lab_train, fil_train, bag_train, config,
                name='train')
    print()
    print("Processing val data")
    processData(x_val, y_val, lab_val, fil_val, bag_val, config,
                name='val')
    print()
    print("Processing test data")
    processData(x_test, y_test, lab_test, fil_test, bag_test, config,
                name='test')

# Run only when executed as a script, so importing this module does not parse
# sys.argv or start the whole pipeline.
if __name__ == '__main__':
    main()

