import numpy as np
import cvxpy as cp
import matplotlib.pyplot as plt
import os.path
import seaborn as sns
import os
import argparse
# from icecream import ic
def ic(*args, **kwargs):
    """No-op stand-in used in place of the commented-out ``icecream`` import.

    Accepts any positional/keyword arguments, ignores them and returns None,
    so the debugging calls sprinkled through this file cost nothing.
    """
    return None


################ plot settings #################################
import matplotlib as mpl
# Global matplotlib configuration, applied once at import time.
# Serif font preference; the mpl.rc('font', ...) call below sets 'serif' again.
mpl.rcParams['font.serif'] = ['times new roman']
# LaTeX preamble used for text rendering (usetex is switched on below).
mpl.rcParams['text.latex.preamble'] = r'\usepackage{newtxmath}'
# Scale the current line width (x2) and marker size (x1.5) in place.
mpl.rcParams['lines.linewidth'] *= 2
mpl.rcParams['lines.markersize'] *= 1.5
# Base font size; several plotting functions in this file call mpl.rc('font', ...)
# again with their own local size.
heatmap_font_size = 30 
mpl.rc('font',**{'family':'serif','serif':['Times'], 'size': heatmap_font_size})
mpl.rc('legend', **{'fontsize': 20})
# Render all text through LaTeX -- presumably requires a working LaTeX install; verify.
mpl.rc('text', usetex=True)

def eval_loss_p(Q, theta_star, eps, theta_1, sigma_w=0.1, q=2):
    """Evaluate the robust objective at ``theta_1`` via a cvxpy expression.

    Builds the same expression that ``robust`` minimises, assigns ``theta_1``
    to the cvxpy variable and returns the resulting numeric value.

    Args:
        Q: matrix applied (transposed) to ``theta - theta_star``.
        theta_star: reference parameter vector; its shape fixes the variable shape.
        eps: weight multiplying the ``q``-norm of ``theta``.
        theta_1: point at which the objective is evaluated.
        sigma_w: scalar stacked with the residual before taking the 2-norm.
        q: norm order passed to ``cp.norm`` for the ``theta`` penalty.
            The original body read an undefined name ``q`` (NameError on every
            call); it is now an explicit keyword whose default of 2 matches
            the default used by ``eval_loss`` and ``robust``.

    Returns:
        The value of the objective expression at ``theta_1``.
    """
    theta = cp.Variable(theta_star.shape)
    # 2-norm of the residual stacked with sigma_w, i.e. sqrt(||Q^T d||^2 + sigma_w^2).
    sigma_theta = cp.norm(cp.hstack([Q.T @ (theta - theta_star), sigma_w]))
    c_1 = np.sqrt(2/np.pi)
    c_2 = 1 - c_1 ** 2
    obj = c_2 * sigma_theta ** 2  + (eps * cp.norm(theta, q) + c_1 * sigma_theta) ** 2
    # Assigning a value to the variable lets cvxpy evaluate the expression numerically.
    theta.value = theta_1
    return obj.value

def my_norm(x, q):
    """Return the ``q``-norm of the vector ``x``.

    Args:
        x: array-like vector.
        q: norm order; one of ``2``, ``1``, the string ``'inf'`` or ``np.inf``.

    Returns:
        The Euclidean norm, the sum of absolute values, or the largest
        absolute value of ``x`` respectively.

    Raises:
        ValueError: if ``q`` is not a supported order. The previous version
            fell through and returned None, which only surfaced later as an
            unrelated TypeError in the caller's arithmetic.
    """
    if q == 2:
        return np.linalg.norm(x)
    if q == 1:
        return np.sum(np.abs(x))
    if q == 'inf' or q == np.inf:
        return np.max(np.abs(x))
    raise ValueError(f"unsupported norm order {q!r}; expected 2, 1, 'inf' or np.inf")
def eval_loss(Q, theta_star, eps, theta, q=2, sigma_w=0.1):
    """Evaluate the loss of ``theta`` numerically with numpy.

    ``sigma_theta`` is ``sqrt(||Q.T @ (theta - theta_star)||_2^2 + sigma_w^2)``.

    Args:
        Q: matrix applied (transposed) to the parameter error.
        theta_star: reference parameter vector.
        eps: weight on the ``q``-norm of ``theta``; ``None`` selects the
            plain-``sigma_theta`` return path.
        theta: parameter vector being evaluated.
        q: norm order forwarded to ``my_norm``.
        sigma_w: scalar added in quadrature to the residual norm.

    Returns:
        ``sigma_theta`` itself when ``eps is None`` (note: not squared);
        otherwise ``c_2 * sigma_theta**2 + (eps * ||theta||_q + c_1 * sigma_theta)**2``
        with ``c_1 = sqrt(2/pi)`` and ``c_2 = 1 - c_1**2``.
    """
    residual_norm = np.linalg.norm(Q.T @ (theta - theta_star))
    sigma_theta = np.sqrt(residual_norm**2 + sigma_w**2)
    if eps is None:
        return sigma_theta

    coef_lin = np.sqrt(2/np.pi)
    coef_quad = 1 - coef_lin ** 2
    attack_term = eps * my_norm(theta, q) + coef_lin * sigma_theta
    return coef_quad * sigma_theta ** 2  + attack_term ** 2

def robust(Q, theta_star, eps, q=2, return_value=False, sigma_w=0.1):
    """Minimise the robust objective over ``theta`` with cvxpy (SCS solver).

    The objective is
    ``c_2 * sigma_theta**2 + (eps * ||theta||_q + c_1 * sigma_theta)**2`` where
    ``sigma_theta`` is the 2-norm of ``Q.T @ (theta - theta_star)`` stacked
    with ``sigma_w``, ``c_1 = sqrt(2/pi)`` and ``c_2 = 1 - c_1**2``.
    The problem is unconstrained.

    Args:
        Q: matrix applied (transposed) to the parameter error.
        theta_star: reference parameter vector; fixes the variable shape.
        eps: weight on the ``q``-norm of ``theta``.
        q: norm order handed to ``cp.norm``.
        return_value: when true, also return the optimal objective value.
        sigma_w: scalar stacked with the residual.

    Returns:
        ``theta.value`` or, if ``return_value`` is set, the pair
        ``(theta.value, prob.value)``.
    """
    coef_lin = np.sqrt(2/np.pi)
    coef_quad = 1 - coef_lin ** 2

    theta = cp.Variable(theta_star.shape)
    sigma_theta = cp.norm(cp.hstack([Q.T @ (theta - theta_star), sigma_w]))
    objective = coef_quad * sigma_theta ** 2  + (eps * cp.norm(theta, q) + coef_lin * sigma_theta) ** 2

    prob = cp.Problem(cp.Minimize(objective), [])
    prob.solve(solver=cp.SCS)
    ic(prob.status)
    return (theta.value, prob.value) if return_value else theta.value

def get_norm_frac(theta, c):
    """Return the share of ``sum(theta**2)`` carried by entries ``theta[c:]``.

    Args:
        theta: numpy vector.
        c: index from which entries are counted in the numerator.

    Returns:
        ``sum(theta[c:]**2) / sum(theta**2)``.
    """
    squares = theta**2
    tail_energy = squares[c:].sum()
    total_energy = squares.sum()
    return tail_energy / total_energy

def gen_Q(p=5, c=2, corr_in=0.5, eta=0.1, scale=1):
    """Build the ``p x p`` factor ``Q`` and the matrix ``Sigma = Q @ Q.T``.

    Before row normalisation, with ``p=5, c=2, corr_in=0.5, eta=0.1``::

        1    0.5  0  0  0
        0.5  1    0  0  0
        0.1  0.1  1  0  0
        0.1  0.1  0  1  0
        0.1  0.1  0  0  1

    i.e. ones on the diagonal, ``corr_in`` off the diagonal inside the leading
    ``c x c`` block, and ``eta`` in the first ``c`` columns of the remaining
    rows. Each row is then divided by its Euclidean norm, after which rows
    ``c:`` are multiplied by ``scale``.

    The previous implementation scaled ``Q[-1]``, ``Q[-2]`` and ``Q[-3]``
    unconditionally. That is the same set of rows only when ``p - c == 3``
    (the ``p=5, c=2`` default); it raised IndexError for ``p < 3`` and scaled
    the wrong rows otherwise. Results for ``p - c == 3`` or ``scale == 1`` are
    unchanged.

    Args:
        p: matrix dimension.
        c: size of the leading block.
        corr_in: off-diagonal value inside the leading block.
        eta: value placed in the first ``c`` columns of rows ``c:``.
        scale: multiplier applied to rows ``c:`` after normalisation.

    Returns:
        Tuple ``(Sigma, Q)``.
    """
    Q = np.zeros((p, p))
    Q[:c, :c] = corr_in
    Q[c:, :c] = eta
    np.fill_diagonal(Q, 1)

    # Normalise row by row (each ``row`` is a view, so this edits Q in place).
    for row in Q:
        row /= np.linalg.norm(row)

    Q[c:, :] *= scale

    Sigma = Q @ Q.T
    return Sigma, Q

def plot_core_vs_total_acc(save=True):
    """Plot losses of the 'core' and 'total' solutions against epsilon.

    For each norm order ``q`` in ``{2, 1, 'inf'}`` and each epsilon on that
    order's grid, ``robust`` is solved twice: once with the full ``Q``
    ('total') and once with the Cholesky factor of the leading ``c x c`` block
    of ``Sigma`` ('core'). Both solutions are scored with ``eval_loss`` under
    three evaluation settings (``eps=None``, the training epsilon, and the
    mean of the grid).

    Args:
        save: when true, each plot is written as a JPEG under
            ``figures/core vs total loss/`` and the figure is cleared;
            otherwise all plots are placed on a 3x3 subplot grid and shown.

    Changes from the previous version: the output directory is created before
    saving (``savefig`` failed because nothing created it), the loop-invariant
    grid mean is computed once per ``q``, the file/title suffix comes from
    ``dual`` instead of a duplicate local table, and the title typo
    "feautres" is corrected.
    """
    p, c = 5, 2
    theta_star = np.zeros(p)
    theta_star[:c] = 1

    corr_in = 0.5
    corr_cr = 0.3

    Sigma, Q_total = gen_Q(p=p, c=c, corr_in=corr_in, eta=corr_cr)
    Q_core = np.linalg.cholesky(Sigma[:c, :c])
    dict_range_q = {
            2: np.linspace(0, 2, 100),
            1: np.linspace(0, 4, 200),
            'inf': np.linspace(0, 4, 100),
            }

    out_dir = 'figures/core vs total loss'
    if save:
        os.makedirs(out_dir, exist_ok=True)
    else:
        fig_i = 1
        fig = plt.figure()

    for q, li_eps in dict_range_q.items():
        loss = {'core': {}, 'total': {}}
        eps_base = np.mean(li_eps)
        for eps in li_eps:
            theta_total = robust(Q_total, theta_star, eps, q)
            theta_core = robust(Q_core, theta_star[:c], eps, q)
            dict_eps_eval = {
                    'Loss with epsilon=0': None,
                    'Loss with same epsilon as training': eps,
                    'Loss with epsilon=mean': eps_base,
                    }
            for eps_name, eps_eval in dict_eps_eval.items():
                loss_total = eval_loss(Q_total, theta_star, eps_eval, theta_total, q)
                loss_core = eval_loss(Q_core, theta_star[:c], eps_eval, theta_core, q)
                loss['core'].setdefault(eps_name, []).append(loss_core)
                loss['total'].setdefault(eps_name, []).append(loss_total)

        # Keys were inserted in evaluation-setting order, one list per setting.
        for eps_name in loss['total']:
            if not save:
                fig.add_subplot(3, 3, fig_i)
                fig_i += 1
            for str_core in ['core', 'total']:
                plt.plot(li_eps, loss[str_core][eps_name], marker='*', label=str_core)
            plt.xlabel('Epsilon')
            plt.ylabel(eps_name)
            plt.title(f"{eps_name} features, norm={dual(q)}")
            plt.legend()
            if save:
                plt.savefig(f'{out_dir}/{eps_name}, norm={dual(q)}.jpeg')
                plt.clf()
    if not save:
        plt.show()

def script_distributional():
    """Produce the three distribution-shift figures (q='inf', q=2, q=1 with scale=3).

    ``eta`` is passed as 0.3: ``plot_distributional`` used to discard its
    ``eta`` argument (then 0.25) and compute with a hard-coded 0.3, so passing
    0.3 keeps the computed curves the same while making the cache file name
    describe what was actually computed.
    """
    eta = 0.3
    runs = (
        ([1.3, 1.8], np.linspace(0, 2, 100), {'q': 'inf'}),
        ([1, 1.35], np.linspace(0, 2, 100), {'q': 2}),
        ([1, 1.15], np.linspace(0, 4, 200), {'q': 1, 'scale': 3}),
    )
    for li_eps, li_noise, kwargs in runs:
        plot_distributional(li_eps, li_noise, eta, plot=True, **kwargs)


def plot_distributional(li_eps, li_noise, eta, p=5, c=2, q='inf', scale=1, plot=False):
    """Compute (with caching) and optionally plot loss under random perturbation of ``Q``.

    For each training epsilon the 'total' solution (all ``p`` coordinates) and
    the 'core' solution (first ``c`` coordinates, zero elsewhere) are obtained
    from ``robust``. For each noise level, ``Q`` is perturbed 10000 times by
    Gaussian noise on rows ``c:`` (rows ``:c`` are left untouched) and
    ``eval_loss(..., eps=None, ...)`` is averaged over the draws.

    Results are cached as a pickled dict in ``temp_npy/`` under a name built
    from all arguments; the computation is skipped when that file exists.

    Args:
        li_eps: training epsilons, one pair of curves each.
        li_noise: noise standard deviations (x axis).
        eta: forwarded to ``gen_Q``.
        p, c: problem dimension and number of leading ('core') coordinates.
        q: norm order forwarded to ``robust``.
        scale: forwarded to ``gen_Q``.
        plot: when true, load the cached result and save the figure to
            ``figs/distributional{dual(q)}_sc={scale}.pdf``.

    Changes from the previous version: ``eta``, ``p`` and ``c`` were
    overwritten inside the body (0.3, 5, 2) after being used in the cache
    name, so the cache key did not match the computation; they are now
    honoured. The third plot colour was ``'o'``, which is a marker code rather
    than a colour; colours are now valid and reused cyclically instead of
    silently dropping curves beyond the third epsilon.
    """
    heatmap_font_size = 22
    mpl.rc('font',**{'family':'serif','serif':['Times'], 'size': heatmap_font_size})

    name = f"distributional_eps__{st_float(li_eps)}__noise__{st_float(li_noise)}_{eta}_{p}_{c}_{q}_{scale}"
    path_npy = f"temp_npy/{name}.npy"
    path_image = f"figs/distributional{dual(q)}_sc={scale}.pdf"
    li_str_eps = list(map(str, li_eps))

    if not os.path.exists(path_npy):
        theta_star = np.zeros(p)
        theta_star[:c] = 1
        Sigma, Q = gen_Q(p=p, c=c, eta=eta, scale=scale)
        Q_c = np.linalg.cholesky(Sigma[:c, :c])

        loss = {'core': {}, 'total': {}}
        repeat_count = 10000
        for eps, str_eps in zip(li_eps, li_str_eps):
            theta_total = robust(Q, theta_star, eps, q)
            # Embed the c-dimensional core solution into R^p with zeros elsewhere.
            theta_core = np.zeros_like(theta_total)
            theta_core[:c] = robust(Q_c, theta_star[:c], eps, q)

            for noise_param in li_noise:
                ans_total = []
                ans_core = []
                for _ in range(repeat_count):
                    noise = np.random.randn(*Q.shape) * noise_param
                    noise[:c, :] = 0
                    Q_p = Q + noise
                    ans_total.append(eval_loss(Q_p, theta_star, None, theta_total))
                    ans_core.append(eval_loss(Q_p, theta_star, None, theta_core))
                loss['total'].setdefault(str_eps, []).append(np.mean(ans_total))
                loss['core'].setdefault(str_eps, []).append(np.mean(ans_core))
        save_dict(path_npy, loss)

    if plot:
        loss = load_dict(path_npy)
        colors = ['r', 'b', 'orange']
        for i, (eps, str_eps) in enumerate(zip(li_eps, li_str_eps)):
            color = colors[i % len(colors)]
            plt.plot(li_noise, loss['total'][str_eps], color=color, linestyle='solid', label=r"$\epsilon$=" + f"{eps:.2f}" + ", total", alpha=0.5)
            plt.plot(li_noise, loss['core'][str_eps], color=color,  linestyle='dotted', label=r"$\epsilon$=" + f"{eps:.2f}" + ", core")
        plt.xlabel(r'Noise parameter $\sigma_{Q}$')
        plt.ylabel('Loss after distribution shift')
        plt.legend(loc='upper center', bbox_to_anchor=(0.50, 1.27), ncol=2, fancybox=False, shadow=False, fontsize=19)
        plt.savefig(path_image, bbox_inches='tight')
        plt.clf()

def save_dict(path, d):
    """Write the object ``d`` to ``path`` with ``np.save``."""
    np.save(file=path, arr=d)


def load_dict(path):
    """Read back an object written by ``save_dict``.

    NOTE: ``allow_pickle=True`` unpickles the file contents; only load files
    this script produced itself.
    """
    stored = np.load(path, allow_pickle=True)
    return stored.item()

def dual(q):
    """Map a norm order to its counterpart: 2 -> 2, 1 -> 'inf', 'inf' -> 1.

    Any other value yields None.
    """
    for order, counterpart in ((2, 2), (1, 'inf'), ('inf', 1)):
        if q == order:
            return counterpart
    return None

def st_float(li):
    """Summarise a sequence as ``"<first>_<last>_<length>"``.

    The first and last elements are formatted with four decimals; used to
    build cache file names.
    """
    count = str(len(li))
    first, last = li[0], li[-1]
    return f"{first:.4f}_{last:.4f}_{count}"

def plot_heatmap(li_eps, li_scale, eta, p=5, c=2, q='inf', plot=False, xticklabels=None, yticklabels=None):
    """Compute (with caching) and optionally draw the epsilon-by-scale heatmap.

    Cell ``[i, j]`` holds ``get_norm_frac(theta, c)`` for the ``robust``
    solution at ``li_eps[i]`` and ``gen_Q(..., scale=li_scale[j])``. Cells are
    pre-filled with 0.6; for a given scale, once a solution has norm below
    1e-2 the remaining epsilons in that column are skipped and keep the fill
    value.

    Args:
        li_eps: epsilons (rows).
        li_scale: values forwarded as ``scale`` to ``gen_Q`` (columns).
        eta, p, c: forwarded to ``gen_Q``; ``c`` also to ``get_norm_frac``.
        q: norm order forwarded to ``robust``.
        plot: when true, load the cached matrix and save the heatmap to
            ``figs/heatmap{dual(q)}.pdf``.
        xticklabels, yticklabels: forwarded to ``sns.heatmap``.

    Change from the previous version: the x-axis label typo "feautres" is
    corrected to "features".
    """
    heatmap_font_size = 21
    mpl.rc('font',**{'family':'serif','serif':['Times'], 'size': heatmap_font_size})

    name = f"heatmap_eps__{st_float(li_eps)}__eta__{st_float(li_scale)}_{eta}_{p}_{c}_{q}"
    path_npy = f"temp_npy/{name}.npy"
    calc = not os.path.exists(path_npy)
    path_image = f"figs/heatmap{dual(q)}.pdf"
    theta_star = np.zeros(p)
    theta_star[:c] = 1
    if calc:
        matrix = np.ones((len(li_eps), len(li_scale))) * 0.6
        for j, scale in enumerate(li_scale):
            _, Q = gen_Q(p=p, c=c, eta=eta, scale=scale)
            for i, eps in enumerate(li_eps):
                ic()
                theta = robust(Q, theta_star, eps, q)
                if np.linalg.norm(theta) < 1e-2:
                    # Solution is (numerically) zero: leave the rest of this
                    # column at the fill value.
                    break
                matrix[i, j] = get_norm_frac(theta, c)
        save_dict(path_npy, {'matrix': matrix})
    if plot:
        matrix = load_dict(path_npy)['matrix']

        sns.heatmap(matrix, xticklabels=xticklabels, yticklabels=yticklabels)
        plt.xlabel("Scale of the spurious features")
        plt.ylabel(r"$\epsilon$ in adversarial training")
        plt.savefig(path_image, bbox_inches='tight')
        plt.clf()

def plot_frac_spurious_vary_p(li_eps, li_p, eta, scale=1,  q='inf', c=2, plot=False, ymin=-0.02, ymax=1):
    """Compute (with caching) and optionally plot ``get_norm_frac`` vs epsilon for several ``p``.

    For each dimension ``p`` in ``li_p`` the ``robust`` solution is computed
    along ``li_eps``; the sweep for that ``p`` stops at the first epsilon whose
    solution has norm below 1e-2. Curves are stored in a dict keyed by
    ``(p, c)`` and cached under ``temp_npy/``.

    Args:
        li_eps: epsilons (x axis).
        li_p: problem dimensions, one curve each.
        eta, scale: forwarded to ``gen_Q``.
        q: norm order forwarded to ``robust``.
        c: number of leading coordinates set to 1 in ``theta_star``.
        plot: when true, load the cache and save ``figs/pcl1{dual(q)}_{c}.pdf``.
        ymin, ymax: y-axis limits.
    """
    plt.figure(1, np.array([7.5, 7]))
    name = f"eps__{st_float(li_eps)}__p_c__{st_float(li_p)}_{scale}_{eta}_{q}_{c}"
    cache_path = f"temp_npy/{name}.npy"
    cache_missing = not os.path.exists(cache_path)
    figure_path = f"figs/pcl1{dual(q)}_{c}.pdf"

    if cache_missing:
        curves = dict()
        for p in li_p:
            target = np.zeros(p)
            target[:c] = 1
            _, Q = gen_Q(p=p, c=c, eta=eta, scale=scale)
            fractions = []
            for eps in li_eps:
                ic()
                theta = robust(Q, target, eps, q)
                if np.linalg.norm(theta) < 1e-2:
                    break
                fractions.append(get_norm_frac(theta, c))
            curves[(p, c)] = np.array(fractions)
        ic(curves.keys())
        save_dict(cache_path, curves)

    if not plot:
        return

    curves = load_dict(cache_path)
    ic(curves.keys())
    for p in li_p:
        y = curves[(p, c)]
        # A sweep may have stopped early, so trim the x grid to match.
        plt.plot(li_eps[:len(y)], y, marker='^', label=f"m={p}", alpha=0.7)
        ic()
    plt.ylim(ymin=ymin, ymax=ymax)
    plt.xlabel(r"$\epsilon$ used in " +  "adversarial training")
    plt.ylabel(r"NFS")
    plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.15), ncol=len(li_p), fancybox=False, shadow=False, fontsize=22)
    plt.savefig(figure_path, bbox_inches='tight')
    plt.clf()

def plot_frac_spurious(li_eps, li_eta, scale=1, p=5, c=2, q='inf', plot=False, ymin=-0.02, ymax=.72):
    """Compute (with caching) and optionally plot ``get_norm_frac`` vs epsilon for several ``eta``.

    For each ``eta`` in ``li_eta`` the ``robust`` solution is computed along
    ``li_eps``; the sweep for that ``eta`` stops at the first epsilon whose
    solution has norm below 1e-2. Curves are stored in a dict keyed by ``eta``
    and cached under ``temp_npy/``.

    Args:
        li_eps: epsilons (x axis).
        li_eta: values forwarded as ``eta`` to ``gen_Q``, one curve each.
        scale, p, c: forwarded to ``gen_Q``; ``c`` also to ``get_norm_frac``.
        q: norm order forwarded to ``robust``.
        plot: when true, load the cache and save ``figs/{dual(q)}.pdf``.
        ymin, ymax: y-axis limits.
    """
    plt.figure(1, np.array([7.5, 7]))
    name = f"eps__{st_float(li_eps)}__eta__{st_float(li_eta)}_{scale}_{p}_{c}_{q}"
    cache_path = f"temp_npy/{name}.npy"
    cache_missing = not os.path.exists(cache_path)
    figure_path = f"figs/{dual(q)}.pdf"
    target = np.zeros(p)
    target[:c] = 1

    if cache_missing:
        curves = dict()
        for eta in li_eta:
            _, Q = gen_Q(p=p, c=c, eta=eta, scale=scale)
            fractions = []
            for eps in li_eps:
                ic()
                theta = robust(Q, target, eps, q)
                if np.linalg.norm(theta) < 1e-2:
                    break
                fractions.append(get_norm_frac(theta, c))
            curves[eta] = np.array(fractions)
        save_dict(cache_path, curves)

    if not plot:
        return

    curves = load_dict(cache_path)
    for eta, y in curves.items():
        # A sweep may have stopped early, so trim the x grid to match.
        label = r"$\eta$=" + f"{eta:.2f}"
        plt.plot(li_eps[:len(y)], y, marker='^', label=label, alpha=0.7)
        ic()
    plt.ylim(ymin=ymin, ymax=ymax)
    plt.xlabel(r"$\epsilon$ used in " +  "adversarial training")
    plt.ylabel(r"NFS")
    plt.legend(loc='upper center', bbox_to_anchor=(0.45, 1.15), ncol=len(li_eta), fancybox=False, shadow=False, fontsize=22)
    plt.savefig(figure_path, bbox_inches='tight')
    plt.clf()

def script_l1():
    """Run ``plot_frac_spurious`` with ``q='inf'`` on the standard eta/epsilon grids.

    This function was previously defined twice with identical bodies; the
    second definition simply shadowed the first, so one copy is kept.
    """
    li_eta = np.linspace(0, 0.5, 3)
    li_eps = np.linspace(0, 3, 30)
    plot_frac_spurious(li_eps, li_eta, plot=True, q='inf')

def script_linfty():
    """Run ``plot_frac_spurious`` with ``q=1`` on the standard eta/epsilon grids."""
    eps_grid = np.linspace(0, 3, 30)
    eta_grid = np.linspace(0, 0.5, 3)
    plot_frac_spurious(eps_grid, eta_grid, q=1, plot=True)

def script_l2():
    """Run ``plot_frac_spurious`` with ``q=2`` on the standard eta/epsilon grids."""
    eps_grid = np.linspace(0, 3, 30)
    eta_grid = np.linspace(0, 0.5, 3)
    plot_frac_spurious(eps_grid, eta_grid, q=2, plot=True)

def script_heatmap():
    """Draw the heatmap for ``q=1`` on a 31x31 epsilon-by-scale grid.

    Tick labels are sparse: every 5th scale and every 6th epsilon is labelled
    (rounded to two decimals), the rest are empty strings. Labels are derived
    from the grids themselves, so changing a grid size no longer requires
    updating a separate hard-coded loop bound.
    """
    li_eps = np.linspace(0.8, 1.15, 31)
    eta = 0.3
    li_scale = np.linspace(1, 10, 31)

    def fl(x):
        return str(np.round(x, 2))

    xticklabels = [fl(v) if i % 5 == 0 else "" for i, v in enumerate(li_scale)]
    yticklabels = [fl(v) if i % 6 == 0 else "" for i, v in enumerate(li_eps)]
    plot_heatmap(li_eps, li_scale, eta, plot=True, q=1, xticklabels=xticklabels, yticklabels=yticklabels)

def script_l1_vary_p():
    """Run ``plot_frac_spurious_vary_p`` with ``q='inf'`` for two (dimensions, c) settings.

    The local ``eta`` is now the single source for the value passed to both
    calls (it was previously assigned but unused while 0.5 was repeated as a
    literal).
    """
    eta = 0.5
    li_eps = np.linspace(0, 3, 30)
    for li_p, c in (([5, 8, 20], 4), ([20, 40, 60], 10)):
        plot_frac_spurious_vary_p(li_eps, li_p, eta=eta, plot=True, c=c, q='inf')

def get_parser():
    """Build the command-line parser.

    Returns:
        An ``argparse.ArgumentParser`` with a single integer option
        ``--figure`` (default 1) selecting which figure script ``main`` runs.
    """
    parser = argparse.ArgumentParser(description='Adversarial Training for MNIST', formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument(
        '--figure',
        type=int,
        default=1,
        help=(
            "figure to generate (default: 1, which runs nothing):\n"
            "  2  -> script_l1, script_l2, script_linfty\n"
            "  3  -> script_heatmap\n"
            "  4  -> script_distributional\n"
            "  10 -> script_l1_vary_p"
        ),
    )
    return parser

def mkdir(path):
    """Create directory ``path`` (including parents) if it does not exist yet.

    Uses ``exist_ok=True`` rather than a separate existence check, so there is
    no window between the check and the creation in which another process
    could create the directory and make ``makedirs`` fail.
    """
    os.makedirs(path, exist_ok=True)


def main():
    """Create the output directories and run the script(s) for ``--figure``.

    Figure numbers 2, 3, 4 and 10 are mapped to their scripts. Any other
    value (including the default 1) used to do nothing silently; it now
    prints the usage line and the list of supported numbers, then returns
    normally.
    """
    mkdir("temp_npy")
    mkdir("figs")
    parser = get_parser()
    args = parser.parse_args()

    figure_scripts = {
        2: (script_l1, script_l2, script_linfty),
        3: (script_heatmap,),
        4: (script_distributional,),
        10: (script_l1_vary_p,),
    }
    scripts = figure_scripts.get(args.figure)
    if scripts is None:
        parser.print_usage()
        print(f"No script is registered for --figure {args.figure}; "
              f"supported values: {sorted(figure_scripts)}")
        return
    for script in scripts:
        script()

if __name__ == '__main__':
    main()
