import torch
from torch_geometric.datasets import TUDataset
import os
from utils import *
from torch_geometric.data import InMemoryDataset
from dataset_ogb import PygGraphPropPredDataset
from torch_geometric import transforms as T
from torch_geometric.utils import degree
from utils import *


# Dataset names that load_dataset routes to get_TUdataset.
TUData = [
    "PROTEINS",
    "IMDB-BINARY",
    "REDDIT-BINARY",
    "COLLAB",
    "NCI1",
    "NCI109",
    "MUTAG",
    "DD",
    "PTC_MR",
    "REDDIT-MULTI-5K",
]
# Dataset names that load_dataset routes to get_OGBdataset.
OGB_Data = [
    'ogbg-molhiv',
    'ogbg-molpcba',
]



def file_initialize():
    """Ensure the ``./tmp`` working directory exists.

    The directory is created when missing and left untouched when it is
    already there. ``exist_ok=True`` is used instead of a separate existence
    check so that a directory created by another process between the check
    and the creation does not raise ``FileExistsError``.

    Returns:
        None.
    """
    os.makedirs("./tmp", exist_ok=True)

def load_dataset(dataset_name, model_name, shuffle=False):
    """Load a dataset by name, dispatching on the dataset family.

    Args:
        dataset_name: Name of the dataset; must appear in ``TUData`` or
            ``OGB_Data``.
        model_name: Name of the model the dataset is prepared for; forwarded
            unchanged to the family-specific loader.
        shuffle: Forwarded as the ``shuffle`` keyword of the family-specific
            loader.

    Returns:
        The dataset object built by ``get_TUdataset`` or ``get_OGBdataset``.

    Raises:
        ValueError: If ``dataset_name`` is in neither ``TUData`` nor
            ``OGB_Data``. ``ValueError`` is a subclass of ``Exception``, so
            callers catching ``Exception`` keep working.
    """
    # The working directory is created before the name is validated, exactly
    # as before, so the on-disk side effect does not depend on the outcome.
    file_initialize()

    if dataset_name in TUData:
        return get_TUdataset(dataset_name, model_name, shuffle=shuffle)
    if dataset_name in OGB_Data:
        return get_OGBdataset(dataset_name, model_name, shuffle=shuffle)

    raise ValueError(
        "Error in load_dataset: unknown dataset {!r}; expected one of {}".format(
            dataset_name, TUData + OGB_Data
        )
    )


def get_OGBdataset(name, model_name, sparse=True, cleaned=False, normalize=False,Permutation=None,index=None, shuffle=True, pre_transform=None):
    """Build a ``PygGraphPropPredDataset`` rooted at ``./tmp``.

    Args:
        name: Dataset name passed to ``PygGraphPropPredDataset``.
        model_name: When equal to ``"GPS"``, a random-walk positional encoding
            (``walk_length=20``, stored under ``'pe'``) is used as the
            pre-transform and the ``pre_transform`` argument is ignored.
        normalize: When true, ``dataset.data.x`` is centred and scaled in place
            with its per-column mean and standard deviation.
        pre_transform: Pre-transform used for every model other than ``"GPS"``.
        sparse, cleaned, Permutation, index, shuffle: Accepted but not read by
            this function.

    Returns:
        The constructed dataset.
    """
    # NOTE(review): unlike get_TUdataset, the GPS and non-GPS variants share
    # the same root directory here -- confirm that cached processed files are
    # not reused across different pre-transforms.
    if model_name == "GPS":
        chosen_pre_transform = T.AddRandomWalkPE(walk_length=20, attr_name='pe')
    else:
        chosen_pre_transform = pre_transform

    dataset = PygGraphPropPredDataset(
        name,
        root=os.path.join('./tmp'),
        pre_transform=chosen_pre_transform,
        skip_collate=False,
    )

    if not normalize:
        return dataset

    # NOTE(review): presumably x is a floating-point tensor at this point;
    # confirm before enabling normalize for a new dataset.
    dataset.data.x -= torch.mean(dataset.data.x, axis=0)
    dataset.data.x /= torch.std(dataset.data.x, axis=0)
    return dataset
    

def get_TUdataset(name, model_name, sparse=True, cleaned=False, normalize=False,Permutation=None,index=None,shuffle=True,pre_transform=None):
    """Build a ``TUDataset`` and attach the transforms the models expect.

    Steps, in order: construct the dataset, drop ``edge_attr``, optionally
    wrap it in ``ShuffleDataset``, give feature-less datasets degree-based
    node features (or standardise existing features when ``normalize`` is
    set), optionally append a dense conversion, and finally shuffle the two
    NCI datasets.

    Args:
        name: Dataset name passed to ``TUDataset``.
        model_name: When equal to ``"GPS"``, the dataset is rooted at
            ``./tmp/pe`` and pre-transformed with a random-walk positional
            encoding (``walk_length=20``, stored under ``'pe'``); the
            ``pre_transform`` argument is ignored in that case. Otherwise the
            root is ``./tmp``.
        sparse: When false, ``T.ToDense`` sized to the largest graph is
            appended to the dataset transform.
        cleaned: Forwarded to ``TUDataset``.
        normalize: When true and node features exist, ``dataset.data.x`` is
            centred and scaled in place per column.
        shuffle: When true, the dataset is wrapped in ``ShuffleDataset``.
        pre_transform: Pre-transform used for every model other than ``"GPS"``.
        Permutation, index: Accepted but not read by this function.

    Returns:
        The prepared dataset.
    """
    if model_name == "GPS":
        root = os.path.join('./tmp/pe')
        chosen_pre_transform = T.AddRandomWalkPE(walk_length=20, attr_name='pe')
    else:
        root = os.path.join('./tmp')
        chosen_pre_transform = pre_transform
    dataset = TUDataset(root, name, use_node_attr=True, pre_transform=chosen_pre_transform, cleaned=cleaned)

    dataset.data.edge_attr = None

    if shuffle:
        dataset = ShuffleDataset(dataset)

    if dataset.data.x is None:
        # No node features: derive them from node degrees.
        degree_vectors = []
        max_degree = 0
        for graph in dataset:
            node_degrees = degree(graph.edge_index[0], dtype=torch.long)
            degree_vectors.append(node_degrees)
            max_degree = max(max_degree, node_degrees.max().item())

        if max_degree >= 1000:
            # Large degrees: use a normalised degree value built from the
            # overall mean/std instead of a one-hot encoding.
            # NormalizedDegree is not defined in this file (presumably it
            # comes from the `utils` star import).
            all_degrees = torch.cat(degree_vectors, dim=0).to(torch.float)
            mean = all_degrees.mean().item()
            std = all_degrees.std().item()
            dataset.transform = NormalizedDegree(mean, std)
        else:
            dataset.transform = T.OneHotDegree(max_degree)
    elif normalize:
        dataset.data.x -= torch.mean(dataset.data.x, axis=0)
        dataset.data.x /= torch.std(dataset.data.x, axis=0)

    if not sparse:
        largest_graph = 0
        for graph in dataset:
            largest_graph = max(graph.num_nodes, largest_graph)

        to_dense = T.ToDense(largest_graph)
        existing_transform = dataset.transform
        if existing_transform is None:
            dataset.transform = to_dense
        else:
            dataset.transform = T.Compose([existing_transform, to_dense])

    # The two NCI datasets are shuffled regardless of the `shuffle` flag.
    if name in ('NCI1', 'NCI109'):
        dataset = dataset.shuffle()
    return dataset


class ShuffleDataset(InMemoryDataset):
    """In-memory dataset holding the graphs of another dataset in random order.

    The wrapped dataset is kept on ``self.tu_dataset`` and its ``name`` is
    copied to ``self.name``. The permuted graphs are collated into
    ``self.data`` / ``self.slices`` at construction time.
    """

    def __init__(self, tu_dataset):
        """Wrap *tu_dataset* and immediately build the shuffled copy."""
        super().__init__('.', None, None)
        self.tu_dataset = tu_dataset
        self.name = tu_dataset.name
        collated = self.process_dataset()
        self.data, self.slices = collated

    def process_dataset(self):
        """Return the collated result of the wrapped graphs in a random order.

        The order comes from ``torch.randperm`` over the wrapped dataset's
        length, so it depends on the global torch RNG state.
        """
        order = torch.randperm(len(self.tu_dataset))
        shuffled_graphs = []
        for position in order:
            shuffled_graphs.append(self.tu_dataset[position])
        return self.collate(shuffled_graphs)



def load_splited_dataset(folder):
    """Load and return the object stored in ``data.pt`` inside *folder*.

    Args:
        folder: Directory that contains a ``data.pt`` file.

    Returns:
        Whatever ``torch.load`` returns for that file.
    """
    data_path = os.path.join(folder, 'data.pt')
    return torch.load(data_path)

def load_datasets(root_dir='data_splits'):
    """Load the train, validation and test splits stored under *root_dir*.

    Each split is read from ``<root_dir>/<split>/data.pt`` through
    ``load_splited_dataset``. The previous implementation called
    ``load_dataset`` with a single path argument, which cannot succeed because
    that function also requires ``model_name`` and expects a dataset name
    rather than a folder.

    Args:
        root_dir: Directory containing the ``train``, ``val`` and ``test``
            sub-folders.

    Returns:
        A ``(train_dataset, val_dataset, test_dataset)`` tuple.
    """
    train_dataset = load_splited_dataset(os.path.join(root_dir, 'train'))
    val_dataset = load_splited_dataset(os.path.join(root_dir, 'val'))
    test_dataset = load_splited_dataset(os.path.join(root_dir, 'test'))
    return train_dataset, val_dataset, test_dataset

