from libauc_mod.models import resnet18
from libauc_mod.datasets import CelebaDataset  # CIFAR10, CIFAR100, Melanoma,
from libauc_mod.utils import ImbalancedDataGenerator
from libauc_mod.losses.auc import pAUC_CVaR_smag_Loss, pAUC_CVaR_Loss
from libauc_mod.optimizers import SMAG, SOPA, SGA, ESGA, ESGD
from libauc_mod.utils import ImbalancedDataGenerator
from libauc_mod.sampler import DualSampler  # data resampling (for binary class)
from libauc_mod.metrics import pauc_roc_score
from libauc_mod.utils import CosineLRScheduler

import torch
import torch.nn as nn
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader

import os
import pickle
from datetime import datetime as dt
from datetime import timedelta as td
from datetime import date
import argparse
import torch.optim as optim
from adversarial_head import AttrLinearClassifier
from auc_fair_metric import AUCFairness, EOD_EOP, DP

# Command-line interface for the pAUC + adversarial-fairness training script.
parser = argparse.ArgumentParser(description='pAUC_fairness')

### Parameter Setting
parser.add_argument('--SEED', default=123, type=int)
# parser.add_argument('--gpu_id', default='0', type=str, help='id(s) for CUDA_VISIBLE_DEVICES')
parser.add_argument('--num_workers', default=0, type=int)
parser.add_argument('--total_epochs', default=60, type=int)
parser.add_argument('--batch_size', default=64, type=int)
parser.add_argument('--weight_decay', default=5e-4, type=float, help='regularization weight decay')
parser.add_argument('--lr', default=1e-3, type=float, help='learning rate for sopa and egda')
parser.add_argument('--lr_0', default=0.01, type=float,
                    help='learning rate for w_hat in smag (also the adversarial head learning rate)')
parser.add_argument('--lr_1', default=0.1, type=float, help='learning rate for w in smag')
parser.add_argument('--eta', default=1e1, type=float, help='learning rate for control negative samples weights')
# Without an explicit type a value given on the command line would stay a
# string, and the `epoch in decay_epochs` membership test would fail.
parser.add_argument('--decay_epochs', default=[20, 40], type=int, nargs='+',
                    help='epochs at which the learning rates are decayed, e.g. --decay_epochs 20 40')
parser.add_argument('--decay_factor', default=10, type=float,
                    help='decay factor passed to update_lr at every decay epoch')
parser.add_argument('--max_fpr', default=0.3, type=float, help='upper bound for FPR')
parser.add_argument('--sampling_rate', default=0.5, type=float)
parser.add_argument('--optimizer', default='sopa', type=str, help='sopa, smag or egda')
parser.add_argument('--optimizer_mode', default='sgd', type=str, help='sgd')
parser.add_argument('--gamma', default=100, type=float)
parser.add_argument('--dataset', default='celeba_binary', type=str, help='celeba_binary')
parser.add_argument('--sensitive_label', default='Male', type=str,
                    help='For fairness adversarial sensitive label of celeba dataset')
parser.add_argument('--binary_label', default='Attractive', type=str, help='For binary label celeba dataset')
parser.add_argument('--adv_alpha', default=0.0, type=float,
                    help='weight of the adversarial loss added to the pAUC loss')
parser.add_argument('--cos_lr', default=0, type=int, help='1 for using cosine learning rate scheduler, 0 for not using')
parser.add_argument('--pretrain', default=1, type=int, help='whether to use pre-trained model')
# NOTE(review): main() computes `rand_init_wbuf = args.init_wbuf != 0`, i.e. a
# non-zero value enables random initialisation, which is the opposite of what
# this help text states -- confirm which of the two is intended.
parser.add_argument('--init_wbuf', default=0, type=int,
                    help='0 for random initialization for w_buf, 1 for setting initial w_buf equal to w')
parser.add_argument('--epoch_steps', default=10, type=int, help='second loop algorithm inner steps')


def set_all_seeds(SEED):
    """Seed the NumPy and PyTorch random generators and pin the cuDNN flags.

    Args:
        SEED: integer seed handed to both ``np.random.seed`` and
            ``torch.manual_seed``.
    """
    # REPRODUCIBILITY
    for seed_rng in (np.random.seed, torch.manual_seed):
        seed_rng(SEED)

    # cuDNN: request deterministic kernels and switch off the auto-tuner.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.benchmark = False
    cudnn_backend.deterministic = True


class ImageDataset(Dataset):
    """In-memory image dataset that also reports an index for the loss function.

    Each item is ``(image, target, idx)``. In ``'train'`` mode ``idx`` is the
    position of the sample among the positive samples (``target == 1``), or
    ``-1`` for every other sample. In any other mode ``idx`` is the plain
    dataset index that was requested.

    Args:
        images: array of images; stored as ``uint8``.
        targets: per-sample labels; entries equal to 1 are the positives.
        image_size: side length both transforms resize to.
        crop_size: side length of the random crop used in train mode.
        mode: ``'train'`` selects the augmenting transform and the index
            remapping; anything else selects the test transform.
    """

    def __init__(self, images, targets, image_size=32, crop_size=30, mode='train'):
        self.images = images.astype(np.uint8)
        self.targets = targets
        self.mode = mode
        self.transform_train = transforms.Compose([
            transforms.ToTensor(),
            transforms.RandomCrop((crop_size, crop_size), padding=None),
            transforms.RandomHorizontalFlip(),
            transforms.Resize((image_size, image_size)),
        ])
        self.transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((image_size, image_size)),
        ])

        # for loss function
        # np.asarray keeps the comparison element-wise even when `targets` is a
        # plain Python list (a bare `list == 1` is just False).
        self.pos_indices = np.flatnonzero(np.asarray(targets) == 1)
        # dataset index -> position among the positives
        self.pos_index_map = {
            int(sample_idx): slot for slot, sample_idx in enumerate(self.pos_indices)
        }

    def __len__(self):
        """Return the number of stored images."""
        return len(self.images)

    def _positive_slot(self, idx):
        """Return the position of sample ``idx`` among the positives, or -1.

        Uses the precomputed dict so the lookup is O(1) instead of scanning
        ``pos_indices`` for every sample that is fetched.
        """
        return self.pos_index_map.get(idx, -1)

    def __getitem__(self, idx):
        """Return ``(transformed image, target, idx)`` for dataset index ``idx``."""
        target = self.targets[idx]
        # self.images is already uint8 (see __init__), so no extra cast is needed.
        image = Image.fromarray(self.images[idx])
        if self.mode == 'train':
            idx = self._positive_slot(idx)
            image = self.transform_train(image)
        else:
            image = self.transform_test(image)
        return image, target, idx


def _evaluate_split(model, loader, max_fpr):
    """Score every batch of ``loader`` and compute pAUC plus fairness metrics.

    The raw second output of ``model`` (no sigmoid applied) is used as the
    prediction score. The caller is expected to have put the model in eval
    mode and to be inside ``torch.no_grad()``.

    Returns:
        Tuple ``(pauc, auc_fair_dict, eod, eop, dp)`` as returned by
        ``pauc_roc_score``, ``AUCFairness``, ``EOD_EOP`` and ``DP``.
    """
    preds, labels, groups = [], [], []
    for batch, batch_labels, batch_sensitive, _ in loader:
        _, y_pred = model(batch.cuda())
        preds.append(y_pred.cpu().detach().numpy())
        labels.append(batch_labels.numpy())
        groups.append(batch_sensitive.numpy())
    labels = np.concatenate(labels)
    preds = np.concatenate(preds)
    groups = np.concatenate(groups)

    pauc = pauc_roc_score(labels, preds, max_fpr=max_fpr)
    auc_fair_dict = AUCFairness(preds, labels, groups)
    eod, eop = EOD_EOP(preds, labels, groups)
    dp = DP(preds, groups)
    return pauc, auc_fair_dict, eod, eop, dp


def main():
    """Train a ResNet-18 with a pAUC loss plus an adversarial head and log metrics.

    Steps, in order:
      1. parse the command-line arguments from the module-level ``parser``;
      2. build the CelebA train/val/test loaders (train uses ``DualSampler``);
      3. create the model, the adversarial head and the optimizer pair selected
         by ``--optimizer`` (``sopa``, ``smag`` or ``egda``);
      4. train for ``--total_epochs`` epochs, evaluating on the validation and
         test sets every 500 iterations;
      5. pickle the collected metric histories into the results directory.

    Raises:
        ValueError: if ``--dataset`` or ``--optimizer`` is not supported.
    """
    ### parameters
    args = parser.parse_args()

    SEED = args.SEED
    batch_size = args.batch_size
    print(batch_size)
    total_epochs = args.total_epochs
    weight_decay = args.weight_decay  # regularization weight decay

    # learning rate for sopa
    lr = args.lr

    # learning rates for smag (eta_0 / eta_1 reuse the same two values)
    lr_0 = args.lr_0
    lr_1 = args.lr_1
    eta_0 = args.lr_0
    eta_1 = args.lr_1

    eta = args.eta  # learning rate for control negative samples weights
    decay_epochs = args.decay_epochs
    decay_factor = args.decay_factor
    max_fpr = args.max_fpr  # upper bound for FPR
    sampling_rate = args.sampling_rate
    optimizer_mode = args.optimizer_mode
    gamma = args.gamma
    dataset = args.dataset
    optimizer_option = args.optimizer
    cos_lr = args.cos_lr
    step_size = 100  # iterations between scheduler steps during the first epoch
    eval_step = 500  # iterations between two evaluations
    iter_count = 0
    # NOTE(review): the --init_wbuf help text says 0 means random initialisation,
    # but this expression is True for non-zero values -- confirm which is intended.
    rand_init_wbuf = args.init_wbuf != 0
    print(args)

    ### datasets
    if dataset == 'celeba_binary':
        print('loading dataset ', dataset, '...', flush=True)
        celeba_mode = 'binary'
        root = './CelebA/'
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        train_dataset = CelebaDataset(
            root + 'celeba_attr_train.csv',
            root + 'img_align_celeba/',
            transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]), mode=celeba_mode, binary_label=args.binary_label, sensitive_label=args.sensitive_label)

        val_dataset = CelebaDataset(root + 'celeba_attr_val.csv', root + 'img_align_celeba/',
                                    transforms.Compose([
                                        transforms.ToTensor(),
                                        normalize,
                                    ]), mode=celeba_mode, binary_label=args.binary_label,
                                    sensitive_label=args.sensitive_label)

        test_dataset = CelebaDataset(root + 'celeba_attr_test.csv', root + 'img_align_celeba/',
                                     transforms.Compose([
                                         transforms.ToTensor(),
                                         normalize,
                                     ]), mode=celeba_mode, binary_label=args.binary_label,
                                     sensitive_label=args.sensitive_label)

        sampler = DualSampler(train_dataset, batch_size, sampling_rate=sampling_rate)
        trainloader = torch.utils.data.DataLoader(
            train_dataset, batch_size=batch_size, shuffle=(sampler is None),
            num_workers=args.num_workers, pin_memory=True, sampler=sampler)

        print('................>>>>>> length of trainloader {}'.format(len(trainloader)))

        validloader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=batch_size, shuffle=False,
            num_workers=args.num_workers, pin_memory=True)

        testloader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=batch_size, shuffle=False,
            num_workers=args.num_workers, pin_memory=True)
    else:
        # Fail with a clear message instead of a NameError on `sampler` below.
        raise ValueError("Invalid dataset: {}".format(dataset))

    data_pos_len = sampler.pos_len
    data_len = sampler.pos_len + sampler.neg_len
    print('data_pos_len: ', data_pos_len, 'data_len: ', data_len)

    ### model
    set_all_seeds(SEED)
    # Honour --pretrain; the flag is also recorded in the output file name.
    model = resnet18(pretrained=bool(args.pretrain), num_classes=1, last_activation=None)
    ### Adversarial Head
    adv_head = AttrLinearClassifier()

    model = model.cuda()
    adv_head = adv_head.cuda()

    if optimizer_option == 'sopa':
        loss_fn = pAUC_CVaR_Loss(pos_len=data_pos_len, beta=max_fpr, data_len=data_len)
        optimizer = SOPA(model.parameters(), loss_fn=loss_fn, mode=optimizer_mode, lr=lr, eta=eta,
                         weight_decay=weight_decay)
        adv_optimizer = SGA(adv_head.parameters(), lr=lr_0)
    elif optimizer_option == 'smag':
        loss_fn = pAUC_CVaR_smag_Loss(pos_len=data_pos_len, beta=max_fpr, data_len=data_len, gamma=gamma, eta_0=eta_0,
                                      eta_1=eta_1)
        optimizer = SMAG(model.parameters(), loss_fn=loss_fn, mode=optimizer_mode, lr_0=lr_0, lr_1=lr_1, eta=eta,
                         gamma=gamma, rand_init_wbuf=rand_init_wbuf)
        adv_optimizer = SGA(adv_head.parameters(), lr=lr_0)
    elif optimizer_option == 'egda':
        loss_fn = pAUC_CVaR_smag_Loss(pos_len=data_pos_len, beta=max_fpr, data_len=data_len, gamma=gamma, eta_0=eta_0,
                                      eta_1=eta_1)
        optimizer = ESGD(model.parameters(), loss_fn=loss_fn, lr=lr, eta=eta,
                         weight_decay=weight_decay, gamma=gamma, rand_init_wbuf=True, epoch_steps=args.epoch_steps)
        adv_optimizer = ESGA(adv_head.parameters(), lr=lr_0, epoch_steps=args.epoch_steps)
    else:
        raise ValueError("Invalid optimizer_option: {}".format(optimizer_option))
    # Every optimizer option uses the same criterion for the adversarial head.
    adv_criterion = nn.BCEWithLogitsLoss()

    if cos_lr:
        # smag keeps its model learning rate under 'lr_1', the others under 'lr'.
        if optimizer_option != 'smag':
            param_group_field = 'lr'
        else:
            param_group_field = 'lr_1'
        # min_lr = lr_1
        min_lr = 0.00001
        lr_scheduler = CosineLRScheduler(optimizer=optimizer,
                                         param_group_field=param_group_field,
                                         t_initial=total_epochs,
                                         t_mul=1.0,
                                         lr_min=min_lr,
                                         decay_rate=1,
                                         warmup_lr_init=0.00001,
                                         warmup_t=20,
                                         cycle_limit=1,
                                         t_in_epochs=True,
                                         noise_range_t=None,
                                         noise_pct=0.67,
                                         noise_std=1.0,
                                         noise_seed=42,
                                         )

    ### training
    print('Start Training', flush=True)
    print('-' * 30)
    timer_start = dt.now()

    tr_pAUC, te_pAUC, va_pAUC = [], [], []
    valid_eod, valid_eop, valid_dp, test_eod, test_eop, test_dp = [], [], [], [], [], []
    valid_intraAUC, valid_interAUC, valid_xAUC = [], [], []
    test_intraAUC, test_interAUC, test_xAUC = [], [], []
    best_test_pAUC_list = []
    tr_loss_list = []
    best_vali = float('-inf')
    best_test = float('-inf')
    # Stays None until the first evaluation that improves the validation pAUC.
    best_record = None

    print('start training...', flush=True)
    for epoch in range(total_epochs):
        # Step decay is only used when the cosine scheduler is disabled.
        if (not cos_lr) and (epoch in decay_epochs):
            optimizer.update_lr(decay_factor=decay_factor)
            adv_optimizer.update_lr(decay_factor=decay_factor)

        train_loss = 0
        for idx, data in enumerate(trainloader):

            train_data, train_labels, sensitive_labels, index = data
            train_data, train_labels, sensitive_labels = train_data.cuda(), train_labels.cuda(), sensitive_labels.cuda()
            feats, y_pred = model(train_data)
            y_prob = torch.sigmoid(y_pred)
            loss = loss_fn(y_prob, train_labels, index)

            # Adversarial term: the head predicts the sensitive attribute from
            # the backbone features and its loss is added with weight adv_alpha.
            adv_pred = adv_head(feats)
            # BCEWithLogitsLoss needs float targets (labels are already on the GPU).
            sensitive_labels = sensitive_labels.float()
            adv_loss = adv_criterion(adv_pred, sensitive_labels)  # for \w_a
            loss = loss + args.adv_alpha * adv_loss

            # minimize w
            optimizer.zero_grad()
            adv_optimizer.zero_grad()
            loss.backward()
            optimizer.step()  # gradient descent
            adv_optimizer.step()  # gradient ascent for \w_a

            train_loss = train_loss + loss.item()

            # During the first epoch the scheduler is additionally stepped every
            # `step_size` iterations.
            if cos_lr and epoch == 0 and idx % step_size == 0:
                lr_scheduler.step(idx // step_size)
                if optimizer_option == 'smag':
                    cur_lr0 = optimizer.param_groups[0]["lr_0"]
                    cur_lr1 = optimizer.param_groups[0]["lr_1"]
                    print('Epoch: ', epoch, ' iteration: ', idx, ' lr0: ', cur_lr0, ' lr1: ', cur_lr1)
                else:
                    cur_lr = optimizer.param_groups[0]["lr"]
                    print('Epoch: ', epoch, ' iteration: ', idx, ' lr: ', cur_lr)

            iter_count += 1

            if iter_count % eval_step == 0:
                model.eval()
                adv_head.eval()
                with torch.no_grad():
                    if dataset == 'celeba_binary':
                        # Train pAUC is not computed for this dataset; 0 is logged.
                        single_train_auc = 0
                    else:
                        train_pred = []
                        train_true = []
                        for jdx, data in enumerate(trainloader):
                            train_data, train_labels, train_sensitive_labels, _ = data
                            train_data = train_data.cuda()
                            _, y_pred = model(train_data)
                            y_prob = torch.sigmoid(y_pred)
                            train_pred.append(y_prob.cpu().detach().numpy())
                            train_true.append(train_labels.numpy())
                        train_true = np.concatenate(train_true)
                        train_pred = np.concatenate(train_pred)
                        single_train_auc = pauc_roc_score(train_true, train_pred, max_fpr=max_fpr)

                    # Valid
                    (single_valid_auc, i_valid_auc_fair_dicts, i_valid_eod,
                     i_valid_eop, i_valid_dp) = _evaluate_split(model, validloader, max_fpr)
                    i_valid_intraAUC = float(i_valid_auc_fair_dicts['intraAUCF'])
                    i_valid_interAUC = float(i_valid_auc_fair_dicts["interAUCF"])
                    i_valid_xAUC = float(i_valid_auc_fair_dicts["xAUCF"])
                    valid_eod.append('{:.4f}'.format(i_valid_eod))
                    valid_eop.append('{:.4f}'.format(i_valid_eop))
                    valid_dp.append('{:.4f}'.format(i_valid_dp))
                    valid_intraAUC.append(i_valid_intraAUC)
                    valid_interAUC.append(i_valid_interAUC)
                    valid_xAUC.append(i_valid_xAUC)
                    print('valid_auc_fair_dict : ', i_valid_auc_fair_dicts, 'valid_eod : ',
                          '{:.4f}'.format(i_valid_eod.item()), 'valid_eop : ', '{:.4f}'.format(i_valid_eop.item()),
                          'valid_dp : ', '{:.4f}'.format(i_valid_dp.item()))

                    # Test
                    (single_test_auc, i_test_auc_fair_dicts, i_test_eod,
                     i_test_eop, i_test_dp) = _evaluate_split(model, testloader, max_fpr)
                    i_test_intraAUC = float(i_test_auc_fair_dicts['intraAUCF'])
                    i_test_interAUC = float(i_test_auc_fair_dicts["interAUCF"])
                    i_test_xAUC = float(i_test_auc_fair_dicts["xAUCF"])
                    test_eod.append('{:.4f}'.format(i_test_eod))
                    test_eop.append('{:.4f}'.format(i_test_eop))
                    test_dp.append('{:.4f}'.format(i_test_dp))
                    test_intraAUC.append(i_test_intraAUC)
                    test_interAUC.append(i_test_interAUC)
                    test_xAUC.append(i_test_xAUC)
                    print('test_auc_fair_dict : ', i_test_auc_fair_dicts, 'test_eod : ',
                          '{:.4f}'.format(i_test_eod.item()), 'test_eop : ', '{:.4f}'.format(i_test_eop.item()),
                          'test_dp : ', '{:.4f}'.format(i_test_dp.item()))

                    # Model selection: keep the test score of the best validation score.
                    if single_valid_auc > best_vali:
                        best_vali = single_valid_auc
                        best_test = single_test_auc
                        best_record = [iter_count, train_loss / (idx + 1), single_train_auc, single_valid_auc,
                                       single_test_auc]
                    print(
                        'Step=%s, Loss=%.4f, Train_pAUC(0.3)=%.4f, Valid_pAUC(0.3)=%.4f, Test_pAUC(0.3)=%.4f, Best_test_pAUC=%.4f' % (
                            iter_count, train_loss / (idx + 1), single_train_auc, single_valid_auc, single_test_auc,
                            best_test), flush=True)

                    best_test_pAUC_list.append(best_test)
                    tr_loss_list.append(train_loss / (idx + 1))
                    tr_pAUC.append(single_train_auc)
                    te_pAUC.append(single_test_auc)
                    va_pAUC.append(single_valid_auc)

                print('output_res_len: ', len(valid_eod), len(valid_eop), len(valid_dp))
                # best_record is still None if no evaluation has improved on -inf
                # yet (e.g. a NaN validation score); skip the summary in that case.
                if best_record is not None:
                    print(
                        'Best result : Iter=%s, Loss=%.4f, Train_pAUC(0.3)=%.4f, Valid_pAUC(0.3)=%.4f, Test_pAUC(0.3)=%.4f' % (
                            best_record[0], best_record[1], best_record[2], best_record[3], best_record[4]))
                print('Total training time : ', (dt.now() - timer_start).total_seconds())

                model.train()
                adv_head.train()

        if cos_lr:
            lr_scheduler.step(epoch + 1)
            if optimizer_option == 'smag':
                cur_lr0 = optimizer.param_groups[0]["lr_0"]
                cur_lr1 = optimizer.param_groups[0]["lr_1"]
                print('Epoch: ', epoch, ' lr0: ', cur_lr0, ' lr1: ', cur_lr1)
            else:
                cur_lr = optimizer.param_groups[0]["lr"]
                print('Epoch: ', epoch, ' lr: ', cur_lr)

    ### save the metric histories
    output_dict = {'train_pAUC': tr_pAUC, 'test_pAUC': te_pAUC, 'val_pAUC': va_pAUC,
                   'best_test_pAUC': best_test_pAUC_list, 'tr_loss_list': tr_loss_list,
                   'val_EOD': valid_eod, 'valid_eop': valid_eop, 'val_DP': valid_dp,
                   'test_EOD': test_eod, 'test_EOP': test_eop, 'test_DP': test_dp,
                   'valid_intraAUC': valid_intraAUC, 'valid_interAUC': valid_interAUC, 'valid_xAUC': valid_xAUC,
                   'test_intraAUC': test_intraAUC, 'test_interAUC': test_interAUC, 'test_xAUC': test_xAUC}
    today = date.today()
    datestr = today.strftime("%b%d%Y")

    if optimizer_option == 'sopa':
        if weight_decay == 0:
            wd_str = '0'
        else:
            wd_str = str(int(1 / weight_decay))
        file_name = '{0}_{1}_{2}_lr{3}_wd{4}_maxfpr{5}_seed{6}_{7}_adv_{8}_{9}_isPRETRAIN_{10}_lr0{11}.p'.format(
            optimizer_option, optimizer_mode, dataset, str(int(1 / lr)), wd_str, str(int(max_fpr * 10)), str(SEED),
            datestr, str(args.adv_alpha), args.binary_label, str(args.pretrain), str(1/args.lr_0))
    else:
        file_name = '{0}_{1}_{2}_lr0{3}_lr1{4}_eta0{5}_eta1{6}_gamma{7}_maxfpr{8}_seed{9}_{10}_adv_{11}_{12}_isPRETRAIN_{13}.p'.format(
            optimizer_option, optimizer_mode, dataset, str(int(1 / lr_0)), str(int(1 / lr_1)), str(int(1 / eta_0)),
            str(int(1 / eta_1)),
            str(gamma), str(int(max_fpr * 10)), str(SEED), datestr, str(args.adv_alpha), args.binary_label,
            str(args.pretrain))
    print('creating new file: ', file_name)
    results_path = 'xxx'  # please add local path

    # Create the directory first so a finished run cannot fail at save time.
    os.makedirs(results_path, exist_ok=True)
    new_file_path = os.path.join(results_path, file_name)
    with open(new_file_path, 'wb') as handle:
        pickle.dump(output_dict, handle)


# Run the training entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()