import torch
import torch.nn as nn
from PIL import Image
from torch.distributed import init_process_group, destroy_process_group

def get_data_scaler(centered=False):
  """Build the forward data normalizer.

  The input data are assumed to lie in [0, 1].

  Args:
    centered: when truthy, the returned callable maps `x` to `x * 2 - 1`
      (i.e. [0, 1] -> [-1, 1]); otherwise it returns `x` unchanged.

  Returns:
    A one-argument callable applying the selected scaling.
  """
  def _to_symmetric_range(x):
    # [0, 1] -> [-1, 1]
    return x * 2. - 1.

  def _passthrough(x):
    return x

  return _to_symmetric_range if centered else _passthrough

def get_data_inverse_scaler(centered=False):
  """Build the inverse data normalizer.

  Args:
    centered: when truthy, the returned callable maps `x` to `(x + 1) / 2`
      (i.e. [-1, 1] -> [0, 1]); otherwise it returns `x` unchanged.

  Returns:
    A one-argument callable applying the selected inverse scaling.
  """
  def _to_unit_range(x):
    # [-1, 1] -> [0, 1]
    return (x + 1.) / 2.

  def _passthrough(x):
    return x

  return _to_unit_range if centered else _passthrough

class AverageMeter(object):
    """Tracks the most recent value and a running weighted average.

    Attributes (all set to 0 by `reset`):
        val: the value passed to the most recent `update` call.
        sum: running total of `val * n` over all updates.
        count: running total of the weights `n`.
        avg: `sum / count`, or 0 while `count` is still 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every tracked statistic back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record `val` with weight `n` and refresh the running average.

        Args:
            val: the newly observed value.
            n: weight of the observation (e.g. number of samples it covers).
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero-weight update (e.g. an empty batch) before any real sample
        # leaves count at 0; report an average of 0 instead of raising
        # ZeroDivisionError.
        self.avg = self.sum / self.count if self.count else 0

class CustomDataset(torch.utils.data.Dataset):
    """Dataset pairing each entry of `X` (converted to a PIL image) with `Y`.

    Args:
        X: indexable collection of samples; each sample must support
            `.squeeze()` and be accepted by `Image.fromarray`
            (presumably numpy arrays -- confirm against callers).
        Y: indexable collection of labels, indexed in parallel with `X`.
        transform: optional callable applied to the PIL image; `None`
            disables it.
    """

    def __init__(self, X, Y, transform):
        super().__init__()
        self.X, self.Y = X, Y
        self.transform = transform

    def __len__(self):
        """Return the number of samples, taken from `X`."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return `(image, label)` for position `idx`."""
        sample, label = self.X[idx], self.Y[idx]
        image = Image.fromarray(sample.squeeze())

        if self.transform is None:
            return (image, label)
        return (self.transform(image), label)


class CustomDataset_idx(torch.utils.data.Dataset):
    """Dataset yielding `(image, label, idx)` so callers can track positions.

    Args:
        X: indexable collection of samples; each sample must support
            `.squeeze()` and be accepted by `Image.fromarray`
            (presumably numpy arrays -- confirm against callers).
        Y: indexable collection of labels, indexed in parallel with `X`.
        transform: optional callable applied to the PIL image; `None`
            disables it.
    """

    def __init__(self, X, Y, transform):
        super().__init__()
        self.X, self.Y = X, Y
        self.transform = transform

    def __len__(self):
        """Return the number of samples, taken from `X`."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return `(image, label, idx)` for position `idx`."""
        sample, label = self.X[idx], self.Y[idx]
        image = Image.fromarray(sample.squeeze())

        if self.transform is None:
            return (image, label, idx)
        return (self.transform(image), label, idx)
    
def to_flattened_numpy(x):
  """Convert torch tensor `x` to a 1-D numpy array.

  The tensor is detached from the autograd graph and moved to the CPU
  before conversion, then reshaped to a single dimension.
  """
  host_tensor = x.detach().cpu()
  as_array = host_tensor.numpy()
  return as_array.reshape(-1)


def from_flattened_numpy(x, shape):
  """Reshape the numpy array `x` to `shape` and wrap it as a torch tensor."""
  reshaped = x.reshape(shape)
  return torch.from_numpy(reshaped)