import numpy as np
import math
import string


def index(array, M, m):
    """Return the first position ``p < M`` with ``array[p] == m``.

    Only the first ``M`` entries of ``array`` are searched. When ``m`` is not
    among them, ``M`` itself is returned, so the result can be compared as a
    rank where "absent" sorts after every present entry.
    """
    return next((pos for pos in range(M) if array[pos] == m), M)


def encode(num, bit):
    """Return the binary digits of ``num`` (most significant first) as ints.

    The digit list is left-padded with zeros to ``bit`` entries. If ``num``
    needs more than ``bit`` digits, all of them are returned; nothing is
    truncated.
    """
    digits = bin(num)[2:].zfill(bit)
    return [int(digit) for digit in digits]


def decode(C):
    """Interpret the 1-D bit sequence ``C`` (most significant first) as an int.

    Only entries equal to 1 contribute; any other value counts as a 0 bit.
    """
    width = np.shape(C)[0]
    return sum(2 ** (width - pos - 1) for pos in range(width) if C[pos] == 1)


def collision_indicator(j, arm_preference, pull, M):
    """Return 1 if agent ``j`` is blocked on the arm it pulls, else 0.

    Agent ``j`` is blocked when some agent among the first ``M`` pulls the
    same arm (``pull[j]``) and is ranked strictly earlier than ``j`` in that
    arm's preference row ``arm_preference[pull[j]]``. Ranks come from
    ``index``; an agent missing from the row gets rank ``M``.
    """
    contested_arm = pull[j]
    # Every candidate is evaluated (no short-circuit); only whether the list
    # is empty matters.
    blockers = [
        rival
        for rival in range(M)
        if pull[rival] == contested_arm
        and index(arm_preference[contested_arm], M, rival)
        < index(arm_preference[contested_arm], M, j)
    ]
    return 1 if blockers else 0


def index_assignment(arm_preference, M, K, information, t_total, t_total_collision):
    """Let every agent learn its rank on each arm by walking through the arms.

    Every agent starts on arm 0. There are ``K`` phases (one per arm ``k``)
    of ``M`` rounds each. In every round all ``M`` agents pull their current
    arm. An agent that pulls arm ``k`` during phase ``k`` without being
    blocked (``collision_indicator == 0``) moves on to the next arm (mod
    ``K``) and writes its own id into ``information[i, k, t]``.

    Args:
        arm_preference: per-arm preference rows, indexed ``arm_preference[arm]``.
        M: number of agents.
        K: number of arms.
        information: array indexed ``[agent, arm, round]``; updated in place.
        t_total: ``[agent, arm]`` counts of pulls that did not collide;
            updated in place.
        t_total_collision: ``[agent, arm]`` counts of all pulls; updated in
            place.

    Returns:
        ``(information, t_total, t_total_collision)`` (the same objects passed in).
    """
    pull = np.zeros(M, int)
    arm = np.zeros(M, int)
    for k in range(K):
        for t in range(M):
            # Freeze this round's choices; ``arm`` may advance while scoring.
            pull[:] = arm
            for i in range(M):
                collided = collision_indicator(i, arm_preference, pull, M)
                t_total[i, pull[i]] += 1 - collided
                t_total_collision[i, pull[i]] += 1
                if pull[i] == k and collided == 0:
                    arm[i] = (arm[i] + 1) % K
                    information[i, k, t] = i
    return information, t_total, t_total_collision


def information_exchange(k1, k2, k, arm_preference, K, M, information, t_total, t_total_collision):
    """Send one agent's ``information`` table to another, one bit per round.

    The sender is the agent ``j`` with ``information[j, k, k1] == j``. The
    receiver is the agent ``i`` with ``information[i, k, k2] == i``. For every
    entry ``(k3, k4)`` the sender transmits ``information[j, k3, k4] + 1``
    bit by bit:

    * a 1 bit: the sender pulls arm ``k``;
    * a 0 bit: the sender pulls arm ``(k + 1) % K``;
    * the receiver always pulls arm ``k``;
    * every other agent pulls arm ``(k + 1) % K``.

    The receiver records its own ``collision_indicator`` as the bit. After
    the last bit it decodes the word. A nonzero result is stored as
    ``information[i, k3, k4] = word - 1``; an all-zero word is ignored.

    NOTE(review): this relies on the sender (rank k1) being preferred by
    arm ``k`` over the receiver (rank k2), so the sender's presence on arm
    ``k`` is what blocks the receiver. Callers pass ``k1 < k2``.

    Returns:
        ``(information, t_total, t_total_collision)``; all updated in place.
    """
    pull = np.zeros(M, int)
    # Transmitted values are information + 1. Assuming entries are agent ids
    # below M, they reach M, and M needs M.bit_length() bits.
    # ceil(log2(M)) is one bit short whenever M is a power of two.
    n_bits = int(M).bit_length()
    for k3 in range(K):
        next_arm = (k + 1) % K
        for k4 in range(M):
            C = np.zeros(n_bits, int)
            for k5 in range(n_bits):
                for j in range(M):
                    if information[j, k, k1] == j:
                        bit = encode(information[j, k3, k4] + 1, n_bits)[k5]
                        pull[j] = k if bit == 1 else next_arm
                    elif information[j, k, k2] == j:
                        pull[j] = k
                    else:
                        pull[j] = next_arm
                for i in range(M):
                    collided = collision_indicator(i, arm_preference, pull, M)
                    t_total[i, pull[i]] += 1 - collided
                    t_total_collision[i, pull[i]] += 1
                    if k1 != k2 and information[i, k, k2] == i:
                        C[k5] = collided
                        if k5 == n_bits - 1:
                            word = decode(C)
                            if word != 0:
                                information[i, k3, k4] = word - 1
    return information, t_total, t_total_collision


def information_access(arm_preference, M, K, information, t_total, t_total_collision):
    """Run one information-exchange pass over every arm and rank pair.

    For each arm ``k`` in ``range(K)`` and each pair of ranks
    ``k1 < k2 < M``, calls ``information_exchange`` and threads the updated
    arrays through.

    Returns:
        ``(information, t_total, t_total_collision)``. This is also returned
        when ``M == 0``, in which case nothing is exchanged.
    """
    # NOTE(review): this used to sit inside an unused ``for m in range(M)``
    # loop whose body returned on the first iteration, so only one pass ever
    # ran. If M repeated passes were intended, reintroduce the loop with the
    # return placed after it.
    for k in range(K):
        for k1 in range(M):
            for k2 in range(k1 + 1, M):
                information, t_total, t_total_collision = information_exchange(
                    k1, k2, k, arm_preference, K, M, information, t_total, t_total_collision)
    return information, t_total, t_total_collision


def whether_leader(information, agent_0, arm, agent, M, K=3):
    """Return the ids in ``range(K)`` that act as leaders, from agent 0's view.

    Reads ``information[0][k][0]`` for ``k`` in ``range(K)``:

    * If all these entries are equal (call the shared value ``top``), the
      leaders are every ``k`` except ``top``.
    * Otherwise every ``k`` in ``range(K)`` is a leader.

    Args:
        information: indexed ``information[agent][arm][rank]``.
        agent_0, arm, agent, M: unused; kept for interface compatibility.
        K: how many indices to inspect. Defaults to 3, the value this
            function always used before it was a parameter. (It computed
            ``len(arm)`` and then immediately overwrote it with 3.)
            NOTE(review): callers most likely want ``K=len(arm)``; confirm.

    Returns:
        A list of ints in increasing order.
    """
    firsts = [information[0][k][0] for k in range(K)]
    uniform = all(a == b for a, b in zip(firsts, firsts[1:]))
    if not uniform:
        return list(range(K))
    return [k for k in range(K) if k != firsts[k]]


def findtime(t_total_collision, K, agent):
    """Return the total number of pulls recorded for one agent on arms ``0..K-1``.

    Args:
        t_total_collision: ``[agent][arm]`` pull counts.
        K: number of arm columns to sum.
        agent: either an integer agent id (Python ``int`` or NumPy integer)
            or a sequence whose first element is the agent id.

    Returns:
        ``sum(t_total_collision[agent_id][j] for j in range(K))``, accumulated
        left to right starting from 0.
    """
    # isinstance, not ``type(...) == int``: NumPy integers (e.g. elements of
    # an agent array) must be treated as ids, not indexed with [0].
    k = agent if isinstance(agent, (int, np.integer)) else agent[0]
    t = 0
    for j in range(K):
        t = t + t_total_collision[k][j]
    return t


def success_information(leader, follower, arm, success, arm_preference, information, M, t_total, t_total_collision, agent, reward, K, value):
    """Exchange ``success`` flags between agents of one layer via collisions.

    The whole schedule is repeated ``len(leader)`` times. Each repetition
    goes over every layer arm ``arm[k]`` and every rank pair ``m1 < m2 < M``:

    * A leader ranked ``m1`` on ``arm[k]`` (per its own ``information``)
      pulls ``arm[k]`` when ``success == 0``, otherwise
      ``arm[(k + 1) % len(arm)]``.
    * Any agent ranked ``m2`` on ``arm[k]`` pulls ``arm[k]``.
    * Everyone else pulls ``arm[(k + 1) % len(arm)]``.
    * A follower ranked ``m2`` whose rank-``m1`` agent is a leader sets
      ``success = 1 - collision``.
    * A leader ranked ``m2`` with ``success == 1`` sets
      ``success = 1 - collision``.

    Whenever ``findtime(t_total_collision, K, agent)`` is a multiple of
    100000, ``sum(t_total * value)`` over all ``M`` agents and ``K`` arms is
    stored in ``reward[count / 100000]``.

    Returns:
        ``(success, t_total, t_total_collision, reward)``; all updated in place.
    """
    K_layer = len(arm)
    M_layer = len(leader)
    # -1 marks agents outside this layer: they never share a real arm.
    pull = np.zeros(M, int) - 1
    for _ in range(M_layer):
        for k in range(K_layer):
            this_arm = arm[k]
            other_arm = arm[(k + 1) % K_layer]
            for m1 in range(M):
                for m2 in range(m1 + 1, M):
                    for j in leader:
                        if information[j, this_arm, m1] == j:
                            pull[j] = this_arm if success[j] == 0 else other_arm
                        elif information[j, this_arm, m2] == j:
                            pull[j] = this_arm
                        else:
                            pull[j] = other_arm
                    for i in follower:
                        pull[i] = this_arm if information[i, this_arm, m2] == i else other_arm
                    for j in follower:
                        # collision_indicator takes exactly 4 arguments; the
                        # former 5-argument call raised TypeError.
                        collided = collision_indicator(j, arm_preference, pull, M)
                        t_total[j, pull[j]] += 1 - collided
                        t_total_collision[j, pull[j]] += 1
                        if (information[j, this_arm, m2] == j
                                and information[j, this_arm, m1] in leader):
                            success[j] = 1 - collided
                    for i in leader:
                        collided = collision_indicator(i, arm_preference, pull, M)
                        t_total[i, pull[i]] += 1 - collided
                        t_total_collision[i, pull[i]] += 1
                        if information[i, this_arm, m2] == i and success[i] == 1:
                            success[i] = 1 - collided
                    elapsed = findtime(t_total_collision, K, agent)
                    if elapsed % 100000 == 0:
                        # Distinct loop names so the arm index ``k`` above is
                        # not clobbered for the remaining m1/m2 iterations.
                        earn = 0
                        for a in range(M):
                            for b in range(K):
                                earn = t_total[a, b] * value[a][b] + earn
                        reward[int(elapsed / 100000)] = earn
    return success, t_total, t_total_collision, reward


def GS_arm(information, u, leader, arm_preference, arm, follower, M, agent, K, t_total, t_total_collision, reward, value):
    """Settle the leaders on arms, then probe which layer arms followers can hold.

    Phase 1 runs ``len(leader) ** 2`` rounds:

    * Each leader pulls its best remaining arm of ``arm``, ranked by
      descending ``u[i, :]``.
    * Followers all pull ``arm[0]``.
    * A leader that collides moves on to its next choice.

    Phase 2 goes over each layer arm ``arm[k]`` and each follower in turn:

    * The probing follower pulls ``arm[k]``.
    * The other followers pull ``arm[(k + 1) % len(arm)]``.
    * Leaders keep their phase-1 choice.
    * If the probing follower collides, ``arm[k]`` is removed from
      ``arm_left``.

    Whenever ``findtime(t_total_collision, K, agent)`` is a multiple of
    100000, ``sum(t_total * value)`` over all ``M`` agents and ``K`` arms is
    stored in ``reward[count / 100000]``.

    ``information`` is not read.

    Returns:
        ``(pull, arm_left, t_total, t_total_collision, reward)``.
    """
    K_layer = int(np.shape(arm)[0])
    M_layer_leader = int(np.shape(leader)[0])
    M_layer_follower = int(np.shape(follower)[0])
    estimation = np.zeros((M, K), int)
    # -1 marks agents outside this layer: they never share a real arm.
    pull = np.zeros(M, int) - 1
    optimal = np.zeros(M, int)
    arm_left = list(arm)

    def record_reward():
        # Log cumulative earnings at every 100000th pull of ``agent``.
        elapsed = findtime(t_total_collision, K, agent)
        if elapsed % 100000 == 0:
            earn = 0
            for a in range(M):
                for b in range(K):
                    earn = t_total[a, b] * value[a][b] + earn
            # 100000 here matches the modulus above (was a 10000 typo).
            reward[int(elapsed / 100000)] = earn

    for i in leader:
        estimation[i, :] = u[i, :].argsort()[::-1]
        optimal[i] = 0

    # Phase 1: leaders settle on their best available arm.
    for _ in range(M_layer_leader ** 2):
        for i in leader:
            # NOTE(review): raises IndexError if optimal[i] runs past the last
            # column without finding an arm of ``arm``.
            while estimation[i, optimal[i]] not in arm:
                optimal[i] += 1
            pull[i] = estimation[i, optimal[i]]
        for j in follower:
            pull[j] = arm[0]
        for j in follower:
            collided = collision_indicator(j, arm_preference, pull, M)
            t_total[j, pull[j]] += 1 - collided
            t_total_collision[j, pull[j]] += 1
        for i in leader:
            collided = collision_indicator(i, arm_preference, pull, M)
            t_total[i, pull[i]] += 1 - collided
            t_total_collision[i, pull[i]] += 1
            if collided == 1:
                optimal[i] += 1
        record_reward()

    # Phase 2: each follower probes each layer arm in turn.
    for k in range(K_layer):
        this_arm = arm[k]
        other_arm = arm[(k + 1) % K_layer]
        for t in range(M_layer_follower):
            prober = follower[t]
            for i in leader:
                pull[i] = estimation[i, optimal[i]]
            for j in follower:
                pull[j] = this_arm if j == prober else other_arm
            for i in leader:
                collided = collision_indicator(i, arm_preference, pull, M)
                t_total[i, pull[i]] += 1 - collided
                t_total_collision[i, pull[i]] += 1
            for j in follower:
                collided = collision_indicator(j, arm_preference, pull, M)
                t_total[j, pull[j]] += 1 - collided
                t_total_collision[j, pull[j]] += 1
                if j == prober and collided == 1 and this_arm in arm_left:
                    arm_left.remove(this_arm)
                record_reward()
    return pull, arm_left, t_total, t_total_collision, reward


def exploration(arm, agent, arm_preference, u, T, time, value, M, t_total, t_total_collision, reward_0):
    """Round-robin exploration of the layer arms by the layer agents.

    The exploration runs for ``len(arm) * ceil(log2(T))`` rounds. In round
    ``t`` the ``i``-th agent of ``agent`` pulls ``arm[(i + t) % len(arm)]``.
    Each pull:

    * draws ``N(value[i][arm], 1)``, zeroed if the agent collides;
    * folds the draw into the running mean ``u[i, arm]``;
    * increments ``time[i, arm]``;
    * updates ``t_total`` (non-colliding pulls) and ``t_total_collision``
      (all pulls).

    ``t_0`` starts at agent 0's total pull count (the row sum of
    ``t_total_collision[0, :]``) and advances by one per round. Whenever it
    is a multiple of 100000, ``sum(t_total * value)`` over all ``M`` agents
    and all arms is stored in ``reward_0[t_0 / 100000]``.

    Returns:
        ``(u, time, t_total, t_total_collision, reward_0)``; all updated in place.
    """
    K = int(np.shape(arm)[0])
    M_1 = int(np.shape(agent)[0])
    # -1 marks agents outside this layer, matching the other phases. With 0
    # they would appear to sit on arm 0 and could cause spurious collisions.
    pull = np.zeros(M, int) - 1
    t_1 = math.log(T, 2)
    K_0 = np.shape(t_total_collision)[1]  # total number of arms
    t_0 = 0
    for a in range(K_0):
        t_0 = t_0 + t_total_collision[0, a]
    for t in range(K * math.ceil(t_1)):
        for i in range(M_1):
            pull[agent[i]] = arm[(i + t) % K]
        t_0 = t_0 + 1
        for i in agent:
            collided = collision_indicator(i, arm_preference, pull, M)
            sample = np.random.normal(loc=value[i][pull[i]], scale=1.0, size=None) * (1 - collided)
            u[i, pull[i]] = (u[i, pull[i]] * time[i, pull[i]] + sample) / (time[i, pull[i]] + 1)
            time[i, pull[i]] += 1
            t_total[i, pull[i]] += 1 - collided
            t_total_collision[i, pull[i]] += 1
        if t_0 % 100000 == 0:
            # Sum over every arm (K_0), not just the first len(arm) columns.
            earn = 0
            for a in range(M):
                for b in range(K_0):
                    earn = t_total[a, b] * value[a][b] + earn
            reward_0[int(t_0 / 100000)] = earn
    return u, time, t_total, t_total_collision, reward_0


def whether_success(j, u, time, arm, T):
    """Return 1 if agent ``j``'s arm estimates are all separated, else 0.

    Each arm ``a`` in ``arm`` gets the interval
    ``u[j, a] +/- sqrt(2 * ln(T) / time[j, a])``. The function returns 0 as
    soon as two distinct arms have overlapping intervals; pairs are visited
    in ``arm`` order, both orientations. If no pair overlaps it returns 1.
    """
    log_T = math.log(T)

    def half_width(a):
        return ((2 * log_T) / time[j, a]) ** (1 / 2)

    for a in arm:
        for b in arm:
            if a == b:
                continue
            if u[j, a] > u[j, b]:
                overlap = u[j, a] - half_width(a) < u[j, b] + half_width(b)
            else:
                overlap = u[j, a] + half_width(a) > u[j, b] - half_width(b)
            if overlap:
                return 0
    return 1
