
############################################################################################
# witness.R
#
# Code to search for witnesses and return bounds. This includes functions that simulate
# cases
#
# Code by
#
#  - Ricardo Silva (ricardo@stats.ucl.ac.uk)
#  - Robin Evans (robin.evans@stats.ox.ac.uk)
#
# Current version: 15/08/2014
# First version: 31/03/2014

#library(spikeSlabGAM)

source("witness_util.R")
source("bayesian.R")

# Note: many functions use an object of type 'problem', a list with the fields
#
# - model: a list of the same size as the number of vertices in the graph of the problem.
#          Each entry model[[i]] is an array corresponding to P(V_i = 1 | parents = p),
#          p being a parent configuration. Configurations are sorted as follows:
#          if [1 2 3 ... (k - 2) (k - 1) k] are the parents, the corresponding configuration
#          positions within vector model[[i]] are 1 + bin2dec(000...000), 1 + bin2dec(000...001),
#          1 + bin2dec(000...010), ..., 1 + bin2dec(111...111)
#          [BEWARE: this is not the order in which R stores array data]
# - counts: array of count data over all variables
# - probs: array giving the joint distribution over all variables
# - graph: binary matrix where graph[i, j] = 1 if V_j is a parent of V_i, 0 otherwise
# - ancestrals: a list where ancestrals[[i]] is a vector of indices indicating the ancestrals of V_i,
#               *excluding* itself
# - X_idx, Y_idx: the indices of variables X and Y, given that the problem is to find the causal
#                 effect X --> Y defined by P(Y = 1 | do(X = 1)) - P(Y = 1 | do(X = 0))
# - latent_idx: vector indicating which variables in the problem are latent

# Also notice: the approach requires the specification of a vector of 'epsilons',
# parameters which define which relaxations are allowed. This 6-dimensional vector represents
# the following information:
#
#     |eta_x0^star - eta_x1^star| <= epsilons[1]
#     |eta_x0^star - P(Y = 1 | X = x, W = 0)| <= epsilons[2]
#     |eta_x1^star - P(Y = 1 | X = x, W = 1)| <= epsilons[3]
#     |delta_w0^star - P(X = 1 | W = 0)| <= epsilons[4]
#     |delta_w1^star - P(X = 1 | W = 1)| <= epsilons[4]
#     epsilons[5] * P(U) <= P(U | W = w) <= epsilons[6] * P(U),
#
# where
#
#     eta_xw^star  == P(Y = 1 | X = x, W = w, U)
#     delta_w^star == P(X = 1 | W = w, U)

############################################################################################
# witness_search::
#
# Given a problem, find the witnesses and adjustment sets (if any) for the causal relation
# X --> Y. NOTICE THIS DOES AN EXHAUSTIVE SEARCH AND SHOULD NOT BE USED FOR PROBLEMS WITH
# A LARGE NUMBER OF VARIABLES.
#
# A pair (W, Z) is accepted when W is judged dependent on Y given Z, but independent of Y
# given Z together with X (as decided by infer_dep).
#
# * Input:
#
# - problem: a problem instance for the ACE of some X on some Y
# - min_only: if TRUE, then for each witness, once a set of a particular size is accepted,
#             candidate sets larger than it are skipped. If FALSE, every admissible set is
#             reported
# - pop_solve: if TRUE, assume we know the population graph instead of data
# - verbose: if TRUE, print progress messages
# - stop_at_first: if TRUE, stop as soon as some witness is found
#
# * Output:
#
# - witness, Z: the two pieces of information indicating the found witness x background set
#               combinations, stored separately as an array (witness) and a list (Z).
#               Each (witness[i], Z[[i]]) pair is unique, but witness may contain repeated
#               entries, as several background sets may be accepted for the same witness
# - witness_score: two-column matrix, one row per accepted pair. Column 1 is the log score of
#                  "W dependent on Y given Z" minus that of independence; column 2 is the log
#                  score of "W independent of Y given Z and X" minus that of dependence. Used to
#                  select a combination out of the available choices

witness_search <- function(problem, min_only = TRUE, pop_solve = FALSE, verbose = FALSE, stop_at_first = FALSE)
{
  x <- problem$X_idx
  y <- problem$Y_idx
  latents <- problem$latent_idx
  if (pop_solve) {
    num_v <- ncol(problem$graph)
  } else {
    num_v <- length(dim(problem$counts))
  }

  # Candidate witnesses and background variables: every observed variable other than X, Y
  witness_pool <- seq_len(num_v)[-c(x, y, latents)]
  witness_choice <- c()
  Z_choice <- list()
  witness_score <- matrix(nrow = 0, ncol = 2)

  for (w in witness_pool) {
    if (stop_at_first && length(witness_choice) > 0) break

    if (verbose) cat("################################# Testing witness", w, "\n")
    Z_set <- setdiff(witness_pool, w)
    num_combo <- 2^length(Z_set)
    found_size <- Inf  # size of the smallest set accepted so far for this witness

    # One row per subset of Z_set (entries > 0 mark membership)
    Zs <- combinations(rep(2, length(Z_set)))

    for (i in seq_len(num_combo)) {
      if (stop_at_first && length(witness_choice) > 0) break
      Z <- Z_set[Zs[i, ] > 0]
      # Honour min_only: skip candidates larger than a set already accepted for this witness
      if (min_only && length(Z) > found_size) next

      D1 <- infer_dep(w, y, Z, problem, pop_solve)
      D2 <- infer_dep(w, y, c(Z, x), problem, pop_solve)

      # decision == TRUE means "independent": need W dep. Y | Z and W indep. Y | Z, X
      if (!D1$decision && D2$decision) {
        witness_choice <- c(witness_choice, w)
        Z_choice[[length(witness_choice)]] <- Z
        witness_score <- rbind(witness_score, c(D1$scores[2] - D1$scores[1], D2$scores[1] - D2$scores[2]))
        if (verbose) cat("Accepting witness", w, " set", Z, "\n")
        found_size <- min(found_size, length(Z))
      }
    }
  }

  return(list(witness = witness_choice, Z = Z_choice, witness_score = witness_score))
}

############################################################################################
# infer_dep::
#
# Check whether x and y are independent given Z. With pop_solve = TRUE the answer comes from
# a d-separation oracle on the problem's graph; otherwise it is estimated from the counts by
# infer_dep_bayes_saturated.
#
# * Input:
#
# - x, y: the indices of the two variables to test
# - Z: the indices of the conditioning set
# - problem: a problem instance containing the counts (and graph/ancestrals for the oracle)
# - pop_solve: if TRUE, use the problem instance's graph instead of the counts, so this
#              becomes an oracle
#
# * Output: a list with
#
# - decision: TRUE if x is judged to be independent of y given Z, FALSE otherwise
# - scores: scores[1] supports independence, scores[2] supports dependence. In oracle mode
#           these are c(0, -Inf) when d-separated and c(-Inf, 0) otherwise

infer_dep <- function(x, y, Z, problem, pop_solve)
{
  if (!pop_solve) {
    #if (length(Z) > 10) return(infer_dep_bayes_logit(x, y, Z, counts=problem$counts)) # Future extension
    return(infer_dep_bayes_saturated(x, y, Z, problem$counts))
  }

  is_separated <- dsep(x, y, Z, problem$graph, problem$ancestrals)
  # Degenerate "probabilities" for the two hypotheses, turned into log scores
  oracle_prob <- if (is_separated) c(1, 0) else c(0, 1)
  list(decision = is_separated, scores = log(oracle_prob))
}


############################################################################################
# infer_dep_bayes_saturated::
#
# Check whether x and y are independent given Z, as estimated from data. This compares two
# saturated models for binary y (parents {x} U Z versus parents Z only) using Heckerman's BDe
# score, with a uniform prior over the two hypotheses.
#
# * Input:
#
# - x, y: the indices of the two variables to test
# - Z: the indices of the conditioning set
# - counts: array of counts
# - prior: hyperparameter for the prior (effective sample size)
#
# * Output:
#
# - decision: TRUE, if x is judged to be independent of y given Z, FALSE otherwise
# - scores: scores[1] is the log marginal likelihood of the model where x and y are independent given Z,
#           while scores[2] is the case where x and y are dependent

infer_dep_bayes_saturated <- function(x, y, Z, counts, prior = 10)
{
  # BDe log marginal likelihood for binary y, summed over all parent configurations.
  # n0 / n1 hold the counts of y = 0 / y = 1 per configuration; 'ess' is the per-configuration
  # hyperparameter, split evenly between the two values of y.
  family_score <- function(n0, n1, ess) {
    ess_half <- ess / 2
    sum(lgamma(ess) - lgamma(ess + (n0 + n1)) +
        lgamma(ess_half + n0) - lgamma(ess_half) +
        lgamma(ess_half + n1) - lgamma(ess_half))
  }

  ## Model with x among the parents: slice the (y, x, Z) table on y; the
  ## remaining dimensions (x, rev(Z)) index the parent configurations
  table_yxz <- marginTable(counts, c(y, x, rev(Z)))
  with_x_n1 <- subtable(table_yxz, 1, 2)
  with_x_n0 <- subtable(table_yxz, 1, 1)
  score_dep <- family_score(with_x_n0, with_x_n1, prior / 2^(length(Z) + 1))

  ## Model without x: sum x out, leaving one cell per configuration of Z
  z_dims <- seq_along(Z) + 1
  no_x_n1 <- marginTable(with_x_n1, z_dims)
  no_x_n0 <- marginTable(with_x_n0, z_dims)
  score_indep <- family_score(no_x_n0, no_x_n1, prior / 2^length(Z))

  scores <- c(score_indep, score_dep)
  list(decision = scores[1] > scores[2], scores = scores)
}

#######################################################################################################
# witness_bound_search::
#
# Given a problem, first find the witnesses and adjustment sets (if any) for the causal relation
# X --> Y. NOTICE THIS DOES AN EXHAUSTIVE SEARCH AND SHOULD NOT BE USED FOR PROBLEMS WITH
# A LARGE NUMBER OF VARIABLES.
#
# After all admissible (witness, adjustment set) pairs are found, generate bounds for each one by
# averaging per-configuration bounds over the empirical distribution of Z. A pair is discarded
# when some configuration of Z with positive empirical probability has no posterior samples.
#
# * Input:
#
# - problem: a problem instance that contains the indices of X and Y, the set of latent variables.
#            It should either contain an array of count data, or the true DAG and ancestral relationships.
# - epsilons: the relaxation parameters. See comments at the beginning of this file
# - M: sample size for Monte Carlo simulations
# - verbose: if TRUE, print progress messages
# - pop_solve: if TRUE, assume we know the population graph instead of data. Notice that data is
#              still used when computing posteriors over bounds.
#
# * Output:
#
# - w_list: a list with fields witness, Z and witness_score (as produced by witness_search),
#           restricted to the pairs that produced a bound
# - bounds: a matrix where each row corresponds to the matching w_list entry, and the two columns
#           correspond to an estimate of the lower bound and upper bound. It has zero rows when
#           no solution is found

witness_bound_search <- function(problem, epsilons, M, verbose = FALSE, pop_solve = FALSE)
{
  w_list <- witness_search(problem, TRUE, pop_solve, verbose)
  num_found <- length(w_list$witness)
  if (num_found == 0) {
    cat("No solution found\n")
    # Same shape as the bounds returned on the normal path
    return(list(w_list = w_list, bounds = matrix(nrow = 0, ncol = 2)))
  }

  bounds <- matrix(0, nrow = num_found, ncol = 2)

  for (i in seq_len(num_found)) {

    w <- w_list$witness[i]
    if (verbose) cat(i, ":: ################################# Bounding using witness", w, "\n")
    Z <- w_list$Z[[i]]

    theta_samples <- witness_posterior_sampling_saturated(w, problem$X_idx, problem$Y_idx, Z,
                                                          counts = problem$counts, epsilons = epsilons, M = M,
                                                          verbose = verbose, prior_sample = FALSE)
    num_states <- 2^length(Z)

    # Empirical P(Z), in the same cell order as theta_samples (not completely Bayesian yet)
    P_Z_hat <- c(marginTable(problem$counts, rev(Z)))
    P_Z_hat <- P_Z_hat / sum(P_Z_hat)

    for (j in seq_len(num_states)) {
      # Unobserved configuration: carries no weight in the average
      if (P_Z_hat[j] == 0) next
      if (nrow(theta_samples[[j]]$W0) == 0) {
        # Observed configuration without samples: mark the i-th entry to be purged below
        bounds[i, ] <- c(NA, NA)
        break
      }
      bounds[i, ] <- bounds[i, ] + colMeans(bayesian_interval_generation_analytical(theta_samples[[j]], epsilons), na.rm = TRUE) * P_Z_hat[j]
    }
    if (verbose) cat("\n")

  }

  # Purge failed witnesses

  keep <- !is.na(bounds[, 1]) & !is.na(bounds[, 2])
  f_w_list <- list(witness = w_list$witness[keep],
                   Z = w_list$Z[keep],
                   witness_score = w_list$witness_score[keep, , drop = FALSE])
  f_bounds <- bounds[keep, , drop = FALSE]

  # Finalize

  if (!any(keep))
    cat("No solution found\n")

  return(list(w_list = f_w_list, bounds = f_bounds))
}


#######################################################################################################
# witness_posterior_sampling_saturated::
#
# Generates a sample of the posterior distribution of the lower bound/upper bound of the
# ACE X --> Y with instrumental variable W. This averages over the distribution of background
# variables Z, where the averaging is exhaustive (so it doesn't scale to sets Z with more than
# a dozen items or so). NOTE: the conditioning distribution P(Z) is not, however, being estimated in
# a Bayesian way at this point - the empirical distribution is used instead - but this is something
# easy to fix.
#
# * Input:
#
# - w, x, y, Z: the indices of the corresponding variables (Z being a vector here)
# - dat: a matrix of 0/1 values, one column per variable (used only if counts is missing)
# - epsilons: the relaxation parameters. See comments at the beginning of this file
# - M: sample size for Monte Carlo simulations
# - counts: array of counts with one binary dimension per variable; built from dat if missing
# - verbose: if TRUE, print one progress line per configuration of Z
# - prior_sample: if TRUE and there are no data points corresponding to a particular value of Z,
#                 then sample from the prior. Otherwise that configuration gets an empty
#                 placeholder, and the remaining configurations are still processed
# - numerical_method: passed to bayesian_posterior_sampling; if TRUE, a numerical approach is used
#                     for bound calculation instead of the analytical one
#
# * Output:
#
# - theta_samples: a list with one entry per configuration of Z, in the cell order of
#                  marginTable(counts, rev(Z)). Each entry has fields W (posterior samples of
#                  P(W = 1)), W0 (posterior samples of P(X, Y | W = 0)), W1 (analogously) and
#                  rejection_rate. W0 is documented as an M x 4 table, where column indices
#                  1, 2, 3, 4 correspond to combinations (Y = 0, X = 0), (Y = 0, X = 1),
#                  (Y = 1, X = 0), (Y = 1, X = 1).
#                  NOTE(review): the empty placeholders below use 8 columns; confirm against
#                  bayesian_posterior_sampling.
#                  If sampling for some configuration rejects everything (rejection_rate == 1),
#                  all later configurations are filled with empty placeholders.

witness_posterior_sampling_saturated <- function(w, x, y, Z, dat, epsilons, M, counts,
                                                 verbose = FALSE, prior_sample = FALSE, numerical_method = FALSE)
{

  if (missing(counts)) {
    # Row r maps to cell 1 + sum_k dat[r, k] * 2^(k - 1): the first variable varies fastest,
    # matching R's array storage order
    counts <- array(tabulate(dat %*% 2^(seq_len(ncol(dat)) - 1) + 1, nbins = 2^ncol(dat)), rep(2, ncol(dat)))
  }

  num_states <- 2^length(Z)
  # Prior hyperparameters, spread evenly over the configurations of Z
  alpha_0 <- rep(10, 4) / num_states
  alpha_1 <- rep(10, 4) / num_states
  alpha_W <- c(10, 10) / num_states

  # Placeholder for configurations that are not sampled
  no_samples <- list(W = c(),
                     W0 = matrix(nrow = 0, ncol = 8),
                     W1 = matrix(nrow = 0, ncol = 8),
                     rejection_rate = 1)

  ## matrix: each row corresponds to state of Z, each col to state of (w,x,y)
  totals <- marginTable(counts, c(rev(Z), w, x, y))
  dim(totals) <- c(num_states, 8)

  theta_samples <- vector("list", num_states)

  for (i in seq_len(num_states)) {
    has_data <- sum(totals[i, ]) > 0
    if (verbose) {
      cat("witness = ", w, ", Z = ", sep = "")
      cat(Z)
      cat(", state = ", paste(as.integer(intToBits(i - 1)[rev(seq_along(Z))]), collapse = ""), sep = "")
      if (!has_data) cat(" [no data]")
      cat("\n")
    }

    if (!has_data && !prior_sample) {
      # Nothing to sample here. This is not evidence against the witness (the configuration has
      # zero empirical weight), so it must not stop the remaining configurations from being sampled
      theta_samples[[i]] <- no_samples
      next
    }

    theta_samples[[i]] <- bayesian_posterior_sampling(alpha_0, alpha_1, alpha_W, epsilons = epsilons,
                                                      M = M, verbose = verbose, numerical_method = numerical_method,
                                                      counts = array(totals[i, ], rep(2, 3)))
    if (theta_samples[[i]]$rejection_rate == 1 && i < num_states) {
      # Every sample was rejected: useless witness, terminate earlier
      theta_samples[(i + 1):num_states] <- rep(list(no_samples), num_states - i)
      break
    }
  }

  return(theta_samples)
}

#######################################################################################################
# .witness_bracket_hits::
#
# Indices k of the intervals [brackets[k], brackets[k + 1]], k = 1, ..., length(brackets) - 1,
# that lie entirely within [l, u]. Comparisons are elementwise over all brackets.

.witness_bracket_hits <- function(brackets, l, u)
{
  which(head(brackets, -1) >= l & tail(brackets, -1) <= u)
}

#######################################################################################################
# witness_summarize_bounds::
#
# Gets the output of witness_bound_search and provides a handful of summaries.
#
# * Input:
#
# - v: the output of a run of witness_bound_search
# - problem, epsilons: passed to witness_generate_cheap_numerical_bound for the chosen pair
# - verbose: passed to witness_generate_cheap_numerical_bound
# - taboo_vars: variables that may not appear in the chosen background set Z
#
# * Output:
#
# - brackets: a discretization of the [-1, 1] space which we use to estimate the areas which
#             are covered more frequently by the different bounds
# - b_prop: for each of the length(brackets) - 1 intervals between consecutive brackets, the
#           fraction of bounds that fully cover it
# - min_l, max_u: minimum of all lower bounds, maximum of all upper bounds (rounded to a grid of
#                 1 / 100)
# - tightest_bound: tightest of all bounds, with chosen_w_tighest / chosen_Z_tighest being the
#                   corresponding witness and background set
# - highest_bound: bound associated with the highest witness score, replaced by the result of
#                  witness_generate_cheap_numerical_bound when that is not NULL
# - chosen_w, chosen_Z: witness and background set associated with highest_bound
#
# If v contains no bounds, uninformative defaults are returned (bounds of [-1, 1], witness 0).

witness_summarize_bounds <- function(v, problem, epsilons, verbose = FALSE, taboo_vars = c())
{
  num_brackets <- 100
  brackets <- seq(-1, 1, 2 / (num_brackets - 1))
  b_count <- rep(0, num_brackets - 1)
  num_bounds <- length(v$w_list$witness)
  if (num_bounds == 0)
    return(list(brackets = brackets, b_prop = rep(0, num_brackets - 1), min_l = -1, max_u = 1,
                tightest_bound = c(-1, 1), highest_bound = c(-1, 1), chosen_w = 0, chosen_Z = c(),
                chosen_w_tighest = 0, chosen_Z_tighest = c()))

  min_l <- Inf; max_u <- -Inf
  tightest_bound <- rep(0, 2); t_b <- Inf; chosen_w_t <- 0; chosen_Z_t <- c()
  for (i in seq_len(num_bounds)) {
    if (is.na(v$bounds[i, 1]) || is.na(v$bounds[i, 2])) next
    # Snap both ends to a grid of resolution 1 / num_brackets
    l <- round(v$bounds[i, 1] * num_brackets) / num_brackets; min_l <- min(l, min_l)
    u <- round(v$bounds[i, 2] * num_brackets) / num_brackets; max_u <- max(u, max_u)
    if (u - l < t_b) {
      t_b <- u - l
      tightest_bound <- c(l, u)
      chosen_w_t <- v$w_list$witness[i]
      chosen_Z_t <- v$w_list$Z[[i]]
    }
    # Interval indices stay within 1..(num_brackets - 1), the length of b_count
    b_i <- .witness_bracket_hits(brackets, l, u)
    b_count[b_i] <- b_count[b_i] + 1
  }
  b_prop <- b_count / num_bounds

  scores <- rowSums(v$w_list$witness_score)
  uses_taboo <- vapply(v$w_list$Z, function(Z) length(intersect(Z, taboo_vars)) > 0, logical(1))
  scores[uses_taboo] <- -Inf
  # NOTE: if every candidate uses a taboo variable, which.max still returns the first one
  idx_highest <- which.max(scores)
  highest_bound <- v$bounds[idx_highest, ]
  chosen_w <- v$w_list$witness[idx_highest]
  chosen_Z <- v$w_list$Z[[idx_highest]]

  numerical_bound <- witness_generate_cheap_numerical_bound(problem, epsilons, chosen_w, chosen_Z, 1000, verbose = verbose)
  if (!is.null(numerical_bound)) {
    highest_bound <- numerical_bound
  } else {
    cat("Length of sub-optimal bound:", highest_bound[2] - highest_bound[1], "\n")
  }

  return(list(brackets = brackets, b_prop = b_prop, min_l = min_l, max_u = max_u,
              chosen_w_tighest = chosen_w_t, chosen_Z_tighest = chosen_Z_t, tightest_bound = tightest_bound,
              chosen_w = chosen_w, chosen_Z = chosen_Z, highest_bound = highest_bound))
}

