import numpy as np
import pandas as pd


def form_regions(pair, positive):
    """Collect the per-side fields for one complex's labelled index pairs.

    Uses ``pair['pos_idx']`` when ``positive`` is truthy and
    ``pair['neg_idx']`` otherwise.  Column ``side`` of the chosen index
    tensor is used to gather rows from the matching per-side entry of
    ``pair`` (``'<name>s<side>'``, or ``'positions<side>'`` for the
    ``'center'`` output), and the two sides are stacked along axis 1.

    Returns a dict with keys ``'resname'``, ``'pdb_name'``, ``'model'``,
    ``'residue'``, ``'center'``, ``'chain'`` and ``'label'``; ``'label'`` is
    a float32 vector with one entry per index row, filled with
    ``positive`` cast to float.
    """
    import tensorflow as tf

    idx = pair['pos_idx'] if positive else pair['neg_idx']
    num = tf.shape(idx)[0]

    sides = (0, 1)
    result = {}
    for name in ('resname', 'pdb_name', 'model', 'residue', 'center',
                 'chain'):
        # 'center' is the only output that does not follow the
        # '<name>s<side>' naming scheme; it is read from the positions.
        if name == 'center':
            keys = ['positions{:}'.format(side) for side in sides]
        else:
            keys = ['{:}s{:}'.format(name, side) for side in sides]
        gathered = [tf.gather(pair[key], idx[:, side])
                    for side, key in zip(sides, keys)]
        result[name] = tf.stack(gathered, axis=1)

    result['label'] = tf.cast(tf.fill((num,), positive), tf.float32)
    return result


def _pair_in_list(name, list):
    """
    Return a boolean tensor telling whether ``name`` occurs in ``list``.

    ``name`` is wrapped into a 1x1 set and intersected with ``list.T``; the
    result is true when that intersection is non-empty.  ``list`` must
    expose a ``.T`` attribute (callers in this file pass the pandas frames
    returned by ``pd.read_table``).
    """
    import tensorflow as tf

    query = tf.expand_dims([name], 0)
    overlap = tf.sets.set_intersection(query, list.T)
    return tf.not_equal(tf.size(overlap), 0)


def _pairdataset_to_regpairdataset(pair_dataset, positive):
    """Turn a dataset of complex pairs into a dataset of amino acid pairs.

    Each output element is the pruned parent complex (see ``_prune``)
    merged with the fields of one amino acid pair produced by
    ``form_regions``.  ``positive`` selects whether the positive or the
    negative index pairs of each complex are expanded.
    """
    import tensorflow as tf

    count_key = 'num_pos' if positive else 'num_neg'

    def _explode(complex_pair):
        # One dataset element per amino acid pair of this complex.
        return tf.data.Dataset.from_tensor_slices(
            form_regions(complex_pair, positive))

    def _replicate(pruned):
        # Emit the parent complex once per amino acid pair so that it can
        # be zipped alongside the exploded pairs.
        return tf.data.Dataset.from_tensors(pruned).repeat(pruned[count_key])

    regions = pair_dataset.flat_map(_explode)
    # Non-essential fields are dropped from the parent for speed.
    parents = pair_dataset.map(_prune).flat_map(_replicate)

    zipped = tf.data.Dataset.zip((parents, regions))
    return zipped.map(_combine, num_parallel_calls=8)


def _prune(pair):
    """Keep only the per-atom fields of ``pair`` and add pair counts.

    The returned dict carries ``positions``, ``elements`` and ``aids`` for
    both sides unchanged, plus ``'num_pos'`` / ``'num_neg'``: the number of
    rows of ``pair['pos_idx']`` / ``pair['neg_idx']`` as int64 scalars.
    """
    import tensorflow as tf

    passthrough = ('positions0', 'positions1',
                   'elements0', 'elements1',
                   'aids0', 'aids1')
    pruned_pair = {key: pair[key] for key in passthrough}

    for count_key, idx_key in (('num_pos', 'pos_idx'),
                               ('num_neg', 'neg_idx')):
        pruned_pair[count_key] = tf.cast(
            tf.shape(pair[idx_key])[0], tf.int64)
    return pruned_pair


def _combine(pair, regpair):
    """Return a copy of ``pair`` overlaid with the entries of ``regpair``.

    Neither input is modified; on a key clash the value from ``regpair``
    wins.
    """
    merged = pair.copy()
    for key, value in regpair.items():
        merged[key] = value
    return merged


def get_dataset_size(dataset):
    """Count the elements of a tf dataset by iterating over it once.

    A one-shot iterator is created and evaluated in a fresh ``tf.Session``
    until it raises ``OutOfRangeError``; the number of successful
    evaluations is returned.
    """
    import tensorflow as tf

    next_el = dataset.make_one_shot_iterator().get_next()

    total = 0
    with tf.Session() as sess:
        try:
            while True:
                sess.run(next_el)
                total += 1
        except tf.errors.OutOfRangeError:
            # Iterator exhausted: every element has been counted.
            pass
    return total


def _read_name_table(path, column):
    """Read a header-less, whitespace-delimited file of names as strings.

    Returns the pandas frame produced by ``pd.read_table`` with its column
    labelled ``column``.
    """
    # ``sep=r'\s+'`` is the documented equivalent of the deprecated
    # ``delim_whitespace=True`` flag.
    return pd.read_table(path, names=[column], sep=r'\s+', dtype=str)


def create_tf_dataset(run_params, source, tfrecords, fn):
    """Build the balanced, batched input dataset for one data split.

    Args:
        run_params: dict of settings.  Keys read here: ``num_<split>``,
            ``keep_file_<split>``, ``prune_file_<split>``,
            ``keep_file_pairs_<split>``, ``seed`` (non-testing only),
            ``num_rolls``, ``rolls_per_pass``, ``batch_size``, ``towers``,
            ``shuffle_buffer``, ``num_interleaved``, ``loose`` and
            ``max_epochs`` (both non-testing only).  For the testing split
            ``run_params['rolls_per_pass']`` is overwritten in place with
            ``run_params['num_rolls']``.
        source: ``'training'`` or ``'validation'``; any other value is
            treated as the testing split.
        tfrecords: passed to ``tf.data.TFRecordDataset``.
        fn: function mapped over every example before batching.

    Returns:
        ``(dataset, num_batches)`` where ``num_batches`` is
        ``num_examples // batch_size * num_passes``.

    Raises:
        RuntimeError: if both a prune file and a keep file are given.
        AssertionError: if the example count, roll count or batch size are
            not divisible as required.
    """
    import tensorflow as tf

    testing = source not in ('training', 'validation')
    split = 'testing' if testing else source

    num_examples = run_params['num_' + split]
    if testing:
        # Testing performs all rolls in a single pass.  NOTE: this mutates
        # the caller's dict, which is kept as-is for compatibility.
        run_params['rolls_per_pass'] = run_params['num_rolls']
    keep_file = run_params['keep_file_' + split]
    prune_file = run_params['prune_file_' + split]
    keep_file_pairs = run_params['keep_file_pairs_' + split]
    if not testing:
        # Training and validation draw from the same seeded stream but use
        # different slices of it, so their shuffle seeds differ.
        np.random.seed(run_params['seed'])
        seeds = np.random.randint(10000000, size=4)
        if source == 'training':
            pos_seed, neg_seed = seeds[0:2]
        else:
            pos_seed, neg_seed = seeds[2:4]

    num_rolls = run_params['num_rolls']
    rolls_per_pass = run_params['rolls_per_pass']
    batch_size = run_params['batch_size']
    towers = run_params['towers']

    assert num_examples % batch_size == 0
    assert num_rolls % rolls_per_pass == 0
    assert batch_size % towers == 0

    # Floor division keeps these integral on Python 3 as well; they are
    # handed to repeat()/batch(), which require integer arguments.
    num_passes = num_rolls // rolls_per_pass
    num_batches = num_examples // batch_size * num_passes

    buffer_size = run_params['shuffle_buffer']
    num_interleaved = run_params['num_interleaved']

    # The 2x multiplier is to ensure we get a balanced number of pos and neg.
    aug_batch_size = batch_size // towers * 2 * rolls_per_pass

    dataset = tf.data.TFRecordDataset(tfrecords).interleave(
        lambda x: tf.data.Dataset.from_tensors(parse_tf_example(x)),
        num_interleaved)

    if prune_file and keep_file:
        raise RuntimeError(
            "Can't specify list to prune and list to keep at same time!")
    if prune_file:
        list_to_prune = _read_name_table(prune_file, 'complex')
        dataset = dataset.filter(
            lambda x: tf.logical_not(
                _pair_in_list(x['complex'], list_to_prune)))
    if keep_file:
        list_to_keep = _read_name_table(keep_file, 'complex')
        dataset = dataset.filter(
            lambda x: _pair_in_list(x['complex'], list_to_keep))
    if keep_file_pairs:
        pairs_to_keep = _read_name_table(keep_file_pairs, 'id')
        # NOTE(review): this reads x['id'], but parse_tf_example as visible
        # in this file does not emit an 'id' key -- confirm where it is
        # supposed to come from.
        dataset = dataset.filter(
            lambda x: _pair_in_list(x['complex'] + '_' + x['id'],
                                    pairs_to_keep))

    pos_dataset = _pairdataset_to_regpairdataset(dataset, True)
    neg_dataset = _pairdataset_to_regpairdataset(dataset, False)
    if not testing:
        if run_params['loose']:
            pos_dataset = pos_dataset.apply(tf.contrib.data.shuffle_and_repeat(
                buffer_size, num_passes, seed=pos_seed))
            neg_dataset = neg_dataset.apply(tf.contrib.data.shuffle_and_repeat(
                buffer_size, num_passes, seed=neg_seed))
        else:
            pos_dataset = pos_dataset.shuffle(buffer_size, seed=pos_seed,
                                              reshuffle_each_iteration=True)
            pos_dataset = pos_dataset.take(num_examples)
            pos_dataset = pos_dataset.repeat(num_passes)
            neg_dataset = neg_dataset.shuffle(buffer_size, seed=neg_seed,
                                              reshuffle_each_iteration=True)
            neg_dataset = neg_dataset.take(num_examples)
            neg_dataset = neg_dataset.repeat(num_passes)
    dataset = tf.data.Dataset.zip((pos_dataset, neg_dataset))
    if not testing:
        dataset = dataset.repeat(run_params['max_epochs'])

    # Interleave positive and negative examples.  From
    # https://stackoverflow.com/questions/46938530/produce-balanced-mini-batch-with-dataset-api
    dataset = dataset.flat_map(
        lambda ex_pos, ex_neg:
        tf.data.Dataset.from_tensors(ex_pos)
        .repeat(rolls_per_pass)
        .concatenate(tf.data.Dataset.from_tensors(ex_neg)
                     .repeat(rolls_per_pass)))

    dataset = dataset.map(fn, num_parallel_calls=8).batch(aug_batch_size)

    dataset = dataset.prefetch(buffer_size=towers)
    return dataset, num_batches


def parse_tf_example(example_serialized):
    """Parse one serialized example into a dict of dense tensors.

    For each side (0 and 1) the variable-length features are densified:
    ``positions<side>`` is reshaped to rows of 3 floats, and the string
    features are padded with ``''``.  ``pos_idx`` and ``neg_idx`` are
    reshaped to rows of 2 int64 indices.  ``src0``, ``src1`` and
    ``complex`` are scalar strings.
    """
    import tensorflow as tf

    sides = (0, 1)
    string_fields = ('elements', 'atom_names', 'residues', 'resnames',
                     'pdb_names', 'chains', 'models', 'aids')
    index_fields = ('pos_idx', 'neg_idx')

    # Describe the layout of the serialized record.
    spec = {}
    for side in sides:
        spec['positions{:}'.format(side)] = tf.VarLenFeature(tf.float32)
        for field in string_fields:
            spec['{:}{:}'.format(field, side)] = tf.VarLenFeature(tf.string)
    for key in ('src0', 'src1', 'complex'):
        spec[key] = tf.FixedLenFeature([], tf.string)
    for key in index_fields:
        spec[key] = tf.VarLenFeature(tf.int64)

    pair = tf.parse_single_example(example_serialized, features=spec)

    # Variable-length features come back sparse; convert them to dense.
    for side in sides:
        pos_key = 'positions{:}'.format(side)
        dense_positions = tf.sparse_tensor_to_dense(pair[pos_key])
        pair[pos_key] = tf.reshape(dense_positions, [-1, 3])
        for field in string_fields:
            key = '{:}{:}'.format(field, side)
            pair[key] = tf.sparse_tensor_to_dense(pair[key], default_value='')

    for key in index_fields:
        pair[key] = tf.reshape(tf.sparse_tensor_to_dense(pair[key]), [-1, 2])
    return pair
