from imports import *
from utils import load_json
from torch.utils.data import DataLoader, Dataset

from tango import step
import time
import tango
from tango.common import FromParams
import termplotlib as tpl
from weights_composer import remove_components, get_ov
import sys
from ioi_inhibition_exp import (
    PromptDataset, 
    load_dataset, 
    load_model, 
    DataParams, 
    #ModelParams, 
    calc_inhib_score,
    get_inhibition_scores
)
from my_plotly import *
import plotly.graph_objects as go

class WeightResetter:
    """Snapshot one attention head's ``W_V``/``W_O`` slices and restore them later.

    At construction time, clones of ``model.blocks[inhib_layer].attn.W_V[inhib_head]``
    and ``...W_O[inhib_head]`` are stored in ``head_weights`` under the key
    ``(inhib_layer, inhib_head)``. ``reset_inhib_weights`` writes those clones
    back into the same slices of the model.
    """

    def __init__(self, model, inhib_layer, inhib_head):
        # NOTE: this is kind of hacky -- only the single (layer, head) pair
        # passed in is tracked; no other heads are snapshotted.
        self.model = model
        self.head_weights = {}
        attn = self.model.blocks[inhib_layer].attn
        self.head_weights[(inhib_layer, inhib_head)] = {
            'V': attn.W_V[inhib_head].clone(),
            'O': attn.W_O[inhib_head].clone(),
        }

    def reset_inhib_weights(self):
        """Write every saved slice back into the model and return the model."""
        for (layer, head), saved in self.head_weights.items():
            attn = self.model.blocks[layer].attn
            # Clone again so the stored snapshot never shares memory with the
            # value handed to the assignment.
            attn.W_V[head] = saved['V'].clone()
            attn.W_O[head] = saved['O'].clone()
        return self.model


@step(cacheable=True, deterministic=True, version='003')
def rm_ov_component( 
    model: FromParams,
    dataset: DataParams,
    inhib_layer: int,
    inhib_head:int,
    comp_idx,
    mover_layer: int, 
    mover_head: int ) -> np.array:
    """Rewrite one head's OV weights from its SVD and score every prompt.

    The OV matrix of head ``(inhib_layer, inhib_head)`` is obtained via
    ``get_ov``, decomposed with ``.svd()``, and handed to
    ``remove_components`` together with ``comp_idx``. The ``A`` and ``B``
    attributes of the result are written into the head's ``W_V`` and ``W_O``
    slices. Each prompt of each batch is then scored with
    ``calc_inhib_score`` against head ``(mover_layer, mover_head)``.

    Args:
        model: wrapper whose ``.model`` attribute is the transformer to edit.
        dataset: wrapper whose ``.dataset`` attribute is an iterable of
            batches; each batch is indexed by key and must contain ``'text'``.
        inhib_layer: layer index of the head whose weights are replaced.
        inhib_head: head index of the head whose weights are replaced.
        comp_idx: component indices forwarded to ``remove_components``;
            ``None`` is treated as an empty list.
        mover_layer: layer index forwarded to ``calc_inhib_score``.
        mover_head: head index forwarded to ``calc_inhib_score``.

    Returns:
        A 1-D NumPy array with one score per prompt, in iteration order.

    Note:
        The model's weights are modified in place and are NOT restored here;
        callers that reuse the model must reset them themselves.
    """
    hooked_model = model.model
    batches = dataset.dataset

    # Identity check, not ``== None``: with an array-like ``comp_idx`` the
    # equality would be evaluated element-wise and the ``if`` would be
    # ambiguous (or silently wrong).
    if comp_idx is None:
        comp_idx = []

    ov = get_ov(hooked_model, inhib_layer, inhib_head)
    u, s, v = ov.svd()
    comp = remove_components(u, s, v, comp_idx)

    # Overwrite the head's value/output projections with the edited factors.
    attn = hooked_model.blocks[inhib_layer].attn
    attn.W_V[inhib_head] = comp.A
    attn.W_O[inhib_head] = comp.B

    def get_prompt(prompts, idx):
        # Pick example ``idx`` out of every field of the batch.
        return {key: prompts[key][idx] for key in prompts}

    inhib_scores = []
    for batch in batches:
        text = batch['text']
        # One forward pass per batch; the cache is sliced per example below.
        _, cache = hooked_model.run_with_cache(text)
        for batch_idx in range(len(text)):
            cur_prompt = get_prompt(batch, batch_idx)
            score = calc_inhib_score(
                hooked_model,
                cur_prompt,
                cache.apply_slice_to_batch_dim(batch_idx),
                mover_layer,
                mover_head,
            )
            inhib_scores.append(score.item())

    return np.array(inhib_scores)

def plot_scores(scores, value_range=(-1.0, 1.0), bins=11):
    """Print a horizontal terminal histogram of ``scores``.

    Args:
        scores: array-like of numeric scores.
        value_range: ``(low, high)`` interval passed to ``np.histogram`` as
            ``range``; values outside it are not counted. Defaults to the
            interval ``(-1.0, 1.0)``.
        bins: number of histogram bins. Defaults to 11.

    Returns:
        None. The histogram is rendered with termplotlib via ``fig.show()``.
    """
    counts, bin_edges = np.histogram(scores, range=value_range, bins=bins)
    fig = tpl.figure()
    fig.hist(counts, bin_edges, orientation="horizontal", force_ascii=False)
    fig.show()

def show_plot(scores, vanilla_score, return_fig=True):
    """Build a plotly figure of per-row mean scores against a reference line.

    Args:
        scores: 2-D NumPy array. Each row is one x position (the caller passes
            one row per removed component); the mean and spread are taken
            along axis 1.
        vanilla_score: y value of the dotted horizontal reference line.
        return_fig: when true, return the figure; otherwise display it with
            ``fig.show()`` and return ``None``.

    Returns:
        The ``go.Figure`` when ``return_fig`` is true, else ``None``.
    """
    n_rows = len(scores)
    x = np.arange(n_rows)
    print(scores.shape)
    mean = scores.mean(axis=1)
    rmcomp_line = go.Scatter(x=x, y=mean, mode='lines+markers', name='Remove Component')
    # The reference line spans the plotted rows instead of a hard-coded 64,
    # so it stays correct for any number of components (64 rows -> [0, 64]).
    vanilla_line = go.Scatter(
        x=[0, n_rows],
        y=[vanilla_score, vanilla_score],
        mode='lines',
        name='Vanilla',
        line={"color": "blue", "dash": 'dot'},
    )
    # Reference value once observed: 0.7010605 vanilla score for gpt2 9 9.

    # Compute the per-row standard deviation once and reuse it.
    std = scores.std(axis=1)
    print(std.shape)
    # Band half-width: std / sqrt(number of columns), i.e. the standard error.
    sem = std / np.sqrt(scores.shape[1])
    y_upper = mean + sem
    y_lower = mean - sem

    shading = go.Scatter(
        x=x.tolist() + x[::-1].tolist(),  # x, then x reversed
        y=y_upper.tolist() + y_lower[::-1].tolist(),  # upper, then lower reversed
        fill='toself',
        fillcolor='rgba(00,100,80,0.2)',
        line=dict(color='rgba(255,255,255,0)'),
        hoverinfo="skip",
        showlegend=False,
    )

    fig = go.Figure(data=[vanilla_line, shading, rmcomp_line])
    if return_fig:
        return fig
    fig.show()


import tango
from tango.common import FromParams
from weights_composer import remove_components
from tango.common.det_hash import CustomDetHash

def load_model(model_name):
    """Return ``HookedTransformer.from_pretrained(model_name, ...)``.

    The model is loaded with ``fold_value_biases=True`` and
    ``refactor_factored_attn_matrices=False``. This module-level definition
    replaces the ``load_model`` name imported from ``ioi_inhibition_exp``
    earlier in the file.
    """
    loader_options = {
        "fold_value_biases": True,
        "refactor_factored_attn_matrices": False,
    }
    return HookedTransformer.from_pretrained(model_name, **loader_options)

class ModelParams(CustomDetHash):
    """Holder for a model name and (optionally) the loaded model.

    The deterministic hash of an instance is derived from ``model_name`` only
    (see ``det_hash_object``), so the loaded weights do not take part in it.
    """

    CACHEABLE = True
    DETERMINISTIC = True
    # The constant was originally spelled ``DETERIMINISTIC``; the old name is
    # kept as an alias so any existing reference keeps working.
    DETERIMINISTIC = DETERMINISTIC

    def __init__(self, model_name: str, should_load=True):
        """Store ``model_name`` and load the model unless ``should_load`` is false."""
        self.model_name = model_name
        # Stays ``None`` until ``load()`` is called when ``should_load`` is false.
        self.model = None
        if should_load:
            self.load()

    def load(self):
        """(Re)load the model named ``self.model_name`` into ``self.model``."""
        self.model = load_model(self.model_name)

    def det_hash_object(self):
        """Return the string form of ``model_name`` as the object to hash."""
        return str(self.model_name)

if __name__ == "__main__":
    # Usage: <script> INHIB_LAYER INHIB_HEAD MOVER_LAYER MOVER_HEAD
    inhib_layer = int(sys.argv[1])
    inhib_head = int(sys.argv[2])
    mover_layer = int(sys.argv[3])
    mover_head = int(sys.argv[4])
    print("HEADS", inhib_layer, inhib_head, mover_layer, mover_head)
    ws = tango.Workspace.from_url("./tango_workspace")

    # Select checkpoint and dataset together. Other checkpoints previously
    # used here: 'gpt2-small', 'EleutherAI/pythia-1b'.
    mode = 'opt-125m'
    if mode == 'pythia':
        model_name = 'EleutherAI/pythia-160m'
        dataset_path = 'datasets/pythia_ioi_dataset_200.json'
    else:
        model_name = 'opt-125m'
        dataset_path = 'datasets/ioi_dataset_200.json'

    model_params = ModelParams(model_name)
    # Snapshot the head's weights before any step overwrites them.
    weight_resetter = WeightResetter(model_params.model, inhib_layer, inhib_head)
    data_params = DataParams(dataset_path, batch_size=25)

    def timed_scores(comp_idx):
        """Run ``rm_ov_component`` for ``comp_idx``, print the elapsed time, return scores."""
        started = time.time()
        result = rm_ov_component(
            model=model_params,
            dataset=data_params,
            inhib_layer=inhib_layer,
            inhib_head=inhib_head,
            comp_idx=comp_idx,
            mover_layer=mover_layer,
            mover_head=mover_head,
        ).result(ws)
        finished = time.time()
        print("TIME", finished - started)
        return result

    # Baseline run: no component indices are passed.
    scores = timed_scores(None)
    plot_scores(scores)
    print(scores.mean(), 'VANILLA MEAN')
    vanilla_score = scores.mean()

    # One run per component index, restoring the saved weights before each.
    all_scores = []
    d_head = model_params.model.cfg.d_head
    for component in range(d_head):
        print("INDEX:", component, flush=True)
        model_params.model = weight_resetter.reset_inhib_weights()
        scores = timed_scores([component])
        all_scores.append(scores)
        plot_scores(scores)
        print(scores.mean(), 'MEAN', flush=True)

    stacked_scores = np.array(all_scores)
    print('all scores', stacked_scores.shape)
    # Saving the raw array is currently disabled:
    # np.save(f'exp_site/results/inhibition_exp_rm_comps/{inhib_layer}_{inhib_head}_{mover_layer}_{mover_head}_rm_one.npy', np.array(all_scores))

    fig = show_plot(stacked_scores, vanilla_score, return_fig=True)
    fig.update_layout(
        title=f'Remove Component from {inhib_layer}.{inhib_head} to {mover_layer}.{mover_head}',
        xaxis_title='Component Index',
        yaxis_title='Inhibition Score',
    )
    # Keep only the part after the last '/' (no-op when there is no '/').
    mname = model_name.split('/')[-1]
    fig_to_json(fig, f'exp_site/results/inhibition_exp_rm_comps/{mname}_{inhib_layer}_{inhib_head}_{mover_layer}_{mover_head}_rm_inh_scores.json')

    # Final run with an explicit empty index list on freshly restored weights.
    model_params.model = weight_resetter.reset_inhib_weights()
    scores = timed_scores([])
    plot_scores(scores)
    print(scores.mean(), 'MEAN')

