function [x, cost, info, options] = neldermead(problem, x, options)
% Nelder Mead optimization algorithm for derivative-free minimization.
%
% function [x, cost, info, options] = neldermead(problem)
% function [x, cost, info, options] = neldermead(problem, x0)
% function [x, cost, info, options] = neldermead(problem, x0, options)
% function [x, cost, info, options] = neldermead(problem, [], options)
%
% Apply a Nelder-Mead minimization algorithm to the problem defined in
% the problem structure, starting with the population x0 if it is provided
% (otherwise, a random population on the manifold is generated). A
% population is a cell containing points on the manifold. The number of
% elements in the cell must be dim+1, where dim is the dimension of the
% manifold: problem.M.dim(). A population that is not a cell, or that has
% fewer than dim+1 elements, raises the error
% 'manopt:neldermead:population'.
%
% To specify options without specifying an initial population, pass x0 as
% [] (the empty matrix).
%
% This algorithm is a plain adaptation of the Euclidean Nelder-Mead method
% to the Riemannian setting. It comes with no convergence guarantees and
% there is room for improvement. In particular, we compute centroids as
% Karcher means, which seems overly expensive: cheaper forms of
% average-like quantities might work better.
% This solver is nonetheless useful for problems for which no derivatives
% are available, and it may constitute a starting point for the development
% of other Riemannian derivative-free methods.
%
% None of the options are mandatory. Solver-specific options and their
% default values:
%   maxiter (max(2000, 4*dim))
%       Maximum number of iterations, checked through stoppingcriterion.
%   reflection (1)
%       Step factor used to compute the reflected point.
%   expansion (2)
%       Step factor used to compute the expanded point.
%   contraction (.5)
%       Step factor used for both outside and inside contractions.
%   storedepth (0)
%       Depth of the StoreDB cache (this solver does not need caching).
% Let xbar be the centroid of the dim best points, and let
% v = problem.M.log(xbar, xworst). With t the corresponding step factor:
% - the reflected, expanded and outside-contracted points are
%   problem.M.retr(xbar, v, -t);
% - the inside-contracted point is problem.M.retr(xbar, v, t).
% Shrinkage replaces every point but the best one by its pairmean with the
% best point, so the shrinkage factor is fixed to .5.
% With options.verbosity >= 1 the stopping reason is printed; with
% options.verbosity >= 2 the cost-evaluation count, the best cost and the
% type of step taken are printed at each iteration. Other general Manopt
% options (e.g., statsfun and the criteria checked by stoppingcriterion)
% are accepted too; their defaults come from getGlobalDefaults.
%
% Requires problem.M.pairmean(x, y) to be defined (computes the average
% between two points, x and y), in addition to problem.M.log and
% problem.M.retr (and problem.M.rand if no initial population is given).
%
% If options.statsfun is defined, it will receive a cell of points x (the
% current simplex being considered at that iteration), and, if required,
% one store structure corresponding to the best point, x{1}. The points are
% ordered by increasing cost: f(x{1}) <= f(x{2}) <= ... <= f(x{dim+1}),
% where dim = problem.M.dim().
%
% Based on http://www.optimization-online.org/DB_FILE/2007/08/1742.pdf.
%
% See also: manopt/solvers/pso/pso

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Dec. 30, 2012.
% Contributors: 
% Change log: 
%
%   Apr.  4, 2015 (NB):
%       Working with the new StoreDB class system.
%       Clarified interactions with statsfun and store.
%
%   Nov. 11, 2016 (NB):
%       If options.verbosity is < 2, prints minimal output.
%
%   Sep.  6, 2018 (NB):
%       Using retraction instead of exponential.

    
    % Verify that the problem description is sufficient for the solver.
    if ~canGetCost(problem)
        warning('manopt:getCost', ...
                'No cost provided. The algorithm will likely abort.');  
    end
    
    % Dimension of the manifold
    dim = problem.M.dim();

    % Set local defaults here
    localdefaults.storedepth = 0;                     % no need for caching
    localdefaults.maxiter = max(2000, 4*dim);
    
    localdefaults.reflection = 1;
    localdefaults.expansion = 2;
    localdefaults.contraction = .5;
    % Shrinkage is forced to .5 so that the manifolds' pairmean functions
    % can be used to shrink the simplex.
    % localdefaults.shrinkage = .5;
    
    % Merge global and local defaults, then merge w/ user options, if any.
    localdefaults = mergeOptions(getGlobalDefaults(), localdefaults);
    if ~exist('options', 'var') || isempty(options)
        options = struct();
    end
    options = mergeOptions(localdefaults, options);
    
    % Start timing for initialization.
    timetic = tic();
    
    % If no initial simplex x is given by the user, generate one at random.
    % Otherwise, make sure it can be indexed as a population of dim+1
    % points: fail early with an explicit message instead of an opaque
    % indexing error further down.
    if ~exist('x', 'var') || isempty(x)
        x = cell(dim+1, 1);
        for i = 1 : dim+1
            x{i} = problem.M.rand();
        end
    elseif ~iscell(x) || numel(x) < dim+1
        error('manopt:neldermead:population', ...
              ['The initial population must be a cell containing ' ...
               'dim+1 = %d points on the manifold.'], dim+1);
    end
    
    % Create a store database and a key for each point.
    storedb = StoreDB(options.storedepth);
    key = cell(size(x));
    for i = 1 : dim+1
        key{i} = storedb.getNewKey();
    end
    
    % Compute objective-related quantities for x, and setup a
    % function evaluations counter.
    costs = zeros(dim+1, 1);
    for i = 1 : dim+1
        costs(i) = getCost(problem, x{i}, storedb, key{i});
    end
    costevals = dim+1;
    
    % Sort simplex points by cost. Only the first dim+1 points of x are
    % kept by this reindexing.
    [costs, order] = sort(costs);
    x = x(order);
    key = key(order);
    
    % Iteration counter.
    % At any point, iter is the number of fully executed iterations so far.
    iter = 0;
    
    % Save stats in a struct array info, and preallocate.
    % savestats will be called twice for the initial iterate (number 0),
    % which is unfortunate, but not problematic.
    stats = savestats();
    info(1) = stats;
    info(min(10000, options.maxiter+1)).iter = [];
    
    % Start iterating until stopping criterion triggers.
    while true
        
        % Make sure we don't use too much memory for the store database.
        storedb.purge();
        
        stats = savestats();
        info(iter+1) = stats; %#ok<AGROW>
        iter = iter + 1;
        
        % Start timing this iteration.
        timetic = tic();
        
        % Sort simplex points by cost.
        [costs, order] = sort(costs);
        x = x(order);
        key = key(order);

        % Log / display iteration information here.
        if options.verbosity >= 2
            fprintf('Cost evals: %7d\tBest cost: %+.4e\t', ...
                    costevals, costs(1));
        end
        
        % Run standard stopping criterion checks.
        [stop, reason] = stoppingcriterion(problem, x, options, info, iter);
    
        if stop
            if options.verbosity >= 1
                % Print the reason verbatim: it must not be interpreted as
                % a format string (it could contain '%' or '\').
                fprintf('%s\n', reason);
            end
            break;
        end
        
        % Compute a centroid for the dim best points.
        xbar = centroid(problem.M, x(1:end-1));
        
        % Compute the direction for moving along the axis xbar - worst x.
        vec = problem.M.log(xbar, x{end});
        
        % Reflection step: move from xbar away from the worst point.
        xr = problem.M.retr(xbar, vec, -options.reflection);
        keyr = storedb.getNewKey();
        costr = getCost(problem, xr, storedb, keyr);
        costevals = costevals + 1;
        
        % If the reflected point is honorable, drop the worst point,
        % replace it by the reflected point and start new iteration.
        if costr >= costs(1) && costr < costs(end-1)
            if options.verbosity >= 2
                fprintf('Reflection\n');
            end
            costs(end) = costr;
            x{end} = xr;
            key{end} = keyr;
            continue;
        end
        
        % If the reflected point is better than the best point, expand.
        if costr < costs(1)
            xe = problem.M.retr(xbar, vec, -options.expansion);
            keye = storedb.getNewKey();
            coste = getCost(problem, xe, storedb, keye);
            costevals = costevals + 1;
            if coste < costr
                if options.verbosity >= 2
                    fprintf('Expansion\n');
                end
                costs(end) = coste;
                x{end} = xe;
                key{end} = keye;
                continue;
            else
                if options.verbosity >= 2
                    fprintf('Reflection (failed expansion)\n');
                end
                costs(end) = costr;
                x{end} = xr;
                key{end} = keyr;
                continue;
            end
        end
        
        % If the reflected point is worse than the second to worst point,
        % contract.
        if costr >= costs(end-1)
            if costr < costs(end)
                % do an outside contraction (on the reflected side)
                xoc = problem.M.retr(xbar, vec, -options.contraction);
                keyoc = storedb.getNewKey();
                costoc = getCost(problem, xoc, storedb, keyoc);
                costevals = costevals + 1;
                if costoc <= costr
                    if options.verbosity >= 2
                        fprintf('Outside contraction\n');
                    end
                    costs(end) = costoc;
                    x{end} = xoc;
                    key{end} = keyoc;
                    continue;
                end
            else
                % do an inside contraction (towards the worst point)
                xic = problem.M.retr(xbar, vec, options.contraction);
                keyic = storedb.getNewKey();
                costic = getCost(problem, xic, storedb, keyic);
                costevals = costevals + 1;
                if costic <= costs(end)
                    if options.verbosity >= 2
                        fprintf('Inside contraction\n');
                    end
                    costs(end) = costic;
                    x{end} = xic;
                    key{end} = keyic;
                    continue;
                end
            end
        end
        
        % If we get here, shrink the simplex around x{1}: every other point
        % is replaced by its pairmean with x{1} and re-evaluated.
        if options.verbosity >= 2
            fprintf('Shrinkage\n');
        end
        for i = 2 : dim+1
            x{i} = problem.M.pairmean(x{1}, x{i});
            key{i} = storedb.getNewKey();
            costs(i) = getCost(problem, x{i}, storedb, key{i});
        end
        costevals = costevals + dim;
        
    end
    
    
    info = info(1:iter);
    
    % Iteration done: return only the best point found.
    cost = costs(1);
    x = x{1};
    
    
    
    % Routine in charge of collecting the current iteration stats.
    % Nested function: it reads iter, costs, costevals, timetic, info, x,
    % key, storedb, problem and options from the enclosing workspace.
    function stats = savestats()
        stats.iter = iter;
        stats.cost = costs(1);
        stats.costevals = costevals;
        if iter == 0
            stats.time = toc(timetic);
        else
            stats.time = info(iter).time + toc(timetic);
        end
        % The statsfun can only possibly receive one store structure. We
        % pass the key to the best point, so that the best point's store
        % will be passed. But the whole cell x of points is passed through.
        stats = applyStatsfun(problem, x, storedb, key{1}, options, stats);
    end
    
end
