import pandas as pd
import numpy as np
from sklearn.linear_model import LinearRegression
import scipy
from scipy.optimize import minimize
from scipy.optimize import curve_fit
from itertools import product
from configs import *
from utils import *


# Define the Huber loss function
def custom_huber_loss(y_true, y_pred, delta=1e-3):
    """Return the summed Huber loss between ``y_true`` and ``y_pred``.

    Residuals with magnitude ``<= delta`` contribute ``0.5 * r**2``; larger ones contribute
    ``delta * (|r| - 0.5 * delta)``.
    """
    residual = y_true - y_pred
    magnitude = np.abs(residual)
    quadratic_part = 0.5 * residual**2
    linear_part = delta * (magnitude - 0.5 * delta)
    return np.sum(np.where(magnitude <= delta, quadratic_part, linear_part))

def huber_loss_objective(params, F, losses):
    """Huber loss (delta=1e-3) between ``log(losses)`` and the saturating power-law prediction.

    ``params`` is ``(a, e, alpha)``; the predicted log-loss is
    ``logaddexp(a - alpha * log(F), e)``, i.e. ``log(exp(a) * F**-alpha + exp(e))``.
    """
    log_scale, log_floor, alpha = params
    predicted_log_loss = np.logaddexp(log_scale - alpha * np.log(F), log_floor)
    observed_log_loss = np.log(losses)
    return custom_huber_loss(observed_log_loss, predicted_log_loss, delta=1e-3)


def fetch_flop(df, flop, loss_key='train/loss_smoothed', warmup_remove_factor=1e-12, n_key='params', 
               seq_len=2048, bs_key='bs',
               flop_per_token_key='flops_per_token', flop_tolerance=0.1):
    """Extract, for every run in ``df``, the loss at the compute budget ``flop``.

    Each row of ``df`` describes one run. ``row[loss_key]`` is a pandas Series of losses indexed
    by (presumably) optimizer step; rows also need ``warmup_tokens``, ``seq_len``, ``bs_key``,
    ``flop_per_token_key`` and ``n_key`` entries. Steps are converted to FLOPs as
    ``step * seq_len * row[bs_key] * row[flop_per_token_key]``. The checkpoint closest to
    ``flop`` in log space is located, and runs whose closest checkpoint deviates by more than
    ``flop_tolerance`` (relative) are skipped. When more than one checkpoint is available, the
    loss is interpolated linearly in log-log space over up to ten neighbouring checkpoints.

    Args:
        df: one row per run.
        flop: target compute budget.
        loss_key: column holding the per-step loss Series.
        warmup_remove_factor: checkpoints before ``warmup_remove_factor * warmup_tokens``
            tokens are discarded.
        n_key: column copied to the output ``n``.
        seq_len: sequence length used in the step -> FLOP conversion.
        bs_key: batch-size column.
        flop_per_token_key: FLOPs-per-token column.
        flop_tolerance: maximum relative distance between ``flop`` and the nearest checkpoint.

    Returns:
        DataFrame with columns ``n``, ``t`` (``flop / row[flop_per_token_key]``, or the
        checkpoint's FLOPs divided by it when only one checkpoint exists) and ``loss``.
        It is an empty frame without columns when no run qualifies.
    """
    out = []
    for _, row in df.iterrows():
        if len(row[loss_key]) == 0:
            continue
        loss_vals = row[loss_key].dropna().groupby(level=0).mean().sort_index()

        # Drop checkpoints that fall inside the requested fraction of warmup.
        # NOTE(review): the threshold uses the per-row ``seq_len`` column while the FLOP
        # conversion below uses the ``seq_len`` argument -- confirm these are meant to agree.
        warmup_steps = (warmup_remove_factor * row.warmup_tokens) / row[bs_key] / row.seq_len
        loss_vals = loss_vals[loss_vals.index >= warmup_steps]
        if len(loss_vals) == 0:
            continue

        loss_vals.index = loss_vals.index.astype(float) * seq_len * row[bs_key] * row[flop_per_token_key]
        flop_vals = loss_vals.index

        flop_ind = flop_vals.searchsorted(flop)
        if flop_ind > 0:
            # Pick whichever of the two bracketing checkpoints is closer to ``flop`` in log space
            # (when ``flop`` exceeds all checkpoints only the last one is considered).
            flop_ind += -1 + np.abs(np.log(flop_vals[flop_ind - 1:flop_ind + 1] / flop)).argmin()
        rel_err = np.exp(np.abs(np.log(flop_vals[flop_ind] / flop))) - 1
        if rel_err > flop_tolerance:
            continue

        if len(flop_vals) > 1:
            window = slice(max(0, flop_ind - 5), flop_ind + 5)
            flop_slice = flop_vals[window]
            loss_slice = loss_vals.iloc[window]
            loss_interp = np.exp(np.interp(np.log(flop), np.log(flop_slice), np.log(loss_slice)))
            out.append(dict(n=row[n_key], t=flop / row[flop_per_token_key], loss=loss_interp))
        else:
            out.append(dict(n=row[n_key], t=flop_vals[flop_ind] / row[flop_per_token_key],
                            loss=loss_vals.iloc[flop_ind]))

    return pd.DataFrame(out)


def power_law_fit(df, x, y, weighted=False):
    """Fit ``y = coef * x**exponent`` by linear regression in log-log space.

    Rows with any missing value (in any column) are dropped first. When ``weighted`` is set and
    ``df`` has a ``f'{y}_star_std'`` column, each point is weighted by ``1 / std**2``.

    Args:
        df: data frame holding columns ``x`` and ``y``.
        x: name of the independent column.
        y: dependent column name, or a list/tuple of names fitted independently.
        weighted: whether to use the ``*_star_std`` column as inverse-variance weights.

    Returns:
        dict with ``f'{y}_exponent'``, ``f'{y}_coef'`` and ``f'{y}_r2'`` (R^2 of the log-log fit,
        scored without weights) for every requested ``y``.
    """
    if isinstance(y, (list, tuple)):
        combined = {}
        for target in y:
            combined.update(power_law_fit(df, x, target, weighted=weighted))
        return combined

    complete = df.dropna()
    log_x = np.log(complete[x].values).reshape(-1, 1)
    log_y = np.log(complete[y].values)
    std_column = f'{y}_star_std'
    use_weights = weighted and std_column in df.columns
    weights = 1 / complete[std_column].values ** 2 if use_weights else None

    model = LinearRegression().fit(log_x, log_y, sample_weight=weights)
    return {
        f'{y}_exponent': model.coef_.item(),
        f'{y}_coef': np.exp(model.intercept_),
        f'{y}_r2': model.score(log_x, log_y),
    }

def fit_compute_optimal_power_laws(optimal_pairs, bootstrap_data, bootstrap_num=None):
    """Fit power laws ``coef * flops**exponent`` for ``n``, ``t`` and ``multiplier``.

    Args:
        optimal_pairs: one row per compute budget with ``flops``, point estimates ``n``, ``t``,
            ``multiplier`` and optionally ``*_star_std`` columns (used as weights by
            ``power_law_fit`` when ``weighted=True``).
        bootstrap_data: one row per compute budget with ``flops``, bootstrap draws
            ``n_stars``/``t_stars``/``multiplier_stars`` and the log-std columns
            ``n_star_std``/``t_star_std`` (plus ``multiplier_star_std`` when present). Rows
            containing any missing value are ignored.
        bootstrap_num: number of bootstrap fits. Defaults to the smallest number of draws
            available for any budget/target.

    Returns:
        dict with:
        - ``basic``/``weighted``: fits to ``optimal_pairs``.
        - ``bootstrap``/``bootstrap_weighted``: one fit per bootstrap index.
        - ``bs_median``/``bs_median_weighted``: fits to the per-budget medians of the draws.
        Each fit is the dict returned by ``power_law_fit``.
    """
    targets = ['n', 't', 'multiplier']
    pairs = optimal_pairs.reset_index()
    out = {'basic': power_law_fit(pairs, 'flops', targets),
           'weighted': power_law_fit(pairs, 'flops', targets, weighted=True)}

    std_cols = ['n_star_std', 't_star_std']
    if 'multiplier_star_std' in bootstrap_data.columns:
        # power_law_fit only weights a target whose ``<target>_star_std`` column is present;
        # without this column the weighted multiplier fits below would silently be unweighted,
        # unlike the 'weighted' fit on optimal_pairs.
        std_cols.append('multiplier_star_std')
    bootstrap_samples = bootstrap_data.dropna().set_index('flops')[
        ['n_stars', 't_stars', 'multiplier_stars'] + std_cols].rename(
        columns=lambda x: x.replace('_stars', ''))
    if bootstrap_num is None:
        bootstrap_num = bootstrap_samples[targets].applymap(len).min().min()

    for name, is_weighted in (('bootstrap', False), ('bootstrap_weighted', True)):
        # the lambda is evaluated immediately inside applymap, so binding ``i`` late is safe
        out[name] = [power_law_fit(
            bootstrap_samples.applymap(lambda x: maybe_get_item(x, i)).reset_index(),
            'flops', targets, weighted=is_weighted)
            for i in range(bootstrap_num)]

    bootstrap_medians = bootstrap_samples.applymap(np.median).reset_index()
    out['bs_median'] = power_law_fit(bootstrap_medians, 'flops', targets)
    out['bs_median_weighted'] = power_law_fit(bootstrap_medians, 'flops', targets, weighted=True)
    return out



def get_noise_for_loss(loss, bootstrap_iters, noise_low=0.005, noise_high=0.1, l_threshold_high=6, l_threshold_low=3):
    """Return losses perturbed with loss-dependent Gaussian noise for bootstrapping.

    Only the first ``len(loss) // bootstrap_iters`` entries of ``loss`` are read (looked up as
    ``loss[k]``). This suits input built by stacking the same values ``bootstrap_iters`` times.
    Each value receives ``N(0, 1)`` noise scaled by:
    - ``noise_low`` when the value is ``<= l_threshold_low``;
    - ``noise_high`` when it is ``>= l_threshold_high``;
    - in between, a scale interpolated linearly in log space between the two.

    Returns:
        Flat array of length ``bootstrap_iters * (len(loss) // bootstrap_iters)``, ordered
        replicate by replicate.
    """
    per_replicate = len(loss) // bootstrap_iters
    standard_draws = np.random.normal(0, 1, (bootstrap_iters, per_replicate))

    def scale_for(value):
        if value >= l_threshold_high:
            return noise_high
        if value <= l_threshold_low:
            return noise_low
        log_scale = np.interp(value, [l_threshold_low, l_threshold_high], [np.log(noise_low), np.log(noise_high)])
        return np.exp(log_scale)

    perturbed = np.zeros((bootstrap_iters, per_replicate))
    for column, center in enumerate(loss[k] for k in range(per_replicate)):
        perturbed[:, column] = center + scale_for(center) * standard_draws[:, column]
    return perturbed.flatten()


def vectorized_interp_with_seed_noise(df, n_interp_, bootstrap_iters, seed_noise=None,
                                      min_std_factor=0.33, tok_or_n='n'):
    """Bootstrap the loss minimiser by re-interpolating noise-perturbed copies of ``df``.

    ``df`` (columns ``tok_or_n`` and ``loss``) is replicated ``bootstrap_iters`` times. Each
    copy's losses are perturbed by ``get_noise_for_loss`` (extra keyword arguments taken from
    ``seed_noise``), and each copy is interpolated (Akima, log-log) onto the grid ``n_interp_``.
    Copies whose grid minimum is the first or last grid point are discarded.

    Args:
        df: observed points.
        n_interp_: interpolation grid for ``tok_or_n``.
        bootstrap_iters: number of noisy copies.
        seed_noise: optional dict of keyword arguments for ``get_noise_for_loss``.
        min_std_factor: scales the lower bounds on both returned stds. The bounds are the log
            spacing of the first two grid points, and the smallest adjacent log-loss ratio in
            ``df``.
        tok_or_n: name of the x column.

    Returns:
        ``(None, 0, None, None, None)`` when fewer than ``bootstrap_iters // 2`` copies have an
        interior minimum. Otherwise ``(star_std, None, stars, star_losses, loss_star_std)``,
        where both stds are log-space standard deviations, floored as above and multiplied by
        ``bootstrap_iters / len(stars)``.
    """
    noise_kwargs = {} if seed_noise is None else seed_noise
    last_grid_idx = len(n_interp_) - 1

    replicated = pd.concat([df] * bootstrap_iters).reset_index(drop=True)
    replicated['loss'] = get_noise_for_loss(replicated.loss, bootstrap_iters=bootstrap_iters, **noise_kwargs)
    replicated['batch_id'] = np.repeat(np.arange(bootstrap_iters), len(df))
    replicated.sort_values(by=['batch_id', tok_or_n], inplace=True)

    log_grid = np.log(n_interp_)

    def curve_on_grid(batch):
        spline = scipy.interpolate.Akima1DInterpolator(np.log(batch[tok_or_n]), np.log(batch['loss']))
        return np.exp(spline(log_grid))

    curves = replicated.groupby('batch_id').apply(curve_on_grid)
    argmins = curves.apply(np.argmin)

    # (replicate position, grid index) of every minimum that lies strictly inside the grid
    interior = [(pos, idx) for pos, idx in enumerate(argmins) if idx != 0 and idx != last_grid_idx]
    valid_results = [n_interp_[idx] for _, idx in interior]
    valid_results_loss = [curves[pos][idx] for pos, idx in interior]

    if len(valid_results) < bootstrap_iters // 2:
        return None, 0, None, None, None

    n_star_std_ = np.std(np.log(valid_results))
    grid_floor = min_std_factor * np.log(n_interp_[1] / n_interp_[0])  # assumes a log-uniform grid
    n_star_std_ = max(n_star_std_, grid_floor) * (bootstrap_iters / len(valid_results))

    loss_star_std_ = np.std(np.log(valid_results_loss))
    # NOTE(review): these adjacent log-loss ratios are signed. If ``df`` is ordered along a
    # U-shaped loss curve, the smallest ratio is negative and this floor never binds. Confirm
    # whether an absolute value was intended.
    adjacent_log_ratios = [np.log(df.loss.iloc[k + 1] / df.loss.iloc[k]) for k in range(len(df) - 1)]
    loss_floor = min_std_factor * min(adjacent_log_ratios)
    loss_star_std_ = max(loss_star_std_, loss_floor) * (bootstrap_iters / len(valid_results_loss))
    return n_star_std_, None, valid_results, valid_results_loss, loss_star_std_


def interpolation(df_, interp_num, bootstrap_iters, seed_noise, min_std_factor, interp_num_multiplier, std_method, col):
    """Locate the loss minimum of ``df_`` along ``col`` by log-log Akima interpolation.

    The loss is interpolated on ``interp_num`` geometrically spaced points between the minimum
    and maximum of ``df_[col]``. When ``std_method == 'add_seed_noise'``, the minimiser is also
    bootstrapped with ``vectorized_interp_with_seed_noise``.

    Args:
        df_: observed points with columns ``col`` and ``loss``.
        interp_num: number of grid points.
        bootstrap_iters: number of noisy copies for the bootstrap.
        seed_noise: keyword arguments forwarded to the noise generator (may be ``None``).
        min_std_factor: base factor for the std floors of the bootstrap.
        interp_num_multiplier: multiplies ``min_std_factor`` before the bootstrap (presumably
            because the grid is that many times finer than the data -- confirm with caller).
        std_method: ``'add_seed_noise'`` to bootstrap; anything else skips it.
        col: x column name.

    Returns:
        ``(star_ind, star_std, noised_stars, grid, loss_on_grid, noised_loss, loss_star_std)``.
        Without bootstrapping: ``star_std``, ``noised_loss`` and ``loss_star_std`` are ``None``
        and ``noised_stars`` is an empty list.
    """
    interp_ = np.geomspace(df_[col].min(), df_[col].max(), interp_num)
    df_ = df_.sort_values(col)
    interpolator = scipy.interpolate.Akima1DInterpolator(np.log(df_[col]), np.log(df_.loss))
    loss_interp_ = np.exp(interpolator(np.log(interp_)))
    star_ind_ = loss_interp_.argmin()

    if std_method == 'add_seed_noise':
        star_std_, _, noised_stars_, noised_loss, loss_star_std = vectorized_interp_with_seed_noise(
            df_, interp_, bootstrap_iters, seed_noise, min_std_factor * interp_num_multiplier, tok_or_n=col)
    else:
        # Previously these two were left unbound on this path, raising UnboundLocalError.
        star_std_ = None
        noised_stars_ = []
        noised_loss = None
        loss_star_std = None

    return star_ind_, star_std_, noised_stars_, interp_, loss_interp_, noised_loss, loss_star_std


def _aggregate_by_model_size(df_, groupby_action):
    """Collapse rows sharing a model size ``n`` into one: the lowest-loss row ('min') or the mean ('mean')."""
    if groupby_action == 'min':
        # take the best value of lr, etc., if there are multiple ones - could potentially do better by interpolating here too
        df_ = df_.loc[df_.groupby(['n']).loss.idxmin()]
    elif groupby_action == 'mean':
        df_ = df_.groupby('n').mean()
    else:
        raise ValueError(f'Unknown groupby_action {groupby_action}')
    return df_.reset_index()


def interp_flop(big_df, loss_key, flop_vals=(8e16, 3e17, 6e17, 3e18, 6e18, 1e19), groupby_action='min',
                warmup_remove_factor=1e-12,
                interp_num_multiplier=25,
                n_key='params', n_star_std_method='add_seed_noise', t_star_std_method='add_seed_noise',
                bootstrap_iters=1000,
                min_std_factor=0.33,
                seed_noise=None, flop_tolerance=0.1,
                flop_per_token_key='flops_per_token',
                bs_median_as_obs=True
                ):
    """Build IsoFLOP curves and locate the compute-optimal model size and token count per budget.

    For every budget ``c`` in ``flop_vals``:
    1. The loss of each run at ``c`` is extracted with ``fetch_flop``.
    2. Runs of equal size ``n`` are aggregated according to ``groupby_action``.
    3. The loss is interpolated against ``n`` and against ``t`` via ``interpolation``.
    Budgets with fewer than three distinct sizes are recorded but not interpolated.

    Args:
        big_df: one row per training run (see ``fetch_flop``).
        loss_key: column holding the per-step loss Series.
        flop_vals: compute budgets to evaluate.
        groupby_action: ``'min'`` or ``'mean'`` aggregation over runs of equal size.
        interp_num_multiplier: interpolation grid has ``(num_sizes - 1) * interp_num_multiplier``
            points.
        n_star_std_method / t_star_std_method: ``'add_seed_noise'`` to bootstrap the optimum.
        seed_noise: keyword arguments for the bootstrap noise generator.
        bs_median_as_obs: replace the point estimates ``n``, ``t``, ``multiplier`` and ``loss``
            of each valid budget by the median of the bootstrap draws, where available.
        Remaining arguments are forwarded to ``fetch_flop``, ``interpolation`` or the bootstrap.

    Returns:
        ``(out_df, optimal_pairs_df, max_loss, min_loss)``:
        - ``out_df``: one row per budget with interpolation grids and bootstrap results.
        - ``optimal_pairs_df``: one row per interpolated budget with the optimum. Its ``n`` is
          ``None`` when the minimum lies on the grid edge or the bootstrap failed.
        - ``max_loss``/``min_loss``: extreme aggregated losses over interpolated budgets.
    """
    out = []
    optimal_pairs = []
    max_loss, min_loss = 0, 1e12

    for c in flop_vals:
        df_ = fetch_flop(big_df, c, loss_key=loss_key, 
                         warmup_remove_factor=warmup_remove_factor, n_key=n_key, 
                         flop_per_token_key=flop_per_token_key,
                         flop_tolerance=flop_tolerance)

        # At least three distinct model sizes are needed for a meaningful interior minimum. Check
        # before aggregating (an empty fetch result has no 'n' column to group by) and again
        # afterwards, because collapsing repeated sizes can leave fewer than three points.
        if len(df_) >= 3:
            df_ = _aggregate_by_model_size(df_, groupby_action)
        if len(df_) < 3:
            out.append(dict(n_interp=None, loss_interp=None, t_interp=None, 
                            loss_interp_tok=None, opt_ind=None, opt_tok_ind=None, flops=c))
            continue

        interp_num = (len(df_) - 1) * interp_num_multiplier

        max_loss, min_loss = max(max_loss, df_.loss.max()), min(min_loss, df_.loss.min())
        
        n_star_ind_, n_star_std_, noised_n_stars_, n_interp_, loss_interp_, noised_loss, loss_star_std = interpolation(
            df_, interp_num, bootstrap_iters, seed_noise, min_std_factor, interp_num_multiplier, n_star_std_method, 'n')

        # The token-axis pass contributes only its optimum and draws. Its bootstrapped losses are
        # discarded, so 'loss_stars' and 'loss_star_std' both come from the n-axis bootstrap.
        t_star_ind_, t_star_std_, noised_t_stars_, t_interp_, loss_interp_tok_, _, _ = interpolation(
            df_, interp_num, bootstrap_iters, seed_noise, min_std_factor, interp_num_multiplier, t_star_std_method, 't')
        
        if n_star_ind_ != 0 and n_star_ind_ != interp_num - 1 and noised_n_stars_ is not None:
            optimal_pairs.append(
                dict(flops=c, n=n_interp_[n_star_ind_], t=t_interp_[t_star_ind_], multiplier=c / 6 / (n_interp_[n_star_ind_]**2),
                     loss=loss_interp_.min(), loss_t=loss_interp_tok_.min(),
                     n_vals=df_.n.values, t_vals=df_.t.values, loss_vals=df_.loss
                    )
            )
        else:
            optimal_pairs.append(
                dict(flops=c, n=None, t=None, loss=None, loss_t=None,
                     n_vals=df_.n.values, t_vals=df_.t.values, loss_vals=df_.loss
                    )
            )
        out.append(
            dict(n_interp=n_interp_, loss_interp=loss_interp_, 
                 t_interp=t_interp_, loss_interp_tok=loss_interp_tok_, 
                 opt_ind=n_star_ind_, opt_tok_ind=t_star_ind_, flops=c, 
                 orig_n=df_.n, orig_t=df_.t, orig_loss=df_.loss)
            )
        if n_star_std_method == 'add_seed_noise':
            # perhaps compute std and mean of opt_inds and min_losses
            # in a separate function (that fit the line)
            # for now it's like that for backward compatibility
            out[-1]['n_star_std'] = n_star_std_
            out[-1]['n_stars'] = noised_n_stars_
            out[-1]['multiplier_stars'] = (c / (6 * np.array(noised_n_stars_)**2)) if noised_n_stars_ is not None else None
            optimal_pairs[-1]['n_star_std'] = n_star_std_ 

            # multiplier ~ 1/n**2, so its log-std is twice that of n
            out[-1]['multiplier_star_std'] = 2 * n_star_std_ if n_star_std_ is not None else None
            optimal_pairs[-1]['multiplier_star_std'] = 2 * n_star_std_ if n_star_std_ is not None else None

            out[-1]['t_star_std'] = t_star_std_
            out[-1]['t_stars'] = noised_t_stars_
            optimal_pairs[-1]['t_star_std'] = t_star_std_

            out[-1]['loss_stars'] = noised_loss
            out[-1]['loss_star_std'] = loss_star_std 
            optimal_pairs[-1]['loss_star_std'] = loss_star_std

    out_df = pd.DataFrame(out)
    optimal_pairs_df = pd.DataFrame(optimal_pairs)

    if bs_median_as_obs and not optimal_pairs_df.empty:
        out_by_flops = out_df.set_index('flops')
        for ind, row in optimal_pairs_df.iterrows():
            if row['n'] is None or np.isnan(row['n']):
                continue
            data_row = out_by_flops.loc[row['flops']]
            for key in ['n', 't', 'multiplier', 'loss']:
                # fall back to the point estimate when no bootstrap draws were stored
                stars = data_row.get(key + '_stars')
                optimal_pairs_df.at[ind, key] = np.median(stars) if stars is not None else row[key]

    return out_df, optimal_pairs_df, max_loss, min_loss


def fit_loss_with_saturation(flops, loss):
    """Fit ``loss ~ A * flops**-alpha + E`` by a multi-start least-squares fit in log space.

    The model is ``log(loss) = logaddexp(a - alpha * log(F), e)``, with ``A = exp(a)`` and
    ``E = exp(e)``. ``scipy.optimize.curve_fit`` is started from every point of a small grid
    of ``(a, e, alpha)``. Among the starts that converge, the fit with the smallest Huber loss
    (``huber_loss_objective``) is kept.

    Args:
        flops: compute values ``F``.
        loss: losses aligned with ``flops``, or a list/tuple of such sequences, each fitted
            separately.

    Returns:
        ``{'A': ..., 'E': ..., 'alpha': ...}`` for the best fit, or ``None`` if no start
        converged. A list of such results is returned when ``loss`` is a list/tuple.
    """
    if isinstance(loss, (list, tuple)):
        return [fit_loss_with_saturation(flops, ll) for ll in loss]

    def model_func(F, a, e, alpha):
        return np.logaddexp(a - alpha * np.log(F), e)

    alpha_vals = np.arange(0, 0.4, 0.1)
    e_vals = np.arange(-1, 1.5, 0.5)
    a_vals = np.arange(0, 30, 5)
    log_loss = np.log(loss)
    best_loss = np.inf
    best_params = None
    for alpha, e, a in product(alpha_vals, e_vals, a_vals):
        try:
            popt, _ = curve_fit(model_func, flops, log_loss, p0=[a, e, alpha], method='trf',
                                ftol=1e-6, xtol=1e-6, max_nfev=100)
        except RuntimeError:
            # this start did not converge within max_nfev; try the next one
            continue
        result_loss = huber_loss_objective(popt, flops, loss)
        if result_loss < best_loss:
            best_loss = result_loss
            best_params = popt

    if best_params is None:
        return None
    return {'A': np.exp(best_params[0]), 'E': np.exp(best_params[1]), 'alpha': best_params[2]}

def predict_and_estimate_cost(df, predict_targets, confidence_level=0.05, anytime=True,
                              max_models_per_flop=50, max_excess_loss=1.0, max_multiplier=100, base_flop_vals=None,
                              **predict_args):
    """Extrapolate the compute-optimal model size and tally the compute spent to do so.

    For every prefix (of length >= 3) of ``base_flop_vals``:
    - Run ``interp_flop`` and ``fit_compute_optimal_power_laws``.
    - Predict ``n`` at each target compute as ``n_coef * target**n_exponent``, using the
      ``bs_median_weighted`` fit.
    - Compute a ``1 - confidence_level`` interval from the ``bootstrap_weighted`` fits.
    - Estimate the cost as the total FLOPs of the models considered. A model counts when its
      excess loss is below ``max_excess_loss`` and its multiplier is ``<= max_multiplier``, with
      at most ``max_models_per_flop`` models per budget. With ``anytime``, each model size is
      counted once, at its largest budget.

    Returns:
        DataFrame indexed by ``cost`` with one row per prefix and alphabetically sorted columns.
    """
    if base_flop_vals is None:
        base_flop_vals = FLOP_VALS

    def extract_single_prediction(fit_dict):
        coef, exponent = fit_dict['n_coef'], fit_dict['n_exponent']
        prediction = dict(exponent=exponent)
        for target in predict_targets:
            prediction[f'prediction_at_{target:.3g}'] = coef * (target ** exponent)
        return prediction

    half_alpha = confidence_level / 2
    rows = []
    for prefix_len in range(3, len(base_flop_vals) + 1):
        data, optimal_pairs, _, _ = interp_flop(
            df, flop_vals=base_flop_vals[:prefix_len], **predict_args)
        fit_results = fit_compute_optimal_power_laws(optimal_pairs, data)

        point_prediction = extract_single_prediction(fit_results['bs_median_weighted'])
        bs_predictions = pd.DataFrame([extract_single_prediction(fit) for fit in fit_results['bootstrap_weighted']])
        interval = bs_predictions.quantile([half_alpha, 1 - half_alpha])
        interval.index = ['lo', 'hi']
        interval_dict = {f'{k}_{q}': v for (q, k), v in interval.stack().items()}

        # Collect the (budget, model) pairs that would have been trained.
        kept = []
        exploded = optimal_pairs.dropna().explode(['n_vals', 'loss_vals'])
        for flop, per_flop in exploded.groupby('flops'):
            per_flop = per_flop.copy()
            per_flop['excess_loss'] = per_flop['loss_vals'] - per_flop['loss_vals'].min()
            per_flop['multiplier'] = flop / (6 * per_flop.n_vals ** 2)
            ranked = per_flop.sort_values('excess_loss')
            eligible = ranked.query('excess_loss < @max_excess_loss & multiplier <= @max_multiplier')
            kept.append(eligible.iloc[:max_models_per_flop][['flops', 'n_vals', 'loss_vals']])
        kept_df = pd.concat(kept, axis=0, ignore_index=True)

        spent = kept_df.groupby('n_vals').flops.max() if anytime else kept_df.flops
        cost = spent.sum()

        summary = dict(max_flop=optimal_pairs.dropna().flops.max(),
                       cost=cost, optimal_pairs=optimal_pairs, bs_data=data,
                       bs_predictions=bs_predictions)
        rows.append(summary | point_prediction | interval_dict)

    return pd.DataFrame(rows).set_index('cost').sort_index(axis=1)


def perform_varying_compute_analysis(df, predict_targets,
                                     config_compute, confidence_level=0.05,
                                     flop_vals=None, seed=42):
    """Run ``predict_and_estimate_cost`` for one configuration and wrap the result in one row.

    Args:
        df: results frame with ``dataset``, ``hparams``, ``warmup`` and ``decay`` columns.
        predict_targets: compute budgets at which to predict the optimal model size.
        config_compute: ``(dataset, hparams, warmup, decay, param_count, val)``. Its last two
            entries select the ``ISOFLOP_ARGS`` entry.
        confidence_level: two-sided level of the bootstrap interval. It is forwarded to
            ``predict_and_estimate_cost`` and recorded in the output.
        flop_vals: budgets to use (defaults to ``FLOP_VALS``).
        seed: seed for numpy's global RNG.

    Returns:
        Single-row DataFrame with the configuration, the inputs and ``results_df``.
    """
    np.random.seed(seed)

    if flop_vals is None:
        flop_vals = FLOP_VALS
    dataset, hparams, warmup, decay, param_count, val = config_compute
    show_df = df.query("dataset==@dataset and hparams==@hparams and warmup==@warmup and decay==@decay")

    # NOTE(review): unlike perform_main_analysis, no explicit ``seed_noise`` is passed here, so
    # the bootstrap uses the noise generator's defaults -- confirm this is intended.
    df_compute = predict_and_estimate_cost(
        show_df, predict_targets, confidence_level=confidence_level, base_flop_vals=flop_vals,
        **ISOFLOP_ARGS[config_compute[-2:]]
    )
    return pd.DataFrame([dict(dataset=dataset, hparams=hparams, warmup=warmup, decay=decay, param_count=param_count, val=val,
                              base_flop_vals=flop_vals, predict_targets=predict_targets, confidence_level=confidence_level, results_df=df_compute)])


def perform_main_analysis(results_df, configs,
                          flop_vals=None,
                          seed=42, seed_noise_args=None,
                          ):
    """Build IsoFLOP optima and compute-optimal power-law fits for each configuration.

    Args:
        results_df: results frame with ``dataset``, ``hparams``, ``warmup`` and ``decay`` columns.
        configs: iterable of ``(dataset, hparams, warmup, decay, param_count, val)`` tuples.
            The last two entries select the ``ISOFLOP_ARGS`` entry.
        flop_vals: budgets to use (defaults to ``FLOP_VALS``).
        seed: seed for numpy's global RNG.
        seed_noise_args: mapping from config to the bootstrap noise arguments (defaults to
            ``SEED_ARGS``).

    Returns:
        DataFrame with one row per configuration that has data: the configuration fields,
        ``optimal_pairs``, ``fit_results``, ``data``, ``max_loss`` and ``min_loss``.
    """
    np.random.seed(seed)

    if flop_vals is None:
        flop_vals = FLOP_VALS
    if seed_noise_args is None:
        seed_noise_args = SEED_ARGS
    out = []
    for config in configs:
        dataset, hparams, warmup, decay, param_count, val = config
        # Match the string form of each config value (as the previously quoted query did), but
        # pass it as a variable so values containing quotes cannot break the expression.
        dataset_q, hparams_q, warmup_q, decay_q = (str(v) for v in (dataset, hparams, warmup, decay))
        show_df = results_df.query(
            "dataset == @dataset_q and hparams == @hparams_q and warmup == @warmup_q and decay == @decay_q")

        if len(show_df) == 0:
            continue
        data, optimal_pairs, max_loss, min_loss = interp_flop(
            show_df, seed_noise=seed_noise_args[config],
            flop_vals=flop_vals, **ISOFLOP_ARGS[config[-2:]]
        )

        fit_results = fit_compute_optimal_power_laws(optimal_pairs, data)

        out.append(dict(dataset=dataset, hparams=hparams, warmup=warmup, decay=decay, param_count=param_count, val=val,
                        optimal_pairs=optimal_pairs, fit_results=fit_results,
                        data=data, max_loss=max_loss, min_loss=min_loss))
    return pd.DataFrame(out)


# For hparams sweep
def minimize_with_interp(df, x_key='lr', y_key='loss', interp_num=100, groupby_action='min', interpolator=scipy.interpolate.Akima1DInterpolator):
    """Find the ``x_key`` value minimising ``y_key`` by interpolation in log-log space.

    Rows are first reduced to one per ``x_key`` value: the lowest ``y_key`` row (``'min'``) or
    the mean (``'mean'``). ``log(y_key)`` is then interpolated against ``log(x_key)`` on
    ``interp_num`` evenly spaced log-x points, and the grid minimiser is taken.

    Returns:
        One-row DataFrame indexed by ``x_key`` containing:
        - ``on_edge``: ``2`` when the minimiser lies below the second-smallest or above the
          second-largest sampled x, else ``1``.
        - Every other column of the reduced ``df`` (except ``'index'``), interpolated in log
          space at the minimiser.
        With fewer than two distinct x values, the row has NaN x / ``y_key`` and
        ``on_edge=True``. It is still indexed by ``x_key`` so results concatenate consistently.

    Raises:
        ValueError: for an unknown ``groupby_action``.
    """
    df = df.copy().reset_index()
    if groupby_action == 'min':
        df = df.loc[df.groupby([x_key])[y_key].idxmin()]  # take the best value of lr, etc., if there are multiple ones - could potentially do better by interpolating here too
        df = df.set_index(x_key)
    elif groupby_action == 'mean':
        df = df.groupby(x_key).mean()
    else:
        raise ValueError(f'Unknown groupby_action {groupby_action}')
    df = df.sort_index()

    if len(df) < 2:
        return pd.DataFrame({x_key: [np.nan], y_key: [np.nan], 'on_edge': True}).set_index(x_key)

    xlog, ylog = np.log(df.index.values), np.log(df[y_key].values)
    interp = interpolator(xlog, ylog)

    xlog_i = np.linspace(xlog.min(), xlog.max(), interp_num)
    ylog_i = interp(xlog_i)

    x_i, y_i = np.exp(xlog_i), np.exp(ylog_i)

    best = y_i.argmin()
    argmin_xlog_i = xlog_i[best]
    argmin_x_i = x_i[best]
    on_edge = argmin_x_i < np.exp(xlog[1]) or argmin_x_i > np.exp(xlog[-2])
    out = {x_key: [argmin_x_i], 'on_edge': int(on_edge) + 1}

    for key in df.columns:
        if key == 'index':
            continue
        out[key] = [np.exp(interpolator(xlog, np.log(df[key].values))(argmin_xlog_i))]
    return pd.DataFrame(out).set_index(x_key)


# For hparams sweep
def create_pivot_df(df, loss_col='loss'):
    """Pivot a sweep into one row per ``(lr, bs, beta2)`` and one loss column per model size.

    Float model sizes become columns named ``final_loss_smoothed_<params:.2e>``. The first loss
    value is kept for duplicate entries, and rows are ordered by ``bs`` and then ``lr``.
    """
    wide = df.pivot_table(index=['lr', 'bs', 'beta2'], columns='params', values=loss_col, aggfunc='first')
    wide = wide.reset_index()

    def column_label(col):
        if isinstance(col, float):
            return f'final_loss_smoothed_{col:.2e}'
        return col

    wide.columns = [column_label(col) for col in wide.columns]
    return wide.sort_values(by=['bs', 'lr']).reset_index(drop=True)


# For hparams sweep
def get_interpolated_hparams_dfs(df_sweep, min_params_for_fit=2.5e7, max_params_for_fit=1.1e8):
    """Interpolate the optimal learning rate and batch size per model size and fit power laws.

    For each ``(params, bs)`` group, the loss-minimising ``lr`` is found with
    ``minimize_with_interp``. Then, per ``params``, the loss-minimising ``bs`` among those
    optima is found. Power laws ``bs(params)`` and ``lr(params)`` are fitted on model sizes
    strictly between ``min_params_for_fit`` and ``max_params_for_fit``.

    Returns:
        ``(optima_per_size, {'bs': fit, 'lr': fit})``, each fit as returned by ``power_law_fit``.
    """
    best_lr_per_bs = (df_sweep
                      .drop('excess_loss', axis=1)
                      .groupby(['params', 'bs'])
                      .apply(minimize_with_interp)
                      .drop(['bs', 'params'], axis=1)
                      .reset_index())
    best_lr_and_bs = (best_lr_per_bs
                      .groupby(['params'])
                      .apply(lambda group: minimize_with_interp(group, x_key='bs'))
                      .drop('params', axis=1)
                      .reset_index())

    query_str = f"params > {min_params_for_fit} and params < {max_params_for_fit}"
    fit_range = best_lr_and_bs.query(query_str).reset_index()
    fit = {'bs': power_law_fit(fit_range.copy(), 'params', 'bs'),
           'lr': power_law_fit(fit_range.copy(), 'params', 'lr')}
    return best_lr_and_bs, fit