using Distributions, LinearAlgebra, Random, Lasso
using DataFrames, CSV
BLAS.set_num_threads(1)

include("datagen.jl")
include("functions.jl")

# Scenario index: read from the environment variable `case`, defaulting to 1 when unset.
case = parse(Int, get(ENV, "case", "1"))

# Third positional argument of `PilotEst2` (presumably the pilot sample size — confirm).
Np = 500
# First argument of `gendat`; also the divisor of the averaged squared probability error.
N = 500000

# Length of the zero vector handed to `getcase` (presumably the number of covariates).
k = 50

# Number of replications run for every (gamma, rho) setting.
S = 500

# Values of `rho` looped over and forwarded to `SubsampleEst2`.
rhos = [0.0025, 0.005, 0.0075, 0.01]

# `lthr` is forwarded unchanged to `SubsampleEst2`.
lthr = 1e-8
# Values of `gamma` looped over and forwarded to `SubsampleEst2`.
gammas = [1]

# Tags that only appear in the name of the output file.
dis = "nor"
ver = "1"

# Criterion prefix: "-opt" is appended before it is passed to the estimators,
# and it is also part of the output file name.
crtn = "P"

# Data generating: look up the scenario and unpack the pieces used further down.
par = getcase(zeros(k), case)
name = par.name
true_idx = par.true_idx
beta0 = par.beta0
alpha0 = par.alpha0

Sigma = getSigma(par)

# Subsampling experiment
#
# For every (gamma, rho) pair, run S replications: draw a data set with `gendat`, fit the
# pilot estimator (`PilotEst2`) and the subsample estimator (`SubsampleEst2`), record
# per-replication metrics, and append one 17-column summary row to `rst`.

"""
    median_or_nan(v)

Return `median(v)`, or `NaN` when `v` is empty (`median` throws on an empty collection).
"""
median_or_nan(v) = isempty(v) ? NaN : median(v)

Random.seed!(2)
rst = Matrix{Float64}(undef, 0, 17)
@time for gamma in gammas, rho in rhos
    # Per-replication records. NaN entries belong to replications whose estimator did
    # not report "Successful convergence".
    berr = fill(NaN, S)      # squared error of the coefficient estimate
    perr = fill(NaN, S)      # mean squared error of the fitted probabilities
    fvnum = fill(NaN, S)     # length of est.fscr_idx
    svnum = fill(NaN, S)     # length of est.sscr_idx
    aucs = fill(NaN, S)
    flsrts = fill(NaN, S)
    lambdas = fill(NaN, S)
    t_plts = fill(NaN, S)    # wall time of the pilot step
    t_ests = fill(NaN, S)    # wall time of the subsample step
    # Selection indicators are defined for all S replications; a replication that did
    # not converge keeps the value 0.
    cover = zeros(S)
    overcover = zeros(S)
    # Explicit convergence mask. Comparing against the NaN literal with `!==` depends on
    # the NaN bit pattern and would treat a computed NaN as a converged replication.
    conv_idx = falses(S)

    # Loop-invariant criterion label shared by both estimators.
    crit = string(crtn, "-opt")

    @time for i in 1:S
        X, y = gendat(N, alpha0, beta0, Sigma)
        t_plts[i] = @elapsed plt = PilotEst2(X, y, Np, criterion = crit, standardize = true)
        t_ests[i] = @elapsed est = SubsampleEst2(X, y, plt, gamma, rho, lthr,
                                                 nlambda = 100, eps = 0.001,
                                                 method = "bic", criterion = crit,
                                                 standardize = false)
        if est.message == "Successful convergence"
            conv_idx[i] = true
            berr[i] = sum(abs2, est.adpbetas .- beta0)
            # Success probabilities under the true and the estimated parameters.
            ptrue = 1 .- 1 ./ (1 .+ exp.(alpha0[1] .+ X * beta0))
            pest = 1 .- 1 ./ (1 .+ exp.(est.alpha .+ X * est.adpbetas))
            perr[i] = sum(abs2, ptrue .- pest) / N
            fvnum[i] = length(est.fscr_idx)
            svnum[i] = length(est.sscr_idx)
            aucs[i] = est.auc
            flsrts[i] = est.flsrt
            lambdas[i] = est.lambda[1]
            if issubset(true_idx, est.sscr_idx)
                cover[i] = 1
                # Covered, but est.sscr_idx has more elements than true_idx.
                if length(true_idx) < length(est.sscr_idx)
                    overcover[i] = 1
                end
            end
        end
    end

    n_conv = count(conv_idx)
    n_no_conv = S - n_conv

    # Total time over replications 2:S; the first replication is left out
    # (presumably to keep compilation time out of the total — confirm).
    t_plt = sum(t_plts[2:end])
    t_est = sum(t_ests[2:end])

    # Summaries over converged replications only (NaN when none converged).
    bmse = median_or_nan(berr[conv_idx])
    pmse = median_or_nan(perr[conv_idx])
    auc = median_or_nan(aucs[conv_idx])
    flsrt = median_or_nan(flsrts[conv_idx])
    lmbdbst = mean(lambdas[conv_idx])

    # Share of all S replications in which est.sscr_idx did not contain true_idx
    # (non-converged replications count here), with its standard error.
    covr = 1 - mean(cover)
    vcovr = std(1 .- cover) / sqrt(S)
    # Share of all S replications that covered true_idx without having more elements.
    ovcor = mean(cover) - mean(overcover)
    vovcor = std(cover .- overcover) / sqrt(S)

    # Mean index-set sizes. Their standard errors use the number of converged
    # replications, because that is the sample the mean and std are taken over.
    mfvnum = mean(fvnum[conv_idx])
    vfvnum = std(fvnum[conv_idx]) / sqrt(n_conv)
    msvnum = mean(svnum[conv_idx])
    vsvnum = std(svnum[conv_idx]) / sqrt(n_conv)

    # One summary row; the order must match the 17 column names applied to `rst` later.
    row = Float64[gamma, rho, bmse, pmse, covr, vcovr, ovcor, vovcor,
                  mfvnum, vfvnum, msvnum, vsvnum, auc, flsrt, lmbdbst,
                  t_plt + t_est, n_no_conv]
    global rst = vcat(rst, permutedims(row))
end
# Attach column names to the 17-column result matrix and save it as a CSV file under
# the "results" directory (created if it does not exist yet).
colnames = [
    :gamma, :rho, :bmse, :pmse, :covr, :vcovr, :ovcor, :vovcor,
    :mfvnum, :vfvnum, :msvnum, :vsvnum, :auc, :flsrt, :lambda, :ttotal, :noconv,
]
rst = DataFrame(rst, colnames)
mkpath("results")
path = "results/Subsampling-$(ver)-$(dis)-$(name)-$(crtn).csv"
CSV.write(path, rst)
