#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
import pandas as pd
from parameters_chexpert import para
import torch

from logistic_torch_base import Logistic_AUC_base

# Integer settings forwarded to the classifier below:
# kvalue2 is passed as k_value and kvalue1 as k2_value.
kvalue1 = 10000
kvalue2 = 100000

# Split fractions and a dataset tag. None of them are read further down in
# this script (the classifier and the output filename use para.dataname).
train_size = 0.5
val_size = 0.25
dataname = "chexpert"
# dataname = "a9a"

# Pre-extracted CheXpert feature matrix and label matrix, loaded from .npy
# files and wrapped as torch tensors (torch.from_numpy shares memory).
features_np = np.load('./dataset/chexpert/CheXpert_train_hidden_features_all.npy')
labels_np = np.load('./dataset/chexpert/CheXpert_train_labels_all.npy')
X = torch.from_numpy(features_np)
y = torch.from_numpy(labels_np)

# Map every non-positive label to 0, in place.
y[y.le(0)] = 0

# Loss variant used for this baseline run ("HingeAUC" is the other option).
Model_name = "Logistic"
num_dis = para.num_dis

classifier = Logistic_AUC_base(
    lr_0=para.lr_0,
    num_iter=para.num_iter,
    T0=para.T0,
    batch=para.batch,
    c=para.c,
    k_value=kvalue2,
    k2_value=kvalue1,
    seed=para.seed,
    dataname=para.dataname,
    Model_name=Model_name,
)

# Train on label column num_dis - 1 (para.num_dis is presumably 1-based —
# confirm). The same features/labels are passed for all three (X, y) pairs
# that fit() accepts.
y_target = y[:, num_dis - 1]
classifier.fit(X, y_target, X, y_target, X, y_target)

# Traces recorded by the classifier during fit().
data_pass_dc = classifier.data_pass
time_dc = classifier.time_list
loss_dc = classifier.loss_list
auc_list = classifier.pauc_list

# Both weight columns come from the same classifier.theta, so w1 and w2 are
# identical here. Two columns are presumably kept so the CSV layout matches
# two-weight runs — confirm. detach() lets .numpy() succeed even if theta is
# tracked by autograd.
w_1 = classifier.theta.detach().cpu()
w_2 = classifier.theta.detach().cpu()

w1 = w_1.numpy()
w2 = w_2.numpy()

df1 = pd.DataFrame(data_pass_dc, columns=['data_pass'])
df2 = pd.DataFrame(loss_dc, columns=['loss'])
df3 = pd.DataFrame(w1.T, columns=['w1'])
df4 = pd.DataFrame(w2.T, columns=['w2'])
df5 = pd.DataFrame(auc_list, columns=['pAUC'])
df6 = pd.DataFrame(time_dc, columns=['time'])

# Align every column on its row index with an outer join. Chaining
# DataFrame.join (default how='left') keyed everything on the weight frame's
# index, silently dropping trace rows whenever more were logged than there are
# weight entries. Shorter columns are padded with NaN. Column order is
# unchanged: w1, w2, data_pass, loss, pAUC, time.
d5 = pd.concat([df3, df4, df1, df2, df5, df6], axis=1)

# Output filename encodes the run's hyper-parameters.
s = (
    f"{para.dataname}_{num_dis}_{Model_name}"
    f"_lr0={para.lr_0}_lr_out={para.lr_outer}_mu={para.mu}"
    f"_T0={para.T0}_numStages={para.num_iter}_c={para.c}"
    f"_seed={para.seed}_random_baseline.csv"
)
d5.to_csv(s)