import os

import numpy as np
import torch
# Autograd anomaly detection is switched on globally as soon as this module is
# imported. NOTE(review): PyTorch documents this as a debugging aid that slows
# execution down - confirm it is meant to stay enabled for regular training.
torch.autograd.set_detect_anomaly(True)
from torch import optim

from data_loader.data_loader import data_loaders
from model.MLP_vae import MLPVAE
from model.conv_vae import ConvVAE
# Installs an "ignore" filter for all warnings, process-wide.
import warnings; warnings.filterwarnings("ignore")

# Print tensors with 2 digits of precision and lines up to 200 characters wide.
torch.set_printoptions(2, linewidth=200)

def initialize_model():
    """Build the model, data loaders, optimizers and LR scheduler from ``args``.

    Reads the module-level ``args`` namespace (set by ``fit`` / ``eval``) and
    publishes its results through module-level globals: ``device``,
    ``train_loader``, ``test_loader``, ``model``, ``arch_optimizer``,
    ``weight_optimizer`` and ``scheduler``.

    Side effects:
        * sets ``args.cuda``;
        * seeds the NumPy and PyTorch random number generators with
          ``args.seed``;
        * creates ``./saved_models/<dataset>`` and ``./loss_logs/<dataset>``;
        * restores model weights from ``args.model_file`` when that file exists.

    Raises:
        KeyError: if ``args.arch`` is neither ``"MLP"`` nor ``"conv"``.
    """

    global device, train_loader, test_loader
    global model, arch_optimizer, weight_optimizer, scheduler

    # use the GPU if one is available
    args.cuda = torch.cuda.is_available()
    gpu = "cuda:" + str(args.gpu)
    device = torch.device(gpu if args.cuda else "cpu")
    print("Training on ", device)

    # load data
    # NOTE(review): the loaders are created before the RNGs are seeded below,
    # so any randomness inside data_loaders() is not governed by args.seed -
    # confirm this ordering is intended.
    train_loader, test_loader = data_loaders(args)

    # random seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed_all(args.seed)

    # initialize a model
    models = {"MLP": MLPVAE, "conv": ConvVAE}
    model_class = models[args.arch]
    model = model_class(
        args=args,
        device=device,
        img_shape=args.img_shape,
        h_dim=args.h_dim,
        z_dim=args.z_dim,
        truncation=args.truncation,
    ).to(device)

    # create directories for saved models and log files if not already present
    os.makedirs("./saved_models/" + args.dataset, exist_ok=True)
    os.makedirs("./loss_logs/" + args.dataset, exist_ok=True)

    # load saved weights if a checkpoint is present; map_location remaps the
    # stored tensors onto the device chosen above, so a checkpoint written on
    # a GPU can still be restored on a CPU-only host (or another GPU index)
    if os.path.exists(args.model_file):
        state_dict = torch.load(args.model_file, map_location=device)
        model.load_state_dict(state_dict)

    # separate the architecture parameters (names containing "var_", the beta
    # process params) from the ordinary network weights
    named_params = list(model.named_parameters())
    arch_params = [param for name, param in named_params if "var_" in name]
    weights = [param for name, param in named_params if "var_" not in name]

    # epochs at which the weight learning rate is multiplied by gamma
    milestones = [3, 10, 20, 50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 710, 720, 750, 800, 850, 900, 1000, 1200, 1400]

    # optimizers and scheduler (the scheduler only adjusts weight_optimizer)
    arch_optimizer = optim.Adam(arch_params, lr=args.arch_learning_rate)
    weight_optimizer = optim.AdamW(weights, lr=args.learning_rate, eps=3e-3, weight_decay=1e-6)
    scheduler = optim.lr_scheduler.MultiStepLR(weight_optimizer, milestones=milestones, gamma=0.70)


def train(epoch):
    """Run one pass over ``train_loader``, stepping both optimizers per batch.

    Uses the module-level ``model``, ``args``, ``train_loader``,
    ``arch_optimizer`` and ``weight_optimizer``. Increments
    ``model.train_step`` once per batch and prints the loss every
    ``args.log_interval`` steps.

    Args:
        epoch: epoch number, used only in the progress message.
    """
    model.train()

    # scale factor for KL divergence for mini batch training
    num_batches = len(train_loader)
    kl_scale = args.arch_beta * (num_batches / len(train_loader.dataset))

    # iterating over mini batches
    for batch_idx, (data, _) in enumerate(train_loader):
        # Drop singleton dimensions, but keep the leading batch axis: a bare
        # squeeze() would also remove it when a batch holds a single sample
        # (e.g. a trailing partial batch).
        batch_size = data.shape[0]
        data = data.squeeze()
        if batch_size == 1:
            data = data.unsqueeze(0)

        _, loss, enc_kl_arch, dec_kl_arch = model(data, args.S, args.M, args.K)

        arch_optimizer.zero_grad()
        weight_optimizer.zero_grad()

        # total objective = model loss + scaled architecture KL terms
        loss = loss + (enc_kl_arch + dec_kl_arch) * kl_scale
        # NOTE(review): only one backward pass per batch is visible here, so
        # retain_graph=True may be unnecessary - kept in case the model reuses
        # graph state across calls; confirm against the model code.
        loss.backward(retain_graph=True)
        arch_optimizer.step()
        weight_optimizer.step()

        model.train_step += 1
        if model.train_step % args.log_interval == 0:
            progress = 100.0 * batch_idx / num_batches
            print("Train Epoch: {}/{} ({:.0f}%)\tLoss: {:.6f}\t ".format(epoch, args.epochs, progress, loss.item()))


def test(epoch):
    """Evaluate the model on ``test_loader`` and return the negated mean score.

    Uses the module-level ``model``, ``args`` and ``test_loader``. Each batch
    is passed through the model with ``args.test_arch_n``, ``1`` and
    ``args.log_likelihood_k``; the first returned value is reduced with
    ``model.logmeanexp`` over dimension 0, and the results of all batches are
    concatenated and averaged.

    Args:
        epoch: accepted for symmetry with ``train``; not used in the body.

    Returns:
        The negated mean of the concatenated values (a NumPy scalar), which is
        also printed rounded to two decimals.
    """
    model.eval()

    with torch.no_grad():
        elbos = []
        for data, _ in test_loader:
            # Drop singleton dimensions, but keep the leading batch axis: a
            # bare squeeze() would also remove it for a single-sample batch.
            batch_size = data.shape[0]
            data = data.squeeze()
            if batch_size == 1:
                data = data.unsqueeze(0)

            elbo, _, _, _ = model(data, args.test_arch_n, 1, args.log_likelihood_k)
            elbos.append(elbo.squeeze(0))

        # NOTE(review): the slice uses args.S although the model was called
        # with args.test_arch_n - confirm which one is intended here.
        losses = [
            model.logmeanexp(elbo, 0).squeeze(0)[:args.S].cpu().numpy().flatten()
            for elbo in elbos
        ]
        test_loss = -np.concatenate(losses).mean()
        print("-" * 100)
        # NOTE(review): the printed value is the negated mean, despite the label.
        print("ELBO: ", round(test_loss, 2))
        print("-"*100)

        return test_loss

def fit(args_l):
    """Set up a model from ``args_l`` and train it for ``args_l.epochs`` epochs.

    The namespace is stored in the module-level ``args`` so the other
    functions in this module can read it. Progress is appended to
    ``args.loss_logfile``. The model is evaluated once before training, after
    epoch 1, and after every 20th epoch; whenever the evaluation loss improves
    on ``model.best_loss`` the state dict is saved to ``args.model_file``.

    Args:
        args_l: namespace holding the hyperparameters and file paths.
    """

    global args
    args = args_l

    # initializes model, optimizers and data loaders for training
    initialize_model()

    # The context manager closes the log file even if training raises.
    # Lines end with "\n" rather than os.linesep: the file is opened in text
    # mode, which already translates "\n" to the platform line ending.
    with open(args.loss_logfile, "a") as fpr:
        fpr.write(
            f"batch_size = {args.batch_size}; learning_rate = {args.learning_rate}"
            f"; arch_learning_rate = {args.arch_learning_rate}\n"
        )
        fpr.flush()

        # evaluate the saved / freshly initialized model to get the baseline
        with torch.no_grad():
            test_loss = test(0)
        model.best_loss = test_loss

        # start training
        for epoch in range(1, args.epochs + 1):
            train(epoch)

            # evaluate after the first epoch and then every 20 epochs
            if epoch == 1 or epoch % 20 == 0:
                with torch.no_grad():
                    test_loss = test(epoch)
                fpr.write(f"epoch = {epoch}; loss = {test_loss}\n")
                fpr.flush()

                # save this model if the performance is better
                if test_loss < model.best_loss:
                    model.best_loss = test_loss
                    torch.save(model.state_dict(), args.model_file)

            # learning-rate schedule advances once per epoch
            scheduler.step()

def eval(args_l):
    """Evaluate the model described by ``args_l`` once and print the result.

    Stores ``args_l`` in the module-level ``args``, builds the model (which
    restores ``args.model_file`` when it exists), runs ``test`` with epoch 0
    and prints the returned loss.

    NOTE(review): the name shadows the ``eval`` builtin inside this module; it
    is kept because callers use it.
    """
    global args
    args = args_l

    initialize_model()

    with torch.no_grad():
        evaluation_loss = test(0)

    print(evaluation_loss)