import traceback
from pathlib import Path

import torch
import torchvision
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
import numpy as np

from GaussianConv import GaussianConv
import matplotlib.pyplot as plt
import os

class ImageLoader(Dataset):
    """Dataset yielding ON/OFF center-surround responses of images or random patches.

    Per item:
      1. load the image (from ``self.trainset``; when ``custom_loader`` is set,
         ``self.trainset[idx]`` is treated as an image file path);
      2. unless ``return_color``, average the channels to grayscale;
      3. compute center minus surround response with two Gaussian kernels;
      4. standardize (median-centered, divided by std) and squash with ``tanh``;
      5. if ``patch_size`` is set, crop a random, non-blank patch;
      6. split into ON (> 0.05) and OFF (< -0.05, as magnitude) maps.

    Items are ``(tensor, label)``; the tensor's last axis holds the channels
    ``[on, off, zeros]`` (shape ``(patch_h, patch_w, 3)`` when ``patch_size`` is
    set). ``label`` is ``-1`` for custom-loaded images and for failed items.
    """

    # Upper bound on random crops tried per item, so an (almost) blank image
    # cannot make __getitem__ loop forever.
    max_patch_tries = 1000

    def __init__(self, patch_size=(7, 7), dataset="MNIST", return_color=False, root='../_DATA'):
        """Set up the source dataset and the center/surround kernels.

        Args:
            patch_size: ``(height, width)`` of the random crop, or ``None`` to
                return the whole filtered image.
            dataset: dataset name. Only ``"MNIST"`` is loaded here.
                NOTE(review): for any other value ``trainset`` is not created;
                presumably the caller assigns ``trainset``/``custom_loader``
                afterwards — confirm.
            return_color: if False, channels are averaged into one grayscale channel.
            root: directory MNIST is read from (``download=False``, so the files
                must already exist there).
        """
        self.patch_size = patch_size
        self.return_color = return_color
        self.custom_loader = False
        if dataset == "MNIST":
            transform = transforms.Compose([transforms.ToTensor()])
            self.trainset = torchvision.datasets.MNIST(root=root, train=True, download=False, transform=transform)
        # center/surround kernels: the small, narrow Gaussian is the center,
        # the larger, wider one is the surround
        self.center_cells = GaussianConv(kernel_size=5, sigma=2, device='cpu')
        self.surround_cells = GaussianConv(kernel_size=9, sigma=3, device='cpu')

    def __len__(self):
        return len(self.trainset)

    def _load_image(self, idx):
        """Return ``(img, label)`` with ``img`` channel-first, as ToTensor produces it."""
        if not self.custom_loader:
            return self.trainset[idx]
        # load image directly from disk; ``with`` closes the underlying file
        with Image.open(self.trainset[idx]) as pil_img:
            return transforms.ToTensor()(pil_img), -1

    def _random_patch(self, img):
        """Crop a random ``patch_size`` window of ``img`` whose summed |response| exceeds 1.

        Tries up to ``max_patch_tries`` crops; if none qualifies, the most
        active crop seen is returned instead of looping forever.
        """
        patch_h, patch_w = self.patch_size
        # upper bounds are exclusive: +1 so the last row/column offset is
        # reachable and a patch as large as the image is valid
        max_x = img.shape[0] - patch_h + 1
        max_y = img.shape[1] - patch_w + 1
        if max_x <= 0 or max_y <= 0:
            raise ValueError(f"patch_size {self.patch_size} larger than image {tuple(img.shape)}")

        best, best_activity = None, None
        for _ in range(max(1, self.max_patch_tries)):
            # torch RNG (not numpy's) so each DataLoader worker draws different crops
            x = int(torch.randint(0, max_x, (1,)))
            y = int(torch.randint(0, max_y, (1,)))
            patch = img[x:x + patch_h, y:y + patch_w]
            activity = patch.abs().sum().item()
            if activity > 1:
                return patch
            if best is None or activity > best_activity:
                best, best_activity = patch, activity
        return best

    def __getitem__(self, idx):
        """Return ``(on_off_image, label)`` for item ``idx``.

        On a processing error the traceback is printed and a zero tensor of
        shape ``(patch_h, patch_w, 3)`` with label ``-1`` is returned. With
        ``patch_size=None`` there is no fixed fallback shape, so the error is
        re-raised.
        """
        try:
            img, label = self._load_image(idx)
            img = img.permute(1, 2, 0)  # channel-last

            if not self.return_color:
                img = torch.mean(img, dim=2, keepdim=True)

            # center-surround response from the two gaussian kernels
            img = self.center_cells(img.squeeze(2)) - self.surround_cells(img.squeeze(2))
            # standardize around the median, then squash into (-1, 1)
            img = torch.tanh((img - img.median()) / (img.std() + 1e-12))

            img_final = img if self.patch_size is None else self._random_patch(img)

            # set near-zero responses to zero; OFF cells carry the magnitude of negative responses
            on_cells = torch.where(img_final > 0.05, img_final, torch.zeros_like(img_final))
            off_cells = torch.where(img_final < -0.05, img_final.abs(), torch.zeros_like(img_final))
            # stack ON, OFF and an unused zero channel (easier to reshape/view as images later)
            img_final = torch.stack([on_cells, off_cells, torch.zeros_like(on_cells)]).permute(1, 2, 0)
        except Exception:
            if self.patch_size is None:
                raise
            traceback.print_exc()
            img_final = torch.zeros(size=(self.patch_size[0], self.patch_size[1], 3))
            label = -1

        return img_final, label
