# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: MIT-0

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
sns.set_style('whitegrid')
# NOTE: a bare ``sns.axes_style({...})`` call used to follow here. ``axes_style``
# only builds and returns a style object (for use as a context manager or to
# inspect parameters); it does not apply it, and the returned object was
# discarded, so the call had no effect. Its intended settings (axes edge colour
# 0.5 and a 'Helvectia' -- presumably Helvetica -- font) would need to go
# through ``sns.set_style('whitegrid', rc)`` / ``sns.set_theme`` instead, using
# valid rc keys and colour values.

class dotdict(dict):
    """A ``dict`` whose items are also reachable as attributes.

    ``d.key`` reads ``d['key']`` via ``dict.get``, so a missing key yields
    ``None`` rather than raising ``AttributeError``. ``d.key = v`` stores the
    item, and ``del d.key`` removes it (raising ``KeyError`` if absent).
    """

    def __getattr__(self, name):
        return dict.get(self, name)

    def __setattr__(self, name, value):
        dict.__setitem__(self, name, value)

    def __delattr__(self, name):
        dict.__delitem__(self, name)

# Worker attributes plotted by ``fixed_budget_plot``, keyed by an ``arguments``
# entry: (attribute holding the means, attribute holding the standard errors).
_FIXED_BUDGET_FIELDS = {
    'oracle-acc': ('accuracy_mean_oracle', 'accuracy_ste_oracle'),
    'ols-acc': ('accuracy_mean_ols', 'accuracy_ste_ols'),
    '2sls-acc': ('accuracy_mean_two_sls', 'accuracy_ste_two_sls'),
    'p2sls-acc': ('accuracy_mean_pseudo_two_sls', 'accuracy_ste_pseudo_two_sls'),
    'oracle-sr': ('simple_regret_mean_oracle', 'simple_regret_ste_oracle'),
    'ols-sr': ('simple_regret_mean_ols', 'simple_regret_ste_ols'),
    '2sls-sr': ('simple_regret_mean_two_sls', 'simple_regret_ste_two_sls'),
    'p2sls-sr': ('simple_regret_mean_pseudo_two_sls',
                 'simple_regret_ste_pseudo_two_sls'),
    'oracle-bias': ('bias_mean_oracle', 'bias_ste_oracle'),
    'ols-bias': ('bias_mean_ols', 'bias_ste_ols'),
    '2sls-bias': ('bias_mean_two_sls', 'bias_ste_two_sls'),
    'p2sls-bias': ('bias_mean_pseudo_two_sls', 'bias_ste_pseudo_two_sls'),
}

# Fallback attribute names, tried only when the canonical mean attribute is
# missing. '2sls-bias' used to read ``bias_two_sls``, which breaks the
# ``<metric>_mean_<estimator>`` pattern every other entry follows; keep
# accepting it so existing result objects still plot.
_FIXED_BUDGET_LEGACY_MEANS = {'bias_mean_two_sls': 'bias_two_sls'}


def _fixed_budget_series(worker, argument):
    """Return the ``(means, stes)`` arrays that ``argument`` selects on ``worker``.

    Raises:
        ValueError: If ``argument`` is not a key of ``_FIXED_BUDGET_FIELDS``.
        AttributeError: If ``worker`` has no data for the selected attributes
            (a ``dotdict`` reports missing keys as ``None``, which is treated
            the same way).
    """
    if argument not in _FIXED_BUDGET_FIELDS:
        raise ValueError('Unknown argument {!r}; expected one of {}'.format(
            argument, sorted(_FIXED_BUDGET_FIELDS)))
    mean_name, ste_name = _FIXED_BUDGET_FIELDS[argument]

    means = getattr(worker, mean_name, None)
    if means is None and mean_name in _FIXED_BUDGET_LEGACY_MEANS:
        means = getattr(worker, _FIXED_BUDGET_LEGACY_MEANS[mean_name], None)
    stes = getattr(worker, ste_name, None)
    if means is None or stes is None:
        missing = mean_name if means is None else ste_name
        raise AttributeError('worker has no {!r} data for argument {!r}'.format(
            missing, argument))
    # asarray lets plain lists work with the elementwise band arithmetic.
    return np.asarray(means), np.asarray(stes)


def fixed_budget_plot(workers, arguments, start=0, end=None, labels=None, save=False, save_name='', xlabel='', ylabel=''):
    """Plot mean curves with a ±1 standard-error band over time, one per worker.

    ``arguments[i]`` chooses which metric/estimator series of ``workers[i]`` is
    drawn; valid keys are those of ``_FIXED_BUDGET_FIELDS`` (e.g. 'ols-acc',
    '2sls-sr', 'p2sls-bias').

    Args:
        workers: Result objects, or plain dicts (wrapped in ``dotdict`` for
            attribute access), exposing the attributes named in
            ``_FIXED_BUDGET_FIELDS``.
        arguments: One ``_FIXED_BUDGET_FIELDS`` key per worker.
        start: Index of the first time step plotted; x values start at
            ``start + 1``.
        end: Exclusive index of the last time step plotted. When None, each
            worker is plotted to the end of its own series; otherwise it is
            clipped to the series length.
        labels: Optional legend label per worker.
        save: When True, save the figure to ``save_name`` before showing it.
        save_name: Output path passed to ``plt.savefig``.
        xlabel: X-axis label; '' (default) uses 'Time Index'.
        ylabel: Y-axis label; '' (default) uses 'Identification Probability'.

    Raises:
        ValueError: If an entry of ``arguments`` is not a known key.
        AttributeError: If a worker lacks the data an argument refers to.
    """
    plt.figure(figsize=(10, 6))
    for i, worker in enumerate(workers):
        if type(worker) is dict:
            worker = dotdict(worker)
        means, stes = _fixed_budget_series(worker, arguments[i])

        # Resolve the window per worker instead of overwriting ``end``: reusing
        # the first worker's length made shorter later series fail with an
        # x/y length mismatch.
        stop = len(means) if end is None else min(end, len(means))
        x = np.arange(start + 1, stop + 1)
        lower = (means - stes)[start:stop]
        upper = (means + stes)[start:stop]

        if labels is not None:
            plt.plot(x, means[start:stop], label=labels[i])
            plt.fill_between(x=x, y1=lower, y2=upper, alpha=.4)
        else:
            plt.plot(x, means[start:stop])
            plt.fill_between(x=x, y1=lower, y2=upper, alpha=.5)

    if labels is not None:
        plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.15), ncol=len(labels), fontsize=20)

    plt.xlabel(xlabel or 'Time Index', fontsize=20)
    plt.ylabel(ylabel or 'Identification Probability', fontsize=20)

    plt.xticks(fontsize=16)
    plt.yticks(fontsize=16)
    plt.tight_layout()
    if save:
        plt.savefig(save_name)
    plt.show()
    
    
def line_plot(worker_groups, argument, labels=None, xtick_labels=None, save=False, save_name='', xlabel='', ylabel='', xticklabels=None):
    """Plot one mean curve with a ±1 standard-error band per group of workers.

    Within a group, worker ``j`` supplies the value plotted at x position ``j``.

    Args:
        worker_groups: Iterable of worker sequences. Each worker is a result
            object or a plain dict (wrapped in ``dotdict`` for attribute
            access); the caller's sequences are not modified.
        argument: 'accuracy' (reads ``accuracy_mean`` / ``accuracy_ste``) or
            'sample_complexity' (reads ``sample_complexity_mean`` /
            ``sample_complexity_ste``).
        labels: Optional legend label per group.
        xtick_labels: Alias for ``xticklabels``; used only when
            ``xticklabels`` is None.
        save: When True, save the figure to ``save_name`` before showing it.
        save_name: Output path passed to ``plt.savefig``.
        xlabel: X-axis label; '' (default) uses 'Time Index'.
        ylabel: Y-axis label; '' (default) uses 'Sample Complexity'.
        xticklabels: Optional tick labels for x positions 0, 1, 2, ...

    Raises:
        ValueError: If ``argument`` is not a supported metric.
    """
    fields = {
        'accuracy': ('accuracy_mean', 'accuracy_ste'),
        'sample_complexity': ('sample_complexity_mean', 'sample_complexity_ste'),
    }
    if argument not in fields:
        raise ValueError(
            "argument must be 'accuracy' or 'sample_complexity', got {!r}".format(argument))
    mean_name, ste_name = fields[argument]
    if xticklabels is None:
        xticklabels = xtick_labels

    plt.figure(figsize=(10, 6))
    for i, workers in enumerate(worker_groups):
        # Build a new list rather than assigning into ``workers``: that mutated
        # the caller's data and failed outright for tuples.
        workers = [dotdict(w) if type(w) is dict else w for w in workers]
        means = np.array([getattr(w, mean_name) for w in workers])
        stes = np.array([getattr(w, ste_name) for w in workers])
        x = np.arange(len(means))

        if labels is not None:
            plt.plot(x, means, label=labels[i], lw=2)
            plt.fill_between(x=x, y1=means - stes, y2=means + stes, alpha=.4)
        else:
            plt.plot(x, means, lw=2)
            plt.fill_between(x=x, y1=means - stes, y2=means + stes, alpha=.5)

    if labels is not None:
        plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.15), ncol=len(labels), fontsize=20)

    plt.xlabel(xlabel or 'Time Index', fontsize=22)
    plt.ylabel(ylabel or 'Sample Complexity', fontsize=22)

    plt.xticks(fontsize=18)
    plt.yticks(fontsize=18)
    if xticklabels is not None:
        # Setting new tick labels creates fresh Text objects, so the font size
        # must be passed again or the custom labels fall back to the default.
        plt.xticks(range(len(xticklabels)), xticklabels, fontsize=18)

    plt.ticklabel_format(style='sci', axis='y', scilimits=(0, 0))

    plt.tight_layout()
    if save:
        plt.savefig(save_name)
    plt.show()
            
            
def bar_plot(workers, argument, labels=None, save=False, save_name='', xlabel='', ylabel=''):
    """Draw one bar per worker showing a mean with standard-error error bars.

    Bars are coloured in order from the 'tab10' colormap, and x tick labels
    are left blank.

    Args:
        workers: Iterable of result objects, or plain dicts (wrapped in
            ``dotdict`` for attribute access); the caller's sequence is not
            modified.
        argument: 'accuracy' (reads ``accuracy_mean`` / ``accuracy_ste``) or
            'sample_complexity' (reads ``sample_complexity_mean`` /
            ``sample_complexity_ste``).
        labels: Accepted for signature compatibility; not currently used.
        save: When True, save the figure to ``save_name`` before showing it.
        save_name: Output path passed to ``plt.savefig``.
        xlabel: X-axis label; '' (default) uses 'Algorithm'.
        ylabel: Y-axis label; '' (default) uses 'Sample Complexity'.

    Raises:
        ValueError: If ``argument`` is not a supported metric.
    """
    fields = {
        'accuracy': ('accuracy_mean', 'accuracy_ste'),
        'sample_complexity': ('sample_complexity_mean', 'sample_complexity_ste'),
    }
    if argument not in fields:
        raise ValueError(
            "argument must be 'accuracy' or 'sample_complexity', got {!r}".format(argument))
    mean_name, ste_name = fields[argument]

    # Wrap dicts in a new list instead of overwriting the caller's entries.
    workers = [dotdict(w) if type(w) is dict else w for w in workers]
    means = [getattr(w, mean_name) for w in workers]
    stes = [getattr(w, ste_name) for w in workers]

    plt.figure(figsize=(10, 6))
    cmap = plt.get_cmap("tab10")
    positions = range(len(means))
    plt.bar(positions, means, yerr=stes, color=[cmap(i) for i in positions], capsize=10)

    plt.xlabel(xlabel or 'Algorithm', fontsize=20)
    plt.ylabel(ylabel or 'Sample Complexity', fontsize=20)

    # Keep one tick per bar but leave its label empty.
    plt.xticks(positions, [''] * len(means), fontsize=16)
    plt.yticks(fontsize=16)
    plt.ticklabel_format(style='sci', axis='y', scilimits=(0, 0))

    plt.tight_layout()
    if save:
        plt.savefig(save_name)
    plt.show()
    

