from typing import Callable, List, Literal, NamedTuple, Optional, Tuple

from tqdm import tqdm

import jax
import jax.numpy as jnp
import jax.scipy as jsp

from ott.geometry import costs, pointcloud
from ott.solvers import linear
from scipy.optimize import bisect

# Public API of this module: the names exported by a star-import.
__all__ = [
    "ExperimentOutput",
    "get_concave_fn",
    "structured_loss",
    "gradient_step",
    "train",
    "get_scale_reg",
    "generate_data",
    "proj",
    "svd_peakiness",
    "subspace_criterion",
    "get_mse",
]


class ExperimentOutput(NamedTuple):
    """Bundle of results returned by :func:`train`."""

    # Point clouds that were passed to ``train`` (returned unchanged).
    x: jnp.ndarray
    y: jnp.ndarray
    # Ground-truth cost given to ``train``; ``None`` when none was supplied.
    gt_cost: Optional[costs.RegTICost]
    # Cost after the last gradient step.
    pred_cost: costs.RegTICost
    # Cost recorded when the lowest objective value was observed.
    best_pred_cost: costs.RegTICost
    # Per-iteration histories; iterations whose objective was non-finite
    # are skipped, so these lists may be shorter than ``n_iter``.
    losses: List[float]
    criteria: List[float]
    norm_grads: List[float]


def get_concave_fn(
    rng: jax.Array, *, kind: Literal["quad", "l2-1", "lse", "icnn"], d: int
) -> Callable[[jnp.ndarray], float]:
    """Build a concave scalar function of a ``d``-dimensional vector.

    Args:
        rng: Random key; only the ``"quad"`` and ``"icnn"`` kinds consume it.
        kind: Family of function to build:

            - ``"quad"``: ``-z^T A z`` with ``A = 0.5 * B @ B.T`` and ``B`` a
              random Gaussian matrix of shape ``(d, 2 * d)``.
            - ``"l2-1"``: ``-(sum((z - 0.3) ** 2)) ** 1.1``.
            - ``"lse"``: ``-logsumexp(z)``.
            - ``"icnn"``: the negated output of a randomly initialised ICNN
              whose parameters were passed through
              ``W2NeuralDual._clip_weights_icnn``.
        d: Dimension of the input vector.

    Returns:
        A callable mapping a vector to a scalar.

    Raises:
        NotImplementedError: If ``kind`` is not one of the values above.
    """
    if kind == "quad":
        A = jax.random.normal(rng, (d, 2 * d))
        # A @ A.T is positive semi-definite, so the negated form is concave.
        A = 0.5 * A @ A.T
        return lambda z: -jnp.sum(z * (A.dot(z)))
    if kind == "l2-1":
        return lambda z: -(jnp.sum((z - 0.3) ** 2) ** 1.1)
    if kind == "lse":
        return lambda z: -jsp.special.logsumexp(z)
    if kind == "icnn":
        # Imported lazily so the neural dependencies are only required when
        # this kind is actually requested.
        from ott.neural.methods import neuraldual
        from ott.neural.networks import icnn
        # as defined in: https://arxiv.org/pdf/2106.01954
        dim_hidden = [max(2 * d, 64), max(2 * d, 64), max(d, 32)]
        net = icnn.ICNN(d, dim_hidden, pos_weights=False)
        params = net.init(rng, jnp.ones((1, d)))
        params = neuraldual.W2NeuralDual._clip_weights_icnn(params)
        return jax.jit(lambda z: -net.apply(params, z))
    raise NotImplementedError(kind)


def structured_loss(
    cost_fn: costs.RegTICost,
    x: jnp.ndarray,
    y: jnp.ndarray,
    epsilon: Optional[float],
    inner_iterations: int = 10,
) -> jnp.ndarray:
    """Transport-weighted sum of the regulariser over pairwise displacements.

    Solves the linear OT problem between ``x`` and ``y`` under ``cost_fn`` and
    returns ``sum_ij P_ij * cost_fn.reg(x_i - y_j)``, where ``P`` is the
    transport matrix of the solution.

    Args:
        cost_fn: Cost used for the point-cloud geometry; its ``reg`` method
            is evaluated on every displacement.
        x: Source points, shape ``(nx, d)``.
        y: Target points; broadcast against ``x`` as ``(ny, d)``.
        epsilon: Regularisation strength forwarded to the geometry.
        inner_iterations: Forwarded to ``linear.solve``.

    Returns:
        Scalar loss value.
    """
    num_src, dim = x.shape

    # NOTE(review): ``implicit_diff=None`` presumably makes gradients unroll
    # through the solver iterations -- confirm against the OTT docs.
    ot_out = linear.solve(
        pointcloud.PointCloud(x, y, cost_fn=cost_fn, epsilon=epsilon),
        implicit_diff=None,
        inner_iterations=inner_iterations,
    )

    # All pairwise displacements x_i - y_j, flattened to (nx * ny, d).
    pairwise = (x[:, None, :] - y).reshape(-1, dim)
    reg_values = jax.vmap(cost_fn.reg)(pairwise).reshape(num_src, -1)

    return jnp.sum(ot_out.matrix * reg_values)


def gradient_step(
    cost: costs.RegTICost, grads: costs.RegTICost, alpha: float
) -> Tuple[costs.RegTICost, jnp.ndarray]:
    """Take one descent step on the matrix parameter of ``cost``.

    Args:
        cost: Current cost; its ``matrix``, ``scaling_reg`` and ``orthogonal``
            attributes are read.
        grads: Object of the same structure whose ``matrix`` holds the
            gradient with respect to ``cost.matrix``.
        alpha: Step size.

    Returns:
        ``(new_cost, grad_norm)``: the cost rebuilt with the updated matrix
        and the norm of the direction that was actually applied.
    """
    current, grad = cost.matrix, grads.matrix

    if grad.ndim == 2:
        # Dense matrix: correct the gradient as ``G - M G^T M`` (presumably
        # the Riemannian gradient for orthonormal rows -- confirm), then map
        # the updated matrix back with ``proj``.
        direction = grad - current @ grad.T @ current
        updated = proj(current - alpha * direction)
    else:
        # diagonal matrix, no projection
        direction = grad
        updated = current - alpha * direction

    grad_norm = jnp.linalg.norm(direction)
    new_cost = cost.tree_unflatten(
        {"orthogonal": cost.orthogonal}, (cost.scaling_reg, updated)
    )
    return new_cost, grad_norm


def train(
    cost_fn: costs.RegTICost,
    x: jnp.ndarray,
    y: jnp.ndarray,
    loss: Callable[
        [costs.RegTICost, jnp.ndarray, jnp.ndarray, Optional[float], int],
        jnp.ndarray,
    ],
    step_size_schedule: Callable[[int], float] = lambda _: 0.1,
    n_iter: int = 500,
    epsilon: Optional[float] = -1.0,
    gt_cost: Optional[costs.RegTICost] = None,
    inner_iterations: int = 10,
) -> ExperimentOutput:
    """Fit the matrix parameter of ``cost_fn`` by gradient descent on ``loss``.

    Args:
        cost_fn: Initial cost.
        x: Source points.
        y: Target points.
        loss: Called as ``loss(cost_fn, x, y, epsilon, inner_iterations)``;
            differentiated with respect to its first argument. It must name
            its last parameter ``inner_iterations`` because that argument is
            marked static for ``jax.jit``.
        step_size_schedule: Maps the iteration index to a step size.
        n_iter: Number of loop iterations (including skipped ones).
        epsilon: Regularisation passed to ``loss``. A non-positive value is
            replaced by ``0.2`` times the default epsilon of the point-cloud
            geometry built from ``x``, ``y`` and ``cost_fn``. ``None`` is
            forwarded unchanged.
        gt_cost: Optional ground-truth cost, used only to report a criterion.
        inner_iterations: Forwarded to ``loss``.

    Returns:
        An :class:`ExperimentOutput` with the final cost, the cost that
        attained the lowest observed objective and the per-iteration
        histories.
    """
    loss_grad = jax.jit(
        jax.value_and_grad(loss), static_argnames=["inner_iterations"]
    )

    if epsilon is not None and epsilon <= 0.0:
        epsilon = pointcloud.PointCloud(x, y, cost_fn=cost_fn).epsilon * 0.2

    losses, norm_grads, criteria = [], [], []
    best_obj, best_pred_cost = jnp.inf, cost_fn
    for i in (pbar := tqdm(range(n_iter))):
        obj, grads = loss_grad(cost_fn, x, y, epsilon, inner_iterations)
        # Check Sinkhorn converged
        if not jnp.isfinite(obj):
            if epsilon is None:
                # ``None`` cannot be scaled: materialise the geometry's
                # default epsilon first so it can be increased below.
                epsilon = pointcloud.PointCloud(x, y, cost_fn=cost_fn).epsilon
            epsilon = epsilon * 1.1
            pbar.set_description(
                f"Obj: NaN  |Grad|: NaN  Loss: NaN  Eps: {epsilon:.5f}"
            )
            continue

        # ``obj`` was evaluated on the cost *before* the update below, so
        # that is the cost to remember as the best one.
        if obj < best_obj:
            best_obj = obj
            best_pred_cost = cost_fn

        alpha = step_size_schedule(i)
        cost_fn, grad_norm = gradient_step(cost_fn, grads, alpha=alpha)

        if gt_cost is None:
            criterion = jnp.array(jnp.nan)
        elif isinstance(gt_cost, costs.ElasticL1Diag):
            criterion = costs.Cosine()(gt_cost.diag, cost_fn.diag)
        else:
            criterion = subspace_criterion(
                gt=gt_cost.matrix, pred=cost_fn.matrix
            )

        # ``epsilon`` may legitimately be ``None``, which ``:.5f`` rejects.
        eps_text = "default" if epsilon is None else f"{epsilon:.5f}"
        pbar.set_description(
            f"Obj: {obj:.5f} |Grad|: {grad_norm:.5f} Criterion: {criterion:.5f}  Eps: {eps_text}"
        )

        losses.append(obj.item())
        norm_grads.append(grad_norm.item())
        criteria.append(criterion.item())

    return ExperimentOutput(
        x=x,
        y=y,
        gt_cost=gt_cost,
        pred_cost=cost_fn,
        best_pred_cost=best_pred_cost,
        losses=losses,
        criteria=criteria,
        norm_grads=norm_grads,
    )


def get_scale_reg(
    *,
    rng: jax.Array,
    rng_mat: jax.Array,
    target_criterion: float,
    n: int,
    d: int,
    d_proj: int,
    concave_fn: Callable[[jnp.ndarray], float],
    cost_type: Literal["el-l1", "el-l2", "el-l1-diag"],
    use_bisect: bool = True,
) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray], costs.RegTICost, jnp.ndarray]:
    """Search for a ``scaling_reg`` whose generated data hits a target criterion.

    Builds a ground-truth cost of the requested type, then tunes its
    ``scaling_reg`` so that the criterion reported by :func:`generate_data`
    reaches ``target_criterion``.

    Args:
        rng: Random key forwarded to :func:`generate_data`; the same key is
            reused for every evaluation.
        rng_mat: Random key used to draw the cost's matrix (``"el-l2"`` and
            ``"el-l1-diag"`` only).
        target_criterion: Criterion value to reach.
        n: Number of points per generated set.
        d: Data dimension.
        d_proj: Number of rows of the matrix for ``"el-l2"``.
        concave_fn: Concave function forwarded to :func:`generate_data`.
        cost_type: Which cost to build.
        use_bisect: If true, solve ``criterion(sr) == target_criterion`` by
            bisection on ``[1e-3, 1e6]``. Otherwise multiply ``scaling_reg``
            by ``1.2``, starting from ``0.1``, until the criterion exceeds
            the target; that loop has no iteration cap.

    Returns:
        ``(data, gt_cost, criterion)`` where ``data`` is the tuple returned by
        :func:`generate_data` for the final cost.

    Raises:
        NotImplementedError: If ``cost_type`` is not supported.
        ValueError: Raised by ``scipy.optimize.bisect`` when the criterion
            does not change sign across the bracket.
    """

    def callback(sr: float) -> float:
        # ``gt_cost`` is bound below, before the first call happens.
        (_, mat), aux_data_ = gt_cost.tree_flatten()
        cost_fn = type(gt_cost).tree_unflatten(aux_data_, (sr, mat))

        _, crit = generate_data(
            rng=rng, gt_cost=cost_fn, n=n, d=d, concave_fn=concave_fn
        )
        print(f"sr={cost_fn.scaling_reg:.4f}, crit={crit:.4f}, target={target_criterion}")

        return crit - target_criterion

    if cost_type == "el-l2":
        matrix = proj(jax.random.normal(rng_mat, (d_proj, d)))
        gt_cost = costs.ElasticL2(scaling_reg=1e-1, matrix=matrix, orthogonal=True)
    elif cost_type == "el-l1-diag":
        matrix = jax.random.normal(rng_mat, (d,))
        gt_cost = costs.ElasticL1Diag(scaling_reg=1e-1, matrix=matrix, orthogonal=False)
    elif cost_type == "el-l1":
        gt_cost = costs.ElasticL1(scaling_reg=1e-1)
    else:
        raise NotImplementedError(cost_type)

    if use_bisect:
        # ``disp=False``: a non-converged run does not raise; the last
        # iterate is used and the printed result object shows the status.
        sr, res = bisect(
            callback,
            a=1e-3,
            b=1e6,
            maxiter=64,
            disp=False,
            full_output=True,
            rtol=1e-2,
        )
        print(res)
        (_, matrix), aux_data = gt_cost.tree_flatten()
        gt_cost = type(gt_cost).tree_unflatten(aux_data, (sr, matrix))
        data, criterion = generate_data(
            rng=rng, gt_cost=gt_cost, n=n, d=d, concave_fn=concave_fn,
        )
        return data, gt_cost, criterion

    # Geometric search: grow ``scaling_reg`` until the target is exceeded.
    while True:
        data, criterion = generate_data(
            rng=rng, gt_cost=gt_cost, n=n, d=d, concave_fn=concave_fn
        )
        print(f"sr={gt_cost.scaling_reg:.4f}, crit={criterion:.4f}, target={target_criterion}")
        if criterion > target_criterion:
            break
        (scaling_reg, matrix), aux_data = gt_cost.tree_flatten()
        gt_cost = type(gt_cost).tree_unflatten(aux_data, (scaling_reg * 1.2, matrix))

    return data, gt_cost, criterion


def generate_data(
    *,
    rng: jax.Array,
    gt_cost: costs.RegTICost,
    n: int,
    d: int,
    concave_fn: Callable[[jnp.ndarray], float],
) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray], jnp.ndarray]:
    """Sample Gaussian points and push them through the map induced by ``gt_cost``.

    The map is ``x -> x - gt_cost.prox_reg(grad(g_h)(x))`` where
    ``g_h = gt_cost.h_transform(concave_fn, variable=True)``.

    Args:
        rng: Random key; split once into train and test keys.
        gt_cost: Ground-truth cost; must be an ``ElasticL1``, ``ElasticL2``
            or ``ElasticL1Diag`` instance.
        n: Number of points in each of the train and test sets.
        d: Dimension of the points.
        concave_fn: Concave function passed to ``h_transform``.

    Returns:
        ``((x, y, x_te, y_te), criterion)``. For the L1 costs ``criterion`` is
        the fraction of test displacement coordinates with absolute value at
        most ``1e-4``. Otherwise it is :func:`svd_peakiness` of the test
        displacements with ``d_proj = gt_cost.matrix.shape[0]``.

    Raises:
        AssertionError: If ``gt_cost`` has an unsupported type.
    """
    # Fail fast: check the cost type before any tracing or compilation.
    assert isinstance(gt_cost, (costs.ElasticL1, costs.ElasticL2, costs.ElasticL1Diag)), type(gt_cost)

    rng1, rng2 = jax.random.split(rng)
    g_h = gt_cost.h_transform(concave_fn, variable=True)
    transport_fn = jax.jit(jax.vmap(lambda x: x - gt_cost.prox_reg(jax.grad(g_h)(x))))

    x = jax.random.normal(rng1, (n, d))
    x_te = jax.random.normal(rng2, (n, d))
    y = transport_fn(x)
    y_te = transport_fn(x_te)

    if isinstance(gt_cost, (costs.ElasticL1, costs.ElasticL1Diag)):
        # Share of (near-)zero displacement coordinates on the test set.
        criterion = jnp.mean(jnp.abs(y_te - x_te) <= 1e-4)
    else:
        d_proj = gt_cost.matrix.shape[0]
        criterion = svd_peakiness(x_te, y_te, d_proj=d_proj)

    return (x, y, x_te, y_te), criterion


@jax.jit
def proj(matrix: jnp.ndarray) -> jnp.ndarray:
    """Replace every singular value of ``matrix`` by one.

    Takes the reduced SVD ``matrix = U diag(s) V^H`` and returns ``U V^H``.
    """
    left, _, right_h = jnp.linalg.svd(matrix, full_matrices=False)
    return jnp.dot(left, right_h)


def svd_peakiness(x: jnp.ndarray, y: jnp.ndarray, *, d_proj: int) -> jnp.ndarray:
    """Share of the displacement's singular-value mass in its top ``d_proj`` values.

    Args:
        x: Source points.
        y: Transported points, same shape as ``x``.
        d_proj: Number of leading singular values to count.

    Returns:
        ``sum(s[:d_proj]) / sum(s)`` where ``s`` are the singular values of
        ``y - x``. The result is NaN when ``y == x`` everywhere, because all
        singular values are then zero.
    """
    z = y - x
    # Only the singular values are needed; skip building U and V (a full
    # SVD of an (n, d) matrix would materialise an n x n factor).
    svalues = jnp.linalg.svd(z, compute_uv=False)
    return jnp.sum(svalues[:d_proj]) / jnp.sum(svalues)


@jax.jit
def subspace_criterion(*, gt: jnp.ndarray, pred: jnp.ndarray) -> jnp.ndarray:
    """Residual of projecting the rows of ``gt`` onto the row space of ``pred``.

    Args:
        gt: Ground-truth matrix; each row is one direction.
        pred: Predicted matrix with the same number of columns.

    Returns:
        Squared Frobenius norm of ``gt.T - pred.T @ coeffs`` divided by the
        number of rows of ``gt``, where ``coeffs`` solves
        ``(pred @ pred.T) coeffs = pred @ gt.T``.
    """
    gram = pred @ pred.T
    cross = pred @ gt.T
    coeffs = jsp.linalg.solve(gram, cross)
    residual = gt.T - pred.T @ coeffs
    sq_residual = jnp.linalg.norm(residual) ** 2
    return sq_residual / gt.shape[0]


@jax.jit
def get_mse(x, y, x_te, y_te, cost_fn, epsilon) -> jnp.ndarray:
    """Mean squared error of the estimated transport map on held-out points.

    Solves the OT problem between ``x`` and ``y`` under ``cost_fn``, converts
    the solution to dual potentials, transports ``x_te`` with them in chunks
    of 512 points and compares the result with ``y_te``.

    Returns:
        Mean over test points of the squared Euclidean error.
    """
    chunk = 512
    geom = pointcloud.PointCloud(
        x, y, cost_fn=cost_fn, epsilon=epsilon, batch_size=chunk
    )
    potentials = linear.solve(geom).to_dual_potentials()

    n_test, _ = x_te.shape
    # The Python loop is unrolled at trace time, one slice per chunk.
    mapped = jnp.concatenate(
        [
            potentials.transport(x_te[start: start + chunk])
            for start in range(0, n_test, chunk)
        ]
    )

    sq_err = jnp.sum((mapped - y_te) ** 2, axis=-1)
    return jnp.mean(sq_err)
