import torch
from attacks import *
from models import *
import matplotlib.pyplot as plt
from train import Trainer
from tqdm import tqdm
import os
import pickle
from color_mnist import *

def get_results_path(save_root, l2_norm=False):
    """Map a checkpoint path to the pickle under ``./results`` that caches its scores.

    The basename of ``save_root`` (text after the last ``/``) has its last
    four characters dropped (presumably a 4-character extension such as
    ``.pth`` -- confirm against how checkpoints are named). When ``l2_norm``
    is true, an ``_l2`` suffix is appended before ``.pkl``.
    """
    stem = save_root.rsplit('/', 1)[-1][:-4]
    suffix = '_l2' if l2_norm else ''
    return f'./results/{stem}{suffix}.pkl'

def load_results_dict(results_path):
    """Return the pickled results dict stored at ``results_path``.

    This is the read side of ``cache_results_dict``. When no file exists at
    ``results_path`` an empty dict is returned instead.
    """
    if not os.path.exists(results_path):
        return {}
    with open(results_path, 'rb') as handle:
        return pickle.load(handle)

def cache_results_dict(results_path, results_dict):
    """Pickle ``results_dict`` to ``results_path``, overwriting any existing file.

    The parent directory of ``results_path`` is created if it is missing, so
    the very first cache write (e.g. before ``./results`` exists) succeeds
    instead of raising ``FileNotFoundError``.
    """
    parent = os.path.dirname(results_path)
    if parent:  # a bare filename has no directory component to create
        os.makedirs(parent, exist_ok=True)
    with open(results_path, 'wb') as f:
        pickle.dump(results_dict, f)

def eval_robustness(trainer, epsilons, steps=10, alpha=2/255, norm='linf'):
    """Measure ``trainer``'s test accuracy under attack at each budget in ``epsilons``.

    Args:
        trainer: object exposing a writable ``adv_params`` attribute, an
            ``init_attack()`` method, and a ``test()`` method returning a
            ``(loss, acc)`` pair (only ``acc`` is kept).
        epsilons: iterable of attack budgets in units of 1/255; each value is
            divided by 255 before being stored as ``'eps'``.
        steps, alpha, norm: remaining attack hyper-parameters written into
            ``adv_params``; the defaults are the values previously hard-coded.

    Returns:
        List of accuracies, in the iteration order of ``epsilons``.
    """
    accs = []
    for eps in epsilons:
        print(eps)  # progress output
        # Build a fresh dict per budget: the previous version mutated one
        # shared dict, so anything that kept a reference to it (trainer or
        # attack state) silently saw later budgets.
        trainer.adv_params = {'steps': steps, 'alpha': alpha, 'norm': norm, 'eps': eps / 255}
        trainer.init_attack()
        _, acc = trainer.test()
        accs.append(acc)
    return accs

def _load_or_compute_accs(trainer, epsilons):
    """Return accuracy under attack for each budget in ``epsilons``, using the on-disk cache.

    Budgets already present in the results pickle for ``trainer.save_path`` are
    reused. Missing ones are evaluated with ``eval_robustness`` and written back
    to the cache before returning.
    """
    results_path = get_results_path(trainer.save_path)
    results_dict = load_results_dict(results_path)
    # Only attack at budgets with no cached result; sorted so the evaluation
    # order (and its progress printout) is deterministic.
    missing = sorted(set(epsilons).difference(results_dict))
    if missing:
        accs = eval_robustness(trainer, missing)
        results_dict.update(zip(missing, accs))
        cache_results_dict(results_path, results_dict)
    return [results_dict[eps] for eps in epsilons]


def plot_robustness(adv=False, spur_corr='color_shift', dset='cifar',
                    mtypes=('resnet',),
                    flip_fracs=(0.025, 0.05, 0.125, 0.25, 0.5)):
    """Plot accuracy under attack vs. attack budget and save it under ``./plots``.

    One subplot per entry of ``mtypes``, with one curve per entry of
    ``flip_fracs``. Each curve comes from ``Trainer(model_type, flip_frac,
    spur_corr, dset, adv_params=None)`` after ``restore_model()``. Cached
    accuracies are reused and missing budgets are evaluated and cached.

    Args:
        adv: selects the epsilon grid only ([0..10] if true, a finer [0..6]
            grid otherwise); it does not change how the Trainer is built.
        spur_corr, dset: forwarded to ``Trainer`` and used in the file name.
        mtypes: model types, one subplot each.
        flip_fracs: values forwarded to ``Trainer``; curve label is the ratio
            ``(1 - f) / f`` rounded to an integer.
    """
    import matplotlib.cm as cmap

    if adv:
        epsilons = [0, 2, 4, 6, 8, 10]
    else:
        epsilons = [0, 0.5, 1, 2, 3, 4, 5, 6]

    # Apply the style BEFORE creating the figure: axes-level rcParams (face
    # colour, grid) are read when the axes are created, so calling it later
    # left this figure mostly unstyled.
    plt.style.use('ggplot')
    fig, axs = plt.subplots(1, len(mtypes), figsize=(len(mtypes) * 3, 2.75))
    if len(mtypes) == 1:
        axs = [axs]
    colors = [cmap.viridis(i / len(flip_fracs)) for i in range(len(flip_fracs))]

    for ax, model_type in zip(axs, mtypes):
        for flip_frac, c in zip(flip_fracs, colors):
            trainer = Trainer(model_type, flip_frac, spur_corr, dset, adv_params=None)
            trainer.restore_model()
            accs = _load_or_compute_accs(trainer, epsilons)
            if flip_frac > 0:
                label = '$\\rho=${}:1'.format(round((1 - flip_frac) / flip_frac))
            else:  # avoid ZeroDivisionError for an unflipped dataset
                label = '$\\rho=\\infty$'
            ax.plot(epsilons, accs, '*-', c=c, label=label)

        ax.set_xlabel(r'PGD $\ell_\infty$ Attack Budget ($\epsilon/ 255$)')
        ax.set_ylabel('Accuracy under Attack')
    # Every panel uses the same flip_frac colours, so one legend suffices.
    ax.legend()
    fig.tight_layout()
    os.makedirs('./plots', exist_ok=True)
    fig.savefig('./plots/{}_{}_robustness.png'.format(dset, spur_corr), dpi=300, bbox_inches='tight', pad_inches=0.01)
    plt.close(fig)

def color_shift_figure(fn_1=lambda x: shift_color(x, channel_dim=0), fn_2=lambda x: shift_color(x, channel_dim=1), shift_name1='redder', shift_name2='greener'):
    """Save example figures of one label-5 ("dog") CIFAR image and two shifted copies.

    Writes two files into ``eg_color_shift/`` (created if missing):
      * ``dog_og_<shift_name1>.jpg``: the unmodified image;
      * ``dog_colored_<shift_name1>.jpg``: ``fn_1`` and ``fn_2`` outputs side by side.

    Args:
        fn_1, fn_2: callables applied to the image with a leading batch
            dimension added (``unsqueeze(0)``); their output is squeezed back.
        shift_name1, shift_name2: names shown in the subplot titles;
            ``shift_name1`` also names the output files.

    Raises:
        ValueError: if no image with label 5 is found in the loader.
    """
    from color_mnist import get_cifar_loaders

    loader, _ = get_cifar_loaders()
    # Scan batches until one contains a label-5 image (previously only the
    # first batch was checked, raising IndexError if it had none).
    dog = None
    for x, y in loader:
        matches = x[y == 5]
        if len(matches):
            dog = matches[0]
            break
    if dog is None:
        raise ValueError('no image with label 5 found in the CIFAR loader')

    red_dog = fn_1(dog.unsqueeze(0)).squeeze(0)
    green_dog = fn_2(dog.unsqueeze(0)).squeeze(0)
    # Move the leading axis last (presumably CHW -> HWC) for imshow.
    dog, red_dog, green_dog = [t.numpy().transpose(1, 2, 0) for t in (dog, red_dog, green_dog)]

    os.makedirs('eg_color_shift', exist_ok=True)

    f, ax = plt.subplots(1, 1, figsize=(1.5, 1.5))
    ax.set_axis_off()
    ax.imshow(dog)
    ax.set_title(r'Dog, Label$\in [5,9]$', fontsize=9)
    f.savefig('eg_color_shift/dog_og_{}.jpg'.format(shift_name1), dpi=300, bbox_inches='tight', pad_inches=0.02)
    plt.close(f)

    f, axs = plt.subplots(1, 2, figsize=(3, 1.5))
    for a in axs.ravel():
        a.set_axis_off()
    axs[0].imshow(red_dog)
    axs[0].set_title(f'$Pr[${shift_name1}$]=0.05$', fontsize=8)
    axs[1].imshow(green_dog)
    axs[1].set_title(f'$Pr[${shift_name2}$]=0.95$', fontsize=8)
    f.savefig('eg_color_shift/dog_colored_{}.jpg'.format(shift_name1), dpi=300, bbox_inches='tight', pad_inches=0.02)
    plt.close(f)


# Script entry: regenerate the CIFAR color-shift robustness plot
# (the arguments below are plot_robustness's defaults, spelled out).
# Other figures this script can produce -- uncomment as needed:
#   plot_robustness(spur_corr='lighting')
#   color_shift_figure()
#   color_shift_figure(fn_1=lambda x: alter_lighting(x, scale=1.25), fn_2=lambda x:alter_lighting(x, scale=0.75), shift_name1='brighter', shift_name2='darker')
plot_robustness(adv=False, spur_corr='color_shift', dset='cifar')




