import argparse
import time
import logging
import sys
import os
from os import listdir
from os.path import isfile, join
import pickle
import copy
import numpy as np
import tensorflow as tf
import json
from nasbench import api

from params import *
from data import Data
from nas_bench.cell import Cell
from meta_neural_net import MetaNeuralnet
from acquisition_functions import acq_fn
#from darts.arch import Arch


# nasbench methods

def get_val_mean(metrics):
    """Average the 'final_validation_accuracy' of runs 0, 1 and 2 at epoch 108.

    ``metrics`` is the computed-stats mapping returned by
    ``nasbench.get_metrics_from_hash``; only ``metrics[108][0..2]`` is read.
    """
    runs = metrics[108]
    accuracies = [runs[run]['final_validation_accuracy'] for run in (0, 1, 2)]
    return np.mean(accuracies)

def num_edges(matrix):
    """Return the number of edges in a 0/1 adjacency matrix (sum of all entries)."""
    return np.asarray(matrix).sum()

def num_vertices(matrix):
    """Return the number of vertices, i.e. the row count of the adjacency matrix."""
    shape = np.shape(matrix)
    return shape[0]

# methods that iterate through nasbench

def get_top_k_matrix(nasbench, k=5, max_vertices=7, max_edges=9):
    """Bucket nasbench architectures by (vertex count, edge count), keeping the k best per bucket.

    Gathered for preliminary analysis only; the meta neural network
    experiment does not call it.

    Returns:
        A nested list ``grid`` where ``grid[v - 1][e - 1]`` holds ``k`` dicts in
        ascending ``'mean'`` order for architectures with ``v`` vertices and
        ``e`` edges. Buckets that saw fewer than ``k`` architectures keep
        ``{'mean': 0, 'loss': 1}`` placeholder entries.
    """
    grid = [[[{'mean': 0, 'loss': 1} for _ in range(k)]
             for _ in range(max_edges)]
            for _ in range(max_vertices)]

    for counter, unique_hash in enumerate(nasbench.hash_iterator(), start=1):
        if counter % 100000 == 0:
            print(counter)

        fixed_stats, computed_stats = nasbench.get_metrics_from_hash(unique_hash)
        mean = get_val_mean(computed_stats)
        adjacency = fixed_stats['module_adjacency']
        vertices = num_vertices(adjacency)
        edges = num_edges(adjacency)

        bucket = grid[vertices - 1][edges - 1]
        if not mean > bucket[0]['mean']:
            continue

        # overwrite the weakest entry, then restore ascending order
        bucket[0] = {'mean': mean,
                     'loss': np.round(1 - mean, 4),
                     'matrix': adjacency,
                     'ops': fixed_stats['module_operations'],
                     'hash': unique_hash,
                     'vertices': vertices,
                     'edges': edges}
        bucket.sort(key=lambda entry: entry['mean'])

    print('done generating top k matrix')
    return grid

def get_top_k(nasbench, k=1000, max_vertices=7, max_edges=9):
    """Return the k best nasbench architectures under vertex/edge limits.

    Every hash yielded by ``nasbench.hash_iterator()`` is scored with
    ``get_val_mean`` and is eligible if it has at most ``max_vertices``
    vertices and at most ``max_edges`` edges. Progress is printed every
    100000 hashes.

    Args:
        nasbench: a loaded ``nasbench.api.NASBench`` instance.
        k: maximum number of architectures to return.
        max_vertices: inclusive upper bound on the vertex count.
        max_edges: inclusive upper bound on the edge count.

    Returns:
        A list of at most ``k`` dicts sorted by ascending ``'mean'`` (the best
        architecture is last). Each dict has the keys 'mean', 'loss',
        'matrix', 'ops', 'hash', 'vertices' and 'edges'. Only architectures
        with a mean strictly greater than 0 are eligible. When fewer than
        ``k`` architectures qualify, the list is shorter. It is no longer
        padded with ``{'mean': 0, 'loss': 1}`` placeholders that lacked the
        architecture keys.
    """
    import heapq

    if k <= 0:
        return []

    # Min-heap of (mean, sequence_number, arch); heap[0] is the worst arch kept
    # so far. The unique sequence number breaks ties so dicts are never compared.
    heap = []

    for n, unique_hash in enumerate(nasbench.hash_iterator(), start=1):
        if n % 100000 == 0:
            print(n)

        fix, comp = nasbench.get_metrics_from_hash(unique_hash)
        mean = get_val_mean(comp)
        vertices = num_vertices(fix['module_adjacency'])
        edges = num_edges(fix['module_adjacency'])

        if vertices > max_vertices or edges > max_edges:
            continue

        # Until k archs are kept the bar is 0 (the old placeholder mean);
        # afterwards a newcomer must strictly beat the current worst.
        threshold = heap[0][0] if len(heap) == k else 0
        if not mean > threshold:
            continue

        arch = {'mean': mean,
                'loss': np.round(1 - mean, 4),
                'matrix': fix['module_adjacency'],
                'ops': fix['module_operations'],
                'hash': unique_hash,
                'vertices': vertices,
                'edges': edges}

        if len(heap) < k:
            heapq.heappush(heap, (mean, n, arch))
        else:
            # pop the current worst and push the newcomer in one O(log k) step
            heapq.heapreplace(heap, (mean, n, arch))

    print('done generating top k')

    # (mean, n) pairs are unique, so sorting never falls through to the dicts
    return [arch for _, _, arch in sorted(heap)]

def get_all_constrained(nasbench,
                        min_vertices=3,
                        max_vertices=7,
                        min_edges=7,
                        max_edges=9):
    """Collect every nasbench architecture whose size lies in the given ranges.

    Both ranges are inclusive. Progress is printed every 100000 hashes.

    Returns:
        A list of dicts with the keys 'matrix', 'ops', 'hash', 'vertices' and
        'edges', in ``nasbench.hash_iterator()`` order.
    """
    selected = []

    for idx, unique_hash in enumerate(nasbench.hash_iterator(), start=1):
        if idx % 100000 == 0:
            print(idx)

        fixed_stats, _ = nasbench.get_metrics_from_hash(unique_hash)
        adjacency = fixed_stats['module_adjacency']
        vertices = num_vertices(adjacency)
        edges = num_edges(adjacency)

        in_range = (min_vertices <= vertices <= max_vertices
                    and min_edges <= edges <= max_edges)
        if in_range:
            selected.append({'matrix': adjacency,
                             'ops': fixed_stats['module_operations'],
                             'hash': unique_hash,
                             'vertices': vertices,
                             'edges': edges})

    print('done generating all constrained')
    return selected


def prune_by_distance(xtrain, candidates,
                      sort_by=5,
                      top_closest=1000,
                      test_size=10000,
                      distance='path'):
    """Keep only the candidates closest (in Hamming distance) to the training set.

    Each candidate is scored by the mean Hamming distance between its
    encoding and its ``sort_by`` nearest rows among the first ``top_closest``
    rows of ``xtrain``. Candidates are returned in ascending score order and
    truncated to ``test_size``.

    Args:
        xtrain: 2-D array-like of training encodings, one row per architecture.
        candidates: list of dicts. The key compared against ``xtrain`` depends
            on ``distance``: 'adjacency' reads candidate['adjacency'], 'path'
            reads candidate['path'], and any other value reads
            candidate['encoding'].
        sort_by: number of nearest training rows averaged into each score. If
            fewer reference rows exist, all of them are averaged.
        top_closest: number of leading ``xtrain`` rows used as references,
            clipped to ``len(xtrain)``.
        test_size: maximum number of candidates returned.
        distance: selects the candidate key, as described above.

    Returns:
        A list of at most ``test_size`` candidates, closest first. Ties keep
        their input order.

    Raises:
        ValueError: if ``sort_by < 1``, or if there are candidates to rank but
            no reference rows.
    """
    if sort_by < 1:
        raise ValueError('sort_by must be at least 1, got {}'.format(sort_by))

    encoding_key = {'adjacency': 'adjacency', 'path': 'path'}.get(distance, 'encoding')

    # Slicing clips top_closest to the number of available rows. Indexing
    # xtrain[j] past its end used to raise IndexError.
    reference = np.asarray(xtrain)[:top_closest]
    if len(candidates) and len(reference) == 0:
        raise ValueError('need at least one training row to prune candidates')

    dists = []
    for i, candidate in enumerate(candidates):
        if i % 1000 == 0:
            print('prune by distance', i, 'of', len(candidates))
        encoding = np.asarray(candidate[encoding_key])

        # Hamming distance to every reference row at once. There is no capped
        # sentinel: the old initial value of 100 hid distances >= 100.
        hamming = np.sum(reference != encoding, axis=1)
        dists.append(np.mean(np.sort(hamming)[:sort_by]))

    # stable sort so equally-distant candidates keep a deterministic order
    sorted_indices = np.argsort(dists, kind='stable')

    print('min dists to xtrain', [dists[i] for i in sorted_indices[:30]])

    return [candidates[i] for i in sorted_indices[:test_size]]

def load_increased_edges(filename='graphs_6_11.json'):
    """Load candidate architectures that lie outside the nasbench search space.

    The JSON file is an object whose values are lists: element 0 is an
    adjacency matrix and element 1 is a list of op indices into ``ops_list``
    below. The keys are not used (presumably graph hashes).

    Args:
        filename: path of the JSON file. Defaults to the previously hard-coded
            'graphs_6_11.json'.

    Returns:
        A list of dicts with the keys 'matrix' (np.ndarray) and 'ops'
        (list of str), in the file's key order.
    """
    ops_list = ['conv3x3-bn-relu', 'conv1x1-bn-relu', 'maxpool3x3', 'output', 'input']

    # The context manager closes the file handle, which previously leaked.
    # Binary mode lets json detect the UTF-8/16/32 encoding itself.
    with open(filename, 'rb') as f:
        arches_dict = json.load(f)

    arches = []
    for value in arches_dict.values():
        matrix = np.array(value[0])
        ops = [ops_list[i] for i in value[1]]
        arches.append({'matrix': matrix, 'ops': ops})

    return arches


def _generate_training_data(search_space, nasbench, train_type, train_size,
                            train_vertices, train_edges, encoding_type, cutoff):
    """Build the training cells used by experiment(), according to ``train_type``.

    'best': the ``train_size`` best nasbench architectures within the training
    vertex/edge limits.
    'some_random': the best ``train_size * 9 // 10`` architectures plus
    ``train_size // 10`` random ones, generated with
    ``max_edges=train_edges, max_nodes=train_vertices``.

    Raises:
        ValueError: if ``train_type`` is not 'best' or 'some_random'.
    """
    if train_type == 'best':
        print('generating best training arches')
        top_k = get_top_k(nasbench, k=train_size,
                          max_vertices=train_vertices, max_edges=train_edges)
        return search_space.convert_to_cells(top_k,
                                             encoding_type=encoding_type,
                                             cutoff=cutoff,
                                             train=True)

    if train_type == 'some_random':
        print('generating best arches with some random arches')
        # first get the best architectures
        top_k = get_top_k(nasbench, k=train_size * 9 // 10,
                          max_vertices=train_vertices, max_edges=train_edges)
        best_cells = search_space.convert_to_cells(top_k,
                                                   encoding_type=encoding_type,
                                                   cutoff=cutoff,
                                                   train=True)

        # Then add some random architectures. They must respect the *training*
        # vertex limit, just like the edge limit. Passing test_vertices here
        # would leak architectures from the larger test space into training.
        random_cells = search_space.generate_random_dataset(num=train_size // 10,
                                                            encoding_type=encoding_type,
                                                            cutoff=cutoff,
                                                            max_edges=train_edges,
                                                            max_nodes=train_vertices)

        # Stagger the data so the first 2 * (train_size // 10) rows are half best
        # and half random. This matters because prune_by_distance() compares
        # candidates only against a prefix of the training set.
        head = train_size // 10
        return [*best_cells[:head], *random_cells, *best_cells[head:]]

    raise ValueError("train_type must be 'best' or 'some_random', got {!r}".format(train_type))


def _generate_candidate_arches(nasbench, outside_ss, train_vertices, test_vertices,
                               train_edges, test_edges):
    """Return nasbench-format candidate specs for experiment().

    Outside the nasbench space, the specs are read with load_increased_edges().
    Otherwise they are every nasbench architecture with more edges, more
    vertices, or both than the training limits, up to the test limits.
    """
    if outside_ss:
        # nasbench cannot enumerate these graphs, so load the specs from a file
        return load_increased_edges()

    # candidates with extra edges
    increased_edges = []
    if train_edges < test_edges:
        increased_edges = get_all_constrained(nasbench,
                                              min_vertices=3,
                                              max_vertices=train_vertices,
                                              min_edges=train_edges + 1,
                                              max_edges=test_edges)

    # candidates with extra vertices
    increased_vertices = []
    if train_vertices < test_vertices:
        increased_vertices = get_all_constrained(nasbench,
                                                 min_vertices=train_vertices + 1,
                                                 max_vertices=test_vertices,
                                                 min_edges=train_vertices - 1,
                                                 max_edges=train_edges)

    # candidates with extra vertices and extra edges
    increased_both = []
    if train_vertices < test_vertices and train_edges < test_edges:
        increased_both = get_all_constrained(nasbench,
                                             min_vertices=train_vertices + 1,
                                             max_vertices=test_vertices,
                                             min_edges=max(train_vertices - 1, train_edges + 1),
                                             max_edges=test_edges)

    return [*increased_edges, *increased_vertices, *increased_both]


def experiment(search_space,
               params,
               nasbench,
               encoding_type='path',
               cutoff=40,
               distance='adj',
               explore_type='mean',
               train_type='some_random',
               train_size=1000,
               test_size=5000,
               train_vertices=6,
               test_vertices=6,
               train_edges=7,
               test_edges=9,
               trials=1,
               num_output=10,
               num_ensemble=3,
               out_file='outside',
               save_dir='outside'):
    """Train a meta neural net ensemble on small architectures and rank larger ones.

    The training set is drawn from architectures within
    ``train_vertices``/``train_edges``. The candidates exceed those limits, up
    to ``test_vertices``/``test_edges``. The candidate set is pruned to the
    ``test_size`` candidates closest to the training set. For each trial, a
    fresh ensemble of ``num_ensemble`` MetaNeuralnets is fit. Candidates are
    ranked with ``acq_fn``, and the top ``num_output`` results are pickled to
    ``save_dir/{out_file}_{encoding_type}_{trial}.pkl``.

    Args:
        search_space: object providing convert_to_cells,
            generate_random_dataset and query_arch.
        params: keyword arguments forwarded to ``MetaNeuralnet.fit``.
        nasbench: a loaded ``nasbench.api.NASBench`` instance.
        encoding_type, cutoff: forwarded to the search_space conversions.
        distance: forwarded to prune_by_distance(). Values other than
            'adjacency'/'path' compare the candidates' 'encoding'.
        explore_type: forwarded to ``acq_fn``.
        train_type: 'best' or 'some_random' (see _generate_training_data).
        trials: number of independent ensembles trained on the same data.
        num_output: number of top-ranked candidates recorded per trial.

    Raises:
        ValueError: for an unknown ``train_type``.
    """
    # the test set lies outside nasbench when it needs >9 edges or >7 vertices
    outside_ss = test_edges > 9 or test_vertices > 7

    train_data = _generate_training_data(search_space, nasbench, train_type, train_size,
                                         train_vertices, train_edges, encoding_type, cutoff)

    candidate_arches = _generate_candidate_arches(nasbench, outside_ss,
                                                  train_vertices, test_vertices,
                                                  train_edges, test_edges)

    # convert the candidates from the nasbench format to the bananas format
    candidates = search_space.convert_to_cells(candidate_arches,
                                               encoding_type=encoding_type,
                                               cutoff=cutoff,
                                               train=False)

    xtrain = np.array([d['encoding'] for d in train_data])
    ytrain = np.array([d['val_loss'] for d in train_data])

    print('len xtrain', xtrain.shape)
    print('len ytrain', ytrain.shape)
    print('len candidates', len(candidates))

    if test_size < len(candidates):
        # how many leading training rows prune_by_distance compares against
        if train_size >= 5000:
            top_closest = 1000
        elif train_size >= 1000:
            top_closest = 500
        else:
            top_closest = 100

        # prune the candidate set to only contain arches close to arches from the training set
        candidates = prune_by_distance(xtrain, candidates,
                                       top_closest=top_closest,
                                       test_size=test_size,
                                       distance=distance)

    xcandidates = np.array([c['encoding'] for c in candidates])

    print('len xcandidates after pruning', xcandidates.shape)

    # Generating the training and candidate sets is slow, so every trial
    # reuses them and only retrains the meta neural net ensemble.
    for t in range(trials):
        # Reset per trial; otherwise earlier trials' ensembles leak into this
        # trial's acquisition ranking, train error and reported predictions.
        train_errors = []
        candidate_predictions = []

        for num in range(num_ensemble):
            meta_neuralnet = MetaNeuralnet()
            meta_neuralnet.fit(xtrain, ytrain, **params)

            train_pred = np.squeeze(meta_neuralnet.predict(xtrain))
            train_errors.append(np.mean(abs(train_pred - ytrain)))
            print('finished ensemble', num)

            # get candidates predictions
            candidate_predictions.append(np.squeeze(meta_neuralnet.predict(xcandidates)))
            print('finished predicting')

            # clear the tensorflow graph
            tf.reset_default_graph()

        tf.keras.backend.clear_session()

        train_error = np.round(np.mean(train_errors, axis=0), 5)

        # rank the candidates by their acquisition function value
        candidate_indices = acq_fn(candidate_predictions, explore_type)

        print('Train error: {}'.format(train_error))

        # ensemble statistics per candidate, computed once instead of per output
        pred_means = np.mean(candidate_predictions, axis=0)
        pred_stds = np.std(candidate_predictions, axis=0)

        results = []
        best_val_loss = 100

        for i in candidate_indices[:num_output]:
            mean = pred_means[i]
            std = pred_stds[i]
            preds = [np.round(member[i], 3) for member in candidate_predictions]
            print('mean', mean)
            print('std', std)
            print('predictions', preds)

            if outside_ss:
                # outside nasbench: record the spec so it can be trained later
                result = {**candidates[i], 'mean': mean, 'std': std, 'preds': preds}
                result.pop('encoding')
                results.append(result)

            else:
                # inside nasbench: look up and record the true val loss
                trained = search_space.query_arch(candidates[i]['spec'])
                trained.pop('encoding')
                if trained['val_loss'] < best_val_loss:
                    best_val_loss = trained['val_loss']

                print(trained['val_loss'])
                results.append({'mean': mean, 'std': std, 'predictions': preds, **trained})

        filename = os.path.join(save_dir, '{}_{}_{}.pkl'.format(out_file, encoding_type, t))

        print('overall best:', best_val_loss)
        print('saving', filename)
        with open(filename, 'wb') as f:
            pickle.dump(results, f)


def run_outside(args, save_dir):
    """Run experiment() once for each encoding type, using parsed CLI args.

    Loads nasbench from './nasbench_only108.tfrecord'. Output pickles are
    written by experiment() into ``save_dir``.
    """
    metann_params = meta_neuralnet_params(args.search_space)
    logging.info(metann_params)

    # Strip the 'search_space' and 'mf' entries; the remaining entries are
    # passed to experiment() as its meta neural net params. The 'mf' value
    # itself is unused, but the pop still raises KeyError if it is missing.
    mp = copy.deepcopy(metann_params)
    ss = mp.pop('search_space')
    mp.pop('mf')

    nasbench_folder = './'
    nasbench = api.NASBench(os.path.join(nasbench_folder, 'nasbench_only108.tfrecord'))
    search_space = Data(ss, loaded_nasbench=nasbench)

    encoding_types = ['adj', 'cat_adj', 'path', 'trunc_path',
                      'cat_path', 'trunc_cat_path']

    for encoding_type in encoding_types:
        print('starting to run', encoding_type)

        experiment(search_space, mp, nasbench,
                   encoding_type=encoding_type,
                   num_ensemble=args.num_ensemble,
                   train_type=args.train_type,
                   train_vertices=args.train_vertices,
                   test_vertices=args.test_vertices,
                   train_edges=args.train_edges,
                   test_edges=args.test_edges,
                   trials=args.trials,
                   train_size=args.train_size,
                   out_file=args.output_filename,
                   save_dir=save_dir)


def main(args):
    """Create the save directory, set up logging, and run the experiments.

    Log records go to stdout and to ``<save_dir>/log.txt``. If
    ``args.save_dir`` is empty, 'results_outside/' is used.
    """
    save_dir = args.save_dir or 'results_outside/'
    # makedirs also creates missing parent directories and tolerates an
    # existing directory (os.mkdir failed on nested paths)
    os.makedirs(save_dir, exist_ok=True)

    # set up logging
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                        format=log_format, datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join(save_dir, 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)
    logging.info(args)

    run_outside(args, save_dir)
    

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Args for path length experiments')
    parser.add_argument('--search_space', type=str, default='nasbench_outside',
                        help='search space name passed to meta_neuralnet_params')
    parser.add_argument('--output_filename', type=str, default='outside', help='name of output files')
    # the help text previously duplicated --output_filename's by mistake
    parser.add_argument('--train_type', type=str, default='some_random',
                        choices=['best', 'some_random'],
                        help='training set: only top architectures (best) '
                             'or mostly top plus some random ones (some_random)')
    parser.add_argument('--save_dir', type=str, default=None, help='name of save directory')
    parser.add_argument('--num_ensemble', type=int, default=10, help='size of metann ensemble')
    parser.add_argument('--train_vertices', type=int, default=6, help='training vertices')
    parser.add_argument('--test_vertices', type=int, default=6, help='testing vertices')
    parser.add_argument('--train_edges', type=int, default=9, help='training edges')
    parser.add_argument('--test_edges', type=int, default=11, help='testing edges')
    parser.add_argument('--trials', type=int, default=1, help='num trials with same dataset')
    parser.add_argument('--train_size', type=int, default=5000, help='training set size')

    args = parser.parse_args()
    main(args)
