"""""""""
PyTorch implementation of "A simple neural network module for relational reasoning".
Code is based on pytorch/examples/mnist (https://github.com/pytorch/examples/tree/master/mnist)
"""""""""
from __future__ import print_function
import argparse
import os
#import cPickle as pickle
import pickle
import random
import numpy as np

import torch
from torch.autograd import Variable

from model import RN, RN_reconstruct
from torch.utils.tensorboard import SummaryWriter

import matplotlib.pyplot as plt


# Training settings
# Command-line interface of the reconstruction training script.
# NOTE(review): with the defaults below, --start-epoch (50) is greater than
# --epochs (20), so the epoch loop at the bottom of this file,
# range(start_epoch, epochs + 1), is empty unless --epochs is raised.
parser = argparse.ArgumentParser(description='PyTorch Relational-Network sort-of-CLEVR Example')
parser.add_argument('--model', type=str, choices=['RN_SRN', 'RN', 'CNN_MLP'], default='RN',
                    help='model variant to train (default: RN)')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                    help='input batch size for training (default: 64)')
parser.add_argument('--start-epoch', type=int, default=50,
                    help='start epoch (default: 50)')
parser.add_argument('--epochs', type=int, default=20, metavar='N',
                    help='number of epochs to train (default: 20)')
parser.add_argument('--lr', type=float, default=0.0001, metavar='LR',
                    help='learning rate (default: 0.0001)')
# The help text used to advertise a default of 1; the real default is 0.1.
parser.add_argument('--inner-lr', type=float, default=0.1,
                    help='inner learning rate (default: 0.1)')
parser.add_argument('--inner-iters', type=int, default=10,
                    help='inner iterations (default: 10)')
parser.add_argument('--sparse-loss', type=float, default=0,
                    help='sparse loss (default: 0)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--resume', type=str,
                    help='resume from model stored')
parser.add_argument('--no-pretrained', action='store_true', default=False,
                    help='start from scratch')

args = parser.parse_args()
# CUDA is used only when it was not disabled on the command line and a device
# is actually available.
args.cuda = not args.no_cuda and torch.cuda.is_available()

# Seed the CPU RNG (and the CUDA RNG when CUDA is used).
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

# Checkpoints are mapped to the CPU while loading so that files written by a
# GPU run can still be restored on a CPU-only machine; the model itself is
# moved to the GPU further down when args.cuda is set.
if args.model == 'RN':
    args.use_srn=False
    model = RN_reconstruct(args)
    if not args.no_pretrained:
        # strict=False: keys that do not match between checkpoint and model are ignored.
        model.load_state_dict(torch.load('model/428_epoch_RN_50.pth', map_location='cpu'), strict=False)
else:
    # NOTE(review): every choice other than 'RN' (i.e. 'RN_SRN' and 'CNN_MLP')
    # ends up here, and this checkpoint is loaded even when --no-pretrained is
    # given -- confirm that both are intended.
    args.use_srn = True
    model = RN_reconstruct(args)
    model.load_state_dict(torch.load('recon_model/epoch_RN_SRN_50.pth', map_location='cpu'))
#     if not args.no_pretrained:
#         model.load_state_dict(torch.load('model/5_25_SRN_40.pth'), strict=False)

  
# Directory searched for --resume checkpoints.
model_dirs = './model'
bs = args.batch_size
# Pre-allocated batch buffers; tensor_data() resizes and refills them in place
# for every mini-batch.
input_img = torch.FloatTensor(bs, 3, 75, 75)
input_qst = torch.FloatTensor(bs, 11)
label = torch.LongTensor(bs)
import time

# One TensorBoard run directory per invocation: the hyper-parameters plus the
# last four characters of the current timestamp.
writer = SummaryWriter(f"reconstruct_run/{args.model}_{args.lr}_{args.inner_lr}_{args.inner_iters}_{args.sparse_loss}_{str(time.time())[-4:]}", purge_step=0, flush_secs = 10)

if args.cuda:
    model.cuda()
    input_img = input_img.cuda()
    input_qst = input_qst.cuda()
    label = label.cuda()

input_img = Variable(input_img)
input_qst = Variable(input_qst)
label = Variable(label)

def tensor_data(data, i):
    img = torch.from_numpy(np.asarray(data[0][bs*i:bs*(i+1)]))
    qst = torch.from_numpy(np.asarray(data[1][bs*i:bs*(i+1)]))
    ans = torch.from_numpy(np.asarray(data[2][bs*i:bs*(i+1)]))

    input_img.data.resize_(img.size()).copy_(img)
    input_qst.data.resize_(qst.size()).copy_(qst)
    label.data.resize_(ans.size()).copy_(ans)


def cvt_data_axis(data):
    """Turn a sequence of 3-field samples into three parallel lists.

    Each element of ``data`` is indexed at positions 0, 1 and 2; the result is
    the tuple ``(all field-0 values, all field-1 values, all field-2 values)``
    in the original sample order.
    """
    def column(index):
        return [sample[index] for sample in data]

    return (column(0), column(1), column(2))

# Exponentially smoothed training loss; updated by train() and kept across epochs.
running_loss = 0
# Only the parameters of model.reconstruct are handed to the optimizer.
optimizer = torch.optim.Adam(model.reconstruct.parameters(), lr=args.lr)


def train(epoch, rel):
    """Train the reconstruction model for one pass over ``rel``.

    ``rel`` is a list of ``(img, qst, ans)`` samples; it is shuffled in place.
    The list is consumed in full batches of ``bs`` samples (a trailing partial
    batch is dropped). For every batch the model output ``recon`` is compared
    against ``1 - input_img`` with a summed squared error, which is logged to
    TensorBoard under ``train/loss`` and back-propagated. Every 1000th batch
    additionally logs figures of the first sample's generated set, its
    reconstruction and its target.

    ``epoch`` is only used for the TensorBoard step and figure tags.
    """
    global running_loss

    def log_image(tag, tensor, step=None):
        # Render one image tensor (axes 0 and 2 swapped for imshow) as a figure.
        fig = plt.figure()
        plt.imshow(tensor.transpose(0, 2).detach().cpu())
        writer.add_figure(tag, fig, global_step=step)

    model.train()
    random.shuffle(rel)

    columns = cvt_data_axis(rel)
    num_batches = len(columns[0]) // bs
    smoothing = 0.05
    for step in range(num_batches):
        optimizer.zero_grad()
        tensor_data(columns, step)
        recon, gen_set, losses = model(input_img)
        target_img = 1 - input_img
        loss = ((target_img - recon) ** 2).sum()
        batch_loss = loss.detach().cpu().item()
        writer.add_scalar("train/loss", batch_loss, global_step=epoch * num_batches + step)

        if step % 1000 == 0:
            tag = f"epoch-{epoch}/img-{step}"
            # Elements of the first sample's generated set, then its
            # reconstruction as the final step of the same tag.
            for j, element in enumerate(gen_set[0]):
                log_image(tag, element, j)
            log_image(tag, recon[0], len(gen_set[0]))
            log_image(f"epoch-{epoch}/img-{step}-target", target_img[0])

        running_loss = smoothing * batch_loss + (1 - smoothing) * running_loss
        print(f'running loss [{step}/{num_batches}]: {running_loss}')

        print('inner losses: ', [round(l.item()/bs, 2) for l in losses])
        loss.backward()
        optimizer.step()
        
def eval(epoch, rel):
    """Report the mean per-batch reconstruction loss of the model on ``rel``.

    ``rel`` is a list of ``(img, qst, ans)`` samples, consumed in full batches
    of ``bs`` samples (a trailing partial batch is dropped). The per-batch loss
    is the summed squared error between the model output ``recon`` and
    ``1 - input_img``; the mean over all batches is printed and logged to
    TensorBoard under ``eval/loss``. Every 10th batch also logs figures of the
    first sample's generated set, its reconstruction and its target.

    If ``rel`` holds fewer than ``bs`` samples there is nothing to average, so
    a message is printed and the function returns without logging (previously
    this case raised ZeroDivisionError).

    NOTE: the name shadows the builtin ``eval``; it is kept because the epoch
    loop in this file calls it by this name.
    NOTE(review): there is deliberately no ``torch.no_grad()`` here -- the
    model is configured with inner-lr / inner-iters options, so its forward
    pass may depend on autograd; confirm before adding it.
    """
    model.eval()
    rel = cvt_data_axis(rel)
    num_batches = len(rel[0]) // bs
    if num_batches == 0:
        print(f'Eval skipped: {len(rel[0])} samples do not fill a batch of {bs}')
        return

    running_loss = 0
    for batch_idx in range(num_batches):
        tensor_data(rel, batch_idx)
        recon, gen_set, losses = model(input_img)
        target_img = 1 - input_img
        loss = ((target_img - recon)**2).sum()
        running_loss += loss.detach().cpu().item()
        if batch_idx % 10 == 0:
            # Elements of the first sample's generated set, one figure per step.
            for j, s_ in enumerate(gen_set[0].detach().cpu()):
                fig = plt.figure()
                plt.imshow(s_.transpose(0,2))
                writer.add_figure(f"epoch-eval-{epoch}/img-{batch_idx}", fig, global_step=j)

            # The reconstruction goes last under the same tag.
            fig = plt.figure()
            plt.imshow(recon[0].detach().cpu().transpose(0,2))
            writer.add_figure(f"epoch-eval-{epoch}/img-{batch_idx}", fig, global_step = len(gen_set[0]))

            fig = plt.figure()
            plt.imshow(target_img[0].detach().cpu().transpose(0,2))
            writer.add_figure(f"epoch-eval-{epoch}/img-{batch_idx}-target", fig)
    running_loss /= num_batches
    print(f'Eval loss: {running_loss}')
    # Same step value as before: the index of the last evaluated batch.
    writer.add_scalar("eval/loss", running_loss, global_step=epoch*num_batches + (num_batches - 1))

def _split_questions(datasets):
    """Flatten ``(img, relations, norelations)`` records into per-question samples.

    Returns ``(rel, norel)``: two lists of ``(img, qst, ans)`` tuples, where
    ``img`` is the record's image with axes 0 and 2 swapped.
    """
    rel, norel = [], []
    for img, relations, norelations in datasets:
        img = np.swapaxes(img, 0, 2)
        rel.extend((img, qst, ans) for qst, ans in zip(relations[0], relations[1]))
        norel.extend((img, qst, ans) for qst, ans in zip(norelations[0], norelations[1]))
    return rel, norel


def load_data(dirs='./recon_data_small', name='sort-of-clevr.pickle'):
    """Load the pickled dataset and expand it into per-question samples.

    Args:
        dirs: directory containing the pickle file.
        name: file name of the pickle inside ``dirs``.

    Returns:
        ``(rel_train, rel_test, norel_train, norel_test)``, each a list of
        ``(img, qst, ans)`` tuples.

    Raises:
        OSError: if the pickle file cannot be opened.
    """
    print('loading data...')
    filename = os.path.join(dirs, name)
    # NOTE: pickle.load can execute arbitrary code; only point this at
    # dataset files from a trusted source.
    with open(filename, 'rb') as f:
        train_datasets, test_datasets = pickle.load(f)
    print('processing data...')

    rel_train, norel_train = _split_questions(train_datasets)
    rel_test, norel_test = _split_questions(test_datasets)

    return (rel_train, rel_test, norel_train, norel_test)
    

rel_train, rel_test, norel_train, norel_test = load_data()

# Only an already-existing directory is tolerated; any other failure
# (e.g. a permission error) is allowed to surface instead of being hidden.
try:
    os.makedirs(model_dirs)
except FileExistsError:
    print('directory {} already exists'.format(model_dirs))

# Epoch checkpoints are written to recon_model/, which is not the directory
# created above; make sure it exists so torch.save cannot fail on a missing
# directory after a full epoch of training.
os.makedirs('recon_model', exist_ok=True)

if args.resume:
    filename = os.path.join(model_dirs, args.resume)
    if os.path.isfile(filename):
        print('==> loading checkpoint {}'.format(filename))
        # Map to CPU so a checkpoint saved on a GPU also loads on a CPU-only
        # machine; load_state_dict copies the values into the model's tensors.
        checkpoint = torch.load(filename, map_location='cpu')
        model.load_state_dict(checkpoint)
        print('==> loaded checkpoint {}'.format(filename))
    else:
        print('==> no checkpoint found at {}'.format(filename))

if args.start_epoch > args.epochs:
    print('nothing to do: start epoch {} is greater than final epoch {}'.format(
        args.start_epoch, args.epochs))

# Each epoch evaluates first, then trains, then saves the resulting weights.
for epoch in range(args.start_epoch, args.epochs + 1):
    eval(epoch, rel_test)
    train(epoch, rel_train)
    torch.save(model.state_dict(), 'recon_model/epoch_{}_{:02d}.pth'.format(args.model, epoch))

