from re import I
from huge.algo import huge_human
from numpy import VisibleDeprecationWarning
import doodad as dd
import huge.doodad_utils as dd_utils
import argparse
import wandb
import io
import imageio
from imageio import v3 as iio
from fastapi import Response, FastAPI, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware

# ASGI application object; the route handlers in this module are registered on it.
app = FastAPI()

# CORS is configured fully open: any origin, method and header is accepted and
# credentials are allowed.
# NOTE(review): presumably needed so a browser front end served from another
# origin can call this server — confirm before exposing it beyond a trusted network.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Count of requests handled by `get_image` so far. The handler rebinds this
# name through its own `global start` declaration; no `global` statement is
# needed (or meaningful) here at module scope.
start = 0

@app.get("/image", response_class=Response)
async def get_image(answer=None, background_tasks: BackgroundTasks = None):
    """Handle one labelling request and reply with the next image as a PNG.

    Every call increments the module-level ``start`` counter.

    * On the very first call (``start == 1``) the handler only schedules
      ``algo.train`` as a background task and returns ``None``; no image is
      produced for that request.
    * On every later call ``answer`` is converted to a label, handed to
      ``algo.add_point_and_fetch_case`` and the image it returns is encoded
      as PNG. Whenever ``start`` is a multiple of
      ``algo.train_goal_selector_freq`` the handler also calls
      ``algo.collect_and_train_goal_selector()`` before replying.

    ``algo`` is the module-level object assigned by ``run``.

    Args:
        answer: ``"right"`` maps to label ``1``, ``"left"`` maps to label
            ``0``; any other value (including the default ``None``) maps to
            label ``None``.
        background_tasks: task queue used to schedule ``algo.train`` on the
            first request.

    Returns:
        ``None`` for the first request, otherwise a ``Response`` with media
        type ``image/png`` carrying the encoded image bytes.
    """
    global start

    print("answer", answer)
    start += 1

    # First request: kick off training in the background and stop here.
    if start == 1:
        background_tasks.add_task(algo.train)
        return None

    label = 1 if answer == "right" else 0 if answer == "left" else None
    print("label", label)

    im = algo.add_point_and_fetch_case(label)

    # Periodically collect data and retrain the goal selector.
    if start % algo.train_goal_selector_freq == 0:
        algo.collect_and_train_goal_selector()

    # Encode the image in memory; the buffer is closed once the bytes are copied out.
    with io.BytesIO() as png_buffer:
        iio.imwrite(png_buffer, im, plugin="pillow", format="PNG")
        png_bytes = png_buffer.getvalue()

    response_headers = {
        'Content-Disposition': 'inline; filename="robot.png"',
        "CrossOrigin": "Anonymous",
    }
    return Response(png_bytes, headers=response_headers, media_type='image/png')
    

# TODO: We can use environment variables
def run(
    output_dir='/tmp',
    start_frontier=-1,
    frontier_expansion_rate=10,
    frontier_expansion_freq=-1,
    select_goal_from_last_k_trajectories=-1,
    throw_trajectories_not_reaching_goal=False,
    repeat_previous_action_prob=0.8,
    greedy_before_stopping=False,
    reward_layers="600,600",
    fourier=False,
    fourier_goal_selector=False,
    command_goal_if_too_close=False,
    display_trajectories_freq=20,
    label_from_last_k_steps=-1,
    label_from_last_k_trajectories=-1,
    contrastive=False,
    pick_or_place=False,
    k_goal=1,
    use_horizon=False,
    sample_new_goal_freq=1,
    weighted_sl=False,
    buffer_size=20000,
    stopped_thresh=0.05,
    eval_episodes=200,
    maze_type=0,
    random_goal=False,
    explore_length=20,
    desired_goal_sampling_freq=0.0,
    num_blocks=1,
    deterministic_rollout=False,
    train_policy_freq=10,
    network_layers="128,128",
    epsilon_greedy_rollout=0,
    epsilon_greedy_exploration=0.2,
    remove_last_k_steps=8,
    select_last_k_steps=8,
    eval_freq=5e3,
    expl_noise_mean=0,
    expl_noise_std=1,
    goal_selector_epochs=400,
    stop_training_goal_selector_after=-1,
    no_training_goal_selector=False,
    normalize=False,
    set_desired_when_stopped=True,
    task_config="slide_cabinet,microwave",
    human_input=False,
    logger_dump=False,
    save_videos=True,
    continuous_action_space=False,
    goal_selector_batch_size=64,
    goal_threshold=-1,
    check_if_stopped=False,
    not_save_videos=False,
    human_data_file='',
    env_name='pointmass_empty',
    train_goal_selector_freq=10,
    distance_noise_std=0,
    exploration_when_stopped=True,
    remove_last_steps_when_stopped=True,
    goal_selector_num_samples=100,
    data_folder="data",
    display_plots=False,
    render=False,
    explore_episodes=5,
    gpu=0,
    sample_softmax=False,
    seed=0,
    load_goal_selector=False,
    batch_size=100,
    train_regression=False,
    load_buffer=False,
    save_buffer=-1,
    policy_updates_per_step=1,
    select_best_sample_size=1000,
    max_path_length=50,
    lr=5e-4,
    train_with_preferences=True,
    start_policy_timesteps=500,
    log_tensorboard=False,
    use_oracle=False,
    exploration_horizon=30,
    use_wrong_oracle=False,
    comment="",
    max_timesteps=2e-4,
    goal_selector_name='',
    **kwargs
):
    """Build the environment and the ``huge_human.GCSL`` algorithm object.

    The function does not start training; it only constructs the algorithm,
    stores it in the module-level ``algo`` global and returns it.

    Steps, in order:
      1. Import the RL dependencies lazily and configure the GPU
         (``rlkit``'s ``set_gpu_mode(True, 0)`` followed by ``rlutil``'s
         ``set_gpu(gpu)``).
      2. Seed ``rlutil.torch`` and NumPy with ``seed``.
      3. Create the environment with ``envs.create_env`` from ``env_name``,
         ``task_config``, ``num_blocks``, ``random_goal``, ``maze_type``,
         ``pick_or_place``, ``continuous_action_space`` and ``goal_threshold``.
      4. Fetch ``envs.get_env_params(env_name)`` and override a set of entries
         with the corresponding arguments, then obtain the policy, goal
         selector, buffers and base algorithm kwargs from
         ``variants.get_params``.
      5. Override entries of those base kwargs with the corresponding
         arguments and instantiate ``huge_human.GCSL``.

    All remaining keyword parameters are forwarded unchanged to the
    ``GCSL`` constructor under the same name.

    Parameters that are accepted but never read inside this function:
    ``output_dir``, ``load_buffer``, ``select_best_sample_size``,
    ``start_policy_timesteps``, ``exploration_horizon``, ``comment`` and any
    extra ``**kwargs``.

    Returns:
        The constructed ``huge_human.GCSL`` instance (also bound to the
        module-level name ``algo``).
    """
    global algo

    # Imports are kept local and in this exact order; some of them are unused
    # by name but are retained because importing may have side effects.
    import gym
    import numpy as np
    from rlutil.logging import log_utils, logger

    import rlkit.torch.pytorch_util as rlkit_ptu
    rlkit_ptu.set_gpu_mode(True, 0)

    import rlutil.torch as torch
    import rlutil.torch.pytorch_util as ptu

    # Envs
    from huge import envs
    from huge.envs.env_utils import DiscretizedActionEnv

    # Algo
    from huge.algo import buffer, variants, networks

    ptu.set_gpu(gpu)
    if not gpu:
        print('Not using GPU. Will be slow.')

    torch.manual_seed(seed)
    np.random.seed(seed)

    env = envs.create_env(
        env_name,
        task_config,
        num_blocks,
        random_goal,
        maze_type,
        pick_or_place,
        continuous_action_space,
        goal_threshold,
    )

    env_params = envs.get_env_params(env_name)
    # Entries written before the parameters are printed.
    printed_overrides = (
        ('max_trajectory_length', max_path_length),
        ('network_layers', network_layers),
        ('reward_layers', reward_layers),
        ('buffer_size', buffer_size),
        ('use_horizon', use_horizon),
        ('fourier', fourier),
        ('pick_or_place', pick_or_place),
        ('fourier_goal_selector', fourier_goal_selector),
        ('normalize', normalize),
        ('env_name', env_name),
    )
    for param_key, param_value in printed_overrides:
        env_params[param_key] = param_value

    print(env_params)
    # These two entries are written only after the print above.
    env_params['goal_selector_name'] = goal_selector_name
    env_params['continuous_action_space'] = continuous_action_space

    (
        env,
        policy,
        goal_selector,
        replay_buffer,
        goal_selector_buffer,
        gcsl_kwargs,
    ) = variants.get_params(env, env_params)

    kwarg_overrides = (
        ('lr', lr),
        ('max_timesteps', max_timesteps),
        ('batch_size', batch_size),
        ('max_path_length', max_path_length),
        ('policy_updates_per_step', policy_updates_per_step),
        ('explore_episodes', explore_episodes),
        ('eval_episodes', eval_episodes),
        ('eval_freq', eval_freq),
        ('remove_last_k_steps', remove_last_k_steps),
        ('select_last_k_steps', select_last_k_steps),
        ('train_policy_freq', train_policy_freq),
        ('continuous_action_space', continuous_action_space),
        ('expl_noise_mean', expl_noise_mean),
        ('expl_noise_std', expl_noise_std),
        ('check_if_stopped', check_if_stopped),
    )
    for kwarg_key, kwarg_value in kwarg_overrides:
        gcsl_kwargs[kwarg_key] = kwarg_value
    print(gcsl_kwargs)

    # Keyword arguments passed straight through to the GCSL constructor.
    algo_options = dict(
        log_tensorboard=log_tensorboard,
        train_with_preferences=train_with_preferences,
        use_oracle=use_oracle,
        save_buffer=save_buffer,
        train_regression=train_regression,
        load_goal_selector=load_goal_selector,
        sample_softmax=sample_softmax,
        display_plots=display_plots,
        render=render,
        data_folder=data_folder,
        goal_selector_num_samples=goal_selector_num_samples,
        train_goal_selector_freq=train_goal_selector_freq,
        remove_last_steps_when_stopped=remove_last_steps_when_stopped,
        exploration_when_stopped=exploration_when_stopped,
        distance_noise_std=distance_noise_std,
        save_videos=save_videos,
        logger_dump=logger_dump,
        human_input=human_input,
        epsilon_greedy_exploration=epsilon_greedy_exploration,
        epsilon_greedy_rollout=epsilon_greedy_rollout,
        set_desired_when_stopped=set_desired_when_stopped,
        explore_length=explore_length,
        greedy_before_stopping=greedy_before_stopping,
        stopped_thresh=stopped_thresh,
        weighted_sl=weighted_sl,
        sample_new_goal_freq=sample_new_goal_freq,
        k_goal=k_goal,
        frontier_expansion_freq=frontier_expansion_freq,
        frontier_expansion_rate=frontier_expansion_rate,
        start_frontier=start_frontier,
        select_goal_from_last_k_trajectories=select_goal_from_last_k_trajectories,
        throw_trajectories_not_reaching_goal=throw_trajectories_not_reaching_goal,
        command_goal_if_too_close=command_goal_if_too_close,
        display_trajectories_freq=display_trajectories_freq,
        label_from_last_k_steps=label_from_last_k_steps,
        label_from_last_k_trajectories=label_from_last_k_trajectories,
        contrastive=contrastive,
        deterministic_rollout=deterministic_rollout,
        repeat_previous_action_prob=repeat_previous_action_prob,
        desired_goal_sampling_freq=desired_goal_sampling_freq,
        goal_selector_batch_size=goal_selector_batch_size,
        goal_selector_epochs=goal_selector_epochs,
        not_save_videos=not_save_videos,
        use_wrong_oracle=use_wrong_oracle,
        human_data_file=human_data_file,
        no_training_goal_selector=no_training_goal_selector,
        stop_training_goal_selector_after=stop_training_goal_selector_after,
    )

    algo = huge_human.GCSL(
        env,
        policy,
        goal_selector,
        replay_buffer,
        goal_selector_buffer,
        **algo_options,
        **gcsl_kwargs
    )

    # Only referenced by the disabled logger setup below.
    exp_prefix = 'example/%s/gcsl/' % (env_name,)
    #if logger_dump:
    #    log_utils.setup_logger(exp_prefix=exp_prefix, log_base_dir=output_dir)
    print("about to start training")
    # Training is intentionally not started here:
    #algo.train()
    #algo.startup()
    return algo

# TODO: Planned flow:
#   1. Create the algo and run its startup phase (random trajectories, etc.).
#   2. Collect some labels; after a given number of labels, proceed with
#      multiple rollouts.
#   3. Keep looping like this.

# TODO: use config here

#parser.add_argument("--start_hallucination",type=int, default=0)
# Environment used for this session; also used to name the wandb project/run.
env_name = "pointmass_rooms"

# Keyword arguments for `run`. Any `run` parameter not listed here keeps the
# default from `run`'s signature.
params = {
    'seed': 0,
    'env_name': env_name,  # 'pointmass_rooms', #['lunar', 'pointmass_empty','pointmass_rooms', 'pusher', 'claw', 'door'],
    'gpu': 0,
    # Not a named parameter of `run`; it is absorbed by `run`'s `**kwargs`,
    # which `run` never reads.
    'use_preferences': True,
    'log_tensorboard': True,  # args.log_tensorboard,
    'train_with_preferences': True,
    'use_oracle': False,
    'lr': 5e-4,
    'max_timesteps': 2000000,
    'batch_size': 64,
    'max_path_length': 70,
    'select_best_sample_size': 1000,
    'policy_updates_per_step': 100,
    'sample_softmax': True,
    'explore_episodes': 5,
    'render': False,
    'display_plots': True,
    'goal_selector_num_samples': 20,  # todo: it shouldn't matter
    'train_goal_selector_freq': 10,
    'remove_last_steps_when_stopped': True,
    'exploration_when_stopped': True,
    'eval_episodes': 5,
    'human_input': False,
    'epsilon_greedy_exploration': 0,
    'epsilon_greedy_rollout': 1,
    'eval_freq': 5e3,
    'explore_length': 20,
    'network_layers': "400,600,600,300",
    'random_goal': False,
    'maze_type': 3,
    'stopped_thresh': 0.05,
    'buffer_size': 1000,
    'sample_new_goal_freq': 5,
    'reward_layers': "400,600,600,300",
    'start_frontier': 40,
    'frontier_expansion_rate': 5,
    'frontier_expansion_freq': 50,
    'select_goal_from_last_k_trajectories': 100,
    'fourier': True,
    'fourier_goal_selector': True,
    'label_from_last_k_steps': 20,
    'label_from_last_k_trajectories': 50,
    'repeat_previous_action_prob': 0.25,
    'continuous_action_space': False,
    'desired_goal_sampling_freq': 0.1,
    'goal_threshold': 0.05,
    'goal_selector_epochs': 1000,
    'check_if_stopped': True,
    'human_data_file': "",
    'use_wrong_oracle': False,
    'no_training_goal_selector': False,
    'save_videos': False,
    'not_save_videos': True,
    #'start_hallucination': args.start_hallucination
}

# Start the wandb run at import time; the config is left empty.
wandb.init(
    project=env_name + "gcsl_preferences",
    name=f"{env_name}",
    config={},
)

# Build the algorithm at import time so the request handler can use it.
# A plain assignment at module scope already binds the module-level name, so
# no `global` statement is needed here.
algo = run(**params)
# dd_utils.launch(run, params, mode='local', instance_type='c4.xlarge')
