import tensorflow as tf
import numpy as np


class SimpleTrainer:
    """Build loss, training and accuracy ops for a classifier graph (TF1 API).

    Calling an instance with ``logits`` and ``y_batch`` builds three ops:

    - a scalar loss op,
    - an optimizer ``minimize`` op,
    - an accuracy op.

    Retrieve them afterwards with :meth:`getLossOp`, :meth:`getTrainOp` and
    :meth:`getAccuracyOp`. Each getter returns ``None`` until the instance
    has been called.
    """

    def __init__(self, loss='x-entropy'):
        """Create a trainer configured for one loss type.

        Args:
            loss: ``'x-entropy'`` for softmax cross-entropy, or ``'l2'`` for
                mean squared error between labels and logits.

        Raises:
            AssertionError: if ``loss`` is not one of the supported names.
        """
        self.XENTROPY = 'x-entropy'
        self.L2 = 'l2'
        self.loss = loss
        assert loss in [self.XENTROPY, self.L2]
        # Populated by __call__; None until the graph ops are built.
        self.__lossOp = None
        self.__trainOp = None
        self.__accuracyOp = None

    def getTrainOp(self):
        """Return the optimizer ``minimize`` op, or None if not built yet."""
        return self.__trainOp

    def getLossOp(self):
        """Return the scalar loss op, or None if not built yet."""
        return self.__lossOp

    def getAccuracyOp(self):
        """Return the accuracy op, or None if not built yet."""
        return self.__accuracyOp

    def __loss(self, logits, y_batch):
        """Build the mean loss op for ``logits`` vs ``y_batch``.

        The op is stored on the instance and also returned.
        """
        # Block gradients into the labels: softmax_cross_entropy_with_logits_v2
        # would otherwise backpropagate into its `labels` argument as well.
        labels = tf.stop_gradient(y_batch)
        if self.loss == self.XENTROPY:
            loss = tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits,
                                                             labels=labels)
        else:
            loss = tf.losses.mean_squared_error(labels=labels,
                                                predictions=logits)
            print("Using L2 loss")
        # Average over the per-example losses. This should be a no-op if the
        # loss is already reduced to a scalar (the MSE branch, by default) —
        # confirm against the TF version in use.
        self.__lossOp = tf.reduce_mean(loss)
        return self.__lossOp

    def __train(self, lossOp, optimizer, learningRate, momentum, varList):
        """Create and store the optimizer op that minimizes ``lossOp``.

        Args:
            lossOp: scalar loss tensor to minimize.
            optimizer: ``'ADAM'`` or ``'Nesterov'``.
            learningRate: optimizer learning rate.
            momentum: momentum coefficient; only used by ``'Nesterov'``.
            varList: variables to update. None lets the optimizer use its
                default variable set.

        Raises:
            AssertionError: if ``optimizer`` is not a supported name.
            ValueError: same condition, when asserts are stripped (``-O``).
        """
        assert optimizer in ['ADAM', 'Nesterov']
        # Compare with ==, not `is`: identity of equal strings depends on
        # interning. A runtime-built 'ADAM' passes the assert above but would
        # match neither `is` branch.
        if optimizer == 'Nesterov':
            opt = tf.train.MomentumOptimizer(learning_rate=learningRate,
                                             use_nesterov=True,
                                             momentum=momentum)
        elif optimizer == 'ADAM':
            opt = tf.train.AdamOptimizer(learning_rate=learningRate)
        else:
            # Only reachable when the assert above has been stripped (-O).
            raise ValueError("Unsupported optimizer: %r" % (optimizer,))

        self.__trainOp = opt.minimize(lossOp, var_list=varList)

    def __metric(self, logits, y_batch):
        """Build and store the accuracy op.

        Accuracy is the fraction of rows where the argmax of ``logits``
        equals the argmax of ``y_batch``.
        """
        # Evaluate model (with test logits, for dropout to be disabled).
        # NOTE(review): argmax over axis 1 assumes (batch, classes) inputs
        # with one-hot / probability labels — confirm against callers.
        correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(y_batch, 1))
        self.__accuracyOp = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    def __call__(self, logits, y_batch, optimizer='Nesterov',
                 learningRate=0.001, momentum=0.9, lossFunc=None,
                 varList=None):
        """Build the loss, training and accuracy ops.

        Args:
            logits: model output tensor.
            y_batch: label tensor, compared against ``logits``.
            optimizer: ``'Nesterov'`` (momentum with Nesterov) or ``'ADAM'``.
            learningRate: optimizer learning rate.
            momentum: momentum coefficient for ``'Nesterov'``.
            lossFunc: optional callable that receives the base loss op and
                returns the op actually minimized (e.g. to add regularizers).
                Its result replaces the stored loss op.
            varList: variables to train, or None for the optimizer default.
        """
        lossOp = self.__loss(logits, y_batch)
        if lossFunc is not None:
            lossOp = lossFunc(lossOp)
            self.__lossOp = lossOp
        self.__train(lossOp, optimizer, learningRate, momentum, varList)
        self.__metric(logits, y_batch)


