import torch
import torch.nn as nn
import torchvision
import pytorch_lightning as pl
import copy
import lightly

from lightly.models.modules.heads import MoCoProjectionHead
from lightly.models.utils import deactivate_requires_grad
from lightly.models.utils import update_momentum
from lightly.models.utils import batch_shuffle
from lightly.models.utils import batch_unshuffle
import torchmetrics

from models.ViT import ViT

# Integer index -> feature size lookup table. Nothing in this module reads
# it; presumably consumed elsewhere (e.g. flattened feature width per
# backbone stage, with -1 meaning the final stage) — TODO confirm.
feature_map = {
    0: 2 ** 16,   # 65536
    1: 2 ** 15,   # 32768
    2: 2 ** 14,   # 16384
    3: 2 ** 13,   # 8192
    4: 2 ** 9,    # 512
    -1: 2 ** 9,   # 512
}


class ClassifierViT(pl.LightningModule):
    """Supervised image classifier that trains a ViT backbone end to end.

    The backbone is a ``ViT`` built with a class token whose output size is
    ``num_classes``. Training minimizes cross-entropy and optimizes with Adam
    under a cosine-annealing learning-rate schedule.

    NOTE: uses the ``training_epoch_end`` / ``validation_epoch_end`` hooks,
    which exist in Lightning 1.x but were removed in Lightning 2.0.

    Args:
        num_classes: Number of target classes (backbone output size).
        max_epochs: Period ``T_max`` of the cosine LR schedule.
        depth: Passed to ``ViT`` as ``num_layers``.
        head: Passed to ``ViT`` as ``head`` (presumably attention heads —
            confirm against ``models.ViT``).
        mlp_hidden: Passed to ``ViT`` as ``mlp_hidden``.
        hidden: Passed to ``ViT`` as ``hidden``.
        lr: Initial Adam learning rate.
        weight_decay: Adam weight decay.
        eta_min: Lowest learning rate reached by the cosine schedule.
    """

    def __init__(self, num_classes=100,
                 max_epochs=500, depth=10, head=5, mlp_hidden=384, hidden=384,
                 lr=1e-3, weight_decay=5e-5, eta_min=1e-5):
        super().__init__()

        self.backbone = ViT(num_layers=depth, head=head, is_cls_token=True,
                            mlp_hidden=mlp_hidden, hidden=hidden, num_classes=num_classes)

        self.max_epochs = max_epochs
        self.lr = lr
        self.weight_decay = weight_decay
        self.eta_min = eta_min
        self.is_cuda = torch.cuda.is_available()
        self.num_classes = num_classes
        # The task-based torchmetrics API requires an int ``num_classes`` for
        # the multiclass task; passing only the task name raises ValueError.
        self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)

        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return the backbone's class logits for input ``x``."""
        y_hat = self.backbone(x)
        return y_hat

    def training_step(self, batch, batch_idx):
        """Compute cross-entropy on one batch and log loss and running accuracy.

        ``batch`` is unpacked as ``(inputs, targets, _)``; the third element
        is ignored.
        """
        x, y, _ = batch
        y_hat = self.forward(x)
        loss = self.criterion(y_hat, y)
        # Update the metric state; the epoch-level value is logged in
        # training_epoch_end from the same accumulated state.
        self.accuracy(y_hat, y)
        self.log("train_loss_fc", loss)
        self.log("Loss/total_loss", loss)
        self.log('train_acc_step', self.accuracy)
        return loss

    def training_epoch_end(self, outputs):
        """Log the accuracy accumulated over the training epoch."""
        self.log('train_acc_epoch', self.accuracy)

    def validation_step(self, batch, batch_idx):
        """Count correct top-1 predictions on one validation batch.

        Returns:
            ``(num, correct)``: the batch size as an ``int`` and the number of
            correct predictions as a float tensor. Both are aggregated in
            ``validation_epoch_end``.
        """
        x, y, _ = batch
        logits = self.forward(x)
        # Softmax is monotonic, so the arg-max of the raw logits is already
        # the predicted class; normalizing first is unnecessary.
        predicted = logits.argmax(dim=1)
        num = predicted.size(0)
        correct = (predicted == y).float().sum()
        return num, correct

    def validation_epoch_end(self, outputs):
        """Aggregate per-batch counts into top-1 accuracy and log ``val_acc``.

        Assumes a single validation dataloader, i.e. ``outputs`` is a flat list
        of ``(num, correct)`` tuples — TODO revisit if more loaders are added.
        """
        if not outputs:
            return
        total_num = sum(num for num, _ in outputs)
        if total_num == 0:
            # Nothing was evaluated; avoid logging a NaN accuracy.
            return
        total_correct = sum(correct for _, correct in outputs)
        acc = total_correct / total_num
        self.log("val_acc", acc, on_epoch=True, prog_bar=True)

    def configure_optimizers(self):
        """Build Adam with a cosine-annealed learning rate.

        The scheduler's period ``T_max`` is ``max_epochs``. A scheduler
        returned in this list form is stepped once per epoch by Lightning by
        default.
        """
        optim = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optim, T_max=self.max_epochs, eta_min=self.eta_min)
        return [optim], [scheduler]
