import argparse
import numpy as np
from os.path import join, exists
from os import makedirs
from matplotlib import pyplot as plt


def MSE_over_h(args):
    """Plot the empirical squared error of value estimates against step size h.

    For every data budget ``D`` in ``args.D`` this loads
    ``<args.data_dir>/Vhat_D_{D}.npy``. It averages ``(Vhat - args.V) ** 2``
    over axis 0 (the runs) and plots the result against ``args.h_list``. It
    also shades a band of +/- one standard error (std over runs divided by
    sqrt(number of runs)). Both axes are log-scaled. The figure is saved as
    ``MSE_T_{time_limit}_{env_name}.pdf`` under ``<args.data_dir>/img``.

    Args:
        args: namespace with attributes ``data_dir`` (str), ``env_name``
            (str), ``V`` (float, true value), ``D`` (iterable of budgets),
            ``h_list`` (sequence of step sizes, one per column of each
            ``Vhat`` array), ``time_limit`` (float or None) and
            ``show_img`` (bool).

    Raises:
        ValueError: if ``args.h_list`` is None, or if the per-h error curve
            of a loaded ``Vhat`` does not have one entry per element of
            ``args.h_list``.
    """
    h_list = args.h_list
    if h_list is None:
        raise ValueError("h_list must be provided (one step size per column of Vhat).")
    T = args.time_limit
    V = args.V
    D_list = args.D
    print(f"V is {V}; D_list is {D_list}; h_list is {h_list}")

    fig, ax = plt.subplots()
    for D in D_list:
        Vhat = np.load(join(args.data_dir, f'Vhat_D_{D}.npy'))
        # Vhat is expected to be [nb_trials, nb_h]; statistics reduce over trials.
        sq_err = (Vhat - V) ** 2
        nb_runs = Vhat.shape[0]
        obj = np.mean(sq_err, axis=0)
        # Standard error of the mean across runs.
        std = np.std(sq_err, axis=0) / np.sqrt(nb_runs)
        if np.shape(obj) != (len(h_list),):
            raise ValueError(
                f"Vhat_D_{D}.npy yields an error curve of shape {np.shape(obj)}, "
                f"expected ({len(h_list)},) to match h_list."
            )
        ax.plot(h_list, obj, '-o', linewidth=2, label=f'B={D:.0f}',
                markeredgewidth=1.5, alpha=0.6,
                markeredgecolor=(0, 0, 0, 1))
        ax.fill_between(h_list, obj + std, obj - std, alpha=0.1)
    ax.set_yscale('log')
    ax.set_xscale('log')
    # time_limit is optional on the CLI; only mention T in the title when known.
    t_part = f', T={T:.0f}' if T is not None else ''
    ax.set_title(r'$(\hat{V}_M(h) - V)^2$' + t_part + f', {args.env_name}')
    ax.set_xlabel("h")
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.legend(bbox_to_anchor=(1.05, 0.6))
    fig.tight_layout()
    fname = f'MSE_T_{T}_{args.env_name}'
    img_dir = join(args.data_dir, 'img')
    # exist_ok avoids the check-then-create race of exists()/makedirs().
    makedirs(img_dir, exist_ok=True)
    fig.savefig(join(img_dir, fname + '.pdf'), bbox_inches='tight', dpi=300)
    if args.show_img:
        plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


if __name__ == '__main__':
    # Command-line entry point: parse options and draw the MSE-vs-h plot.
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, default='log_pendulum_Vhat',
                        help='directory.')
    parser.add_argument('--env_name', type=str, default='pendulum',
                        help='env name.')
    parser.add_argument('--V', type=float, default=-0.16, help='true value')
    parser.add_argument('--D', type=int, nargs="+", default=[10000], help='data budget')
    parser.add_argument('--show_img', action='store_true',
                        help='if true, run plt.show()')
    # Plotting needs one h per column of each Vhat array, so there is no
    # usable default; let argparse report a missing value up front.
    parser.add_argument('--h_list', type=float, nargs="+", required=True,
                        help='list of h')
    parser.add_argument('--time_limit', type=float, default=None,
                        help='specify environment time limit (physical time).')
    args = parser.parse_args()

    MSE_over_h(args)
