import torch
import math
import re
import os
PREFIX_CHECKPOINT_DIR = "checkpoint"
# Full-name match for directories such as "checkpoint-500"; group 1 captures the trailing digits.
_re_checkpoint = re.compile(rf"^{PREFIX_CHECKPOINT_DIR}\-(\d+)$")

def get_last_checkpoint_trainerstate_robust(folder):
    """Return the path of the highest-numbered complete checkpoint in *folder*.

    An entry of *folder* qualifies when its name matches ``_re_checkpoint``
    (``checkpoint-<digits>``), it is a directory, and it contains a
    ``trainer_state.json`` file. Among qualifying entries the one with the
    largest integer suffix wins; on equal suffixes the entry listed first by
    ``os.listdir`` is kept.

    Returns:
        ``os.path.join(folder, <name>)`` for the winning entry, or ``None``
        when nothing qualifies.

    Raises:
        OSError: propagated from ``os.listdir`` (e.g. *folder* does not exist).
    """
    best_name = None
    best_step = None
    for name in os.listdir(folder):
        match = _re_checkpoint.search(name)
        if match is None:
            continue
        candidate_dir = os.path.join(folder, name)
        if not os.path.isdir(candidate_dir):
            continue
        # Only checkpoints that have a trainer state file count as complete.
        if not os.path.exists(os.path.join(candidate_dir, "trainer_state.json")):
            continue
        step = int(match.group(1))
        # Strict ">" keeps the first entry on ties, matching max()'s behaviour.
        if best_step is None or step > best_step:
            best_name, best_step = name, step
    if best_name is None:
        return None
    return os.path.join(folder, best_name)

def sinusoidal_encoding(max_position, d_model, min_freq=1e-4):
    """Build a table of sinusoidal position encodings.

    Row ``r`` encodes position ``p = r + 1`` and column ``c`` uses the
    1-based feature index ``k = c + 1``. With
    ``angle = p * min_freq ** (2 * k / d_model)``, entry ``[r, c]`` is
    ``sin(angle)`` when ``k`` is even and ``cos(angle)`` when ``k`` is odd.

    As a sanity check, the entries at row ``max_position // 2 - 1`` and
    columns ``d_model // 2 - 1`` / ``d_model // 2`` are compared (via
    ``assert``, so skipped under ``python -O``) with a double-precision
    ``math`` computation. A check whose row or column index would fall
    outside the table is skipped instead of reading a wrapped-around entry.

    Args:
        max_position: number of positions (rows of the table).
        d_model: encoding width (columns of the table).
        min_freq: base of the per-column frequency
            ``min_freq ** (2 * k / d_model)``.

    Returns:
        A float tensor of shape ``(max_position, d_model)``.
    """
    position = torch.arange(max_position).float() + 1
    mask = torch.arange(d_model) + 1
    sin_mask = mask % 2 == 0
    cos_mask = ~sin_mask
    exponent = (2 * mask) / d_model
    freqs = min_freq ** exponent
    angles = position.unsqueeze(1) * freqs.unsqueeze(0)
    pos_enc = torch.cos(angles) * cos_mask + torch.sin(angles) * sin_mask

    test_pos = max_position // 2
    test_index = d_model // 2

    # Row test_pos - 1 holds position test_pos. When max_position < 2,
    # test_pos is 0 and row index -1 would wrap to the last row (or fail on
    # an empty table), so the expected value would not describe the entry
    # read. Column test_index only exists when d_model >= 1.
    if test_pos >= 1 and d_model >= 1:
        # Column test_index - 1 has k = test_index and column test_index has
        # k = test_index + 1; even k uses sin, odd k uses cos.
        if test_index % 2 == 0:
            before_fn, at_fn = math.sin, math.cos
        else:
            before_fn, at_fn = math.cos, math.sin

        # With d_model == 1, test_index is 0 and column -1 would wrap to
        # column 0, which has k = 1 rather than k = 0.
        if test_index >= 1:
            assert torch.allclose(
                pos_enc[test_pos - 1, test_index - 1],
                torch.Tensor(
                    [before_fn(test_pos * min_freq ** ((2 * test_index) / d_model))]
                ),
            )

        assert torch.allclose(
            pos_enc[test_pos - 1, test_index],
            torch.Tensor(
                [at_fn(test_pos * min_freq ** ((2 * test_index + 2) / d_model))]
            ),
        )
    return pos_enc


def random_encoding(max_positions, d_model, norm=1):
    """Return random encodings with a fixed length per row.

    Draws a ``(max_positions, d_model)`` matrix from a standard normal
    distribution (using torch's global RNG), rescales every row to unit L2
    norm, then multiplies by ``norm``.
    """
    samples = torch.randn((max_positions, d_model))
    row_lengths = samples.norm(dim=1, keepdim=True)
    return samples / row_lengths * norm


def random_encoding_fourier(max_positions, d_model):
    """Unfinished stub: samples a Gaussian matrix and discards it.

    Always returns ``None``. The only observable effect is that it draws
    ``max_positions * d_model`` values from torch's global RNG, which shifts
    every subsequent random draw.

    NOTE(review): the sampled matrix is never used or returned; confirm the
    intended output before relying on this function.
    """
    torch.randn(max_positions, d_model)  # result is discarded; see docstring
    return None


# Reference sketch of a Fourier feature mapping (NumPy; not called anywhere in this file):
# def fourier_mapping(x, B):
#     x_proj = (2.*np.pi*x) @ B.T
#     return np.concatenate([np.sin(x_proj), np.cos(x_proj)], axis=-1)


def topk(
    logits,
    gt_classes,
    k_list,
):
    """Compute top-k accuracy for every k in ``k_list``.

    Args:
        logits: 2-D tensor of scores, ``(batch, num_classes)``.
        gt_classes: 1-D tensor of ground-truth class indices, ``(batch,)``.
        k_list: iterable of ints; it is consumed once.

    Returns:
        A list parallel to ``k_list``. Each item is a 0-dim tensor holding the
        fraction of rows whose ground-truth class is among the k
        highest-scoring classes. A k larger than ``num_classes`` considers
        every class. An empty ``k_list`` yields ``[]``.
    """
    assert len(logits.shape) == 2
    assert len(gt_classes.shape) == 1
    # Materialize so a one-shot iterator is not exhausted by max() below.
    k_list = list(k_list)
    if not k_list:
        return []
    batch, num_classes = logits.shape
    # torch.topk rejects k > size of the dimension; slicing [:, :k] below
    # already caps each k at the number of available columns.
    max_k = min(max(k_list), num_classes)
    top_labels_max_k = torch.topk(logits, max_k, dim=1)[1]
    hits = top_labels_max_k == gt_classes.unsqueeze(1)
    return [torch.sum(hits[:, :k]) / batch for k in k_list]


def gen_attn_mask(sequence_length, len=None):
    """Build a boolean padding mask from per-sequence lengths.

    Args:
        sequence_length: 1-D tensor of lengths, one per batch element.
        len: width of the mask. Defaults to the largest value in
            ``sequence_length`` (0 for an empty batch). The parameter keeps
            its original name for caller compatibility, which shadows the
            builtin ``len`` inside this function.

    Returns:
        Bool tensor of shape ``(batch, len)`` on ``sequence_length``'s device,
        where ``mask[b, t]`` is ``True`` iff ``t < sequence_length[b]``.
    """
    if len is None:
        # torch.arange(None) would raise, so derive the width from the data.
        len = int(sequence_length.max()) if sequence_length.numel() > 0 else 0
    seq_range = torch.arange(len, device=sequence_length.device)
    # (1, len) compared with (batch, 1) broadcasts to (batch, len).
    return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)


def binary_encoding(max_position, d_model, epsilon=0.3):
    """Build chunked 0/1 position encodings with optional overlap.

    The ``d_model`` columns are split into ``max_position`` contiguous chunks
    of ``chunk_size = d_model // max_position`` columns; the last chunk also
    absorbs the remainder. Row ``pos`` is 1 on its chunk and 0 elsewhere.
    With ``num_intersection = (epsilon / 2) * chunk_size``, every chunk start
    except the first moves left and every chunk end except the last moves
    right by ``num_intersection``. The resulting float boundaries are
    truncated back to integers, so adjacent positions overlap.

    NOTE: when ``d_model < max_position``, ``chunk_size`` is 0, so every row
    but the last is all zeros and the last row is all ones.

    Args:
        max_position: number of positions (rows).
        d_model: encoding width (columns).
        epsilon: overlap fraction in ``[0, 1]``.

    Returns:
        A float tensor of shape ``(max_position, d_model)`` of 0s and 1s.
    """
    assert 0 <= epsilon <= 1, "epsilon value should lie in [0,1]"
    if max_position == 0:
        # d_model // 0 below would raise ZeroDivisionError; there are no rows.
        return torch.zeros(0, d_model)
    chunk_size = d_model // max_position
    start_of_chunks = chunk_size * torch.arange(max_position)
    end_of_chunks = start_of_chunks + chunk_size
    end_of_chunks[-1] = d_model  # last chunk absorbs the leftover columns
    # Widen interior boundaries to account for epsilon. The shifted values
    # are floats and are cast back to the integer dtype, truncating toward
    # zero (boundaries stay non-negative for epsilon <= 1).
    num_intersection = (epsilon / 2) * chunk_size
    start_of_chunks[1:] = (start_of_chunks[1:] - num_intersection).long()
    end_of_chunks[:-1] = (end_of_chunks[:-1] + num_intersection).long()

    # A plain loop is fine here; this table is built once.
    binary_embeds = torch.zeros(max_position, d_model)
    boundaries = zip(start_of_chunks.tolist(), end_of_chunks.tolist())
    for pos, (start, end) in enumerate(boundaries):
        binary_embeds[pos, start:end] = 1
    return binary_embeds

def count_params_hf(model):
    """Return the total element count of all tensors in a name -> tensor mapping.

    NOTE(review): a second ``def count_params_hf`` further down this file
    rebinds the name, so this version is not reachable as
    ``count_params_hf`` once the module has been imported.
    """
    total = 0
    for _, tensor in model.items():
        total += math.prod(tensor.shape)
    return total
    
def count_params_hf(model):
    """Return the total number of scalar elements across a model's parameters.

    Accepts either an object exposing ``named_parameters()`` (e.g. a
    ``torch.nn.Module``) or a plain mapping of name -> tensor (e.g. a loaded
    state dict). Each tensor contributes ``math.prod(tensor.shape)``
    elements (1 for a 0-dim tensor).
    """
    if hasattr(model, "named_parameters"):
        tensors = (param for _, param in model.named_parameters())
    else:
        # Mapping input: count every stored tensor.
        tensors = model.values()
    return sum(math.prod(t.shape) for t in tensors)
