from abc import ABCMeta

import numpy as np
import tensorflow as tf
import tree


class ValuePredictionTask(metaclass=ABCMeta):
    """Train a value-prediction algorithm on a fixed dataset and score it.

    The class itself only stores the algorithm, the list of criteria and the
    batch size. The following attributes are read by the methods below but
    are NOT set in ``__init__``; they have to be provided (e.g. by a
    subclass) before ``step`` / ``MSE`` are called:

    * ``self.dataset`` -- mapping with a ``'samples'`` entry that contains
      the arrays ``'state_0'``, ``'action'``, ``'state_1'``, ``'reward'``,
      ``'terminal'`` and optionally ``'rho'``.
    * ``self.true_value_states`` and ``self.true_values`` -- inputs and
      targets used by ``MSE``.
    """

    def __init__(self, algorithm, criteria, batch_size=128):
        """Store the configuration and reset the step counter.

        Args:
            algorithm: Object exposing ``update_V(...)`` and a ``V``
                attribute; kept as ``self.method``.
            criteria: Iterable of method names of this class (e.g.
                ``'MSE'``, ``'RMSE'``) that are evaluated after each call
                to ``step``.
            batch_size: Number of samples drawn per update.
        """
        self.method = algorithm
        self.criteria = criteria
        self.batch_size = batch_size

        self.total_steps = 0

    def step(self, num_steps=1, episodic=False):
        """Run ``num_steps`` batch updates and evaluate all criteria.

        Each update draws ``self.batch_size`` indices uniformly at random
        (with replacement) from the dataset and passes the selected rows,
        promoted to at least 2-D, to ``self.method.update_V``.

        Args:
            num_steps: Number of updates to run. ``0`` skips the updates
                and only evaluates the criteria.
            episodic: Currently unused; kept for interface compatibility.

        Returns:
            A dict with an ``'errors'`` entry mapping every criterion name
            to its value, merged with the per-key mean (via
            ``tf.reduce_mean``) of the diagnostics returned by
            ``update_V`` over the executed steps.

        Raises:
            ValueError: If ``num_steps`` is negative.
        """
        if num_steps < 0:
            # Reject before touching any state; otherwise ``total_steps``
            # would silently run backwards.
            raise ValueError(
                'num_steps must be non-negative, got {}'.format(num_steps))

        samples = self.dataset['samples']
        state_0s = samples['state_0']
        actions = samples['action']
        state_1s = samples['state_1']
        rewards = samples['reward']
        terminals = samples['terminal']
        # When no 'rho' entry is stored, hand NaNs of matching shape to the
        # algorithm instead.
        rhos = samples.get('rho', np.full(terminals.shape, np.nan))

        N = state_0s.shape[0]

        method = self.method

        steps_diagnostics = []
        for _ in range(num_steps):
            batch_indices = np.random.randint(N, size=self.batch_size)
            step_diagnostics = method.update_V(
                state_0s=np.atleast_2d(state_0s[batch_indices]),
                actions=np.atleast_2d(actions[batch_indices]),
                state_1s=np.atleast_2d(state_1s[batch_indices]),
                rewards=np.atleast_2d(rewards[batch_indices]),
                terminals=np.atleast_2d(terminals[batch_indices]),
                rhos=np.atleast_2d(rhos[batch_indices]))
            steps_diagnostics.append(step_diagnostics)

        self.total_steps += num_steps

        if steps_diagnostics:
            algorithm_diagnostics = tree.map_structure(
                lambda *x: tf.reduce_mean(x), *steps_diagnostics)
        else:
            # ``tree.map_structure`` needs at least one structure; with no
            # executed steps there is simply nothing to average.
            algorithm_diagnostics = {}

        errors = {}
        for criterion in self.criteria:
            error_fn = getattr(self, criterion)
            errors[criterion] = error_fn(method.V)

        return {'errors': errors, **algorithm_diagnostics}

    def MSE(self, value_function):
        """Return the mean squared error against ``self.true_values``.

        Args:
            value_function: Object whose ``values(states)`` method returns
                predictions for ``self.true_value_states``.

        Returns:
            The result of ``tf.losses.MeanSquaredError`` applied to the
            predictions and ``self.true_values``. The shapes of both are
            asserted to be equal first.
        """
        predictions = value_function.values(self.true_value_states)
        true_values = self.true_values
        tf.debugging.assert_equal(tf.shape(predictions), tf.shape(true_values))
        MSE = tf.losses.MeanSquaredError(tf.losses.Reduction.AUTO)(
            y_pred=predictions, y_true=true_values)
        return MSE

    # The criteria below are placeholders. They accept arbitrary arguments
    # because ``step`` calls every criterion as ``criterion(method.V)``;
    # a bare ``(self)`` signature would fail with ``TypeError`` before the
    # intended ``NotImplementedError`` could be raised.

    def MSPBE(self, *args, **kwargs):
        """Not implemented here; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def MSBE(self, *args, **kwargs):
        """Not implemented here; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def MSPBE_tar(self, *args, **kwargs):
        """Not implemented here; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def MSBE_tar(self, *args, **kwargs):
        """Not implemented here; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def RMSE(self, *args, **kwargs):
        """Return the square root of ``MSE`` for the same arguments."""
        return np.sqrt(self.MSE(*args, **kwargs))

    def RMSPBE(self, *args, **kwargs):
        """Return the square root of ``MSPBE`` for the same arguments."""
        return np.sqrt(self.MSPBE(*args, **kwargs))

    def RMSBE(self, *args, **kwargs):
        """Return the square root of ``MSBE`` for the same arguments."""
        return np.sqrt(self.MSBE(*args, **kwargs))

    def RMSPBE_tar(self, *args, **kwargs):
        """Return the square root of ``MSPBE_tar`` for the same arguments."""
        return np.sqrt(self.MSPBE_tar(*args, **kwargs))

    def RMSBE_tar(self, *args, **kwargs):
        """Return the square root of ``MSBE_tar`` for the same arguments."""
        return np.sqrt(self.MSBE_tar(*args, **kwargs))
