import numpy as np
import random as rd
from os import listdir
import time
import networkx as nx
import cProfile
import hashlib as hl
import matplotlib.pyplot as plt
import pickle

# Seeds only Python's `random` module (used below by rd.sample / rd.shuffle /
# rd.random). NOTE(review): NumPy's global RNG, used via np.random.choice in
# EstimateSet / EstimateMarginal / Filter2, is NOT seeded here, so those
# Monte-Carlo estimates differ between runs — confirm that is intended.
rd.seed(12345)


# Reference paper: http://tda.gatech.edu/papers/deveci13-europar.pdf

# Assignment-problem (Hungarian algorithm) solvers, presumably noted as
# alternatives for the matching computation:
#   http://software.clapper.org/munkres/
#   scipy.optimize.linear_sum_assignment(cost_matrix)

def EstimateSet(estimator,f,X, S,size,samples):
    """Monte-Carlo estimate of the expected gain of adding a random subset of X to S.

    Draws random `size`-element subsets R of X (using NumPy's global RNG) and
    returns mean(estimator(f, S | R)) - estimator(f, S) over up to `samples`
    of them.

    Args:
        estimator: callable ``estimator(f, set) -> number``.
        f: problem instance passed through to `estimator`.
        X: candidate elements to sample from.
        S: current solution set.
        size: number of elements per sampled subset (must be >= 1).
        samples: maximum number of sampled subsets averaged over.

    Returns:
        0 when X is empty. Otherwise the estimated expected marginal value.
        If no duplicate-free chunk was drawn, the mean is taken over an empty
        list and the result is nan (NumPy also emits a RuntimeWarning).
    """
    if len(X)==0:
        return 0
    fS = estimator(f,S)
    # np.random.choice samples *with* replacement, so draw 1.5x the needed
    # elements, cut them into consecutive chunks of `size`, and keep only the
    # chunks without repeated elements (i.e. genuine `size`-subsets of X).
    Rand = np.random.choice(list(X),int(size*samples*1.5))
    Rtmp = [set(Rand[size*i:size*(i+1)]) for i in range(len(Rand)//size)]
    Rset = [R for R in Rtmp if len(R)==size]
    return np.mean([estimator(f,S.union(R)) for R in Rset[:samples]]) - fS
    
def EstimateMarginal(estimator,f,X, S,a,size,samples):
    """Monte-Carlo estimate of element a's marginal contribution on top of S plus a random subset of X.

    For up to `samples` random `size`-subsets R of X (NumPy's global RNG),
    averages estimator(f, S | R | {a}) - estimator(f, (S | R) - {a}).

    Args:
        estimator: callable ``estimator(f, set) -> number``.
        f: problem instance passed through to `estimator`.
        X: candidate elements the random subsets are drawn from.
        S: current solution set.
        a: element whose marginal contribution is estimated.
        size: number of elements per sampled subset (must be >= 1).
        samples: maximum number of sampled subsets averaged over.

    Returns:
        0 when X is empty (matching EstimateSet). Otherwise the estimated
        marginal; nan if no duplicate-free chunk was drawn.
    """
    if len(X)==0:
        return 0
    # Sampling is with replacement: over-draw, chunk, and keep only chunks
    # without repeats, capped at `samples` like EstimateSet.
    # Integer division is required: range() rejects floats on Python 3.
    Rand = np.random.choice(list(X),int(size*samples*1.5))
    Rtmp = [set(Rand[size*i:size*(i+1)]) for i in range(len(Rand)//size)]
    Rset = [R for R in Rtmp if len(R)==size][:samples]
    return np.mean([estimator(f,S.union(list(R)+[a]))-estimator(f,(S.union(R)).difference([a])) for R in Rset])

def Filter(estimator,f,X, S,size,m,r):
    """Shrink the candidate set X until a random `size`-subset of it is valuable enough.

    The loop runs while EstimateSet's estimate of the expected gain from
    adding a random `size`-subset of X to S is below
    (1-epsilon)*(OPT-f(S))/r. On each pass, every element whose
    EstimateMarginal value is below (1+epsilon/2)*(1-epsilon)*(OPT-f(S))/k
    is discarded.

    NOTE: epsilon, OPT and k are read from module-level globals, not passed in.

    Args:
        estimator, f: objective (``estimator(f, set) -> number``) and instance.
        X: set of candidate elements.
        S: current solution set.
        size: size of the random subsets used by the estimators.
        m: number of Monte-Carlo samples per estimate.
        r: number of adaptive rounds (scales the acceptance bound).

    Returns:
        The filtered set, which may be empty.
    """
    fS = estimator(f,S)
    # fS is fixed for the whole call, so the discard threshold is loop-invariant.
    threshold = (1+epsilon/2)*(1-epsilon)*(OPT-fS)/k
    # Stop as soon as X is empty. EstimateSet returns 0 for an empty X, and
    # X can then never change, so the loop would spin forever whenever the
    # right-hand bound is positive.
    while X and EstimateSet(estimator,f,X, S,size,m) < (1-epsilon)*(OPT-fS)/r:
        lX = list(X)
        vsxa = [EstimateMarginal(estimator,f,X, S,a,size,m) for a in lX]
        A = [a for a, v in zip(lX, vsxa) if v < threshold]
        X = X.difference(A)
    return X

def Filter2(estimator,f,X, S,size,m,r):
    """Variant of Filter that shares one batch of random subsets across all elements.

    The loop condition is the same as Filter's. Instead of calling
    EstimateMarginal once per element (fresh samples each time), each pass
    draws up to m random `size`-subsets once and evaluates every element's
    marginal on that shared batch. Prints progress traces on every pass.

    NOTE: epsilon, OPT and k are read from module-level globals, not passed in.

    Returns:
        The filtered set, which may be empty.
    """
    fS = estimator(f,S)
    threshold = (1+epsilon/2)*(1-epsilon)*(OPT-fS)/k
    # Stop once X is empty; otherwise EstimateSet keeps returning 0 and the
    # loop never terminates.
    while X and EstimateSet(estimator,f,X, S,size,m) < (1-epsilon)*(OPT-fS)/r:
        print('filtering')
        lX = list(X)
        # Over-draw with replacement and keep duplicate-free chunks, capped at
        # m (matching EstimateSet). Integer division: range() needs an int.
        Rand = np.random.choice(lX,int(size*m*1.5))
        Rtmp = [set(Rand[size*i:size*(i+1)]) for i in range(len(Rand)//size)]
        Rset = [R for R in Rtmp if len(R)==size][:m]
        print(['filtering2',threshold])
        vsxa = [np.mean([estimator(f,S.union(list(R)+[a]))-estimator(f,(S.union(R)).difference([a])) for R in Rset]) for a in lX]
        print(['filtering3'])
        A = [a for a, v in zip(lX, vsxa) if v < threshold]
        print(['filtering4',vsxa,A])
        X = X.difference(A)
        print(len(X),len(A))
    return X
# NOTE: the triple-quoted string below is an earlier, disabled draft of
# AmortizedFiltering ("algorithm 5"). It is a bare string expression, so it is
# never executed; the live implementation follows right after it.
# NOTE(review): the draft indexes the set X as X[i] while iterating over its
# elements, and passes a set to rd.sample, so it would not run as written if
# re-enabled.
'''
#algorithm 5
def AmortizedFiltering(estimator,f, r,epsilon,OPT,m):
    size = int(k/r)
    S = set()
    fS=0
    sN = set(N)
    for epoch in range(int(20/epsilon)):
        X = sN.copy()
        T = set()
        vsX = EstimateSet(estimator,f,X, S,size,m) #T????????????????????? see original paper!!!!!!!!!!!!!!!!!!
        print(vsX,(epsilon/20)*(OPT-estimator(f,S)))
        while vsX< (epsilon/20)*(OPT-estimator(f,S)) and len(S.union(T))<k:
            vsxa = [EstimateMarginal(estimator,f,X, S,a,size,m) for a in X] #.difference(S)????
            A = [X[i] for i in X if vsxa[i]<(1+epsilon/2)*(1-epsilon)*(OPT-fS)/k]
            X = X.difference(A)
            T = T.union(rd.sample(X,size))
            vsX = EstimateSet(estimator,f,X, S,size,m)
        S = S.union(T) #assert at most k!!!!!
        fS = estimator(f,S)
        print(fS,S)
    return S
'''
def AmortizedFiltering(estimator,f, r,epsilon,OPT,m):
    """Adaptive-sampling maximisation of estimator(f, .) under a cardinality bound k.

    Runs int(20/epsilon) epochs. In each epoch a block T is grown: the ground
    set is filtered with Filter, then a random subset of about k/r survivors
    is added to T. This repeats until T raises the value by at least
    (epsilon/20)*(OPT - f(S)) or |S | T| reaches k, and T is then merged
    into S.

    Args:
        estimator: callable ``estimator(f, set) -> number``.
        f: problem instance.
        r: number of rounds; each step adds about k/r elements.
        epsilon: accuracy parameter (controls epoch count and gain target).
        OPT: (estimate of) the optimal value.
        m: number of Monte-Carlo samples passed to Filter.

    NOTE: reads the module-level globals N (ground set) and k (cardinality
    bound). Filter itself reads the *global* epsilon, OPT and k, not the
    arguments given here; they coincide in this script's driver loop.

    Returns:
        The selected set S, with len(S) <= k.
    """
    # Guard against r > k: size 0 would make EstimateSet divide by zero.
    size = max(1, int(k/r))
    S = set()
    sN = set(N)
    for epoch in range(int(20/epsilon)):
        X = sN.copy()
        T = set()
        fS = estimator(f,S)
        while estimator(f,S.union(T)) - fS < (epsilon/20)*(OPT-fS) and len(S.union(T))<k:
            X = Filter(estimator,f,X, S,size,m,r)
            if not X:
                # Nothing survived filtering, so there is nothing to sample.
                break
            # Cap the draw so |S | T| never exceeds k and never asks for more
            # elements than X holds. random.sample() rejects sets on Python
            # 3.11+, so sample from list(X) (older versions converted the set
            # to a tuple internally, giving the same draw).
            budget = min(size, len(X), k - len(S.union(T)))
            T = T.union(rd.sample(list(X), budget))
        S = S.union(T)
    print(('jadaptive',estimator(f,S),S))
    return S
    
# NOTE: the triple-quoted string below is a disabled earlier version of the
# threshold-sampling algorithm (ALG). It is a bare string expression and is
# never executed; ALG_nonpar below is the live variant.
# NOTE(review): it calls a helper `diff` that is not defined in the visible
# code, so it would not run as written if re-enabled.
'''    
def ALG(estimator,f,N,k,epsilon,Delta,OPT,maxfa):
    t=min(OPT/(epsilon*k),maxfa)
    sN = set(N)
    #rd.shuffle(sN)
    S = set()
    for j in range(Delta):
        X = sN.copy()
        c=0
        while len(X)>0:
            c=c+1
            sample_size = min([k-len(S),len(X)])
            A = rd.sample(X,sample_size)
            XX = [[a for a in X if diff(estimator,f,S.union(A[:i]),a)>=t] for i in range(sample_size)]
            r = [i for i in range(sample_size) if len(XX[i])<(1-epsilon)*len(X)]
            if len(r)==0:
                S = S.union(A)
                break
            ii = min(r)
            S = S.union(A[:ii])
            X = set(XX[ii])
            #print('bb',X,len(X),S)
            #print(R['s'],R['t'])
        print('j',j,t,estimator(f,S),S,c)
        t = t*(1-epsilon)
    return S    
'''
    
def ALG_nonpar(estimator,f,N,k,epsilon,Delta,OPT,maxfa):
    """Threshold-sampling maximisation of estimator(f, .) subject to |S| <= k.

    Runs `Delta` passes with a threshold t that starts at
    min(OPT/(epsilon*k), maxfa) and is multiplied by (1-epsilon) after each
    pass. Within a pass, X starts as all of N and each round does the
    following:
      * draw a random sequence A of min(k-|S|, |X|) elements of X;
      * find the first prefix length i at which fewer than (1-epsilon)*|X|
        elements of X still have marginal gain >= t over S + A[:i];
      * add A[:i] to S and shrink X to those surviving elements;
      * if no prefix thins X that much, add all of A and end the pass.
    Prefixes 0 and 1 are checked first, so the common case avoids evaluating
    every prefix. Returns as soon as |S| == k.

    Returns:
        (S, c): the selected set and the number of rounds performed.
    """
    t=min(OPT/(epsilon*k),maxfa)
    S = set()
    c = 0
    for j in range(Delta):
        X = set(N)
        while len(X)>0:
            c=c+1
            if(len(S)==k):
                print('je',j,t,estimator(f,S),S,c)
                return S,c
            sample_size = min([k-len(S),len(X)])
            # random.sample() rejects sets on Python 3.11+; list(X) gives the
            # same draw that older versions produced via tuple(X).
            A = rd.sample(list(X),sample_size)
            # base[i] = value of S extended by the first i sampled elements.
            base = [estimator(f,S.union(A[:i])) for i in range(sample_size)]

            # Cheap first pass: only the first (at most two) prefixes.
            small_size = min(2,sample_size)
            XX = [[a for a in X if estimator(f,S.union(A[:i]+[a]))-base[i]>=t] for i in range(small_size)]
            r = [i for i in range(small_size) if len(XX[i])<(1-epsilon)*len(X)]
            if len(r)>0:
                ii = min(r)
                S = S.union(A[:ii])
                X = set(XX[ii])
                continue

            # Fall back to evaluating the remaining prefixes.
            XX = XX + [[a for a in X if estimator(f,S.union(A[:i]+[a]))-base[i]>=t] for i in range(small_size,sample_size)]
            r = [i for i in range(sample_size) if len(XX[i])<(1-epsilon)*len(X)]
            if len(r)==0:
                # No prefix thins X enough: take all of A and lower t.
                S = S.union(A)
                break
            ii = min(r)
            S = S.union(A[:ii])
            X = set(XX[ii])
        t = t*(1-epsilon)
    return S,c
    

    
def ALG_add(estimator,f,N,k):
    """Top-k baseline: the k elements with the largest singleton values estimator(f, {i}).

    Elements are shuffled before sorting, presumably so that ties in the
    singleton value are broken randomly. Prints the value of the result.

    Returns:
        A set of at most k elements (empty when k <= 0).
    """
    NN = list(N)
    rd.shuffle(NN)
    res = [estimator(f,set([i])) for i in NN]
    arr = np.array(NN)
    # Slice from an explicit start index: argsort()[-k:] would return *every*
    # index when k == 0, because -0 == 0.
    start = max(len(NN) - max(k, 0), 0)
    ind = np.array(res).argsort()[start:]
    arr = arr[ind]
    print('jjmaxmarginaltoempty',estimator(f,set(arr)),arr)
    return set(arr)

def ALG_bestsample(estimator,f,N,k,tries):
    """Random-sampling baseline: the best of `tries` uniformly random k-subsets of N.

    Args:
        estimator: callable ``estimator(f, set) -> number``.
        f: problem instance.
        N: sequence to sample from (passed to random.sample).
        k: subset size.
        tries: number of random subsets drawn.

    Returns:
        The sampled set with the largest value; later samples win ties.
        Returns an empty set when tries == 0.
    """
    S = set()
    best_val = -np.inf
    for _ in range(tries):
        St = set(rd.sample(N,k))
        tmp = estimator(f,St)
        # The best value must be updated together with the best set;
        # otherwise later, worse samples overwrite the best one.
        if tmp >= best_val:
            S, best_val = St, tmp
    print('jjbestsample',estimator(f,S),S)
    return S
    
def ALG_greedy(estimator,f,N,k,rounds):
    """Greedy baseline: for min(rounds, k) steps, add the element maximising estimator(f, S | {i}).

    Ties go to the first maximiser in N's order. Prints the final value and
    set.

    Returns:
        The chosen set.
    """
    chosen = set()
    for _ in range(min(rounds, k)):
        scores = np.array([estimator(f, chosen | {cand}) for cand in N])
        pool = np.array(list(N))
        chosen = chosen | {pool[scores.argmax()]}
    print('rrgreedy',estimator(f,chosen),chosen)
    return chosen

def estimateNorm(f,S):
    """Return max(vals[s]) for norm == -1, else max(vals[s]**norm)**(1/norm), over s in S.

    f['norm'] selects the exponent and f['vals'] maps elements to values.
    Raises ValueError for an empty S.
    """
    p = f['norm']
    values = f['vals']
    if p != -1:
        # NOTE(review): for non-negative values this equals the plain max, so
        # the exponent has no effect; an l_p norm would use sum(...) instead
        # of max(...). Confirm which was intended.
        return np.power(max(values[s]**p for s in S),1./p)
    return max(values[s] for s in S)
    
def estimatelong(f,S):
    """Weighted, capped count of the elements of S falling into four consecutive ranges.

    The ranges are (-inf, n1), [n1, n2), [n2, n3) and [n3, n4), with cut
    points f['n1'] .. f['n4']. The count in range j is capped at f['Vnj'] and
    weighted by 12, 9, 6 and 4 respectively. Elements >= n4 contribute
    nothing.
    """
    cuts = (f['n1'], f['n2'], f['n3'], f['n4'])
    caps = (f['Vn1'], f['Vn2'], f['Vn3'], f['Vn4'])
    weights = (12, 9, 6, 4)
    counts = [sum(1 for x in S if x < cuts[0])]
    counts.extend(sum(1 for x in S if lo <= x < hi) for lo, hi in zip(cuts, cuts[1:]))
    return sum(w*min(cnt, cap) for w, cnt, cap in zip(weights, counts, caps))

def estimate(f,S):
    """Objective for instances built by randomGBM.

    Elements below f['nG'] form group G, elements in [nG, nG+nB) form group B,
    and all other elements of S count as group M. The value combines:
      * up to gV group-M elements (weight 1);
      * group-G elements filling the remaining gV slots (weight 2);
      * up to bV of the remaining G/B elements (weight 2);
      * everything else from G and B (weight 1).
    """
    g_cut = f['nG']
    b_cut = g_cut + f['nB']
    goods = sum(1 for x in S if x < g_cut)
    bads = sum(1 for x in S if g_cut <= x < b_cut)
    mids = len(S) - goods - bads
    cap_g = f['gV']
    cap_b = f['bV']
    take_m = min(mids, cap_g)
    take_best = min(goods, cap_g - take_m)
    take_second = min(bads + goods - take_best, cap_b)
    leftover = goods + bads - take_second - take_best
    return take_m + 2*take_best + 2*take_second + leftover


def estimatesimple(f,S):
    """Capped counts of S in group G (elements < nG) and group B (elements in [nG, nG+nB)).

    Returns min(|S & G|, f['gV']) + min(|S & B|, f['bV']).
    """
    g_cut = f['nG']
    b_cut = f['nG'] + f['nB']
    goods = sum(1 for x in S if x < g_cut)
    bads = sum(1 for x in S if g_cut <= x < b_cut)
    return min(goods, f['gV']) + min(bads, f['bV'])

def estimatNOoverlap(f,S):
    """Objective for non-overlapping groups: sum of (group index + 1) over groups hit by S.

    The range [0, f['n']) is split into f['groups'] equal-width intervals
    [gs*i, gs*(i+1)) with gs = n/groups. Group i contributes i+1 if it
    contains at least one element of S, and 0 otherwise. Elements outside
    [0, n) are ignored.
    """
    g = f['groups']
    gs = f['n']/g
    # (The previous body contained a stray `r#eturn ...` line, which evaluated
    # the name `r` and raised NameError unless a global `r` existed.)
    return sum(i + 1 for i in range(g) if any(gs*i <= s < gs*(i+1) for s in S))

    
def generateNOoverlap(n,groups):
    """Instance description for estimatNOoverlap: n elements split into `groups` groups.

    The 'howmany' entry simply mirrors `groups`.
    """
    return dict(n=n, groups=groups, howmany=groups)
    
    
def randomNorm(N,norm):
    """Random instance for estimateNorm: one rd.random() value per element of N plus the exponent."""
    return {'norm': norm, 'vals': [rd.random() for _ in N]}

    
    
def randomlong(N,k):
    """Instance for estimatelong: four consecutive ranges over len(N) elements.

    The first range has k elements. The second and third each take a third
    of the remainder (rounded down), and the fourth gets the rest. Every
    range's cap is k/3.
    """
    total = len(N)
    size1 = k
    size2 = int((total-size1)/3)
    size3 = int((total-size1)/3)
    size4 = total - size3 - size2 - size1
    return {
        'n1': size1,
        'n2': size1 + size2,
        'n3': size1 + size2 + size3,
        'n4': size1 + size2 + size3 + size4,
        'Vn1': k/3,
        'Vn2': k/3,
        'Vn3': k/3,
        'Vn4': k/3,
    }

def randomGBM(N):
    """Instance for estimate / estimatesimple over n = len(N) elements.

    Elements [0, nG) form group G and [nG, nG+nB) form group B, with
    nG = 2*int(log(n)**2) and nB = 2*int(log(n)**4). The caps are
    gV = int(log(n)**2) and bV = int(log(n)). Prints a one-line summary.
    """
    n = len(N)
    log_n = np.log(n)
    good_cap = int(log_n**2)
    bad_cap = int(log_n**1)
    nG = 2*int(log_n**2)
    nB = 2*int(log_n**4)
    print('f=',n,nG,nB,good_cap,bad_cap)
    return {'gV': good_cap, 'bV': bad_cap, 'nG': nG, 'nB': nB, 'n': n}
        
    
def OXS_grapg(G,S):
    """Memoised total weight of a maximum-weight matching between the players and the items in S.

    Takes the subgraph of bipartite graph G induced by all player nodes (the
    first side returned by nx.bipartite.sets) together with the nodes in S,
    and returns the weight of a maximum-weight matching in it, using the
    edges' "weight" attribute.

    Results are cached in the module-level dict R, keyed by a SHA-256 of
    str(S). R['t'] counts calls and R['s'] counts cache hits.
    NOTE(review): R is not defined in the visible code; it must be a dict
    pre-populated with integer 't' and 's' entries before the first call.
    Equal sets built in a different insertion order may stringify
    differently, which only costs a cache miss.
    """
    global R
    R['t']=R['t']+1
    # hashlib requires bytes on Python 3.
    ss = hl.sha256(str(S).encode('utf-8')).hexdigest()
    # dict.has_key() no longer exists on Python 3; use `in`.
    if ss in R:
        R['s']=R['s']+1
        return R[ss]

    players, elements = nx.bipartite.sets(G)
    A = G.subgraph(S.union(players))
    H = nx.algorithms.max_weight_matching(A)
    if isinstance(H, dict):
        # networkx 1.x: mate dict lists every matched edge twice (u->v, v->u).
        max_wt = sum([G[u][v]["weight"]/2.0 for u,v in H.items()])
    else:
        # networkx >= 2.0: a set of (u, v) pairs, each matched edge once.
        max_wt = sum([float(G[u][v]["weight"]) for u,v in H])
    R[ss] = max_wt
    return max_wt
    
def getmax(f,N,estimator):
    """Largest singleton value estimator(f, {a}) over a in N (ValueError if N is empty)."""
    best = max(estimator(f, {a}) for a in N)
    return best

# ---------------------------------------------------------------------------
# Experiment driver.
# For r = 0..l, build a non-overlapping-groups instance with 2**r groups over
# n = 2**l elements, run every algorithm above with cardinality bound k, and
# record each algorithm's value relative to the greedy reference value OPT.
# NOTE: N, k, epsilon and OPT must stay module-level globals, because Filter
# and AmortizedFiltering read them directly.
# ---------------------------------------------------------------------------
l = 10
n = int(2**l)
N = range(n)

epsilon = 0.1
iterations = 1
f = [0]*(l+1)
OPTS =[0]*(l+1)
OPTa =[0]*(l+1)
S1 = [0]*(l+1)
S2 = [0]*(l+1)
S3 = [0]*(l+1)
S4 = [0]*(l+1)
S5 = [0]*(l+1)
Rounds = [0]*(l+1)

for r in range(l+1):

    k= 2**int(l/2)
    rr = k

    f[r] = generateNOoverlap(n,int(2**r))
    estimator = estimatNOoverlap

    maxfa = getmax(f[r],N,estimator)
    Delta = int(1/(epsilon**2))
    # The full-k greedy solution's value serves as the reference optimum OPT.
    OPTS[r] = ALG_greedy(estimator,f[r],N,k,k)
    OPT = estimator(f[r],OPTS[r])
    S1[r],Rounds[r] = ALG_nonpar(estimator,f[r],N,k,epsilon,Delta,OPT,maxfa)
    S2[r] = ALG_add(estimator,f[r],N,k)
    # Greedy limited to the same number of rounds ALG_nonpar used.
    S3[r] = ALG_greedy(estimator,f[r],N,k,Rounds[r])
    S4[r] = ALG_bestsample(estimator,f[r],N,k,len(N))
    S5[r] = AmortizedFiltering(estimator,f[r], rr,epsilon,OPT,5)
    OPTa[r] = OPT

Val_OPT = OPTa
OPT = OPTa


def _ratio_to_opt(solutions):
    """Value of each solutions[j] divided by OPT[j] (nan where OPT[j] is 0)."""
    return [(estimator(f[j],solutions[j])+0.0)/OPT[j] if OPT[j] else float('nan')
            for j in range(r+1)]


Val_alg = _ratio_to_opt(S1)
Val_add = _ratio_to_opt(S2)
Val_greedy = _ratio_to_opt(S3)
Val_bestsample = _ratio_to_opt(S4)
Val_Asampling = _ratio_to_opt(S5)

a = {'f':f,'OPTS':OPTS,'S1':S1,'S2':S2,'S3':S3,'S4':S4,'S5':S5, 'Val_alg': Val_alg,
'Val_add': Val_add, 'Val_greedy': Val_greedy, 'Val_bestsample': Val_bestsample, 'Val_OPT': OPT, 'Val_Asampling': Val_Asampling}
# Close (and therefore flush) the file before plt.show() blocks.
with open('nooverlap.pkl','wb') as outfile:
    pickle.dump(a,outfile)

# Ratio to the greedy reference (OPT == 1) for every algorithm, by group count.
xs = [2**j for j in range(r+1)]
O, = plt.plot(xs,[1]*(r+1), label='OPT ')
A, = plt.plot(xs,Val_alg, label='AS')
B, = plt.plot(xs,Val_greedy, label='greedy')
C, = plt.plot(xs,Val_add, label='TOPk')
D, = plt.plot(xs,Val_bestsample, label='bestSample')
E, = plt.plot(xs,Val_Asampling, label='Asampling')
plt.legend(bbox_to_anchor=(1.5, 1),handles=[O,A,B,C,D,E])
plt.show()

a = {'f':f,'OPTS':OPTS,'S1':S1,'S2':S2,'S3':S3,'S4':S4,'S5':S5, 'Val_alg': Val_alg,
'Val_add': Val_add, 'Val_greedy': Val_greedy, 'Val_bestsample': Val_bestsample, 'Val_OPT': Val_OPT, 'Val_Asampling': Val_Asampling}
with open('UD_ADD.pkl','wb') as outfile:
    pickle.dump(a,outfile)

