import os
import csv
import numpy as np
import pandas as pd
from PIL import Image
import scipy.spatial.distance as distance
import skimage.exposure
import skimage.filters
import skimage.morphology

class Dataset(object):
    """Index of protein images and their paired second-channel images on disk.

    Images are registered with ``add_image`` / ``add_dataset``; each entry of
    ``image_info`` is a dict with the keys ``"id"``, ``"path"`` and ``"name"``.
    Call ``prepare`` after registering images to (re)build ``num_images`` and
    ``image_ids``.
    """

    def __init__(self):
        self.image_ids = []
        self.image_info = []

    def add_image(self, image_id, path, name):
        """Register one image.

        Args:
            image_id: Identifier stored under ``"id"``. The loaders index
                ``image_info`` by this value, so it should equal the entry's
                position in ``image_info``.
            path: Directory containing the image file.
            name: File name of the protein image inside ``path``.
        """
        image_info = {
            "id": image_id,
            "path": path,
            "name": name,
        }
        self.image_info.append(image_info)

    def add_dataset(self, root_dir):
        """Add every ``*_protein*`` image found in the subdirectories of ``root_dir``.

        Each immediate subdirectory of ``root_dir`` is scanned and every file
        whose name contains ``"_protein"`` is registered. Entries of
        ``root_dir`` that are not directories are skipped.

        Args:
            root_dir: Directory whose subdirectories hold the images. A
                trailing path separator is optional.
        """
        # Continue numbering after already-registered images so ids stay equal
        # to the position in ``image_info`` even if this is called repeatedly.
        next_id = len(self.image_info)
        for currdir in os.listdir(root_dir):
            subdir = os.path.join(root_dir, currdir)
            if not os.path.isdir(subdir):
                # Stray files next to the per-protein folders are not datasets.
                continue
            print("Adding " + currdir)

            # Keep the stained (protein) images only.
            image_names = [
                fname for fname in os.listdir(subdir) if "_protein" in fname
            ]

            for image_name in image_names:
                self.add_image(
                    image_id=next_id,
                    # Stored with a trailing "/" so that ``path + name`` is a
                    # valid file path for any code that concatenates them.
                    path=subdir + "/",
                    name=image_name)
                next_id += 1

    @staticmethod
    def _read_image(filepath):
        """Read an image file into a numpy array, closing the file afterwards."""
        with Image.open(filepath) as img:
            return np.array(img)

    def _load_pair(self, path, proteinname):
        """Load a protein image and its paired second-channel image.

        The paired file name is derived by replacing ``"_protein"`` with
        ``"_nucleus"``; if that file does not exist, the name obtained by
        replacing ``"_gfp"`` with ``"_rfp"`` is tried instead.

        Returns:
            Tuple ``(protein, brightfield)`` of numpy arrays.

        Raises:
            FileNotFoundError: If the protein image or both candidate paired
                files are missing.
        """
        protein = self._read_image(os.path.join(path, proteinname))
        brightfieldname = proteinname.replace("_protein", "_nucleus")
        try:
            brightfield = self._read_image(os.path.join(path, brightfieldname))
        except FileNotFoundError:
            brightfieldname = proteinname.replace("_gfp", "_rfp")
            brightfield = self._read_image(os.path.join(path, brightfieldname))
        return protein, brightfield

    def load_image(self, image_id):
        """Return ``(protein, brightfield)`` arrays for the image at ``image_id``."""
        info = self.image_info[image_id]
        return self._load_pair(info['path'], info['name'])

    def load_image_with_label(self, image_id):
        """Return ``(protein, brightfield, label)`` for the image at ``image_id``.

        ``label`` is the file path of the protein image.
        """
        info = self.image_info[image_id]
        protein, brightfield = self._load_pair(info['path'], info['name'])
        label = os.path.join(info['path'], info['name'])
        return protein, brightfield, label

    def sample_pair_equally(self, image_id):
        """Sample a different image from the same folder with equal probability.

        Args:
            image_id: Index of the anchor image in ``image_info``.

        Returns:
            Tuple ``(protein, brightfield)`` of numpy arrays for a uniformly
            drawn ``*_protein*`` file in the anchor's folder, excluding the
            anchor itself.

        Raises:
            ValueError: If the folder holds no other protein image.
        """
        info = self.image_info[image_id]
        path = info['path']
        name = info['name']

        candidates = [
            fname for fname in os.listdir(path)
            if "_protein" in fname and fname != name
        ]

        # Check the filtered candidates rather than the raw listing: the folder
        # also contains the paired channel files, so a raw count above one does
        # not mean another protein image exists.
        if not candidates:
            raise ValueError("Directory " + path + " has only one image.")

        sampled_image = np.random.choice(candidates)
        return self._load_pair(path, sampled_image)

    def prepare(self):
        """Prepares the Dataset class for use."""
        # Build (or rebuild) everything else from the info dicts.
        self.num_images = len(self.image_info)
        self.image_ids = np.arange(self.num_images)
