import copy, time, traceback, logging, os, gc
import numpy as np
import pandas as pd
from sklearn.neighbors import NearestNeighbors

from ..constants import BINARY, MULTICLASS, REGRESSION, SOFTCLASS
from ..models.tabular_nn.tabular_nn_model import TabularNeuralNetModel
from ...metrics import mean_squared_error


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# When a binary problem is converted to regression, labels/probabilities in [0, 1] are linearly
# rescaled into [EPS, 1-EPS] (see format_distillation_labels), keeping targets away from exactly 0 and 1.
EPS_bin2regress = 0.01

def format_distillation_labels(y_train, y_test, problem_type, num_classes=None):
    """ Converts train/test labels into the representation used for distillation.

        MULTICLASS: integer class labels are one-hot encoded into pd.DataFrames with num_classes columns.
        BINARY: labels are linearly rescaled from [0, 1] into [EPS_bin2regress, 1 - EPS_bin2regress].
        Any other problem_type: labels are passed through untouched.

        Args:
            y_train, y_test: label objects (must support .to_numpy() for MULTICLASS and arithmetic for BINARY).
            problem_type: one of the problem-type constants.
            num_classes: number of classes, only used for MULTICLASS.

        Returns:
            Tuple (y_train, y_test) holding the converted labels.
    """
    def one_hot(labels):
        # Row k receives a single 1 in the column selected by the k-th integer label.
        class_ids = labels.to_numpy()
        encoded = np.zeros((class_ids.size, num_classes))
        encoded[np.arange(class_ids.size), class_ids] = 1
        return pd.DataFrame(encoded)

    def rescale(labels):
        # Affine map sending `lower` -> EPS and `upper` -> 1 - EPS.
        lower = 0.0
        upper = 1.0
        return EPS_bin2regress + ((1-2*EPS_bin2regress)/(upper-lower)) * (labels - lower)

    if problem_type == MULTICLASS:
        return (one_hot(y_train), one_hot(y_test))
    if problem_type == BINARY:
        return (rescale(y_train), rescale(y_test))
    return (y_train, y_test)

def augment_data(X_train, feature_types_metadata, augmentation_data=None, augment_method='spunge', augment_args=None):
    """ Produces extra (unlabeled) datapoints for distillation and strips duplicates from them.

        Args:
            X_train: training features (pd.DataFrame) the synthetic data is derived from.
            feature_types_metadata: mapping of feature type -> feature names, forwarded to the augmentation method.
            augmentation_data: if not None, used directly as the augmented data and augment_method is ignored.
            augment_method: one of 'spunge', 'munge', 'gan'.
            augment_args: optional dict of keyword arguments forwarded to the chosen augmentation method.

        Returns:
            pd.DataFrame with the augmented rows, as returned by postprocess_augmented().

        Raises:
            ValueError: if augment_method is not recognized (and augmentation_data is None).
    """
    # Use None as the default to avoid sharing one mutable dict across calls.
    if augment_args is None:
        augment_args = {}

    if augmentation_data is not None:
        X_aug = augmentation_data
    elif augment_method == 'spunge':
        X_aug = spunge_augment(X_train, feature_types_metadata, **augment_args)
    elif augment_method == 'munge':
        X_aug = munge_augment(X_train, feature_types_metadata, **augment_args)
    elif augment_method == 'gan':
        X_aug = gan_augment(X_train, feature_types_metadata, **augment_args)
    else:
        raise ValueError(f"unknown augment_method: {augment_method}")

    # Preview goes through the module logger instead of an always-on debug print.
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug("X_aug: %s", X_aug.head())
    return postprocess_augmented(X_aug, X_train)

def postprocess_augmented(X_aug, X):
    """ Removes from the augmented data every row that duplicates a row of X or an earlier augmented row.

        Args:
            X_aug: pd.DataFrame of augmented datapoints.
            X: pd.DataFrame of original training datapoints.

        Returns:
            pd.DataFrame containing only the novel augmented rows, with a fresh 0..n-1 index.
            Neither input is modified.
    """
    num_original = len(X)
    combined = pd.concat([X, X_aug])
    # Mark first occurrences, then explicitly exclude the rows that came from X.
    # Counting rows after de-duplication (as tail(len - len(X)) would) is wrong whenever X itself
    # contains duplicates, because fewer than len(X) original rows survive the de-duplication.
    is_novel = ~combined.duplicated(keep='first').to_numpy()
    is_novel[:num_original] = False
    X_aug = combined[is_novel]
    print(f"Augmented training dataset with {len(X_aug)} extra datapoints (after removing duplicates)")
    return X_aug.reset_index(drop=True, inplace=False)

# To grid-search {frac_perturb,continuous_feature_noise}: call spunge_augment() many times and track validation score in Trainer.
def spunge_augment(X, feature_types_metadata, num_augmented_samples = 10000, frac_perturb = 0.1,
                   continuous_feature_noise = 0.1, **kwargs):
    """ Creates synthetic datapoints for distillation with a simplified MUNGE scheme that needs no
        nearest-neighbor search: each synthetic row starts as a copy of a real row, a few of its
        features are replaced by values drawn from that feature's column, and continuous features
        additionally receive masked Gaussian noise.

        Args:
            X: pd.DataFrame of real datapoints.
            feature_types_metadata: mapping of feature type -> list of feature names; names listed under
                'float', 'int' or 'datetime' are treated as continuous.
            num_augmented_samples: number of synthetic rows to generate.
            frac_perturb: fraction of features/examples that are perturbed. Values near 0 keep the
                synthetic distribution closer to the real data.
            continuous_feature_noise: noise std-dev for a continuous feature is this factor times the
                feature's std-dev. Values near 0 keep the synthetic distribution closer to the real data.
            **kwargs: accepted and ignored.

        Returns:
            pd.DataFrame with num_augmented_samples rows and the columns of X.

        Raises:
            ValueError: if frac_perturb > 1.
    """
    if frac_perturb > 1.0:
        raise ValueError("frac_perturb must be <= 1")
    print(f"SPUNGE: Augmenting training data with {num_augmented_samples} synthetic samples for distillation...")

    all_columns = list(X.columns)
    num_rows = len(X)
    # Upper bound on how many features get swapped in any single synthetic row (always at least one).
    max_swaps = max(1, int(frac_perturb*len(X.columns)))

    # Pre-allocate the output by repeating the first real row; every row is overwritten below.
    X_aug = pd.concat([X.iloc[[0]].copy()]*num_augmented_samples)
    X_aug.reset_index(drop=True, inplace=True)

    # Features that receive additive noise after the swapping pass.
    numeric_names = [name
                     for kind in ('float', 'int', 'datetime') if kind in feature_types_metadata
                     for name in feature_types_metadata[kind]]

    # Pass 1: hot-deck swap a random subset of features in each row (real rows are cycled through in order).
    for row_pos in range(num_augmented_samples):
        synthetic_row = X.iloc[row_pos % num_rows].copy()
        swap_count = np.random.choice(range(1, max_swaps+1))
        for column in np.random.choice(all_columns, size=swap_count, replace=False):
            synthetic_row[column] = X[column].sample(n=1).values[0]
        X_aug.iloc[row_pos] = synthetic_row

    # Pass 2: add Gaussian noise to continuous features; each row is noised with probability frac_perturb.
    for column in X.columns:
        if column not in numeric_names:
            continue
        noise_scale = np.nanstd(X[column])*continuous_feature_noise
        gaussian = np.random.normal(scale=noise_scale, size=num_augmented_samples)
        apply_noise = np.random.binomial(n=1, p=frac_perturb, size=num_augmented_samples)
        X_aug[column] = pd.Series(X_aug[column] + gaussian*apply_noise, index=X_aug.index)

    return X_aug


# Example: z = munge_augment(train_data[:100], trainer.feature_types_metadata, num_augmented_samples=25, s= 0.1, perturb_prob=0.9)
# To grid-search {p,s}: call munge_augment() many times and track validation score in Trainer.
def munge_augment(X, feature_types_metadata, num_augmented_samples = 10000, perturb_prob = 0.5,
                   s = 1.0, **kwargs):
    """ Use MUNGE to generate synthetic datapoints for learning to mimic teacher model in distillation.

        Each synthetic row starts as a copy of a real row; a random subset of its features is replaced
        by the value of that row's nearest neighbor (neighbors are found in the vectorized feature space
        produced by a throwaway TabularNeuralNetModel). Continuous features additionally get Gaussian
        noise whose std-dev is |own value - neighbor value| / s.

        Args:
            X: pd.DataFrame of real datapoints.
            feature_types_metadata: mapping of feature type -> list of feature names; names listed under
                'float', 'int' or 'datetime' are treated as continuous.
            num_augmented_samples: number of additional augmented data points to return
            perturb_prob: probability of perturbing each feature during augmentation. Set near 0 to ensure augmented sample distribution remains closer to real data.
            s: We noise numeric features by their std-dev divided by this factor (inverse of continuous_feature_noise). Set large to ensure augmented sample distribution remains closer to real data.
            **kwargs: accepted and ignored.

        Returns:
            pd.DataFrame with num_augmented_samples rows and the columns of X (continuous columns cast to float).

        Raises:
            ValueError: if perturb_prob is outside [0, 1].
    """
    # Fail fast: validate before the expensive featurization and neighbor search below.
    if perturb_prob > 1.0 or perturb_prob < 0.0:
        raise ValueError("perturb_prob must be between 0 and 1")

    # A dummy neural-net model is used only for its preprocessing, to obtain a numeric vector per row.
    nn_dummy = TabularNeuralNetModel( path='nn_dummy', name='nn_dummy', problem_type=REGRESSION, objective_func=mean_squared_error,
                    hyperparameters={'num_dataloading_workers':0,'proc.embed_min_categories':np.inf}, features = list(X.columns))
    nn_dummy.feature_types_metadata = feature_types_metadata
    processed_data = nn_dummy.process_train_data(nn_dummy.preprocess(X), pd.Series([1]*len(X)))
    X_vector = processed_data.dataset._data[processed_data.vectordata_index].asnumpy()
    # Drop references to the large intermediates before the neighbor search.
    processed_data = None
    nn_dummy = None
    gc.collect()

    # n_neighbors=2 because each point is queried against the data it was fit on;
    # column 1 is taken as the nearest neighbor (column 0 is normally the point itself).
    neighbor_finder = NearestNeighbors(n_neighbors=2)
    neighbor_finder.fit(X_vector)
    _, neigh_ind = neighbor_finder.kneighbors(X_vector)
    neigh_ind = neigh_ind[:,1]  # contains indices of nearest neighbors
    neighbor_finder = None
    gc.collect()

    print(f"MUNGE: Augmenting training data with {num_augmented_samples} synthetic samples for distillation...")
    X = X.copy()  # continuous columns are cast below; keep the caller's frame untouched
    # Pre-allocate the output by repeating the first real row; every row is overwritten below.
    X_aug = pd.concat([X.iloc[[0]].copy()]*num_augmented_samples)
    X_aug.reset_index(drop=True, inplace=True)
    continuous_types = ['float','int', 'datetime']
    continuous_featnames = [] # these features will have neighbor values with added noise
    for contype in continuous_types:
        if contype in feature_types_metadata:
            continuous_featnames += feature_types_metadata[contype]
    for col in continuous_featnames:
        X_aug[col] = X_aug[col].astype(float)
        X[col] = X[col].astype(float)

    all_columns = list(X.columns)
    for i in range(num_augmented_samples):
        og_ind = i % len(X)  # cycle through the real rows in order
        augdata_i = X.iloc[og_ind].copy()
        neighbor_i = X.iloc[neigh_ind[og_ind]].copy()
        # Number of perturbed features ~ Binomial(num_features, perturb_prob).
        num_perturb_i = np.random.binomial(X.shape[1], p=perturb_prob, size=1)[0]
        cols_toperturb = np.random.choice(all_columns, size=num_perturb_i, replace=False)
        for col in cols_toperturb:
            new_val = neighbor_i[col]
            if col in continuous_featnames:
                new_val += np.random.normal(scale=np.abs(augdata_i[col]-new_val)/s)
            augdata_i[col] = new_val
        X_aug.iloc[i] = augdata_i

    return X_aug

def nearest_neighbor(numer_i, categ_i, numer_candidates, categ_candidates):
    """ Returns tuple (index, dist) of nearest neighbor point in the list of candidates (pd.DataFrame) to query point i (pd.Series).
        Uses Euclidean distance for numerical features, Hamming for categorical features.

        The two parts are combined by simple addition: dist = Euclidean distance over the numerical
        features + number of categorical features whose values differ.

        Args:
            numer_i: pd.Series with the query point's numerical features.
            categ_i: pd.Series with the query point's categorical features.
            numer_candidates: pd.DataFrame of candidates' numerical features, one row per candidate.
                Columns are matched to numer_i by position, so they must be in the same order.
            categ_candidates: pd.DataFrame of candidates' categorical features, rows aligned with
                numer_candidates. Columns are matched to categ_i by position.

        Returns:
            (index, dist): positional (0-based) row index of the closest candidate and its distance.
            Ties resolve to the first such candidate.

        Raises:
            ValueError: if there are no candidates.
    """
    num_candidates = len(numer_candidates)
    if num_candidates == 0:
        raise ValueError("nearest_neighbor requires at least one candidate")

    # Euclidean part; the 1-D query broadcasts against the 2-D candidate matrix.
    numer_diff = numer_candidates.to_numpy(dtype=float) - numer_i.to_numpy(dtype=float)
    dists = np.sqrt((numer_diff ** 2).sum(axis=1))

    # Hamming part: count of categorical features that differ from the query.
    categ_mismatch = categ_candidates.to_numpy() != categ_i.to_numpy()
    dists = dists + np.asarray(categ_mismatch).reshape(num_candidates, -1).sum(axis=1)

    index = int(np.argmin(dists))
    distance = float(dists[index])
    return (index, distance)

def gan_augment(X, feature_types_metadata, num_augmented_samples=10000, epochs=300, **kwargs):
    """ Augments data using CTGAN from here: https://github.com/sdv-dev/CTGAN
        Need to do: pip install ctgan
        Only after autogluon has been installed, since ctgan depends on newer sklearn > 0.21

        Missing values are imputed on a copy of X before fitting: continuous features with the column
        mean, all other (discrete) features with the column mode.

        Args:
            X: pd.DataFrame of real datapoints (not modified).
            feature_types_metadata: mapping of feature type -> list of feature names; names listed under
                'float', 'int' or 'datetime' are treated as continuous, all other columns as discrete.
            num_augmented_samples: number of additional augmented data points to return.
            epochs: Number of epochs to run GAN training (authors use 300 by default)
            **kwargs: accepted and ignored.

        Returns:
            The num_augmented_samples datapoints sampled from the fitted CTGAN synthesizer.

        Raises:
            ImportError: if the ctgan package is not installed.
    """
    print(f"GAN: Augmenting training data with {num_augmented_samples} synthetic samples for distillation...")
    try:
        import ctgan  # noqa: F401 -- imported only to verify the package is available
    except ImportError as e:
        raise ImportError(f"Error importing ctgan package: {e}. To use gan_augment(), please install ctgan via:  pip install ctgan") from e

    from ctgan import CTGANSynthesizer

    # Split the columns into continuous vs. discrete; CTGAN needs the discrete ones named explicitly.
    continuous_types = ['float','int', 'datetime']
    continuous_featnames = []
    for contype in continuous_types:
        if contype in feature_types_metadata:
            continuous_featnames += feature_types_metadata[contype]
    discrete_columns = [col for col in list(X.columns) if col not in continuous_featnames]

    # impute missing values:
    X = X.copy()
    for feat in continuous_featnames:
        X.loc[:,feat] = X[feat].fillna(X[feat].mean())
    for feat in discrete_columns:
        X.loc[:,feat] = X[feat].fillna(X[feat].mode()[0])

    # Named `synthesizer` so it does not shadow the imported `ctgan` module.
    synthesizer = CTGANSynthesizer()
    synthesizer.fit(X, discrete_columns, epochs=epochs)
    X_aug = synthesizer.sample(num_augmented_samples)
    return X_aug



# OLD:
def augment_data_hotdeck(X, feature_types_metadata, num_augmented_samples = 50000, continuous_feature_noise = 0.1):
    """ Generates synthetic datapoints for learning to mimic teacher model in distillation.
        These data are independent samples from the marginal distribution of each feature.

        Args:
            X: pd.DataFrame of real datapoints.
            feature_types_metadata: mapping of feature type -> list of feature names; names listed under
                'float', 'int' or 'datetime' are treated as continuous.
            num_augmented_samples: number of total augmented data points to return (we add extra points to training set until this number is reached).
            continuous_feature_noise: we noise numeric features by this factor times their std-dev.

        Returns:
            X itself if it already has at least num_augmented_samples rows; otherwise a new pd.DataFrame
            holding the de-duplicated synthetic rows followed by the rows of X (index reset).
    """
    # `>=` rather than `>`: with exactly num_augmented_samples rows there is nothing left to generate,
    # and asking pd.concat for zero frames below would raise.
    if len(X) >= num_augmented_samples:
        print("No data augmentation performed since training data is large enough.")
        return X
    num_augmented_samples = num_augmented_samples - len(X)

    # Pre-allocate the synthetic frame by repeating the first real row; every column is overwritten below.
    X_aug = pd.concat([X.iloc[[0]]]*num_augmented_samples)
    X_aug.reset_index(drop=True, inplace=True)
    continuous_types = ['float','int', 'datetime']
    continuous_featnames = [] # these features will have shuffled values with added noise
    for contype in continuous_types:
        if contype in feature_types_metadata:
            continuous_featnames += feature_types_metadata[contype]
    for feature in X.columns:
        feature_data = X[feature]
        # Resample each column independently (with replacement) from its observed values.
        new_feature_data = feature_data.sample(n=num_augmented_samples, replace=True)
        new_feature_data.reset_index(drop=True, inplace=True)
        if feature in continuous_featnames:
            noise = np.random.normal(scale=np.std(feature_data)*continuous_feature_noise, size=num_augmented_samples)
            new_feature_data = new_feature_data + noise
        X_aug[feature] = pd.Series(new_feature_data, index=X_aug.index)
    X_aug.drop_duplicates(keep='first', inplace=True)
    X_aug = pd.concat([X_aug, X])
    X_aug.reset_index(drop=True, inplace=True)
    print("Augmented training dataset has %s datapoints" % X_aug.shape[0])
    return X_aug

# TODO: experimental code.


def passdata_gibbs(trainer, outputdir=''):
    """ Dumps the trainer's train/validation data to CSV files, plus the list of numerical features.

        Files written (names are appended directly to outputdir):
            Xtrain.csv, Xval.csv, Ytrain.csv, Yval.csv, numericalfeatures.csv, and an empty marker file
            all_features_numerical.txt or notall_features_numerical.txt.

        Note: outputdir needs to end in '/'. The directory must already exist.

        Args:
            trainer: object providing load_X_train(), load_X_val(), load_y_train(), load_y_val()
                and a feature_types_metadata mapping.
            outputdir: path prefix for the output files.
    """
    X_train = trainer.load_X_train()  # these are deterministic
    X_val = trainer.load_X_val()
    Y_train = trainer.load_y_train()
    Y_val = trainer.load_y_val()

    # Names listed under these types in the trainer's metadata count as numerical features.
    continuous_types = ['float','int', 'datetime']
    continuous_featnames = []
    for contype in continuous_types:
        if contype in trainer.feature_types_metadata:
            continuous_featnames += trainer.feature_types_metadata[contype]
    # Decided once, after all numerical feature names have been collected.
    all_numerical = len(continuous_featnames) == X_train.shape[1]

    numer_feats = pd.Series(continuous_featnames)
    # Features are written without the index, matching how the labels are written.
    X_train.to_csv(outputdir + 'Xtrain.csv', index=False)
    X_val.to_csv(outputdir + 'Xval.csv', index=False)
    Y_train.to_csv(outputdir + 'Ytrain.csv', index=False)
    Y_val.to_csv(outputdir + 'Yval.csv', index=False)
    numer_feats.to_csv(outputdir + 'numericalfeatures.csv', index=False)

    # Empty marker file signalling whether every feature is numerical.
    if all_numerical:
        marker_name = 'all_features_numerical.txt'
    else:
        marker_name = 'notall_features_numerical.txt'
    with open(outputdir + marker_name, 'a'):
        pass

    print(f"Data saved to: {outputdir}")



def OLDaugment_trade(self, X = None):
    """ Exports the training features in an all-numeric form for training a TRADE generative model.

        Every non-continuous feature is integer-encoded (codes follow the sorted order of its unique
        values). The encoded table is written to "data4trade.csv" and the list of continuous feature
        names is pickled to "continuous_features.p", both in the current working directory.

        Args:
            self: object providing feature_types_metadata (feature type -> list of feature names) and,
                when X is None, load_X_train().
            X: optional pd.DataFrame of features; it is not modified. Loaded via self.load_X_train() if None.

        Script to reload from these files for training TRADE generative model:
            import pandas as pd
            import pickle
            X = pd.read_csv("data4trade.csv")
            continuous_featnames = pickle.load( open( "continuous_features.p", "rb" ) )
            # Num categories for a categorical feature FEAT: len(X[FEAT].unique())
            num_categories = {}
            for feat in X.columns:
                if feat not in continuous_featnames:
                    num_categories[feat] = len(X[feat].unique())
            X = X.to_numpy()
    """
    import pickle

    if X is None:
        X = self.load_X_train()
    # Work on a copy so the integer encoding below never leaks into the caller's DataFrame.
    X = X.copy()

    # Names listed under these types are left as-is; everything else is encoded.
    continuous_types = ['float','int', 'datetime']
    continuous_featnames = []
    for contype in continuous_types:
        if contype in self.feature_types_metadata:
            continuous_featnames += self.feature_types_metadata[contype]

    # Convert categoricals to int:
    feature_levels = {}
    for feature in X.columns:
        if feature in continuous_featnames:
            continue
        feat_categories = sorted(X[feature].unique())
        feature_levels[feature] = {category: code for code, category in enumerate(feat_categories)}
        X[feature] = X[feature].map(feature_levels[feature])

    # Save X:
    X.to_csv("data4trade.csv", index=False)
    # Context manager guarantees the pickle file is flushed and closed.
    with open("continuous_features.p", "wb") as pickle_file:
        pickle.dump(continuous_featnames, pickle_file)

def augmented_filter(self, X_aug, y_aug, X_real, y_real):
    """ Filters out certain points from the augmented dataset so that it better matches the real data.

        For MULTICLASS/BINARY problems, augmented points are dropped per predicted class until the
        class proportions of the augmented predictions match those of y_real (every class keeps at
        least one point). For other problem types nothing is dropped.

        Args:
            self: object providing problem_type.
            X_aug: pd.DataFrame of augmented features. Rows are dropped from it IN PLACE.
            y_aug: predicted probabilities for X_aug; converted to a pd.DataFrame (MULTICLASS) or pd.Series.
            X_real: real features (unused here).
            y_real: real labels (pd.Series) defining the target class distribution.

        Returns:
            Tuple (X_aug, y_aug) after filtering, both with a reset index when anything was dropped.
    """
    indices_to_drop = []
    # NOTE(review): get_pred_from_proba is neither defined nor imported in the visible part of this
    # file -- confirm it is in scope before relying on this function.
    y_aug_hard = pd.Series(get_pred_from_proba(y_aug, problem_type=self.problem_type))
    if self.problem_type == MULTICLASS:
        y_aug = pd.DataFrame(y_aug)
    else:
        y_aug = pd.Series(y_aug)

    # NOTE(review): rows are dropped by index label below, which presumably requires X_aug, y_aug and
    # y_aug_hard to share the same (default 0..n-1) index -- TODO confirm against callers.

    if self.problem_type in [MULTICLASS, BINARY]:
        # Target class proportions (from the real labels) and the counts they imply for the augmented set.
        p_y = y_real.value_counts(sort=False).sort_index()/len(y_real)
        desired_class_cnts = p_y * len(y_aug_hard)
        y_aug_cnts = y_aug_hard.value_counts(sort=False).sort_index()
        if len(y_aug_cnts) != len(p_y):  # some classes were never predicted, so cannot match label distributions
            return X_aug, y_aug

        # Shrink the desired counts so that no class needs more points than were actually predicted for it.
        scaling = np.min(y_aug_cnts/desired_class_cnts)
        desired_class_cnts = scaling *desired_class_cnts
        desired_class_cnts = np.maximum(np.floor(desired_class_cnts), 1).astype(int)
        num_to_drop = np.maximum(0, y_aug_cnts - desired_class_cnts)
        for clss in p_y.index.values.tolist():
            if clss in y_aug_cnts.index:
                # Drop the first surplus points predicted as this class.
                clss_inds = y_aug_hard[y_aug_hard == clss].index.tolist()
                indices_to_drop += clss_inds[:num_to_drop[clss]]

    if len(indices_to_drop) == 0:
        return X_aug, y_aug

    y_aug.drop(indices_to_drop, inplace=True)
    y_aug.reset_index(drop=True, inplace=True)
    X_aug.drop(indices_to_drop, inplace=True)
    X_aug.reset_index(drop=True, inplace=True)
    print(f"Augmented training dataset has {len(X_aug)} datapoints after augmented_filter")
    return X_aug, y_aug


