"""Code to do with pdb structure sequences."""
import h5py
import numpy as np

import src.feat.database as db


def add_seq_information(dataset, seq_src, seqmodel_params):
    """Add sequence conservation windows (PSSM / PSFM) to an existing tf dataset.

    Args:
        dataset: tf.data.Dataset whose elements are dicts containing at least
            'pdb_name', 'model', 'chain' and 'residue'. These are indexed as
            ``[example, which]`` with ``which`` in (0, 1), so elements are
            presumably batched 2-D tensors (TODO confirm with caller).
        seq_src: path to an HDF5 file with one top-level group per PDB name.
            Each group holds one subgroup per ``<model>_<chain>``, and each
            subgroup has the datasets 'pssm', 'psfm' and 'pos_to_res'.
        seqmodel_params: dict; 'cons_window_radius' is the number of
            positions taken on each side of the residue of interest.

    Returns:
        The dataset with two extra keys per element: 'pssm' (int32) and
        'psfm' (float32), each of shape (num_examples, 2, 2 * radius + 1, 20).
        Window rows stay zero in three cases: the PDB is absent from
        seq_src, the chain has no PSSM rows, or the window runs past either
        end of the sequence. Keys already present in an input element take
        precedence over the added ones.
    """
    import tensorflow as tf

    radius = seqmodel_params['cons_window_radius']
    size = radius * 2 + 1

    # Load everything up front so the py_func below only does lookups.
    data = {}
    with h5py.File(seq_src, 'r') as f:
        for group in f.keys():
            data[group] = {}
            for chain in f[group].keys():
                chain_grp = f[group][chain]
                pos_to_res = chain_grp['pos_to_res'][:]
                data[group][chain] = {
                    'pssm': chain_grp['pssm'][:],
                    'psfm': chain_grp['psfm'][:],
                    # Residue identifier -> row index into pssm / psfm.
                    'res_to_pos': {r: i for i, r in enumerate(pos_to_res)},
                }

    def _to_str(value):
        # Under Python 3, tf.py_func passes string tensors in as bytes, while
        # h5py group keys are str; decode so the chain lookup key matches.
        return value.decode() if isinstance(value, bytes) else value

    def _get_pssm_window(pdb_names, models, chains, residues):
        num_ex = pdb_names.shape[0]
        pssms = np.zeros((num_ex, 2, size, 20), dtype='i4')
        psfms = np.zeros((num_ex, 2, size, 20), dtype='f4')
        for ex in range(num_ex):
            for which in (0, 1):
                pdb_name = db.get_pdb_name(
                    pdb_names[ex, which], with_type=False)
                if pdb_name not in data:
                    continue
                model = _to_str(models[ex, which])
                chain = _to_str(chains[ex, which])
                grp = data[pdb_name][model + '_' + chain]
                num_pos = grp['pssm'].shape[0]
                if num_pos == 0:
                    continue
                pos = grp['res_to_pos'][residues[ex, which]]
                start = pos - radius
                # Clip the window [start, start + size) to the valid rows
                # [0, num_pos); positions falling outside remain zero.
                lo = max(start, 0)
                hi = min(start + size, num_pos)
                if lo < hi:
                    dst = slice(lo - start, hi - start)
                    pssms[ex, which, dst] = grp['pssm'][lo:hi]
                    psfms[ex, which, dst] = grp['psfm'][lo:hi]
        return pssms, psfms

    def _add_seq_features(ex):
        pssm, psfm = tf.py_func(
            _get_pssm_window,
            [ex['pdb_name'], ex['model'], ex['chain'], ex['residue']],
            [tf.int32, tf.float32])
        # py_func outputs have no static shape; restore what is known.
        pssm.set_shape([None, 2, size, 20])
        psfm.set_shape([None, 2, size, 20])
        return dict({'pssm': pssm, 'psfm': psfm}, **ex)

    # Use a single map rather than zipping the dataset with a mapped copy of
    # itself. The zip iterates the upstream pipeline twice, which wastes work
    # and misaligns examples with their features whenever that pipeline is
    # shuffled or otherwise nondeterministic.
    return dataset.map(_add_seq_features)
