function checkdiff(problem, x, d, force_gradient)
% Checks the consistency of the cost function and directional derivatives.
%
% function checkdiff(problem)
% function checkdiff(problem, x)
% function checkdiff(problem, x, d)
%
% checkdiff performs a numerical test to check that the directional
% derivatives defined in the problem structure agree up to first order with
% the cost function at some point x, along some direction d. The test is
% based on a truncated Taylor series (see online Manopt documentation).
%
% Both x and d are optional and will be sampled at random if omitted.
%
% See also: checkgradient checkhessian

% If force_gradient = true (hidden parameter), then the function will call
% getGradient and infer the directional derivative, rather than call
% getDirectionalDerivative directly. This is used by checkgradient.

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Dec. 30, 2012.
% Contributors: 
% Change log: 
%
%   March 26, 2017 (JB):
%       Detects if the approximated linear model is exact
%       and provides the user with the corresponding feedback.
% 
%   April 3, 2015 (NB):
%       Works with the new StoreDB class system.
%
%   Aug. 2, 2018 (NB):
%       Using storedb.remove() to avoid unnecessary cache build-up.
%
%   Sep. 6, 2018 (NB):
%       Now checks whether M.exp() is available; uses retraction otherwise.
%
%   June 18, 2019 (NB):
%       Now issues a warning if the cost function returns complex values.

    if ~exist('force_gradient', 'var')
        force_gradient = false;
    end
        
    % Verify that the problem description is sufficient.
    if ~canGetCost(problem)
        error('It seems no cost was provided.');
    end
    if ~force_gradient && ~canGetDirectionalDerivative(problem)
        error('It seems no directional derivatives were provided.');
    end
    if force_gradient && ~canGetGradient(problem)
        % Would normally issue a warning, but this function should only be
        % called with force_gradient on by checkgradient, which will
        % already have issued a warning.
    end
        
    x_isprovided = exist('x', 'var') && ~isempty(x);
    d_isprovided = exist('d', 'var') && ~isempty(d);
    
    if ~x_isprovided && d_isprovided
        error('If d is provided, x must be too, since d is tangent at x.');
    end
    
    % If x and / or d are not specified, pick them at random.
    if ~x_isprovided
        x = problem.M.rand();
    end
    if ~d_isprovided
        d = problem.M.randvec(x);
    end

    % Compute the value f0 at f and directional derivative at x along d.
    storedb = StoreDB();
    xkey = storedb.getNewKey();
    f0 = getCost(problem, x, storedb, xkey);
    
    if ~force_gradient
        df0 = getDirectionalDerivative(problem, x, d, storedb, xkey);
    else
        grad = getGradient(problem, x, storedb, xkey);
        df0 = problem.M.inner(x, grad, d);
    end
    
    % Pick a stepping function: exponential or retraction?
    if isfield(problem.M, 'exp')
        stepper = problem.M.exp;
    else
        stepper = problem.M.retr;
        % No need to issue a warning: to check the gradient, any retraction
        % (which is first-order by definition) is appropriate.
    end
    
    % Compute the value of f at points on the geodesic (or approximation
    % of it) originating from x, along direction d, for stepsizes in a
    % large range given by h.
    h = logspace(-8, 0, 51);
    value = zeros(size(h));
    for k = 1 : length(h)
        y = stepper(x, d, h(k));
        ykey = storedb.getNewKey();
        value(k) = getCost(problem, y, storedb, ykey);
        storedb.remove(ykey); % no need to keep it in memory
    end
    
    % Compute the linear approximation of the cost function using f0 and
    % df0 at the same points.
    model = polyval([df0 f0], h);
    
    % Compute the approximation error
    err = abs(model - value);
    
    % And plot it.
    loglog(h, err);
    title(sprintf(['Directional derivative check.\nThe slope of the '...
                   'continuous line should match that of the dashed\n'...
                   '(reference) line over at least a few orders of '...
                   'magnitude for h.']));
    xlabel('h');
    ylabel('Approximation error');
    
    line('xdata', [1e-8 1e0], 'ydata', [1e-8 1e8], ...
         'color', 'k', 'LineStyle', '--', ...
         'YLimInclude', 'off', 'XLimInclude', 'off');
    
     
    if ~all( err < 1e-12 )
        % In a numerically reasonable neighborhood, the error should
        % decrease as the square of the stepsize, i.e., in loglog scale,
        % the error should have a slope of 2.
        isModelExact = false;
        window_len = 10;
        [range, poly] = identify_linear_piece(log10(h), log10(err), window_len);
    else
        % The 1st order model is exact: all errors are (numerically) zero
        % Fit line from all points, use log scale only in h.
        isModelExact = true;
        range = 1:numel(h);
        poly = polyfit(log10(h), err, 1);
        % Set mean error in log scale for plot.
        poly(end) = log10(poly(end));
        % Change title to something more descriptive for this special case.
        title(sprintf(...
              ['Directional derivative check.\n'...
               'It seems the linear model is exact:\n'...
               'Model error is numerically zero for all h.']));
    end
    hold all;
    loglog(h(range), 10.^polyval(poly, log10(h(range))), 'LineWidth', 3);
    hold off;
    
    if ~isModelExact
        fprintf('The slope should be 2. It appears to be: %g.\n', poly(1));
        fprintf(['If it is far from 2, then directional derivatives ' ...
                 'might be erroneous.\n']);
    else
        fprintf(['The linear model appears to be exact ' ...
                 '(within numerical precision),\n'...
                 'hence the slope computation is irrelevant.\n']);
    end
    
    if ~(isreal(value) && isreal(f0))
        fprintf(['# The cost function appears to return complex values' ...
              '.\n# Please ensure real outputs.\n']);
    end
    
end
