% Experiment driver. For each dataset / target-shift setting it runs the
% cross-validated comparison in gather_results and saves the three result
% vectors to .mat files; two insurance settings also run get_timings.
%
% Variables passed to gather_results:
%   X            - feature matrix
%   y            - observed targets
%   z            - shifted targets, derived from y below
%   q_start      - upper end of the initial bisection interval [0, q_start]
%   bias_upper   - forwarded to bisect/evaluate_Fq
%   tol          - bisection stops once the interval width is <= tol
%   gamma_range  - grid of gamma values
%   nu_range     - ridge penalties passed to MATLAB's ridge()
%
% NOTE(review): the variable name `table` shadows MATLAB's built-in
% table() constructor for the rest of this script.

% ---------------- Red wine, "modest" shift ----------------
table = readtable('wine/winequality-red.csv');
% Columns 1-11 are used as features, column 12 as the target.
X = table{:, 1:11};
y = table{:, 12};
% z is y clamped to [0, 10] and then raised to at least 6, i.e. z lies in
% [6, 10]. (The "< 0" clamp is therefore redundant.)
z = y*1;
z(z > 10.0) = 10.0;
z(z < 0.0) = 0.0;
z(z <6) = 6;

gamma_range = linspace(1e-3, 0.2, 15);
nu_range = [1e-4, 1e-3, 1e-2, 1e-1, 1, 10, 100 ,1000];
tol = 1e-1;
q_start = 1000;
bias_upper = 20;

[bisect_results, bruckner_results, ridge_results] = gather_results(X, y, z, q_start, bias_upper, tol, gamma_range, nu_range);
save('bisect_wine_modest_nip_new.mat', 'bisect_results');
save('bruck_wine_modest_nip_new.mat', 'bruckner_results');
save('ridge_wine_modest_nip_new.mat', 'ridge_results');

% Timing benchmark for this setting (disabled):
% gamma = 0.5;
% [bruck_mr, bruck_std, bisect_mr, bisect_std] = get_timings(X, y, z, gamma, q_start, tol, [100, 200, 300, 400, 500, 600, 700, 800], bias_upper);
% save('r_bisect_mean_wine_modest.mat', 'bisect_mr');
% save('r_bisect_std_wine_modest.mat', 'bisect_std');
% save('r_bruck_mean_wine_modest.mat', 'bruck_mr');
% save('r_bruck_std_wine_modest.mat', 'bruck_std');

% ---------------- Red wine, "severe" shift ----------------
% Same construction as above but with a floor of 8, so z lies in [8, 10].
% X, y and the hyper-parameter grids are reused from the modest setting.
%z = y*2;
z = 1*y;
z(z > 10.0) = 10.0;
z(z < 0) = 0.0;
z(z < 8) = 8;

[bisect_results, bruckner_results, ridge_results] = gather_results(X, y, z, q_start, bias_upper, tol, gamma_range, nu_range);
save('bisect_wine_severe_nip_new.mat', 'bisect_results');
save('bruck_wine_severe_nip_new.mat', 'bruckner_results');
save('ridge_wine_severe_nip_new.mat', 'ridge_results');

% Timing benchmark for this setting (disabled):
% gamma = 0.5;
% [bruck_mr, bruck_std, bisect_mr, bisect_std] = get_timings(X, y, z, gamma, q_start, tol, [100, 200, 300, 400, 500, 600, 700, 800], bias_upper);
% save('bisect_mean_wine_severe.mat', 'bisect_mr');
% save('bisect_std_wine_severe.mat', 'bisect_std');
% save('bruck_mean_wine_severe.mat', 'bruck_mr');
% save('bruck_std_wine_severe.mat', 'bruck_std');

% ---------------- Insurance, "modest" shift ----------------
% createOneHotEncoding is not defined in this file. It is presumably a
% helper that expands the named categorical column into indicator columns
% (TODO confirm).
table = readtable('insurance.csv');
T = createOneHotEncoding(table, 'sex');
T = createOneHotEncoding(T, 'smoker');
T = createOneHotEncoding(T, 'region');
% Target: charges. z is charges minus 100, floored at 0. Both y and z are
% then divided by 100.
y = T.charges;
z = y - 100;
z(z < 0.0) = 0.0;
z = z / 100;
y = y / 100;
% Every remaining column (after dropping 'charges') is used as a feature.
T = removevars(T,{'charges'});
X = T{:,:};

gamma_range = linspace(0.01, 0.75, 20);
nu_range = linspace(1e-5, 10000, 100);
tol = 1e-8;
bias_upper = 20;
q_start = 13380000 / 500.0;


[bisect_results, bruckner_results, ridge_results] = gather_results(X, y, z, q_start, bias_upper, tol, gamma_range, nu_range);
save('bisect_in_modest_nip_new.mat', 'bisect_results');
save('bruck_in_modest_nip_new.mat', 'bruckner_results');
save('ridge_in_modest_nip_new.mat', 'ridge_results');

% Mean / std wall-clock time of bisect and bruckner_method over random
% training subsets of each listed size.
gamma = 0.5;
[bruck_mr, bruck_std, bisect_mr, bisect_std] = get_timings(X, y, z, gamma, q_start, tol, [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000, 1100, 1200], bias_upper);
save('bisect_mean_in_modest_nip.mat', 'bisect_mr');
save('bisect_std_in_modest_nip.mat', 'bisect_std');
save('bruck_mean_in_modest_nip.mat', 'bruck_mr');
save('bruck_std_in_modest_nip.mat', 'bruck_std');

% ---------------- Insurance, "severe" shift ----------------
% Identical preprocessing, except the shift subtracts 300 instead of 100.
% The grids, tol, bias_upper and q_start from the modest setting are reused.
table = readtable('insurance.csv');
T = createOneHotEncoding(table, 'sex');
T = createOneHotEncoding(T, 'smoker');
T = createOneHotEncoding(T, 'region');
y = T.charges;
z = y - 300;
z(z < 0.0) = 0.0;
z = z / 100;
y = y / 100;
T = removevars(T,{'charges'});
X = T{:,:};
% NOTE(review): unlike the other runs, these filenames have no "_new"
% suffix, so they can overwrite results from an older run. Confirm this is
% intended.
[bisect_results, bruckner_results, ridge_results] = gather_results(X, y, z, q_start, bias_upper, tol, gamma_range, nu_range);
save('bisect_in_severe_nip.mat', 'bisect_results');
save('bruck_in_severe_nip.mat', 'bruckner_results');
save('ridge_in_severe_nip.mat', 'ridge_results');

gamma = 0.5;
[bruck_mr, bruck_std, bisect_mr, bisect_std] = get_timings(X, y, z, gamma, q_start, tol, [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000, 1100, 1200], bias_upper);
save('bisect_mean_in_severe_nip.mat', 'bisect_mr');
save('bisect_std_in_severe_nip.mat', 'bisect_std');
save('bruck_mean_in_severe_nip.mat', 'bruck_mr');
save('bruck_std_in_severe_nip.mat', 'bruck_std');

% Synthetic heat-map experiments (disabled):
% map_values = heatmap(0.3, 5.0);
% map_ridge = heatmap_ridge(0.3, 5.0);
% save('rebut_map_values_bruck.mat', 'map_values');
% save('rebut_map_values_ridge.mat', 'map_ridge');

function A = build_A(X, y, z, q, gamma)
    % BUILD_A  Assemble the (n+3)x(n+3) block matrix over the variable order
    % [weights (n); bias (1); shift variable (1); homogenising entry (1)].
    %
    %   X     - m x n features        y, z - m x 1 observed / shifted targets
    %   q     - scalar offset subtracted on the lower-right blocks
    %   gamma - positive scalar; the shift-related blocks are scaled by 1/gamma
    inv_g = 1.0 / gamma;
    e = ones(size(X, 1), 1);
    d = z - y;
    % The (3,4) and (4,3) entries share one expression.
    cross_term = inv_g * y.'*(y - z) - inv_g * q;

    A = [ X.'*X,         X.'*e,         inv_g*X.'*d,                          -X.'*y; ...
          e.'*X,         e.'*e,         inv_g*e.'*d,                          -e.'*y; ...
          inv_g*d.'*X,   inv_g*d.'*e,   inv_g*inv_g*d.'*d - inv_g*inv_g*q,    cross_term; ...
          -y.'*X,        -y.'*e,        cross_term,                           y.'*y - q ];
end

function B = build_B(X, y, z, q, gamma)
    % BUILD_B  (n+3)x(n+3) matrix with the identity on the first n x n
    % block and -0.5 at the symmetric pair (n+2, n+3) / (n+3, n+2).
    % Only the column count of X is used; the remaining arguments are
    % accepted for a uniform builder signature.
    num_w = size(X, 2);
    B = zeros(num_w + 3);
    B(1:num_w, 1:num_w) = eye(num_w);
    B(num_w + 2, num_w + 3) = -0.5;
    B(num_w + 3, num_w + 2) = -0.5;
end

function C = build_C(X, y, z, q, gamma)
    % BUILD_C  (n+3)x(n+3) zero matrix except C(end, end) = -1, where
    % n = size(X, 2). Other arguments are unused.
    dim = size(X, 2) + 3;
    C = zeros(dim, dim);
    C(end, end) = -1.0;
end

function Aq = build_quad_A(X, y , z , q, gamma)
    % BUILD_QUAD_A  Quadratic-term matrix ((n+2)x(n+2)) of the objective
    % minimised in evaluate_Fq, over the variable order
    % [shift variable (1); weights (n); bias (1)].
    inv_g = 1.0 / gamma;
    e = ones(size(X, 1), 1);
    d = z - y;

    corner = inv_g * inv_g * d.'*d - inv_g * inv_g * q;
    Aq = [ corner,        inv_g*d.'*X,   inv_g*d.'*e; ...
           inv_g*X.'*d,   X.'*X,         X.'*e; ...
           inv_g*e.'*d,   e.'*X,         e.'*e ];
end

function aq = build_lin_A(X, y , z, q, gamma)
    % BUILD_LIN_A  Linear-term vector ((n+2)x1) of the objective in
    % evaluate_Fq, ordered [shift variable; weights (n); bias].
    scale = 1.0 / gamma;
    shift_entry = scale * y.'*(y - z) - scale * q;
    weight_entries = -X.'*y;
    bias_entry = -ones(size(X, 1), 1).'*y;
    aq = [shift_entry; weight_entries; bias_entry];
end

function ac = build_const_A(X, y, z, q, gamma)
    % BUILD_CONST_A  Constant term of the objective in evaluate_Fq:
    % the squared norm of y minus q. X, z and gamma are unused.
    sq_norm_y = y.'*y;
    ac = sq_norm_y - q;
end

function Bq = build_quad_B(X, y, z, q, gamma)
    % BUILD_QUAD_B  Quadratic part ((n+2)x(n+2)) of the equality constraint
    % in evaluate_Fq: identity, except the leading (shift variable) entry is 0.
    % NOTE(review): the bias entry is not treated specially; an earlier
    % comment said the bias needed "something special".
    dim = size(X, 2) + 2;
    Bq = diag([0, ones(1, dim - 1)]);
end

function bb = build_lin_B(X, y, z, q, gamma)
    % BUILD_LIN_B  Linear part ((n+2)x1) of the equality constraint in
    % evaluate_Fq: -1 on the leading (shift variable) entry, 0 elsewhere.
    bb = [-1; zeros(size(X, 2) + 1, 1)];
end

function ub = bias_bound_upper(X, y, z, q, gamma)
    % BIAS_BOUND_UPPER  (n+2)x1 selector with +1 on the last (bias) entry.
    ub = [zeros(size(X, 2) + 1, 1); 1.0];
end

function lb = bias_bound_lower(X, y, z, q, gamma)
    % BIAS_BOUND_LOWER  (n+2)x1 selector with -1 on the last (bias) entry.
    lb = [zeros(size(X, 2) + 1, 1); -1.0];
end

function Mb = bias_matrix_bound(X, y, z, q, gamma)
    % BIAS_MATRIX_BOUND  (n+2)x(n+2) zero matrix with 1 in the bottom-right
    % (bias, bias) position.
    dim = size(X, 2) + 2;
    Mb = zeros(dim);
    Mb(end, end) = 1.0;
end

function [val_opt, w_opt, W_opt] = evaluate_Fq(X, y, z, q, gamma, bias_upper)
    % EVALUATE_FQ  Solve the semidefinite relaxation for a fixed q with CVX.
    %
    %   Minimises trace(Aq*W) + 2*ab.'*w + ac subject to
    %   trace(Bq*W) + 2*bb.'*w == 0 and [W w; w.' 1] >= 0. Inside
    %   `cvx_begin ... sdp`, the matrix inequality is a semidefinite
    %   constraint.
    %
    %   Outputs:
    %     val_opt - cvx_optval, the optimal value (bisect tests its sign)
    %     w_opt   - (n+2)x1 vector w, ordered like build_quad_A:
    %               [shift variable; weights (n); bias]
    %     W_opt   - (n+2)x(n+2) matrix variable W
    %
    % NOTE(review): ub, lb, Mb and bias_upper are computed/accepted but never
    % used in the model, so no bias bound is actually enforced.
    Aq = build_quad_A(X, y, z , q, gamma);
    ab = build_lin_A(X, y, z, q, gamma);
    ac = build_const_A(X, y, z, q, gamma);

    Bq = build_quad_B(X, y, z, q, gamma);
    bb = build_lin_B(X, y, z, q, gamma);

    ub = bias_bound_upper(X, y, z, q, gamma);
    lb = bias_bound_lower(X, y, z, q, gamma);
    Mb = bias_matrix_bound(X, y, z, q, gamma);
    
    m = size(X, 1);
    n = size(X, 2);
    
    cvx_begin quiet sdp
        variable W(n+2, n+2) symmetric
        variable w(n+2)
        minimise ( trace(Aq*W) + 2*ab.'*w + ac )
        trace(Bq*W) + 2*bb.'*w == 0;
        [W w, ; w.' 1 ] >= 0; 
    cvx_end
    
    % The full vector (including the leading shift variable) is returned;
    % callers slice it themselves.
    %w_opt = w(2:n+2);
    w_opt = w;
    W_opt = W;
    
    val_opt = cvx_optval;
end

function [w_opt, W_opt, B] = bisect(X,y, z, gamma, bias_upper, q_start, tol)
    % BISECT  Bisection on q over [0, q_start] using the sign of evaluate_Fq.
    %
    %   A midpoint with optimal value >= 0 raises the lower end. Otherwise it
    %   lowers the upper end, and that midpoint's solution is kept. The loop
    %   stops once the interval is no wider than tol.
    %
    %   Outputs:
    %     w_opt, W_opt - relaxed solution at the last accepted q (q_start if
    %                    no midpoint was accepted)
    %     B            - homogenised constraint matrix [Bq bb; bb.' 0] for
    %                    that q, used by rank_decompose
    lo = 0.0;
    hi = q_start;
    % Seed the outputs with the solution at the initial upper end.
    [~, w_opt, W_opt] = evaluate_Fq(X, y, z, hi, gamma, bias_upper);

    while hi - lo > tol
        q_mid = 0.5 * (lo + hi);
        [F_mid, w_mid, W_mid] = evaluate_Fq(X, y, z, q_mid, gamma, bias_upper);
        if F_mid >= 0
            lo = q_mid;
            continue;
        end
        % Negative (or NaN) value: shrink from above and keep this solution.
        hi = q_mid;
        w_opt = w_mid;
        W_opt = W_mid;
    end

    % `hi` is exactly the q of the retained solution.
    Bq = build_quad_B(X, y, z, hi, gamma);
    bb = build_lin_B(X, y, z, hi, gamma);
    B = [Bq, bb; bb.', 0];
end

function w_opt = ridge_regression(X, y, nu)
    % RIDGE_REGRESSION  Ridge regression with an unpenalised bias, via CVX.
    %
    %   Minimises ||X*w + b - y||^2 + nu*||w||^2 over w (n x 1) and scalar b.
    %   Returns w_opt = [w; b], i.e. (n+1) x 1 with the bias last.
    m = size(X, 1);
    n = size(X, 2);
    
    cvx_begin quiet
        variable w(n)
        variable b
        minimise ( sum_square_abs(X*w + b - y) + nu*w.'*w ) 
    cvx_end
    
    w_opt = [w; b];
end

function loss = compute_loss(X, y, z, gamma, w)
    % COMPUTE_LOSS  Mean squared error of the model w = [wv (n); b] on the
    % modified design X_fake, defined as the solution of
    %   X_fake * (wv*wv.' + gamma*I) = z*wv.' - b*ones(m,1)*wv.' + gamma*X.
    %
    %   X - m x n features; y, z - m x 1; gamma - positive scalar.
    %   Returns ||X_fake*wv + b - y||^2 / m.
    [m, n] = size(X);
    wv = w(1:n);
    b = w(n+1);
    rhs = z*wv.' - b*ones(m, 1)*wv.' + gamma*X;
    % Right division solves the linear system directly instead of forming
    % an explicit inverse (more accurate and cheaper).
    X_fake = rhs / (wv*wv.' + gamma*eye(n));
    loss = (norm(X_fake*wv + b - y)^2) / m;
end

function loss = compute_loss_no_bias(X, y, z, gamma, w)
    % COMPUTE_LOSS_NO_BIAS  Mean squared error where the bias is folded into
    % w: X is augmented with a column of ones and w = [weights (n); bias].
    %
    %   X_fake solves X_fake * (w*w.' + gamma*I) = z*w.' + gamma*[X ones],
    %   and the result is ||X_fake*w - y||^2 / m.
    [m, n] = size(X);
    X_aug = [X, ones(m, 1)];
    rhs = z*w.' + gamma*X_aug;
    % Solve via right division rather than multiplying by inv(...).
    X_fake = rhs / (w*w.' + gamma*eye(n+1));
    loss = (norm(X_fake*w - y)^2) / m;
end

function obj = bruckner_objective(X, y, z, gamma)
    % BRUCKNER_OBJECTIVE  Handle f with [loss, grad] = f(v), where
    % v = [w (n); tau (m)] and
    %   loss = ||X*w + (w.'*w)*tau - y||^2.
    % The gradient (returned only when requested) is [d/dw; d/dtau].
    % z and gamma are unused.
    n = size(X, 2);
    m = size(X, 1);
    obj = @objective_value;

    function [loss, g] = objective_value(v)
        w = reshape(v(1:n), n, 1);
        tau = reshape(v(n+1:n+m), m, 1);
        sq = w.'*w;
        residual = X*w + sq.*tau - y;
        loss = norm(residual, 2)^2;
        if nargout > 1
            % d/dw: 2*X.'*r + 4*(tau.'*r)*w, written as two 2*(...) terms.
            grad_w = 2*X.'*residual + 2*residual.'*tau*w + 2*tau.'*residual*w;
            grad_tau = 2*sq*residual;
            g = [grad_w; grad_tau];
        end
    end
end

function nonlincon = bruckner_constraints(X, y, z, gamma)
    % BRUCKNER_CONSTRAINTS  fmincon-style nonlinear constraint handle.
    %   [c, ceq, gc, gceq] = nonlincon(v), v = [w (n); tau (m)]:
    %     c    = []  (no inequality constraints)
    %     ceq  = 2*(m/gamma)*(X*w + (w.'*w)*tau - z) + tau   (m x 1)
    %     gceq = (n+m) x m, column i is the gradient of ceq(i)
    %   y is unused.
    [m, n] = size(X);
    nonlincon = @constraint_values;

    function [c, ceq, gc, gceq] = constraint_values(v)
        w = reshape(v(1:n), n , 1);
        tau = reshape(v(n+1:n+m), m, 1);
        c = [];
        ceq = 2*(m/gamma)*(X*w + w.'*w.*tau - z) + tau;
        if nargout > 2
            gc = [];
            jac_w = 2*(m /gamma)*X + 4*(m/gamma)*tau*w.';
            % Each ceq(i) depends only on tau(i), so this block is diagonal.
            jac_tau = eye(m) + 2*(m/gamma)*w.'*w*eye(m);
            gceq = [jac_w, jac_tau].';
        end
    end
end

function w = bruckner_method(X, y, z, gamma)
    % BRUCKNER_METHOD  Solve the problem defined by bruckner_objective and
    % bruckner_constraints with fmincon (interior-point).
    %
    %   A column of ones is appended to X, so the optimisation variable is
    %   v = [weights (n); bias (1); tau (m)], starting from zeros.
    %   Returns w = [weights; bias], i.e. the first n+1 entries of the solution.
    %
    %   Gradients are not supplied, so fmincon estimates them. An analytic
    %   Hessian builder (get_hessian_finder) exists but is not passed in.
    [m, n] = size(X);
    X_aug = [X, ones(m, 1)];
    obj = bruckner_objective(X_aug, y, z, gamma);
    nonlincon = bruckner_constraints(X_aug, y, z, gamma);
    v0 = zeros(n+1+m, 1);
    options = optimoptions('fmincon', 'Algorithm', 'interior-point', ...
        'MaxFunctionEvaluations', 1e+10, 'Display', 'off', ...
        'SpecifyConstraintGradient', false, 'SpecifyObjectiveGradient', false, ...
        'OptimalityTolerance', 1e-2);
    x_opt = fmincon(obj, v0, [], [], [], [], [], [], nonlincon, options);
    w = x_opt(1:n+1);
end

function [bisect_results, bruckner_results, ridge_results] = gather_results(X, y, z, q_start, bias_upper, tol, gamma_range, nu_range)
    % GATHER_RESULTS  10-fold cross-validated test losses for ridge
    % regression, the bisection/SDP method and Bruckner's method.
    %
    %   Each fold trains on CVO.training(i) and is scored with
    %   compute_loss_no_bias on CVO.test(i). Outputs are column vectors
    %   holding column-major flattened arrays (see sub2ind / ind2sub):
    %     ridge_results    - [num_folds, num_gamma, num_nu]
    %     bisect_results   - [num_folds, num_gamma]
    %     bruckner_results - [num_folds, num_gamma]
    num_folds = 10;
    m = size(X, 1);
    num_gamma = size(gamma_range, 2);
    num_nu = size(nu_range, 2);
    CVO = cvpartition(m, 'Kfold', num_folds);

    % ---- Ridge regression baseline ----
    ridge_space = [num_folds, num_gamma, num_nu];
    ridge_results = zeros(prod(ridge_space), 1);

    for i = 1:num_folds
        [X_tr, y_tr, ~, X_te, y_te, z_te] = split_cv_fold(X, y, z, CVO, i);

        fprintf('Doing ridge regression\n');
        % With scaled = 0, ridge returns [intercept; b_1; ...; b_n] per nu.
        % compute_loss_no_bias expects [b_1; ...; b_n; intercept], so rotate
        % the intercept row to the end. Swapping rows 1 and n+1 instead would
        % misalign the coefficients with the columns of X.
        ridge_params = ridge(y_tr, X_tr, nu_range, 0);
        ridge_params = [ridge_params(2:end, :); ridge_params(1, :)];

        for j = 1:num_gamma
            gamma = gamma_range(j);
            for k = 1:num_nu
                lidx = sub2ind(ridge_space, i, j, k);
                ridge_results(lidx) = compute_loss_no_bias(X_te, y_te, z_te, gamma, ridge_params(:, k));
            end
        end
    end

    % ---- Bisection + rank-one decomposition ----
    simSpace = [num_folds, num_gamma];
    numSims = prod(simSpace);
    bisect_results = zeros(numSims, 1);
    bruckner_results = zeros(numSims, 1);

    for idx = 1:numSims
        fprintf('bisect progress is: %i\n', idx);
        [i, j] = ind2sub(simSpace, idx);
        [X_tr, y_tr, z_tr, X_te, y_te, z_te] = split_cv_fold(X, y, z, CVO, i);
        gamma = gamma_range(j);
        [w_s, W_s, B] = bisect(X_tr, y_tr, z_tr, gamma, bias_upper, q_start, tol);
        % Homogenised (n+3)x(n+3) matrix [W w; w.' 1] to decompose.
        W = [W_s, w_s; w_s.', 1];
        w_opt = rank_decompose(B, W);
        % Drop the leading shift variable and the trailing homogenising entry,
        % leaving [weights; bias].
        w_opt = w_opt(2:end-1);
        bisect_results(idx) = compute_loss_no_bias(X_te, y_te, z_te, gamma, w_opt);
    end

    % ---- Bruckner's method (parallel over folds x gamma) ----
    parfor_progress(numSims);
    parfor idx = 1:numSims
        [i, j] = ind2sub(simSpace, idx);
        [X_tr, y_tr, z_tr, X_te, y_te, z_te] = split_cv_fold(X, y, z, CVO, i);
        gamma = gamma_range(j);
        w_b = bruckner_method(X_tr, y_tr, z_tr, gamma);
        bruckner_results(idx) = compute_loss_no_bias(X_te, y_te, z_te, gamma, w_b);
        parfor_progress;
    end
    parfor_progress(0);
end

function [X_tr, y_tr, z_tr, X_te, y_te, z_te] = split_cv_fold(X, y, z, CVO, i)
    % SPLIT_CV_FOLD  Training and test rows of X, y and z for fold i of the
    % cvpartition object CVO.
    trIdx = CVO.training(i);
    teIdx = CVO.test(i);
    X_tr = X(trIdx, :);
    y_tr = y(trIdx);
    z_tr = z(trIdx);
    X_te = X(teIdx, :);
    y_te = y(teIdx);
    z_te = z(teIdx);
end

function map_values = heatmap(gamma, q_start)
    % HEATMAP  Test-loss difference (Bruckner minus bisection) on a synthetic
    % 2-feature problem, over a 32x32 grid of shifted-target weight vectors.
    %
    %   y = X*[0.5; 0.5] + 1; for grid point (i, j), z = X*[w1(i); w2(j)] + 1.
    %   Returns a 1024x1 column in column-major grid order.
    %   NOTE(review): this name shadows MATLAB's built-in heatmap chart
    %   function within this file.
    X_tr = randn(20, 2);
    y_tr = X_tr*[0.5; 0.5] + 1.0;
    X_te = randn(50, 2);
    y_te = X_te*[0.5; 0.5] + 1.0;

    w1_list = linspace(-5, 5, 32);
    w2_list = linspace(-5, 5, 32);

    simSpace = [32, 32];
    numSims = prod(simSpace);
    bruckner_values = zeros(numSims, 1);
    bisect_values = zeros(numSims, 1);

    for idx = 1:numSims
        [i, j] = ind2sub(simSpace, idx);
        z_tr = X_tr*[w1_list(i); w2_list(j)] + 1.0;
        z_te = X_te*[w1_list(i); w2_list(j)] + 1.0;
        % bruckner_method already returns [weights; bias].
        w_b = bruckner_method(X_tr, y_tr, z_tr, gamma);
        bruckner_values(idx) = compute_loss(X_te, y_te, z_te, gamma, w_b);
        fprintf('In heatmap bruck %d \n', idx);
    end

    for idx = 1:numSims
        [i, j] = ind2sub(simSpace, idx);
        z_tr = X_tr*[w1_list(i); w2_list(j)] + 1.0;
        z_te = X_te*[w1_list(i); w2_list(j)] + 1.0;
        w_s = bisect(X_tr, y_tr, z_tr, gamma, 10.0, q_start, 0.001);
        % bisect returns [shift variable; weights; bias]; compute_loss expects
        % [weights; bias], so drop the leading entry.
        bisect_values(idx) = compute_loss(X_te, y_te, z_te, gamma, w_s(2:end));
        fprintf('In heatmap bisect %d \n', idx);
    end

    map_values = bruckner_values - bisect_values;
end

function map_values = heatmap_ridge(gamma, q_start)
    % HEATMAP_RIDGE  Test-loss difference (ridge minus bisection) on the same
    % kind of synthetic 32x32 grid as heatmap, for 8 ridge penalties.
    %
    %   Returns a 1024 x 8 matrix: column k is ridge(nu_k) loss minus the
    %   bisection loss at each grid point (after the clamp below).
    X_tr = randn(20, 2);
    y_tr = X_tr*[0.5; 0.5] + 1.0;
    X_te = randn(50, 2);
    y_te = X_te*[0.5; 0.5] + 1.0;

    w1_list = linspace(-5, 5, 32);
    w2_list = linspace(-5, 5, 32);

    simSpace = [32, 32];
    numSims = prod(simSpace);
    bisect_values = zeros(numSims, 1);

    for idx = 1:numSims
        [i, j] = ind2sub(simSpace, idx);
        z_tr = X_tr*[w1_list(i); w2_list(j)] + 1.0;
        z_te = X_te*[w1_list(i); w2_list(j)] + 1.0;
        w_s = bisect(X_tr, y_tr, z_tr, gamma, 10.0, q_start, 0.001);
        % bisect returns [shift variable; weights; bias]; compute_loss expects
        % [weights; bias], so drop the leading entry.
        bisect_values(idx) = compute_loss(X_te, y_te, z_te, gamma, w_s(2:end));
        fprintf('In heatmap bisect %d \n', idx);
    end

    nu_range = [1e-4, 1e-3, 1e-2, 1e-1, 1, 10, 100 ,1000];
    num_nu = numel(nu_range);

    % The ridge fit depends only on (X_tr, y_tr, nu), not on the grid point,
    % so fit once per nu instead of once per (grid point, nu).
    ridge_ws = zeros(size(X_tr, 2) + 1, num_nu);
    for k = 1:num_nu
        ridge_ws(:, k) = ridge_regression(X_tr, y_tr, nu_range(k));
    end

    ridge_values = zeros(numSims, num_nu);
    for idx = 1:numSims
        [i, j] = ind2sub(simSpace, idx);
        z_te = X_te*[w1_list(i); w2_list(j)] + 1.0;
        for k = 1:num_nu
            ridge_values(idx, k) = compute_loss(X_te, y_te, z_te, gamma, ridge_ws(:, k));
            fprintf('In heatmap ridge %d \n', idx);
        end
    end

    % NOTE(review): min(A, 2) clamps every loss at 2 element-wise, so the
    % result stays numSims x num_nu. If the intent was the best loss over nu
    % at each grid point, that would be min(ridge_values, [], 2).
    ridge_values = min(ridge_values, 2);

    map_values = ridge_values - bisect_values;
end

function hessian_obj = get_obj_hessian(X, y, z, gamma)
    % GET_OBJ_HESSIAN  Handle H = hessian_obj(v) giving the Hessian of
    % ||X*w + (w.'*w)*tau - y||^2 with respect to v = [w (n); tau (m)].
    % z and gamma are unused.
    [m, n] = size(X);
    hessian_obj = @objective_hessian;

    function H = objective_hessian(v)
        w = v(1:n);
        tau = v(n+1:m+n);
        s = w.'*w;
        r = X*w + w.'*w*tau - y;   % residual

        H_ww = 2*X.'*X + 4*X.'*tau*w.' + 4*w*tau.'*X + 8*tau.'*tau*w*w.' + 4*r.'*tau*eye(n);
        H_tt = 2*s*s*eye(m);
        H_wt = 2*X'*s + 4*s*w*tau.' + 4*w*(X*w + s*tau - y).';

        H = [H_ww, H_wt; H_wt.', H_tt];
    end
end

function hessian_ci = get_hessian_cons(X, y, z, gamma)
    % GET_HESSIAN_CONS  Handle H = hessian_ci(v, i) giving the Hessian of the
    % i-th equality constraint used by bruckner_constraints,
    %   ceq_i = 2*(m/gamma)*(X(i,:)*w + (w.'*w)*tau(i) - z(i)) + tau(i),
    % with respect to v = [w (n); tau (m)]. y and z do not affect it.
    [m, n] = size(X);
    hessian_ci = @constraint_hessian;

    function H = constraint_hessian(v, i)
        w = v(1:n);
        tau = v(n+1:m+n);
        % d^2/dw^2 of 2*(m/gamma)*(w.'*w)*tau(i)
        dwdw = 4*(m / gamma)*tau(i)*eye(n);
        % d/dtau(i) of the w-gradient 2*(m/gamma)*(X(i,:).' + 2*tau(i)*w),
        % i.e. 4*(m/gamma)*w. This matches the tau*w.' term of gceq_w in
        % bruckner_constraints; the factor was previously 2.
        dwdt = zeros(n, m);
        dwdt(:, i) = 4*(m / gamma)*w;
        dtdt = zeros(m, m);

        H = [dwdw, dwdt; dwdt.', dtdt];
    end
end

function hessian_finder = get_hessian_finder(X, y, z, gamma)
    % GET_HESSIAN_FINDER  fmincon 'HessianFcn'-style handle:
    %   H = hessian_finder(v, lambda)
    % returns the objective Hessian plus sum_i lambda.eqnonlin(i) times the
    % Hessian of equality constraint i (one constraint per row of X).
    obj_hess = get_obj_hessian(X, y, z, gamma);
    con_hess = get_hessian_cons(X, y, z, gamma);
    num_cons = size(X, 1);
    hessian_finder = @lagrangian_hessian;

    function H = lagrangian_hessian(v, lambda)
        multipliers = lambda.eqnonlin;
        H = obj_hess(v);
        for c = 1:num_cons
            H = H + multipliers(c)*con_hess(v, c);
        end
    end
end

function [bruck_mr, bruck_sd, bisect_mr, bisect_sd] = get_timings(X, y, z, gamma, q_start, tol, size_range, bias_upper)
    % GET_TIMINGS  Wall-clock timing of bisect and bruckner_method.
    %
    %   For each training-set size in size_range, draws 10 random subsets of
    %   that many rows (via randperm) and times one call of each method with
    %   tic/toc. Returns the per-size mean and std of those 10 timings as
    %   column vectors.
    num_repeats = 10;
    num_sizes = size(size_range, 2);
    [bruck_mr, bruck_sd, bisect_mr, bisect_sd] = deal(zeros(num_sizes, 1));

    for s = 1:num_sizes
        fprintf('Starting a dataset size\n');

        sample_size = size_range(s);
        t_bruck = zeros(1, num_repeats);
        t_bisect = zeros(1, num_repeats);

        for r = 1:num_repeats
            order = randperm(size(X, 1));
            rows = order(1:sample_size);
            X_tr = X(rows, :);
            y_tr = y(rows);
            z_tr = z(rows);

            tic;
            bisect(X_tr,y_tr, z_tr, gamma, bias_upper, q_start, tol);
            t_bisect(r) = toc;

            tic;
            bruckner_method(X_tr, y_tr, z_tr, gamma);
            t_bruck(r) = toc;
        end

        bruck_mr(s) = mean(t_bruck);
        bruck_sd(s) = std(t_bruck);
        bisect_mr(s) = mean(t_bisect);
        bisect_sd(s) = std(t_bisect);
    end
end

function w_opt = rank_decompose(M, W)
    % RANK_DECOMPOSE  Rank-one decomposition of a PSD matrix W with respect
    % to the quadratic form defined by M.
    %
    %   W is factorised as W = sum_i u_i*u_i.' from its eigen-decomposition.
    %   While some u_i has u_i.'*M*u_i > 0 and another u_j has
    %   u_j.'*M*u_j < 0, the pair is rotated within its span so that the
    %   first form becomes zero. Each rotation keeps u_i*u_i.' + u_j*u_j.'
    %   (and hence W) unchanged.
    %
    %   Inputs:  M - n x n matrix; W - n x n symmetric PSD matrix.
    %   Output:  w_opt - the last column of the resulting factor (n x 1).
    %   NOTE(review): the sign of w_opt is inherited from eig and is not
    %   normalised.
    n = size(W, 1);
    [V, D] = eig(W);
    % Clamp tiny negative eigenvalues (solver noise) so the factor is real.
    u = V .* sqrt(max(diag(D), 0)).';
    % q_forms(i) = u(:, i).'*M*u(:, i)
    q_forms = sum(u .* (M*u), 1).';

    % Forms within qtol of zero count as zero, so rounding in rotated forms
    % cannot keep the loop alive. Each rotation zeroes at least one nonzero
    % form, so at most n rotations are needed.
    qtol = 1e-10 * max(1, max(abs(q_forms)));
    for iter = 1:n
        pos_idx = find(q_forms > qtol, 1);
        neg_idx = find(q_forms < -qtol, 1);
        if isempty(pos_idx) || isempty(neg_idx)
            break;
        end

        up = u(:, pos_idx);
        un = u(:, neg_idx);
        a = q_forms(pos_idx);
        b = 2 * (up.' * M * un);
        c = q_forms(neg_idx);
        % Root of a*t^2 + b*t + c = 0 makes (t*up + un).'*M*(t*up + un) = 0.
        % Since a > 0 > c, the discriminant is positive.
        t = (-b + sqrt(b*b - 4*a*c)) / (2*a);
        s = sqrt(t*t + 1);
        u(:, pos_idx) = (t*up + un) / s;
        u(:, neg_idx) = (t*un - up) / s;
        q_forms(pos_idx) = u(:, pos_idx).' * M * u(:, pos_idx);
        q_forms(neg_idx) = u(:, neg_idx).' * M * u(:, neg_idx);
    end

    w_opt = u(:, end);
end















