import argparse
import os
import sys
import time
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# Prevent python from saving out .pyc files (set before the project-local
# imports below so that they, too, leave no bytecode behind)
sys.dont_write_bytecode = True
# Add models and tasks to path so that main() can load a task / model module
# by name via __import__(args.task) / __import__(args.model_name).
# Each insert goes to the front of sys.path, so './tasks' ends up searched
# before './models'.
# NOTE(review): these paths are relative to the current working directory,
# not to this file; presumably the script is meant to be launched from the
# directory containing models/ and tasks/ — confirm.
sys.path.insert(0, './models')
sys.path.insert(0, './tasks')
# Logging utility (project-local `util` module; provides `log.info` used below)
from util import log

# Method for creating directory if it doesn't exist yet
def check_path(path):
	"""Create directory ``path`` (and any missing parents) if it does not exist.

	``os.makedirs(..., exist_ok=True)`` is used instead of an exists-check followed
	by ``os.mkdir`` so that nested paths work in a single call and a directory
	created concurrently between the check and the mkdir does not raise.

	Raises:
		FileExistsError: if ``path`` exists but is not a directory.
	"""
	os.makedirs(path, exist_ok=True)

class seq_dataset(Dataset):
	"""Map-style dataset over pre-generated sequences and their labels.

	``dset`` must provide ``'seq'`` and ``'y'`` entries that support integer
	indexing; ``'seq'`` must also expose ``.shape`` (e.g. NumPy arrays or
	tensors). Item ``i`` is the pair ``(dset['seq'][i], dset['y'][i])``.
	``args`` is accepted for call-site compatibility but is not used.
	"""

	def __init__(self, dset, args):
		sequences, labels = dset['seq'], dset['y']
		self.seq = sequences
		self.y = labels
		# Number of samples is taken from the leading dimension of 'seq'.
		self.len = sequences.shape[0]

	def __len__(self):
		return self.len

	def __getitem__(self, idx):
		return self.seq[idx], self.y[idx]


def plot(tensor):
	"""Debug helper: draw every image in ``tensor`` side by side in one row.

	If ``tensor`` has rank 3, each ``tensor[i]`` is drawn as a grayscale image
	and the figure is saved to ``svg/temp.svg``. Otherwise each ``tensor[i]`` is
	transposed with ``(1, 2, 0)`` (presumably channel-first -> channel-last) and
	the figure is shown interactively.

	Args:
		tensor: indexable stack of images; ``len(tensor)`` panels are drawn.
	"""
	import matplotlib.pyplot as plt
	num_img = len(tensor)
	# squeeze=False always returns a 2-D array of Axes; with the default, a
	# single image yields a bare Axes object and indexing it would fail.
	fig, axes = plt.subplots(1, num_img, squeeze=False)
	ax_list = axes[0]

	grayscale = len(tensor.shape) == 3
	for ax, img in zip(ax_list, tensor):
		if grayscale:
			ax.imshow(np.array(img), cmap='gray')
		else:
			ax.imshow(np.array(img).transpose((1, 2, 0)))
		ax.set_xticks([])
		ax.set_yticks([])

	if grayscale:
		plt.tight_layout(pad=0)
		# savefig does not create missing directories.
		os.makedirs('svg', exist_ok=True)
		plt.savefig('svg/temp.svg', format='svg', dpi=1200)
		# Release the figure so repeated calls do not accumulate open figures.
		plt.close(fig)
	else:
		plt.show()


def train(args, model, device, optimizer, epoch, train_loader):
	"""Run one training epoch and log the mean loss and accuracy.

	Args:
		args: parsed command-line arguments (not used here; kept so ``train``
			and ``test`` share a calling convention).
		model: module called as ``model(x_seq, device)`` returning
			``(y_pred_linear, y_pred)``; ``y_pred_linear`` is fed to
			``nn.CrossEntropyLoss`` and ``y_pred`` is compared element-wise
			with ``y`` for accuracy.
		device: torch device each batch is moved to.
		optimizer: optimizer stepped once per batch.
		epoch: epoch number, used only in the log line.
		train_loader: iterable of ``(x_seq, y)`` batches.
	"""
	# Set to training mode
	model.train()
	loss_fn = nn.CrossEntropyLoss()

	loss_sum = 0.0
	acc_sum = 0.0
	num_batches = 0
	for x_seq, y in train_loader:
		# Load data to device
		x_seq = x_seq.to(device)
		y = y.to(device)
		# Zero out gradients for optimizer
		optimizer.zero_grad()
		# Run model
		y_pred_linear, y_pred = model(x_seq, device)
		loss = loss_fn(y_pred_linear, y.long())
		# Update model
		loss.backward()
		optimizer.step()

		# Accumulate Python floats, not the loss tensor: summing tensors keeps
		# each batch's autograd graph alive until the end of the epoch.
		loss_sum += loss.item()
		acc_sum += torch.eq(y_pred, y).float().mean().item() * 100.0
		num_batches += 1

	if num_batches == 0:
		log.info('[Epoch: ' + str(epoch) + '] no training batches')
		return
	log.info('[Epoch: ' + str(epoch) + '] '
			 + '[Loss = ' + '{:.4f}'.format(loss_sum / num_batches) + '] '
			 + '[Accuracy = ' + '{:.2f}'.format(acc_sum / num_batches) + '] ')


def test(args, model, device, test_loader):
	"""Evaluate ``model`` on ``test_loader``, log the result and write it to disk.

	Loss and accuracy are per-sample averages: each batch's mean is weighted by
	its number of samples, so a smaller final batch does not skew the result.
	Results are written to
	``./test/<task>/m<m_holdout>/<model_name>/run<run>.txt`` as a header line
	``loss acc`` followed by the two values.

	Args:
		args: parsed arguments; ``task``, ``m_holdout``, ``model_name`` and
			``run`` are read to build the output path.
		model: module called as ``model(x_seq, device)`` returning
			``(y_pred_linear, y_pred)``.
		device: torch device each batch is moved to.
		test_loader: iterable of ``(x_seq, y)`` batches.
	"""
	log.info('Evaluating on test set...')
	# Set to eval mode
	model.eval()
	loss_fn = nn.CrossEntropyLoss()

	loss_total = 0.0
	acc_total = 0.0
	num_samples = 0
	# Gradients are not needed for evaluation; without no_grad every forward
	# pass would build an autograd graph that is immediately thrown away.
	with torch.no_grad():
		for x_seq, y in test_loader:
			x_seq = x_seq.to(device)
			y = y.to(device)

			# Run model
			y_pred_linear, y_pred = model(x_seq, device)
			loss = loss_fn(y_pred_linear, y.long())

			# DataLoader collates samples along dim 0, so len(y) is the batch size.
			n = len(y)
			loss_total += loss.item() * n
			acc_total += torch.eq(y_pred, y).float().mean().item() * 100.0 * n
			num_samples += n

	# Report overall test performance
	if num_samples > 0:
		avg_loss = loss_total / num_samples
		avg_acc = acc_total / num_samples
	else:
		# Same NaN that np.mean([]) yields, without its RuntimeWarning.
		avg_loss = avg_acc = float('nan')

	log.info('[Test Summary] '
			 + '[Loss = ' + '{:.4f}'.format(avg_loss) + '] '
			 + '[Accuracy = ' + '{:.2f}'.format(avg_acc) + ']')

	# Save performance
	model_dir = os.path.join('.', 'test', args.task, 'm' + str(args.m_holdout), args.model_name)
	os.makedirs(model_dir, exist_ok=True)
	test_fname = os.path.join(model_dir, 'run' + args.run + '.txt')
	with open(test_fname, 'w') as test_f:
		test_f.write('loss acc\n')
		test_f.write('{:.4f}'.format(avg_loss) + ' ' + '{:.2f}'.format(avg_acc))

def _parse_args():
	"""Build the command-line parser and return the parsed arguments."""
	parser = argparse.ArgumentParser()
	# Model settings
	parser.add_argument('--model_name', type=str, default='ESBN', help="{'ESBN', 'Transformer', 'NTM', 'LSTM', 'PrediNet', 'RN', 'MNM', 'TRN', 'ESBN_confidence_ablation', 'ESBN_default_memory'}")
	parser.add_argument('--norm_type', type=str, default='contextnorm', help="{'nonorm', 'contextnorm', 'tasksegmented_contextnorm'}")
	parser.add_argument('--encoder', type=str, default='conv', help="{'c4_group_cnn', 'mlp', 'resnet'}")

	parser.add_argument('--use_memory', type=int, default=1, help="whether use memory or not - used for FINE model")
	parser.add_argument('--num_memory', type=int, default=16, help="number of memory matrices - used for FINE model")
	parser.add_argument('--num_memory_layer', type=int, default=4, help="number of NICE layers - used for FINE model")
	parser.add_argument('--pseudoinverse_init', type=float, default=1e-4, help="initialization of Ben-Cohen algo for approximating pseudo-inverses")

	# Task settings
	parser.add_argument('--task', type=str, default='same_diff', help="{'transformation'}")
	parser.add_argument('--n_shapes', type=int, default=100, help="n = total number of shapes available for training and testing")
	parser.add_argument('--m_holdout', type=int, default=0, help="m = number of objects (out of n) withheld during training")
	parser.add_argument('--transformation_mixture', type=str, default='translation,rotation,scale,shear,reflection',
						help="transformations used in mixture tasks")
	parser.add_argument('--generalization_type', type=str, default='object', help="{object, object+function}")

	# Training settings
	parser.add_argument('--num_train', type=int, default=10000, help="number of training data points")
	parser.add_argument('--train_batch_size', type=int, default=32)
	parser.add_argument('--lr', type=float, default=5e-4)
	parser.add_argument('--weight_decay', type=float, default=0)
	parser.add_argument('--epochs', type=int, default=50)
	parser.add_argument('--transformation_method', type=str)
	parser.add_argument('--log_interval', type=int, default=10)

	# Test settings
	parser.add_argument('--num_test', type=int, default=20000, help="number of testing data points")
	parser.add_argument('--test_batch_size', type=int, default=100)

	# Device settings
	parser.add_argument('--no-cuda', action='store_true', default=False)
	parser.add_argument('--device', type=int, default=0)

	# Run number
	parser.add_argument('--run', type=str, default='1')
	return parser.parse_args()


def _split_shapes(args):
	"""Shuffle the pool of object ids and split it into train / test ids.

	The pool is ``np.arange(900)`` when the transformation method mentions
	``omniglot``, ``np.arange(100)`` when it mentions ``cifar100-``, and
	``np.arange(args.n_shapes)`` otherwise. With ``args.m_holdout > 0`` the first
	``m_holdout`` shuffled ids are held out for testing and the rest are used for
	training; otherwise both splits share the full pool.

	Returns:
		``(train_shapes, test_shapes)`` as NumPy arrays.
	"""
	# --transformation_method has no default, so it is None unless given;
	# treat that as "no special dataset" instead of crashing on `in None`.
	method = args.transformation_method or ''
	if 'omniglot' in method:
		all_shapes = np.arange(900)
	elif 'cifar100-' in method:
		all_shapes = np.arange(100)
	else:
		all_shapes = np.arange(args.n_shapes)

	np.random.shuffle(all_shapes)
	if args.m_holdout > 0:
		return all_shapes[args.m_holdout:], all_shapes[:args.m_holdout]
	return all_shapes, all_shapes


def main():
	"""Parse arguments, generate the task, build the model, then train and test."""
	args = _parse_args()

	# Set up cuda
	use_cuda = not args.no_cuda and torch.cuda.is_available()
	device = torch.device("cuda:" + str(args.device) if use_cuda else "cpu")

	# Randomly assign objects to training or test set
	train_shapes, test_shapes = _split_shapes(args)

	# Generate training and test sets; the task module is imported by name
	task_gen = __import__(args.task)
	log.info('Generating task: ' + args.task + '...')
	args, train_set, test_set = task_gen.create_task(args, train_shapes, test_shapes)

	# Convert to PyTorch DataLoaders
	train_loader = DataLoader(seq_dataset(train_set, args), batch_size=args.train_batch_size, shuffle=True)
	test_loader = DataLoader(seq_dataset(test_set, args), batch_size=args.test_batch_size, shuffle=True)

	# Create model; the model module is imported by name
	model_class = __import__(args.model_name)
	model = model_class.Model(task_gen, args).to(device)

	# Append relevant hyperparameter values to model name (test() uses
	# args.model_name as a directory in its output path)
	args.model_name = args.model_name + '_' + args.norm_type + '_lr' + str(args.lr)

	# Create optimizer
	log.info('Setting up optimizer...')
	optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

	# Train
	log.info('Training begins...')
	for epoch in range(1, args.epochs + 1):
		train(args, model, device, optimizer, epoch, train_loader)
		# Periodic evaluation. Skipped on the last epoch because the final
		# evaluation below would repeat it on the very same weights.
		if epoch % 10 == 0 and epoch < args.epochs:
			test(args, model, device, test_loader)
	# Test model
	test(args, model, device, test_loader)


if __name__ == '__main__':
	# Run training/evaluation only when executed as a script, not on import.
	main()
