from models.classifier import Classifier
import pytorch_lightning as pl
import torch
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger

from models.classifier_vit import ClassifierViT
from models.ssl_models import VICRegModel
from utils.model_checkpointing import AlteredModelCheckpoint

# Number of GPUs handed to ``pl.Trainer``: one if CUDA is available, else zero (CPU).
gpus = int(torch.cuda.is_available())
import os


def train_classification(backbone, dataloader_train, dataloader_test, output_path,
                         max_epochs=1200, layer_idx=-1, use_vit=False,
                         num_classes=100, checkpoint_every_n_epochs=10):
    """Train a classifier with Lightning, logging and checkpointing under ``output_path``.

    By default a ``Classifier`` is built on top of ``backbone``. With
    ``use_vit=True`` a ``ClassifierViT`` is built instead, and ``backbone``
    and ``layer_idx`` are not used.

    Args:
        backbone: Feature extractor handed to ``Classifier`` (ignored when
            ``use_vit`` is True).
        dataloader_train: Dataloader passed to ``trainer.fit`` for training.
        dataloader_test: Dataloader passed to ``trainer.fit`` for validation.
        output_path: Root directory. CSV logs go to its ``classification``
            sub-directory and checkpoints to ``checkpoints``.
        max_epochs: Number of training epochs, forwarded both to the Trainer
            and to the model constructor.
        layer_idx: Forwarded to ``Classifier``. It presumably selects which
            backbone layer feeds the head (TODO confirm).
        use_vit: Train ``ClassifierViT`` instead of ``Classifier``. Defaults
            to False, the value that was previously hard-coded.
        num_classes: Number of output classes (previously hard-coded to 100).
        checkpoint_every_n_epochs: Passed as ``every_n_epochs`` and
            ``every_n_val_epochs`` to ``AlteredModelCheckpoint`` (previously
            hard-coded to 10).
    """
    if use_vit:
        # The ViT variant is constructed without ``backbone``/``layer_idx``.
        classifier = ClassifierViT(num_classes=num_classes,
                                   max_epochs=max_epochs,
                                   depth=10, head=8, mlp_hidden=384,
                                   hidden=384)
    else:
        classifier = Classifier(backbone, layer_idx=layer_idx,
                                num_classes=num_classes, max_epochs=max_epochs)

    logger = CSVLogger(output_path, name="classification")
    checkpoint_callback = AlteredModelCheckpoint(
        save_on_train_epoch_end=False,
        dirpath=os.path.join(output_path, 'checkpoints'),
        # In Lightning's ModelCheckpoint, save_top_k=-1 keeps every saved
        # checkpoint. This assumes AlteredModelCheckpoint keeps that meaning.
        save_top_k=-1,
        every_n_epochs=checkpoint_every_n_epochs,
        every_n_val_epochs=checkpoint_every_n_epochs,
        monitor='Loss/total_loss',
        ssl_train=False,
    )

    trainer = pl.Trainer(max_epochs=max_epochs, gpus=gpus,
                         default_root_dir=output_path,
                         logger=logger,
                         progress_bar_refresh_rate=100,
                         callbacks=[checkpoint_callback])

    trainer.fit(classifier, dataloader_train, dataloader_test)


def run_classification(output_path, max_epochs=1002, seed=0):
    """Train a classifier on the backbone of a newly constructed ``VICRegModel``.

    Loads the evaluation dataloaders, seeds all RNGs, and builds a ResNet-18
    ``VICRegModel`` for 100 classes. No weights are loaded here. Finally it
    calls ``train_classification`` on that model's backbone.

    Args:
        output_path: Directory for logs and checkpoints.
        max_epochs: Training epochs for the classifier. The default of 1002
            is the value that was previously hard-coded.
        seed: Seed passed to ``pl.seed_everything`` (previously hard-coded to 0).
    """
    from offline_evaluation.load_datasets import load_datasets
    # The first returned loader (the SSL training set) is not used here.
    _, dataloader_train_eval, dataloader_test_eval = \
        load_datasets(augment=False, use_imagenet_transforms=False)

    # Seeding happens after data loading, which keeps the original order.
    # Any randomness inside load_datasets is therefore not covered by it.
    pl.seed_everything(seed)
    benchmark_model = VICRegModel(dataloader_kNN=dataloader_train_eval,
                                  dataloader_test=dataloader_test_eval,
                                  num_classes=100, resnet_type="resnet-18")
    # Only move to the GPU when one exists. The unconditional ``.cuda()``
    # crashed on CPU-only machines even though the module-level ``gpus``
    # already falls back to 0.
    if torch.cuda.is_available():
        benchmark_model = benchmark_model.cuda()

    backbone = benchmark_model.backbone

    train_classification(backbone=backbone, dataloader_train=dataloader_train_eval,
                         dataloader_test=dataloader_test_eval, output_path=output_path,
                         max_epochs=max_epochs)