
"""Tests for lattice estimators."""
import pandas as pd
import numpy as np
import tensorflow as tf
import tensorflow_lattice as tfl
from sklearn.model_selection import KFold
import itertools

# Names of the model input columns, in the order they are fed to the model.
column_names = [
    'Age',
    'WorkClass',
    'EducationNum',
    'MaritalStatus',
    'Occupation',
    'Relationship',
    'Race',
    'Sex',
    'CapitalGain',
    'CapitalLoss',
    'HoursPerWeek',
    'NativeCountry',
]

# Number of pre-split dataset folds to process (fold directories 0..folds-1).
folds = 1

# Per-fold accuracies, filled in by the main loop and summarised at the end.
train_evaluations, test_evaluations = [], []

# (input_min, input_max) used to place the uniform calibration keypoints of
# each feature; the keypoint outputs always start uniformly in [0.0, 1.0].
_KEYPOINT_INPUT_RANGES = {
    'Age': (0.0, 100.0),
    'WorkClass': (0.0, 10.0),
    'EducationNum': (0.0, 20.0),
    'MaritalStatus': (0.0, 10.0),
    'Occupation': (0.0, 20.0),
    'Relationship': (0.0, 5.0),
    'Race': (0.0, 4.0),
    'Sex': (0.0, 1.0),
    'CapitalGain': (0.0, 100000.0),
    'CapitalLoss': (0.0, 4000.0),
    'HoursPerWeek': (1.0, 100.0),
    'NativeCountry': (0.0, 40.0),
}

# Hyper-parameter grid explored by the cross-validated search.
key_points = [10, 50, 100]
rates = [0.1, 0.01, 0.001]
batch_sizes = [16, 32, 64, 128, 256, 512, 1024]
epochs = [50, 100, 400, 500, 800, 1000, 1500, 2000]

# Every input column is fed to the estimator as a plain numeric column.
feature_columns = [
    tf.feature_column.numeric_column(name) for name in column_names
]


def _features_from_frame(frame):
    """Return a {column name: numpy array} dict of the model input columns."""
    return {name: np.array(frame[name]) for name in column_names}


def _select_rows(features, row_index):
    """Return a copy of the feature dict restricted to the rows in row_index."""
    return {name: values[row_index] for name, values in features.items()}


def _make_keypoints_init_fns(num_keypoints):
    """Return one uniform keypoint initialiser per feature.

    num_keypoints is bound when this function is called, so the returned
    callables do not depend on a module-level variable that later changes.
    """
    def _uniform(input_min, input_max):
        return lambda: tfl.uniform_keypoints_for_signal(
            num_keypoints,
            input_min=input_min,
            input_max=input_max,
            output_min=0.0,
            output_max=1.0)

    return {
        name: _uniform(input_min, input_max)
        for name, (input_min, input_max) in _KEYPOINT_INPUT_RANGES.items()
    }


def _build_estimator(num_keypoints, learning_rate, **extra_hparams):
    """Build the calibrated classifier for one hyper-parameter setting.

    extra_hparams are forwarded unchanged to tfl.CalibratedLatticeHParams.
    """
    hparams = tfl.CalibratedLatticeHParams(
        feature_names=list(column_names),
        num_keypoints=num_keypoints,
        learning_rate=learning_rate,
        **extra_hparams)

    # Set feature monotonicity.
    hparams.set_feature_param('HoursPerWeek', 'monotonicity', +1)
    hparams.set_feature_param('CapitalGain', 'monotonicity', +1)

    return tfl.calibrated_linear_classifier(
        feature_columns=feature_columns,
        hparams=hparams,
        keypoints_initializers_fn=_make_keypoints_init_fns(num_keypoints))


def _numpy_input_fn(features, labels, batch_size, num_epochs):
    """Return an unshuffled in-memory input_fn over features and labels."""
    # NOTE(review): shuffle=False is kept from the previous behaviour for
    # training as well; confirm that training on unshuffled rows is intended.
    return tf.compat.v1.estimator.inputs.numpy_input_fn(
        x=features,
        y=labels,
        batch_size=batch_size,
        num_epochs=num_epochs,
        shuffle=False)


for fold in range(folds):
    train_data_path = 'adult/{}/train_data.csv'.format(fold)
    test_data_path = 'adult/{}/test_data.csv'.format(fold)
    train_dataset = pd.read_csv(train_data_path, index_col=0)
    test_dataset = pd.read_csv(test_data_path, index_col=0)

    # Training and testing data.
    train_features = _features_from_frame(train_dataset)
    test_features = _features_from_frame(test_dataset)
    train_labels = np.array(train_dataset['IncomeRange'])
    test_labels = np.array(test_dataset['IncomeRange'])

    # Grid search with 5-fold cross-validation on the training data.
    kf = KFold(n_splits=5, shuffle=True, random_state=2)
    best_parameters = None
    # Lowest mean validation loss seen so far; lower is better, so the
    # search starts from +inf and keeps strictly smaller values.
    best_evaluations = float('inf')

    for k, r, b, e in itertools.product(key_points, rates, batch_sizes, epochs):
        evaluations = []
        # Split on the label array: its length is the number of samples.
        # Splitting the feature dict would use len(dict), i.e. the number of
        # feature columns, and only ever select the first few rows.
        for train_index, test_index in kf.split(train_labels):
            train_X = _select_rows(train_features, train_index)
            train_Y = train_labels[train_index]
            test_X = _select_rows(train_features, test_index)
            test_Y = train_labels[test_index]

            # NOTE(review): the search models are built without the
            # non_monotonic_* hparams that the final model sets below;
            # kept as before - confirm this difference is intended.
            lattice_estimator = _build_estimator(k, r)

            # Train!
            lattice_estimator.train(
                input_fn=_numpy_input_fn(train_X, train_Y, b, e))

            # Validate with a single pass: repeating the data for e epochs
            # does not change the averaged metrics, it only costs time.
            evaluation = lattice_estimator.evaluate(
                input_fn=_numpy_input_fn(test_X, test_Y, b, 1))
            evaluations.append(evaluation['average_loss'])

        result = np.mean(evaluations)
        if result < best_evaluations:
            best_evaluations = result
            best_parameters = [k, r, b, e]
        print(best_parameters)

    if best_parameters is None:
        # Happens only if every configuration produced a non-comparable
        # (e.g. NaN) loss; fail clearly instead of training with zeros.
        raise RuntimeError('grid search found no configuration with a finite loss')

    print(best_parameters)
    k, r, b, e = best_parameters

    # Retrain on the full training set with the selected hyper-parameters.
    lattice_estimator = _build_estimator(
        k,
        r,
        non_monotonic_num_lattices=1,
        non_monotonic_lattice_rank=1,
        non_monotonic_lattice_size=2)
    lattice_estimator.train(
        input_fn=_numpy_input_fn(train_features, train_labels, b, e))

    # Estimate accuracy on the training and the held-out test data, reading
    # each dataset exactly once.
    train_estimate = lattice_estimator.evaluate(
        input_fn=_numpy_input_fn(train_features, train_labels, b, 1))
    test_estimate = lattice_estimator.evaluate(
        input_fn=_numpy_input_fn(test_features, test_labels, b, 1))
    print(test_estimate)
    train_evaluations.append(train_estimate['accuracy'])
    test_evaluations.append(test_estimate['accuracy'])


# Summarise the per-fold accuracies: mean and standard deviation for the
# training data first, then for the test data.
mean_train, std_train = np.mean(train_evaluations), np.std(train_evaluations)
mean_test, std_test = np.mean(test_evaluations), np.std(test_evaluations)

print('mean_train: %s' % mean_train)
print('std_train: %s' % std_train)
print('mean_test: %s' % mean_test)
print('std_test: %s' % std_test)

