import argparse_utils
import numpy as onp
import jax
import jax.random as random
import jax.numpy as np
import neural_tangents as nt
from data import get_datasets
from neural_tangents import stax
from sklearn.metrics import roc_auc_score, roc_curve, average_precision_score
from jax.example_libraries import optimizers
from jax import jit, grad, vmap
import os
import copy
from pathlib import Path
import math
from neural_tangents import stax
import torch.nn.functional as F
from pathlib import Path
import os
import glob
import torch
from torch import nn as nn
import math
import copy
import torch.optim as optim
import functools

def get_argparse():
    """Build the command-line argument parser for this script.

    Only declares the options; nothing is parsed here. The caller is expected
    to invoke the parsing method on the returned object.

    Several options (``--hidden_widths``, ``--hidden_depths``, ``--threshold``)
    are declared as strings; presumably they hold comma-separated lists that
    the caller splits -- TODO confirm against the call site.

    Returns:
        The configured ``argparse_utils.CustomArgumentParser`` instance.
    """
    parser = argparse_utils.CustomArgumentParser()
    # Output location and device selection.
    parser.add_argument('--out_dir', type=str, default='out')
    parser.add_argument('--no_cuda', default=False, type=argparse_utils.str2bool)

    # Dataset selection and size.
    parser.add_argument('--dataset', type=str, default='mnist', help='mnist, concentric', metavar="data")
    parser.add_argument('--num_train_data', type=int, default=1000, metavar="train")

    # parser.add_argument('--last_layer_only', default=False, type=argparse_utils.str2bool, metavar=">
    # NOTE(review): -1 looks like a "use everything" sentinel -- confirm with the data loader.
    parser.add_argument('--num_test_data', type=int, default=-1, metavar="test")
    parser.add_argument('--num_ensemble', type=int, default=30)
    parser.add_argument('--first_specialist', type=int, default=0)

    parser.add_argument('--binary', default=False, type=argparse_utils.str2bool, metavar="binary")

    #    parser.add_argument('--foo', type=int, default=30, metavar="ens")
    parser.add_argument('--compute_gd_ntk_change', default=False, type=argparse_utils.str2bool)
    parser.add_argument('--no_lin', default=False, type=argparse_utils.str2bool)

    # GD setting: only ones with metavar!
    parser.add_argument('--loss', type=str, default="mse", metavar="loss")

    # Architecture
    parser.add_argument('--hidden_widths', type=str, default="1024")
    parser.add_argument('--hidden_depths', type=str, default="2")
    parser.add_argument('--net', type=str, default="mlp", metavar="net")
    parser.add_argument('--activation', type=str, default="relu", metavar="act")
    parser.add_argument('--bias', default=1, type=float, metavar="bias")
    parser.add_argument('--lr', default=0.1, type=float, metavar="lr")
    parser.add_argument('--threshold', default="0.1,0.03,0.01", type=str)

    # Optimizer settings.
    parser.add_argument('--momentum', default=0.9, type=float, metavar="momentum")
    parser.add_argument('--batch_sgd', default=-1, type=int, metavar="sgdbatch")

    # WRN CONFIG
    parser.add_argument('--batch', default=-1, type=int)

    return parser

def append(dict, key, value):
    """Append ``value`` to the list stored under ``key`` in ``dict``.

    A new empty list is created first when ``key`` is not present yet. The
    mapping is modified in place and nothing is returned.
    """
    # setdefault only inserts the empty list when the key is missing.
    dict.setdefault(key, []).append(value)

def load_from_cache_or_compute(dir_name, file_prefix, fct, *args, **kwargs):
    """Return the cached result of ``fct`` or compute and cache it.

    The cache file is ``<dir_name>/<file_prefix>.pt``. When it exists, its
    content is loaded with ``torch.load`` and returned without calling
    ``fct``. Otherwise ``dir_name`` is created if needed, ``fct(*args,
    **kwargs)`` is evaluated, the result is written with ``torch.save`` and
    then returned.

    Args:
        dir_name: Directory that holds the cache file.
        file_prefix: File name without the ``.pt`` extension.
        fct: Callable producing the value on a cache miss.
        *args, **kwargs: Forwarded to ``fct``.

    Returns:
        The loaded or freshly computed result.
    """
    file_name = os.path.join(dir_name, file_prefix + ".pt")
    # Use a plain existence check: a glob lookup would treat "[", "]", "*" and
    # "?" in the path as pattern characters and miss an existing cache file.
    if os.path.exists(file_name):
        print("LOAD")
        # NOTE: torch.load unpickles the file; only point this at cache
        # directories whose content is trusted.
        return torch.load(file_name)

    Path(dir_name).mkdir(parents=True, exist_ok=True)
    result = fct(*args, **kwargs)
    torch.save(result, file_name)
    return result

def acc_fn(pred, y):
    """Fraction of rows where the arg-max of ``pred`` equals the arg-max of ``y``.

    Both inputs are reduced along axis 1; the count of matches is divided by
    ``y.shape[0]``.
    """
    predicted_idx = np.argmax(pred, axis=1)
    target_idx = np.argmax(y, axis=1)
    num_matches = np.sum(predicted_idx == target_idx)
    return num_matches / y.shape[0]

def get_loss(loss_id):
    """Return the loss function registered under ``loss_id``.

    Supported identifiers:
        ``'mse'``: ``0.5 * mean((pred - y) ** 2)``.

    Args:
        loss_id: Name of the loss.

    Returns:
        A callable ``loss(pred, y)``.

    Raises:
        ValueError: If ``loss_id`` is not a supported identifier.
    """
    if loss_id == 'mse':
        return lambda pred, y: 0.5 * np.mean((pred - y) ** 2)
    # Fail with a clear message instead of an UnboundLocalError on `loss`.
    raise ValueError("Unknown loss '{}'; supported losses: 'mse'".format(loss_id))

def get_covariance(m1_, m2_):
    """Row-wise covariance between two rank-3 arrays, summed over the last axis.

    For every index ``a`` along axis 0 and every index ``i`` along the last
    axis, both inputs are centred along axis 1 and the mean of the product of
    the centred values over axis 1 is taken (division by ``shape[1]``, i.e.
    the biased estimator). The contributions of all ``i`` are summed.

    Args:
        m1_: Array indexed as ``[a, b, i]``.
        m2_: Array with the same shape as ``m1_``.
            (presumably axis 0 = data points, axis 1 = ensemble members,
            axis 2 = outputs -- TODO confirm against the caller)

    Returns:
        Array of shape ``(m1_.shape[0],)``.
    """
    centered_1 = m1_ - m1_.mean(axis=1, keepdims=True)
    centered_2 = m2_ - m2_.mean(axis=1, keepdims=True)
    # Only the diagonal of `m1 @ m2.T` is needed, i.e. the row-wise sum of
    # products; computing it directly avoids building an (N, N) matrix.
    return jax.numpy.sum(centered_1 * centered_2, axis=(1, 2)) / m1_.shape[1]


def compute_auroc(in_dis, out_dis):
    """AUROC for telling ``out_dis`` scores (label 1) apart from ``in_dis`` scores (label 0).

    Both inputs must be one-dimensional; their shapes are reported in the
    assertion message otherwise. The labels and the concatenated scores are
    handed to ``roc_auc_score``.
    """
    both_one_dim = len(in_dis.shape) == 1 and len(out_dis.shape) == 1
    assert both_one_dim, "{} {}".format(in_dis.shape, out_dis.shape)

    num_in = in_dis.shape[0]
    num_out = out_dis.shape[0]
    labels = np.concatenate([np.zeros(num_in), np.ones(num_out)]).reshape(-1)
    scores = np.concatenate([in_dis, out_dis]).reshape(-1)
    return roc_auc_score(labels, scores)

def get_activation(activation):
    """Return the activation layer and matching weight std for ``activation``.

    Supported identifiers:
        ``"relu"``: ``stax.Relu()``.
        ``"softplus"``: element-wise ``(1/5) * log(1 + exp(5 * x))``.
    Both are paired with a weight standard deviation of ``sqrt(2)``.

    Args:
        activation: Name of the activation.

    Returns:
        Tuple ``(activation_fn, w_std)``.

    Raises:
        ValueError: If ``activation`` is not a supported identifier.
    """
    if activation == "relu":
        return stax.Relu(), math.sqrt(2)
    if activation == 'softplus':
        # logaddexp(5x, 0) == log(exp(5x) + 1), evaluated in a stable way.
        softplus = nt.stax.Elementwise(fn=lambda x: 1/5*np.logaddexp(5*x, 0))
        return softplus, math.sqrt(2)
    # Fail with a clear message instead of an UnboundLocalError on return.
    raise ValueError(
        "Unknown activation '{}'; supported: 'relu', 'softplus'".format(activation))

def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
    """Residual block: two LayerNorm -> Relu -> 3x3 Conv stages plus a shortcut.

    Args:
        channels: Number of output channels of every convolution.
        strides: Strides of the first main-branch convolution and of the
            shortcut convolution (when one is used).
        channel_mismatch: When true the shortcut is a 3x3 convolution,
            otherwise it is the identity.

    Returns:
        The ``stax`` layer summing the main branch and the shortcut.
    """
    conv_std = math.sqrt(2)

    main_branch = stax.serial(
        stax.LayerNorm(),
        stax.Relu(),
        stax.Conv(channels, (3, 3), strides, padding='SAME', W_std=conv_std),
        stax.LayerNorm(),
        stax.Relu(),
        stax.Conv(channels, (3, 3), padding='SAME', W_std=conv_std),
    )

    if channel_mismatch:
        skip_branch = stax.Conv(channels, (3, 3), strides, padding='SAME', W_std=conv_std)
    else:
        skip_branch = stax.Identity()

    # Duplicate the input, run both branches, then add their outputs.
    return stax.serial(
        stax.FanOut(2),
        stax.parallel(main_branch, skip_branch),
        stax.FanInSum(),
    )

def WideResnetGroup(n, channels, strides=(1, 1)):
    """Chain of residual blocks sharing one channel count.

    The first block uses ``strides`` and a convolutional shortcut; it is
    followed by ``n - 1`` blocks with unit strides and an identity shortcut.
    """
    leading_block = WideResnetBlock(channels, strides, channel_mismatch=True)
    trailing_blocks = [WideResnetBlock(channels, (1, 1)) for _ in range(n - 1)]
    return stax.serial(leading_block, *trailing_blocks)


def get_wrn(out, width, depth, bias=None, activation='relu', data=None):
    """Build a wide-residual-network ``stax`` model.

    Args:
        out: Width of the final dense layer.
        width: Multiplier applied to the base channel counts 16, 32 and 64.
        depth: Number of residual blocks per group.
        bias: Passed as ``b_std`` to the final dense layer.
        activation: Not used by this builder.
        data: Dataset name; must not be ``None``. ``"cifar10"`` selects an
            8x8 average pool, anything else a 7x7 one.

    Returns:
        The serial ``stax`` layer.
    """
    assert data is not None

    widen = width
    blocks_per_group = depth
    pool_window = (8, 8) if data == "cifar10" else (7, 7)

    layers = [
        stax.Conv(16, (3, 3), padding='SAME', W_std=math.sqrt(2)),
        WideResnetGroup(blocks_per_group, int(16 * widen)),
        WideResnetGroup(blocks_per_group, int(32 * widen), (2, 2)),
        WideResnetGroup(blocks_per_group, int(64 * widen), (2, 2)),
        # The original author marked this final normalisation/activation
        # pair with "???".
        stax.LayerNorm(),
        stax.Relu(),
        stax.AvgPool(pool_window),
        stax.Flatten(),
        stax.Dense(out, 1., b_std=bias),
    ]
    return stax.serial(*layers)

def get_mlp(out, width, depth, bias=None, activation="relu"):
    """Build a fully connected ``stax`` model.

    The input is flattened, followed by ``depth`` pairs of (Dense of size
    ``width``, activation) and a final Dense layer of size ``out``. ``bias``
    is passed as ``b_std`` to every Dense layer; the hidden layers also get
    the weight std returned by ``get_activation``.
    """
    layers = [stax.Flatten()]
    activation_fn, w_std = get_activation(activation)

    for _ in range(depth):
        layers += [stax.Dense(width, b_std=bias, W_std=w_std), activation_fn]

    # The output layer keeps the library's default weight std.
    layers.append(stax.Dense(out, b_std=bias))
    return stax.serial(*layers)

def get_miniconv(out, width, depth, bias=None, activation="relu", stride=1):
    """Build a small convolutional ``stax`` model.

    ``depth`` pairs of (3x3 'SAME' Conv with ``width`` channels, activation)
    are followed by Flatten and a Dense layer of size ``out``.

    Args:
        stride: Either one value used for every convolution, or a list that
            is indexed by layer position.
        bias: Passed as ``b_std`` to every Conv and Dense layer.
    """
    activation_fn, w_std = get_activation(activation)
    stride_per_layer = isinstance(stride, list)

    layers = []
    for layer_idx in range(depth):
        layer_stride = stride[layer_idx] if stride_per_layer else stride
        conv = stax.Conv(
            width,
            (3, 3),
            strides=(layer_stride, layer_stride),
            padding='SAME',
            b_std=bias,
            W_std=w_std,
        )
        layers.append(conv)
        layers.append(activation_fn)

    layers.extend([stax.Flatten(), stax.Dense(out, b_std=bias)])
    return stax.serial(*layers)

def get_miniminiconv(out, width, depth, **kwargs):
    """``get_miniconv`` with strides alternating 2, 1, 2, 1, ... across layers.

    The stride list has one entry per layer: ``depth // 2`` repetitions of
    ``[2, 1]`` plus a trailing ``2`` when ``depth`` is odd. All other keyword
    arguments are forwarded unchanged.
    """
    full_pairs, leftover = divmod(depth, 2)
    alternating_strides = [2, 1] * full_pairs
    if leftover == 1:
        alternating_strides = alternating_strides + [2]
    return get_miniconv(out, width, depth, **kwargs, stride=alternating_strides)
