from torch.utils.data import Dataset
import pandas as pd
from PIL import Image
import numpy as np
import os

def pil_loader(path: str) -> Image.Image:
    """Read the image stored at ``path`` and return it converted to RGB mode."""
    # Hand PIL an explicitly managed file object rather than a path so the
    # handle is closed on exit and no ResourceWarning is raised
    # (https://github.com/python-pillow/Pillow/issues/835).
    with open(path, 'rb') as handle:
        return Image.open(handle).convert('RGB')
    
class ChestXray(Dataset):
    """Multi-label image dataset built from ``Data_Entry_2017.csv`` and a split file.

    Each item is ``(image, label)`` where ``label`` is a 0/1 vector with one
    entry per name in ``used_labels`` (in that order).
    """

    def __init__(self, root, split_txt, transform, labeled_file=None):
        """
        Args:
            root: Root directory path. ``images/``, ``Data_Entry_2017.csv`` and
                the split file are all resolved relative to it.
            split_txt: Name of the split file (relative to ``root``) listing one
                image name per line.
            transform: Callable applied to the loaded PIL image in
                ``__getitem__``, or ``None`` to return the image unchanged.
            labeled_file: Optional path to a label file in the format written by
                ``save_label_file``. When given, image names and labels are
                loaded from it instead of being rebuilt from the csv (and no
                label file is written).
        """
        self.img_path = os.path.join(root, 'images')
        self.csv_path = os.path.join(root, 'Data_Entry_2017.csv')
        self.split_txt = os.path.join(root, split_txt)
        # "No Finding" is deliberately absent: it is encoded as an all-zero vector.
        self.used_labels = ["Atelectasis", "Consolidation", "Infiltration",
                            "Pneumothorax", "Edema", "Emphysema",
                            "Fibrosis", "Effusion", "Pneumonia",
                            "Pleural_Thickening", "Cardiomegaly", "Nodule",
                            "Mass", "Hernia"]

        self.labels_maps = {"No Finding": 0,
                            "Atelectasis": 1,
                            "Consolidation": 2,
                            "Infiltration": 3,
                            "Pneumothorax": 4,
                            "Edema": 5,
                            "Emphysema": 6,
                            "Fibrosis": 7,
                            "Effusion": 8,
                            "Pneumonia": 9,
                            "Pleural_Thickening": 10,
                            "Cardiomegaly": 11,
                            "Nodule": 12,
                            "Mass": 13,
                            "Hernia": 14}

        # Read the csv file, dropping its header row.
        self.data_info = pd.read_csv(self.csv_path, skiprows=[0], header=None)

        # Column 0 holds the image names, column 1 the '|'-separated findings.
        self.image_name_all = np.asarray(self.data_info.iloc[:, 0])
        self.labels_all = np.asarray(self.data_info.iloc[:, 1])

        self.image_name = []
        self.label = []

        self.transform = transform

        if labeled_file is None:
            self._build_labels_from_csv()
            self.save_label_file()
        else:
            self.image_name, self.label = self.load_labels(label_file=labeled_file)

        self.data_len = len(self.image_name)

        self.image_name = np.asarray(self.image_name)
        self.label = np.asarray(self.label)

    def _build_labels_from_csv(self):
        """Fill ``image_name`` / ``label`` for the split, keeping them index-aligned."""
        with open(self.split_txt) as f:
            split_names = [line.strip() for line in f if line.strip()]

        # One dict lookup per split entry instead of scanning a list per csv row.
        findings_by_name = dict(zip(self.image_name_all, self.labels_all))

        missing = 0
        for name in split_names:
            findings = findings_by_name.get(name)
            if findings is None:
                # No csv row means no label; drop the entry so names and
                # labels stay the same length and in the same order.
                missing += 1
                continue
            present = set(findings.split("|"))
            self.image_name.append(name)
            self.label.append([int(l in present) for l in self.used_labels])

        if missing:
            print("* Skipped {} split entries without a row in {}".format(
                missing, self.csv_path))

    def load_labels(self, label_file):
        """
        Load labels from a saved label file.

        Each non-blank line is ``<image_name> <l0>,<l1>,...``.

        Args:
            label_file (str): Path to the label file.

        Returns:
            image_names (list): List of image names.
            labels (list): List of labels corresponding to each image.
        """
        image_names = []
        labels = []
        print("* Load Labels from {}".format(label_file))

        with open(label_file, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Tolerate blank / trailing empty lines.
                    continue
                image_name, label_str = line.split(' ', 1)
                image_names.append(image_name)
                labels.append([int(l) for l in label_str.split(',')])

        return image_names, labels

    def save_label_file(self):
        """Write ``image_name``/``label`` pairs next to the split file as ``<split>_labels.txt``."""
        # splitext only strips the final extension, so dots elsewhere in the
        # path (e.g. a root of "./data") do not truncate the file name.
        label_file = os.path.splitext(self.split_txt)[0] + '_labels.txt'
        print("* Save Labels to {}".format(label_file))
        with open(label_file, 'w') as f:
            for name, labels in zip(self.image_name, self.label):
                label_str = ','.join(str(l) for l in labels)
                f.write(f"{name} {label_str}\n")
        print("Labels saved successfully.")

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index``; the image is RGB, transformed if a transform is set."""
        single_image_name = self.image_name[index]

        # Open image
        image = pil_loader(os.path.join(self.img_path, single_image_name))
        if self.transform is not None:
            image = self.transform(image)

        # Multi-hot label vector aligned with ``used_labels``.
        single_image_label = self.label[index]

        return (image, single_image_label)

    def __len__(self):
        """Return the number of images in the dataset."""
        return self.data_len
