function [W, stat] = gsp_learn_graph_log_degrees(Z, a, b, params)
%GSP_LEARN_GRAPH_LOG_DEGREES Learn graph from pairwise distances using negative log prior on nodes degrees
%   Usage:  [W, stat] = gsp_learn_graph_log_degrees(Z, a, b)
%           [W, stat] = gsp_learn_graph_log_degrees(Z, a, b, params)
%
%   Inputs:
%         Z         : Matrix with (squared) pairwise distances of nodes, or
%                     a vector with one entry per edge
%         a         : Log prior constant  (bigger a -> bigger weights in W)
%         b         : ||W||_F^2 prior constant  (bigger b -> more dense W)
%         params    : Optional parameters
%
%   Outputs:
%         W         : Weighted adjacency matrix (an edge vector if Z is a
%                     vector)
%         stat      : Optional output statistics (adds small overhead)
%
%
%   'W = gsp_learn_graph_log_degrees(Z, a, b, params)' computes a weighted
%   adjacency matrix W from squared pairwise distances in Z, using the
%   smoothness assumption that trace(X^T L X) is small, where X is
%   the data (columns) changing smoothly from node to node on the graph and
%   L = D-W is the combinatorial graph Laplacian. See the paper of the
%   references for the theory behind the algorithm.
%
%   Alternatively, Z can contain other types of distances and use the
%   smoothness assumption that
%
%      sum(sum(W .* Z))
%
%   is small.
%
%   The minimization problem solved is
%
%      minimize_W sum(sum(W .* Z)) - a * sum(log(sum(W))) + b * ||W||_F^2/2 + c * ||W-W_0||_F^2/2
%
%   subject to W being a valid weighted adjacency matrix (non-negative,
%   symmetric, with zero diagonal).
%
%   The algorithm used is forward-backward-forward (FBF) based primal dual
%   optimization (see references).
%
%   Example:
%
%         G = gsp_sensor(256);
%         f1 = @(x,y) sin((2-x-y).^2);
%         f2 = @(x,y) cos((x+y).^2);
%         f3 = @(x,y) (x-.5).^2 + (y-.5).^3 + x - y;
%         f4 = @(x,y) sin(3*((x-.5).^2+(y-.5).^2));
%         X = [f1(G.coords(:,1), G.coords(:,2)), f2(G.coords(:,1), G.coords(:,2)), f3(G.coords(:,1), G.coords(:,2)), f4(G.coords(:,1), G.coords(:,2))];
%         figure; subplot(2,2,1); gsp_plot_signal(G, X(:,1)); title('1st smooth signal');
%         subplot(2,2,2); gsp_plot_signal(G, X(:,2)); title('2nd smooth signal');
%         subplot(2,2,3); gsp_plot_signal(G, X(:,3)); title('3rd smooth signal');
%         subplot(2,2,4); gsp_plot_signal(G, X(:,4)); title('4th smooth signal');
%         Z = gsp_distanz(X').^2;
%         % we can multiply the pairwise distances with a number to control sparsity
%         [W] = gsp_learn_graph_log_degrees(Z*25, 1, 1);
%         % clean up zeros
%         W(W<1e-5) = 0;
%         G2 = gsp_update_weights(G, W);
%         figure; gsp_plot_graph(G2); title('Graph with edges learned from above 4 signals');
%
%
%   Additional parameters
%   ---------------------
%
%    params.w_init   : Initialization point, in the same format as Z
%                      (matrix or edge vector). "params.W_init" is also
%                      accepted. Default: zeros(size(Z))
%    verbosity       : Default = 1. Above 1 adds a small overhead
%    maxit           : Maximum number of iterations. Default: 1000
%    tol             : Tolerance for stopping criterion. Default: 1e-5
%    step_size       : Step size from the interval (0,1). Default: 0.5
%    max_w           : Maximum weight allowed for each edge (or inf).
%                      Default: inf
%    w_0             : Matrix or edge vector W_0 for adding the prior
%                      c/2*||W - W_0||_F^2
%    c               : Multiplier for the prior c/2*||W - W_0||_F^2
%                      (required if w_0 is given)
%    fix_zeros       : Fix a set of edges to zero (true/false).
%                      Default: false
%    edge_mask       : Mask indicating the non zero edges (required if
%                      "fix_zeros" is set)
%
%   If fix_zeros is set, an edge_mask is needed. Only the edges
%   corresponding to the non-zero values in edge_mask will be learnt. This
%   has two applications: (1) for large scale applications it is cheaper to
%   learn a subset of edges. (2) for some applications we don't want some
%   connections to be allowed, for example for locality on images.
%
%   The cost of each iteration is linear in the number of edges to be
%   learned, or the square of the number of nodes (numel(Z)) if fix_zeros
%   is not set.
%
%   The function is using the UNLocBoX functions sum_squareform and
%   squareform_sp.
%   The stopping criterion is whether both relative primal and dual
%   distance between two iterations are below a given tolerance.
%
%   To set the step size use the following rule of thumb: Set it so that
%   relative change of primal and dual converge with similar rates (use
%   verbosity > 1).
%
%   See also: gsp_learn_graph_l2_degrees gsp_distanz gsp_update_weights
%       squareform_sp sum_squareform gsp_compute_graph_learning_theta
%
%   References:
%     V. Kalofolias. How to learn a graph from smooth signals. Technical
%     report, AISTATS 2016: proceedings at Journal of Machine Learning
%     Research (JMLR), 2016.
%     
%     N. Komodakis and J.-C. Pesquet. Playing with duality: An overview of
%     recent primal-dual approaches for solving large-scale optimization
%     problems. Signal Processing Magazine, IEEE, 32(6):31--54, 2015.
%     
%     V. Kalofolias and N. Perraudin. Large Scale Graph Learning from Smooth
%     Signals. arXiv preprint arXiv:1710.05654, 2017.
%     
%
%
%   Url: https://epfl-lts2.github.io/gspbox-html/doc/learn_graph/gsp_learn_graph_log_degrees.html

% Copyright (C) 2013-2016 Nathanael Perraudin, Johan Paratte, David I Shuman.
% This file is part of GSPbox version 0.7.5
%
% This program is free software: you can redistribute it and/or modify
% it under the terms of the GNU General Public License as published by
% the Free Software Foundation, either version 3 of the License, or
% (at your option) any later version.
%
% This program is distributed in the hope that it will be useful,
% but WITHOUT ANY WARRANTY; without even the implied warranty of
% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
% GNU General Public License for more details.
%
% You should have received a copy of the GNU General Public License
% along with this program.  If not, see <http://www.gnu.org/licenses/>.

% If you use this toolbox please kindly cite
%     N. Perraudin, J. Paratte, D. Shuman, V. Kalofolias, P. Vandergheynst,
%     and D. K. Hammond. GSPBOX: A toolbox for signal processing on graphs.
%     ArXiv e-prints, Aug. 2014.
% http://arxiv.org/abs/1408.5781


% Author: Vassilis Kalofolias
% Testing: gsp_test_learn_graph
% Date: June 2015


%% Default parameters
if nargin < 4
    params = struct;
end

if not(isfield(params, 'verbosity')),   params.verbosity = 1;       end
if not(isfield(params, 'maxit')),       params.maxit = 1000;        end
if not(isfield(params, 'tol')),         params.tol = 1e-5;          end
if not(isfield(params, 'step_size')),   params.step_size = .5;      end     % from (0, 1)
if not(isfield(params, 'fix_zeros')),   params.fix_zeros = false;   end
if not(isfield(params, 'max_w')),       params.max_w = inf;         end
% The initialization point has been documented as "W_init"; accept it as an
% alias of "w_init".
if isfield(params, 'W_init') && not(isfield(params, 'w_init'))
    params.w_init = params.W_init;
end
% Fail early with a clear message instead of an undefined-field error below.
if params.fix_zeros && not(isfield(params, 'edge_mask'))
    error('When params.fix_zeros is set, params.edge_mask should also be specified');
end


%% Fix parameter size and initialize
if isvector(Z)
    z = Z;  % lazy copying of matlab doesn't allocate new memory for z
else
    z = squareform_sp(Z);
end
% clear Z   % for large scale computation

z = z(:);
l = length(z);                      % number of edges
% n(n-1)/2 = l => n = (1 + sqrt(1+8*l))/ 2
n = round((1 + sqrt(1+8*l))/ 2);    % number of nodes

if isfield(params, 'w_0')
    if not(isfield(params, 'c'))
        error('When params.w_0 is specified, params.c should also be specified');
    else
        c = params.c;
    end
    if isvector(params.w_0)
        w_0 = params.w_0;
    else
        w_0 = squareform_sp(params.w_0);
    end
    w_0 = w_0(:);
else
    w_0 = 0;
end
% The prior c/2*||W - W_0||_F^2 is active whenever w_0 is given with a
% non-zero c. This includes W_0 == 0, where the prior still penalizes ||W||,
% so the decision must not be based on the value of w_0.
% (&& short-circuits, so c is only read when it has been defined above.)
use_prior = isfield(params, 'w_0') && c ~= 0;

% if sparsity pattern is fixed we optimize with respect to a smaller number
% of variables, all included in w
if params.fix_zeros
    if not(isvector(params.edge_mask))
        params.edge_mask = squareform_sp(params.edge_mask);
    end
    % use only the non-zero elements to optimize
    ind = find(params.edge_mask(:));
    z = full(z(ind));
    if not(isscalar(w_0))
        w_0 = full(w_0(ind));
    end
else
    z = full(z);
    w_0 = full(w_0);
end


%% Initialize the primal variable (zeros unless params.w_init is given)
if isfield(params, 'w_init')
    if isvector(params.w_init)
        w = params.w_init;
    else
        w = squareform_sp(params.w_init);
    end
    w = w(:);
    if numel(w) ~= l
        error('params.w_init should contain one value per edge of Z');
    end
    % keep only the optimized edges, like z and w_0 above
    if params.fix_zeros
        w = w(ind);
    end
    w = full(w);
else
    w = zeros(size(z));
end

%% Needed operators
% S*w = sum(W)
if params.fix_zeros
    [S, St] = sum_squareform(n, params.edge_mask(:));
else
    [S, St] = sum_squareform(n);
end

% S: edges -> nodes
K_op = @(w) S*w;

% S': nodes -> edges
Kt_op = @(z) St*z;

if params.fix_zeros
    norm_K = normest(S);
    % approximation: 
    % sqrt(2*(n-1)) * sqrt(nnz(params.edge_mask) / (n*(n+1)/2)) /sqrt(2)
else
    % norm of S for the full set of edges (for a subset of edges, as with
    % params.fix_zeros, this value is only an upper bound)
    norm_K = sqrt(2*(n-1));
end

%% TODO: Rescaling??
% we want    h.beta == norm_K   (see definition of mu)
% we can multiply all terms by s = norm_K/2*b so new h.beta==2*b*s==norm_K


%% Learn the graph
% min_{W>=0}  tr(X'*L*X) - a * sum(log(sum(W))) + b/2 * norm(W,'fro')^2 + c/2 * norm(W-W_0,'fro')^2, where L = diag(sum(W))-W
% in edge-vector form (each edge counted once in w):
% min_w       I{w>=0} + 2*w'*z  - a * sum(log(S*w)) + b * norm(w)^2 + c * norm(w-w_0)^2
% min_w                f(w)     +      g(K_op(w))    +              h(w)

% put proximal of trace plus positivity together
f.eval = @(w) 2*w'*z;    % half should be counted
%f.eval = @(W) 0;
f.prox = @(w, c) min(params.max_w, max(0, w - 2*c*z));  % all change the same

param_prox_log.verbose = params.verbosity - 3;
g.eval = @(z) -a * sum(log(z));
g.prox = @(z, c) prox_sum_log(z, c*a, param_prox_log);
% proximal of conjugate of g: z-c*g.prox(z/c, 1/c)
g_star_prox = @(z, c) z - c*a * prox_sum_log(z/(c*a), 1/(c*a), param_prox_log);

if use_prior
    h.eval = @(w) b * norm(w)^2 + c * norm(w - w_0,'fro')^2;
    h.grad = @(w) 2 * ((b+c) * w - c * w_0);
    h.beta = 2 * (b+c);
else
    % no prior term: cheaper expressions
    h.eval = @(w) b * norm(w)^2;
    h.grad = @(w) 2 * b * w;
    h.beta = 2 * b;
end

%% My custom FBF based primal dual (see [1] = [Komodakis, Pesquet])
% parameters mu, epsilon for convergence (see [1])
mu = h.beta + norm_K;     %TODO: is it squared or not??
epsilon = lin_map(0.0, [0, 1/(1+mu)], [0,1]);   % in (0, 1/(1+mu) )

% INITIALIZATION
% primal variable w ALREADY INITIALIZED (zeros or params.w_init)
% dual variable
v_n = K_op(w);
if nargout > 1 || params.verbosity > 1
    stat.f_eval = nan(params.maxit, 1);
    stat.g_eval = nan(params.maxit, 1);
    stat.h_eval = nan(params.maxit, 1);
    stat.fgh_eval = nan(params.maxit, 1);
    stat.pos_violation = nan(params.maxit, 1);
end
if params.verbosity > 1
    fprintf('Relative change of primal, dual variables, and objective fun\n');
end

tic
gn = lin_map(params.step_size, [epsilon, (1-epsilon)/mu], [0,1]);              % in [epsilon, (1-epsilon)/mu]
% Defined before the loop so the final summary also works when maxit < 1
% (an empty "for" range leaves the loop variable empty).
n_iter = 0;
rel_norm_primal = nan;
rel_norm_dual = nan;
for i = 1:params.maxit
    n_iter = i;
    Y_n = w - gn * (h.grad(w) + Kt_op(v_n));
    y_n = v_n + gn * (K_op(w));
    P_n = f.prox(Y_n, gn);
    p_n = g_star_prox(y_n, gn); % = y_n - gn*g_prox(y_n/gn, 1/gn)
    Q_n = P_n - gn * (h.grad(P_n) + Kt_op(p_n));
    q_n = p_n + gn * (K_op(P_n));
    
    if nargout > 1 || params.verbosity > 2
        stat.f_eval(i) = f.eval(w);
        stat.g_eval(i) = g.eval(K_op(w));
        stat.h_eval(i) = h.eval(w);
        stat.fgh_eval(i) = stat.f_eval(i) + stat.g_eval(i) + stat.h_eval(i);
        stat.pos_violation(i) = -sum(min(0,w));
    end
    rel_norm_primal = norm(- Y_n + Q_n, 'fro')/norm(w, 'fro');
    rel_norm_dual = norm(- y_n + q_n)/norm(v_n);
    
    if params.verbosity > 3
        fprintf('iter %4d: %6.4e   %6.4e   %6.3e', i, rel_norm_primal, rel_norm_dual, stat.fgh_eval(i));
    elseif params.verbosity > 2
        fprintf('iter %4d: %6.4e   %6.4e   %6.3e\n', i, rel_norm_primal, rel_norm_dual, stat.fgh_eval(i));
    elseif params.verbosity > 1
        fprintf('iter %4d: %6.4e   %6.4e\n', i, rel_norm_primal, rel_norm_dual);
    end
    
    w = w - Y_n + Q_n;
    v_n = v_n - y_n + q_n;
    
    if rel_norm_primal < params.tol && rel_norm_dual < params.tol
        break
    end
end
stat.time = toc;
if params.verbosity > 0
    fprintf('# iters: %4d. Rel primal: %6.4e Rel dual: %6.4e  OBJ %6.3e\n', n_iter, rel_norm_primal, rel_norm_dual, f.eval(w) + g.eval(K_op(w)) + h.eval(w));
    fprintf('Time needed is %f seconds\n', stat.time);
end

% use the following for testing:
% g.L = K_op;
% g.Lt = Kt_op;
% g.norm_L = norm_K;
% [w, info] = fbf_primal_dual(w, f, g, h, params);
% %[w, info] = fb_based_primal_dual(w, f, g, h, params);


%%

if params.verbosity > 3
    figure; plot(real([stat.f_eval, stat.g_eval, stat.h_eval])); hold all; plot(real(stat.fgh_eval), '.'); legend('f', 'g', 'h', 'f+g+h');
    figure; plot(stat.pos_violation); title('sum of negative (invalid) values per iteration')
    figure; semilogy(max(0,-diff(real(stat.fgh_eval'))),'b.-'); hold on; semilogy(max(0,diff(real(stat.fgh_eval'))),'ro-'); title('|f(i)-f(i-1)|'); legend('going down','going up');
end

% scatter the optimized edges back into a full-length (sparse) edge vector
if params.fix_zeros
    w = sparse(ind, ones(size(ind)), w, l, 1);
end

% return in the same format (edge vector or matrix) as the input Z
if isvector(Z)
    W = w;
else
    W = squareform_sp(w);
end



