############################################################################################
# witness_util.R
#
# Code to generate synthetic cases for the witness search procedure and other utility functions
#
# Code by
#
#  - Ricardo Silva (ricardo@stats.ucl.ac.uk)
#  - Robin Evans (robin.evans@stats.ox.ac.uk)
#
# Current version: 15/08/2014
# First version: 30/08/2014

source("logis.R")

library(igraph)
library(rje) # for functions 'combinations()' and 'marginTable()'

# See documentation in 'witness.R' for more details on the data structures that are generated here

############################################################################################
# dtoc::
#
# Convert a dataset into a count table.
#
# Input:
# 
# - d: a dataset
#
# Output:
#
# - the count representation of the dataset

dtoc <- function(d)
{
  # Every row of d is read as a binary number whose least-significant bit is
  # column 1; adding 1 turns that number into a 1-based cell index.
  n_vars <- ncol(d)
  bit_weights <- 2^(seq_len(n_vars) - 1)
  cell_index <- d %*% bit_weights + 1

  # Count how many rows fall in each of the 2^n_vars cells and shape the
  # counts as a 2 x 2 x ... x 2 array (first column of d = first dimension).
  cell_counts <- tabulate(cell_index, nbins = 2^n_vars)
  return(array(cell_counts, dim = rep(2, n_vars)))
}

############################################################################################
# simulate_witness_model_fixed::
#
# This simulates from a pre-specified witness model:
#      ACE of interest, P(Y = 1 | do(X = 1)) - P(Y = 1 | do(X = 0))
#
# The model is the graph
#
#  Z1 <- Z2
#  Z2 <- \empty
#  Z3 <- {Z1, Z4}
#  Z4 <- \empty
#  X  <- Z1
#  Y  <- {X, Z2, Z4}
#
#  Latent variables: Z4
#
# So the adjustment set should be Z2, with Z1 the witness. Z3 should not be included
#
# * Input:
#
# - M: sample_size

simulate_witness_model_fixed <- function(M)
{
  # Position of every variable in the adjacency matrix
  idx <- c(Z1 = 1, Z2 = 2, Z3 = 3, Z4 = 4, X = 5, Y = 6)

  # Parent sets of the fixed graph (Z2 and Z4 have no parents)
  parent_sets <- list(Z1 = "Z2",
                      Z3 = c("Z1", "Z4"),
                      X  = "Z1",
                      Y  = c("Z2", "Z4", "X"))

  # g[child, parent] == 1 marks an edge parent -> child
  g <- matrix(0, nrow = 6, ncol = 6)
  for (child in names(parent_sets)) {
    g[idx[[child]], idx[parent_sets[[child]]]] <- 1
  }

  # Draw the CPTs and a sample of size M, then attach the problem description
  sim <- get_bin_DAG_model_data(g, M)
  problem_info <- list(graph = g, X_idx = idx[["X"]], Y_idx = idx[["Y"]],
                       latent_idx = idx[["Z4"]])
  return(c(sim, problem_info))
}

############################################################################################
# simulate_witness_model_fixed_simplest::
#
# An even simpler fixed case: X <-- Z, Y <-- X, Z. Z is latent, so no possibility of finding
# a witness.

simulate_witness_model_fixed_simplest <- function(M)
{
  # Variable order: Z (latent), X, Y
  z_pos <- 1
  x_pos <- 2
  y_pos <- 3

  # g[child, parent] == 1 marks an edge parent -> child
  g <- matrix(0, nrow = 3, ncol = 3)
  g[x_pos, z_pos] <- 1            # X <- Z
  g[y_pos, c(z_pos, x_pos)] <- 1  # Y <- Z, X

  # Draw the CPTs and a sample of size M, then attach the problem description
  sim <- get_bin_DAG_model_data(g, M)
  problem_info <- list(graph = g, X_idx = x_pos, Y_idx = y_pos, latent_idx = z_pos)
  return(c(sim, problem_info))
}

############################################################################################
# get_bin_DAG_model_data::
#
# Given a DAG, generate a binary model and simulate a corresponding sample
#
# * Input
#
# - g: binary matrix representing a DAG, each row represent the corresponding parents
# - M: sample size
#
# * Output
#
# list containing
# - model: a list of CPTs, where model[[i]] is P(V_i = 1 | parents = p), p being a parent
#          configuration. Configurations are sorted as follows. If [1 2 3 ... (k - 2) (k - 1) k]
#          are the parents, the corresponding configuration indices are
#          1 + bin2dec(000...000), 1 + bin2dec(000...001), 1 + bin2dec(000...010), 1 + bin2dec(111...11)
# - probs: array giving joint distribution over all variables
# - counts: array of count data over all variables. Again, this assumes modest dimensionality
# - ancestrals: list containing the strict ancestors of each vertex

get_bin_DAG_model_data <- function(g, M)
{
  # Builds a random binary model for the DAG 'g' (g[child, parent] == 1 marks an
  # edge parent -> child), its joint distribution and a multinomial sample of
  # size M drawn from that joint.
  #
  # Returns list(model, counts, ancestrals, probs) as described in the header.

  num_v <- nrow(g)

  ## Get topological ordering. The edge vector is (parent, child, parent, child, ...).
  ## 'n = num_v' keeps vertices that take part in no edge: without it the graph is
  ## sized by the largest index appearing in an edge, so trailing isolated vertices
  ## (or every vertex of an edgeless DAG) would never be visited below and would
  ## get neither a CPT nor a factor in 'probs'.
  graph_edge <- c(rbind(col(g)[g == 1], row(g)[g == 1]))
  graph_temp <- graph(graph_edge, n = num_v)
  v_order <- topological.sort(graph_temp)

  ## Create model and joint distribution
  model <- list()
  ancestrals <- list()
  probs <- array(1, rep(2, num_v))

  for (v in v_order) {

    parents <- which(g[v, ] == 1)
    np <- length(parents)

    #cpt <- runif(2^length(parents)) # Problems generated like this are far too easy

    cpt <- logitDist(np, nu = 10, TRUE)              # Non-linearities, to make problems harder
    prob_cpt <- which(cpt > 0.975)                   # Entries above 0.975 are redrawn in [0.95, 0.975]
    cpt[prob_cpt] <- 0.95 + runif(length(prob_cpt)) * 0.025
    prob_cpt <- which(cpt < 0.025)                   # Entries below 0.025 are redrawn in [0.025, 0.05]
    cpt[prob_cpt] <- 0.05 - runif(length(prob_cpt)) * 0.025

    model[[v]] <- cpt

    if (np > 0) {
      ## c(1 - cpt, cpt) is shaped as an array whose first np dimensions index the
      ## parent configuration and whose last dimension is v itself. The permutation
      ## moves the dimensions into the order: parents with index below v (ascending),
      ## then v, then parents with index above v (ascending).
      ## NOTE(review): presumably patternRepeat() expects the dimensions in
      ## increasing variable order - confirm against the rje documentation.
      bef <- sum(parents < v)
      perm <- c(seq(from = np, by = -1, length.out = bef), np + 1, rev(seq_len(np - bef)))
      cpt <- aperm(array(c(1 - cpt, cpt), rep(2, np + 1)), perm)
    } else {
      cpt <- c(1 - cpt, cpt)
    }

    ## Multiply this variable's factor into the joint
    patt <- patternRepeat(cpt, c(v, parents), rep(2, num_v))
    probs <- probs * patt

    ## Strict ancestors of v: its parents plus everything above them. The parents'
    ## entries are already filled in because vertices are visited in topological order.
    ancestrals[[v]] <- sort.int(union(parents, unlist(ancestrals[parents])))

  }

  ## One multinomial draw of size M over all 2^num_v cells, shaped like 'probs'
  counts <- rmultinom(1, M, c(probs))
  dim(counts) <- dim(probs)

  return(list(model = model, counts = counts, ancestrals = ancestrals, probs = probs))
}

############################################################################################
# bindag_monte_carlo_causal_effect::
#
# Finds the true ACE for X --> Y using Monte Carlo simulation.
#
# * Input:
#
# - problem: the problem specification
# - M: sample size for the simulation
#
# * Output:
#
# - a scalar, the true ACE (estimated by Monte Carlo)

bindag_monte_carlo_causal_effect <- function(problem, M)
{
  # Estimates the ACE of X on Y by simulating M samples from the model in
  # 'problem' with X replaced by a fair coin (all edges into X cut), then
  # comparing mean(Y) between the X = 1 and X = 0 samples.
  #
  # Returns list(effect_real, effect_naive2), where effect_naive2 is the
  # unadjusted contrast P(Y = 1 | X = 1) - P(Y = 1 | X = 0) from problem$counts.

  num_v <- nrow(problem$graph)

  ## Edge vector (parent, child, parent, child, ...), ordered by child and then
  ## by parent. 'n = num_v' keeps vertices that take part in no edge; without it
  ## the graph is sized by the largest index in an edge and trailing isolated
  ## vertices would never be simulated (their column would stay all zero).
  edge_pairs <- which(t(problem$graph) == 1, arr.ind = TRUE)
  graph_edge <- c(t(edge_pairs))
  graph_temp <- graph(graph_edge, n = num_v)
  v_order <- topological.sort(graph_temp)

  ## Sample from the modified (intervened) model
  data <- matrix(0, nrow = M, ncol = num_v)

  for (v in v_order) {

    if (v != problem$X_idx) {
      parents <- which(problem$graph[v, ] == 1)
      cpt <- problem$model[[v]]
    } else {
      ## Intervention: X ignores its parents and is uniform on {0, 1}
      parents <- c()
      cpt <- 0.5
    }

    if (length(parents) == 0) {
      data[, v] <- runif(M) < cpt[1]
      next
    }

    ## Walk through the parent configurations in CPT order: the state vector is
    ## a binary counter in which the last parent is the least significant bit.
    p_state <- rep(0, length(parents))
    num_combo <- 2^length(parents)
    for (p in seq_len(num_combo)) {

      ## Generate data for the rows whose parents match the current configuration
      row_idx <- rep(1, M)
      for (j in seq_along(parents)) {
        row_idx <- row_idx * (data[, parents[j]] == p_state[j])
      }
      d <- runif(sum(row_idx)) < cpt[p]
      data[which(row_idx == 1), v] <- d

      ## Advance to the next configuration
      for (k in length(parents):1) {
        if (p_state[k] == 0) {
          p_state[k] <- 1
          break
        }
        p_state[k] <- 0
      }

    }

  }

  do_0.rows <- which(data[, problem$X_idx] == 0)
  do_1.rows <- which(data[, problem$X_idx] == 1)
  effect <- mean(data[do_1.rows, problem$Y_idx]) - mean(data[do_0.rows, problem$Y_idx])

  ## Use problem$counts to estimate naive causal effect P(Y = 1 | X = 1) - P(Y = 1 | X = 0)
  tmp <- conditionTable(problem$counts, problem$Y_idx, problem$X_idx)
  effect_naive2 <- tmp[2, 2] - tmp[2, 1]

  return(list(effect_real = effect, effect_naive2 = effect_naive2))

}

############################################################################################
# bindag_causal_effect::
#
# Finds the true ACE for X --> Y, plus two other (usually biased) effects. This uses brute
# force marginalization, so it barely scales up to more than a dozen variables.
#
# * Input:
#
# - problem: the problem specification
#
# * Output:
#
# - effect_real: the true ACE
# - effect_naive: uses the whole set of observed variables as the adjustment set.
#                 Biased, in general.
# - effect_naive2: naive ACE, P(Y = 1 | X = 1) - P(Y = 1 | X = 0).
# tested

bindag_causal_effect <- function(problem)
{
  # Computes three contrasts for X -> Y from the exact joint problem$probs:
  #   effect_real   - adjusting for every other variable, latents included
  #   effect_naive  - adjusting for every other non-latent variable
  #   effect_naive2 - no adjustment, P(Y = 1 | X = 1) - P(Y = 1 | X = 0)

  num_v <- nrow(problem$graph)
  probs <- problem$probs
  x_idx <- problem$X_idx
  y_idx <- problem$Y_idx

  ## Marginal over (S, X, Y), with dimensions in exactly that order
  joint_over <- function(S) marginTable(probs, c(S, x_idx, y_idx))

  ## sum_s [P(Y=1 | X=1, s) - P(Y=1 | X=0, s)] P(s) for a table laid out as
  ## (S, X, Y) with k adjustment variables. In linear indexing the 2^k cells of
  ## S vary fastest, so the blocks are (X=0,Y=0), (X=1,Y=0), (X=0,Y=1), (X=1,Y=1).
  adjusted_contrast <- function(P_SXY, k) {
    P_SX <- marginTable(P_SXY, seq_len(k + 1))
    P_S <- marginTable(P_SX, seq_len(k))
    num_s <- 2^k
    sq <- seq(num_s)
    sum((P_SXY[3 * num_s + sq] / P_SX[num_s + sq] -
         P_SXY[2 * num_s + sq] / P_SX[        sq]) * P_S[sq])
  }

  ## Naive effect: condition on everybody (non-latent)
  S_obs <- seq_len(num_v)[-c(x_idx, y_idx, problem$latent_idx)]
  k_obs <- length(S_obs)
  P_obs <- joint_over(S_obs)
  effect_naive <- adjusted_contrast(P_obs, k_obs)

  ## Naive effect 2: condition on nobody
  P_Y.X <- conditionTable(P_obs, k_obs + 2, k_obs + 1)
  effect_naive2 <- P_Y.X[2, 2] - P_Y.X[2, 1]

  ## Real effect: condition on everybody (including latents)
  S_all <- seq_len(num_v)[-c(x_idx, y_idx)]
  effect <- adjusted_contrast(joint_over(S_all), length(S_all))

  return(list(effect_real = effect, effect_naive = effect_naive, effect_naive2 = effect_naive2))

}

############################################################################################
# bindag_causal_effect_backdoor_S::
#
# As before, but the one given by backdoor adjustment S
#
# * Input:
#
# - problem: the problem specification
# - S: index of the variables used in the backdoor adjustment
#
# * Output:
#
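# - ACE: average causal effect resulting from this adjustment
# - odds: ratio of the adjusted P(Y = 1 | do(X = 1)) to the adjusted P(Y = 1 | do(X = 0))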

bindag_causal_effect_backdoor_S <- function(problem, S)
{
  # Back-door adjustment of the effect of X on Y using the adjustment set S.
  #
  # Returns list(ACE, odds):
  #   ACE  = sum_s [P(Y=1 | X=1, s) - P(Y=1 | X=0, s)] P(s)
  #   odds = sum_s P(Y=1 | X=1, s) P(s) / sum_s P(Y=1 | X=0, s) P(s)
  #          (a ratio of the two adjusted probabilities; the name is kept for
  #          compatibility with existing callers)

  ## Joint distribution: exact if available, otherwise empirical frequencies
  if (!is.null(problem$probs)) {
    probs <- problem$probs
  } else {
    probs <- problem$counts / sum(problem$counts) # Empirical counts
  }

  ## Marginals laid out as (S, X, Y), (S, X) and (S)
  k <- length(S)
  P_SXY <- marginTable(probs, c(S, problem$X_idx, problem$Y_idx))
  P_SX <- marginTable(P_SXY, seq_len(k + 1))
  P_S <- marginTable(P_SX, seq_len(k))

  ## In linear indexing the 2^k cells of S vary fastest, so the blocks of P_SXY
  ## are (X=0,Y=0), (X=1,Y=0), (X=0,Y=1), (X=1,Y=1).
  num_s <- 2^k
  sq <- seq(num_s)

  p_y1_x1 <- P_SXY[3 * num_s + sq] / P_SX[num_s + sq]
  p_y1_x0 <- P_SXY[2 * num_s + sq] / P_SX[        sq]
  weight <- P_S[sq]

  ## Strata of S with zero mass contribute nothing to the sums, but their
  ## conditional probabilities are 0/0 = NaN and would turn the whole result
  ## into NaN (common with empirical counts). Leave them out. Strata with
  ## positive mass in which one value of X is never observed still give NaN.
  used <- weight > 0

  ACE <- sum(((p_y1_x1 - p_y1_x0) * weight)[used])
  odds <- sum((p_y1_x1 * weight)[used]) /
          sum((p_y1_x0 * weight)[used])

  return(list(ACE = ACE, odds = odds))
}


############################################################################################
# bindag_getmarginal::
#
# Given an array of numbers representing the full contingency table of a multivariate binary
# distribution, returns its marginal over a subset S. Again, this can only handle small
# dimensional problems, around a dozen variables.
#
# * Input:
#
# - table: array representation of a contingency table. If we have p variables, V_1 to V_p,
#          indexing is done by imagining the binary string (V_1, V_2, ..., V_p) and converting
#          it to decimal, plus 1. So 000...00 gets mapped to position 1, 000...01 gets mapped
#          to position 2, etc 111...11 to position 2^p.
# - S: vector indicating the indices of the set whose marginal we want

bindag_getmarginal <- function(table, S)
{
  # Number of binary variables represented by the flat table
  n_vars <- log2(length(table))

  # View the flat table as a 2 x ... x 2 array. The flat index treats V_1 as
  # the most significant bit, while the first array dimension varies fastest,
  # so variable V_i sits on array dimension n_vars + 1 - i.
  full_array <- array(table, dim = rep(2, n_vars))
  array_dims <- rev(n_vars + 1 - S)

  # Sum out everything else and flatten back to a plain vector
  marginal <- marginTable(full_array, array_dims)
  return(c(marginal))
}

############################################################################################
# dsep::
#
# Returns true if x is d-separated from y given S in a dag (with redundant ancestral sets
# pre-computed).
#
# * Input:
#
# - x, y: indices of X and Y
# - S: indices of conditioning set S
# - dag: the DAG, as a binary adjacency matrix
# - ancestrals: a list indicating, for each variable, which ancestors it has (excluding itself)

dsep <- function(x, y, S, dag, ancestrals)
{
  # Tests whether x and y are d-separated given S in 'dag' (dag[child, parent] == 1
  # marks an edge parent -> child). Returns TRUE if they are d-separated and
  # FALSE otherwise.
  #
  # Method: restrict the graph to {x, y} union S and their ancestors, connect the
  # parents of every member of S, delete S, drop edge directions and check whether
  # y can be reached from x.

  ## Get ancestral graph first
  num_v <- dim(dag)[1]
  in_anc <- rep(FALSE, num_v)
  in_anc[c(x, y, S)] <- TRUE
  in_anc[unlist(ancestrals[c(x, y, S)])] <- TRUE
  dag[!in_anc, ] <- 0

  ## Marry parents. Only the parents of nodes in S need to be joined: a common
  ## child that is not in S stays in the graph and already links its parents.
  for (s in S) {
    parents <- (dag[s, ] == 1)
    dag[parents, parents] <- 1
  }

  ## Remove S and non-ancestors, shifting the positions of x and y accordingly.
  ## drop = FALSE keeps G a matrix even if a single node remains.
  keep <- in_anc
  keep[S] <- FALSE
  rmv <- which(!keep)
  G <- dag[keep, keep, drop = FALSE]
  x <- x - sum(rmv < x)
  y <- y - sum(rmv < y)

  ## Convert to undirected
  G <- (G + t(G)) > 0

  ## A node without any edge cannot be connected to the other one
  if (!any(G[x, ]) || !any(G[y, ])) return(TRUE)

  ## Test for graph connectivity between x and y by growing the set of nodes
  ## reachable from x one step at a time. Working with a logical vector instead
  ## of powers of the adjacency matrix avoids walk counts overflowing to Inf
  ## (and then NaN) on larger graphs, and costs one pass over G per step.
  reached <- G[x, ]
  repeat {
    if (reached[y]) return(FALSE)
    grown <- reached | (colSums(G[reached, , drop = FALSE]) > 0)
    if (all(grown == reached)) return(TRUE)
    reached <- grown
  }
}

############################################################################################
# simulate_witness_model_basic::
#
# Generates a general witness problem, where two variables are our X and Y corresponding to
# the causal effect of interest. The way this is generated is very straightforward and
# results in 'easy' problems where conditioning on all covariates (regardless of causal
# structure) is on average a 'good enough' solution, possibly due to many symmetries in
# the generation.
#
# * Input
#
# - p: number of variables (besides X and Y)
# - par_max: maximum number of parents
# - p_latent: probability that a variable with two or more children is made latent
# - M: sample size
#
# * Output
#
# - problem: a problem instance

simulate_witness_model_basic <- function(p, par_max, p_latent, M)
{
  # Generates a random DAG over p covariates plus X and Y, marks some covariates
  # as latent, and simulates a binary model and a sample of size M from it.
  #
  # Returns the output of get_bin_DAG_model_data() extended with
  # graph, X_idx, Y_idx and latent_idx.

  ## Generate graph
  g <- matrix(0, p + 2, p + 2) # variable order: X and Y are p + 1 and p + 2, respectively

  ## Every variable after the first draws between 1 and par_max parents among
  ## the variables that precede it, which keeps the graph acyclic.
  for (i in 2:(p + 2)) {
    num_par_i <- min(i - 1, sample(1:par_max, 1))
    par_i <- sample(1:(i - 1), num_par_i)
    g[i, par_i] <- 1
  }
  g[p + 2, p + 1] <- 1 # X is always a parent of Y

  ## Set some variables with two or more children to latent.
  ## seq_len(p) instead of 1:p: with p = 0 the range 1:0 would visit column 1,
  ## which is X in that case, and could flag it as latent.
  latent_idx <- c()
  for (i in seq_len(p)) {
    if (sum(g[, i]) > 1 && runif(1) < p_latent)
      latent_idx <- c(latent_idx, i)
  }

  m <- get_bin_DAG_model_data(g, M)
  return(c(m, list(graph = g, X_idx = p + 1, Y_idx = p + 2, latent_idx = latent_idx)))
}

############################################################################################
# simulate_witness_model::
#
# An attempt to make harder problems, since previous methods tend to generate models where
# conditioning on everybody is nearly as good as conditioning on a correct set. The way
# this works is as follows: first generate a graph with a given number of background variables
# which have no latent common parents with X and Y. Then generate a set of 'sink' variables K
# which have one common latent parent with either X or Y. Latent variables are a pool of independent
# variables with no parents, and are parents of either X or Y but not both, if no_sol = FALSE.
# If no_sol = TRUE, then all latent variables are parents of both X and Y.
#
# * Input
#
# - p: number of background variables (besides X and Y)
# - q: number of sink variables
# - par_max: maximum number of parents in the background set
# - M: sample size
# - no_sol: if TRUE (no solution == TRUE), then latent variables are parents of both X and Y
#
# * Output
#
# - problem: a problem instance

simulate_witness_model <- function(p, q, par_max, M, no_sol = FALSE, verbose = FALSE)
{
  # Generates a harder witness problem: p background variables, q sink
  # variables (at least 2), X, Y and a pool of q parentless latent variables.
  # Every sink has all latents as parents. With no_sol = TRUE every latent is a
  # parent of both X and Y; otherwise each latent is a parent of exactly one of
  # them. If verbose = TRUE the generated graph is printed.
  #
  # Returns the output of get_bin_DAG_model_data() extended with
  # graph, X_idx, Y_idx and latent_idx.

  ## Preliminaries: here, do rejection sampling on a smaller model to
  ## test whether there is a witness set or not
  if (par_max < 1) stop("Must allow at least one parent")

  if (!no_sol) {
    while (TRUE) {
      g_dummy <- matrix(0, nrow = p + 2, ncol = p + 2)

      ## Number of parents of each background variable, capped by the number of
      ## variables that precede it. seq_len(p)[-1] is empty when p = 1, whereas
      ## 2:p would run backwards.
      num_par <- pmin(seq_len(p) - 1, sample.int(par_max, p, replace = TRUE))
      for (i in seq_len(p)[-1]) {
        par_i <- sample.int(i - 1, num_par[i])
        g_dummy[i, par_i] <- 1
      }
      for (i in (p + 1):(p + 2)) {
        num_par_i <- min(p, sample(par_max, 1))
        par_i <- sample(p, num_par_i)
        g_dummy[i, par_i] <- 1
      }
      g_dummy[p + 2, p + 1] <- 1

      ## Only the ancestral sets are needed here, so a sample of size 1 suffices
      m <- get_bin_DAG_model_data(g_dummy, 1)
      problem_dummy <- list(graph = g_dummy, X_idx = p + 1, Y_idx = p + 2, latent_idx = c(),
                            ancestrals = m$ancestrals)
      true_witness <- witness_search(problem_dummy, TRUE, TRUE, FALSE, TRUE)
      if (length(true_witness$witness) > 0) break
    }
  }

  ## Generate graph
  q <- max(2, q) # Minimum number of sink variables is 2
  num_h <- q     # The pool of latent variables is also given by q
  num_var <- p + q + 2 + num_h

  X_idx <- p + q + 1
  Y_idx <- p + q + 2

  g <- matrix(0, nrow = num_var, ncol = num_var) # variable order: X and Y are indexed as p + q + 1 and
                                                 # p + q + 2. Latent variables are indexed
                                                 # as p + q + 3 ... num_var
                                                 # Observable variables
  for (i in 2:(p + q)) {
    num_par_i <- min(i - 1, sample(1:par_max, 1))
    par_i <- sample(1:(i - 1), num_par_i)
    g[i, par_i] <- 1
  }
  for (i in c(X_idx, Y_idx)) {
    num_par_i <- min(p, sample(1:par_max, 1))
    par_i <- sample(1:p, num_par_i)
    g[i, par_i] <- 1
  }
  g[Y_idx, X_idx] <- 1

  ## If g_dummy has been generated, override g with it
  if (!no_sol) {
    g[1:p, 1:p]   <- g_dummy[1:p, 1:p]
    g[X_idx, 1:p] <- g_dummy[p + 1, 1:p]
    g[Y_idx, 1:p] <- g_dummy[p + 2, 1:p]
  }

  g[p + seq_len(q), (p + q + 3):num_var] <- 1 # SINK <- all H

  latent_idx <- (p + q + 3):num_var

  if (no_sol) { # Guarantee no solution: make all latent parents be parents of X and Y
    g[X_idx, p + q + 2 + 1:q] <- 1
    g[Y_idx, p + q + 2 + 1:q] <- 1
  }
  else {
    ## Latent variables: flip a coin, then decide whether latent variable will be linked to
    ## X or Y. But the first one is assigned to X and the second to Y
    g[X_idx, p + q + 3] <- 1
    g[Y_idx, p + q + 4] <- 1

    if (q > 2) {
      ## One uniform draw on [0, 1] per remaining latent. The second positional
      ## argument of runif() is 'min', so runif(q - 2, 0.5) would draw from
      ## [0.5, 1] and send every remaining latent to Y.
      whX <- runif(q - 2)
      g[X_idx, p + q + 2 + seq(3, q)[whX <  0.5]] <- 1
      g[Y_idx, p + q + 2 + seq(3, q)[whX >= 0.5]] <- 1
    }
  }

  ## Generate model
  if (verbose) {
    cat("Graph structure:\n\n")
    l_idx <- rep(0, num_var); l_idx[latent_idx] <- 1
    for (i in 1:num_var) {
      cat(i, ":", which(g[i, ] == 1))
      if (l_idx[i]) cat(" [LATENT]", sep = "")
      if (i > p && i <= p + q)  cat(" [SINK]", sep = "")
      if (i == X_idx) cat(" [X]", sep = "")
      if (i == Y_idx) cat(" [Y]", sep = "")
      cat("\n")
    }
  }
  m <- get_bin_DAG_model_data(g, M)
  return(c(m, list(graph = g, X_idx = X_idx, Y_idx = Y_idx, latent_idx = latent_idx)))
}

############################################################################################
# witness_print_graph::
#
# Just prints graph information in a way that is easier to read.
#
# Input:
# 
# - problem: a problem instance, as defined in 'witness.R'

witness_print_graph <- function(problem)
{
  # Prints the positions of X and Y, then one line per node of the form
  # "<node> : <parents>" followed by the tags [LATENT], [X] and [Y] where they
  # apply. Output goes to the console; nothing useful is returned.

  cat("  X = ", problem$X_idx, ", Y = ", problem$Y_idx, "\n", sep = "")

  n_nodes <- ncol(problem$graph)
  is_latent <- rep(0, n_nodes)
  is_latent[problem$latent_idx] <- 1

  # seq_len() instead of 1:n so that an empty graph prints only the header
  for (i in seq_len(n_nodes)) {
    cat(" ", i, ":", which(problem$graph[i, ] == 1))
    if (is_latent[i]) cat(" [LATENT]", sep = "")
    if (problem$X_idx == i) cat(" [X]", sep = "")
    if (problem$Y_idx == i) cat(" [Y]", sep = "")
    cat("\n")
  }
}

############################################################################################
# witness_print_model::
#
# Just prints model information in a way that is easier to read.
#
# Input:
# 
# - problem: a problem instance, as defined in 'witness.R'

witness_print_model <- function(problem)
{
  # Prints, for every node, its parents, its tags ([LATENT], [X], [Y]) and its
  # conditional probability table rounded to two decimals. Each table row shows
  # a parent configuration as a string of 0s and 1s, in the order in which the
  # entries of problem$model[[i]] are stored. Output goes to the console.

  n_nodes <- ncol(problem$graph)
  is_latent <- rep(0, n_nodes)
  is_latent[problem$latent_idx] <- 1

  cat("  *****\n")
  # seq_len() instead of 1:n so that an empty graph prints only the separator
  for (i in seq_len(n_nodes)) {
    parents <- which(problem$graph[i, ] == 1)
    cat("  Node", i, ": parents", parents)
    if (is_latent[i]) cat(" [LATENT]", sep = "")
    if (problem$X_idx == i) cat(" [X]", sep = "")
    if (problem$Y_idx == i) cat(" [Y]", sep = "")

    cat("\n\n")

    params <- round(problem$model[[i]] * 100) / 100
    n_par <- length(parents)
    if (n_par > 0) {
      p_state <- rep(0, n_par)
      for (cfg in seq_len(2^n_par)) {
        cat("  ", p_state, sep = "")
        cat(": ", params[cfg], "\n", sep = "")

        # Advance the binary counter; the last parent is the least significant
        # bit. A separate loop variable is used so the configuration index is
        # not overwritten.
        for (pos in n_par:1) {
          if (p_state[pos] == 1) {
            p_state[pos] <- 0
          } else {
            p_state[pos] <- 1
            break
          }
        }
      }
    } else {
      cat("  EMPTY: ", params, "\n", sep = "")
    }

    cat("  *****\n")
  }
}

############################################################################################
# witness_generate_WZ_posterior::
#
# Generate distribution over bounds for a particular choice of witness and admissible set.
#
# Input:
# 
# - problem: a problem instance, as defined in 'witness.R'
# - epsilons: array of relaxation parameters
# - witness: index of witness with the data
# - Z: indices of the admissible set variables within the data
# - M: number of Monte Carlo samples
#
# Output:
#
# - bounds: a matrix where the first column are lower bound samples and the second
#           column are upper bound samples

witness_generate_WZ_posterior <- function(problem, epsilons, witness, Z, M, verbose = FALSE, numerical_method = FALSE)
{
  # Posterior samples of the model parameters, one list entry per state of Z
  theta_samples <- witness_posterior_sampling_saturated(
    witness, problem$X_idx, problem$Y_idx, Z,
    counts = problem$counts, epsilons = epsilons, M = M,
    verbose = verbose, prior_sample = FALSE, numerical_method = numerical_method
  )

  # No samples at all for the first state of Z: nothing can be computed
  if (nrow(theta_samples[[1]]$W0) == 0) {
    cat("No bounds can be calculated\n")
    return(NULL)
  }

  # Empirical distribution of Z, used to weight the per-state bounds
  z_weights <- c(marginTable(problem$counts, rev(Z)))
  z_weights <- z_weights / sum(z_weights) # That is, not completely Bayesian yet

  # Accumulate the weighted (lower, upper) bound samples over the states of Z
  num_states <- 2^length(Z)
  bounds <- matrix(rep(0, 2 * M), nrow = M)
  for (state in seq_len(num_states)) {
    if (numerical_method) {
      if (verbose) {
        cat(state, "out of", num_states, "\n")
      }
      state_bounds <- bayesian_interval_generation(theta_samples[[state]], epsilons, verbose)
    } else {
      state_bounds <- bayesian_interval_generation_analytical(theta_samples[[state]], epsilons)
    }
    bounds <- bounds + state_bounds * z_weights[state]
  }

  return(bounds)
}

############################################################################################
# witness_generate_cheap_numerical_bound::
#
# Generates a bound using the posterior mean based on the analytical bounds, then plug
# this single data point into the numerical procedure to generate tighter bounds.
#
# Input:
# 
# - problem: a problem instance, as defined in 'witness.R'
# - epsilons: array of relaxation parameters
# - witness: index of witness with the data
# - Z: indices of the admissible set variables within the data
# - M: number of Monte Carlo samples
#
# Output:
#
# - bounds_hat: a vector with the lower bound and upper bound point estimator

witness_generate_cheap_numerical_bound <- function(problem, epsilons, witness, Z, M, verbose = FALSE)
{
  # Shared exit path: report the failure and hand back NULL
  give_up <- function() {
    cat("No tight numerical bounds can be calculated\n")
    NULL
  }

  # Posterior samples from the analytical procedure, one list entry per state of Z
  theta_samples <- witness_posterior_sampling_saturated(
    witness, problem$X_idx, problem$Y_idx, Z,
    counts = problem$counts, epsilons = epsilons, M = M,
    verbose = verbose, prior_sample = FALSE, numerical_method = FALSE
  )
  if (nrow(theta_samples[[1]]$W0) == 0) {
    return(give_up())
  }

  # Empirical distribution of Z, used to weight the per-state bounds
  z_weights <- c(marginTable(problem$counts, rev(Z)))
  z_weights <- z_weights / sum(z_weights) # That is, not completely Bayesian yet

  num_states <- 2^length(Z)
  bounds_hat <- rep(0, 2)
  for (state in seq_len(num_states)) {
    if (verbose) {
      cat(state, "out of", num_states, "\n")
    }

    # Collapse the samples for this state into their posterior means
    samples <- theta_samples[[state]]
    theta_mean <- list(W  = mean(samples$W),
                       W0 = matrix(colMeans(samples$W0), nrow = 1),
                       W1 = matrix(colMeans(samples$W1), nrow = 1))
    if (is.na(sum(theta_mean$W0 + theta_mean$W1))) {
      return(give_up())
    }

    # Feed the single mean point to the numerical procedure
    state_bounds <- bayesian_interval_generation(theta_mean, epsilons, verbose)
    bounds_hat <- bounds_hat + state_bounds * z_weights[state]
  }

  if (is.na(bounds_hat[1]) || is.na(bounds_hat[2])) {
    return(give_up())
  }
  return(bounds_hat)
}

###############################################################################################
# witness_simple_bound::
#
# Similar to the previous, but uses a point estimate of the probability distribution. Fit model
# under independence constraint.
#
# Input:
# 
# - problem: a problem instance, as defined in 'witness.R'
# - epsilons: array of relaxation parameters
# - witness: index of witness with the data
# - Z: indices of the admissible set variables within the data
# - prior: Dirichlet prior
#
# Output:
#
# - bounds_hat: a vector with the lower bound and upper bound point estimator

witness_simple_bound <- function(problem, epsilons, witness, Z, prior = 10)
{
  num_states <- 2^length(Z)
  z_dims <- seq_along(Z)

  # Empirical distribution of Z, used to weight the per-state bounds
  z_weights <- c(marginTable(problem$counts, rev(Z)))
  z_weights <- z_weights / sum(z_weights)

  # Smoothed marginal of the witness; its second entry is used as theta$W
  W_hat <- marginTable(problem$counts, witness) + prior / 2
  W_hat <- W_hat / sum(W_hat)

  # Count tables laid out as (Y, X, W, Z...), (Y, X, Z...) and (X, W, Z...)
  counts_YXW <- marginTable(problem$counts, c(problem$Y_idx, problem$X_idx, witness, rev(Z)))
  counts_YX  <- marginTable(counts_YXW, c(1, 2, z_dims + 3))
  counts_XW  <- marginTable(counts_YXW, c(2, 3, z_dims + 3))

  # Slices at the first and second level of the second variable of each table
  slices <- list(Y.X0 = subtable(counts_YX, 2, 1),
                 Y.X1 = subtable(counts_YX, 2, 2),
                 X.W0 = subtable(counts_XW, 2, 1),
                 X.W1 = subtable(counts_XW, 2, 2))

  # Pseudo-count added to every cell before normalizing
  pseudo <- prior / (4 * num_states)

  # Smoothed, normalized slice of 'tab' at the given levels of Z
  smoothed_at <- function(tab, z_levels) {
    cells <- subtable(tab, z_dims + 1, z_levels) + pseudo
    cells / sum(cells)
  }

  # Products arranged as (Y=0,X=0), (Y=0,X=1), (Y=1,X=0), (Y=1,X=1)
  joint_given_w <- function(y_x0, y_x1, x_w) {
    c(y_x0[1] * x_w[1], y_x1[1] * x_w[2], y_x0[2] * x_w[1], y_x1[2] * x_w[2])
  }

  z_value <- rep(1, length(Z))
  bounds_hat <- c(0, 0)

  for (i in seq_len(num_states)) {

    sY.X0 <- smoothed_at(slices$Y.X0, z_value)
    sY.X1 <- smoothed_at(slices$Y.X1, z_value)
    sX.W0 <- smoothed_at(slices$X.W0, z_value)
    sX.W1 <- smoothed_at(slices$X.W1, z_value)

    theta_hat <- list(W  = W_hat[2],
                      W0 = matrix(joint_given_w(sY.X0, sY.X1, sX.W0), nrow = 1),
                      W1 = matrix(joint_given_w(sY.X0, sY.X1, sX.W1), nrow = 1))
    bounds_hat <- bounds_hat + bayesian_interval_generation(theta_hat, epsilons) * z_weights[i]

    # Move to the next configuration of Z levels (first position varies fastest)
    for (j in z_dims) {
      if (z_value[j] == 1) {
        z_value[j] <- 2
        break
      }
      z_value[j] <- 1
    }

  }

  if (is.na(bounds_hat[1]) || is.na(bounds_hat[2])) {
    return(NULL)
  }
  return(bounds_hat)

}
