import copy

import torch

from .. import buffers as B
from . import babyai, minihack, rnd


def create_models_and_buffers(env, FLAGS):
    """Build actor/learner model pairs and the rollout buffers for ``env``.

    The environment family is chosen from ``FLAGS.env``:

    * names starting with ``"BabyAI"`` use ``babyai.create_models``. The
      learner models are deep copies of the actor models, moved to
      ``FLAGS.device``. Both observation buffers are ``torch.uint8``.
    * any other name goes through ``minihack.create_models``. The actor pair
      is built on ``"cpu"`` and a second, separately constructed pair is
      built for ``FLAGS.device``. Observation specs are taken from
      ``env.observation_space.spaces``, followed by fixed-size ``int64``
      message buffers.

    Args:
        env: Environment instance whose observation space(s) determine the
            observation buffer specs.
        FLAGS: Config object. ``env`` and ``device`` are read here, and the
            whole object is forwarded to the model factories and to
            ``B.create_buffers``.

    Returns:
        ``(model, generator_model, learner_model, learner_generator_model,
        buffers)``.
    """
    if FLAGS.env.startswith("BabyAI"):
        model, generator_model = babyai.create_models(env, FLAGS)
        num_instrs = len(babyai.language.INSTRS)
        obs_spaces = {
            name: (space.shape, torch.uint8)
            for name, space in (
                ("frame", env.observation_space),
                ("partial_frame", env.partial_observation_space),
            )
        }
        # Learner copies start from the actor's exact parameters but live on
        # the training device.
        learner_model, learner_generator_model = (
            copy.deepcopy(net).to(FLAGS.device) for net in (model, generator_model)
        )
    else:
        model, generator_model = minihack.create_models(env, FLAGS, device="cpu")
        obs_spaces = {
            key: (space.shape, space.dtype)
            for key, space in env.observation_space.spaces.items()
        }
        num_instrs = generator_model.num_actions

        # Fixed-shape message buffers, appended after the env-provided entries
        # (insertion order is kept as-is).
        message_specs = (
            ("message", (25,)),
            ("message_len", (1,)),
            ("split_messages", (5, 25)),
            ("split_messages_len", (5, 1)),
        )
        for key, shape in message_specs:
            obs_spaces[key] = (shape, torch.int64)

        # NOTE(review): the learner pair comes from a separate factory call
        # rather than a deep copy, so its initial weights are presumably not
        # identical to the actor's. Confirm that callers sync them.
        learner_model, learner_generator_model = minihack.create_models(
            env, FLAGS, device=FLAGS.device
        )
        # NOTE(review): possibly redundant given device=FLAGS.device above;
        # kept in case the factory does not move the modules itself.
        learner_model, learner_generator_model = (
            learner_model.to(FLAGS.device),
            learner_generator_model.to(FLAGS.device),
        )

    buffers = B.create_buffers(
        obs_spaces,
        model.num_actions,
        generator_model.logits_size,
        generator_model.raw_goal_size,
        num_instrs,
        FLAGS,
    )
    return model, generator_model, learner_model, learner_generator_model, buffers


def create_rnd_model(env, FLAGS):
    """Build the RND novelty model for ``env``.

    For non-BabyAI environments (selected via ``FLAGS.env``) this delegates
    to ``minihack.create_rnd_model``. For BabyAI it wraps zero-argument
    factories in ``rnd.RNDModel`` and moves the result to ``FLAGS.device``.
    When ``FLAGS.separate_message_novelty`` is set, a second factory for a
    message-specific RND model is passed; otherwise ``None`` is passed.

    Args:
        env: Environment instance forwarded to the model factories.
        FLAGS: Config object. ``env``, ``separate_message_novelty`` (BabyAI
            only) and ``device`` (BabyAI only) are read here.

    Returns:
        The constructed RND model.
    """
    if not FLAGS.env.startswith("BabyAI"):
        # NOTE(review): unlike the BabyAI path, no explicit .to(FLAGS.device)
        # is applied here. Presumably minihack.create_rnd_model handles device
        # placement; confirm.
        return minihack.create_rnd_model(env, FLAGS)

    # RNDModel receives factories rather than built networks, so construction
    # is deferred to it.
    def rnd_model_fn():
        return babyai.create_rnd_model(env, FLAGS)

    message_rnd_model_fn = None
    if FLAGS.separate_message_novelty:

        def message_rnd_model_fn():
            return babyai.create_message_rnd_model(env, FLAGS)

    wrapped = rnd.RNDModel(rnd_model_fn, message_rnd_model_fn, FLAGS)
    return wrapped.to(FLAGS.device)
