{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Include\n",
    "using LightGraphs, SparseArrays, SimpleWeightedGraphs\n",
    "using Statistics, BenchmarkTools, LinearAlgebra, ProgressMeter\n",
    "using Base.Threads, PhyloNetworks, StatsBase, Distributions\n",
    "using Base.GC, JLD2, FileIO, CSV, DataFrames\n",
    "using Random, NPZ, GraphRecipes, Plots, Laplacians"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "include(\"TreeRep.jl\")\n",
    "include(\"ConstructTree.jl\")\n",
    "include(\"LevelTree.jl\")\n",
    "include(\"NJ.jl\")\n",
    "include(\"Utilities.jl\")\n",
    "include(\"Visualize.jl\")\n",
    "include(\"TreeOpt.jl\")\n",
    "include(\"SparseRep.jl\")\n",
    "include(\"Bartal.jl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the interactive Plotly backend for every figure in this notebook.\n",
    "plotly()\n",
    "\n",
    "# Shared font settings: a larger font for axis titles, a smaller one for ticks and legends.\n",
    "# Splat `myfonts...` into a plot call's keyword arguments to apply them.\n",
    "font  = Plots.font(\"Helvetica\", 15)\n",
    "font2 = Plots.font(\"Helvetica\", 9)\n",
    "myfonts = Dict(\n",
    "    :guidefont  => font,\n",
    "    :xtickfont  => font2,\n",
    "    :ytickfont  => font2,\n",
    "    :legendfont => font2,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Visualize the Immunological Trees\n",
    "\n",
    "The first argument is the tree; the `W` keyword argument gives its edge weights."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw tree `R` with edge weights `wbfs`: the first 8 vertices are labelled with the animal\n",
    "# names, and the remaining 7 vertices get empty labels.\n",
    "# NOTE(review): `R` and `wbfs` are not defined in any cell above this one; they come from a\n",
    "# later cell run out of order. Confirm which tree/weights are intended.\n",
    "Visualize.visualize(R,W = wbfs, labels = [text(\"Dog\", :top, 20),text(\"Bear\", :left, 20),\n",
    "        text(\"Racoon\", :top, 20), text(\"Weasel\", :left, 20),text(\"Seal\", :bottom, 20),\n",
    "        text(\"Sea Lion\", :top, 20),\n",
    "        text(\"Cat\", :left, 20),\n",
    "        text(\"Monkey\", :right, 20),\n",
    "        \"\",\"\",\"\",\"\",\"\",\"\",\"\"])\n",
    "plot!(legend=:false,axis=:false)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw graph `g2` with edge weights `D5`, labelling the first 8 vertices with animal names.\n",
    "# NOTE(review): `g2` and `D5` are only assigned in cells further down (e.g. the NJ\n",
    "# reconstruction); this cell depends on out-of-order execution.\n",
    "Visualize.visualize(g2,W=D5, labels = [\"dog\",\"bear\",\"racoon\", \"weasel\",\"seal\",\"sea lion\",\"cat\",\"monkey     \",\n",
    "        \"\",\"\",\"\",\"\",\"\",\"\",\"\"])\n",
    "plot!(legend=:false,axis=:false)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Plotting optimization-based embeddings for the immunological data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 2-D PM embedding coordinates of the 8 animals, one row per animal (filled in below).\n",
    "P = zeros(Float64, 8, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# PM Coordinates\n",
    "# Hard-coded 2-D PM embedding of the Sarich data, one row per animal, in the order of the\n",
    "# labels used by the next plot: Dog, Bear, Racoon, Weasel, Seal, Sea Lion, Cat, Monkey.\n",
    "\n",
    "P[1,:] = [-0.6015,  -0.7989]\n",
    "P[2,:] = [0.8426,  0.5386]\n",
    "P[3,:] = [-0.8188,  0.5740]\n",
    "P[4,:] = [0.7942, -0.6047]\n",
    "P[5,:] = [-0.9950, -0.0819]\n",
    "P[6,:] = [-0.1758,  0.9827]\n",
    "P[7,:] = [0.0065, -0.9208]\n",
    "P[8,:] = [-0.0761, -0.1009]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": []
   },
   "outputs": [],
   "source": [
    "# Scatter the PM embedding, annotating every point with its animal name (right of the point).\n",
    "scatter(P[:, 1], P[:, 2];\n",
    "        series_annotations = map(name -> text(name, :right, 20),\n",
    "                                 [\"Dog\", \"Bear\", \"Racoon\", \"Weasel\", \"Seal\", \"Sea Lion\", \"Cat\", \"Monkey\"]),\n",
    "        legend = :false, ms = 1, axis = :false, myfonts...)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 2-D PT embedding coordinates of the 8 animals, one row per animal (filled in below).\n",
    "P2 = zeros(Float64, 8, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reading in the PT embeddings\n",
    "# Each row of T is presumably one point in hyperboloid coordinates (x0, x1, x2); the map\n",
    "# x -> x[2:end] / (1 + x0) sends it to the Poincaré disk. Every row of P2 is filled in,\n",
    "# not only the first, so the scatter below plots all 8 animals.\n",
    "# NOTE(review): assumes T is 8×3 with rows in the same animal order as P — confirm.\n",
    "T = npzread(\"./sarich.final_coordinates0.npy\")\n",
    "for k in 1:size(P2, 1)\n",
    "    P2[k, :] = T[k, 2:end] ./ (1 + T[k, 1])\n",
    "end\n",
    "P2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scatter the PT embedding; each animal name gets its own label position.\n",
    "scatter(P2[:, 1], P2[:, 2];\n",
    "        series_annotations = map((name, pos) -> text(name, pos, 20),\n",
    "                                 [\"Dog\", \"Bear\", \"Racoon\", \"Weasel\", \"Seal\", \"Sea Lion\", \"Cat\", \"Monkey\"],\n",
    "                                 [:right, :right, :top, :bottom, :right, :right, :right, :right]),\n",
    "        legend = :false, axis = :false)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Random tree reconstruction experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 7\n",
    "\n",
    "# Per-size results: times (T1 = TreeRep, T2 = NJ), MAP ratios, average distortions, and\n",
    "# vertex counts (nvs = input tree, nvs1 = TreeRep output, nvs2 = NJ output).\n",
    "T1, T2, map1, map2, dist1, dist2, nvs1, nvs2, nvs = (zeros(N) for _ in 1:9)\n",
    "nvs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": [
     0
    ],
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# For each size parameter, build a random weighted tree, reconstruct it from its distance\n",
    "# matrix with TreeRep and with NJ, and record time, vertex counts, MAP ratio (relative to the\n",
    "# true tree) and average distortion. Iterates over the result arrays' indices so the loop\n",
    "# bound cannot drift from their allocated size.\n",
    "for i in eachindex(T1)\n",
    "    # Random tree: BFS tree (from a random root) of utilities.block applied to a double binary tree.\n",
    "    g = utilities.block(LightGraphs.SimpleGraphs.DoubleBinaryTree(i),10)\n",
    "    n = nv(g)\n",
    "    g = LightGraphs.bfs_tree(g,rand(1:n))\n",
    "    G = SimpleGraph(n)\n",
    "    for e in edges(g)\n",
    "        add_edge!(G,e)\n",
    "    end\n",
    "    @show(G)\n",
    "    # Symmetric random edge weights, restricted to the tree's edges.\n",
    "    W = rand(n,n)\n",
    "    W = W+W'\n",
    "    W = adjacency_matrix(G) .* W\n",
    "    D = utilities.parallel_dp_shortest_paths(G,W);\n",
    "\n",
    "    # TreeRep reconstruction.\n",
    "    T1[i] = @elapsed G2,W2 = TreeRep.metric_to_structure(D,undef,undef);\n",
    "    nvs[i] = nv(G)\n",
    "\n",
    "    B = W2[1:nv(G2),1:nv(G2)];\n",
    "    B = sparse(B);\n",
    "    B = (B .> 0) .* B;   # zero out non-positive edge weights\n",
    "    D2 = utilities.parallel_dp_shortest_paths(G2, B, false);\n",
    "    for v = 1:n\n",
    "        D2[v,v] = 0\n",
    "    end\n",
    "    nvs1[i] = nv(G2)\n",
    "    map1[i] = utilities.MAP(D2[1:n,1:n],G)/utilities.MAP(D,G)\n",
    "    dist1[i] = utilities.avg_distortion(D2[1:n,1:n],D)\n",
    "\n",
    "    # NJ reconstruction: turn its edge list into a graph plus a symmetric weight matrix.\n",
    "    T2[i] = @elapsed R =  NJ.nj!(copy(convert(Matrix{Float64},D)));\n",
    "    g2 = SimpleGraph(R.numNodes)\n",
    "    w = spzeros(R.numNodes,R.numNodes)\n",
    "    for k = 1:R.numEdges\n",
    "        src = R.edge[k].node[1].number\n",
    "        dst = R.edge[k].node[2].number\n",
    "        add_edge!(g2,src,dst)\n",
    "        w[src,dst] = R.edge[k].length\n",
    "        w[dst,src] = w[src,dst]\n",
    "    end\n",
    "    nvs2[i] = nv(g2)\n",
    "\n",
    "    D5 = utilities.parallel_dp_shortest_paths(g2, w)\n",
    "    dist2[i] = utilities.avg_distortion(D5[1:n,1:n],D);\n",
    "    map2[i] = utilities.MAP(D5[1:n,1:n],G)/utilities.MAP(D,G)\n",
    "end\n",
    "\n",
    "@show(T1,T2,map1,map2,dist1,dist2,nvs,nvs1,nvs2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Running time of TreeRep vs NJ against the number of vertices in the input tree.\n",
    "plot(nvs, T1; label = \"TreeRep\", lc = :blue, linewidth = 2,\n",
    "     xlabel = \"Nodes\", ylabel = \"Time taken (Seconds)\", myfonts...)\n",
    "scatter!(nvs, T1; label = \"TreeRep\", mc = :blue, shape = :circle,\n",
    "         xlabel = \"Nodes\", ylabel = \"Time taken (Seconds)\", myfonts...)\n",
    "plot!(nvs, T2; label = \"NJ\", lc = :red, linewidth = 2, myfonts...)\n",
    "scatter!(nvs, T2; label = \"NJ\", mc = :red, shape = :xcross, myfonts...)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Vertices in each reconstructed tree divided by vertices in the original tree.\n",
    "plot(nvs, nvs1 ./ nvs; label = \"TreeRep\", xlabel = \"Nodes\",\n",
    "     ylabel = \"Ratio of number of nodes in returned tree to original tree\")\n",
    "plot!(nvs, nvs2 ./ nvs; label = \"NJ\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Random Points"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": [
     11,
     15,
     35
    ],
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Random point sets with fixed dimension 10 and scale 2^i, i = 1..N.\n",
    "# Each distance matrix is read from CSV and embedded into a tree by every method; the\n",
    "# average distortion (after a least-squares scale fit α, except for NJ) goes into dist1..dist7.\n",
    "# NOTE: `build_graph` is defined in a later cell of this notebook; run that cell first.\n",
    "N = 10\n",
    "dist1 = zeros(10)\n",
    "dist2 = zeros(10)\n",
    "dist3 = zeros(10)\n",
    "dist4 = zeros(10)\n",
    "dist5 = zeros(10)\n",
    "dist6 = zeros(10)\n",
    "dist7 = zeros(10)\n",
    "\n",
    "n = 100\n",
    "\n",
    "for i = 1:N\n",
    "    sc = 2^i\n",
    "    # Rebuild the symmetric distance matrix from the (id1, id2, weight) edge list.\n",
    "    Z = CSV.read(\"rand-dim10-scale$sc.csv\")\n",
    "    D = zeros(100,100)\n",
    "    for k = 1:50*99\n",
    "        ii = Z[k,1]\n",
    "        j = Z[k,2]\n",
    "        D[ii,j] = Z[k,3]\n",
    "        D[j,ii] = Z[k,3]\n",
    "    end\n",
    "    n = 100\n",
    "    #D = utilities.rand_hyperbolic(n,10,2^i)\n",
    "\n",
    "    # TreeRep\n",
    "    @time G2, W2 = TreeRep.metric_to_structure(D,undef,undef)\n",
    "    B = W2[1:nv(G2),1:nv(G2)];\n",
    "    B = sparse(B);\n",
    "    B = (B .> 0) .* B;\n",
    "    D2 = utilities.parallel_dp_shortest_paths(G2, B, false);\n",
    "    α = tr(D2[1:n,1:n]'*D)/tr(D2[1:n,1:n]'*D2[1:n,1:n])\n",
    "    dist1[i] = utilities.avg_distortion(D2[1:n,1:n]*α,D)\n",
    "\n",
    "    # NJ (distortion reported without the α rescaling used for the other methods)\n",
    "    @time R = NJ.nj!(copy(convert(Matrix{Float64},D)))\n",
    "    g2 = SimpleGraph(R.numNodes)\n",
    "    w = spzeros(R.numNodes,R.numNodes)\n",
    "    for k = 1:R.numEdges\n",
    "        src = R.edge[k].node[1].number\n",
    "        dst = R.edge[k].node[2].number\n",
    "        add_edge!(g2,src,dst)\n",
    "        w[src,dst] = R.edge[k].length\n",
    "        w[dst,src] = w[src,dst]\n",
    "    end\n",
    "    D5 = utilities.parallel_dp_shortest_paths(g2, w, false);\n",
    "    dist2[i] = utilities.avg_distortion(D5[1:n,1:n],D)\n",
    "\n",
    "    # ConstructTree\n",
    "    @time T,W4 = ConstructTree.basicConstructTree(collect(2:n),1,D)\n",
    "    D4 = LightGraphs.floyd_warshall_shortest_paths(T,W4[1:nv(T),1:nv(T)]).dists;\n",
    "    α = tr(D4[1:n,1:n]'*D)/tr(D4[1:n,1:n]'*D4[1:n,1:n])\n",
    "    dist3[i] = utilities.avg_distortion(D4[1:n,1:n]*α,D)\n",
    "\n",
    "    # MST of the kNN graph (k = 10)\n",
    "    g = utilities.kNN(D,10)\n",
    "    @time r = LightGraphs.prim_mst(g,D)\n",
    "    R = SimpleGraph(n)\n",
    "    for e in r\n",
    "        add_edge!(R,e)\n",
    "    end\n",
    "    wbfs =  adjacency_matrix(R) .* D\n",
    "    D3 = utilities.parallel_dp_shortest_paths(R,wbfs);\n",
    "    α = tr(D3[1:n,1:n]'*D)/tr(D3[1:n,1:n]'*D3[1:n,1:n])\n",
    "    dist4[i] = utilities.avg_distortion(D3[1:n,1:n]*α,D)\n",
    "\n",
    "    # MST of the complete graph\n",
    "    g = CompleteGraph(n)\n",
    "    @time r = LightGraphs.prim_mst(g,D)\n",
    "    R = SimpleGraph(n)\n",
    "    for e in r\n",
    "        add_edge!(R,e)\n",
    "    end\n",
    "    wbfs =  adjacency_matrix(R) .* D\n",
    "    D3 = utilities.parallel_dp_shortest_paths(R,wbfs);\n",
    "    α = tr(D3[1:n,1:n]'*D)/tr(D3[1:n,1:n]'*D3[1:n,1:n])\n",
    "    dist5[i] = utilities.avg_distortion(D3[1:n,1:n]*α,D)\n",
    "\n",
    "    # AKPW on the kNN graph (k = 10)\n",
    "    g = utilities.kNN(D,10)\n",
    "    A = adjacency_matrix(g) .* D\n",
    "    @time R = Laplacians.akpw(A);\n",
    "    g2 = build_graph(R)\n",
    "    @time D6 = LightGraphs.floyd_warshall_shortest_paths(g2, R).dists;\n",
    "    # Fit α for this tree; reusing the MST's α would mis-scale the AKPW distances.\n",
    "    α = tr(D6[1:n,1:n]'*D)/tr(D6[1:n,1:n]'*D6[1:n,1:n])\n",
    "    dist6[i] = utilities.avg_distortion(D6[1:n,1:n]*α,D)\n",
    "\n",
    "    # AKPW on the complete graph\n",
    "    g = CompleteGraph(n)\n",
    "    A = adjacency_matrix(g) .* D\n",
    "    @time R = Laplacians.akpw(A);\n",
    "    g2 = build_graph(R)\n",
    "    @time D6 = LightGraphs.floyd_warshall_shortest_paths(g2, R).dists;\n",
    "    α = tr(D6[1:n,1:n]'*D)/tr(D6[1:n,1:n]'*D6[1:n,1:n])\n",
    "    dist7[i] = utilities.avg_distortion(D6[1:n,1:n]*α,D)\n",
    "\n",
    "    # Write the distances back out as an (id1, id2, weight) edge list of the complete graph.\n",
    "    g = CompleteGraph(n)\n",
    "    E = ne(g)\n",
    "\n",
    "    id1 = convert(Array{Int64,1},zeros(E))\n",
    "    id2 = convert(Array{Int64,1},zeros(E))\n",
    "    weight = zeros(E)\n",
    "\n",
    "    Ed = collect(edges(g))\n",
    "\n",
    "    for k = 1:E\n",
    "        e = Ed[k]\n",
    "        id1[k] = e.src\n",
    "        id2[k] = e.dst\n",
    "        weight[k] = D[e.src,e.dst]\n",
    "    end\n",
    "\n",
    "    df = DataFrame(id1 = id1, id2 = id2, weight = weight)\n",
    "    scale = 2^i\n",
    "    dim = 10\n",
    "    CSV.write(\"rand-dim$dim-scale$scale.csv\",  df, writeheader=true)\n",
    "\n",
    "    @show((dist1[i],dist2[i]))\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average distortion per method for the scale experiment (dimension fixed at 10, scale 2^x).\n",
    "# The x-axis is therefore the scale, not the dimension.\n",
    "# NOTE: dist6LM and dist8PT are loaded from JLD2 files in cells further down.\n",
    "plot(collect(1:10),dist1,label=\"TreeRep\", lc = :blue, linewidth = 2)\n",
    "plot!(collect(1:10),dist2,label=\"NJ\", lc = :red, linewidth = 2)\n",
    "plot!(collect(1:10),dist3,label=\"ConstructTree\", lc = :green, linewidth = 2)\n",
    "plot!(collect(1:10),dist5,label=\"MST : Complete\", lc = :indigo, linewidth = 2)\n",
    "plot!(collect(1:10),dist4,label=\"MST : KNN 10\", lc = :darkorange, linewidth = 2)\n",
    "plot!(collect(1:10),dist6,label=\"AKPW : KNN 10\", lc = :pink, linewidth = 2)\n",
    "plot!(collect(1:10),dist7,label=\"AKPW : Complete\", lc = :black, linewidth = 2)\n",
    "plot!(collect(1:10),dist6LM,label=\"Lorentz Maps\", lc = :yellow, linewidth = 2)\n",
    "plot!(collect(1:10),dist8PT,label=\"PT\", lc = :gray, linewidth = 2)\n",
    "scatter!(collect(1:10),dist1,label=\"TreeRep\", shape=:circle, mc = :blue)\n",
    "scatter!(collect(1:10),dist2,label=\"NJ\",shape=:cross, mc = :red)\n",
    "scatter!(collect(1:10),dist3,label=\"ConstructTree\", shape =:xcross, mc = :green)\n",
    "scatter!(collect(1:10),dist4,label=\"MST : KNN 10\", shape =:vline, mc = :darkorange)\n",
    "scatter!(collect(1:10),dist5,label=\"MST : Complete\", shape =:diamond, mc = :indigo)\n",
    "scatter!(collect(1:10),dist6,label=\"AKPW : KNN 10\", shape = :square, mc = :pink)\n",
    "scatter!(collect(1:10),dist7,label=\"AKPW : Complete\", shape = :dtriangle, mc = :black)\n",
    "scatter!(collect(1:10),dist6LM,label=\"Lorentz Maps\", shape = :hexagon, mc = :yellow)\n",
    "scatter!(collect(1:10),dist8PT,label=\"PT\", shape = :triangle, mc = :gray)\n",
    "plot!(xlabel = \"Scale (Log scale)\",ylabel = \"Average Distortion (Log scale)\",yscale=:log; myfonts...)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": [
     11
    ]
   },
   "outputs": [],
   "source": [
    "# Random point sets with fixed scale 1 and dimension 2^i, i = 1..N.\n",
    "# Each distance matrix is read from CSV and embedded into a tree by every method; the\n",
    "# average distortion (after a least-squares scale fit α, except for NJ) goes into dist1..dist7.\n",
    "# NOTE: `build_graph` is defined in a later cell of this notebook; run that cell first.\n",
    "N = 10\n",
    "dist1 = zeros(10)\n",
    "dist2 = zeros(10)\n",
    "dist3 = zeros(10)\n",
    "dist4 = zeros(10)\n",
    "dist5 = zeros(10)\n",
    "dist6 = zeros(10)\n",
    "dist7 = zeros(10)\n",
    "\n",
    "n = 100\n",
    "\n",
    "for i = 1:N\n",
    "    sc = 2^i\n",
    "    # Rebuild the symmetric distance matrix from the (id1, id2, weight) edge list.\n",
    "    Z = CSV.read(\"rand-dim$sc-scale1.csv\")\n",
    "    D = zeros(100,100)\n",
    "    for k = 1:50*99\n",
    "        ii = Z[k,1]\n",
    "        j = Z[k,2]\n",
    "        D[ii,j] = Z[k,3]\n",
    "        D[j,ii] = Z[k,3]\n",
    "    end\n",
    "    n = 100\n",
    "    #D = utilities.rand_hyperbolic(n,2^i,1)\n",
    "\n",
    "    # TreeRep\n",
    "    G2, W2 = TreeRep.metric_to_structure(D,undef,undef)\n",
    "    B = W2[1:nv(G2),1:nv(G2)];\n",
    "    B = sparse(B);\n",
    "    B = (B .> 0) .* B;\n",
    "    D2 = utilities.parallel_dp_shortest_paths(G2, B, false);\n",
    "    α = tr(D2[1:n,1:n]'*D)/tr(D2[1:n,1:n]'*D2[1:n,1:n])\n",
    "    dist1[i] = utilities.avg_distortion(D2[1:n,1:n]*α,D)\n",
    "\n",
    "    # NJ (distortion reported without the α rescaling used for the other methods)\n",
    "    R = NJ.nj!(copy(convert(Matrix{Float64},D)))\n",
    "    g2 = SimpleGraph(R.numNodes)\n",
    "    w = spzeros(R.numNodes,R.numNodes)\n",
    "    for k = 1:R.numEdges\n",
    "        src = R.edge[k].node[1].number\n",
    "        dst = R.edge[k].node[2].number\n",
    "        add_edge!(g2,src,dst)\n",
    "        w[src,dst] = R.edge[k].length\n",
    "        w[dst,src] = w[src,dst]\n",
    "    end\n",
    "    D5 = utilities.parallel_dp_shortest_paths(g2, w, false);\n",
    "    dist2[i] = utilities.avg_distortion(D5[1:n,1:n],D)\n",
    "\n",
    "    # ConstructTree\n",
    "    T,W4 = ConstructTree.basicConstructTree(collect(2:n),1,D)\n",
    "    D4 = LightGraphs.floyd_warshall_shortest_paths(T,W4[1:nv(T),1:nv(T)]).dists;\n",
    "    α = tr(D4[1:n,1:n]'*D)/tr(D4[1:n,1:n]'*D4[1:n,1:n])\n",
    "    dist3[i] = utilities.avg_distortion(D4[1:n,1:n]*α,D)\n",
    "\n",
    "    # MST of the kNN graph (k = 10)\n",
    "    g = utilities.kNN(D,10)\n",
    "    r = LightGraphs.prim_mst(g,D)\n",
    "    R = SimpleGraph(n)\n",
    "    for e in r\n",
    "        add_edge!(R,e)\n",
    "    end\n",
    "    wbfs =  adjacency_matrix(R) .* D\n",
    "    D3 = utilities.parallel_dp_shortest_paths(R,wbfs);\n",
    "    α = tr(D3[1:n,1:n]'*D)/tr(D3[1:n,1:n]'*D3[1:n,1:n])\n",
    "    dist4[i] = utilities.avg_distortion(D3[1:n,1:n]*α,D)\n",
    "\n",
    "    # MST of the complete graph\n",
    "    g = CompleteGraph(n)\n",
    "    r = LightGraphs.prim_mst(g,D)\n",
    "    R = SimpleGraph(n)\n",
    "    for e in r\n",
    "        add_edge!(R,e)\n",
    "    end\n",
    "    wbfs =  adjacency_matrix(R) .* D\n",
    "    D3 = utilities.parallel_dp_shortest_paths(R,wbfs);\n",
    "    α = tr(D3[1:n,1:n]'*D)/tr(D3[1:n,1:n]'*D3[1:n,1:n])\n",
    "    dist5[i] = utilities.avg_distortion(D3[1:n,1:n]*α,D)\n",
    "\n",
    "    # AKPW on the kNN graph (k = 10)\n",
    "    g = utilities.kNN(D,10)\n",
    "    A = adjacency_matrix(g) .* D\n",
    "    @time R = Laplacians.akpw(A);\n",
    "    g2 = build_graph(R)\n",
    "    @time D6 = LightGraphs.floyd_warshall_shortest_paths(g2, R).dists;\n",
    "    # Fit α for this tree; reusing the MST's α would mis-scale the AKPW distances.\n",
    "    α = tr(D6[1:n,1:n]'*D)/tr(D6[1:n,1:n]'*D6[1:n,1:n])\n",
    "    dist6[i] = utilities.avg_distortion(D6[1:n,1:n]*α,D)\n",
    "\n",
    "    # AKPW on the complete graph\n",
    "    g = CompleteGraph(n)\n",
    "    A = adjacency_matrix(g) .* D\n",
    "    @time R = Laplacians.akpw(A);\n",
    "    g2 = build_graph(R)\n",
    "    @time D6 = LightGraphs.floyd_warshall_shortest_paths(g2, R).dists;\n",
    "    α = tr(D6[1:n,1:n]'*D)/tr(D6[1:n,1:n]'*D6[1:n,1:n])\n",
    "    dist7[i] = utilities.avg_distortion(D6[1:n,1:n]*α,D)\n",
    "\n",
    "    # Write the distances back out as an (id1, id2, weight) edge list of the complete graph.\n",
    "    g = CompleteGraph(n)\n",
    "    E = ne(g)\n",
    "\n",
    "    id1 = convert(Array{Int64,1},zeros(E))\n",
    "    id2 = convert(Array{Int64,1},zeros(E))\n",
    "    weight = zeros(E)\n",
    "\n",
    "    Ed = collect(edges(g))\n",
    "\n",
    "    for k = 1:E\n",
    "        e = Ed[k]\n",
    "        id1[k] = e.src\n",
    "        id2[k] = e.dst\n",
    "        weight[k] = D[e.src,e.dst]\n",
    "    end\n",
    "\n",
    "    df = DataFrame(id1 = id1, id2 = id2, weight = weight)\n",
    "    scale = 1\n",
    "    dim = 2^i\n",
    "    CSV.write(\"rand-dim$dim-scale$scale.csv\",  df, writeheader=true)\n",
    "\n",
    "    @show((dist1[i],dist2[i]))\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average distortion per method for the dimension experiment (scale fixed at 1, dimension 2^x).\n",
    "# The x-axis is therefore the dimension, not the scale.\n",
    "# NOTE: dist7LM and dist9PT are loaded from JLD2 files in cells further down.\n",
    "plot(collect(1:10),dist1,label=\"TreeRep\", lc = :blue, linewidth = 2)\n",
    "plot!(collect(1:10),dist2,label=\"NJ\", lc = :red, linewidth = 2)\n",
    "plot!(collect(1:10),dist3,label=\"ConstructTree\", lc = :green, linewidth = 2)\n",
    "plot!(collect(1:10),dist5,label=\"MST : Complete\", lc = :indigo, linewidth = 2)\n",
    "plot!(collect(1:10),dist4,label=\"MST : KNN 10\", lc = :darkorange, linewidth = 2)\n",
    "plot!(collect(1:10),dist6,label=\"AKPW : KNN 10\", lc = :pink, linewidth = 2)\n",
    "plot!(collect(1:10),dist7,label=\"AKPW : Complete\", lc = :black, linewidth = 2)\n",
    "plot!(collect(1:10),dist7LM,label=\"Lorentz Maps\", lc = :yellow, linewidth = 2)\n",
    "plot!(collect(1:10),dist9PT,label=\"PT\", lc = :gray, linewidth = 2)\n",
    "scatter!(collect(1:10),dist1,label=\"TreeRep\", shape=:circle, mc = :blue)\n",
    "scatter!(collect(1:10),dist2,label=\"NJ\",shape=:cross, mc = :red)\n",
    "scatter!(collect(1:10),dist3,label=\"ConstructTree\", shape =:xcross, mc = :green)\n",
    "scatter!(collect(1:10),dist4,label=\"MST : KNN 10\", shape =:vline, mc = :darkorange)\n",
    "scatter!(collect(1:10),dist5,label=\"MST : Complete\", shape =:diamond, mc = :indigo)\n",
    "scatter!(collect(1:10),dist6,label=\"AKPW : KNN 10\", shape = :square, mc = :pink)\n",
    "scatter!(collect(1:10),dist7,label=\"AKPW : Complete\", shape = :dtriangle, mc = :black)\n",
    "scatter!(collect(1:10),dist7LM,label=\"Lorentz Maps\", shape = :hexagon, mc = :yellow)\n",
    "scatter!(collect(1:10),dist9PT,label=\"PT\", shape = :triangle, mc = :gray)\n",
    "plot!(xlabel = \"Dimension (Log scale)\",ylabel = \"Average Distortion (Log scale)\",yscale=:log; myfonts...)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lorentz Maps average distortions, read from variable `dist6` in dist6.jld2.\n",
    "dist6LM = load(\"dist6.jld2\", \"dist6\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lorentz Maps average distortions, read from variable `dist7` in dist7.jld2.\n",
    "dist7LM = load(\"dist7.jld2\", \"dist7\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# PT average distortions, read from variable `dist8` in dist8.jld2.\n",
    "dist8PT = load(\"dist8.jld2\", \"dist8\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# PT average distortions, read from variable `dist9` in dist9.jld2.\n",
    "dist9PT = load(\"dist9.jld2\", \"dist9\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Making CSV files as inputs for the optimization-based methods"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Random point set of size n; the remaining rand_hyperbolic arguments are presumably\n",
    "# (dimension, scale) = (10, 1000) — confirm against utilities.rand_hyperbolic.\n",
    "n = 1000\n",
    "D = utilities.rand_hyperbolic(n, 10, 1000);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Complete graph on the n points: one edge per unordered pair.\n",
    "g = CompleteGraph(n)\n",
    "E = ne(g)  # n*(n-1)/2 edges"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Edge list of the complete graph: endpoints (id1, id2) and the distance D between them.\n",
    "Ed = collect(edges(g))\n",
    "id1 = Int64[e.src for e in Ed]\n",
    "id2 = Int64[e.dst for e in Ed]\n",
    "weight = Float64[D[e.src, e.dst] for e in Ed];"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the edge list with a header row (id1, id2, weight).\n",
    "df = DataFrame(id1 = id1, id2 = id2, weight = weight)\n",
    "CSV.write(\"rand1000.csv\", df; writeheader = true)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Zeisel and CBMC\n",
    "\n",
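    "`cite-dists.jld2` is the CBMC data file and `zeisel-dists.jld2` is the Zeisel data file. Both store the matrix as `A2`, so load only the one you want."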
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Zeisel pairwise-distance matrix, stored in the file as `A2`.\n",
    "@load \"zeisel-dists.jld2\" A2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CBMC pairwise-distance matrix; it is also stored as `A2`, so it replaces the Zeisel matrix if both cells run.\n",
    "@load \"cite-dists.jld2\" A2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the loaded matrix as the metric; G is the complete graph on its points.\n",
    "D = A2\n",
    "n = size(D, 1)\n",
    "G = CompleteGraph(n)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create the edge-list data files used by the optimization-based methods (PM, LM, and PT). CBMC is too large to run these optimization methods on."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the complete-graph edge list of D (id1, id2, weight) with a header row.\n",
    "# NOTE(review): the output is named sarich.csv, although D here is the Zeisel/CBMC matrix\n",
    "# loaded above — confirm the intended file name.\n",
    "g = CompleteGraph(size(D, 1))\n",
    "E = ne(g)\n",
    "\n",
    "Ed = collect(edges(g))\n",
    "id1 = Int64[e.src for e in Ed]\n",
    "id2 = Int64[e.dst for e in Ed]\n",
    "weight = Float64[D[e.src, e.dst] for e in Ed]\n",
    "\n",
    "df = DataFrame(id1 = id1, id2 = id2, weight = weight)\n",
    "CSV.write(\"sarich.csv\", df; writeheader = true)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-save the CSV edge list as a space-delimited .edges file without a header row.\n",
    "filename = \"sarich.\"\n",
    "CSV.write(string(filename, \"edges\"), CSV.read(string(filename, \"csv\"));\n",
    "          delim = ' ', writeheader = false)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Export tree G2 (weights W2, presumably from TreeRep) as a 0-indexed, space-separated\n",
    "# edge list. Weights are divided by 100 and negative weights are clamped to 0.\n",
    "open(\"zeisel-tree.edges\", \"w\") do io\n",
    "    for e in edges(G2)\n",
    "        w = W2[e.src, e.dst] / 100\n",
    "        w < 0 && (w = 0)\n",
    "        write(io, \"$(e.src - 1) $(e.dst - 1) $w\\n\")\n",
    "    end\n",
    "end"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load Data Sets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is the Sarich et al. immunological data set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sarich et al. pairwise distances between 8 animals (symmetric, zero diagonal).\n",
    "# NOTE(review): row order is presumably Dog, Bear, Racoon, Weasel, Seal, Sea Lion, Cat,\n",
    "# Monkey, matching the plot labels above — confirm against the original data source.\n",
    "D = sparse([ 0  32  48  51  50  48  98 148;\n",
    "32   0  26  34  29  33  84 136;\n",
    "48  26   0  42  44  44  92 152;\n",
    "51  34  42   0  44  38  86 142;\n",
    "50  29  44  44   0  24  89 142;\n",
    "48  33  44  38  24   0  90 142;\n",
    "98  84  92  86  89  90   0 148;\n",
    "148 136 152 142 142 142 148 0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": [
     0
    ]
   },
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "    build_graph(A)\n",
    "\n",
    "Return a `SimpleGraph` on `size(A, 1)` vertices with an edge {i, j} for every nonzero\n",
    "`A[i, j]` with `i > j` (only the strict lower triangle of `A` is inspected).\n",
    "\n",
    "For sparse `A` only the stored entries are visited, so the cost is O(nnz) rather than\n",
    "O(n^2) element lookups; dense matrices are scanned entry by entry.\n",
    "\"\"\"\n",
    "function build_graph(A)\n",
    "    n = size(A, 1)\n",
    "    g = SimpleGraph(n)\n",
    "    if issparse(A)\n",
    "        rows, cols, vals = findnz(A)\n",
    "        for (i, j, v) in zip(rows, cols, vals)\n",
    "            # findnz may report explicitly stored zeros, so test the value too.\n",
    "            if i > j && v != 0\n",
    "                add_edge!(g, i, j)\n",
    "            end\n",
    "        end\n",
    "    else\n",
    "        for i = 1:n, j = 1:i-1\n",
    "            if A[i, j] != 0\n",
    "                add_edge!(g, i, j)\n",
    "            end\n",
    "        end\n",
    "    end\n",
    "    return g\n",
    "end"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load the data sets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Weighted, space-delimited edge list of the bio-CE-GT graph; report its size and connectivity.\n",
    "G, w = utilities.read_tree_withweights(\"./../hyperbolics-master/data/edges/bio-CE-GT.edges\", \" \")\n",
    "n, E = nv(G), ne(G)\n",
    "@show((n, E));\n",
    "@show(is_connected(G));"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Space-delimited WordNet edge list, passed through utilities.remove_loops; report size and connectivity.\n",
    "G = utilities.remove_loops(utilities.read_tree(\"./../hyperbolics-master/data/edges/wordnet.edges\", \" \"))\n",
    "n, E = nv(G), ne(G)\n",
    "@show((n, E));\n",
    "@show(is_connected(G));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Extract largest connected component."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Connected components of G and the index of the largest one.\n",
    "C = connected_components(G)\n",
    "idxmax = argmax(map(length, C))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restrict to the largest component; V maps the new vertex ids back to the ids in G.\n",
    "g, V = induced_subgraph(G, C[idxmax])\n",
    "@show(is_connected(g));\n",
    "n, E = nv(g), ne(g)\n",
    "\n",
    "print(n, \" \", E)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# All-pairs shortest-path distances on the largest component; adjacency_matrix(g) supplies unit edge weights.\n",
    "@time D = utilities.parallel_dp_shortest_paths(g,adjacency_matrix(g));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Calculating δ"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# δ (presumably Gromov δ-hyperbolicity) of the metric normalized to maximum distance 1;\n",
    "# the meaning of the second argument (1) is defined in Utilities.jl.\n",
    "d = utilities.calc_delta_for_w(D./maximum(D),1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Run the Alon et al. (AKPW) algorithm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep the distances only on the edges of G (elementwise product with the adjacency matrix).\n",
    "# NOTE(review): when G comes from a graph file, D was computed on the largest component `g`;\n",
    "# if G is disconnected the sizes differ and adjacency_matrix(g) is probably intended. Confirm.\n",
    "A = adjacency_matrix(G).*D;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# AKPW low-stretch spanning tree of the weighted graph A, returned as a weight matrix.\n",
    "@time R = Laplacians.akpw(A);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Graph of the AKPW tree's nonzero weights, then its all-pairs distances weighted by R.\n",
    "g2 = build_graph(R)\n",
    "@time D5 = utilities.parallel_dp_shortest_paths(g2,R);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# MAP (presumably mean average precision) of the AKPW tree distances with respect to graph g.\n",
    "utilities.MAP(D5[1:n,1:n],g)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# α = <D5, D> / <D5, D5> is the least-squares scale minimizing ||α*D5 - D||_F;\n",
    "# report the average distortion of the scaled tree metric.\n",
    "α = tr(D5[1:n,1:n]'*D)/tr(D5[1:n,1:n]'*D5[1:n,1:n])\n",
    "utilities.avg_distortion(D5[1:n,1:n]*α,D)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Bartal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": [
     2
    ]
   },
   "outputs": [],
   "source": [
    "# Average the point-to-point distance matrices of `ntrials` trees from Bartal.bartal\n",
    "# (presumably randomized), then report the least-squares-scaled average distortion and MAP.\n",
    "# The progress bar is sized to the real number of trials.\n",
    "ntrials = 200\n",
    "D7 = zeros(n,n)\n",
    "p2 = Progress(ntrials)\n",
    "for i = 1:ntrials\n",
    "    R = Bartal.bartal(g,collect(1:n),D)\n",
    "    D6 = utilities.parallel_dp_shortest_paths(R[1],R[4])\n",
    "    p = R[2]\n",
    "    # IndexToIdx[p[k]] = k, i.e. the inverse of p (assuming p is a permutation).\n",
    "    IndexToIdx = copy(R[2])\n",
    "    for k = 1:length(p)\n",
    "        IndexToIdx[p[k]] = k\n",
    "    end\n",
    "    D2p = zeros(n,n)\n",
    "    for a = 1:length(p)\n",
    "        for b = 1:a-1\n",
    "            D2p[a,b] = D6[IndexToIdx[a],IndexToIdx[b]]\n",
    "            D2p[b,a] = D2p[a,b]\n",
    "        end\n",
    "    end\n",
    "    # Running mean of the distance matrices over the first i trials.\n",
    "    D7 = (D7*(i-1) + D2p)/i\n",
    "    update!(p2,i)\n",
    "    flush(stdout)\n",
    "end\n",
    "α = tr(D7[1:n,1:n]'*D)/tr(D7[1:n,1:n]'*D7[1:n,1:n])\n",
    "@show(utilities.avg_distortion(D7*α,D))\n",
    "@show(utilities.MAP(D7,g))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Construct Tree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): this overwrites `g` (the largest-component graph from above) with the\n",
    "# complete graph, so the MAP values computed below are taken against the complete graph.\n",
    "g = CompleteGraph(n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): λ is not used in this cell; presumably ConstructTree reads it as a global — confirm.\n",
    "global λ = 2000\n",
    "n = size(D, 1)\n",
    "# Build the tree from the metric D with vertex 1 as the start and vertices 2..n to insert.\n",
    "@time T, W4 = ConstructTree.basicConstructTree(collect(2:n), 1, D)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# All-pairs distances in tree T with edge weights W4 (Floyd–Warshall).\n",
    "@time D4 = LightGraphs.floyd_warshall_shortest_paths(T,W4[1:nv(T),1:nv(T)]).dists;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Least-squares scale α minimizing ||α*D4 - D||_F over the first n vertices.\n",
    "α = tr(D4[1:n,1:n]'*D)/tr(D4[1:n,1:n]'*D4[1:n,1:n])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average distortion of the scaled ConstructTree metric.\n",
    "utilities.avg_distortion(D4[1:n,1:n]*α,D)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# MAP of the ConstructTree distances with respect to graph g.\n",
    "utilities.MAP(D4[1:n,1:n],g)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# LevelTree algorithm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Build the LevelTree graph from g and D (third argument 1 — presumably the root vertex),\n",
    "# pass it through utilities.remove_loops, and compute its all-pairs distances with unit weights.\n",
    "@time gT = utilities.remove_loops(LevelTree.build_level_graph(g,D,1))\n",
    "DT = utilities.parallel_dp_shortest_paths(gT, adjacency_matrix(gT));"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# MAP of the LevelTree distances with respect to graph g.\n",
    "utilities.MAP(DT,g)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Least-squares scale α for the LevelTree metric, then its average distortion.\n",
    "α = tr(DT[1:n,1:n]'*D)/tr(DT[1:n,1:n]'*DT[1:n,1:n])\n",
    "utilities.avg_distortion(DT[1:n,1:n]*α,D)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tree Rep"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Globals not used by the visible cells below; presumably read by TreeRep for progress\n",
    "# reporting — NOTE(review): confirm in TreeRep.jl.\n",
    "global p2 = Progress(nv(g))\n",
    "global jj = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": []
   },
   "outputs": [],
   "source": [
    "NN = 20  # number of independent TreeRep runs\n",
    "\n",
    "# One slot per run: wall time, then MAP / distortion before and after the heuristic optimization.\n",
    "times, map2, distort, map2opt, distortopt = ntuple(_ -> zeros(NN), 5)\n",
    "j = 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## TreeRep with the heuristic optimization (reports both unoptimized and optimized results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": [
     0
    ],
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "for j = 1:NN\n",
    "    # Build a tree for D with TreeRep (timed).\n",
    "    times[j] = @elapsed G2,W2 = TreeRep.metric_to_structure(D,undef,undef);\n",
    "    \n",
    "    flush(stdout)\n",
    "    # Sparse weight matrix of G2, keeping only positive entries.\n",
    "    B = W2[1:nv(G2),1:nv(G2)];\n",
    "    B = sparse(B);\n",
    "    B = (B .> 0) .* B;\n",
    "    \n",
    "    # Score the unoptimized tree on the n leaf rows/cols. α is the least-squares scale\n",
    "    # argmin ||α*D2n - D||_F = sum(D2n .* D) / sum(D2n .^ 2), i.e. tr(D2n'*D)/tr(D2n'*D2n)\n",
    "    # without the O(n^3) matrix products.\n",
    "    D2 = utilities.parallel_dp_shortest_paths(G2, B,false);\n",
    "    D2n = D2[1:n,1:n]\n",
    "    α = sum(D2n .* D)/sum(abs2, D2n)\n",
    "    \n",
    "    map2[j] = utilities.MAP(D2n,g)\n",
    "    distort[j] = utilities.avg_distortion(D2n*α,D)\n",
    "    # Drop large intermediates before the memory-heavy optimization step.\n",
    "    D2 = 0\n",
    "    D2n = 0\n",
    "    B = 0\n",
    "    GC.gc()\n",
    "    \n",
    "    N = size(D)[1]\n",
    "    \n",
    "    @show(Sys.free_memory()/2^(30))\n",
    "    \n",
    "    # All strictly-lower-triangular index pairs (i, k), k < i, row by row. The inner index is k,\n",
    "    # not j, so it cannot be confused with the run index j. Rebuilt per run so it can be freed\n",
    "    # while TreeRep runs.\n",
    "    IDXs = Tuple{Int,Int}[(i,k) for i = 1:N for k = 1:i-1]\n",
    "    \n",
    "    @show(Sys.free_memory()/2^(30))\n",
    "    \n",
    "    times[j] += @elapsed x,EdgetoIdx = TreeOpt.lsngd_mengdi(G2,D,W2,IDXs,0.0001,200);\n",
    "    \n",
    "    IDXs = 0\n",
    "    W2 = 0\n",
    "    GC.gc()\n",
    "    \n",
    "    # Symmetric weight matrix from the optimized edge vector x, negatives clamped to 0.\n",
    "    # The loop-local is `wt`, not `w`, so it cannot overwrite the global NJ weight matrix `w`\n",
    "    # under IJulia's soft-scope rules.\n",
    "    N = nv(G2)\n",
    "    W3 = zeros(N,N)\n",
    "    for e in edges(G2)\n",
    "        i2 = e.src\n",
    "        j2 = e.dst\n",
    "        wt = max(0,x[EdgetoIdx[(i2,j2)]])\n",
    "        W3[i2,j2] = wt\n",
    "        W3[j2,i2] = wt\n",
    "    end\n",
    "    \n",
    "    B = W3[1:nv(G2),1:nv(G2)];\n",
    "    B = sparse(B);\n",
    "    B = (B .> 0) .* B;\n",
    "    # NOTE(review): the purpose of the 1e-14 offset is not visible here; the leaf diagonal is reset to 0.\n",
    "    D3 = utilities.parallel_dp_shortest_paths(G2, B,false) .+ 1e-14;\n",
    "    for i = 1:n\n",
    "        D3[i,i] = 0\n",
    "    end\n",
    "\n",
    "    D3n = D3[1:n,1:n]\n",
    "    α = sum(D3n .* D)/sum(abs2, D3n)\n",
    "\n",
    "    distortopt[j] = utilities.avg_distortion(D3n*α,D)\n",
    "    map2opt[j] = utilities.MAP(D3n,g)\n",
    "    # Report both the unoptimized and the optimized scores for this run.\n",
    "    @show((times[j],distort[j],map2[j],distortopt[j],map2opt[j]))\n",
    "end\n",
    "\n",
    "# Best (min distortion, max MAP) and mean scores over all runs, before and after optimization.\n",
    "t = mean(times)\n",
    "dis = minimum(distort)\n",
    "m = maximum(map2)\n",
    "disopt = minimum(distortopt)\n",
    "mopt = maximum(map2opt)\n",
    "\n",
    "mdis = mean(distort)\n",
    "mm = mean(map2)\n",
    "mdisopt = mean(distortopt)\n",
    "mmopt = mean(map2opt)\n",
    "\n",
    "@show((t,dis,m))\n",
    "@show((t,disopt,mopt))\n",
    "@show((t,mdis,mm))\n",
    "@show((t,mdisopt,mmopt))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## This has the full optimization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": [
     6
    ],
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# TreeRep followed by the full least-squares edge-weight optimization, repeated NN times.\n",
    "NN = 20\n",
    "\n",
    "times = zeros(NN)\n",
    "map2 = zeros(NN)\n",
    "distort = zeros(NN)\n",
    "\n",
    "for j = 1:NN\n",
    "    times[j] = @elapsed G2,W2 = TreeRep.metric_to_structure(D,undef,undef);\n",
    "    \n",
    "    flush(stdout)\n",
    "    B = W2[1:nv(G2),1:nv(G2)];\n",
    "    B = sparse(B);\n",
    "    B = (B .> 0) .* B;\n",
    "    \n",
    "    D2 = utilities.parallel_dp_shortest_paths(G2, B,false);\n",
    "    \n",
    "    # NOTE(review): unlike the other methods, no least-squares scale α is applied here — confirm.\n",
    "    distort[j] = utilities.avg_distortion(D2[1:n,1:n],D)\n",
    "    map2[j] = utilities.MAP(D2[1:n,1:n],g)\n",
    "    \n",
    "    B = 0\n",
    "    D2 = 0\n",
    "    \n",
    "    GC.gc()\n",
    "    \n",
    "    # Unoptimized scores are printed here; distort[j]/map2[j] are overwritten with the\n",
    "    # optimized scores below, so the summary only reflects the optimized trees.\n",
    "    @show((times[j],distort[j],map2[j]))\n",
    "    flush(stdout)\n",
    "    \n",
    "    # Build the linear system A x ≈ b over the edge weights of G2 and solve it (timed).\n",
    "    times[j] += @elapsed A,b,EdgetoIdx,x0 = TreeOpt.makeAbMatrix(G2,D,W2)\n",
    "    times[j] += @elapsed x,loss = TreeOpt.lsngd(A,b,0.00000001,x0,5000)\n",
    "    flush(stdout)\n",
    "    N = nv(G2)\n",
    "    W2  = zeros(N,N)\n",
    "    E = collect(edges(G2))\n",
    "    # NOTE(review): under IJulia's soft scope, `w` below overwrites a global `w` (the NJ weight\n",
    "    # matrix) if that cell has already run.\n",
    "    for e in E\n",
    "        i2 = e.src\n",
    "        j2 = e.dst\n",
    "        idx = EdgetoIdx[(i2,j2)]\n",
    "        w = max(0,x[idx])\n",
    "        W2[i2,j2] = w\n",
    "        W2[j2,i2] = w\n",
    "    end\n",
    "    flush(stdout)\n",
    "    B = W2[1:nv(G2),1:nv(G2)];\n",
    "    B = sparse(B);\n",
    "    B = (B .> 0) .* B;\n",
    "    # NOTE(review): called without the `false` third argument used for the unoptimized tree above — confirm.\n",
    "    D2 = utilities.parallel_dp_shortest_paths(G2, B) .+1e-13;\n",
    "    for i = 1:size(D2)[1]\n",
    "        D2[i,i] = 0\n",
    "    end\n",
    "    flush(stdout)\n",
    "\n",
    "    distort[j] = utilities.avg_distortion(D2[1:n,1:n],D)\n",
    "    map2[j] = utilities.MAP(D2[1:n,1:n],g)\n",
    "    flush(stdout)\n",
    "    @show((times[j],distort[j],map2[j]))\n",
    "    flush(stdout)\n",
    "end\n",
    "\n",
    "# Mean time, best (lowest) distortion and best (highest) MAP over the runs.\n",
    "t = mean(times)\n",
    "dis = minimum(distort)\n",
    "m = maximum(map2)\n",
    "\n",
    "@show((t,dis,m))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## This has no optimization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": [
     7
    ],
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "N = 20  # number of independent TreeRep runs\n",
    "\n",
    "times = zeros(N)\n",
    "map2 = zeros(N)\n",
    "distort = zeros(N)\n",
    "# NOTE(review): presumably pre-binds D2 as a global so the last run's distances remain after the loop.\n",
    "D2 = 0\n",
    "j=1\n",
    "for j = 1:N\n",
    "    times[j] = @elapsed G2,W2 = TreeRep.metric_to_structure(D,undef,undef);\n",
    "    G2 = utilities.remove_loops(G2)\n",
    "    @show(times[j])\n",
    "    flush(stdout)\n",
    "    B = W2[1:nv(G2),1:nv(G2)];\n",
    "    B = sparse(B);\n",
    "    B = (B .> 0) .* B;\n",
    "    \n",
    "    W2 = 0\n",
    "    GC.gc()\n",
    "    D2 = utilities.parallel_dp_shortest_paths(G2, B);\n",
    "    # Least-squares scale α = sum(D2n .* D) / sum(D2n .^ 2) over the n leaf rows/cols only:\n",
    "    # including the Steiner-node rows/cols of D2 in the denominator would bias α toward 0.\n",
    "    # Computed inline because trm is defined further down the notebook.\n",
    "    D2n = D2[1:n,1:n]\n",
    "    α = sum(D2n .* D)/sum(abs2, D2n)\n",
    "    distort[j] = utilities.avg_distortion(D2n*α,D)\n",
    "    map2[j] = utilities.MAP(D2n,g)\n",
    "    \n",
    "    @show((distort[j],map2[j]))\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": [],
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Mean time, best (lowest) distortion, best (highest) MAP, and the mean scores over the runs;\n",
    "# `minimum(distort)` matches the other summary cells (lower distortion is better).\n",
    "t = mean(times)\n",
    "dis = minimum(distort)\n",
    "m = maximum(map2)\n",
    "mdis = mean(distort)\n",
    "mm = mean(map2)\n",
    "\n",
    "@show((t,dis,m))\n",
    "@show((t,mdis,mm))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Neighbor Joining"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Neighbor joining on a Float64 copy of D (the copy presumably protects D from nj!'s in-place changes).\n",
    "@time R = NJ.nj!(copy(convert(Matrix{Float64},D)));"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the NJ network R into a SimpleGraph g2 plus a symmetric sparse edge-weight matrix w.\n",
    "g2 = SimpleGraph(R.numNodes)\n",
    "w = spzeros(R.numNodes,R.numNodes)\n",
    "for ed in R.edge[1:R.numEdges]\n",
    "    u = ed.node[1].number\n",
    "    v = ed.node[2].number\n",
    "    add_edge!(g2,u,v)\n",
    "    w[u,v] = w[v,u] = ed.length\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# All-pairs distances in the NJ tree g2 with edge weights w.\n",
    "@time D5 = utilities.parallel_dp_shortest_paths(g2, w)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Least-squares scale α = sum(D5n .* D) / sum(D5n .^ 2) on the leaf block (O(n^2) form of\n",
    "# tr(D5n'*D)/tr(D5n'*D5n)), applied before measuring distortion as for the other methods.\n",
    "D5n = D5[1:n,1:n]\n",
    "α = sum(D5n .* D)/sum(abs2, D5n)\n",
    "@show(utilities.avg_distortion(D5n*α,D));"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# MAP of the NJ leaf distances with respect to g.\n",
    "@show(utilities.MAP(D5[1:n,1:n],g));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# BFS Tree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: reassigns g (presumably a 10-nearest-neighbour graph of D); later MAP calls use this g.\n",
    "g = utilities.kNN(D,10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Complete graph on the points of D; input to the MST below.\n",
    "G = CompleteGraph(size(D)[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of points, taken from the complete graph G.\n",
    "n = nv(G)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimum spanning tree of G with edge weights D; @elapsed reports the time taken.\n",
    "@elapsed r = LightGraphs.prim_mst(G,D)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Materialize the MST edge list r as a SimpleGraph on n vertices.\n",
    "R = SimpleGraph(n)\n",
    "foreach(e -> add_edge!(R, e), r)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# BFS tree of g from a random root. TODO: no seed is set, so the root (and all results) change between runs.\n",
    "@elapsed r = LightGraphs.bfs_tree(g,rand(1:n))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copy the BFS tree's edges into an undirected SimpleGraph on n vertices (rebinds R).\n",
    "R = SimpleGraph(n)\n",
    "foreach(e -> add_edge!(R, e), edges(r))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": []
   },
   "outputs": [],
   "source": [
    "# Weight each edge of R by the corresponding entry of D, then compute all-pairs tree distances.\n",
    "wbfs =  adjacency_matrix(R) .* D\n",
    "D3 = utilities.parallel_dp_shortest_paths(R,wbfs);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": [
     0
    ]
   },
   "outputs": [],
   "source": [
    "# trm(A, B): trace of the product A*B, i.e. the sum over i and k of A[i,k]*B[k,i],\n",
    "# computed row by row without forming A*B. Rows of A are paired with columns of B.\n",
    "function trm(A,B)\n",
    "    n = size(A)[1]\n",
    "    # Seed with a zero of the promoted element type so `t` keeps one concrete type\n",
    "    # (an Int seed would change type after the first float addition).\n",
    "    t = zero(promote_type(eltype(A), eltype(B)))\n",
    "    for i = 1:n\n",
    "        # Views avoid copying row i of A and column i of B on every iteration.\n",
    "        t += @views sum(A[i,:] .* B[:,i])\n",
    "    end\n",
    "    \n",
    "    return t\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": []
   },
   "outputs": [],
   "source": [
    "# Least-squares scale for the tree distances; trm(A,B) = tr(A*B). R has n vertices, so D3 is\n",
    "# presumably n×n and sum(D3.^2) covers exactly the leaf block.\n",
    "α = trm(D3[1:n,1:n],D)/(sum(D3.^2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": []
   },
   "outputs": [],
   "source": [
    "# Average distortion of the α-scaled tree distances against D.\n",
    "@show(utilities.avg_distortion(D3[1:n,1:n]*α,D));"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": []
   },
   "outputs": [],
   "source": [
    "# MAP of the tree distances with respect to g.\n",
    "@show(utilities.MAP(D3[1:n,1:n],g));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Heuristic optimization of the MST (uses whichever tree is currently bound to `R`; the BFS cells above rebind it)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "N = size(D, 1)\n",
    "\n",
    "@show(Sys.free_memory()/2^(30))\n",
    "\n",
    "# Enumerate every strictly-lower-triangular pair (i, j), j < i, row by row, into a\n",
    "# preallocated vector of length N(N-1)/2; the counter c ends at L + 1.\n",
    "L = N*(N-1) ÷ 2\n",
    "IDXs = Vector{Tuple{Int,Int}}(undef, L)\n",
    "c = 1\n",
    "for i = 1:N, j = 1:i-1\n",
    "    IDXs[c] = (i, j)\n",
    "    c += 1\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Heuristic edge-weight optimization on R. NOTE(review): the unweighted adjacency_matrix(R) is passed as the weights, whereas the TreeRep loop passes W2 — confirm intended.\n",
    "@time x,EdgetoIdx = TreeOpt.lsngd_mengdi(R,D,adjacency_matrix(R),IDXs,0.0001,200);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild symmetric edge weights on R from the optimized vector x (negatives clamped to 0).\n",
    "N = nv(R)\n",
    "W3 = zeros(N,N)\n",
    "E = collect(edges(R))\n",
    "for e in E\n",
    "    i2 = e.src\n",
    "    j2 = e.dst\n",
    "    idx = EdgetoIdx[(i2,j2)]\n",
    "    # `wt`, not `w`: under IJulia's soft scope, assigning `w` in this top-level loop would\n",
    "    # overwrite the global NJ weight matrix `w` built earlier in the notebook.\n",
    "    wt = max(0,x[idx])\n",
    "    W3[i2,j2] = wt\n",
    "    W3[j2,i2] = wt\n",
    "end\n",
    "\n",
    "B = W3[1:nv(R),1:nv(R)];\n",
    "B = sparse(B);\n",
    "B = (B .> 0) .* B;\n",
    "D3 = utilities.parallel_dp_shortest_paths(R, B) .+ 1e-14;\n",
    "for i = 1:n\n",
    "    D3[i,i] = 0\n",
    "end\n",
    "\n",
    "# NOTE(review): unlike the TreeRep heuristic loop, no least-squares scale α is applied before\n",
    "# avg_distortion here; these assignments also replace the `distort`/`map2` arrays with scalars.\n",
    "distort = utilities.avg_distortion(D3[1:n,1:n],D)\n",
    "map2 = utilities.MAP(D3[1:n,1:n],g)\n",
    "\n",
    "@show((distort,map2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Calculating statistics for the outputs from PM and LM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): `sc` is defined in a later cell (sc = 2^i); on a fresh kernel run that cell first.\n",
    "Z = CSV.read(\"rand-dim$sc-scale1.csv\")\n",
    "D = zeros(100,100)\n",
    "# Presumably each of the 100*99/2 point pairs appears once as (row, col, distance); fill D\n",
    "# symmetrically. `ri`/`ci` (not `i`/`j`) keep this top-level loop from overwriting the globals\n",
    "# `i` (used as the dist8 index) and `j` under IJulia's soft-scope rules.\n",
    "for k = 1:50*99\n",
    "    ri = Z[k,1]\n",
    "    ci = Z[k,2]\n",
    "    D[ri,ci] = Z[k,3]\n",
    "    D[ci,ri] = Z[k,3]\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): dist7 is not referenced elsewhere in this part of the notebook.\n",
    "dist7 = zeros(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Identity ordering of the n points; the commented-out npzread would load a saved ordering instead.\n",
    "p = collect(1:n) #npzread(\"./../../../grid-worm-order.npy\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inverse permutation of p, so that IndexToIdx[p[i]] == i. invperm also raises an ArgumentError\n",
    "# if p is not a permutation of 1:length(p) (e.g. duplicated entries), instead of silently\n",
    "# producing a wrong mapping.\n",
    "IndexToIdx = invperm(p);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): no effect (the trailing ; suppresses the display); presumably a leftover inspection cell.\n",
    "IndexToIdx;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "D2 = npzread(\"./zeisel-knn10.npy\")\n",
    "# Reorder the off-diagonal entries of D2 by IndexToIdx, mirroring each value so D2p stays\n",
    "# symmetric; the diagonal is kept as copied from D2.\n",
    "D2p = copy(D2)\n",
    "for i in eachindex(p), j in 1:i-1\n",
    "    D2p[i,j] = D2p[j,i] = D2[IndexToIdx[i],IndexToIdx[j]]\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Least-squares scale α = sum(D2p .* D) / sum(D2p .^ 2) on the leaf block; equal to\n",
    "# tr(D2p'*D)/tr(D2p'*D2p) but O(n^2) instead of two O(n^3) matrix products.\n",
    "α = sum(D2p[1:n,1:n] .* D)/sum(abs2, D2p[1:n,1:n])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average distortion of the α-scaled PM/LM distances against D.\n",
    "utilities.avg_distortion(D2p*α,D)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# MAP of the PM/LM distances with respect to g.\n",
    "utilities.MAP(D2p[1:n,1:n],g)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Calculate statistics for the outputs from PT "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-dataset results indexed by i (sc = 2^i); only dist8 is filled in below.\n",
    "dist8 = zeros(10)\n",
    "dist9 = zeros(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of points in the rand-dim*-scale1 datasets (matches the 100×100 D built below).\n",
    "n=100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select the dataset dimension sc = 2^i; i also indexes dist8 below.\n",
    "i = 10\n",
    "sc = 2^i"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Z = CSV.read(\"rand-dim$sc-scale1.csv\")\n",
    "D = zeros(100,100)\n",
    "# Fill D symmetrically from the (row, col, distance) triples in the first 50*99 rows of Z.\n",
    "for k in 1:50*99\n",
    "    ii, j, dval = Z[k,1], Z[k,2], Z[k,3]\n",
    "    D[ii,j] = dval\n",
    "    D[j,ii] = dval\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "D2 = npzread(\"./rand-dim$sc-scale1-l-1024npy.npy\")\n",
    "# Force a zero diagonal on the first n points.\n",
    "foreach(k -> D2[k,k] = 0, 1:n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Least-squares scale α = sum(D2 .* D) / sum(D2 .^ 2) on the leaf block; equal to\n",
    "# tr(D2'*D)/tr(D2'*D2) but O(n^2) instead of two O(n^3) matrix products.\n",
    "α = sum(D2[1:n,1:n] .* D)/sum(abs2, D2[1:n,1:n])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distortion of the α-scaled PT distances, stored at index i of the current dataset.\n",
    "dist8[i] = utilities.avg_distortion(D2*α,D)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# MAP of the PT distances with respect to g.\n",
    "utilities.MAP(D2[1:n,1:n],g)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Ultra method\n",
    "\n",
    "Not a tree method: this is a low-dimensional hyperbolic embedding. (I think the original algorithm has a bug.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": [
     0,
     8,
     37,
     86
    ]
   },
   "outputs": [],
   "source": [
    "# compute_Ca(D, a=1): n×n matrix with Ca[i,j] = 2*max_k D[a,k] - D[a,i] - D[a,j], built from\n",
    "# row `a` of D. D + Ca equals 2*(max_k D[a,k] - (i|j)_a), with (i|j)_a the Gromov product at a.\n",
    "function compute_Ca(D,a=1)\n",
    "    # Multiplying by 1.0 promotes integer rows to floating point, so the result is always float.\n",
    "    row = D[a,:] .* 1.0\n",
    "    return 2*maximum(row) .- (row .+ row')\n",
    "end\n",
    "\n",
    "# compute_U(M): build a rooted binary tree from the minimum spanning tree of the complete graph\n",
    "# weighted by M (n×n). Leaves are 1..n, the root is n+1, and internal nodes come from n+2:3n;\n",
    "# recurse_U consumes 2(n-1) of those ids, so vertex 3n is presumably left isolated.\n",
    "# Returns the tree T (3n vertices) and its sparse edge-weight matrix W.\n",
    "function compute_U(M)\n",
    "    n = size(M)[1]\n",
    "    G = CompleteGraph(n)\n",
    "    r = LightGraphs.prim_mst(G,M)\n",
    "    R = SimpleGraph(n)\n",
    "    for e in r\n",
    "        add_edge!(R,e)\n",
    "    end\n",
    "    \n",
    "    # Keep only the MST edge weights (zero everywhere else).\n",
    "    A = LightGraphs.adjacency_matrix(R)\n",
    "    Mp = A .* M\n",
    "    \n",
    "    # Ids available for internal nodes; n+1 is reserved for the root.\n",
    "    nextroots = collect(n+2:3*n)\n",
    "    W = spzeros(3*n,3*n)\n",
    "    \n",
    "    W,nextroots = recurse_U(R,collect(1:n),Mp,nextroots,W,n+1)\n",
    "    \n",
    "    # Build the graph from W's stored entries, then replace the -1 placeholders that\n",
    "    # recurse_U writes for leaf edges with their real weight 0.\n",
    "    T = SimpleGraph(3*n)\n",
    "    I,J,V = findnz(W)\n",
    "    for k = 1:length(I)\n",
    "        add_edge!(T,I[k],J[k])\n",
    "        if W[I[k],J[k]] == -1\n",
    "            W[I[k],J[k]] = 0\n",
    "        end\n",
    "    end\n",
    "    \n",
    "    return T,W\n",
    "end\n",
    "\n",
    "# recurse_U(T, V, M, nextroots, W, r): top-down construction of the hierarchy below node r.\n",
    "#   T         - tree on local vertices 1:length(V); mutated (one edge removed per call)\n",
    "#   V         - maps local vertex k to its label V[k] (leaf id in the output tree)\n",
    "#   M         - sparse matrix of T's edge weights, indexed by local vertex\n",
    "#   nextroots - unused internal-node ids; this call takes the first two as child roots\n",
    "#   W         - sparse weight matrix of the output tree, filled in place and returned\n",
    "#   r         - id of the current internal node\n",
    "# Returns (W, nextroots) with the ids consumed by this subtree removed from nextroots.\n",
    "function recurse_U(T,V,M,nextroots,W,r)\n",
    "    if length(V) == 1\n",
    "        # Single leaf: attach it to r. -1 is a placeholder because assigning 0 to an\n",
    "        # unstored sparse entry stores nothing; compute_U turns it back into 0.\n",
    "        W[r,V[1]] = -1\n",
    "        W[V[1],r] = -1\n",
    "        \n",
    "        return W,nextroots\n",
    "    end\n",
    "    \n",
    "    r1 = nextroots[1]\n",
    "    r2 = nextroots[2]\n",
    "    \n",
    "    # Split at the heaviest stored edge (i, j) of T.\n",
    "    # NOTE(review): edges whose weight in M is exactly 0 are not stored and are never chosen.\n",
    "    I,J,U = findnz(M)\n",
    "    m = argmax(U)\n",
    "    i = I[m]\n",
    "    j = J[m]\n",
    "    \n",
    "    # Both child roots hang from r with half of that weight.\n",
    "    # NOTE(review): if the weight is 0 these entries are not stored and the edges are lost.\n",
    "    half = U[m]/2\n",
    "    W[r,r1] = half\n",
    "    W[r,r2] = half\n",
    "    W[r2,r] = half\n",
    "    W[r1,r] = half\n",
    "    \n",
    "    # Removing the edge from the tree leaves two connected components.\n",
    "    rem_edge!(T,i,j)\n",
    "    C = connected_components(T)\n",
    "    \n",
    "    # induced_subgraph returns the subgraph and its new-to-old vertex map.\n",
    "    T1,V1 = induced_subgraph(T,C[1])\n",
    "    T2,V2 = induced_subgraph(T,C[2])\n",
    "    \n",
    "    M1 = M[V1,V1]\n",
    "    M2 = M[V2,V2]\n",
    "    \n",
    "    # Compose with V so the children see this call's labels.\n",
    "    V1 = V[V1]\n",
    "    V2 = V[V2]\n",
    "    \n",
    "    W,nextroots = recurse_U(T1,V1,M1,nextroots[3:end],W,r1)\n",
    "    W,nextroots = recurse_U(T2,V2,M2,nextroots,W,r2)\n",
    "   \n",
    "    return W,nextroots\n",
    "end\n",
    "\n",
    "# eps(D): Ultra pipeline. Shifts D by Ca (compute_Ca), builds a tree from the shifted matrix with\n",
    "# compute_U, takes its shortest-path distances between the n leaves, undoes the shift, and runs\n",
    "# neighbor joining on the result. Returns the NJ tree as a SimpleGraph g2 and sparse weights w.\n",
    "# NOTE(review): this defines Main.eps, shadowing Base.eps. If Base.eps was already used in Main\n",
    "# the definition errors; otherwise eps(...) in Main afterwards calls this function. Consider renaming.\n",
    "function eps(D)\n",
    "    n = size(D)[1]\n",
    "    Ca = compute_Ca(D)\n",
    "    M = D + Ca\n",
    "    T,W = compute_U(M)\n",
    "    # compute_U presumably leaves vertex 3n isolated; keep the largest component.\n",
    "    C = connected_components(T)\n",
    "    T1,_ = induced_subgraph(T,C[argmax(length.(C))])\n",
    "    # NOTE(review): W is indexed by T's vertex ids while T1 is renumbered; the two agree only if the\n",
    "    # kept component is vertices 1..nv(T1) in order (true when only vertex 3n is dropped).\n",
    "    U = LightGraphs.Parallel.floyd_warshall_shortest_paths(T1,W).dists\n",
    "    AD = U[1:n,1:n] - Ca\n",
    "    \n",
    "    # Neighbor joining on a Float64 copy of the recovered leaf distances.\n",
    "    @time R = NJ.nj!(copy(convert(Matrix{Float64},AD)))\n",
    "    # Convert the NJ network into a SimpleGraph plus a symmetric sparse edge-weight matrix.\n",
    "    g2 = SimpleGraph(R.numNodes)\n",
    "    w = spzeros(R.numNodes,R.numNodes)\n",
    "    for i = 1:R.numEdges\n",
    "        src = R.edge[i].node[1].number\n",
    "        dst = R.edge[i].node[2].number\n",
    "        add_edge!(g2,src,dst)\n",
    "        w[src,dst] = R.edge[i].length\n",
    "        w[dst,src] = w[src,dst]\n",
    "    end\n",
    "    \n",
    "    return g2,w\n",
    "end"
   ]
  }
 ],
 "metadata": {
  "@webio": {
   "lastCommId": null,
   "lastKernelId": null
  },
  "kernelspec": {
   "display_name": "Julia 1.4.1",
   "language": "julia",
   "name": "julia-1.4"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.4.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
