import numpy as np
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import pyplot as plt


class PWA():
    """Continuous piecewise-linear two-state system with four conic regions.

    The plane is split by the lines x2 = x1 (normal ``hyp[0]``) and
    x2 = -x1 (normal ``hyp[1]``) into four wedges.  Region 1 contains the
    positive x1 axis; regions 2-4 follow counter-clockwise.  In region i the
    dynamics are::

        x_new = [A1[i] . x, A2[i] . x] + B u
        y     = C x

    ``A1`` (resp. ``A2``) stacks the per-region row vectors that produce the
    first (resp. second) component of the next state.  The rows for regions
    1 and 3 are solved from those of regions 2 and 4 so that the map is
    continuous across every region boundary.
    """

    def __init__(self):
        # Normal vectors of the two switching lines x2 = x1 and x2 = -x1.
        hyp = np.array([[-1, 1],
                        [1, 1]])

        # Direction vectors of those lines: orth[i] is orthogonal to hyp[i]
        # (defined by hand rather than computed).
        orth = np.array([[1, 1],
                         [-1, 1]])

        # First state component (x1_new): the linear maps of regions 2 and 4
        # are chosen freely; regions 1 and 3 are derived for continuity.
        a2 = -2 * np.array([-0.5, -0.2])
        a4 = -2 * np.array([0.7, 0.1])
        self.A1 = self._continuous_region_maps(orth, a2, a4)

        # Second state component (x2_new), built the same way.
        a2 = 1.0 * np.array([-0.6, -0.6])
        a4 = 1.0 * np.array([0.8, 0.2])
        self.A2 = self._continuous_region_maps(orth, a2, a4)

        self.hyp = hyp
        self.orth = orth

        self.B = np.array([[0], [1]])
        self.C = np.array([1, 0])

    @staticmethod
    def _continuous_region_maps(orth, a2, a4):
        """Return the (4, 2) stack ``[a1; a2; a3; a4]`` of per-region rows.

        ``a2`` and ``a4`` are given.  ``a1`` and ``a3`` are solved so that
        each region's row agrees with its neighbour's along the boundary
        direction they share:

        * region 1 meets region 2 along orth[0] and region 4 along orth[1];
        * region 3 meets region 2 along orth[1] and region 4 along orth[0].
        """
        a1 = np.linalg.solve(orth[[0, 1], :],
                             [np.inner(orth[0, :], a2),
                              np.inner(orth[1, :], a4)])
        a3 = np.linalg.solve(orth[[1, 0], :],
                             [np.inner(orth[1, :], a2),
                              np.inner(orth[0, :], a4)])
        return np.vstack([a1, a2, a3, a4])

    def _region(self, x):
        """Return the 0-based index (0..3) of the region containing ``x``.

        Points on a boundary go to the first region whose test they satisfy.
        Anything failing the first three tests (including NaN input) is
        assigned to the last region.
        """
        s0 = np.inner(self.hyp[0, :], x)
        s1 = np.inner(self.hyp[1, :], x)
        if s0 <= 0 and s1 >= 0:
            return 0
        if s0 >= 0 and s1 >= 0:
            return 1
        if s0 >= 0 and s1 <= 0:
            return 2
        return 3

    def state_transition(self, x, u):
        """Return the next state ``A(x) @ x + B @ u`` as a length-2 array.

        ``u`` may be a length-1 array or a scalar.
        """
        return self.get_A_for_mode(x) @ x + self.B @ np.atleast_1d(u)

    def get_A_for_mode(self, x):
        """Return the 2x2 linear map active in the region containing ``x``."""
        r = self._region(x)
        return np.vstack([self.A1[r, :], self.A2[r, :]])

    def output_map(self, x):
        """Return the output ``C @ x`` (the first state component)."""
        return np.matmul(self.C, x)

    def simulate(self, x0, u, *, return_output=False):
        """Roll the system forward from ``x0`` under the input sequence ``u``.

        Args:
            x0: initial state, shape (nx,).
            u: input sequence, shape (T, nu); row t is applied at step t.
            return_output: if True, also return the output trajectory.

        Returns:
            The state trajectory ``x`` of shape (T + 1, nx), with ``x[0] == x0``.
            If ``return_output`` is True, returns ``(x, y)`` instead, where
            ``y`` has shape (T + 1, 1).
        """
        T = u.shape[0]
        nx = x0.shape[0]
        x = np.zeros((T + 1, nx))
        y = np.zeros((T + 1, 1))

        xt = x0
        x[0, :] = xt
        y[0, :] = self.output_map(xt)

        for t in range(T):
            xt = self.state_transition(xt, u[t, :])
            x[t + 1, :] = xt
            y[t + 1, :] = self.output_map(xt)

        if return_output:
            return x, y
        return x

    def state_transition_from_numpy(self, x, u):
        """Alias of :meth:`state_transition` for callers passing numpy arrays."""
        return self.state_transition(x, u)

    def visualize_net(self, xrange, yrange, dynamics, elevation, nu):
        """Plot the one-step map of ``dynamics`` over a grid as 3-D surfaces.

        For every grid point (xi, yi), ``dynamics.simulate`` is run from
        ``[xi, yi]`` with inputs ``nu``.  Each component of the state after
        the first step is drawn as a surface, once per view angle.

        Args:
            xrange, yrange: 1-D grids of x1 and x2 values.
            dynamics: object exposing ``simulate(x0, u)`` like this class.
            elevation: iterable of view angles.  Despite the name, each value
                is passed to ``view_init`` as the azimuth; the elevation is
                fixed at 30 degrees.
            nu: input sequence passed to ``dynamics.simulate``.

        Returns:
            ``(data, X, Y)``.  ``data`` has shape (4, nx * ny) with rows
            [x1, x2, x1_new, x2_new]; ``X, Y`` come from ``np.meshgrid``.
        """
        xrange = np.asarray(xrange)
        yrange = np.asarray(yrange)
        num_data = xrange.size * yrange.size

        X, Y = np.meshgrid(xrange, yrange)  # both shaped (ny, nx)

        data = np.zeros((4, num_data))
        colors = ['c', 'r']

        di = 0
        for xi in xrange:
            for yi in yrange:
                x = dynamics.simulate(np.array([xi, yi]), nu)
                data[:, di] = np.hstack([xi, yi, x[1, 0], x[1, 1]])
                di += 1

        for k in range(2):
            # Columns were written with y varying fastest (x outer loop), so
            # the row-major reshape is (nx, ny); transpose to (ny, nx) to
            # match the meshgrid layout expected by plot_surface.
            D = data[2 + k, :].reshape((xrange.size, yrange.size)).T

            for e in elevation:
                fig = plt.figure()
                ax = fig.add_subplot(projection='3d')
                ax.plot_surface(X, Y, D, color=colors[k])
                ax.view_init(30, e)
                ax.set_xlabel('x1')
                ax.set_ylabel('x2')
                ax.set_title('x ' + str(k + 1) + '_new')
                plt.show()

        return data, X, Y
