import numpy as np
import pandas as pd
import os
import tensorflow as tf

def runOpList(sess, dataIngest, opList, X, Y, batchSize, epochs=1,
              feed_dict=None):
    '''
    Evaluate `opList` in `sess` repeatedly until the input pipeline is
    exhausted. The session is assumed to already contain the graph.

    dataIngest.runInitializer(sess, X, Y, batchSize, epochs) is called
    first (presumably to (re)initialise the input pipeline on X, Y).
    Each sess.run(opList, feed_dict=feed_dict) result is collected; the
    loop stops when tf.errors.OutOfRangeError is raised.

    Returns a list with one sess.run result per executed step.
    '''
    dataIngest.runInitializer(sess, X, Y, batchSize, epochs)
    collected = []
    try:
        while True:
            collected.append(sess.run(opList, feed_dict=feed_dict))
    except tf.errors.OutOfRangeError:
        # End of the dataset: this is the normal termination signal.
        pass
    return collected

def to_onehot(indices, numClasses):
    '''
    One-hot encode a 1-D array of class indices.

    indices: 1-D integer array; every entry must lie in [0, numClasses).
    numClasses: number of columns of the encoding.
    Returns a float array of shape [len(indices), numClasses] with a single
    1 per row, at the column given by that row's index.
    Raises AssertionError if indices is not 1-D or an index is out of range.
    '''
    assert indices.ndim == 1
    b = np.zeros((len(indices), numClasses))
    if len(indices) == 0:
        # max/min are undefined on an empty array; nothing to set.
        return b
    # Every index must address an existing column of b. Negative indices
    # would otherwise silently wrap around to the last columns.
    assert np.min(indices) >= 0
    assert np.max(indices) < numClasses
    b[np.arange(len(indices)), indices] = 1
    return b

def unison_shuffled_copies(elems):
    '''
    Shuffle several equal-length arrays with one shared random permutation.

    elems: sequence of at least two indexable arrays of the same length.
    A single np.random.permutation is drawn and applied to every array, so
    corresponding entries stay aligned (e.g. features with their labels).
    Returns a list of the shuffled copies, in the order of `elems`.
    '''
    assert len(elems) > 1
    length = len(elems[0])
    perm = np.random.permutation(length)
    for arr in elems:
        assert len(arr) == length
    return [arr[perm] for arr in elems]

def getPrecisionRecall(cmatrix, label=1):
    '''
    Precision and recall of a single class from a confusion matrix.

    cmatrix[i][j] counts samples of true class j predicted as class i, so
    the row sum of `label` is the number predicted as `label` and the
    column sum is the number truly of class `label`.
    A zero denominator is replaced by 1, so an empty class scores 0.
    Returns (precision, recall).
    '''
    truePositives = cmatrix[label][label]
    predictedCount = np.sum(cmatrix, axis=1)[label]
    actualCount = np.sum(cmatrix, axis=0)[label]
    precision = truePositives / (predictedCount if predictedCount != 0 else 1)
    recall = truePositives / (actualCount if actualCount != 0 else 1)
    return precision, recall

def getMacroPrecisionRecall(cmatrix):
    '''
    Macro-averaged precision and recall of a confusion matrix.

    Per-class precision is diag / row sum (TP + FP) and per-class recall is
    diag / column sum (TP + FN); a class whose denominator is 0 contributes
    0. The returned values are the unweighted means over classes.
    Returns (precision, recall).
    '''
    predictedTotals = np.sum(cmatrix, axis=1)   # TP + FP per class
    actualTotals = np.sum(cmatrix, axis=0)      # TP + FN per class

    def _perClass(totals):
        return [cmatrix[k][k] / t if t != 0 else 0
                for k, t in enumerate(totals)]

    perClassPrecision = _perClass(predictedTotals)
    perClassRecall = _perClass(actualTotals)
    precision = np.sum(perClassPrecision) / len(perClassPrecision)
    recall = np.sum(perClassRecall) / len(perClassRecall)
    return precision, recall

def getMicroPrecisionRecall(cmatrix):
    '''
    Micro-averaged precision and recall of a confusion matrix.

    Counts are pooled over all classes: the numerator is the trace (total
    correctly classified). Both denominators are the sum of all entries, so
    precision == recall here.
    An all-zero matrix returns (0.0, 0.0) instead of nan. This matches the
    zero-denominator handling of getPrecisionRecall.
    Returns (precision, recall).
    '''
    # TP + FP per class (row sums)
    precisionlist = np.sum(cmatrix, axis=1)
    # TP + FN per class (column sums)
    recalllist = np.sum(cmatrix, axis=0)
    num = float(np.trace(cmatrix))
    denomP = np.sum(precisionlist)
    denomR = np.sum(recalllist)
    precision = num / denomP if denomP != 0 else 0.0
    recall = num / denomR if denomR != 0 else 0.0
    return precision, recall

def getMacroMicroFScore(cmatrix):
    '''
    Returns macro and micro f-scores.
    Refer: http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.104.8244&rep=rep1&type=pdf

    macro: mean over classes of each class's F1. Per-class precision and
        recall use the row/column sums; a zero sum gives 0.
    micro: F1 of the pooled precision/recall (trace over total).
    All zero denominators are guarded. An all-zero matrix therefore
    yields (0.0, 0.0) rather than a nan micro score.
    Returns (macro, micro).
    '''
    predictedCounts = np.sum(cmatrix, axis=1)   # TP + FP per class
    trueCounts = np.sum(cmatrix, axis=0)        # TP + FN per class
    numClasses = len(predictedCounts)

    macro = 0.0
    for i in range(numClasses):
        tp = cmatrix[i][i]
        pre = tp / predictedCounts[i] if predictedCounts[i] != 0 else 0
        rec = tp / trueCounts[i] if trueCounts[i] != 0 else 0
        denom = pre + rec
        if denom == 0:
            denom = 1
        macro += pre * rec * 2 / denom
    macro /= numClasses

    num = 0.0
    for i in range(numClasses):
        num += cmatrix[i][i]
    denom1 = np.sum(predictedCounts)
    denom2 = np.sum(trueCounts)
    # Guard the pooled ratios too; without this an empty matrix gives nan.
    pi = num / denom1 if denom1 != 0 else 0.0
    rho = num / denom2 if denom2 != 0 else 0.0
    denom = pi + rho
    if denom == 0:
        denom = 1
    micro = 2 * pi * rho / denom
    return macro, micro

def getConfusionMatrix(predicted, target, numClasses):
    '''
    confusion[i][j]: Number of elements of class j
        predicted as class i

    predicted, target: 1-D integer label arrays of equal length.
    numClasses: size of the (square) returned float matrix.
    Raises AssertionError on non 1-D inputs or mismatched lengths.
    '''
    assert(predicted.ndim == 1)
    assert(target.ndim == 1)
    # Without this, extra entries of a longer `target` would be ignored.
    assert(len(predicted) == len(target))
    arr = np.zeros([numClasses, numClasses])
    if len(predicted) > 0:
        # np.add.at accumulates repeated (i, j) pairs, unlike arr[i, j] += 1
        # with fancy indexing, which would count each pair only once.
        np.add.at(arr, (predicted, target), 1)
    return arr

def getLengthScores(Y_predicted, val=1):
    '''
    Returns a matrix whose entry [i, j] is the number of consecutive
    entries equal to `val` in row i that end at column j (0 when
    Y_predicted[i, j] != val).
    Y_predicted: [-1, numSubinstance] Is the instance level class
        labels.
    '''
    scores = np.zeros(Y_predicted.shape)
    for row, bag in zip(scores, Y_predicted):
        # `row` is a view into `scores`, so writes land in the result.
        run = 0
        for col, label in enumerate(bag):
            run = run + 1 if label == val else 0
            row[col] = run
    return scores

def bagStats(Y_predicted, Y_bag, minSubsequenceLen = 4, numClass=2,
             redirFile = None):
    '''
    Returns bag level statistics given instance level predictions

    For every non-zero class the longest run of consecutive instance
    predictions of that class is measured per bag. A bag is assigned the
    class with the longest run (ties go to the lowest class) if that run
    is at least minSubsequenceLen; otherwise it is assigned class 0
    (negative by default).

    Y_predicted: [-1, numsubinstance] predicted instance level labels.
    Y_bag: [-1] correct bag level labels.
    redirFile: accepted for interface compatibility; unused here.
    Returns (accuracy, confusion matrix from getConfusionMatrix).
    '''
    assert(Y_predicted.ndim == 2)
    assert(Y_bag.ndim == 1)
    numBags = Y_predicted.shape[0]

    # Column k holds, per bag, the longest run of class k + 1.
    longestRuns = np.array([
        np.max(getLengthScores(Y_predicted, val=cls), axis=1)
        for cls in range(1, numClass)
    ]).T
    assert(longestRuns.ndim == 2)
    assert(longestRuns.shape[0] == numBags)
    assert(longestRuns.shape[1] == numClass - 1)

    bestRun = np.max(longestRuns, axis=1)
    assert(bestRun.ndim == 1)
    assert(bestRun.shape[0] == numBags)
    isPositive = bestRun >= minSubsequenceLen
    bestClass = np.argmax(longestRuns, axis=1) + 1

    prediction = np.zeros((numBags))
    prediction[isPositive] = bestClass[isPositive]
    assert(len(Y_bag) == len(prediction))
    acc = np.mean((prediction == Y_bag).astype('int'))
    prediction = prediction.astype('int')
    cmatrix = getConfusionMatrix(prediction, Y_bag, numClass)
    return acc, cmatrix

def _metricsPerSubsequenceLen(predictions, Y_bag, numSubinstance, numClass,
                              redirFile):
    '''Build one row of bag-level metrics per minSubsequenceLen in 1..numSubinstance.'''
    columns = ['len', 'acc', 'macro-fsc', 'macro-pre', 'macro-rec',
               'micro-fsc', 'micro-pre', 'micro-rec']
    for k in range(numClass):
        columns += ['pre_%02d' % k, 'rec_%02d' % k]

    rows = []
    for length in range(1, numSubinstance + 1):
        trueAcc, cmatrix = bagStats(predictions, Y_bag, numClass=numClass,
                                    minSubsequenceLen=length,
                                    redirFile=redirFile)
        macroF, microF = getMacroMicroFScore(cmatrix)
        macroP, macroR = getMacroPrecisionRecall(cmatrix)
        microP, microR = getMicroPrecisionRecall(cmatrix)
        row = {'len': length, 'acc': trueAcc,
               'macro-fsc': macroF, 'macro-pre': macroP, 'macro-rec': macroR,
               'micro-fsc': microF, 'micro-pre': microP, 'micro-rec': microR}
        for k in range(numClass):
            pre, rec = getPrecisionRecall(cmatrix, label=k)
            row['pre_%02d' % k] = pre
            row['rec_%02d' % k] = rec
        rows.append(row)
    return pd.DataFrame(rows, columns=columns)


def _printMultiClassReport(df, numClass, verbose, redirFile):
    '''Print the metrics table followed by the best value of each headline metric.'''
    with pd.option_context('display.max_rows', 100,
                           'display.max_columns', 100,
                           'expand_frame_repr', True):
        print(df, file=redirFile)

    # Row idx corresponds to minSubsequenceLen idx + 1.
    idx = np.argmax(df['acc'].values)
    val = np.max(df['acc'].values)
    print("Max accuracy %f at subsequencelength %d" % (val, idx + 1),
          file=redirFile)
    val = np.max(df['micro-fsc'].values)
    idx = np.argmax(df['micro-fsc'].values)
    print("Max micro-f %f at subsequencelength %d" % (val, idx + 1),
          file=redirFile)
    val = df['micro-pre'].values[idx]
    print("Micro-precision %f at subsequencelength %d" % (val, idx + 1),
          file=redirFile)
    val = df['micro-rec'].values[idx]
    print("Micro-recall %f at subsequencelength %d" % (val, idx + 1),
          file=redirFile)

    idx = np.argmax(df['macro-fsc'].values)
    val = np.max(df['macro-fsc'].values)
    print("Max macro-f %f at subsequencelength %d" % (val, idx + 1),
          file=redirFile)
    val = df['macro-pre'].values[idx]
    print("macro-precision %f at subsequencelength %d" % (val, idx + 1),
          file=redirFile)
    val = df['macro-rec'].values[idx]
    print("macro-recall %f at subsequencelength %d" % (val, idx + 1),
          file=redirFile)
    if numClass == 2 and verbose:
        idx = np.argmax(df['fscore_01'].values)
        val = np.max(df['fscore_01'].values)
        print('Max fscore %f at subsequencelength %d' % (val, idx + 1),
              file=redirFile)
        print('Precision %f at subsequencelength %d' %
              (df['pre_01'].values[idx], idx + 1), file=redirFile)
        print('Recall %f at subsequencelength %d' % (df['rec_01'].values[idx],
                                                     idx + 1), file=redirFile)


def analysisModelMultiClass(predictions, Y_bag, numSubinstance,
                            numClass, redirFile=None, verbose=False,
                            silent=False):
    '''
    Some basic analysis on predictions and true labels (multiclass version).

    bagStats is evaluated with every minSubsequenceLen in
    1..numSubinstance. The resulting bag-level metrics are tabulated and
    printed, followed by the best accuracy / micro-F / macro-F and the
    length at which each is reached.

    predictions: [-1, numSubinstance] instance level predicted labels.
    Y_bag: [-1] true bag level labels.
    numClass: number of classes; class 0 is the negative class.
    redirFile: file-like object to print to (None prints to stdout).
    verbose: also include per-class precision/recall columns
        ('pre_XX'/'rec_XX') and, for 2 classes, print the best class-1
        f-score with its precision and recall.
    silent: discard all output (redirFile is replaced by os.devnull).
    In the 2 class setting an 'fscore_01' column (f-score of class 1) is
    always included.

    Returns the printed DataFrame, one row per subsequence length.
    '''
    assert (predictions.ndim == 2)
    assert (predictions.shape[1] == numSubinstance)
    assert (Y_bag.ndim == 1)
    assert (len(Y_bag) == len(predictions))

    # Only close a handle that this function opened itself.
    ownsFile = bool(silent)
    if ownsFile:
        redirFile = open(os.devnull, 'w')
    try:
        df = _metricsPerSubsequenceLen(predictions, Y_bag, numSubinstance,
                                       numClass, redirFile)
        colList = ['len', 'acc', 'macro-fsc', 'macro-pre', 'macro-rec',
                   'micro-fsc', 'micro-pre', 'micro-rec']
        if verbose:
            colList += [col for col in df.columns if col not in colList]
        if numClass == 2:
            precisionList = df['pre_01'].values
            recallList = df['rec_01'].values
            denom = precisionList + recallList
            denom[denom == 0] = 1
            numer = 2 * precisionList * recallList
            df['fscore_01'] = numer / denom
            colList.append('fscore_01')
        df = df[colList]
        _printMultiClassReport(df, numClass, verbose, redirFile)
    finally:
        if ownsFile:
            redirFile.close()
    return df

class GraphManager:
    '''
    Manages saving and restoring graphs. Designed to be used with EMI-RNN
    though is general enough to be useful otherwise as well.
    '''

    def __init__(self):
        pass

    def checkpointModel(self, saver, sess, modelPrefix, globalStep=1000,
                        silent=False):
        '''
        Save `sess` with `saver.save` under `modelPrefix`, tagged with
        `globalStep`. Prints the save location unless `silent`.
        '''
        saver.save(sess, modelPrefix, global_step=globalStep)
        if not silent:
            print('Model saved to %s, global_step %d' % (modelPrefix, globalStep))

    def loadCheckpoint(self, sess, modelPrefix, globalStep):
        '''
        Restore the checkpoint '<modelPrefix>-<globalStep>' into `sess`.

        The meta graph '<modelPrefix>-<globalStep>.meta' is loaded with
        tf.train.import_meta_graph, its saver restores the variables, and
        tf.get_default_graph() is returned.
        Raises AssertionError if the .meta file is missing from the
        directory of `modelPrefix`.
        '''
        metaname = modelPrefix + '-%d.meta' % globalStep
        basename = os.path.basename(metaname)
        # A prefix without a directory part has dirname '', which
        # os.listdir rejects; look in the current directory instead.
        dirname = os.path.dirname(modelPrefix) or '.'
        fileList = [x for x in os.listdir(dirname) if x.startswith(basename)]
        assert len(fileList) > 0, 'Checkpoint file not found'
        msg = 'Too many or too few checkpoint files for globalStep: %d' % globalStep
        assert len(fileList) == 1, msg
        saver = tf.train.import_meta_graph(metaname)
        # The checkpoint prefix is the meta file name without '.meta'.
        saver.restore(sess, metaname[:-len('.meta')])
        graph = tf.get_default_graph()
        return graph

def flatten_features(X, Y, group=8, mode='FC', groupStride=None):
    '''
    Split every instance of X into windows ("groups") of `group`
    consecutive time steps. The instance's label is repeated per window.

    X: [numInstances, numSteps, numFeatures]
    Y: [numInstances, numLabels], one label row per instance.
    group: window length in time steps, 1 <= group <= numSteps. The
        trailing numSteps % group steps are dropped first so that no
        window is incomplete.
    groupStride: offset between consecutive window starts, in number of
        steps (default: `group`, i.e. non-overlapping). Must be >= 1.
    mode (returned shapes, G = number of windows per instance):
        'FC'      : each window flattened.
                    x: [numInstances, G, group * numFeatures],
                    y: [numInstances, G, numLabels]
        'RNN'     : time step dimension retained, windows of all
                    instances stacked.
                    x: [numInstances * G, group, numFeatures],
                    y: [numInstances * G, numLabels]
        'Stacked' : time step and group dimension retained.
                    x: [numInstances, G, group, numFeatures],
                    y: [numInstances, G, numLabels]
    '''
    assert mode in ['FC', 'RNN', 'Stacked']
    assert X.ndim == 3
    assert Y.ndim == 2
    assert group in np.arange(1, X.shape[1]+1)
    assert len(X) == len(Y)
    r = X.shape[1] % group
    # We don't want incomplete groups
    if r != 0:
        X = X[:, :-r, :]

    if groupStride is None:
        groupStride = group
    else:
        groupStride = int(groupStride)
        # A stride of 1 (maximally overlapping windows) is valid; only a
        # non-positive stride is rejected because it would never advance.
        assert groupStride >= 1

    numSteps = X.shape[1]
    # Window start offsets s with s + group <= numSteps.
    starts = range(0, numSteps - group + 1, groupStride)
    x_new, y_new = [], []
    for instance, label in zip(X, Y):
        windows = [instance[s:s + group, :] for s in starts]
        if mode == 'FC':
            windows = [np.reshape(w, [-1]) for w in windows]
        x_new.append(np.array(windows))
        y_new.append(np.array([label] * len(windows)))
    x_new = np.array(x_new)
    y_new = np.array(y_new)
    if mode == 'RNN':
        assert x_new.ndim == 4
        x_new = np.reshape(x_new, [-1, x_new.shape[2], x_new.shape[3]])
        y_new = np.reshape(y_new, [-1, y_new.shape[2]])
    return x_new, y_new

def convertToDetectionLabels(Y, label):
    '''
    Collapse 2D one-hot labels into binary one-hot "is `label`" labels.

    Y: [N, numClasses]; the class of each row is its argmax.
    Returns a float [N, 2] array: column 1 is set for rows whose class
    equals `label`, column 0 otherwise.
    '''
    assert Y.ndim == 2
    isTarget = (np.argmax(Y, axis=1) == label).astype(int)
    # Row k of the 2x2 identity is the one-hot vector for k.
    return np.eye(2)[isTarget]

def softmax(x, axis=1):
    """Compute softmax values of x along `axis`.

    The maximum is subtracted per slice along `axis` (not globally), so each
    slice is stabilised independently. Subtracting one global maximum lets
    slices far below it underflow to all zeros, which gives 0/0 = nan.
    """
    shifted = x - np.max(x, axis=axis, keepdims=True)
    e_x = np.exp(shifted)
    div = np.sum(e_x, axis=axis, keepdims=True)
    return e_x / div

def sigmoid(x):
    '''
    Element-wise logistic function 1 / (1 + exp(-x)).

    exp(-x) is cast to float32 before the addition, so array inputs give a
    float32 result.
    '''
    expNeg = np.exp(-x).astype(np.float32)
    denominator = 1 + expNeg
    return 1.0 / denominator
