import torch
from torch import nn
import torch.nn.functional as F
from .ETP import ETP
from sentence_transformers import SentenceTransformer


class FASTopic(nn.Module):
    """Embedding-based topic model built on two ETP transport modules.

    Documents are represented by frozen sentence-transformer embeddings, while
    topic and word embeddings are learned in the same space. ``DT_ETP`` relates
    document embeddings to topic embeddings and ``TW_ETP`` relates topic
    embeddings to word embeddings. Their transport plans, rescaled by their
    row count, are used as the document-topic matrix ``theta`` and the
    topic-word matrix ``beta``.
    """

    def __init__(self, args):
        """Build the model and encode the training corpus once.

        Args:
            args: configuration object. Read here: ``vocab_size``,
                ``num_topics``, ``device``, ``train_texts`` and
                ``model.sentence_model``, ``model.DT_alpha``,
                ``model.TW_alpha``. Read later by ``forward`` / ``get_theta``:
                ``model.weight_DT``, ``model.weight_TW``, ``model.theta_temp``.
        """
        super().__init__()

        self.args = args

        vocab_size = args.vocab_size
        num_topics = args.num_topics
        sentence_model = args.model.sentence_model
        self.device = args.device

        # Training-document embeddings are computed once and kept frozen
        # (requires_grad=False). They are registered as a Parameter so they
        # follow the module across .to() calls.
        self.sentence_model = SentenceTransformer(sentence_model, device=self.device)
        doc_embeddings = self.load_contextual_embed(args.train_texts, device=self.device)
        self.doc_embeddings = nn.Parameter(torch.from_numpy(doc_embeddings), requires_grad=False)

        embed_size = self.doc_embeddings.shape[1]

        # Learnable word (vocab_size, embed_size) and topic (num_topics, embed_size)
        # embeddings, L2-normalised row-wise at initialisation.
        word_embeddings = nn.init.trunc_normal_(torch.empty(vocab_size, embed_size))
        self.word_embeddings = nn.Parameter(F.normalize(word_embeddings))

        topic_embeddings = torch.empty((num_topics, self.word_embeddings.shape[1]))
        nn.init.trunc_normal_(topic_embeddings, std=0.1)
        self.topic_embeddings = nn.Parameter(F.normalize(topic_embeddings))

        # Uniform column vectors, (vocab_size, 1) and (num_topics, 1), passed to
        # the ETP modules as ``init_b_dist``.
        self.word_weights = nn.Parameter((torch.ones(vocab_size) / vocab_size).unsqueeze(1))
        self.topic_weights = nn.Parameter((torch.ones(num_topics) / num_topics).unsqueeze(1))

        self.DT_ETP = ETP(sinkhorn_alpha=args.model.DT_alpha, init_b_dist=self.topic_weights)
        self.TW_ETP = ETP(sinkhorn_alpha=args.model.TW_alpha, init_b_dist=self.word_weights)

    def load_contextual_embed(self, texts, device, show_progress_bar=True):
        """Encode ``texts`` with the sentence-transformer model.

        ``device`` is accepted for API compatibility but is not used. The
        encoder was already constructed on ``self.device``. The result is
        whatever ``SentenceTransformer.encode`` returns. Callers in this class
        pass it to ``torch.from_numpy``, so a NumPy array is expected.
        """
        embeddings = self.sentence_model.encode(texts, show_progress_bar=show_progress_bar)
        return embeddings

    def pairwise_euclidean_distance(self, x, y):
        """Return squared Euclidean distances between the rows of ``x`` and ``y``.

        ``x`` is (n, d) and ``y`` is (m, d); the result is (n, m). The value
        comes from the expansion ``|x|^2 + |y|^2 - 2 x.y``, so entries can be
        very slightly negative due to floating-point cancellation.
        """
        cost = torch.sum(x ** 2, dim=1, keepdim=True) + torch.sum(y ** 2, dim=1) - 2 * torch.matmul(x, y.t())
        return cost

    def softmax_euclidean(self, x, y, temp=1.0, dim=-1):
        """Softmax of negative squared distances (divided by ``temp``) along ``dim``."""
        dist = self.pairwise_euclidean_distance(x, y)
        out = F.softmax(-dist / temp, dim=dim)
        return out

    # only for testing
    def get_beta(self):
        """Return the topic-word matrix.

        This is the ``TW_ETP`` transport plan scaled by its number of rows
        (presumably ``num_topics``).
        """
        _, transp_TW = self.TW_ETP(self.topic_embeddings, self.word_embeddings)

        # use transport plan as beta
        beta = transp_TW * transp_TW.shape[0]

        return beta

    # only for testing
    def get_theta(self, texts):
        """Infer document-topic proportions for new ``texts``.

        For each document and topic the weight is ``exp(-dist / theta_temp)``.
        It is divided, per topic, by the same quantity summed over the training
        documents, and each document's row is then normalised to sum to 1.

        Returns:
            Tensor with one row per encoded text and one column per topic,
            on the device of ``topic_embeddings``.
        """
        doc_embeddings = self.load_contextual_embed(texts, self.device)
        doc_embeddings = torch.from_numpy(doc_embeddings)
        return self._theta_from_embeddings(doc_embeddings)

    def _theta_from_embeddings(self, doc_embeddings):
        """Compute ``get_theta``'s normalised topic weights from precomputed embeddings."""
        topic_embeddings = self.topic_embeddings
        # Follow the topic embeddings' device/dtype rather than ``self.device``,
        # so this works whether or not the module was moved to ``args.device``.
        doc_embeddings = doc_embeddings.to(device=topic_embeddings.device, dtype=topic_embeddings.dtype)
        train_embeddings = self.doc_embeddings.to(device=topic_embeddings.device, dtype=topic_embeddings.dtype)

        temp = self.args.model.theta_temp
        dist = self.pairwise_euclidean_distance(doc_embeddings, topic_embeddings)
        train_dist = self.pairwise_euclidean_distance(train_embeddings, topic_embeddings)

        # theta_ij is proportional to exp(-dist_ij/T) / sum_n exp(-train_dist_nj/T).
        # It is evaluated in log space: large distances or a small T would
        # otherwise underflow a whole row to 0, and the row normalisation
        # would then produce 0/0 = nan.
        log_train_norm = torch.logsumexp(-train_dist / temp, dim=0)
        return F.softmax(-dist / temp - log_train_norm, dim=1)

    @staticmethod
    def _reconstruction_loss(theta, beta, train_bow, eps=1e-12):
        """Return the negative log-likelihood of ``train_bow`` under ``theta @ beta``, averaged over rows."""
        recon = torch.matmul(theta, beta)
        # ``eps`` keeps log() finite when an entry of ``recon`` underflows to 0.
        # Without it, a zero count times log(0) gives 0 * -inf = nan.
        return -(train_bow * (recon + eps).log()).sum(dim=1).mean()

    def forward(self, train_bow):
        """Compute the training loss.

        Args:
            train_bow: target matrix for ``theta @ beta``. It is assumed to be
                (num_docs, vocab_size) bag-of-words counts aligned with
                ``doc_embeddings``; confirm against the caller.

        Returns:
            dict with key ``'loss'``. The loss is the weighted sum of the two
            ETP losses plus the reconstruction loss.
        """
        # Each ETP call is unpacked as a (loss, transport plan) pair.
        loss_DT, transp_DT = self.DT_ETP(self.doc_embeddings, self.topic_embeddings)
        loss_TW, transp_TW = self.TW_ETP(self.topic_embeddings, self.word_embeddings)

        # Rescale each plan by its row count.
        theta = transp_DT * transp_DT.shape[0]
        beta = transp_TW * transp_TW.shape[0]

        loss_TM = self._reconstruction_loss(theta, beta, train_bow)

        loss = self.args.model.weight_DT * loss_DT + self.args.model.weight_TW * loss_TW + loss_TM

        rst_dict = {
            'loss': loss,
        }

        return rst_dict
