# -*- coding: utf-8 -*-
"""polar_modified.py


This module contains polar code routines, some of which were modified for channel simulation.
Modified functions are marked by comments; all other functions are unchanged from the source below.

Source: https://github.com/henrypfister/polar_intro

"""

import copy
import numpy as np
import numba
from numba import int64, float64, jit, njit, vectorize
import matplotlib.pyplot as plt

@njit
def bec_p1(x: int64[:], e: float64) -> (float64[:]):
    '''
    Send the 0/1 inputs x through a binary erasure channel BEC(e), returning P1 values

          Arguments:
                  x (int[:]): array of 0/1 channel inputs
                  e (float): channel erasure probability

          Returns:
                  y (float[:]): P1 value per position: 0.5 where the symbol was
                                erased, otherwise the transmitted bit x[k]
    '''
    n_sym = len(x)
    # The buffer of uniform draws doubles as the output array.
    y = np.random.random_sample(n_sym)
    for k in range(n_sym):
        # A draw below e erases the symbol (P1 = 1/2); otherwise the bit is known.
        y[k] = 0.5 if y[k] < e else x[k]
    return y

@njit
def bsc_p1(x: int64[:], p: float64) -> (float64[:]):
    '''
    Transmit binary x through BSC(p) with P1 domain output

          Arguments:
                  x (int[:]): array of 0/1 channel inputs
                  p (float): channel crossover (bit-flip) probability

          Returns:
                  y (float[:]): Numpy array of P1 values for received values:
                                1-p where a 1 was received, p where a 0 was received
    '''
    z = np.random.random_sample(len(x))
    # Start every position at the P1 value of a received 0 ...
    y = np.zeros(len(x))+p
    for i in range(len(x)):
        # ... and switch to 1-p where a 1 is received. The received bit is
        # x[i] XOR flip with flip = (z[i] < p), so it is 1 exactly when the
        # flip indicator differs from (x[i] == 1).
        if ((z[i]<p) != (x[i]==1)):
            y[i] = 1-p

    return y

@njit
def awgn_p1(x: int64[:], esno_db: float64) -> (float64[:]):
    '''
    Send the 0/1 inputs x through a BIAWGN channel at Es/N0 = esno_db (in dB), returning P1 values

          Arguments:
                  x (int[:]): array of 0/1 channel inputs
                  esno_db (float): channel signal energy divided by noise spectral density in dB

          Returns:
                  y (float[:]): Numpy array of P1 values for received values
    '''
    # Mean of the unit-variance received sample when a 0 is sent
    amp = np.sqrt(10**(esno_db/10) * 2)
    # Draw every received sample as if a 0 had been sent ...
    r = np.random.normal(loc=amp,scale=1,size=x.shape)
    p1 = 1/(1+np.exp(2*amp*r))
    # ... then flip P1 wherever the input was 1 (same as sending the opposite sign).
    for k, bit in enumerate(x):
        if bit == 1:
            p1[k] = 1 - p1[k]
    return p1

@njit
def awgn_p1_new(x: int64[:], sigma: float64) -> (float64[:]):
    '''
    Transmit binary x through a BIAWGN channel with noise std dev sigma, with P1 domain output

    Bits are mapped 0 -> -1 and 1 -> +1, and Gaussian noise N(0, sigma^2) is added.

          Arguments:
                  x (int[:]): array of 0/1 channel inputs
                  sigma (float): std dev of noise

          Returns:
                  y (float[:]): Numpy array of P1 values, P(bit = 1 | received),
                                for received values (uniform prior on the bit)
    '''
    #Change 0's to -1's
    y = 2*x - 1
    y = y + np.random.normal( loc=0, scale = sigma, size = x.shape )
    # The posterior N(y;+1,s^2) / (N(y;+1,s^2) + N(y;-1,s^2)) simplifies exactly
    # to the logistic 1 / (1 + exp(-2y/sigma^2)). The simplified form needs one exp
    # instead of three. It also never evaluates 0/0: the ratio form returns NaN
    # when both Gaussian kernels underflow, e.g. for every sample when sigma == 0,
    # where this form gives the correct limits 0 and 1.
    y_p1 = 1.0/(1.0 + np.exp(-2.0*y/sigma**2))

    return y_p1

@njit # Input/output specifications below to make Numba work
def polar_decode(y: float64[:],f: float64[:]) -> (int64[:],int64[:]):
    '''
    Recursive successive cancellation polar decoder from P1 observations

    Positions with f == 1/2 are treated as information bits. Any other f value
    marks a frozen bit with value f. For a frozen bit the decoder still reports
    its own hard decision in u, but feeds the frozen value f back into later
    stages through x (genie-aided decoding). polar_channel_mc relies on this to
    estimate per-bit error rates.
    NOTE(review): len(y) == len(f) is presumably required to be a power of two,
    since each stage splits y into even/odd halves that must have equal length.

          Arguments:
                  y (float[:]): channel observations in output order
                  f (float[:]): input a priori probabilities in input order

          Returns:
                  u (int[:]): input hard decisions in input order
                  x (int[:]): output hard decisions in output order
    '''
    # Recurse down to length 1
    N = len(y)
    if (N==1):
        # Hard decision for this bit (random tie-break when y == 1/2)
        x = hard_dec_rr(y)
        if (f[0]==1/2):
            # Information bit: the hard decision is used for both u and x
            return x, x.copy()
        else:
            # Frozen bit: report the hard decision as u, but propagate the
            # known frozen value f as x so later bits are decoded with it
            return x, f.astype(np.int64)
    else:
        # Compute soft mapping back one stage (P1 of y[2k] XOR y[2k+1])
        u1est = cnop(y[::2],y[1::2])

        # R_N^T maps u1est to top polar code
        uhat1, u1hardprev = polar_decode(u1est,f[:(N//2)])

        # Using the re-encoded top-half decisions u1hardprev and y, estimate u2:
        # remove u1 from the even observations, then combine with the odd ones
        u2est = vnop(cnop(u1hardprev,y[::2]),y[1::2])

        # R_N^T maps u2est to bottom polar code
        uhat2, u2hardprev = polar_decode(u2est,f[(N//2):])

    # Pass u decisions up and interleave x1,x2 hard decisions
    #   note: Numba doesn't like np.concatenate
    u = np.zeros(N,dtype=np.int64)
    u[:(N//2)] = uhat1
    u[(N//2):] = uhat2
    x1 = cnop(u1hardprev,u2hardprev)
    x2 = u2hardprev
    x = np.zeros(N,dtype=np.int64)
    x[::2] = x1
    x[1::2] = x2

    return u, x


@njit
def polar_channel_mc(n: int64, chan, p: float64, M: int64) -> (float64[:]):
    '''
    Monte Carlo estimate of error rates for effective channels of length N=2^n polar code

          Arguments:
                  n (int): number of polarization stages
                  chan (function): channel function chan(x, p) that sends bits over the
                                   channel and returns P1 observations (e.g. bec_p1, bsc_p1,
                                   awgn_p1). NOTE(review): presumably must be Numba-compiled,
                                   since this function is @njit.
                  p (float): parameter for channel
                  M (int): Number of blocks for Monte Carlo estimate (must be positive)

          Returns:
                  biterrd (float[:]): Numpy array of length N with the estimated error
                                      rate of each effective bit channel, in input order

          Raises:
                  ValueError: if M <= 0 (previously this silently divided by zero)
    '''
    if M <= 0:
        raise ValueError("polar_channel_mc: M must be a positive number of blocks")

    # Setup parameters
    N = 2**n
    # Every position is frozen to 0 (f != 1/2), so polar_decode reports its own
    # hard decision for each u_i while feeding the true value 0 forward.
    f = np.zeros(N)
    biterrd = np.zeros(N)

    # The all-zero codeword is sent in every block.
    x = np.zeros(N,dtype=np.int64)
    for _ in range(M):
        y = chan(x,p)
        uhat, _xhat = polar_decode(y,f)
        # The transmitted u is all zeros, so every decoded 1 is a bit error.
        biterrd += uhat.astype(np.float64)

    return biterrd/M


@njit('(int64[:])(int64[:])') # Input/output specifications to make Numba work
def polar_transform(u):
    '''
    Encode polar information vector u

          Arguments:
                  u (int64[:]): Numpy array of input bits; length must be 0 or a power of two

          Returns:
                  x (int64[:]): Numpy array of encoded bits (always a new array, never u itself)

          Raises:
                  ValueError: if the length is greater than 1 and not a power of two
    '''
    N = len(u)
    # Base case. Testing N <= 1 rather than N == 1 stops an empty input from
    # recursing forever; copying keeps a length-1 result from aliasing u.
    if N <= 1:
        return u.copy()
    # Each stage pairs u[2k] with u[2k+1], so an odd length cannot be split.
    if N % 2 != 0:
        raise ValueError("polar_transform: len(u) must be a power of two")
    # R_N maps odd/even indices (i.e., u1u2/u2) to first/second half
    # Compute odd/even outputs of (I_{N/2} \otimes G_2) transform
    x = np.zeros(N, dtype=np.int64)
    x[:N//2] = polar_transform((u[::2]+u[1::2])%2)
    x[N//2:] = polar_transform(u[1::2])
    return x


# Check-node operation in P1 domain
#   For two independent bits with P1 equal to w1,w2, return the P1 of their XOR,
#   i.e. the probability of ODD parity (exactly one of the two bits is 1).
#   On hard 0/1 inputs this reduces to the XOR of the bits itself.
@vectorize([float64(float64,float64)],nopython=True)
def cnop(w1,w2):
    # P(w1 bit = 1, w2 bit = 0) + P(w1 bit = 0, w2 bit = 1)
    return w1*(1-w2) + w2*(1-w1)

# Bit-node (variable-node) operation in P1 domain
#   Combine two independent P1 observations w1, w2 of the same bit, which has
#   a uniform prior, into the posterior P1 of that bit.
#   NOTE(review): if w1 and w2 are contradictory certainties (one is 0, the other 1)
#   the denominator is 0; the result is then undefined (NaN or an error depending
#   on Numba's error model) -- confirm callers never hit this.
@vectorize([float64(float64,float64)],nopython=True)
def vnop(w1,w2):
    both_one = w1 * w2
    both_zero = (1 - w1) * (1 - w2)
    return both_one / (both_one + both_zero)

# Hard decision with randomized rounding in P1 domain
#   Return the MAP bit for the P1 observation w: 1 when w is clearly above 1/2,
#   0 when clearly below. A uniform jitter in [0, 2e-12) is added and compared
#   against 0.5 + 1e-12, so an exact tie (w == 1/2) becomes an (approximately)
#   fair coin flip, and values within ~1e-12 above 1/2 are also randomized.
@vectorize([int64(float64)],nopython=True)
def hard_dec_rr(w):
    jitter = 2e-12*np.random.random_sample(1)
    is_one = (w + jitter) > 0.5 + 1e-12
    return np.int64(is_one.all())

# MODIFIED FOR CHANNEL SIMULATION
# Deterministic hard decision in P1 domain: returns 1 iff w > 0.5. An exact tie
# (w == 0.5) always maps to 0; no tie-breaking randomness is used here (the
# function takes no z). A NaN input also maps to 0.
@vectorize([int64(float64)],nopython=True)
def hard_dec_rr_with_cr(w):
    # Compare the scalar directly instead of building a 1-element array per
    # element and reducing it with .all(); the result is identical.
    return np.int64(w > 0.5)


# MODIFIED FOR CHANNEL SIMULATION
@njit # Input/output specifications below to make Numba work
def polar_decode_with_cr(y: float64[:],z: float64[:],f: float64[:]) -> (int64[:],int64[:],int64[:]):
    '''
    Recursive successive cancellation polar decoder from P1 observations, in which
    every input bit is sampled from its P1 value using common randomness z

    At leaf i the bit is sampled as u_i = [z_i > 1 - P1_i], so P(u_i = 1) = P1_i
    when z_i ~ Unif(0,1). It is compared with v_i = [z_i > 0.5], the bit the same
    z_i yields under a uniform prior, and delta_i = u_i XOR v_i records where they differ.

          Arguments:
                  y (float[:]): channel observations in output order
                  z (float[:]): realizations of Unif(0,1) common randomness, in input order
                  f (float[:]): input a priori probabilities in input order
                                (not used: every position is sampled, frozen or not)

          Returns:
                  u (int[:]): sampled input bits in input order
                  x (int[:]): output hard decisions in output order
                  delta (int[:]): u XOR v per position, i.e. 1 where the sampled bit
                                  disagrees with the uniform-prior bit from z
    '''
    # Recurse down to length 1
    N = len(y)
    if (N==1):
        # Sample this bit from its P1 value using the shared uniform z[0]
        u_i = np.int64(z[0] > 1 - y[0])
        # Bit the same z[0] gives under a uniform (P1 = 1/2) prior
        v_i = np.int64(z[0] > 0.5)
        delta_i = np.array([np.int64(u_i^v_i)], dtype=np.int64)
        # y is deliberately left untouched: at the top level it is the caller's array.
        return np.array([u_i]), np.array([u_i]), delta_i

    # Compute soft mapping back one stage (P1 of y[2k] XOR y[2k+1])
    u1est = cnop(y[::2],y[1::2])

    # R_N^T maps u1est to top polar code
    uhat1, u1hardprev, delta1 = polar_decode_with_cr(u1est,z[:(N//2)],f[:(N//2)])

    # Using the re-encoded top-half bits u1hardprev and y, estimate u2
    u2est = vnop(cnop(u1hardprev,y[::2]),y[1::2])

    # R_N^T maps u2est to bottom polar code
    uhat2, u2hardprev, delta2 = polar_decode_with_cr(u2est,z[(N//2):],f[(N//2):])

    # Pass u and delta up and interleave x1,x2 hard decisions
    #   note: Numba doesn't like np.concatenate
    u = np.zeros(N,dtype=np.int64)
    delta = np.zeros(N, dtype=np.int64)
    u[:(N//2)] = uhat1
    u[(N//2):] = uhat2
    delta[:(N//2)] = delta1
    delta[(N//2):] = delta2
    x1 = cnop(u1hardprev,u2hardprev)
    x2 = u2hardprev
    x = np.zeros(N,dtype=np.int64)
    x[::2] = x1
    x[1::2] = x2

    return u, x, delta