#!/usr/bin/env python3
# -*- coding: utf-8 -*-



import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision.models.resnet import ResNet, Bottleneck, BasicBlock
from torch import optim

import os
import numpy as np
from sklearn.metrics import roc_auc_score

from multi_task_sampler import DataSampler
from Multi_pAUC_KL import Multi_pAUC_KL

from utils import PAUC_MultiLabel,pAUC_mini, ImageDataset, pretrain_cifar, partial_auc
from libauc.datasets import imbalance_generator
from libauc.datasets import CIFAR10, CIFAR100



# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


def resnet18(num_classes):
    """Build a ``ResNet`` from ``BasicBlock`` units with two blocks per stage.

    Args:
        num_classes: Forwarded to ``ResNet`` as its ``num_classes`` argument.

    Returns:
        The newly constructed ``ResNet`` instance.
    """
    blocks_per_stage = [2, 2, 2, 2]
    return ResNet(BasicBlock, blocks_per_stage, num_classes=num_classes)

def run_train(epoch, model, optimizer, criterion, train_loader, AUC1_tr,AUC2_tr, T):
    """Train ``model`` for one epoch and record training partial-AUC metrics.

    Args:
        epoch: Epoch index; only used in the progress print.
        model: Network to optimise. It is switched to training mode here.
        optimizer: Optimizer that is zeroed and stepped once per batch.
        criterion: Loss invoked as ``criterion(probs, targets, ind)``. Its
            ``beta1`` attribute is overwritten every step with ``1/sqrt(T)``.
        train_loader: Iterable yielding ``(ind, inputs, targets)`` batches.
        AUC1_tr: List that receives this epoch's
            ``partial_auc(..., max_fpr=0.1)`` value.
        AUC2_tr: List that receives this epoch's
            ``partial_auc(..., max_fpr=0.3)`` value.
        T: Running count of optimisation steps taken before this epoch.

    Returns:
        ``T`` advanced by the number of batches processed in this epoch.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    print('\nEpoch: %d' % epoch)
    model.train()
    train_loss = 0
    num_batches = 0
    train_pred = []
    train_true = []
    for ind, inputs, targets in train_loader:
        T += 1.
        # beta1 shrinks with the global step count, not the per-epoch one.
        criterion.beta1 = 1./np.sqrt(T)

        inputs, targets, ind = inputs.to(device), targets.to(device), ind.to(device)
        optimizer.zero_grad()
        probs = torch.sigmoid(model(inputs))
        loss = criterion(probs, targets, ind)
        loss.backward()
        optimizer.step()

        # Score the same probabilities that were fed to the loss. The previous
        # code applied sigmoid a second time here, which only squashed the
        # scores into a narrower range before they reached the metric.
        train_pred.append(probs.detach().cpu().numpy())
        train_true.append(targets.cpu().numpy())

        train_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        # Fail with a clear message instead of an obscure concatenate error.
        raise ValueError('train_loader yielded no batches')

    train_true = np.concatenate(train_true)
    train_pred = np.concatenate(train_pred)
    val_auc_mean1 = partial_auc(train_true, train_pred, max_fpr=0.1)
    val_auc_mean2 = partial_auc(train_true, train_pred, max_fpr=0.3)

    print('Training Loss: %.3f, AUC1: %.3f, AUC2: %.3f'%(train_loss/num_batches, val_auc_mean1, val_auc_mean2))
    AUC1_tr.append(val_auc_mean1)
    AUC2_tr.append(val_auc_mean2)
    return T
    
def run_test(epoch, model, test_loader, AUC1_te, AUC2_te, best_auc):
    """Evaluate ``model`` on ``test_loader`` and track the best partial AUC.

    Args:
        epoch: Epoch index; accepted for symmetry with ``run_train`` but not
            used in the body.
        model: Network to evaluate. It is switched to eval mode here.
        test_loader: Iterable yielding ``(ind, inputs, targets)`` batches.
        AUC1_te: List that receives ``partial_auc(..., max_fpr=0.1)``.
        AUC2_te: List that receives ``partial_auc(..., max_fpr=0.3)``.
        best_auc: Best ``max_fpr=0.1`` value seen so far.

    Returns:
        ``(best_auc, save_flag)`` where ``best_auc`` is the possibly updated
        best value and ``save_flag`` is ``True`` only when this evaluation
        strictly improved on the incoming ``best_auc``.
    """
    model.eval()
    save_flag = False
    batch_scores = []
    batch_labels = []
    with torch.no_grad():
        for ind, inputs, targets in test_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            ind = ind.to(device)

            scores = torch.sigmoid(model(inputs))

            batch_scores.append(scores.cpu().detach().numpy())
            batch_labels.append(targets.cpu().numpy())

        y_true = np.concatenate(batch_labels)
        y_score = np.concatenate(batch_scores)
        pauc_fpr01 = partial_auc(y_true, y_score, max_fpr=0.1)
        pauc_fpr03 = partial_auc(y_true, y_score, max_fpr=0.3)

        # Only the max_fpr=0.1 metric decides whether to checkpoint.
        if pauc_fpr01 > best_auc:
            save_flag = True
            best_auc = pauc_fpr01

        print('Testing AUC1: %.3f, AUC2: %.3f'%(pauc_fpr01, pauc_fpr03))

    AUC1_te.append(pauc_fpr01)
    AUC2_te.append(pauc_fpr03)
    return best_auc, save_flag

def main():
    """Train and evaluate a ResNet-18 on one-hot CIFAR-100 labels.

    The experiment configuration is hard-coded at the top of the function.
    Each epoch runs ``run_train`` followed by ``run_test``. Whenever the test
    metric tracked by ``run_test`` improves, the model weights are saved under
    ``models/``. After the last epoch the four per-epoch metric lists are
    written to ``.npy`` files in the current working directory.
    """
    loss_type = 'PAUC'
    batch_size = 100
    tasks = 100
    epochs = 100
    pre_train = True
    dataset = 'cifar100'
    model_dir = 'models'

    AUC1_te = []
    AUC1_tr = []
    AUC2_te = []
    AUC2_tr = []

    (train_data, train_label), (test_data, test_label) = CIFAR100()

    # Turn the integer class ids into one-hot rows with one column per task.
    tmp1, tmp2 = list(range(len(train_label))), list(range(len(test_label)))
    train_labels, test_labels = torch.zeros(len(train_label),tasks), torch.zeros(len(test_label),tasks)
    train_labels[tmp1, torch.tensor(train_label).squeeze().long()] += 1
    test_labels[tmp2, torch.tensor(test_label).squeeze().long()] += 1

    train_loader = DataLoader(ImageDataset(train_data, train_labels, mode='train'),
                         sampler=DataSampler(train_labels,batchSize=batch_size,multi_tasks=10),
                         batch_size=batch_size, num_workers=4, pin_memory=True)
    test_loader = DataLoader(ImageDataset(test_data, test_labels, mode='test'), batch_size=128,
                        shuffle=False, num_workers=4,  pin_memory=True)

    model = resnet18(tasks).to(device)

    # Both the pretrained weights and the checkpoints live here; create the
    # directory up front so saving cannot fail on a fresh checkout.
    os.makedirs(model_dir, exist_ok=True)

    if pre_train:
        model_path = model_dir+'/'+dataset+'_resnet18_CrossEntropyLoss_pretrain.pth'
        if os.path.isfile(model_path):
            # map_location lets a checkpoint saved on GPU load on a CPU-only host.
            model.load_state_dict(torch.load(model_path, map_location=device))
        else:
            pretrain_cifar(model, train_loader, test_loader, model_path)

    if loss_type == 'PAUC':
        criterion = PAUC_MultiLabel(num_classes=tasks, eta1=0.1, eta2=0.1, beta=0.2, beta0=0.9, beta1=0.9, tau1=1, tau2=1)
    elif loss_type == 'SOPA':
        criterion = Multi_pAUC_KL(data_len=train_labels.shape[0], gamma=0.1, Lambda=1, total_tasks=tasks)
    else:
        criterion = pAUC_mini(threshold=1., gamma=0.7)

    # The criterion's own parameters are optimised together with the model's.
    optimizer = optim.SGD(list(model.parameters())+list(criterion.parameters()),momentum=0.9, lr=5e-3, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,milestones=[50],gamma=0.1)

    # The original file name appended an undefined variable (`tt`), which
    # raised NameError on the first improving epoch; the suffix is dropped.
    checkpoint_path = model_dir+'/'+dataset+'_resnet18_'+loss_type+'.pth'

    best_auc = 0
    T = 0
    for epoch in range(epochs):
        T = run_train(epoch, model, optimizer, criterion, train_loader, AUC1_tr,AUC2_tr, T)
        best_auc, save_flag = run_test(epoch, model, test_loader, AUC1_te,AUC2_te, best_auc)
        scheduler.step()

        if save_flag:
            torch.save(model.state_dict(), checkpoint_path)

    np.save(dataset+'train_auc1_'+loss_type+'_'+'.npy',AUC1_tr)
    np.save(dataset+'test_auc1_'+loss_type+'_'+'.npy',AUC1_te)
    np.save(dataset+'train_auc2_'+loss_type+'_'+'.npy',AUC2_tr)
    np.save(dataset+'test_auc2_'+loss_type+'_'+'.npy',AUC2_te)

if __name__ == "__main__":
    # Start the experiment only when run as a script, not when imported.
    main()
    
        
        
