import torch
import torch.nn as nn
from clip.model import CLIP

class ClipViTEmbedder(nn.Module):
    """Embed images with the vision tower of a CLIP model.

    ``forward`` re-runs the steps of ``model.visual`` explicitly: patch
    convolution, class token, positional embedding, transformer, final
    LayerNorm and optional projection. It returns one embedding per image,
    taken from the class token.

    Args:
        model: A CLIP model; only ``model.visual`` is used by ``forward``.
        class_emb_only: Stored on the instance but not read by ``forward``,
            which always returns only the class-token embedding.
            NOTE(review): either wire this flag up or drop it — confirm
            with callers before changing the default behavior.
    """

    def __init__(self, model: "CLIP", class_emb_only: bool = False):
        super().__init__()
        self.model = model
        self.class_emb_only = class_emb_only

        # Switch this module and the wrapped model into eval mode (affects
        # layers such as dropout). This does NOT disable autograd; wrap calls
        # in torch.no_grad() when gradients are not needed.
        self.eval()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the class-token embedding for a batch of images.

        Args:
            x: Image batch ``(N, C, H, W)`` accepted by ``model.visual.conv1``.
                H and W must produce the patch grid the positional embedding
                was built for. The batch is cast to the conv weight dtype, so
                float32 input also works with a reduced-precision model.

        Returns:
            Tensor of shape ``(N, D)``. ``D`` is ``proj.shape[1]`` when the
            vision tower has a projection, otherwise the transformer width.
        """
        m = self.model.visual

        # Match the model's parameter dtype (the CLIP weights may be fp16);
        # a mismatch would otherwise make conv1 raise.
        x = x.to(m.conv1.weight.dtype)

        x = m.conv1(x)  # [N, width, grid, grid]
        x = x.flatten(2).permute(0, 2, 1)  # [N, grid ** 2, width]

        # Prepend the learned class token to every sequence in the batch.
        cls_token = m.class_embedding.to(x.dtype).expand(x.shape[0], 1, -1)
        x = torch.cat([cls_token, x], dim=1)  # [N, grid ** 2 + 1, width]
        x = x + m.positional_embedding.to(x.dtype)
        x = m.ln_pre(x)

        # The transformer is fed sequence-first tensors.
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = m.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        # Keep only the class token as the image summary.
        x = m.ln_post(x[:, 0, :])

        # Some vision towers have no output projection; return the
        # normalized class token unchanged in that case.
        if m.proj is not None:
            x = x @ m.proj
        return x