import argparse
import numpy as np
import pickle as pk
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

from tinydfa import DFAManager, FeedbackLayer
from tinydfa.alignment import GradientAlignmentMetrics
from tinydfa.rp.differential_privacy import RandomProjectionDP, TernarizedRandomProjectionDP, OpticalRandomProjectionDP
from tinydfa.utils.normalizations import FeedbackNormalization
from tinydfa.asymmetric import AsymmetricFunction, ForwardFunction, BackwardFunction


class GradNoise:
    """Identity in the forward pass; clip/rescale + Gaussian noise on the gradient in backward.

    Used to apply DP noise when training with backpropagation (for DFA this is handled
    inside the tinydfa.rp.differential_privacy ops).

    For an incoming gradient ``g`` (2-D: norms are taken over dim 1, width ``d = g.shape[1]``)
    the backward pass returns, per row::

        (min(g, 1) * tau / ||g_row||_2 + N(0, 1) / sqrt(d) * sigma * 10) / sqrt(d)

    Rows whose norm is zero yield a zero (noise-only) gradient instead of NaN.

    NOTE(review): the scaling is ``tau / ||g||`` unconditionally (normalization rather than
    ``min(1, tau / ||g||)`` clipping), values are capped only from above, and the extra
    factor 10 on the noise is unexplained here — confirm these are intended.

    Args:
        tau_feedback_privacy: target per-row norm scale ``tau``.
        sigma_privacy: noise multiplier ``sigma``.
    """

    def __init__(self, tau_feedback_privacy, sigma_privacy):
        self.tau_feedback_privacy = tau_feedback_privacy
        self.sigma_privacy = sigma_privacy

        class GradNoiseFunction(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input):
                return input

            @staticmethod
            def backward(ctx, grad_output):
                width = grad_output.shape[1]
                # Norm of the *unclipped* gradient, kept as (batch, 1) so it broadcasts.
                row_norm = grad_output.norm(2, dim=1, keepdim=True)
                # tau / ||g||, but 0 for all-zero rows (the original produced 0 * inf = NaN).
                tau_feedback_clip = torch.where(
                    row_norm > 0,
                    self.tau_feedback_privacy / row_norm,
                    torch.zeros_like(row_norm),
                )
                # Out-of-place cap: never mutate the gradient buffer autograd hands us.
                clipped = grad_output.clamp(max=1.0) * tau_feedback_clip

                noise = (
                    torch.randn(grad_output.shape, device=grad_output.device)
                    / np.sqrt(width)
                    * self.sigma_privacy
                    * 10
                )

                return (clipped + noise) / np.sqrt(width)

        self.grad_noise_function = GradNoiseFunction.apply

    def __call__(self, input):
        return self.grad_noise_function(input)


# Fully connected neural network
class MNISTFullyConnected(nn.Module):
    """Three-layer fully connected classifier for flattened 28x28 inputs (784 -> hidden -> hidden -> 10).

    Both hidden activations pass through ``FeedbackLayer`` hooks managed by a ``DFAManager``,
    so the same network can be trained with several methods selected by ``training_method``:

    * ``"DFA"`` / ``"TDFA"`` / ``"ODFA"``: tanh-based ``AsymmetricFunction`` activations, with the
      feedback op ``RandomProjectionDP`` / ``TernarizedRandomProjectionDP`` /
      ``OpticalRandomProjectionDP`` respectively.
    * ``"BP"``: ReLU activations, ``dfa.use_bp = True``, and ``GradNoise`` applied to the output.
    * ``"SHALLOW"``: ReLU activations, ``DFAManager(no_feedbacks=True)`` and feedback-point
      recording disabled.

    Hidden activations are shifted by ``tau_privacy[0] / sqrt(hidden_size)`` and clamped to
    ``[-tau_privacy[1] / sqrt(hidden_size), tau_privacy[1] / sqrt(hidden_size)]``.

    Args:
        hidden_size: width of both hidden layers.
        sigma_privacy: DP noise scale forwarded to the projection op (and to GradNoise for BP).
        tau_privacy: ``(tau_min, tau_max)`` activation offset / clipping values.
        tau_feedback_privacy: feedback clipping value forwarded to the DFA/TDFA projection ops
            and to GradNoise for BP.
        training_method: one of "BP", "DFA", "TDFA", "ODFA", "SHALLOW".
        ternarization_treshold: threshold forwarded to the TDFA/ODFA projection ops.
    """

    def __init__(
        self,
        hidden_size,
        sigma_privacy=0.01,
        tau_privacy=(1e-6, 0.1),
        tau_feedback_privacy=0.2,
        training_method="DFA",
        ternarization_treshold=0.15,
    ):
        super(MNISTFullyConnected, self).__init__()
        self.fc1 = nn.Linear(784, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, 10)

        self.training_method = training_method

        normalization = FeedbackNormalization.FAN_OUT

        # Feedback projection op: only build the one that is actually used.
        if self.training_method == "TDFA":
            self.rp = TernarizedRandomProjectionDP(
                ternarization_treshold=ternarization_treshold,
                sigma_privacy=sigma_privacy,
                tau_feedback_privacy=tau_feedback_privacy,
                verbose=False,
            )
        elif self.training_method == "ODFA":
            # NOTE(review): tau_feedback_privacy is not passed to the optical op; confirm intended.
            self.rp = OpticalRandomProjectionDP(
                sigma_privacy=sigma_privacy, ternarization_treshold=ternarization_treshold, verbose=True
            )
        else:
            self.rp = RandomProjectionDP(
                sigma_privacy=sigma_privacy,
                tau_feedback_privacy=tau_feedback_privacy,
                verbose=False,
            )

        # DFA variants use the tanh asymmetric activation; BP / SHALLOW keep ReLU.
        if self.training_method in ("DFA", "TDFA", "ODFA"):
            activation = AsymmetricFunction(ForwardFunction.TANH, BackwardFunction.TANH)
        else:
            activation = torch.relu

        self.dfa1, self.dfa2 = FeedbackLayer(), FeedbackLayer()
        self.dfa = DFAManager(
            [self.dfa1, self.dfa2],
            no_feedbacks=(self.training_method == "SHALLOW"),
            rp_operation=self.rp,
            normalization=normalization,
        )

        self.bp_noise_function = lambda x: x
        if self.training_method == "BP":
            # Switch to backpropagation.
            self.dfa.use_bp = True
            self.bp_noise_function = GradNoise(tau_feedback_privacy=tau_feedback_privacy, sigma_privacy=sigma_privacy)

        if self.training_method == "SHALLOW":
            # It's good practice in shallow not to collect feedback points.
            self.dfa.record_feedback_point_all = False

        # Offset and clip bounds are fixed per model; compute them once.
        activation_offset = tau_privacy[0] / np.sqrt(hidden_size)
        activation_bound = tau_privacy[1] / np.sqrt(hidden_size)
        self.activation = lambda x: (activation(x) + activation_offset).clamp(-activation_bound, activation_bound)

    def forward(self, x):
        """Return class log-probabilities of shape ``(batch, 10)`` for a batch of images."""
        x = x.reshape(x.shape[0], -1)
        x = self.dfa1(self.activation(self.fc1(x)))
        x = self.dfa2(self.activation(self.fc2(x)))
        x = self.dfa(self.fc3(x))
        x = self.bp_noise_function(x)
        # Explicit class dimension: the implicit-dim form is deprecated and warns.
        return F.log_softmax(x, dim=1)


def train(model, train_loader, optimizer, alignment, device, epoch, dry_run):
    """Run one training epoch and collect the per-batch losses.

    On the first batch only, ``alignment(data, target, F.nll_loss)`` is called; the second
    element of the pair it returns is stored (and printed) as this epoch's alignment.

    Args:
        model: network whose output is fed to ``F.nll_loss``.
        train_loader: iterable of ``(data, target)`` batches supporting ``len``.
        optimizer: optimizer updating ``model``'s parameters.
        alignment: callable ``(data, target, loss_fn) -> (angles, alignments)``.
        device: device each batch is moved to.
        epoch: epoch label stored in the result and shown in progress output.
        dry_run: when true, stop after the first batch.

    Returns:
        dict with keys "epoch", "loss" (list of per-batch float losses) and "alignment".
    """
    model.train()
    record = {"epoch": epoch, "loss": [], "alignment": None}

    for step, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        if step == 0:
            _, first_alignments = alignment(inputs, labels, F.nll_loss)
            record["alignment"] = first_alignments
            print(first_alignments)

        optimizer.zero_grad()
        batch_loss = F.nll_loss(model(inputs), labels)
        batch_loss.backward()
        optimizer.step()

        loss_value = batch_loss.item()
        # Carriage return keeps the progress on a single console line.
        print(f"Epoch {epoch}: training loss at batch {step}/{len(train_loader)}: {loss_value:.4f}", end="\r")
        record["loss"].append(float(loss_value))

        if dry_run:
            break

    return record


def validation(model, validation_loader, device, epoch, val_batch_size):
    """Evaluate ``model`` on ``validation_loader``; report mean NLL loss and accuracy.

    The sample count is taken from the batches actually seen, so a short final batch
    is no longer over-counted. (Previously ``len(loader) * val_batch_size`` was used,
    which deflated both the loss and the accuracy.)

    Args:
        model: network returning log-probabilities (fed to ``F.nll_loss``).
        validation_loader: iterable of ``(data, target)`` batches.
        device: device each batch is moved to.
        epoch: label stored in the result and printed (e.g. an int or "test").
        val_batch_size: kept for backward compatibility; no longer used.

    Returns:
        dict with "epoch", "loss" (mean per-sample NLL) and "accuracy" (percent).

    Raises:
        ValueError: if the loader yields no samples.
    """
    model.eval()
    validation_loss, correct, validation_samples = 0.0, 0, 0
    validation_data = {"epoch": epoch, "loss": None, "accuracy": None}
    with torch.no_grad():
        for data, target in validation_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            validation_loss += F.nll_loss(output, target, reduction="sum").item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            validation_samples += target.size(0)

    if validation_samples == 0:
        raise ValueError("validation_loader yielded no samples")

    validation_loss /= validation_samples
    accuracy = correct / validation_samples * 100

    validation_data["loss"] = validation_loss
    validation_data["accuracy"] = accuracy

    print(f"Epoch {epoch}: validation loss {validation_loss:.4f}, " f"accuracy {accuracy:.2f}.")

    return validation_data


def get_loaders(dataset_name, dataset_path, batch_size, val_batch_size, use_gpu):
    """Build train / validation / test DataLoaders for MNIST or FashionMNIST.

    The first 10% of the official training set (by index) is held out as the
    validation split; the rest is used for training. Both are sampled randomly
    via ``SubsetRandomSampler``. The test loader iterates the official test set in order.

    Args:
        dataset_name: "MNIST" or "FashionMNIST".
        dataset_path: root folder for torchvision (downloaded there if missing).
        batch_size: training batch size.
        val_batch_size: batch size for the validation and test loaders.
        use_gpu: if true, use 8 worker processes and pinned memory.

    Returns:
        ``(train_loader, val_loader, test_loader)``.

    Raises:
        ValueError: for an unknown ``dataset_name``.
    """
    gpu_args = {"num_workers": 8, "pin_memory": True} if use_gpu else {}

    if dataset_name == "MNIST":
        mean, std = (0.1307,), (0.3081,)
        torchvision_dataset = torchvision.datasets.MNIST
    elif dataset_name == "FashionMNIST":
        # FashionMNIST has its own pixel statistics; the MNIST ones were previously reused here.
        mean, std = (0.2860,), (0.3530,)
        torchvision_dataset = torchvision.datasets.FashionMNIST
    else:
        raise ValueError(f"Unknown dataset {dataset_name}!")

    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])

    train_val_dataset = torchvision_dataset(dataset_path, train=True, download=True, transform=transform)
    test_dataset = torchvision_dataset(dataset_path, train=False, download=True, transform=transform)

    train_val_samples = len(train_val_dataset)
    indices = list(range(train_val_samples))
    split_train_val = int(np.floor(0.1 * train_val_samples))
    train_indices, val_indices = indices[split_train_val:], indices[:split_train_val]

    train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indices)
    val_sampler = torch.utils.data.sampler.SubsetRandomSampler(val_indices)

    train_loader = torch.utils.data.DataLoader(
        train_val_dataset,
        batch_size=batch_size,
        sampler=train_sampler,
        **gpu_args,
    )
    val_loader = torch.utils.data.DataLoader(
        train_val_dataset,
        batch_size=val_batch_size,
        sampler=val_sampler,
        **gpu_args,
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=val_batch_size,
        **gpu_args,
    )

    return train_loader, val_loader, test_loader


def main(args):
    """Run the experiment described by the parsed command-line ``args``.

    Trains for ``args.epochs`` epochs, evaluating on the validation split after each
    epoch, then evaluates once on the test set. Writes into the folder ``args.save_path``:

    * ``data_<run_id>.pk``: pickled ``(training_data, validation_data)``; the last entry
      of ``validation_data`` is the test-set result.
    * ``model_<run_id>.torch``: the model's ``state_dict()``.
    """
    import os

    # Create the output folder up front so a bad path fails before training, not after it.
    if args.save_path:
        os.makedirs(args.save_path, exist_ok=True)

    # Setup hardware/GPU:
    use_gpu = not args.no_gpu and torch.cuda.is_available()
    device = torch.device(f"cuda:{args.gpu_id}" if use_gpu else "cpu")
    torch.manual_seed(args.seed)

    # Setup data:
    train_loader, validation_loader, test_loader = get_loaders(
        args.dataset_name, args.dataset_path, args.batch_size, args.val_batch_size, use_gpu
    )

    # NOTE(review): args.ternarization_treshold is parsed but not forwarded to the model here.
    model = MNISTFullyConnected(
        args.hidden_size,
        sigma_privacy=args.sigma_privacy,
        tau_privacy=(args.tau_min_privacy, args.tau_max_privacy),
        tau_feedback_privacy=args.tau_feedback_privacy,
        training_method=args.training_method,
    ).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.momentum)

    alignment = GradientAlignmentMetrics(model)

    training_data, validation_data = [], []
    for epoch in range(1, args.epochs + 1):
        epoch_training_data = train(model, train_loader, optimizer, alignment, device, epoch, args.dry_run)
        epoch_validation_data = validation(model, validation_loader, device, epoch, args.val_batch_size)

        training_data.append(epoch_training_data)
        validation_data.append(epoch_validation_data)

    print("=== FINAL PERFORMANCE (TEST SET) ===")
    test_data = validation(model, test_loader, device, "test", args.val_batch_size)
    validation_data.append(test_data)

    # Join as a folder path: plain concatenation turned "/save" into "/savedata_...".
    data_path = os.path.join(args.save_path, f"data_{args.run_id}.pk")
    model_path = os.path.join(args.save_path, f"model_{args.run_id}.torch")

    with open(data_path, "wb") as data_save_file:
        pk.dump((training_data, validation_data), data_save_file)
    # Save the state_dict rather than the module: the module holds lambdas (activation,
    # bp_noise_function) and, for BP, a locally defined autograd Function, none of which pickle.
    torch.save(model.state_dict(), model_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="(Fashion)MNIST Photonic Training")

    parser.add_argument(
        "-s",
        "--save-path",
        type=str,
        default="/save",
        help="folder in which to save model and run data (default:/save)",
    )
    parser.add_argument(
        "-i", "--run-id", type=str, default="run.pk", help="name of the run for data saving (default:)"
    )

    # Training/DFA related options:
    parser.add_argument(
        "-t",
        "--training-method",
        type=str,
        choices=["BP", "DFA", "TDFA", "ODFA", "SHALLOW"],
        default="ODFA",
        help="training method to use, choose from backpropagation (BP), direct feedback "
        "alignment (DFA), or only topmost layer (SHALLOW) (default: ODFA)",
    )
    parser.add_argument(
        "--ternarization-treshold",
        type=float,
        default=0.15,
        help="treshold for ternarization with TDFA/ODFA (default: 0.15)",
    )
    parser.add_argument(
        "--sigma-privacy",
        type=float,
        default=0.2,
        help="DP sigma value for synthetic gradient noise (default: 0.01)",
    )
    parser.add_argument(
        "--tau-min-privacy",
        type=float,
        default=1e-6,
        help="tau_min value for activation offsetting (default:1e-6)",
    )
    parser.add_argument(
        "--tau-max-privacy",
        type=float,
        default=1.0,
        help="DP tau_max value for activation clipping (default:1.0)",
    )
    parser.add_argument(
        "--tau-feedback-privacy",
        type=float,
        default=1.0,
        help="DP tau_feedback value for feedback clipping (default:1.0)",
    )

    # Model definition:
    parser.add_argument("--hidden-size", type=int, default=512, help="hidden layer size (default: 256)")

    # Training batch and epochs:
    parser.add_argument("--batch-size", type=int, default=256, help="training batch size (default: 128)")
    parser.add_argument("--val-batch-size", type=int, default=1000, help="validation batch size (default: 1000)")
    parser.add_argument("--epochs", type=int, default=15, help="number of epochs to train (default: 15)")

    # Optimization:
    parser.add_argument("--learning-rate", type=float, default=1e-2, help="SGD learning rate (default: 0.01)")
    parser.add_argument("--momentum", type=float, default=0.9, help="SGD momentum (default: 0.9)")

    # Hardware/GPUs:
    parser.add_argument("--no-gpu", action="store_true", default=False, help="disables GPU training")
    parser.add_argument("--gpu-id", type=int, default=0, help="id of the gpu to use (default: 0)")

    # Technical options:
    parser.add_argument("--dry-run", action="store_true", default=False, help="quickly check a single pass")
    parser.add_argument("--seed", type=int, default=0, help="random seed (default: 0)")

    # Dataset:
    parser.add_argument(
        "-d",
        "--dataset-name",
        type=str,
        choices=["MNIST", "FashionMNIST"],
        default="FashionMNIST",
        help="name of the MNIST variant on which to train, choose from MNIST, or FashionMNIST",
    )
    parser.add_argument("-p", "--dataset-path", type=str, default="/data", help="path to dataset (default: /data)")

    args = parser.parse_args()
    main(args)
