function [w1,iter,fs,T,method_label,err,perf,I1,ssn_iter,ssn_time,cg_iter] = pg_main(A, b, s, w0, loss_function, eps, acce, identify, opts)
%PG_MAIN  Gradient-type iteration w <- newpoint(s, w, lambda, g) with an
%   optional extrapolation line search (ACCE) and optional restricted Newton
%   phases on a stable index set (IDENTIFY).
%
%   Inputs
%     A, b          data passed to the loss; A is m-by-n and iterates are
%                   n-by-1 (see zeros(n,1) below).
%     s             parameter forwarded to PS2 / newpoint / residual
%                   (presumably the sparsity level -- TODO confirm).
%     w0            initial point; it is first mapped through PS2(w0, s).
%     loss_function 'logistic' or 'leastsquare' (case-insensitive). Any
%                   other value falls back to least squares with a message.
%                   For least squares opts.eps_l2 is forced to 0.
%     eps           stopping tolerance on the value returned by residual().
%     acce          if true, try an extrapolation line search along w1 - w0
%                   whenever the index set did not change. This also halves
%                   opts.stable_min.
%     identify      if true, run restricted_newton on the columns A(:,I1)
%                   once the index set has stayed unchanged for
%                   opts.stable_min consecutive iterations.
%     opts          options struct, completed by pg_main_opts. Fields read
%                   here: base_CG, eps_l2, verbose, angle_eps, L_coeff,
%                   iter_max, time_max, stable_min, newton_iters,
%                   preconditioned, init_linesearch, sigma, alpha_min,
%                   alpha_max. Optional fields: normA / timeA (skip the
%                   eigenvalue estimate) and Atest / btest (enable
%                   prediction performance).
%
%   Outputs
%     w1            final iterate.
%     iter          iteration counter at exit (starts at 2 after the first
%                   step).
%     fs            objective value per recorded iterate.
%     T             elapsed time per recorded iterate. It excludes the
%                   evaluation time and includes the time spent estimating
%                   the norm.
%     method_label  label from methodname('PG', acce, identify).
%     err           residual() value per recorded iterate.
%     perf          pred_perf() results on (Atest, btest) at the initial
%                   point, the first iterate and the final iterate; [] when
%                   no test data is given.
%     I1            index set returned by the last newpoint call.
%     ssn_iter      number of restricted Newton phases entered.
%     ssn_time      total time spent in restricted Newton phases.
%     cg_iter       total CG iterations reported by restricted_newton.

opts = pg_main_opts(opts);
% With extrapolation enabled, the stability threshold for entering a
% restricted Newton phase is halved.
if acce; opts.stable_min = opts.stable_min/2; end

%Parameters
% Prediction performance is tracked only when test data is supplied.
predperf = (isfield(opts,'Atest')) && (isfield(opts,'btest'));
if predperf; Atest = opts.Atest; btest = opts.btest; end

% Select the loss, its gradient, and the factor applied to ||A||_2^2 when
% forming L below.
switch lower(loss_function)
    case 'logistic'
        loss = @(z,b) logistic_loss(z,b);
        grad = @(z,b) logistic_grad(z,b);
        L_scaling = 0.25;
    case 'leastsquare'
        loss = @(z,b) least_square_loss(z,b);
        grad = @(z,b) least_square_grad(z,b);
        L_scaling = 1;
        opts.eps_l2 = 0;
    otherwise
        fprintf('Using LEAST SQUARE LOSS as default\n');
        loss = @(z,b) least_square_loss(z,b);
        grad = @(z,b) least_square_grad(z,b);
        L_scaling = 1;
        opts.eps_l2 = 0;
end

base_CG = opts.base_CG;
current_CG = base_CG;   % CG budget handed to restricted_newton; may grow
eps_l2 = opts.eps_l2;
verbose = opts.verbose;
angle_eps = opts.angle_eps;

%Initialization
[m,n] = size(A);

unchanged = 0;   % consecutive iterations with an unchanged index set
iter = 2;        % the initial point and the first step are already counted
stable = 0;
perf = [];
T = 0;
ssn_iter = 0;
ssn_time = 0;
cg_iter = 0;

[w0,I0] = PS2(w0,s);
[f, Aw0, g] = loss_fun_and_grad(loss, grad, w0, A, b, eps_l2);
fs = f;

if acce; oldg = g; end

% L = L_scaling * lambda_max(A'*A) + eps_l2, where lambda_max(A'*A) equals
% ||A||_2^2. The eigenvalue is computed on the smaller of A'*A and A*A'
% unless opts.normA supplies it.
eval_timer_norm = tic;
if ~isfield(opts,'normA')
    if (m > n)
        S = @(x) A' * (A*x);
        len = n;
    else
        S = @(x) A * (A'*x);
        len = m;
    end
    % NOTE(review): the eigs options-structure field is documented as 'tol'.
    % Confirm 'Tolerance' is honoured by the MATLAB/Octave version in use.
    opt.Tolerance = 1e-3;
    L = L_scaling * eigs(S, len, 1, 'lr', opt) + eps_l2; %norm(S)
    time_norm = toc(eval_timer_norm);
else
    % opts.normA takes the place of lambda_max(A'*A) in the formula above.
    L = L_scaling * opts.normA + eps_l2;
    % Time spent computing normA elsewhere, if it was reported.
    if isfield(opts,'timeA')
        time_norm = opts.timeA;
    else
        time_norm = 0;
    end
end
% Step parameter passed to newpoint and residual.
lambda = opts.L_coeff/L;

err = residual(s, w0, g, lambda);
if predperf; perf = pred_perf(Atest, btest, w0, loss_function); end

if (verbose > 0); fprintf('Initial f %15.20e  error_measure %g\n',f,err);  end

%First iteration
start = tic;

[w1,I1] = newpoint(s, w0, lambda, g);
[f, Aw1, g] = loss_fun_and_grad(loss, grad, w1, A, b, eps_l2);

% Objective decrease of the first step. The main loop overwrites this, but
% it must exist for the final report even if the loop body never runs.
relative_err = fs(end) - f;
oldf = f;
fs = [fs; f];

%%%%%%%Performance evaluation%%%%%%%%%%
% Time spent in residual/pred_perf is subtracted from T.
eval_timer = tic;
err = [err; residual(s, w1, g, lambda)];
if predperf; perf = [perf; pred_perf(Atest, btest, w1, loss_function)]; end
eval_time_acc = toc(eval_timer);
T = [T; toc(start) - eval_time_acc + time_norm];
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

%Main Loop
while (err(end)>eps) && (iter < opts.iter_max) && (T(end)<opts.time_max)
    iter = iter + 1;
    % Count consecutive iterations in which the index set did not change.
    if identify || acce
        stable = isequal(I0,I1);
        unchanged = stable * (unchanged + 1);
        I0 = I1;
    end

    if identify && (unchanged>=opts.stable_min)
        % Restricted Newton phase on the columns indexed by I1. Restricted
        % to the subset, so the weight between f and constraint violation
        % doesn't matter.
        ssn_start = tic;
        ssn_iter = ssn_iter + 1;
        w0 = w1;
        Aw0 = Aw1;

        wnew = w1(I1);
        U = A(:,I1);

        [~, wnew, Awnew, increase_flag, cg] = restricted_newton(wnew, U, b, loss, grad, eps_l2, opts.newton_iters, verbose, current_CG, opts.preconditioned);
        cg_iter = cg_iter + cg;
        % Grow the CG budget by 2^increase_flag when requested, otherwise
        % reset it.
        if (increase_flag)
            current_CG = current_CG * 2^increase_flag;
        else
            current_CG = base_CG;
        end

        % Embed the restricted solution back into an n-vector.
        w1_temp = zeros(n,1);
        w1_temp(I1) = wnew;

        [f_temp, ~, g_temp] = loss_fun_and_grad(loss, grad, w1_temp, A, b, eps_l2, Awnew);

        if (oldf - f_temp < 0) %to be replaced by a criterion indicating that SSN fails
            fprintf('\t iter %d: SSN failed to obtain a point with lower objective\n',iter);
            % Reject the Newton point: w1 and g are kept, f is restored, the
            % CG budget is reset and the stability count restarts.
            current_CG = base_CG;
            f = oldf;
            unchanged = -1;
        else
            fprintf('\t iter %d: SSN obtained a point with a lower objective value\n',iter);
            %Conduct pg step as a safeguard to avoid getting stuck in
            %wrong subspace
            [w1 ,I1] = newpoint(s, w1_temp, lambda, g_temp);
            [f, Aw1, g] = loss_fun_and_grad(loss, grad, w1, A, b, eps_l2);
        end

        ssn_time = ssn_time + toc(ssn_start);

    else %Usual (accelerated) iterations
        current_CG = base_CG;
        if acce
            newg = g;
            if stable
                % Candidate extrapolation direction p = w1 - w0 and
                % gradient difference r.
                r = (newg - oldg);
                p = w1 - w0;
                descent = g'*p;
                % Normalized descent measure -g'*p / (||p|| * ||g(I1)||).
                % Extrapolate only when it reaches angle_eps.
                descent1 = -descent / norm(p) / norm(g(I1));

                if descent1 >= angle_eps
                    norm_p = p'*p;
                    Ap = Aw1 - Aw0;

                    % Choose the initial trial step for the line search.
                    switch opts.init_linesearch
                        case 'BB'
                            init_step = norm_p / (p'*r);
                        case 'L'
                            init_step = initstep_cal(norm_p, descent, Ap, grad, Aw1, b, L);
                        case 'L1'
                            L1 = p'*r / norm_p;
                            init_step = initstep_cal(norm_p, descent, Ap, grad, Aw1, b, L1);
                        case 'Hessian'
                            init_step = initstep_cal(norm_p, descent, Ap, grad, Aw1, b);
                        case 'All'
                            init_step_BB = norm_p / (p'*r);
                            L1 = p'*r / norm_p;
                            init_step_L1 = initstep_cal(norm_p, descent, Ap, grad, Aw1, b, L1);
                            init_step_Hes = initstep_cal(norm_p, descent, Ap, grad, Aw1, b);
                            init_step = max([init_step_L1, init_step_Hes, init_step_BB]);
                        otherwise
                            init_step = 10;
                    end
                    init_step = min( max(init_step,opts.alpha_min) , opts.alpha_max);
                end
            end
            oldg = newg;
        else
            oldg = g;
        end

        w0 = w1;
        Aw0 = Aw1;

        if acce && stable
            if descent1 >= angle_eps
                [w1, g, stepsize] = linesearch(loss, grad, A, b, opts.sigma, w1, p, norm_p, g,  f, Aw1, Ap, init_step, eps_l2);
            else
                stepsize = 0;
            end

            if stepsize ~=1
                fprintf('\tIter %d: Extrapolation with stepsize %g , unchanged = %d\n', iter, stepsize, unchanged);
            end
        end

        [w1, I1] = newpoint(s, w1, lambda, g);
        [f, Aw1, g] = loss_fun_and_grad(loss, grad, w1, A, b, eps_l2);
    end

    fs = [fs;f];
    relative_err = (oldf-f);   % objective decrease in this iteration
    oldf = f;

    %%%%%%%Performance evaluation%%%%%%%%%%
    eval_timer = tic;
    err = [err;residual(s,w1,g,lambda)];
    % if predperf; perf = [perf; pred_perf(Atest, btest, w1, loss_function)]; end
    eval_time_acc = eval_time_acc + toc(eval_timer);
    T = [T; toc(start) - eval_time_acc + time_norm];
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

    if (verbose > 0) && identify && (unchanged>=opts.stable_min)
        fprintf('\tIter %d (Newton iteration): Newton steps %d\n',iter, opts.newton_iters);
        fprintf('\t        time %g f %15.20e  stable %d unchanged %d error_measure %g rel_error %g \n', T(end), f, stable, unchanged, err(end), relative_err);
    end

    if (verbose > 0 && (mod(iter,verbose) == 0 || iter == opts.iter_max))
        fprintf('iter %d time %g f %15.20e  stable %d unchanged %d error_measure %g rel_error %g \n',iter, T(end), f, stable, unchanged, err(end), relative_err);
    end

    % Stop once the objective increases.
    if (relative_err < 0)
        fprintf('iter %d: relative_err = %g; Cannot improve obj anymore \n',iter,relative_err);
        break;
    end

end
% Prediction performance at the final iterate.
if predperf; perf = [perf; pred_perf(Atest, btest, w1, loss_function)]; end
%Results
method_label = methodname('PG',acce,identify);

if err(end)<eps
    fprintf('Solution obtained using %s \n',method_label);
else
    fprintf('Rerun %s \n',method_label);
end
fprintf('\t iter %d time %g f %15.20e  stable %d unchanged %d error_measure %g rel_error %g \n',iter, T(end), f, stable, unchanged, err(end), relative_err);
end
