import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from torch.utils.checkpoint import checkpoint
from torch.autograd.function import Function

import warnings
import gc
import re
import functools
import tqdm
from contextlib import contextmanager
from dataclasses import asdict
from enum import Enum
from typing import List, Literal, Optional, Union, Any

from transformers.models.opt.modeling_opt import OPTForCausalLM
from transformers.models.llama.modeling_llama import LlamaForCausalLM

from binarization.binarized_modules import BinaryMoS

def get_blocks(model):
    """Return the layer container of a supported causal-LM model.

    The architecture is identified by the model's class *name* (a string
    comparison, not an ``isinstance`` check):

    * ``'LlamaForCausalLM'`` -> ``model.model.layers``
    * ``'OPTForCausalLM'``   -> ``model.model.decoder.layers``

    Args:
        model: Model object whose class name selects the lookup path.

    Returns:
        The ``layers`` attribute found at the architecture-specific path.

    Raises:
        NotImplementedError: If the class name is neither of the two above;
            the exception carries ``type(model)``.
    """
    arch_name = model.__class__.__name__

    if arch_name == 'LlamaForCausalLM':
        return model.model.layers
    if arch_name == 'OPTForCausalLM':
        return model.model.decoder.layers

    raise NotImplementedError(type(model))
            
def replace_with_mos(root_module, args, config):
    """Replace every ``nn.Linear`` under ``root_module`` with a ``BinaryMoS``, in place.

    Each replacement is built as ``BinaryMoS(config, weight, bias)`` from the
    weight and bias of the linear layer it replaces, attached to the same
    parent under the same attribute name, and tagged with a ``global_name``
    equal to ``args.model_id + <dotted module name>``. A line is printed for
    every replaced layer.

    Args:
        root_module: Module tree to modify; must provide ``named_modules()``.
        args: Object exposing ``model_id``, which is concatenated with the
            dotted module name.
        config: Forwarded unchanged as the first argument to ``BinaryMoS``.

    Returns:
        None. ``root_module`` is mutated.

    NOTE(review): if ``root_module`` itself is an ``nn.Linear`` its name is
    the empty string, so the ``setattr`` below targets the root under an
    empty attribute name rather than replacing the root.
    """
    # Snapshot the tree up front: the loop below rewires children, and the
    # snapshot also serves as the name -> parent lookup table.
    modules_by_name = dict(root_module.named_modules())

    for full_name, layer in modules_by_name.items():
        if not isinstance(layer, nn.Linear):
            continue

        # "a.b.c" -> parent "a.b", attribute "c"; a name without a dot
        # yields parent "" (the root) and the name itself as the attribute.
        parent_path, _, attr_name = full_name.rpartition(".")
        parent = modules_by_name[parent_path]

        replacement = BinaryMoS(config, layer.weight, layer.bias)
        setattr(parent, attr_name, replacement)
        print(f"replace layer {full_name} with {replacement}")
        replacement.global_name = args.model_id + full_name

def to_regular_linear_mos(root_module):
    """Replace every ``BinaryMoS`` under ``root_module`` with its regular-linear form, in place.

    For each ``BinaryMoS`` found, the object returned by its
    ``to_regular_linear()`` method is attached to the same parent under the
    same attribute name, and a line is printed for the replaced layer.

    Args:
        root_module: Module tree to modify; must provide ``named_modules()``.

    Returns:
        None. ``root_module`` is mutated.
    """
    # Snapshot the tree up front: the loop below rewires children, and the
    # snapshot also serves as the name -> parent lookup table.
    modules_by_name = dict(root_module.named_modules())

    for full_name, layer in modules_by_name.items():
        if not isinstance(layer, BinaryMoS):
            continue

        # "a.b.c" -> parent "a.b", attribute "c"; a name without a dot
        # yields parent "" (the root) and the name itself as the attribute.
        parent_path, _, attr_name = full_name.rpartition(".")
        parent = modules_by_name[parent_path]

        converted = layer.to_regular_linear()
        setattr(parent, attr_name, converted)
        print(f"replace layer {full_name} with {converted}")
