# Parallel benchmark driver: runs seeded SLQ diffusions (several q values) and the
# ACL PageRank diffusion from samples of ground-truth clusters of one dataset, and
# saves the per-cluster result curves to JLD2 files.
using Distributed
# Start 100 local worker processes. This must precede the @everywhere lines below,
# because @everywhere only reaches processes that already exist.
addprocs(100);

@everywhere using LightGraphs
@everywhere using DelimitedFiles
@everywhere using SparseArrays
@everywhere using MAT
@everywhere using Random
@everywhere using JLD2,FileIO
@everywhere using MatrixNetworks
@everywhere using Statistics


# Project code needed on master and workers (used below: module PageRank, module SLQ,
# compute_pr_rc). NOTE(review): which file defines which name is inferred from usage
# in this script — confirm against the included files.
@everywhere include("PageRank.jl")
@everywhere include("SLQcvx.jl") # this includes SLQ.jl
@everywhere include("common.jl")
@everywhere include("FlowSeed-1.0.jl")

# Per-dataset step size for k_list, the candidate output-set sizes 1, 1+step, 1+2step, ...
@everywhere global offset = Dict("dblp"=>1,"liveJournal"=>10,"amazon"=>10)
# Active dataset: input is read from "$dataset.mat", results are saved as "$dataset-*.jld2".
@everywhere global dataset = "liveJournal"

@everywhere function worker_slq(jobs,results,vars)
    # Worker loop for the SLQ experiments.
    #
    # Rebuilds the graph and the selection of 600 ground-truth clusters from `vars`
    # (the Dict read from "$dataset.mat", holding "A" and "C"). It then repeatedly takes
    # `(q, index)` jobs from `jobs` and puts
    # `(q, index, curr_col, curr_col_degnorm, elapsed_seconds)` on `results`.
    # `curr_col[i]` is the second value returned by `compute_pr_rc` (named `rc`,
    # presumably recall) for the top-`k_list[i]` nodes of the diffusion ranking.
    # `q == -1` is the stop sentinel.
    #
    # This function is launched with `remote_do`, which does not propagate exceptions
    # to the caller. A job that threw would never produce a result, and the master
    # would block forever in `take!(results)`. Per-job failures are therefore logged
    # and reported as a NaN-filled result instead.
    A = Int.(vars["A"])
    C = vars["C"]
    # number of member nodes of each ground-truth cluster (column of C)
    csize = [nnz(C[:,i]) for i in 1:size(C,2)]
    G = SLQ.graph(A)
    # target cluster size: (largest cluster size)^(3/4)
    c34 = round(Int,maximum(csize)^(3/4))
    # the 600 clusters whose sizes are closest to c34; job `index` refers to cids[index]
    cids = sortperm(abs.(csize .- c34))[1:600]
    k_list = collect(1:offset[dataset]:round(Int,c34/2))
    nk = length(k_list)
    # kappa argument passed to SLQ.slq_diffusion for each supported q
    q_kappa_map = Dict(1.5 => 0.02, 4.0 => 0.001, 8.0 => 0.00001, 10.0 => 0.000001)
    while true
        q,index = take!(jobs)
        q == -1 && break
        curr_col = zeros(nk)
        curr_col_degnorm = zeros(nk)
        curr_time = 0.0
        try
            truth = C[:,cids[index]].nzind
            n = length(truth)
            # fixed-seed 10% sample (at least one node) of the cluster as the seed set
            S = truth[randperm(MersenneTwister(1),n)[1:max(1,round(Int,0.1*n))]]
            L = SLQ.QHuberLoss(q, 0.0)
            kappa = q_kappa_map[q]
            curr_time = @elapsed begin
                x_slq_degnorm, _, _ = SLQ.slq_diffusion(G, S, 0.05, kappa, 0.9, L;
                    max_iters=10000000, epsilon=1.0e-8)
            end
            x_slq = (x_slq_degnorm.^(q-1)).*G.deg
            # rank nodes by decreasing score
            sorted_ids = sortperm(-x_slq)
            sorted_ids_degnorm = sortperm(-x_slq_degnorm)
            for (i,k) in enumerate(k_list)
                _, curr_col[i] = compute_pr_rc(sorted_ids[1:k],truth)
                _, curr_col_degnorm[i] = compute_pr_rc(sorted_ids_degnorm[1:k],truth)
            end
        catch err
            err isa InterruptException && rethrow()
            @error "SLQ job failed; reporting NaN results" q index exception=(err, catch_backtrace())
            fill!(curr_col, NaN)
            fill!(curr_col_degnorm, NaN)
            curr_time = NaN
        end
        put!(results,(q,index,curr_col,curr_col_degnorm,curr_time))
    end
end




@everywhere function worker_acl(jobs,results,vars)
    # Worker loop for the ACL baseline.
    #
    # Rebuilds the graph and the selection of 600 ground-truth clusters from `vars`
    # (the Dict read from "$dataset.mat", holding "A" and "C"). It then repeatedly takes
    # a cluster `index` from `jobs` and puts
    # `(index, curr_col, curr_col_degnorm, elapsed_seconds)` on `results`.
    # `curr_col[i]` is the second value returned by `compute_pr_rc` (named `rc`,
    # presumably recall) for the top-`k_list[i]` nodes of the diffusion ranking.
    # `index == -1` is the stop sentinel.
    #
    # Launched via `remote_do`, which does not propagate exceptions to the caller. A job
    # that threw would never produce a result, and the master would block forever in
    # `take!(results)`. Per-job failures are therefore logged and reported as a
    # NaN-filled result.
    A = Int.(vars["A"])
    C = vars["C"]
    # number of member nodes of each ground-truth cluster (column of C)
    csize = [nnz(C[:,i]) for i in 1:size(C,2)]
    G = SLQ.graph(A)
    # target cluster size: (largest cluster size)^(3/4)
    c34 = round(Int,maximum(csize)^(3/4))
    # the 600 clusters whose sizes are closest to c34; job `index` refers to cids[index]
    cids = sortperm(abs.(csize .- c34))[1:600]
    k_list = collect(1:offset[dataset]:round(Int,c34/2))
    nk = length(k_list)
    while true
        index = take!(jobs)
        index == -1 && break
        curr_col = zeros(nk)
        curr_col_degnorm = zeros(nk)
        curr_time = 0.0
        try
            truth = C[:,cids[index]].nzind
            n = length(truth)
            # fixed-seed 10% sample (at least one node) of the cluster as the seed set
            S = truth[randperm(MersenneTwister(1),n)[1:max(1,round(Int,0.1*n))]]
            curr_time = @elapsed x_acl = PageRank.acl_diffusion(G,S,0.05,0.002)
            x_acl_degnorm = x_acl./G.deg
            # rank nodes by decreasing score
            sorted_ids = sortperm(-x_acl)
            sorted_ids_degnorm = sortperm(-x_acl_degnorm)
            for (i,k) in enumerate(k_list)
                _, curr_col[i] = compute_pr_rc(sorted_ids[1:k],truth)
                _, curr_col_degnorm[i] = compute_pr_rc(sorted_ids_degnorm[1:k],truth)
            end
        catch err
            err isa InterruptException && rethrow()
            @error "ACL job failed; reporting NaN results" index exception=(err, catch_backtrace())
            fill!(curr_col, NaN)
            fill!(curr_col_degnorm, NaN)
            curr_time = NaN
        end
        put!(results,(index,curr_col,curr_col_degnorm,curr_time))
    end
end



function make_jobs_slq(q_list,jobs)
    # Enqueue one `(q, cluster_index)` job for every q in `q_list` and every
    # index 1:600, grouped by q in `q_list` order. Then enqueue one `(-1.0, -1)`
    # stop sentinel per worker process.
    # Iterators.product varies its first argument fastest, so all 600 indices of
    # one q are emitted before moving on to the next q.
    for (cluster_index, q) in Iterators.product(1:600, q_list)
        put!(jobs, (q, cluster_index))
    end
    stop_signal = (-1.0, -1)
    for _ in workers()
        put!(jobs, stop_signal)
    end
    return nothing
end

function make_jobs_acl(jobs)
    # Enqueue cluster indices 1:600 in ascending order, followed by one `-1`
    # stop sentinel for each worker process.
    foreach(cluster_index -> put!(jobs, cluster_index), 1:600)
    foreach(_ -> put!(jobs, -1), workers())
    return nothing
end




function huge_graph_parallel_slq(q_list)
    # Run the SLQ experiment for every q in `q_list`, using all worker processes, on
    # the 600 ground-truth clusters of "$dataset.mat" whose sizes are closest to
    # (largest cluster size)^(3/4).
    #
    # Returns `(records, records_degnorm)`. These are Dicts keyed by `string(q)`, each
    # holding a `length(k_list) x 600` matrix whose column `index` is the `curr_col`
    # (resp. `curr_col_degnorm`) vector reported by `worker_slq` for job `index`.
    # Both Dicts also have a "time" entry: a `length(q_list) x 1` matrix of summed
    # per-job diffusion times for each q.
    #
    # The graph and the cluster selection are only needed by the workers (each
    # rebuilds them from `vars`), so the master does not build them. This avoids a
    # redundant full-graph construction on a huge graph.
    vars = matread("$dataset.mat")
    C = vars["C"]
    ncases = 600  # must match make_jobs_slq and the workers' `cids` selection
    nclusters = size(C,2)
    # Workers index the `ncases` selected clusters; fail here, before any job is
    # queued, instead of having every worker die on a BoundsError.
    nclusters >= ncases ||
        error("$dataset.mat has $nclusters clusters; at least $ncases are required")
    csize = [nnz(C[:,i]) for i in 1:nclusters]
    c34 = round(Int,maximum(csize)^(3/4))
    # same k range the workers build; only its length is needed here
    nk = length(1:offset[dataset]:round(Int,c34/2))
    nexps = length(q_list)*ncases
    # room for every job plus one stop sentinel per worker, so make_jobs_slq never blocks
    jobs = RemoteChannel(()->Channel{Tuple{Float64,Int64}}(nexps+length(workers())))
    results = RemoteChannel(()->Channel{Tuple{Float64,Int64,Array{Float64,1},Array{Float64,1},Float64}}(nexps))
    records = Dict(string(q)=>zeros(nk,ncases) for q in q_list)
    records["time"] = zeros(length(q_list),1)
    records_degnorm = Dict(string(q)=>zeros(nk,ncases) for q in q_list)
    records_degnorm["time"] = zeros(length(q_list),1)
    make_jobs_slq(q_list,jobs)
    for p in workers()
        remote_do(worker_slq,p,jobs,results,vars)
    end
    while nexps > 0 # wait for all jobs to finish
        q,index,curr_col,curr_col_degnorm,curr_time = take!(results)
        records[string(q)][:,index] = curr_col
        records_degnorm[string(q)][:,index] = curr_col_degnorm
        nexps = nexps - 1
        qpos = findfirst(==(q), q_list)
        records["time"][qpos] += curr_time
        records_degnorm["time"][qpos] += curr_time
        println("$nexps jobs left.")
    end
    return records,records_degnorm
end




function huge_graph_parallel_acl()
    # Run the ACL baseline, using all worker processes, on the 600 ground-truth
    # clusters of "$dataset.mat" whose sizes are closest to
    # (largest cluster size)^(3/4).
    #
    # Returns `(records, records_degnorm)`. Each is a Dict whose "acl" entry is a
    # `length(k_list) x 600` matrix; column `index` is the `curr_col`
    # (resp. `curr_col_degnorm`) vector reported by `worker_acl` for job `index`.
    # Each Dict's "time" entry is a 1x1 matrix holding the summed per-job
    # diffusion time.
    #
    # The graph and the cluster selection are only needed by the workers (each
    # rebuilds them from `vars`), so the master does not build them. This avoids a
    # redundant full-graph construction on a huge graph.
    vars = matread("$dataset.mat")
    C = vars["C"]
    ncases = 600  # must match make_jobs_acl and the workers' `cids` selection
    nclusters = size(C,2)
    # Workers index the `ncases` selected clusters; fail here, before any job is
    # queued, instead of having every worker die on a BoundsError.
    nclusters >= ncases ||
        error("$dataset.mat has $nclusters clusters; at least $ncases are required")
    csize = [nnz(C[:,i]) for i in 1:nclusters]
    c34 = round(Int,maximum(csize)^(3/4))
    # same k range the workers build; only its length is needed here
    nk = length(1:offset[dataset]:round(Int,c34/2))
    nexps = ncases
    # room for every job plus one stop sentinel per worker, so make_jobs_acl never blocks
    jobs = RemoteChannel(()->Channel{Int64}(nexps+length(workers())))
    results = RemoteChannel(()->Channel{Tuple{Int64,Array{Float64,1},Array{Float64,1},Float64}}(nexps))
    records = Dict("acl"=>zeros(nk,ncases))
    records["time"] = zeros(1,1)
    records_degnorm = Dict("acl"=>zeros(nk,ncases))
    records_degnorm["time"] = zeros(1,1)
    make_jobs_acl(jobs)
    for p in workers()
        remote_do(worker_acl,p,jobs,results,vars)
    end
    while nexps > 0 # wait for all jobs to finish
        index,curr_col,curr_col_degnorm,curr_time = take!(results)
        records["acl"][:,index] = curr_col
        records_degnorm["acl"][:,index] = curr_col_degnorm
        nexps = nexps - 1
        records["time"][1] += curr_time
        records_degnorm["time"][1] += curr_time
        println("$nexps jobs left.")
    end
    return records,records_degnorm
end


# Experiment 1: SLQ for each q, saving raw and degree-normalized results.
q_list = [1.5, 4.0, 8.0]
records, records_degnorm = huge_graph_parallel_slq(q_list)
for (path, recs) in ("$dataset-slq.jld2" => records, "$dataset-slq-degnorm.jld2" => records_degnorm)
    save(path, recs)
end

# Experiment 2: ACL baseline, saved the same way.
records, records_degnorm = huge_graph_parallel_acl()
for (path, recs) in ("$dataset-acl.jld2" => records, "$dataset-acl-degnorm.jld2" => records_degnorm)
    save(path, recs)
end