from mpi4py import MPI
import numpy as np
import math
import random
from numpy import linalg as LA
import scipy.io as sio
import sys
import sklearn.datasets
import time


def cost(x, y, th, alpha, num_data):
    """Return the L2-regularised logistic loss divided by ``num_data``.

    Parameters
    ----------
    x : array whose rows are samples; ``np.dot(x, th)`` gives the logits.
    y : array of 0/1 labels, one per row of ``x``.
    th : 1-D parameter vector.  Its last entry is left out of the
        regularisation term.
    alpha : regularisation strength.
    num_data : divisor applied to the summed loss plus penalty.

    Returns
    -------
    ``(sum_i loss_i + alpha/2 * ||th[:-1]||_2^2) / num_data``
    """
    logits = np.dot(x, th)
    # With p = 1/(1+exp(-z)):  -y*log(p) - (1-y)*log(1-p) == log(1+exp(z)) - y*z.
    # The right-hand side never evaluates log(0), so the loss stays finite
    # even when the logits are large enough for p to round to exactly 0 or 1.
    data_term = np.sum(np.logaddexp(0.0, logits) - y * logits, axis=0)
    penalty = (alpha / 2) * (LA.norm(th[0:len(th) - 1], 2) ** 2)
    return (data_term + penalty) / num_data

def sigmoid(a):
    """Return the logistic function 1 / (1 + exp(-a)), applied elementwise."""
    denominator = 1 + np.exp(-a)
    return 1.0 / denominator

def all_one_stack(data):
    """Return ``data`` with an extra trailing column filled with ones.

    The ones column has one entry per row of ``data`` and is appended with
    ``np.hstack``; it pairs with the extra (last) coefficient of the
    ``dim + 1``-long parameter vectors used in this file.
    """
    row_count = data.shape[0]
    ones_column = np.ones((row_count, 1))
    return np.hstack((data, ones_column))


def QSGD1(g):
    """Stochastically quantise ``g`` to its L2 norm plus a {-1, 0, +1} vector.

    Coordinate ``i`` keeps its sign with probability ``|g[i]| / ||g||_2`` and
    is set to zero otherwise, so ``norm * quantized_vector`` equals ``g`` in
    expectation.  Randomness comes from the global ``np.random`` state; one
    uniform draw is consumed per coordinate.

    Parameters
    ----------
    g : 1-D numeric array.

    Returns
    -------
    (l2_value, quantized_vector)
        The Euclidean norm of ``g`` and the sign vector described above.
        An all-zero ``g`` yields ``(0.0, zeros)``.
    """
    n_coords = len(g)
    l2_value = LA.norm(g, 2)
    if l2_value > 0:
        keep_prob = np.abs(g) / l2_value
    else:
        # Avoid 0/0: a zero vector has nothing to keep.  The uniform draw
        # below still happens so the random stream advances as usual.
        keep_prob = np.zeros(n_coords)
    keep_mask = np.random.uniform(size=n_coords) < keep_prob
    quantized_vector = np.multiply(np.sign(g), keep_mask * 1)
    return l2_value, quantized_vector

def error_test(X, y, w):
    """Return the percentage of samples whose predicted label equals ``y``.

    A sample is predicted as 1 when ``sigmoid(X @ w)`` is strictly greater
    than 0.5 and as 0 otherwise.  Note that, despite the function's name,
    the value returned is the share of *matching* predictions (0-100).

    Parameters
    ----------
    X : array whose rows are samples (same column count as ``len(w)``).
    y : array of 0/1 labels; it is flattened before the comparison.
    w : 1-D parameter vector.
    """
    probabilities = sigmoid(np.dot(X, w))
    predictions = np.zeros(len(y))
    predictions[probabilities > 0.5] = 1
    n_matching = sum(y.flatten() == predictions)
    return n_matching / len(y) * 100

def local_update(x, y, w, alpha, num_data, dim):
    """Compute one quantised Newton step for regularised logistic regression.

    Parameters
    ----------
    x : 2-D array of samples (rows) with ``dim + 1`` columns.
    y : 1-D array of 0/1 labels, one per row of ``x``.
    w : 1-D parameter vector of length ``dim + 1``.
    alpha : regularisation strength; added as ``alpha * w`` to the gradient
        and ``alpha * I`` to the Hessian.
    num_data : divisor applied to both gradient and Hessian.
    dim : number of columns of ``x`` minus one; sizes the identity matrix.

    Returns
    -------
    1-D array: ``-pinv(Hessian) @ gradient`` after stochastic quantisation by
    ``QSGD1`` (norm times a {-1, 0, +1} vector).

    NOTE(review): the ``alpha * w`` term covers every coordinate of ``w``,
    whereas ``cost`` in this file leaves the last coordinate out of its
    penalty -- confirm which of the two is intended.
    """
    xTrans = x.transpose()
    sig = sigmoid(np.dot(x, w))
    grad = (np.dot(xTrans, sig - y) + alpha * w) / num_data

    # diag(sig) @ diag(1 - sig) @ x is just x with row i scaled by
    # sig[i] * (1 - sig[i]); broadcasting does that without materialising
    # two dense (n_samples x n_samples) matrices.
    row_weights = sig * (1 - sig)
    SX = row_weights[:, np.newaxis] * x
    Hess = (np.dot(xTrans, SX) + alpha * np.eye(dim + 1)) / num_data

    # Pseudo-inverse keeps the step defined even if the Hessian is singular.
    update = -np.dot(np.linalg.pinv(Hess), grad)
    B, qvec = QSGD1(update)
    return B * qvec


def byz_res(M, n_serv, n_byz):
    """Sum the columns of ``M`` that have the smallest Euclidean norms.

    The first ``n_serv`` columns are ranked by L2 norm and the
    ``n_serv - n_byz - 2`` smallest ones are added together; the remaining
    (largest-norm) columns are discarded.

    Parameters
    ----------
    M : 2-D array with one column per server.
    n_serv : number of server columns to consider.
    n_byz : number of servers assumed to be Byzantine.

    Returns
    -------
    1-D array with one entry per row of ``M``.  If ``n_serv - n_byz - 2`` is
    not positive, no column is selected and the result is all zeros.
    """
    column_norms = LA.norm(M[:, :n_serv], axis=0)
    ordered_set = np.argsort(column_norms)
    # Clamp at zero: a negative count would otherwise act as a Python
    # negative slice bound and silently keep almost every column.
    thres = max(int(n_serv - n_byz - 2), 0)
    choose_set = ordered_set[:thres]
    return np.sum(M[:, choose_set], axis=1)



# Load Data using sklearn 


# Read the svmlight-format file "w5a" from the current working directory.
X1, y1 = sklearn.datasets.load_svmlight_file("w5a")
# The loader returns a sparse matrix; convert it to a dense ndarray.
X2 = np.array(X1.todense())
n1, d = X2.shape
# Labels as an (n1, 1) column.
y1 = y1.reshape((n1, 1))
# Relabel -1 as 0 on a copy, so y1 keeps the labels as read from the file.
y2 = y1.copy()
y2[y2 == -1] = 0
number_data = len(y2)

# No hold-out split: the full dataset is used both for training and as the
# "test" set.
n_train = int(n1)
n_test = n_train
X_train = X_test = X2
y_train = y_test = y2

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()
status = MPI.Status()

# Validate the launch configuration up front.  Every rank sees the same
# `size` and `sys.argv`, so all of them stop together with a readable message
# instead of a ZeroDivisionError / IndexError further down.
if size < 2:
    sys.exit("error: at least 2 MPI processes are required "
             "(rank 0 aggregates, ranks 1..size-1 hold the data shards)")
if len(sys.argv) < 3:
    sys.exit("usage: <script> <iter_num> <frac_byz>")

# Ranks 1..size-1 each process one contiguous shard of the training rows.
number_server = size - 1
data_per_server = n_train // number_server + 1
dim = d
MaxIter = 30
learning_rate = .005
# Run identifier; only used in the output file name.
iter_num = int(sys.argv[1])
rept = 1
alpha = 1

# Fraction of servers whose updates are corrupted; the first
# `num_byzantine` server columns are the ones attacked.
frac_byz = float(sys.argv[2])
if not 0.0 <= frac_byz <= 1.0:
    sys.exit("error: frac_byz must lie in [0, 1], got " + str(frac_byz))
num_byzantine = int(number_server * frac_byz)

byz = np.arange(num_byzantine)


# Per-iteration records (filled in on rank 0 only).
Result = np.zeros(MaxIter)
Error = np.zeros(MaxIter)
num_data = data_per_server
TIME = np.zeros(MaxIter)
delay = 10
comm.Barrier()




#comm.Barrier()


# Rank 0 owns the model: every one of the dim + 1 coefficients starts at 0.1.
if rank == 0:
    w = np.ones(dim + 1) / 10
else:
    # Placeholder; replaced by the broadcast at the top of every iteration.
    w = np.empty(dim + 1, dtype=np.float64)
    local_dict = {}


t1 = time.time()
for i in range(MaxIter):
    # Distribute the current model from rank 0 to all ranks.
    w = comm.bcast(w, root=0)

    if rank == 0:
        # Rank 0 has no data shard and its gathered entry is never used, so
        # it contributes a placeholder instead of computing a Newton step.
        global_update = None
    else:
        # Worker r handles rows [(r-1)*data_per_server, r*data_per_server),
        # clipped to the number of training rows.
        st = int((rank - 1) * data_per_server)
        fin = min(int(rank * data_per_server), n_train)
        global_update = local_update(
            all_one_stack(X_train[st:fin, :]),
            y_train[st:fin].flatten(),
            w,
            alpha,
            data_per_server,
            dim,
        )
    recvdata = comm.gather(global_update, root=0)

    if rank == 0:
        # Skip rank 0's placeholder; column j of M is the update of worker j+1.
        M = np.array(recvdata[1:size]).T

        # Adversary attack: flip and shrink the updates of the Byzantine servers.
        for serv in byz:
            M[:, serv] = -(0.9 * M[:, serv])

        global_red = byz_res(M, number_server, num_byzantine)

        w = w + (learning_rate * global_red)
        # Time from the end of the previous evaluation to this model update;
        # the cost/accuracy evaluation below is excluded from the measurement.
        TIME[i] = time.time() - t1
        Result[i] = cost(all_one_stack(X_train), y_train.flatten(), w, alpha, number_data)
        Error[i] = error_test(all_one_stack(X_test), y_test, w)

        t1 = time.time()

a = {
    'result': Result,
    'error': Error,
    'TIME': TIME,
    'num_serv': number_server,
}
fname = 'Oneround_w5a_neg_comp_' + 'iter_' + str(iter_num) + 'byz' + str(num_byzantine)
# Result/Error/TIME are only filled in on rank 0; on every other rank they are
# still all zeros.  Writing from rank 0 alone keeps those ranks from racing on
# (and possibly overwriting) the same output file.
if rank == 0:
    sio.savemat(fname, a)







