import numpy as np
from dataclasses import dataclass
import math
import matplotlib.pyplot as plt
import sys

# Input:  an underlying BMS (binary-input memoryless symmetric) channel W, a bound
#         mu = 2*nu on the output alphabet size, a code length N = 2^m, and an index i
#         with binary representation i = <b1,b2,...,bm>_2
# Output: a BMS channel that is degraded with respect to the bit channel W_i

# New Data dataclass
@dataclass(eq=False, repr=False)
class Data_Element:
    """One entry of the merge list: a pair of adjacent output symbols.

    ``(a, b)`` are the two transition probabilities of the left symbol of the
    pair and ``(a_prime, b_prime)`` those of the right symbol.  ``deltaI`` is
    the value the heap is ordered by (the capacity lost if the two symbols are
    merged).  ``left`` / ``right`` point to the neighbouring entries of the
    list, or are ``None`` at the ends.

    The fields are annotated so that ``@dataclass`` actually registers them:
    without annotations the decorator sees no fields at all, and the generated
    ``__eq__`` would declare any two elements equal.  Equality is left as
    object identity (``eq=False``) because elements are linked into a list and
    must stay distinguishable.
    """
    a: float = -1.0
    b: float = -1.0
    a_prime: float = -1.0
    b_prime: float = -1.0
    deltaI: float = -1.0
    left: "Data_Element | None" = None
    right: "Data_Element | None" = None
    h: int = -1 # index of the data element in the heap array

    def __repr__(self):
        # The neighbours are left out on purpose: following left/right would
        # walk the whole linked list for every element printed.
        return (
            f"Data_Element(a={self.a!r}, b={self.b!r}, "
            f"a_prime={self.a_prime!r}, b_prime={self.b_prime!r}, "
            f"deltaI={self.deltaI!r}, h={self.h!r})"
        )

@dataclass
class MinHeap:
    """Array-backed binary min-heap ordered by the ``deltaI`` attribute of its items.

    Items are also chained through their ``left`` / ``right`` attributes in
    insertion order, and each item records its current array position in ``h``
    once the heap has been sorted.
    """

    def __init__(self):
        # Backing array; index 0 is the minimum once min_sort() has run
        self.heap = []

    def insert(self, item):
        """Append ``item`` without re-ordering and link it to the previous tail."""
        position = len(self.heap)
        self.heap.append(item)

        # The item appended just before becomes the left neighbour
        if position > 0:
            predecessor = self.heap[position - 1]
            item.left = predecessor
            predecessor.right = item
        # ordering is deferred until the caller invokes min_sort()

    def extract_min(self, fix):
        """Remove and return the item stored at array index 0.

        When ``fix`` is truthy and other items remain, the neighbours of the
        removed item are re-linked to each other and the heap is re-sorted.
        When ``fix`` is falsy the item is simply popped: links and ordering
        are left as they are (used while draining the heap at the end).
        """
        count = len(self.heap)
        top = self.heap[0]
        relink = count > 1 and fix

        if relink:
            before = top.left
            after = top.right
            if before is None:
                # removed item was left-most: its right neighbour becomes the new left end
                self.heap[after.h].left = None
            elif after is None:
                # removed item was right-most: its left neighbour becomes the new right end
                self.heap[before.h].right = None
            else:
                # interior item: join the two neighbours to each other
                self.heap[before.h].right = after
                self.heap[after.h].left = before

        del self.heap[0]
        if relink:
            self.min_sort()
        return top

    def min_sort(self):
        """Rebuild the heap order bottom-up and refresh every item's ``h``."""
        count = len(self.heap)
        if count == 1:
            # a single item has no parent/child pass to record its position
            self.heap[0].h = 0

        for parent in reversed(range(count // 2)):
            self.heapify(count, parent)

    def heapify(self, size, i):
        """Sift the item at index ``i`` down until neither child has a smaller ``deltaI``."""
        nodes = self.heap
        while True:
            left_child = 2 * i + 1
            right_child = 2 * i + 2

            # Record the current positions of the node and its children
            if i < size:
                nodes[i].h = i
            if left_child < size:
                nodes[left_child].h = left_child
            if right_child < size:
                nodes[right_child].h = right_child

            smallest = i
            if left_child < size and nodes[left_child].deltaI < nodes[smallest].deltaI:
                smallest = left_child
            if right_child < size and nodes[right_child].deltaI < nodes[smallest].deltaI:
                smallest = right_child

            if smallest == i:
                return

            nodes[i], nodes[smallest] = nodes[smallest], nodes[i]
            # keep the recorded positions in step with the swap
            nodes[i].h = i
            nodes[smallest].h = smallest

            i = smallest  # continue from the child that received the larger item

    def get_min(self):
        """Return the item at index 0, or ``None`` when the heap is empty."""
        return self.heap[0] if self.heap else None

    def update_vals(self,index,new_deltaI, new_a, new_b, new_a_prime, new_b_prime):
        """Overwrite ``deltaI`` of the item at ``index``; other fields only when not ``None``."""
        target = self.heap[index]
        target.deltaI = new_deltaI
        optional_fields = (
            ("a", new_a),
            ("b", new_b),
            ("a_prime", new_a_prime),
            ("b_prime", new_b_prime),
        )
        for attribute, value in optional_fields:
            if value is not None:
                setattr(target, attribute, value)

    def get_size(self):
        """Return the number of items currently stored."""
        return len(self.heap)

# Algorithm A
def degrading_procedure(W, mu, b):
    """Follow the bit path ``b`` from ``W``, merging after every transform.

    ``W`` is first passed through ``degrading_merge``.  Then, for each entry
    of ``b`` in order, the current channel is transformed with ``square``
    (entry equal to 0) or ``circle`` (any other entry) and merged again with
    the same bound ``mu``.  The final merged channel is returned.
    """
    Q = degrading_merge(W, mu)
    for bit in b:
        transform = square if bit == 0 else circle
        Q = degrading_merge(transform(Q), mu)
    return Q

# Arikan channel transformation 1
def square(W):
    """Apply the first Arikan transform to the channel matrix ``W``.

    ``W`` is indexed ``W[x, y]`` (input row, output column).  The result has
    the same number of rows and ``Y**2`` columns, with

        new_W[u1, y1*Y + y2] = 1/2 * sum_{u2} W[u1^u2, y1] * W[u2, y2]

    The output index is computed arithmetically.  The previous implementation
    concatenated zero-padded binary strings of ``y1`` and ``y2``, which only
    equals ``y1*Y + y2`` when ``Y`` is a power of two; for any other width
    distinct pairs collided or indexed past the end of the array.

    The XOR on the row index means the number of rows is expected to be a
    power of two (2 for a binary-input channel).
    """
    num_rows, num_cols = W.shape[0], W.shape[1]
    new_W = np.zeros((num_rows, num_cols**2))
    for u1 in range(num_rows):
        # outer(...)[y1, y2] = W(y1|u1^u2) * W(y2|u2); accumulate over u2
        joint = sum(np.outer(W[u1 ^ u2, :], W[u2, :]) for u2 in range(num_rows))
        # Row-major flattening places (y1, y2) at column y1*Y + y2
        new_W[u1, :] = joint.ravel() / 2
    return new_W

# Arikan channel transformation 2
def circle(W):
    """Apply the second Arikan transform to the channel matrix ``W``.

    ``W`` is indexed ``W[x, y]`` (input row, output column).  The result has
    the same number of rows ``X`` and ``X * Y**2`` columns, with

        new_W[u2, (y1*Y + y2)*X + u1] = 1/2 * W[u1^u2, y1] * W[u2, y2]

    The output index is computed arithmetically.  The previous implementation
    concatenated zero-padded binary strings of ``y1``, ``y2`` and ``u1``,
    which only matches this formula when ``Y`` is a power of two; for any
    other width the index collided or ran past the end of the array.

    The XOR on the row index means the number of rows is expected to be a
    power of two (2 for a binary-input channel).
    """
    num_rows, num_cols = W.shape[0], W.shape[1]
    new_W = np.zeros((num_rows, (num_cols**2) * num_rows))
    for u2 in range(num_rows):
        for u1 in range(num_rows):
            # outer(...)[y1, y2] = W(y1|u1^u2) * W(y2|u2)
            pair = np.outer(W[u1 ^ u2, :], W[u2, :])
            # u1 is the fastest-varying part of the output index, so its
            # columns are every X-th column starting at offset u1
            new_W[u2, u1::num_rows] = pair.ravel() / 2
    return new_W

# Algorithm C
def degrading_merge(W,mu):
    """Merge output symbols of ``W`` until at most ``mu`` of them remain.

    ``W`` is indexed ``W[x, y]`` (W: X -> Y).  Only the first ``Y // 2``
    columns are read; each is treated as the representative of one conjugate
    pair of outputs.  Representatives are oriented so that their likelihood
    ratio ``W[0, y] / W[1, y]`` is at least 1, sorted by that ratio, and then
    the adjacent pair with the smallest ``deltaI`` is merged repeatedly until
    ``mu // 2`` representatives are left.

    The function works through the module-level ``heap`` (via
    ``insertRightmost`` / ``getMin`` / ``removeMin`` / ``valueUpdated``).  It
    must exist and be empty on entry; it is empty again on return.

    Parameters
    ----------
    W : 2-D numpy array of transition probabilities.  It is not modified.
    mu : bound on the output alphabet size; expected to be even (mu = 2*nu).

    Returns
    -------
    ``W`` itself when ``Y // 2 <= mu // 2``.  Otherwise a new array with
    ``W.shape[0]`` rows and ``mu`` columns in which column ``mu - 1 - i`` is
    column ``i`` with its two rows swapped.

    Raises
    ------
    ValueError
        If ``mu < 2`` while ``W`` has more than one column pair, since not
        even a single conjugate pair would fit.
    """
    # W: X -> Y
    Y = W.shape[1]
    L = Y // 2      # conjugate pairs present in W
    v = mu // 2     # conjugate pairs allowed in the result

    # Degraded W is W itself if output size is already <= mu
    if L <= v:
        return W

    if v < 1:
        raise ValueError("mu must be at least 2 to hold one conjugate pair of outputs")

    # Want: 1 <= LR(y1) <= LR(y2) <= ... <= LR(y_L)

    # Stand-in for exact zeros so that the logarithms stay finite
    epsilon = 10**(-20)

    # Work on a copy of the representatives so the caller's matrix is left untouched
    new_W = W[:, :L].copy()

    # The ratios are compared in the log domain to avoid overflow, so the
    # condition LR >= 1 reads log(LR) >= 0 (not >= 1).
    LR = _log_likelihood_ratio(new_W, epsilon)
    needs_flip = ~(LR >= 0)
    # Swap the two rows of every representative whose ratio is below 1, i.e.
    # use its conjugate symbol instead (right-hand side is a copy)
    new_W[:, needs_flip] = new_W[::-1, :][:, needs_flip]

    # Sort the representatives by ascending likelihood ratio
    LR1 = _log_likelihood_ratio(new_W, epsilon)
    new_W = new_W[:, np.argsort(LR1)]

    # One data element per adjacent pair of representatives
    for i in range(1,L):
        d = Data_Element() # initialize new data element
        d.a = new_W[0,i-1]
        d.b = new_W[1,i-1]
        d.a_prime = new_W[0,i]
        d.b_prime = new_W[1,i]
        d.deltaI = calcDeltaI(d.a,d.b,d.a_prime,d.b_prime)
        insertRightmost(d)

    # Now that heap is built, need to arrange heap according to deltaI value
    heap.min_sort()

    l = L
    a_plus = b_plus = None

    # The heap starts with L-1 elements; every merge removes one
    while l > v:
        d = getMin() # pair whose merge loses the least capacity
        a_plus = d.a + d.a_prime
        b_plus = d.b + d.b_prime
        dLeft = d.left
        dRight = d.right
        removeMin()

        l -= 1

        # The merged symbol replaces the right symbol of the left neighbour ...
        if dLeft is not None:
            new_deltaI = calcDeltaI(dLeft.a, dLeft.b, a_plus, b_plus)
            valueUpdated(dLeft.h, new_deltaI, None, None, a_plus, b_plus)
        # ... and the left symbol of the right neighbour.  dRight.h is read
        # only now because the update above re-sorts the heap.
        if dRight is not None:
            new_deltaI = calcDeltaI(a_plus, b_plus, dRight.a_prime, dRight.b_prime)
            valueUpdated(dRight.h, new_deltaI, a_plus, b_plus, None, None)

    # Initialize Q - output space 2-by-mu
    heapsize = heap.get_size()
    Q = np.zeros((W.shape[0],mu))

    if heapsize == 0:
        # Only reachable when v == 1: every representative has been merged
        # into one symbol, which lives in (a_plus, b_plus) of the last merge
        # and no longer in any heap element.
        Q[0,0] = a_plus
        Q[1,0] = b_plus
        Q[0,mu-1] = b_plus
        Q[1,mu-1] = a_plus
        return Q

    # Construct the returned probability matrix: each remaining element
    # contributes its left symbol and, mirrored, that symbol's conjugate
    for i in range(heapsize):
        min_elem = removeMin(fix = 0)
        Q[0,i] = min_elem.a
        Q[1,i] = min_elem.b
        Q[0,(Q.shape[1]-1)-i] = min_elem.b
        Q[1,(Q.shape[1]-1)-i] = min_elem.a
        if min_elem.right is None: # right end of the list: its right symbol is stored only in the 'prime' fields
            sp_ind = mu//2 - 1
            Q[0,sp_ind] = min_elem.a_prime
            Q[1,sp_ind] = min_elem.b_prime
            Q[0,sp_ind+1] = min_elem.b_prime
            Q[1,sp_ind+1] = min_elem.a_prime

    return Q


def _log_likelihood_ratio(cols, epsilon):
    """Return log(cols[0]) - log(cols[1]) per column, with exact zeros replaced by ``epsilon``."""
    numerator = cols[0, :] + epsilon * (cols[0, :] == 0)
    denominator = cols[1, :] + epsilon * (cols[1, :] == 0)
    return np.log(numerator) - np.log(denominator)

# Initialize deltaI field for a new data element
def calcDeltaI(a,b,a_prime,b_prime):
    """Return ``C(a, b) + C(a_prime, b_prime) - C(a + a_prime, b + b_prime)``.

    This is the quantity stored in the ``deltaI`` field of a data element:
    the combined ``C`` of the two separate symbols minus the ``C`` of the
    symbol obtained by adding them together.
    """
    kept_apart = C(a, b) + C(a_prime, b_prime)
    merged = C(a + a_prime, b + b_prime)
    return kept_apart - merged

# Helper function for calcDeltaI()
def C(a,b):
    """Return ``-(a+b)*log2((a+b)/2) + a*log2(a) + b*log2(b)``.

    Any of the three terms whose argument is below 1e-15 is taken to be 0,
    which avoids evaluating the logarithm at (or extremely close to) zero.
    """
    tiny = 1e-15
    total = a + b

    mixture_term = 0 if total < tiny else -total * math.log2(total / 2)
    a_term = 0 if a < tiny else a * math.log2(a)
    b_term = 0 if b < tiny else b * math.log2(b)

    return mixture_term + a_term + b_term

# Inserts element as rightmost element of list and updates heap accordingly
def insertRightmost(d):
    """Append ``d`` to the module-level ``heap`` by calling its ``insert`` method.

    No re-ordering is requested here; ``heap`` must already have been created
    by the caller (it is a module-level global, not an argument).
    """
    heap.insert(d)

# Returns the data element with smallest delta I capacity
def getMin():
    """Return whatever ``get_min()`` of the module-level ``heap`` returns, without removing it."""
    return heap.get_min()

# Removes the element returned by getMin from heap
def removeMin(fix = 1):
    """Remove and return the minimum element of the module-level ``heap``.

    Returns ``None`` when ``heap.get_min()`` reports nothing to remove.
    ``fix`` is forwarded unchanged to ``heap.extract_min``.
    """
    if heap.get_min() is None:
        return None
    return heap.extract_min(fix)

# Updates the heap due to a change in deltaI resulting from a merge, no change to list
def valueUpdated(index, new_deltaI, new_a, new_b, new_a_prime, new_b_prime):
    """Write new values into the entry at ``index`` of the module-level ``heap``, then re-sort it.

    All six arguments are forwarded unchanged to ``heap.update_vals``.
    Because ``heap.min_sort()`` runs afterwards, positions of elements in the
    heap array may differ once this function returns.
    """
    # Update the values in the heap
    heap.update_vals(index, new_deltaI, new_a, new_b, new_a_prime, new_b_prime)
    # NOTE(review): min_sort() is called after every single update rather than
    # moving only the changed element - confirm this cost is acceptable.
    heap.min_sort() # Re-sort the heap according to the new values
   
# ---------------------------------------------------------------------- #

def execute(N, e, mu):
    """Compute capacity and error probability of the degraded channel of every index.

    A BSC with crossover probability ``e`` is pushed through
    ``degrading_procedure`` once per index ``i`` in ``range(N)``, using the
    ``m``-bit binary expansion of ``i`` (most significant bit first,
    ``N = 2^m``) as the bit path and ``mu`` as the alphabet bound.

    The module-level ``heap`` is replaced by a fresh ``MinHeap`` as a side
    effect, since the merge routines operate on that global.

    Returns
    -------
    (capacity, error_prob) : two numpy arrays of length ``N``.
        ``capacity[i]`` is ``H(Y) - 0.5*H(Y|X=0) - 0.5*H(Y|X=1)`` in bits and
        ``error_prob[i]`` is ``sum_y min(W[0, y], W[1, y]) / 2`` for the
        channel returned for index ``i``.

    Raises
    ------
    ValueError
        If ``N`` is not a positive power of two.  Such an ``N`` previously
        produced bit paths of different lengths for different indices.
    """
    # initialize heap to be altered globally
    global heap

    if N < 1 or N & (N - 1):
        raise ValueError("N must be a positive power of two")

    m = int(math.log2(N)) # N = 2^m
    epsilon = e   # BSC flipping probability

    # initialize BSC channel with crossover probability epsilon
    W_init = np.array([[1-epsilon, epsilon],[epsilon, 1-epsilon]])

    heap = MinHeap()    # sorted according to deltaI field

    capacity = np.zeros(N)
    error_prob = np.zeros(N)

    for i in range(N):
        # m-bit binary expansion of i, most significant bit first.  For N == 1
        # (m == 0) this is the empty path, i.e. no transform is applied.
        b = [(i >> shift) & 1 for shift in range(m - 1, -1, -1)]

        channel = degrading_procedure(W_init,mu,b)

        # Compute capacity       H(Y) - 0.5*H(Y|X=1) - 0.5*H(Y|X=0)
        HY = _entropy_bits(.5*channel[0, :] + .5*channel[1, :])
        HYX0 = _entropy_bits(channel[0, :])
        HYX1 = _entropy_bits(channel[1, :])

        capacity[i] = (HY - 0.5*HYX0 - 0.5*HYX1)

        # computing error probability: sum_i min(pi,qi)/2
        error_prob[i] = np.sum(np.minimum(channel[0, :], channel[1, :]) / 2)

    return capacity, error_prob


def _entropy_bits(probabilities):
    """Return sum of p*log2(1/p) over the given values, using 0*log(0) = 0."""
    return sum(p * math.log2(1 / p) for p in probabilities if p != 0)

import pandas as pd
def plotting(N, epsilon, mu):
    """Run ``execute`` and show three scatter plots of its results.

    The figures are, in order: capacity per index, capacity sorted in
    ascending order (with a dotted vertical line), and error probability per
    index.  Each ``plt.show()`` is called before the next figure is drawn.
    """
    capacity, error_prob = execute(N, epsilon, mu)
    indices = np.arange(N)

    # Figure 1: capacity of every sub-channel in index order
    plt.scatter(indices, capacity, s=5)
    plt.title(
        f'BSC Polarization - Channel Capacity For Each Subchannel, Epsilon={epsilon!s}',
        fontdict={'fontsize': 20},
    )
    plt.xlabel('Index', fontdict={'fontsize': 15})
    plt.ylabel('Capacity', fontdict={'fontsize': 15})
    plt.show()

    # Figure 2: the same capacities in ascending order
    ascending = np.argsort(capacity)
    plt.scatter(indices, capacity[ascending], s=5, label="Capacity")
    plt.title(
        f'BSC Polarization - Sorted  Channel Capacity, Epsilon={epsilon!s}',
        fontdict={'fontsize': 20},
    )
    plt.xlabel('Sorted Indices', fontdict={'fontsize': 15})
    plt.ylabel('Capacity', fontdict={'fontsize': 15})
    # NOTE(review): this expression is the binary entropy of epsilon, yet the
    # line drawn at N times it is labelled "Mutual Information" - confirm the
    # intended label.
    MI = (-epsilon*np.log2(epsilon) - (1-epsilon)*np.log2(1-epsilon))
    plt.vlines(x=N*(MI), colors="r", linestyles='dotted', ymax=1, ymin=0, label="Mutual Information")
    plt.legend(fontsize=20)
    plt.show()

    # Figure 3: error probability of every sub-channel in index order
    plt.scatter(indices, error_prob, s=5)
    plt.title('BSC Degraded Channel - Channel Error Probability for each index i')
    plt.xlabel('Index')
    plt.ylabel('Error Probability')
    plt.show()

# plotting(8192, 0.4, 4)

# Return a vector of size N where an element is 1 if the index channel is good, 0 if poor
def get_good_ind(N, epsilon, mu, thresh):
    """Flag every index whose error probability is strictly below ``thresh``.

    The error probabilities come from ``return_error_prob(N, epsilon, mu)``
    (which also resets the module-level ``heap``); this function no longer
    carries its own copy of that loop.

    Returns a float numpy array of length ``N`` holding 1 where
    ``error_prob[i] < thresh`` and 0 elsewhere.
    """
    error_prob = return_error_prob(N, epsilon, mu)
    return (error_prob < thresh).astype(float)


# Return the indices of the good channels (error probability <= thresh), sorted in
# ascending order of error probability (best channel first)
def get_good_ind_inorder(N, epsilon, mu, thresh):
    """Return the good channel indices, best (lowest error probability) first.

    The error probabilities come from ``return_error_prob(N, epsilon, mu)``
    (which also resets the module-level ``heap``); this function no longer
    carries its own copy of that loop.

    Returns a list of the indices ``i`` with ``error_prob[i] <= thresh``, in
    ascending order of error probability.
    """
    error_prob = return_error_prob(N, epsilon, mu)
    # NOTE(review): the bound here is inclusive (<=) while the sibling
    # get_good_ind* functions use a strict (<) comparison - confirm which one
    # is intended.
    return [ind for ind in np.argsort(error_prob) if error_prob[ind] <= thresh]

def return_error_prob(N, epsilon, mu):
    """Return the error probability of the degraded channel of every index.

    A BSC with crossover probability ``epsilon`` is pushed through
    ``degrading_procedure`` once per index ``i`` in ``range(N)``, using the
    ``m``-bit binary expansion of ``i`` (most significant bit first,
    ``N = 2^m``) as the bit path and ``mu`` as the alphabet bound.

    The module-level ``heap`` is replaced by a fresh ``MinHeap`` as a side
    effect, since the merge routines operate on that global.

    Returns a numpy array of length ``N`` with
    ``error_prob[i] = sum_y min(W[0, y], W[1, y]) / 2``.

    Raises
    ------
    ValueError
        If ``N`` is not a positive power of two.  Such an ``N`` previously
        produced bit paths of different lengths for different indices.
    """
    # initialize heap to be altered globally
    global heap

    if N < 1 or N & (N - 1):
        raise ValueError("N must be a positive power of two")

    m = int(math.log2(N)) # such that N = 2^m

    # initialize BSC channel with crossover probability epsilon
    W_init = np.array([[1-epsilon, epsilon],[epsilon, 1-epsilon]])

    heap = MinHeap()    # sorted according to deltaI field

    error_prob = np.zeros(N)

    for i in range(N):
        # m-bit binary expansion of i, most significant bit first.  For N == 1
        # (m == 0) this is the empty path, i.e. no transform is applied.
        b = [(i >> shift) & 1 for shift in range(m - 1, -1, -1)]

        channel = degrading_procedure(W_init,mu,b)

        # computing error probability: sum_i min(pi,qi)/2
        error_prob[i] = np.sum(np.minimum(channel[0, :], channel[1, :]) / 2)

    return error_prob


# Return all N channel indices sorted in ascending order of error probability
# (best channel first)
def all_inds_inorder(N, epsilon, mu):
    """Return all ``N`` indices ordered by ascending error probability.

    The error probabilities come from ``return_error_prob(N, epsilon, mu)``
    (which also resets the module-level ``heap``); this function no longer
    carries its own copy of that loop.  The result is the ``np.argsort`` of
    those probabilities.
    """
    return np.argsort(return_error_prob(N, epsilon, mu))


# Return a vector of size N where an element is 1 if the index channel is good, 0 if poor,
# together with the degraded channel matrices of all N indices (array of shape 2 x mu x N)
def get_good_ind_and_channels(N, e, mu, thresh):
    """Return the good-channel flags together with every degraded channel matrix.

    A BSC with crossover probability ``e`` is pushed through
    ``degrading_procedure`` once per index ``i`` in ``range(N)``, using the
    ``m``-bit binary expansion of ``i`` (most significant bit first,
    ``N = 2^m``) as the bit path and ``mu`` as the alphabet bound.

    The module-level ``heap`` is replaced by a fresh ``MinHeap`` as a side
    effect, since the merge routines operate on that global.

    Returns
    -------
    (good_ind, all_channels)
        ``good_ind`` is a float array of length ``N`` holding 1 where the
        error probability ``sum_y min(W[0, y], W[1, y]) / 2`` is strictly
        below ``thresh`` and 0 elsewhere.  ``all_channels`` has shape
        ``(2, mu, N)``; ``all_channels[:, :, i]`` is the channel of index
        ``i``.  A channel with fewer than ``mu`` columns is stored with its
        first half at the left edge and its second half at the right edge,
        leaving all-zero (zero-probability) columns in the middle.

    Raises
    ------
    ValueError
        If ``N`` is not a positive power of two.  Such an ``N`` previously
        produced bit paths of different lengths for different indices.
    """
    # initialize heap to be altered globally
    global heap

    if N < 1 or N & (N - 1):
        raise ValueError("N must be a positive power of two")

    m = int(math.log2(N)) # such that N = 2^m

    # BSC flipping probability
    epsilon = e

    # initialize BSC channel with crossover probability epsilon
    W_init = np.array([[1-epsilon, epsilon],[epsilon, 1-epsilon]])

    heap = MinHeap()    # sorted according to deltaI field

    error_prob = np.zeros(N)
    all_channels = np.zeros((2,mu,N))

    for i in range(N):
        # m-bit binary expansion of i, most significant bit first.  For N == 1
        # (m == 0) this is the empty path, i.e. no transform is applied.
        b = [(i >> shift) & 1 for shift in range(m - 1, -1, -1)]

        channel = degrading_procedure(W_init,mu,b)

        # The procedure hands back channels narrower than mu untouched, and a
        # direct assignment would then fail to broadcast.  Splitting at the
        # middle keeps each half of the columns in its own half of the slot;
        # for a full-width channel this is a plain copy.
        width = channel.shape[1]
        half = width // 2
        all_channels[:, :half, i] = channel[:, :half]
        all_channels[:, mu - (width - half):, i] = channel[:, half:]

        # computing error probability: sum_i min(pi,qi)/2
        error_prob[i] = np.sum(np.minimum(channel[0, :], channel[1, :]) / 2)

    good_ind = (error_prob < thresh).astype(float)

    return good_ind, all_channels

# plotting(1024, 0.1, 4) # uncomment when running this file
# print(len(get_good_ind_inorder(1024, 0.31, 4, 0.75)))
      
def get_goodind_and_error_prob(N, e, mu, thresh):
    """Return the good-channel flags together with the error probabilities.

    The error probabilities come from ``return_error_prob(N, e, mu)`` (which
    also resets the module-level ``heap``); this function no longer carries
    its own copy of that loop.

    Returns
    -------
    (good_ind, error_prob)
        ``error_prob`` is the array returned by ``return_error_prob`` and
        ``good_ind`` is a float array of the same length holding 1 where
        ``error_prob[i] < thresh`` and 0 elsewhere.
    """
    error_prob = return_error_prob(N, e, mu)
    good_ind = (error_prob < thresh).astype(float)
    return good_ind, error_prob


def rank_depends_on_epsilon():
    """Print the index ordering from ``all_inds_inorder`` for four crossover probabilities.

    Uses N = 8 and mu = 4 with epsilon = 0.001, 0.01, 0.05 and 0.1.  All four
    orderings are computed first and then printed one per line, in that order.
    """
    N, mu = 8, 4
    crossover_values = (0.001, 0.01, 0.05, 0.1)

    orderings = [all_inds_inorder(N, eps, mu) for eps in crossover_values]

    for ordering in orderings:
        print(ordering)


# rank_depends_on_epsilon()