import argparse

import jax
import jax.numpy as np

from jax import grad, jit, vmap, pmap, value_and_grad
from jax import random

from jax.tree_util import tree_multimap, tree_map
from utils import optimizers
from utils import adaptation_utils
from utils.regularizers import weighted_parameter_loss
import haiku as hk

import numpy as onp

import tensorflow_datasets as tfds
import tensorflow as tf
# disables tf from seeing the GPU, might help with OOM errors on BRC
tf.config.experimental.set_visible_devices([], "GPU")

from jax.config import config

import os
import requests
import pickle

import time

from models.util import get_model

from utils.training_utils import train_epoch
from utils.eval import eval_ds_all, get_logits, get_labels

from utils.losses import nll, accuracy, entropy, brier, ece
from utils.misc import get_single_copy, manual_pmap_tree

from posteriors.utils import sample_weights_diag
from posteriors.swag import init_swag, update_swag, collect_posterior

parser = argparse.ArgumentParser(
    description='Adapts each SWAG ensemble member to a target test set by entropy '
                'minimization and logs ensemble-averaged results'
)
# NOTE: --dir and --log_prefix are not read anywhere in this script.
parser.add_argument(
    "--dir",
    type=str,
    default=None,
    required=False,
    help="Training directory for logging results"
)
parser.add_argument(
    "--log_prefix",
    type=str,
    default=None,
    required=False,
    help="Name prefix for logging results"
)
parser.add_argument(
    "--data_dir",
    type=str,
    default=None,
    required=False,
    help="Directory for storing datasets"
)
parser.add_argument(
    "--seed",
    type=int,
    default=0,
    required=False
)
parser.add_argument(
    "--wd",
    type=float,
    default=5e-4,
    required=False
)
parser.add_argument(
    "--swag_file",
    type=str,
    default=None,
    required=True,
    help="Path format string for the pickled SWAG state of each ensemble member; "
         "'{}' is replaced by the member index (seed)"
)
parser.add_argument(
    "--model",
    type=str,
    default="ResNet26",
    required=False,
    help="Model class"
)
parser.add_argument(
    "--dataset",
    type=str,
    default="mnist",
    required=False,
    help="TFDS name of the target dataset; its test split is used for both "
         "adaptation and evaluation"
)
# NOTE: --corruption_type and --corruption_level are not read anywhere in this script.
parser.add_argument(  # try cifar10-1?
    "--corruption_type",
    type=str,
    default="brightness",
    required=False,
)
parser.add_argument(
    "--corruption_level",
    type=int,
    default=1,
    required=False,
)
parser.add_argument(
    "--ensemble_size",
    type=int,
    default=10,
    required=False,
)
parser.add_argument(
    "--n_epochs",
    type=int,
    default=10,
    required=False,
)
parser.add_argument(
    "--batch_size",
    type=int,
    default=128,
    required=False,
)
parser.add_argument(
    "--lr",
    type=float,
    default=1e-3,
    required=False,
)
parser.add_argument(
    "--adapt_bn_only",
    dest="adapt_bn_only",
    action='store_true'
)
# Optional regularization towards the SWAG posterior, weighted by
# --swag_posterior_weight.
parser.add_argument(
    "--use_swag_posterior",
    dest="use_swag_posterior",
    action='store_true'
)
parser.add_argument(
    "--use_data_augmentation",
    dest="use_data_augmentation",
    action='store_true'
)
parser.add_argument(
    "--swag_posterior_weight",
    type=float,
    default=1e-4,
    required=False,
)

args = parser.parse_args()

### CIFAR10 channel means and stddevs
# Both have shape (3,): one entry per color channel.
channel_means = np.array([0.4914, 0.4822, 0.4465])
channel_stds = np.array([0.2023, 0.1994, 0.2010])

ds = args.dataset

n_classes = 10

# pmap only maps over the devices attached to this host, so the per-host
# device axis of batches and replicated parameters must use the *local*
# device count (identical to jax.device_count() on a single host).
n_devices = jax.local_device_count()

batch_size = args.batch_size
def preprocess_inputs(datapoint):
    """Turn one TFDS example into a normalized 32x32 image and a one-hot label.

    The image is divided by 255 (presumably uint8 pixels -> [0, 1]), resized
    to 32x32 and standardized with the CIFAR-10 ``channel_means`` /
    ``channel_stds``. The integer label becomes a one-hot vector of length
    ``n_classes``.

    Returns:
        ``(image, one_hot_label)``
    """
    scaled = datapoint['image'] / 255
    resized = tf.image.resize(scaled, (32, 32))
    standardized = (resized - channel_means) / channel_stds
    one_hot_label = tf.one_hot(datapoint['label'], n_classes)
    return standardized, one_hot_label

def augment_train_data(image, label):
    """Optionally augment one preprocessed training example.

    With ``--use_data_augmentation`` the image is padded/cropped to 36x36,
    randomly cropped back to 32x32x3 and randomly flipped left-right, and
    0.005 is added to every label entry. Otherwise the pair is returned
    unchanged.
    """
    if not args.use_data_augmentation:
        return image, label
    padded = tf.image.resize_with_crop_or_pad(image, 36, 36)
    cropped = tf.image.random_crop(padded, size=(32, 32, 3))
    flipped = tf.image.random_flip_left_right(cropped)
    # Crude "label smoothing": the constant is added rather than mixed in, so
    # a one-hot target no longer sums to 1.
    return flipped, label + 0.005


AUTOTUNE = tf.data.experimental.AUTOTUNE

# Source-domain data (SVHN train split), used to collect batch-norm running
# statistics for the unadapted baseline.
uncorrupted_ds_train = tfds.load('svhn_cropped', split='train', data_dir=args.data_dir, shuffle_files=True)
ds_uncorrupted = uncorrupted_ds_train.map(preprocess_inputs).cache().batch(batch_size).prefetch(AUTOTUNE)

# Target-domain data. Adaptation is transductive: the "train" pipeline is built
# from the target dataset's *test* split, which is also the evaluated split.
orig_ds_train = tfds.load(args.dataset, split='test', data_dir=args.data_dir, shuffle_files=True)
orig_ds_test = tfds.load(args.dataset, split='test', data_dir=args.data_dir)

# Adaptation batches with shape [n_devices, batch_size, ...] for pmap.
# shuffle() must come *after* cache(): a cached dataset replays its first
# pass verbatim, so shuffling before caching would freeze the epoch order.
ds_train = (
    orig_ds_train
    .map(preprocess_inputs, num_parallel_calls=AUTOTUNE)
    .cache()
    .shuffle(10000, reshuffle_each_iteration=True)
    .map(augment_train_data, num_parallel_calls=AUTOTUNE)
    .batch(batch_size, drop_remainder=True)
    .batch(n_devices, drop_remainder=True)
    .prefetch(AUTOTUNE)
)

ds_test = orig_ds_test.map(preprocess_inputs).cache().batch(batch_size).prefetch(AUTOTUNE)
ds_train_eval = orig_ds_train.map(preprocess_inputs).cache().batch(batch_size).prefetch(AUTOTUNE)

# Options only take effect once attached with with_options(); mutating the
# object returned by Dataset.options() leaves the pipeline unchanged.
options = tf.data.Options()
options.experimental_threading.private_threadpool_size = 48
options.experimental_threading.max_intra_op_parallelism = 1
ds_train = ds_train.with_options(options)

model = get_model(args.model, n_classes)
def _apply_model(params, state, x, is_training):
    """Run ``model.apply`` with no RNG key and the given ``is_training`` flag."""
    return model.apply(params, state, None, x, is_training)


### removes RNG component and runs with is_training=True
@jit
def orig_net_apply(params, state, rng, x):
    """Training-mode forward pass; ``rng`` is accepted for API parity but ignored."""
    return _apply_model(params, state, x, True)


@jit
def orig_net_apply_eval(params, state, x):
    """Forward pass with ``is_training=False``."""
    return _apply_model(params, state, x, False)


@jit
def orig_net_apply_eval_bn(params, state, x):
    """Evaluation-time forward pass that still uses ``is_training=True``.

    Presumably this makes batch norm normalize with the statistics of the
    evaluated batch rather than its running averages.
    """
    return _apply_model(params, state, x, True)

# One copy of the init key per device so that pmap can map over it.
_init_key = random.PRNGKey(0)
rng = np.broadcast_to(_init_key, (n_devices,) + _init_key.shape)


def _init_network(key, images):
    """Haiku ``init`` in training mode; run once per device under pmap."""
    return model.init(key, images, is_training=True)


# Initialize per-device copies of parameters and state from one sample batch.
# Only the network state is kept; the parameters used later come from the
# SWAG checkpoints.
_sample_images = next(iter(tfds.as_numpy(ds_train)))[0]
init_params, init_state = pmap(_init_network)(rng, _sample_images)
del init_params, _sample_images


all_logits = []
# Test-set logits grouped by evaluation stage; collect_logits appends one
# array per ensemble member to each list.
logits_dict = {
    'Initial Logits': [],
    'Initial Batchnorm Adapted Logits': [],
}
logits_dict.update({'Epoch_{} Logits'.format(epoch): [] for epoch in range(args.n_epochs)})

num_epochs = args.n_epochs


def step_size_schedule(i):
    """Constant learning-rate schedule: ``args.lr`` for every step ``i``."""
    del i  # the step size does not depend on the step index
    return args.lr

# Metrics reported by every test-set evaluation in collect_logits.
_EVAL_METRICS = (nll, entropy, accuracy, brier, ece)


def _evaluate(apply_fn, params, state):
    """Compute ``_EVAL_METRICS`` for ``apply_fn(params, state, x)`` on ``ds_test``."""
    return eval_ds_all(tfds.as_numpy(ds_test), params, state, apply_fn, _EVAL_METRICS)


def _test_logits(apply_fn, params, state):
    """Return ``(logits, labels)`` from ``get_logits`` for ``apply_fn`` on ``ds_test``."""
    return get_logits(tfds.as_numpy(ds_test), params, state, apply_fn)


def _bn_only_apply_fns(other_params):
    """Build apply functions whose parameter argument holds only batch-norm params.

    Each function merges its first argument with the frozen ``other_params``
    and forwards to the matching ``orig_net_apply*`` function, keeping that
    function's signature.

    Returns:
        ``(net_apply, net_apply_eval, net_apply_eval_bn)``, all jitted.
    """
    def full_params(bn_p):
        return hk.data_structures.merge(bn_p, other_params)

    net_apply = jit(lambda bn_p, state, rng, x: orig_net_apply(full_params(bn_p), state, rng, x))
    net_apply_eval = jit(lambda bn_p, state, x: orig_net_apply_eval(full_params(bn_p), state, x))
    net_apply_eval_bn = jit(lambda bn_p, state, x: orig_net_apply_eval_bn(full_params(bn_p), state, x))
    return net_apply, net_apply_eval, net_apply_eval_bn


def collect_logits(swag_state):
    """Adapt one SWAG ensemble member to the target test set and record its logits.

    Starting from the SWAG posterior mean, this:
      1. evaluates with batch-norm statistics collected on ``ds_uncorrupted``
         and appends the logits to ``logits_dict['Initial Logits']``;
      2. evaluates with statistics collected on ``ds_test`` and appends to
         ``logits_dict['Initial Batchnorm Adapted Logits']``;
      3. trains for ``num_epochs`` epochs on ``ds_train`` with ``entropy`` as
         the objective (all parameters, or only batch-norm parameters with
         ``--adapt_bn_only``), optionally with a SWAG-posterior regularizer,
         appending test logits after each epoch to
         ``logits_dict['Epoch_<i> Logits']``.

    Args:
        swag_state: unpickled SWAG state; ``collect_posterior`` turns it into
            posterior means and variances.

    Returns:
        None. Results are appended to the module-level ``logits_dict``.
    """
    swag_means, swag_vars = collect_posterior(swag_state)

    # Replicate the posterior mean along a leading device axis, matching
    # init_state and the [n_devices, batch, ...] training batches used by pmap.
    init_params = tree_map(lambda x: np.broadcast_to(x, (n_devices,) + x.shape), swag_means)
    net_state = init_state

    single_params = swag_means
    single_state = get_single_copy(init_state)

    if args.adapt_bn_only:
        # Partition the *replicated* tree. The trainable batch-norm params keep
        # their device axis for pmap, and one copy of the frozen rest is merged
        # back in by the apply functions.
        bn_params, other_params = hk.data_structures.partition(
            lambda m, n, p: 'batchnorm' in m, init_params)
        other_params = get_single_copy(other_params)
        net_apply, net_apply_eval, net_apply_eval_bn = _bn_only_apply_fns(other_params)
        net_params = bn_params
        print("Working with adapt bn only", flush=True)
    else:
        net_params = init_params
        net_apply = orig_net_apply
        net_apply_eval = orig_net_apply_eval
        net_apply_eval_bn = orig_net_apply_eval_bn
        print("Adapting all parameters", flush=True)

    if args.use_swag_posterior:
        print("Using swag posterior")

        def regularizer(params):
            return args.swag_posterior_weight * weighted_parameter_loss(params, swag_means, swag_vars)

        # Evaluate once up front so that a broken regularizer fails before training.
        # NOTE(review): with --adapt_bn_only, training passes only batch-norm
        # params to the regularizer; confirm weighted_parameter_loss handles that.
        regularizer(single_params)
    else:
        regularizer = None

    opt_init, opt_update, get_params = optimizers.momentum(step_size=step_size_schedule, mass=0.9, wd=args.wd)
    opt_state = pmap(opt_init)(net_params)

    # Baseline: batch-norm statistics collected on the source-domain data.
    uncorrupted_state = adaptation_utils.collect_batchnorm_running_stats(
        swag_means, single_state, tfds.as_numpy(ds_uncorrupted), net_apply)
    test_results = _evaluate(net_apply_eval, swag_means, uncorrupted_state)
    logits, _ = _test_logits(net_apply_eval, swag_means, uncorrupted_state)
    print("Initial Logits Results", test_results)
    logits_dict['Initial Logits'].append(onp.array(logits))

    # Batch-norm statistics re-collected on the target test data.
    batchnorm_adapted_state = adaptation_utils.collect_batchnorm_running_stats(
        swag_means, single_state, tfds.as_numpy(ds_test), net_apply)
    test_results = _evaluate(net_apply_eval, swag_means, batchnorm_adapted_state)
    logits, _ = _test_logits(net_apply_eval, swag_means, batchnorm_adapted_state)
    logits_dict["Initial Batchnorm Adapted Logits"].append(onp.array(logits))
    print("Initial Batchnorm Results", test_results)

    # NOTE(review): a single, un-replicated key is passed to the distributed
    # train_epoch (unchanged from before); confirm train_epoch handles it.
    rng = random.PRNGKey(args.seed)
    for epoch in range(num_epochs):
        start = time.time()
        opt_state, net_state, train_loss = train_epoch(epoch,
                                                       opt_state,
                                                       net_state,
                                                       rng,
                                                       tfds.as_numpy(ds_train),
                                                       entropy,
                                                       get_params,
                                                       net_apply,
                                                       opt_update,
                                                       regularizer=regularizer,
                                                       distributed=True)
        print('Epoch {}: {} {}'.format(epoch, train_loss, time.time() - start), flush=True)

        # Evaluate after every epoch on the first device's copy of the
        # replicated parameters and state.
        eval_params, eval_net_state = get_single_copy((get_params(opt_state), net_state))
        test_results = _evaluate(net_apply_eval_bn, eval_params, eval_net_state)
        logits, _ = _test_logits(net_apply_eval_bn, eval_params, eval_net_state)
        logits_dict['Epoch_{} Logits'.format(epoch)].append(onp.array(logits))
        print("Evaluation {}".format(epoch), test_results, time.time() - start)


# Output path for the ensemble-marginal results of this configuration.
bn_only_str = 'adaptbnonly_' if args.adapt_bn_only else ''

filename = 'logs/entropy_minimization/ensemble{}/{}/{}/posteriorweight{}_{}lr{}_batchsize{}/seed{}.pkl'.format(args.ensemble_size, ds, args.model, args.swag_posterior_weight, bn_only_str, args.lr, args.batch_size, args.seed)
# Best-effort report of whether a readable results file already exists at this
# path. Any failure (missing, unreadable, corrupt) is reported the same way.
try:
    with open(filename, 'rb') as f:
        pickle.load(f)
    print(filename, 'file loaded')
except Exception:
    print(filename, 'file not found')


# Adapt every ensemble member in turn; each call appends its logits to logits_dict.
for seed in range(args.ensemble_size):
    # NOTE(review): pickle.load can execute arbitrary code; only point
    # --swag_file at trusted, locally produced SWAG checkpoints.
    with open(args.swag_file.format(seed), 'rb') as f:
        swag_state = pickle.load(f)
    # The checkpoint file is closed before the long adaptation run starts.
    collect_logits(swag_state)
    del swag_state

# Ground-truth labels of the target test set, shared by all ensemble members.
test_labels = get_labels(tfds.as_numpy(ds_test))

def marginal_logits(logits):
    """Log of the ensemble-averaged predictive distribution.

    Args:
        logits: sequence of per-member logit arrays; they are stacked along a
            new leading (member) axis, and the class axis is the last one.

    Returns:
        ``log(mean_k softmax(logits[k]))`` over the member axis, computed as a
        weighted logsumexp of per-member log-probabilities.
    """
    member_log_probs = jax.nn.log_softmax(np.array(logits), axis=-1)
    print(member_log_probs.shape)
    num_members = member_log_probs.shape[0]
    weight = 1 / num_members
    return jax.scipy.special.logsumexp(member_log_probs, axis=0, b=weight)

def _summary_stats(logits):
    """Return ``[nll, entropy, accuracy, brier, ece]`` of ``logits`` against ``test_labels``."""
    return [metric(logits, test_labels) for metric in (nll, entropy, accuracy, brier, ece)]


log_dict = {}

# Ensemble-marginal metrics before any gradient-based adaptation.
for name in ['Initial Logits', 'Initial Batchnorm Adapted Logits']:
    stats = _summary_stats(marginal_logits(logits_dict[name]))
    log_dict[name] = stats
    print(name, stats)

# Ensemble-marginal metrics after each adaptation epoch.
for itr in range(args.n_epochs):
    stats = _summary_stats(marginal_logits(logits_dict['Epoch_{} Logits'.format(itr)]))
    log_dict['Epoch_{} Test'.format(itr)] = stats
    print('Epoch_{} Marginal'.format(itr), stats)

print(filename)
os.makedirs(os.path.dirname(filename), exist_ok=True)
# Context manager guarantees the pickle is flushed and the file closed.
with open(filename, 'wb') as f:
    pickle.dump(log_dict, f)
