using LightGraphs
using DelimitedFiles
using SparseArrays
using MAT
using Random
using JLD2,FileIO
using PyCall
using PyPlot
using MatrixNetworks
using Statistics

include("PageRank.jl")
include("SLQcvx.jl") # this includes SLQ.jl
include("common.jl")
include("CRDlgc.jl")
include("HeatKernel.jl")
include("FlowSeed-1.0.jl")
include("GCN.jl")

# Load the LiveJournal data set: "A" is the matrix the graph is built from and
# "C" holds one column per community (the nonzero rows of a column are used
# below as that community's members).
dataset = "livejournal"
vars = matread("$dataset.mat")

# Entries of A are converted to Int before the graph object is constructed.
A = Int.(vars["A"])
C = vars["C"]

# Number of stored nonzeros in each column of C, i.e. the size of each
# community. A typed comprehension gives a concrete Vector{Int} rather than
# the Vector{Any} produced by pushing into an untyped `[]`.
csize = Int[nnz(C[:, i]) for i in axes(C, 2)]

G = SLQ.graph(A)

# Target community size: the largest community size raised to the power 3/4.
c34 = round(Int, maximum(csize)^(3/4))
# The 600 communities whose sizes are closest to the target size.
cids = sortperm(abs.(csize .- c34))[1:600]
# Cluster sizes k at which recall is recorded (1, 201, 401, ... up to c34/2).
# NOTE(review): the LiveJournal plotting code further down rebuilds k_list with
# a step of 10 (the alternative kept below); the step used here must match the
# one used when plotting, otherwise the curve lengths will not agree.
k_list = collect(1:200:round(Int, c34 / 2))
#k_list = collect(1:10:round(Int,c34/2))

q_list = [1.5, 4.0, 8.0, 16.0]
# kappa passed to SLQ.slq_diffusion for each q. There is no active entry for
# q = 16.0. Values that were tried previously and are currently disabled:
#   4.0  => 0.0005
#   16.0 => 0.000000001
q_kappa_map = Dict{Float64,Float64}(
    1.5 => 0.02,
    4.0 => 0.001,
    8.0 => 0.00001,
)

# Recall tables: one row per entry of k_list, one column per selected community.
rc_slq_records = Dict(string(q) => zeros(length(k_list), length(cids)) for q in q_list)
rc_slq_degnorm_records = Dict(string(q) => zeros(length(k_list), length(cids)) for q in q_list)
rc_acl_all = zeros(length(k_list), length(cids))
rc_acl_degnorm_all = zeros(length(k_list), length(cids))

# ACL baseline. For every selected community, 10% of its members (at least one)
# are used as the seed set of PageRank.acl_diffusion. Nodes are then ranked by
# the diffusion vector and by the vector divided by G.deg, and the recall of
# the top-k nodes is recorded for each k in k_list. total_time accumulates the
# wall-clock time spent inside acl_diffusion only.
global total_time = 0.0

# The float copy of A and the total degree do not depend on the community, so
# they are built once instead of once per iteration. The `let` block keeps the
# float copy from staying alive after the loop.
let float_adjacency = A .* 1.0, graph_volume = sum(G.deg) * 1.0
    for (j, cid) in enumerate(cids)
        truth = C[:, cid].nzind
        # The statistics returned here are not used afterwards; the call is kept
        # so the experiment performs the same steps as before.
        _, _, _, conductance = set_stats(float_adjacency, truth, graph_volume)
        n = length(truth)
        # A fresh generator with the same seed is used for every community, so
        # the choice of seed nodes is reproducible.
        seed = 1
        S = truth[randperm(MersenneTwister(seed), n)[1:max(1, round(Int, 0.1 * n))]]
        curr_time = @elapsed x_acl = PageRank.acl_diffusion(G, S, 0.05, 0.002)
        global total_time += curr_time
        x_acl_degnorm = x_acl ./ G.deg
        # Descending order of the scores (ascending order of the negated scores).
        sorted_ids = sortperm(-1 * x_acl)
        sorted_ids_degnorm = sortperm(-1 * x_acl_degnorm)
        for (i, k) in enumerate(k_list)
            _, rc_acl = compute_pr_rc(sorted_ids[1:k], truth)
            rc_acl_all[i, j] = rc_acl
            _, rc_acl_degnorm = compute_pr_rc(sorted_ids_degnorm[1:k], truth)
            rc_acl_degnorm_all[i, j] = rc_acl_degnorm
            @show j,k,rc_acl,rc_acl_degnorm,curr_time,total_time
        end
    end
end

# Persist both recall tables together with the accumulated diffusion time.
jldopen("results/$dataset-acl.jld2", "w") do file
    file["acl"] = rc_acl_all
    file["time"] = total_time
end
jldopen("results/$dataset-acl-degnorm.jld2", "w") do file
    file["acl-degnorm"] = rc_acl_degnorm_all
    file["time"] = total_time
end

# SLQ experiments for q = 1.5, 4.0 and 8.0, run one after another. These were
# three copies of the same loop; they are merged into a single loop over q so
# the procedure is defined in one place. For every q the results are appended
# to the SLQ result files under the key string(q) as soon as that q finishes.
q_list = [1.5, 4.0, 8.0]

# The float copy of A and the total degree are independent of q and of the
# community, so they are built once. The `let` block keeps the float copy from
# staying alive after the experiments.
let float_adjacency = A .* 1.0, graph_volume = sum(G.deg) * 1.0
    for q in q_list
        kappa = q_kappa_map[q]
        # Time spent inside slq_diffusion for this q. Declared `local`
        # explicitly: a plain assignment here is ambiguous with the global of
        # the same name used by the ACL run.
        local total_time = 0.0
        for (j, cid) in enumerate(cids)
            truth = C[:, cid].nzind
            # The statistics returned here are not used afterwards; the call is
            # kept so the experiment performs the same steps as before.
            _, _, _, conductance = set_stats(float_adjacency, truth, graph_volume)
            n = length(truth)
            # A fresh generator with the same seed is used for every community,
            # so the seed nodes match the ones drawn for the ACL run.
            seed = 1
            delta = 0.0
            S = truth[randperm(MersenneTwister(seed), n)[1:max(1, round(Int, 0.1 * n))]]
            L = SLQ.QHuberLoss(q, delta)
            curr_time = @elapsed (x_slq_degnorm, _, _) = SLQ.slq_diffusion(G, S, 0.05, kappa, 0.9, L, max_iters=10000000, epsilon=1.0e-8)
            total_time += curr_time
            # Second ranking vector: solution raised to the power q-1, scaled
            # by the degrees.
            x_slq = (x_slq_degnorm .^ (q - 1)) .* G.deg
            # Descending order of the scores (ascending order of the negated
            # scores).
            sorted_ids = sortperm(-1 * x_slq)
            sorted_ids_degnorm = sortperm(-1 * x_slq_degnorm)
            for (i, k) in enumerate(k_list)
                _, rc_slq = compute_pr_rc(sorted_ids[1:k], truth)
                rc_slq_records[string(q)][i, j] = rc_slq
                _, rc_slq_degnorm = compute_pr_rc(sorted_ids_degnorm[1:k], truth)
                rc_slq_degnorm_records[string(q)][i, j] = rc_slq_degnorm
                @show j,k,q,rc_slq,rc_slq_degnorm,curr_time,total_time
            end
        end
        jldopen("results/$dataset-slq.jld2", "a+") do file
            file[string(q)] = rc_slq_records[string(q)]
        end
        jldopen("results/$dataset-slq-degnorm.jld2", "a+") do file
            file[string(q)] = rc_slq_degnorm_records[string(q)]
        end
    end
end











# q_list = [16.0]

# for q in q_list
#     kappa = q_kappa_map[q]
#     for (j,cid) in enumerate(cids)
#         truth = C[:,cid].nzind
#         _,_,_,conductance = set_stats(A.*1.0,truth,sum(G.deg)*1.0)
#         n = length(truth)
#         seed = 1
#         delta = 0.0
#         S = truth[randperm(MersenneTwister(seed),n)[1:max(1,round(Int,0.1*n))]]
#         L = SLQ.QHuberLoss(q, delta)
#         (x_slq_degnorm,r,iter) = SLQ.slq_diffusion(G, S, 0.05, kappa, 0.9, L, max_iters=10000000,epsilon=1.0e-8)
#         x_slq = (x_slq_degnorm.^(q-1)).*G.deg
#         for (i,k) in enumerate(k_list)
#             cluster_slq = sortperm(-1*x_slq)[1:k]
#             pr_slq,rc_slq = compute_pr_rc(cluster_slq,truth)
#             rc_slq_records[string(q)][i,j] = rc_slq
#             cluster_slq_degnorm = sortperm(-1*x_slq_degnorm)[1:k]
#             pr_slq_degnorm,rc_slq_degnorm = compute_pr_rc(cluster_slq_degnorm,truth)
#             rc_slq_degnorm_records[string(q)][i,j] = rc_slq_degnorm
#             @show j,k,q,rc_slq,rc_slq_degnorm
#         end
#     end
#     jldopen("livejournal-slq.jld2", "a+") do file
#         file[string(q)] = rc_slq_records[string(q)]
#     end
#     jldopen("livejournal-slq-degnorm.jld2", "a+") do file
#         file[string(q)] = rc_slq_degnorm_records[string(q)]
#     end
# end











# Plotting setup for the DBLP data set: reload the data to rebuild k_list, then
# read the stored recall tables.
dataset = "dblp"
vars = matread("$dataset.mat")
A = Int.(vars["A"])
C = vars["C"]
# Community sizes (stored nonzeros per column of C) as a concrete Vector{Int}.
csize = Int[nnz(C[:, i]) for i in axes(C, 2)]
G = SLQ.graph(A)

c34 = round(Int, maximum(csize)^(3/4))
cids = sortperm(abs.(csize .- c34))[1:600]
# x-axis of the plot; must match the k_list used when the results were computed.
k_list = collect(1:1:round(Int, c34 / 2))
rc_slq_records = load("results/$dataset-slq.jld2")
rc_slq_records_degnorm = load("results/$dataset-slq-degnorm.jld2")
rc_acl_all = load("results/$dataset-acl.jld2", "acl")
# The degree-normalized ACL table is written under the key "acl-degnorm" by the
# experiment code in this file; the key "acl" is accepted as a fallback for
# result files that were written with that name.
rc_acl_all_degnorm = let saved = load("results/$dataset-acl-degnorm.jld2")
    haskey(saved, "acl-degnorm") ? saved["acl-degnorm"] : saved["acl"]
end

# One figure with, for each q, the median recall curve of SLQ and of its
# second ("normalized") ranking, each with a shaded band of two standard errors
# and a text label attached to the right end of the curve by a short black line.
fig, ax = subplots(1, 1, figsize=(5, 4))
q_list = [1.5, 4.0, 8.0]
# Hand-tuned vertical shifts of the labels, one per q.
offset = [-0.04, -0.02, 0.001]
offset_degnorm = [-0.065, -0.04, -0.02]
for (q, dy, dy_norm) in zip(q_list, offset, offset_degnorm)
    key = string(q)
    recalls = rc_slq_records[key]
    recalls_norm = rc_slq_records_degnorm[key]
    # Median over the communities (columns) for every k (row).
    med_slq = vec(median(recalls, dims=2))
    med_slq_degnorm = vec(median(recalls_norm, dims=2))

    curve = ax.plot(k_list, med_slq)
    half_band = 2 * (vec(std(recalls, dims=2)) ./ sqrt(600))
    ax.fill_between(k_list, med_slq - half_band, med_slq + half_band,
                    alpha=0.3, color=curve[1].get_color())

    curve = ax.plot(k_list, med_slq_degnorm)
    half_band = 2 * (vec(std(recalls_norm, dims=2)) ./ sqrt(600))
    ax.fill_between(k_list, med_slq_degnorm - half_band, med_slq_degnorm + half_band,
                    alpha=0.3, color=curve[1].get_color())

    # Labels and leader lines at the right edge of the plot.
    last_raw = med_slq[end]
    last_norm = med_slq_degnorm[end]
    ax.text(420, last_raw + dy, "SLQ (q=$q)", fontsize=14)
    ax.plot([405, 418], [last_raw, last_raw + dy + 0.01], color="k")
    ax.plot([405, 418], [last_norm, last_norm + dy_norm + 0.02], color="k")
    ax.text(420, last_norm + dy_norm + 0.015, "SLQ normalized (q=$q)", fontsize=14)
    @show med_slq[end]
end

# Use font size 18 for the major tick labels of both axes. `tick_params` is
# used instead of touching each tick's `label` attribute, which is an alias
# that newer Matplotlib releases no longer provide.
ax.tick_params(axis="both", which="major", labelsize=18)

# Add the two ACL curves (raw and degree-normalized) to the DBLP figure, each
# with a two-standard-error band and a labelled leader line, then hide the top
# and right spines and write the figure to disk.
med_acl = vec(median(rc_acl_all, dims=2))
med_acl_degnorm = vec(median(rc_acl_all_degnorm, dims=2))

p = ax.plot(k_list, med_acl)
stderror = vec(std(rc_acl_all, dims=2)) ./ sqrt(600)
ax.fill_between(k_list, med_acl .- 2 .* stderror, med_acl .+ 2 .* stderror,
                alpha=0.3, color=p[1].get_color())

p = ax.plot(k_list, med_acl_degnorm)
stderror = vec(std(rc_acl_all_degnorm, dims=2)) ./ sqrt(600)
ax.fill_between(k_list, med_acl_degnorm .- 2 .* stderror, med_acl_degnorm .+ 2 .* stderror,
                alpha=0.3, color=p[1].get_color())

# Labels at the right edge, connected to the curve ends by black lines.
ax.text(420, med_acl[end] - 0.04, "ACL", fontsize=14)
ax.plot([405, 418], [med_acl[end], med_acl[end] - 0.03], color="k")
ax.plot([405, 418], [med_acl_degnorm[end], med_acl_degnorm[end] - 0.02], color="k")
ax.text(420, med_acl_degnorm[end] - 0.035, "ACL normalized", fontsize=14)

for side in ("top", "right")
    ax.spines[side].set_visible(false)
end
fig.savefig("figures/600-communities-$dataset.pdf", format="pdf", bbox_inches="tight")



# Plotting setup for the LiveJournal data set: reload the data to rebuild
# k_list, then read the stored recall tables.
# The name is spelled in lower case, matching the name under which the
# experiment code in this file reads the data and writes its result files;
# a differently-cased name does not resolve on case-sensitive file systems.
# The figure saved later with "$dataset" in its path uses this spelling too.
dataset = "livejournal"
vars = matread("$dataset.mat")
A = Int.(vars["A"])
C = vars["C"]
# Community sizes (stored nonzeros per column of C) as a concrete Vector{Int}.
csize = Int[nnz(C[:, i]) for i in axes(C, 2)]
G = SLQ.graph(A)

c34 = round(Int, maximum(csize)^(3/4))
cids = sortperm(abs.(csize .- c34))[1:600]
# x-axis of the plot.
# NOTE(review): this uses a step of 10, while the experiment code earlier in
# this file currently builds its k_list with a step of 200; the stored tables
# must have been produced with the same step as the one used here.
k_list = collect(1:10:round(Int, c34 / 2))
rc_slq_records = load("results/$dataset-slq.jld2")
rc_slq_records_degnorm = load("results/$dataset-slq-degnorm.jld2")
rc_acl_all = load("results/$dataset-acl.jld2", "acl")
# The degree-normalized ACL table is written under the key "acl-degnorm" by the
# experiment code in this file; the key "acl" is accepted as a fallback for
# result files that were written with that name.
rc_acl_all_degnorm = let saved = load("results/$dataset-acl-degnorm.jld2")
    haskey(saved, "acl-degnorm") ? saved["acl-degnorm"] : saved["acl"]
end

# One figure with the median SLQ recall curves and their two-standard-error
# bands. In each pass one raw curve and one "normalized" curve are drawn; the
# q values for the two kinds are taken from separate lists, so the pair drawn
# in a given pass need not share the same q.
fig, ax = subplots(1, 1, figsize=(5, 4))
q_list = [1.5, 8.0, 4.0]
q_list_degnorm = [1.5, 4.0, 8.0]
# Hand-tuned vertical shifts of the labels, one per pass.
offset = [0.01, -0.01, 0.01]
offset_degnorm = [-0.03, -0.04, -0.025]
for (q, q_norm, dy, dy_norm) in zip(q_list, q_list_degnorm, offset, offset_degnorm)
    # Raw SLQ ranking for q.
    recalls = rc_slq_records[string(q)]
    med_slq = vec(median(recalls, dims=2))
    curve = ax.plot(k_list, med_slq)
    half_band = 2 * (vec(std(recalls, dims=2)) ./ sqrt(600))
    ax.fill_between(k_list, med_slq - half_band, med_slq + half_band,
                    alpha=0.3, color=curve[1].get_color())
    last_raw = med_slq[end]
    ax.text(4520, last_raw + dy, "SLQ (q=$q)", fontsize=14)
    ax.plot([4370, 4470], [last_raw, last_raw + dy + 0.01], color="k")
    @show med_slq[end]

    # "Normalized" SLQ ranking for q_norm.
    recalls_norm = rc_slq_records_degnorm[string(q_norm)]
    med_slq_degnorm = vec(median(recalls_norm, dims=2))
    curve = ax.plot(k_list, med_slq_degnorm)
    half_band = 2 * (vec(std(recalls_norm, dims=2)) ./ sqrt(600))
    ax.fill_between(k_list, med_slq_degnorm - half_band, med_slq_degnorm + half_band,
                    alpha=0.3, color=curve[1].get_color())
    last_norm = med_slq_degnorm[end]
    ax.plot([4370, 4470], [last_norm, last_norm + dy_norm + 0.025], color="k")
    ax.text(4520, last_norm + dy_norm + 0.015, "SLQ normalized (q=$q_norm)", fontsize=14)
    @show med_slq_degnorm[end]
end


# Use font size 18 for the major tick labels of both axes. `tick_params` is
# used instead of touching each tick's `label` attribute, which is an alias
# that newer Matplotlib releases no longer provide.
ax.tick_params(axis="both", which="major", labelsize=18)

# Add the two ACL curves (raw and degree-normalized) to the LiveJournal figure,
# each with a two-standard-error band and a labelled leader line, then hide the
# top and right spines and write the figure to disk.
med_acl = vec(median(rc_acl_all, dims=2))
med_acl_degnorm = vec(median(rc_acl_all_degnorm, dims=2))

p = ax.plot(k_list, med_acl)
stderror = vec(std(rc_acl_all, dims=2)) ./ sqrt(600)
ax.fill_between(k_list, med_acl - 2 * stderror, med_acl + 2 * stderror, alpha=0.3, color=p[1].get_color())

p = ax.plot(k_list, med_acl_degnorm)
stderror = vec(std(rc_acl_all_degnorm, dims=2)) ./ sqrt(600)
ax.fill_between(k_list, med_acl_degnorm - 2 * stderror, med_acl_degnorm + 2 * stderror, alpha=0.3, color=p[1].get_color())

# Labels at the right edge, connected to the curve ends by black lines.
ax.text(4520, med_acl[end] - 0.03, "ACL", fontsize=14)
ax.plot([4370, 4470], [med_acl[end], med_acl[end] - 0.02], color="k")
ax.plot([4370, 4470], [med_acl_degnorm[end], med_acl_degnorm[end] - 0.01], color="k")
# Same font size as every other curve label; a size of 1 left the leader line
# above pointing at text too small to read.
ax.text(4520, med_acl_degnorm[end] - 0.04, "ACL normalized", fontsize=14)

ax.spines["top"].set_visible(false)
ax.spines["right"].set_visible(false)

fig.savefig("figures/600-communities-$dataset.pdf", format="pdf", bbox_inches="tight")