from huge.algo import huge
from numpy import VisibleDeprecationWarning
import doodad as dd
import huge.doodad_utils as dd_utils
import argparse
import wandb
import yaml

def run(start_frontier = -1,
        frontier_expansion_rate=10,
        frontier_expansion_freq=-1,
        select_goal_from_last_k_trajectories=-1,
        throw_trajectories_not_reaching_goal=False,
        repeat_previous_action_prob=0.8,
        reward_layers="600,600",
        fourier=False,
        fourier_goal_selector=False,
        command_goal_if_too_close=False,
        display_trajectories_freq=20,
        label_from_last_k_steps=-1,
        label_from_last_k_trajectories=-1,
        contrastive=False,
        k_goal=1, use_horizon=False,
        sample_new_goal_freq=1,
        weighted_sl=False,
        buffer_size=20000,
        stopped_thresh=0.05,
        eval_episodes=200,
        maze_type=0,
        random_goal=False,
        explore_length=20,
        desired_goal_sampling_freq=0.0,
        num_blocks=1,
        deterministic_rollout=False,
        network_layers="128,128",
        epsilon_greedy_rollout=0,
        epsilon_greedy_exploration=0.2,
        remove_last_k_steps=8,
        select_last_k_steps=8,
        eval_freq=5e3,
        expl_noise_std = 1,
        goal_selector_epochs=400,
        stop_training_goal_selector_after=-1,
        normalize=False,
        task_config="slide_cabinet,microwave",
        human_input=False,
        save_videos = True,
        continuous_action_space=False,
        goal_selector_batch_size=64,
        goal_threshold=-1,
        check_if_stopped=False,
        human_data_file='',
        env_name='pointmass_empty',train_goal_selector_freq=10,
        distance_noise_std=0,  exploration_when_stopped=True,
        remove_last_steps_when_stopped=True,
        goal_selector_num_samples=100, data_folder="data", display_plots=False, render=False,
        explore_episodes=5, gpu=0, sample_softmax=False, seed=0, load_goal_selector=False,
        batch_size=100, train_regression=False,load_buffer=False, save_buffer=-1, policy_updates_per_step=1,
        select_best_sample_size=1000, max_path_length=50, lr=5e-4, train_with_preferences=True,
        log_tensorboard=False, use_oracle=False, exploration_horizon=30,
        use_wrong_oracle=False,
        pretrain_goal_selector=False,
        pretrain_policy=False,
        num_demos=0,
        demo_epochs=100000,
        demo_goal_selector_epochs=1000,
        goal_selector_buffer_size=50000,
        fill_buffer_first_episodes=0,
        run_path='',
        policy_name='',
        max_timesteps=2e-4, goal_selector_name='', **extra_params):
    """Build the environment, networks and buffers, then train a HUGE agent.

    Steps performed, in order:

    1. Configure the GPU helpers and seed torch / numpy with ``seed``.
    2. Create the environment with ``envs.create_env`` from ``env_name``,
       ``task_config``, ``num_blocks``, ``random_goal``, ``maze_type``,
       ``continuous_action_space`` and ``goal_threshold``.
    3. Fetch the per-environment parameters with ``envs.get_env_params`` and
       override them with the architecture / buffer arguments.
    4. Obtain the (possibly wrapped) env, policy, goal selector, both buffers
       and the base algorithm kwargs from ``variants.get_params``.
    5. If ``run_path`` is set, restore ``checkpoint/{policy_name}.h5`` and
       ``checkpoint/{goal_selector_name}.h5`` from that wandb run and load
       them into the policy and the goal selector.
    6. Construct ``huge.HUGE`` with all remaining options and call ``train()``.

    Every keyword argument is a hyper-parameter that is forwarded unchanged
    to one of the calls above; see those callees for the exact meaning.

    Arguments accepted but not used inside this function:
    ``load_buffer``, ``exploration_horizon`` and anything collected in
    ``**extra_params`` (the launcher passes e.g. ``method`` and ``comment``).

    NOTE(review): the default ``max_timesteps=2e-4`` is smaller than one
    step and looks like a typo (``2e4``?). It is left untouched here so that
    the signature stays backward compatible -- confirm and fix separately.

    Returns:
        None. The function runs for its training side effects only.
    """

    import gym
    import numpy as np

    # rlkit keeps its own device state, so it is configured separately from
    # rlutil below. A distinct alias avoids silently shadowing this module
    # with the rlutil import of the same name.
    import rlkit.torch.pytorch_util as rlkit_ptu
    rlkit_ptu.set_gpu_mode(True, gpu)

    import rlutil.torch as torch
    import rlutil.torch.pytorch_util as ptu

    # Envs

    from huge import envs
    from huge.envs.env_utils import DiscretizedActionEnv

    # Algo
    from huge.algo import buffer, variants, networks

    ptu.set_gpu(gpu)
    # NOTE(review): `gpu` is used as a device index elsewhere in this
    # function, so this message is also printed for device 0 -- confirm
    # whether that is intended.
    if not gpu:
        print('Not using GPU. Will be slow.')

    torch.manual_seed(seed)
    np.random.seed(seed)

    env = envs.create_env(env_name, task_config, num_blocks, random_goal, maze_type, continuous_action_space, goal_threshold)

    # Start from the environment defaults and override them with the
    # values requested by the caller.
    env_params = envs.get_env_params(env_name)
    env_params['max_trajectory_length'] = max_path_length
    env_params['network_layers'] = network_layers
    env_params['reward_layers'] = reward_layers
    env_params['buffer_size'] = buffer_size
    env_params['use_horizon'] = use_horizon
    env_params['fourier'] = fourier
    env_params['fourier_goal_selector'] = fourier_goal_selector
    env_params['normalize'] = normalize
    env_params['env_name'] = env_name
    env_params['goal_selector_buffer_size'] = goal_selector_buffer_size

    print(env_params)
    # These two keys are added after the print above, so they do not show up
    # in the logged dictionary.
    env_params['goal_selector_name'] = goal_selector_name
    env_params['continuous_action_space'] = continuous_action_space
    env, policy, goal_selector, replay_buffer, goal_selector_buffer, gcsl_kwargs = variants.get_params(env, env_params)

    # Optionally warm-start from checkpoints stored in a previous wandb run.
    # A truthiness test skips the restore for both '' and None (e.g. a YAML
    # `null`), instead of calling wandb.restore with run_path=None.
    if run_path:
        # NOTE(review): map_location hard-codes a CUDA device, so restoring
        # presumably requires a GPU -- confirm before running on CPU.
        device = f"cuda:{gpu}"
        expert_policy = wandb.restore(f"checkpoint/{policy_name}.h5", run_path=run_path)
        policy.load_state_dict(torch.load(expert_policy.name, map_location=device))
        expert_goal_selector = wandb.restore(f"checkpoint/{goal_selector_name}.h5", run_path=run_path)
        goal_selector.load_state_dict(torch.load(expert_goal_selector.name, map_location=device))

    # Override the base algorithm kwargs returned by variants.get_params.
    gcsl_kwargs['lr'] = lr
    gcsl_kwargs['max_timesteps'] = max_timesteps
    gcsl_kwargs['batch_size'] = batch_size
    gcsl_kwargs['max_path_length'] = max_path_length
    gcsl_kwargs['policy_updates_per_step'] = policy_updates_per_step
    gcsl_kwargs['explore_episodes'] = explore_episodes
    gcsl_kwargs['eval_episodes'] = eval_episodes
    gcsl_kwargs['eval_freq'] = eval_freq
    gcsl_kwargs['remove_last_k_steps'] = remove_last_k_steps
    gcsl_kwargs['select_last_k_steps'] = select_last_k_steps
    gcsl_kwargs['continuous_action_space'] = continuous_action_space
    gcsl_kwargs['expl_noise_std'] = expl_noise_std
    gcsl_kwargs['check_if_stopped'] = check_if_stopped
    gcsl_kwargs['num_demos'] = num_demos
    gcsl_kwargs['demo_epochs'] = demo_epochs
    gcsl_kwargs['demo_goal_selector_epochs'] = demo_goal_selector_epochs
    print(gcsl_kwargs)

    algo = huge.HUGE(
        env,
        policy,
        goal_selector,
        replay_buffer,
        goal_selector_buffer,
        log_tensorboard=log_tensorboard,
        train_with_preferences=train_with_preferences,
        use_oracle=use_oracle,
        save_buffer=save_buffer,
        train_regression=train_regression,
        load_goal_selector=load_goal_selector,
        sample_softmax=sample_softmax,
        display_plots=display_plots,
        fill_buffer_first_episodes=fill_buffer_first_episodes,
        render=render,
        data_folder=data_folder,
        goal_selector_num_samples=goal_selector_num_samples,
        train_goal_selector_freq=train_goal_selector_freq,
        remove_last_steps_when_stopped=remove_last_steps_when_stopped,
        exploration_when_stopped=exploration_when_stopped,
        distance_noise_std=distance_noise_std,
        save_videos=save_videos,
        human_input=human_input,
        epsilon_greedy_exploration=epsilon_greedy_exploration,
        epsilon_greedy_rollout=epsilon_greedy_rollout,
        explore_length=explore_length,
        stopped_thresh=stopped_thresh,
        weighted_sl=weighted_sl,
        sample_new_goal_freq=sample_new_goal_freq,
        k_goal=k_goal,
        frontier_expansion_freq=frontier_expansion_freq,
        frontier_expansion_rate=frontier_expansion_rate,
        start_frontier=start_frontier,
        select_goal_from_last_k_trajectories=select_goal_from_last_k_trajectories,
        throw_trajectories_not_reaching_goal=throw_trajectories_not_reaching_goal,
        command_goal_if_too_close=command_goal_if_too_close,
        display_trajectories_freq=display_trajectories_freq,
        label_from_last_k_steps=label_from_last_k_steps,
        label_from_last_k_trajectories=label_from_last_k_trajectories,
        contrastive=contrastive,
        deterministic_rollout=deterministic_rollout,
        repeat_previous_action_prob=repeat_previous_action_prob,
        desired_goal_sampling_freq=desired_goal_sampling_freq,
        goal_selector_batch_size=goal_selector_batch_size,
        goal_selector_epochs=goal_selector_epochs,
        use_wrong_oracle=use_wrong_oracle,
        human_data_file=human_data_file,
        stop_training_goal_selector_after=stop_training_goal_selector_after,
        select_best_sample_size=select_best_sample_size,
        pretrain_goal_selector=pretrain_goal_selector,
        pretrain_policy=pretrain_policy,
        env_name=env_name,
        **gcsl_kwargs
    )

    algo.train()


if __name__ == "__main__":
    # Command-line interface. Apart from the identifying options (seed, gpu,
    # env_name, comment, method) every option defaults to None, which means
    # "not given on the command line": such options are not copied into the
    # parameters below, so the value from config.yaml (or the default of
    # `run`) is kept.
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--gpu", type=int, default=0)
    parser.add_argument("--env_name", type=str, default='pointmass_empty')
    parser.add_argument("--comment", type=str, default='')
    parser.add_argument("--method", type=str, default='hugrl')
    parser.add_argument("--epsilon_greedy_rollout", type=float, default=None)
    parser.add_argument("--explore_episodes", type=int, default=None)
    parser.add_argument("--max_path_length", type=int, default=None)
    parser.add_argument("--repeat_previous_action_prob", type=float, default=None)
    parser.add_argument("--task_config", type=str, default=None)
    parser.add_argument("--train_goal_selector_freq", type=int, default=None)
    parser.add_argument("--goal_selector_num_samples", type=int, default=None)
    # The flags default to None rather than False: a False default is
    # "not None" and would always overwrite the value from config.yaml,
    # making it impossible to enable pretraining from the config file.
    parser.add_argument("--pretrain_policy", action="store_true", default=None)
    parser.add_argument("--pretrain_goal_selector", action="store_true", default=None)
    parser.add_argument("--num_demos", type=int, default=None)
    parser.add_argument("--demo_epochs", type=int, default=None)
    parser.add_argument("--lr", type=float, default=None)
    parser.add_argument("--stop_training_goal_selector_after", type=int, default=None)
    parser.add_argument("--goal_selector_buffer_size", type=int, default=None)
    parser.add_argument("--demo_goal_selector_epochs", type=int, default=None)
    parser.add_argument("--eval_freq", type=int, default=None)
    parser.add_argument("--max_timesteps", type=int, default=None)
    parser.add_argument("--start_frontier", type=int, default=None)
    parser.add_argument("--num_blocks", type=int, default=None)
    parser.add_argument("--desired_goal_sampling_freq", type=float, default=None)
    parser.add_argument("--human_data_file", type=str, default=None)
    parser.add_argument("--run_path", type=str, default=None)
    parser.add_argument("--policy_name", type=str, default=None)
    parser.add_argument("--fill_buffer_first_episodes", type=int, default=None)
    parser.add_argument("--goal_selector_name", type=str, default=None)

    args = parser.parse_args()

    with open("config.yaml") as file:
        config = yaml.safe_load(file)

    # Layer the configuration: common settings first, then the section for
    # the chosen method, then the section for the chosen environment.
    params = config["common"]
    params.update(config[args.method])
    params.update(config[args.env_name])

    # Explicit command-line values take precedence over config.yaml.
    for key, value in vars(args).items():
        if value is not None:
            params[key] = value

    # The double underscore is kept on purpose so that folder names stay the
    # same as those produced by earlier runs.
    params["data_folder"] = f"{args.env_name}__use_oracle_{args.seed}"

    wandb.init(project=args.env_name+"gcsl_preferences", name=f"{args.env_name}_{args.method}_{args.seed}_{args.comment}", config=params)

    print("params before run", params)

    run(**params)
    # dd_utils.launch(run, params, mode='local', instance_type='c4.xlarge')
