from spaghettini import quick_register

import random

import torch

# Added to the denominator when computing the relative residual (`rel_diff`) in the solvers
# below, so the ratio stays finite even when the next iterate has zero norm.
RELATIVE_RESIDUAL_EPS = 1e-5


@quick_register
def fixed_point_iterator(f, x0, num_iters):
    """Apply `f` to `x0` a fixed number of times and measure how close the result is to a fixed point.

    Args:
        f: Callable mapping the current iterate to the next one.
        x0: Initial iterate (a torch tensor).
        num_iters: Number of applications of `f` used to produce `result`.

    Returns:
        A dict with:
            result: The iterate after `num_iters` applications of `f`.
            num_iters: The `num_iters` argument, echoed back.
            diff_l2: L2 norm of `f(result) - result`, reduced over every dim except dim 0.
            rel_diff: `diff_l2` divided by the same norm of `f(result)` plus RELATIVE_RESIDUAL_EPS.
    """
    iterate = x0
    for _step in range(num_iters):
        iterate = f(iterate)

    def sample_norm(t):
        # Reduce over every axis except the leading one.
        # NOTE(review): for tensors with fewer than 2 dims this list is empty; confirm how
        # torch.sum treats an empty `dim` there is what's intended.
        reduce_dims = list(range(1, t.dim()))
        return torch.sum(t ** 2, dim=reduce_dims).sqrt()

    # One extra application of `f` measures the fixed-point residual of the returned iterate.
    next_iterate = f(iterate)
    residual_norm = sample_norm(next_iterate - iterate)
    relative_residual = residual_norm / (sample_norm(next_iterate) + RELATIVE_RESIDUAL_EPS)

    return {
        "result": iterate,
        "num_iters": num_iters,
        "diff_l2": residual_norm,
        "rel_diff": relative_residual,
    }


@quick_register
def fixed_point_iterator_with_random_depth(f, x0, min_num_iters, max_num_iters):
    """Apply `f` to `x0` a randomly chosen number of times and report convergence metrics.

    The iteration count is drawn with `random.randint`, i.e. uniformly from the closed
    interval [`min_num_iters`, `max_num_iters`], using the `random` module's global generator.

    Args:
        f: Callable mapping the current iterate to the next one.
        x0: Initial iterate (a torch tensor).
        min_num_iters: Smallest allowed iteration count; must be at least 2.
        max_num_iters: Largest allowed iteration count; must be >= `min_num_iters`.

    Returns:
        A dict with:
            result: The iterate after the sampled number of applications of `f`.
            num_iters: The sampled iteration count.
            diff_l2: L2 norm of `f(result) - result`, reduced over every dim except dim 0.
            rel_diff: `diff_l2` divided by the same norm of `f(result)` plus RELATIVE_RESIDUAL_EPS.

    Raises:
        ValueError: If `min_num_iters < 2` or `min_num_iters > max_num_iters`.
    """
    # Explicit exceptions instead of `assert`: asserts are stripped under `python -O`, which
    # would let invalid configurations through silently.
    if min_num_iters < 2:
        raise ValueError(f"min_num_iters must be at least 2, got {min_num_iters}.")
    if min_num_iters > max_num_iters:
        raise ValueError(
            f"min_num_iters ({min_num_iters}) cannot be larger than max_num_iters ({max_num_iters})."
        )

    num_iters = random.randint(min_num_iters, max_num_iters)
    z = x0
    for _ in range(num_iters):
        z = f(z)

    # Compute convergence metrics.
    # NOTE(review): for tensors with fewer than 2 dims the reduction `dim` list is empty; confirm
    # torch.sum's handling of that case is what's intended.
    get_sample_norm = lambda x: torch.sqrt(torch.sum(x**2, dim=list(range(1, len(x.shape)))))
    # One extra application of `f` measures the fixed-point residual of the returned iterate.
    z_plus_one = f(z)
    diff = z_plus_one - z
    diff_l2 = get_sample_norm(diff)
    rel_diff = diff_l2 / (get_sample_norm(z_plus_one) + RELATIVE_RESIDUAL_EPS)

    output_dict = dict(result=z, num_iters=num_iters, diff_l2=diff_l2, rel_diff=rel_diff)

    return output_dict


@quick_register
def truncated_fixed_point_iterator(f, x0, num_iters, num_keep_grads_iters="half"):
    """A solver where the backwards pass only backprops through the final `num_keep_grads_iters` iterations.

    Runs `num_iters` applications of `f` starting from `x0`. The first
    `num_iters - num_keep_grads_iters` applications run with gradient tracking disabled and
    their output is detached. Gradients of `result` therefore flow only through the last
    `num_keep_grads_iters` applications, and no autograd graph is kept for the prefix.

    Args:
        f: Callable mapping the current iterate to the next one.
        x0: Initial iterate (a torch tensor).
        num_iters: Total number of applications of `f` used to produce `result`.
        num_keep_grads_iters: Number of trailing iterations to backprop through, or "half"
            for `num_iters // 2`.

    Returns:
        A dict with:
            result: The iterate after `num_iters` applications of `f`.
            num_iters: The `num_iters` argument, echoed back.
            diff_l2: L2 norm of `f(result) - result`, reduced over every dim except dim 0.
            rel_diff: `diff_l2` divided by the same norm of `f(result)` plus RELATIVE_RESIDUAL_EPS.
            num_keep_grads_iters: The resolved number of trailing iterations with gradients.

    Raises:
        ValueError: If `num_keep_grads_iters` is negative or larger than `num_iters`.
    """
    if num_keep_grads_iters == "half":
        num_keep_grads_iters = num_iters // 2

    if num_keep_grads_iters > num_iters:
        message = (
            f"The number of iters to keep in backwards pass ({num_keep_grads_iters}) cannot be "
            f"larger than the total number of iterations ({num_iters})."
        )
        raise ValueError(message)
    if num_keep_grads_iters < 0:
        # A negative value would make the no-grad prefix longer than `num_iters`, silently
        # running more iterations than requested.
        message = f"The number of iters to keep in backwards pass must be non-negative, got {num_keep_grads_iters}."
        raise ValueError(message)

    num_no_grad_iter = num_iters - num_keep_grads_iters

    z = x0
    # The prefix is detached below anyway. Running it under no_grad avoids building (and
    # holding in memory) an autograd graph that would be thrown away, which is the point of
    # truncating the backward pass.
    with torch.no_grad():
        for _ in range(num_no_grad_iter):
            z = f(z)

    # Cut any remaining link to the graph (e.g. when the prefix is empty and x0 requires
    # grad). Clone so `result` never aliases `x0`'s storage.
    z = z.detach().clone()

    for _ in range(num_keep_grads_iters):
        z = f(z)

    # Compute convergence metrics.
    # NOTE(review): for tensors with fewer than 2 dims the reduction `dim` list is empty; confirm
    # torch.sum's handling of that case is what's intended.
    get_sample_norm = lambda x: torch.sqrt(torch.sum(x**2, dim=list(range(1, len(x.shape)))))
    # One extra application of `f` measures the fixed-point residual of the returned iterate.
    z_plus_one = f(z)
    diff = z_plus_one - z
    diff_l2 = get_sample_norm(diff)
    rel_diff = diff_l2 / (get_sample_norm(z_plus_one) + RELATIVE_RESIDUAL_EPS)

    output_dict = dict(result=z, num_iters=num_iters, diff_l2=diff_l2, rel_diff=rel_diff,
                       num_keep_grads_iters=num_keep_grads_iters)

    return output_dict
