import torch
import torch.nn as nn
from util import log
import numpy as np
from modules import *


class Model(nn.Module):
	"""
	Encode a sequence of frames, meta-learn (per example) a two-layer map from the first
	frame's embedding to the second's, apply it to the third frame's embedding, and score
	every remaining frame (the candidates) by a weighted squared distance to that prediction.
	"""

	def __init__(self, task_gen, args):
		"""
		:param task_gen: accepted for interface compatibility; not used by this constructor
		:param args: config namespace; reads encoder, transformation_method, num_memory,
			num_memory_layer, use_memory (and pseudoinverse_init later in forward)
		:raises ValueError: if args.encoder is not a known encoder name
		"""
		super().__init__()
		self.args = args

		# Encoder
		log.info('Building encoder...')
		if args.encoder == 'c4_group_cnn':
			# 3 input channels for cifar100-based transformations, 1 otherwise.
			in_channels = 3 if 'cifar100' in args.transformation_method else 1
			self.encoder = C4_Encoder_group_cnn(in_channels=in_channels)
		elif args.encoder == 'mlp':
			self.encoder = Encoder_mlp(args)
		elif args.encoder == 'resnet':
			self.encoder = Encoder_ResNet()
		else:
			# Fail fast: otherwise self.encoder is never set and forward() dies with AttributeError.
			raise ValueError("Unknown encoder: {!r}".format(args.encoder))

		# NOTE(review): assumes every encoder emits z_size-dimensional embeddings — confirm.
		self.z_size = 128
		self.mid_size = 256
		self.num_memory = args.num_memory
		self.num_memory_layer = args.num_memory_layer  # not read elsewhere in this class

		self.activation = nn.ReLU()
		# Projects z_2 into the hidden (mid_size) space used as layer 0's pseudo-target.
		# Kept as a one-element ModuleList so state_dict keys stay "phi.0.0.*".
		self.phi = nn.ModuleList([
			nn.Sequential(nn.Linear(self.z_size, self.mid_size), self.activation),
		])

		if self.args.use_memory == 1:
			# One key/value bank per meta-learned layer (index 0 and 1). Each slot holds
			# z_size * mid_size values, the same element count as that layer's query weight;
			# _meta_attention only relies on that count (it flattens and reshapes).
			self.memory_key_list = nn.Parameter(torch.randn(2, self.num_memory, self.z_size, self.mid_size),
												requires_grad=True)
			self.memory_value_list = nn.Parameter(torch.randn(2, self.num_memory, self.z_size, self.mid_size),
												  requires_grad=True)

		# Learned log-weights for each embedding dimension in the candidate distance.
		self.diff_coef = nn.Parameter(torch.zeros(self.z_size), requires_grad=True)

	def forward(self, x_seq, device):
		"""
		Score the candidate frames of a batch of sequences.

		Frames 0-2 are the context; every later frame is a candidate.

		:param x_seq: either (batch_size, T, H, W), in which case a size-1 channel dim is
			inserted per frame, or a higher-rank tensor whose per-frame slice x_seq[:, t]
			is fed to the encoder directly
		:param device: unused; kept for interface compatibility
		:return: (y_pred_linear, y_pred) -- scores of size (batch_size, T - 3), non-positive
			with higher meaning closer to the prediction, and the argmax candidate index
		:raises ValueError: if there are no candidate frames (T <= 3)
		"""
		num_context = 3
		num_frames = x_seq.shape[1]
		if num_frames <= num_context:
			raise ValueError(
				"x_seq must contain more than {} frames (context + candidates), got {}".format(
					num_context, num_frames))

		# Encode all images in sequence
		z_frames = []
		for t in range(num_frames):
			if x_seq.dim() == 4:
				x_t = x_seq[:, t].unsqueeze(1)
			else:
				x_t = x_seq[:, t]
			z_frames.append(self.encoder(x_t))
		z_1, z_2, z_3 = z_frames[:num_context]
		z_choices = torch.stack(z_frames[num_context:], dim=1)  # (batch_size, num_choices, z_size)

		# Meta-learn weights of a two-layer net that maps z_1 to z_2
		use_memory = self._memory_enabled()
		W_list = []
		z_in = z_1
		for i in range(2):
			# Layer 0 targets phi(z_2) in the hidden space; layer 1 targets z_2 itself.
			z_pseudotarget = self.phi[0](z_2) if i == 0 else z_2
			W = self._meta_find_query(z_in=z_in, z_target=z_pseudotarget, device=device)  # (batch_size, target_size, input_size)
			if use_memory:
				W = self._meta_attention(index=i, W=W, M_key=self.memory_key_list[i],
										 M_value=self.memory_value_list[i], device=device)
			W_list.append(W)
			z_in = self._meta_forward(index=i, z=z_in, W=W, device=device)

		# Apply the learned net to z_3 to get the predicted embedding
		z_in = z_3
		for i, W in enumerate(W_list):
			z_in = self._meta_forward(index=i, z=z_in, W=W, device=device)
		z_predicted = torch.clip(z_in, min=-1e6, max=1e6)

		# Broadcast the prediction over the candidate axis (works for any number of candidates).
		diff = torch.exp(self.diff_coef) * (z_choices - z_predicted.unsqueeze(1)) ** 2
		y_pred_linear = -torch.sum(diff, dim=-1) / self.z_size
		y_pred_linear = torch.clip(y_pred_linear, min=-1e6)
		y_pred = y_pred_linear.argmax(1)
		return y_pred_linear, y_pred

	def _memory_enabled(self):
		"""
		Return True when the key/value memory banks exist and are non-empty.

		The banks are only created when args.use_memory == 1, so checking
		args.num_memory alone would reference missing parameters.
		"""
		return self.args.use_memory == 1 and self.args.num_memory > 0

	def _meta_forward(self, index, z, W, device):
		"""
		Apply one meta-learned linear layer: W @ z, followed by ReLU for layer 0 only.

		:param index: layer index; 0 applies the activation, anything else is linear
		:param z: of size (batch_size, input_size)
		:param W: of size (batch_size, output_size, input_size)
		:param device: unused
		:return: of size (batch_size, output_size)
		"""
		out = torch.bmm(W, z.unsqueeze(-1)).squeeze(-1)
		if index == 0:
			out = self.activation(out)
		return out

	def _meta_attention(self, index, W, M_key, M_value, device):
		"""
		Replace a query weight by an attention-weighted sum of memory values.

		The scaled dot product of the flattened query with each flattened key (divided by
		sqrt(A*B) and clipped to [-1e6, 1e6]) is used directly as the weight of the
		matching value; no softmax is applied.

		:param index: unused
		:param W: query of size (batch_size, A, B)
		:param M_key: of size (num_memory, ...) with A*B elements per slot
		:param M_value: same size as M_key
		:param device: unused
		:return: attention output of size (batch_size, A, B)
		"""
		batch_size, A, B = W.shape

		W_flat = W.reshape(batch_size, -1)                  # (batch_size, A*B)
		M_key_flat = M_key.reshape(M_key.shape[0], -1)      # (num_memory, A*B)
		M_value_flat = M_value.reshape(M_value.shape[0], -1)  # (num_memory, A*B)

		# Plain matmul against the shared memory; avoids building batch_size copies of it.
		attention_weight = torch.matmul(W_flat, M_key_flat.t()) / np.sqrt(A * B)  # (batch_size, num_memory)
		attention_weight = torch.clip(attention_weight, min=-1e6, max=1e6)

		W_attention = torch.matmul(attention_weight, M_value_flat)  # (batch_size, A*B)
		return W_attention.reshape(batch_size, A, B)

	def _meta_find_query(self, z_in, z_target, device):
		"""
		Build the rank-1 weight W = z_target (outer) pinv(z_in), so that W @ z_in approximates z_target.

		:param z_in: of size (batch_size, z_in_size)
		:param z_target: of size (batch_size, z_target_size)
		:param device: unused
		:return: weight of size (batch_size, z_target_size, z_in_size); there is no bias term
		"""
		z_in_pseudoinverse = self._approx_pseudo_inverse(z_in.unsqueeze(-1))  # (batch_size, 1, z_in_size)
		return torch.bmm(z_target.unsqueeze(-1), z_in_pseudoinverse)

	def _approx_pseudo_inverse(self, A, iterative_step=3):
		"""
		Iteratively approximate the pseudo-inverse of a batch of matrices.

		Starts from args.pseudoinverse_init * A^T and applies `iterative_step` rounds of
		X <- 2X - X A X. NOTE(review): this only converges when pseudoinverse_init is small
		enough relative to A's largest singular value; presumably tuned by the caller.

		:param A: of size (batch_size, m, n)
		:param iterative_step: number of refinement rounds
		:return: of size (batch_size, n, m)
		"""
		A_pseudoinverse = (self.args.pseudoinverse_init * A).transpose(1, 2)
		for _ in range(iterative_step):
			A_pseudoinverse = 2 * A_pseudoinverse - torch.bmm(torch.bmm(A_pseudoinverse, A), A_pseudoinverse)
		return A_pseudoinverse

