function M = stiefelgeneralizedfactory(n, p, B)
% Returns a manifold structure of "scaled" orthonormal matrices.
%
% function M = stiefelgeneralizedfactory(n, p)
% function M = stiefelgeneralizedfactory(n, p, B)
%
% The generalized Stiefel manifold is the set of "scaled" orthonormal 
% nxp matrices X such that X'*B*X is identity. B must be positive definite.
% If B is identity, then this is the standard Stiefel manifold.
%
% The generalized Stiefel manifold is endowed with a scaled metric
% by making it a Riemannian submanifold of the Euclidean space,
% again endowed with the scaled inner product.
%
% Some notions (not all) are from Section 4.5 of the paper
% "The geometry of algorithms with orthogonality constraints",
% A. Edelman, T. A. Arias, S. T. Smith, SIMAX, 1998.
%
% Paper link: http://arxiv.org/abs/physics/9806030.
%
% Note: egrad2rgrad and ehess2rhess involve solving linear systems in B. If
% this is a bottleneck for a specific application, then a way forward is to
% create a modified version of this file which preprocesses B to speed this
% up (typically, by computing a Cholesky factorization of it, then calling
% an appropriate solver; or by exploiting sparsity.)
%
% See also: stiefelfactory  grassmannfactory  grassmanngeneralizedfactory 

% This file is part of Manopt: www.manopt.org.
% Original author: Bamdev Mishra, June 30, 2015.
% Contributors: Hiroyuki Sato and Kensuke Aihara, Dec. 27, 2018.
%
% Change log:
%
%    Sep.  6, 2018 (NB):
%        Removed M.exp() as it was not implemented.
%
%    Dec. 27, 2018 (NB):
%        Added retraction M.retr_qr (based on QR factorization) as an
%        alternative to the previous retraction, which is now accessible as
%        M.retr_polar. The new retraction should be more efficient: it is
%        now the default, as M.retr = M.retr_qr. This new retraction is a
%        contribution of H. Sato and K. Aihara, as per their paper:
%        'Cholesky QR-based retraction on the generalized Stiefel manifold'
%        https://link.springer.com/article/10.1007%2Fs10589-018-0046-7

    
    if ~exist('B', 'var') || isempty(B)
        B = speye(n); % Standard Stiefel manifold.
    end
    
    M.name = @() sprintf('Generalized Stiefel manifold St(%d, %d)', n, p);
    
    M.dim = @() (n*p - .5*p*(p+1));
    
    M.inner = @(X, eta, zeta) trace(eta'*(B*zeta)); % Scaled metric.
    
    M.norm = @(X, eta) sqrt(M.inner(X, eta, eta));
    
    M.dist = @(X, Y) error('stiefelgeneralizedfactory.dist not implemented yet.');
    
    M.typicaldist = @() sqrt(p);
    
    % Orthogonal projection of an ambient vector U to the tangent space
    % at X.
    M.proj = @projection;
    function Up = projection(X, U)
        BX = B*X;
        
        % Projection onto the tangent space
        Up = U - X*symm(BX'*U);  
    end
    
    M.tangent = M.proj;
    
    M.egrad2rgrad = @egrad2rgrad;
    function rgrad = egrad2rgrad(X, egrad)
        
        % First, scale egrad according the to the scaled metric in the
        % Euclidean space. Ideally, B should be preprocessed to ease
        % solving linear systems, e.g., via Cholesky factorization.
        egrad_scaled = B\egrad;
        
        % Second, project onto the tangent space.
        % rgrad = egrad_scaled - X*symm((B*X)'*egrad_scaled);
        %
        % Verify that symm(BX'*egrad_scaled) = symm(X'*egrad).
        
        rgrad = egrad_scaled - X*symm(X'*egrad);
    end
    
    
    
    M.ehess2rhess = @ehess2rhess;
    function rhess = ehess2rhess(X, egrad, ehess, H)
        egraddot = ehess;
        Xdot = H;
        
        % Directional derivative of the Riemannian gradient.
        egrad_scaleddot = B\egraddot;
        rgraddot = egrad_scaleddot - Xdot*symm(X'*egrad) ...
                                   - X*symm(Xdot'*egrad) ...
                                   - X*symm(X'*egraddot);
        
        % Project onto the tangent space.
        rhess = M.proj(X, rgraddot);
    end
    
    
    M.retr_polar = @retraction_polar;
    function Y = retraction_polar(X, U, t)
        if nargin < 3
            t = 1.0;
        end
        Y = guf(X + t*U); % Ensures Y'*B*Y is identity.
    end

    M.retr_qr = @retraction_qr;
    function Y = retraction_qr(X, U, t)
        if nargin < 3
            t = 1.0;
        end
        Y = gqf(X + t*U); % Ensures Y'*B*Y is identity.
    end

    % By default, we use the QR retraction
    M.retr = M.retr_qr;

    M.hash = @(X) ['z' hashmd5(X(:))];
    
    M.rand = @random;
    function X = random()
        X = guf(randn(n, p)); % Ensures X'*B*X is identity;
        % TODO: test if this can be replaced by gqf.
    end
    
    M.randvec = @randomvec;
    function U = randomvec(X)
        U = projection(X, randn(n, p));
        U = U / norm(U(:));
    end
    
    M.lincomb = @matrixlincomb;
    
    M.zerovec = @(X) zeros(n, p);
    
    % This transport is compatible with the generalized polar retraction.
    M.transp = @(X1, X2, d) projection(X2, d);
    
    M.vec = @(X, u_mat) u_mat(:);
    M.mat = @(X, u_vec) reshape(u_vec, [n, p]);
    M.vecmatareisometries = @() false;
    
    % Some auxiliary functions
    symm = @(D) (D + D')/2;
    

    function X = gqf(Y)
        % Generalized QR decomposition of an n-by-p matrix Y.
        % X'*B*X is identity. See Sato and Aihara 2018,
        % Cholesky QR-based retraction on the generalized Stiefel manifold,
        % https://link.springer.com/article/10.1007%2Fs10589-018-0046-7
        %
        % Comment by Nicolas Boumal, Dec. 27, 2018, following discussions
        % with Hiroyuki Sato, Kensuke Aihara and Bamdev Mishra:
        %
        % In principle, to orthonormalize the columns of Y against the
        % inner product defined by B, it would be better to implement (for
        % example) a modified Gram-Schmidt algorithm, possibly run twice.
        % Indeed, the latter would be affected by the condition number of
        % sqrtm(B)*Y. In contrast, the method below first squares the data,
        % hence is affected by the condition number of Y'*B*Y, which is the
        % square of that of sqrtm(B)*Y. Fortunately, it is easily shown
        % that, when Y = X + U for a tangent vector U at X, the condition
        % number of sqrtm(B)*Y is upper bounded by the square root of
        % 1 + ||U||_X^2. In Riemannian optimization, it is seldom necessary
        % to retract very large vectors, hence it is expected that the
        % condition numbers encountered in practice will be reasonable. As
        % a result, the implementation below which calls directly upon
        % low-level routines (Cholesky and a small triangular system solve)
        % is expected to run faster in practice than a home-made
        % implementation of modified Gram-Schmidt, which would involve a
        % loop, and still be sufficiently accurate.
        R = chol(symm(Y'*(B*Y)));
        X = Y / R;
    end

    function X = guf(Y)
        % Generalized polar decomposition of an n-by-p matrix Y.
        % X'*B*X is identity.
        
        % Method 1
        [u, ~, v] = svd(Y, 0);
  
        % Instead of the following three steps, an equivalent, but an 
        % expensive way is to do X = u*(sqrtm(u'*(B*u))\(v')).
        [q, ssquare] = eig(u'*(B*u));
        qsinv = q/sparse(diag(sqrt(diag(ssquare))));
        X = u*((qsinv*q')*v'); % X'*B*X is identity.
        
        
        % Another computation using restricted_svd
        % [u, ~, v] = restricted_svd(Y);
        % X = u*v'; % X'*B*X is identity.
        
    end
    
    function [u, s, v] = restricted_svd(Y) %#ok<DEFNU>
        % We compute a thin svd-like decomposition of an n-by-p matrix Y 
        % into matrices u, s, and v such that u is an n-by-p matrix
        % with u'*B*u being identity, s is a p-by-p diagonal matrix 
        % with positive entries, and v is a p-by-p orthogonal matrix.
        % Y = u*s*v'.
        [v, ssquare] = eig(symm(Y'*(B*Y))); % Y*B*Y is positive definite
        ssquarevec = diag(ssquare);
        
        s = sparse(diag(abs(sqrt(ssquarevec))));
        u = Y*(v/s); % u'*B*u is identity.
    end

end
