#include "coupled_mckean_vlasov.h"
//#include "polynomial_loss.h" 
#include "polynomial_loss_convex.h"
#include <stdio.h>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

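// The loss is chosen at compile time by the includes above: for the convex loss keep
// "polynomial_loss_convex.h" and comment out "polynomial_loss.h"; for the non-convex
// loss do the opposite.
// Experiments were run with "./run_poly 100 100 8 30000 0.0002 1 0" (non-convex).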

// Recording cadence for the two output modes (see is_checkpoint).
static const int kEvolutionStride = 500;  // evolution mode: record every 500 steps
static const int kFinalWindow = 100;      // final mode: only within the last 100 steps...
static const int kFinalStride = 10;       // ...and every 10th step of that window

// Returns true when metrics should be recorded after iteration `step` (0-based)
// of a run of `n_steps` iterations:
//   - evolution mode (final_mode == false): every kEvolutionStride-th step, from step 0;
//   - final mode (final_mode == true): every kFinalStride-th step among the last
//     kFinalWindow steps.
static bool is_checkpoint(int step, int n_steps, bool final_mode)
{
  if (final_mode)
    return step >= n_steps - kFinalWindow && step % kFinalStride == 0;
  return step % kEvolutionStride == 0;
}

// Time stamp written in the first column of each log line for iteration `step`.
// NOTE(review): this reproduces the original `i*500*lr` exactly (computed in
// double so that step*500 cannot overflow an int). Checkpoints are already 500
// steps apart, so the extra factor 500 looks like a leftover; the elapsed time
// after `step` Langevin steps would be step*lr. The MD/WFR runs are also stamped
// with lr even though they step with lr_md. Kept unchanged so existing logs and
// plotting scripts stay comparable — confirm before changing.
static double checkpoint_time(int step, double lr)
{
  return static_cast<double>(step) * 500.0 * lr;
}

// Builds the metrics log file name "[time_]log_nx=NNNN_ny=NNNN[_cv].dat":
// the "time_" prefix marks evolution mode, the "_cv" suffix the convex variant.
static std::string log_file_name(int nx, int ny, bool final_mode, bool convex)
{
  char buffer[128];
  snprintf(buffer, sizeof(buffer), "%s_nx=%04d_ny=%04d%s.dat",
           final_mode ? "log" : "time_log", nx, ny, convex ? "_cv" : "");
  return std::string(buffer);
}

// Experiment driver. For every dimension of a fixed sweep and 20 random problem
// instances per dimension, it runs three solvers on the xs/ys particle systems:
//   "Langev": langevin_step on both systems,
//   "MD":     gd_step_weights on both systems,
//   "WFR":    gd_step_weights followed by transport_step_weights on both systems,
// and at each checkpoint logs the loss and the NI error returned by
// computeNI / computeNI_weights.
//
// usage: ./run_poly nx ny dim n_steps lr final convex
// Outputs (in the working directory):
//   [time_]log_nx=NNNN_ny=NNNN[_cv].dat  one "time loss NI" line per checkpoint
//   xs_nx=NNNN_ny=NNNN.dat / ys_...      the Langevin systems' `xs` member at each checkpoint
int main(int argc, char **argv) {

  // parse the arguments
  if (argc != 8) {
    fprintf(stderr, "usage: ./run_poly nx ny dim n_steps lr final convex\n");
    exit(-1);
  }

  int nx = atoi(argv[1]);       // number of particles in the xs systems
  int ny = atoi(argv[2]);       // number of particles in the ys systems
  // argv[3] (dim) is accepted for command-line compatibility but ignored:
  // the dimensions come from the sweep in dim_array below.
  int n_steps = atoi(argv[4]);  // iterations per solver and per run
  double lr = atof(argv[5]);    // step size for the Langevin and WFR transport steps
  // final != 0: record metrics only near the end of each run (the "final NI");
  // final == 0: record the evolution of the metrics over time.
  const bool final_mode = atof(argv[6]) != 0.0;
  // convex != 0: replace A_0 and A_2 by Gram matrices (positive semi-definite).
  const bool convex = atof(argv[7]) != 0.0;
  // Step size for the MD/WFR weight updates (could be a command-line argument, as could beta).
  double lr_md = 0.1;
  // NOTE(review): an earlier comment here said lr = 0.002 was the value used in the
  // experiments, whereas the usage note at the top of the file quotes 0.0002 — confirm.

  // Problem dimensions to sweep over.
  const std::vector<int> dim_array = final_mode
      ? std::vector<int>{1, 2, 4, 8, 16, 32}
      : std::vector<int>{3, 5, 9, 17};
  fprintf(stderr, "Running with %d %d", nx, ny);
  for (int d : dim_array)
    fprintf(stderr, " %d", d);
  fprintf(stderr, "\n");

  const std::string log_name = log_file_name(nx, ny, final_mode, convex);
  FILE *log_file_pointer = fopen(log_name.c_str(), "w");
  if (log_file_pointer == nullptr) {
    fprintf(stderr, "cannot open %s for writing\n", log_name.c_str());
    exit(-1);
  }

  // NOTE(review): the snapshot names carry neither the "_cv" suffix nor the mode
  // prefix, so runs with the same nx, ny overwrite each other's snapshots.
  char xs_file_name[80];
  char ys_file_name[80];
  snprintf(xs_file_name, sizeof(xs_file_name), "xs_nx=%04d_ny=%04d.dat", nx, ny);
  snprintf(ys_file_name, sizeof(ys_file_name), "ys_nx=%04d_ny=%04d.dat", nx, ny);
  std::ofstream xs_stream(xs_file_name);
  std::ofstream ys_stream(ys_file_name);
  if (!xs_stream || !ys_stream) {
    fprintf(stderr, "cannot open %s / %s for writing\n", xs_file_name, ys_file_name);
    exit(-1);
  }

  for (int dim : dim_array) {
    for (int runs = 0; runs < 20; ++runs) {  // 20 random problem instances per dimension
      fprintf(stdout, "%d %d\n", dim, runs);

      // Random coefficients of the loss (their role is defined by the included loss header).
      mat A_0 = randn<mat>(dim, dim);
      mat A_1 = randn<mat>(dim, dim);
      mat A_2 = randn<mat>(dim, dim);
      mat A_3 = eye<mat>(dim, dim);
      mat a_0 = randn<mat>(dim, 1);
      mat a_1 = randn<mat>(dim, 1);
      if (convex) {
        // A^T A is positive semi-definite.
        A_0 = A_0.t() * A_0;
        A_2 = A_2.t() * A_2;
      }

      // particle systems for Langevin
      xsSystem xs_system;
      ysSystem ys_system;
      xs_system.initialize(nx, dim, A_0, A_1, A_2, A_3, a_0, a_1);
      ys_system.initialize(ny, dim, A_0, A_1, A_2, A_3, a_0, a_1);
      double beta = 10000;  // passed to langevin_step (presumably an inverse temperature — confirm)

      // particle systems for Mirror Descent / WFR
      xsSystem xs_system_md;
      ysSystem ys_system_md;
      xs_system_md.initialize(nx, dim, A_0, A_1, A_2, A_3, a_0, a_1);
      ys_system_md.initialize(ny, dim, A_0, A_1, A_2, A_3, a_0, a_1);

      // Auxiliary systems handed to computeNI / computeNI_weights.
      xsSystem xs_system2;
      ysSystem ys_system2;
      xsSystem ind_xs_system;
      ysSystem ind_ys_system;
      int nx_NI = 20 * nx;
      int ny_NI = 20 * ny;

      // Fresh auxiliary systems with 20x the particles before every NI evaluation.
      auto reset_ni_systems = [&]() {
        xs_system2.initialize(nx_NI, dim, A_0, A_1, A_2, A_3, a_0, a_1);
        ys_system2.initialize(ny_NI, dim, A_0, A_1, A_2, A_3, a_0, a_1);
      };

      // Checkpoint for the weighted (MD/WFR) systems: NI first, then loss,
      // in the same order as before so any randomness is consumed identically.
      auto record_weighted = [&](const char *label, int step) {
        reset_ni_systems();
        const double ni = xs_system_md.computeNI_weights(&ys_system_md, &xs_system2, &ys_system2,
                                                         &ind_xs_system, &ind_ys_system, 5000, lr,
                                                         true);  // true is for time averaging
        const double loss = xs_system_md.compute_loss_weights_full(&ys_system_md, true);
        fprintf(stdout, "%s %e %e %e\n", label, checkpoint_time(step, lr), loss, ni);
        fprintf(log_file_pointer, "%e %e %e\n", checkpoint_time(step, lr), loss, ni);
      };

      // run Langevin
      for (int i = 0; i < n_steps; i++) {
        xs_system.langevin_step(&ys_system, lr, beta);
        ys_system.langevin_step(&xs_system, lr, beta);
        if (is_checkpoint(i, n_steps, final_mode)) {
          const double loss = xs_system.compute_loss(&ys_system);
          reset_ni_systems();
          const double ni = xs_system.computeNI(&ys_system, &xs_system2, &ys_system2,
                                                &ind_xs_system, &ind_ys_system, 5000, lr);
          fprintf(stdout, "%s %e %e %e\n", "Langev", checkpoint_time(i, lr), loss, ni);
          fprintf(log_file_pointer, "%e %e %e\n", checkpoint_time(i, lr), loss, ni);
          xs_system.xs.print(xs_stream);
          ys_system.xs.print(ys_stream);
        }
      }

      // run MD
      for (int i = 0; i < n_steps; i++) {
        xs_system_md.gd_step_weights(&ys_system_md, lr_md);
        ys_system_md.gd_step_weights(&xs_system_md, lr_md);
        if (is_checkpoint(i, n_steps, final_mode))
          record_weighted("MD", i);
      }

      // run WFR, starting again from freshly initialized systems
      xs_system_md.initialize(nx, dim, A_0, A_1, A_2, A_3, a_0, a_1);
      ys_system_md.initialize(ny, dim, A_0, A_1, A_2, A_3, a_0, a_1);
      for (int i = 0; i < n_steps; i++) {
        xs_system_md.gd_step_weights(&ys_system_md, lr_md);
        ys_system_md.gd_step_weights(&xs_system_md, lr_md);
        xs_system_md.transport_step_weights(&ys_system_md, lr, false);
        ys_system_md.transport_step_weights(&xs_system_md, lr, false);
        if (is_checkpoint(i, n_steps, final_mode))
          record_weighted("WFR", i);
      }
    }
  }

  // Close explicitly and return normally: exit() does not destroy automatic
  // objects, so the previous `exit(0)` could drop output still buffered in the
  // ofstreams.
  xs_stream.close();
  ys_stream.close();
  fclose(log_file_pointer);
  return 0;
}
