import numpy as np
import random as rd
from os import listdir
import time
import networkx as nx
import cProfile
import hashlib as hl
import matplotlib.pyplot as plt
import pickle
from scipy.optimize import linear_sum_assignment
import make_graph

# Seed Python's `random` module so rd.sample / rd.shuffle draws are repeatable.
# NOTE(review): numpy's global RNG (np.random.choice / np.random.uniform used
# below) is NOT seeded here, so numpy-based sampling differs between runs.
rd.seed(12345)



# NOTE(review): `global` at module scope is a no-op; k and epsilon are actually
# assigned inside simulation()/simul() and read as module globals elsewhere.
global k
global epsilon










def EstimateSet(estimator,f,X, S,size,samples):
    """Monte-Carlo estimate of the gain of adding a random `size`-subset of X to S.

    Draws ``int(size*samples*1.5)`` elements of X with replacement, splits them
    into consecutive chunks of length `size`, keeps only chunks consisting of
    `size` distinct elements, and averages ``estimator(f, S | R) - estimator(f, S)``
    over at most `samples` of those chunks.

    Args:
        estimator: callable ``estimator(f, S) -> number``.
        f: function description passed through to `estimator`.
        X: candidate elements (any collection; converted to a list).
        S: current solution set.
        size: number of elements in each random subset.
        samples: maximum number of random subsets to average over.

    Returns:
        The estimated gain. Returns 0 when X is empty or ``size <= 0`` (adding
        nothing gains nothing). If no chunk of `size` distinct elements was drawn,
        ``np.mean`` of an empty list yields nan; note that ``nan < x`` is False.
    """
    # Cheap exits first, so f(S) is not evaluated when the answer is known.
    if len(X) == 0 or size <= 0:
        return 0
    fS = estimator(f, S)
    draws = np.random.choice(list(X), int(size * samples * 1.5))
    chunks = [set(draws[size * i:size * (i + 1)]) for i in range(len(draws) // size)]
    # Sampling is with replacement, so drop chunks that contain duplicates.
    full_chunks = [c for c in chunks if len(c) == size]
    return np.mean([estimator(f, S.union(R)) for R in full_chunks[:samples]]) - fS
    
def EstimateMarginal(estimator,f,X, S,a,size,samples):
    """Monte-Carlo estimate of element `a`'s marginal contribution on top of S.

    Draws ``int(size*samples*1.5)`` elements of X with replacement, splits them
    into consecutive chunks of `size`, keeps the chunks made of `size` distinct
    elements, and averages ``f(S | R | {a}) - f((S | R) - {a})`` over every kept
    chunk R.
    """
    pool = list(X)
    draws = np.random.choice(pool, int(size * samples * 1.5))
    n_chunks = int(len(draws) / size)

    subsets = []
    for idx in range(n_chunks):
        chunk = set(draws[idx * size:(idx + 1) * size])
        # Sampling is with replacement; keep only chunks of distinct elements.
        if len(chunk) == size:
            subsets.append(chunk)

    gains = []
    for R in subsets:
        with_a = S.union(list(R) + [a])
        without_a = S.union(R).difference([a])
        gains.append(estimator(f, with_a) - estimator(f, without_a))
    return np.mean(gains)

def Filter(estimator,f,X, S,size,m,r):
    """Discard low-marginal elements from X until a random block of X is good enough.

    Reads the module globals `epsilon`, `OPT` and `k`. While the estimated gain of
    a random `size`-subset of X (EstimateSet) is below
    ``(1-epsilon)*(OPT-f(S))/r``, estimate each element's marginal contribution
    (EstimateMarginal, `m` samples) and remove every element whose estimate is
    below ``(1+epsilon/2)*(1-epsilon)*(OPT-f(S))/k``.

    Returns:
        The filtered candidate set (a subset of X); empty if every element was
        filtered out.
    """
    global k
    global OPT
    fS = estimator(f,S)
    # Both thresholds depend only on f(S), which is fixed for this call.
    set_threshold = (1-epsilon)*(OPT-fS)/r
    elem_threshold = (1+epsilon/2)*(1-epsilon)*(OPT-fS)/k
    while EstimateSet(estimator,f,X, S,size,m) < set_threshold:
        candidates = list(X)
        marginals = [EstimateMarginal(estimator,f,X, S,a,size,m) for a in candidates]
        A = [a for a, v in zip(candidates, marginals) if v < elem_threshold]
        X = X.difference(A)
        # With X empty, EstimateSet returns 0 and nothing more can be removed,
        # so continuing would loop forever whenever the threshold is positive.
        if not X:
            break
    return X

def Filter2(estimator,f,X, S,size,m,r):
    """Variant of Filter that reuses one batch of random subsets for every element.

    Reads the module globals `epsilon`, `OPT` and `k`. While EstimateSet reports a
    gain below ``(1-epsilon)*(OPT-f(S))/r``, draw one batch of random `size`-subsets
    of X. Estimate every element's marginal contribution against that same batch,
    then drop the elements whose estimate is below
    ``(1+epsilon/2)*(1-epsilon)*(OPT-f(S))/k``.

    Returns:
        The filtered candidate set (a subset of X); empty if X runs out.
    """
    fS = estimator(f,S)
    # Both thresholds depend only on f(S), which is fixed for this call.
    set_threshold = (1-epsilon)*(OPT-fS)/r
    elem_threshold = (1+epsilon/2)*(1-epsilon)*(OPT-fS)/k
    while EstimateSet(estimator,f,X, S,size,m) < set_threshold:
        if not X:
            # np.random.choice cannot sample from an empty population.
            break
        lX = list(X)
        Rand = np.random.choice(lX,int(size*m*1.5))
        # Integer division: len(Rand)/size is a float on Python 3 and range() rejects it.
        Rtmp = [set(Rand[size*i:size*(i+1)]) for i in range(len(Rand)//size)]
        # Sampling is with replacement; keep only subsets of distinct elements.
        Rset = [R for R in Rtmp if len(R)==size]
        vsxa = [np.mean([estimator(f,S.union(list(R)+[a]))-estimator(f,(S.union(R)).difference([a])) for R in Rset]) for a in lX]
        A = [a for a, v in zip(lX, vsxa) if v < elem_threshold]
        X = X.difference(A)
    return X
    
#   implementation of AmortizedFiltering 
def AmortizedFiltering(estimator,f, r,epsilon,OPT,m):
    """Adaptive-sampling maximization of f subject to |S| <= k.

    Reads the module globals `k` (cardinality bound) and `N` (ground set). Runs
    ``int(20/epsilon)`` epochs. In each epoch it repeatedly filters the ground set
    (Filter) and adds a random block of ``int(k/r)`` surviving elements to T. This
    continues until T's gain over S reaches ``(epsilon/20)*(OPT-f(S))`` or S|T
    holds k elements; then T is merged into S.

    NOTE(review): Filter reads the module globals `epsilon` and `OPT`, not the
    parameters of the same name passed here; callers must keep them in sync.

    Args:
        estimator: callable ``estimator(f, S) -> number``.
        f: function description passed to `estimator`.
        r: number of blocks the budget k is split into (block size int(k/r)).
        epsilon: accuracy parameter.
        OPT: estimate of the optimum value.
        m: number of samples Filter uses per estimate.

    Returns:
        The selected set S, with at most k elements.
    """
    global k
    global N

    # At least one element per block, otherwise T could never grow.
    size = max(1, int(k/r))
    S = set()
    sN = set(N)
    for epoch in range(int(20/epsilon)):
        X = sN.copy()
        T = set()
        fS = estimator(f,S)
        while estimator(f,S.union(T)) - fS < (epsilon/20)*(OPT-fS) and len(S.union(T))<k:
            X = Filter(estimator,f,X, S,size,m,r)
            if not X:
                # Nothing survived the filter; there is nothing left to sample.
                break
            # Cap the block so that S|T can never exceed k elements.
            room = k - len(S.union(T))
            # random.sample requires a sequence on Python >= 3.11.
            T = T.union(rd.sample(list(X), min(size, len(X), room)))
        S = S.union(T)
    print(('jadaptive',estimator(f,S),S))
    return S
    
   
    
    
# GSAS    
def ALG_nonpar(estimator,f,N,k,epsilon,Delta,OPT,maxfa):
    """GSAS: threshold-sampling maximization of f subject to |S| <= k.

    The threshold starts at t = min(OPT/(epsilon*k), maxfa) and shrinks by a
    factor (1-epsilon) after each of `Delta` rounds. Within a round it repeatedly:
      1. draws a random ordered block A of min(k-|S|, |X|) candidates;
      2. finds the shortest prefix A[:i] after which fewer than a (1-epsilon)
         fraction of X still has marginal >= t;
      3. adds that prefix to S and shrinks X to the elements that still pass.
    Prefixes of length 0 and 1 are checked first, so the full prefix scan is
    only paid when they do not trigger.

    Args:
        estimator: callable ``estimator(f, S) -> number``.
        f: function description passed to `estimator`.
        N: ground set.
        k: cardinality bound.
        epsilon: accuracy parameter.
        Delta: number of threshold rounds.
        OPT: estimate of the optimum value.
        maxfa: upper cap on the initial threshold t.

    Returns:
        (S, c): the selected set and the number of inner-loop iterations performed.
    """
    t=min(OPT/(epsilon*k),maxfa)
    S = set()
    c = 0
    for j in range(Delta):
        X = set(N)
        while len(X)>0:
            c=c+1
            if(len(S)==k):
                return S,c
            sample_size = min([k-len(S),len(X)])
            # random.sample requires a sequence on Python >= 3.11; list(X)
            # keeps the set's iteration order.
            A = rd.sample(list(X),sample_size)
            # base[i] = f(S | first i elements of A)
            base = [estimator(f,S.union(A[:i])) for i in range(sample_size)]

            # Cheap pass: check only the prefixes of length 0 and 1.
            small_size = min(2,sample_size)
            XX = [[a for a in X if estimator(f,S.union(A[:i]+[a]))-base[i]>=t] for i in range(small_size)]
            r = [i for i in range(small_size) if len(XX[i])<(1-epsilon)*len(X)]
            if len(r)>0:
                ii = min(r)
                S = S.union(A[:ii])
                X = set(XX[ii])
                continue

            # Full pass over the remaining prefixes.
            XX = XX + [[a for a in X if estimator(f,S.union(A[:i]+[a]))-base[i]>=t] for i in range(small_size,sample_size)]
            r = [i for i in range(sample_size) if len(XX[i])<(1-epsilon)*len(X)]
            if len(r)==0:
                # No prefix shrinks X enough: take the whole block and move on
                # to the next (lower) threshold.
                S = S.union(A)
                break
            ii = min(r)
            S = S.union(A[:ii])
            X = set(XX[ii])
        t = t*(1-epsilon)
    return S,c
    

# choose the k elements with the largest singleton values
def ALG_add(estimator,f,N,k):
    """TOP-k baseline: return the k elements with the largest singleton values.

    The elements are shuffled before ranking, which randomizes the order in
    which equal-valued elements reach argsort. Returns an empty set when
    k <= 0.
    """
    if k <= 0:
        # argsort()[-0:] would select every element rather than none.
        return set()
    NN = list(N)
    rd.shuffle(NN)
    res = [estimator(f,set([i])) for i in NN]
    arr = np.array(NN)
    ind = np.array(res).argsort()[-k:]
    return set(arr[ind])

    
# sample `tries` random sets and return the one with the maximal value
def ALG_bestsample(estimator,f,N,k,tries):
    """Best-of-random-samples baseline.

    Draws `tries` uniformly random k-subsets of N and returns the one with the
    highest estimator value. Among equally good samples the later one wins.
    Returns an empty set when tries <= 0.
    """
    best_set = set()
    best_val = -np.inf
    for _ in range(tries):
        candidate = set(rd.sample(N,k))
        val = estimator(f,candidate)
        # `>=` so a later sample replaces an equally good earlier one.
        if val >= best_val:
            best_set, best_val = candidate, val
    return best_set
    
    
# greedy algorithm    
def ALG_greedy(estimator,f,N,k,rounds):
    """Greedy baseline: add the best not-yet-chosen element, for min(rounds, k) steps.

    Each step evaluates estimator(f, S | {i}) for every element i of N not yet in
    S and adds the maximizer (the first one on ties, like np.argmax). Stops early
    once every element of N has been chosen.
    """
    S = set()
    pool = list(N)
    for _ in range(min(rounds,k)):
        # Skip chosen elements: re-picking one (S | {i} == S) would waste a step,
        # which otherwise happens whenever the marginals tie.
        candidates = [i for i in pool if i not in S]
        if not candidates:
            break
        values = np.array([estimator(f,S.union([i])) for i in candidates])
        S = S.union([candidates[values.argmax()]])
    return S

    
    
    
# evaluate f on S (assumes f is of norm structure)
def estimateNorm(f,S):
    """Evaluate a norm-structured function on S.

    f['vals'] holds one value per element and f['norm'] is p. For p == -1 this
    returns the largest value in S; otherwise it returns
    (max over s in S of vals[s]**p) ** (1/p).
    NOTE(review): an L^p norm would sum the powers; taking the max means that for
    p > 0 and non-negative values the result is simply the largest value.

    Returns 0 for an empty S, matching the other estimators (max() of an empty
    sequence would raise ValueError).
    """
    if not S:
        return 0
    norm = f['norm']
    vals = f['vals']
    if norm==-1:
        return max(vals[s] for s in S)
    return np.power(max(vals[s]**norm for s in S),1./norm)
    
# evaluate f on S (assumes f is of 4types structure)
def estimatelong(f,S):
    """Evaluate a 4types-structured function on S.

    Each element i of S is counted in every bucket whose range contains it:
    type 1 is i < f['n1'], type 2 is f['n1'] <= i < f['n2'], type 3 is
    f['n2'] <= i < f['n3'], and type 4 is f['n3'] <= i < f['n4']. A type
    contributes its weight (12, 9, 6, 4) per element, up to its cap
    f['Vn1'] .. f['Vn4'] elements.
    """
    b1, b2, b3, b4 = f['n1'], f['n2'], f['n3'], f['n4']
    counts = [0, 0, 0, 0]
    for i in S:
        if i < b1:
            counts[0] += 1
        if b1 <= i < b2:
            counts[1] += 1
        if b2 <= i < b3:
            counts[2] += 1
        if b3 <= i < b4:
            counts[3] += 1
    weights = (12, 9, 6, 4)
    caps = (f['Vn1'], f['Vn2'], f['Vn3'], f['Vn4'])
    return sum(w * min(c, cap) for w, c, cap in zip(weights, counts, caps))

# evaluate f on S (assumes f is of GBN structure)
def estimate(f,S):
    """Evaluate a GBN-structured function on S.

    Elements i < f['nG'] are "G", elements with f['nG'] <= i < f['nG']+f['nB']
    are "B", and every other element of S counts as "M".
    Scoring:
      - up to f['gV'] M-elements score 1 each;
      - the remaining gV capacity is filled with G-elements at 2 each;
      - up to f['bV'] of the leftover G/B elements score 2 each;
      - all other G/B elements score 1 each.
    """
    g_end = f['nG']
    b_end = f['nG'] + f['nB']
    n_good = sum(1 for i in S if i < g_end)
    n_bad = sum(1 for i in S if g_end <= i < b_end)
    n_rest = len(S) - n_good - n_bad

    taken_m = min(n_rest, f['gV'])
    top = min(n_good, f['gV'] - taken_m)
    runner_up = min(n_bad + n_good - top, f['bV'])
    leftover = n_good + n_bad - runner_up - top
    return taken_m + 2 * top + 2 * runner_up + leftover

# evaluate f on S (assumes f is of GB structure)
def estimatesimple(f,S):
    """Evaluate a GB-structured function on S.

    Elements i < f['nG'] are "G" and elements with f['nG'] <= i < f['nG']+f['nB']
    are "B". The value is the G count capped at f['gV'] plus the B count capped
    at f['bV']; all other elements contribute nothing.
    """
    g_end = f['nG']
    b_end = g_end + f['nB']
    good = 0
    bad = 0
    for i in S:
        if i < g_end:
            good += 1
        if g_end <= i < b_end:
            bad += 1
    return min(good, f['gV']) + min(bad, f['bV'])

# return a random function of norm structure           
def randomNorm(N,norm):
    """Build a norm-structured function with one uniform [0, 1) value per element of N.

    Draws from Python's `random` module, one rd.random() call per element.
    """
    values = []
    for _ in N:
        values.append(rd.random())
    return {'norm': norm, 'vals': values}

# return a function of 4types structure               
def randomlong(N,k):
    """Build a 4types-structured function over ground set N for budget k.

    With n = len(N), the index ranges are:
      - type 1: [0, k);
      - types 2 and 3: int((n-k)/3) indices each;
      - type 4: the remainder up to n.
    f['n1'] .. f['n4'] store the cumulative upper bounds. Each type saturates
    after k // 3 chosen elements (f['Vn1'] .. f['Vn4']). Despite the name,
    nothing here is random.
    """
    n = len(N)
    n1 = k
    n2 = int((n-n1)/3)  # alternative tried: int(np.log(n)**1)*k
    n3 = int((n-n1)/3)  # alternative tried: int(np.log(n)**2)*k
    n4 = n - n3 - n2 - n1
    # Caps bound element *counts*; integer division keeps them integral
    # (k/3 is a float on Python 3, e.g. 3.33 for k=10).
    cap = k // 3
    f={}
    f['n1']=n1
    f['n2']=n1+n2
    f['n3']=n1+n2+n3
    f['n4']=n1+n2+n3+n4
    f['Vn1'] = cap
    f['Vn2'] = cap
    f['Vn3'] = cap
    f['Vn4'] = cap
    return f

# return a function of GBN/GB structure                   
def randomGBM(N):
    """Build a GBN/GB-structured function whose sizes grow with log(len(N)).

    With L = ln(len(N)):
      - gV = int(L**2) and bV = int(L);
      - the G range has 2*int(L**2) elements and the B range 2*int(L**4).
    Prints a one-line summary of the sizes.
    """
    n = len(N)
    log_n = np.log(n)
    f = {}
    f['gV'] = int(log_n**2)
    f['bV'] = int(log_n**1)
    f['nG'] = 2*int(log_n**2)
    f['nB'] = 2*int(log_n**4)
    f['n'] = n
    print('f=',n,f['nG'],f['nB'],f['gV'],f['bV'])
    return f
        
# evaluates G on S (assumes G is a graph structure)
def OXS_grapg(G,S):
    """Value of S as the maximum-weight matching between all players and the elements in S.

    Uses the module-global dict R as a memo cache:
      - R['t'] counts calls;
      - R['s'] counts cache hits;
      - every other key maps a hash of S to its cached value.
    NOTE(review): the cache key does not include G, so R must be reset before a
    different graph is evaluated.
    """
    global R
    R['t']=R['t']+1
    # Sort so that equal sets always give the same key; sha256 needs bytes on Python 3.
    ss = hl.sha256(str(sorted(S, key=repr)).encode('utf-8')).hexdigest()
    if ss in R:
        R['s']=R['s']+1
        return R[ss]
    players, elements = nx.bipartite.sets(G)
    A = G.subgraph(S.union(players))
    H = nx.algorithms.max_weight_matching(A)
    if isinstance(H, dict):
        # Mate dict (older networkx): every matched edge appears as u->v and v->u.
        max_wt = sum([G[u][v]["weight"]/2.0 for u,v in H.items()])
    else:
        # Set of (u, v) pairs (networkx >= 2): each matched edge appears once.
        max_wt = sum([G[u][v]["weight"] for u,v in H])
    R[ss] = max_wt
    return max_wt




# evaluates f on S (assumes f holds a cost matrix built from a graph; faster implementation)
def linear_sum_assignment_S(f,S):
    """Value of S as a maximum-weight assignment of the rows in S to distinct columns.

    f['cost'] is a 2-D cost matrix with one row per element. graphToMat builds it
    as negated edge weights with one column per player. The result is minus the
    minimum total cost of assigning the rows listed in S to distinct columns.
    Returns 0 for an empty S.
    """
    if len(S)==0: return 0
    G = f['cost']
    sub = G[list(S)]
    # When S has more rows than there are columns, only len(col_ind) rows get
    # assigned, so index through row_ind instead of assuming every row is used.
    row_ind, col_ind = linear_sum_assignment(sub)
    max_wt = -sub[row_ind, col_ind].sum()
    return max_wt

# translate a graph data structure to our format
def graphToMat(G):
    """Convert a weighted bipartite networkx graph into the cost-matrix format.

    Returns a dict with:
      'cost': array of shape (len(elements), len(players)); entry [j, i] is minus
          the biadjacency entry between player i and element j.
      'elements': the element nodes, in the same order as the rows of 'cost'.
    """
    f = {}
    players, elements = nx.bipartite.sets(G)
    # Fix one iteration order so the rows of 'cost' and 'elements' line up;
    # np.array() of a raw set would give a 0-d object array instead.
    players = list(players)
    elements = list(elements)
    A = nx.algorithms.bipartite.matrix.biadjacency_matrix(G,players,elements)
    f['elements'] = np.array(elements)
    # Dense, transposed and negated: the same values as indexing A entry by
    # entry, without O(rows*cols) sparse lookups.
    f['cost'] = -np.asarray(A.toarray()).T
    return f

def random_grapg(n,m,r, low_players=150, low_elements=350, n_hubs=10, hub_degree=50):
    """Build a random weighted bipartite test graph designed to mislead random baselines.

    Assumes networkx's numbering for bipartite.random_graph(n, m, r): one side is
    0..n-1 and the other n..n+m-1 (the hub edges below rely on this).
    Construction:
      - An edge (u, v) gets a weight in [0.9, 1) only when u >= low_players and
        v >= low_elements; every other edge gets a weight in [0, 0.1).
      - `n_hubs` nodes are drawn (with replacement) from 0..n-1, and each is
        joined with weight 1 to nodes n..n+hub_degree-1.
    The defaults reproduce the original hard-coded constants.
    """
    G = nx.bipartite.random_graph(n, m, r)

    # this is to fool random
    for u,v in G.edges():
        if u < low_players or v < low_elements:
            G[u][v]["weight"] = np.random.uniform(0,0.1)
        else:
            G[u][v]["weight"] = np.random.uniform(0.9,1)
    # generate hub edges
    hubs = np.random.choice(list(range(n)), n_hubs)
    edges_to_add = [(e, n + j, 1) for e in hubs for j in range(hub_degree)]
    G.add_weighted_edges_from(edges_to_add)
    return G

def simulation(f_type,f_data,outputname):
    """Run every algorithm on one function family, plot the values and pickle the sets.

    Args:
        f_type: one of
            'simple': GBN functions (evaluated with `estimate`) on range(f_data['n']);
            'long': 4types functions on range(f_data['n']);
            'graph': assignment functions built from the bipartite graph f_data['G'].
            Any other value returns immediately.
        f_data: dict holding 'n' or 'G' as above.
        outputname: path of the pickle file that receives the functions and chosen sets.

    Reads the module globals `epsilon` and `iterations`. Sets the globals `k`, `N`,
    `OPT`, `f` and `estimator`, and resets the memo cache `R`. In iteration i the
    budget is k = 10*(i+1); after each iteration the values of all algorithms so
    far are plotted against k.
    """
    global k
    global epsilon
    global iterations
    global R
    global N
    global OPT
    global f
    global estimator

    # Memo cache and hit/call counters used by OXS_grapg. Declared global so
    # the counters printed below are the ones OXS_grapg updates; reset per run
    # because its keys do not identify the function being evaluated.
    R = {}
    R['s'] = 0
    R['t'] = 0

    f = [0]*iterations
    OPTS =[set()]*iterations
    S1 = [set()]*iterations
    S2 = [set()]*iterations
    S3 = [set()]*iterations
    S4 = [set()]*iterations
    S5 = [set()]*iterations
    Rounds = [0]*iterations

    if f_type =='simple':
        n = f_data['n']
        N= range(n)
        f = [randomGBM(N) for i in range(iterations)]
        maxfa = [1]*iterations
        estimator = estimate
    elif f_type =='long':
        n = f_data['n']
        N = range(n)
        f = [randomlong(N,(i+1)*10) for i in range(iterations)]
        maxfa = [12]*iterations
        estimator = estimatelong
    elif f_type =='graph':
        G = f_data['G']
        f = [graphToMat(G) for i in range(iterations)]
        players, elements = nx.bipartite.sets(G)
        N = range(len(elements))
        maxfa = [-np.min(f[i]['cost']) for i in range(iterations)]
        estimator = linear_sum_assignment_S
    else:
        return

    Delta = int(1/(epsilon**2))
    for i in range(iterations):
        k=(i+1)*10
        r = k

        # Greedy with k rounds serves as the OPT reference.
        OPTS[i] = ALG_greedy(estimator,f[i],N,k,k)
        OPT = estimator(f[i],OPTS[i])
        S1[i],Rounds[i] = ALG_nonpar(estimator,f[i],N,k,epsilon,Delta,OPT,maxfa[i])
        S2[i] = ALG_add(estimator,f[i],N,k)
        # Greedy limited to as many rounds as GSAS used.
        S3[i] = ALG_greedy(estimator,f[i],N,k,Rounds[i])
        S4[i] = ALG_bestsample(estimator,f[i],N,k,len(N))
        S5[i] = AmortizedFiltering(estimator,f[i], r,epsilon,OPT,5) #r=20 log1+/3(n)/

        print([R['s'],R['t']])
        budgets = [(j+1)*10 for j in range(i+1)]
        series = [
            ('OPT', OPTS),
            ('GSAS', S1),
            ('greedy', S3),
            ('TOPk', S2),
            ('bestSample', S4),
            ('Asampling', S5),
        ]
        handles = []
        for label, chosen in series:
            values = [estimator(f[j],chosen[j]) for j in range(i+1)]
            line, = plt.plot(budgets,values, label=label)
            handles.append(line)
        plt.legend(bbox_to_anchor=(0.4, 1),handles=handles)
        plt.show()

    a = {'f':f,'OPTS':OPTS,'S1':S1,'S2':S2,'S3':S3,'S4':S4,'S5':S5}
    with open(outputname,"wb+") as out:
        pickle.dump(a,out)

    
def simul():
    """Run all experiments G1..G8 and write each result pickle under ./output/.

    Sets the module globals `R`, `iterations` (10) and `epsilon` (0.1) that
    simulation() reads. The experiments are:
      - G1 and G2: synthetic 'long' and 'simple' functions on n = 5000 elements;
      - G3 and G4: random bipartite graphs;
      - G5 to G8: graphs unpickled from ./data.
    """
    global R
    R = {}
    R['s'] = 0
    R['t'] = 0

    f_data = {}
    global epsilon
    global iterations
    iterations = 10
    epsilon = 0.1

    #G1
    f_data['n'] = 5000
    simulation('long',f_data,'./output/G1.pkl')
    #G2
    simulation('simple',f_data,'./output/G2.pkl')
    #G3
    f_data['G'] = random_grapg(275,200,0.25)
    simulation('graph',f_data,'./output/G3.pkl')
    #G4
    f_data['G'] = random_grapg(275,200,0.75)
    simulation('graph',f_data,'./output/G4.pkl')
    #G5..G8
    # NOTE: pickle.load can execute arbitrary code; only load trusted data files.
    graph_runs = [
        ('./data/G_ad.gpickle', './output/G5.pkl'),
        ('./data/G_spon.gpickle', './output/G6.pkl'),
        ('./data/G_giveaway.gpickle', './output/G7.pkl'),
        ('./data/G_win.gpickle', './output/G8.pkl'),
    ]
    for graph_path, output_path in graph_runs:
        with open(graph_path,'rb') as fh:
            f_data['G'] = pickle.load(fh)
        simulation('graph',f_data,output_path)
# NOTE(review): this runs every experiment as soon as the module is imported;
# consider guarding it with `if __name__ == "__main__":` so the helpers can be
# imported without side effects.
simul()