import argparse
import yaml
import wandb
from dataset import load_dataset,TUData,OGB_Data
from train  import train_model
from utils import seed_everything
import os
import torch
from train_dist  import data_loader
from ogb.graphproppred import Evaluator

# Command-line options selecting the model, dataset, device index and seed.
parser = argparse.ArgumentParser(description='  ')

parser.add_argument('--model', default='GMT', type=str, help='train the XX model')
parser.add_argument('--dataset', default='PROTEINS', type=str, help='on which dataset')
parser.add_argument('--device', required=False, default=7, type=int, help='Device Number')
parser.add_argument("--seed", type=int, default=1234, help="random seed (default: 1234)")
args = parser.parse_args()

# Seed once at startup, before the dataset is loaded/shuffled.
# (utils.seed_everything presumably seeds Python/NumPy/torch RNGs — confirm there.)
seed_everything(args.seed)

# Per-dataset hyperparameters; the main block reads config[args.dataset]
# (e.g. its 'batch_size' entry).
config_path = "./config/example.yaml"
# Context manager closes the file deterministically instead of leaking the
# handle until garbage collection.
with open(config_path, 'r') as config_file:
    config = yaml.safe_load(config_file)

# Number of training runs performed by the main loop.
repeat_time = 10

# When True, the shuffled dataset is saved to ./tmp/<dataset>/<model>/splited.pt.
save_data = False

if __name__ == '__main__':
    # Load (with shuffling) the dataset and build the loaders shared by all runs.
    dataset = load_dataset(args.dataset, args.model, shuffle=True)
    dataloaders = data_loader(args.dataset, dataset, config[args.dataset]['batch_size'])

    if save_data:
        folder = os.path.join("./tmp", args.dataset, args.model)
        # exist_ok=True replaces the racy exists()-then-makedirs() check.
        os.makedirs(folder, exist_ok=True)
        torch.save(dataset, os.path.join(folder, 'splited.pt'))
        print("Shuffled Data saved to ", folder)

    # Datasets listed in TUData get no OGB evaluator/task type; all others use
    # an OGB Evaluator and the dataset's task_type. Neither depends on the
    # repeat index, so build them once instead of on every iteration.
    is_tu_dataset = args.dataset in TUData
    evaluator = None if is_tu_dataset else Evaluator(args.dataset)
    task_type = None if is_tu_dataset else dataset.task_type

    for _run in range(repeat_time):
        # NOTE(review): every repeat receives the same seed; if train_model
        # re-seeds with it, all repeats may be identical — confirm intent.
        train_model(args.model,
                    dataset,
                    dataloaders=dataloaders,
                    config=config[args.dataset],
                    device=args.device,
                    save_model=True,
                    patience=30,
                    min_delta=0.005,
                    seed=args.seed,
                    evaluator=evaluator,
                    task_type=task_type,
                    )
        
      

