
"""Tests for lattice estimators."""
import pandas as pd
import numpy as np
import tensorflow as tf
import tensorflow_lattice as tfl
from sklearn.model_selection import KFold
import itertools

# Column headers for the auto-mpg table, one entry per line.
# NOTE(review): this list is not referenced anywhere in the visible script, and
# 'Model Year' (with a space) differs from the 'ModelYear' column that is read
# from the fold CSVs below -- confirm before using it to index those frames.
column_names = [
    'MPG',
    'Cylinders',
    'Displacement',
    'Horsepower',
    'Weight',
    'Acceleration',
    'Model Year',
    'Origin',
]

# Per-fold 'average_loss' values of the tuned model; appended to by the fold
# loop below (two distinct lists, one for the training and one for the test set).
train_evaluations, test_evaluations = [], []

# Number of pre-split folds stored under auto-mpg/<fold>/.
folds = 3
# Feature columns read from every fold CSV, in the order handed to the model.
_FEATURE_NAMES = ['Cylinders', 'Displacement', 'Horsepower', 'Weight',
                  'Acceleration', 'ModelYear', 'Origin']

# (feature, monotonicity) pairs applied through hparams.set_feature_param.
_FEATURE_MONOTONICITY = [
    ('Displacement', -1),
    ('Weight', -1),
    ('Horsepower', -1),
]

# (input_min, input_max) of the uniform keypoint grid used for each feature.
_KEYPOINT_INPUT_RANGES = {
    'Cylinders': (0.0, 10.0),
    'Displacement': (0.0, 500.0),
    'Horsepower': (0.0, 300.0),
    'Weight': (0.0, 6000.0),
    'Acceleration': (0.0, 30.0),
    'ModelYear': (0.0, 2500.0),
    'Origin': (1.0, 3.0),
}

# Number of KFold splits that are actually trained per grid-search candidate
# (the splitter produces 5; only the first 3 are used to bound the run time).
_GRID_SEARCH_SPLITS = 3


def _features_from_frame(frame):
    """Return a {feature name: numpy array} dict built from a fold DataFrame."""
    return {name: np.array(frame[name]) for name in _FEATURE_NAMES}


def _keypoints_initializer(num_keypoints, input_min, input_max):
    """Return a zero-argument callable creating uniform calibrator keypoints.

    The arguments are bound when this helper is called, so the returned
    callable does not depend on module-level variables that may be reassigned
    before the estimator invokes it.
    """
    def _init():
        return tfl.uniform_keypoints_for_signal(
            num_keypoints,
            input_min=input_min,
            input_max=input_max,
            output_min=0.0,
            output_max=1.0)
    return _init


def _build_lattice_estimator(feature_columns, num_keypoints, learning_rate):
    """Create a calibrated lattice regressor for one hyperparameter setting.

    Args:
        feature_columns: numeric feature columns, one per entry of
            _FEATURE_NAMES.
        num_keypoints: number of calibration keypoints per feature.
        learning_rate: learning rate stored in the hyperparameters.

    Returns:
        The estimator returned by tfl.calibrated_lattice_regressor.
    """
    hparams = tfl.CalibratedLatticeHParams(
        feature_names=list(_FEATURE_NAMES),
        num_keypoints=num_keypoints,
        learning_rate=learning_rate,
    )

    # Set feature monotonicity.
    for feature_name, monotonicity in _FEATURE_MONOTONICITY:
        hparams.set_feature_param(feature_name, 'monotonicity', monotonicity)

    # Define keypoint init.
    keypoints_init_fns = {
        feature_name: _keypoints_initializer(num_keypoints, input_min, input_max)
        for feature_name, (input_min, input_max) in _KEYPOINT_INPUT_RANGES.items()
    }

    return tfl.calibrated_lattice_regressor(
        feature_columns=feature_columns,
        hparams=hparams,
        keypoints_initializers_fn=keypoints_init_fns)


def _numpy_input_fn(features, labels, batch_size, num_epochs):
    """Wrap in-memory numpy data in an unshuffled estimator input function."""
    return tf.compat.v1.estimator.inputs.numpy_input_fn(
        x=features,
        y=labels,
        batch_size=batch_size,
        num_epochs=num_epochs,
        shuffle=False)


# Feature definition (identical for every fold, so built once).
feature_columns = [
    tf.feature_column.numeric_column(name) for name in _FEATURE_NAMES
]

# Grid search space.
key_points = [10, 50, 100]
rates = [0.01, 0.001]
batch_sizes = [32, 64, 128]
epochs = [1000, 1500, 2000]

# A fixed random_state makes every candidate see the same splits.
kf = KFold(n_splits=5, shuffle=True, random_state=2)

for fold in range(folds):
    train_data_path = 'auto-mpg/'+str(fold)+'/train_data.csv'
    test_data_path = 'auto-mpg/'+str(fold)+'/test_data.csv'
    train_dataset = pd.read_csv(train_data_path, index_col=0)
    test_dataset = pd.read_csv(test_data_path, index_col=0)

    # Training and testing data for this fold.
    train_features = _features_from_frame(train_dataset)
    test_features = _features_from_frame(test_dataset)
    train_labels = np.array(train_dataset['MPG'])
    test_labels = np.array(test_dataset['MPG'])

    # Grid search: keep the candidate with the lowest mean validation loss.
    best_parameters = [0, 0, 0, 0]
    # Start from +inf so the first finite result always replaces the
    # placeholder above, however large the loss is.
    best_evaluations = np.inf

    for k, r, b, e in itertools.product(key_points, rates, batch_sizes, epochs):
        evaluations = []
        # Split over the label array, which has one entry per sample. Passing
        # the feature dict would make KFold split over its 7 keys instead of
        # the rows of the training set.
        splits = itertools.islice(kf.split(train_labels), _GRID_SEARCH_SPLITS)
        for train_index, test_index in splits:
            train_X = {key: values[train_index]
                       for key, values in train_features.items()}
            train_Y = train_labels[train_index]
            test_X = {key: values[test_index]
                      for key, values in train_features.items()}
            test_Y = train_labels[test_index]

            lattice_estimator = _build_lattice_estimator(feature_columns, k, r)

            # Train-grid
            lattice_estimator.train(
                input_fn=_numpy_input_fn(train_X, train_Y, b, e))

            # Test-grid: one pass is enough. 'average_loss' is a per-example
            # mean, so repeating the data for `e` epochs would only multiply
            # the evaluation time.
            evaluation = lattice_estimator.evaluate(
                input_fn=_numpy_input_fn(test_X, test_Y, b, 1))
            evaluations.append(evaluation['average_loss'])

        result = np.mean(evaluations)
        if result < best_evaluations:
            best_evaluations = result
            best_parameters = [k, r, b, e]
        print(best_parameters)

    print(best_parameters)
    k, r, b, e = best_parameters

    # Retrain on the whole training fold with the selected hyperparameters.
    lattice_estimator = _build_lattice_estimator(feature_columns, k, r)
    train_input_fn = _numpy_input_fn(train_features, train_labels, b, e)
    lattice_estimator.train(input_fn=train_input_fn)

    # Evaluate with single-pass input functions (see the note in the grid
    # search): the training input function repeats the data for `e` epochs.
    train_eval_input_fn = _numpy_input_fn(train_features, train_labels, b, 1)
    test_input_fn = _numpy_input_fn(test_features, test_labels, b, 1)

    # Estimate
    train_estimate = lattice_estimator.evaluate(input_fn=train_eval_input_fn)
    test_estimate = lattice_estimator.evaluate(input_fn=test_input_fn)
    train_evaluations.append(train_estimate['average_loss'])
    test_evaluations.append(test_estimate['average_loss'])


# Aggregate the per-fold average losses into a mean and a standard deviation
# for the training and the test side, then report them one per line.
mean_train, std_train = np.mean(train_evaluations), np.std(train_evaluations)
mean_test, std_test = np.mean(test_evaluations), np.std(test_evaluations)

# str() is kept explicit so the numpy scalars are rendered exactly as before.
_summary_lines = [
    'mean_train: ' + str(mean_train),
    'std_train: ' + str(std_train),
    'mean_test: ' + str(mean_test),
    'std_test: ' + str(std_test),
]
for _summary_line in _summary_lines:
    print(_summary_line)

