import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from torchvision import transforms
import os 
import json
import matplotlib.pyplot as plt
import glob
import numpy as np
import random
import json

def hamming_distance(train_config):
    """Return the digit-wise distance of ``train_config`` from ``"000"``.

    The first three characters of ``train_config`` are parsed as integers
    and the absolute differences to the corresponding digits of the
    reference code ``"000"`` are summed. For codes made only of 0/1 digits
    this equals the Hamming distance; larger digits contribute their full
    absolute difference.

    Args:
        train_config: Indexable sequence of at least three digit
            characters, e.g. ``"010"``.

    Returns:
        The summed absolute difference (a numpy integer scalar, because
        ``np.abs`` is used for each term).
    """
    reference = "000"
    total = 0
    for pos, ref_digit in enumerate(reference):
        total += np.abs(int(train_config[pos]) - int(ref_digit))
    return total


class my_dataset(Dataset):
    """Randomly sampled dataset of images named ``CLEVR_<config>_*.png``.

    Images are looked up under ``../data/<dataset>/`` (relative to the
    current working directory). Every image is expected to have a JSON
    sidecar at the same path with ``.png`` replaced by ``.json``.

    Note that ``__getitem__`` ignores the requested index and draws a
    random file on every call, and ``__len__`` reports ``num_samples``
    rather than the number of files found.
    """

    def __init__(self, transform=None, num_samples=5000, dataset="", configs="", training=True, n_class_color=None, n_class_size=None, test_size=None, alpha=1.0, beta=2.0, remove_node=None, flag_zipf=2, flag_double=1):
        """Collect the image paths for the requested configuration(s).

        Args:
            transform: Optional callable applied to each opened PIL image.
            num_samples: Value reported by ``__len__``; also the sampling
                budget distributed over configs when ``flag_zipf`` is 0 or 1.
            dataset: Dataset directory name under ``../data/``. Substrings
                such as ``"single-body_2d_3classes"`` or ``"two-body"``
                select the label layout.
            configs: In training mode an iterable of config codes (e.g.
                ``["000", "010"]``); in test mode a single config string
                that is spliced into the glob pattern.
            training: Selects the ``train`` (True) or ``test`` (False) files.
            n_class_color: Number of color classes; ``1`` keeps the raw
                color values instead of binning them.
            n_class_size: Number of size classes; ``<= 1`` keeps the raw
                size value instead of binning it. Must be a number, since
                it is compared and used in arithmetic.
            test_size: Size value assigned to test samples (see
                ``__getitem__``).
            alpha: With ``flag_zipf == 2``, the number of files kept for the
                config equal to ``remove_node``. The special value ``1500``
                disables the alternative ``train_<remove_node>`` directory
                for config ``"000"``.
            beta: Base of the per-distance sampling ratios used when
                ``flag_zipf`` is 0 or 1.
            remove_node: Optional config code whose files are truncated to
                ``alpha`` entries. ``None`` disables both the truncation and
                the alternative directory.
            flag_zipf: ``2`` loads all files per config; ``0`` and ``1``
                sub-sample each config by its distance from ``"000"``.
            flag_double: ``0`` duplicates the path list of config ``"010"``.

        Raises:
            ValueError: If training with a ``flag_zipf`` other than 0, 1 or 2.
        """
        self.training = training
        self.test_size = test_size

        self.dataset = dataset
        self.n_class_size = n_class_size
        self.n_class_color = n_class_color
        if n_class_size != 1 and "single-body" in dataset:
            self.size_bins = np.linspace(0.8, 3.1, n_class_size + 1)
        if n_class_color != 1 and "single-body" in dataset:
            self.color_bins = np.linspace(0.0, 1.0, n_class_color + 1)
        if n_class_size != 1 and "two-body" in dataset:
            self.size_bins = np.linspace(-10.0, 10.0, n_class_size + 1)

        # Sampling weight per distance from "000" (only for flag_zipf 0/1).
        ratios = None
        sum_ratio = None
        if flag_zipf == 0:
            ratios = {0: beta**3, 1: beta**2, 2: beta}
        elif flag_zipf == 1:
            ratios = {0: beta**2, 1: beta, 2: beta**2}
        if ratios is not None:
            sum_ratio = ratios[0] + 3 * ratios[1] + 3 * ratios[2]

        data_root = "../data/" + dataset
        if training:
            self.train_image_paths = []
            for config in configs:
                if flag_zipf == 2:
                    # The alternative directory only exists in terms of a
                    # concrete remove_node; without one, fall back to the
                    # regular train directory instead of failing on
                    # str + None.
                    use_alt_dir = (
                        config == "000"
                        and remove_node is not None
                        and alpha != 1500
                        and remove_node != "100"
                    )
                    if use_alt_dir:
                        new_paths = glob.glob(data_root + "/train_" + remove_node + "/CLEVR_000_*.png")
                    else:
                        new_paths = glob.glob(data_root + "/train/CLEVR_" + config + "_*.png")
                    if remove_node == config:
                        # alpha defaults to a float; slices need an int.
                        # NOTE(review): glob order is arbitrary, so which
                        # files survive the cut is filesystem dependent.
                        new_paths = new_paths[:int(alpha)]
                    if flag_double == 0 and config == "010":
                        new_paths = new_paths + new_paths
                else:
                    if ratios is None:
                        raise ValueError(
                            "flag_zipf must be 0, 1 or 2, got %r" % (flag_zipf,)
                        )
                    distance = hamming_distance(config)
                    new_paths = glob.glob(data_root + "/train/CLEVR_" + config + "_*.png")
                    new_paths = new_paths[:int(num_samples * ratios[distance] / sum_ratio)]
                self.train_image_paths += new_paths
        else:
            self.test_image_paths = glob.glob(data_root + "/test/CLEVR_" + configs + "_*.png")

        if self.training:
            self.len_data = len(self.train_image_paths) - 1
        else:
            self.len_data = len(self.test_image_paths) - 1

        self.num_samples = num_samples
        self.transform = transform

    def __getitem__(self, index):
        """Return a randomly chosen ``(img, label)`` pair.

        ``index`` is ignored: each call draws a uniformly random path from
        the train or test list.

        Returns:
            A tuple ``(img, label)``. ``img`` is the opened image after the
            optional transform. ``label`` is a dict keyed by integers: key 0
            is the first digit of the config code taken from the file name;
            the remaining keys hold color / size / position values whose
            layout depends on ``self.dataset``.
        """
        paths = self.train_image_paths if self.training else self.test_image_paths
        img_path = paths[random.randint(0, len(paths) - 1)]

        img = Image.open(img_path)
        if self.transform is not None:
            img = self.transform(img)

        # Config code embedded in the file name: ..._<config>_<suffix>.png
        name_labels = img_path.split("_")[-2]
        is_3class = (
            "single-body_2d_3classes" in self.dataset
            or "single-body_3d_3classes" in self.dataset
        )
        is_4class = "single-body_2d_4classes" in self.dataset

        # The sidecar layout differs per dataset variant.
        with open(img_path.replace(".png", ".json"), 'r') as f:
            my_dict = json.loads(f.read())
            if "single-body_2d_3classes" in self.dataset:
                _size = my_dict[1]
                _color = my_dict[2][:3]
            if "single-body_3d_3classes" in self.dataset:
                _size = my_dict[0]
                _color = my_dict[1][:3]
            if is_4class:
                _size = my_dict[0]
                _color = my_dict[1][:3]
                _position = my_dict[2]
            elif "two-body" in self.dataset:
                _size = my_dict

        if self.training:
            # Training labels come from the sidecar values.
            size = _size
            if is_3class:
                color = _color
            if is_4class:
                color = _color
                position = _position
        else:
            # Test labels are derived from the config digits instead.
            if "two-body" in self.dataset:
                if int(name_labels[2]) == 0:
                    size = self.test_size
                if int(name_labels[2]) == 1:
                    size = -self.test_size
            else:
                if int(name_labels[2]) == 0:
                    size = 2.6
                if int(name_labels[2]) == 1:
                    size = self.test_size
                if is_3class:
                    if int(name_labels[1]) == 0:
                        color = [0.9, 0.1, 0.1]
                    if int(name_labels[1]) == 1:
                        color = [0.1, 0.1, 0.9]
                    if int(name_labels[1]) == 2:
                        color = [0.1, 0.9, 0.1]
                if is_4class:
                    if int(name_labels[1]) == 0:
                        color = [0.9, 0.1, 0.1]
                    if int(name_labels[1]) == 1:
                        color = [0.1, 0.1, 0.9]
                    if int(name_labels[3]) == 0:
                        position = 0.0
                    if int(name_labels[3]) == 1:
                        position = 0.5

        if self.n_class_size <= 1:
            size = np.array(size, dtype=np.float32)
        else:
            if self.training:
                # .item() extracts the single bin index explicitly; calling
                # int() on a 1-element array is deprecated in recent NumPy.
                size = int(np.digitize([size], self.size_bins).item()) - 1
            else:
                if int(name_labels[2]) == 0:
                    size = 0
                if int(name_labels[2]) == 1:
                    size = int(self.n_class_size - 1)
        if is_3class:
            if self.n_class_color == 1:
                color = np.array(color, dtype=np.float32)
            else:
                if self.training:
                    # Only the first color component is binned.
                    color = int(np.digitize(color[:1], self.color_bins).item()) - 1
                else:
                    if int(name_labels[1]) == 0:
                        color = 0
                    if int(name_labels[1]) == 1:
                        color = int(self.n_class_color - 1)
                    if int(name_labels[2]) == 1:
                        size = int(self.n_class_size - 1)

        if is_3class:
            label = {0: int(name_labels[0]), 1: color, 2: size}
        elif is_4class:
            label = {0: int(name_labels[0]), 1: np.array(color, dtype=np.float32), 2: size, 3: np.array(position, dtype=np.float32)}
        else:
            label = {0: int(name_labels[0]), 1: int(name_labels[1]), 2: size}

        return img, label

    def __len__(self):
        """Return ``num_samples`` (not the number of files found)."""
        return self.num_samples


if __name__ == '__main__':
    # Manual smoke test: build the dataset, fetch one batch, print its
    # labels and shape, and display the first image of the batch.
    transform = transforms.Compose([transforms.ToTensor()])
    dataset = my_dataset(transform, dataset="single-body_2d_3classes", n_class_size=1, n_class_color=1, configs=["000","010","100","001"])
    dataloader = DataLoader(dataset, batch_size=4)

    for img, label in dataloader:
        print('label=',label)
        print(img.shape)
        # ToTensor yields (C, H, W) while imshow expects (H, W, C); the
        # axis order (1, 2, 0) converts without transposing the picture.
        plt.imshow(np.transpose(img[0].numpy(), (1, 2, 0)))
        plt.show()
        print('img.shape=',img.shape)
        # Only the first batch is inspected.
        break

