__all__ = ['PowerGPT']

from typing import Union, Tuple
import torch
import torch_geometric
from torch_geometric.typing import (
    Adj,
    OptTensor,
    pyg_lib
)
from copy import deepcopy
# Cell
from .powergpt_components.basics import *
from .powergpt_components.attention import *
from torch_geometric.nn.conv import RGCNConv

from .powergpt_components.pos_encoding import positional_encoding


class RelationGCN(RGCNConv):
    """RGCNConv variant whose node features are ``num_patch`` stacked vectors.

    Every node carries a flat feature vector that ``message`` splits into
    ``num_patch`` consecutive rows of ``in_channels`` features. Each row is
    multiplied by the weight matrix of the edge's relation via
    ``pyg_lib.ops.segment_matmul``, so one relation weight is shared by all
    patches of a message.

    ``root_weight`` and ``bias`` default to False here.
    NOTE(review): the per-patch split only happens on the pyg_lib path (when
    ``edge_type_ptr`` is given). Otherwise ``message`` returns ``x_j``
    untouched, and the parent class presumably applies ``weight[i]`` to the
    whole flat vector. That only lines up when ``num_patch == 1``. Confirm.
    """

    def __init__(
        self,
        in_channels: Union[int, Tuple[int, int]],
        out_channels: int,
        num_relations: int,
        num_patch: int,
        num_bases: Optional[int] = None,
        num_blocks: Optional[int] = None,
        aggr: str = 'mean',
        root_weight: bool = False,
        is_sorted: bool = False,
        bias: bool = False,
        **kwargs,
    ):
        # Keyword arguments so the call does not depend on RGCNConv's positional order.
        super().__init__(
            in_channels,
            out_channels,
            num_relations,
            num_bases=num_bases,
            num_blocks=num_blocks,
            aggr=aggr,
            root_weight=root_weight,
            is_sorted=is_sorted,
            bias=bias,
            **kwargs,
        )
        # Number of patch vectors stacked inside each node's feature vector.
        self.num_patch = num_patch

    def message(self, x_j: Tensor, edge_type_ptr: OptTensor) -> Tensor:
        """Apply the relation weight to every patch of every edge message.

        Args:
            x_j: source-node features per edge, [E x (num_patch * in_channels)].
            edge_type_ptr: CSR-style boundaries of the edges grouped by relation,
                or None (then ``x_j`` is returned unchanged).

        Returns:
            [E x (num_patch * out_channels)] on the segment-matmul path.
        """
        if torch_geometric.typing.WITH_PYG_LIB and edge_type_ptr is not None:
            num_edges = x_j.size(0)
            # One row per (edge, patch); the rows of one edge stay contiguous.
            per_patch = x_j.reshape(num_edges * self.num_patch, -1)
            # Each edge now spans num_patch rows, so the relation boundaries scale too.
            out = pyg_lib.ops.segment_matmul(per_patch, edge_type_ptr * self.num_patch, self.weight)
            # Use -1 (not the input width) so out_channels != in_channels also works.
            return out.reshape(num_edges, -1)
        return x_j

# Cell
class PowerGPT(nn.Module):
    """
    Patch Transformer backbone followed by a stack of relational GCN layers.

    ``forward`` returns the per-node embedding averaged over patches,
    [n_nodes x n_vars x d_model]. A task head (``self.head``) is built from
    ``head_type`` but is currently NOT applied in ``forward``. For reference,
    the heads produce:
         [bs x target_dim x nvars] for prediction
         [bs x target_dim] for regression
         [bs x target_dim] for classification
         [bs x num_patch x n_vars x patch_len] for pretrain / imputation
    """
    def __init__(self,
                 c_in:int,
                 target_dim:int,
                 patch_len:int,
                 context_points:int,
                 stride:int,
                 num_patch:int,
                 mask_ratio:float=0.4,
                 n_layers:int=3,
                 d_model=128,
                 n_heads=16,
                 shared_embedding=True,
                 d_ff:int=256,
                 norm:str='BatchNorm',
                 attn_dropout:float=0.,
                 dropout:float=0.,
                 act:str="swiglu",
                 res_attention:bool=True,
                 pre_norm:bool=True,
                 store_attn:bool=False,
                 pe:str='zeros',
                 learn_pe:bool=True,
                 head_dropout = 0,
                 head_type = "prediction",
                 name = None,
                 individual = False,
                 y_range:Optional[tuple]=None,
                 verbose:bool=False,
                 # Relational GCN fusion
                 fusion_num_relations=8,
                 num_layers=2,
                 **kwargs):
        """
        Args (non-obvious ones):
            c_in: number of variables per node (n_vars).
            num_patch: number of patches per series; passed to the backbone and GCN layers.
            head_type: one of 'pretrain', 'prediction', 'regression',
                'classification', 'imputation'. It also switches the backbone
                and ``forward`` into their pretrain code paths.
            fusion_num_relations: number of edge relation types in the GCN.
            num_layers: number of RelationGCN layers (at least one is always built).
            **kwargs: forwarded to PowerGPTEncoder.
        """
        super().__init__()

        assert head_type in ['pretrain', 'prediction', 'regression', 'classification', 'imputation'], \
            'head type should be one of pretrain, prediction, regression, classification or imputation'
        # Backbone
        self.backbone = PowerGPTEncoder(c_in=c_in,
                                        num_patch=num_patch,
                                        patch_len=patch_len,
                                        n_layers=n_layers,
                                        d_model=d_model,
                                        n_heads=n_heads,
                                        shared_embedding=shared_embedding,
                                        d_ff=d_ff,
                                        attn_dropout=attn_dropout,
                                        dropout=dropout,
                                        act=act,
                                        norm=norm,
                                        res_attention=res_attention,
                                        pre_norm=pre_norm,
                                        store_attn=store_attn,
                                        pe=pe,
                                        learn_pe=learn_pe,
                                        verbose=verbose,
                                        head_type=head_type,
                                        **kwargs)

        # Head
        self.n_vars = c_in
        self.head_type = head_type
        self.name = name
        self.patch_len = patch_len
        self.stride = stride
        self.mask_ratio = mask_ratio
        self.target_dim = target_dim
        self.context_points = context_points
        # Relational GCN stack. The original construction always created at
        # least one layer, so num_layers <= 1 still yields a single layer.
        # NOTE(review): the layers are built with in_channels=d_model, so the
        # fusion presumably requires n_vars == 1. Confirm.
        self.relational_gcn_layers = nn.ModuleList([
            RelationGCN(d_model, d_model, num_relations=fusion_num_relations, num_patch=num_patch)
            for _ in range(max(1, num_layers))
        ])

        if head_type == "pretrain" or head_type == 'imputation':
            self.head = PretrainHead(d_model, patch_len, head_dropout)  # custom head passed as a partial func with all its kwargs
        elif head_type == "prediction":
            self.head = PredictionHead(individual, self.n_vars, d_model, num_patch, target_dim, head_dropout)
        elif head_type == "regression":
            self.head = RegressionHead(self.n_vars, d_model, target_dim, head_dropout, y_range)
        elif head_type == "classification":
            self.head = ClassificationHead(self.n_vars, d_model, target_dim, head_dropout)

    def _encode_pretrain(self, data):
        """Encode target nodes with gradients and neighbour nodes without.

        The first ``data.batch_size_`` rows of the per-node fields are the
        targets and the remaining rows are their neighbours. The fields are
        temporarily replaced by slices so the backbone sees one group at a
        time, and they are restored afterwards, even if the backbone raises.
        Returns the target embeddings followed by the neighbour embeddings.
        """
        saved = {field: getattr(data, field) for field in ('x', 'mask', 'node_attr', 'x_cov')}
        split = data.batch_size_
        try:
            # Neighbours first, without gradient tracking.
            with torch.no_grad():
                for field, value in saved.items():
                    setattr(data, field, value[split:])
                z_nei = self.backbone(data)
            for field, value in saved.items():
                setattr(data, field, value[:split])
            z_tar = self.backbone(data)
        finally:
            for field, value in saved.items():
                setattr(data, field, value)
        return torch.cat((z_tar, z_nei), dim=0)

    def forward(self, data):
        """
        Encode every node, fuse the embeddings over the relation graph, and
        average over patches.

        Args:
            data: graph batch. Reads ``x`` ([n_nodes x num_patch x n_vars x patch_len]
                for the backbone), ``node_attr``, ``edge_index`` and ``edge_type``.
                In 'pretrain' mode it also reads ``mask``, ``x_cov`` and ``batch_size_``.

        Returns:
            Tensor [n_nodes x n_vars x d_model].
        """
        if self.head_type == 'pretrain':
            z = self._encode_pretrain(data)
        else:
            z = self.backbone(data)
            # NOTE(review): this overwrites the caller's data.x with the embedding.
            # It is kept for backward compatibility; confirm whether anything reads it.
            data.x = z

        # The backbone returns [n x n_vars x d_model x num_patch], but
        # RelationGCN.message splits each node vector into num_patch contiguous
        # chunks of d_model features, and the result is read back as
        # [.., num_patch, d_model]. Put the patch axis first before flattening,
        # otherwise features and patches get scrambled.
        batch_size, n_vars, d_model, num_patch = z.shape
        z = z.permute(0, 1, 3, 2).reshape(batch_size, -1)

        for layer in self.relational_gcn_layers:
            z = layer(z, data.edge_index, data.edge_type)
        z = z.reshape(batch_size, n_vars, num_patch, d_model)
        return torch.mean(z, dim=2)

class PretrainHead(nn.Module):
    """Map every patch embedding back to ``patch_len`` raw values.

    forward:
        x: [bs x nvars x num_patch x d_model]. The linear layer acts on the
           last axis, so that axis must be d_model.
        returns: [bs x num_patch x nvars x patch_len]

    NOTE(review): PowerGPTEncoder emits [bs x nvars x d_model x num_patch];
    such a tensor would need ``transpose(2, 3)`` before this head. Confirm
    with callers.
    """
    def __init__(self, d_model, patch_len, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(d_model, patch_len)

    def forward(self, x):
        reconstructed = self.linear(self.dropout(x))  # [bs x nvars x num_patch x patch_len]
        # swap the variable and patch axes
        return reconstructed.permute(0, 2, 1, 3)


class PredictionHead(nn.Module):
    """Flatten each variable's [d_model x num_patch] embedding and project it to the horizon.

    With ``individual=True`` every variable gets its own flatten/linear/dropout
    (applied as linear then dropout). Otherwise one shared stack is used
    (applied as dropout then linear).
    """
    def __init__(self, individual, n_vars, d_model, num_patch, forecast_len, head_dropout=0, flatten=False):
        super().__init__()

        self.individual = individual
        self.n_vars = n_vars
        self.flatten = flatten
        head_dim = d_model * num_patch

        if not self.individual:
            # Shared projection for all variables (replaces the `flatten` flag attribute).
            self.flatten = nn.Flatten(start_dim=-2)
            self.linear = nn.Linear(head_dim, forecast_len)
            self.dropout = nn.Dropout(head_dropout)
            return

        # One independent projection per variable.
        self.linears = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        self.flattens = nn.ModuleList()
        for _ in range(self.n_vars):
            self.flattens.append(nn.Flatten(start_dim=-2))
            self.linears.append(nn.Linear(head_dim, forecast_len))
            self.dropouts.append(nn.Dropout(head_dropout))

    def forward(self, x):
        """
        x: [bs x nvars x d_model x num_patch]
        returns: [bs x forecast_len x nvars]
        """
        if self.individual:
            per_var = [
                drop(lin(flat(x[:, idx, :, :])))  # [bs x forecast_len]
                for idx, (flat, lin, drop) in enumerate(zip(self.flattens, self.linears, self.dropouts))
            ]
            projected = torch.stack(per_var, dim=1)  # [bs x nvars x forecast_len]
        else:
            projected = self.linear(self.dropout(self.flatten(x)))  # [bs x nvars x forecast_len]
        return projected.transpose(2, 1)


class RegressionHead(nn.Module):
    """Regress from the last patch embedding of every variable.

    If ``y_range`` is truthy, the output is squashed with ``SigmoidRange(*y_range)``.
    """
    def __init__(self, n_vars, d_model, output_dim, head_dropout, y_range=None):
        super().__init__()
        self.y_range = y_range
        self.flatten = nn.Flatten(start_dim=1)
        self.dropout = nn.Dropout(head_dropout)
        self.linear = nn.Linear(n_vars*d_model, output_dim)

    def forward(self, x):
        """
        x: [bs x nvars x d_model x num_patch]
        returns: [bs x output_dim]
        """
        last_patch = x[:, :, :, -1]  # [bs x nvars x d_model]
        y = self.linear(self.dropout(self.flatten(last_patch)))
        if self.y_range:
            y = SigmoidRange(*self.y_range)(y)
        return y



class ClassificationHead(nn.Module):
    """Classify from the last patch embedding of every variable.

    NOTE(review): this head takes the last item along dim 2, whereas
    RegressionHead takes it along dim 3. The two heads assume different input
    layouts; confirm which one callers provide.
    """
    def __init__(self, n_vars, d_model, n_classes, head_dropout):
        super().__init__()
        self.flatten = nn.Flatten(start_dim=1)
        self.dropout = nn.Dropout(head_dropout)
        self.linear = nn.Linear(n_vars*d_model, n_classes)

    def forward(self, x):
        """
        x: [bs x nvars x num_patch x d_model]
        returns: [bs x n_classes] (logits)
        """
        last_patch = x[:, :, -1, :]  # [bs x nvars x d_model]
        features = self.dropout(self.flatten(last_patch))  # [bs x nvars * d_model]
        return self.linear(features)



class PowerGPTEncoder(nn.Module):
    """Channel-wise patch Transformer encoder.

    Processing steps:
    - Each patch is projected to ``d_model``, by one shared linear layer or by
      one layer per variable when ``shared_embedding`` is False.
    - A per-node instance token, looked up from ``data.node_attr``, is
      prepended to every patch sequence.
    - Positional encodings are added and the sequence runs through a TSTEncoder.
    - The instance-token position is dropped from the output.

    In 'pretrain' mode, masked patches (``data.mask``) are replaced by a
    learnable mask token. Embeddings of four categorical covariates
    (``data.x_cov``) are also added to the patch embeddings. The table names
    suggest weather / temperature / weekday / holiday codes; TODO confirm.
    """
    def __init__(self, c_in, num_patch, patch_len,
                 n_layers=3, d_model=128, n_heads=16, shared_embedding=True,
                 d_ff=256, norm='RMSNorm', attn_dropout=0., dropout=0., act="gelu", store_attn=False,
                 res_attention=True, pre_norm=False,
                 pe='zeros', learn_pe=True, verbose=False, head_type='pretrain', **kwargs):

        super().__init__()
        self.n_vars = c_in
        self.num_patch = num_patch
        self.patch_len = patch_len
        self.d_model = d_model
        self.shared_embedding = shared_embedding
        self.head_type = head_type

        # Patch projection onto the d_model space
        if shared_embedding:
            self.W_P = nn.Linear(patch_len, d_model)
        else:
            self.W_P = nn.ModuleList([nn.Linear(patch_len, d_model) for _ in range(self.n_vars)])

        # Positional encoding; one extra slot for the prepended instance token
        self.W_pos = positional_encoding(pe, learn_pe, num_patch + 1, d_model)

        # Learnable replacement for masked patches (pretrain mode)
        self.mask_token = nn.Parameter(torch.randn(d_model))

        # Categorical covariate tables and the per-node instance table
        self.weat_embs_tab = nn.Embedding(15, d_model)
        self.temp_embs_tab = nn.Embedding(45, d_model)
        self.week_embs_tab = nn.Embedding(7, d_model)
        self.holi_embs_tab = nn.Embedding(2, d_model)
        self.instance_embs = nn.Embedding(6, d_model)
        # Residual dropout
        self.dropout = nn.Dropout(dropout)

        # Encoder
        self.encoder = TSTEncoder(d_model, n_heads, d_ff=d_ff, norm=norm, attn_dropout=attn_dropout, dropout=dropout,
                                  pre_norm=pre_norm, activation=act, res_attention=res_attention, n_layers=n_layers,
                                  store_attn=store_attn)

    def _covariate_embedding(self, x_cov, bs, num_patch):
        """Average the four covariate embeddings -> [bs x num_patch x d_model].

        Covariate ``k`` is read from ``x_cov[:, :, k, :]``. Its embeddings are
        summed with the others, divided by 4, and averaged over the trailing
        axis of ``x_cov``.
        """
        tables = (self.weat_embs_tab, self.temp_embs_tab, self.week_embs_tab, self.holi_embs_tab)
        total = None
        for k, table in enumerate(tables):
            emb = table(x_cov[:, :, k, :].reshape(-1)).reshape(bs, num_patch, -1, self.d_model)
            total = emb if total is None else total + emb
        return torch.mean(total / 4, dim=2)

    def forward(self, data) -> Tensor:
        """
        data.x: tensor [bs x num_patch x nvars x patch_len]
        returns: tensor [bs x nvars x d_model x num_patch]
        """
        patches = data.x
        bs, num_patch, n_vars, patch_len = patches.shape

        # Project patches -> [bs x num_patch x nvars x d_model]
        if self.shared_embedding:
            emb = self.W_P(patches)
        else:
            emb = torch.stack([self.W_P[v](patches[:, :, v, :]) for v in range(n_vars)], dim=2)
        emb = emb.transpose(1, 2)  # [bs x nvars x num_patch x d_model]

        pretraining = self.head_type == 'pretrain'
        if pretraining:
            # NOTE(review): data.mask is indexed as 4-D here ([..., 0] on the last axis),
            # while patch_masking produces a 3-D mask. Confirm the pipeline expands it.
            patch_mask = data.mask[:, :, :, 0].permute(0, 2, 1)  # [bs x nvars x num_patch]
            emb[patch_mask] = self.mask_token

        tokens = torch.reshape(emb, (bs * n_vars, num_patch, self.d_model))
        if pretraining:
            # NOTE(review): the covariate term is [bs x num_patch x d_model] and is added
            # to [bs*nvars x num_patch x d_model]; this lines up only when nvars == 1. Confirm.
            tokens = tokens + self._covariate_embedding(data.x_cov, bs, num_patch)

        # Prepend the instance token -> [bs*nvars x (num_patch + 1) x d_model]
        instance_tok = self.instance_embs(data.node_attr).unsqueeze(1)
        tokens = torch.cat([instance_tok, tokens], dim=1)
        tokens = self.dropout(tokens + self.W_pos)

        encoded = self.encoder(tokens)
        encoded = torch.reshape(encoded, (-1, n_vars, num_patch + 1, self.d_model))
        # -> [bs x nvars x d_model x (num_patch + 1)], then drop the instance-token slot
        return encoded.permute(0, 1, 3, 2)[:, :, :, 1:]


# Cell
class TSTEncoder(nn.Module):
    """A stack of ``n_layers`` TSTEncoderLayer modules applied in sequence.

    With ``res_attention`` the scores returned by each layer are passed to the
    next layer as ``prev``. Only the final hidden states are returned.
    """
    def __init__(self, d_model, n_heads, d_ff=None,
                 norm='BatchNorm', attn_dropout=0., dropout=0., activation='gelu',
                 res_attention=False, n_layers=1, pre_norm=False, store_attn=False):
        super().__init__()

        layer_kwargs = dict(n_heads=n_heads, d_ff=d_ff, norm=norm,
                            attn_dropout=attn_dropout, dropout=dropout,
                            activation=activation, res_attention=res_attention,
                            pre_norm=pre_norm, store_attn=store_attn)
        self.layers = nn.ModuleList([TSTEncoderLayer(d_model, **layer_kwargs) for _ in range(n_layers)])
        self.res_attention = res_attention

    def forward(self, src: Tensor):
        """
        src: tensor [bs x q_len x d_model]
        returns: tensor of the same shape
        """
        hidden = src
        if not self.res_attention:
            for layer in self.layers:
                hidden = layer(hidden)
            return hidden

        scores = None
        for layer in self.layers:
            hidden, scores = layer(hidden, prev=scores)
        return hidden


class TSTEncoderLayer(nn.Module):
    """One Transformer encoder layer: multi-head attention + position-wise FFN.

    Each sub-layer has a residual connection with dropout and a normalisation
    chosen by ``norm``:
    - a name containing "batch" selects BatchNorm1d over d_model;
    - a name containing "rms" selects RMSNorm;
    - anything else selects LayerNorm.

    With ``pre_norm`` the input is overwritten by its normalised version before
    the sub-layer, so the residual is added to the normalised input. Without
    it, normalisation is applied after the residual add.
    With ``res_attention`` the attention scores are also returned (and can be
    fed back via ``prev``).
    """
    def __init__(self, d_model, n_heads, d_ff=256, store_attn=False,
                 norm='BatchNorm', attn_dropout=0, dropout=0., bias=True,
                 activation="gelu", res_attention=False, pre_norm=True):
        super().__init__()
        assert not d_model % n_heads, f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
        head_dim = d_model // n_heads  # used for both d_k and d_v

        # Attention sub-layer
        self.res_attention = res_attention
        self.self_attn = MultiheadAttention(d_model, n_heads, head_dim, head_dim, attn_dropout=attn_dropout,
                                            proj_dropout=dropout, res_attention=res_attention)
        self.dropout_attn = nn.Dropout(dropout)
        self.norm_attn = self._make_norm(norm, d_model)

        # Position-wise feed-forward sub-layer
        self.ff = nn.Sequential(nn.Linear(d_model, d_ff, bias=bias),
                                get_activation_fn(activation),
                                nn.Dropout(dropout),
                                nn.Linear(d_ff, d_model, bias=bias))
        self.dropout_ffn = nn.Dropout(dropout)
        self.norm_ffn = self._make_norm(norm, d_model)

        self.pre_norm = pre_norm
        self.store_attn = store_attn

    @staticmethod
    def _make_norm(norm, d_model):
        """Build the normalisation module selected by the (case-insensitive) ``norm`` name."""
        kind = norm.lower()
        if "batch" in kind:
            # BatchNorm1d normalises dim 1, so move d_model there and back.
            return nn.Sequential(Transpose(1, 2), nn.BatchNorm1d(d_model), Transpose(1, 2))
        if 'rms' in kind:
            return RMSNorm(d_model)
        return nn.LayerNorm(d_model)

    def forward(self, src: Tensor, prev: Optional[Tensor] = None):
        """
        src: tensor [bs x q_len x d_model]
        prev: attention scores from the previous layer (used with res_attention)
        returns: src, or (src, scores) when res_attention is set
        """
        # --- attention sub-layer ---
        if self.pre_norm:
            src = self.norm_attn(src)
        if self.res_attention:
            attn_out, attn, scores = self.self_attn(src, src, src, prev)
        else:
            attn_out, attn = self.self_attn(src, src, src)
        if self.store_attn:
            self.attn = attn
        src = src + self.dropout_attn(attn_out)
        if not self.pre_norm:
            src = self.norm_attn(src)

        # --- feed-forward sub-layer ---
        if self.pre_norm:
            src = self.norm_ffn(src)
        src = src + self.dropout_ffn(self.ff(src))
        if not self.pre_norm:
            src = self.norm_ffn(src)

        return (src, scores) if self.res_attention else src



def patch_masking(x, patch_len, stride, mask_ratio):
    """
    Patch a batch and mask it for pre-training.

    The first half of the batch (``x[:bs // 2]``) gets forecast-style masking
    through ``pred_masking``. The second half gets random patch masking through
    ``random_masking``.

    Args:
        x: [bs x seq_len x n_vars]
    Returns:
        x_masked: [bs x num_patch x n_vars x patch_len], the masked input
        y:        [bs x num_patch x n_vars x patch_len], the un-masked patches (target)
        mask:     bool [bs x num_patch x n_vars]; per the helpers, True (1) marks a masked patch
    All three share the sample order of ``x``.
    """
    bs = x.shape[0]
    half = bs // 2
    x_patch, _ = create_patch(x, patch_len, stride)  # [bs x num_patch x n_vars x patch_len]
    x_mask_t, _, mask_t, _ = random_masking(x_patch[half:], mask_ratio)
    # Pass a copy: forecast masking writes zeros, and those must not leak into
    # the target y (x_patch is also an unfold view of x).
    x_mask_b, mask_b = pred_masking(x_patch[:half].clone(), mask_ratio)
    # Concatenate in the order the halves were taken from x_patch so that row i
    # of x_masked / mask lines up with row i of y.
    mask = torch.cat([mask_b.to(mask_t.device), mask_t], dim=0).bool()
    x_masked = torch.cat([x_mask_b, x_mask_t], dim=0)
    y = x_patch
    return x_masked, y, mask

def _loss(self, preds, target):
    """
    preds:   [bs x num_patch x n_vars x patch_len]
    targets: [bs x num_patch x n_vars x patch_len]
    """
    loss = (preds - target) ** 2
    loss = loss.mean(dim=-1)
    loss = (loss * self.mask).sum() / self.mask.sum()
    return loss


def create_patch(xb, patch_len, stride):
    """
    Cut a batch of series into (possibly overlapping) patches.

    xb: [bs x seq_len x n_vars]
    returns: (patches [bs x num_patch x n_vars x patch_len], num_patch)

    When seq_len >= patch_len, leading time steps that do not fit a whole
    stride are dropped, so the last patch ends at the final time step.
    """
    seq_len = xb.shape[1]
    num_patch = (max(seq_len, patch_len) - patch_len) // stride + 1
    covered = patch_len + stride * (num_patch - 1)  # time steps spanned by the patches
    trimmed = xb[:, seq_len - covered:, :]  # [bs x covered x n_vars]
    patches = trimmed.unfold(dimension=1, size=patch_len, step=stride)
    return patches, num_patch


def random_masking(xb, mask_ratio):
    """
    Randomly mask whole patches, independently per sample and per variable.

    xb: [bs x num_patch x nvars x patch_len]
    returns:
        x_masked:    same shape as xb; masked patches are zero
        x_kept:      [bs x len_keep x nvars x patch_len], the kept patches in shuffled order
        mask:        [bs x num_patch x nvars], 0 = keep, 1 = removed
        ids_restore: [bs x num_patch x nvars], permutation that undoes the shuffle
    where len_keep = int(num_patch * (1 - mask_ratio)).
    """
    bs, L, nvars, D = xb.shape
    x = xb.clone()
    len_keep = int(L * (1 - mask_ratio))

    # Random permutation per (sample, variable); the smallest noise values are kept.
    noise = torch.rand(bs, L, nvars, device=xb.device)
    ids_shuffle = torch.argsort(noise, dim=1)
    ids_restore = torch.argsort(ids_shuffle, dim=1)

    def expand_idx(idx):
        # broadcast a [bs x n x nvars] index over the patch_len axis
        return idx.unsqueeze(-1).repeat(1, 1, 1, D)

    x_kept = torch.gather(x, dim=1, index=expand_idx(ids_shuffle[:, :len_keep, :]))

    # Zero-fill the removed slots, then undo the shuffle.
    padding = torch.zeros(bs, L - len_keep, nvars, D, device=xb.device)
    shuffled = torch.cat([x_kept, padding], dim=1)
    x_masked = torch.gather(shuffled, dim=1, index=expand_idx(ids_restore))

    # Binary mask in shuffled order, then unshuffled the same way.
    mask = torch.ones([bs, L, nvars], device=x.device)
    mask[:, :len_keep, :] = 0
    mask = torch.gather(mask, dim=1, index=ids_restore)
    return x_masked, x_kept, mask, ids_restore


def pred_masking(xb, mask_ratio):
    """
    Forecast-style masking: hide the last ``int(num_patch * mask_ratio)`` patches.

    Args:
        xb: [bs x num_patch x nvars x patch_len]. It is not modified.
        mask_ratio: fraction of trailing patches to hide.
    Returns:
        x_masked: a copy of xb with the trailing patches set to zero.
        mask: float [bs x num_patch x nvars] on xb's device, 1 where masked.
    """
    bs, L, nvars, D = xb.shape
    masked_len = int(L * mask_ratio)
    # Use an explicit start index: `-masked_len:` would select every patch when
    # masked_len == 0.
    start = max(L - masked_len, 0)
    # Copy so that neither the caller's tensor nor the storage it views is zeroed.
    x_masked = xb.clone()
    x_masked[:, start:, :, :] = 0
    mask = torch.zeros([bs, L, nvars], device=xb.device)
    mask[:, start:, :] = 1
    return x_masked, mask
    

class MaskMSELoss(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, pred, target, mask):
        loss = (pred[mask] - target[mask]) ** 2
        loss = loss.mean()
        return loss


class RMSNorm(torch.nn.Module):
    """Root-mean-square normalisation over the last axis with a learnable per-feature scale.

    y = x / sqrt(mean(x ** 2, axis=-1) + eps) * weight

    The statistic is computed in float32, and the result is cast back to the
    input dtype before scaling.
    """
    def __init__(self, dim: int, eps: float = 1e-6):
        """
        Args:
            dim: size of the normalised (last) axis.
            eps: added inside the square root for numerical stability.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # learnable scale, initialised to 1

    def _norm(self, x):
        """Divide x by the root mean square of its last axis."""
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * inv_rms

    def forward(self, x):
        """Normalise x in float32, restore its dtype, then apply the learnable scale."""
        normed = self._norm(x.float())
        return normed.type_as(x) * self.weight


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Return complex64 rotation factors of shape (end, dim // 2) for rotary embeddings.

    Entry [m, i] is exp(1j * m * theta ** (-2i / dim)).
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device)  # type: ignore
    angles = torch.outer(positions, inv_freq).float()  # type: ignore
    return torch.polar(torch.ones_like(angles), angles)


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """View ``freqs_cis`` (seq_len, pairs) so it broadcasts against ``x`` along axes 1 and -1.

    Every other axis becomes size 1. Raises AssertionError if ``x`` has fewer
    than 2 dims or if the shapes disagree.
    """
    ndim = x.ndim
    assert ndim > 1
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    kept_axes = (1, ndim - 1)
    return freqs_cis.view(*[size if axis in kept_axes else 1 for axis, size in enumerate(x.shape)])


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate consecutive feature pairs of queries and keys by ``freqs_cis``.

    Args:
        xq, xk: tensors whose axis 1 is the sequence axis and whose last axis
            (even length) holds the features, e.g. [bs x seq x heads x head_dim]
            or [bs x seq x dim].
        freqs_cis: complex factors of shape (seq, last_dim // 2).

    Returns:
        The rotated (xq, xk), each with its input's shape and dtype.
    """
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    # flatten(-2) merges the (pairs, 2) axes back for any input rank; the
    # previous flatten(3) was only correct for 4-D inputs.
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)
    return xq_out.type_as(xq), xk_out.type_as(xk)


class RoPE(nn.Module):
    """Rotary positional embedding, following LLaMA.

    Intended to be applied to queries and keys just before attention.
    Consecutive feature pairs are treated as complex numbers and rotated by an
    angle proportional to the position index.
    """
    def __init__(self, d_model, max_len=5000) -> None:
        super(RoPE, self).__init__()
        # Precompute the rotation factors for every position up to max_len.
        self.register_buffer('freqs_cis', self.precompute_freqs_cis(d_model, max_len))

    def precompute_freqs_cis(self, d_model, max_len, theta=10000.0):
        """Return complex rotation factors of shape (max_len, d_model // 2)."""
        # One base angle theta_i per feature pair.
        pair_idx = torch.arange(0, d_model, 2)[: (d_model // 2)].float()
        inv_freq = 1.0 / (theta ** (pair_idx / d_model))
        # angle[m, i] = m * theta_i for positions m = 0 .. max_len - 1
        angles = torch.outer(torch.arange(max_len), inv_freq).float()
        # Unit complex numbers cos(angle) + i * sin(angle).
        return torch.polar(torch.ones_like(angles), angles)

    def broadcast(self, freqs_cis, x):
        """View freqs_cis (seq_len, pairs) so it broadcasts against x along axes 1 and -1."""
        last_axis = x.dim() - 1
        assert freqs_cis.shape == (x.shape[1], x.shape[-1])  # (seq_len, d_model)
        return freqs_cis.view(*[size if axis in (1, last_axis) else 1 for axis, size in enumerate(x.shape)])

    def forward(self, x):
        """
        x: (batch_n, seq_len, d_model) or (batch_n, seq_len, n_head, d_model)
        returns: rotated tensor with x's shape and dtype
        """
        # Pair up the features and view them as complex numbers.
        pairs = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2).float())
        rotation = self.broadcast(self.freqs_cis[: x.shape[1]], pairs)
        # Rotate, then unpack back to real pairs and merge them into the last axis.
        rotated = torch.view_as_real(pairs * rotation).flatten(-2)
        return rotated.type_as(x)