"""
Script for reproducing the experiments (Figures 1-6) presented in the main paper.
All results are obtained with tempered likelihoods.
"""

import numpy as np
import pandas as pd
import pickle
from mog_main_prop_var import fig3_data
from mog_main import fig5_data
from mog_main_nruns import fig6_data

from plot_rdp_bounds import plot_fig1, plot_fig2
from plot_var_vs_clip import plot_fig3
from plot_approx_errors import plot_fig4
from plot_scatter import plot_fig5
from plot_eps_vs_acc_nruns import plot_fig6

from X_corr import get_x_corr_params
# Figures to reproduce: set a flag to 1 to generate/plot that figure, 0 to skip it.
fig1 = 1  # plotted with plot_rdp_bounds.plot_fig1
fig2 = 1  # plotted with plot_rdp_bounds.plot_fig2
fig3 = 0  # data: mog_main_prop_var.fig3_data; plot: plot_var_vs_clip.plot_fig3
fig4 = 0  # data: X_corr.get_x_corr_params; plot: plot_approx_errors.plot_fig4
fig5 = 0  # data: mog_main.fig5_data; plot: plot_scatter.plot_fig5
fig6 = 0  # data: mog_main_nruns.fig6_data; plot: plot_eps_vs_acc_nruns.plot_fig6

# 1: rerun the experiments; 0: use the precomputed result files under ./results/.
from_scratch = 0

def _run_or_load(seed, scratch, run, cached_path):
    """Reseed NumPy's global RNG, then either run an experiment or use a stored result.

    The seed is applied before the branch, i.e. also when the stored result is
    used and nothing is rerun.

    Args:
        seed: Value passed to ``np.random.seed``.
        scratch: Truthy to call ``run()``; falsy to return ``cached_path``.
        run: Zero-argument callable that runs the experiment. Its return value
            is handed on to the plotting code (presumably a results-file path,
            like ``cached_path`` -- confirm against the ``*_data`` functions).
        cached_path: Path of the precomputed results used when ``scratch`` is falsy.

    Returns:
        ``run()`` if ``scratch`` is truthy, otherwise ``cached_path``.
    """
    np.random.seed(seed)
    if not scratch:
        return cached_path
    return run()


if __name__ == '__main__':

    # --- Data: rerun each selected experiment (from_scratch) or use stored results.
    # Every experiment reseeds NumPy's global RNG with 2 right before it runs.
    if fig3:
        dp_pickle_fname_fig_3 = _run_or_load(
            2, from_scratch, fig3_data,
            './results/dp_mcmc_results_temped_multiple_prop_vars_23_1.p')
    if fig5:
        non_dp_pickle_fname_fig_5 = _run_or_load(
            2, from_scratch, lambda: fig5_data(privacy=False),
            './results/non_dp_mcmc_results_temped_23_1.p')
        dp_pickle_fname_fig_5 = _run_or_load(
            2, from_scratch, lambda: fig5_data(privacy=True),
            './results/dp_mcmc_results_temped_23_1.p')
    if fig6:
        def _fig6_non_dp():
            # Fig. 6's non-private baseline is the same experiment as Fig. 5's
            # (same generator, same stored file), so reuse the run produced
            # above when Fig. 5 was selected instead of repeating it.
            if fig5:
                return non_dp_pickle_fname_fig_5
            return fig5_data(privacy=False)

        non_dp_pickle_fname_fig_6 = _run_or_load(
            2, from_scratch, _fig6_non_dp,
            './results/non_dp_mcmc_results_temped_23_1.p')
        dp_pickle_fname_fig_6 = _run_or_load(
            2, from_scratch, fig6_data,
            './results/dp_mcmc_results_temped_n_runs_23_1.p')

    # --- Plots, in figure order.
    if fig1:
        plot_fig1(np.linspace(0.001, 0.01, 100),   # qs
                  [100000, 1000000, 10000000],     # Ns
                  T=5000)
    if fig2:
        plot_fig2(np.arange(10, 100000, 10),       # Ts
                  [100, 1000, 10000],              # bs
                  q=0.001)
    if fig3:
        plot_fig3(dp_pickle_fname_fig_3)
    if fig4:
        np.random.seed(1606)
        if not from_scratch:
            plot_fig4('./results/approx_error_res.pickle')
        else:
            # Recompute the X_corr parameters for each value of C (this loop
            # takes a while), writing one pickle per C.
            for C in (1.0, 1.5, 1.75, 2.0):
                out_file = './X_corr/supplement_test_x_corr_params_C{}.pickle'.format(np.round(C, 2))
                get_x_corr_params(x_max=10, n_points=1000, C=C, lr=1e-2,
                                  T=20000, path_to_file=out_file)
            # NOTE(review): called without a path here -- presumably plot_fig4
            # falls back to a default input; confirm it matches the files above.
            plot_fig4()
    if fig5:
        plot_fig5(non_dp_pickle_fname_fig_5, dp_pickle_fname_fig_5)
    if fig6:
        plot_fig6(non_dp_pickle_fname_fig_6, dp_pickle_fname_fig_6)

