import numpy as np
import torch
import torch.nn as nn


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Orthogonally initialise a layer's weight and fill its bias with a constant.

    Args:
        layer: Module exposing ``weight`` and ``bias`` attributes
            (e.g. ``nn.Linear`` or ``nn.Conv2d``). ``bias`` may be ``None``
            for layers constructed with ``bias=False``; in that case only the
            weight is initialised.
        std: Gain passed to ``nn.init.orthogonal_``.
        bias_const: Value written into every bias entry.

    Returns:
        The same ``layer`` object, modified in place, so calls can be nested
        inside ``nn.Sequential(...)``.
    """
    nn.init.orthogonal_(layer.weight, std)
    # Layers built with bias=False have ``bias is None``; constant_ would fail on it.
    if layer.bias is not None:
        nn.init.constant_(layer.bias, bias_const)
    return layer


def get_representation_net_for_minigrid(space, embedding_size=64):
    """Build an image encoder sized for a MiniGrid observation space.

    ``space.shape`` is read as ``(c, n, m)``: the first axis is used as the
    convolution input-channel count and the other two as spatial axes. The
    architecture is picked so every convolution still sees a non-empty map:

    * ``n <= 1`` or ``m <= 1``: no convolutions; flatten, then linear.
    * ``n <= 3`` or ``m <= 3``: three 2x2 convolutions, the first zero-padded by 1.
    * ``n <= 9`` or ``m <= 9``: three unpadded 2x2 convolutions.
    * otherwise: as above, with a 2x2 max-pool after the first convolution.

    Args:
        space: Observation space with a 3-element ``shape`` ``(c, n, m)``.
        embedding_size: Width of the output embedding.

    Returns:
        ``(module, dtype, embedding_size)``. ``module`` maps a batch shaped
        ``(B, c, n, m)`` to ``(B, embedding_size)``; ``dtype`` is
        ``torch.float32``.
    """
    c, n, m = space.shape[0], space.shape[1], space.shape[2]

    if n <= 1 or m <= 1:
        # Too small for even one 2x2 kernel: feed the raw pixels to a linear layer.
        image_conv = nn.Sequential(
            nn.Flatten(),
            layer_init(nn.Linear(c * n * m, embedding_size)),
            nn.ReLU(),
        )
    elif n <= 3 or m <= 3:
        # Per spatial axis: padded 2x2 conv -> s+1, then 2x2 -> s, then 2x2 -> s-1.
        image_embedding_size = (n - 1) * (m - 1) * 64
        image_conv = nn.Sequential(
            layer_init(nn.Conv2d(c, 32, (2, 2), stride=1, padding=1)),
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, (2, 2), stride=1, padding=0)),
            nn.ReLU(),
            layer_init(nn.Conv2d(64, 64, (2, 2), stride=1, padding=0)),
            nn.ReLU(),
            nn.Flatten(),
            layer_init(nn.Linear(image_embedding_size, embedding_size)),
            nn.ReLU(),
        )
    elif n <= 9 or m <= 9:
        # Three unpadded 2x2 convs each shrink a spatial axis by one.
        image_embedding_size = (n - 3) * (m - 3) * 64
        image_conv = nn.Sequential(
            layer_init(nn.Conv2d(c, 32, (2, 2), stride=1, padding=0)),
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, (2, 2), stride=1, padding=0)),
            nn.ReLU(),
            layer_init(nn.Conv2d(64, 64, (2, 2), stride=1, padding=0)),
            nn.ReLU(),
            nn.Flatten(),
            layer_init(nn.Linear(image_embedding_size, embedding_size)),
            nn.ReLU(),
        )
    else:
        # conv (s-1) -> 2x2 max-pool (floor halve) -> conv (-1) -> conv (-1).
        image_embedding_size = ((n - 1) // 2 - 2) * ((m - 1) // 2 - 2) * 64
        image_conv = nn.Sequential(
            layer_init(nn.Conv2d(c, 32, (2, 2), stride=1, padding=0)),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),
            layer_init(nn.Conv2d(32, 64, (2, 2), stride=1, padding=0)),
            nn.ReLU(),
            layer_init(nn.Conv2d(64, 64, (2, 2), stride=1, padding=0)),
            nn.ReLU(),
            nn.Flatten(),
            layer_init(nn.Linear(image_embedding_size, embedding_size)),
            nn.ReLU(),
        )

    return image_conv, torch.float32, embedding_size


class _NormalizationLayer(nn.Module):
    """Divide the input by 255 (maps 8-bit pixel values into ``[0, 1]``).

    Defined at module level (not inside the builder) so that models containing
    it can be pickled, e.g. with ``torch.save(model)``.
    """

    def forward(self, x):
        return x / 255.0


def get_representation_net_for_miniworld(space, embedding_size=64):
    """Build a Nature-DQN-style image encoder for a MiniWorld observation space.

    ``space.shape`` is read as ``(c, n, m)``: ``c`` is the convolution
    input-channel count and ``n``/``m`` the spatial axes. Inputs are divided by
    255 before an 8x8/stride-4, 4x4/stride-2, 3x3/stride-1 convolution stack.

    Args:
        space: Observation space with a 3-element ``shape`` ``(c, n, m)``.
        embedding_size: Width of the output embedding.

    Returns:
        ``(module, dtype, embedding_size)``. ``module`` maps a batch shaped
        ``(B, c, n, m)`` to ``(B, embedding_size)``; ``dtype`` is
        ``torch.float32``.
    """
    c, n, m = space.shape[0], space.shape[1], space.shape[2]

    def _conv_stack_out(size):
        # Unpadded conv output: (size - kernel) // stride + 1, for each layer.
        size = (size - 8) // 4 + 1
        size = (size - 4) // 2 + 1
        return size - 3 + 1

    image_embedding_size = _conv_stack_out(n) * _conv_stack_out(m) * 64
    image_conv = nn.Sequential(
        _NormalizationLayer(),
        layer_init(nn.Conv2d(c, 32, (8, 8), stride=4, padding=0)),
        nn.ReLU(),
        layer_init(nn.Conv2d(32, 64, (4, 4), stride=2, padding=0)),
        nn.ReLU(),
        layer_init(nn.Conv2d(64, 64, (3, 3), stride=1, padding=0)),
        nn.ReLU(),
        nn.Flatten(),
        layer_init(nn.Linear(image_embedding_size, embedding_size)),
        nn.ReLU(),
    )

    return image_conv, torch.float32, embedding_size
