import os
import os.path as osp
import numpy as np
from copy import deepcopy
import argparse
import csv

import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

from dataset import setup_dataset
from models import setup_model

def target_to_roi(target):
    """Return the ROI names associated with a prediction target.

    Args:
        target: Target identifier; one of 'c', 'init_latent' or 'blip'.

    Returns:
        A newly created list of ROI name strings for ``target``.

    Raises:
        ValueError: If ``target`` is not one of the supported identifiers.
    """
    rois_by_target = {
        'c': ('ventral',),
        'init_latent': ('early',),
        'blip': ('early', 'ventral', 'midventral',
                 'midlateral', 'lateral', 'parietal'),
    }
    try:
        rois = rois_by_target[target]
    except (KeyError, TypeError):
        # Fail here with a clear message instead of returning None and
        # letting callers crash later (e.g. in '_'.join(...)).
        raise ValueError(
            f'unknown target {target!r}; '
            f'expected one of {sorted(rois_by_target)}'
        ) from None
    # Hand back a fresh list so callers may mutate it freely.
    return list(rois)

def write_metrics(row, sweep_dir):
    """Append ``row`` as one CSV line to ``train_<target>.csv`` in ``sweep_dir``.

    The file name is built from the module-level ``target`` variable, so
    that name must be defined before this function is called. The file is
    opened in append mode, so existing content is kept.

    Args:
        row: Iterable of values written as a single CSV record.
        sweep_dir: Directory that holds (or will hold) the CSV file.
    """
    out_path = osp.join(sweep_dir, f'train_{target}.csv')
    with open(out_path, mode='a', newline='') as handle:
        csv.writer(handle).writerow(row)

def save_preds_and_model(model, dl_te, root_dir):
    """Run ``model`` over ``dl_te`` and store its predictions and weights.

    Reads the module-level ``opt`` (for ``opt.gpu``) and ``target`` names.
    The model is switched to eval mode and left in that mode.

    Args:
        model: Model to evaluate and whose ``state_dict`` is saved.
        dl_te: Iterable yielding ``(x, _)`` batches; only ``x`` is used.
        root_dir: Directory that receives both output files.
    """
    model.eval()

    batch_outputs = []
    for inputs, _ in dl_te:
        with torch.no_grad():
            batch_outputs.append(model(inputs.to(opt.gpu)).cpu())
    stacked = torch.cat(batch_outputs).numpy()

    # Predictions: the path has no extension, so np.save adds '.npy'.
    roi_tag = '_'.join(target_to_roi(target))
    np.save(osp.join(root_dir, f'{target}_preds_from_{roi_tag}'), stacked)

    # Weights of the model that produced the predictions above.
    weights_path = osp.join(root_dir,
                            f'best_model_state_dict_target={target}.pt')
    torch.save(model.state_dict(), weights_path)

def eval(model, dl_va):
    """Return the average per-batch MSE of ``model`` over ``dl_va``.

    The model is put in eval mode for the pass and switched back to train
    mode before returning. Batches are moved to the device given by the
    module-level ``opt.gpu``.

    NOTE: this function shadows the builtin ``eval`` within this module.

    Args:
        model: Model to evaluate.
        dl_va: Iterable of ``(x, y)`` batches that also supports ``len()``.

    Returns:
        Sum of the batch losses divided by the number of batches.
    """
    model.eval()
    loss_fn = torch.nn.MSELoss()
    running_total = 0.0
    for inputs, targets in dl_va:
        inputs = inputs.to(opt.gpu)
        targets = targets.to(opt.gpu)
        with torch.no_grad():
            outputs = model(inputs)
        running_total += loss_fn(targets, outputs).item()
    mean_loss = running_total / len(dl_va)
    model.train()
    return mean_loss

def single_sweep_setting(model, lr, wd, dl_tr, dl_va, epochs, sweep_dir,
                         patience=15):
    """Train ``model`` with one (lr, wd) setting and keep the best checkpoint.

    Trains with AdamW and an MSE loss, evaluates on ``dl_va`` after every
    epoch, and appends one metrics row per epoch (plus a header row) via
    ``write_metrics``. Training stops early once the validation loss has
    failed to improve for more than ``patience`` consecutive epochs.

    Args:
        model: Model to train; it is moved to ``opt.gpu`` and trained in place.
        lr: Learning rate passed to AdamW.
        wd: Weight decay passed to AdamW.
        dl_tr: Training loader yielding ``(x, y)`` batches; must support ``len()``.
        dl_va: Validation loader passed to ``eval``.
        epochs: Maximum number of epochs to run.
        sweep_dir: Directory passed to ``write_metrics`` for the CSV log.
        patience: Number of non-improving epochs tolerated before stopping.

    Returns:
        Tuple ``(min_val_loss, best_model)``. ``best_model`` is a deep copy
        of the model taken at the best validation epoch. If no epoch ever
        improved on the initial value (e.g. ``epochs == 0`` or every
        validation loss was NaN) the result is ``(float('inf'), None)``.
    """
    model = model.to(opt.gpu)
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=wd)
    # NOTE(review): the cosine period is taken from the global ``opt.epochs``
    # rather than the ``epochs`` argument, so a run with fewer epochs only
    # covers the first part of the schedule -- confirm this is intended.
    scheduler = CosineAnnealingLR(optimizer, opt.epochs)
    criterion = torch.nn.MSELoss()

    write_metrics(['epoch', 'train_loss', 'val_loss', 'lr'], sweep_dir)

    best_model = None
    # Start from +inf so the first finite validation loss always counts as
    # an improvement; a fixed numeric ceiling would leave best_model as None
    # whenever the loss scale happens to exceed it.
    min_val_loss = float('inf')
    epochs_since_best = 0
    for epoch in range(epochs):

        if epochs_since_best > patience:
            break

        avg_loss = 0.0
        for x, y in dl_tr:
            x, y = x.to(opt.gpu), y.to(opt.gpu)
            y_pred = model(x)
            loss = criterion(y, y_pred)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            avg_loss += loss.item()
        # Record the lr that was used during this epoch, before stepping.
        epoch_lr = optimizer.param_groups[0]['lr']
        scheduler.step()
        avg_loss /= len(dl_tr)

        curr_val_loss = eval(model, dl_va)
        if curr_val_loss < min_val_loss:
            min_val_loss = curr_val_loss
            # Snapshot the whole model; training continues on the original.
            best_model = deepcopy(model)
            epochs_since_best = 0
        else:
            epochs_since_best += 1

        write_metrics([epoch, avg_loss, curr_val_loss, epoch_lr], sweep_dir)

    return min_val_loss, best_model

def main_single_target(target):
    """Sweep lr/wd for one target, retrain with the best pair, and save results.

    Outputs go under ``<subject>/<bottleneck_model>/bottleneck_dim=<dim>``
    (values taken from the module-level ``opt``), with one ``lr=.._wd=..``
    sub-directory per sweep setting.

    Args:
        target: Target identifier passed to ``target_to_roi`` and
            ``setup_dataset``.

    Raises:
        RuntimeError: If no sweep setting produced a usable validation loss,
            or if the final training run returned no model.
    """
    root_dir = osp.join(opt.subject,
                        opt.bottleneck_model,
                        f'bottleneck_dim={opt.bottleneck_dim}')
    os.makedirs(root_dir, exist_ok=True)

    # set up the dataset
    ds_tr, ds_va, ds_te = setup_dataset(opt.subject, target_to_roi(target),
                                         target, split_fraction=opt.split_fraction)
    dl_tr = DataLoader(ds_tr, batch_size=opt.bs_tr, shuffle=True,
                       num_workers=opt.num_workers)
    # Validation is not shuffled: the validation loss is a mean of per-batch
    # means, so with a short final batch a shuffled order would change the
    # value from epoch to epoch even for identical weights.
    dl_va = DataLoader(ds_va, batch_size=opt.bs_va, shuffle=False,
                       num_workers=opt.num_workers)
    dl_te = DataLoader(ds_te, batch_size=opt.bs_te, shuffle=False,
                       num_workers=opt.num_workers, drop_last=False)

    # use bottleneck dimension, input and output dimensions to set up the model
    orig_model = setup_model(opt.bottleneck_model, opt.bottleneck_dim,
                             ds_te.X.shape[1], ds_te.Y.shape[1])

    # sweep over the hyperparameters
    # +inf (not a fixed ceiling) so any finite validation loss can win.
    global_min_val_loss = float('inf')
    best_lr, best_wd = None, None
    for lr in opt.lrs:
        for wd in opt.wds:

            sweep_dir = osp.join(root_dir, f'lr={lr}_wd={wd}')
            os.makedirs(sweep_dir, exist_ok=True)

            # Every setting starts from the same initial weights; only the
            # validation loss is kept, the sweep model itself is discarded.
            val_loss, _ = single_sweep_setting(deepcopy(orig_model),
                                               lr, wd, dl_tr, dl_va,
                                               opt.sweep_epochs, sweep_dir)

            if val_loss < global_min_val_loss:
                global_min_val_loss = val_loss
                best_lr, best_wd = lr, wd

    if best_lr is None:
        raise RuntimeError(
            f'hyperparameter sweep for target {target!r} selected no '
            f'(lr, wd) pair; check --lrs/--wds and the validation losses')

    # retrain from the initial weights with the selected setting
    _, model = single_sweep_setting(deepcopy(orig_model),
                                    best_lr, best_wd, dl_tr, dl_va,
                                    opt.epochs, root_dir)
    if model is None:
        raise RuntimeError(
            f'final training for target {target!r} produced no model '
            f'(lr={best_lr}, wd={best_wd})')
    save_preds_and_model(model, dl_te, root_dir)

if __name__ == '__main__':

    # Command-line configuration. The parsed namespace is stored in the
    # module-level name ``opt``, which the functions in this file read.
    parser = argparse.ArgumentParser()
    # --subject and --bottleneck_model are path components of the output
    # directory; leaving them unset (None) fails inside osp.join, so they
    # are required and argparse reports a usage error instead.
    parser.add_argument("--subject", type=str, required=True,
                        help="subj01 or subj02  or subj05  or subj07")
    parser.add_argument("--bottleneck_dim", type=int, default=None,
                        help="1, 5, 10, etc.")
    parser.add_argument("--bottleneck_model", type=str, required=True,
                        help="check options in models.py")
    parser.add_argument("--gpu", type=int, default=None)

    # NOTE these don't really change
    parser.add_argument("--lrs", nargs='+', type=float, default=[1e-5, 1e-4, 1e-3, 1e-2],
                        help='list of lrs to sweep over')
    parser.add_argument("--wds", nargs='+', type=float, default=[1e-3, 1e-2, 1e-1, 5e-1],
                        help='list of wds to sweep over')
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--sweep_epochs", type=int, default=25)
    parser.add_argument("--bs_tr", type=int, default=1024)
    parser.add_argument("--bs_va", type=int, default=350)
    parser.add_argument("--bs_te", type=int, default=100)
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument("--split_fraction", type=float, default=0.1)

    opt = parser.parse_args()

    # The loop variable must be named ``target``: write_metrics and
    # save_preds_and_model read that module-level name when building
    # their output file names.
    for target in ('init_latent', 'blip'):
        main_single_target(target)