# -*- coding: utf-8 -*-
"""
20-link pole balancing task with impoverished features, on-policy case
"""
__author__ = "Christoph Dann <cdann@cdann.de>"
import td
import examples
import numpy as np
import matplotlib.pyplot as plt
import dynamic_prog as dp
import util
import features
import policies
from task import LinearLQRValuePredictionTask
from experiments import experiment_main

import os
import pickle


# ---------------------------------------------------------------------------
# Problem definition.
# NOTE: the entry point at the bottom of this file forwards *every*
# module-level name via `experiment_main(**globals())`, so the variable names
# in this section are kept exactly as they are.
# ---------------------------------------------------------------------------
dim = 20
gamma = 0.95
dt = 0.1
sigma = np.full(2 * dim, 1.)

mdp = examples.NLinkPendulumMDP(
    np.full(dim, .5),
    np.full(dim, .6),
    sigma=sigma,
    dt=dt,
)

# Feature map in use; the alternative below is kept for reference.
phi = features.squared_diag(2 * dim)
# phi = features.squared_tri(dim*(dim+1)/2+1)

# Length of the feature vector, probed by evaluating phi on a zero state.
n_feat = len(phi(np.zeros(mdp.dim_S)))

# Only the first of the three values returned by the LQR solver is used.
theta_p, _, _ = dp.solve_LQR(mdp, gamma=gamma)
theta_p = np.array(theta_p)
theta_o = theta_p.copy()

# The sampling policy is built from the solver's parameters plus noise.
beh_policy = policies.LinearContinuous(
    theta=theta_p,
    noise=np.full(dim, 0.4),
)

# All-zero initial parameter vector.
theta0 = np.zeros(n_feat)

task = LinearLQRValuePredictionTask(
    mdp, gamma, phi, theta0,
    policy=beh_policy,
    normalize_phi=True,
    mu_next=1000,
)


# datasets = []
# datasets_dir = '/tmp/bbo/datasets'
# if not os.path.exists(datasets_dir):
#     os.makedirs(datasets_dir)

# for i in range(25):
#     (state_0s,
#      actions,
#      rewards,
#      state_1s,
#      restarts) = mdp.samples_cached(
#         n_iter=30000,
#         n_restarts=1,
#         policy=beh_policy,
#         seed=i, verbose=100)

#     terminals = np.roll(restarts, -1, axis=0)
#     terminals[-1] = False

#     samples = {
#         'state_0': state_0s,
#         'action': actions,
#         'state_1': state_1s,
#         'reward': rewards.reshape(-1, 1),
#         'terminal': terminals.reshape(-1, 1),
#         'info': {},
#     }

#     dataset = {'samples': samples}

#     datasets.append(dataset)

#     with open(os.path.join(datasets_dir, 'dataset-{i}.pkl'.format(i=i)), 'wb') as f:
#         pickle.dump(dataset, f)


# Reference values for the states held by the task (presumably the true value
# function evaluated on the task's state sample -- confirm against
# LinearLQRValuePredictionTask).
value_functions_dir = '/tmp/bbo/value_functions'
# exist_ok=True replaces the racy "check os.path.exists, then create" pattern.
os.makedirs(value_functions_dir, exist_ok=True)

# 'i,bi->b' takes the dot product of the parameter vector with every row of
# task.mu_phi_full; the trailing [..., None] turns the (b,) result into (b, 1).
values = np.einsum(
    'i,bi->b',
    features.squared_tri(task.mdp.dim_S).param_forward(*task.V_true),
    task.mu_phi_full
)[..., None]

states = task.mu

# The leftover debugging `breakpoint()` calls that surrounded the disabled
# export below were removed: they stopped every run in the debugger before
# the experiment could start.

# with open(os.path.join(value_functions_dir, 'value_function.pkl'), 'wb') as f:
#     pickle.dump((states, values), f)


# ---------------------------------------------------------------------------
# Estimators to compare.  Each stanza sets its hyper-parameters, builds the
# estimator, attaches plotting metadata (name / color / optional line style)
# and appends it to `methods`.  The order of the list is preserved.
# Hyper-parameter names (alpha, lam, mu, eps) are deliberately re-bound per
# stanza and are also visible to `experiment_main(**globals())`.
# ---------------------------------------------------------------------------
methods = []

# --- BBO variants ----------------------------------------------------------
alpha = 1.0
bbo_v2 = td.BBOV2(alpha, D_a=beh_policy.dim_A, phi=phi)
bbo_v2.color = "black"
bbo_v2.name = r"BBO-v2"
methods.append(bbo_v2)

alpha = 1.0
bbo_v3 = td.BBOV3(alpha, D_a=beh_policy.dim_A, phi=phi)
bbo_v3.color = "black"
bbo_v3.name = r"BBO-v3"
methods.append(bbo_v3)

# --- Gradient TD -----------------------------------------------------------
alpha, mu = 0.0005, 2.
gtd = td.GTD(alpha=alpha, beta=mu * alpha, phi=phi)
gtd.color = "r"
gtd.name = r"GTD $\alpha$={} $\mu$={}".format(alpha, mu)
methods.append(gtd)

alpha = 0.0005
mu = 1.
gtd = td.GTD2(alpha=alpha, beta=mu * alpha, phi=phi)
gtd.color = "orange"
gtd.name = r"GTD2 $\alpha$={} $\mu$={}".format(alpha, mu)
methods.append(gtd)

# --- TD(lambda): RMalpha step-size object, then constant step size ---------
alpha, lam = td.RMalpha(0.06, 0.5), .0
td0 = td.LinearTDLambda(alpha=alpha, lam=lam, phi=phi, gamma=gamma)
td0.color = "k"
td0.name = r"TD({}) $\alpha$={}".format(lam, alpha)
methods.append(td0)

alpha, lam = .0005, .0
td0 = td.LinearTDLambda(alpha=alpha, lam=lam, phi=phi, gamma=gamma)
td0.color = "k"
td0.name = r"TD({}) $\alpha$={}".format(lam, alpha)
methods.append(td0)

# --- TDC -------------------------------------------------------------------
lam, alpha, mu = 0.0, 0.0005, .05
tdc = td.TDCLambda(alpha=alpha, mu=mu, lam=lam, phi=phi, gamma=gamma)
tdc.color = "b"
tdc.name = r"TDC({}) $\alpha$={} $\mu$={}".format(lam, alpha, mu)
methods.append(tdc)

# --- Least-squares family (the name `lstd` is reused for all three) --------
alpha, lam = .01, 0.0
lstd = td.RecursiveLSPELambda(lam=lam, alpha=alpha, phi=phi, gamma=gamma)
lstd.color = "g"
lstd.name = r"LSPE({}) $\alpha$={}".format(lam, alpha)
methods.append(lstd)

lam, eps = 0., 0.01
lstd = td.RecursiveLSTDLambda(lam=lam, eps=eps, phi=phi, gamma=gamma)
lstd.color = "g"
lstd.ls = "-."
lstd.name = r"LSTD({}) $\epsilon$={}".format(lam, eps)
methods.append(lstd)

alpha, lam = 0.0005, .2
lstd = td.FPKF(lam=lam, alpha=alpha, phi=phi, gamma=gamma)
lstd.color = "g"
lstd.ls = "-."
lstd.name = r"FPKF({}) $\alpha$={}".format(lam, alpha)
methods.append(lstd)

# --- Residual gradient -----------------------------------------------------
alpha = .0005
rg = td.ResidualGradientDS(alpha=alpha, phi=phi, gamma=gamma)
rg.color = "brown"
rg.ls = "--"
rg.name = r"RG DS $\alpha$={}".format(alpha)
methods.append(rg)

alpha = .003
rg = td.ResidualGradient(alpha=alpha, phi=phi, gamma=gamma)
rg.color = "brown"
rg.name = r"RG $\alpha$={}".format(alpha)
methods.append(rg)

# --- Bellman residual minimisation -----------------------------------------
brm = td.RecursiveBRMDS(phi=phi, eps=0.01)
brm.color = "b"
brm.ls = "--"
brm.name = "BRMDS"
methods.append(brm)

brm = td.RecursiveBRM(phi=phi, eps=1e5)
brm.color = "b"
brm.name = "BRM"
methods.append(brm)

# Experiment settings.  None of these names is read anywhere else in this
# file; they reach `experiment_main` as keyword arguments through the
# `**globals()` call below, so they must not be renamed here.
# The per-name meanings are inferred from the names only -- confirm against
# `experiments.experiment_main` before relying on them.
l = 30000  # presumably the run length in samples/steps -- TODO confirm
error_every = 300  # presumably the interval at which errors are evaluated
n_indep = 50  # presumably the number of independent runs
n_eps = 1  # presumably the number of episodes per run
episodic = False
criterion = "MSE"
criteria = ["RMSPBE", "RMSBE", "RMSE", "MSPBE", "MSBE", "MSE"]
title = "11. 20-link Lin. Pole Balancing On-pol."
name = "link20_imp_onpolicy"


if __name__ == "__main__":
    # Forwards every module-level name (settings, `task`, `methods`, and also
    # the imported modules) to experiment_main as keyword arguments.
    experiment_main(**globals())
