from torch.utils.data import Dataset, DataLoader
import os
import torch
import h5py
import numpy as np
from scipy.io import loadmat
import pytorch_lightning as pl
from typing import List, Union, Optional
numeric = Union[int, float]
import torch.nn.functional as F
import torch.nn as nn


class NormalizeRescale:
    """Standardize a tensor with fixed statistics, optionally resampling its leading columns.

    ``mean`` and ``std`` may each be a scalar or a per-channel sequence (anything with
    ``__len__``, e.g. a list, tuple or array). A per-channel statistic is broadcast along
    the FIRST dimension of the input, one value per row/channel.

    If ``rescale`` is True, the first ``upper_old`` columns of a 2-D input are resampled
    bicubically to ``upper_new`` columns and the remaining columns are appended unchanged,
    so (for ``ncols >= upper_old``) the output has ``upper_new + ncols - upper_old`` columns.

    Args:
        mean: value(s) subtracted from the input.
        std: value(s) the centered input is divided by.
        rescale: whether ``__call__`` also resamples the leading columns.
        upper_old: number of leading input columns to resample.
        upper_new: number of columns they are resampled to.
    """

    def __init__(self,
                 mean: Union[float, List[float]] = 0,
                 std: Union[float, List[float]] = 1,
                 rescale: bool = False,
                 upper_old: int = 51,
                 upper_new: int = 21,
                ):
        super().__init__()
        self.mean = mean
        self.std = std
        # Anything with a length is treated as per-channel statistics.
        self.multi_mean = hasattr(mean, '__len__')
        self.multi_std = hasattr(std, '__len__')
        self.rescale = rescale
        self.upper_old = upper_old
        self.upper_new = upper_new

    @staticmethod
    def _as_broadcastable(value, per_channel, x):
        """Return ``value`` in a form that broadcasts against ``x``.

        Scalars are returned unchanged. Per-channel values become a float32 tensor on
        ``x``'s device shaped ``(C, 1, ..., 1)`` so they line up with ``x``'s first dim.
        """
        if not per_channel:
            return value
        trailing = [1] * (x.dim() - 1)
        # Build on x's device: a CPU tensor here would fail against CUDA inputs
        # (e.g. model outputs passed to ``unnormalize``).
        stat = torch.as_tensor(value, dtype=torch.float32, device=x.device)
        return stat.reshape(-1, *trailing)

    def _normalize(self, x):
        """Return ``(x - mean) / std`` with the statistics broadcast against ``x``."""
        mean = self._as_broadcastable(self.mean, self.multi_mean, x)
        std = self._as_broadcastable(self.std, self.multi_std, x)
        return (x - mean) / std

    def _rescale(self, x):
        """Resample the first ``upper_old`` columns of 2-D ``x`` to ``upper_new`` columns."""
        rows, _ = x.size()  # unpacking also enforces a 2-D input
        # F.interpolate expects (N, C, H, W): add singleton batch/channel dims.
        head = F.interpolate(x[None, None, :, :self.upper_old], [rows, self.upper_new],
                             mode='bicubic', align_corners=True)
        # torch.cat rather than the newer alias torch.concatenate, for older torch builds.
        return torch.cat((head[0, 0], x[..., self.upper_old:]), -1)

    def __call__(self, x):
        """Normalize ``x`` and, if enabled, rescale its leading columns."""
        x = self._normalize(x)
        if self.rescale:
            x = self._rescale(x)
        return x

    def unnormalize(self, x):
        """Invert the normalization: ``x * std + mean`` (column rescaling is NOT undone)."""
        mean = self._as_broadcastable(self.mean, self.multi_mean, x)
        std = self._as_broadcastable(self.std, self.multi_std, x)
        return x * std + mean
    
class FileNormalizeRescale(nn.Module):
    """Standardize with statistics loaded from a ``torch.save`` file; optionally resample.

    The file must hold a mapping with ``'{datatype}_mean'`` and ``'{datatype}_std'``
    entries. They are registered as buffers, so they follow the module across
    ``.to(device)`` and are included in its ``state_dict``.

    If ``rescale`` is True, the first ``upper_old`` columns of a 2-D input are resampled
    bicubically to ``upper_new`` columns and the remaining columns are appended unchanged.

    Args:
        filename: path of the statistics file.
        datatype: which statistics to use, ``'ssp'`` or ``'at'``.
        rescale: whether ``forward`` also resamples the leading columns.
        upper_old: number of leading input columns to resample.
        upper_new: number of columns they are resampled to.
    """

    def __init__(self,
                 filename: str,
                 datatype: str,
                 rescale: bool = False,
                 upper_old: int = 51,
                 upper_new: int = 21,
                ):
        super().__init__()
        assert datatype in ['ssp', 'at']
        # map_location='cpu': statistics saved from GPU tensors would otherwise be restored
        # onto that GPU (failing on CPU-only hosts); the buffers move later with .to(device).
        # NOTE(review): torch.load unpickles the file -- only load trusted statistics files.
        data = torch.load(filename, map_location='cpu')
        self.register_buffer('mean', torch.Tensor(data[f'{datatype}_mean']))
        self.register_buffer('std', torch.Tensor(data[f'{datatype}_std']))
        del data
        self.rescale = rescale
        self.upper_old = upper_old
        self.upper_new = upper_new

    def _normalize(self, x):
        """Return ``(x - mean) / std`` using standard (trailing-dim) broadcasting."""
        return (x - self.mean) / self.std

    def _rescale(self, x):
        """Resample the first ``upper_old`` columns of 2-D ``x`` to ``upper_new`` columns."""
        rows, _ = x.size()  # unpacking also enforces a 2-D input
        # F.interpolate expects (N, C, H, W): add singleton batch/channel dims.
        head = F.interpolate(x[None, None, :, :self.upper_old], [rows, self.upper_new],
                             mode='bicubic', align_corners=True)
        # torch.cat rather than the newer alias torch.concatenate, for older torch builds.
        return torch.cat((head[0, 0], x[..., self.upper_old:]), -1)

    def forward(self, x):
        """Normalize ``x`` and, if enabled, rescale its leading columns."""
        x = self._normalize(x)
        if self.rescale:
            x = self._rescale(x)
        return x

    def unnormalize(self, x):
        """Invert the normalization (column rescaling is NOT undone).

        Handles an unbatched input (same rank as the statistics) or a batched one
        (exactly one extra leading dim). Any other rank is returned unchanged.
        """
        if self.mean.ndim == x.ndim:
            # single sample
            return x * self.std + self.mean
        elif self.mean.ndim + 1 == x.ndim:
            # batch: add a leading axis to the statistics
            return x * self.std[None] + self.mean[None]
        else:
            # NOTE(review): unsupported rank is passed through un-normalized; confirm
            # callers never rely on this silently.
            return x

def loadMatFile(matpath, key, time_idx=None):
    """Read variable ``key`` from a MATLAB ``.mat`` file, optionally picking one entry.

    Files readable by ``scipy.io.loadmat`` are loaded with it and the variable is
    transposed. For files where loadmat raises ``NotImplementedError`` (MATLAB v7.3 /
    HDF5 files), the variable is read with h5py instead.

    When ``time_idx`` is given, entry ``[time_idx, 0]`` is selected. With loadmat it is
    taken from the transposed variable and transposed again. With h5py the entry is used
    as a reference into the same file and dereferenced.

    Args:
        matpath: path to the ``.mat`` file.
        key: name of the MATLAB variable.
        time_idx: optional row index into the variable's first column.

    Returns:
        The requested numpy data.

    Raises:
        ValueError: if the loadmat path fails for any reason other than the v7.3 case
            (missing file, missing key, bad index). The original error is chained.
            Errors from the h5py fallback propagate unchanged.
    """
    try:
        data = loadmat(matpath)[key].T
        if time_idx is not None:
            data = data[time_idx, 0].T
        return data
    except NotImplementedError:
        # scipy cannot read v7.3 (HDF5-based) files; fall back to h5py.
        with h5py.File(matpath, 'r') as data:
            out = data[key][()]
            if time_idx is not None:
                # The selected entry is an object reference into this file.
                out = data[out[time_idx, 0]][()]
            return out
    except Exception as err:
        # Surface the failure with context instead of silently returning None.
        raise ValueError(f'could not read {key!r} from {matpath!r}') from err
class OATDataModule(pl.LightningDataModule):
    """Lightning data module serving ``OATFlatEarthDataset`` splits.

    Every selected slice is split by time index: train ``[0, 1000)``, validation
    ``[1000, 1200)`` and test ``[1200, slice_end)`` using the dataset's default
    ``slice_end``.

    Args:
        data_dir: root directory passed to the datasets as ``basepath``.
        slices: slice number(s) to use; a single integer is wrapped in a list.
            Defaults to slices 1-10.
        batch_size: batch size for all dataloaders.
        ssp_transform: optional SSP transform forwarded to the datasets.
        at_transform: optional arrival-time transform forwarded to the datasets.
        num_workers: worker processes per dataloader.
        shuffle_train: whether the training dataloader shuffles. Off by default, in
            which case training samples are served in slice/time order.
    """

    def __init__(self, data_dir: str = "./",
                 slices: Optional[Union[list[int], int]] = None,
                 batch_size=32, ssp_transform=None, at_transform=None,
                 num_workers: int = 8, shuffle_train: bool = False):
        super().__init__()
        self.data_dir = data_dir
        if slices is None:
            # Fresh list per instance instead of a shared mutable default.
            slices = list(range(1, 11))
        if isinstance(slices, (int, np.integer)):
            self.slices = [int(slices)]
        else:
            # Copy so later changes to the caller's list don't leak into the datasets.
            self.slices = list(slices)
        self.batch_size = batch_size
        self.ssp_transform = ssp_transform
        self.at_transform = at_transform
        self.num_workers = num_workers
        self.shuffle_train = shuffle_train

    def prepare_data(self):
        pass

    def _make_dataset(self, **split):
        """Build an ``OATFlatEarthDataset`` over ``self.slices`` for one time-index split."""
        return OATFlatEarthDataset(self.data_dir, slice_sets=self.slices,
                                   ssp_transform=self.ssp_transform,
                                   at_transform=self.at_transform,
                                   **split)

    def setup(self, stage: str):
        """Create the datasets Lightning needs for ``stage``."""
        if stage == "fit":
            self.oat_train = self._make_dataset(slice_end=1000)
        # trainer.validate() calls setup("validate") only, so build the val split then too.
        if stage in ("fit", "validate"):
            self.oat_val = self._make_dataset(slice_start=1000, slice_end=1200)
        if stage == "test":
            self.oat_test = self._make_dataset(slice_start=1200)

    def train_dataloader(self):
        return DataLoader(self.oat_train, batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=self.shuffle_train)

    def val_dataloader(self):
        return DataLoader(self.oat_val, batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=False)

    def test_dataloader(self):
        return DataLoader(self.oat_test, batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=False)
        

class OATDataset(Dataset):
    def __init__(self, basepath, slice_start=0, slice_end=1439, slice_sets=np.arange(1, 11)):
        """Map-style dataset of (SSP, direct-path arrival time) pairs, "30L" layout.

        Samples enumerate every time index in ``[slice_start, slice_end)`` for each slice
        number in ``slice_sets``, one slice after another.

        Files read per sample (``x`` = slice number, ``t`` = 0-based time index):
            SSP: basepath/Slice{x}/30L/Sound_Speed_Profiles/Slice{x}_{t+1}.mat,
                 variable 'SSP' (author's note: files 1-1439, each 100x12)
            ATs: basepath/Slice{x}/30L/output_arrival_time_E_data/cm_tau_dir.mat,
                 entry [t, 0] of 'cm_tau_dir' dereferenced with h5py
                 (author's note: (1439, 1) -> (20, 20))
        """
        self.basepath = basepath
        self.slice_sets = slice_sets
        self.slice_start, self.slice_end = slice_start, slice_end
        # Cached counts used to map a flat index onto (slice, time index).
        self.num_slices = len(self.slice_sets)
        self.num_in_slice = slice_end - slice_start

    def __len__(self):
        per_slice = self.slice_end - self.slice_start
        return self.num_slices * per_slice

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(ssp, ats)`` as float tensors."""
        per_slice = self.num_in_slice
        slice_num = self.slice_sets[idx // per_slice]
        time_idx = self.slice_start + idx % per_slice

        ssp_file = os.path.join(
            self.basepath,
            f'Slice{slice_num}/30L/Sound_Speed_Profiles/Slice{slice_num}_{time_idx+1}.mat',
        )
        ssp_array = loadmat(ssp_file)['SSP']

        at_file = os.path.join(
            self.basepath,
            f'Slice{slice_num}/30L/output_arrival_time_E_data/cm_tau_dir.mat',
        )
        with h5py.File(at_file, 'r') as handle:
            # The selected entry is used as a reference into the same file.
            ref = handle['cm_tau_dir'][time_idx, 0]
            at_array = handle[ref][()]

        return torch.Tensor(ssp_array), torch.Tensor(at_array)
 
class OATFlatEarthDataset(OATDataset):
    def __init__(
        self,
        basepath,
        slice_start=0,
        slice_end=1439,
        slice_sets=np.arange(1, 11),
        ssp_transform=None,
        at_transform=None,
    ):
        """``OATDataset`` variant for the "flat earth" ``Acoustic_Data`` directory layout.

        Files read per sample (``x`` = slice number, ``t`` = time index):
            SSP: basepath/Slice_{x}/Acoustic_Data/cm_ssp.mat, variable 'cm_ssp', entry t,
                 cropped to ``[:11, :231]``
                 (author's note: (1439, 13, 314) -> (1439, 11, 231))
            ATs: basepath/Slice_{x}/Acoustic_Data/cm_tau_dir.mat ('cm_tau_dir') and
                 basepath/Slice_{x}/Acoustic_Data/cm_tau_sur.mat ('cm_tau_sur'), entry t
                 of each, stacked along a new leading axis in that order
                 (author's note: (1439, 1) -> (20, 20))

        Args:
            ssp_transform: optional callable applied to the SSP tensor.
            at_transform: optional callable applied to the stacked arrival-time tensor.
        """
        super().__init__(basepath, slice_start, slice_end, slice_sets)
        self.ssp_transform = ssp_transform
        self.at_transform = at_transform

    def __getitem__(self, idx):
        """Return ``(ssp, ats)`` for flat index ``idx``, applying transforms when set."""
        per_slice = self.num_in_slice
        slice_num = self.slice_sets[idx // per_slice]
        time_idx = self.slice_start + idx % per_slice

        # Sound speed profile, cropped before any transform is applied.
        ssp_file = os.path.join(self.basepath, f'Slice_{slice_num}/Acoustic_Data/cm_ssp.mat')
        raw_ssp = loadMatFile(ssp_file, 'cm_ssp', time_idx)
        ssp = torch.Tensor(raw_ssp[:11, :231])
        if self.ssp_transform:
            ssp = self.ssp_transform(ssp)

        # Direct-path and surface-path arrival times, read in that order.
        acoustic_dir = os.path.join(self.basepath, f'Slice_{slice_num}/Acoustic_Data/')
        sources = (('cm_tau_dir.mat', 'cm_tau_dir'), ('cm_tau_sur.mat', 'cm_tau_sur'))
        tau_parts = [
            loadMatFile(os.path.join(acoustic_dir, file_name), var_name, time_idx)
            for file_name, var_name in sources
        ]
        ats = torch.Tensor(np.concatenate([part[None] for part in tau_parts], 0))
        if self.at_transform:
            ats = self.at_transform(ats)
        return ssp, ats
    
class NAFlatEarthInitDataset(OATFlatEarthDataset):
    def __init__(
        self,
        init_path: Optional[str] = None,
        basepath: str = './data/',
        slice_start=0,
        slice_end=1439,
        slice_sets=np.arange(1, 11),
        ssp_transform=None,
        at_transform=None,
    ):
        """``OATFlatEarthDataset`` that additionally returns a per-sample initial SSP.

        SSPs
        basepath/Slice_{x}/Acoustic_Data/cm_ssp.mat
        (1439, 13, 314) -> (1439, 11, 231)

        ATs
        basepath/Slice_{x}/Acoustic_Data/cm_tau_dir.mat
        basepath/Slice_{x}/Acoustic_Data/cm_tau_sur.mat
        (1439, 1) -> (20, 20)

        Args:
            init_path: file written with ``torch.save`` holding one initial SSP per
                sample, indexed by the dataset's flat index. When None, all-zero
                ``(11, 231)`` initials are used (matching the SSP crop).
            basepath, slice_start, slice_end, slice_sets, ssp_transform, at_transform:
                forwarded to ``OATFlatEarthDataset``.

        Raises:
            ValueError: if the number of initials differs from the dataset length.
        """
        super().__init__(basepath, slice_start, slice_end, slice_sets, ssp_transform, at_transform)

        num_data = (slice_end - slice_start) * len(slice_sets)
        if init_path is not None:
            # map_location='cpu': tensors saved from a GPU would otherwise come back on
            # that GPU, breaking CPU-only hosts and DataLoader worker processes.
            # NOTE(review): torch.load unpickles the file -- only use trusted init files.
            self.inits = torch.load(init_path, map_location='cpu')
        else:
            self.inits = torch.zeros(num_data, 11, 231)
        # Explicit check (not assert) so it still runs under `python -O`.
        if self.inits.size(0) != num_data:
            raise ValueError(
                f"Number of inits incorrect. Expected {num_data} but got {self.inits.size(0)}"
            )

    def __getitem__(self, idx):
        """Return ``(ssp, ats, init)`` where ``init`` is ``self.inits[idx]``."""
        ssp, ats = super().__getitem__(idx)

        return ssp, ats, self.inits[idx]
class NADataModule(OATDataModule):
    """``OATDataModule`` whose splits are ``NAFlatEarthInitDataset`` instances.

    Each sample additionally carries an initial SSP loaded from ``init_path``, or zeros
    when ``init_path`` is None.
    """

    def __init__(self, data_dir: str = "./",
                 init_path: Optional[str] = None,
                 slices: Optional[Union[list[int], int]] = None,
                 batch_size=32, ssp_transform=None, at_transform=None):
        if slices is None:
            # Fresh list per instance instead of a shared mutable default.
            slices = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        super().__init__(data_dir, slices, batch_size, ssp_transform, at_transform)
        self.init_path = init_path

    def _make_init_dataset(self, **split):
        """Build a ``NAFlatEarthInitDataset`` over ``self.slices`` for one time-index split."""
        return NAFlatEarthInitDataset(self.init_path, self.data_dir,
                                      slice_sets=self.slices,
                                      ssp_transform=self.ssp_transform,
                                      at_transform=self.at_transform,
                                      **split)

    def setup(self, stage: str):
        """Create the datasets Lightning needs for ``stage``.

        NOTE(review): the same ``init_path`` is used for every split. The splits have
        different lengths and ``NAFlatEarthInitDataset`` requires the init count to match,
        so a given init file can satisfy only one split -- confirm this is intended.
        """
        if stage == "fit":
            self.oat_train = self._make_init_dataset(slice_end=1000)
        # trainer.validate() calls setup("validate") only, so build the val split then too.
        if stage in ("fit", "validate"):
            self.oat_val = self._make_init_dataset(slice_start=1000, slice_end=1200)
        if stage == "test":
            self.oat_test = self._make_init_dataset(slice_start=1200)
