import numpy as np
import pandas as pd
import pickle
import matplotlib.pyplot as plt

from exact_rdp import *

def _plot_eps_vs_error(eps, err, sem, ylabel, out_fname):
	"""Draw error-vs-epsilon curves for the two parameters and save them as a PDF.

	eps is the x-axis; err and sem are indexed as [:, 0] and [:, 1] for the two
	plotted parameters. The current axes are cleared before drawing.
	"""
	series = (('red', r'$\theta_1$'), ('blue', r'$\theta_2$'))

	plt.cla()
	# Nearly transparent error bars first, so the solid curves are drawn on top.
	for col, (color, _) in enumerate(series):
		plt.errorbar(eps, err[:, col], yerr=sem[:, col], alpha=0.01,
				color=color)
	for col, (color, label) in enumerate(series):
		plt.plot(eps, err[:, col], color=color, label=label)
	plt.legend(loc='best')
	plt.xlabel(r'$\epsilon$', fontsize=16)
	plt.ylabel(ylabel, fontsize=16)
	plt.savefig(out_fname, format='pdf')


def plot_fig6(fname_nondp, fname_dp):
	"""Plot the error of DP chain statistics against the privacy budget epsilon.

	Both arguments are paths to pickle files. Each unpickled object is indexed
	positionally: from the DP file, [0] must provide 'eps', [1] must provide
	'T', 'B' and 'N', and [2] is the collection of DP chains (it needs
	``.shape[0]`` and iteration over per-run chains). From the non-DP file only
	[2] is used, as the reference chain for the "true" mean and variance.

	For every iteration count t in 1000..T the function computes the smallest
	epsilon found over the candidate RDP orders (with delta = 1/N), adds the
	'eps' value from the DP file, and compares the statistics of the first t
	samples of each DP chain against the reference. Two figures are written to
	the current working directory: 'eps_vs_mean_acc.pdf' and
	'eps_vs_var_acc.pdf'. Returns None.

	NOTE: pickle.load can execute arbitrary code; only pass trusted files.
	"""
	# Context managers make sure both files are closed even if unpickling fails.
	with open(fname_dp, 'rb') as dp_file:
		dp_results = pickle.load(dp_file)
	with open(fname_nondp, 'rb') as nondp_file:
		nondp_results = pickle.load(nondp_file)

	dp_dpvi_params = dp_results[0]
	dp_dp_mcmc_params = dp_results[1]
	dp_chain = dp_results[2]
	nondp_chain = nondp_results[2]

	# Compute the privacy budget for number of iterations
	T = dp_dp_mcmc_params['T']
	b = dp_dp_mcmc_params['B']
	N = dp_dp_mcmc_params['N']
	dp_dpvi_eps = dp_dpvi_params['eps']
	Ts = np.arange(1000, T+1)
	max_alpha = min(100, b//5)
	delta = 1/N
	q = b/N

	# Evaluate rd_approx once per order instead of rebuilding the list from
	# alpha=2 on every pass of the loop below.
	# NOTE(review): assumes rd_approx is deterministic in (alpha, b) - confirm
	# against exact_rdp.
	eps_by_alpha = [rd_approx(alpha, b) for alpha in range(2, max_alpha)]

	min_eps_T = np.inf*np.ones(len(Ts))
	for max_alpha_ in range(3, max_alpha):
		# Orders 2..max_alpha_ are the first max_alpha_-1 entries.
		eps_alpha_list = eps_by_alpha[:max_alpha_-1]
		amplified_eps = amplified_RDP(eps_alpha_list, max_alpha_, q)
		total_eps = [from_RDP_to_DP(t*amplified_eps, max_alpha_, delta) for t in Ts]
		min_eps_T = np.minimum(min_eps_T, total_eps)
	min_eps_T = min_eps_T + dp_dpvi_eps

	## Comparison between DP and non-DP
	true_mean = nondp_chain.mean(0)
	true_var = nondp_chain.var(0)

	dp_mean = np.zeros([len(Ts), 2])
	dp_var = np.zeros([len(Ts), 2])

	dp_mean_sem = np.zeros([len(Ts), 2])
	dp_var_sem = np.zeros([len(Ts), 2])

	n_runs = dp_chain.shape[0]
	sqrt_n_runs = np.sqrt(n_runs)
	for i, t in enumerate(Ts):
		# Per-run statistics over the first t samples, computed once and reused
		# for both the absolute error and its standard error across runs.
		run_means = np.array([chain[:t].mean(0) for chain in dp_chain])
		run_vars = np.array([chain[:t].var(0) for chain in dp_chain])
		dp_mean[i] = np.abs(run_means.mean(axis=0)-true_mean)
		dp_var[i] = np.abs(run_vars.mean(axis=0)-true_var)
		dp_mean_sem[i] = np.std(run_means-true_mean, axis=0)/sqrt_n_runs
		dp_var_sem[i] = np.std(run_vars-true_var, axis=0)/sqrt_n_runs

	_plot_eps_vs_error(min_eps_T, dp_mean, dp_mean_sem,
			r'$|\mu_{true}-\mu_{DP}|$', 'eps_vs_mean_acc.pdf')
	_plot_eps_vs_error(min_eps_T, dp_var, dp_var_sem,
			r'$|\sigma^2_{true}-\sigma^2_{DP}|$', 'eps_vs_var_acc.pdf')
	plt.close()
