import numpy as np

def g_new(x, alpha):
    """Evaluate a two-term weighted log-sum-exp without overflow.

    Computes

        log( alpha / (2*alpha - 1) * exp((alpha - 1) * x)
             + (alpha - 1) / (2*alpha - 1) * exp(-alpha * x) )

    by subtracting the larger of the two exponents before exponentiating and
    adding it back after the log. Large ``|x|`` or ``alpha`` therefore do not
    overflow ``np.exp``.

    Both mixture weights are positive only when ``alpha > 1``. For other
    orders, the argument of the log may be non-positive and the result nan/-inf.

    NOTE(review): this has the shape of (alpha - 1) times the Renyi divergence
    between two Laplace distributions, with ``x`` presumably the
    sensitivity-to-scale ratio. Confirm against the derivation.

    Args:
        x: scalar or NumPy array; broadcast elementwise.
        alpha: order parameter.

    Returns:
        The log-mixture value, with the same broadcast shape as ``x``.
    """
    pos_exponent = (alpha - 1) * x
    neg_exponent = -alpha * x
    # Shift by the larger exponent so the largest exp() evaluated is exp(0) == 1.
    shift = np.maximum(pos_exponent, neg_exponent)

    denom = 2 * alpha - 1
    weight_pos = alpha / denom
    weight_neg = (alpha - 1) / denom

    mixture = weight_pos * np.exp(pos_exponent - shift) + weight_neg * np.exp(neg_exponent - shift)
    return shift + np.log(mixture)

def compute_epsilon_new(sigma, alpha, beta, eta, eta_l1, tau, T, delta):
    """Compute epsilon for the given sigma and other parameters.

    Evaluates

        (T - tau) / (alpha - 1) * g_new(sqrt(2) * eta_l1 * (1 - beta) / sigma, alpha)
      + 1 / (alpha - 1)         * g_new(sqrt(2) * beta**(T - tau) * eta_l1 / sigma, alpha)
      + log(1 / delta) / (alpha - 1)

    NOTE(review): the last term has the form of the usual RDP -> (epsilon, delta)
    conversion. The sqrt(2) / sigma factor suggests sigma is the standard
    deviation of the Laplace noise. Both are assumptions; confirm against the
    derivation.

    Args:
        sigma: noise level being calibrated (must be non-zero).
        alpha: order parameter (must differ from 1).
        beta: decay factor; enters as ``1 - beta`` and ``beta ** (T - tau)``.
        eta: unused by this function (kept for a uniform call signature).
        eta_l1: multiplies both ``g_new`` arguments.
        tau: iteration index; ``T - tau`` is used as a count and as an exponent.
        T: total number of iterations.
        delta: target delta (must be positive).

    Returns:
        The resulting epsilon (scalar, or array under NumPy broadcasting).
    """
    order_gap = alpha - 1
    remaining = T - tau
    root_two = np.sqrt(2)

    # Contribution repeated over the last ``T - tau`` steps.
    steady_arg = root_two * eta_l1 * (1 - beta) / sigma
    steady_part = remaining / order_gap * g_new(steady_arg, alpha)

    # Single contribution damped by beta ** (T - tau).
    decayed_arg = root_two * beta ** remaining * eta_l1 / sigma
    decayed_part = 1 / order_gap * g_new(decayed_arg, alpha)

    conversion = np.log(1 / delta) / order_gap
    return steady_part + decayed_part + conversion

def find_sigma(epsilon, alpha, beta, eta, eta_l1, tau, T, delta, sigma_min, sigma_max,
               tolerance=1e-11, epsilon_fn=None):
    """Find sigma using binary search within the given range.

    Bisects ``[sigma_min, sigma_max]`` for the smallest sigma whose epsilon is
    strictly below the target. This assumes epsilon is non-increasing in sigma,
    i.e. more noise never costs more privacy.

    During the search, ``sigma_max`` is always either the caller's original
    upper bound or a point verified to satisfy ``epsilon_fn(...) < epsilon``.
    ``sigma_min`` is either the original lower bound or a point verified NOT to
    satisfy it. The upper end is returned so the calibrated noise stays on the
    safe side of the budget; it is at most ``tolerance`` above the true
    threshold.

    If even the original ``sigma_max`` misses the target, that bound is
    returned unchanged and does NOT satisfy the target. Callers must widen the
    range in that case.

    Args:
        epsilon: target epsilon.
        alpha, beta, eta, eta_l1, tau, T, delta: forwarded unchanged to
            ``epsilon_fn``.
        sigma_min: lower end of the search bracket.
        sigma_max: upper end of the search bracket.
        tolerance: stop once the bracket is no wider than this.
        epsilon_fn: callable with the signature of ``compute_epsilon_new``.
            Defaults to ``compute_epsilon_new``.

    Returns:
        The upper end of the final bracket.
    """
    if epsilon_fn is None:
        epsilon_fn = compute_epsilon_new

    while sigma_max - sigma_min > tolerance:
        sigma_mid = (sigma_min + sigma_max) / 2
        epsilon_mid = epsilon_fn(sigma_mid, alpha, beta, eta, eta_l1, tau, T, delta)
        if epsilon_mid < epsilon:
            # Enough noise: the answer is at or below sigma_mid.
            sigma_max = sigma_mid
        else:
            # Too little noise (or nan): the answer is above sigma_mid.
            sigma_min = sigma_mid
    return sigma_max

def noise_calibration_laplace(epsilon, delta, eta, eta_l1, beta, max_iter, middlestep=0,
                              alpha_values=None, sigma_min=1e-12, sigma_max=10):
    """Grid-search alpha and tau for the smallest calibrated Laplace noise level.

    For every ``alpha`` in the grid and every ``tau`` in
    ``range(middlestep, max_iter)``, ``find_sigma`` is called with
    ``T = max_iter``. The smallest sigma found, and the (tau, alpha) that
    produced it, are kept.

    Args:
        epsilon: target epsilon.
        delta: target delta.
        eta: forwarded to ``find_sigma``.
        eta_l1: forwarded to ``find_sigma``.
        beta: forwarded to ``find_sigma``.
        max_iter: upper bound (exclusive) for tau; also passed as ``T``.
        middlestep: first tau value tried.
        alpha_values: iterable of orders to try. Defaults to
            1.1, 2.1, ..., 100.1 followed by 200, 500, ..., 1e8
            (the original hard-coded grid).
        sigma_min: lower end of each bisection bracket (default 1e-12).
        sigma_max: upper end of each bisection bracket (default 10).

    Returns:
        ``(min_sigma, best_tau, best_alpha)``. If
        ``range(middlestep, max_iter)`` is empty, returns
        ``(inf, None, None)``.
    """
    if alpha_values is None:
        # np.arange with a float start gives 1.1, 2.1, ..., 100.1 (not integers).
        alpha_values = np.concatenate(
            (np.arange(1.1, 101), [200, 500, 1000, 2000, 5000, 1e4, 5e4, 1e5, 1e6, 1e7, 1e8])
        )

    min_sigma = float('inf')
    best_tau = None
    best_alpha = None

    for alpha in alpha_values:
        for tau in range(middlestep, max_iter):
            sigma = find_sigma(epsilon, alpha, beta, eta, eta_l1, tau, max_iter, delta,
                               sigma_min, sigma_max)
            # Strict '<' keeps the first (alpha, tau) pair among ties.
            if sigma < min_sigma:
                min_sigma = sigma
                best_tau = tau
                best_alpha = alpha

    return min_sigma, best_tau, best_alpha