/* coupled_mckean_vlasov.cc

   Implements a coupled nonlinear McKean-Vlasov particle system.

   (c) Grant Rotskoff, Carles Domingo-Enrich, 2020.
 */

#include "coupled_mckean_vlasov.h"

using namespace arma;
// pi to double precision: atan(1) is pi/4, and scaling by 4 (a power of two) is exact.
const double pi = atan(1.0) * 4.0;

/* Default constructor: an empty system with zero particles in zero dimensions.
   The matrices are left default-constructed; initialize() sets them up. */
CoupledMcKeanVlasov::CoupledMcKeanVlasov()
  : n(0), d(0)
{
}

/* Destructor: nothing to do by hand, every member releases its own storage. */
CoupledMcKeanVlasov::~CoupledMcKeanVlasov()
{
}

/* Set up a system of n0 particles in d0 dimensions.

   - Positions xs (n0 x d0) are drawn from a standard Gaussian, then each row is
     rescaled to unit Euclidean norm (renormalize()).
   - Weights ws (n0 x 1) start as all ones and are rescaled to unit L1 norm,
     i.e. each weight is 1/n0.
   - ws_matrix holds ws broadcast across the d0 columns, matching how
     gd_step_weights refreshes it. ws_avg / ws_avg_matrix start from these
     weights.
   - The A_* / a_* matrices are stored as given. NOTE(review): presumably they
     are consumed by compute_grad / compute_grad_weights, which are not visible
     here. */
void CoupledMcKeanVlasov::initialize(int n0, int d0, mat A_00, mat A_10, mat A_20, mat A_30, mat a_00, mat a_10)
{
  n = n0;
  d = d0;
  ws = ones<mat>(n,1);
  ws_matrix = ones<mat>(n,d);
  upd_ws = zeros(n,1);
  xs = randn<mat>(n,d);
  grad_xs = zeros<mat>(n,d);
  // NOTE(review): xs_copy is an independent random draw (not a copy of xs) and
  // advances the RNG; confirm this is intended.
  xs_copy = randn<mat>(n,d);
  A_0 = A_00;
  A_1 = A_10;
  A_2 = A_20;
  A_3 = A_30;
  a_0 = a_00;
  a_1 = a_10;
  renormalize();
  renormalize_weights();
  // renormalize_weights() only rescales ws. Re-broadcast it so ws_matrix holds
  // the normalised weights (1/n) rather than the stale all-ones matrix, exactly
  // as gd_step_weights does after its own renormalisation.
  ws_matrix.each_col() = ws;
  ws_avg = ws;
  ws_avg_matrix = ws_matrix;
}

/* Collapse the system onto the single particle stored in row n_row of xs.
   That particle receives unit weight in ws, ws_avg, ws_matrix and
   ws_avg_matrix; d is unchanged.
   NOTE(review): grad_xs, upd_ws and xs_copy keep their previous row count here;
   confirm the gradient / weight-update routines resize them. */
void CoupledMcKeanVlasov::restrict_row(int n_row) 
{
  n = 1;
  xs = xs.row(n_row);

  ws.ones(1, 1);
  ws_matrix.ones(1, d);

  ws_avg = ws;
  ws_avg_matrix = ws_matrix;
}

/* Rescale every particle (each row of xs) to unit Euclidean norm, i.e. project
   the positions back onto the unit sphere. */
void CoupledMcKeanVlasov::renormalize()
{
  const unsigned int euclidean_p = 2; // p of the p-norm
  const unsigned int per_row = 1;     // dim = 1: normalise each row independently
  xs = normalise(xs, euclidean_p, per_row);
}

/* Rescale the weight column ws to unit L1 norm. Since the weights are
   non-negative, this means they sum to one.
   Only ws is touched: callers that rely on ws_matrix must re-broadcast it
   afterwards (as gd_step_weights does with ws_matrix.each_col() = ws). */
void CoupledMcKeanVlasov::renormalize_weights()
{
  ws = normalise(ws, 1, 0); // p=1 norm, dim=0: normalise the column so the weights sum to one
}

/* One plain gradient-descent step on the particle positions with step size lr.
   Afterwards every particle is projected back onto the unit sphere.
   The gradient comes from compute_grad(ys_system), which is expected to fill
   grad_xs (NOTE(review): compute_grad is not visible here). */
void CoupledMcKeanVlasov::gd_step(CoupledMcKeanVlasov *ys_system, double lr)
{
  compute_grad(ys_system);
  xs -= grad_xs * lr;
  renormalize(); // back onto the unit sphere
}

/* One multiplicative update of the particle weights with step size lr.
   Steps:
   1. Each weight is scaled by exp(-lr * upd_ws).
   2. The column is rescaled to unit L1 norm.
   3. It is broadcast into every column of ws_matrix.
   4. It is added to the running sum ws_avg (transport_step_weights normalises
      ws_avg before using it).
   NOTE(review): weights_update(ys_system) is presumably what fills upd_ws; it
   is not visible here. */
void CoupledMcKeanVlasov::gd_step_weights(CoupledMcKeanVlasov *ys_system, double lr)
{
  weights_update(ys_system);

  ws = ws % exp(-lr * upd_ws);
  renormalize_weights();

  ws_matrix.each_col() = ws; // keep the broadcast copy in sync with ws
  ws_avg += ws;              // accumulated sum of weights, not yet normalised
}

/* One Euler-Maruyama Langevin step on the positions at inverse temperature beta.
   The update is a drift of -lr * grad_xs plus Gaussian noise scaled by
   sqrt(2 * lr / beta). Every particle is then projected back onto the unit
   sphere. */
void CoupledMcKeanVlasov::langevin_step(CoupledMcKeanVlasov *ys_system, double lr, double beta)
{
  compute_grad(ys_system);
  const double noise_scale = sqrt(2*lr/beta);
  xs += -lr * grad_xs + noise_scale * randn<mat>(n, d);
  renormalize();
}

/* One gradient step on the particle positions, using compute_grad_weights
   (presumably fills grad_xs; not visible here), followed by projection back onto
   the unit sphere.

   If avg is true, the gradient is evaluated with both systems' weights
   temporarily replaced by their normalised running sums (ws_avg, unit L1 norm).
   Afterwards ALL weight state on both systems (ws and ws_matrix) is restored,
   so the invariant "every column of ws_matrix equals ws" still holds on return.
   Previously only ws was restored, which left ws_matrix holding the averaged
   weights. */
void CoupledMcKeanVlasov::transport_step_weights(CoupledMcKeanVlasov *ys_system, double lr, bool avg)
{
  mat x_ws_copy;
  mat x_ws_matrix_copy;
  mat y_ws_copy;
  mat y_ws_matrix_copy;
  if (avg)
  {
    // Snapshot everything we are about to overwrite (taken before any write,
    // so this is also correct when ys_system == this).
    x_ws_copy = ws;
    x_ws_matrix_copy = ws_matrix;
    y_ws_copy = ys_system->ws;
    y_ws_matrix_copy = ys_system->ws_matrix;

    ws = normalise(ws_avg, 1, 0);
    ws_matrix.each_col() = ws;
    ys_system->ws = normalise(ys_system->ws_avg, 1, 0);
    ys_system->ws_matrix.each_col() = ys_system->ws;
  }
  compute_grad_weights(ys_system);
  xs += -lr * grad_xs;
  renormalize();
  if (avg)
  {
    ws = x_ws_copy;
    ws_matrix = x_ws_matrix_copy;
    ys_system->ws = y_ws_copy;
    ys_system->ws_matrix = y_ws_matrix_copy;
  }
}
