"""Training and testing interaction prediction."""
import glob
import json
import logging
import os

import h5py
import numpy as np
import sklearn.metrics as sm
import tqdm

import src.feat.sequence as seq
import src.learning.interaction.pair_to_tfrecord as ptt
import src.learning.interaction.model as im
import src.learning.interaction.model_params as ip
import src.learning.interaction.model_runner as run
import src.learning.interaction.seqmodel_params as qp
import src.learning.interaction.test_params as sp
import src.learning.subgrid_generation as sg


def compute_accuracy(predictions, labels, model_params):
    """Compute classification metrics with a fixed 0.5 threshold.

    Every example is assigned to exactly one predicted class: a prediction
    greater than or equal to 0.5 counts as interacting (positive) and
    anything else as non-interacting (negative).  Splitting on a single
    comparison guarantees that predictions sitting exactly on the threshold
    are not silently dropped from the confusion counts.  NaN predictions
    fail the comparison and are therefore counted as negative.

    Args:
        predictions: Array-like of prediction scores; flattened before
            thresholding.
        labels: Array-like of ground-truth labels (0 = non-interacting,
            1 = interacting), one per prediction; flattened as well.
        model_params: Unused; kept so existing callers keep working.

    Returns:
        Tuple ``(precision, recall, accuracy, TP, TN, FP, FN)``.  The three
        ratios are floats and fall back to 0.0 when their denominator is
        zero; the four counts are plain Python ints.
    """
    scores = np.asarray(predictions).ravel()
    truth = np.asarray(labels).ravel()

    # One mask and its complement: each example lands in exactly one bucket.
    predicted_pos = scores >= 0.5
    predicted_neg = ~predicted_pos

    # Labels we predicted as being non-interacting / interacting.
    negative_pred = truth[predicted_neg]
    positive_pred = truth[predicted_pos]

    # Built-in ints keep the counts printable with '{:d}' and serializable.
    TN = int(np.count_nonzero(negative_pred == 0))
    FN = int(np.count_nonzero(negative_pred == 1))
    TP = int(np.count_nonzero(positive_pred == 1))
    FP = int(np.count_nonzero(positive_pred == 0))

    total = TP + TN + FP + FN
    precision = float(TP) / (TP + FP) if TP + FP else 0.0
    recall = float(TP) / (TP + FN) if TP + FN else 0.0
    accuracy = float(TP + TN) / total if total else 0.0
    return precision, recall, accuracy, TP, TN, FP, FN


def test_model_main(args):
    """Test best model and write its metrics to ``<out_dir>/summary.json``.

    Args:
        args: Parsed command-line arguments.  The attributes read here are
            ``model_json``, ``test_json``, ``out_dir``, ``model_dir``,
            ``model_chkpt`` and ``seqmodel_json``.
    """
    model_params = ip.load_params(args.model_json)
    test_params = sp.load_params(args.test_json)
    (auc, te_acc, te_pre, te_rec, te_tp, te_tn, te_fp, te_fn, te_loss) = \
        test_model(model_params, test_params, args.out_dir, args.model_dir,
                   args.model_chkpt, seqmodel_json=args.seqmodel_json)

    # The metrics may be NumPy scalars (the confusion counts come from
    # np.sum), which the json module refuses to serialize on Python 3.
    # Convert everything to built-in numbers before dumping.
    final_report = {}
    final_report['test_auc'] = float(auc)
    final_report['test_acc'] = float(te_acc)
    final_report['test_pre'] = float(te_pre)
    final_report['test_rec'] = float(te_rec)
    final_report['test_tp'] = int(te_tp)
    final_report['test_tn'] = int(te_tn)
    final_report['test_fp'] = int(te_fp)
    final_report['test_fn'] = int(te_fn)
    final_report['test_loss'] = float(te_loss)

    with open(args.out_dir + '/summary.json', 'w') as f:
        json.dump(final_report, f)


def test_model(model_params, test_params, out_dir, model_dir, model_chkpt="",
               seqmodel_json=""):
    """Evaluate a trained model on the test set and save its ROC data.

    Runs every test batch through the loaded model, gathers the per-example
    predictions and metadata, writes them together with the ROC curve to
    ``<out_dir>/roc.h5`` and logs the summary metrics.

    Args:
        model_params: Model parameters; passed to the subgrid generator and
            to ``compute_accuracy``.
        test_params: Mapping of test settings.  Read here for
            'num_directions', 'num_rolls', 'dataset_tfrecords', 'towers'
            and, when a sequence model is used, 'seq_src'; it is also
            forwarded to ``ptt.create_tf_dataset``.
        out_dir: Output directory; created if it does not exist.
        model_dir: Directory handed to ``im.InteractionModel``.
        model_chkpt: Checkpoint filename forwarded to ``model.load``.
        seqmodel_json: Path to the sequence-model parameter file.  An empty
            string disables sequence information.

    Returns:
        Tuple ``(auc, accuracy, precision, recall, TP, TN, FP, FN, loss)``
        where ``loss`` is the summed batch loss divided by the number of
        test batches.
    """
    import tensorflow as tf

    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    # Create 3D grid generator.
    gen = sg.TFSubgridGenerator(
        model_params, test_params['num_directions'],
        test_params['num_rolls'])

    # Create tf.Dataset for reading in test set.
    test_tfrecords = glob.glob(
        test_params['dataset_tfrecords'] + '/*.tfrecord')
    test_dataset, num_batches_testing = ptt.create_tf_dataset(
        test_params, 'testing', test_tfrecords, gen.get_gridded_pair)

    has_seq = seqmodel_json != ""
    if has_seq:
        assert os.path.exists(seqmodel_json)
        seqmodel_params = qp.load_params(seqmodel_json)
        test_dataset = seq.add_seq_information(
            test_dataset, test_params['seq_src'], seqmodel_params)

    test_iterator = test_dataset.make_one_shot_iterator()

    # Feedable iterator.
    handle = tf.placeholder(tf.string, shape=[], name='handle')
    iterator = tf.data.Iterator.from_string_handle(
        handle, test_dataset.output_types, test_dataset.output_shapes)

    model = im.InteractionModel(model_dir)
    model.load(test_params['towers'], checkpoint_filename=model_chkpt,
               iterator=iterator, has_seq=has_seq)
    mr = run.ModelRunner(model, test_iterator, False)

    # compute final accuracy on test set
    logging.info('Starting to test model')
    # Collect per-batch chunks and concatenate once after the loop instead
    # of re-copying the growing arrays on every batch.  Each list is seeded
    # with an empty array so the final shapes and dtypes match what
    # incremental concatenation would produce, even with zero batches.
    dir_chunks = [np.zeros((0, 2), dtype='i4')]
    roll_chunks = [np.zeros((0, 2), dtype='i4')]
    pred_chunks = [np.zeros((0))]
    label_chunks = [np.zeros((0))]
    uid_chunks = [np.zeros((0, 2), dtype='|S50')]
    total_loss = 0
    t = tqdm.trange(num_batches_testing, desc='Acc: {:6.4f}'.format(0))
    epoch_acc = 0
    for i in t:
        bi = mr.next()
        # Running mean of the per-batch accuracy, for the progress bar.
        epoch_acc += (bi.acc - epoch_acc) / (i + 1)
        total_loss += bi.loss
        pred_chunks.append(bi.pred)
        label_chunks.append(bi.example['label'])
        uid_chunks.append(bi.example['uid'])
        dir_chunks.append(bi.example['direction'])
        roll_chunks.append(bi.example['roll'])
        t.set_description('Acc: {:6.4f}'.format(epoch_acc))

    all_pred = np.concatenate(pred_chunks)
    all_labels = np.concatenate(label_chunks)
    all_dirs = np.concatenate(dir_chunks)
    all_rolls = np.concatenate(roll_chunks)
    # Force fixed-width byte strings before writing the uids to HDF5.
    all_uids = np.concatenate(uid_chunks).astype(dtype='|S50')

    te_loss = total_loss / num_batches_testing

    auc = sm.roc_auc_score(all_labels, all_pred)

    fpr, tpr, thresholds = sm.roc_curve(all_labels, all_pred)

    with h5py.File(out_dir + '/roc.h5', 'w') as f:
        f.create_dataset('fpr', data=np.array(fpr))
        f.create_dataset('tpr', data=np.array(tpr))
        f.create_dataset('thresholds', data=np.array(thresholds))
        f.create_dataset('predictions', data=all_pred)
        f.create_dataset('labels', data=all_labels)
        f.create_dataset('dirs', data=all_dirs)
        f.create_dataset('rolls', data=all_rolls)
        f.create_dataset('uids', data=all_uids)

    logging.info('Test loss: {:0.5f}'.format(te_loss))
    logging.info('Test AUC: {:0.2f}'.format(auc))
    (te_pre, te_rec, te_acc, te_tp, te_tn, te_fp, te_fn) = \
        compute_accuracy(all_pred, all_labels, model_params)
    logging.info('Test acc:  {:0.2f} pre: {:0.2f} rec: {:0.2f} TP: {:6d} '
                 'TN: {:6d} FP: {:6d} FN: {:6d}'
                 .format(te_acc, te_pre, te_rec, te_tp, te_tn, te_fp, te_fn))

    return (auc, te_acc, te_pre, te_rec, te_tp, te_tn, te_fp, te_fn, te_loss)
