# -*- coding: utf-8 -*-



# NOTE(review): `path` is not read anywhere else in the visible file; it looks
# like a leftover placeholder for a data directory — confirm before removing.
path = " "

import csv
import numpy as np
import random
import torch
from PIL import Image
from torch.utils.data import DataLoader,TensorDataset
from torchvision.transforms import Compose, RandomCrop, Pad, RandomHorizontalFlip, Resize, RandomAffine,RandomResizedCrop,CenterCrop
from torchvision.transforms import ToTensor, Normalize,transforms
from PIL.Image import BICUBIC
import scipy.io as sio

# Mini-batch size and worker-process count handed to every DataLoader built
# further down in this file.
batch_size = 128
num_workers=4

import torch
import torch.nn as nn
import torch.nn.functional as F

def _load_feature_mat(mat_path):
    """Read one feature file and return its ``(features, label, g)`` arrays.

    The file is parsed with ``scipy.io.loadmat`` and the three arrays are
    looked up under the keys ``'features'``, ``'label'`` and ``'g'``.
    """
    contents = sio.loadmat(mat_path)
    # The files are MATLAB .mat archives, so report the actual path rather
    # than the old, misleading ".npy" message.
    print(f"load {mat_path} done")
    return contents['features'], contents['label'], contents['g']


test_features, test_label, test_g = _load_feature_mat('saved_test_feature.mat')
train_features, train_label, train_g = _load_feature_mat('saved_train_feature.mat')
# A previously disabled variant loaded validation_features / validation_label /
# validation_g from '120_preds-on_validation.mat'; those three names are now
# produced by the train/validation split performed later in this file.
validation_test_features, validation_test_label, validation_test_g = _load_feature_mat(
    'saved_validation_feature.mat')

from sklearn.model_selection import train_test_split

# Append the group column to the features so that one train_test_split call
# keeps features, labels and groups aligned.
num_feature_columns = train_features.shape[1]
train_features_groups = np.hstack((train_features, train_g))

train_features_groups, validation_features_groups, train_label, validation_label = train_test_split(
    train_features_groups, train_label, test_size=0.5, random_state=1)

# Undo the stacking.  The feature width is taken from the data instead of the
# previously hard-coded 512, so other feature sizes are sliced correctly.
train_features = train_features_groups[:, :num_feature_columns]
train_g = np.reshape(train_features_groups[:, -1], (-1, 1))
validation_features = validation_features_groups[:, :num_feature_columns]
validation_g = np.reshape(validation_features_groups[:, -1], (-1, 1))

train_tensor_x = torch.Tensor(train_features)
train_tensor_y = torch.Tensor(train_label)
train_tensor_g = torch.Tensor(train_g)
val_tensor_x = torch.Tensor(validation_features)
val_tensor_y = torch.Tensor(validation_label)
val_tensor_g = torch.Tensor(validation_g)
val_test_tensor_x = torch.Tensor(validation_test_features)
val_test_tensor_y = torch.Tensor(validation_test_label)
val_test_tensor_g = torch.Tensor(validation_test_g)
test_tensor_x = torch.Tensor(test_features)
test_tensor_y = torch.Tensor(test_label)
test_tensor_g = torch.Tensor(test_g)


def _make_loader(dataset):
    """Wrap ``dataset`` in a shuffled DataLoader using the module-level settings."""
    return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers,
                      shuffle=True, drop_last=False, pin_memory=True)


# Every dataset yields (features, label, group) triples.
train_my_dataset = TensorDataset(train_tensor_x, train_tensor_y, train_tensor_g)
train_my_dataloader = _make_loader(train_my_dataset)
test_my_dataset = TensorDataset(test_tensor_x, test_tensor_y, test_tensor_g)
test_my_dataloader = _make_loader(test_my_dataset)
val_my_dataset = TensorDataset(val_tensor_x, val_tensor_y, val_tensor_g)
val_my_dataloader = _make_loader(val_my_dataset)
val_test_my_dataset = TensorDataset(val_test_tensor_x, val_test_tensor_y, val_test_tensor_g)
val_test_my_dataloader = _make_loader(val_test_my_dataset)

# Notebook-only setup step (IPython shell magic is not valid Python syntax).
# Run it from a shell before executing this file:
#   pip install pytorch-ignite

import torch.nn.functional as F
#from utils.metrics import topk_corrects
import torch
from torch.autograd import grad
def gather_flat_grad(loss_grad):
    """Concatenate a sequence of gradient tensors into a single 1-D tensor.

    Entries that are ``None`` are skipped; every remaining tensor is made
    contiguous, flattened, and appended in the order it was given.
    """
    flattened = []
    for piece in loss_grad:
        if piece is None:
            continue
        flattened.append(piece.contiguous().view(-1))
    return torch.cat(flattened)

def neumann_hyperstep_preconditioner(d_val_loss_d_theta, d_train_loss_d_w, elementary_lr, num_neumann_terms, model):
    """Approximate a vector-inverse-Hessian product with a truncated Neumann series.

    Starting from ``v = d_val_loss_d_theta`` (detached), each iteration maps the
    current term ``c`` to ``c - elementary_lr * (H c)``, where ``H c`` is obtained
    by differentiating ``d_train_loss_d_w`` w.r.t. ``model.parameters()`` with
    ``c`` as ``grad_outputs``.  The terms are summed and the sum is scaled by
    ``elementary_lr``.

    ``d_train_loss_d_w`` must still carry its autograd graph; the graph is
    retained on every call so it can be reused afterwards.
    """
    series_sum = d_val_loss_d_theta.detach()
    term = series_sum

    for _ in range(num_neumann_terms):
        # Hessian-vector product for the current term of the series.
        hessian_vector = gather_flat_grad(
            grad(d_train_loss_d_w, model.parameters(), grad_outputs=term.view(-1), retain_graph=True)
        )
        term = term - elementary_lr * hessian_vector
        series_sum = series_sum + term

    return elementary_lr * series_sum

def train_epoch(cur_epoch, model, in_loader, in_criterion , in_optimizer, in_logit_adjust=None, in_params=None,
    is_out=False, out_loader=None, out_optimizer=None, out_criterion=None, out_logit_adjust=None, out_params=None,out_posthoc=False,
    ITER_LR=None, ARCH_EPOCH=0,num_classes=2,ARCH_INTERVAL=1,ARCH_TRAIN_SAMPLE=1,ARCH_VAL_SAMPLE=1):
    """Performs one epoch of bilevel optimization.

    Every batch of ``in_loader`` produces one ``in_optimizer`` step on the model
    weights (inner problem).  When ``is_out`` is true, the hyper-parameters in
    ``out_params`` are additionally updated through ``out_optimizer`` (outer
    problem): with ``out_posthoc=False`` via a hyper-gradient assembled from a
    Neumann-series preconditioner, with ``out_posthoc=True`` by back-propagating
    the outer loss directly.

    Args:
        cur_epoch: Epoch number, used only in the final log line.
        model: Network that is trained in place.
        in_loader: Iterable of ``(data, label, group)`` batches; ``label`` and
            ``group`` are indexed as 2-D tensors and column 0 is used.
        in_criterion: Called as ``in_criterion(preds, label, group, in_params)``.
        in_optimizer: Optimizer for the model weights.
        in_logit_adjust: Optional callable, invoked as
            ``in_logit_adjust(preds, in_params)`` before the inner loss.
        in_params: Hyper-parameters forwarded to ``in_criterion``.
        is_out: Enables the outer (hyper-parameter) update for this epoch.
        out_loader: Iterable of ``(data, label, group)`` batches for the outer loss.
        out_optimizer: Optimizer for the hyper-parameters.
        out_criterion: Outer loss (see the review notes in the body about its
            two different call signatures).
        out_logit_adjust: Logit adjustment used only when ``out_posthoc`` is true.
        out_params: Hyper-parameter tensors; those with ``requires_grad`` receive
            the hyper-gradient.
        out_posthoc: Selects the direct-backprop outer update.
        ITER_LR: Not read by the body.
        ARCH_EPOCH: Not read by the body (appears only in a disabled condition).
        num_classes: Forwarded to ``assign_hyper_gradient``; also sizes ``num_hypers``.
        ARCH_INTERVAL: The hyper-step runs on iterations where
            ``cur_iter % ARCH_INTERVAL == 0``.
        ARCH_TRAIN_SAMPLE: Number of training batches averaged for the
            training-loss gradient.
        ARCH_VAL_SAMPLE: Number of outer batches averaged for the outer-loss gradient.

    Returns:
        None.  Progress is reported with ``print``.
    """

    

    # Enable training mode
    model.train()
    if is_out:
        print('lr: ',in_optimizer.param_groups[0]['lr'],'  arch lr: ',out_optimizer.param_groups[0]['lr'])
        # Independent iterators so the hyper-step can draw extra batches without
        # disturbing the main loop over in_loader.
        out_iter = iter(out_loader)
        in_iter_alt=iter(in_loader)
    else:
        print('lr: ',in_optimizer.param_groups[0]['lr'])

    total_correct=0.
    total_sample=0.
    total_loss=0.
    # NOTE(review): this lowercase local is never read; the ARCH_INTERVAL argument is what is used below.
    arch_interval=20
    num_weights, num_hypers = sum(p.numel() for p in model.parameters()), 3*num_classes
    use_reg=True
    # Accumulator for the flattened training-loss gradient (kept on CPU).
    d_train_loss_d_w = torch.zeros(num_weights)
    #d_train_loss_d_w = torch.zeros(num_weights).cuda()
    for cur_iter, (in_data, in_label,in_group) in enumerate(in_loader):
        # Keep column 0 of the label / group tensors as integer class indices.
        in_label = in_label.long()[:,0]
        in_group = in_group.long()[:,0]
        
        # Transfer the data to the current GPU device
        #in_data, in_targets = in_data.cuda(non_blocking=True), in_targets_2[:,0].cuda(non_blocking=True)
        # Update architecture
        if is_out and not out_posthoc:# and cur_epoch>=ARCH_EPOCH:
            model.train()
            out_optimizer.zero_grad()
            if cur_iter%ARCH_INTERVAL==0:
                # Step 1: average the training-loss gradient over ARCH_TRAIN_SAMPLE
                # batches, keeping the graph (create_graph=True) so it can be
                # differentiated again below.
                for cur_iter_alt in range(ARCH_TRAIN_SAMPLE):
                    try:
                        in_data_alt, in_label_alt,in_group_alt = next(in_iter_alt)
                    except StopIteration:
                        # Loader exhausted: restart it and keep sampling.
                        in_iter_alt = iter(in_loader)
                        in_data_alt, in_label_alt,in_group_alt = next(in_iter_alt) 
                    #in_data_alt, in_targets_alt = in_data_alt.cuda(non_blocking=True), in_targets_alt.long().cuda(non_blocking=True)
                    in_label_alt = in_label_alt.long()[:,0]
                    in_group_alt = in_group_alt.long()[:,0]
                    in_optimizer.zero_grad()
                    in_preds=model(in_data_alt)
                    in_loss=in_criterion(in_preds,in_label_alt,in_group_alt,in_params) 
                    d_train_loss_d_w+=gather_flat_grad(grad(in_loss,model.parameters(),create_graph=True))
                d_train_loss_d_w/=ARCH_TRAIN_SAMPLE
                #d_val_loss_d_theta, direct_grad = torch.zeros(num_weights).cuda(), torch.zeros(num_hypers).cuda()
                d_val_loss_d_theta, direct_grad = torch.zeros(num_weights), torch.zeros(num_hypers)

                # Step 2: average the outer-loss gradient w.r.t. the model weights
                # over ARCH_VAL_SAMPLE batches.
                for _ in range(ARCH_VAL_SAMPLE):
                    try:
                        out_data, out_label, out_group = next(out_iter)
                    except StopIteration:
                        out_iter = iter(out_loader)
                        out_data, out_label, out_group = next(out_iter) 
                #for _,(out_data,out_targets) in enumerate(out_loader):
                    #out_data, out_targets1 = out_data.cuda(non_blocking=True), out_targets.long()[:,0].cuda(non_blocking=True)
                    out_group = out_group.long()[:,0]
                    out_label = out_label.long()[:,0]
                    model.zero_grad()
                    in_optimizer.zero_grad()
                    out_preds = model(out_data)
                    # NOTE(review): `la` is not a parameter or local here; it is
                    # resolved as a module-level global at call time.
                    out_loss = out_criterion(out_preds,out_label, out_group,la,out_params)
                    d_val_loss_d_theta += gather_flat_grad(grad(out_loss, model.parameters(), retain_graph=use_reg))
                    # if use_reg:
                    #     direct_grad+=gather_flat_grad(grad(out_loss, get_trainable_hyper_params(out_params), allow_unused=True))
                    #     direct_grad[direct_grad != direct_grad] = 0
                d_val_loss_d_theta/=ARCH_VAL_SAMPLE
                # direct_grad stays all-zero (its accumulation is disabled above)
                # and is not added to hyper_grad below.
                direct_grad/=ARCH_VAL_SAMPLE
                # This assignment is overwritten by the next statement.
                preconditioner = d_val_loss_d_theta
                
                # Step 3: precondition the outer gradient (elementary_lr=1.0, 5 Neumann terms).
                preconditioner = neumann_hyperstep_preconditioner(d_val_loss_d_theta, d_train_loss_d_w, 1.0,
                                                                5, model)
                # Step 4: differentiate the training gradient w.r.t. the trainable
                # hyper-parameters, weighted by the preconditioner.
                indirect_grad = gather_flat_grad(
                    grad(d_train_loss_d_w, get_trainable_hyper_params(out_params), grad_outputs=preconditioner.view(-1),allow_unused=True))
                hyper_grad=indirect_grad#+direct_grad
                out_optimizer.zero_grad()
                # The negated hyper-gradient is written into .grad of each trainable
                # hyper-parameter, then the outer optimizer steps.
                assign_hyper_gradient(out_params,-hyper_grad,num_classes)
                out_optimizer.step()
                #d_train_loss_d_w = torch.zeros(num_weights).cuda()
                # Fresh accumulator for the next hyper-step.
                d_train_loss_d_w = torch.zeros(num_weights)
        
        if is_out and out_posthoc:
            try:
                out_data, out_label, out_group = next(out_iter)
            except StopIteration:
                out_iter = iter(out_loader)
                out_data, out_label, out_group = next(out_iter) 
            #out_data, out_targets = out_data.cuda(non_blocking=True), out_targets.cuda(non_blocking=True)
            out_group = out_group.long()[:,0]
            out_label = out_label.long()[:,0]
            out_preds=model(out_data)
            # NOTE(review): the two calls below use different signatures than the
            # non-posthoc branch (out_logit_adjust gets only logits and params;
            # out_criterion gets four arguments, without `la`) — confirm that the
            # callables passed for post-hoc use accept these forms.
            out_preds=out_logit_adjust(out_preds,params=out_params)
            out_loss=out_criterion(out_preds,out_label, out_group,out_params)
            out_optimizer.zero_grad()
            out_loss.backward()
            out_optimizer.step()


        # Perform the forward pass
        in_preds = model(in_data)
        if not in_logit_adjust is None:
            # NOTE(review): called with (logits, params) only — confirm the
            # supplied callable does not expect label/group arguments.
            in_preds=in_logit_adjust(in_preds,in_params)
        # Compute the loss
        loss = in_criterion(in_preds, in_label, in_group, in_params)
        # Perform the backward pass
        in_optimizer.zero_grad()
        loss.backward()
        # torch.nn.utils.clip_grad_norm(model.parameters(), 5.0)
        in_optimizer.step()

        # Compute the errors
        mb_size = in_data.size(0)
        ks = [1] 
        top1_correct = topk_corrects(in_preds, in_label, ks)[0]
        
        # Copy the stats from GPU to CPU (sync point)
        loss = loss.item()
        top1_correct = top1_correct.item()
        total_correct+=top1_correct
        total_sample+=mb_size
        # Weight the batch loss by batch size so the epoch figure is a per-sample mean.
        total_loss+=loss*mb_size
    # Log epoch stats
    print(f'Epoch {cur_epoch} :  Loss = {total_loss/total_sample}   ACC = {total_correct/total_sample*100.}')



import numpy as np

@torch.no_grad()
def eval_epoch(data_loader, model, criterion, cur_epoch, text, params=None, logit_adjust=None, num_classes=2,class_wise=False):
    """Evaluate ``model`` on ``data_loader`` and report mean loss and accuracy.

    Args:
        data_loader: Iterable of ``(data, label, group)`` batches; column 0 of
            ``label`` and ``group`` is used as the integer class / group index.
        model: Network to evaluate; it is switched to eval mode.
        criterion: Called as ``criterion(logits, label, group, la, params)``
            where ``la`` is read from module scope.
        cur_epoch: Epoch number, used only in the summary line.
        text: Tag identifying the dataset in the summary line.
        params: Hyper-parameters forwarded to ``criterion`` / ``logit_adjust``.
        logit_adjust: Optional callable applied as
            ``logit_adjust(logits, label, group, params)`` before scoring.
        num_classes: Number of label classes for the class-wise statistics.
        class_wise: When true, per-class accuracy is appended to the summary.

    Returns:
        ``(summary, mean_loss, accuracy_percent)``; the summary is also printed.
    """
    model.eval()
    correct = 0.
    total = 0.
    loss = 0.
    class_correct = np.zeros(num_classes, dtype=float)
    class_total = np.zeros(num_classes, dtype=float)

    for data, label, group in data_loader:
        label = label.long()[:, 0]
        group = group.long()[:, 0]
        logits = model(data)
        if logit_adjust is not None:
            logits = logit_adjust(logits, label, group, params)

        preds = logits.max(1)[1]
        mb_size = data.size(0)
        # Weight by batch size so the final figure is a per-sample mean.
        loss += criterion(logits, label, group, la, params).item() * mb_size
        total += mb_size
        hits = preds.eq(label.view_as(preds))
        correct += hits.sum().item()
        if class_wise:
            for i in range(num_classes):
                in_class = label == i
                class_total[i] += in_class.sum().item()
                class_correct[i] += hits[in_class].sum().item()

    summary = f'TEST {text}: Epoch {cur_epoch} :  Loss = {loss/total}   ACC = {correct/total*100.}'
    if class_wise:
        # Extend the summary instead of re-formatting the already formatted
        # line, which used to duplicate the "TEST ... Epoch ..." prefix.
        # Classes with no samples yield nan without a NumPy warning.
        with np.errstate(divide='ignore', invalid='ignore'):
            class_acc = class_correct / class_total * 100.
        summary = f'{summary} Class wise = {class_acc}'
    print(summary)
    return summary, loss/total, correct/total*100.

def loss_adjust_cross_entropy(logits,label, group,params):
    """Subgroup-reweighted cross-entropy on group-dependent adjusted logits.

    For a sample with group ``g`` the logits are transformed to
    ``logits * sigmoid(dy[:, g]) - ly[:, g]``; the per-sample cross-entropy is
    then divided by the prior of the sample's ``(label, group)`` subgroup and
    averaged over the batch.

    Args:
        logits: ``(batch, num_classes)`` raw scores.
        label: ``(batch,)`` integer class indices (0 or 1).
        group: ``(batch,)`` integer group indices (0 or 1).
        params: Sequence whose first two items are ``dy`` and ``ly``, each of
            shape ``(num_classes, 2)``.

    Returns:
        Scalar tensor with the weighted mean loss.
    """
    dy = params[0]
    ly = params[1]

    # Subgroup priors indexed by label*2+group — presumably the empirical
    # subgroup frequencies of the training data; verify against the dataset.
    # Built on the logits' device so the loss also works on GPU.
    pi = torch.tensor([0.73, 0.038, 0.012, 0.22], dtype=torch.float32, device=logits.device)
    subgroup_one_hot = F.one_hot(label * 2 + group, num_classes=4).to(pi.dtype)
    pi_yg = (pi * subgroup_one_hot).sum(dim=1)

    # Select the column of dy / ly that belongs to each sample's group.
    group_one_hot = F.one_hot(group, num_classes=2).to(dy.dtype).T
    d_yg = torch.mm(dy, group_one_hot).T
    l_yg = torch.mm(ly, group_one_hot).T

    adjusted_logits = logits * torch.sigmoid(d_yg) - l_yg
    per_sample_loss = F.cross_entropy(adjusted_logits, label, reduction='none')
    weighted_loss = (1 / pi_yg) * per_sample_loss
    return torch.mean(weighted_loss)

def cross_entropy(logits,label, group,la,params):
    """Plain mean cross-entropy of ``logits`` against ``label``.

    ``group``, ``la`` and ``params`` are accepted so the function can be used
    where a five-argument criterion is expected; they are ignored.
    """
    plain_loss = F.cross_entropy(logits, label)
    return plain_loss

def logit_adjust_ly(logits,label, group,params):
    """Apply the group-dependent scale and shift to ``logits``.

    With ``dy = params[0]`` and ``ly = params[1]``, a sample in group ``g`` is
    mapped to ``logits * sigmoid(dy[:, g]) - ly[:, g]``.  ``label`` is not used.
    """
    scale_table = params[0]
    shift_table = params[1]
    # (2, batch) indicator matrix: column j marks the group of sample j.
    group_indicator = F.one_hot(group, num_classes=2).T.type(torch.float32)
    per_sample_scale = torch.mm(scale_table, group_indicator).T
    per_sample_shift = torch.mm(shift_table, group_indicator).T
    return logits * torch.sigmoid(per_sample_scale) - per_sample_shift
    
def outer_loss(logits,label, group,la,params):
    """Outer objective: mix of a subgroup loss-gap penalty and plain cross-entropy.

    Per-sample cross-entropy is divided by the prior of the sample's
    ``(label, group)`` subgroup and summed per subgroup.  The penalty is the
    absolute gap between the two groups' sums within label 0 plus the same gap
    within label 1.  The result is ``la * penalty + (1 - la) * mean_ce``.

    Args:
        logits: ``(batch, num_classes)`` raw scores.
        label: ``(batch,)`` integer class indices (0 or 1).
        group: ``(batch,)`` integer group indices (0 or 1).
        la: Mixing weight of the penalty term.
        params: Unused; kept for a uniform criterion signature.

    Returns:
        Scalar loss tensor.
    """
    # Subgroup priors indexed by label*2+group — presumably the empirical
    # subgroup frequencies of the training data; verify against the dataset.
    # Built on the logits' device so the loss also works on GPU.
    pi = torch.tensor([0.73, 0.038, 0.012, 0.22], dtype=torch.float32, device=logits.device)
    subgroup_one_hot = F.one_hot(label * 2 + group, num_classes=4).to(pi.dtype)
    pi_yg = (pi * subgroup_one_hot).sum(dim=1)

    per_sample_loss = F.cross_entropy(logits, label, reduction='none')
    weighted_loss = (1 / pi_yg) * per_sample_loss
    # Broadcast instead of transposing a 1-D tensor (``.T`` on non-2-D tensors
    # is deprecated in PyTorch).
    subgroup_loss = (subgroup_one_hot * weighted_loss.unsqueeze(1)).sum(dim=0)
    loss_deo = (torch.abs(subgroup_loss[0] - subgroup_loss[1])
                + torch.abs(subgroup_loss[2] - subgroup_loss[3]))
    return la * loss_deo + (1 - la) * per_sample_loss.mean()
def get_trainable_hyper_params(params):
    """Return, in their original order, the entries of ``params`` with ``requires_grad`` set."""
    trainable = []
    for candidate in params:
        if candidate.requires_grad:
            trainable.append(candidate)
    return trainable
def assign_hyper_gradient(params,gradient,num_classes):
    """Scatter a flat hyper-gradient into ``.grad`` of the trainable entries of ``params``.

    ``gradient`` is consumed in consecutive chunks of ``num_classes * 2``
    values, one chunk per parameter that has ``requires_grad`` set; parameters
    without it are skipped and consume nothing.  Each chunk is reshaped to
    ``(num_classes, 2)`` and stored as a copy.

    The target shape used to be hard-coded as ``(2, 2)``, which only worked
    for ``num_classes == 2``.
    """
    chunk = num_classes * 2
    offset = 0
    for para in params:
        if not para.requires_grad:
            continue
        para.grad = torch.reshape(gradient[offset:offset + chunk], (num_classes, 2)).clone()
        offset += chunk

import os
import numpy as np
import random
from numpy.lib.scimath import log
import torch
import ignite
from torch._C import dtype

import matplotlib.pyplot as plt
import torchvision.utils as vutils
#from torch.utils.metrics import print_num_params
#from core.trainer import loss_adjust_cross_entropy,cross_entropy, logit_adjust_ly, train_epoch,eval_epoch
import torch.optim as optim
import torch.nn as nn
'''
assert torch.cuda.is_available()
assert torch.backends.cudnn.enabled
torch.backends.cudnn.benchmark = True
device = "cuda"
'''

def topk_corrects(preds, labels, ks):
    """Count the top-k correct predictions for each k in ``ks``.

    Args:
        preds: ``(batch, num_classes)`` scores.
        labels: ``(batch,)`` integer class indices.
        ks: Iterable of k values.

    Returns:
        A list with one scalar float tensor per k: the number of samples whose
        label is among the k highest-scoring classes.

    Raises:
        AssertionError: If the batch sizes of ``preds`` and ``labels`` differ.
    """
    err_str = "Batch dim of predictions and labels must match"
    assert preds.size(0) == labels.size(0), err_str
    # Find the top max_k predictions for each sample
    _top_max_k_vals, top_max_k_inds = torch.topk(
        preds, max(ks), dim=1, largest=True, sorted=True
    )
    # (batch_size, max_k) -> (max_k, batch_size)
    top_max_k_inds = top_max_k_inds.t()
    # (batch_size, ) -> (max_k, batch_size)
    rep_max_k_labels = labels.view(1, -1).expand_as(top_max_k_inds)
    # (i, j) = 1 if top i-th prediction for the j-th sample is correct
    top_max_k_correct = top_max_k_inds.eq(rep_max_k_labels)
    # The comparison result can inherit the transposed strides, so a slice of
    # more than one row is not viewable as 1-D; reshape copies when needed.
    return [top_max_k_correct[:k, :].reshape(-1).float().sum() for k in ks]

def print_num_params(model, display_all_modules=False):
    """Print the total number of parameters of ``model``.

    With ``display_all_modules`` set, the element count of every named
    parameter is printed first, one per line.
    """
    grand_total = 0
    for name, tensor in model.named_parameters():
        count = tensor.numel()
        if display_all_modules:
            print("{}: {}".format(name, count))
        grand_total += count
    print("-" * 50)
    print("Total number of parameters: {:.2e}".format(grand_total))

import torch
import torch.nn as nn
import torch.nn.functional as F

D_in = np.shape(test_features)[1]
num_classes = 2

# Seed before building the model so the initial weights are reproducible.
seed = 9
torch.manual_seed(seed)
model = torch.nn.Linear(D_in, num_classes)
print(model.weight)
print(model.bias)

# NOTE(review): network_model and hp_lr are not read anywhere in the visible
# file, and batch_size is re-bound here after the DataLoaders were built.
network_model='resnet18'
hp_lr=0.00005
batch_size=128

total_epoch=450
ARCH_EPOCH=150   # first epoch in which the outer (hyper-parameter) step may run
lr=0.05          # inner optimizer lr (an earlier 0.0005 was overridden by this value)
arch_lr=0.001    # outer optimizer lr
save_path='./results/loss_adjust_ly_dy_pretrain'

# Hyper-parameters, one column per group; created on CPU.
dy=torch.ones([num_classes,2],dtype=torch.float32)
ly=torch.zeros([num_classes,2],dtype=torch.float32)
wy=torch.ones([num_classes,2],dtype=torch.float32)

train_optimizer = optim.SGD(params=model.parameters(),lr=lr,momentum=0.9,weight_decay=1e-4)
val_optimizer = optim.SGD(params=[{'params':dy},{'params':ly},{'params':wy}],
                        lr=arch_lr,momentum=0.9,weight_decay=1e-4)
train_lr_scheduler=optim.lr_scheduler.MultiStepLR(train_optimizer,milestones=[280,330],gamma=0.1)
val_lr_scheduler=optim.lr_scheduler.MultiStepLR(val_optimizer,milestones=[280,330],gamma=0.2)

# Only dy and ly are trained; wy never requires grad.
dy.requires_grad=True
ly.requires_grad=True

# Mixing weight read as a module-level global by the evaluation / outer-loss calls.
la = 0.03
os.makedirs('./results/plain/', exist_ok=True)
logfile=open('./results/plain/logs.txt',mode='w')
torch.save(model,'./results/plain/init_model.pth')

if save_path is None:
    import time
    save_path=f'./results/{int(time.time())}'
os.makedirs(save_path, exist_ok=True)

checkpoint = 0
if checkpoint==0:
    torch.save(model,f'{save_path}/init_model.pth')
    # Release the placeholder log opened above before re-binding the name;
    # the handle used to be leaked.
    logfile.close()
    logfile=open(f'{save_path}/logs.txt',mode='w')
    dy_log=open(f'{save_path}/dy.txt',mode='w')
    ly_log=open(f'{save_path}/ly.txt',mode='w')
    acc_log=open(f'{save_path}/acc.txt',mode='w')
    config_log=open(f'{save_path}/config.txt',mode='w')

ARCH_VAL_SAMPLE = 10
ARCH_TRAIN_SAMPLE = 10
ARCH_INTERVAL = 10
ARCH_EPOCH_INTERVAL = 1
ARCH_END = 370   # last epoch in which the outer step may run

# NOTE(review): this evaluates the test loader but tags the line ' train_dataset'.
text,loss,test_acc=eval_epoch(test_my_dataloader,model,cross_entropy,0,' train_dataset',params=[dy,ly],logit_adjust=None,num_classes=num_classes)

checkpoint_interval = 5
train_acc_all = []
val_acc_all = []
test_acc_all = []
train_loss_all = []
val_loss_all = []
test_loss_all = []
try:
    for i in range(checkpoint+1,total_epoch+1):
        torch.cuda.empty_cache()
        # NOTE(review): the "train" metrics are computed on val_test_my_dataloader.
        text,train_loss,train_acc=eval_epoch(val_test_my_dataloader,model,cross_entropy,i,' train_dataset',params=[dy,ly],logit_adjust=None,num_classes=num_classes)
        logfile.write(text+'\n')
        text,val_loss,val_acc=eval_epoch(val_my_dataloader,model,cross_entropy,i,' val_dataset',params=[dy,ly,wy],logit_adjust=None,num_classes=num_classes)
        logfile.write(text+'\n')
        text,test_loss,test_acc=eval_epoch(test_my_dataloader,model,cross_entropy,i,' test_dataset',params=[dy,ly,wy],logit_adjust=None,num_classes=num_classes)
        logfile.write(text+'\n')
        print(dy,ly,'\n')
        train_epoch(i, model,
                in_loader=train_my_dataloader, in_criterion=loss_adjust_cross_entropy,
                in_optimizer=train_optimizer,in_params=[dy,ly],
                is_out=(i>=ARCH_EPOCH) and (i<=ARCH_END) and ((i+1)%ARCH_EPOCH_INTERVAL)==0,
                out_loader=val_my_dataloader, out_optimizer=val_optimizer,
                out_criterion=outer_loss, out_logit_adjust=None, out_params=[dy,ly],out_posthoc=False,
                num_classes=num_classes,
                ARCH_EPOCH=ARCH_EPOCH,ARCH_INTERVAL=ARCH_INTERVAL,
                ARCH_TRAIN_SAMPLE=ARCH_TRAIN_SAMPLE,ARCH_VAL_SAMPLE=ARCH_VAL_SAMPLE)
        train_acc_all.append(train_acc)
        val_acc_all.append(val_acc)
        test_acc_all.append(test_acc)
        train_loss_all.append(train_loss)
        val_loss_all.append(val_loss)
        test_loss_all.append(test_loss)
        logfile.write(str(dy)+str(ly)+'\n\n')
        dy_log.write(f'{dy.detach().cpu().numpy()}\n')
        ly_log.write(f'{ly.detach().cpu().numpy()}\n')
        acc_log.write(f'{train_acc} {val_acc} {test_acc}\n')
        # Flush every epoch so the logs survive an interrupted run.
        logfile.flush()
        dy_log.flush()
        ly_log.flush()
        acc_log.flush()
        train_lr_scheduler.step()
        val_lr_scheduler.step()
        if i%checkpoint_interval==0:
            torch.save(model,f'{save_path}/epoch_{i}.pth')
finally:
    # Close every log even if training raises; config_log was never closed before.
    logfile.close()
    dy_log.close()
    ly_log.close()
    acc_log.close()
    config_log.close()
torch.save(model,f'{save_path}/loss_adjustment.pth')

def eval_per_class(data_loader, model, text,flag=0):
    """Report overall and per-subgroup accuracy of ``model`` on ``data_loader``.

    Samples are bucketed into four subgroups by ``(label, group)``:
    ``'00'``, ``'01'``, ``'10'``, ``'11'``.  Samples whose label or group is
    not 0 or 1 count towards the overall accuracy only.

    Args:
        data_loader: Iterable of ``(data, label, group)`` batches; column 0 of
            ``label`` and ``group`` is used.
        model: Network to evaluate; it is switched to eval mode.
        text: Tag printed in front of the two summary lines.
        flag: ``0`` prints the accuracy of every non-empty subgroup; any other
            value suppresses those lines.  'No image' is printed for each
            empty subgroup either way.

    Returns:
        ``(accuracy_percent, balanced_accuracy_percent, per_subgroup_accuracy)``
        where the last item lists the accuracies of the non-empty subgroups in
        subgroup order and the balanced accuracy is their mean.
    """
    model.eval()
    classes = ('00', '01', '10', '11')
    correct = 0.
    total = 0.
    class_group_correct = [0.] * 4
    class_group_total = [0.] * 4

    with torch.no_grad():
        for data, label, group in data_loader:
            label = label.long()[:, 0]
            group = group.long()[:, 0]
            preds = model(data).max(1)[1]
            # Kept 1-D on purpose: squeezing turned a batch of one into a 0-d
            # tensor that could not be indexed per sample.
            hits = preds.eq(label)
            valid = (label >= 0) & (label <= 1) & (group >= 0) & (group <= 1)
            subgroup = label * 2 + group
            total += data.size(0)
            correct += hits.sum().item()
            for k in range(4):
                members = valid & (subgroup == k)
                class_group_total[k] += members.sum().item()
                class_group_correct[k] += hits[members].sum().item()

    accuracy_4 = []
    for k in range(4):
        if class_group_total[k] != 0:
            subgroup_acc = 100 * class_group_correct[k] / class_group_total[k]
            if flag == 0:
                print('Accuracy of %5s : %2f %%' % (classes[k], subgroup_acc))
            accuracy_4.append(subgroup_acc)
        else:
            print('No image')

    print(f'{text}:ACC = {correct/total*100.}')
    print(f'{text}:balance ACC = {np.mean(accuracy_4)}')

    return correct/total*100.,np.mean(accuracy_4),accuracy_4

# Final per-subgroup report on every split.  All statements sit at module
# level; the stray indentation made the file unparseable.
train_acc,train_balanced_acc,_=eval_per_class(train_my_dataloader,model,' train_dataset',flag=0)
print('val')
val_acc,val_balanced_acc,_=eval_per_class(val_my_dataloader,model,' val_dataset',flag=1)
print('val(test)')
val_test_acc,val_test_balanced_acc,_=eval_per_class(val_test_my_dataloader,model,' val_test_dataset',flag=0)
print('test')
test_acc,test_balanced_acc,_=eval_per_class(test_my_dataloader,model,' test_dataset',flag=0)

