function M = multinomialfactory(n, m)
% Manifold of n-by-m column-stochastic matrices with positive entries.
%
% function M = multinomialfactory(n)
% function M = multinomialfactory(n, m)
%
% The returned structure M is a Manopt manifold structure to optimize over
% the set of n-by-m matrices with (strictly) positive entries and such that
% the entries of each column sum to one. By default, m = 1, which
% corresponds to the relative interior of the simplex (discrete probability
% distributions with nonzero probabilities.)
%
% The metric imposed on the manifold is the Fisher metric such that 
% the set of n-by-m column-stochastic matrices (a.k.a. the multinomial
% manifold) is a Riemannian submanifold of the space of n-by-m matrices.
% Also it should be noted that the retraction operation that we define 
% is first order and as such the checkhessian tool cannot verify 
% the slope correctly at non-critical points.
%             
% The file is based on developments in the research paper
% Y. Sun, J. Gao, X. Hong, B. Mishra, and B. Yin,
% "Heterogeneous tensor decomposition for clustering via manifold
% optimization", arXiv:1504.01777, 2015.
%
% Link to the paper: http://arxiv.org/abs/1504.01777.
%
% The exponential and logarithmic map and the distance appear in:
% F. Astrom, S. Petra, B. Schmitzer, C. Schnorr,
% "Image Labeling by Assignment",
% Journal of Mathematical Imaging and Vision, 58(2), pp. 221?238, 2017.
% doi: 10.1007/s10851-016-0702-4
% arxiv: https://arxiv.org/abs/1603.05285
%
% Please cite the Manopt paper as well as the research paper:
% @Article{sun2015multinomial,
%   author  = {Y. Sun and J. Gao and X. Hong and B. Mishra and B. Yin},
%   title   = {Heterogeneous Tensor Decomposition for Clustering via Manifold Optimization},
%   journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence},
%   year    = {2016},
%   volume  = {38},
%   number  = {3},
%   pages   = {476--489},
%   doi     = {10.1109/TPAMI.2015.2465901}
% }

% This file is part of Manopt: www.manopt.org.
% Original author: Bamdev Mishra, April 06, 2015.
% Contributors: Ronny Bergmann
% Change log:
%
%    Sep. 6, 2018 (NB):
%        Removed M.exp() as it was not implemented.
%
%    Apr. 12, 2020 (RB):
%        Adds exp, log, dist.

    if ~exist('m', 'var') || isempty(m)
        m = 1;
    end

    M.name = @() sprintf('%dx%d column-stochastic matrices with positive entries', n, m);
    
    M.dim = @() (n-1)*m;
    
    % We impose the Fisher metric.
    M.inner = @iproduct;
    function ip = iproduct(X, eta, zeta)
        ip = sum((eta(:).*zeta(:))./X(:));
    end
    
    M.norm = @(X, eta) sqrt(M.inner(X, eta, eta));
    
    M.dist = @(X, Y) norm(2*acos(sum(sqrt(X.*Y), 1)));
    
    M.typicaldist = @() m*pi/2; % This is an approximation.
    
    % Column vector of ones of length n.
    % TODO: eliminate e by using bsxfun
    e = ones(n, 1);
    
    M.exp = @exponential;
    function Y = exponential(X, U, t)
        if nargin == 3
            tU = t*U;
        else
            tU = U;
        end
        Y = zeros(size(X));
        for mm = 1 : m
            x = X(:, mm);
            s = sqrt(x);
            us = tU(:, mm) ./ s ./ 2;
            un = norm(us);
            if un < eps
                Y(:, mm) = X(:, mm);
            else
                Y(:, mm) = (cos(un).*s + sin(un)/un.*us).^2;
            end
        end
    end

    M.log = @logarithm;
    function U = logarithm(X,Y)
        a = sqrt(X.*Y);
        s = sum(a, 1);
        U = 2*acos(s) ./ sqrt(1-s.^2) .* (a - s.*X);
    end
    
    M.egrad2rgrad = @egrad2rgrad;
    function rgrad = egrad2rgrad(X, egrad)
        Xegrad = X.*egrad;
        lambda = -sum(Xegrad, 1); % Row vector of length m.
        rgrad = Xegrad + (e*lambda).*X; % This is in the tangent space.
    end
    
    M.ehess2rhess = @ehess2rhess;
    function rhess = ehess2rhess(X, egrad, ehess, eta)
        
        % Riemannian gradient computation.
        % lambda is a row vector of length m.
        Xegrad = X.*egrad;
        lambda = - sum(Xegrad, 1);
        rgrad =  Xegrad + (e*lambda).*X;
        
        % Directional derivative of the Riemannian gradient.
        % lambdadot is a row vector of length m.
        Xehess = X.*ehess;
        etaegrad = eta.*egrad;
        lambdadot = -sum(etaegrad, 1) - sum(Xehess, 1); 
        rgraddot = etaegrad + Xehess + (e*lambdadot).*X + (e*lambda).*eta;
        
        % Correction term because of the non-constant metric that we
        % impose. The computation of the correction term follows the use of
        % Koszul formula.
        correction_term = - 0.5*(eta.*rgrad)./X;
        rhess = rgraddot + correction_term;
        
        % Finally, projection onto the tangent space.
        rhess = M.proj(X, rhess);
    end
    
    % Projection of the vector eta in the ambient space onto the tangent
    % space.
    M.proj = @projection;
    function etaproj = projection(X, eta)
        alpha = sum(eta, 1); % Row vector of length m.
        etaproj = eta - (e*alpha).*X;
    end
    
    M.tangent = M.proj;
    M.tangent2ambient = @(X, eta) eta;
    
    M.retr = @retraction;
    function Y = retraction(X, eta, t)
        if nargin < 3
            t = 1.0;
        end
        % A first-order retraction.
        Y = X.*exp(t*(eta./X)); % Based on mapping for positive scalars.
        Y = Y./(e*(sum(Y, 1))); % Projection onto the constraint set.
        % For numerical reasons, so that we avoid entries going to zero:
        Y = max(Y, eps);
    end
    
    
    M.hash = @(X) ['z' hashmd5(X(:))];
    
    M.rand = @random;
    function X = random()
        % A random point in the ambient space.
        X = rand(n, m); %
        X = X./(e*(sum(X, 1)));
    end
    
    M.randvec = @randomvec;
    function eta = randomvec(X)
        % A random vector in the tangent space
        eta = randn(n, m);
        eta = M.proj(X, eta); % Projection onto the tangent space.
        nrm = M.norm(X, eta);
        eta = eta / nrm;
    end
    
    M.lincomb = @matrixlincomb;
    
    M.zerovec = @(X) zeros(n, m);
    
    M.transp = @(X1, X2, d) projection(X2, d);
    
    % vec and mat are not isometries, because of the scaled metric.
    M.vec = @(X, U) U(:);
    M.mat = @(X, u) reshape(u, n, m);
    M.vecmatareisometries = @() false;
end
