import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

import argparse

import os

from compute_patterns import acc, int_to_list

def check_data(data_dir):
    """Return True when ``counts.npy`` exists inside *data_dir*.

    Only ``counts.npy`` is probed; the other files that ``load_data`` reads
    (``y_data.npy``, ``patterns.npy``) are not checked here.
    """
    counts_file = f'{data_dir}/counts.npy'
    return os.path.exists(counts_file)

def load_data(data_dir):
    """Load one saved run from *data_dir*.

    Returns:
        ``(y, last_counts, patterns)`` where ``y`` is ``y_data.npy``,
        ``last_counts`` is the final entry (``[-1]``) of ``counts.npy`` and
        ``patterns`` is ``patterns.npy``.
    """
    def _read(stem, **np_kwargs):
        return np.load(f'{data_dir}/{stem}.npy', **np_kwargs)

    labels = _read('y_data')
    all_counts = _read('counts')
    # NOTE(review): allow_pickle=True unpickles arbitrary objects; only point
    # this at files you generated yourself.
    encoded_patterns = _read('patterns', allow_pickle=True)

    return labels, all_counts[-1], encoded_patterns

def get_max_wrong(counts, theta, num_samples):
    """Return the largest number of mistakes admitted to the Rashomon set.

    ``counts[i]`` is treated as the number of patterns making exactly ``i``
    mistakes, so the index of the first positive entry is the best achievable
    error. A slack of ``int(theta * num_samples)`` mistakes is added on top.
    If no entry is positive, the best error is taken to be ``len(counts)``.
    """
    best_error = next(
        (idx for idx, n in enumerate(counts) if n > 0),
        len(counts),
    )
    slack = int(theta * num_samples)
    return best_error + slack

# A slower method to calculate the diversity by using the definition directly
def calc_diversity_hamming(y, patterns, max_wrong, num_samples, rashomon_count):
    """Compute pattern diversity directly from pairwise distances.

    Every pattern in ``patterns`` is decoded with ``int_to_list`` and kept only
    if it makes at most ``max_wrong`` mistakes against ``y``. The distances
    reported by ``acc`` are then summed over all ordered pairs of kept patterns
    (including each pattern with itself). The sum is normalised by
    ``rashomon_count ** 2`` and by ``num_samples``.

    The Rashomon-set members are decoded and filtered once up front. The
    original form re-decoded and re-scored every pattern inside the inner
    loop, which did O(P^2) decode/score work instead of O(P). The pairs are
    visited in the same order, so the result is identical.

    Args:
        y: labels the patterns are scored against.
        patterns: encoded patterns, decoded via ``int_to_list``.
        max_wrong: largest mistake count allowed in the Rashomon set.
        num_samples: length of each decoded pattern.
        rashomon_count: number of patterns in the Rashomon set, used as the
            normaliser.

    Returns:
        The mean pairwise distance, per pair and per sample.
    """
    # Decode and filter once; these lists are reused for every pair below.
    members = []
    for encoded in patterns:
        decoded = int_to_list(encoded, num_samples)
        _, wrong = acc(decoded, y)
        if wrong <= max_wrong:
            members.append(decoded)

    total = 0
    for p1 in members:
        for p2 in members:
            # The second value from acc() is used as the distance; presumably
            # the count of positions where p1 and p2 disagree. Confirm in
            # compute_patterns.
            _, dist = acc(p1, p2)
            total += dist

    total /= rashomon_count * rashomon_count
    total /= num_samples

    return total

# A faster method to calculate the diversity using sample agreement
def calc_diversity(y, patterns, max_wrong, num_samples, rashomon_count):
    """Compute pattern diversity from per-sample vote tallies.

    Each pattern making at most ``max_wrong`` mistakes against ``y`` (as
    reported by ``acc``) has its decoded vector added into an integer tally of
    length ``num_samples``. With ``share = tally / rashomon_count``, the
    diversity is ``2 * mean(share * (1 - share))``.

    Assuming ``int_to_list`` yields 0/1 entries, this equals the pairwise
    definition in ``calc_diversity_hamming`` without the quadratic pair loop.
    """
    votes = np.zeros(shape=(num_samples,), dtype=np.int64)
    for encoded in patterns:
        decoded = int_to_list(encoded, num_samples)
        _, mistakes = acc(decoded, y)
        if mistakes <= max_wrong:
            votes += np.array(decoded)

    share = votes / rashomon_count
    spread = share * (1 - share)
    return 2 * spread.mean()

def plot(data, path, datasets, dataset_names, x, y, xlabel, ylabel):
    """Draw one line per dataset and save the figure as ``{path}.png`` and ``{path}.pdf``.

    Args:
        data: DataFrame with a ``'Dataset'`` column plus the ``x``/``y`` columns.
        path: output path without extension.
        datasets: values of ``data['Dataset']`` to plot, one line each.
        dataset_names: legend labels, paired with ``datasets`` via ``zip``
            (extra entries in either list are ignored).
        x, y: column names for the axes.
        xlabel, ylabel: axis label text.
    """
    fig, ax = plt.subplots()
    try:
        for dataset, name in zip(datasets, dataset_names):
            subset = data[data['Dataset'] == dataset]
            # No ``hue`` is assigned, so a ``palette`` argument would be ignored;
            # each call instead takes the next colour from the axes' colour cycle.
            sns.lineplot(data=subset, x=x, y=y, ax=ax, label=name)

        ax.set_xlabel(xlabel, size=22)
        ax.set_ylabel(ylabel, size=22)
        ax.tick_params(axis='both', labelsize=18)
        ax.legend(fontsize=18, loc='upper left')

        fig.savefig(f'{path}.png', bbox_inches='tight')
        fig.savefig(f'{path}.pdf', bbox_inches='tight')
    finally:
        # pyplot keeps every figure alive until it is closed explicitly; without
        # this, each call leaks a figure.
        plt.close(fig)


def _rashomon_stats(data_dir, theta):
    """Load the run saved in *data_dir* and return ``(rashomon_count, diversity)``.

    Raises:
        ValueError: if the Rashomon threshold lies past the last index of ``counts``.
    """
    y, counts, patterns = load_data(data_dir)
    num_samples = y.shape[0]
    max_wrong = get_max_wrong(counts, theta, num_samples)

    # A threshold beyond the end of ``counts`` would make the slice below stop
    # early and silently undercount. Fail loudly instead: a bare ``assert`` is
    # stripped under ``python -O``.
    if max_wrong > counts.shape[0] - 1:
        raise ValueError(
            f'{data_dir}: max_wrong={max_wrong} exceeds the last index of '
            f'counts ({counts.shape[0] - 1}); try a smaller theta'
        )

    rashomon_count = counts[:max_wrong + 1].sum()
    diversity = calc_diversity(y, patterns, max_wrong, num_samples, rashomon_count)
    return rashomon_count, diversity


def main(data_dirs, data_names, output_dir, noises, num_draws, theta):
    """Collect Rashomon-set size and diversity per noise level and plot both.

    For each directory in ``data_dirs``, the noise-free run stored directly in
    it is evaluated. Then each ``{data_dir}/noise_{noise}/trial_{trial}`` that
    contains data is evaluated. The results are plotted into
    ``{output_dir}/rashomon_size`` and ``{output_dir}/diversity`` (PNG + PDF).

    Raises:
        ValueError: if fewer names than directories are given, or a run's
            Rashomon threshold falls outside its ``counts`` array.
    """
    # plot() pairs directories with names via zip(), which would silently drop
    # every directory that has no name.
    if len(data_names) < len(data_dirs):
        raise ValueError(
            f'got {len(data_dirs)} data_dirs but only {len(data_names)} data_names'
        )

    arr = []
    for data_dir in data_dirs:
        rashomon_count, diversity = _rashomon_stats(data_dir, theta)
        # The noise-free result is replicated num_draws times so the noise=0
        # point has as many rows as a full set of noisy trials.
        arr.extend([rashomon_count, diversity, 0, data_dir] for _ in range(num_draws))

        for noise in noises:
            for trial in range(num_draws):
                path = f'{data_dir}/noise_{noise}/trial_{trial}'
                # Trials without a counts.npy are skipped rather than treated as errors.
                if not check_data(path):
                    continue
                rashomon_count, diversity = _rashomon_stats(path, theta)
                arr.append([rashomon_count, diversity, noise, data_dir])

    df = pd.DataFrame(arr, columns=['Rashomon Count', 'Diversity', 'Noise', 'Dataset'])

    os.makedirs(output_dir, exist_ok=True)

    plot(
        df,
        path=f'{output_dir}/rashomon_size',
        datasets=data_dirs,
        dataset_names=data_names,
        x='Noise',
        y='Rashomon Count',
        xlabel=r"Label noise, $\rho$",
        ylabel='Number of Patterns',
    )
    plot(
        df,
        datasets=data_dirs,
        path=f'{output_dir}/diversity',
        dataset_names=data_names,
        x='Noise',
        y='Diversity',
        xlabel=r"Label noise, $\rho$",
        ylabel='Pattern Diversity',
    )

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Graphs the averaged results of the size of the rashomon set and diversity')
    add = parser.add_argument

    # Saved Rashomon-set runs to plot, and the legend label for each one.
    add('-d', '--data_dirs', nargs='*',
        default=['./rashomon_sets/iris', './rashomon_sets/wine_pca4', './rashomon_sets/seeds_pca4'])
    add('-n', '--data_names', nargs='*', default=['Iris', 'Wine 4', 'Seeds 4'])
    add('-o', '--output_dir', default='./figures')
    # Noise levels looked up as <data_dir>/noise_<level>/trial_<k>.
    add('-no', '--noises', nargs='*', type=float, default=[0.02, 0.04, 0.06, 0.08, 0.1, 0.15])
    add('-nd', '--num_draws', default=5, type=int)
    add('-t', '--theta', default=0.02, type=float)

    args = parser.parse_args()

    # Each dest defined above matches a keyword parameter of main() one-to-one.
    main(**vars(args))