"""Code for using a model with a generator."""
import collections as col
import timeit

import numpy as np


# Immutable record returned by ModelRunner.next() for every model step.
BatchInfo = col.namedtuple(
    'BatchInfo',
    (
        'pred',               # prediction output returned by the model
        'example',            # the gridded pairs returned alongside it
        'acc',                # accuracy reported by the model
        'loss',               # loss reported by the model
        'summary',            # summary object reported by the model
        'run_metadata',       # run metadata; None for inference steps
        'elapsed_generator',  # time attributed to the input generator
        'elapsed_learning',   # elapsed seconds of the learning step
        'check_sum',          # sum of the example's 'grid' entry
    ),
)


class ModelRunner:
    """A combination of a model and a generator.

    Iterating over a ``ModelRunner`` performs one model step per item: a
    training step (``model.train``) when ``training`` is true, otherwise an
    inference step (``model.infer_handle``). Each step is reported as a
    ``BatchInfo``.

    NOTE(review): ``next`` never raises ``StopIteration`` itself, so a plain
    ``for`` loop over this object only ends if the model call raises.
    Presumably the input pipeline signals exhaustion that way; confirm
    with callers.
    """

    def __init__(self, model, generator, training, use_handle=True):
        """Bind a model to a data generator.

        Args:
            model: object exposing ``sess.run``, ``train(handle)`` and
                ``infer_handle(handle=...)``.
            generator: object exposing ``string_handle()``. It is only
                consulted here when ``use_handle`` is true.
            training: if true, each step calls ``model.train``; otherwise
                each step calls ``model.infer_handle``.
            use_handle: if true, ``generator.string_handle()`` is evaluated
                once through ``model.sess.run`` and the result is passed to
                every model call. If false, ``None`` is passed instead.
        """
        self.training = training
        self.model = model
        self.generator = generator
        if use_handle:
            # Evaluated once up front; every step reuses the same value.
            self.handle = self.model.sess.run(generator.string_handle())
        else:
            self.handle = None

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def next(self):
        """Run one model step and return its results.

        Returns:
            BatchInfo: ``pred``, ``acc``, ``loss`` and ``summary`` are taken
            from the model call; ``example`` is the ``gridded_pairs`` the
            model returned; ``run_metadata`` is ``None`` for inference steps;
            ``elapsed_generator`` is always 0; ``elapsed_learning`` is the
            duration of the model call measured with
            ``timeit.default_timer``; ``check_sum`` is
            ``np.sum(gridded_pairs['grid'])``.
        """
        learning_time_start = timeit.default_timer()
        if self.training:
            pred_out, acc, loss, summary, run_metadata, gridded_pairs = (
                self.model.train(self.handle))
        else:
            pred_out, acc, loss, summary, gridded_pairs = (
                self.model.infer_handle(handle=self.handle))
            # The inference path returns no run metadata.
            run_metadata = None
        # Stop the clock before computing the checksum so elapsed_learning
        # covers only the model call, not this post-processing.
        elapsed_learning = timeit.default_timer() - learning_time_start

        check_sum = np.sum(gridded_pairs['grid'])
        # Keyword arguments keep the nine fields from silently shifting
        # position if BatchInfo's field order ever changes.
        return BatchInfo(
            pred=pred_out,
            example=gridded_pairs,
            acc=acc,
            loss=loss,
            summary=summary,
            run_metadata=run_metadata,
            # NOTE(review): generator time is not measured here; always 0.
            elapsed_generator=0,
            elapsed_learning=elapsed_learning,
            check_sum=check_sum)
