import fix_imports
import tqdm
import argparse
import time
import numpy as np
import torch

import matplotlib
from sys import platform as sys_pf
if sys_pf == 'darwin':
    matplotlib.use("TkAgg")
import matplotlib.pyplot as plt
import matplotlib.style as style
import matplotlib.collections as mc
from model import model_from_state, make_model
import seaborn as sns


def sample_vectors(state, scale=1.0):
    """Build one line segment per training input in the (a, b) parameter plane.

    For every input x_i (taken from ``state['uv']`` and sorted ascending) the
    segment runs from (1, -x_i) to (-1, x_i), i.e. it lies on the line
    b = -x_i * a through the origin: the set of (a, b) with a * x_i + b = 0.

    Args:
        state: mapping whose ``'uv'`` entry holds the input samples. Any
            array shape is accepted and flattened.
        scale: factor applied to every endpoint coordinate.

    Returns:
        Array of shape (n, 2, 2) where ``lines[i, j]`` is endpoint j of
        segment i as an (a, b) pair (the layout ``LineCollection`` expects).
    """
    # ravel (not squeeze) keeps a 1-D array even for a single sample; squeeze
    # would produce a 0-d array on which np.sort and len() fail.
    x = np.sort(np.ravel(state['uv']))

    lines = np.zeros([len(x), 2, 2])
    # Endpoint 0 at a = +1, endpoint 1 at a = -1, both on b = -x * a.
    lines[:, 0, 0] = 1.0
    lines[:, 0, 1] = -x
    lines[:, 1, 0] = -1.0
    lines[:, 1, 1] = x

    return lines * scale


def plot_bg(lines, ax, alpha, s=1.0):
    """Shade the square [-s, s] x [-s, s] and the regions between adjacent lines.

    The whole square is filled first. Then, for each pair of neighbouring
    segments in ``lines``, the region between them is filled over the
    horizontal range [-s, s].

    Args:
        lines: array laid out like the output of ``sample_vectors``. For row
            i, ``[i, 1, 1]`` is used as the height at -s and ``[i, 0, 1]`` as
            the height at +s.
        ax: matplotlib Axes receiving the ``fill_between`` calls.
        alpha: transparency passed to every fill.
        s: half-width of the shaded square.
    """
    span = [-s, s]
    top, bottom = [s, s], [-s, -s]
    ax.fill_between(span, top, bottom, alpha=alpha)

    for current, following in zip(lines[:-1], lines[1:]):
        edge_lo = [current[1, 1], current[0, 1]]
        edge_hi = [following[1, 1], following[0, 1]]
        ax.fill_between(span, edge_lo, edge_hi, alpha=alpha)


def least_squares_ab(state, num_hidden=50000):
    """Fit output weights for a freshly initialised ReLU layer; return scaled (a, b).

    A new model is built with ``make_model(1, 1, num_hidden)``. Its first
    layer provides weights ``a`` and biases ``b``. The ReLU features
    max(a * x + b, 0) of the inputs ``state['uv']`` are used to solve the
    least-squares problem ``features @ c ~= state['x']`` for output weights c.

    Args:
        state: mapping with ``'uv'`` (inputs) and ``'x'`` (targets). Both are
            converted with ``.astype``, so they are expected to be numpy arrays.
        num_hidden: number of hidden units passed to ``make_model``. The
            default is the value that was previously hard-coded.

    Returns:
        Tuple ``(ac, bc)`` of 1-D float32 arrays holding c*a and c*b for the
        units with non-zero c*a, ordered by the angle ``arctan2(ac, bc)``.
    """
    x = torch.from_numpy(state['uv'].astype(np.float32))
    # reshape(-1) instead of squeeze() so a single target stays 1-D for lstsq.
    y = state['x'].astype(np.float32).reshape(-1)

    model = make_model(1, 1, num_hidden)
    a = model[0].weight
    b = model[0].bias

    # The features are only evaluated, never differentiated.
    with torch.no_grad():
        # assumes a is (num_hidden, 1) and x is (num_samples, 1) -- TODO confirm
        features = torch.clamp(a @ x.transpose(0, 1) + b.unsqueeze(1), min=0)
    features = features.cpu().numpy().transpose()

    c = np.linalg.lstsq(features, y, rcond=-1)[0].astype(np.float32)

    a_np = a.detach().cpu().numpy().squeeze()
    b_np = b.detach().cpu().numpy().squeeze()
    c_np = np.squeeze(c)
    ac, bc = c_np * a_np, c_np * b_np

    keep = np.nonzero(ac)
    ac, bc = ac[keep], bc[keep]

    # Scaling both coordinates by the same positive norm leaves arctan2
    # unchanged, so the angles are taken directly from (ac, bc).
    order = np.argsort(np.arctan2(ac, bc))
    return ac[order], bc[order]


def _parse_args():
    """Parse the command-line arguments for the trajectory plot."""
    argparser = argparse.ArgumentParser()
    argparser.add_argument("state_file", type=str)
    # NOTE(review): --save, --plot-residual, --samples and --print-every are
    # accepted but not used by main(); kept so existing invocations still work.
    argparser.add_argument("--save", type=str, default="", help="Save as a video")
    argparser.add_argument("--plot-residual", action="store_true", help="Plot the residual")
    argparser.add_argument("--samples", type=int, default=500, help="Number of curve samples")
    argparser.add_argument("--print-every", type=int, default=100, help="Print a message every k epochs")
    argparser.add_argument("--num-trajectories", "-nt", type=int, default=10, help="Number of trajectory lines to plot")
    argparser.add_argument("-o", "--output", type=str, default="",
                           help="Filename to save the figure to. Will display the figure if no file name is set")
    argparser.add_argument("-s", "--scale", type=float, default=1.0, help="Scale factor for the figure")
    return argparser.parse_args()


def _use_seaborn_style():
    """Apply the bundled seaborn matplotlib style under whichever name exists.

    Matplotlib 3.6 renamed the 'seaborn' style to 'seaborn-v0_8' and 3.8
    removed the old name, so a bare ``style.use('seaborn')`` fails there.
    """
    for name in ("seaborn", "seaborn-v0_8"):
        if name in style.available:
            style.use(name)
            return


def _load_state(path):
    """Load a saved training-state file with ``torch.load``.

    NOTE(review): the state presumably contains numpy arrays (``state['uv']``
    is used as one). The ``weights_only=True`` default of torch >= 2.6 refuses
    to unpickle those, so full unpickling is requested when the installed
    torch supports the flag. Full unpickling can execute arbitrary code, so
    only load trusted files.
    """
    import inspect

    if "weights_only" in inspect.signature(torch.load).parameters:
        return torch.load(path, weights_only=False)
    return torch.load(path)


def _select_trajectory_neurons(first_layer, min_x, max_x, num_trajectories):
    """Randomly pick distinct hidden units whose initial knot lies in (min_x, max_x).

    The knot of unit i is -b_i / a_i, using the '0.weight' (a) and '0.bias' (b)
    entries of ``first_layer``. If fewer in-range units exist than requested,
    sampling falls back to all units.

    Returns:
        Integer index array of at most ``num_trajectories`` distinct units.
    """
    a = first_layer['0.weight']
    b = first_layer['0.bias']
    knots = (-b.squeeze() / a.squeeze()).detach().cpu().numpy().reshape(-1)
    # Non-finite knots (a == 0) fail the strict comparisons and are excluded.
    candidates = np.flatnonzero((min_x < knots) & (knots < max_x))
    if num_trajectories > len(candidates):
        candidates = np.arange(len(knots))
    count = max(0, min(num_trajectories, len(candidates)))
    # Without replacement, so every plotted trajectory belongs to a different unit.
    return np.random.choice(candidates, count, replace=False)


def _compute_knot_trajectories(state, model, num_states, num_hidden):
    """Return (num_states, num_hidden) arrays of a*|c| and b*|c| per checkpoint.

    For checkpoint i, ``a``/``b``/``c`` are the '0.weight'/'0.bias'/'2.weight'
    entries of ``state['saved_states'][i][1]``.
    """
    knots_a = np.zeros([num_states, num_hidden])
    knots_b = np.zeros([num_states, num_hidden])

    for i in tqdm.tqdm(range(num_states)):
        params = state['saved_states'][i][1]
        # NOTE(review): the model is not used for plotting. Loading each
        # checkpoint presumably only checks it matches the architecture.
        model.load_state_dict(params)

        with torch.no_grad():
            c_abs = np.abs(params['2.weight'].squeeze().cpu().numpy())
            knots_a[i] = params['0.weight'].squeeze().cpu().numpy() * c_abs
            knots_b[i] = params['0.bias'].squeeze().cpu().numpy() * c_abs

    return knots_a, knots_b


def _plot_trajectories(state, knots_a, knots_b, trajectory_neurons, s, output):
    """Draw the input lines, shaded background and selected unit trajectories.

    Saves to ``output`` when it is non-empty; otherwise shows the figure.
    """
    if output:
        plt.figure(figsize=(10, 10), dpi=240)
    else:
        plt.figure()
    ax = plt.axes()
    ax.xaxis.set_ticklabels([])
    ax.yaxis.set_ticklabels([])

    cmap = sns.color_palette()
    lines = sample_vectors(state, scale=s)
    ax.axis([-s, s, -s, s])
    plot_bg(lines, ax, alpha=0.4, s=s)
    ax.add_collection(mc.LineCollection(lines, linewidths=1, label="$x_i$", color=cmap[1]))
    ax.plot(knots_a[:, trajectory_neurons], knots_b[:, trajectory_neurons], linewidth=2, color=cmap[2])
    # Mark each selected unit's position at the last saved checkpoint.
    ax.scatter(knots_a[-1, trajectory_neurons], knots_b[-1, trajectory_neurons], s=128, color=cmap[0])

    if output:
        plt.savefig(output, dpi=240)
    else:
        plt.show()


def main():
    """Plot how hidden-unit parameters (a*|c|, b*|c|) move over training.

    Loads the state file given on the command line and scales each first-layer
    weight and bias by the magnitude of its output weight at every saved
    checkpoint. It then plots the paths of a random subset of units on top of
    the lines b = -x_i * a induced by the training inputs x_i.
    """
    _use_seaborn_style()
    args = _parse_args()

    state = _load_state(args.state_file)
    num_states = len(state['saved_states'])
    num_hidden = state['num_hidden_units'][0]
    min_x, max_x = np.min(state['uv']), np.max(state['uv'])

    model = model_from_state(state)
    trajectory_neurons = _select_trajectory_neurons(
        state['saved_states'][0][1], min_x, max_x, args.num_trajectories)
    knots_a, knots_b = _compute_knot_trajectories(state, model, num_states, num_hidden)

    _plot_trajectories(state, knots_a, knots_b, trajectory_neurons, args.scale, args.output)


if __name__ == "__main__":
    # Run the plotting script only when executed directly, not when imported.
    main()
