import argparse
import datetime

import numpy as np
import pandas as pd
import tensorflow as tf

from relational_erm.data_cleaning.pokec import load_data_pokec, process_pokec_attributes
from relational_erm.sampling import adapters, factories
from relational_erm.data_cleaning.wikipedia import load_data_wikipedia_processed
from scipy.special import expit
import numpy as np
import numpy.random as npr
import pandas as pd
import scipy.stats as stats
import os
from fnmatch import fnmatch
from sklearn.linear_model import LogisticRegression, LinearRegression
from xgboost import XGBClassifier
import statsmodels.api as sm

def add_parser_sampling_arguments(parser=None):
    """Register the sampling / training command-line options on a parser.

    Parameters
    ----------
    parser: an existing ``argparse.ArgumentParser`` to extend, or None to
        create a fresh one.

    Returns
    -------
    parser: the parser with all options below registered.
    """
    if parser is None:
        parser = argparse.ArgumentParser()

    # Short alias; every option below is registered on the same parser.
    add = parser.add_argument

    # Run control and data splitting.
    add('--seed', type=int, default=42)
    add('--max-steps', type=int, default=1000)
    add('--proportion-censored', type=float, default=0.5,
        help='proportion of censored vertex labels at train time.')
    add('--batch-size', type=int, default=100, help='minibatch size')
    add('--dataset-shards', type=int, default=None, help='dataset parallelism')

    # Which subgraph sampler to use for training and (optionally) testing.
    add('--sampler', type=str, default=None, choices=factories.dataset_names(),
        help='The sampler to use. biased-walk gives a skipgram random-walk with unigram negative '
             'sampling; p-sampling gives p-sampling with unigram negative sampling; uniform-edge '
             'gives uniform edge sampling with unigram negative sampling; biased-walk-induced-uniform '
             'gives induced random-walk with unigram negative-sampling; p-sampling-induced gives '
             'p-sampling with induced non-edges.')
    add('--sampler-test', type=str, default=None,
        choices=factories.dataset_names(),
        help='if not None, the sampler to use for testing')

    # Embedding model.
    add('--embedding-dim', type=int, default=128,
        help='Number of dimensions. Default is 128.')
    add('--indef-ip', default=False, action='store_true',
        help='Uses a krein inner product instead of the regular inner product.')

    # Sample sizes.
    add('--num-edges', type=int, default=800,
        help='Number of edges per sample.')
    add('--p-sample-prob', type=float, default=None,
        help='Probability of sampling a vertex for p-sampling. Only used if the sampling scheme is a '
             'p-sampling scheme, in which case this is used to override the num-edges argument.')
    add('--window-size', type=int, default=10,
        help='Context size for optimization. Default is 10.')
    add('--num-negative', type=int, default=5,
        help='negative examples per vertex for negative sampling')
    add('--num-negative-total', type=int, default=None,
        help='total number of negative vertices sampled')

    # Optimisation (note: these three flags are spelled with underscores).
    add('--embedding_learning_rate', type=float, default=0.025,
        help='SGD learning rate for embedding updates.')
    add('--global_learning_rate', type=float, default=1.,
        help='SGD learning rate for global updates.')
    add('--global_regularization', type=float, default=1.,
        help='Regularization scale for global variables.')

    return parser


### Several functions for getting the dataset into the right tf.data.Dataset format
def compose(*fns):
    """ Chains the given functions into a single callable.

    The functions are applied left to right: the first function receives the
    call arguments, and each later function receives the previous result.
    A tuple result is unpacked into positional arguments for the next
    function; any other result is passed as a single argument.

    Parameters
    ----------
    fns: the functions to compose
    Returns
    -------
    comp: a function that represents the composition of the given functions.
        With no functions given, it returns its arguments as a tuple.
    """

    def comp(*args):
        # Start from the raw argument tuple so the first function gets it unpacked.
        value = args
        for fn in fns:
            value = fn(*value) if isinstance(value, tuple) else fn(value)
        return value

    return comp


def get_dataset_fn(sampler, args):
    """Build the dataset factory for `sampler`, defaulting to 'biased-walk' when it is None."""
    sampler_name = 'biased-walk' if sampler is None else sampler
    return factories.make_dataset(sampler_name, args)


def make_test_fn(graph_data, args, treatments, outcomes, dataset_fn=None, num_samples=None, is_test=False):
    """Create an input function yielding sampled subgraphs with treatment/outcome labels.

    Parameters
    ----------
    graph_data: the graph; its `num_vertices` is used for the label split.
    args: parsed options; `seed`, `proportion_censored`, `batch_size` and
        `num_edges` are read.
    treatments: per-vertex treatment labels (also determines the padded batch dtype).
    outcomes: per-vertex outcome labels (also determines the padded batch dtype).
    dataset_fn: callable invoked as `dataset_fn(graph_data, args.seed)` to
        produce the sampled dataset. It is called unconditionally, so the
        None default is not usable as-is.
    num_samples: if not None, only this many samples are taken.
    is_test: forwarded to `adapters.make_split_vertex_labels`.

    Returns
    -------
    input_fn: a zero-argument function that builds the dataset.
    """

    def input_fn():
        samples = dataset_fn(graph_data, args.seed)

        # Relabel the subgraph, attach labels, split them into censored /
        # observed, then format as (features, labels).
        label_pipeline = adapters.compose(
            adapters.relabel_subgraph(),
            adapters.append_vertex_labels(treatments, 'treatment'),
            adapters.append_vertex_labels(outcomes, 'outcome'),
            adapters.make_split_vertex_labels(
                graph_data.num_vertices, args.proportion_censored,
                np.random.RandomState(args.seed), is_test, is_pred=False),
            adapters.add_sample_size_info(),
            adapters.format_features_labels1())

        samples = samples.map(label_pipeline, 8)
        if num_samples is not None:
            samples = samples.take(num_samples)

        batch_size, num_edges = args.batch_size, args.num_edges

        if batch_size is not None:
            # Pad every sample to a common size so they can be stacked.
            padded_batcher = adapters.padded_batch_samples_supervised(
                batch_size, n_edges_max=num_edges * 2,
                n_vertices_max=num_edges,
                t_dtype=treatments.dtype, o_dtype=outcomes.dtype)
            samples = samples.apply(padded_batcher)

        return samples.prefetch(tf.data.experimental.AUTOTUNE)

    return input_fn


def make_no_graph_input_fn0(graph_data, args, treatments, outcomes, filter_test=False, batch_size=78982):
    """
    A dataset w/ all the label processing, but no graph structure.
    Used at evaluation and prediction time

    Parameters
    ----------
    graph_data: the graph; `num_vertices` and `weights` are read.
    args: parsed options; `proportion_censored` and `seed` are read.
    treatments: per-vertex treatment labels appended under 'treatment'.
    outcomes: per-vertex outcome labels appended under 'outcome'.
    filter_test: if True, keep only elements whose 'in_test' feature equals 1.
        Also forwarded as `is_test` to `adapters.make_split_vertex_labels`.
    batch_size: number of vertices per batch. The default, 78982, is the
        value that used to be hard-coded here; pass `graph_data.num_vertices`
        to get every vertex in one batch for a graph of a different size.

    Returns
    -------
    input_fn: a zero-argument function that builds the batched dataset.
    """

    def input_fn():
        # One placeholder entry per vertex; the real treatment, weights and
        # outcome labels are attached by the adapters below.
        metadata = {'edge_list': tf.expand_dims(np.zeros((graph_data.num_vertices, 2)), 1),
                    'vertex_index': tf.expand_dims(np.array(range(graph_data.num_vertices)), 1),
                    'treatment': tf.expand_dims(np.ones(graph_data.num_vertices), 1),
                    'weights': graph_data.weights,
                    'is_positive': tf.expand_dims(np.ones(graph_data.num_vertices), 1)}

        num_samples = graph_data.num_vertices

        def gen():
            # Yield one dict per vertex, in vertex order.
            for i in range(num_samples):
                ls = {}
                for key, val in metadata.items():
                    ls[key] = val[i]
                yield ls

        dataset = tf.data.Dataset.from_generator(gen, output_types={k: tf.int64 for k in metadata})
        # dataset = dataset.shuffle(buffer_size=num_samples, reshuffle_each_iteration=True)
        data_processing = adapters.compose(
            adapters.append_vertex_labels(treatments, 'treatment'),
            adapters.append_vertex_labels(metadata['weights'], 'weights'),
            adapters.append_vertex_labels(outcomes, 'outcome'),
            adapters.make_split_vertex_labels(
                graph_data.num_vertices, args.proportion_censored,
                np.random.RandomState(args.seed), is_test=filter_test, is_pred=True),
            adapters.format_features_labels())

        dataset = dataset.map(data_processing, 8)

        if filter_test:
            def filter_test_fn(features, labels):
                return tf.equal(tf.squeeze(features['in_test']), 1)

            dataset = dataset.filter(filter_test_fn)

        dataset = dataset.batch(batch_size=batch_size, drop_remainder=False)

        return dataset

    return input_fn


### SIMULATING TREATMENT/OUTCOME VARIABLES:

def sigmoid(x):
    """Return the logistic function 1 / (1 + exp(-x)), elementwise.

    Delegates to ``scipy.special.expit`` so that large negative inputs do not
    trigger the overflow RuntimeWarning that a hand-written ``np.exp(-x)``
    produces.

    Parameters
    ----------
    x: a scalar or array-like of real values.

    Returns
    -------
    sig: values in [0, 1], same shape as `x`.
    """
    return expit(x)


def simulate_y(propensities, treatment, beta0=1.0, beta1=1.0, gamma=1.0, set_seed=42):
    """Simulate continuous outcomes from propensities and treatment levels.

    Parameters
    ----------
    propensities: array of propensity scores; centred at 0.5 to form the confounding term.
    treatment: array of treatment levels, one per unit.
    beta0: coefficient on the treatment.
    beta1: coefficient on the centred propensity (confounding strength).
    gamma: scale of the standard-normal noise.
    set_seed: seed applied to NumPy's global RNG before drawing the noise.

    Returns
    -------
    y: noisy outcome, `y1 + gamma * noise`.
    y0: the confounding-only part, `beta1 * (propensities - 0.5)`.
    y1: the noiseless outcome at the given treatment, `beta0 * treatment + y0`.
    """
    centred = (propensities - 0.5).astype(np.float32)

    # Reseeds the *global* NumPy RNG, so the noise is the same for a given seed.
    np.random.seed(set_seed)
    n_units = propensities.shape[0]
    eps = np.random.normal(0., 1., size=n_units).astype(np.float32)

    baseline = beta1 * centred
    noiseless = beta0 * treatment + baseline
    observed = noiseless + gamma * eps

    return observed, baseline, noiseless


def simulate_y_binary(propensities, treatment, beta0=1.0, beta1=1.0, gamma=1.0, set_seed=42):
    """Simulate binary outcomes by passing a noisy linear score through a sigmoid.

    The score is `beta0 * treatment + beta1 * (propensities - 0.5) + gamma * noise`
    with standard-normal noise; each outcome is then one Bernoulli draw with
    success probability `sigmoid(score)`.

    Parameters
    ----------
    propensities: array of propensity scores.
    treatment: array of treatment levels, one per unit.
    beta0: coefficient on the treatment.
    beta1: coefficient on the centred propensity.
    gamma: scale of the noise added to the score.
    set_seed: seed applied to NumPy's global RNG before any draw.

    Returns
    -------
    y: array of 0/1 outcomes.
    """
    centred = (propensities - 0.5).astype(np.float32)

    # Noise and the Bernoulli draws both come from the same reseeded global stream.
    np.random.seed(set_seed)
    eps = np.random.normal(0., 1., size=propensities.shape[0]).astype(np.float32)

    baseline = beta1 * centred
    score = (beta0 * treatment + baseline) + gamma * eps
    return np.random.binomial(1, sigmoid(score))


def simulate_from_wikipedia_covariate(data_dir, covariate='unique_category', beta0=1.0, beta1=1.0, gamma=1.0, set_seed=42):
    """Simulate neighbour-averaged treatments and continuous outcomes on the Wikipedia graph.

    Each vertex first draws its own binary treatment with a propensity driven
    by its category; the treatment actually used is the mean of its
    neighbours' draws. Outcomes then come from `simulate_y`.

    Parameters
    ----------
    data_dir: directory passed to `load_data_wikipedia_processed`.
    covariate: confounding covariate; only 'unique_category' is accepted.
    beta0, beta1, gamma: forwarded to `simulate_y`.
    set_seed: seed for NumPy's global RNG (also forwarded to `simulate_y`).

    Returns
    -------
    t, y, y0, y1: float32 arrays (treatment and the three `simulate_y` outputs).
    propensities: the per-vertex propensity used for the individual draws.

    Raises
    ------
    Exception: if `covariate` is not 'unique_category'.
    """
    graph_data, profiles = load_data_wikipedia_processed(data_dir)

    np.random.seed(set_seed)

    # Rank each category among the sorted distinct values, then shift by one,
    # so the codes start at -1 (i.e. -1, 0, 1, ...), not at 0.
    categories = profiles['unique_category'].tolist()
    category_codes = np.searchsorted(np.unique(categories), categories) - 1.

    if covariate != 'unique_category':
        raise Exception("covariate name not recognized")
    confounder = category_codes

    # Codes -1 / 0 / 1 give propensities 0.15 / 0.5 / 0.85.
    # NOTE(review): presumably there are exactly three categories; a code of 2
    # or more would push the propensity above 1 -- confirm against the data.
    propensities = 0.5 + 0.35 * confounder
    own_draw = np.random.binomial(1, propensities)

    # A vertex's treatment is the average draw over its neighbours (self excluded).
    neighbour_mean = np.array(
        [np.mean(own_draw[graph_data.adjacency_list.get_neighbours(v)], dtype=np.float32)
         for v in range(len(own_draw))],
        dtype=np.float32)

    y, y0, y1 = simulate_y(propensities, neighbour_mean, beta0=beta0, beta1=beta1, gamma=gamma, set_seed=set_seed)

    return (neighbour_mean.astype(np.float32),
            y.astype(np.float32),
            y0.astype(np.float32),
            y1.astype(np.float32),
            propensities)


def simulate_from_wikipedia_covariate_y_binary(data_dir, covariate='unique_category', beta0=1.0, beta1=1.0, gamma=1.0, set_seed=42):
    """Simulate neighbour-averaged treatments and binary outcomes on the Wikipedia graph.

    Same treatment mechanism as the continuous version: each vertex draws a
    binary treatment from a category-driven propensity, and the treatment used
    is the mean over its neighbours. Outcomes come from `simulate_y_binary`.

    Parameters
    ----------
    data_dir: directory passed to `load_data_wikipedia_processed`.
    covariate: confounding covariate; only 'unique_category' is accepted.
    beta0, beta1, gamma: forwarded to `simulate_y_binary`.
    set_seed: seed for NumPy's global RNG (also forwarded to `simulate_y_binary`).

    Returns
    -------
    t: float32 array of neighbour-averaged treatments.
    y: float32 array of 0/1 outcomes.

    Raises
    ------
    Exception: if `covariate` is not 'unique_category'.
    """
    graph_data, profiles = load_data_wikipedia_processed(data_dir)

    np.random.seed(set_seed)

    # Rank each category among the sorted distinct values, then shift by one,
    # so the codes start at -1 (i.e. -1, 0, 1, ...), not at 0.
    categories = profiles['unique_category'].tolist()
    category_codes = np.searchsorted(np.unique(categories), categories) - 1.

    if covariate != 'unique_category':
        raise Exception("covariate name not recognized")
    confounder = category_codes

    propensities = 0.5 + 0.35 * confounder
    own_draw = np.random.binomial(1, propensities)

    # A vertex's treatment is the average draw over its neighbours (self excluded).
    neighbour_mean = np.array(
        [np.mean(own_draw[graph_data.adjacency_list.get_neighbours(v)], dtype=np.float32)
         for v in range(len(own_draw))],
        dtype=np.float32)

    outcome = simulate_y_binary(propensities, neighbour_mean, beta0=beta0, beta1=beta1, gamma=gamma, set_seed=set_seed)

    return neighbour_mean.astype(np.float32), outcome.astype(np.float32)


def simulate_from_wikipedia_covariate_treatment_all0(data_dir, covariate='unique_category', beta0=1.0, beta1=1.0, gamma=1.0,
                                                 set_seed=42):
    """Simulate outcomes on the Wikipedia graph with every vertex's own treatment set to 0.

    The treatment passed to `simulate_y` is the mean of the (all-zero)
    individual treatments over each vertex's neighbours.

    Parameters
    ----------
    data_dir: directory passed to `load_data_wikipedia_processed`.
    covariate: confounding covariate; only 'unique_category' is accepted.
    beta0, beta1, gamma: forwarded to `simulate_y`.
    set_seed: seed for NumPy's global RNG (also forwarded to `simulate_y`).

    Returns
    -------
    t, y, y0, y1: float32 arrays (treatment and the three `simulate_y` outputs).

    Raises
    ------
    Exception: if `covariate` is not 'unique_category'.
    """
    graph_data, profiles = load_data_wikipedia_processed(data_dir)
    np.random.seed(set_seed)

    # Rank each category among the sorted distinct values, then shift by one,
    # so the codes start at -1 (i.e. -1, 0, 1, ...), not at 0.
    categories = profiles['unique_category'].tolist()
    category_codes = np.searchsorted(np.unique(categories), categories) - 1.

    if covariate != 'unique_category':
        raise Exception("covariate name not recognized")
    confounder = category_codes

    # Propensities still enter the outcome model as the confounding term.
    propensities = 0.5 + 0.35 * confounder
    own_treatment = np.zeros(shape=len(propensities), dtype=np.float32)

    # Average the individual treatments over each vertex's neighbours (self excluded).
    neighbour_mean = np.array(
        [np.mean(own_treatment[graph_data.adjacency_list.get_neighbours(v)], dtype=np.float32)
         for v in range(len(own_treatment))],
        dtype=np.float32)

    y, y0, y1 = simulate_y(propensities, neighbour_mean, beta0=beta0, beta1=beta1, gamma=gamma, set_seed=set_seed)

    return (neighbour_mean.astype(np.float32),
            y.astype(np.float32),
            y0.astype(np.float32),
            y1.astype(np.float32))


def simulate_from_wikipedia_covariate_treatment_all1(data_dir, covariate='unique_category', beta0=1.0, beta1=1.0, gamma=1.0,
                                                 set_seed=42):
    """Simulate outcomes on the Wikipedia graph with every vertex's own treatment set to 1.

    The treatment passed to `simulate_y` is the mean of the (all-one)
    individual treatments over each vertex's neighbours.

    Parameters
    ----------
    data_dir: directory passed to `load_data_wikipedia_processed`.
    covariate: confounding covariate; only 'unique_category' is accepted.
    beta0, beta1, gamma: forwarded to `simulate_y`.
    set_seed: seed for NumPy's global RNG (also forwarded to `simulate_y`).

    Returns
    -------
    t, y, y0, y1: float32 arrays (treatment and the three `simulate_y` outputs).

    Raises
    ------
    Exception: if `covariate` is not 'unique_category'.
    """
    graph_data, profiles = load_data_wikipedia_processed(data_dir)

    np.random.seed(set_seed)

    # Rank each category among the sorted distinct values, then shift by one,
    # so the codes start at -1 (i.e. -1, 0, 1, ...), not at 0.
    categories = profiles['unique_category'].tolist()
    category_codes = np.searchsorted(np.unique(categories), categories) - 1.

    if covariate != 'unique_category':
        raise Exception("covariate name not recognized")
    confounder = category_codes

    # Propensities still enter the outcome model as the confounding term.
    propensities = 0.5 + 0.35 * confounder
    own_treatment = np.ones(shape=len(propensities), dtype=np.float32)

    # Average the individual treatments over each vertex's neighbours (self excluded).
    neighbour_mean = np.array(
        [np.mean(own_treatment[graph_data.adjacency_list.get_neighbours(v)], dtype=np.float32)
         for v in range(len(own_treatment))],
        dtype=np.float32)

    y, y0, y1 = simulate_y(propensities, neighbour_mean, beta0=beta0, beta1=beta1, gamma=gamma, set_seed=set_seed)

    return (neighbour_mean.astype(np.float32),
            y.astype(np.float32),
            y0.astype(np.float32),
            y1.astype(np.float32))


# def simulate_exogeneity_experiment(base_propensity_scores, exogeneous_con=0.,
#                                    beta0=1.0, beta1=1.0, gamma=1.0):
#     extra_confounding = np.random.normal(0, 1, base_propensity_scores.shape[0]).astype(np.float32)
#
#     propensities = expit((1. - exogeneous_con) * logit(base_propensity_scores) +
#                          exogeneous_con * extra_confounding).astype(np.float32)
#
#     treatment = np.random.binomial(1, propensities)
#     y, y0, y1 = simulate_y(propensities, treatment, beta0=beta0, beta1=beta1, gamma=gamma)
#
#     t = treatment.astype(np.int32)
#     y = y.astype(np.float32)
#     y0 = y0.astype(np.float32)
#     y1 = y1.astype(np.float32)
#
#     return t, y, y0, y1, propensities


def simulate_from_wikipedia_covariate_treatment_label(data_dir, covariate='unique_category', set_seed=2):
    """Simulate a setting where the outcome is each vertex's own treatment label.

    Each vertex draws a binary treatment from a category-driven propensity;
    that draw is returned as the outcome. Half of the treated vertices are then
    reset to 0, and the returned treatment is the mean of these partially
    hidden labels over each vertex's neighbours.

    The hidden half is sampled with NumPy's global RNG, so it is reproducible
    for a given `set_seed`. (The previous code called `random.sample` without
    importing `random`, which raised NameError.)

    Parameters
    ----------
    data_dir: directory passed to `load_data_wikipedia_processed`.
    covariate: confounding covariate; only 'unique_category' is accepted.
    set_seed: seed applied to NumPy's global RNG before any draw.

    Returns
    -------
    t: float32 array, neighbour mean of the partially hidden treatment labels.
    y: float32 array, the full (unhidden) treatment labels.

    Raises
    ------
    Exception: if `covariate` is not 'unique_category'.
    """
    graph_data, profiles = load_data_wikipedia_processed(data_dir)

    # Rank each category among the sorted distinct values, then shift by one,
    # so the codes start at -1 (i.e. -1, 0, 1, ...), not at 0.
    unique_category = profiles['unique_category'].tolist()
    unique_category = np.searchsorted(np.unique(unique_category), unique_category) - 1.

    if covariate == 'unique_category':
        confounder = unique_category
    else:
        raise Exception("covariate name not recognized")

    # simulate treatments and outcomes
    np.random.seed(set_seed)
    propensities = 0.5 + 0.35 * confounder
    treatment = np.random.binomial(1, propensities)
    y = treatment

    # Hide (set to 0) a random half of the treated vertices, without replacement.
    treatment_new = treatment.copy()
    treated = np.flatnonzero(treatment_new == 1)
    hidden = np.random.choice(treated, size=treated.shape[0] // 2, replace=False)
    treatment_new[hidden] = 0

    # A vertex's treatment is the mean visible label over its neighbours (self excluded).
    treatment_agg = np.empty(shape=(len(treatment_new)), dtype=np.float32)
    for i in range(len(treatment_new)):
        neighbours = graph_data.adjacency_list.get_neighbours(i)
        treatment_agg[i] = np.mean(treatment_new[neighbours], dtype=np.float32)

    t = treatment_agg.astype(np.float32)
    y = y.astype(np.float32)

    return t, y


def simulate_from_wikipedia_covariate_treatment_all0_treatment_label(data_dir, covariate='unique_category'):
    """Treatment-label setting with every vertex's own treatment fixed to 0.

    Parameters
    ----------
    data_dir: directory passed to `load_data_wikipedia_processed`.
    covariate: confounding covariate; only 'unique_category' is accepted.

    Returns
    -------
    t: float32 array, mean of the all-zero treatments over each vertex's neighbours.
    y: float32 array, the all-zero individual treatments used as the outcome.

    Raises
    ------
    Exception: if `covariate` is not 'unique_category'.
    """
    graph_data, profiles = load_data_wikipedia_processed(data_dir)

    # Rank each category among the sorted distinct values, then shift by one,
    # so the codes start at -1 (i.e. -1, 0, 1, ...), not at 0.
    categories = profiles['unique_category'].tolist()
    category_codes = np.searchsorted(np.unique(categories), categories) - 1.

    if covariate != 'unique_category':
        raise Exception("covariate name not recognized")
    confounder = category_codes

    # The propensities are only used here for their length (one entry per vertex).
    propensities = 0.5 + 0.35 * confounder
    own_treatment = np.zeros(shape=len(propensities), dtype=np.float32)

    # Average the individual treatments over each vertex's neighbours (self excluded).
    neighbour_mean = np.array(
        [np.mean(own_treatment[graph_data.adjacency_list.get_neighbours(v)], dtype=np.float32)
         for v in range(len(own_treatment))],
        dtype=np.float32)

    return neighbour_mean.astype(np.float32), own_treatment.astype(np.float32)


def simulate_from_wikipedia_covariate_treatment_all1_treatment_label(data_dir, covariate='unique_category'):
    """Treatment-label setting with every vertex's own treatment fixed to 1.

    Parameters
    ----------
    data_dir: directory passed to `load_data_wikipedia_processed`.
    covariate: confounding covariate; only 'unique_category' is accepted.

    Returns
    -------
    t: float32 array, mean of the all-one treatments over each vertex's neighbours.
    y: float32 array, the all-one individual treatments used as the outcome.

    Raises
    ------
    Exception: if `covariate` is not 'unique_category'.
    """
    graph_data, profiles = load_data_wikipedia_processed(data_dir)

    # Rank each category among the sorted distinct values, then shift by one,
    # so the codes start at -1 (i.e. -1, 0, 1, ...), not at 0.
    categories = profiles['unique_category'].tolist()
    category_codes = np.searchsorted(np.unique(categories), categories) - 1.

    if covariate != 'unique_category':
        raise Exception("covariate name not recognized")
    confounder = category_codes

    # The propensities are only used here for their length (one entry per vertex).
    propensities = 0.5 + 0.35 * confounder
    own_treatment = np.ones(shape=len(propensities), dtype=np.float32)

    # Average the individual treatments over each vertex's neighbours (self excluded).
    neighbour_mean = np.array(
        [np.mean(own_treatment[graph_data.adjacency_list.get_neighbours(v)], dtype=np.float32)
         for v in range(len(own_treatment))],
        dtype=np.float32)

    return neighbour_mean.astype(np.float32), own_treatment.astype(np.float32)


def simulate_from_wikipedia_covariate_binary_region(data_dir, covariate='unique_category', set_seed=2):
    """Simulate neighbour-averaged treatments and return them with the confounder itself.

    Each vertex draws a binary treatment from a category-driven propensity.
    Half of the treated vertices are then reset to 0, and the returned
    treatment is the mean of these partially hidden labels over each vertex's
    neighbours. The second return value is the category code, not an outcome.

    The hidden half is sampled with NumPy's global RNG, so it is reproducible
    for a given `set_seed`. (The previous code called `random.sample` without
    importing `random`, which raised NameError.)

    Parameters
    ----------
    data_dir: directory passed to `load_data_wikipedia_processed`.
    covariate: confounding covariate; only 'unique_category' is accepted.
    set_seed: seed applied to NumPy's global RNG before any draw.

    Returns
    -------
    t: float32 array, neighbour mean of the partially hidden treatment labels.
    confounder: float32 array of category codes (starting at -1).

    Raises
    ------
    Exception: if `covariate` is not 'unique_category'.
    """
    graph_data, profiles = load_data_wikipedia_processed(data_dir)

    # Rank each category among the sorted distinct values, then shift by one,
    # so the codes start at -1 (i.e. -1, 0, 1, ...), not at 0.
    unique_category = profiles['unique_category'].tolist()
    unique_category = np.searchsorted(np.unique(unique_category), unique_category) - 1.

    if covariate == 'unique_category':
        confounder = unique_category
    else:
        raise Exception("covariate name not recognized")

    # simulate treatments
    np.random.seed(set_seed)
    propensities = 0.5 + 0.35 * confounder
    treatment = np.random.binomial(1, propensities)

    # Hide (set to 0) a random half of the treated vertices, without replacement.
    treatment_new = treatment.copy()
    treated = np.flatnonzero(treatment_new == 1)
    hidden = np.random.choice(treated, size=treated.shape[0] // 2, replace=False)
    treatment_new[hidden] = 0

    # A vertex's treatment is the mean visible label over its neighbours (self excluded).
    treatment_agg = np.empty(shape=(len(treatment_new)), dtype=np.float32)
    for i in range(len(treatment_new)):
        neighbours = graph_data.adjacency_list.get_neighbours(i)
        treatment_agg[i] = np.mean(treatment_new[neighbours], dtype=np.float32)

    t = treatment_agg.astype(np.float32)
    confounder = confounder.astype(np.float32)
    return t, confounder





def main():
    """Fit an OLS baseline for the treatment effect on simulated Wikipedia data.

    Simulates treatments and outcomes on the graph in 'dat/wikipedia', pulls
    the labelled vertices through the no-graph input pipeline, joins them with
    the embedding stored in 'groups.txt', and regresses the outcome on the
    treatment plus embedding using a random half of the rows. The treatment
    coefficient and its standard error are printed.
    """
    parser = add_parser_sampling_arguments()
    args = parser.parse_args()
    data_dir = 'dat/wikipedia'
    graph_data, _ = load_data_wikipedia_processed(data_dir)
    t, y, _, _, _ = simulate_from_wikipedia_covariate(data_dir, covariate='unique_category', beta0=1.0, beta1=1,
                                                      gamma=1.0, set_seed=42)

    treatments = t
    outcomes = y
    make_prediction_generator = make_no_graph_input_fn0(graph_data, args, treatments, outcomes,
                                                        filter_test=False)
    prediction_generator = make_prediction_generator()
    # Only the first batch of the prediction dataset is used.
    itr = iter(prediction_generator)
    sample = next(itr)

    sbm_embedding = np.loadtxt('groups.txt')
    sbm_embedding = sbm_embedding[:, 1:]  # drop the first column of embedding
    # Sort rows by the (now first) column, then drop that column as well.
    # NOTE(review): presumably that column is the vertex id, so rows line up
    # with the vertex order of the dataset -- confirm against groups.txt.
    sbm_embedding = sbm_embedding[sbm_embedding[:, 0].argsort()]
    sbm_embedding = sbm_embedding[:, 1:]

    outcomes = sample[1]['outcome']
    treatments = sample[0]['treatment']
    X = np.column_stack([treatments, sbm_embedding])
    Y = tf.squeeze(outcomes).numpy()

    # Fit on a random half of the rows, sampled without replacement.
    n = X.shape[0]
    train_prop = 0.5
    train_idx = npr.choice(np.arange(n), int(train_prop * n), replace=False)
    X_tr = X[train_idx]
    Y_tr = Y[train_idx]
    X_tr_coef = sm.add_constant(X_tr)
    model1 = sm.OLS(Y_tr, X_tr_coef).fit()

    # Index 1 is the treatment column: add_constant puts the intercept first.
    print('sbm_estimate')
    print(model1.params[1])
    print('sbm_std')
    print(model1.bse[1])


if __name__ == "__main__":
    main()
