"""
File to contain stuff about domain restrictions. We can typically speed things
up a good deal by using domain restrictions, but these are complicated. We
can also incorporate upper bounds into this object.

"""
import numpy as np
import numbers
import utilities as utils
import torch
import torch.nn as nn
import copy
from cvxopt import solvers, matrix


class Domain(object):
    """ Can support combinations of box + l2 bounds """
    def __init__(self, dimension, x):
        """ Creates a domain with no constraints attached yet.
        ARGS:
            dimension : size of the ambient space
            x : central point of the domain, or None when there is no
                center. A non-None value is converted with utils.as_numpy
                and flattened to one dimension
        """
        self.dimension = dimension
        self.x = None if x is None else utils.as_numpy(x).reshape(-1)

        # Constraint state, filled in later by the setter methods
        self.box_low = self.box_high = None
        self.l2_radius = self.linf_radius = None

        # The original box is tracked apart from the bounds produced by
        # upper bounds, together with the masks describing which of the
        # current bounds are still unmodified.
        self.original_box_low = self.original_box_high = None
        self.unmodified_bounds_low = self.unmodified_bounds_high = None

    def as_dict(self):
        """ Returns the state of this domain as a plain dictionary, keyed by
            attribute name. The stored values are the attributes themselves
            (no copies are made).
        """
        field_names = ('dimension', 'x',
                       'box_low', 'box_high',
                       'l2_radius', 'linf_radius',
                       'original_box_low', 'original_box_high',
                       'unmodified_bounds_low', 'unmodified_bounds_high')
        return {name: getattr(self, name) for name in field_names}

    @classmethod
    def from_dict(cls, saved_dict):
        """ Rebuilds a domain from a dictionary shaped like the output of
            as_dict.
        ARGS:
            saved_dict : dict holding 'dimension', 'x' and every bound
                         attribute name; a missing key raises KeyError
        RETURNS:
            new instance of cls with those attributes restored
        """
        restored = cls(saved_dict['dimension'], saved_dict['x'])

        bound_fields = ('box_low', 'box_high',
                        'l2_radius', 'linf_radius',
                        'original_box_low', 'original_box_high',
                        'unmodified_bounds_low', 'unmodified_bounds_high')
        for field in bound_fields:
            setattr(restored, field, saved_dict[field])

        return restored


    ###########################################################################
    #                                                                         #
    #                   FORWARD FACING METHODS                                #
    #                                                                         #
    ###########################################################################
    def set_original_hyperbox_bound(self, lo, hi):
        """ Sets the original hyperbox bounds which don't ever get modified
        ARGS:
            lo : real number - lower bound applied to every coordinate
            hi : real number - upper bound applied to every coordinate
        RETURNS:
            None
        """

        # Standard hyperbox setup
        lo = self._number_to_arr(lo)
        hi = self._number_to_arr(hi)
        self.set_hyperbox_bound(lo, hi)

        # Keep private copies so the recorded original box never shares
        # storage with the working box_low / box_high arrays.
        self.original_box_low = lo.copy()
        self.original_box_high = hi.copy()

        # Every bound starts out unmodified. The builtin bool is used as the
        # dtype because the np.bool alias no longer exists in current NumPy.
        self.unmodified_bounds_high = np.ones(self.dimension, dtype=bool)
        self.unmodified_bounds_low = np.ones(self.dimension, dtype=bool)



    def set_hyperbox_bound(self, lo, hi):
        """ Adds the box [lo, hi] to the domain by forwarding to
            _add_box_constraint; lo and hi may be real numbers or arrays.
        """
        self._add_box_constraint(lo=lo, hi=hi)

    def set_upper_bound(self, bound, lp_norm):
        """ Records an upper bound for the requested norm.
        ARGS:
            bound : the upper bound, handed to the norm-specific setter
            lp_norm : string - 'l_inf' or 'l_2'; any other key raises
                      KeyError
        """
        setter_name = {'l_inf': 'set_l_inf_upper_bound',
                       'l_2': 'set_l_2_upper_bound'}[lp_norm]
        setter = getattr(self, setter_name)
        setter(bound)


    def set_l_inf_upper_bound(self, bound):
        """ Tightens the l_inf radius around self.x and the matching box.
        ARGS:
            bound : real number or None - new upper bound on the l_inf
                    distance from self.x. None is a no-op, which mirrors
                    set_l_2_upper_bound
        RETURNS:
            None
        """
        if bound is None:
            return
        assert self.x is not None

        # Only ever shrink the stored radius
        if self.linf_radius is not None:
            self.linf_radius = min([bound, self.linf_radius])
        else:
            self.linf_radius = bound

        # An l_inf ball is itself a box, so fold it into the box bounds
        self._add_box_constraint(self.x - bound, self.x + bound)

    def set_l_2_upper_bound(self, bound):
        """ Tightens the l_2 radius around self.x.
        ARGS:
            bound : real number or None - new upper bound on the l_2
                    distance from self.x. None leaves the domain untouched
        RETURNS:
            None
        """
        if bound is None:
            return
        assert self.x is not None

        # The stored radius can only shrink
        if self.l2_radius is None:
            self.l2_radius = bound
        else:
            self.l2_radius = min([bound, self.l2_radius])

        # The l_2 ball sits inside the box of the same radius, so the box
        # bounds can be tightened too
        ball_lo, ball_hi = self.x - bound, self.x + bound
        self._add_box_constraint(ball_lo, ball_hi)


    def feasible_facets(self, A, b, indices_to_check=None):
        """ For the polytope Ax <= b, filters the constraints <a_i, x> = b_i
            down to those that might be feasible inside this domain.
        ARGS:
            A : numpy.ndarray (M x self.dimension) - constraint matrix
            b : numpy.ndarray (M) - constants
            indices_to_check : optional list of row indices (out of M) to
                               restrict the check to; None checks every row
        RETURNS:
            SET of row indices (numbered as in the full A) that pass both
            the l-inf box check and the l2 check. Membership does not
            guarantee feasibility, but every index left out is INFEASIBLE
        """
        sub_A, sub_b, to_original = self._idx_map_helper(A, b,
                                                         indices_to_check)

        survivors = self._linf_box_feasible_facets(sub_A, sub_b)
        survivors = survivors.intersection(
            self._l2_ball_feasible_facets(sub_A, sub_b))

        # Translate positions in the sub-system back to rows of A
        return {to_original(local_idx) for local_idx in survivors}


    def minimal_facet_projections(self, A, b, indices_to_check=None):
        """ For the polytope Ax <= b, keeps the constraints <a_i, x> = b_i
            whose minimal projection from self.x stays within the current
            l_2 and l_inf bounds. The l_inf radius is first refreshed from
            the box bounds.
        ARGS:
            A : numpy.ndarray (M x self.dimension) - constraint matrix
            b : numpy.ndarray (M) - constants
            indices_to_check : optional list of row indices (out of M) to
                               restrict the check to; None checks every row
        RETURNS:
            SET of row indices (numbered as in the full A) that are within
            both bounds. Membership does not guarantee feasibility, but
            every index left out is INFEASIBLE
        """
        self._compute_linf_radius()
        sub_A, sub_b, to_original = self._idx_map_helper(A, b,
                                                         indices_to_check)

        survivors = self._minimal_facet_projection_helper(sub_A, sub_b,
                                                          'l_inf')
        survivors = survivors.intersection(
            self._minimal_facet_projection_helper(sub_A, sub_b, 'l_2'))

        # Translate positions in the sub-system back to rows of A
        return {to_original(local_idx) for local_idx in survivors}



    def original_box_constraints(self):
        """ Returns two np arrays for the hyperplane constraints that are in
            both the original constraints and the hyperbox low/hi bounds
        RETURNS (A, b):
            - A stacks the rows of -I selected by unmodified_bounds_low
              above the rows of I selected by unmodified_bounds_high
            - b holds the matching entries of -box_low, then box_high
            - (None, None) when nothing can be reported: either a box side
              is missing, or its unmodified mask was never set
        """
        As, bs = [], []

        # A side whose mask is None is skipped: indexing an array with None
        # would insert a new axis instead of selecting rows.
        if self.box_low is not None and self.unmodified_bounds_low is not None:
            As.append(-1 * np.eye(self.dimension)[self.unmodified_bounds_low])
            bs.append(-1 * self.box_low[self.unmodified_bounds_low])
        if (self.box_high is not None and
                self.unmodified_bounds_high is not None):
            As.append(np.eye(self.dimension)[self.unmodified_bounds_high])
            bs.append(self.box_high[self.unmodified_bounds_high])

        if not As:
            return None, None
        return np.vstack(As), np.hstack(bs)


    def box_constraints(self):
        """ Expresses the current box bounds as a linear system Ax <= b.
        RETURNS (A, b):
            - with both bounds set, A is a (2N, N) numpy array and b is a
              (2N,) numpy array
            - (None, None) when neither bound is set

        VERTICAL ORDER IS ALWAYS LOWER_BOUNDS -> UPPER_BOUNDS
        """
        row_blocks, offset_blocks = [], []

        # x >= low   is written as   -x <= -low
        if self.box_low is not None:
            row_blocks.append(-1 * np.eye(self.dimension))
            offset_blocks.append(-1 * self.box_low)

        # x <= high  is already in the right form
        if self.box_high is not None:
            row_blocks.append(np.eye(self.dimension))
            offset_blocks.append(self.box_high)

        if not row_blocks:
            return None, None
        return np.vstack(row_blocks), np.hstack(offset_blocks)


    def nonredundant_box_constraints(self, A, b, tight_idx):
        """ Selects the box constraints that may intersect a tight facet of
            the system Ax <= b, where the tight_idx'th row is tight
            (i.e. <a_i, x> = b_i)

            This is just a fast check to see which of the box constraint
            hyperplanes don't intersect the hyperplane <a_i, x> = b_i
        ARGS:
            A : numpy.ndarray (M x self.dimension) - constraint matrix
            b : numpy.ndarray (M) - constants
            tight_idx : integer (python int or numpy integer) - row of A
                        that is tight
        RETURNS:
            (A, b) where A is an (M, n) array for M <= 2n
                     and b is an (M,) array
            These are the rows of self.box_constraints() that were kept.

        DERIVATION:
        Let a, b be the tight constraint (<a, x> = b)
        and let L, U be the box constraints (vectors of size n)

        Then a necessary (but not sufficient) condition for feasibility of
        this facet is that
            - <a+, U> + <a-, L> := b_u >= b  (and)
            - <a+, L> + <a-, U> := b_l <= b
        Now the hyperplane <a, x> = b intersects hyperplane (x_i = c) iff
            b_u - (a_i+ * u_i + a_i- * l_i) + a_i * c >= b  (and)
            b_l - (a_i+ * l_i + a_i- * u_i) + a_i * c <= b
        """

        # First check inputs/state makes sense
        assert self.box_low is not None
        assert self.box_high is not None
        # numbers.Integral also admits numpy integer scalars, which is what
        # index computations on arrays usually produce
        assert isinstance(tight_idx, numbers.Integral)
        a, b = A[tight_idx].reshape(-1), b[tight_idx]

        # Next separate a into its positive and negative components
        a_plus = np.maximum(a, 0)
        a_minus = a - a_plus

        # compute upper bounds and lower bounds for ALL indices
        a_plus_u = a_plus * self.box_high
        a_plus_l = a_plus * self.box_low
        a_minus_u = a_minus * self.box_high
        a_minus_l = a_minus * self.box_low

        b_u = np.sum(a_plus_u + a_minus_l)  # largest <a, x> over the box
        b_l = np.sum(a_plus_l + a_minus_u)  # smallest <a, x> over the box

        # compute upper/lower bounds as vectors lacking the i'th component
        b_u_lacking_i = b_u - a_plus_u - a_minus_l
        b_l_lacking_i = b_l - a_plus_l - a_minus_u

        # Feasibility with coordinate i pinned to its upper bound
        uppers_u = b_u_lacking_i + a * self.box_high >= b
        uppers_l = b_l_lacking_i + a * self.box_high <= b
        uppers = np.logical_and(uppers_u, uppers_l)

        # Feasibility with coordinate i pinned to its lower bound
        lowers_u = b_u_lacking_i + a * self.box_low >= b
        lowers_l = b_l_lacking_i + a * self.box_low <= b
        lowers = np.logical_and(lowers_u, lowers_l)

        # Boolean selector array is lower -> upper, matching the row order
        # of box_constraints()
        selector = np.hstack((lowers, uppers))

        box_constraint_A, box_constraint_b = self.box_constraints()

        return box_constraint_A[selector, :], box_constraint_b[selector]


    def box_to_tensor(self):
        """ Packs the box bounds into a torch tensor with one row per
            coordinate, laid out as [low, high]; useful for interval
            propagation. Returns None unless both bounds are set.
        """
        if self.box_low is None or self.box_high is None:
            return None

        low_column = self.box_low.reshape(-1)
        high_column = self.box_high.reshape(-1)
        return torch.Tensor(np.column_stack((low_column, high_column)))


    def current_upper_bound(self, lp_norm):
        """ Accessor for the radius currently stored for a norm.
        ARGS:
            lp_norm : string - 'l_2' or 'l_inf'; any other key raises
                      KeyError
        RETURNS:
            the stored radius, which is None if it was never set
        """
        radius_attr = {'l_2': 'l2_radius',
                       'l_inf': 'linf_radius'}[lp_norm]
        return getattr(self, radius_attr)


    def contains(self, y):
        """ Tests whether a point satisfies every constraint that is
            currently set on the domain.
        ARGS:
            y : numpy.ndarray - the point; it is flattened before checking
        RETURNS:
            bool - True iff y passes the box, l_inf and l_2 checks that are
                   active (a domain with no constraints contains everything)
        """
        assert isinstance(y, np.ndarray)
        point = y.reshape(-1)

        verdicts = []

        # Coordinate-wise box bounds
        if self.box_low is not None:
            verdicts.append(np.all(point >= self.box_low))
        if self.box_high is not None:
            verdicts.append(np.all(point <= self.box_high))

        # Distance from the center in the l_inf norm
        if self.linf_radius is not None:
            verdicts.append(np.abs(point - self.x).max() <= self.linf_radius)

        # Distance from the center in the l_2 norm
        if self.l2_radius is not None:
            l2_distance = np.linalg.norm(point - self.x, 2)
            verdicts.append(l2_distance <= self.l2_radius)

        return all(verdicts)


    ###########################################################################
    #                                                                         #
    #    FEASIBILITY / MINIMAL PROJECTION HELPERS                             #
    #                                                                         #
    ###########################################################################

    @classmethod
    def _idx_map_helper(cls, A, b, indices_to_check=None):
        """ Restricts the system (A, b) to a subset of its rows.
        ARGS:
            A : numpy.ndarray - constraint matrix
            b : numpy.ndarray - constants
            indices_to_check : optional iterable of row indices; None keeps
                               the whole system
        RETURNS:
            (A', b', map_fxn) where A', b' hold the selected rows in sorted
            index order and map_fxn turns a row position of A' back into the
            corresponding row index of A
        """
        if indices_to_check is None:
            def identity(i):
                return i
            return A, b, identity

        ordered = sorted(indices_to_check)
        position_to_row = dict(enumerate(ordered))
        return A[ordered, :], b[ordered], position_to_row.__getitem__


    def _linf_box_feasible_facets(self, A, b):
        """ Same args as self.feasible_facets.

            Row i is kept when b_i lies between the smallest and largest
            value <a_i, x> can take over the box. Without a full box every
            row is kept.
        RETURNS:
            set of row indices that pass the check
        """
        num_rows = A.shape[0]
        if self.box_low is None or self.box_high is None:
            return set(range(num_rows))

        positive_part = np.maximum(A, 0)
        negative_part = np.minimum(A, 0)

        # Extremes of <a_i, x> over the box, one entry per row
        largest = (positive_part.dot(self.box_high) +
                   negative_part.dot(self.box_low))
        smallest = (positive_part.dot(self.box_low) +
                    negative_part.dot(self.box_high))
        reachable = np.logical_and(largest >= b, smallest <= b)

        return set(np.flatnonzero(reachable).tolist())


    def _l2_ball_feasible_facets(self, A, b):
        """ Same args as self.feasible_facets.

            NOT IMPLEMENTED: no l2 filtering happens yet, so every row index
            of A is returned.
        """
        num_rows = A.shape[0]
        return set(range(num_rows))


    def _minimal_facet_projection_helper(self, A, b, lp):
        """ Keeps the rows of Ax <= b whose scaled slack at self.x is within
            the radius stored for the given norm.
        ARGS:
            A : numpy.ndarray - constraint matrix
            b : numpy.ndarray - constants
            lp : string - 'l_2' or 'l_inf'; any other key raises KeyError
        RETURNS:
            set of row indices i with
                (b_i - <a_i, self.x>) / ||a_i||_dual <= radius
            where the dual norm is the 2-norm for 'l_2' and the 1-norm for
            'l_inf'. Every row index is returned when no radius is set.
        """
        radius = {'l_2': self.l2_radius, 'l_inf': self.linf_radius}[lp]
        if radius is None:
            return set(range(A.shape[0]))

        # ord=None makes np.linalg.norm compute the 2-norm of each row
        dual_order = {'l_2': None, 'l_inf': 1}[lp]
        row_norms = np.linalg.norm(A, ord=dual_order, axis=1)

        slack = b - A.dot(self.x)
        within_radius = np.divide(slack, row_norms) <= radius
        return {idx for idx, ok in enumerate(within_radius) if ok}



    ###########################################################################
    #                                                                         #
    #                   HELPERS FOR BOX BOUNDS                                #
    #                                                                         #
    ###########################################################################



    def _number_to_arr(self, number_val):
        """ Converts a real number into an array of shape (self.dimension,)
            with every entry equal to that number.
        """
        assert isinstance(number_val, numbers.Real)
        template = np.ones(self.dimension)
        return template * number_val


    def _compute_new_lohi(self, lo_or_hi, vals):
        """ Merges candidate bounds into the current lows or highs.
        ARGS:
            lo_or_hi : string - 'lo' or 'hi'
            vals: np.array of shape (self.dimension,) - candidate bounds
        RETURNS:
            - array holding the elementwise max of vals and box_low (for
              'lo') or the elementwise min of vals and box_high (for 'hi')
            - boolean numpy array that is True where the current bound is
              within 1e-8 of the merged one, i.e. where it did not change
        """
        eps = 1e-8
        assert lo_or_hi in ['lo', 'hi']
        assert isinstance(vals, np.ndarray) and vals.shape == (self.dimension,)

        # Lower bounds can only move up, upper bounds can only move down
        merge, existing = {'lo': (np.maximum, self.box_low),
                           'hi': (np.minimum, self.box_high)}[lo_or_hi]
        merged = merge(vals, existing)
        is_unchanged = abs(existing - merged) < eps
        return merged, is_unchanged

    def _add_box_constraint(self, lo, hi):
        """ Intersects the current box with the box [lo, hi].
        ARGS:
            lo: float or np.array(self.dimension) - coordinate-wise
                lowerbounds
            hi: float or np.array(self.dimension) - coordinate-wise
                upperbounds
        RETURNS:
            None
        """

        # Promote scalar bounds to full arrays
        if isinstance(lo, numbers.Real):
            lo = self._number_to_arr(lo)
        if isinstance(hi, numbers.Real):
            hi = self._number_to_arr(hi)

        # A side that was never set simply takes the new values. Otherwise
        # it keeps the tighter bound per coordinate, and the coordinates
        # that moved are cleared (in place) from the unmodified mask.
        if self.box_low is None:
            self.box_low = lo
        else:
            self.box_low, low_kept = self._compute_new_lohi('lo', lo)
            np.logical_and(self.unmodified_bounds_low, low_kept,
                           out=self.unmodified_bounds_low)

        if self.box_high is None:
            self.box_high = hi
        else:
            self.box_high, high_kept = self._compute_new_lohi('hi', hi)
            np.logical_and(self.unmodified_bounds_high, high_kept,
                           out=self.unmodified_bounds_high)

    def _compute_linf_radius(self):
        """ Modifies the self.linf_radius based on the box bounds.

            The radius derived from the box is the largest coordinate-wise
            distance from self.x to either box face; self.linf_radius is
            only ever lowered to it, never raised.
        RETURNS:
            the updated self.linf_radius, or None (leaving the stored radius
            untouched) when a box side or the center self.x is missing
        """
        # Both box sides and a center are needed to derive a radius; any
        # missing piece would otherwise fail on None arithmetic below.
        if self.box_high is None or self.box_low is None or self.x is None:
            return None

        linf_radius = max([np.abs(self.box_high - self.x).max(),
                           np.abs(self.x - self.box_low).max()])
        if self.linf_radius is None:
            self.linf_radius = linf_radius
        else:
            self.linf_radius = min(linf_radius, self.linf_radius)
        return self.linf_radius

    ###########################################################################
    #                                                                         #
    #                   HELPERS FOR L2 DEAD CONSTRAINTS                       #
    #                                                                         #
    ###########################################################################


    def l2_bound_layer1(self, weight, bias, x):
        """ Given input weights and biases will compute if each neuron can be
            on or off, by bounding every pre-activation <weight[i], z> +
            bias[i] over the points z allowed by the box and the l2 ball of
            radius self.l2_radius around x.

        We do this using cvxopt conelp, but it's a messy formulation
        the G matrix looks like
        [ -I | I | 0row | -I]^T
        and the h matrix looks like
        [-lo_bounds, u_bounds, radius, -x]
        The first 2 * dimension rows are linear (box) constraints and the
        last dimension + 1 rows form one second-order cone.

        ARGS:
            weight : numpy array with one row of length self.dimension per
                     neuron
            bias : indexable with one entry per neuron
            x : numpy array of length self.dimension - center of the l2 ball
                (this argument is used, not self.x)
        RETURNS:
            (new_lows, new_highs) : two lists with one entry per neuron,
            built from the solver's 'primal objective' for weight[i] and
            for -weight[i] respectively, plus bias[i]
        """
        # Function-scope import: time is only needed for the timing printout
        import time

        G_list = [-1 * np.eye(self.dimension),
                  np.eye(self.dimension),
                  np.zeros((1, self.dimension)),
                  -1 * np.eye(self.dimension)]
        # Builtin float replaces the np.float alias, which no longer exists
        # in current NumPy
        G = matrix(np.vstack(G_list).astype(float))

        h_list = [-1 * self.box_low,
                  self.box_high,
                  np.array([self.l2_radius]),
                  -x]
        h = matrix(np.hstack(h_list).astype(float))

        dims = {'l': self.dimension * 2, 'q': [self.dimension + 1], 's': [0]}

        m = weight.shape[0]

        # Do lowers first: scale=1 minimizes <w, z>, scale=-1 minimizes
        # <-w, z>, which is undone by multiplying the objective by scale
        new_lows, new_highs = [], []
        for scale, working_list in [(1, new_lows), (-1, new_highs)]:
            for i in range(m):
                if (i % 10) == 0:
                    print(scale, i)
                c = matrix(weight[i].astype(float) * scale)
                start = time.time()
                # NOTE(review): confirm that conelp honors the `solver`
                # keyword; it may only apply to the lp/socp/sdp front ends.
                solve_out = solvers.conelp(c, G, h, dims, solver='mosek')
                print(i, time.time() - start)
                working_list.append(scale * solve_out['primal objective'] +
                                    bias[i])

        return new_lows, new_highs




