"""
    example:
    $ python LQR_Scalar.py -T 8 -a -1.0 -B 4096

    Output:

    An .npy file in the data directory with a name that looks like
    (a, T and B are parsed as floats, so they are printed with a decimal point):
        Vhat_a_-1.0_T_8.0_B_4096.0_seed_0.npy

    It stores a numpy array of shape (N_trials, len(h_list)): each row
        is one trial and each column corresponds to one choice of h.
        The trials will be averaged when plotting.
"""

import os
import time
import argparse
import numpy as np
from tqdm import trange
from scipy.linalg import toeplitz

# For reproducibility: the global NumPy RNG is seeded once, and the same seed
# value is later embedded in the output file name.
seed = 0
np.random.seed(seed)

# Command-line options: all three are mandatory and parsed as floats.
parser = argparse.ArgumentParser()
for _flag, _help in (
    ('-a', 'parameter a<0'),
    ('-T', 'horizon T'),
    ('-B', 'data budget B'),
):
    parser.add_argument(_flag, type=float, required=True, help=_help)
args = parser.parse_args()

a = args.a
T = args.T
B = args.B  # e.g. B = 2**12

# Fine time grid used as a stand-in for the continuous-time system.
N0 = 2**16
h0 = T / N0  # fine-grid step (equals 2^-16 only when T = 1)
imax = 16
M0 = int(B * (2**imax) / N0)  # longest number of trajectories that we need

sigma = 1  # noise scale: the increments w are drawn with std sigma * sqrt(h0)
q = 1  # weight multiplying the sums of x^2 in compute_Vhat
print(f"Hyperparams: T={T},B={B},h0={h0},N0={N0},M0={M0},sigma={sigma},q={q}")

# Candidate sampling periods h = 2^i * h0 for i = 1..imax, i.e. from 2*h0 up to T.
h_list = 2**np.arange(1, imax + 1) * h0
print(f"List of h: h = {h_list}")
print(f"a={a} and equivalent a in DT: {np.exp(a)}")

def gen_CT_data(vectorize=False, dtype=np.float32):
    """Generate M0 trajectories, each with N0 samples, as a proxy to the CT system.

    Every trajectory starts at x[0] = 0 and follows the recursion
        x[k] = exp(a*h0) * x[k-1] + w[k-1],
    where the w[k] are independent zero-mean normal draws with standard
    deviation sigma * sqrt(h0). The function reads the module-level a, h0,
    N0, M0 and sigma and draws from the global NumPy random generator.

    Args:
        vectorize: only used when a != 0. If True, all time steps are
            computed with a single (N0-1) x (N0-1) lower-triangular Toeplitz
            matrix product instead of the step-by-step loop. The matrix
            needs O(N0^2) memory, so this is only practical for small N0.
        dtype: dtype in which the noise and the state are stored.

    Returns:
        Array of shape (N0, M0) containing x[k]^2. Each column is one
        trajectory and row 0 (time 0) is always zero, for every value of a.
    """
    # we are using left Riemann Sum so only need w(0),w(1),...w(N0-2)
    w = np.random.normal(scale=sigma * np.sqrt(h0), size=(N0 - 1, M0)).astype(dtype)
    x = np.zeros((N0, M0), dtype=dtype)

    if a == 0:
        # a=0: x is just the running sum of w. It is written below the
        # x[0] = 0 row so the result has N0 rows and row k is time k*h0,
        # exactly like the a != 0 branches (callers subsample rows by stride).
        x[1:, :] = np.cumsum(w, axis=0)
    else:
        a_eq = np.exp(a * h0)
        if vectorize:  # inefficient due to the large matrix multiplication
            # Lower-triangular Toeplitz matrix with entry (i, j) = a_eq**(i-j)
            # for i >= j, so row i accumulates the discounted past noise.
            col = a_eq ** np.arange(N0 - 1)
            row = np.zeros(N0 - 1)
            row[0] = 1
            A_n = toeplitz(col, row)
            x[1:, :] = A_n @ w  # => result (N0-1 by M0), each col is a trajectory
        else:
            for k in trange(1, N0, desc='generate x(t)'):  # time steps 1..N0-1
                x[k, :] = a_eq * x[k - 1, :] + w[k - 1, :]

    # Square in place: x is local, and this avoids a second (N0, M0) array.
    np.square(x, out=x)
    return x

def compute_Vhat(h, x_sq, verbose=False):
    """Estimate the value from squared states sampled with period h.

    Reads the module-level T, B, h0 and q.

    Args:
        h: sampling period; every int(h/h0)-th row of x_sq is kept.
        x_sq: 2-D array of squared states on the fine grid, one row per
            fine time step and one column per trajectory.
        verbose: if True, print the derived sizes and the first per-trajectory
            returns for inspection.

    Returns:
        The mean over the selected trajectories of the per-trajectory
        Riemann-sum return.
    """
    N = int(T/h)  # number of samples in one trajectory
    M = max(int(B/N), 1)  # number of trajectories, at least 1
    h_ratio = int(h/h0)  # stride on the fine grid
    if verbose:
        print(f"h={h},N={N},M={M}, h/h0={h_ratio}")

    # Subsample x(kh)^2 from the fine-grid data and keep the first M columns.
    xh_sq = x_sq[::h_ratio, :M]
    n_rows = xh_sq.shape[0]

    # Left Riemann sum: every sample but the last covers a full interval h ...
    Jh = q * h * xh_sq[:-1, :].sum(axis=0)
    # ... and the last sample covers whatever is left of the horizon, which
    # works for a full trajectory as well as for a partial one.
    Jh += q * (T - (n_rows - 1)*h) * xh_sq[-1, :]

    # each Jh[i] is the return for i-th trajectory
    if verbose:
        print(f"Jh was {Jh[:10]}")
        Jh_orig = q * h * xh_sq.sum(axis=0)
        print(f"Jh_orig was {Jh_orig[:10]}")
        print(f"shape of xh_sq: {xh_sq.shape}")
        print(f"shape of Jh: {Jh.shape}")

    # avg over M trajectories
    return Jh.mean()

# Monte-Carlo experiment: every trial draws a fresh fine-grid data set and
# evaluates the estimator at each sampling period in h_list.
N_trials = 50
Vhat = np.zeros((N_trials, len(h_list)))  # row = trial, column = index into h_list

start = time.time()
for j in trange(N_trials, desc="trials"):  # outer loop of many trials to approximate the E[(Vh-V)^2]
    x_sq = gen_CT_data()
    for i, h in enumerate(h_list):
        # Diagnostics are printed only once: first h of the first trial.
        Vhat[j, i] = compute_Vhat(h, x_sq, i == 0 and j == 0)

print(f"total time for {N_trials} trials: {time.time()-start}")


path = 'data'
# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs(path, exist_ok=True)
fname = os.path.join(path, f'Vhat_a_{a}_T_{T}_B_{B}_seed_{seed}.npy')
with open(fname, 'wb') as f:
    np.save(f, Vhat)

