import copy
import shutil
import re
from torch.utils.tensorboard import SummaryWriter
from src.models import SmallCNN, ResNet50, resnet32
from src.utils import (
    AverageMeter,
    calculate_accuracy,
    change_column_value_of_existing_row,
    load_checkpoint,
    get_mask,
    save_checkpoint,
    apply_mask_and_save_images,
    select_device,
    Logger,
    write_config_to_csv,
    get_heatmap_generator,
    calculate_data_heat_map_mean,
    update_dataset_and_dataloader
)


from torch import Tensor
from tqdm import tqdm
from abc import ABC, abstractmethod

import torch.nn.functional as F
import numpy as np
import torch.optim as optim
import torch.nn as nn

import torch
import os
import math
import time

# Per-dataset model settings, keyed by dataset name. Each entry carries the
# number of target classes ('num_classes') and the 'input_size' value
# associated with that dataset.
model_size_configuration = {
    dataset_name: {'num_classes': class_count, 'input_size': size}
    for dataset_name, class_count, size in (
        ('cifar10', 10, 32),
        ('svhn', 10, 32),
        ('catsvsdogs', 2, 64),
        ('in9l', 9, 224),
        ('mnist', 2, 28),
        ('celeba', 2, 224),
    )
}


class TrainBaseMethod(ABC):
    def __init__(self, args) -> None:
        """Create the run directory tree and the shared training utilities.

        A run id is obtained from ``write_config_to_csv`` and used as the name
        of the run directory ``<base_dir>/runs/<run_id>``, under which the
        ``logs``, ``checkpoints`` and ``augmented_data`` folders are created.
        The constructor also builds the TensorBoard writer, selects the
        device, builds the cross-entropy loss and creates the logger.

        Args:
            args: Run configuration object. ``base_dir``, ``use_cuda``,
                ``class_weights``, ``masking_arch`` and ``arch`` are read
                here; ``masking_arch`` is overwritten with ``arch`` when it
                is ``None``.
        """
        self.args = args
        self.current_epoch = 0
        self.preprocess = None
        # Unit std and zero mean, each shaped (3, 1, 1).
        self.std = np.reshape([1.0, 1.0, 1.0], [3, 1, 1])
        self.mean = np.reshape([0.0, 0.0, 0.0], [3, 1, 1])

        # The "runs" directory only needs to be created once.
        runs_dir = os.path.join(self.args.base_dir, "runs")
        os.makedirs(runs_dir, exist_ok=True)
        self.run_configs_file_path = os.path.join(runs_dir, "run_configs.csv")
        self.run_id = write_config_to_csv(args, self.run_configs_file_path)

        self.run_dir = os.path.join(runs_dir, str(self.run_id))
        os.makedirs(self.run_dir, exist_ok=True)
        log_dir = os.path.join(self.run_dir, "logs")
        os.makedirs(log_dir, exist_ok=True)
        self.writer = SummaryWriter(log_dir)
        self.model_save_dir = os.path.join(self.run_dir, "checkpoints")
        os.makedirs(self.model_save_dir, exist_ok=True)
        self.augmented_data_save_dir = os.path.join(
            self.run_dir, "augmented_data")
        os.makedirs(self.augmented_data_save_dir, exist_ok=True)

        self.device = select_device(self.args.use_cuda)
        if args.class_weights is not None:
            # Force a floating-point tensor: integer weights from the
            # configuration would otherwise produce an integer tensor, which
            # the loss cannot use as class weights.
            class_weights = torch.tensor(
                args.class_weights, dtype=torch.float)
            self.loss_function = nn.CrossEntropyLoss(
                weight=class_weights.to(self.device))
        else:
            self.loss_function = nn.CrossEntropyLoss()

        # Masking falls back to the training architecture when unspecified.
        if self.args.masking_arch is None:
            self.args.masking_arch = self.args.arch

        # The configuration string is handed to the logger only when no
        # log.txt exists in the log directory yet.
        if os.path.isfile(os.path.join(log_dir, "log.txt")):
            self.logger = Logger(log_dir, None)
        else:
            self.logger = Logger(log_dir, str(self.args))
        self.writer.add_text("Configuration", str(self.args))

    @abstractmethod
    def prepare_data_loaders(self, train: bool = True) -> None:
        """Build the datasets and data loaders used by this trainer.

        Concrete subclasses must override this method. Methods of this class
        call it as ``prepare_data_loaders(train=True)``,
        ``prepare_data_loaders(train=False)`` and with no argument at all, so
        an override has to accept an optional ``train`` keyword.

        Other methods of this class read ``self.train_loader``,
        ``self.val_loader``, ``self.test_loader``, ``self.train_dataset``,
        ``self.data_to_mask_loader`` and ``self.data_to_mask_dataset``
        (presumably populated by the override; verify against subclasses).

        Args:
            train: Flag forwarded by the callers in this class.
                NOTE(review): the default given here is a placeholder for the
                abstract signature; the effective default is whatever the
                concrete subclass declares. Confirm against subclasses.
        """

    def prepare_lr_scheduler(self) -> None:
        if self.args.lr_scheduler_name == "multi_step":
            lr_scheduler = optim.lr_scheduler.MultiStepLR(
                optimizer=self.optimizer,
                milestones=self.args.schedule,
                gamma=self.args.gamma,
                last_epoch=-1,
            )
        else:
            raise NotImplementedError(
                f"{self.args.lr_scheduler_name} scheduler not implemented!"
            )
        self.lr_scheduler = lr_scheduler

    def prepare_optimizer(self) -> None:
        if self.args.optimizer == "sgd":
            self.optimizer = optim.SGD(
                self.model.parameters(),
                lr=self.args.lr,
                momentum=self.args.momentum,
                weight_decay=self.args.weight_decay,
                nesterov=self.args.use_nesterov
            )
        elif self.args.optimizer == "adam":
            self.optimizer = optim.Adam(
                self.model.parameters(),
                lr=self.args.lr,
                weight_decay=self.args.weight_decay,
            )
        else:
            raise NotImplementedError(
                f"{self.args.lr_scheduler_name} optimizer not implemented!"
            )

    def run_an_epoch(self, data_loader, epoch, train=False, val_or_test="val"):
        """Run one pass over ``data_loader`` and return the average accuracy.

        Args:
            data_loader: Iterable of batches. Element 0 of each batch is used
                as the model input and element 2 as the targets.
            epoch: Epoch number; used in the log line and, when training, in
                the progress-bar label.
            train: When true the model is put in train mode, gradients are
                enabled and ``self.optimize`` is called for every batch.
                Otherwise the model is put in eval mode and gradients are
                disabled.
            val_or_test: Label used for the log line and progress bar when
                ``train`` is false.

        Returns:
            ``accuracies.avg``, the running average of the per-batch values
            returned by ``calculate_accuracy``.
        """
        if train:
            self.model.train()
            progress_bar_description = 'Epoch ' + str(epoch)
        else:
            self.model.eval()
            progress_bar_description = val_or_test
        losses = AverageMeter()
        accuracies = AverageMeter()
        with torch.set_grad_enabled(train):
            progress_bar = tqdm(data_loader)
            # The label is constant for the whole pass, so set it once
            # instead of on every batch.
            progress_bar.set_description(progress_bar_description)
            self.logger.info(
                f"{'train' if train else val_or_test} epoch: {epoch}"
            )
            for data in progress_bar:
                inputs = data[0].to(self.device)
                targets = data[2].to(self.device)
                outputs = self.model(inputs)
                loss = self.loss_function(outputs, targets)
                losses.update(loss.item(), inputs.size(0))
                # Only the predicted class index is needed; detach() replaces
                # the legacy ``.data`` access.
                _, predictions = F.softmax(outputs, dim=1).detach().max(1)
                # NOTE(review): each batch is weighted 1 here while the loss
                # is weighted by batch size; if calculate_accuracy returns a
                # per-batch mean, a smaller last batch is over-weighted.
                # Confirm against calculate_accuracy before changing.
                accuracies.update(calculate_accuracy(targets, predictions), 1)
                if train:
                    self.optimize(loss=loss)

                progress_bar.set_postfix(
                    {
                        "loss": losses.avg,
                        "accuracy": accuracies.avg,
                    }
                )
            self.logger.info(f"loss: {losses.avg}")
            self.logger.info(f"accuracy: {accuracies.avg}")
        return accuracies.avg

    def optimize(self, loss: Tensor) -> None:
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def prepare_model(self, arch, multi_gpu=True) -> None:
        if arch == "small_cnn":
            self.model = SmallCNN(
                num_classes=model_size_configuration[self.args.dataset]['num_classes'], drop_rate=0.0)
        elif arch == "resnet50":
            self.model = ResNet50(
                pretrained=self.args.use_pretrained_weights, num_classes=model_size_configuration[self.args.dataset]['num_classes'])
        elif arch == 'resnet32':
            self.model = resnet32(pretrained=self.args.use_pretrained_weights, num_classes=model_size_configuration[self.args.dataset]['num_classes'])
        else:
            raise NotImplementedError()
        if multi_gpu:
            self.model = nn.DataParallel(
                self.model, device_ids=self.args.gpu_ids)
        self.model = self.model.to(self.device)

    def test(self, checkpoint_path=None):
        """Evaluate the model on the test loader and record its accuracy.

        A fresh model of ``self.args.arch`` is built. A checkpoint is loaded
        into it when ``self.args.epochs > 0`` or a saved checkpoint directory
        is configured. The resulting accuracy is written to the ``accuracy``
        column of this run's row in the run-configs CSV.

        Args:
            checkpoint_path: Checkpoint to load. When ``None``,
                ``self.args.checkpoint_name`` inside ``self.model_save_dir``
                is used; if it is not there yet it is first copied from
                ``self.args.saved_checkpoint_dir``.

        Raises:
            FileNotFoundError: If the default checkpoint is missing locally
                and no ``saved_checkpoint_dir`` is configured to copy it from.
        """
        self.prepare_data_loaders(train=False)
        self.logger.info("-" * 10 + "testing the model" +
                         "-" * 10, print_msg=True)
        self.prepare_model(arch=self.args.arch)
        self.model = self.model.to(self.device)
        if self.args.epochs > 0 or self.args.saved_checkpoint_dir is not None:
            if checkpoint_path is None:
                checkpoint_path = os.path.join(
                    self.model_save_dir, self.args.checkpoint_name
                )
                if not os.path.isfile(checkpoint_path):
                    if self.args.saved_checkpoint_dir is None:
                        # Without this check os.path.join(None, ...) fails
                        # with an opaque TypeError.
                        raise FileNotFoundError(
                            f"checkpoint {checkpoint_path} does not exist and "
                            "no saved_checkpoint_dir is configured"
                        )
                    shutil.copy(
                        os.path.join(
                            self.args.saved_checkpoint_dir,
                            self.args.checkpoint_name
                        ),
                        checkpoint_path
                    )
            (
                self.model,
                _,
                _,
                self.current_epoch,
            ) = load_checkpoint(
                model=self.model,
                optimizer=None,
                lr_scheduler=None,
                checkpoint_path=checkpoint_path,
            )
            self.logger.info(
                "-" * 10 + "model checkpoint loaded" + "-" * 10, print_msg=True)
        accuracy = self.run_an_epoch(
            data_loader=self.test_loader, epoch=0, train=False, val_or_test="test"
        )
        change_column_value_of_existing_row(
            "accuracy",
            accuracy,
            self.run_configs_file_path,
            self.run_id,
        )

    def train_model(self, phase="", use_lr_scheduler=True):
        """Train ``self.model`` for ``self.args.epochs`` epochs, with resume.

        After every epoch the model is evaluated on ``self.val_loader``.
        ``best_model_for_{phase}.pt`` is saved whenever the validation
        accuracy improves, and ``last_epoch_model_for_{phase}.pt`` is saved
        after each epoch. Both go to ``self.model_save_dir``.

        Resuming: when ``self.args.saved_checkpoint_dir`` is set and either
        (a) ``phase`` is not ``"final"`` and that directory contains
        ``last_epoch_model_for_{phase}.pt``, or (b) ``phase`` is ``"final"``
        and ``self.args.resume_final_training`` is true, the model, optimizer
        and scheduler are restored from that checkpoint. The saved best/last
        checkpoints are copied into ``self.model_save_dir``, and training
        continues from the epoch after the restored one. If the restored
        epoch is already ``>= self.args.epochs`` the method returns without
        training.

        Args:
            phase: Label embedded in checkpoint file names and log messages.
            use_lr_scheduler: When true the scheduler is stepped after each
                epoch.
        """
        banner = "-" * 10

        def data_count():
            # Evaluated lazily: only the branches that log the count touch
            # the dataset attributes.
            return (len(self.train_dataset.data_path)
                    + len(self.train_dataset.masked_data_path))

        best_name = f"best_model_for_{phase}.pt"
        last_name = f"last_epoch_model_for_{phase}.pt"
        saved_dir = self.args.saved_checkpoint_dir

        if not saved_dir:
            should_resume = False
        elif phase == "final":
            should_resume = self.args.resume_final_training
        else:
            # Check for the file that is actually loaded below (the check
            # used to be hard-coded to the "erm" checkpoint name).
            should_resume = os.path.isfile(os.path.join(saved_dir, last_name))

        resume_epoch = 0
        if should_resume:
            (
                self.model,
                self.optimizer,
                self.lr_scheduler,
                last_epoch,
            ) = load_checkpoint(
                model=self.model,
                optimizer=self.optimizer,
                lr_scheduler=self.lr_scheduler,
                checkpoint_path=os.path.join(saved_dir, last_name),
            )
            for checkpoint_name in (best_name, last_name):
                shutil.copyfile(
                    os.path.join(saved_dir, checkpoint_name),
                    os.path.join(self.model_save_dir, checkpoint_name))

            if last_epoch >= self.args.epochs:
                self.logger.info(
                    banner +
                    f"completely trained model checkpoint loaded for {phase}" +
                    banner,
                    print_msg=True
                )
                return
            self.logger.info(
                banner +
                f"partially trained model checkpoint loaded for {phase} and resume training from epoch {last_epoch+1} on {data_count()} data" +
                banner,
                print_msg=True
            )
            resume_epoch = last_epoch + 1
        elif phase == "final" and self.args.initialize_final_train_model_with_erm_weights:
            self.logger.info(
                banner +
                f"training the model from erm initialized weights on {data_count()} data for phase {phase}" +
                banner,
                print_msg=True
            )
        else:
            self.logger.info(
                banner +
                f"training the model from scratch on {data_count()} data for phase {phase}" +
                banner,
                print_msg=True
            )

        # NOTE(review): the best accuracy restarts at -inf on resume, so the
        # first resumed epoch overwrites the best checkpoint copied above.
        # Confirm whether that is intended.
        best_accuracy = -math.inf
        for current_epoch in range(resume_epoch, self.args.epochs):
            self.current_epoch = current_epoch
            self.run_an_epoch(
                data_loader=self.train_loader, epoch=current_epoch, train=True)
            # Pass the real epoch so the validation log line identifies which
            # epoch it belongs to (it was always logged as epoch 0).
            val_accuracy = self.run_an_epoch(
                data_loader=self.val_loader, epoch=current_epoch, train=False,
                val_or_test="val"
            )
            if use_lr_scheduler:
                self.lr_scheduler.step()
            self.logger.info(
                f"lr: {self.lr_scheduler.get_last_lr()[0]}",
                print_msg=True
            )
            if val_accuracy > best_accuracy:
                save_checkpoint(
                    model=self.model,
                    optimizer=self.optimizer,
                    lr_scheduler=self.lr_scheduler,
                    checkpoint_path=os.path.join(
                        self.model_save_dir, best_name),
                    current_epoch=self.current_epoch,
                )
                best_accuracy = val_accuracy
            save_checkpoint(
                model=self.model,
                optimizer=self.optimizer,
                lr_scheduler=self.lr_scheduler,
                checkpoint_path=os.path.join(self.model_save_dir, last_name),
                current_epoch=self.current_epoch,
            )

    def test_selective_classification(self, checkpoint_path=None, phase="final"):
        """Log the selective-classification error at each requested coverage.

        For every value in ``self.args.coverage`` a confidence threshold is
        read from the validation confidences (sorted in descending order) and
        applied to the test set. The expected coverage, the achieved test
        coverage and the error among the retained test samples are logged.

        Args:
            checkpoint_path: Checkpoint loaded into ``self.model``. When
                ``None``, ``best_model_for_erm.pt`` in ``self.model_save_dir``
                is used (copied from ``self.args.saved_checkpoint_dir`` first
                if it is not there yet).
            phase: When ``"final"``, a second model is loaded from
                ``best_model_for_final.pt`` and the confidence is the softmax
                of the element-wise product of both models' softmax outputs.
                Otherwise only ``self.model``'s softmax output is used.

        Raises:
            FileNotFoundError: If a default checkpoint is missing locally and
                no ``saved_checkpoint_dir`` is configured to copy it from.
        """
        self.prepare_data_loaders(train=False)
        self.prepare_model(arch=self.args.arch)
        self.model = self.model.to(self.device)
        self.logger.info(
            "-" * 10 + "getting the selective classification's error" + "-" * 10, print_msg=True)
        if checkpoint_path is None:
            checkpoint_path = self._ensure_local_checkpoint(
                "best_model_for_erm.pt")

        # The second model is only needed (and only copied) for "final".
        final_model = None
        if phase == "final":
            final_checkpoint_path = self._ensure_local_checkpoint(
                "best_model_for_final.pt")
            final_model = copy.deepcopy(self.model).to(self.device)
            (
                final_model,
                _,
                _,
                self.current_epoch,
            ) = load_checkpoint(
                model=final_model,
                optimizer=None,
                lr_scheduler=None,
                checkpoint_path=final_checkpoint_path
            )
            final_model.eval()
        (
            self.model,
            _,
            _,
            self.current_epoch,
        ) = load_checkpoint(
            model=self.model,
            optimizer=None,
            lr_scheduler=None,
            checkpoint_path=checkpoint_path
        )
        self.model.eval()

        with torch.no_grad():
            _, val_prob = self._collect_sorted_confidences(
                self.val_loader, final_model)
            test_eqs, test_prob = self._collect_sorted_confidences(
                self.test_loader, final_model)
            for e_cov in self.args.coverage:
                # Index into the descending validation confidences, clamped
                # to the last element.
                thresholded_index = round((e_cov / 100) * len(val_prob))
                threshold = val_prob[min(thresholded_index, len(val_prob) - 1)]
                predicted_samples = test_prob >= threshold
                error = 1 - (torch.sum(predicted_samples * test_eqs) / torch.sum(predicted_samples))
                self.logger.info(
                    "EXP COV {}, COV {}, ERROR: {}".format(
                        e_cov, round((torch.sum(predicted_samples).item() / len(predicted_samples)) * 100.0, 2), 100 * round(error.item(), 4)
                    ),
                    print_msg=True
                )

    def _ensure_local_checkpoint(self, checkpoint_name):
        """Return the path of ``checkpoint_name`` in ``self.model_save_dir``, copying it from the saved checkpoint dir if absent."""
        local_path = os.path.join(self.model_save_dir, checkpoint_name)
        if not os.path.isfile(local_path):
            if self.args.saved_checkpoint_dir is None:
                # Without this check os.path.join(None, ...) fails with an
                # opaque TypeError.
                raise FileNotFoundError(
                    f"checkpoint {local_path} does not exist and "
                    "no saved_checkpoint_dir is configured"
                )
            shutil.copy(
                os.path.join(self.args.saved_checkpoint_dir, checkpoint_name),
                local_path
            )
        return local_path

    def _collect_sorted_confidences(self, data_loader, final_model=None):
        """Return ``(correct, confidence)`` tensors for ``data_loader``, both ordered by descending confidence."""
        correct_chunks = []
        confidence_chunks = []
        for inputs, _, targets in tqdm(data_loader):
            inputs, targets = inputs.to(self.device), targets.to(self.device)
            probabilities = F.softmax(self.model(inputs), dim=1)
            if final_model is not None:
                final_probabilities = F.softmax(final_model(inputs), dim=1)
                probabilities = F.softmax(
                    probabilities * final_probabilities, dim=1)
            class_probs, class_preds = probabilities.detach().max(1)
            confidence_chunks.append(class_probs)
            correct_chunks.append(
                class_preds.cpu().eq(targets.detach().cpu()))
        correct = torch.cat(correct_chunks, 0).cpu()
        confidence = torch.cat(confidence_chunks, 0).cpu()
        order = torch.sort(confidence, descending=True)[1]
        return (
            torch.gather(correct, dim=0, index=order),
            torch.gather(confidence, dim=0, index=order),
        )

    def check_and_load_saved_masks(self):
        data_is_complete = True
        for threshold_type_for_masking in self.args.threshold_types_for_masking:
            if os.path.isdir(os.path.join(
                    self.args.saved_mask_dir, threshold_type_for_masking)):
                data_source_path = os.path.join(
                    self.args.saved_mask_dir, threshold_type_for_masking)

                masked_data_count = 0
                for label in os.listdir(data_source_path):
                    masked_data_count += len(os.listdir(
                        os.path.join(data_source_path, label)))
                if masked_data_count == len(self.train_dataset):
                    if self.args.epochs > 0:
                        shutil.copyfile(os.path.join(
                            self.args.saved_checkpoint_dir, f'best_model_for_erm.pt'), os.path.join(self.model_save_dir, f'best_model_for_erm.pt'))
                    self.logger.info(
                        "-"*10 + f"using masked data from saved data dir" + "-"*10, print_msg=True)
                    self.augmented_data_save_dir = os.path.join(
                        self.args.saved_mask_dir)
                else:
                    data_is_complete = False
            else:
                data_is_complete = False

        return data_is_complete

    def train(self):
        self.prepare_data_loaders(train=True)
        self.prepare_model(arch=self.args.arch)
        self.model = self.model.to(self.device)
        self.prepare_optimizer()
        self.prepare_lr_scheduler()
        self.train_model(phase="erm")

    def finetune_with_cutout(self):
        """Train an ERM model (unless saved masks are complete), then run the final phase with cutout enabled.

        Steps:

        1. Prepare data loaders and, when ``self.args.saved_mask_dir`` is set,
           check for complete saved masks.
        2. If the saved data is not complete, train the ``"erm"`` phase and
           evaluate its best checkpoint (selective classification or plain
           test, depending on ``self.args.selective_classification``).
        3. Build a fresh model, optimizer and scheduler. Depending on
           ``initialize_final_train_model_with_erm_weights`` and
           ``continue_final_train_with_erm_lr``, restore the ERM weights with
           or without the optimizer/scheduler state.
        4. Set ``self.args.epochs`` to ``self.args.final_train_epochs``, set
           ``self.args.cutout`` to ``True``, rebuild the data loaders and
           train the ``"final"`` phase. The LR scheduler is stepped only when
           ``continue_final_train_with_erm_lr`` is false.
        """
        self.prepare_data_loaders()
        data_is_complete = False
        if self.args.saved_mask_dir is not None:
            data_is_complete = self.check_and_load_saved_masks()

        erm_checkpoint_path = os.path.join(
            self.model_save_dir, "best_model_for_erm.pt")

        if not data_is_complete:
            self.prepare_model(arch=self.args.arch)
            self.prepare_optimizer()
            self.prepare_lr_scheduler()
            self.train_model(phase="erm")
            if self.args.selective_classification:
                # Only the ERM model exists at this point, so evaluate in
                # "erm" mode (as train_augmask does). The default "final"
                # mode would require best_model_for_final.pt, which has not
                # been trained yet.
                self.test_selective_classification(
                    erm_checkpoint_path, phase="erm")
            else:
                self.test(erm_checkpoint_path)

        self.prepare_model(arch=self.args.arch)
        self.prepare_optimizer()
        self.prepare_lr_scheduler()
        if self.args.initialize_final_train_model_with_erm_weights:
            if self.args.continue_final_train_with_erm_lr:
                # Restore weights together with optimizer/scheduler state.
                (
                    self.model,
                    self.optimizer,
                    self.lr_scheduler,
                    _,
                ) = load_checkpoint(
                    model=self.model,
                    optimizer=self.optimizer,
                    lr_scheduler=self.lr_scheduler,
                    checkpoint_path=erm_checkpoint_path,
                )
            else:
                # Restore the weights only.
                (
                    self.model,
                    _,
                    _,
                    _,
                ) = load_checkpoint(
                    model=self.model,
                    optimizer=None,
                    lr_scheduler=None,
                    checkpoint_path=erm_checkpoint_path,
                )
        self.args.epochs = self.args.final_train_epochs
        self.args.cutout = True
        self.prepare_data_loaders(train=True)
        self.train_model(phase="final", use_lr_scheduler=(
            not self.args.continue_final_train_with_erm_lr))

    def train_augmask(self):
        """Run the full pipeline: ERM training, masking, then final training on the masked data.

        Unless complete saved masks are found under
        ``self.args.saved_mask_dir``, an ``"erm"`` model is trained and
        evaluated, a masking model of ``self.args.masking_arch`` is prepared
        and ``mask_data`` writes the masked images. The training dataset and
        loader are then rebuilt from the per-threshold masked-data
        directories, and the ``"final"`` phase is trained for
        ``self.args.final_train_epochs`` epochs, optionally starting from the
        best ERM checkpoint.
        """
        def erm_checkpoint_path():
            # Built on demand so each use reads the current model_save_dir.
            return os.path.join(self.model_save_dir, "best_model_for_erm.pt")

        self.prepare_data_loaders(train=True)

        saved_masks_usable = False
        if self.args.saved_mask_dir is not None:
            saved_masks_usable = self.check_and_load_saved_masks()

        if not saved_masks_usable:
            # Phase 1: train and evaluate the ERM model.
            self.prepare_model(arch=self.args.arch)
            self.prepare_optimizer()
            self.prepare_lr_scheduler()
            self.train_model(phase='erm')
            if self.args.selective_classification:
                self.test_selective_classification(
                    erm_checkpoint_path(), phase='erm')
            else:
                self.test(erm_checkpoint_path())

            # Phase 2: build the (single-device) masking model and load the
            # best ERM weights into it when they exist.
            if os.path.isfile(erm_checkpoint_path()):
                self.prepare_model(arch=self.args.masking_arch, multi_gpu=False)
                self.model, _opt, _sched, _epoch = load_checkpoint(
                    model=self.model,
                    optimizer=None,
                    lr_scheduler=None,
                    checkpoint_path=erm_checkpoint_path(),
                )
                self.logger.info(
                    "-"*10 + "best erm model's weights loaded for masking" + "-"*10, print_msg=True
                )
            # NOTE(review): with epochs == 0 the masking model is rebuilt
            # here even if ERM weights were just loaded above; confirm this
            # is intended.
            if self.args.epochs == 0:
                self.prepare_model(arch=self.args.masking_arch, multi_gpu=False)
            self.mask_data()

        # Phase 3: final training on the masked data.
        final_train_data_dirs = [
            os.path.join(self.augmented_data_save_dir, threshold_type)
            for threshold_type in self.args.threshold_types_for_masking
        ]
        self.prepare_model(arch=self.args.arch)
        self.prepare_optimizer()
        self.prepare_lr_scheduler()
        self.train_dataset, self.train_loader = update_dataset_and_dataloader(
            self.train_dataset, data_dir=final_train_data_dirs, batch_size=self.args.train_batch, workers=self.args.workers)

        if self.args.initialize_final_train_model_with_erm_weights:
            if self.args.continue_final_train_with_erm_lr:
                self.logger.info(
                    "-"*10 + "loading best erm weights and scheduler" + "-"*10,
                    print_msg=True
                )
                self.model, self.optimizer, self.lr_scheduler, _epoch = load_checkpoint(
                    model=self.model,
                    optimizer=self.optimizer,
                    lr_scheduler=self.lr_scheduler,
                    checkpoint_path=erm_checkpoint_path(),
                )
            elif os.path.isfile(erm_checkpoint_path()):
                self.logger.info(
                    "-"*10 + "loading best erm weights" + "-"*10,
                    print_msg=True
                )
                self.model, _opt, _sched, _epoch = load_checkpoint(
                    model=self.model,
                    optimizer=None,
                    lr_scheduler=None,
                    checkpoint_path=erm_checkpoint_path(),
                )

        self.args.epochs = self.args.final_train_epochs
        step_scheduler = not self.args.continue_final_train_with_erm_lr
        self.train_model(phase="final", use_lr_scheduler=step_scheduler)

    def mask_data(self):
        """Generate masks for the data-to-mask loader and save the masked images.

        A heat-map generator is built for ``self.model`` on the layer returned
        by ``self.model.get_grad_cam_target_layer()``. For every entry of
        ``self.args.threshold_types_for_masking`` the first run of digits in
        its name is parsed as an integer coefficient, and an output directory
        ``<augmented_data_save_dir>/<threshold_type>`` is created. Each batch
        is passed to ``get_mask``; the returned per-threshold masks are paired
        with the output directories in order and written with
        ``apply_mask_and_save_images``. The elapsed time is reported at the
        end.

        Raises:
            ValueError: If a threshold type name contains no digits.
        """
        target_layer = self.model.get_grad_cam_target_layer()
        heat_map_generator = get_heatmap_generator(self.args.heat_map_generation_method, self.model, target_layer, self.args.use_cuda, grad_cam_weight=self.args.grad_cam_weight)
        self.logger.info(
            "-" * 10 + f"masking and saving {len(self.data_to_mask_dataset)} data" + "-" * 10, print_msg=True)

        # A global threshold is only computed in "global" mean mode.
        heat_maps_mean = None
        if self.args.mask_mean_mode == "global":
            heat_maps_mean = calculate_data_heat_map_mean(
                self.train_loader, self.device, heat_map_generator)

        masked_data_save_dirs = []
        std_coefficients = []
        for threshold_type_for_masking in self.args.threshold_types_for_masking:
            # Raw string: '\d' in a plain literal is an invalid escape.
            coefficient_match = re.search(r'\d+', threshold_type_for_masking)
            if coefficient_match is None:
                raise ValueError(
                    f"threshold type {threshold_type_for_masking!r} "
                    "does not contain a numeric coefficient"
                )
            std_coefficients.append(int(coefficient_match.group()))
            masked_data_save_dir = os.path.join(
                self.augmented_data_save_dir, threshold_type_for_masking)
            os.makedirs(masked_data_save_dir, exist_ok=True)
            masked_data_save_dirs.append(masked_data_save_dir)

        masking_start_time = time.time()
        for data in tqdm(self.data_to_mask_loader):
            images, images_pathes, targets = data[0], data[1], data[2]
            images = images.to(self.device)
            image_masks = get_mask(
                images=images,
                heat_map_generator=heat_map_generator,
                masking=self.args.masking,
                remove_k=self.args.remove_k,
                global_threshold=heat_maps_mean,
                std_coefficients=std_coefficients,
            )
            for im_masks, masked_data_save_dir in zip(image_masks, masked_data_save_dirs):
                apply_mask_and_save_images(
                    image_masks=im_masks, masked_data_save_dir=masked_data_save_dir, images_pathes=images_pathes, targets=targets
                )
        tqdm.write("Masking Time: {}".format(
            time.time() - masking_start_time))
