from mpi4py import MPI
import numpy as np
import math
import random
from numpy import linalg as LA
import scipy.io as sio
import sys
import sklearn.datasets
import time





def cost(x, y, th, alpha, num_data):
    """Return the regularized logistic loss of weights ``th``, divided by ``num_data``.

        cost = ( sum_i [ -y_i log s_i - (1 - y_i) log(1 - s_i) ]
                 + (alpha / 2) * ||th[:-1]||^2 ) / num_data,   s = sigmoid(x @ th)

    The last weight is left out of the penalty; it pairs with the ones-column
    that ``all_one_stack`` appends to the features.

    NOTE(review): ``local_grad`` / ``local_update`` regularize *every* weight
    (including that last one), so this reported objective differs slightly
    from the one the optimizer actually minimizes. Confirm which is intended.

    Args:
        x: feature matrix, one row per sample.
        y: labels (assumed 0/1); a flat vector or an (n, 1) column.
        th: weight vector, one entry per column of ``x``.
        alpha: L2 regularization strength.
        num_data: normalizer applied to the total.

    Returns:
        The scalar loss.
    """
    th_flat = np.ravel(th)
    z = np.dot(x, th_flat)
    # Flatten so a column-shaped y cannot broadcast against z into an (n, n) matrix.
    y = np.asarray(y, dtype=float).ravel()
    # -log(sigmoid(z)) == logaddexp(0, -z) and -log(1 - sigmoid(z)) == logaddexp(0, z).
    # Unlike log(pro) / log(1 - pro), this stays finite when the sigmoid
    # saturates to exactly 0.0 or 1.0 in floating point.
    data_loss = np.sum(y * np.logaddexp(0, -z) + (1 - y) * np.logaddexp(0, z))
    penalized = th_flat[:-1]
    reg = (alpha / 2) * np.dot(penalized, penalized)
    return (data_loss + reg) / num_data

def sigmoid(a):
    """Return the logistic function 1 / (1 + e^(-a)), elementwise for arrays.

    For very negative inputs, np.exp overflows to inf, so the result is 0.0.
    """
    denom = np.exp(-a) + 1
    return 1.0 / denom

def all_one_stack(data):
    """Return ``data`` with a trailing column of ones appended (bias feature).

    Args:
        data: 2-D array with one row per sample.

    Returns:
        New array of shape (data.shape[0], data.shape[1] + 1).
    """
    n_rows = data.shape[0]
    bias_col = np.ones((n_rows, 1))
    return np.concatenate((data, bias_col), axis=1)

def local_update(x, y, alpha, diction, num_data, dim):
    """Return this worker's approximate Newton direction (the GIANT local step).

    Builds the local regularized logistic-loss Hessian at the broadcast
    weights and applies its pseudo-inverse to the broadcast gradient:

        H      = (x^T diag(s * (1 - s)) x + alpha * I) / num_data,  s = sigmoid(x @ w)
        update = -pinv(H) @ grad

    Args:
        x: this worker's feature matrix, one row per sample; it must have
            dim + 1 columns to match the (dim + 1) x (dim + 1) identity.
        y: the worker's labels. Not used by the computation; kept for the
            caller's interface.
        alpha: L2 regularization strength added to the Hessian diagonal.
        diction: dict with keys 'w' (current weights) and 'grad' (the
            gradient the direction is computed for).
        num_data: normalizer applied to the Hessian.
        dim: number of raw features (the Hessian is (dim + 1) x (dim + 1)).

    Returns:
        The Newton direction -pinv(H) @ grad.
    """
    w = diction['w']
    grad = diction['grad']
    sig = sigmoid(np.dot(x, w))

    # Per-sample curvature s_i * (1 - s_i). Scaling the rows of x by these
    # weights equals diag(weights) @ x, without materializing n x n matrices
    # (the previous diag @ diag @ x was O(n^3) time and O(n^2) memory).
    curvature = sig * (1 - sig)
    SX = curvature[:, np.newaxis] * x
    hess = (np.dot(x.T, SX) + alpha * np.eye(dim + 1)) / num_data

    # pinv (not solve) so a singular H (e.g. alpha == 0 on a rank-deficient
    # shard) still yields a minimum-norm direction.
    return -np.dot(np.linalg.pinv(hess), grad)

def error_test(X, y, w):
    """Return the classification ACCURACY of weights ``w`` on (X, y), in percent.

    Despite its name, this returns the percentage of samples predicted
    *correctly*, not the error rate. A sample is predicted 1 when
    sigmoid(X @ w) > 0.5 and 0 otherwise. That prediction is compared
    against ``y`` after flattening it, so an (n, 1) label column is accepted.
    """
    probs = sigmoid(np.dot(X, w))
    predicted = np.zeros(len(y))
    predicted[probs > 0.5] = 1
    n_correct = np.count_nonzero(y.flatten() == predicted)
    return n_correct / len(y) * 100




def byz_res(M, n_serv, n_byz, margin=2):
    """Aggregate worker updates robustly by summing only the smallest-norm ones.

    The first ``n_serv`` columns of ``M`` (one per worker) are ranked by
    Euclidean norm. The ``n_serv - n_byz - margin`` smallest are summed; the
    larger-norm columns, where large Byzantine updates would rank, are dropped.

    Args:
        M: 2-D array whose columns are per-worker update vectors.
        n_serv: number of workers (leading columns of M) to consider.
        n_byz: number of workers assumed to be Byzantine.
        margin: extra columns dropped beyond ``n_byz``. Defaults to 2, the
            value previously hard-coded.

    Returns:
        1-D array of length M.shape[0]: the sum of the retained columns.
        It is all zeros when no column is retained (including n_serv == 0).
    """
    norms = np.array([LA.norm(M[:, serv], 2) for serv in range(n_serv)])
    order = np.argsort(norms)
    # Clamp at 0: a negative count would slice from the end (order[0:-k])
    # and silently keep n_serv - k columns instead of none.
    keep = max(int(n_serv - n_byz - margin), 0)
    return np.sum(M[:, order[:keep]], axis=1)

def local_grad(X, y, w, alpha, num_data):
    """Return the gradient of (sum of logistic losses + (alpha/2)*||w||^2) / num_data.

    Computed as (X^T (sigmoid(X @ w) - y) + alpha * w) / num_data. Every
    entry of ``w`` is regularized, including the one paired with the
    ones-column from ``all_one_stack``. ``y`` is flattened first; labels are
    presumably 0/1 (the driver maps -1 to 0).
    """
    residual = sigmoid(np.dot(X, w)) - y.flatten()
    data_term = np.dot(X.T, residual)
    return (data_term + alpha * w) / num_data



# ---------------------------------------------------------------------------
# Experiment driver: GIANT-style distributed Newton for L2-regularized
# logistic regression on w5a, with simulated Byzantine workers.
#
# Launched under MPI with two arguments:
#   sys.argv[1]  integer tag, only used in the output file name
#   sys.argv[2]  fraction of workers whose Newton direction is corrupted
#
# Rank 0 is the coordinator and owns no data. Ranks 1..size-1 are workers,
# each owning one contiguous shard of the training set.
# ---------------------------------------------------------------------------

# Load data using sklearn (svmlight / LIBSVM format).
X1, y1 = sklearn.datasets.load_svmlight_file('w5a')

# Densify the features and map labels from -1/+1 to 0/1.
X2 = np.array(X1.todense())
n1, d = X2.shape
y1 = np.reshape(y1, (n1, 1))
y2 = np.copy(y1)
y2[np.where(y2 == -1)] = 0
number_data = len(y2)

# No held-out split: the "test" set is the training set itself, so the
# recorded accuracy is training accuracy.
n_train = int(n1)
n_test = n_train
X_train = X2
X_test = X2
y_train = y2
y_test = y2

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()
status = MPI.Status()

number_server = size - 1
# The +1 guarantees number_server shards cover all n_train rows.
data_per_server = n_train // number_server + 1
dim = d
MaxIter = 30
learning_rate = .01
iter_num = int(sys.argv[1])
rept = 1
alpha = 1
Result = np.zeros(MaxIter)   # training objective (see cost) per iteration
Error = np.zeros(MaxIter)    # accuracy in percent (see error_test), despite the name
num_data = data_per_server
TIME = np.zeros(MaxIter)     # wall-clock seconds per iteration, measured on rank 0

frac_byz = float(sys.argv[2])
num_byzantine = int(number_server * frac_byz)

# Worker columns (ranks 1..num_byzantine) whose Newton direction is corrupted.
byz = np.arange(num_byzantine)

comm.Barrier()

delay = 10  # only used by the disabled time.sleep experiment below

if rank == 0:
    w = np.ones(dim + 1) / 10
    # The evaluation data never changes, so append the bias column once.
    X_train_aug = all_one_stack(X_train)
    X_test_aug = all_one_stack(X_test)
else:
    w = np.empty(dim + 1, dtype=np.float64)
    local_dict = {}
    # This worker's contiguous shard (the last one may be shorter), with the
    # bias column appended once instead of on every iteration.
    st = (rank - 1) * data_per_server
    fin = min(rank * data_per_server, n_train)
    X_local = all_one_stack(X_train[st:fin, :])
    y_local = y_train[st:fin].flatten()

t1 = time.time()
for i in range(MaxIter):
    w = comm.bcast(w, root=0)

    # Phase 1: the global gradient is the average of the workers' local
    # gradients. Rank 0 has no shard, so it sends None and is excluded;
    # otherwise its empty-data "gradient" (pure alpha*w regularization)
    # would be added to the sum.
    local_g = None
    if rank != 0:
        local_g = local_grad(X_local, y_local, w, alpha, data_per_server)
    recvdata = comm.gather(local_g, root=0)
    grad_red = None
    if rank == 0:
        grad_red = np.sum(np.array(recvdata[1:]), axis=0) / number_server

    root_dict = {}
    comm.Barrier()
    root_dict['grad'] = grad_red
    root_dict['w'] = w
    local_dict = comm.bcast(root_dict, root=0)

    # Phase 2: each worker computes a Newton direction from its local Hessian.
    newton_update = None
    if rank != 0:
        newton_update = local_update(X_local, y_local, alpha, local_dict,
                                     data_per_server, dim)
    recvnewton = comm.gather(newton_update, root=0)

    if rank == 0:
        # One column per worker (rank 0's None is skipped).
        M = np.array(recvnewton[1:]).T
        # Simulated Byzantine attack: flip and shrink the chosen directions.
        for serv in byz:
            M[:, serv] = -(0.9 * M[:, serv])
#               time.sleep(delay)
        global_red = byz_res(M, number_server, num_byzantine)
        Result[i] = cost(X_train_aug, y_train.flatten(), w, alpha, n_train)
        Error[i] = error_test(X_test_aug, y_test, w)
        # Per-iteration time, including the evaluation just above.
        TIME[i] = time.time() - t1
        t1 = time.time()
        w = w + (learning_rate * global_red)

# Only rank 0 records the metrics (the workers' arrays stay all zeros), so
# only rank 0 writes the file; otherwise every rank would write the same path.
if rank == 0:
    a = {}
    a['result'] = Result
    a['error'] = Error
    a['TIME'] = TIME
    a['num_serv'] = number_server
    fname = 'GIANT_w5a_neg_thres' + '_iter_' + str(iter_num) + 'byz' + str(num_byzantine)
    sio.savemat(fname, a)

