/* coupled_mckean_vlasov.h


   (c) Grant Rotskoff, Carles Domingo-Enrich 2020.
*/

#ifndef PARTICLES_H
#define PARTICLES_H

#include <armadillo>
#include <math.h>
using namespace arma;

/**
 * Base class for one particle system in a pair of coupled McKean-Vlasov
 * dynamics. Most update methods receive a pointer to the "other" system
 * (`ys_system`) that this one is coupled to.
 *
 * The virtual hooks below have trivial defaults: zero gradients, zero weight
 * updates and zero losses. Derived classes are expected to override them.
 * Because the class is used polymorphically, its destructor is virtual.
 */
class CoupledMcKeanVlasov {
  public:
    int n; /**< the number of units */
    int d; /**< dimensionality */
    /** Particle parameters. Presumably n x d, since grad_xs is sized (n,d).
     *  The original comment called this the "vector for the weights";
     *  NOTE(review): confirm. */
    mat xs;
    mat xs_copy;  /**< copy of xs; usage is defined out-of-line */
    mat grad_xs;  /**< gradient buffer; the default hooks set it to zeros(n,d) */
    mat ws; /**< array for the coeffs */
    mat ws_avg;
    mat ws_matrix;
    mat ws_avg_matrix;
    mat upd_ws;   /**< weight-update buffer; the default hook sets it to zeros(n,1) */
    // Model matrices supplied to initialize(); their meaning is defined by the
    // out-of-line implementation and derived classes.
    mat A_0;
    mat A_1;
    mat A_2;
    mat A_3;
    mat a_0;
    mat a_1;

    /* Constructors and Destructors */
    // default constructor, initializes the class
    CoupledMcKeanVlasov();
    // Clears the memory occupied by the class. Virtual because this class has
    // virtual hooks and is meant to be subclassed. Deleting a derived system
    // through a CoupledMcKeanVlasov* is only well-defined with a virtual dtor.
    virtual ~CoupledMcKeanVlasov();

    // initialize the class with specific values
    // (n0 units, d0 dimensions, and the model matrices A_0..A_3, a_0, a_1)
    void initialize(int n0, int d0, mat A_00, mat A_10, mat A_20, mat A_30, mat a_00, mat a_10);

    // Operates on row n_row; the implementation is out-of-line.
    // NOTE(review): the previous comment said "compute the interaction between
    // pairs of particles x,y", which does not match the name. Confirm the
    // intended semantics.
    void restrict_row(int n_row);

    // compute gradients
    // Default: sets grad_xs to an n x d zero matrix (no coupling force).
    virtual void compute_grad(CoupledMcKeanVlasov *ys_system)
    {
      grad_xs = zeros<mat>(n,d);
    }

    // Default: also sets grad_xs (not a separate weights buffer) to an
    // n x d zero matrix.
    virtual void compute_grad_weights(CoupledMcKeanVlasov *ys_system)
    {
      grad_xs = zeros<mat>(n,d);
    }

    // Default: sets upd_ws to an n x 1 zero matrix (no weight update).
    virtual void weights_update(CoupledMcKeanVlasov *ys_system)
    {
      upd_ws = zeros<mat>(n,1);
    }

    // compute the loss
    // Default: returns 0.
    virtual double compute_loss(CoupledMcKeanVlasov *ys_system)
    {
      return 0.;
    }

    // Default: returns 0.
    virtual double compute_loss_weights(CoupledMcKeanVlasov *ys_system)
    {
      return 0.;
    }

    // Default: returns 0. The `avg` flag is unused by the default.
    virtual double compute_loss_weights_full(CoupledMcKeanVlasov *ys_system, bool avg)
    {
      return 0.;
    }

    // Default: returns 0. Overrides presumably compute a metric involving two
    // coupled pairs of systems plus "ind_" systems at iteration `iter` with
    // step size `lr`. NOTE(review): confirm against the derived classes.
    virtual double computeNI(CoupledMcKeanVlasov *ys_system, CoupledMcKeanVlasov *xs_system2, 
                             CoupledMcKeanVlasov *ys_system2, CoupledMcKeanVlasov *ind_xs_system, 
                             CoupledMcKeanVlasov *ind_ys_system, int iter, double lr)
    {
      return 0.;
    }


    // optimization dynamics, requires a pointer to a teacher network
    // (lr: step size). Implemented out-of-line.
    void gd_step(CoupledMcKeanVlasov *ys_system, double lr);

    void gd_step_weights(CoupledMcKeanVlasov *ys_system, double lr);

    // langevin dynamics (lr: step size; beta: presumably inverse temperature,
    // TODO confirm)
    void langevin_step(CoupledMcKeanVlasov *ys_system, double lr, double beta);

    void transport_step_weights(CoupledMcKeanVlasov *ys_system, double lr, bool avg);

    /* helper function*/
    void renormalize();
    void renormalize_weights();
};


#endif
