import pdb
import torch
import torch.nn as nn

class PixelNormalization(nn.Module):
    """Standardize activations across dim -3, then apply a scalar affine.

    At every position of the remaining dims, the values along dim -3 are
    shifted to zero mean and divided by their biased (population) standard
    deviation. The result is multiplied by the learnable ``gamma`` and offset
    by the learnable ``beta``. Both parameters are single-element tensors
    shared by every channel, not per-channel affines.

    NOTE(review): this both centers and rescales. If it is meant to be the
    ProGAN-style pixelwise feature norm (divide by RMS only), confirm the mean
    subtraction is intended.

    Args:
        eps: small positive constant that keeps the denominator away from
            zero. ``eps ** 2`` must be representable (non-zero) in the input
            dtype for the gradient safeguard in ``forward`` to apply.
    """

    def __init__(self, eps=1e-10):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones([1]))  # scalar scale, init 1
        self.beta = nn.Parameter(torch.zeros([1]))  # scalar shift, init 0
        self.eps = eps

    # Assumes input is (N, C, H, W)
    def forward(self, x):
        """Normalize ``x`` along dim -3 and apply ``gamma`` / ``beta``.

        Args:
            x: tensor whose dim -3 is the channel axis, e.g. (N, C, H, W).

        Returns:
            Tensor of the same shape as ``x``.
        """
        mean = x.mean(-3, keepdim=True)
        var = x.var(-3, keepdim=True, unbiased=False)
        # Dividing by x.std() gives NaN gradients whenever the variance is
        # exactly zero (always true when C == 1). d sqrt(v)/dv is infinite at
        # v == 0, and autograd multiplies that by a zero upstream term.
        # Adding eps**2 inside the sqrt keeps the argument strictly positive.
        # sqrt(var + eps**2) lies within eps of std + eps and is never below
        # eps, so the forward result is essentially unchanged.
        denom = torch.sqrt(var + self.eps ** 2)
        return (x - mean) / denom * self.gamma + self.beta
