function M = symmetricfactory(n, k)
% Returns a manifold struct to optimize over k symmetric matrices of size n
%
% function M = symmetricfactory(n)
% function M = symmetricfactory(n, k)
%
% Returns M, a structure describing the Euclidean space of n-by-n symmetric
% matrices equipped with the standard Frobenius distance and associated
% trace inner product, as a manifold for Manopt.
% 
% By default, k = 1. If k > 1, points and vectors are stored in 3D matrices
% X of size nxnxk such that each slice X(:, :, i), for i = 1:k, is
% symmetric.

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Jan. 22, 2014.
% Contributors: 
% Change log: 
%
%   Jan. 25, 2017 (NB):
%       M.tangent = M.proj now, instead of being identity. This is notably
%       necessary so that checkgradient will pick up on gradients that do
%       not lie in the appropriate tangent space.
    
    if ~exist('k', 'var') || isempty(k)
        k = 1;
    end

    M.name = @() sprintf('(Symmetric matrices of size %d)^%d', n, k);
    
    M.dim = @() k*n*(n+1)/2;
    
    M.inner = @(x, d1, d2) d1(:).'*d2(:);
    
    M.norm = @(x, d) norm(d(:), 'fro');
    
    M.dist = @(x, y) norm(x(:)-y(:), 'fro');
    
    M.typicaldist = @() sqrt(k)*n;
    
    M.proj = @(x, d) multisym(d);
    
    M.egrad2rgrad = M.proj;
    
    M.ehess2rhess = @(x, eg, eh, d) M.proj(x, eh);
    
    M.tangent = M.proj;
    
    M.exp = @exp;
    function y = exp(x, d, t)
        if nargin == 3
            y = x + t*d;
        else
            y = x + d;
        end
    end
    
    M.retr = M.exp;
    
    M.log = @(x, y) y-x;
    
    M.hash = @(x) ['z' hashmd5(x(:))];
    
    M.rand = @() multisym(randn(n, n, k));
    
    M.randvec = @randvec;
    function u = randvec(x) %#ok<INUSD>
        u = multisym(randn(n, n, k));
        u = u / norm(u(:), 'fro');
    end
    
    M.lincomb = @matrixlincomb;
    
    M.zerovec = @(x) zeros(n, n, k);
    
    M.transp = @(x1, x2, d) d;
    
    M.pairmean = @(x1, x2) .5*(x1+x2);
    
    
    % Elaborate list of indices of diagonal entries of an nxnxk matrix.
    single_diag_entries = (1:(n+1):n^2)';
    all_diag_entries = bsxfun(@plus, single_diag_entries, n^2*(0:(k-1)));
    all_diag_entries = all_diag_entries(:);
    
    % Likewise, elaborate list of indices of upper-triangular entries.
    single_upper_triangle = find(triu(ones(n), 1));
    all_upper_triangle = bsxfun(@plus, single_upper_triangle, n^2*(0:(k-1)));
    all_upper_triangle = all_upper_triangle(:);
    
    % To vectorize a matrix, we extract all diagonal entries, then all
    % upper-triangular entries, the latter being scaled by sqrt(2) to
    % ensure isometry, that is: given two tangent vectors U and V at a
    % point X, M.inner(X, U, V) is equal to u'*v, where u = M.vec(X, U) and
    % likewise for v. This construction has the advantage of providing a
    % vectorized representation of matrices that has the same length as the
    % intrinsic dimension of the space they live in.
    M.vec = @(x, u_mat) [u_mat(all_diag_entries) ; ...
                         sqrt(2)*u_mat(all_upper_triangle)];
    M.mat = @matricize;
    function u_mat = matricize(X, u_vec) %#ok<INUSL>
        u_mat = zeros(n, n, k);
        u_mat(all_upper_triangle) = u_vec((k*n+1):end) / sqrt(2);
        u_mat = u_mat + multitransp(u_mat);
        u_mat(all_diag_entries) = u_vec(1:(k*n));
    end
    M.vecmatareisometries = @() true;

end

% Former, easier versions for vec / mat. They had the disadvantage of
% giving vector representations of length k*n^2, instead of k*n*(n+1)/2.
% M.vec = @(x, u_mat) u_mat(:);
% M.mat = @(x, u_vec) reshape(u_vec, [m, n]);
% M.vecmatareisometries = @() true;
