import jax
import jax.numpy as jnp
import torch

import torchvision.datasets as dset
import torchvision.transforms as transforms
import neural_tangents as nt

import numpy as np
import flax
import flax.linen as nn
import optax as tx
import neural_tangents.stax as stax

# from jax import config
# config.update("jax_enable_x64", True)

import os

# os.environ['CUDA_VISIBLE_DEVICES'] = '1'


from typing import Any, Callable, Sequence, Tuple
from flax.training import train_state, checkpoints

import matplotlib.pyplot as plt
import functools
import operator
import fire

import data
from utils import *
import pickle
import math
import sys
import time


def get_noise_mask(key, shape, p):
    """Sample a random binary keep-mask.

    Args:
        key: JAX PRNG key used for the uniform draw.
        shape: Shape of the mask to generate.
        p: Threshold compared against uniform draws in [0, 1).

    Returns:
        A float32 array of `shape` holding 1.0 wherever the uniform draw
        is strictly greater than `p` and 0.0 elsewhere.
    """
    uniform_draws = jax.random.uniform(key, shape = shape)
    keep = uniform_draws > p
    return keep.astype(jnp.float32)

def mask_noise(images, noise, noise_mask):
    """Blend `images` and `noise` element-wise according to `noise_mask`.

    Where the mask is 1 the image value is kept, where it is 0 the noise
    value is used; fractional mask values interpolate linearly.
    """
    kept_part = images * noise_mask
    noise_part = noise * (1. - noise_mask)
    return kept_part + noise_part

@functools.partial(jax.jit, static_argnames=('kernel_fn', 'use_mse', 'learn_labels'))
def get_distillation_loss(distill_train_state_params, train_images, train_labels, kernel_fn, noise_mask = None, init_noise = None, use_mse = False, learn_labels = False, minimum_dist = -1, lam = 1e-5):
    """Mean squared error of a regularised kernel regressor fit on the distilled set.

    The distilled images/labels act as the support set: a ridge-regularised
    linear system is solved on their kernel matrix and the resulting
    predictor is evaluated on `train_images`.

    Args:
        distill_train_state_params: Mapping with entries 'images',
            'log_temp' and 'labels'.
        train_images: Inputs on which the predictor is evaluated.
        train_labels: Targets compared against the predictions.
        kernel_fn: Callable invoked as `kernel_fn(x1, x2, 'ntk')`; static
            under jit.
        noise_mask: Optional mask; when given, the distilled images are
            blended with `init_noise` through `mask_noise`.
        init_noise: Noise tensor used together with `noise_mask`.
        use_mse: Accepted (static) but not read by this function.
        learn_labels: When False, gradients do not flow into the labels.
        minimum_dist: Accepted but not read by this function.
        lam: Ridge coefficient; the diagonal term is `lam * n_support`.

    Returns:
        `(loss, 0)` -- the scalar MSE and a constant placeholder auxiliary
        value.
    """
    support_images = distill_train_state_params['images']
    raw_log_temp = distill_train_state_params['log_temp']
    support_labels = distill_train_state_params['labels']

    if not learn_labels:
        # Freeze the labels: they still enter the loss but receive no gradient.
        support_labels = jax.lax.stop_gradient(support_labels)

    # The clamped temperature is computed but does not feed into the loss below.
    clamped_log_temp = jnp.minimum(raw_log_temp, 8)

    if noise_mask is not None:
        support_images = mask_noise(support_images, init_noise, noise_mask)

    gram_support = kernel_fn(support_images, support_images, 'ntk')
    gram_cross = kernel_fn(train_images, support_images, 'ntk')

    # Ridge term scaled by the number of support points.
    n_support = gram_support.shape[0]
    gram_support = gram_support + lam * n_support * jnp.eye(n_support)

    coefficients = jnp.linalg.solve(gram_support, support_labels)
    residual = train_labels - gram_cross @ coefficients

    return jnp.mean(residual ** 2), 0

@functools.partial(jax.jit, static_argnames=('kernel_fn', 'use_mse', 'mmd', 'learn_labels'))
def do_training_step_distillation(train_state, train_images, train_labels, kernel_fn, noise_mask = None, init_noise = None, use_mse = False, mmd = False, learn_labels = False, minimum_dist = -1, lam = 1e-5):
    """Run one gradient step on the distilled parameters.

    Differentiates `get_distillation_loss` with respect to
    `train_state.params`, applies the gradients and bumps the state's
    `train_it` counter by one.

    Args:
        train_state: Training state exposing `params`, `train_it` and
            `apply_gradients`.
        train_images, train_labels: Batch the distillation loss is
            evaluated on.
        kernel_fn: Kernel callable forwarded to the loss (static).
        noise_mask, init_noise, use_mse, learn_labels, minimum_dist, lam:
            Forwarded unchanged to `get_distillation_loss`.
        mmd: Accepted (static) but not read by this function.

    Returns:
        `(new_state, (loss, acc))` where `loss` and `acc` are the value and
        auxiliary output of the loss evaluated at the state's parameters
        before the update.
    """
    loss_and_grad = jax.value_and_grad(get_distillation_loss, argnums = 0, has_aux = True)
    forwarded = dict(
        noise_mask = noise_mask,
        init_noise = init_noise,
        use_mse = use_mse,
        learn_labels = learn_labels,
        minimum_dist = minimum_dist,
        lam = lam,
    )
    (loss, acc), grads = loss_and_grad(train_state.params, train_images, train_labels, kernel_fn, **forwarded)

    next_iteration = train_state.train_it + 1
    updated_state = train_state.apply_gradients(grads = grads, train_it = next_iteration)

    return updated_state, (loss, acc)

# NOTE: the previous decorator declared static_argnames that are not
# parameters of this function; recent JAX releases reject that when the
# function is decorated. Neither K, y nor lam needs to be static.
@jax.jit
def get_dk_and_base_loss(K, y, lam = 1e-3):
    """Compute ridge-regression diagnostics for a kernel matrix and targets.

    With `K_reg = K + n * lam * I`, this solves `K_reg @ alpha = y` and
    evaluates the in-sample predictions `K @ alpha`.

    Args:
        K: Square (n, n) kernel matrix.
        y: Targets; the final `[0, 0]` indexing requires `alpha.T @ preds`
            to be two-dimensional, i.e. `y` is expected to be a 2-D column
            of shape (n, 1).
        lam: Ridge coefficient; the diagonal term added to `K` is `n * lam`.

    Returns:
        A 5-tuple `(dk, base_loss, eps, preds, f_norm)`:
        - dk: `trace(K @ inv(K_reg))`.
        - base_loss: mean squared error between `y` and `preds`.
        - eps: largest absolute residual `max |y - preds|`.
        - preds: in-sample predictions `K @ alpha`.
        - f_norm: `sqrt(alpha^T K alpha)`, taken from the [0, 0] entry.
    """
    n = K.shape[0]
    K_reg = K + jnp.eye(n) * n * lam

    M = K @ jnp.linalg.inv(K_reg)
    alpha = jnp.linalg.solve(K_reg, y)
    preds = K @ alpha

    eps = jnp.max(jnp.abs(y - preds))

    # alpha^T K alpha = sum_i sum_j alpha_i alpha_j K(x_i, x_j), the squared
    # norm of f(z) = sum_i alpha_i K(z, x_i); a (1, 1) matrix for column y.
    f_norm_squared = alpha.T @ preds

    return jnp.trace(M), jnp.mean((y - preds) ** 2), eps, preds, jnp.sqrt(f_norm_squared)[0, 0]


def _assemble_output(shared, distilled_images, log_temp, distilled_labels, distillation_loss, n_distilled):
    """Merge the run-invariant fields with one distillation result into a result record."""
    record = dict(shared)
    record.update({
        'distilled_images': distilled_images,
        'log_temp': log_temp,
        'distilled_labels': distilled_labels,
        'distillation_loss': distillation_loss,
        'n_distilled': n_distilled,
        'gap': distillation_loss - shared['base_loss'],
    })
    return record


def _init_distilled_params(images_key, labels_key, dataset_name, init_real, n_distilled, train_images, train_labels, label_scale):
    """Build the initial trainable pytree ('images', 'log_temp', 'labels') for a distilled set."""
    n_train = train_images.shape[0]

    def random_labels(count):
        return 0.2 * jax.random.normal(labels_key, shape = [count, 1]) / label_scale

    if dataset_name not in ['two_clusters', 'from_kernel']:
        # Random images with the same trailing dimensions as the (4-D) training images.
        images = 0.2 * jax.random.normal(images_key, shape = [n_distilled, train_images.shape[1], train_images.shape[2], train_images.shape[3]])
        labels = random_labels(n_distilled)
    elif not init_real:
        images = jnp.std(train_images) * jax.random.normal(images_key, shape = [n_distilled, 2])
        labels = random_labels(n_distilled)
    elif n_distilled <= n_train:
        # Start from a random subset of real training points.
        indices = jax.random.choice(images_key, n_train, shape = (n_distilled, ), replace = False)
        images = jnp.take(train_images, indices, axis = 0)
        labels = jnp.take(train_labels, indices, axis = 0).astype(jnp.float64)
    else:
        # All real points plus random extras to reach the requested size.
        n_extra = n_distilled - n_train
        images = jnp.concatenate([train_images, jnp.std(train_images) * jax.random.normal(images_key, shape = [n_extra, 2])])
        labels = jnp.concatenate([train_labels, random_labels(n_extra)])

    return {
        'images': images,
        'log_temp': 2.3 * jnp.ones(()),
        'labels': labels,
    }


def _full_dataset_loss(params, train_images, train_labels, kernel_fn, batch_size, **loss_kwargs):
    """Evaluate the distillation loss over the whole training set in chunks of `batch_size`."""
    n_train = train_images.shape[0]
    total = 0
    for low_ind in range(0, n_train, batch_size):
        up_ind = min(low_ind + batch_size, n_train)
        chunk_mean = get_distillation_loss(params, train_images[low_ind: up_ind], train_labels[low_ind: up_ind], kernel_fn, **loss_kwargs)[0]
        # Re-weight each chunk mean by its size so the result is a true per-example mean.
        total += chunk_mean * (up_ind - low_ind)
    return total / n_train


def main(seed = 0, dataset_name = 'mnist_odd_even', output_dir = None, train_set_size = 100, noise_ratio = None, use_mse = False, mmd = False, learn_labels = False, minimum_dist = -1, lam = 1e-5, lengthscale = 0.1, s_factor = 1, scale_mode = 'const', min_size = -1, sigma = 0, find_min = False, kernel = 'rbf', batch_size = 50000, kernel_factor = 1.0, kernel_scale = 'const', init_real = False, n_distilled = None):
    """Distil a dataset by optimising a small support set under a kernel regressor.

    Loads a dataset, computes ridge-regression diagnostics on the full
    kernel matrix, rescales the labels by the resulting function norm and
    then optimises a distilled set with Adam. With `find_min` the distilled
    set size is bisected to find the smallest size whose loss reaches the
    tolerance `8 * lam * f_norm**2 + 2 * eps**2`.

    Args:
        seed: Seed for the JAX PRNG key (also passed to `data.get_dataset`).
        dataset_name: Dataset identifier passed to `data.get_dataset`;
            'two_clusters' and 'from_kernel' use the 2-column initialisation.
        output_dir: If given, `config.txt` and `distillation_result.pkl`
            are written to `./<output_dir>/`.
        train_set_size: Number of training points requested from the dataset.
        noise_ratio: If not None, threshold `p` for `get_noise_mask`; the
            distilled images are blended with fixed noise.
        use_mse, mmd, minimum_dist: Forwarded to the training step.
        learn_labels: Whether the distilled labels receive gradients.
        lam: Base ridge coefficient (see `scale_mode`).
        lengthscale: Kernel lengthscale.
        s_factor: Multiplier on `dk * log(dk)` for the default distilled size.
        scale_mode: 'const', 'sqrt' or 'inv' -- divides `lam` by 1,
            sqrt(train_set_size) or train_set_size; other values leave
            `lam` unchanged.
        min_size: If not -1, the distilled size is set to `min_size`
            whenever the default size is smaller than it.
        sigma: Forwarded to `data.get_dataset`.
        find_min: Run the bisection search over distilled set sizes.
        kernel: One of 'rbf', 'matern12', 'matern32', 'matern52'.
        batch_size: Training mini-batch size; also the chunk size for
            full-dataset loss evaluation.
        kernel_factor, kernel_scale: Scale applied to the RBF kernel
            ('const', 'sqrt' or 'inv' in `train_set_size`).
        init_real: For the 2-D datasets, initialise from real points.
        n_distilled: Explicit distilled set size; defaults to
            `round(s_factor * dk * log(dk))`.
    """
    # Snapshot the arguments before any other local is bound so that the
    # config dump contains exactly the call parameters.
    run_config = repr(locals())

    if output_dir is not None:
        os.makedirs('./{}'.format(output_dir), exist_ok = True)

        with open('./{}/config.txt'.format(output_dir), 'a') as config_file:
            # The file is opened in append mode; end each entry with a newline
            # so that repeated runs do not run together.
            config_file.write(run_config + '\n')

    key = jax.random.PRNGKey(seed)

    base_lam = lam

    # 'const' (and any unrecognised mode) leaves lam at its base value.
    if scale_mode == 'sqrt':
        lam = base_lam/np.sqrt(train_set_size)
    elif scale_mode == 'inv':
        lam = base_lam/train_set_size

    if kernel == 'rbf':
        kernel_pre_scale = 1
        if kernel_scale == 'sqrt':
            kernel_pre_scale = np.sqrt(train_set_size)
        elif kernel_scale == 'inv':
            kernel_pre_scale = train_set_size

        kernel_fn = bind(get_kernel_gaussian, lengthscale = lengthscale, k_scale = kernel_pre_scale * kernel_factor)
    elif kernel == 'matern12':
        kernel_fn = bind(get_kernel_matern_12, lengthscale = lengthscale)
    elif kernel == 'matern32':
        kernel_fn = bind(get_kernel_matern_32, lengthscale = lengthscale)
    elif kernel == 'matern52':
        kernel_fn = bind(get_kernel_matern_52, lengthscale = lengthscale)
    else:
        print(f'Invalid kernel choice: {kernel}')
        # Signal failure to the calling shell instead of exiting with status 0.
        sys.exit(1)

    # NOTE(review): the dataset receives its own PRNGKey(seed), identical to
    # `key` above; kept as-is so the dataset drawn for a seed is unchanged.
    train_images, train_labels, train_mean = data.get_dataset(dataset_name, jax.random.PRNGKey(seed), train_set_size, sigma = sigma, kernel_fn = kernel_fn)
    n_train = train_images.shape[0]

    K_train = kernel_fn(train_images.astype(jnp.float64), train_images.astype(jnp.float64))

    dk, base_loss, eps, preds, f_norm = get_dk_and_base_loss(K_train, train_labels, lam = lam)

    zero_error = jnp.mean(train_labels**2)

    # Normalise the labels by the function norm of the full-data regressor.
    label_scale = f_norm
    train_labels = train_labels/label_scale

    print(f'f_norm before scaling: {f_norm}')
    print(f"base_loss before scaling: {base_loss}")

    # Recompute the diagnostics on the rescaled labels.
    dk, base_loss, eps, preds, f_norm = get_dk_and_base_loss(K_train, train_labels, lam = lam)

    print(f"eps: {eps}")

    eps_star = jnp.power((4 * lam * f_norm **2 + base_loss)/(16 * lam * f_norm **2), 1/3)

    print(eps_star)

    print((((8/eps_star**2) + 4 * (1 + eps_star)) * lam * f_norm))
    print(lam)
    print(label_scale)

    # upper:        8 * lam * f_norm^2 + 2 * eps^2   (stored as `tol`)
    # better_upper: plug eps_star into the 8 / eps^2 form
    # super:        12 * lam * f_norm^2 + 2 * base_loss
    better_upper_bound = (((8/(eps_star**2)) + 4 * (1 + eps_star)) * lam * f_norm**2) + (1 + eps_star) * base_loss
    super_bound = 12 * lam * f_norm ** 2 + 2 * base_loss

    s = dk * jnp.log(dk)

    n_distilled_base = int(round(s_factor * s))

    if n_distilled is None:
        n_distilled = n_distilled_base

    if min_size != -1 and n_distilled_base < min_size:
        n_distilled = min_size

    tol = 8 * lam * f_norm ** 2 + 2 * eps ** 2

    # Fields that are identical in every result record.
    shared_fields = {
        'train_images': train_images,
        'train_labels': train_labels,
        'train_mean': train_mean,
        'dk': dk,
        's': s,
        'base_loss': base_loss,
        'base_lam': base_lam,
        'lam': lam,
        'n_distilled_base': n_distilled_base,
        'eps': eps,
        'sigma': sigma,
        'f_norm': f_norm,
        'upper_bound': tol,
        'eps_star': eps_star,
        'better_upper_bound': better_upper_bound,
        'super_bound': super_bound,
        'label_scale': label_scale,
        'zero_error': zero_error,
    }

    # Bisection bounds on the distilled set size.
    n_distilled_max = K_train.shape[0]
    n_distilled_min = 2

    if find_min:
        n_distilled = min(int(np.ceil(0.5 * (n_distilled_max + n_distilled_min))), 50)

        print(f'Running search with tolerance {tol}')

        # Fallback result if no candidate size reaches the tolerance: the
        # full training set itself.
        output_dict = _assemble_output(
            shared_fields,
            distilled_images = train_images,
            log_temp = 0,
            distilled_labels = train_labels,
            distillation_loss = base_loss,
            n_distilled = n_distilled,
        )

    opt = tx.chain(tx.adam(learning_rate=0.002))
    max_iters = 20000

    while n_distilled_max - n_distilled_min > 1:
        print(f'n_train: {train_set_size}, dk: {dk}, s: {s}, n_distilled_base: {n_distilled_base}, n_distilled_actual: {n_distilled}, lam: {lam}, f_norm: {f_norm}, upper_bound: {tol}, better_upper_bound: {better_upper_bound}, super_bound: {super_bound}, base_loss: {base_loss}')

        print(f'relative crap: upper_bound: {tol * label_scale**2/zero_error}, better_upper_bound: {better_upper_bound * label_scale**2/zero_error}, super_bound: {super_bound * label_scale**2/zero_error}')

        # Use an independent key for every random draw; reusing one key for
        # both images and labels would correlate the two initialisations.
        key, images_key, labels_key, noise_key, mask_key = jax.random.split(key, 5)

        init_images = _init_distilled_params(images_key, labels_key, dataset_name, init_real, n_distilled, train_images, train_labels, label_scale)

        if noise_ratio is not None:
            init_noise = jax.random.normal(noise_key, shape = init_images['images'].shape)
            noise_mask = get_noise_mask(mask_key, shape = init_images['images'].shape, p = noise_ratio)
        else:
            init_noise = None
            noise_mask = None

        loss_kwargs = dict(noise_mask = noise_mask, init_noise = init_noise, use_mse = use_mse, learn_labels = learn_labels, minimum_dist = minimum_dist, lam = lam)

        distill_train_state = TrainStateWithBatchStats.create(apply_fn = None, params = init_images, tx = opt, batch_stats = None, train_it = 0, base_params = None)

        min_loss = np.inf
        best_state = distill_train_state
        for i in range(max_iters + 1):
            if n_train > batch_size:
                key, batch_key = jax.random.split(key)
                indices = jax.random.choice(batch_key, n_train, (batch_size,), replace = False)
                train_images_batch = train_images[indices]
                train_labels_batch = train_labels[indices]
            else:
                train_images_batch = train_images
                train_labels_batch = train_labels
            distill_train_state, (loss, acc) = do_training_step_distillation(distill_train_state, train_images_batch, train_labels_batch, kernel_fn, mmd = mmd, **loss_kwargs)

            if n_train <= batch_size and loss < min_loss:
                best_state = distill_train_state
                min_loss = loss
            elif i%100 == 0:
                # Periodic evaluation on the full dataset, normalised by the
                # dataset size (not the mini-batch size).
                loss = _full_dataset_loss(distill_train_state.params, train_images, train_labels, kernel_fn, batch_size, **loss_kwargs)

                if loss < min_loss:
                    best_state = distill_train_state
                    min_loss = loss

            if find_min and min_loss < tol:
                break

            if i % 10000 == 0:
                print(f'iter: {i}, current distill loss: {loss}, minimum distill loss: {min_loss}, gap: {loss - base_loss}, base_loss: {base_loss}')

        print(f'using checkpoint with min loss of {min_loss}')

        distill_train_state = best_state

        distilled_images = distill_train_state.params['images']
        distilled_labels = distill_train_state.params['labels']

        if noise_ratio is not None:
            distilled_images = mask_noise(distill_train_state.params['images'], init_noise, noise_mask)

        if find_min:
            if min_loss <= tol:
                # This size suffices: tighten the upper bound and keep the result.
                n_distilled_max = n_distilled
                n_distilled = int(np.ceil(0.5 * (n_distilled_max + n_distilled_min)))

                # NOTE(review): 'n_distilled' records the NEXT candidate size,
                # not the size of the stored distilled set (that size is
                # n_distilled_max at this point) -- kept for compatibility
                # with existing result files; confirm the intent.
                output_dict = _assemble_output(
                    shared_fields,
                    distilled_images = distilled_images,
                    log_temp = distill_train_state.params['log_temp'],
                    distilled_labels = distilled_labels,
                    distillation_loss = min_loss,
                    n_distilled = n_distilled,
                )
            else:
                # This size is too small: raise the lower bound.
                n_distilled_min = n_distilled
                n_distilled = int(np.ceil(0.5 * (n_distilled_max + n_distilled_min)))

            # Never re-test the current upper bound.
            if n_distilled == n_distilled_max:
                n_distilled -= 1

        else:
            # Single run: collapse the bounds so the loop terminates.
            n_distilled_max = n_distilled_min + 1

            output_dict = _assemble_output(
                shared_fields,
                distilled_images = distilled_images,
                log_temp = distill_train_state.params['log_temp'],
                distilled_labels = distilled_labels,
                distillation_loss = min_loss,
                n_distilled = n_distilled,
            )

    if find_min:
        print(f'max: {n_distilled_max}, min: {n_distilled_min}')
        output_dict['n_distilled_max'] = n_distilled_max
        output_dict['n_distilled_min'] = n_distilled_min

    if output_dir is not None:
        # Use a context manager so the result file is flushed and closed.
        with open('./{}/distillation_result.pkl'.format(output_dir), 'wb') as result_file:
            pickle.dump(output_dict, result_file)

    print('done')

# Script entry point: expose `main` and its keyword arguments on the command line.
if __name__ == "__main__":
    fire.Fire(main)