import math
import time
import numpy as np
from copy import deepcopy
from multiprocessing.pool import ThreadPool as Pool
from sklearn.gaussian_process.kernels import RBF, Matern
from sklearn.gaussian_process import GaussianProcessRegressor

import optax
import jax
import jax.numpy as jnp

import warnings
warnings.filterwarnings("ignore")


def tuning_mattern(xs, ys, target_xs, target_ys, choice=(0.1, 1, 10, 100), std=0.001, effective_dim=-1):
    """Select the Matern length scale that best predicts held-out targets.

    For every candidate in ``choice`` a Matern (nu=2.5) kernel regression is
    fitted on ``(xs, ys)`` and evaluated at ``target_xs``. The candidate whose
    prediction has the smallest norm of ``prediction - target_ys`` wins.

    Args:
        xs: 2-D array of training inputs, one row per sample.
        ys: Training targets, left-multiplied by the kernel weights.
        target_xs: 2-D array of held-out inputs with the same columns as ``xs``.
        target_ys: Held-out targets the predictions are compared against.
        choice: Indexable sequence of candidate length scales, tried in order.
        std: Value added to the diagonal of the kernel matrix before inversion.
        effective_dim: When positive, only a random subset of at most this
            many input columns is used.

    Returns:
        The element of ``choice`` with the lowest prediction error
        (the first one on ties).
    """
    n, d = xs.shape

    if effective_dim > 0:
        indices = np.random.choice(d, size=min(d, effective_dim), replace=False)
        xs = xs[:, indices]
        target_xs = target_xs[:, indices]

    # Rescale out of place: an in-place `/=` would silently modify the
    # caller's NumPy arrays whenever no column subset was taken above.
    norm = np.sqrt(np.linalg.norm(xs, axis=-1)).mean()
    xs = xs / norm
    target_xs = target_xs / norm

    # The diagonal regulariser does not depend on the candidate length scale.
    jitter = std * np.eye(n)

    loss = []
    for length_scale in choice:
        kernel = Matern(length_scale=length_scale, nu=2.5)

        K_mat_inv = np.linalg.inv(kernel(xs, xs) + jitter)
        k_vec = kernel(target_xs, xs)
        pred_ys = jax.numpy.matmul(jax.numpy.matmul(k_vec, K_mat_inv), ys)

        loss.append(jax.numpy.linalg.norm(pred_ys - target_ys))

    return choice[np.argmin(loss)]

def cosine_similarity(vector1, vector2):
    """Return the cosine similarity of two vectors.

    The denominator is the product of the two Euclidean norms plus ``1e-5``,
    so zero-length inputs yield ``0`` instead of a division by zero.
    """
    denominator = np.linalg.norm(vector1) * np.linalg.norm(vector2) + 1e-5
    return np.dot(vector1, vector2) / denominator

def get_proxy_grad_func(x_values, y_values, kernel=1.0 * RBF(1.0), std=0.01, normalize_x=True, effective_dim=-1):
    """Build a kernel-regression predictor fitted on ``(x_values, y_values)``.

    Args:
        x_values: 2-D array of observed inputs, one row per observation.
        y_values: 2-D array of observed outputs, one row per observation.
        kernel: Callable ``kernel(a, b)`` returning the kernel matrix between
            the rows of ``a`` and ``b``.
        std: Value added to the diagonal of the kernel matrix before inversion.
        normalize_x: When true, inputs are divided by the mean of the square
            roots of the row norms of ``x_values``.
        effective_dim: When positive, only a random subset of at most this
            many input columns is used.

    Returns:
        A function mapping a single input ``x`` to the flattened prediction
        ``kernel(x, x_values) @ inv(K + std * I) @ y_values``. The same column
        subset and scaling chosen here are applied to ``x`` first.
    """
    n, d = x_values.shape

    # To check whether a smaller dimension can also help the optimization.
    if effective_dim > 0:
        indices = np.random.choice(d, size=min(d, effective_dim), replace=False)
        x_values = x_values[:, indices]

    if normalize_x:
        norm = np.sqrt(np.linalg.norm(x_values, axis=-1)).mean()
        # Out-of-place division: `/=` would rescale the caller's NumPy array
        # whenever no column subset was taken above.
        x_values = x_values / norm

    K_mat_inv = np.linalg.inv(kernel(x_values, x_values) + std * np.eye(n))
    # Pre-multiply once so every later prediction is a single kernel row times `ys`.
    ys = jax.numpy.einsum('bi,ij->bj', K_mat_inv, y_values)

    def proxy_grad_func(x):
        # `indices` / `norm` only exist when the matching option is enabled,
        # so they are guarded by the same conditions used above.
        x = x[indices] if effective_dim > 0 else x
        x = x / norm if normalize_x else x
        pred = jax.numpy.einsum('bi,ij->bj', kernel(x.reshape(1, -1), x_values), ys)
        return pred.reshape(-1)

    return proxy_grad_func


def run_standard(func, opt_name, lr, x0, num_iters, num_parall, datas=None, opt_state=None):
    """Run ``num_iters * num_parall`` sequential optimizer steps on ``func``.

    Args:
        func: Objective to differentiate; called as ``func(x)`` or, when
            ``datas`` is given, ``func(x, *datas[i])``.
        opt_name: String evaluated with ``eval`` to obtain an optimizer
            factory that accepts ``learning_rate`` (presumably something like
            ``"optax.adam"`` -- confirm against callers).
        lr: Learning rate passed to the optimizer factory.
        x0: Initial parameters.
        num_iters: Number of outer iterations.
        num_parall: Steps per outer iteration; the total step count is
            ``num_iters * num_parall``.
        datas: Optional sequence with one tuple of extra ``func`` arguments
            per step.
        opt_state: Optional optimizer state to resume from; a fresh state is
            initialised from ``x0`` when omitted.

    Returns:
        ``(x, fx, opt_state)``: the final parameters, the objective value
        computed at the start of the last step (``None`` if no step ran) and
        the final optimizer state.
    """
    fgx_func = jax.jit(jax.value_and_grad(func))
    # NOTE(security): eval executes opt_name as arbitrary Python code; only
    # pass trusted, hard-coded optimizer names.
    global_opt = eval(opt_name)(learning_rate=lr)
    global_opt_state = global_opt.init(x0) if opt_state is None else opt_state

    x = x0
    # Defined up front so a zero-step call returns cleanly instead of raising
    # UnboundLocalError at the return statement.
    fx = None

    for i in range(num_iters * num_parall):
        fx, grad = fgx_func(x) if datas is None else fgx_func(x, *datas[i])
        updates, global_opt_state = global_opt.update(grad, global_opt_state)
        x = optax.apply_updates(x, updates)

    return x, fx, global_opt_state


def run_line_search(func, opt_name, lr, x0, num_iters, num_parall, datas=None, opt_state=None, inter_results=None):
    """Optimise ``func`` by trying several step sizes along the previous gradient.

    In each iteration ``num_parall`` candidates are evaluated. Candidate ``j``
    first moves from the current point using the stored gradient with
    learning rate ``j * lr`` (candidate 0 does not move), then evaluates
    ``func`` and its gradient there and applies one regular optimizer step.
    The candidate with the lowest objective value is kept, and its gradient
    becomes the stored gradient for the next iteration.

    Args:
        func: Objective; called as ``func(x)`` or ``func(x, *data)``.
        opt_name: String evaluated with ``eval`` to obtain an optimizer
            factory that accepts ``learning_rate``.
        lr: Base learning rate.
        x0: Initial parameters.
        num_iters: Number of iterations.
        num_parall: Number of candidates per iteration.
        datas: Optional sequence of extra-argument tuples; iteration ``i``
            uses ``datas[i*num_parall:(i+1)*num_parall]``, one per candidate.
        opt_state: Optional optimizer state to resume from.
        inter_results: Optional dict carrying ``"proxy_grad"`` between calls.
            It is updated in place; a fresh dict is used when omitted.

    Returns:
        ``(x, fx, opt_state)`` of the last selected candidate; ``fx`` is
        ``None`` when ``num_iters`` is 0.
    """
    # A mutable `{}` default would be shared by every call that omits the
    # argument, leaking `proxy_grad` between unrelated optimisation runs.
    if inter_results is None:
        inter_results = {}

    fgx_func = jax.jit(jax.value_and_grad(func))
    # NOTE(security): eval executes opt_name as arbitrary Python code; only
    # pass trusted, hard-coded optimizer names.
    global_opt = eval(opt_name)(learning_rate=lr)
    global_opt_state = global_opt.init(x0) if opt_state is None else opt_state

    def run_update(j, proxy_grad, global_opt_state, x, data=None):
        # Exploratory move: step size j * lr along the stored gradient.
        # (j > 0) zeroes the gradient for the first candidate.
        proxy_opt = eval(opt_name)(learning_rate=j * lr)
        proxy_updates, proxy_opt_state = proxy_opt.update((j > 0) * proxy_grad, global_opt_state)
        proxy_x = optax.apply_updates(x, proxy_updates)

        # Real evaluation and optimizer step at the explored point.
        fx, grad = fgx_func(proxy_x) if data is None else fgx_func(proxy_x, *data)
        updates, proxy_opt_state = global_opt.update(grad, proxy_opt_state)
        x_update = optax.apply_updates(proxy_x, updates)
        return fx, proxy_x, grad, proxy_opt_state, x_update

    x = x0
    fx = None  # stays None when num_iters == 0
    if "proxy_grad" in inter_results:
        proxy_grad = inter_results["proxy_grad"]
    else:
        proxy_grad = jax.numpy.zeros_like(x)

    for i in range(num_iters):
        # `datas` is optional: without it every candidate is evaluated with
        # no extra arguments instead of failing on slicing None.
        if datas is None:
            batch = [None] * num_parall
        else:
            batch = datas[i * num_parall:(i + 1) * num_parall]

        caches = [
            run_update(j, proxy_grad, global_opt_state, x, data)
            for j, data in zip(range(num_parall), batch)
        ]
        idx = np.argmin([c[0] for c in caches]).item()
        fx, _, proxy_grad, global_opt_state, x = caches[idx]

    inter_results["proxy_grad"] = proxy_grad

    return x, fx, global_opt_state


def run_optex(func, opt_name, lr, x0, num_iters, num_parall, datas=None, opt_state=None, effective_dim=-1, inter_results=None):
    """Optimise ``func`` using kernel-estimated gradients to look ahead.

    Each iteration:
      1. fits a gradient estimator (Matern kernel, nu=2.5) on the stored
         history of points and gradients (an all-zero estimator is used while
         the history is empty);
      2. rolls the optimizer forward ``num_parall - 1`` steps with estimated
         gradients, caching every visited point and optimizer state;
      3. evaluates the true value and gradient of ``func`` at every cached
         point and applies one real optimizer step from each;
      4. continues from the result of the last cached point and appends the
         cached points and their true gradients to the history.

    Args:
        func: Objective called as ``func(x, input1, input2)``.
        opt_name: String evaluated with ``eval`` to obtain an optimizer
            factory that accepts ``learning_rate``.
        lr: Learning rate.
        x0: Initial parameters; they are reshaped to a row vector for the
            history, so a flat array is assumed (TODO confirm with callers).
        num_iters: Number of iterations.
        num_parall: Number of cached points evaluated per iteration.
        datas: Sequence of ``(input1, input2)`` pairs. Iteration ``i`` uses
            ``datas[i*num_parall:(i+1)*num_parall]``. Despite the ``None``
            default it must be provided, because it is sliced unconditionally.
        opt_state: Optional optimizer state to resume from.
        effective_dim: Forwarded to ``get_proxy_grad_func``.
        inter_results: Optional dict carrying ``"x_history"``, ``"g_history"``
            and optionally ``"length_scale"`` between calls. The histories are
            written back in place; a fresh dict is used when omitted.

    Returns:
        ``(x, fx, opt_state)`` after the final iteration; ``fx`` is ``None``
        when ``num_iters`` is 0.
    """
    # A mutable `{}` default would be shared by every call that omits the
    # argument, leaking gradient history between unrelated runs.
    if inter_results is None:
        inter_results = {}

    # Upper bound on the number of (point, gradient) pairs carried over.
    max_history = 20

    def run_proxy_update(opt_name, lr, proxy_grad_func, global_opt_state, x):
        proxy_opt = eval(opt_name)(learning_rate=lr)
        # Copy so the look-ahead cannot alias the caller's optimizer state.
        proxy_x, proxy_opt_state = x, deepcopy(global_opt_state)
        proxy_x_cache = [proxy_x]
        proxy_opt_state_cache = [proxy_opt_state]

        for k in range(num_parall - 1):
            proxy_grad = proxy_grad_func(proxy_x)
            proxy_updates, proxy_opt_state = proxy_opt.update(proxy_grad, proxy_opt_state)
            proxy_x = optax.apply_updates(proxy_x, proxy_updates)
            proxy_x_cache.append(proxy_x)
            proxy_opt_state_cache.append(proxy_opt_state)
        return proxy_x_cache, proxy_opt_state_cache

    def run_parallelized_iterations(proxy_x, proxy_opt_state, imgs, labels):
        fx, grad = fgx_func(proxy_x, imgs, labels)
        updates, proxy_opt_state = global_opt.update(grad, proxy_opt_state)
        x_update = optax.apply_updates(proxy_x, updates)
        return fx, grad, proxy_opt_state, x_update

    x = x0
    fx = None  # stays None when num_iters == 0
    fgx_func = jax.value_and_grad(func)
    # NOTE(security): eval executes opt_name as arbitrary Python code; only
    # pass trusted, hard-coded optimizer names.
    global_opt = eval(opt_name)(learning_rate=lr)
    global_opt_state = global_opt.init(x0) if opt_state is None else opt_state

    if "x_history" not in inter_results:
        x_history, g_history = [], []
    else:
        x_history, g_history = inter_results["x_history"], inter_results["g_history"]

    length_scale = inter_results.get("length_scale", 0.1)

    # Number of old entries kept per iteration. A bare `[-20 + num_parall:]`
    # slice stops trimming once num_parall >= 20 (the start index becomes
    # non-negative), so the history would grow without bound.
    keep_old = max(max_history - num_parall, 0)

    for i in range(num_iters):
        if len(x_history) == 0:
            # No observations yet: the look-ahead uses zero gradients.
            proxy_grad_func = lambda z: jnp.zeros_like(x)
        else:
            proxy_grad_func = get_proxy_grad_func(
                jnp.concatenate(x_history, axis=0), jnp.concatenate(g_history, axis=0),
                kernel=1.0 * Matern(length_scale=length_scale, nu=2.5),
                std=1e-3,
                normalize_x=True,
                effective_dim=effective_dim
            )

        proxy_x_cache, proxy_opt_state_cache = run_proxy_update(opt_name, lr, proxy_grad_func, global_opt_state, x)

        # for synthetic, rl, mnist, fashion-mnist task
        batch = datas[i * num_parall:(i + 1) * num_parall]
        input1 = jnp.array([d[0] for d in batch])
        input2 = jnp.array([d[1] for d in batch])
        caches = list(map(run_parallelized_iterations, proxy_x_cache, proxy_opt_state_cache, input1, input2))
        fxs = [c[0] for c in caches]
        gxs = [c[1] for c in caches]
        states = [c[2] for c in caches]
        xs = [c[3] for c in caches]

        # for cifar-10, transformer task
        # input1 = jnp.array([d[0] for d in datas[i*num_parall:(i+1)*num_parall]])
        # input2 = jnp.array([jnp.array([d[1].inputs, d[1].targets]) for d in datas[i*num_parall:(i+1)*num_parall]])
        # proxy_x_cache = jnp.array(proxy_x_cache)
        # proxy_opt_state_cache = jnp.array(proxy_opt_state_cache)
        # fxs,gxs,states,xs = jax.vmap(run_parallelized_iterations)(proxy_x_cache, proxy_opt_state_cache, input1, input2)

        # Continue from the deepest look-ahead point.
        idx = -1
        fx, x = fxs[idx], xs[idx]

        # The optimizer state is carried over only when the optimizer name
        # does not contain "sgd".
        if "sgd" not in opt_name:
            global_opt_state = states[idx]

        # Only record points that were actually evaluated, so both histories
        # keep the same length even when the data batch is short.
        new_xs = [c.reshape(1, -1) for c in proxy_x_cache[:len(gxs)]]
        new_gs = [c.reshape(1, -1) for c in gxs]
        kept_xs = x_history[-keep_old:] if keep_old > 0 else []
        kept_gs = g_history[-keep_old:] if keep_old > 0 else []
        x_history = kept_xs + new_xs
        g_history = kept_gs + new_gs

    inter_results.update({
        "x_history": x_history,
        "g_history": g_history,
    })

    return x, fx, global_opt_state


def run_benchmark(func, opt_name, lr, x0, num_iters, num_parall, datas=None, opt_state=None, inter_results=None):
    """Run the look-ahead scheme with true gradients in place of estimates.

    Each iteration rolls the optimizer forward ``num_parall - 1`` steps using
    the real value-and-gradient function, caching every visited point and
    optimizer state. It then applies one more real optimizer step from every
    cached point and continues from the result of the last one.

    Args:
        func: Objective; called as ``func(x)`` or ``func(x, *data)``.
        opt_name: String evaluated with ``eval`` to obtain an optimizer
            factory that accepts ``learning_rate``.
        lr: Learning rate.
        x0: Initial parameters.
        num_iters: Number of iterations.
        num_parall: Number of cached points per iteration.
        datas: Optional sequence of extra-argument tuples; iteration ``i``
            uses ``datas[i*num_parall:(i+1)*num_parall]``.
        opt_state: Optional optimizer state to resume from.
        inter_results: Not used by this function.

    Returns:
        ``(x, fx, opt_state)`` after the final iteration; ``fx`` is ``None``
        when ``num_iters`` is 0.
    """
    fgx_func = jax.jit(jax.value_and_grad(func))

    # NOTE(security): eval executes opt_name as arbitrary Python code; only
    # pass trusted, hard-coded optimizer names.
    global_opt = eval(opt_name)(learning_rate=lr)
    global_opt_state = global_opt.init(x0) if opt_state is None else opt_state

    def run_proxy_update(opt_name, lr, proxy_grad_func, global_opt_state, x, datas=None):
        proxy_opt = eval(opt_name)(learning_rate=lr)
        # Copy so the look-ahead cannot alias the caller's optimizer state.
        proxy_x, proxy_opt_state = x, deepcopy(global_opt_state)

        proxy_x_cache = [proxy_x]
        proxy_opt_state_cache = [proxy_opt_state]

        for k in range(num_parall - 1):
            # Only the gradient is needed here; the value is discarded.
            _, proxy_grad = proxy_grad_func(proxy_x) if datas is None else proxy_grad_func(proxy_x, *datas[k])
            proxy_updates, proxy_opt_state = proxy_opt.update(proxy_grad, proxy_opt_state)
            proxy_x = optax.apply_updates(proxy_x, proxy_updates)

            proxy_x_cache.append(proxy_x)
            proxy_opt_state_cache.append(proxy_opt_state)
        return proxy_x_cache, proxy_opt_state_cache

    def run_parallelized_iterations(proxy_x, proxy_opt_state, data=None):
        fx, grad = fgx_func(proxy_x) if data is None else fgx_func(proxy_x, *data)
        updates, proxy_opt_state = global_opt.update(grad, proxy_opt_state)
        x_update = optax.apply_updates(proxy_x, updates)
        return fx, grad, proxy_opt_state, x_update

    x = x0
    fx = None  # stays None when num_iters == 0

    for i in range(num_iters):
        batch = None if datas is None else datas[i * num_parall:(i + 1) * num_parall]

        proxy_x_cache, proxy_opt_state_cache = run_proxy_update(
            opt_name, lr, fgx_func, global_opt_state, x, batch
        )

        # Without data every cached point is evaluated with no extra
        # arguments; slicing `datas` unconditionally would fail on None.
        per_point_data = [None] * len(proxy_x_cache) if batch is None else batch
        caches = [
            run_parallelized_iterations(point, state, data)
            for point, state, data in zip(proxy_x_cache, proxy_opt_state_cache, per_point_data)
        ]
        # Continue from the deepest look-ahead point.
        fx, _, global_opt_state, x = caches[-1]

    return x, fx, global_opt_state