"""
Student policy networks
"""


import torch
import torch.nn.functional as F
from torch import nn

from ...envs.babyai import relative_coords_from_grid
from .. import utils
from . import language


class Student(nn.Module):
    """Student policy: maps an observation (and optionally a goal) to an action.

    Pipeline: each grid cell is embedded (object / color / contains embeddings
    plus a goal-indicator channel), encoded by a stack of strided convs,
    concatenated with embeddings of the carried object and color (and, when
    language goals are enabled, an encoding of the goal instruction), passed
    through a two-layer MLP and optionally an LSTM, and finally mapped to
    policy logits and value baseline(s).
    """

    def __init__(
        self,
        observation_shape,
        num_actions,
        FLAGS,
        lang=language.LANG,
        lang_len=language.LANG_LEN,
        lang_templates=language.INSTR_TEMPLATES,
        vocab=language.VOCAB,
        use_intrinsic_rewards=None,
    ):
        """Build all sub-modules.

        Args:
            observation_shape: stored as ``self.observation_shape``; not otherwise
                used by this class.
            num_actions: size of the discrete action space (policy head width).
            FLAGS: config object. This class reads ``generator``, ``goal_dim``,
                ``language_goals``, ``state_embedding_dim``, ``use_lstm``,
                ``num_lstm_layers``, ``partial_obs`` and ``int.twoheaded``.
            lang, lang_len: lookup tables indexed by goal id in ``embed_goal``.
            lang_templates: stored as ``self.lang_templates``; not used here.
            vocab: passed to the language encoder when language goals are on.
            use_intrinsic_rewards: whether goals condition the policy; defaults
                to ``FLAGS.generator`` when None.
        """
        super().__init__()
        self.observation_shape = observation_shape
        self.num_actions = num_actions
        self.FLAGS = FLAGS

        if use_intrinsic_rewards is None:
            use_intrinsic_rewards = FLAGS.generator
        self.use_intrinsic_rewards = use_intrinsic_rewards

        self.use_index_select = True
        # Per-cell embedding widths for object type, color and "contains".
        self.obj_dim = 5
        self.col_dim = 3
        self.con_dim = 2
        self.goal_dim = self.FLAGS.goal_dim
        # Conv input width: the three attribute embeddings plus one goal
        # channel (assumes one frame channel per attribute — see
        # create_embeddings).
        self.num_channels = self.obj_dim + self.col_dim + self.con_dim + 1

        self.embed_object = nn.Embedding(11, self.obj_dim)
        self.embed_color = nn.Embedding(6, self.col_dim)
        self.embed_contains = nn.Embedding(4, self.con_dim)

        self.lang = lang
        self.lang_len = lang_len
        self.lang_templates = lang_templates
        self.vocab = vocab

        def relu_init(module):
            # utils.init is given orthogonal_ for weights, a zero constant for
            # biases, and the ReLU gain.
            return utils.init(
                module,
                nn.init.orthogonal_,
                lambda x: nn.init.constant_(x, 0),
                nn.init.calculate_gain("relu"),
            )

        def head_init(module):
            # Same as relu_init but without an explicit gain (output heads).
            return utils.init(
                module, nn.init.orthogonal_, lambda x: nn.init.constant_(x, 0)
            )

        # Five 3x3 stride-2 convs (padding 1): each maps a side of length n to
        # ceil(n / 2). Built in the same order as a hand-written Sequential so
        # parameter names and initialization are unchanged.
        conv_channels = 32
        conv_layers = []
        in_channels = self.num_channels
        for _ in range(5):
            conv_layers.append(
                relu_init(
                    nn.Conv2d(
                        in_channels=in_channels,
                        out_channels=conv_channels,
                        kernel_size=(3, 3),
                        stride=2,
                        padding=1,
                    )
                )
            )
            conv_layers.append(nn.ELU())
            in_channels = conv_channels
        self.feat_extract = nn.Sequential(*conv_layers)

        # Flattened conv features + carried object/color embeddings. This sizing
        # assumes the conv stack reduces the grid to 1x1, i.e. H, W <= 32.
        fc_input_dim = conv_channels + self.obj_dim + self.col_dim
        if self.FLAGS.language_goals is not None:
            fc_input_dim += self.goal_dim

        state_dim = self.FLAGS.state_embedding_dim
        self.fc = nn.Sequential(
            relu_init(nn.Linear(fc_input_dim, state_dim)),
            nn.ReLU(),
            relu_init(nn.Linear(state_dim, state_dim)),
            nn.ReLU(),
        )

        if self.FLAGS.use_lstm:
            self.core = nn.LSTM(state_dim, state_dim, self.FLAGS.num_lstm_layers)

        if self.FLAGS.language_goals is not None:
            self.language_encoder = language.LanguageEncoder(
                self.vocab, embedding_dim=64, hidden_dim=self.goal_dim
            )

        self.policy = head_init(nn.Linear(state_dim, self.num_actions))
        self.baseline = head_init(nn.Linear(state_dim, 1))
        if self.FLAGS.int.twoheaded:
            # Separate value head used when FLAGS.int.twoheaded is set.
            self.int_baseline = head_init(nn.Linear(state_dim, 1))

    def initial_state(self, batch_size):
        """Return a zero LSTM state ``(h, c)`` for ``batch_size`` sequences.

        Each tensor is [num_layers x batch_size x hidden_size] on the CPU.
        Returns an empty tuple when the model has no LSTM.
        """
        if not self.FLAGS.use_lstm:
            return tuple()
        return tuple(
            torch.zeros(self.core.num_layers, batch_size, self.core.hidden_size)
            for _ in range(2)
        )

    def create_embeddings(self, x, id):
        """Embed one attribute of every grid cell.

        ``id`` selects the attribute: 0 -> ``embed_object``, 1 -> ``embed_color``,
        2 -> ``embed_contains``. Channels ``id, id + 3, ...`` of ``x`` are
        embedded (presumably the frame stores interleaved attribute triplets —
        confirm against the env), and the per-channel embeddings are flattened
        into the last dimension: [N x H x W x C] -> [N x H x W x (k * emb_dim)],
        where k is the number of sliced channels.

        The parameter keeps the name ``id`` (shadowing the builtin) so existing
        keyword callers keep working.

        Raises:
            ValueError: if ``id`` is not 0, 1 or 2.
        """
        embedders = (self.embed_object, self.embed_color, self.embed_contains)
        if id not in (0, 1, 2):
            raise ValueError(f"create_embeddings: id must be 0, 1 or 2, got {id!r}")
        objects_emb = self._select(embedders[id], x[:, :, :, id::3])
        return torch.flatten(objects_emb, 3, 4)

    def _select(self, embed, x):
        """Look up rows of ``embed`` for the integer indices in ``x`` (any shape).

        With ``use_index_select`` this gathers directly from ``embed.weight`` via
        ``index_select`` on the flattened indices and reshapes the result to
        ``x.shape + (emb_dim,)``; otherwise it calls the embedding module.
        """
        if self.use_index_select:
            out = embed.weight.index_select(0, x.reshape(-1))
            # handle reshaping x to 1-d and output back to N-d
            return out.reshape(x.shape + (-1,))
        else:
            return embed(x)

    def add_partial_xy_goals(self, goal_channel, goal, x, x_full):
        """Mark full-grid (x, y) goals in the partial-view goal channel, in place.

        For each flattened sample ``i``, ``goal[i]`` is decoded with
        ``divmod(goal[i], x_full.shape[-2])`` into absolute coordinates, mapped
        into the partial view ``x[i]`` with ``relative_coords_from_grid``, and
        written to ``goal_channel[i]`` only when it falls inside the view and
        the cell's first channel is non-zero (used as the visibility test).
        """
        n_col = x_full.shape[-2]
        for i in range(goal.shape[0]):
            abs_goal_x, abs_goal_y = divmod(goal[i].item(), n_col)
            # Map absolute coords to partial coords (None if not in view).
            rel_coords = relative_coords_from_grid(
                abs_goal_x,
                abs_goal_y,
                x_full[i],
                x[i],
            )
            if rel_coords is None:
                continue
            rel_goal_x, rel_goal_y = rel_coords
            # Check that the frame is visible at the goal cell.
            if x[i, rel_goal_x, rel_goal_y, 0] != 0:
                goal_channel[i, rel_goal_x, rel_goal_y] = 1.0

    def _make_goal_channel(self, inputs, x, goal, use_goal):
        """Return a float [N x H x W x 1] channel marking the (x, y) goal cell.

        ``x`` is the flattened [N x H x W x C] observation. The channel stays
        all zeros unless ``use_goal`` is set and goals are (x, y) cells rather
        than language instructions (those are embedded separately in forward).
        """
        n, h, w = x.shape[:3]
        goal_channel = torch.zeros((n, h, w, 1), device=x.device)
        if not use_goal or self.FLAGS.language_goals is not None:
            return goal_channel

        if self.FLAGS.partial_obs:
            # Goals index the full frame; map them into the partial view.
            x_full = torch.flatten(inputs["frame"], 0, 1)
            self.add_partial_xy_goals(goal_channel, goal, x, x_full)
        else:
            # goal[i] is a flattened (h_index * w + w_index) cell index: set all
            # N indicator cells in one vectorized write instead of N small ones.
            cells = goal.reshape(-1).to(x.device)
            rows = torch.arange(cells.shape[0], device=x.device)
            goal_channel.view(n, h * w)[rows, cells] = 1.0
        return goal_channel

    def _unroll_core(self, core_input, core_state, done, T, B):
        """Step the LSTM over T timesteps, zeroing state where ``done`` is set.

        Args:
            core_input: [T*B x D] MLP features.
            core_state: ``(h, c)`` LSTM state; empty or None means start from
                zeros (the ``forward`` default of ``()`` would otherwise crash
                inside ``nn.LSTM``).
            done: [T x B] boolean tensor.

        Returns:
            ([T*B x D] LSTM outputs, final ``(h, c)`` state).
        """
        if not core_state:
            core_state = tuple(s.to(core_input) for s in self.initial_state(B))
        notdone = (~done).float()
        step_outputs = []
        for step_input, nd in zip(core_input.view(T, B, -1).unbind(), notdone.unbind()):
            # Multiply the state by 0 for sequences flagged done at this step.
            nd = nd.view(1, -1, 1)
            core_state = tuple(nd * s for s in core_state)
            output, core_state = self.core(step_input.unsqueeze(0), core_state)
            step_outputs.append(output)
        return torch.flatten(torch.cat(step_outputs), 0, 1), core_state

    def forward(self, inputs, core_state=(), goal=None):
        """Run the policy over an unroll and return an action per step.

        Args:
            inputs: dict of tensors with leading [T, B] dims. Reads ``"frame"``
                (or ``"partial_frame"`` when ``FLAGS.partial_obs``; ``"frame"`` is
                then also read for (x, y) goals), ``"carried_obj"``,
                ``"carried_col"``, and ``"done"`` when ``FLAGS.use_lstm``.
            core_state: LSTM state, e.g. from ``initial_state``; ignored without
                an LSTM. An empty tuple starts from a zero state.
            goal: [T, B] goal ids, or None for "no goal" (zero goal features).
                For (x, y) goals an id is the flattened cell index
                ``h_index * width + w_index`` of the full frame; for language
                goals it indexes ``self.lang``.

        Returns:
            ``(output, core_state)``. ``output`` holds ``policy_logits``
            [T, B, num_actions], ``baseline`` [T, B], ``action`` [T, B] (sampled
            in training mode, argmax otherwise) and, with ``FLAGS.int.twoheaded``,
            ``int_baseline`` [T, B].
        """
        # -- [unroll_length x batch_size x height x width x channels]
        frame = inputs["partial_frame"] if self.FLAGS.partial_obs else inputs["frame"]
        T, B, h, w, *_ = frame.shape
        n = T * B

        # -- [T*B x height x width x channels]: merge time and batch.
        x = torch.flatten(frame, 0, 1)
        if goal is not None:
            goal = torch.flatten(goal, 0, 1).long()
        # Goals only condition the policy when intrinsic rewards are enabled
        # and a goal was actually supplied.
        use_goal = bool(self.use_intrinsic_rewards) and goal is not None

        goal_channel = self._make_goal_channel(inputs, x, goal, use_goal)

        carried_obj = inputs["carried_obj"].long()
        carried_col = inputs["carried_col"].long()
        x = x.long()

        # -- [T*B x H x W x K]: per-cell attribute embeddings + goal channel.
        x = torch.cat(
            [
                self.create_embeddings(x, 0),
                self.create_embeddings(x, 1),
                self.create_embeddings(x, 2),
                goal_channel,
            ],
            dim=3,
        )

        # Channels-first for the convs. transpose(1, 3) also swaps the two
        # spatial axes, which is harmless as long as it is done consistently.
        x = self.feat_extract(x.transpose(1, 3)).view(n, -1)
        carried_obj_emb = self._select(self.embed_object, carried_obj).view(n, -1)
        carried_col_emb = self._select(self.embed_color, carried_col).view(n, -1)
        union = torch.cat([x, carried_obj_emb, carried_col_emb], dim=1)

        if self.FLAGS.language_goals is not None:
            if use_goal:
                goal_emb = self.embed_goal(goal)
            else:  # Ignore goal: feed zeros of the same width.
                goal_emb = union.new_zeros((n, self.goal_dim))
            union = torch.cat([union, goal_emb], dim=1)

        core_input = self.fc(union)

        if self.FLAGS.use_lstm:
            core_output, core_state = self._unroll_core(
                core_input, core_state, inputs["done"], T, B
            )
        else:
            core_output = core_input
            core_state = tuple()

        policy_logits = self.policy(core_output)
        baseline = self.baseline(core_output)

        if self.training:
            action = torch.multinomial(F.softmax(policy_logits, dim=1), num_samples=1)
        else:
            action = torch.argmax(policy_logits, dim=1)

        output = dict(
            policy_logits=policy_logits.view(T, B, self.num_actions),
            baseline=baseline.view(T, B),
            action=action.view(T, B),
        )
        if self.FLAGS.int.twoheaded:
            output["int_baseline"] = self.int_baseline(core_output).view(T, B)

        return output, core_state

    def embed_goal(self, goal):
        """Encode language goals with ``self.language_encoder``.

        ``goal`` indexes ``self.lang`` and ``self.lang_len`` (presumably each
        instruction's token ids and length — confirm in the ``language``
        module); the looked-up rows are moved to ``goal``'s device first.
        """
        lang = self.lang[goal].to(goal.device)
        lang_len = self.lang_len[goal].to(goal.device)
        return self.language_encoder(lang, lang_len)
