from email import header
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical, Beta
from torch.autograd import Function
from utils import reparameterize
import torchvision.models as models

# 9216 -> 9216
# Models for carracing games
class Encoder(nn.Module):
    """Convolutional VAE encoder used by the car-racing models.

    Four 4x4, stride-2 convolutions (32, 64, 128, 256 channels, each followed by
    ReLU) feed three linear heads: content mean, content log-sigma and class code.
    With this conv stack the default ``flatten_size=1024`` (256x2x2) corresponds
    to 64x64 inputs.

    Args:
        class_latent_size: width of the class-code head.
        content_latent_size: width of the content mu / logsigma heads.
        input_channel: number of channels in the input image.
        flatten_size: length of the flattened conv output; must match the input
            resolution.
    """

    def __init__(self, class_latent_size=8, content_latent_size=32, input_channel=3, flatten_size=1024):
        super().__init__()
        self.class_latent_size = class_latent_size
        self.content_latent_size = content_latent_size
        self.flatten_size = flatten_size

        channels = (input_channel, 32, 64, 128, 256)
        conv_layers = []
        for c_in, c_out in zip(channels, channels[1:]):
            conv_layers += [nn.Conv2d(c_in, c_out, 4, stride=2), nn.ReLU()]
        self.main = nn.Sequential(*conv_layers)

        self.linear_mu = nn.Linear(flatten_size, content_latent_size)
        self.linear_logsigma = nn.Linear(flatten_size, content_latent_size)
        self.linear_classcode = nn.Linear(flatten_size, class_latent_size)

    def forward(self, x):
        """Return ``(mu, logsigma, classcode)`` for the image batch ``x``."""
        h = self.main(x)
        h = h.view(h.size(0), -1)
        return self.linear_mu(h), self.linear_logsigma(h), self.linear_classcode(h)

    def get_feature(self, x):
        """Return only the content mean ``mu``."""
        return self.forward(x)[0]


class Decoder(nn.Module):
    """Transposed-conv decoder used by the car-racing VAE.

    The latent is linearly projected to ``flatten_size`` values and treated as a
    ``flatten_size``x1x1 map. Four stride-2 transposed convolutions then upsample
    it 1 -> 5 -> 13 -> 30 -> 64, producing a 64x64 image passed through a sigmoid.

    Args:
        latent_size: width of the input latent vector.
        output_channel: number of channels in the reconstructed image. Before
            this change it was ignored and the output always had 3 channels.
        flatten_size: width of the projected latent, i.e. the input channels of
            the first transposed convolution.
    """

    def __init__(self, latent_size = 32, output_channel = 3, flatten_size=1024):
        super(Decoder, self).__init__()

        self.fc = nn.Linear(latent_size, flatten_size)

        self.main = nn.Sequential(
            nn.ConvTranspose2d(flatten_size, 128, 5, stride=2), nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 5, stride=2), nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 6, stride=2), nn.ReLU(),
            nn.ConvTranspose2d(32, output_channel, 6, stride=2), nn.Sigmoid()
        )

    def forward(self, x):
        """Decode a (B, latent_size) batch into (B, output_channel, 64, 64) images."""
        x = self.fc(x)
        # Treat the projected vector as a 1x1 spatial map.
        x = x.unsqueeze(-1).unsqueeze(-1)
        return self.main(x)


# DisentangledVAE here is actually a Cycle-Consistent VAE; "disentangled" refers to the disentanglement between domain-general and domain-specific embeddings.
class DisentangledVAE(nn.Module):
    """Car-racing cycle-consistent VAE made of an ``Encoder`` and a ``Decoder``.

    The decoder consumes the content code concatenated with the class code, so
    its latent width is ``class_latent_size + content_latent_size``.
    """

    def __init__(self, class_latent_size = 8, content_latent_size = 32, img_channel = 3, flatten_size = 1024):
        super().__init__()
        joint_latent = class_latent_size + content_latent_size
        self.encoder = Encoder(class_latent_size, content_latent_size, img_channel, flatten_size)
        self.decoder = Decoder(joint_latent, img_channel, flatten_size)

    def forward(self, x):
        """Encode, reparameterize the content code, and reconstruct.

        Returns ``(mu, logsigma, classcode, recon_x)``. The class code is used
        as-is; only the content code goes through ``reparameterize`` (from utils).
        """
        mu, logsigma, classcode = self.encoder(x)
        joint = torch.cat([reparameterize(mu, logsigma), classcode], dim=1)
        return mu, logsigma, classcode, self.decoder(joint)


# Models for CARLA autonomous driving
class CarlaEncoder(nn.Module):
    """Convolutional VAE encoder for the CARLA driving models.

    Four 4x4, stride-2 convolutions (32, 64, 128, 256 channels, each followed by
    ReLU) feed three linear heads: content mean, content log-sigma and class code.
    With this conv stack the default ``flatten_size=9216`` (256x6x6) corresponds
    to 128x128 inputs.

    Args:
        class_latent_size: width of the class-code head.
        content_latent_size: width of the content mu / logsigma heads.
        input_channel: number of channels in the input image.
        flatten_size: length of the flattened conv output.
    """

    def __init__(self, class_latent_size=16, content_latent_size=32, input_channel=3, flatten_size=9216):
        super().__init__()
        self.class_latent_size = class_latent_size
        self.content_latent_size = content_latent_size
        self.flatten_size = flatten_size

        widths = [input_channel, 32, 64, 128, 256]
        stack = []
        for i in range(1, len(widths)):
            stack.append(nn.Conv2d(widths[i - 1], widths[i], 4, stride=2))
            stack.append(nn.ReLU())
        self.main = nn.Sequential(*stack)

        self.linear_mu = nn.Linear(flatten_size, content_latent_size)
        self.linear_logsigma = nn.Linear(flatten_size, content_latent_size)
        self.linear_classcode = nn.Linear(flatten_size, class_latent_size)

    def forward(self, x):
        """Return ``(mu, logsigma, classcode)`` for the image batch ``x``."""
        features = self.main(x)
        flat = features.view(features.size(0), -1)
        mu = self.linear_mu(flat)
        logsigma = self.linear_logsigma(flat)
        return mu, logsigma, self.linear_classcode(flat)

    def get_feature(self, x):
        """Return only the content mean ``mu``."""
        mu, _, _ = self.forward(x)
        return mu


class CarlaDecoder(nn.Module):
    """Transposed-conv decoder producing 128x128 images for the CARLA VAE.

    A linear layer lifts the latent to a 256x6x6 map. Four stride-2 transposed
    convolutions then upsample it 6 -> 14 -> 30 -> 63 -> 128, followed by a
    sigmoid.

    Args:
        latent_size: width of the input latent vector.
        output_channel: number of channels in the reconstructed image. Before
            this change it was ignored and the output always had 3 channels.
    """

    def __init__(self, latent_size = 32, output_channel = 3):
        super(CarlaDecoder, self).__init__()
        self.fc = nn.Linear(latent_size, 256 * 6 * 6)

        self.main = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2), nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2), nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=5, stride=2), nn.ReLU(),
            nn.ConvTranspose2d(32, output_channel, kernel_size=4, stride=2), nn.Sigmoid(),
        )

    def forward(self, x):
        """Decode a (B, latent_size) batch into (B, output_channel, 128, 128) images."""
        x = self.fc(x)
        x = torch.reshape(x, (-1, 256, 6, 6))
        return self.main(x)


class CarlaDisentangledVAE(nn.Module):
    """CARLA cycle-consistent VAE made of a ``CarlaEncoder`` and a ``CarlaDecoder``.

    The decoder consumes the content code concatenated with the class code.
    """

    def __init__(self, class_latent_size = 16, content_latent_size = 32, img_channel = 3, flatten_size=9216):
        super().__init__()
        self.encoder = CarlaEncoder(class_latent_size, content_latent_size, img_channel, flatten_size)
        decoder_latent = class_latent_size + content_latent_size
        self.decoder = CarlaDecoder(decoder_latent, img_channel)

    def forward(self, x):
        """Return ``(mu, logsigma, classcode, recon_x)`` for the image batch ``x``.

        Only the content code goes through ``reparameterize`` (from utils); the
        class code is concatenated unchanged.
        """
        mu, logsigma, classcode = self.encoder(x)
        content = reparameterize(mu, logsigma)
        recon_x = self.decoder(torch.cat((content, classcode), dim=1))
        return mu, logsigma, classcode, recon_x

class ResnetDisentangledVAE(nn.Module):
    """Cycle-consistent VAE with a ResNet-50 encoder and a transposed-conv decoder.

    The decoder consumes the content code concatenated with the class code.
    """

    def __init__(self, class_latent_size = 32, content_latent_size = 256, img_channel = 3, flatten_size=2048):
        super().__init__()
        self.encoder = ResnetVAEEncoder(class_latent_size, content_latent_size, img_channel, flatten_size)
        self.decoder = ResnetVAEDecoder(content_latent_size + class_latent_size, img_channel)

    def forward(self, x):
        """Return ``(mu, logsigma, classcode, recon_x)`` for the image batch ``x``.

        Only the content code goes through ``reparameterize`` (from utils).
        """
        mu, logsigma, classcode = self.encoder(x)
        latent = torch.cat([reparameterize(mu, logsigma), classcode], dim=1)
        recon_x = self.decoder(latent)
        return mu, logsigma, classcode, recon_x


class ResnetVAEEncoder(nn.Module):
    """VAE encoder built on a randomly initialised torchvision ResNet-50 trunk.

    The ResNet-50 classification layer (its last child, ``fc``) is dropped. The
    remaining trunk ends in global average pooling, and its flattened output
    feeds three linear heads: content mean, content log-sigma and class code.

    Args:
        class_latent_size: width of the class-code head.
        content_latent_size: width of the content mu / logsigma heads.
        input_channel: number of channels in the input image. Any value other
            than the stem's native 3 rebuilds the stem convolution; before this
            change the argument was silently ignored.
        flatten_size: length of the flattened trunk output (the default 2048 is
            the width of ResNet-50's pooled features).
    """

    def __init__(self, class_latent_size = 32, content_latent_size = 256, input_channel = 3, flatten_size=2048):
        super(ResnetVAEEncoder, self).__init__()

        # Not used by any layer below; kept so code reading these attributes keeps working.
        self.fc_hidden1, self.fc_hidden2, self.CNN_embed_dim = 1024, 1024, 256

        self.class_latent_size = class_latent_size
        self.content_latent_size = content_latent_size
        self.flatten_size = flatten_size

        backbone = models.resnet50(pretrained=False)
        stem = backbone.conv1
        if input_channel != stem.in_channels:
            # Rebuild the stem with the same geometry so `input_channel` is honoured.
            # For the default of 3 nothing changes, including RNG consumption.
            backbone.conv1 = nn.Conv2d(
                input_channel, stem.out_channels,
                kernel_size=stem.kernel_size, stride=stem.stride,
                padding=stem.padding, bias=stem.bias is not None,
            )
        # Keep every child except the final classification layer.
        self.main = nn.Sequential(*list(backbone.children())[:-1])

        self.linear_mu = nn.Linear(flatten_size, content_latent_size)
        self.linear_logsigma = nn.Linear(flatten_size, content_latent_size)
        self.linear_classcode = nn.Linear(flatten_size, class_latent_size)

    def forward(self, x):
        """Return ``(mu, logsigma, classcode)`` for the image batch ``x``."""
        x = self.main(x)
        x = x.view(x.size(0), -1)

        mu = self.linear_mu(x)
        logsigma = self.linear_logsigma(x)
        classcode = self.linear_classcode(x)

        return mu, logsigma, classcode

    def get_feature(self, x):
        """Return only the content mean ``mu``."""
        mu, logsigma, classcode = self.forward(x)
        return mu


class ResnetVAEDecoder(nn.Module):
    """Transposed-conv decoder producing 224x224 images for the ResNet VAE.

    A linear layer lifts the latent to a 256x12x12 map. Four stride-2 transposed
    convolutions then upsample it 12 -> 26 -> 54 -> 111 -> 224, followed by a
    sigmoid.

    Args:
        latent_size: width of the input latent vector.
        output_channel: number of channels in the reconstructed image. Before
            this change it was ignored and the output always had 3 channels.
    """

    def __init__(self, latent_size = 256, output_channel = 3):
        super(ResnetVAEDecoder, self).__init__()

        # None of these attributes is used by the layers below; they are kept so
        # code reading them keeps working.
        self.fc_hidden1, self.fc_hidden2, self.CNN_embed_dim = 1024, 768, 256
        self.ch1, self.ch2, self.ch3, self.ch4 = 16, 32, 64, 128
        self.k1, self.k2, self.k3, self.k4 = (5, 5), (3, 3), (3, 3), (3, 3)      # 2d kernel size
        self.s1, self.s2, self.s3, self.s4 = (2, 2), (2, 2), (2, 2), (2, 2)      # 2d strides
        self.pd1, self.pd2, self.pd3, self.pd4 = (0, 0), (0, 0), (0, 0), (0, 0)  # 2d padding

        self.main = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2), nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2), nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=5, stride=2), nn.ReLU(),
            nn.ConvTranspose2d(32, output_channel, kernel_size=4, stride=2), nn.Sigmoid(),
        )
        self.fc = nn.Linear(latent_size, 256 * 12 * 12)

    def forward(self, x):
        """Decode a (B, latent_size) batch into (B, output_channel, 224, 224) images."""
        x = self.fc(x).view(-1, 256, 12, 12)
        return self.main(x)



class ViTDisentangledVAE(nn.Module):
    """Cycle-consistent VAE with a ViT encoder and a transposed-conv decoder.

    The decoder consumes the content code concatenated with the class code.
    """

    def __init__(self, class_latent_size = 32, content_latent_size = 512, img_channel = 3, flatten_size=38400):
        super().__init__()
        combined = class_latent_size + content_latent_size
        self.encoder = ViTVAEEncoder(class_latent_size, content_latent_size, img_channel, flatten_size)
        self.decoder = ViTVAEDecoder(combined, img_channel)

    def forward(self, x):
        """Return ``(mu, logsigma, classcode, recon_x)`` for the image batch ``x``.

        Only the content code goes through ``reparameterize`` (from utils).
        """
        mu, logsigma, classcode = self.encoder(x)
        sampled = reparameterize(mu, logsigma)
        recon_x = self.decoder(torch.cat([sampled, classcode], dim=1))
        return mu, logsigma, classcode, recon_x

class ViTVAEEncoder(nn.Module):
    """VAE encoder on a randomly initialised ViT 'B_32' from pytorch_pretrained_vit.

    The backbone's ``fc`` head is deleted. Its output is flattened per sample and
    fed to three linear heads: content mean, content log-sigma and class code.

    Args:
        class_latent_size: width of the class-code head.
        content_latent_size: width of the content mu / logsigma heads.
        input_channel: NOTE(review): accepted but not forwarded to the ViT
            constructor, so it currently has no effect. Verify before relying on it.
        flatten_size: length of the flattened backbone output. Presumably
            tokens x hidden width (e.g. 50 x 768 = 38400); confirm against the
            configured ViT image size.
    """

    def __init__(self, class_latent_size = 32, content_latent_size = 512, input_channel = 3, flatten_size=38400):
        super().__init__()
        self.class_latent_size = class_latent_size
        self.content_latent_size = content_latent_size
        self.flatten_size = flatten_size

        # Imported here rather than at module top, so importing this file does
        # not require pytorch_pretrained_vit.
        from pytorch_pretrained_vit import ViT
        vit = ViT('B_32', pretrained=False)
        # Remove the classification head. Presumably the ViT forward then skips
        # the head and returns token features; verify against the installed version.
        del vit.fc
        self.main = vit

        self.linear_mu = nn.Linear(flatten_size, content_latent_size)
        self.linear_logsigma = nn.Linear(flatten_size, content_latent_size)
        self.linear_classcode = nn.Linear(flatten_size, class_latent_size)

    def forward(self, x):
        """Return ``(mu, logsigma, classcode)`` for the image batch ``x``."""
        tokens = self.main(x)
        flat = tokens.view(tokens.size(0), -1)
        heads = (self.linear_mu, self.linear_logsigma, self.linear_classcode)
        return tuple(head(flat) for head in heads)

    def get_feature(self, x):
        """Return only the content mean ``mu``."""
        return self.forward(x)[0]


class ViTVAEDecoder(nn.Module):
    """Transposed-conv decoder producing 224x224 images for the ViT VAE.

    A linear layer lifts the latent to a 256x12x12 map. Four stride-2 transposed
    convolutions then upsample it 12 -> 26 -> 54 -> 111 -> 224, followed by a
    sigmoid.

    Args:
        latent_size: width of the input latent vector.
        output_channel: number of channels in the reconstructed image. Before
            this change it was ignored and the output always had 3 channels.
    """

    def __init__(self, latent_size = 512, output_channel = 3):
        super(ViTVAEDecoder, self).__init__()
        self.fc = nn.Linear(latent_size, 256 * 12 * 12)

        self.main = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2), nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2), nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=5, stride=2), nn.ReLU(),
            nn.ConvTranspose2d(32, output_channel, kernel_size=4, stride=2), nn.Sigmoid(),
        )

    def forward(self, x):
        """Decode a (B, latent_size) batch into (B, output_channel, 224, 224) images."""
        x = self.fc(x)
        x = torch.reshape(x, (-1, 256, 12, 12))
        return self.main(x)



      
class IthorEncoder(nn.Module):
    """Convolutional VAE encoder for the iTHOR models.

    Four 4x4, stride-2 convolutions (32, 64, 128, 256 channels, each followed by
    ReLU) feed three linear heads: content mean, content log-sigma and class code.
    With this conv stack the default ``flatten_size=36864`` (256x12x12)
    corresponds to 224x224 inputs.

    Args:
        class_latent_size: width of the class-code head.
        content_latent_size: width of the content mu / logsigma heads.
        input_channel: number of channels in the input image.
        flatten_size: length of the flattened conv output.
    """

    def __init__(self, class_latent_size=32, content_latent_size=32, input_channel=3, flatten_size=36864):
        super().__init__()
        self.class_latent_size = class_latent_size
        self.content_latent_size = content_latent_size
        self.flatten_size = flatten_size

        in_out = [(input_channel, 32), (32, 64), (64, 128), (128, 256)]
        self.main = nn.Sequential(
            *[layer for c_in, c_out in in_out for layer in (nn.Conv2d(c_in, c_out, 4, stride=2), nn.ReLU())]
        )

        self.linear_mu = nn.Linear(flatten_size, content_latent_size)
        self.linear_logsigma = nn.Linear(flatten_size, content_latent_size)
        self.linear_classcode = nn.Linear(flatten_size, class_latent_size)

    def forward(self, x):
        """Return ``(mu, logsigma, classcode)`` for the image batch ``x``."""
        out = self.main(x)
        out = out.view(out.size(0), -1)
        return (
            self.linear_mu(out),
            self.linear_logsigma(out),
            self.linear_classcode(out),
        )

    def get_feature(self, x):
        """Return only the content mean ``mu``."""
        mu, _logsigma, _classcode = self.forward(x)
        return mu


class IthorDecoder(nn.Module):
    """Transposed-conv decoder producing 224x224 images for the iTHOR VAE.

    A linear layer lifts the latent to a 256x12x12 map. Four stride-2 transposed
    convolutions then upsample it 12 -> 26 -> 54 -> 111 -> 224, followed by a
    sigmoid.

    Args:
        latent_size: width of the input latent vector.
        output_channel: number of channels in the reconstructed image. Before
            this change it was ignored and the output always had 3 channels.
    """

    def __init__(self, latent_size = 32, output_channel = 3):
        super(IthorDecoder, self).__init__()
        self.fc = nn.Linear(latent_size, 256 * 12 * 12)

        self.main = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2), nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2), nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=5, stride=2), nn.ReLU(),
            nn.ConvTranspose2d(32, output_channel, kernel_size=4, stride=2), nn.Sigmoid(),
        )

    def forward(self, x):
        """Decode a (B, latent_size) batch into (B, output_channel, 224, 224) images."""
        x = self.fc(x)
        x = torch.reshape(x, (-1, 256, 12, 12))
        return self.main(x)


class IthorDisentangledVAE(nn.Module):
    """iTHOR cycle-consistent VAE made of an ``IthorEncoder`` and an ``IthorDecoder``.

    The decoder consumes the content code concatenated with the class code.
    """

    def __init__(self, class_latent_size = 32, content_latent_size = 32, img_channel = 3, flatten_size=36864):
        super().__init__()
        self.encoder = IthorEncoder(class_latent_size, content_latent_size, img_channel, flatten_size)
        self.decoder = IthorDecoder(class_latent_size + content_latent_size, img_channel)

    def forward(self, x):
        """Return ``(mu, logsigma, classcode, recon_x)`` for the image batch ``x``.

        Only the content code goes through ``reparameterize`` (from utils).
        """
        mu, logsigma, classcode = self.encoder(x)
        content = reparameterize(mu, logsigma)
        joint = torch.cat((content, classcode), dim=1)
        return mu, logsigma, classcode, self.decoder(joint)




class CarlaLatentPolicy(nn.Module):
    """Beta-distribution actor-critic operating on a flat input vector.

    The actor MLP feeds two Softplus heads. Each output, plus 1, parameterises a
    per-dimension Beta distribution on [0, 1]. Actions are mapped to [-1, 1] on
    output and mapped back to [0, 1] when passed in. A separate critic MLP
    predicts a scalar value.

    Args:
        input_dim: width of the input vector.
        action_dim: number of action dimensions.
        hidden_layer: hidden widths shared by the actor and critic MLPs. Any
            sequence of ints is accepted; defaults to (64, 64).
    """

    def __init__(self, input_dim, action_dim, hidden_layer=(64, 64)):
        super(CarlaLatentPolicy, self).__init__()
        # Copy into a list: the default is an immutable tuple (never shared
        # mutable state), and tuple/list inputs are handled alike.
        hidden = list(hidden_layer)
        sizes = [input_dim] + hidden

        def linear_relu_stack():
            # Linear + ReLU for every consecutive pair of widths in `sizes`.
            layers = []
            for n_in, n_out in zip(sizes[:-1], sizes[1:]):
                layers += [nn.Linear(n_in, n_out), nn.ReLU()]
            return layers

        # Construction order (actor, alpha, beta, critic) matches the original
        # so parameter order, state_dict keys and seeded init are unchanged.
        self.actor = nn.Sequential(*linear_relu_stack())
        self.alpha_head = nn.Sequential(nn.Linear(hidden[-1], action_dim), nn.Softplus())
        self.beta_head = nn.Sequential(nn.Linear(hidden[-1], action_dim), nn.Softplus())

        critic_layers = linear_relu_stack()
        critic_layers.append(nn.Linear(hidden[-1], 1))
        self.critic = nn.Sequential(*critic_layers)

    def forward(self, x, action=None):
        """Sample an action (or score a given one) and evaluate the critic.

        Args:
            x: input batch for both actor and critic.
            action: optional actions in [-1, 1] to score. When None, an action
                is sampled from the current Beta distribution.

        Returns:
            ``(action in [-1, 1], log-prob summed over the last dim,
            value with the last dim squeezed, entropy summed over the last dim)``.
            The distribution is also stored on ``self.dist``.
        """
        actor_features = self.actor(x)
        # Softplus output is >= 0, so both concentrations are >= 1.
        alpha = self.alpha_head(actor_features) + 1
        beta = self.beta_head(actor_features) + 1
        self.dist = Beta(alpha, beta)
        if action is None:
            action = self.dist.sample()
        else:
            # Map the caller's [-1, 1] action onto the Beta support [0, 1].
            action = (action + 1) / 2
        action_log_prob = self.dist.log_prob(action).sum(-1)
        entropy = self.dist.entropy().sum(-1)
        value = self.critic(x)
        return action * 2 - 1, action_log_prob, value.squeeze(-1), entropy


class CarlaSimpleEncoder(nn.Module):
    """Deterministic convolutional encoder with a single linear projection.

    It uses the same four 4x4, stride-2 conv layers as the VAE encoders in this
    module. With that stack the default ``flatten_size=9216`` (256x6x6)
    corresponds to 128x128 inputs.

    Args:
        latent_size: width of the output feature vector.
        input_channel: number of channels in the input image.
        flatten_size: length of the flattened conv output. It defaults to the
            previously hard-coded 9216 and can now be changed for other input
            resolutions, like the sibling encoders' ``flatten_size``.
    """

    def __init__(self, latent_size = 32, input_channel = 3, flatten_size = 9216):
        super(CarlaSimpleEncoder, self).__init__()
        self.latent_size = latent_size
        self.flatten_size = flatten_size

        self.main = nn.Sequential(
            nn.Conv2d(input_channel, 32, 4, stride=2), nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2), nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2), nn.ReLU(),
            nn.Conv2d(128, 256, 4, stride=2), nn.ReLU()
        )

        self.linear_mu = nn.Linear(flatten_size, latent_size)

    def forward(self, x):
        """Return the (B, latent_size) feature for the image batch ``x``."""
        x = self.main(x)
        x = x.view(x.size(0), -1)
        return self.linear_mu(x)


class CarlaImgPolicy(nn.Module):
    """Beta-distribution actor-critic that reads raw CARLA images plus speed.

    Each input row is a flattened 3x128x128 image followed by one speed value.
    Separate ``CarlaSimpleEncoder`` instances embed the image for the actor and
    for the critic. The speed is appended to each embedding before the MLPs.

    Args:
        input_dim: MLP input width, i.e. the image-embedding width + 1 (speed).
        action_dim: number of action dimensions.
        hidden_layer: hidden widths shared by the actor and critic MLPs. Any
            sequence of ints is accepted; defaults to (400, 300).
    """

    def __init__(self, input_dim, action_dim, hidden_layer=(400, 300)):
        super(CarlaImgPolicy, self).__init__()
        hidden = list(hidden_layer)
        sizes = [input_dim] + hidden

        self.main_actor = CarlaSimpleEncoder(latent_size = input_dim-1)
        self.main_critic = CarlaSimpleEncoder(latent_size = input_dim-1)

        def linear_relu_stack():
            # Linear + ReLU for every consecutive pair of widths in `sizes`.
            layers = []
            for n_in, n_out in zip(sizes[:-1], sizes[1:]):
                layers += [nn.Linear(n_in, n_out), nn.ReLU()]
            return layers

        self.actor = nn.Sequential(*linear_relu_stack())
        self.alpha_head = nn.Sequential(nn.Linear(hidden[-1], action_dim), nn.Softplus())
        self.beta_head = nn.Sequential(nn.Linear(hidden[-1], action_dim), nn.Softplus())

        critic_layers = linear_relu_stack()
        # NOTE: this previously called `layer_init(..., gain=1)`, which is neither
        # defined nor imported in this module, so construction raised NameError.
        # A plain Linear matches CarlaLatentPolicy's critic head. Apply an explicit
        # init here if a custom one was intended.
        critic_layers.append(nn.Linear(hidden[-1], 1))
        self.critic = nn.Sequential(*critic_layers)

    def forward(self, x, action=None):
        """Sample an action (or score a given one) and evaluate the critic.

        Args:
            x: (B, 3*128*128 + 1) batch; a flattened image followed by speed.
            action: optional actions in [-1, 1] to score. When None, an action
                is sampled from the current Beta distribution.

        Returns:
            ``(action in [-1, 1], log-prob summed over the last dim,
            value with the last dim squeezed, entropy summed over the last dim)``.
            The distribution is also stored on ``self.dist``.
        """
        speed = x[:, -1:]
        img = x[:, :-1].reshape(-1, 3, 128, 128)  # image size in carla driving task is 128x128

        actor_in = torch.cat([self.main_actor(img), speed], dim=1)
        critic_in = torch.cat([self.main_critic(img), speed], dim=1)

        actor_features = self.actor(actor_in)
        # Softplus output is >= 0, so both concentrations are >= 1.
        alpha = self.alpha_head(actor_features) + 1
        beta = self.beta_head(actor_features) + 1
        self.dist = Beta(alpha, beta)
        if action is None:
            action = self.dist.sample()
        else:
            # Map the caller's [-1, 1] action onto the Beta support [0, 1].
            action = (action + 1) / 2
        action_log_prob = self.dist.log_prob(action).sum(-1)
        entropy = self.dist.entropy().sum(-1)
        value = self.critic(critic_in)
        return action * 2 - 1, action_log_prob, value.squeeze(-1), entropy


# for vit
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """Build a fixed 2-D sine/cosine positional-embedding table for a square grid.

    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    When ``cls_token`` is true, an all-zero row is prepended.
    """
    coords = np.arange(grid_size, dtype=np.float32)
    # Default 'xy' meshgrid: grid[0] holds the column index, grid[1] the row index.
    grid = np.stack(np.meshgrid(coords, coords), axis=0).reshape(2, 1, grid_size, grid_size)
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if not cls_token:
        return pos_embed
    return np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Concatenate 1-D sin/cos embeddings of ``grid[0]`` and ``grid[1]``.

    Each coordinate plane gets ``embed_dim // 2`` dimensions. The result has
    shape (H*W, embed_dim), with the ``grid[0]`` part first.
    """
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    per_axis = [get_1d_sincos_pos_embed_from_grid(half_dim, grid[axis]) for axis in (0, 1)]
    return np.concatenate(per_axis, axis=1)  # (H*W, D)


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Sine/cosine embedding of scalar positions.

    embed_dim: output dimension for each position (must be even)
    pos: array-like of positions to be encoded; flattened to size (M,)
    out: (M, D) float64 array, sines in the first D/2 columns, cosines in the rest
    """
    assert embed_dim % 2 == 0
    # `np.float` was a deprecated alias of the builtin float (float64) and was
    # removed in NumPy 1.24, where it raises AttributeError.
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = np.asarray(pos).reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    return np.concatenate([np.sin(out), np.cos(out)], axis=1)  # (M, D)
