library(tidyverse)
library(ggpubr)

# Upper edges of the distance buckets, in increasing order.
buckets <- c(0.0001, 0.001, 0.005, 0.01, 0.02, 0.05, 0.1, 0.2, 1.0)

# Map each distance to the smallest bucket edge that is >= it.
#
# Vectorized over `x`. Values above the largest edge fall back to 1.0;
# NA/NaN inputs yield NA rather than raising an error.
#
# @param x Numeric vector of distances.
# @param breaks Increasing numeric vector of bucket upper edges; defaults to
#   the global `buckets`.
# @return Numeric vector the same length as `x`.
bucketize <- function(x, breaks = buckets) {
  # With left.open = TRUE, findInterval() returns i such that
  # breaks[i] < x <= breaks[i + 1] (i == 0 when x <= breaks[1]),
  # so the bucket edge for x is breaks[i + 1].
  idx <- findInterval(x, breaks, left.open = TRUE) + 1L
  out <- breaks[idx]
  # Past the last edge: fixed fallback value.
  out[!is.na(idx) & idx > length(breaks)] <- 1.0
  out
}

# Load the combined benchmark results, then offset every `time` value by 0.01.
# NOTE(review): the reason for the offset is not visible here — presumably to
# keep times strictly positive (e.g. for log scales); confirm.
allin <- read_csv("results/all_in.combined_csv")
allin <- mutate(allin, time = time + 0.01)

# PERFORMANCE

# Keep the graph-based ANN algorithms (kgraph, hnsw, SW-graph) and rows with a
# positive distance. Label each row by its "c2, c1" parameter pair and assign
# it to a distance bucket. To restrict to one dataset, swap in one of the
# commented filters (and the matching ggtitle in the plot).
data <- allin %>%
  filter(distance > 0, algorithm %in% c("kgraph", "hnsw", "SW-graph")) %>%
#  filter(distance > 0, algorithm %in% c("kgraph", "hnsw", "SW-graph"), dataset == "mnist") %>%
#  filter(distance > 0, algorithm %in% c("kgraph", "hnsw", "SW-graph"), dataset == "fashion") %>%
#  filter(distance > 0, algorithm %in% c("kgraph", "hnsw", "SW-graph"), dataset == "sift") %>%
#  filter(distance > 0, algorithm == "kgraph", dataset %in% c("mnist", "fashion")) %>%
  mutate(parameters = paste(c2, c1, sep = ", ")) %>%
  # vapply() guarantees a numeric column even when no rows survive the
  # filter (sapply() would return an empty list in that case).
  mutate(bucketdist = vapply(distance, bucketize, numeric(1)))

# Per (parameters, bucket): total number of rows, used as the denominator.
count <- data %>%
  group_by(parameters, bucketdist) %>%
  summarise(n = n()) %>%
  ungroup()

# Per (c1, c2, parameters, decision, bucket): number of rows with that
# decision, joined to the bucket total `n` on the shared key columns.
bucketed <- data %>%
  group_by(c1, c2, parameters, decision, bucketdist) %>%
  summarise(hit = n()) %>%
  ungroup() %>%
  left_join(count, by = c("parameters", "bucketdist"))

# Fraction of rows in each bucket whose decision is "False".
# NOTE(review): `decision` is compared against the strings "True"/"False";
# if read_csv() parsed the column as logical, these filters match nothing —
# confirm the column type.
falsebase <- bucketed %>%
  filter(decision == "False") %>%
  mutate(rate = hit / n)

# Buckets where every decision was "True" have no "False" row above; add
# them with rate 0 (hit == n) so each parameter setting still gets a bar.
truefill <- bucketed %>%
  filter(decision == "True", hit == n) %>%
  mutate(rate = 0)

plotdata <- bind_rows(falsebase, truefill)

# Dodged bar chart: for each distance bucket, one bar per parameter setting,
# with height `rate`. Keep exactly one title active, matching the data filter
# currently in use.
ggplot(plotdata, aes(x = factor(bucketdist), y = rate)) +
  geom_col(aes(fill = parameters), position = "dodge") +
  ggtitle("all ANN algorithms, all datasets") +
#  ggtitle("all ANN algorithms, MNIST") +
#  ggtitle("all ANN algorithms, Fashion-MNIST") +
#  ggtitle("all ANN algorithms, SIFT") +
#  ggtitle("KGraph, MNIST+Fashion-MNIST, k=10") +
#  ggtitle("KGraph, MNIST+Fashion-MNIST, k=50") +
#  theme(legend.position="none") +
  scale_fill_brewer(palette = "Set1") +
  scale_x_discrete(name = "bucketed distance") +
  scale_y_continuous(name = "recall")
