"""The base modified greedy local search algorithm."""

import dataclasses
import itertools
import random
from typing import Optional, Union

import numpy as np
import tensorflow as tf

from xoid import constants
from xoid.solvers import feasibility
from xoid.solvers import vertex_solvers

from xoid.util import basics
from xoid.util import misc_util
from xoid.util import network_util
from xoid.util import numerics
from xoid.util import vertex_util


# Short module-level aliases for project helpers used throughout this file.
_np = basics.to_np
about_equal = numerics.about_equal

# `_vtx_key` results are used below as dict keys and set members.
_vtx_key = vertex_util.to_vertex_key
inactive_examples_count = vertex_util.inactive_examples_count

# Alias for the network parameter container; not referenced in the code shown here.
ModelParams = network_util.ModelParams


class LevelSetExhaustedError(Exception):
    """Raised when no unexplored vertex is left in the current level set."""


@dataclasses.dataclass
class BaseMglsOptions:
    """Configuration for the base mGLS solver.

    Values are validated with assertions in `__post_init__`.
    """
    # Name of the loss; must be a member of `constants.LOSS_FNS`.
    loss_fn: str

    # How the starting vertex is chosen: 'random' or 'zeros'.
    initializer: str = 'random'
    # Output weights: 1.0, 'pm_1', or an explicit numpy array.
    output_weights_style: Union[str, float, np.ndarray] = 1.0

    # Neighbor-ordering switches read by the solver's neighbor iterators.
    order_vertices_by_inactive_examples: bool = True
    full_binding_constraint_flip_first: bool = True
    binding_constraint_vertices_have_priority: bool = True

    # Optional regularizer name (member of `constants.REGULARIZERS`) and its
    # constant; both are forwarded to the vertex solver.
    regularization: Optional[str] = None
    regularization_constant: float = 0.0

    # Upper bound on explored level-set vertices; None disables the bound.
    max_level_set_attempts: Optional[int] = None

    # Forwarded unchanged to the vertex solver.
    no_scs: Optional[bool] = False

    def __post_init__(self):
        """Check that the configured option values are supported."""
        assert self.initializer in {'random', 'zeros'}
        assert self.loss_fn in constants.LOSS_FNS

        regularizer = self.regularization
        if regularizer is not None:
            assert regularizer in constants.REGULARIZERS

        weights_style = self.output_weights_style
        if isinstance(weights_style, np.ndarray):
            # Explicit weight arrays are not validated here.
            return

        # TODO: Also support -1.0
        assert weights_style in {1.0, 'pm_1'}


class BaseMgls:
    """Base implementation of batch mGLS.

    Does not support:
        - Variable masks.
    """
    def __init__(self, X, Y, m, options: BaseMglsOptions, *, eps=1e-6):
        self.options = options

        self.X = _np(X)
        self.Y = _np(Y)

        self.m = m
        self.N, self.d = X.shape

        self.eps = eps
        self.dtype = self.X.dtype

        self.v: np.ndarray = self._initialize_v()
        assert self.v.shape == (m,)

        self.feasibility_checker = self._initialize_feasibility_checker()
        self.vertex_solver = self._initialize_vertex_solver()

        self._set_up_caches()
        self._initialize_state()

    #################################################################
    # Methods that could be overridden by subclasses.
    #################################################################

    def _initialize_v(self):
        output_weights_style = self.options.output_weights_style
        if isinstance(output_weights_style, np.ndarray):
            v = np.copy(output_weights_style).astype(self.dtype)
            assert v.shape == (self.m,)
        elif output_weights_style == 1.0:
            v = np.ones([self.m], dtype=self.dtype)
        elif output_weights_style == 'pm_1':
            v = misc_util.make_pm_1_v(self.m, self.dtype)
        else:
            raise ValueError(output_weights_style)
        return v

    def _initialize_feasibility_checker(self):
        return feasibility.FeasibilityChecker(self.X, eps=self.eps)

    def _initialize_vertex_solver(self):
        return vertex_solvers.VertexSolver(
            self.X, self.Y, self.options.loss_fn,
            m=self.m,
            v=self.v,
            regularization=self.options.regularization,
            regularization_constant=self.options.regularization_constant,
            eps=self.eps,
            no_scs=self.options.no_scs,
        )

    def _initialize_vertex(self):
        if self.options.initializer == 'zeros':
            return np.zeros([self.m, self.N], dtype=np.int32)

        elif self.options.initializer == 'random':
            layer = tf.keras.layers.Dense(self.m)
            return tf.cast(layer(self.X) > 0, tf.int32).numpy().T

        else:
            raise ValueError(self.options.initializer)
   
    #################################################################

    def _set_up_caches(self):
        self.vertex_to_loss = {}
        self.single_unit_pattern_to_feasibility = {}

    def _initialize_state(self):
        self.current_level_set = {}
        self.explored_level_set_vertices = set()
        self.current_vertex = self._initialize_vertex()
        self.current_vertex_results = None

    def _update_state(self, vertex, within_level_set, maybe_results):
        self.current_vertex = np.copy(vertex)
        self.current_vertex_results = maybe_results

        if not within_level_set:
            self.current_level_set = {}
            self.explored_level_set_vertices = set()

    #################################################################

    def _are_about_equal(self, a, b):
        return about_equal(a, b, self.eps)

    def _has_cross_entropy_loss(self):
        return self.options.loss_fn == 'sigmoid_cross_entropy'

    #################################################################

    def _is_pattern_feasible(self, ap, unit_index):
        # ap.shape = [N]
        ap_key = basics.to_deep_tuple(ap)
        if ap_key not in self.single_unit_pattern_to_feasibility:
            feas = self.feasibility_checker.is_feasible(_np(ap), unit_index)
            self.single_unit_pattern_to_feasibility[ap_key] = feas
        return self.single_unit_pattern_to_feasibility[ap_key]

    def _is_modified_pattern_feasible(self, vertex, modified_unit_inds):
        return all(
            self._is_pattern_feasible(vertex[i], i)
            for i in set(modified_unit_inds)
        )

    def _flip_boundary_constraints(self, vertex, unit_inds, ex_inds):
        flipped = vertex_util.flip_activations(vertex, unit_inds, ex_inds)
        if not self._is_modified_pattern_feasible(flipped, unit_inds):
            return None
        return flipped

    #################################################################

    def _uncached_solve_for_vertex(self, vertex):
        return self.vertex_solver.solve(vertex)

    def _get_loss(self, vertex):
        loss, _ = self._get_loss_and_maybe_results(vertex)
        return loss

    def _get_loss_and_maybe_results(self, vertex):
        vertex = _np(vertex)
        vtx_key = _vtx_key(vertex)
        
        if vtx_key not in self.vertex_to_loss:
            results = self._uncached_solve_for_vertex(vertex)
            loss = results.loss
            self.vertex_to_loss[vtx_key] = loss
        else:
            results = None

        return self.vertex_to_loss[vtx_key], results

    #################################################################

    def _iterate_over_binding_constraint_neighbors(self, vertex):
        assert (vertex == self.current_vertex).all()

        if self.current_vertex_results is not None:
            bcs = self.current_vertex_results.binding_constraints
            
        else:
            # TODO: Better way of seeing whether we done anything.
            if self.vertex_solver.w.value is None:
                return

            if not self.vertex_solver.is_current_vertex(vertex):
                return

            bcs = self.vertex_solver.get_binding_constraints()

        unit_inds, ex_inds = vertex_util.boundary_constraints_to_indices(bcs)

        if unit_inds.size and self.options.full_binding_constraint_flip_first:
            flipped = self._flip_boundary_constraints(vertex, unit_inds, ex_inds)
            if flipped is not None:
                yield flipped

        inds = list(zip(unit_inds, ex_inds))
        random.shuffle(inds)

        for i, j in inds:
            candidate = np.copy(vertex)
            candidate[i, j] = 1 - candidate[i, j]
            if self._is_pattern_feasible(candidate[i], i):
                yield candidate

    def _iterate_over_all_neighbors(self, vertex):
        vertex = _np(vertex)

        inds = list(itertools.product(range(self.m), range(self.N)))
        random.shuffle(inds)

        # TODO: Clean up this logic a bit.
        if self.options.order_vertices_by_inactive_examples:
            
            candidates = []
            for i, j in inds:
                candidate = np.copy(vertex)
                candidate[i, j] = 1 - candidate[i, j]
                candidates.append((i, candidate))
            # The max 1 is there on purpose. The idea is that we can perfectly fit only if
            # exactly 0 or 1 examples are entirely inactive, so we treat those the same.
            candidates = sorted(
                candidates,
                key=lambda v: max(inactive_examples_count(v[1]), 1),
            )

            for i, candidate in candidates:
                if self._is_pattern_feasible(candidate[i], i):
                    yield candidate

        else:
            for i, j in inds:
                candidate = np.copy(vertex)
                candidate[i, j] = 1 - candidate[i, j]
                if self._is_pattern_feasible(candidate[i], i):
                    yield candidate

    def _neighbors_to_try_iterator(self, vertex):
        if self.options.binding_constraint_vertices_have_priority:
            yield from self._iterate_over_binding_constraint_neighbors(vertex)

        yield from self._iterate_over_all_neighbors(vertex)

    #################################################################

    def _inner_get_next_vertex_in_level_set(self):
        # TODO: Make this a lot cleaner.
        # TODO: This is preliminary and can lead to loops and errors.
        current_loss = self._get_loss(self.current_vertex)

        if self.current_vertex_results is not None:
            results = self.current_vertex_results
        else:
            results = self._uncached_solve_for_vertex(self.current_vertex)
        
        bcs = results.binding_constraints
        unit_inds, ex_inds = vertex_util.boundary_constraints_to_indices(bcs)

        # If we are at the boundary of linear region, move to the linear
        # region we are on the boundary of.
        if unit_inds.size:
            flipped = self._flip_boundary_constraints(
                self.current_vertex, unit_inds, ex_inds)
            if flipped is not None:
                maybe_loss = self.vertex_to_loss.get(_vtx_key(flipped), None)
                if (
                    maybe_loss is None
                    or self._are_about_equal(current_loss, maybe_loss)
                    or maybe_loss < current_loss
                ):
                    return flipped

        return None

    def _get_next_vertex_in_level_set(self):
        vtx = self._inner_get_next_vertex_in_level_set()
        if vtx is not None and _vtx_key(vtx) not in self.explored_level_set_vertices:
            return vtx, None
        else:
            for key, (vertex, maybe_results) in self.current_level_set.items():
                if key not in self.explored_level_set_vertices:
                    return vertex, maybe_results
        raise LevelSetExhaustedError('Ran out of vertices in the level set.')

    #################################################################

    def _solve_vertex_iter(self, vertex_iter):
        for vertex in vertex_iter:
            loss, maybe_results = self._get_loss_and_maybe_results(vertex)
            yield vertex, loss, maybe_results

    #################################################################

    def _find_next_vertex(self):
        # Returns (next_vertex, is_next_vertex_within_current_level_set)
        current_loss = self._get_loss(self.current_vertex)

        if self._are_about_equal(current_loss, 0.0):
            return None, None, None

        neighbors_iter = self._neighbors_to_try_iterator(self.current_vertex)
        for neighbor, neighboring_loss, maybe_results in self._solve_vertex_iter(neighbors_iter):
            if self._are_about_equal(current_loss, neighboring_loss):
                self.current_level_set[_vtx_key(neighbor)] = (neighbor, maybe_results)

            elif neighboring_loss < current_loss:
                return neighbor, False, maybe_results

        self.explored_level_set_vertices.add(_vtx_key(self.current_vertex))

        # Now we are on a non-escaping node of a level set.
        print('On non-escaping node of a level set.')

        max_ls_attempts = self.options.max_level_set_attempts
        n_ls_attempts = len(self.explored_level_set_vertices)
        if max_ls_attempts is not None and n_ls_attempts > max_ls_attempts:
            return None, None, None

        neighbor, maybe_results = self._get_next_vertex_in_level_set()
        key = _vtx_key(neighbor)

        if key not in self.explored_level_set_vertices:
            return neighbor, True, maybe_results

        # TODO: Figure out the correct behavior for this. We might
        # end up also never reaching this code, so it might not matter.
        raise ValueError

    def solve_iter(self, max_iters=250):
        for i in range(max_iters):
            vertex, within_level_set, maybe_results = self._find_next_vertex()
            if vertex is None:
                break
            else:
                self._update_state(vertex, within_level_set, maybe_results)

            yield i

    def solve(self, max_iters=250):
        for i in self.solve_iter(max_iters):
            print(f'{i}: {self._get_loss(self.current_vertex)}')

    #################################################################
    
    def get_current_vertex_results(self):
        if self.current_vertex_results is None:
            self.current_vertex_results = self._uncached_solve_for_vertex(
                self.current_vertex)
        return self.current_vertex_results

    def get_network_output_at_current_vertex(self, X):
        results = self.get_current_vertex_results()
        p = results.model_params
        v = self.v
        preacts = np.einsum('nd,dm->nm', X, p.w) + p.b
        acts = np.maximum(preacts, 0.0)
        return np.einsum('nm,m->n', acts, v) + p.c

    def get_network_at_current_vertex(self):
        # TODO: Cache the param values somewhere so we don't have to recompute them!
        # TODO: Ensure masking is handled properly.
        results = self.get_current_vertex_results()
        p = results.model_params
        v = self.v
        return network_util.network_from_parameters(p.w, p.b, v, p.c)

    def reset_vertex_from_weights(self, w, b):
        preacts = np.einsum('nd,dm->nm', self.X, w) + b
        vertex = (preacts.T > 0).astype(np.int32)
        self._update_state(vertex, within_level_set=False, maybe_results=None)
