from collections import OrderedDict
import re
import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
from tqdm.notebook import tqdm

from eval import get_run_metrics, read_run_dir, get_model_from_run, get_data_sampler, get_task_sampler, gen_standard, eval_batch
from plot_utils import basic_plot, collect_results, relevant_model_names

# Notebook setup. This file is an export of a Jupyter notebook: the `%` lines
# below are IPython magics and only run under IPython/Jupyter.
%matplotlib inline
%load_ext autoreload
%autoreload 2

# Plot styling for any figures produced in this notebook.
sns.set_theme('notebook', 'darkgrid')
palette = sns.color_palette('colorblind')

# Root directory of trained runs; a run lives at <run_dir>/<task>/<run_id>
# (see run_path below).
run_dir = "../models"

# Table of the runs found under run_dir; the bare `df` displays it as cell output.
df = read_run_dir(run_dir)
df  # list all the runs in our run_dir

# Task to evaluate -- keep exactly one assignment active. The same name is the
# subdirectory of run_dir that holds the task's runs.
# task = "linear_regression"
#task = "sparse_linear_regression"
#task = "decision_tree"
# task = "relu_2nn_regression"
task = "relu_2nn_regression_chainofthought"

# Run ID for the relu_2nn_regression task (presumably the default, non
# chain-of-thought model -- confirm); pair it with task = "relu_2nn_regression".
# run_id = "2d467e82-1825-45e2-b497-7bd1c366a843"

# Run ID for the relu_2nn_regression_chainofthought task.
run_id = "bb513654-291b-4b16-bcb2-44f4ae04c088"

run_path = os.path.join(run_dir, task, run_id)
# Set to True to recompute this run's evaluation metrics before collecting them.
recompute_metrics = False

if recompute_metrics:
    get_run_metrics(run_path)  # these are normally precomputed at the end of training
def valid_row(r):
    """Tell whether results row ``r`` belongs to the selected task and run.

    The row matches only if both its ``task`` and ``run_id`` attributes equal
    the module-level ``task`` and ``run_id`` settings.
    """
    if r.task != task:
        return False
    return r.run_id == run_id

# Collect metrics for the selected run only; valid_row matches rows on task and
# run_id (presumably collect_results uses it as a row filter -- confirm).
metrics = collect_results(run_dir, df, valid_row=valid_row)
# Load just the run's config (the returned model is discarded) to read the
# input dimensionality.
_, conf = get_model_from_run(run_path, only_conf=True)
n_dims = conf.model.n_dims

# Load the trained model for evaluation below. Only the noise-free case is
# currently evaluated; the noisy-data sweep further down is commented out.
from eval import build_evals

# step=499999 presumably selects the checkpoint from that training step --
# confirm. The model is moved to the GPU and switched to eval mode.
model, conf = get_model_from_run(run_path, step=499999)
model = model.cuda().eval()


def get_model_error(model, xs, ys, num_ic_examples=None, device='cuda', layer_activations=None):
    """Return the model's squared prediction error on a batch of prompts.

    The model predicts every target in the prompt, so position ``k`` of the
    error corresponds to the prediction made after ``k`` in-context examples.

    Args:
        model: Model under evaluation. Called as ``model(xs, ys)``, or as
            ``model.predict(xs, ys, layer_activations=...)`` when
            ``layer_activations`` is given.
        xs: Prompt inputs (presumably (batch, n_points, n_dims) -- confirm).
        ys: Prompt targets. They are indexed as ``[:, position]`` below, so
            they are at least 2-D with the batch on the first axis.
        num_ic_examples: If not None, score only the prediction at position
            ``num_ic_examples``. ``0`` (no in-context examples) is a valid
            position. If None, score every position.
        device: Device the forward pass runs on.
        layer_activations: Optional sequence of tensors forwarded to
            ``model.predict``; each is moved to ``device`` first.

    Returns:
        A Python float when ``num_ic_examples`` is given (mean over the batch
        at that position). Otherwise, a CPU tensor with the batch-mean
        squared error at every position.
    """
    # Inference only: skip autograd bookkeeping to save memory.
    with torch.no_grad():
        if layer_activations is not None:
            layer_activations = [act.to(device) for act in layer_activations]
            pred = model.predict(xs.to(device), ys.to(device), layer_activations=layer_activations).detach()
        else:
            pred = model(xs.to(device), ys.to(device)).detach()

    # Squared error for every prompt and every position.
    sq_error = (ys.cpu() - pred.cpu()).square()
    # Compare against None explicitly: position 0 is meaningful but falsy.
    if num_ic_examples is not None:
        return sq_error[:, num_ic_examples].mean().item()
    return sq_error.mean(dim=0)


def get_linear_regression_error(xs, ys, num_ic_examples=-1):
    """Return the squared error of a per-prompt least-squares baseline.

    Each prompt in the batch is treated as its own linear regression problem.
    Weights are fit with the Moore-Penrose pseudo-inverse on the first
    ``num_ic_examples`` points (``xs[i][:num_ic_examples]``) and then used to
    predict every point of that prompt.

    Args:
        xs: Prompt inputs (presumably (batch, n_points, n_dims) -- confirm).
        ys: Prompt targets (presumably (batch, n_points) -- confirm).
        num_ic_examples: Slice bound for the fit; the error is then read at
            that position.
            - The default ``-1`` fits on all but the last point and scores
              the last point.
            - ``0`` fits on nothing, giving the minimum-norm (zero) weights.
            - None fits on every point and returns the per-position errors.

    Returns:
        A Python float (batch-mean squared error at position
        ``num_ic_examples``) when ``num_ic_examples`` is not None. Otherwise,
        a CPU tensor of batch-mean squared errors, one per position.
    """
    predictions = []
    for x_i, y_i in zip(xs, ys):
        x_fit, y_fit = x_i[:num_ic_examples], y_i[:num_ic_examples]
        if x_fit.shape[0] == 0:
            # No examples to fit: the minimum-norm solution is the zero vector.
            w_hat = x_i.new_zeros(x_i.shape[-1])
        else:
            w_hat = torch.linalg.pinv(x_fit) @ y_fit
        predictions.append(x_i @ w_hat)

    sq_error = (ys.cpu() - torch.stack(predictions).cpu()).square()
    # Compare against None explicitly: position 0 is meaningful but falsy.
    if num_ic_examples is not None:
        return sq_error[:, num_ic_examples].mean().item()
    return sq_error.mean(dim=0)


# Numerically more stable variant of get_linear_regression_error: each system is
# solved with torch.linalg.lstsq (a least-squares solver) instead of an explicit
# pseudo-inverse.
def get_stable_linear_regression_error(xs, ys, num_ic_examples=-1):
    """Return the least-squares baseline error and the worst condition number.

    Each prompt in the batch is treated as its own linear regression problem.
    It is fit with ``torch.linalg.lstsq`` on ``xs[i][:num_ic_examples]`` and
    then used to predict every point of that prompt.

    Args:
        xs: Prompt inputs (presumably (batch, n_points, n_dims) -- confirm).
        ys: Prompt targets (presumably (batch, n_points) -- confirm).
        num_ic_examples: Slice bound for the fit; the error is then read at
            that position.
            - The default ``-1`` fits on all but the last point and scores
              the last point.
            - ``0`` fits on nothing, giving zero weights.
            - None fits on every point and returns per-position errors.

    Returns:
        A tuple ``(error, worst_condition_number)``.
        - ``error`` is a Python float when ``num_ic_examples`` is not None.
          Otherwise, it is a CPU tensor of batch-mean squared errors, one per
          position.
        - ``worst_condition_number`` is the largest 2-norm condition number
          among the matrices actually passed to the solver, as a Python
          float. It is 0.0 if no prompt had any rows to fit.
    """
    predictions = []
    # Start from a real value: torch.Tensor(1) would be uninitialized memory.
    worst_condition_number = 0.0
    for x_i, y_i in zip(xs, ys):
        x_fit, y_fit = x_i[:num_ic_examples], y_i[:num_ic_examples]
        if x_fit.shape[0] == 0:
            # No examples to fit: the minimum-norm solution is the zero vector.
            w_hat = x_i.new_zeros(x_i.shape[-1])
        else:
            w_hat = torch.linalg.lstsq(x_fit, y_fit).solution
            # Condition the same matrix that was solved, not a fixed slice.
            worst_condition_number = max(worst_condition_number, torch.linalg.cond(x_fit).item())
        predictions.append(x_i @ w_hat)

    sq_error = (ys.cpu() - torch.stack(predictions).cpu()).square()
    # Compare against None explicitly: position 0 is meaningful but falsy.
    if num_ic_examples is not None:
        mean_sq_error = sq_error[:, num_ic_examples].mean().item()
    else:
        mean_sq_error = sq_error.mean(dim=0)

    return mean_sq_error, worst_condition_number


# Settings of the "standard" evaluation built from this run's config.
evaluation_kwargs = build_evals(conf)['standard']
(
    data_name,
    task_name,
    batch_size,
    prompting_strategy,
    n_points,
) = (
    evaluation_kwargs[setting]
    for setting in ('data_name', 'task_name', 'batch_size', 'prompting_strategy', 'n_points')
)

# Samplers for prompt inputs and for task instances. The task sampler reuses
# the training-time num_tasks and task_kwargs from the config.
data_sampler = get_data_sampler(data_name, n_dims)
task_sampler = get_task_sampler(
    task_name,
    n_dims,
    batch_size,
    num_tasks=conf.training.num_tasks,
    **conf.training.task_kwargs,
)
# Prompts are produced with gen_standard; the batch loop below evaluates
# num_eval_examples // batch_size batches.
generating_func = gen_standard
num_eval_examples = 1280

# Noise standard deviations to evaluate. Only the noise-free case is active;
# the wider sweep is kept for reference.
# std_devs_list = [0, 0.5, 1, 2, 4, 8, 16, 32, 64, 128]
std_devs_list = [0]

# Per-noise-level result lists, keyed "<std>_noise" (e.g. "0_noise").
# error_ratio_list additionally records the worst condition number per batch.
model_error_list = {'{}_noise'.format(std): [] for std in std_devs_list}
linear_regression_error_list = {'{}_noise'.format(std): [] for std in std_devs_list}
error_ratio_list = {
    'condition_number': [],
    **{'{}_noise'.format(std): [] for std in std_devs_list},
}

# Evaluate the model and the least-squares baseline on freshly sampled,
# noise-free prompts, one batch at a time.
device = "cuda"
for i in range(num_eval_examples // batch_size):
    xs, xs_p = generating_func(data_sampler, n_points, batch_size)
    # Named sampled_task rather than `task` so the global task-name string
    # (used by valid_row and run_path) is not overwritten by a task object.
    sampled_task = task_sampler()

    if task_name == 'relu_2nn_regression_chainofthought':
        # The chain-of-thought task also returns layer activations, which are
        # passed to the model's predict().
        ys, layer_activations = sampled_task.evaluate(xs)
    else:
        ys = sampled_task.evaluate(xs)
        layer_activations = None

    # Per-position model error (1-D tensor) and the baseline's error at the
    # last point (default num_ic_examples=-1).
    model_error = get_model_error(model, xs, ys, device=device, layer_activations=layer_activations)
    linear_regression_error, worst_condition_number = get_stable_linear_regression_error(xs, ys)

    # Compare like with like: both errors at the final query position, which
    # follows model_error.shape[0] - 1 in-context examples.
    final_model_error = model_error[-1]
    num_final_ic_examples = model_error.shape[0] - 1

    model_error_list['0_noise'].append(model_error)
    linear_regression_error_list['0_noise'].append(linear_regression_error)
    # Tensor division yields inf/nan rather than raising if the model error is 0.
    error_ratio_list['0_noise'].append((linear_regression_error / final_model_error).item())
    error_ratio_list['condition_number'].append(worst_condition_number)

    print("Batch: {} | {} in-context examples | Model Error: {} | Linear Regression Error: {}".format(
        i, num_final_ic_examples, final_model_error.item(), linear_regression_error))

# Average the per-batch error curves of each noise level into one
# per-position curve.
mean_model_error_list = {
    key: sum(batch_errors) / len(batch_errors)
    for key, batch_errors in model_error_list.items()
}

print("Mean error of model with num in-context-examples = {}".format(mean_model_error_list))

# Save the mean noise-free error curve: one row per number of in-context
# examples (the row index equals the position in the curve).
# NOTE(review): the filename is fixed even when `task` is the chain-of-thought
# variant, so results from different tasks overwrite each other -- consider
# deriving it from task_name.
noise_free_curve = mean_model_error_list['0_noise']
results_df = pd.DataFrame({
    'num_ic_examples': list(range(len(noise_free_curve))),
    'test_error': noise_free_curve.tolist(),
})
# to_csv fails if the target directory is missing.
os.makedirs("results", exist_ok=True)
results_df.to_csv("results/relu_2nn_regression_results.csv", index=False)


    # # now let's try with noise
    # for std in std_devs_list:
    #     if std == 0:
    #         # we've already done this
    #         continue
    #     noise = torch.randn_like(ys) * std
    #     noise[:, -1] = 0. # don't add noise for the query sample
    #     noisy_ys = ys + noise
    #     model_error = get_model_error(model, xs, noisy_ys)
    #     linear_regression_error, worst_condition_number = get_stable_linear_regression_error(xs, noisy_ys)
    #     print("Batch: {} | Noise with 0 mean {} std | Model Error: {} | Linear Regression Error: {}".format(i, std, model_error, linear_regression_error))

    #     model_error_list['{}_noise'.format(std)].append(model_error)
    #     linear_regression_error_list['{}_noise'.format(std)].append(linear_regression_error)
    #     error_ratio_list['{}_noise'.format(std)].append(linear_regression_error/model_error)
