from argparse import ArgumentParser


def add_basic_args(parser: ArgumentParser) -> ArgumentParser:
    """Register dataset, data-loading, architecture and path options.

    Args:
        parser: Parser to extend; the options are added to it in place.

    Returns:
        The same ``parser`` object that was passed in.
    """
    parser.add_argument(
        "--dataset",
        default="cifar10",
        type=str,
        choices=["cifar10", "svhn", "catsvsdogs",
                 "in9l", "mnist", "waterbirds", "celeba", "imagenet"],
    )
    parser.add_argument(
        "--workers",
        default=8,
        type=int,
        metavar="N",
        # Interpolate the default so the help text cannot drift from the
        # real value (it used to advertise 4 while the default was 8).
        help="number of data loading workers (default: %(default)s)",
    )
    parser.add_argument(
        "--train",
        default=False,
        action="store_true",
        help="train the model.",
    )
    parser.add_argument(
        "--test_batch", default=128, type=int, metavar="N", help="test batchsize"
    )
    parser.add_argument(
        "--arch",
        metavar="ARCH",
        default="small_cnn",
        choices=["resnet50", "resnet18", "resnet32", "small_cnn"],
    )
    parser.add_argument(
        "--data_ratio_to_inject_bias",
        type=float,
        default=0.8,
        help="percentage of train data we want to have specific color of squares",
    )
    parser.add_argument(
        "--base_dir",
        type=str,
        metavar="PATH",
        help="base directory to save data and experiments",
    )
    parser.add_argument(
        "--dataset_dir",
        type=str,
        metavar="PATH",
    )
    # NOTE(review): the two options below share the exact same help text;
    # confirm whether --saved_mask_dir deserves its own description.
    parser.add_argument(
        "--saved_mask_dir",
        default=None,
        type=str,
        metavar="PATH",
        help="path to the saved checkpoints and augmented data",
    )
    parser.add_argument(
        "--saved_checkpoint_dir",
        default=None,
        type=str,
        metavar="PATH",
        help="path to the saved checkpoints and augmented data",
    )
    return parser


def add_optimizer_args(parser: ArgumentParser) -> ArgumentParser:
    """Register optimizer and learning-rate schedule options.

    Args:
        parser: Parser to extend; the options are added to it in place.

    Returns:
        The same ``parser`` object that was passed in.
    """
    parser.add_argument("--optimizer", default="sgd", choices=["sgd", "adam"])
    # Both spellings store into ``lr`` (argparse derives the destination from
    # the first long option string).
    parser.add_argument(
        "--lr",
        "--learning-rate",
        default=0.1,
        type=float,
        metavar="LR",
        help="initial learning rate",
    )
    parser.add_argument(
        "--schedule",
        type=int,
        nargs="+",
        default=[25, 50, 75, 100, 125, 150, 175, 200, 225, 250, 275],
        help="Multiply learning rate by gamma at the scheduled epochs (default: 25,50,75,100,125,150,175,200,225,250,275)",
    )
    parser.add_argument(
        "--gamma",
        type=float,
        default=0.5,
        help="LR is multiplied by gamma on schedule (default: 0.5)",
    )
    parser.add_argument(
        "--momentum", default=0.9, type=float, metavar="M", help="momentum"
    )
    # ``--wd`` is an alias; the parsed value is stored as ``weight_decay``.
    parser.add_argument(
        "--weight_decay",
        "--wd",
        default=1e-4,
        type=float,
        metavar="W",
        help="weight decay (default: 1e-4)",
    )
    parser.add_argument(
        "--lr_scheduler_name",
        default="multi_step",
        type=str,
        # Help text previously misspelled "learning" as "learing".
        help="learning rate scheduler",
    )
    parser.add_argument(
        "--use_nesterov",
        action="store_true",
        default=False,
        help="use nesterov in sgd",
    )

    return parser


def add_device_args(parser: ArgumentParser) -> ArgumentParser:
    """Register the CUDA switch and the list of GPU ids on ``parser``.

    Returns the parser that was passed in.
    """
    register = parser.add_argument
    register("--use_cuda", default=False, action="store_true")
    # ``nargs="*"`` accepts zero or more integers after the flag.
    register("--gpu_ids", nargs="*", default=[0], type=int)
    return parser


def add_train_args(parser: ArgumentParser) -> ArgumentParser:
    """Register the training-loop options on ``parser``.

    The options are declared in a table and added in the listed order, which
    is also the order they appear in the generated help.

    Returns the parser that was passed in.
    """
    option_table = (
        ("--masktune", {"default": False, "action": "store_true"}),
        ("--num_phases", {"default": 2, "type": int}),
        (
            "--epochs",
            {
                "type": int,
                "default": 10,
                "metavar": "N",
                "help": "number of total epochs to run",
            },
        ),
        (
            "--final_train_epochs",
            {
                "type": int,
                "default": 20,
                "help": "number of total epochs to train the model on all masked data+clean data",
            },
        ),
        (
            "--train_batch",
            {
                "type": int,
                "default": 128,
                "metavar": "N",
                "help": "train batchsize",
            },
        ),
        ("--selective_classification", {"default": False, "action": "store_true"}),
        ("--class_weights", {"default": None, "nargs": "*", "type": float}),
        ("--use_pretrained_weights", {"default": False, "action": "store_true"}),
    )
    for flag, settings in option_table:
        parser.add_argument(flag, **settings)
    return parser


def add_test_args(parser: ArgumentParser) -> ArgumentParser:
    """Register the evaluation options on ``parser``.

    Adds the checkpoint file name and the list of test data variants.
    Returns the parser that was passed in.
    """
    allowed_data_types = [
        "mixed_rand",
        "mixed_same",
        "original",
        "biased",
        "fg_mask",
        "mixed_next",
        "no_fg",
        "only_bg_b",
        "only_bg_t",
        "only_fg",
    ]
    parser.add_argument(
        "--checkpoint_name", default="best_model_for_final.pt", type=str
    )
    # One or more variants may be given; each must be one of the allowed names.
    parser.add_argument(
        "--test_data_types",
        nargs="+",
        type=str,
        choices=allowed_data_types,
        default=["biased"],
    )
    return parser


def add_augmask_args(parser: ArgumentParser) -> ArgumentParser:
    """Register the masking-related options on ``parser``.

    Args:
        parser: Parser to extend; the options are added to it in place.

    Returns:
        The same ``parser`` object that was passed in.
    """

    def int_or_float(text: str):
        """Convert ``text`` to ``int`` when possible, otherwise to ``float``.

        A ``ValueError`` from ``float`` propagates so argparse reports the
        value as invalid.
        """
        try:
            return int(text)
        except ValueError:
            return float(text)

    # The default is fractional (0.2), so a plain ``type=int`` rejected any
    # fractional value given on the command line, including the default
    # itself. Integer inputs still parse to ``int`` exactly as before.
    parser.add_argument("--remove_k", type=int_or_float, default=0.2)
    parser.add_argument(
        "--masking",
        type=str,
        choices=[
            "sort_mask",
            "threshold_mask",
            "max_pixel",
            "soft_mask",
            "bernoulli",
            "soft_mask_using_threshold",
        ],
        default="threshold_mask",
    )
    parser.add_argument(
        "--mask_mean_mode", type=str, default="global", choices=["global", "persample"]
    )
    parser.add_argument(
        "--masking_arch",
        default=None,
        choices=["resnet50", "resnet18", "small_cnn"],
    )
    parser.add_argument(
        "--threshold_types_for_masking",
        type=str,
        nargs="+",
        default=["mean+1std"],
    )
    parser.add_argument(
        "--initialize_final_train_model_with_erm_weights",
        action="store_true",
        default=False
    )
    parser.add_argument(
        "--masking_batch_size",
        type=int,
        default=128,
    )
    parser.add_argument(
        "--resume_final_training",
        default=False,
        action='store_true'
    )
    # NOTE(review): the flag name is spelled "checkopints"; it is kept as-is
    # because callers read ``args.checkopints_dir`` / pass this exact flag.
    parser.add_argument(
        "--checkopints_dir",
        type=str,
    )
    parser.add_argument(
        "--continue_final_train_with_erm_lr",
        default=False,
        action="store_true"
    )
    return parser


def biased_mnist_args(parser: ArgumentParser) -> ArgumentParser:
    """Register the biased-MNIST options (square count and bias type).

    Returns the parser that was passed in.
    """
    bias_kinds = ["square", "background", "none", "one_square"]
    register = parser.add_argument
    register(
        "--square_number",
        type=int,
        default=1,
        help="Number of squares to be added to images",
    )
    register(
        "--bias_type",
        type=str,
        choices=bias_kinds,
        default="square",
        help="type of bias to be injected into the MNIST data",
    )
    return parser


def selective_classification_args(parser: ArgumentParser) -> ArgumentParser:
    """Register the ``--coverage`` list of integer percentages.

    Returns the parser that was passed in.
    """
    coverage_settings = {
        "nargs": "+",
        "type": int,
        "default": [80, 85, 90, 95, 100],
        "help": "percentage of test data to be covered by the selective classification",
    }
    parser.add_argument("--coverage", **coverage_settings)
    return parser


def heat_map_args(parser: ArgumentParser) -> ArgumentParser:
    """Register the heat-map options (weighting mode and generation method).

    Returns the parser that was passed in.
    """
    weighting_modes = ["mean", "raw"]
    generation_methods = [
        "grad_cam",
        "score_cam",
        "sparsity",
        "xgrad_cam",
        "ablation_cam",
        "eigen_cam",
        "gradcam_plusplus",
        "full_grad",
        "layer_cam",
    ]
    parser.add_argument(
        "--grad_cam_weight",
        type=str,
        choices=weighting_modes,
        default="mean",
        help="how to calculate weight of each activation map",
    )
    parser.add_argument(
        "--heat_map_generation_method",
        type=str,
        choices=generation_methods,
        default="grad_cam",
        help="method to generate heat map",
    )
    return parser


def cutout_args(parser: ArgumentParser) -> ArgumentParser:
    """Register the cutout augmentation options on ``parser``.

    Adds the on/off switch plus the hole count and square length.
    Returns the parser that was passed in.
    """
    option_table = (
        (
            "--cutout",
            {
                "default": False,
                "action": "store_true",
                "help": "use cutout augmentation",
            },
        ),
        (
            "--n_holes",
            {"type": int, "default": 1, "help": "Number of holes for cutout"},
        ),
        (
            "--length",
            {
                "type": int,
                "default": 8,
                "help": "length of squares chosen for cutout",
            },
        ),
    )
    for flag, settings in option_table:
        parser.add_argument(flag, **settings)
    return parser


def init_train_argparse() -> ArgumentParser:
    """Build the training parser by applying every option group in turn.

    Each registrar receives the parser and returns it; the groups are applied
    in the order listed below.
    """
    registrars = (
        add_basic_args,
        add_device_args,
        add_optimizer_args,
        add_train_args,
        add_test_args,
        add_augmask_args,
        biased_mnist_args,
        selective_classification_args,
        heat_map_args,
        cutout_args,
    )
    parser = ArgumentParser(description="PyTorch Augmask training")
    for register_group in registrars:
        parser = register_group(parser)
    return parser
