import contextlib
import csv
import os
import sys

import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, confusion_matrix, balanced_accuracy_score
from sklearn.model_selection import RandomizedSearchCV
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler, LabelEncoder

np.set_printoptions(threshold=np.inf)

# Feature sets to evaluate, in order. Each entry is a directory (joined onto
# basedir) expected to hold train.txt, test1.txt, test2.txt, new_test3.txt
# and new_test4.txt.
featuresets = [
    "COOS7_deeploc_features/fc_2",
    *("COOS7_self_supervised_features/" + layer
      for layer in ("conv3", "conv4", "conv5")),
    "COOS7_texture_features",
    *("COOS7_vgg16_features/" + layer
      for layer in ("block4_conv2", "block4_conv1", "block3_conv3")),
]
# Root directory that the feature-set paths above are resolved against.
basedir = "./"

def _load_split(featureset, filename, sort_rows=False):
    """Read one tab-separated split file of ``featureset``.

    Column 0 of each row is returned as the (string) class label and columns
    2.. as float32 features; column 1 is not returned (training rows are sorted
    by it -- presumably an image/cell identifier, confirm).

    Args:
        featureset: directory name, resolved against the module-level basedir.
        filename: file inside that directory, e.g. "train.txt".
        sort_rows: if True, reorder rows by the string value in column 1.

    Returns:
        (labels, features): a string array of shape (rows,) and a float32
        array of shape (rows, n_features).

    Raises:
        ValueError: if the file is empty or its rows have unequal lengths.
    """
    path = basedir + "/" + featureset + "/" + filename
    # newline="" is what the csv module expects for files it parses.
    with open(path, newline="") as handle:
        cells = np.array(list(csv.reader(handle, delimiter="\t")))
    if cells.ndim != 2 or cells.shape[0] == 0:
        raise ValueError(path + ": expected a non-empty table with equal-length rows")
    if sort_rows:
        cells = cells[cells[:, 1].argsort()]
    # Rows written with a trailing tab parse to an empty last field; drop that
    # column so the float conversion does not see ''.
    stop = -1 if cells[0, -1] == "" else None
    return cells[:, 0], cells[:, 2:stop].astype(np.float32)


def _load_featureset(featureset):
    """Load, standardise and label-encode every split of ``featureset``.

    The scaler and label encoder are fitted on the training split only and
    then applied to the four test splits.

    Returns:
        (splits, label_map): ``splits`` is a list of
        ``(title, scaled_features, encoded_labels)`` tuples, with the training
        split first and then test splits 1-4. ``label_map`` pairs the sorted
        class names with their integer codes.

    Raises:
        ValueError: from LabelEncoder.transform if a test split contains a
        label absent from the training split.
    """
    print("Loading files")
    train_labels, train_features = _load_split(featureset, "train.txt", sort_rows=True)
    test_data = [_load_split(featureset, name)
                 for name in ("test1.txt", "test2.txt", "new_test3.txt", "new_test4.txt")]

    print("Scaling files")
    scaler = StandardScaler().fit(train_features)
    encoder = LabelEncoder().fit(train_labels)
    classes = np.unique(train_labels)
    label_map = (classes, encoder.transform(classes))

    splits = [("TRAINING DATASET",
               scaler.transform(train_features),
               encoder.transform(train_labels))]
    for number, (labels, features) in enumerate(test_data, start=1):
        splits.append(("TEST DATASET %d" % number,
                       scaler.transform(features),
                       encoder.transform(labels)))
    return splits, label_map


def _write_report(model, splits, label_map, outdir, featureset, show_model=False):
    """Append an evaluation report for a fitted ``model`` to a text file.

    The file is ``<outdir>/<featureset with '/' replaced by '_'>.txt``; the
    directory is created if missing. The report contains:
    - the label map;
    - optionally the model's repr;
    - for each split, a classification report, a confusion matrix and the
      balanced accuracy/error.
    """
    os.makedirs(outdir, exist_ok=True)
    output_name = outdir + "/" + featureset.replace("/", "_") + ".txt"
    # The with-block closes (and flushes) the report file, and redirect_stdout
    # restores the previous stdout even if prediction or scoring raises.
    with open(output_name, "a") as report, contextlib.redirect_stdout(report):
        print(label_map)
        if show_model:
            print(model)
        print("Evaluating on test sets")
        for title, features, labels in splits:
            predicted = model.predict(features)
            accuracy = balanced_accuracy_score(labels, predicted)
            print(title)
            print(classification_report(labels, predicted, digits=4))
            print(confusion_matrix(labels, predicted))
            print("ACCURACY:", accuracy, "ERROR:", 1.0 - accuracy)


# Random-forest search space (loop-invariant). n_estimators is pinned at 200
# rather than searched; max_depth covers int(linspace(1, 50, 5)) plus None.
_rf_max_depth = [int(x) for x in np.linspace(1, 50, num=5)]
_rf_max_depth.append(None)
_rf_param_grid = {'n_estimators': [200],
                  'max_features': ['log2', 'sqrt'],
                  'max_depth': _rf_max_depth,
                  'min_samples_split': [2, 5, 10, 20, 40],
                  'min_samples_leaf': [1, 2, 4, 8, 16]}

# Pass 1: random forest (randomized hyper-parameter search, balanced-accuracy
# scoring) followed by L1 logistic regression, for every feature set.
for featureset in featuresets:
    splits, label_map = _load_featureset(featureset)
    _, training_features, training_labels = splits[0]

    print("Building random forest classifier")
    rf_random = RandomizedSearchCV(estimator=RandomForestClassifier(),
                                   param_distributions=_rf_param_grid,
                                   n_iter=100, cv=4, verbose=2, random_state=42,
                                   n_jobs=100, scoring='balanced_accuracy')
    rf_random.fit(training_features, training_labels)
    rfc = rf_random.best_estimator_
    _write_report(rfc, splits, label_map, "./final_rf_reports", featureset, show_model=True)

    print("Building logistic regresion classifier")
    lr = LogisticRegression(penalty='l1', n_jobs=100, solver='saga',
                            class_weight='balanced').fit(training_features, training_labels)
    _write_report(lr, splits, label_map, "./final_lr_reports", featureset)

# Pass 2: k-nearest neighbours (k=11) for every feature set. This runs only
# after pass 1 has finished for all feature sets, and the data is reloaded
# per feature set rather than kept in memory.
for featureset in featuresets:
    splits, label_map = _load_featureset(featureset)
    _, training_features, training_labels = splits[0]

    print("Building kNN classifier")
    knn = KNeighborsClassifier(n_neighbors=11, n_jobs=100).fit(training_features, training_labels)
    _write_report(knn, splits, label_map, "./final_knn_reports", featureset)