function M = fixedrankfactory_3factors_preconditioned(m, n, k)
% Manifold of m-by-n matrices of rank k with three factor quotient geometry.
%
% function M = fixedrankfactory_3factors_preconditioned(m, n, k)
%
% This geometry is tuned to least squares problems such as low-rank matrix
% completion with ell-2 loss.
%
% A point X on the manifold is represented as a structure with three
% fields: L, S and R. The matrices L (mxk) and R (nxk) have orthonormal
% columns, while the matrix S (kxk) is full rank, such that X = L*S*R'.
%
% Tangent vectors are represented as a structure with three fields: L, S
% and R.
%
% Please cite the Manopt paper as well as the research paper:
%     @InProceedings{mishra2014r3mc,
%       Title        = {{R3MC}: A {R}iemannian three-factor algorithm for low-rank matrix completion},
%       Author       = {Mishra, B. and Sepulchre, R.},
%       Booktitle    = {{53rd IEEE Conference on Decision and Control}},
%       Year         = {2014},
%       Organization = {{IEEE CDC}}
%     }
%
%
% See also: fixedrankfactory_3factors fixedrankfactory_2factors_preconditioned

% This file is part of Manopt: www.manopt.org.
% Original author: Bamdev Mishra, Dec. 30, 2012.
% Contributors:
% Change log:
%
%    Apr.  4, 2015 (BM):
%        Cosmetic changes including avoiding storing the inverse of a kxk matrix.
%
%    Apr. 18, 2018 (NB):
%        Removed lyap dependency.
%
%    Sep. 6, 2018 (NB):
%        Removed M.exp() as it was not implemented.

    
    % Human-readable description of the manifold, built from the sizes m, n and k.
    M.name = @() sprintf('LSR'' (tuned for least square problems) quotient manifold of %dx%d matrices of rank %d', m, n, k);
    
    % Dimension reported for the manifold, as a function of m, n and k.
    M.dim = @() (m+n-k)*k;
    
    % Cache the two Gram matrices of the middle factor, S*S' and S'*S, on the
    % point X. They are used by the inner product and by most of the other
    % operations. Nothing is recomputed when both fields are already present.
    function X = prepare(X)
        alreadyCached = all(isfield(X, {'StS', 'SSt'}));
        if alreadyCached
            return;
        end
        
        X.SSt = X.S*X.S';
        X.StS = X.S'*X.S;
    end
    
    % Riemannian metric. The L and R components are weighted by S*S' and S'*S
    % respectively (a choice motivated by symmetry and tuned to least-squares
    % objective functions); the S component uses the plain trace inner product.
    M.inner = @iproduct;
    function ip = iproduct(X, eta, zeta)
        X = prepare(X);
        
        termL = trace(X.SSt*(eta.L'*zeta.L));
        termR = trace(X.StS*(eta.R'*zeta.R));
        termS = trace(eta.S'*zeta.S);
        
        ip = termL + termR + termS;
    end
    
    % Norm of a tangent vector, induced by the metric M.inner.
    M.norm = @(X, eta) sqrt(M.inner(X, eta, eta));
    
    % No distance is available for this geometry: calling M.dist raises an error.
    M.dist = @(x, y) error('fixedrankfactory_3factors_preconditioned.dist not implemented yet.');
    
    % Typical distance on the manifold, set to 10*k.
    M.typicaldist = @() 10*k;
    
    % Skew-symmetric and symmetric parts of a square matrix; both helpers are
    % used by the nested functions below.
    skew = @(X) .5*(X-X');
    symm = @(X) .5*(X+X');
    
    M.egrad2rgrad = @egrad2rgrad;
    function rgrad = egrad2rgrad(X, egrad)
        % Convert a Euclidean gradient egrad (structure with fields L, S, R)
        % into the Riemannian gradient for the preconditioned metric.
        X = prepare(X);
        
        % Symmetric right-hand sides of the two Lyapunov equations
        %     SSt*BL + BL*SSt = rhsL    and    StS*BR + BR*StS = rhsR.
        rhsL = 2*symm(X.SSt*(egrad.S*X.S'));
        rhsR = 2*symm(X.StS*(egrad.S'*X.S));
        
        % Both equations are solved at once, without calling Matlab's lyap.
        [BL, BR] = tangent_space_lyap(X.S, rhsL, rhsR);
        
        % Subtract the X.L*BL and X.R*BR terms, then apply the inverse of the
        % metric weights on the right. The S component is left untouched.
        rgrad.L = (egrad.L - X.L*BL)/X.SSt;
        rgrad.R = (egrad.R - X.R*BR)/X.StS;
        rgrad.S = egrad.S;
        
        % Debug
        %         BL1 = lyap(X.SSt, -rhsL); % Alternate way
        %         BR1 = lyap(X.StS, -rhsR);
        %         norm(skew(X.SSt*(rgrad.L'*X.L) + rgrad.S*X.S'), 'fro')
        %         norm(skew(X.StS*(rgrad.R'*X.R) - X.S'*rgrad.S), 'fro')
        
    end
    
    
    
    M.ehess2rhess = @ehess2rhess;
    function Hess = ehess2rhess(X, egrad, ehess, eta)
        % Convert the Euclidean gradient egrad and the Euclidean Hessian ehess
        % along the direction eta (all structures with fields L, S, R) into
        % the Riemannian Hessian along eta. The result is passed through
        % M.proj before being returned.
        X = prepare(X);
        
        % Riemannian gradient. This repeats the computation of egrad2rgrad
        % because the intermediate matrices BL and BR are needed below.
        SSL = X.SSt;
        ASL = 2*symm(SSL*(egrad.S*X.S'));
        SSR = X.StS;
        ASR = 2*symm(SSR*(egrad.S'*X.S));
        [BL, BR] = tangent_space_lyap(X.S, ASL, ASR);
        
        rgrad.L = (egrad.L - X.L*BL)/X.SSt;
        rgrad.R = (egrad.R - X.R*BR)/X.StS;
        rgrad.S = egrad.S;
        
        % Directional derivative of the Riemannian gradient.
        % ASLdot and ASRdot are the right-hand sides of the Lyapunov equations
        % solved below for BLdot and BRdot.
        ASLdot = 2*symm((2*symm(X.S*eta.S')*(egrad.S*X.S')) + X.SSt*(ehess.S*X.S' + egrad.S*eta.S')) - 4*symm(symm(eta.S*X.S')*BL);
        ASRdot = 2*symm((2*symm(X.S'*eta.S)*(egrad.S'*X.S)) + X.StS*(ehess.S'*X.S + egrad.S'*eta.S)) - 4*symm(symm(eta.S'*X.S)*BR);
        
        % Alternate way, using Matlab's lyap:
        %         SSLdot = X.SSt;
        %         SSRdot = X.StS;
        %         BLdot = lyap(SSLdot, -ASLdot);
        %         BRdot = lyap(SSRdot, -ASRdot);
        
        [BLdot, BRdot] = tangent_space_lyap(X.S, ASLdot, ASRdot);
        
        Hess.L = (ehess.L - eta.L*BL - X.L*BLdot - 2*rgrad.L*symm(eta.S*X.S'))/X.SSt;
        Hess.R = (ehess.R - eta.R*BR - X.R*BRdot - 2*rgrad.R*symm(eta.S'*X.S))/X.StS;
        Hess.S = ehess.S;
        
        
        
        % BM: Up to this point, everything seems correct.
        % We still need a correction factor for the non-constant metric
        % that is imposed.
        % The computation of the correction factor comes from the Koszul formula.
        % This corresponds to the Riemannian connection in the Euclidean space with the
        % scaled metric.
        Hess.L = Hess.L + (eta.L*symm(rgrad.S*X.S') + rgrad.L*symm(eta.S*X.S'))/X.SSt;
        Hess.R = Hess.R + (eta.R*symm(rgrad.S'*X.S) + rgrad.R*symm(eta.S'*X.S))/X.StS;
        Hess.S = Hess.S - symm(rgrad.L'*eta.L)*X.S - X.S*symm(rgrad.R'*eta.R);
        
        % The Riemannian connection on the quotient space is the
        % projection of the Riemannian connection in the ambient space onto the tangent space of the total space and
        % then onto the horizontal space. 
        % This is accomplished by the following operation.
        Hess = M.proj(X, Hess);
        
        % Debug: both quantities below are expected to be (numerically) zero.
        %         norm(skew(X.SSt*(Hess.L'*X.L) + Hess.S*X.S'))
        %         norm(skew(X.StS*(Hess.R'*X.R) - X.S'*Hess.S))
        
    end
    
    
    
    
    M.proj = @projection;
    function etaproj = projection(X, eta)
        % Project an ambient vector eta (structure with fields L, S, R) in
        % two stages: first onto the tangent space of the total space, then
        % onto the horizontal space at X.
        X = prepare(X);
        
        % Stage 1: tangent space of the total space. Only the L and R
        % components are modified; each one needs a symmetric matrix obtained
        % from lyapunov_symmetric (presumably the solution of
        % SSt*BL + BL*SSt = rhsL, resp. StS*BR + BR*StS = rhsR -- confirm
        % against that helper's documentation).
        rhsL = 2*symm(X.SSt*(X.L'*eta.L)*X.SSt);
        BL = lyapunov_symmetric(X.SSt, rhsL);
        tanL = eta.L - X.L*(BL/X.SSt);
        
        rhsR = 2*symm(X.StS*(X.R'*eta.R)*X.StS);
        BR = lyapunov_symmetric(X.StS, rhsR);
        tanR = eta.R - X.R*(BR/X.StS);
        
        % Stage 2: horizontal space. The skew-symmetric matrices Omega1 and
        % Omega2 solve a coupled system of Lyapunov equations.
        PU = skew((X.L'*tanL)*X.SSt) + skew(X.S*eta.S');
        PV = skew((X.R'*tanR)*X.StS)  + skew(X.S'*eta.S);
        [Omega1, Omega2] = coupled_lyap(X.S, PU, PV);
        %         norm(2*skew(Omega1*X.SSt) - PU -(X.S*Omega2*X.S'),'fro' )
        %         norm(2*skew(Omega2*X.StS) - PV -(X.S'*Omega1*X.S),'fro' )
        
        etaproj.L = tanL - (X.L*Omega1);
        etaproj.S = eta.S - (X.S*Omega2 - Omega1*X.S);
        etaproj.R = tanR - (X.R*Omega2);
        
        % Debug
        %         norm(skew(X.SSt*(etaproj.L'*X.L) + etaproj.S*X.S'))
        %         norm(skew(X.StS*(etaproj.R'*X.R) - X.S'*etaproj.S))
        %
        %         norm(skew(X.SSt*(etaproj.L'*X.L) - X.S*etaproj.S'))
        %         norm(skew(X.StS*(etaproj.R'*X.R) + etaproj.S'*X.S))
        
    end
    
    
    % M.tangent reuses the projection; tangent vectors are returned unchanged
    % by M.tangent2ambient.
    M.tangent = M.proj;
    M.tangent2ambient = @(X, eta) eta;
    
    M.retr = @retraction;
    function Y = retraction(X, eta, t)
        % Retraction: take the step t*eta from X factor by factor (t defaults
        % to 1), map the L and R factors through uf, and cache the Gram
        % matrices of the new S on the returned point.
        if nargin < 3
            t = 1.0;
        end
        
        Y = prepare(struct( ...
            'S', X.S + t*eta.S, ...
            'L', uf(X.L + t*eta.L), ...
            'R', uf(X.R + t*eta.R)));
    end
    
    
    % Hash of a point: the character 'z' followed by the output of hashmd5
    % applied to the stacked entries of L, S and R.
    M.hash = @(X) ['z' hashmd5([X.L(:) ; X.S(:) ; X.R(:)])];
    
    M.rand = @random;
    % The factors L and R live on Stiefel manifolds, so the random generators
    % of the corresponding Stiefel factories are reused for them.
    stiefelm = stiefelfactory(m, k);
    stiefeln = stiefelfactory(n, k);
    function X = random()
        % Random point: L and R come from the Stiefel generators, and S is a
        % diagonal matrix whose diagonal entries are 1 + rand.
        X.L = stiefelm.rand();
        X.R = stiefeln.rand();
        
        diagEntries = 1+rand(k, 1);
        X.S = diag(diagEntries);
        
        X = prepare(X);
    end
    
    M.randvec = @randomvec;
    function eta = randomvec(X)
        % Random horizontal vector at X with unit norm: draw Gaussian
        % components, project them, then divide by the norm.
        ambient.L = randn(m, k);
        ambient.R = randn(n, k);
        ambient.S = randn(k, k);
        
        eta = projection(X, ambient);
        
        scale = M.norm(X, eta);
        for name = {'L', 'R', 'S'}
            eta.(name{1}) = eta.(name{1}) / scale;
        end
    end
    
    % Linear combinations are delegated to the file-local function lincomb.
    M.lincomb = @lincomb;
    
    % Zero vector: all three components are zero matrices (m-by-k, k-by-k, n-by-k).
    M.zerovec = @(X) struct('L', zeros(m, k), 'S', zeros(k, k), ...
        'R', zeros(n, k));
    
    % Vector transport: project d at the destination point x2 (x1 is not used).
    M.transp = @(x1, x2, d) projection(x2, d);
    
    % vec stacks the L, S and R components (in that order) into one column
    % vector; mat splits such a vector back into the three components.
    % vec and mat are not isometries, because of the unusual inner metric.
    M.vec = @(X, U) [U.L(:) ; U.S(:); U.R(:)];
    M.mat = @(X, u) struct('L', reshape(u(1:(m*k)), m, k), ...
        'S', reshape(u((m*k+1): m*k + k*k), k, k), ...
        'R', reshape(u((m*k+ k*k + 1):end), n, k));
    M.vecmatareisometries = @() false;
    
end

% Linear combination of tangent vectors: a1*d1 when called with three inputs,
% a1*d1 + a2*d2 when called with five, applied to each of the fields L, R
% and S. The first input (the point) is not used. Any other number of inputs
% raises an error.
function d = lincomb(x, a1, d1, a2, d2) %#ok<INUSL>
    
    switch nargin
        case 3
            d.L = a1*d1.L;
            d.R = a1*d1.R;
            d.S = a1*d1.S;
        case 5
            d.L = a1*d1.L + a2*d2.L;
            d.R = a1*d1.R + a2*d2.R;
            d.S = a1*d1.S + a2*d2.S;
        otherwise
            error('Bad use of fixedrankfactory_3factors_preconditioned.lincomb.');
    end
    
end

% Replace A by the product of its left and right singular vector matrices
% from the economy-size SVD: if A = U*Sigma*V', the output is U*V'.
function A = uf(A)
    [leftVecs, ignoredSigma, rightVecs] = svd(A, 0); %#ok
    A = leftVecs*rightVecs';
end

function [BU, BV] = tangent_space_lyap(R, E, F)
    % We intend to solve the two linear systems
    %     R*R'*BU + BU*R*R' = E
    %     R'*R*BV + BV*R'*R = F
    % for BU and BV.
    %
    % Two calls to Matlab's lyap would do the job, but the full
    % functionality of lyap is not needed: after a change of basis given by
    % the SVD of R, both systems become entrywise divisions.
    
    [U, Sigma, V] = svd(R);
    r = size(Sigma, 1);
    
    % Right-hand sides expressed in the singular bases.
    rhsU = U'*E*U;
    rhsV = V'*F*V;
    
    % Entry (i, j) of the diagonalized operator is sigma_i^2 + sigma_j^2.
    sv = diag(Sigma);          % all the singular values in a vector
    colRep = sv*ones(1, r);    % columns repeat
    rowRep = colRep';          % rows repeat
    denom = colRep(:).^2 + rowRep(:).^2; % a column vector
    
    % Solve entrywise, then reshape back into matrices.
    CU = reshape(rhsU(:)./denom, r, r);
    CV = reshape(rhsV(:)./denom, r, r);
    
    % Return to the original bases.
    BU = U*CU*U';
    BV = V*CV*V';
    
    % %% Debug
    %
    % norm(R*R'*BU + BU*R*R' - E, 'fro');
    % norm((Sigma.^2)*CU + CU*(Sigma.^2) - rhsU, 'fro');
    %
    % norm(R'*R*BV + BV*R'*R - F, 'fro');
    %
    % BU1 = lyap(R*R', - E);
    % norm(R*R'*BU1 + BU1*R*R' - E, 'fro');
    %
    % BV1 = lyap(R'*R, - F);
    % norm(R'*R*BV1 + BV1*R'*R - F, 'fro');
    %
    % % as accurate as the lyap
    % norm(BU - BU1, 'fro')
    % norm(BV - BV1, 'fro')
end



function [Omega1, Omega2] = coupled_lyap(R, E, F)
    % We intend to solve the coupled system of Lyapunov equations
    %
    %     R*R'*Omega1 + Omega1*R*R' - R*Omega2*R'  = E
    %     R'*R*Omega2 + Omega2*R'*R - R'*Omega1*R  = F
    %
    % for Omega1 and Omega2, both of which are skew-symmetric matrices.
    %
    % In the bases given by the SVD of R, the system decouples into
    % independent 2-by-2 systems, one per matrix entry, which are solved in
    % closed form below.
    
    [U, Sigma, V] = svd(R);
    r = size(Sigma, 1);
    
    % Right-hand sides expressed in the singular bases, stacked as vectors.
    rhsE = U'*E*U;
    rhsF = V'*F*V;
    b1 = rhsE(:);
    b2 = rhsF(:);
    
    sv = diag(Sigma);          % all the singular values in a vector
    colRep = sv*ones(1, r);    % columns repeat
    rowRep = colRep';          % rows repeat
    sI = colRep(:);
    sJ = rowRep(:);
    
    % Coefficients of the entrywise 2-by-2 systems
    %     a.*y1 - c.*y2 = b1
    %     a.*y2 - c.*y1 = b2
    a = sI.^2 + sJ.^2; % a column vector
    c = sI.*sJ;
    
    % Closed-form solution. With A = diag(a) and C = diag(c) this reads
    %     y1 = (A*(C\A) - C) \ (b2 + A*(C\b1));
    %     y2 = A\(b2 + C*y1);
    y1 = (b2 + (a./c).*b1) ./ ((a.^2)./c - c);
    y2 = (b2 + c.*y1)./a;
    
    % Reshape into matrices and return to the original bases.
    Omega1 = U*reshape(y1, r, r)*U';
    Omega2 = V*reshape(y2, r, r)*V';
    
    % %% Debug: whether we have the right solution.
    % norm(R*R'*Omega1 + Omega1*R*R'  - R*Omega2*R' - E, 'fro')
    % norm(R'*R*Omega2 + Omega2*R'*R  - R'*Omega1*R - F, 'fro')
end