# Training script for the target, source and validation models used in one-shot transfer attacks.

from __future__ import print_function
import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.optim as optim
from torchvision import datasets, transforms

import utils
import time
import logging
from train_utils import *
# from models import resnet
import models
from utils import MyCustomDataset, get_architecture, Input_diversity, MultiEnsemble, get_dataset, get_model,load_data, parse_config, parser_opt_scheduler, get_folder_names
from utils import AdvRegWarper, load_advreg_dataset
import sys

from opacus.dp_model_inspector import DPModelInspector
from opacus.utils import module_modification
from opacus import PrivacyEngine


# Command-line interface. Parsed once at import time into the module-level ``args``,
# which ``main`` reads.
parser = argparse.ArgumentParser(
    description='Train a model with the defense selected in the config file '
                'and save its checkpoint.')
parser.add_argument('--model-dir', default='./trained_models',
                    help='root directory under which checkpoints are saved')


# ``main`` converts this with int() before calling get_model, and also uses the raw
# string verbatim in the checkpoint filename, so it is kept as a string here.
parser.add_argument('--model', type=str,
                    help='architecture id (an integer) passed to get_model')
parser.add_argument('--indice', type=int,
                    help='data split index passed to load_data; 0 trains the target '
                         'model, any other value trains a source/validation model')
parser.add_argument('--model-num', type=int,
                    help='run number appended to the checkpoint filename')
parser.add_argument('--config', type=str, default='', help='config path')

args = parser.parse_args()


# DataLoader keyword arguments; not referenced in this part of the file.
kwargs = {'num_workers': 1, 'pin_memory': True}



def _apply_model_defense(model, config, device):
    """Apply the defenses that change the model's modules, and move it to ``device``.

    This must run BEFORE the optimizer is created. Otherwise the parameters of any
    newly created modules are never registered with the optimizer and never train.

    - ``dpsgd``: BatchNorm layers are replaced via opacus'
      ``convert_batchnorm_modules`` and the result is checked with
      ``DPModelInspector``.
    - ``dropout``: ``model.fc`` is replaced by ``Dropout -> Linear`` with the same
      input/output sizes as the original ``fc``.
    - Any other method leaves the architecture untouched.

    Args:
        model: network returned by ``get_model``.
        config: parsed config; reads ``config.defense.method`` and, for dropout,
            ``config.defense.dropout_ratio``.
        device: target ``torch.device``.

    Returns:
        The (possibly modified) model on ``device``.

    Raises:
        ValueError: if the converted model fails opacus' DP compatibility check.
    """
    method = config.defense.method
    if method == 'dpsgd':
        model = module_modification.convert_batchnorm_modules(model)
        inspector = DPModelInspector()
        # Explicit check instead of ``assert`` so it is not stripped under ``python -O``.
        if not inspector.validate(model):
            raise ValueError('model is not compatible with opacus DP-SGD')
    elif method == 'dropout':
        num_ftrs = model.fc.in_features
        model.fc = nn.Sequential(
            nn.Dropout(config.defense.dropout_ratio),
            nn.Linear(num_ftrs, model.fc.out_features),
        )
    return model.to(device)


def _checkpoint_path(run_args):
    """Return the checkpoint file path for this run, creating its directory.

    The directory is ``<model_dir>/<trainer>/<defense>`` for the target model
    (``indice == 0``) and ``<model_dir>/<trainer>`` otherwise. The folder names come
    from ``get_folder_names(run_args.config)`` at indices 2 (trainer) and 3 (defense).
    """
    all_dir_names = get_folder_names(run_args.config)
    trainer_dir = all_dir_names[2]
    defense_dir = all_dir_names[3]

    if run_args.indice == 0:  # target model: also separated by defense config
        model_dir = os.path.join(run_args.model_dir, trainer_dir, defense_dir)
    else:  # source/validation models
        model_dir = os.path.join(run_args.model_dir, trainer_dir)
    os.makedirs(model_dir, exist_ok=True)

    return os.path.join(
        model_dir, f'{run_args.model}_{run_args.indice}_{run_args.model_num}.pt')


def _run_advreg(model, config, device, save_path):
    """Train ``model`` with the adversarial-regularization defense via ``advtune_defense``.

    Builds its own datasets (``load_advreg_dataset``), wraps the model in
    ``AdvRegWarper``, and creates a fresh optimizer/scheduler for the wrapped model.
    It also creates an ``InferenceAttack_HZ`` attack model trained with Adam
    (lr=1e-4) and an MSE loss. Saving is delegated to ``advtune_defense``, which
    receives ``save_path``.
    """
    privateset, testset, refset, privateset_origin = load_advreg_dataset(args.indice)
    print(len(privateset), len(privateset_origin), len(testset), len(refset))

    batch_size = config.batch_size
    private_dataloader = torch.utils.data.DataLoader(privateset, batch_size=batch_size, shuffle=True)
    private_dataloader_origin = torch.utils.data.DataLoader(privateset_origin, batch_size=batch_size, shuffle=True)
    ref_dataloader = torch.utils.data.DataLoader(refset, batch_size=batch_size, shuffle=True)
    test_dataloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

    model = AdvRegWarper(model).to(device)
    criterion = nn.CrossEntropyLoss()
    # Rebuilt so the optimizer tracks the wrapped model's parameters.
    optimizer, scheduler = parser_opt_scheduler(model, config)

    attack_model = InferenceAttack_HZ(config.num_classes).to(device)
    attack_optimizer = optim.Adam(attack_model.parameters(), lr=0.0001)
    attack_criterion = nn.MSELoss()

    # NOTE(review): use_cuda is hard-coded as in the original; train_utils helpers
    # may assume a GPU regardless of ``device``.
    advtune_defense(model=model, private_dataloader=private_dataloader,
                    test_dataloader=test_dataloader, ref_dataloader=ref_dataloader,
                    private_dataloader_origin=private_dataloader_origin,
                    scheduler=scheduler, optimizer=optimizer, criterion=criterion,
                    attack_optimizer=attack_optimizer, attack_criterion=attack_criterion,
                    attack_model=attack_model, config=config, save_path=save_path,
                    alpha=config.defense.alpha, batch_size=config.batch_size,
                    num_epochs=config.epochs, use_cuda=True,
                    n_classes=config.num_classes)


def main():
    """Train one model as described by the CLI ``args`` and the config file.

    Flow:
      1. Build the model (``get_model``) and the data loaders (``load_data``).
      2. Apply architecture-changing defenses (dpsgd / dropout), then create the
         optimizer and scheduler. For dpsgd, attach an opacus ``PrivacyEngine``.
      3. Resolve the checkpoint path.
      4. For ``advreg``, hand off to ``advtune_defense`` and return.
         Otherwise run the epoch loop: train (with L1 regularization for ``l1``),
         then evaluate on the training and test loaders.
      5. Save every ``config.save_freq`` epochs and always after the last epoch.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    config = parse_config(args.config)
    model = get_model(int(args.model), load=False).to(device)
    loss = nn.CrossEntropyLoss()

    train_loader, test_loader = load_data(args.indice, batch_size=config.batch_size, num_samples=25000)

    # Modify the architecture first, THEN build the optimizer, so parameters of
    # replaced modules (GroupNorm for dpsgd, the new fc Linear for dropout) are
    # actually optimized.
    model = _apply_model_defense(model, config, device)
    optimizer, scheduler = parser_opt_scheduler(model, config)

    if config.defense.method == 'dpsgd':
        privacy_engine = PrivacyEngine(
            model,
            batch_size=config.batch_size,
            sample_size=len(train_loader.dataset),
            alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
            noise_multiplier=config.defense.sigma,
            max_grad_norm=config.defense.max_norm,
        )
        privacy_engine.attach(optimizer)

    save_path = _checkpoint_path(args)
    print(save_path)

    if config.defense.method == 'advreg':
        _run_advreg(model, config, device, save_path)
        return

    for epoch in range(1, config.epochs + 1):
        if config.defense.method == 'l1':
            train_for_one_epoch_with_l1(model, loss, train_loader, optimizer, epoch, l1_lambda=config.defense.l1_lambda)
        else:
            train_for_one_epoch(model, loss, train_loader, optimizer, epoch)

        # Evaluation on natural examples: training set first, then test set.
        print('================================================================')
        test_for_one_epoch(model, loss, train_loader, epoch)
        test_for_one_epoch(model, loss, test_loader, epoch)
        print('================================================================')
        if scheduler is not None:
            scheduler.step()

        # Also save on the last epoch so the final weights are never lost when
        # ``epochs`` is not a multiple of ``save_freq``.
        if epoch % config.save_freq == 0 or epoch == config.epochs:
            print(save_path)
            torch.save(model.state_dict(), save_path)

if __name__ == "__main__":
    # Script entry point: train only when executed directly, not on import.
    main()
