%% Q-Whittle Learning with linear function approximation
% Learns one Whittle-index estimate per target state of the 4-state,
% 2-action chain simulated by the local function MODEL.  The Q-values are
% approximated by a sum over m sign-weighted units whose on/off activation
% pattern is frozen at the initial weights theta_0 (see the computation of
% f below).  For every target state s the experiment is repeated MC times
% and the lambda trajectories are averaged into W(s,:).
clear

%% learning parameters
epsilon = 0.5;   % exploration probability (1-epsilon = exploit / epsilon = explore)
verbose = false; % true prints per-iteration diagnostics (extremely slow for large K*MC)

% states
state = [1,2,3,4];
% actions
action = [0,1];

nS  = length(state);
nA  = length(action);
nSA = nS*nA;     % number of (state, action) pairs; row (i-1)*nA+j of Phi is pair (i,j)

% feature matrix: eigenvectors of X, one row per (state, action) pair.
% NOTE(review): symdec is not defined in this file; presumably it returns a
% symmetric matrix so that Phi is orthogonal -- confirm.
X = symdec(nSA,1);
[Phi, ~] = eig(X);

m = 300;         % number of hidden units (width) of the approximator
% Output-layer signs: a random half of the units get -1, the rest stay +1.
% (Indexing must use the permutation itself, not b(permutation), otherwise
% only b(1) is ever flipped.)
b = ones(m,1);
b(randperm(m, floor(m/2))) = -1;

theta_0 = 0.1*randn(m*nSA,1);
theta   = theta_0;
K  = 10000;      % number of iterations per Monte-Carlo run
MC = 100;        % number of Monte-Carlo runs per target state

W = zeros(nS,K); % W(s,k): MC-averaged lambda for target state s after k iterations

% The activation pattern depends only on the initial weights, so it is
% computed once: gate0(i,r) is true when unit r is active for pair i.
Wr_0  = reshape(theta_0, nSA, m);
gate0 = (Phi*Wr_0 > 0);

%% the main loop of the algorithm
for s = 1:nS
    Wmc = zeros(MC, K);   % lambda trajectory of every Monte-Carlo run

    for mc = 1:MC
        % NOTE(review): theta is NOT reset here, so each run (and each target
        % state) warm-starts from the previous one.  If independent runs are
        % intended, add `theta = theta_0;` at this point.
        lambda    = 0;    % initialize the Whittle index estimate
        state_idx = s;    % every run starts in the target state

        for k = 1:K
            if verbose
                disp(['iteration: ' num2str(k)]);
            end

            % Q-value estimates for ALL (state, action) pairs under the
            % current theta, so that the next-state max, the average
            % baseline and the lambda update never read stale entries.
            Wr = reshape(theta, nSA, m);
            f  = (1/sqrt(m)) * (((Phi*Wr) .* gate0) * b);

            base = (state_idx-1)*nA;   % offset of the current state's rows in f

            % epsilon-greedy action selection
            if rand < 1-epsilon        % exploit
                [~, action_idx] = max(f(base + (1:nA)));
            else                       % explore: uniform over the actions
                action_idx = randi(nA);
            end

            % observe the next state and the reward from the simulator
            [next_state,next_reward] = model(state(state_idx),action(action_idx), lambda);
            next_state_idx = find(state==next_state);  % id of the next state

            if verbose
                disp(['current state : ' num2str(state(state_idx))  ' taken action : ' num2str(action(action_idx)) ' next state : ' num2str(state(next_state_idx))]);
                disp([' next reward : ' num2str(next_reward)]);
            end

            % Gradient of f(sa) with respect to theta.  Column r of the
            % nSA-by-m matrix is (1/sqrt(m))*b(r)*gate0(sa,r)*Phi(sa,:)',
            % and (:) stacks the columns in the same order as reshape above.
            sa   = base + action_idx;
            coef = (1/sqrt(m)) * (b' .* gate0(sa,:));
            grad = Phi(sa,:)' * coef;
            grad = grad(:);

            % relative (average-baseline) Q-learning update
            next_base = (next_state_idx-1)*nA;
            td_error  = next_reward + max(f(next_base + (1:nA))) - sum(f)/nSA - f(sa);
            theta     = theta + 1/(k^(1/2)) * grad * td_error;

            % move lambda along Q(s, second action) - Q(s, first action)
            lambda = lambda + 0.1/(k^(1/2))*(f((s-1)*nA+2)-f((s-1)*nA+1));

            state_idx = next_state_idx;

            if verbose
                disp(lambda)
            end
            Wmc(mc,k) = lambda;
        end
    end

    % average over the Monte-Carlo runs
    W(s,:) = mean(Wmc, 1);
end

%% plot every 10th iterate of the averaged index estimates
figure;
kk = 10:10:K;
plot(kk, W(:,kk).');
legend(arrayfun(@(v) sprintf('s%d', v), state, 'UniformOutput', false));


%%

%% This function is used as an observer to give the next state and the next reward using the current state and action
function [next_state,r] = model(s,a, lambda)
% MODEL  One-step simulator: next state and reward for a (state, action) pair.
%   [next_state, r] = model(s, a, lambda)
%   s      - current state, one of 1, 2, 3, 4
%   a      - action, 0 or 1
%   lambda - subsidy added to the reward when a == 0
%
%   With probability 0.5 (uniform draw t > 0.5) the state moves by the
%   offset listed below for (s, a); otherwise it stays at s.  The reward is
%   -1, 0, 0, 1 for s = 1..4, plus (1-a)*lambda.
%
%   Raises model:invalidInput for any other (s, a) instead of returning
%   with unassigned outputs.
t = rand;   % single uniform draw deciding whether the transition happens

if ~(isscalar(s) && isscalar(a) && any(s == 1:4) && any(a == [0 1]))
    error('model:invalidInput', ...
          'model expects s in {1,2,3,4} and a in {0,1}.');
end

% state offset applied when t > 0.5; rows: s = 1..4, columns: a = 0, 1
jump = [ 3,  1; ...
        -1,  1; ...
        -1,  1; ...
        -1, -3];
% reward of each state before the subsidy
base_reward = [-1, 0, 0, 1];

next_state = s + jump(s, a+1)*(t>0.5);
r = base_reward(s) + (1-a)*lambda;
end