#include <cmath>
#include <exception>
#include <string>

#include <RcppArmadillo.h>
#include <Rcpp.h>
#include <omp.h>
// [[Rcpp::depends(RcppArmadillo)]]
// [[Rcpp::plugins(openmp)]]

using namespace Rcpp;
using namespace arma;

// NOTE(review): file-scope OpenMP lock. It is never acquired by the worker
// threads of any parallel loop, so it provides no synchronisation for the
// shared matrices. Per the OpenMP specification, omp_destroy_lock() may only
// be applied to an initialised lock that is in the unlocked state. If nothing
// in the file references `lock` any more, this declaration can be removed.
static omp_lock_t lock;

// [[Rcpp::export]]
Rcpp::List integrated_fun(arma::mat data_use, arma::vec S_depth, arma::vec mlemu1, arma::mat mlesigmahat1,
                 arma::vec t_root_vec, arma::vec omega_root_vec,
                 int core_num){
  //def
  arma::mat a_mat("0 ; 1");
  arma::mat b_mat("1 ; 0");
  int dim_use = data_use.n_cols;
  int sample_size_use = data_use.n_rows;
  arma::vec t_root_vec1 = t_root_vec * sqrt(2);
  arma::vec oemga_troot_vec = omega_root_vec % exp(pow(t_root_vec,2)) * sqrt(2) ;
  int length1 = t_root_vec1.n_elem;
  int length2 = oemga_troot_vec.n_elem;
  //
  arma::mat gradiant_int(dim_use,dim_use,fill::zeros);
  arma::mat Hessian_int(dim_use,dim_use,fill::zeros);
  //
  arma::mat sigma0_mat(dim_use,dim_use,fill::zeros);
  arma::mat if_infinite_sigma0_mat(dim_use,dim_use,fill::zeros);
  omp_init_lock(&lock);
  omp_set_lock(&lock);
  
  omp_set_num_threads(core_num);
#pragma omp parallel for
  for (int i=0;i<(dim_use-1);++i) {
    arma::vec data_1 = data_use.col(i);
    double mu1 = mlemu1(i);
    double sigma1 = mlesigmahat1(i,i);
    for (int j=(i+1);j<dim_use;++j) {
      arma::vec data_2 = data_use.col(j);
      double mu2 = mlemu1(j);
      double sigma2 = mlesigmahat1(j,j);
      //
      arma::vec mu_vec(2);
      mu_vec(0) = mu1;
      mu_vec(1) = mu2;
      //
      arma::vec vec1 = (data_1 % data_2)/pow(S_depth,2);
      //
      double vec1_sum = sum(vec1);
      if(vec1_sum == 0){
        if_infinite_sigma0_mat(i,j) = 1;
      }else{
        sigma0_mat(i,j) = log(sum(vec1)/(double)sample_size_use) - (mu1 + sigma1*0.5) - (mu2 + sigma2*0.5);
        sigma0_mat(j,i) = sigma0_mat(i,j);
      }
    }
  }
  omp_destroy_lock(&lock);
  //
  arma::vec min_vec(dim_use,fill::zeros);
  omp_init_lock(&lock);
  omp_set_lock(&lock);
  
  omp_set_num_threads(core_num);
#pragma omp parallel for
  for (int i=0;i<(dim_use);++i){
    arma::vec vec_min_tem = sigma0_mat.col(i);
    double min_value = 1e+5;
    for(int j=0;j<(dim_use);++j){
      if(i!=j){
        if(if_infinite_sigma0_mat(i,j)!=1){
          if(abs(min_value)>abs(sigma0_mat(i,j))){
            min_value = sigma0_mat(i,j);
          }
        }
      }
    }
    //
    min_vec(i) = min_value;
  }
  omp_destroy_lock(&lock);
  //
  omp_init_lock(&lock);
  omp_set_lock(&lock);
  
  omp_set_num_threads(core_num);
#pragma omp parallel for
  for (int i=0;i<(dim_use-1);++i){
    for (int j=(i+1);j<dim_use;++j) {
      if(if_infinite_sigma0_mat(i,j) == 1){
        if(abs(min_vec(i))>abs(min_vec(j))){
          sigma0_mat(i,j) = min_vec(j); 
        }else{
          sigma0_mat(i,j) = min_vec(i); 
        }
      }
    }
  }
  omp_destroy_lock(&lock);
  //
  omp_init_lock(&lock);
  omp_set_lock(&lock);
  //
  sigma0_mat = mlesigmahat1;
  //
  
  omp_set_num_threads(core_num);
#pragma omp parallel for
  for (int i=0;i<(dim_use-1);++i) {
    arma::vec data_1 = data_use.col(i);
    double mu1 = mlemu1(i);
    double sigma1 = mlesigmahat1(i,i);
    for (int j=(i+1);j<dim_use;++j) {
      arma::vec data_2 = data_use.col(j);
      double mu2 = mlemu1(j);
      double sigma2 = mlesigmahat1(j,j);
      //
      arma::vec mu_vec(2);
      mu_vec(0) = mu1;
      mu_vec(1) = mu2;
      //
      // arma::vec vec1 = (data_1 % data_2)/pow(S_depth,2);
      // double sigmaint0 = log(sum(vec1)/(double)sample_size_use) - (mu1 + sigma1*0.5) - (mu2 + sigma2*0.5);
      double sigmaint0 = sigma0_mat(i,j);
      arma::mat sigma_matrix0(2,2);
      sigma_matrix0(0,0) = sigma1;
      sigma_matrix0(1,1) = sigma2;
      sigma_matrix0(1,0) = sigmaint0;
      sigma_matrix0(0,1) = sigmaint0;
      //
      double det_sigma_matrix0 = det(sigma_matrix0);
      if(det_sigma_matrix0 < 0){
        sigmaint0 = sign(sigmaint0) * (sqrt(sigma1 * sigma2) * 0.99);
        sigma_matrix0(1,0) = sigmaint0;
        sigma_matrix0(0,1) = sigmaint0;
      }
      //
      arma::mat sigma_inv0 = inv_sympd(sigma_matrix0);
      //
      arma::vec z1_root = t_root_vec1 * sigma1 + mu1;
      arma::vec z2_root = t_root_vec1 * sigma2 + mu2;
      
      arma::vec m1_root = oemga_troot_vec * sigma1;
      arma::vec m2_root = oemga_troot_vec * sigma2;
      //
      arma::cube gz_up_sigma_int(sample_size_use,length1,length2);
      arma::cube gz_down_int(sample_size_use,length1,length2);
      arma::cube gz_up_sigma_2_int(sample_size_use,length1,length2);
      //
      arma::cube kkk1_array(sample_size_use,length1,length2);
      arma::vec col_index(2);
      col_index(0) = i;
      col_index(1) = j;
      for(int u = 0;u<length1;++u){
        arma::mat z_root_vec(2, length2);
        for(int v = 0;v<length2;++v){
          z_root_vec(0,v) = z1_root(u);
          z_root_vec(1,v) = z2_root(v);
        }
        //
        arma::mat data_use_sub(sample_size_use,2);
        data_use_sub.col(0) = data_use.col(i);
        data_use_sub.col(1) = data_use.col(j);
        //
        arma::mat temp_mat1(1,length2,fill::ones);
        arma::mat res_mat = data_use_sub * z_root_vec - (S_depth * exp(z1_root(u))) * temp_mat1 - S_depth * (exp(z2_root)).t();
        //
        kkk1_array.subcube(0,u,0, (sample_size_use - 1),u,(length2 - 1)) = res_mat;
      }
      //
      arma::vec kkk1(sample_size_use);
      for(int sample_index = 0;sample_index < sample_size_use; ++sample_index){
        arma::mat mat_temp1 = kkk1_array.subcube(sample_index,0,0, sample_index,(length1 - 1),(length2 - 1));
        kkk1(sample_index) = mat_temp1.max();
      }
      //
      for(int u = 0;u<length1;++u){
        arma::mat z_root_vec(2, length2);
        for(int v = 0;v<length2;++v){
          z_root_vec(0,v) = z1_root(u);
          z_root_vec(1,v) = z2_root(v);
        }
        //
        arma::mat z_root_vec_center = z_root_vec;
        z_root_vec_center.row(0) = z_root_vec_center.row(0) - mu_vec(0);
        z_root_vec_center.row(1) = z_root_vec_center.row(1) - mu_vec(1);
        arma::mat kernel_mat = exp(((z_root_vec_center.t()) * sigma_inv0 * z_root_vec_center /2) * (-1));
        arma::vec bbb_1 = diagvec(kernel_mat,0);
        //
        arma::vec bbb_2(length2);
        arma::vec bbb_3(length2);
        for(int v = 0;v<length2;++v){
          arma::mat mat_test = sigma_inv0 * (z_root_vec_center.col(v)) * (z_root_vec_center.col(v)).t() * sigma_inv0;
          bbb_2(v) = mat_test(0,1);
          //
          arma::mat mat_bbb3_1 = ((z_root_vec_center.col(v)).t() * sigma_inv0 * b_mat);
          arma::mat mat_bbb3_2 = (a_mat.t() * sigma_inv0 * (z_root_vec_center.col(v)));
          
          arma::mat bbb3_base = sigma_inv0 * (z_root_vec_center.col(v)) * a_mat.t() * sigma_inv0 * mat_bbb3_1(0,0)
            + sigma_inv0 * b_mat * (z_root_vec_center.col(v)).t() * sigma_inv0 * mat_bbb3_2(0,0);
          bbb_3(v) = bbb3_base(0,1) + bbb3_base(1,0);
        }
        //
        arma::mat kkk2_0 = kkk1_array.subcube(0,u,0, (sample_size_use - 1),u,(length2 - 1));
        for(int sample_index = 0; sample_index<sample_size_use;++sample_index){
          kkk2_0.row(sample_index) = kkk2_0.row(sample_index) - kkk1(sample_index);
        }
        arma::mat kkk2 = (exp(kkk2_0)).t();
        //
        arma::mat gz_up_sigma_int_res = kkk2;
        arma::mat gz_down_int_res = kkk2;
        arma::mat gz_up_sigma_2_int_res = kkk2;
        for(int v = 0;v<length2;++v){
          gz_up_sigma_int_res.row(v) = gz_up_sigma_int_res.row(v) * (bbb_1(v) * bbb_2(v));
          gz_down_int_res.row(v) = gz_down_int_res.row(v) * (bbb_1(v));
          gz_up_sigma_2_int_res.row(v) = gz_up_sigma_2_int_res.row(v) * (bbb_1(v) * ((-1) * bbb_3(v) + bbb_2(v) * bbb_2(v)));
        }
        //
        gz_up_sigma_int.subcube(0,u,0, (sample_size_use - 1),u,(length2 - 1)) = gz_up_sigma_int_res.t();
        gz_down_int.subcube(0,u,0, (sample_size_use - 1),u,(length2 - 1)) = gz_down_int_res.t();
        gz_up_sigma_2_int.subcube(0,u,0, (sample_size_use - 1),u,(length2 - 1)) = gz_up_sigma_2_int_res.t();
        
      }
      //
      arma::vec gradiant_up_sigma_int(sample_size_use);
      arma::vec gradiant_down_int(sample_size_use);
      arma::vec Hessian_up_sigma_int(sample_size_use);
      //
      for(int sample_index = 0;sample_index < sample_size_use; ++sample_index){
        arma::mat mat_incube1 = gz_up_sigma_int.subcube(sample_index,0,0, sample_index,(length1 - 1),(length2 - 1));
        arma::mat mat_incube2 = gz_down_int.subcube(sample_index,0,0, sample_index,(length1 - 1),(length2 - 1));
        arma::mat mat_incube3 = gz_up_sigma_2_int.subcube(sample_index,0,0, sample_index,(length1 - 1),(length2 - 1));
        //
        arma::mat mat_temp1 = (m1_root).t() * mat_incube1 * m2_root;
        double value1 = (double)(mat_temp1(0,0));
        gradiant_up_sigma_int(sample_index) = value1;
        //
        arma::mat mat_temp2 = (m1_root).t() * mat_incube2 * m2_root;
        double value2 = (double)(mat_temp2(0,0));
        gradiant_down_int(sample_index) = value2;
        //
        arma::mat mat_temp3 = (m1_root).t() * mat_incube3 * m2_root;
        double value3 = (double)(mat_temp3(0,0));
        Hessian_up_sigma_int(sample_index) = value3;
      }
      //
      gradiant_int(i,j) = sum(gradiant_up_sigma_int /gradiant_down_int)/(double(sample_size_use)) - sigma_inv0(0,1);
      //
      Hessian_int(i,j) = sum(Hessian_up_sigma_int /gradiant_down_int)/(double(sample_size_use)) -
        sum(pow(gradiant_up_sigma_int /gradiant_down_int,2))/(double(sample_size_use)) +
        (sigma_inv0(0,0) * sigma_inv0(1,1) + sigma_inv0(0,1) * sigma_inv0(0,1));
      
    }
  }
  omp_destroy_lock(&lock);
  //
  return Rcpp::List::create(Rcpp::Named("gradiant_int") = gradiant_int,
                            Rcpp::Named("Hessian_int") = Hessian_int);
  
}
