# A Pure Exploration Problem (pep) is specified by:
# - a query, as embodied by a correct-answer function istar
# - nanswers: number of possible answers
# - istar: correct answer for feasible μ
# - glrt: value and best response (λ and ξ) to (N, μ) or (w, μ)
# - oracle: characteristic time and oracle weights at μ

"""
    BestArm(expfams)

Best Arm identification problem: the correct answer is the arm with the largest mean.
`expfams` holds the exponential family of each arm; `getexpfam(pep, k)` returns the
family of arm `k`, and the number of possible answers is `length(expfams)`.
"""
struct BestArm
    expfams    # exponential family of each arm, indexed by arm
end

# Number of possible answers: one per arm (μ is not consulted).
function nanswers(pep::BestArm, μ)
    families = pep.expfams
    return length(families)
end
# Correct answer at μ: index of the largest mean (the first such index on ties).
function istar(pep::BestArm, μ)
    _, best = findmax(μ)
    return best
end
# Exponential family of arm `k`.
function getexpfam(pep::BestArm, k)
    families = pep.expfams
    return families[k]
end
# Human-readable description of the problem. Only arm 1's family is inspected, and any
# family whose type is not exactly `Bernoulli` is reported as "Gaussian".
# NOTE(review): confirm no other exponential families are used with BestArm.
function long(pep::BestArm)
    family = typeof(getexpfam(pep, 1)) == Bernoulli ? "Bernoulli" : "Gaussian"
    return string("BAI for ", family, " bandits")
end

# Alternative parameter: the weighted mean of μ1 and μa with weights w1 and wa,
# i.e. (w1 * μ1 + wa * μa) / (w1 + wa), computed via the ratio wa / w1.
# A zero weight on arm 1 returns μa; a zero weight on arm a (or equal means) returns μ1.
# If both weights are zero, μa is returned.
function alt_λ(μ1, w1, μa, wa)
    w1 == 0 && return μa
    if wa == 0 || μ1 == μa
        return μ1
    end
    r = wa / w1                     # relative weight of arm a
    return (μ1 + r * μa) / (1 + r)
end

# Generalised likelihood ratio statistics for Best Arm Identification.
#
# `w` holds per-arm weights (or counts) and `μ` per-arm means. With `astar` the
# empirically best arm, every other arm `a` strictly below it is scored by
#     w[astar] * d(μ[astar], θ) + w[a] * d(μ[a], θ),  θ = alt_λ(μ[astar], w[astar], μ[a], w[a]);
# an arm not strictly below `astar` scores 0, and `astar` itself keeps score Inf.
# Returns `(vals, (k, λ), (astar, μ))`: all scores, the minimising arm `k` together with
# `λ` (a copy of `μ` whose entries `astar` and `k` are both set to arm k's θ), and the
# best arm with the input means.
function glrt(pep, w, μ)
    @assert length(size(μ)) == 1
    ef = getexpfam(pep, 1)
    n = length(μ)
    best = argmax(μ)        # index of the empirically best arm

    scores = fill(Inf, n)   # entry `best` is never overwritten, so it stays Inf
    meet = zeros(n)         # per-arm alternative value θ

    for a in 1:n
        a == best && continue
        if μ[a] < μ[best]
            meet[a] = alt_λ(μ[best], w[best], μ[a], w[a])
            θ = meet[a]     # use the stored (converted) value, as the score is defined on it
            scores[a] = w[best] * d(ef, μ[best], θ) + w[a] * d(ef, μ[a], θ)
        else
            # not strictly below the best mean: nothing separates this arm from `best`
            meet[a] = μ[a]
            scores[a] = 0
        end
    end

    k = argmin(scores)
    λ = copy(μ)
    λ[best] = meet[k]
    λ[k] = meet[k]
    return scores, (k, λ), (best, μ)
end

# Solve for x ≥ 0 such that d(μ1, μx) + x * d(μa, μx) == v, where μx is the alt_λ
# weighted mean of μ1 and μa with relative weight x on μa.
# Returns the tuple (x, μx). Requires 0 ≤ v ≤ d(μ1, μa), the range of the left-hand side.
#
# The search runs over z = x / (1 + x) ∈ [0, 1] instead of x ∈ [0, ∞), and the equation
# is multiplied by (1 - z) = 1 / (1 + x) so the objective stays finite on the whole interval.
function X(expfam, μ1, μa, v)
    upd_a = d(expfam, μ1, μa)   # upper end of the attainable range
    @assert 0 ≤ v ≤ upd_a "0 ≤ $v ≤ $upd_a";

    function residual(z)
        mix = alt_λ(μ1, 1 - z, μa, z)
        return (1 - z) * d(expfam, μ1, mix) + z * d(expfam, μa, mix) - (1 - z) * v
    end

    z = binary_search(residual, 0, 1, ϵ = upd_a * 1e-10)
    x = z / (1 - z)
    return x, alt_λ(μ1, 1 - z, μa, z)
end

# Oracle solution: characteristic time and oracle weights at μs.
#
# Returns `(T, w)`: `w` is the oracle allocation over arms (normalised to sum to 1) and
# `T` the matching characteristic time. The best arm gets raw weight 1, and every other
# arm k gets raw weight x_k = X(expfam, μ*, μs[k], y)[1]. The level y ∈ [0, min_k d(μ*, μs[k])]
# is found by binary search so that Σ_k d(μ*, u_k) / d(μs[k], u_k) == 1, with u_k the
# second output of X. Then T = Σ_k x_k / y.
#
# If the largest mean is not unique (a single arm, all arms equal, or a partial tie at the
# top), the search is ill-posed: a tied arm gives hi = d(μ*, μ*) = 0 and a 0/0 ratio. In
# that case return `T = Inf` with uniform weights.
function oracle(pep, μs)
    μstar = maximum(μs)
    K = length(μs)

    ntop = count(==(μstar), μs)     # number of arms attaining the maximum
    if ntop > 1 || ntop == K        # yes, this happens
        return Inf, ones(K) / K
    end

    expfam = getexpfam(pep, 1)
    astar = argmax(μs)

    # upper end of the search range: X(·, μ*, μs[k], v) requires v ≤ d(μ*, μs[k]) for every k
    hi = minimum(d(expfam, μs[astar], μs[k]) for k in eachindex(μs) if k != astar)

    val = binary_search(
        z -> sum(
            let ux = X(expfam, μs[astar], μs[k], z)[2]
                d(expfam, μs[astar], ux) / d(expfam, μs[k], ux)
            end
            for k in eachindex(μs)
            if k != astar
        ) - 1.0,
        0, hi)

    ws = [k == astar ? 1.0 : X(expfam, μs[astar], μs[k], val)[1] for k in eachindex(μs)]
    Σ = sum(ws)
    return Σ / val, ws ./ Σ
end

# Oracle solution for β = 1/2: the best arm receives weight exactly 1/2.
#
# Returns `(T, w)`. Only the squared gaps Δk² = (μ* - μs[k])² enter the computation; the
# problem's exponential family is never consulted (`pep` is unused).
# NOTE(review): this looks like the unit-variance Gaussian formula — confirm before using
# it with other families.
# Every other arm k gets weight 1/2 / (Δk² / y - 1). Here y ∈ (0, min_k Δk²) solves
# Σ_k 1 / (Δk² / y - 1) == 1, so those weights sum to 1/2. Then T = 4 / y.
#
# If the largest mean is not unique (a single arm, all arms equal, or a partial tie at the
# top), a tied arm has Δk² = 0 and the search would evaluate 0/0. In that case return
# `T = Inf` with uniform weights.
function oracle_beta_half(pep, μs)
    μstar = maximum(μs)
    K = length(μs)

    ntop = count(==(μstar), μs)     # number of arms attaining the maximum
    if ntop > 1 || ntop == K        # yes, this happens
        return Inf, ones(K) / K
    end

    astar = argmax(μs)
    gap2(k) = (μs[astar] - μs[k])^2  # squared gap to the best arm

    # upper end of the search range: each term 1 / (Δk² / z - 1) blows up as z → Δk²
    hi = minimum(gap2(k) for k in eachindex(μs) if k != astar)

    val = binary_search(
        z -> sum(1 / (gap2(k) / z - 1) for k in eachindex(μs) if k != astar) - 1.0,
        0, hi)
    inv_val = 1 / val

    ws = [k == astar ? 1/2 : 0.5 / (inv_val * gap2(k) - 1) for k in eachindex(μs)]
    return 4 * inv_val, ws
end
