import nni
from get_filters import getFilters
import argparse
# Command-line interface. The only option selects the dataset directory under
# ./my_data/ from which features, labels, splits and filters are read.
parser = argparse.ArgumentParser(
    description='Train and evaluate OPEN_Model on a dataset stored under ./my_data/.')
parser.add_argument('--dataset', default='Photo',
                    help='dataset directory name under ./my_data/ '
                         '(expects feats.npy, labels.npy, train_set, val_set, '
                         'test_set; default: %(default)s)')

args = parser.parse_args()


# Test accuracy of every run, each taken at that run's best-validation epoch;
# summarised after the loop.
end_test_acc_list = []

# Imports and device setup happen once, before any run. CUDA_VISIBLE_DEVICES is
# set before TensorFlow enumerates the GPUs below so the device mask applies.
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import pickle

import numpy as np
import tensorflow as tf
from sklearn.metrics import f1_score

from open_model import OPEN_Model

# set_memory_growth may only be called before the GPUs are initialised; doing it
# inside the run loop raises RuntimeError from the second run onwards.
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

num_runs = 1  # number of independent training runs whose test accuracies are averaged

data_path = './my_data/' + args.dataset + '/'
sum_epochs = 1000  # maximum number of epochs per run
early_stopping = 200  # patience: stop after this many epochs without a new best val acc
"""@nni.variable(nni.choice(0.00005,0.00001,0.0005,0.0001,0.005,0.001,0.05,0.01,0.1), name=learning_rate)"""
learning_rate = 0.001  # learning rate (NNI search space in the annotation above)
"""@nni.variable(nni.uniform(0, 0.5), name=in_drop)"""
in_drop = 0.2  # dropout rate passed to OPEN_Model as in_drop (NNI search space above)
L2 = 0.01  # L2 regularization

# Load the dataset once; every run trains a fresh model on the same data.
features = np.load(data_path + 'feats.npy')
# Subtract the mean over the last axis (per-row centring for a 2-D feature matrix).
features = features - np.mean(features, axis=-1, keepdims=True)

labels = np.load(data_path + 'labels.npy')
# NOTE(review): assumes labels are 0..C-1 with every class present; otherwise
# num_classes undercounts and the sparse cross-entropy sees out-of-range labels.
num_classes = len(set(labels))
features = tf.convert_to_tensor(features, dtype=tf.float32)
labels = tf.convert_to_tensor(labels, dtype=tf.float32)

# Integer index files for the three splits (presumably row indices into features).
train_set = np.loadtxt(data_path + 'train_set', dtype=int)
val_set = np.loadtxt(data_path + 'val_set', dtype=int)
test_set = np.loadtxt(data_path + 'test_set', dtype=int)

# Build the cached filters if they are missing, then load them through one path.
filters_path = data_path + 'filters.pkl'
if not os.path.exists(filters_path):
    getFilters(data_path)  # presumably writes filters.pkl into data_path
# NOTE(review): pickle.load can execute arbitrary code; only load locally generated files.
with open(filters_path, 'rb') as f:
    aggregators = pickle.load(f)

train_y_true = tf.gather(labels, train_set)
val_y_true = tf.gather(labels, val_set)
test_y_true = tf.gather(labels, test_set)
# NumPy copies of the split labels for sklearn, converted once instead of every epoch.
train_y_np, val_y_np, test_y_np = (y.numpy() for y in (train_y_true, val_y_true, test_y_true))

for run in range(num_runs):
    # NOTE(review): aggregators is shared across runs; assumes OPEN_Model does not
    # mutate it in place - confirm before raising num_runs.
    model = OPEN_Model(layer_units=[516],
                       num_classes=num_classes,
                       aggregators=aggregators,
                       dropout_func=tf.keras.layers.Dropout,
                       in_drop=in_drop,
                       L2=L2,
                       activation=tf.nn.relu)

    get_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

    train_loss = []
    train_accuracy = []
    val_accuracy = []
    test_accuracy = []
    best_val_acc = 0
    bad_counter = 0
    best_epoch = 0
    end_test_acc = 0

    for epoch in range(sum_epochs):
        # Record only the forward pass and loss. Differentiating and stepping
        # outside the tape keeps the gradient/optimizer/eval ops off the tape.
        with tf.GradientTape() as tape:
            y_pred_training = model(features, training=True)
            y_pred_training = tf.gather(y_pred_training, train_set)
            loss = get_loss(y_true=train_y_true, y_pred=y_pred_training)
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

        # Evaluate on all inputs in inference mode (training=False).
        logits = model(features, training=False)
        prediction = tf.argmax(logits, axis=1, output_type=tf.int32)
        # Micro-averaged F1 equals accuracy for single-label multiclass predictions.
        train_acc = f1_score(train_y_np, tf.gather(prediction, train_set).numpy(), average='micro')
        val_acc = f1_score(val_y_np, tf.gather(prediction, val_set).numpy(), average='micro')
        test_acc = f1_score(test_y_np, tf.gather(prediction, test_set).numpy(), average='micro')
        loss = loss.numpy()
        train_loss.append(loss)
        train_accuracy.append(train_acc)
        val_accuracy.append(val_acc)
        test_accuracy.append(test_acc)
        print('epoch {:04d}: loss: {:.4f}, train acc: {:.4%}, val acc: {:.4%}, test acc: {:.4%}'.
              format(epoch, loss, train_acc, val_acc, test_acc))

        # Model selection on validation accuracy; the reported test accuracy is the
        # one observed at the best-validation epoch.
        if val_acc > best_val_acc:
            end_test_acc = test_acc
            best_val_acc = val_acc
            best_epoch = epoch
            bad_counter = 0
        else:
            bad_counter += 1
        if bad_counter == early_stopping:
            break

    log = 'best_Epoch: {:03d}, best_val_acc: {:.4f}, end_test_acc: {:.4f}'
    print('-------------------the_Epoch:', run, '-----------------------')
    print(log.format(best_epoch, best_val_acc, end_test_acc))
    end_test_acc_list.append(end_test_acc)
    model.summary()

   
# Imported explicitly so this summary does not rely on an import made elsewhere.
import numpy as np

# Summarise the per-run test accuracies (each taken at that run's best-validation
# epoch) and report their mean to NNI as the trial's final result.
print(end_test_acc_list)
a = np.asarray(end_test_acc_list)
last_acc = float(np.mean(a))  # plain float for NNI; computed once and reused below
"""@nni.report_final_result(last_acc)"""
print(last_acc)

print('mean:', last_acc)
print('std:', np.std(a))
