# -*- coding: utf-8 -*-
"""
Created on Tue May 14 21:46:32 2024

@author: shara
"""
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from scipy import stats
from scipy.stats import bootstrap
import statsmodels.api as sm
import pickle
import scienceplots
# Global figure styling: the 'science' and 'grid' styles are applied first,
# then the font and LaTeX text-rendering options below override them.
# (The PGF backend can be selected with matplotlib.use("pgf") if required.)
plt.style.use(['science', 'grid'])
plt.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['Times New Roman'],
    # Render all text through LaTeX, with the packages listed in the preamble.
    'text.usetex': True,
    'text.latex.preamble': r'''
\usepackage{amsmath}
\usepackage[lcgreekalpha]{stix2}
''',
})

def get_data_store_filename(source_power, rate):
    """Return the pickle filename for a given source power and rate.

    The name has the form ``AWGN_Sim_<source_power>_<rate>.pickle``, where
    both values are rendered with ``str()``.
    """
    return f'AWGN_Sim_{source_power!s}_{rate!s}.pickle'
def snr_to_mse(snr, source_power=1):
    """Convert an SNR given in dB to the corresponding MSE.

    Computes ``source_power / 10**(snr/10)``.
    """
    linear_snr = 10 ** (snr / 10)
    return source_power / linear_snr


# ---------------------------------------------------------------------------
# Q-Q plots: for each rate, compare the stored 'Post Rotation MSE Levels'
# against a gamma distribution, and bootstrap a confidence interval on the
# mean of the blocklength-normalised MSE levels.
# ---------------------------------------------------------------------------
operating_rates = [1, 2, 3]
# NOTE: despite its name, this list collects 0.5*log2(1 + P/N) per rate; it is
# the x-coordinate of the simulated points in the performance plot.
operating_snrs = []
source_power = 1
# Nominal value only: replaced by each data file's 'Blocklength' entry before
# it is used.
blocklength = 1000
noise_powers = []
qq_fig, qq_axes = plt.subplots(1, 3)
colors = ['teal', 'teal', 'teal']

# Initial shared x label; its size and position are re-set after
# tight_layout() further down.
qq_fig.supxlabel('Theoretical Quantiles', y=-0.07)

# Confidence intervals (limits on the mean normalised MSE), one entry per rate.
ci_low = []
ci_high = []
for qq_ax, rate, face_color in zip(qq_axes, operating_rates, colors):
    # Per-axes labels are hidden; the figure-level labels are used instead.
    qq_ax.xaxis.label.set_visible(False)
    qq_ax.yaxis.label.set_visible(False)

    qq_ax.tick_params(axis='both', which='major', color='0', labelsize=4)
    qq_ax.tick_params(axis='both', which='minor', color='0.3')
    # The on/off flag is passed positionally: the old ``b=`` keyword was
    # renamed to ``visible=`` and is rejected by newer matplotlib releases.
    qq_ax.grid(True, which='major', color='0.65', linestyle='-', linewidth=0.2)

    # SECURITY: pickle.load can execute arbitrary code; only load data files
    # that were produced by a trusted simulation run.
    data_store_filename = get_data_store_filename(source_power, rate)
    with open(data_store_filename, 'rb') as f:
        data_dictionary = pickle.load(f)

    noise_power = data_dictionary['Noise Power']
    # Read the blocklength *before* normalising, so each file is scaled by
    # its own blocklength rather than the nominal / previous file's value.
    blocklength = data_dictionary['Blocklength']
    mse_levels = np.array(data_dictionary['Post Rotation MSE Levels'])

    # NOTE(review): n_resamples=20 is very small for a percentile interval
    # and no seed is fixed, so the limits vary between runs - confirm this
    # is intended.
    ci_lims = bootstrap(
        (mse_levels / blocklength,),
        np.mean,
        n_resamples=20,
        method='percentile',
    )
    ci_low.append(ci_lims.confidence_interval.low)
    ci_high.append(ci_lims.confidence_interval.high)

    noise_powers.append(noise_power)
    operating_snrs.append(0.5*np.log2(1 + source_power/noise_power))

    # Reference distribution: gamma with shape blocklength/2, loc 0 and
    # scale 2*noise_power.
    pp = sm.ProbPlot(
        mse_levels,
        dist=stats.gamma,
        distargs=(blocklength/2,),
        loc=0,
        scale=2*noise_power,
    )
    pp.qqplot(
        marker='.',
        line="45",
        ax=qq_ax,
        label='R='+str(rate),
        markerfacecolor=face_color,
    )

# Map the MSE limits through c -> 0.5*log2(1 + P/c). This map is decreasing in
# c, so ci_low_mi holds the *upper* limits and ci_high_mi the *lower* limits
# of the transformed quantity.
ci_low_mi = 0.5*np.log2(1 + source_power/np.asarray(ci_low))
ci_high_mi = 0.5*np.log2(1 + source_power/np.asarray(ci_high))
ci_high = np.array(ci_high)

qq_fig.tight_layout()
qq_fig.supxlabel('Theoretical Quantiles', size=7, y=0.07)
qq_fig.supylabel('Sample Quantiles', size=7, x=-0.01)
qq_fig.savefig("TCQQQPlots.pdf")

    #Check if error is distributed symmetrically around source

# MI plot: operating rate against 0.5*log2(1 + P/N) for each simulated rate,
# with the identity line drawn as the 'Lower bound' reference.
lb_color = '#2ca25f'
snr_range = np.linspace(0, 5, 1000)
fig, ax = plt.subplots()

ax.tick_params(axis='both', which='major', color='0', labelsize=12)
ax.tick_params(axis='both', which='minor', color='0.3')
# The on/off flag is passed positionally: the old ``b=`` keyword was renamed
# to ``visible=`` and is rejected by newer matplotlib releases.
ax.grid(True, which='major', color='0.65', linestyle='-', linewidth=0.2)

ax.plot(snr_range, snr_range, color=lb_color, label='Lower bound')
ax.plot(operating_snrs, operating_rates, 'x', color='#67000d', label='TCQ Sim')

# ci_low_mi is derived from the *low* MSE limit, which gives the larger
# x-value, hence the 'Upper' label (and conversely for ci_high_mi).
ax.plot(ci_low_mi, operating_rates, '.', color='black', markersize=1,
        label='Upper Confidence Interval')
ax.plot(ci_high_mi, operating_rates, '.', color='purple', markersize=1,
        label='Lower Confidence Interval')
ax.set_xlabel('$I(X;Y)$', size=8)
ax.set_ylabel('Rate', size=8)
# NOTE(review): every series carries a label but ax.legend() is never called,
# so no legend appears in the saved figure - confirm whether that is intended.
fig.savefig("TCQPerformancePlot.pdf")