"""Track inter-epoch information."""
import glob
import json
import os
import re

import h5py
import numpy as np


class EpochTracker:
    """Tracks epochs for a given training run.

    All state is reconstructed from files in ``out_dir``: checkpoint files
    named like ``checkpoint-<epoch:02d>-<val_loss:.2f>.ckpt`` (optionally
    with ``.index`` / ``.meta`` companions), per-epoch JSON loss summaries
    and per-epoch HDF5 detail files.
    """

    # Extracts the epoch number from a checkpoint *basename*. Only digits are
    # accepted before the separating "-" so a negative validation loss (e.g.
    # "checkpoint-01--0.25.ckpt") cannot leak a "-" into the epoch group, and
    # unrelated/malformed names are ignored instead of crashing int().
    _CHECKPOINT_EPOCH_RE = re.compile(r'checkpoint-(\d+)-.*\.ckpt')

    def __init__(self, out_dir):
        """
        Initialization.

        Args:
            out_dir (str):
                Directory where outputs are located.
        """
        self.out_dir = out_dir

    def get_next_epoch(self):
        """Get what epoch the model located in out_dir is on.

        Returns:
            int: One past the highest epoch number found among checkpoint
            files in ``out_dir``; 0 if ``out_dir`` does not exist or contains
            no recognizable checkpoint files.
        """
        if not os.path.exists(self.out_dir):
            return 0

        next_epoch = 0
        for ckpt in glob.glob(self._get_checkpoint_base(-1)):
            m = self._CHECKPOINT_EPOCH_RE.match(os.path.basename(ckpt))
            if m:
                next_epoch = max(next_epoch, int(m.group(1)) + 1)
        return next_epoch

    def has_checkpoint(self):
        """Determine if we have a checkpointed epoch."""
        return self.latest_checkpoint_filename() != ""

    def has_init(self, epoch=-1):
        """Determine if we have initialization metagraph.

        Args:
            epoch (int): Epoch whose metagraph to look for; -1 (default)
                means the initialization metagraph.
        """
        return os.path.exists(self.metagraph_filename(epoch))

    def latest_checkpoint_filename(self):
        """Get path to latest checkpoint, or empty if non-existent."""
        next_epoch = self.get_next_epoch()
        if next_epoch == 0:
            # Nothing checkpointed yet. Do not call checkpoint_filename(-1):
            # -1 is a wildcard sentinel there and could match any file.
            return ""
        return self.checkpoint_filename(next_epoch - 1)

    def best_checkpoint_filename(self):
        """Get path to best checkpoint, by validation loss.

        Epochs without a loss summary file are skipped. Returns "" if no
        epoch has a summary (or the best epoch has no checkpoint file).
        """
        best_epoch_filename = ""
        min_val = float('inf')
        for epoch in range(self.get_next_epoch()):
            val_loss = self._read_val_loss(epoch)
            if val_loss is not None and val_loss < min_val:
                min_val = val_loss
                best_epoch_filename = self.checkpoint_filename(epoch)
        return best_epoch_filename

    def latest_metagraph_filename(self):
        """Get latest metagraph.

        When no epoch has been checkpointed this resolves to the
        initialization metagraph (``metagraph_filename(-1)``).
        """
        return self.metagraph_filename(self.get_next_epoch() - 1)

    def metagraph_filename(self, epoch=-1):
        """Get path to metagraph for provided epoch (or init if no epoch).

        Returns "" if no metagraph file exists for ``epoch``. If several
        files match, the lexicographically first is returned so the result
        does not depend on filesystem listing order.
        """
        if epoch == -1:
            return self.out_dir + "/init.ckpt.meta"
        candidates = sorted(glob.glob(self._get_checkpoint_base(epoch) +
                                      '-*.ckpt.meta'))
        return candidates[0] if candidates else ""

    def checkpoint_filename(self, epoch):
        """Get path to checkpoint corresponding to given epoch.

        Prefers the newer layout (a ``.ckpt.index`` file, returned without the
        ``.index`` suffix) and falls back to an old-format ``.ckpt`` file.
        Returns "" if neither exists. Matches are sorted so the choice is
        deterministic when several files match.
        """
        base = self._get_checkpoint_base(epoch)

        index_files = sorted(glob.glob(base + '-*.ckpt.index'))
        if index_files:
            # Prune off .index suffix.
            return index_files[0][:-len('.index')]

        # See if we find a candidate in old format.
        legacy_files = sorted(glob.glob(base + '-*.ckpt'))
        if legacy_files:
            return legacy_files[0]

        return ""

    def is_validation_decreasing(self):
        """Return if validation is still decreasing across epochs.

        True when no recorded validation loss exceeds the minimum seen in an
        earlier epoch (equal values count as still decreasing). Epochs with
        no loss summary file are skipped, as in best_checkpoint_filename.
        """
        min_val = float('inf')
        for epoch in range(self.get_next_epoch()):
            val_loss = self._read_val_loss(epoch)
            if val_loss is None:
                continue
            if val_loss > min_val:
                return False
            min_val = val_loss
        return True

    def loss_summary_filename(self, epoch):
        """Get path to loss summary file for provided epoch."""
        return self.out_dir + '/loss-summary-{:02d}.json'.format(epoch)

    def epoch_detail_filename(self, epoch):
        """Get path to epoch detail file for provided epoch."""
        return self.out_dir + '/epoch-detail-{:02d}.h5'.format(epoch)

    def _read_val_loss(self, epoch):
        """Return the 'val_loss' stored for ``epoch``, or None if no summary.

        Raises:
            KeyError: If the summary exists but lacks a 'val_loss' entry.
        """
        try:
            with open(self.loss_summary_filename(epoch), 'r') as f:
                loss_summary = json.load(f)
        except FileNotFoundError:
            return None
        return loss_summary['val_loss']

    def _get_checkpoint_base(self, epoch):
        """Return checkpoint filename with just epoch filled in.

        An ``epoch`` of -1 yields a glob wildcard that matches any epoch.
        """
        if epoch == -1:
            return self.out_dir + '/checkpoint-*'
        return self.out_dir + '/checkpoint-{epoch:02d}'.format(epoch=epoch)

    def _get_checkpoint_format(self):
        """Return the checkpoint path template.

        The template has ``{epoch}`` and ``{val_loss}`` placeholders, with
        the epoch zero-padded to two digits and the loss given to two
        decimal places.
        """
        return self.out_dir + '/checkpoint-{epoch:02d}-{val_loss:.2f}.ckpt'

    def _write_losses(self, losses, epoch):
        """Write epoch val and training loss to file as JSON."""
        with open(self.loss_summary_filename(epoch), 'w') as f:
            json.dump(losses, f)

    def _write_epoch_details(self, losses, times, next_epoch):
        """Write per-batch losses and timing to file.

        Creates HDF5 datasets 'loss' and 'time' in the epoch detail file.
        """
        filename = self.epoch_detail_filename(next_epoch)
        with h5py.File(filename, 'w') as f:
            f.create_dataset('loss', data=np.array(losses))
            f.create_dataset('time', data=np.array(times))
