import argparse

import numpy as np

import torch
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from model import *
from utils import *

from betty.engine import Engine
from betty.problems import ImplicitProblem
from betty.configs import Config, EngineConfig


def _str2bool(value):
    """Convert a command-line string such as "true"/"false" into a bool.

    Used instead of ``type=bool``: argparse would call ``bool("False")``, which
    is ``True`` for every non-empty string, so the flag could never be disabled.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognized spelling;
            argparse reports this as a normal usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if normalized in {"0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser(description="Meta_Weight_Net")
# Runtime / engine options.
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--fp16", action="store_true")
parser.add_argument("--strategy", type=str, default="default")
parser.add_argument("--rollback", action="store_true")
parser.add_argument("--seed", type=int, default=2)

# SGD optimizer hyperparameters.
parser.add_argument("--lr", type=float, default=0.1)
parser.add_argument("--momentum", type=float, default=0.9)
parser.add_argument("--dampening", type=float, default=0.0)
# Accepts e.g. "--nesterov false" / "--nesterov 0" to disable Nesterov momentum.
parser.add_argument("--nesterov", type=_str2bool, default=True)
parser.add_argument("--weight_decay", type=float, default=5e-4)

# Data / pruning options.
# --ratio: fraction of the training set kept after pruning.
parser.add_argument("--ratio", type=float, default=0.9)
# Restricted to the two supported datasets so a typo fails loudly instead of
# silently selecting CIFAR-100.
parser.add_argument(
    "--dataset", type=str, default="cifar10", choices=["cifar10", "cifar100"]
)
parser.add_argument("--batch_size", type=int, default=128)
# --baseline: train on the full, unpruned training set.
parser.add_argument("--baseline", action="store_true")
# --random: prune with a random ordering instead of the saved meta ordering.
parser.add_argument("--random", action="store_true")
args = parser.parse_args()
set_seed(args.seed)

print("Load saved data")
# NOTE(review): the CIFAR-10 per-channel mean/std are used for both CIFAR-10
# and CIFAR-100 below — confirm this is intended for CIFAR-100 runs.
transform_train = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ]
)
transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ]
)

# Any --dataset value other than "cifar10" selects CIFAR-100.
dataset_cls = datasets.CIFAR10 if args.dataset == "cifar10" else datasets.CIFAR100
train_dataset = dataset_cls(
    root="./data", train=True, download=True, transform=transform_train
)
orig_len = len(train_dataset)
if not args.baseline:
    # NOTE(review): this permutation is drawn even in meta mode, where it is
    # immediately replaced; it still advances NumPy's global RNG, so removing
    # it could change any later NumPy random draws.
    sorted_idx = np.random.permutation(len(train_dataset))
    if not args.random:
        print("meta pruning")
        # Presumably a precomputed ordering of training-set indices, most
        # valuable first — verify against the script that writes this file.
        sorted_idx = torch.load("sorted_index.pt")
    else:
        print("random pruning")
    # Keep only the leading `ratio` fraction of the chosen ordering.
    filter_len = int(len(train_dataset) * args.ratio)
    filter_idx = sorted_idx[:filter_len]
    train_dataset = torch.utils.data.Subset(train_dataset, filter_idx)
pruned_len = len(train_dataset)
print("before pruning:", orig_len, "|| after pruning:", pruned_len)
test_dataset = dataset_cls(
    root="./data", train=False, download=True, transform=transform_test
)

train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    pin_memory=True,
    num_workers=2,
)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=args.batch_size,
    pin_memory=True,
    num_workers=2,
)
# Full batches per epoch (a trailing partial batch is not counted). Clamped to
# 1: a pruned set smaller than one batch would otherwise give 0, which makes
# TOTAL_ITER (the cosine schedule length) and the validation interval both 0.
EPOCH_LEN = max(1, len(train_dataset) // args.batch_size)
# Total training iterations: 200 epochs' worth.
TOTAL_ITER = EPOCH_LEN * 200


class Inner(ImplicitProblem):
    """Single betty problem: trains a ResNet-32 on the (possibly pruned) training set."""

    def training_step(self, batch):
        """Return the cross-entropy loss for one ``(inputs, labels)`` batch."""
        inputs, labels = batch
        # The module returns a pair; only the first element (the logits) is used.
        outputs, _ = self.forward(inputs)
        return F.cross_entropy(outputs, labels.long())

    def configure_train_data_loader(self):
        """Return the module-level training DataLoader."""
        return train_dataloader

    def configure_module(self):
        """Build a ResNet-32 with 10 classes for CIFAR-10 and 100 otherwise."""
        num_classes = 10 if args.dataset == "cifar10" else 100
        return ResNet32(num_classes)

    def configure_optimizer(self):
        """Return SGD configured from the command-line hyperparameters."""
        return optim.SGD(
            self.module.parameters(),
            lr=args.lr,
            momentum=args.momentum,
            dampening=args.dampening,
            weight_decay=args.weight_decay,
            # Honors --nesterov (default True). PyTorch requires momentum > 0
            # and dampening == 0 when Nesterov momentum is enabled.
            nesterov=args.nesterov,
        )

    def configure_scheduler(self):
        """Cosine-anneal the learning rate down to 1e-4 over TOTAL_ITER steps.

        T_max is expressed in training iterations, so this presumably expects
        the scheduler to be stepped once per iteration — confirm with betty.
        """
        return torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            TOTAL_ITER,
            eta_min=1e-4,
        )


# Highest test accuracy (percent) observed so far; module-level so it persists
# across successive calls to ``ReweightingEngine.validation``.
best_acc = -1


class ReweightingEngine(Engine):
    """Engine whose validation step reports top-1 accuracy on the test set."""

    @torch.no_grad()
    def validation(self):
        """Evaluate the "inner" problem on ``test_dataloader``.

        Returns:
            dict with ``"acc"`` (this run's top-1 accuracy, in percent) and
            ``"best_acc"`` (the best accuracy seen in any run so far).
        """
        global best_acc
        num_correct = 0
        num_seen = 0
        for images, labels in test_dataloader:
            images = images.to(args.device)
            labels = labels.to(args.device)
            # ``self.inner`` is the problem named "inner" (attribute presumably
            # set by the betty Engine); its forward returns (logits, extra).
            logits, _ = self.inner(images)
            predictions = logits.argmax(dim=1)
            num_correct += predictions.eq(labels).sum().item()
            num_seen += images.size(0)
        acc = num_correct / num_seen * 100
        best_acc = max(best_acc, acc)
        return {"acc": acc, "best_acc": best_acc}


# Problem-level settings for the single training problem.
inner_config = Config(
    type="darts",
    fp16=args.fp16,
    unroll_steps=2,
    darts_alpha=0.01,
)
# Engine-level settings: run TOTAL_ITER steps and validate once per epoch.
engine_config = EngineConfig(
    train_iters=TOTAL_ITER,
    valid_step=EPOCH_LEN,
    strategy=args.strategy,
    roll_back=args.rollback,
    logger_type="tensorboard",
)
inner = Inner(name="inner", config=inner_config)

# Only one problem, so both dependency maps stay empty.
problems = [inner]
l2u = {}
u2l = {}
dependencies = {"l2u": l2u, "u2l": u2l}

engine = ReweightingEngine(
    config=engine_config,
    problems=problems,
    dependencies=dependencies,
)
engine.run()
