import pytorch_lightning as pl 
import torch
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import TensorBoardLogger
import wandb
from attr import evolve
from pytorch_lightning.loggers import TensorBoardLogger
from utils import parse_args
import os

from pl_modules import SI_Linear_Module
from hparams import SI_Linear_HParams_Dict as config_dict
from utils import LR_WD_Scheduler,LR_WD_Logger,Norm_Logger,Model_logger,Manifold_Logger,DEVICE

# import multiprocessing as mp

def main(args):
    """Train an ``SI_Linear_Module`` for one named config and log the run.

    Creates a wandb run, seeds all RNGs from the run id, builds the module,
    a TensorBoard logger and the project callbacks, then fits with Lightning.

    Args:
        args: Parsed command-line namespace. Reads ``args.config`` (a key of
            ``SI_Linear_HParams_Dict``) and ``args.suffix`` (appended to the
            run name as ``<config>_<suffix>`` when non-empty).

    Returns:
        The fitted ``SI_Linear_Module`` instance.
    """
    import zlib  # only needed for the stable run-id -> seed mapping below

    config_name = args.config
    suffix = args.suffix
    # Treat both "" and None as "no suffix".
    log_name = config_name + '_' + suffix if suffix else config_name

    config = config_dict[config_name]
    run = wandb.init(
        project="si_linear",
        name=log_name,
        sync_tensorboard=True,
        reinit=True,
        entity='si-limit-diffusion',
        save_code=True,
        config=config.to_dict(),
    )
    # Derive the global seed from the wandb run id. The built-in hash() of a
    # str is salted per interpreter (PYTHONHASHSEED), so it would yield a seed
    # that cannot be recomputed later; CRC32 is stable across processes.
    seed = zlib.crc32(run.id.encode("utf-8")) % 10000007
    seed_everything(seed)

    method = SI_Linear_Module(config)
    logger = TensorBoardLogger("tb_logs", name=f"si_linear/{log_name}")
    logger.log_hyperparams(config.to_dict())

    log_freq = config.check_val_every_n_epoch
    callbacks = [
        LR_WD_Scheduler(epoch_wise=False),
        LR_WD_Logger(epoch_wise=False),
        Norm_Logger(layer_wise=True, epoch_wise=False, freq=log_freq),
        # Given this run's local wandb directory as its output path; the unit
        # of freq=100 (steps vs. epochs) is defined in utils.Model_logger.
        Model_logger(path=run.dir, freq=100),
        Manifold_Logger(layer_wise=True, epoch_wise=False, freq=log_freq),
    ]
    print(DEVICE)
    print(config.check_val_every_n_epoch, config.max_steps)

    use_cuda = torch.cuda.is_available()
    # NOTE(review): with strategy='ddp' Lightning may re-launch this script per
    # GPU; if so, every rank runs main() and opens its own wandb run (and seed).
    # Confirm whether that is intended.
    trainer = pl.Trainer(
        gpus=-1 if use_cuda else 0,
        max_steps=config.max_steps,
        logger=logger,
        callbacks=callbacks,
        accelerator='gpu' if use_cuda else 'cpu',
        strategy='ddp' if use_cuda else None,
        deterministic=True,
        log_every_n_steps=1,
        # NOTE(review): this Trainer argument was deprecated in Lightning 1.5
        # and removed in 1.7; pin the Lightning version or drop it on upgrade.
        flush_logs_every_n_steps=config.check_val_every_n_epoch,
        check_val_every_n_epoch=config.check_val_every_n_epoch,
    )
    trainer.fit(method)
    run.finish()
    return method
    
if __name__ == '__main__':
    # Script entry point: parse CLI arguments and launch a single training run.
    main(parse_args())