import numpy as np
import matplotlib.pyplot as plt
import cache_tools
from scipy.optimize import minimize_scalar
import all_degraded
import os
import math
import pandas as pd

#### HELPER FUNCTIONS ####

def encoding(u_N):
    """Encode the bit vector ``u_N`` into ``x_N`` over GF(2).

    The generator matrix is G_N = B_N F^{(x)n}, where B_N is the
    bit-reversal permutation matrix from ``permutation_mat``, F is the
    2x2 kernel [[1, 0], [1, 1]] and n = log2(N) with N = len(u_N).
    Returns x_N = u_N G_N reduced modulo 2.
    """
    block_len = len(u_N)
    kernel = np.array([[1, 0], [1, 1]])

    # Bit-reversal permutation and n-fold Kronecker power of the kernel.
    bit_reversal = permutation_mat(block_len)
    kron_power = calc_F(np.log2(block_len), kernel)

    # G_N = B_N F^{(x)n}, then x_N = u_N G_N, both taken mod 2.
    generator = np.mod(np.matmul(bit_reversal, kron_power), 2)
    return np.mod(np.matmul(u_N, generator), 2)

def encoding_return_mats(u_N):
    """Encode ``u_N`` like ``encoding`` and also return the permutation matrix.

    Builds G_N = B_N F^{(x)n} with B_N the bit-reversal permutation
    matrix, F = [[1, 0], [1, 1]] and n = log2(len(u_N)).

    Returns:
        (x_N, B_N): the codeword u_N G_N (mod 2) and the bit-reversal
        permutation matrix that was used to build G_N.
    """
    block_len = len(u_N)
    kernel = np.array([[1, 0], [1, 1]])

    # Bit-reversal permutation and n-fold Kronecker power of the kernel.
    bit_reversal = permutation_mat(block_len)
    kron_power = calc_F(np.log2(block_len), kernel)

    # G_N = B_N F^{(x)n}, then x_N = u_N G_N, both taken mod 2.
    generator = np.mod(np.matmul(bit_reversal, kron_power), 2)
    codeword = np.mod(np.matmul(u_N, generator), 2)
    return codeword, bit_reversal

def decoder(message_inds, y_N, epsilon, froz):
    """Successively decode only the message positions of ``y_N``.

    Message bits are not assumed to cover every good channel; any index
    that is not listed in ``message_inds`` is treated as frozen.

    Args:
        message_inds: iterable of indices that carry information bits.
        y_N: received vector; its length N is the block length.
        epsilon: crossover probability handed to ``likelihood``.
        froz: frozen-bit values, consumed in increasing index order
            (one value for every index not in ``message_inds``).

    Returns:
        Length-N float array ``u_hat_N`` holding the decisions for the
        message positions and the given values at the frozen positions.
    """
    N = len(y_N)
    u_hat_N = np.zeros(N)

    # Build the lookup once: membership on a list would cost O(N) per index.
    message_set = set(message_inds)

    j = 0  # next unread entry of froz
    for i in range(N):
        if i in message_set:  # information bit
            # L_i is a log likelihood value; non-negative means "0".
            L_i = likelihood(y_N, u_hat_N[:i], epsilon)
            u_hat_N[i] = 0 if L_i >= 0 else 1
        else:  # frozen bit
            u_hat_N[i] = froz[j]
            j += 1

    return u_hat_N

def calc_F(n, F):
    """Return the n-fold Kronecker power of ``F``.

    Args:
        n: non-negative whole number of Kronecker factors. Floats such
            as ``np.log2(N)`` are accepted as long as they are integral.
        F: the kernel matrix.

    Returns:
        ``np.ones(1)`` for n == 0, otherwise kron(F, kron(F, ...)) with
        n factors of ``F``.

    Raises:
        ValueError: if ``n`` is negative, non-finite or not a whole
            number (previously this recursed until RecursionError).
    """
    if not np.isfinite(n):
        raise ValueError("n must be a finite non-negative integer, got %r" % (n,))
    levels = int(round(float(n)))
    if levels < 0 or abs(float(n) - levels) > 1e-9:
        raise ValueError("n must be a non-negative integer, got %r" % (n,))

    # Same nesting order as the recursive definition: F (x) F_{n-1}.
    F_N = np.ones(1)
    for _ in range(levels):
        F_N = np.kron(F, F_N)
    return F_N

def permutation_mat(L):
    """Build the L x L bit-reversal permutation matrix.

    Row i of the result is the row of the identity matrix whose index is
    i written in binary (zero-padded to the width of L-1) and reversed.
    """
    width = len(bin(L - 1)[2:])
    reversed_rows = np.array(
        [int(format(idx, 'b').zfill(width)[::-1], 2) for idx in range(L)]
    )
    # Fancy-index the identity to reorder its rows.
    return np.eye(L)[reversed_rows]

# Helper function for the decoder: recursively computes the log10 likelihood
# ratio of the next bit being 0 versus 1, given the earlier decisions u_hat.
# NOTE(review): wrapped by the project decorator cache_tools.memoize, which
# presumably caches results per argument tuple - confirm how it keys arrays.
@cache_tools.memoize
def likelihood(y_N, u_hat, epsilon):
    """Return the log10 likelihood ratio for the bit that follows ``u_hat``.

    Args:
        y_N: received values for the current (sub-)block; halved on each
            recursion level until a single value remains.
        u_hat: decisions already made for the earlier bits of this block.
        epsilon: crossover probability used in the length-1 base case.

    Returns:
        A log10 likelihood ratio; the decoder maps a non-negative value
        to bit 0.
    """
    N = len(y_N)
    if N == 1:
        # Base case: single channel output, ratio is (1-eps)/eps or its inverse.
        if y_N == 0: # y_N is a single value here
            return np.log10((1-epsilon) / epsilon)
        else: # y_1 = 1
            return np.log10(epsilon / (1-epsilon))
    else:
        # y's needed as part of definition of new recursive likelihoods 
        firsthalf_y = y_N[:N//2]
        lasthalf_y = y_N[N//2:]
        
        # Entries at 0-based positions 0,2,4,... (1-based odd positions 1,3,...)
        u_hato = (u_hat[::2]).astype(int) # get only odd rows 1,3,...
        # Entries at 0-based positions 1,3,5,... (1-based even positions 2,4,...)
        u_hate = (u_hat[1::2]).astype(int) # get only even rows 2,4,...

        # When len(u_hat) is odd the last "odd" entry has no partner; drop it
        # so both halves have equal length (it is used below as the exponent).
        if len(u_hato) != len(u_hate):
            u_hato = u_hato[:len(u_hato)-1]

        # Kronecker sum - add componentwise (XOR of the paired decisions)
        new_uhat = np.bitwise_xor(u_hato, u_hate)

        like1 = likelihood(firsthalf_y, new_uhat, epsilon)
        like2 = likelihood(lasthalf_y, u_hate, epsilon)
        # like1 and like2 are log10 values, i.e. the order of magnitude of
        # the underlying likelihood ratios

        # Equation numbers refer to the reference the author followed
        # (presumably Arikan's polar-coding recursions - TODO confirm).
        if len(u_hat) % 2 == 0: # Equation 75: even number of prior decisions
            return safe_compute_even(like1, like2)
        else: # Equation 76: odd number of prior decisions
            # Exponent is set by the most recent decision: 0 -> +1, 1 -> -1
            power = 1 - 2*u_hat[len(u_hat)-1] # either 1 or -1
            return safe_compute_odd(like1, like2, power)

def log_one_plus_x(log_x):
    """Return log10(1 + x) given ``log_x`` = log10(x), without overflow.

    The function works in base 10 throughout, so the first-order
    expansions used far from zero carry a factor log10(e):
        log10(1 + x)  ~=  x * log10(e)                  for tiny x
        log10(1 + x)  ~=  log10(x) + log10(e) / x       for huge x
    The previous version omitted that factor (it used the natural-log
    form of the expansions).

    Args:
        log_x: base-10 logarithm of x (scalar).

    Returns:
        log10(1 + x) as a float.
    """
    log10_e = 0.4342944819032518  # log10(e) = 1 / ln(10)

    if log_x < -745:
        # x is far below anything representable; 1 + x == 1.
        return 0.0
    if log_x < -37:
        # 1 + x rounds to 1 in double precision, so use the expansion.
        return 10.0**log_x * log10_e
    if log_x > 37:
        # 10**log_x could overflow; expand around log_x instead.
        return log_x + 10.0**(-1 * log_x) * log10_e
    return np.log10(1 + 10.0**log_x)

def safe_compute_even(log_like_1, log_like_2):
    """Combine two log likelihood ratios for the even-index recursion.

    Evaluates the log of (like1*like2 + 1) / (like1 + like2) using only
    the logarithms of like1 and like2, which avoids overflow.
    """
    # Denominator: log(a + b) = max + log(1 + 10**(min - max)).
    larger = log_like_2 if log_like_2 > log_like_1 else log_like_1
    smaller = log_like_2 if log_like_2 < log_like_1 else log_like_1
    log_denominator = larger + log_one_plus_x(smaller - larger)

    # Numerator: log(1 + like1*like2), the product being a sum of logs.
    log_numerator = log_one_plus_x(log_like_1 + log_like_2)

    return log_numerator - log_denominator

def safe_compute_odd(like1, like2, power):
    """Combine two log likelihood ratios for the odd-index recursion.

    Evaluates the log of (like1)**power * like2, where ``like1`` and
    ``like2`` are already logarithms, so the result is a signed sum.

    Args:
        like1: log likelihood ratio that is raised to ``power``.
        like2: log likelihood ratio that multiplies the result.
        power: +1 or -1.

    Returns:
        ``like1 + like2`` for power == 1, ``-like1 + like2`` for power == -1.

    Raises:
        ValueError: if ``power`` is neither 1 nor -1 (this used to
            surface as an UnboundLocalError).
    """
    if power == 1:
        signed_like1 = like1
    elif power == -1:
        signed_like1 = -like1
    else:
        raise ValueError("power must be 1 or -1, got %r" % (power,))

    # Multiplication of two terms is addition in the log domain.
    return signed_like1 + like2

def rand_permute_mat(N):
    """Return a random N x N permutation matrix.

    The rows of the identity matrix are shuffled in place with
    ``np.random.shuffle``, so the result depends on numpy's global
    random state.
    """
    perm = np.eye(N)
    np.random.shuffle(perm)  # shuffles along the first axis, i.e. the rows
    return perm

def fix_flip(x_hat, y, k_all):
    """Adjust ``x_hat`` so that it differs from ``y`` in exactly ``k_all`` places.

    Positions are scanned from index 0. While too few bits differ, agreeing
    positions are flipped away from ``y``; while too many differ, differing
    positions are set back to ``y``. Scanning stops as soon as the count
    matches. Only positions where ``y`` is 0 or 1 are ever modified.

    ``x_hat`` is modified in place and also returned.

    Args:
        x_hat: binary estimate to be corrected (mutated).
        y: binary reference vector of length N.
        k_all: required number of positions where x_hat != y (0..N).

    Returns:
        (x_hat, need_flip, index): the corrected vector, the number of
        positions that had to change, and the number of leading positions
        that were scanned (every change lies in ``x_hat[:index]``).

    Raises:
        ValueError: if ``k_all`` is outside 0..N, or cannot be reached by
            the scan (previously this ran off the end with an IndexError
            after partially modifying ``x_hat``).
    """
    N = len(y)
    if not 0 <= k_all <= N:
        raise ValueError(
            "k_all must be between 0 and len(y)=%d, got %r" % (N, k_all)
        )

    # Current number of disagreeing positions (TOTAL).
    num_flip = sum(1 for i in range(N) if x_hat[i] != y[i])

    # Number of bits that will need to be changed.
    need_flip = abs(num_flip - k_all)

    index = 0
    while num_flip != k_all:
        if index >= N:
            raise ValueError(
                "cannot reach k_all=%r flips within %d positions" % (k_all, N)
            )
        bit = y[index]
        if bit == 0 or bit == 1:
            other = 1 - bit
            if num_flip < k_all and x_hat[index] == bit:
                # Agreeing position: flip it away from y.
                x_hat[index] = other
                num_flip += 1
            elif num_flip > k_all and x_hat[index] == other:
                # Disagreeing position: restore it to y.
                x_hat[index] = bit
                num_flip -= 1
        index += 1

    return x_hat, need_flip, index

#### SIMULATOR CODE ####

def sim_encoder(message_inds, y_N, epsilon, N, froz_N, k_all):
    """Simulator encoder: runs the channel decoder on ``y_N``.

    Decodes ``y_N`` into a full u_hat vector, re-encodes it to x_hat and
    then corrects x_hat with ``fix_flip`` so that it differs from ``y_N``
    in exactly ``k_all`` positions.

    Args:
        message_inds: indices of the message (information) bits.
        y_N: input vector; expected to have length ``N``.
        epsilon: crossover probability passed to ``decoder``.
        N: block length.
        froz_N: frozen-bit values for the non-message positions.
        k_all: exact number of positions in which x_hat must differ from y_N.

    Returns:
        (message_uhats, need_flip, l, fixed_bits): the decoded message
        bits in increasing index order, the number of corrected positions,
        the length of the corrected prefix, and that prefix of the
        corrected x_hat.
    """
    u_hat_N = decoder(message_inds, y_N, epsilon, froz_N)

    # Set lookup keeps the extraction O(N) instead of O(N * len(message_inds)).
    message_set = set(message_inds)
    message_uhats = np.array([u_hat_N[i] for i in range(N) if i in message_set])

    # u_hat_N already carries the message decisions and the frozen values in
    # their final positions, so it can be encoded directly (multiply by G_N).
    x_hat_N = encoding(u_hat_N)

    # Fix flipped bits for exact BSC: make x_hat differ from y in exactly
    # k_all positions.
    x_hat_fixed, need_flip, l = fix_flip(x_hat_N, y_N, k_all)

    fixed_bits = x_hat_fixed[:l]

    return message_uhats, need_flip, l, fixed_bits

def sim_decoder(message_inds, u_hat_message_N, N, froz, num_flip, l, fixed_bits):
    """Simulator decoder: runs the channel encoder on the received message bits.

    Re-inserts the frozen bits to rebuild the full u_hat vector, encodes
    it to x_hat, and overwrites the first ``l`` entries with the
    corrected bits supplied by the simulator encoder.

    Args:
        message_inds: indices of the message (information) bits.
        u_hat_message_N: message bits in increasing index order.
        N: block length.
        froz: frozen-bit values for the non-message positions.
        num_flip: accepted for interface compatibility; not used here.
        l: length of the corrected prefix, or -1 when there is none.
        fixed_bits: the ``l`` corrected leading bits of x_hat.

    Returns:
        (x_hat_N, x_hat_fixed): the plain re-encoded vector and the copy
        with the corrected prefix applied.
    """
    # Set lookup keeps the rebuild O(N) instead of O(N * len(message_inds)).
    message_set = set(message_inds)

    # Add back frozen bits to get the full u_hat.
    uhat_N = np.zeros(N)
    j = k = 0
    for i in range(N):
        if i in message_set:
            uhat_N[i] = u_hat_message_N[j]
            j += 1
        else:
            uhat_N[i] = froz[k]
            k += 1

    # Encode by multiplying by G_N.
    x_hat_N = encoding(uhat_N)

    # Apply the corrected prefix to an independent copy so the uncorrected
    # vector can be returned alongside it.
    x_hat_fixed = x_hat_N.copy()
    if l != -1:
        x_hat_fixed[:l] = fixed_bits

    return x_hat_N, x_hat_fixed


#### SIMULATION FUNCTION ####

def simulation_exact_BSC(N, n, epsilon_arr):
    """Simulate ``n`` channel uses of ``N`` bits each for the exact-BSC scheme.

    For every channel use the function draws the exact number of flips
    ``k_all`` ~ Binomial(N, epsilon), maps it onto the calibration curve
    stored in ``<N>epsilon_numincorrect.csv`` to obtain ``theta``, runs
    the simulator encoder, and appends a self-describing record to the
    channel message. The message is then parsed and every record is run
    through the simulator decoder.

    Record layout (all header fields are ``int(log2(N)) + 1`` bits wide so
    that values up to and including N fit; log2(N) bits are one too few
    for ``l == N`` or ``need_flip == N`` and used to corrupt the parse):
        p/N >= 0.5:  [p][message bits][need_flip]
        otherwise:   [p][message bits][l][l fixed bits]
    Callers that account for header overhead should therefore count
    2 * (log2(N) + 1) header bits per record.

    Args:
        N: number of bits per channel use.
        n: number of channel uses (at least 1).
        epsilon_arr: per-use flip probabilities; ``epsilon_arr[i]`` is
            used for channel use ``i``.

    Returns:
        (num_incorrect, num_incorrect_withfix, num_flip, l, k) for the
        LAST channel use: mismatches between x_hat and y without and with
        the correction, the decoded ``need_flip`` (or -1), the decoded
        prefix length ``l`` (or -1), and the message rate k = 1 - H(theta).

    Raises:
        ValueError: if ``n`` is smaller than 1.
        RuntimeError: if the parsed channel message does not contain
            exactly ``n`` records.
    """
    if n < 1:
        raise ValueError("n must be at least 1 channel use, got %r" % (n,))

    field_width = int(np.log2(N)) + 1

    def as_field(value):
        # Fixed-width, zero-padded binary representation of a header value.
        return format(int(value), "b").zfill(field_width)

    # The calibration table depends only on N, so read it once.
    csv_name = str(N) + "epsilon_numincorrect.csv"
    df = pd.read_csv(csv_name)
    num_incorrect_arr = df['Num Incorrect'].values
    epsilons = df['Epsilon'].values

    channel_parts = []
    # Per-use state shared by both ends of the simulation; each record must
    # be decoded with the state of the channel use that produced it.
    shared_state = []
    for i in range(n):
        epsilon = epsilon_arr[i]

        # Exact number of bits to flip for this use.
        k_all = np.random.binomial(N, epsilon)

        # Total flip probability and the matching expected number incorrect.
        delta = k_all / N
        binom_numinc = delta*N

        # Find closest num_incorrect and corresponding epsilon (theta).
        index = np.argmin(np.abs(num_incorrect_arr - binom_numinc))
        theta = epsilons[index]

        # Generate y bits.
        y_N = np.random.randint(2, size=N)

        # Compute number of message bits.
        k = 1 - binary_entropy(theta)
        p = int(N * k)

        # Form list of ordered subchannels.
        all_degraded.generate_orderfile(theta, N)

        # Use theta to determine the ordering of good subchannels.
        file_name = "indices_" + str(N) + "ep" + str(theta) + ".txt"
        directory = "subchannel_rankings"
        with open(os.path.join(directory, file_name)) as f:
            ordered_inds = [int(line) for line in f]

        message_inds = ordered_inds[:p]

        # Randomly assign frozen bits.
        froz_N = np.random.randint(2, size=N-p)

        # Permute indices to not favor any bits over another.
        permute_matrix = rand_permute_mat(N)
        y_N_permute = np.matmul(permute_matrix, y_N)

        # Simulation encoder, driven by theta (the calibrated epsilon).
        u_hat_message_N, need_flip, l, fixed_bits_N = sim_encoder(
            message_inds, y_N_permute, theta, N, froz_N, k_all
        )

        u_hat_message = ''.join(map(str, u_hat_message_N.astype(int)))
        fixed_bits = ''.join(map(str, fixed_bits_N.astype(int)))

        # Prefix code: thresholded on fraction of message bits.
        if p/N >= 0.5:  # "N choose l" prefix scheme
            # Also would need to add on N choose l string...
            channel_parts.append(as_field(p) + u_hat_message + as_field(need_flip))
        else:  # simple prefix scheme
            channel_parts.append(
                as_field(p) + u_hat_message + as_field(l) + fixed_bits
            )

        shared_state.append((message_inds, froz_N, permute_matrix, y_N))

    # Send message bits across channel (currently a noiseless pass-through).
    channel_message = "".join(channel_parts)

    # Decipher channel bits.
    u_hat_message_arr = []
    fixed_bits_arr = []
    l_arr = []
    num_flip_arr = []
    ind = 0
    while ind < len(channel_message):
        # Number of message bits in this record.
        p_decoded = int(channel_message[ind:ind + field_width], 2)
        ind += field_width

        # Extract message bits.
        u_hat_message_arr.append(channel_message[ind:ind + p_decoded])
        ind += p_decoded

        if p_decoded/N >= 0.5:  # "N choose l" prefix scheme
            num_flip_arr.append(int(channel_message[ind:ind + field_width], 2))
            ind += field_width
            # No fixed bits are transmitted in this scheme.
            fixed_bits_arr.append("")
            l_arr.append(-1)
        else:  # simple scheme
            l_decoded = int(channel_message[ind:ind + field_width], 2)
            ind += field_width

            # Extract fixed bits.
            fixed_bits_arr.append(channel_message[ind:ind + l_decoded])
            l_arr.append(l_decoded)
            num_flip_arr.append(-1)
            ind += l_decoded

    if len(u_hat_message_arr) != len(shared_state):
        raise RuntimeError(
            "parsed %d records from the channel message, expected %d"
            % (len(u_hat_message_arr), len(shared_state))
        )

    records = zip(u_hat_message_arr, fixed_bits_arr, l_arr, num_flip_arr, shared_state)
    for u_hat_message, curr_fixed_bits, l, num_flip, state in records:
        message_inds, froz_N, permute_matrix, y_N = state

        # Convert bit strings back to arrays.
        u_hat_message_N = np.array([int(bit) for bit in u_hat_message], dtype=int)
        curr_fixed_bits_N = np.array([int(bit) for bit in curr_fixed_bits], dtype=int)

        x_hat_N_permute, x_hat_N_fix_permute = sim_decoder(
            message_inds, u_hat_message_N, N, froz_N, num_flip, l, curr_fixed_bits_N
        )

        # Un-permute the indices: the inverse of a permutation matrix is its
        # transpose, which is exact and avoids a general matrix inversion.
        x_hat_N = np.matmul(permute_matrix.T, x_hat_N_permute)
        x_hat_N_fix = np.matmul(permute_matrix.T, x_hat_N_fix_permute)

        # Number of positions where x_hat and y disagree.
        num_incorrect = int(np.count_nonzero(x_hat_N != y_N))
        num_incorrect_withfix = int(np.count_nonzero(x_hat_N_fix != y_N))

    return num_incorrect, num_incorrect_withfix, num_flip, l, k

### TESTS ###

def execute_fracmessage_test(N, R, H_inv, trials=3):
    """Run the exact-BSC simulation at every rate point and take medians.

    For each index j of ``R`` the simulation is run ``trials`` times with
    flip probability ``H_inv[j]`` (one channel use each) and the median of
    every returned quantity is recorded.

    Args:
        N: block length passed to ``simulation_exact_BSC``.
        R: sequence of rate points; only its length is used here.
        H_inv: flip probability for each rate point.
        trials: number of repetitions per point (default 3, as before).

    Returns:
        (orig_num_inc, num_inc, num_flip_arr, l_arr, rate): five float
        arrays of length len(R) with the per-point medians of the number
        incorrect without correction, the number incorrect with
        correction, ``num_flip``, ``l`` and the rate ``k``.

    Raises:
        ValueError: if ``trials`` is smaller than 1.
    """
    if trials < 1:
        raise ValueError("trials must be at least 1, got %r" % (trials,))

    num_points = len(R)
    orig_num_inc = np.zeros(num_points)
    num_inc = np.zeros(num_points)
    num_flip_arr = np.zeros(num_points)
    l_arr = np.zeros(num_points)
    rate = np.zeros(num_points)

    for j in range(num_points):
        eps = H_inv[j]
        # One row per trial; columns follow the simulation's return order.
        results = np.zeros((trials, 5))
        for t in range(trials):
            results[t] = simulation_exact_BSC(N=N, n=1, epsilon_arr=[eps])
        medians = np.median(results, axis=0)
        orig_num_inc[j], num_inc[j], num_flip_arr[j], l_arr[j], rate[j] = medians

    return orig_num_inc, num_inc, num_flip_arr, l_arr, rate
# Plot styling: render math text with the STIX font set and use the
# STIXGeneral font family for figures created after this point.
import matplotlib
matplotlib.rcParams['mathtext.fontset'] = 'stix'
matplotlib.rcParams['font.family'] = 'STIXGeneral'
def test_capacity_curve(N):
    """Simulate, save and plot the fraction of incorrect bits versus rate.

    Runs ``execute_fracmessage_test`` on 50 rate points in [0, 1), writes
    the results to ``<N>_scheme3_numincorrect.csv`` and
    ``<N>EXACT_center_hist_numincorrect.csv``, and shows a plot comparing
    the uncorrected and corrected simulator with the expected curve.

    The curves are normalised by ``N`` (previously a hard-coded 1024,
    which mis-scaled the plot for any other block length).

    Args:
        N: block length.
    """
    # Increment points
    R = np.arange(0, 1, 1/50)

    # Inverse capacity function
    H_inv = inverse_H(1-R)

    orig_num_inc, num_inc, num_flip_arr, l_arr, rate = execute_fracmessage_test(N, R, H_inv)

    # Per-bit cost of describing which L of the N positions were flipped;
    # a median of -1 means that scheme was not used at this point.
    log_N_choose_L = np.zeros(len(num_flip_arr))
    for i, L in enumerate(num_flip_arr):
        if L != -1:
            log_N_choose_L[i] = math.log2(math.comb(N, int(L)))/N

    # Total rate: message rate plus overhead; 2*log2(N) accounts for the
    # two header fields.
    plotting_x_axis = np.zeros(len(R))
    for point in range(len(R)):
        if R[point] >= 0.5:  # N choose l scheme
            plotting_x_axis[point] = rate[point] + log_N_choose_L[point] + 2*np.log2(N)/N
        else:  # simple scheme also transmits l fixed bits
            plotting_x_axis[point] = rate[point] + (2*np.log2(N) + l_arr[point])/N

    # Save num_incorrect data to files
    df = pd.DataFrame({'Rate': rate, 'Num Incorrect': orig_num_inc})
    csv_name = str(N) + '_scheme3_numincorrect.csv'
    df.to_csv(csv_name, index=False)
    df = pd.DataFrame({'Total Rate': plotting_x_axis, 'Num Incorrect': num_inc})
    csv_name = str(N) + 'EXACT_center_hist_numincorrect.csv'
    df.to_csv(csv_name, index=False)

    # Normalise counts by the actual block length to get fractions.
    plt.plot(rate, orig_num_inc/N)
    plt.plot(plotting_x_axis, num_inc/N)
    plt.plot(R, H_inv)
    plt.grid()
    plt.xlabel("Rate")
    plt.ylabel("Fraction incorrect X -> Y")
    plt.title("Fraction of incorrect bits vs Rate for n = " + str(N) + " Exact BSC")
    plt.legend(["Prior to Post-Correction", "With Post-Correction", "Expected"])
    plt.show()

def binary_entropy(p):
    """Return the binary entropy H(p) = -p log2(p) - (1-p) log2(1-p) in bits.

    The endpoints use the convention 0 * log2(0) = 0, so H(0) = H(1) = 0
    (the plain formula yields nan there).

    Args:
        p: probability, scalar or array-like.

    Returns:
        A numpy float for scalar input, otherwise an array of the same
        shape as ``p``.
    """
    p = np.asarray(p, dtype=float)
    # The endpoints produce 0 * -inf = nan in the raw formula; silence the
    # resulting warnings and substitute the limit value below.
    with np.errstate(divide='ignore', invalid='ignore'):
        entropy = -p * np.log2(p) - (1 - p) * np.log2(1 - p)
    # Indexing with () turns a 0-d result back into a scalar.
    return np.where((p == 0) | (p == 1), 0.0, entropy)[()]

def inverse_H(arr):
    """Numerically invert the binary entropy for every value in ``arr``.

    For each target value, a bounded scalar minimisation of
    |H(p) - value| over (1e-15, 1 - 1e-15) is run; solutions above 0.5
    are mirrored to 1 - p so every returned probability is at most 0.5.

    Returns:
        numpy array with one probability per entry of ``arr``.
    """
    def _closest_p(target):
        # Minimise the absolute gap between H(p) and the requested value.
        outcome = minimize_scalar(
            lambda p: abs(binary_entropy(p) - target),
            bounds=(1e-15, 1-1e-15),
            method='bounded',
        )
        return outcome.x

    roots = np.array([_closest_p(val) for val in arr])

    # H is symmetric about 0.5; report the branch in [0, 0.5].
    upper_branch = roots > 0.5
    roots[upper_branch] = 1 - roots[upper_branch]
    return roots

def sim_histogram(trials, N):
    """Plot a histogram of simulated bit errors next to a binomial reference.

    Runs ``trials`` single-use simulations at rate k = 0.2, records the
    uncorrected number of incorrect bits from each, and saves a two-panel
    figure (simulation on top, Binomial(N, epsilon) samples below) to
    ``MIX_prefix_capacityplot1024.png``.

    Args:
        trials: number of simulation runs; also used as the bin count.
        N: block length.
    """
    num_incorrect = np.zeros(trials)
    k = 0.2
    epsilon = inverse_H([1-k])

    for i in range(trials):
        # The simulation returns five values; only the first (number of
        # incorrect bits before correction) is needed here. Unpacking into
        # two names raised ValueError on the first trial.
        num_incorrect[i] = simulation_exact_BSC(N=N, n=1, epsilon_arr=epsilon)[0]

    binomialRV = np.random.binomial(N, epsilon, trials)

    ax1 = plt.subplot(211)
    plt.hist(num_incorrect, bins=trials)
    plt.title("Histogram of Simulation Number of Bits incorrect from Y to X")
    plt.xlabel("Number of bits incorrect from Y to X")
    plt.ylabel("Number of trials")

    plt.subplot(212, sharex = ax1)
    plt.hist(binomialRV, bins=trials, color='orange')
    plt.title("Histogram of Binomial Random Variable with Parameters" + str((N,epsilon[0])))
    plt.xlabel("Binomial RV Values")
    plt.ylabel("Number of trials")

    plt.tight_layout()
    plt.grid()
    plt.savefig("MIX_prefix_capacityplot1024.png")

# Plot styling: render math text with the STIX font set and use the
# STIXGeneral font family for figures created after this point.
import matplotlib
matplotlib.rcParams['mathtext.fontset'] = 'stix'
matplotlib.rcParams['font.family'] = 'STIXGeneral'
# test_capacity_curve(1024)
# sim_histogram(20, 1024)


def replot_capacity(N=8192):
    """Re-plot saved simulation results with a smoothed post-correction curve.

    Reads ``<N>_scheme3_numincorrect.csv`` and
    ``<N>EXACT_center_hist_numincorrect.csv`` (the files written by
    ``test_capacity_curve``), fits a degree-7 polynomial to the
    post-correction points, and shows both curves with the expected
    number of incorrect bits.

    Args:
        N: block length whose result files are plotted. Defaults to 8192,
            the value that used to be hard-coded.
    """
    # Increment points
    R = np.arange(0, 1, 1/50)

    # Inverse capacity function
    H_inv = inverse_H(1-R)

    csv_name = str(N) + "_scheme3_numincorrect.csv"
    df = pd.read_csv(csv_name)
    orig_num_inc = df['Num Incorrect'].values
    rate = df['Rate'].values

    csv_name = str(N) + "EXACT_center_hist_numincorrect.csv"
    df = pd.read_csv(csv_name)
    num_inc = df['Num Incorrect'].values
    plotting_x_axis = df['Total Rate'].values

    # Smooth post-correction results with a polynomial fit.
    degree = 7
    coefficients = np.polyfit(plotting_x_axis, num_inc, degree)
    poly_curve = np.poly1d(coefficients)

    # Evaluate the fitted curve on an even grid across the measured range.
    exact_X = np.linspace(min(plotting_x_axis), max(plotting_x_axis), 100)
    exact_Y = poly_curve(exact_X)

    plt.plot(rate, orig_num_inc)
    plt.plot(exact_X, exact_Y)
    # Expected count: flip probability times the block length.
    plt.plot(R, H_inv * N)
    plt.grid()
    plt.xlabel("Rate")
    plt.ylabel("Number of bits incorrect Y -> X")
    plt.title("Num of bits incorrect vs Rate for N = " + str(N) + " Exact BSC")
    plt.legend(["Prior to Post-Correction", "With Post-Correction", "Expected"])
    plt.show()


def replot_capacity2(N=8192):
    """Re-plot the saved uncorrected simulation results against expectation.

    Reads ``<N>_scheme3_numincorrect.csv`` and shows the number of
    incorrect bits versus rate together with the expected curve.

    Args:
        N: block length whose result file is plotted. Defaults to 8192,
            the value that used to be hard-coded.
    """
    # Increment points
    R = np.arange(0, 1, 1/50)

    # Inverse capacity function
    H_inv = inverse_H(1-R)

    csv_name = str(N) + "_scheme3_numincorrect.csv"
    df = pd.read_csv(csv_name)
    orig_num_inc = df['Num Incorrect'].values
    rate = df['Rate'].values

    plt.plot(rate, orig_num_inc)
    # Expected count: flip probability times the block length.
    plt.plot(R, H_inv * N)
    plt.grid()
    plt.xlabel("Rate")
    plt.ylabel("Number of bits incorrect Y -> X")
    plt.title("Num of bits incorrect vs Rate for N = " + str(N))
    plt.legend(["Uncorrected Simulator", "Expected"])
    plt.show()

# replot_capacity()
# replot_capacity2()

# orig_num_inc, num_inc, num_flip_arr, l_arr, rate = simulation_exact_BSC(N=8192, n=1, epsilon_arr=[0.2])

# rate_total = rate + (2*np.log2(8192) + l_arr)/8192

# print(rate_total)