# Simulation study: estimate the connection-probability matrix of a stochastic
# block model (SBM) from a partially observed adjacency matrix while varying
# the number of communities K. For every configuration, M networks are drawn
# and four estimators are scored by their mean squared error against the true
# probability matrix; the per-experiment errors are saved as one RDS file per
# configuration under "Varying K/<Assortative|Disassortative> SBM/".

# NOTE(review): clearing the global environment and changing the working
# directory are kept from the original script because the relative output
# paths below depend on them; both are side effects on the calling session.
rm(list = objects())
setwd("~/Desktop/Experiments/Simulations/Results")

# Fail fast if an estimation package is missing rather than erroring in the
# middle of a long run. Both packages are called through `::` below, so no
# library() call is required.
for (pkg in c("softImpute", "missSBM")) {
  if (!requireNamespace(pkg, quietly = TRUE)) {
    stop("Package '", pkg, "' is required to run this simulation.",
         call. = FALSE)
  }
}

#################### -  Helpers - #########

# Draw a symmetric n_nodes x n_nodes 0/1 matrix with a zero diagonal.
# `prob` is either a single probability or one probability per entry of the
# strict upper triangle (in the order given by upper.tri()).
draw_symmetric_binary <- function(n_nodes, prob) {
  upper <- rbinom(n = n_nodes * (n_nodes - 1) / 2, size = 1, prob = prob)
  out <- matrix(0, n_nodes, n_nodes)
  out[upper.tri(out)] <- upper
  out + t(out)
}

# Plug-in estimate of the probability matrix for a given labelling: each pair
# of blocks gets the mean of the observed (non-NA) entries of `adjacency_obs`
# between them. Block pairs without any observed entry are set to 0, and the
# diagonal of the returned matrix is set to 0.
estimate_theta_from_labels <- function(adjacency_obs, labels, n_blocks) {
  block_ids <- seq_len(n_blocks)
  q_hat <- vapply(block_ids, function(b) {
    vapply(block_ids, function(a) {
      mean(adjacency_obs[labels == a, labels == b], na.rm = TRUE)
    }, numeric(1))
  }, numeric(n_blocks))
  # mean() over an empty / all-NA selection yields NaN; treat it as 0.
  q_hat[is.na(q_hat)] <- 0
  theta_hat <- q_hat[labels, labels]
  diag(theta_hat) <- 0
  theta_hat
}

# Sum of squared entrywise differences divided by the squared number of nodes.
mean_squared_error <- function(estimate, truth) {
  sum((estimate - truth)^2) / nrow(truth)^2
}

#################### -  Set parameters - #########
M <- 100 # number of experiments
K_list <- 2:15 # number of communities
N <- 300 # number of nodes

for (assortative in c(TRUE, FALSE)) {
  # Create the output directory up front so that a missing folder cannot make
  # saveRDS() fail after all experiments have already been computed.
  out_dir <- file.path(
    "Varying K",
    if (assortative) "Assortative SBM" else "Disassortative SBM"
  )
  dir.create(out_dir, recursive = TRUE, showWarnings = FALSE)

  for (K in K_list) {
    # Connection probabilities: 0.4 within / 0.1 between communities in the
    # assortative case, and the reverse in the disassortative case.
    if (assortative) {
      Q <- matrix(0.1, ncol = K, nrow = K)
      diag(Q) <- 0.4
    } else {
      Q <- matrix(0.4, ncol = K, nrow = K)
      diag(Q) <- 0.1
    }

    alpha <- rep(1 / K, K) # uniform community proportions

    # The seed is reset for every configuration, so each (assortative, K)
    # pair starts from the same RNG state.
    set.seed(0)
    start_time <- Sys.time()

    # Per-experiment errors; NA marks an experiment that has not been run.
    Error_missSBM <- rep(NA_real_, M)
    Error_Var <- rep(NA_real_, M)
    Error_true_Z <- rep(NA_real_, M)
    Error_softImpute <- rep(NA_real_, M)

    for (m in seq_len(M)) {

      ##################### -  Draw random variables - #####################
      # Z: community label of each node
      Z <- base::sample(seq_len(K), size = N, replace = TRUE, prob = alpha)

      # Theta: N x N matrix of connection probabilities, Theta[i, j] =
      # Q[Z[i], Z[j]]. Note that its diagonal is Q[Z[i], Z[i]], not 0.
      Theta <- Q[Z, Z]

      # A: full symmetric adjacency matrix without self-loops
      A <- draw_symmetric_binary(N, Theta[upper.tri(Theta)])

      # Omega: symmetric observation mask, each dyad observed w.p. 0.5
      Omega <- draw_symmetric_binary(N, 0.5)

      # A_obs: observed adjacency matrix, NA on the diagonal and on
      # unobserved dyads
      A_obs <- A
      diag(A_obs) <- NA
      A_obs[Omega == 0] <- NA

      ##################### -  Estimate network using softImpute - #####################
      SVD <- softImpute::softImpute(A_obs, rank.max = K, lambda = 0,
                                    maxit = 500)
      estimate_Theta_softImpute <-
        SVD$u %*% diag(SVD$d, nrow = K, ncol = K) %*% t(SVD$v)
      # Clip the low-rank reconstruction to valid probabilities.
      estimate_Theta_softImpute <- pmin(pmax(estimate_Theta_softImpute, 0), 1)
      Error_softImpute[m] <- mean_squared_error(estimate_Theta_softImpute,
                                                Theta)

      ##################### -  Estimate network using missSBM - #####################
      estimator_missSBM <- missSBM::estimateMissSBM(
        adjacencyMatrix = A_obs,
        vBlocks = c(K),
        sampling = "dyad",
        control = list(trace = 0))$bestModel$fittedSBM
      estimate_Theta_missSBM <- estimator_missSBM$expectation

      Error_missSBM[m] <- mean_squared_error(estimate_Theta_missSBM, Theta)

      ##################### -  Estimate network using the variational estimate of z - #####################
      z_est <- estimator_missSBM$memberships
      estimate_Theta_Var <- estimate_theta_from_labels(A_obs, z_est, K)

      # NOTE(review): this estimate and the true-Z one have a zero diagonal
      # while Theta does not, and the softImpute estimate is not zeroed on
      # the diagonal; confirm that including the diagonal in the error is
      # intended.
      Error_Var[m] <- mean_squared_error(estimate_Theta_Var, Theta)

      ##################### -  Estimate network using the true z - #####################
      estimate_Theta_true_Z <- estimate_theta_from_labels(A_obs, Z, K)

      Error_true_Z[m] <- mean_squared_error(estimate_Theta_true_Z, Theta)
    }

    results <- list(N = N, K = K, fixed_K = FALSE, assortative = assortative,
                    Error_softImpute = Error_softImpute,
                    Error_missSBM = Error_missSBM,
                    Error_Var = Error_Var, Error_true_Z = Error_true_Z)

    path <- file.path(out_dir, paste0("K_", K, ".RDS"))
    saveRDS(results, file = path)

    print(path)
    end_time <- Sys.time()
    print(end_time - start_time)
  }
}