import datetime
import os
import sys
import time

import numpy as np

import torch
import torch.distributed as dist
import torchvision

from warmup_scheduler_local.scheduler import GradualWarmupScheduler
from sklearn.metrics.cluster import normalized_mutual_info_score, adjusted_mutual_info_score
from scipy.stats import entropy
import utils
from utils import print_or_log
from SLL_utils.cluster_cross_cpu import get_cluster_assignments
from SLL_utils.cluster_cross_gpu import get_cluster_assignments_gpu

try:
    from apex import amp
except ImportError:
    amp = None

# global variables
# sk_schedule: built in main() as a list of iteration counts; train_one_epoch()
# compares the running batch count against its last element and pops it when a
# clustering step is triggered.
sk_schedule = None
# group: assigned in main() from utils.init_distributed_mode(); forwarded to
# cluster() by train_one_epoch().
group = None
# sk_counter: incremented by cluster() on every call; main() re-derives it when
# training starts from a non-zero epoch.
sk_counter = 0

def make_scheduler(args, optimizer, milestones,  logger):
    """Build the learning-rate scheduler selected by ``args``.

    The base schedule is a ``MultiStepLR`` (using ``milestones`` and
    ``args.lr_gamma``) when ``args.scheduler_type == 'multi_step'``, otherwise
    a ``CosineAnnealingLR`` over ``args.epochs``. When
    ``args.lr_warmup_epochs > 0`` the base schedule is wrapped in a
    ``GradualWarmupScheduler`` with ``multiplier=args.world_size``.

    Args:
        args: namespace providing ``lr_warmup_epochs``, ``scheduler_type``,
            ``lr_gamma``, ``epochs`` and ``world_size``.
        optimizer: optimizer whose learning rate is scheduled.
        milestones: epoch indices passed to ``MultiStepLR``.
        logger: forwarded to ``print_or_log``.

    Returns:
        The scheduler object (warm-up wrapper or bare base scheduler).
    """
    warmup_epochs = args.lr_warmup_epochs
    lr_sched = torch.optim.lr_scheduler

    # Choose and announce the base schedule first; it is the same object
    # regardless of whether a warm-up phase is put in front of it.
    if args.scheduler_type == 'multi_step':
        print_or_log(f'Using Multi-Step LR scheduler', logger=logger)
        base_scheduler = lr_sched.MultiStepLR(
            optimizer,
            milestones=milestones,
            gamma=args.lr_gamma
        )
    else:
        print_or_log(f'Using Cosine Annealing LR scheduler', logger=logger)
        base_scheduler = lr_sched.CosineAnnealingLR(optimizer, args.epochs)

    if warmup_epochs > 0:
        return GradualWarmupScheduler(
            optimizer,
            multiplier=args.world_size,
            total_epoch=warmup_epochs,
            after_scheduler=base_scheduler
        )
    return base_scheduler

def get_loss(activations, targets, headcount=None):
    """Return the cross-entropy loss, averaged over classification heads.

    Args:
        activations: logits. For a single head this is passed straight to
            ``cross_entropy``; for several heads it is indexed as
            ``activations[h]`` to obtain the logits of head ``h``.
        targets: class indices. For several heads, column ``h``
            (``targets[:, h]``) holds the labels of head ``h``.
        headcount: number of heads. When omitted, the module-level
            ``args.headcount`` is used if the script has defined ``args``
            (the previous, implicit behaviour); if no such global exists the
            count is inferred: 1 for a tensor, ``len(activations)`` otherwise.

    Returns:
        Scalar loss tensor.
    """
    if headcount is None:
        # Previously this function read a bare global `args`, which raised
        # NameError whenever the module was imported rather than run as a
        # script. Look it up defensively and fall back to inference.
        script_args = globals().get('args')
        if script_args is not None:
            headcount = script_args.headcount
        elif torch.is_tensor(activations):
            headcount = 1
        else:
            headcount = len(activations)

    if headcount == 1:
        return torch.nn.functional.cross_entropy(activations, targets)

    per_head = [
        torch.nn.functional.cross_entropy(activations[h], targets[:, h])
        for h in range(headcount)
    ]
    return torch.mean(torch.stack(per_head))


def cluster(args, selflabels, dataset, model, logger, writer, group, iter_num):
    """Recompute the pseudo-labels and log clustering-quality statistics.

    New labels are obtained from ``get_cluster_assignments_gpu`` or
    ``get_cluster_assignments`` (chosen by ``args.gpu_sk``) under
    ``torch.no_grad()``. Only the first label column (``[:, 0]``) is used for
    the statistics computed here.

    Args:
        args: namespace; reads ``gpu_sk``, ``global_rank``, ``distributed``
            and ``world_size``.
        selflabels: current label tensor; cloned for comparison with the new
            assignment.
        dataset: must expose ``_labels`` and ``valid_indices``.
        model: forwarded to the assignment helper.
        logger: forwarded to ``print_or_log`` and the assignment helper.
        writer: optional summary writer (``add_scalar`` / ``add_histogram``);
            skipped when falsy.
        group: optional process group used for the final barrier.
        iter_num: step value attached to every written summary.

    Returns:
        The newly computed label tensor.
    """
    selflabels_old = selflabels.clone()
    # cluster
    with torch.no_grad():
        if args.gpu_sk:
            selflabels = get_cluster_assignments_gpu(args, dataset, model, logger, writer, group, iter_num)
        else:
            selflabels = get_cluster_assignments(args, dataset, model, logger, writer, group, iter_num)
    self_labels_np  = selflabels[:, 0].cpu().numpy()
    # Module-level count of how many clustering steps have been run so far.
    global sk_counter
    sk_counter += 1
    # NOTE(review): this None check cannot protect anything -- `selflabels`
    # was already indexed two statements above, so a None result would have
    # raised there. Confirm whether the helpers can return None.
    if selflabels is not None:
        ## VIDEO
        # Agreement between the new and the previous assignment (first column).
        nmi_v = normalized_mutual_info_score(
            self_labels_np,
            selflabels_old[:,0].cpu().numpy(),
            average_method='arithmetic'
        )
        if args.global_rank == 0:
            print_or_log(f'NMI_v: {nmi_v}', logger=logger)
        if writer:
            writer.add_scalar(
                f'train/nmi_v/iter',
                nmi_v,
                iter_num
            )
            writer.add_scalar(
                f'train/optim_count/iter',
                sk_counter,
                iter_num
            )

    # Compare the new assignment against the dataset's labels, restricted to
    # the entries selected by `dataset.valid_indices`.
    true_labels = np.array(dataset._labels)[dataset.valid_indices]
    nmi_to_labels_v = normalized_mutual_info_score(
        self_labels_np,
        true_labels,
        average_method='arithmetic'
    )
    # Despite the "anmi" name this is sklearn's adjusted mutual information.
    anmi_to_labels_v = adjusted_mutual_info_score(
        self_labels_np,
        true_labels,
        average_method='arithmetic'
    )
    print_or_log(f'NMI-tolabels: {nmi_to_labels_v}   aNMI-tolabels: {anmi_to_labels_v}', logger=logger)
    if writer:
        # Video
        writer.add_scalar(
            f'train/nmi-tolabels_v/iter',
            nmi_to_labels_v,
            iter_num
        )
        writer.add_scalar(
            f'train/a-nmi-tolabels_v/iter',
            anmi_to_labels_v,
            iter_num
        )
    # Every 10th clustering step: per-cluster purity (share of the most common
    # dataset label) and entropy of the dataset-label distribution.
    if sk_counter % 10 == 0:
        entropies = []
        purities = []
        for sk_label in np.unique(self_labels_np):
            of_this_cluster = self_labels_np == sk_label
            size = of_this_cluster.sum()
            # NOTE(review): labels come from np.unique of the same array, so
            # `size` should never be 0 here.
            if size != 0:
                uniq, counts = np.unique(true_labels[of_this_cluster], return_counts=True)
                purities.append(max(counts)/sum(1.0*counts))
                entropies.append(entropy(counts/sum(1.0*counts)))
        print_or_log(f'Avg entropy: {np.mean(entropies)}   avg purity: {np.mean(purities)}', logger=logger)
        if writer:
            writer.add_histogram(
                'train/entropies',
                np.array(entropies),
                iter_num
            )
            writer.add_histogram(
                'train/purities',
                np.array(purities),
                iter_num
            )
            writer.add_scalar(
                'train/avg-entropy',
                np.mean(entropies),
                iter_num
            )
            writer.add_scalar(
                'train/avg-purity',
                np.mean(purities),
                iter_num
            )
    # Ensure processes reach to end of optim clusters
    # (only synchronised when distributed with more than 8 processes)
    if args.distributed and args.world_size > 8:
        if group is not None:
            dist.barrier(group=group)
        else:
            dist.barrier()
    return selflabels

def train_one_epoch(
        args,
        data_loader,
        model,
        selflabels,
        optimizer,
        device,
        epoch,
        print_freq,
        apex=False,
        logger=None,
        writer=None,
):
    """Train ``model`` for one epoch against the current pseudo-labels.

    Before each batch the global iteration count is compared with the last
    entry of the module-level ``sk_schedule``; when it has been reached the
    entry is popped and ``cluster`` is called to refresh ``selflabels``.
    The loss is the mean of the video and audio cross-entropy losses.

    Args:
        args: namespace; reads ``headcount``, ``world_size``, ``distributed``,
            ``global_rank`` and ``output_dir``; ``resume`` is set to
            ``'True'`` when a requeue signal is seen.
        data_loader: yields 5-tuples ``(video, audio, _, _, selected)`` where
            ``selected`` indexes rows of ``selflabels``.
        model: called as ``model(video, audio)``, returns video and audio
            activations.
        selflabels: label tensor indexed by ``selected``.
        optimizer: optimizer to step.
        device: device the inputs are moved to.
        epoch: current epoch index (used for logging and iteration count).
        print_freq: logging frequency passed to the metric logger.
        apex: when true, the backward pass goes through ``amp.scale_loss``.
        logger: forwarded to ``print_or_log`` and the metric logger.
        writer: optional summary writer forwarded to logging helpers.

    Returns:
        Tuple ``(average loss, selflabels)``; ``selflabels`` is the possibly
        refreshed label tensor.
    """
    global sk_schedule

    model.train()
    metric_logger = utils.MetricLoggerSLLX(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
    metric_logger.add_meter('clips/s', utils.SmoothedValue(window_size=10, fmt='{value:.3f}'))

    header = 'Epoch: [{}]'.format(epoch)
    batches_thusfar = epoch * len(data_loader)
    single_head = args.headcount == 1
    for batch_idx, batch in metric_logger.log_every(data_loader, print_freq, header, logger, writer, 'train',
                                                    epoch=epoch):
        video, audio, _, _, selected = batch
        if batch_idx == 0:
            print_or_log((video.shape, audio.shape), logger=logger)

        # Occasional clustering via Sinkhorn-Knopp ###############################
        if batches_thusfar + batch_idx >= sk_schedule[-1]:
            ############ optimize labels #########################################
            print_or_log('Optimizaton starting', logger=logger)
            with torch.no_grad():
                sk_schedule.pop()
                # The summary step is expressed in samples; with more than 8
                # processes the per-process batch size is scaled by world size.
                bs = data_loader.batch_size*args.world_size if args.world_size > 8 else data_loader.batch_size
                selflabels = cluster(args, selflabels, data_loader.dataset, model, logger, writer, group,
                                     iter_num=(batches_thusfar + batch_idx)*bs)

        # Cross-Entropy training of CNN #########################################################
        start_time = time.time()
        video, audio = video.to(device), audio.to(device)

        # Get activations
        feat_v, feat_a = model(video, audio)
        if batch_idx == 0:
            if single_head:
                print_or_log((feat_v.shape, feat_a.shape), logger=logger)
            else:
                print_or_log((len(feat_v), feat_v[0].shape, feat_a[0].shape), logger=logger)

        # Both modalities are trained against the same targets, so slice once.
        targets = selflabels[selected, 0] if single_head else selflabels[selected, :]
        loss_vid = get_loss(feat_v, targets)
        loss_aud = get_loss(feat_a, targets)
        loss = 0.5 * loss_vid + 0.5 * loss_aud

        # Backward pass             ############################################################
        optimizer.zero_grad()
        if apex:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optimizer.step()

        # signal received, relaunch experiment
        # `.get` avoids a KeyError when the variable was never exported.
        if os.environ.get('SIGNAL_RECEIVED') == 'True':
            args.resume = 'True'
            if args.global_rank == 0:
                print_or_log("Beginning reqeue", logger=logger)
                utils.trigger_job_requeue(os.path.join(args.output_dir, 'checkpoints', 'checkpoint.pth'))

        batch_size = video.shape[0]
        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[-1]["lr"])
        elapsed = time.time() - start_time
        metric_logger.meters['batch_t/s'].update(elapsed)
        metric_logger.meters['clips/s'].update(batch_size / elapsed)
    if args.distributed:
        dist.barrier()
    torch.cuda.empty_cache()
    return metric_logger.loss.avg, selflabels


def main(args):
    """Run the complete training job configured by ``args``.

    Steps, in order: validate the apex setup, create the output directory,
    initialise distributed mode / signal handler / logger / tensorboard,
    build the model and load encoder weights from ``args.ckpt_path``, wrap
    the model (``DataParallel`` for ``world_size <= 8``, otherwise
    ``DistributedDataParallel`` when distributed), build the clustering
    schedule, optimizer and LR scheduler, then train for
    ``args.start_epoch .. args.epochs - 1`` and save a checkpoint per epoch.

    Mutates ``args`` (``output_dir``, ``mlptype``, ``aug_audio``) and the
    module-level ``group``, ``sk_schedule`` and ``sk_counter``.

    Args:
        args: namespace as produced by ``parse_args``.

    Raises:
        RuntimeError: if ``args.apex`` is set but apex is unusable.
    """
    global group, sk_schedule, sk_counter

    # Set up mixed precision training
    if args.apex:
        if sys.version_info < (3, 0):
            raise RuntimeError("Apex currently only supports Python 3. Aborting.")
        if amp is None:
            raise RuntimeError(
                "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
                "to enable mixed-precision training."
            )

    # Make output dir
    if args.model_name is None:
        model_name = f'av_GDT_{args.vid_base_arch}_{args.aud_base_arch}_epochs_{args.epochs}_bsz_{args.batch_size}_optim_SGD_lr_{args.lr}_scheduler_{args.use_scheduler}'
    else:
        model_name = args.model_name
    args.output_dir = os.path.join(args.output_dir, model_name)
    if args.output_dir:
        utils.mkdir(args.output_dir)

    # Init distributed mode
    if torch.cuda.is_available():
        group = utils.init_distributed_mode(args, make_communication_groups=True)
        if args.distributed:
            group = group[0]

    # init signal handler
    utils.init_signal_handler()

    # Set up logger
    logger = None
    if args.distributed:
        filename = str(args.job_id) + '_' + str(args.global_rank) + '_log.out'
        logger = utils.setup_logger(
            "Video_reader, classification",
            args.output_dir,
            True,
            logname=filename
        )

    # Set up tensorboard (only on the master process)
    tbx_path = os.path.join(args.output_dir, 'tensorboard')
    global_rank = args.global_rank if args.distributed else 0
    is_master = global_rank == 0
    writer = None
    if is_master:
        writer = utils.setup_tbx(
            tbx_path,
            is_master
        )
        writer.add_text("namespace", repr(args))

    # Log version information
    print_or_log(args, logger=logger)
    print_or_log(f"torch version: {torch.__version__}", logger=logger)
    print_or_log(f"torchvision version: {torchvision.__version__}", logger=logger)

    # Set distributed mode
    device = torch.device(args.device)
    if args.world_size <= 8:
        device = 'cuda:0'

    # Set CudNN benchmark
    torch.backends.cudnn.benchmark = True

    # Create model
    print_or_log("Creating model", logger=logger)
    if args.use_mlp:
        # The MLP head variant is tied to whether sync batch-norm is used.
        args.mlptype = 1 if args.sync_bn else 0

    model = utils.load_model(
        model_name=args.model,
        vid_base_arch=args.vid_base_arch,
        aud_base_arch=args.aud_base_arch,
        pretrained=args.pretrained,
        norm_feat=args.norm_feat,
        use_mlp=args.use_mlp,
        mlptype=args.mlptype,
        headcount=args.headcount,
        num_classes=args.num_clusters,
    )
    model.to(device)
    ckpt_dict = torch.load(args.ckpt_path)
    model_weights = ckpt_dict["model"]
    utils.load_model_parameters(model, model_weights, only_encoder=True, on_ddp=True)

    if args.world_size <= 8:
        print_or_log(f'World size: {args.world_size}, Loading data parallel', logger=logger)
        #args.batch_size = args.batch_size * args.world_size
        #args.lr = args.lr * args.world_size
        print_or_log(f'LR: {args.lr}, Batch-size: {args.batch_size}', logger=logger)
        model_without_ddp = model
        model = torch.nn.DataParallel(model)
    else:
        if args.distributed and args.sync_bn:
            print_or_log("Sync BN on model", logger=logger)
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

        model_without_ddp = model
        if args.distributed:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                broadcast_buffers=False,
                # Was `True if args.splitnow else True`, i.e. always True.
                find_unused_parameters=True
            )

    # Replace the boolean flag by the augmentation parameter list; with
    # audio_augtype 'na' the flag is left as it is.
    if args.aug_audio:
        if args.audio_augtype == 'mild':
            args.aug_audio = [1, 1, 2, 5]
        elif args.audio_augtype == 'medium':
            args.aug_audio = [1, 1, 3, 6]
        elif args.audio_augtype == 'heavy':
            args.aug_audio = [2, 2, 3, 6]

    # SK clustering inits
    dataset, data_loader = utils.get_dataloader(args, 0)
    N_dl = len(data_loader)
    N = len(dataset)

    selflabels = torch.zeros((N, args.headcount), dtype=torch.long, device='cuda')

    # Iteration counts at which clustering is triggered, in descending order
    # so that train_one_epoch can pop the next one from the end.
    sk_schedule =  (args.epochs*N_dl*(np.linspace(0, 1, args.nopts)**args.schedulepower)[::-1]).tolist()
    sk_schedule =  [(args.epochs+2)*N_dl]+ sk_schedule # to make sure we don't make it empty
    print_or_log(f'remaining SK opts @ epochs {[np.round(1.0*t/N_dl, 2) for t in sk_schedule]}', logger=logger)

    # Set up training optimizer: parameters whose name contains 'mlp' use the
    # full learning rate, all others a tenth of it. Both are scaled by world
    # size only for >8 processes without LR warm-up.
    scale_by_world = args.world_size > 8 and args.lr_warmup_epochs == 0
    params = []
    for name, param in model_without_ddp.named_parameters():
        if 'mlp' in name:
            print(name, param.shape)
            head_lr  = args.lr * args.world_size if scale_by_world else args.lr
            params.append({'params': param, 'lr': head_lr})
        else:
            base_lr  = (args.lr / 10) * args.world_size if scale_by_world else args.lr / 10
            params.append({'params': param, 'lr': base_lr})

    optimizer = torch.optim.SGD(
        params,
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )

    # For Mixed Precision training
    if args.apex:
        model, optimizer = amp.initialize(
            model,
            optimizer,
            opt_level=args.apex_opt_level
        )

    # Set up LR scheduler
    milestones = [int(lr) - args.lr_warmup_epochs for lr in args.lr_milestones.split(',')]
    lr_scheduler = None
    if args.use_scheduler:
        lr_scheduler = make_scheduler(args, optimizer, milestones,  logger)


    # if args.splitnow:
    #     args.splitnow = False # don't do it again
    #     if args.use_mlp and args.mlp_type in [0, 1]:
    #         model.module.mlp_a.block_forward._modules['6'] = utils.split_head(list(model.module.mlp_a.modules())[-1])
    #         model.module.mlp_v.block_forward._modules['6'] = utils.split_head(list(model.module.mlp_v.modules())[-1])
    #         optimizer.add_param_group(
    #             {'params': [p for p in model.module.mlp_a.block_forward._modules['6'].parameters()]}
    #         )
    #         optimizer.add_param_group(
    #             {'params': [p for p in model.module.mlp_v.block_forward._modules['6'].parameters()]}
    #         )
    #     else:
    #         print(model.module.mlp_a)
    #         model.module.mlp_a = utils.split_head(model.module.mlp_a) # just  a linear layer
    #         model.module.mlp_v = utils.split_head(model.module.mlp_v)
    #         optimizer.add_param_group(
    #             {'params': [p for p in model.module.mlp_a.parameters()]}
    #         )
    #         optimizer.add_param_group(
    #             {'params': [p for p in model.module.mlp_v.parameters()]}
    #         )
    #     lr_scheduler = make_scheduler(args, optimizer, milestones,  logger)
    #     print_or_log(f"splitted the heads. Heads ar enow: {model.module.mlp_v}, {model.module.mlp_a}", logger=logger)

    if args.start_epoch != 0:
        # Keep only the clustering steps that lie after the starting epoch.
        include = [(qq / N_dl > args.start_epoch) for qq in sk_schedule]
        sk_counter = len(sk_schedule) - sum(include) # i.e. (total number of sk-opts) - (number of sk-opts outstanding)
        sk_schedule = (np.array(sk_schedule)[include]).tolist()
        print_or_log(f'remaining SK opts @ epochs {[np.round(1.0*t/N_dl, 2) for t in sk_schedule]}', logger=logger)
        if args.use_scheduler:
            # Fast-forward the scheduler to the starting epoch. This used to
            # read `to_restore['epoch']`, a name that is never defined here.
            for _ in range(args.start_epoch):
                lr_scheduler.step()

    # Set LR if temporal finetuning (1 node)
    if args.world_size <= 8:
        print_or_log(f'World size: {args.world_size}, lr: {args.lr}', logger=logger)
        lr = args.lr
        for pg in optimizer.param_groups:
            pg['lr'] = lr

    # Load dataloader
    print("Creating data loaders", flush=True)
    train_sampler = None
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=None,
        drop_last=True
    )

    # Warmup BN (disabled)
    # if args.start_epoch == 0:
    #     if args.distributed:
    #         train_sampler.set_epoch(999)
    #     utils._warmup_batchnorm(args, model, data_loader, batches=20, group=group)

    # The timer used to live inside the disabled warm-up block, which left
    # `start_time` undefined when the total training time was computed.
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        print_or_log(f'Start training epoch: {epoch}', logger=logger)
        if args.distributed:
            train_sampler.set_epoch(epoch)
        loss, selflabels = train_one_epoch(
            args,
            data_loader,
            model,
            selflabels,
            optimizer,
            device,
            epoch,
            args.print_freq,
            args.apex,
            logger=logger,
            writer=writer,
        )
        if lr_scheduler:
            lr_scheduler.step()
        if args.output_dir:
            utils.save_ckpt(args, epoch, model, optimizer, lr_scheduler, selflabels)
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print_or_log(f'Training time {total_time_str}', logger=logger)


def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Boolean options use a custom ``'bool'`` type that accepts
    yes/true/t/1 and no/false/f/0 (case-insensitive). Their defaults are
    given as strings and are converted by argparse through the same type.

    Args:
        argv: optional list of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with one attribute per option.
    """
    import argparse

    def str2bool(v):
        v = v.lower()
        if v in ('yes', 'true', 't', '1'):
            return True
        elif v in ('no', 'false', 'f', '0'):
            return False
        # ArgumentTypeError (unlike ValueError) makes argparse show this
        # message to the user instead of a generic "invalid value" error.
        raise argparse.ArgumentTypeError(
            'Boolean argument needs to be true or false. '
            'Instead, it is %s.' % v)

    parser = argparse.ArgumentParser(description='Video Representation Learning')
    parser.register('type', 'bool', str2bool)

    # AUDIO UTILS
    parser.add_argument('--aud-sample-rate', default=48000, type=int,
                        help='audio sample rate')
    parser.add_argument('--aud-spec-type', default=1, type=int,
                        help='audio spec type')  # 1 : (40, 99), (257, 199)
    parser.add_argument('--use-volume-jittering', default='False', type='bool',
                        help='use volume jittering')
    parser.add_argument('--use-temporal-jittering', default='False', type='bool',
                        help='use temporal jittering')
    parser.add_argument('--num-sec', default=1, type=int,
                        help='Number of seconds')
    parser.add_argument('--z-normalize', default='False', type='bool',
                        help='normalize audio')
    parser.add_argument('--target-fps', default=30, type=int,
                        help='target fps')
    parser.add_argument("--splitnow", type='bool', default='False',
                        help="Splits the heads' last linear layers into two: i.e. double K.")
    parser.add_argument("--match", type='bool', default='True',
                        help="Match A-V Embeddings")
    parser.add_argument('--ind-groups', type=int, default=1,
                        help='divide the heads into indepedent groups (get different augs)')
    parser.add_argument('--groups', type=int, default=1,
                        help='do SK in multiple groups')
    parser.add_argument('--sk-centercrop', default='False', type='bool',
                        help='Use center cropping for SK clustering')
    parser.add_argument('--stoch-sk-modality', type=float, default=0,
                        help='use stochastic modality selection. if value == 1: uses ONLY visual. if 0: both modalities')
    parser.add_argument('--distribution', default='default', type=str,
                        help='implemented: from `default`, `gauss` or `zipf`')
    parser.add_argument('--gpu-sk', default='False', type='bool',
                        help='Do sinkhorn-knopp on GPU (0)')
    parser.add_argument('--colorjitter', default='False', type='bool',
                        help='Apply random color jitter')
    parser.add_argument('--dualdata', type='bool', default='False',
                        help='use dataloader that returns two samples per video')
    parser.add_argument('--asynced', default=0, type=int, metavar='SY',
                        help='asynced: 0:(basecase), 1:(asynced==postive), -1:(asynced==additional negative)')
    parser.add_argument('--headcount', type=int, default=10,
                        help='how many heads each modality has')

    ### DATA
    parser.add_argument('--dataset', default='kinetics',
                        help='name of dataset')
    parser.add_argument('--augtype', default=1, type=int,
                        help='augmentation type (default: 1)')
    parser.add_argument('--decode-audio', default='True', type='bool',
                        help='Get audio spec')
    parser.add_argument('--aug-audio', default='False', type='bool',
                        help='whether to augment audio')
    parser.add_argument('--audio-augtype', default='mild', type=str,
                        choices=['na', 'mild', 'medium', 'heavy'],
                        help='type of audio-augment default: mild')
    parser.add_argument('--num-data-samples', default=None, type=int,
                        help='number of samples in dataset')
    parser.add_argument('--use-temp-jitter', default='True', type='bool',
                        help='Get clips from random timestamps each epoch')
    parser.add_argument('--center-crop', default='False', type='bool',
                        help='Use center cropping instead of random cropping')

    # NOTE(review): the help text duplicates --dataset, and the default is the
    # int 1 while values given on the command line arrive as str (argparse
    # only converts string defaults). Left unchanged for compatibility.
    parser.add_argument('--fold', default=1, type=str,
                        help='name of dataset')
    parser.add_argument('--clip-len', default=30, type=int,
                        help='number of frames per clip')
    parser.add_argument('--clips-per-video', default=1, type=int,
                        help='number of clips to sample from video')
    parser.add_argument('-j', '--workers', default=0, type=int, metavar='N',
                        help='number of data loading workers (default: 0)')
    parser.add_argument('--train-crop-size', default=112, type=int,
                        help='Size of spatial crops')
    parser.add_argument('--sample-rate', default=1, type=int,
                        help='Subsampling rate: num frames between clips')

    ### MODEL
    parser.add_argument('--model', default='avc', help='model', choices=['avc'])
    parser.add_argument('--vid-base-arch', default='r2plus1d_18',
                        help='Video Base Arch for A-V model',
                        choices=['r2plus1d_18', 'mc3_18', 's3d', 'r2plus1d_34', 'r2plus1d_50'])
    parser.add_argument('--aud-base-arch', default='resnet9',
                        help='Audio Base Arch for A-V model',
                        choices=['resnet9', 'resnet18', 'vgg_audio', 'resnet34', 'resnet50'])
    parser.add_argument("--pretrained", type='bool', default='False',
                        help="Use pre-trained models from the modelzoo")
    parser.add_argument('--use-mlp', default='True', type='bool',
                        help='Use MLP projection head')
    parser.add_argument('--mlptype', default=-1, type=int,
                        help='MLP type (default: -1) -1 is linear layer')
    parser.add_argument('--num-clusters', default=256, type=int,
                        help='number of clusters')

    ### TRAINING
    parser.add_argument('--schedulepower', default=1.5, type=float,
                        help='SK schedule power compared to linear (default: 1.5)')
    parser.add_argument('--nopts', default=160, type=int,
                        help='number of pseudo-opts (default: 160)')
    parser.add_argument('--lamb', default=10, type=int,
                        help='for pseudoopt: lambda (default: 10) ')
    parser.add_argument('-b', '--batch-size', default=4, type=int)
    parser.add_argument('--epochs', default=45, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--lr', default=0.01, type=float,
                        help='initial learning rate')
    parser.add_argument('--use-linear-scaling', default='False', type='bool',
                        help='Linearly scale learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, metavar='W',
                        help='weight decay (default: 1e-4)',
                        dest='weight_decay')
    parser.add_argument("--use-scheduler", type='bool', default='True',
                        help="Use LR scheduler")
    parser.add_argument("--scheduler-type", type=str, default='multi_step',
                        choices=['multi_step', 'cosine'],
                        help="Type of LR scheduler")
    parser.add_argument('--lr-milestones', default='20,30,40', type=str,
                        help='decrease lr on milestones')
    parser.add_argument('--lr-gamma', default=0.1, type=float,
                        help='decrease lr by a factor of lr-gamma')
    parser.add_argument('--lr-warmup-epochs', default=0, type=int,
                        help='number of warmup epochs')
    parser.add_argument("--sync-bn", type='bool', default='False',
                        help="Use sync batch norm")
    parser.add_argument("--warmup-bn", type='bool', default='False',
                        help="Warmup batchnorm")
    parser.add_argument("--norm-feat", type='bool', default='False',
                        help="Normalize embeddings")

    ### LOGGING
    parser.add_argument('--print-freq', default=10, type=int,
                        help='print frequency')
    parser.add_argument('--output-dir', default='.',
                        help='path where to save')
    parser.add_argument('--model-name', default=None,
                        help='exp desc')

    ### CHECKPOINTING
    parser.add_argument('--resume', type='bool', default='False',
                        help='resume from checkpoint')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--ckpt-path', default='',
                        help='resume from checkpoint path')

    # Mixed precision training parameters
    parser.add_argument('--apex', type='bool', default='False',
                        help='Use apex for mixed precision training')
    parser.add_argument('--apex-opt-level', default='O1', type=str,
                        help='For apex mixed precision training'
                             'O0 for FP32 training, O1 for mixed precision training.'
                             'For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet')

    # distributed training parameters
    parser.add_argument('--device', default='cuda',
                        help='device')
    parser.add_argument('--distributed', type='bool', default='False',
                        help="ddp mode")
    parser.add_argument('--dist-backend', default='nccl', type=str,
                        help='distributed backend')
    parser.add_argument('--dist-url', default='env://',
                        help='url used to set up distributed training')
    parser.add_argument('--world-size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--debug_slurm', type='bool', default='False',
                        help="Debug SLURM")
    parser.add_argument('--local_rank', default=-1, type=int,
                        help='Local rank of node')
    parser.add_argument('--master_port', default=-1, type=int,
                        help='Master port of Job')

    return parser.parse_args(argv)


if __name__ == "__main__":
    # Script entry point. Note that this binds `args` as a module-level
    # global, which get_loss() looks up by name.
    args = parse_args()

    # set multi-processing start method
    import torch.multiprocessing as mp

    try:
        mp.set_start_method('forkserver')
    except RuntimeError:
        # Presumably raised when a start method has already been set; the
        # existing one is then kept.
        pass

    main(args)
