# This code is modified from https://github.com/dragen1860/MAML-Pytorch and https://github.com/katerakelly/pytorch-maml

import backbone
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
from methods.meta_template import MetaTemplate


class MAML(MetaTemplate):
    """Model-Agnostic Meta-Learning (MAML) few-shot classifier.

    For each episode the support samples drive ``task_update_num`` gradient
    steps on temporary "fast" weights (stored on every parameter as
    ``param.fast``); the query samples are then scored with the adapted
    weights. The meta-update sums the query losses of ``n_task`` episodes
    before a single optimizer step.

    Args:
        model_func: Forwarded unchanged to ``MetaTemplate`` (presumably the
            callable that builds the feature backbone -- confirm there).
        n_way: Number of classes per episode; also the classifier output size.
        n_support: Number of support samples per class.
        approx: If True use the first-order approximation, i.e. inner-loop
            gradients are treated as constants in the meta-gradient.
    """

    def __init__(self, model_func, n_way, n_support, approx=False):
        super(MAML, self).__init__(model_func, n_way, n_support, change_way=False)

        self.loss_fn = nn.CrossEntropyLoss()
        # For the usage of weight.fast, see Linear_fw / Conv_fw in backbone.py.
        self.classifier = backbone.Linear_fw(self.feat_dim, n_way)
        self.classifier.bias.data.fill_(0)

        self.n_task = 4  # episodes whose query losses are summed per optimizer step
        self.task_update_num = 5  # inner-loop gradient steps per episode
        self.train_lr = 0.01  # inner-loop step size
        self.approx = approx  # first order approx.

    def _episode_labels(self, n_per_class, device):
        """Return labels 0..n_way-1, each repeated ``n_per_class`` times, as int64 on ``device``."""
        labels = np.repeat(np.arange(self.n_way), n_per_class)
        # NumPy's default integer is platform dependent; CrossEntropyLoss needs int64 targets.
        return torch.from_numpy(labels).long().to(device)

    def forward(self, x):
        """Return classification scores for a flat batch of samples ``x``."""
        out = self.feature.forward(x)
        scores = self.classifier.forward(out)
        return scores

    def set_forward(self, x, is_feature=False):
        """Adapt on the support part of an episode and score its query part.

        Args:
            x: Episode tensor whose first dimension indexes the ``n_way``
                classes and whose second dimension holds ``n_support`` support
                samples followed by ``self.n_query`` query samples.
            is_feature: Must be False; pre-computed features are not supported.

        Returns:
            Scores for the ``n_way * n_query`` query samples, computed with the
            adapted fast weights (which remain attached to the parameters).

        Raises:
            AssertionError: If ``is_feature`` is true.
        """
        assert not is_feature, 'MAML do not support fixed feature'

        # Follow the model's device instead of hard-coding CUDA so CPU runs work too.
        device = next(self.parameters()).device
        x = x.to(device)
        sample_shape = x.size()[2:]
        x_a_i = x[:, :self.n_support].contiguous().view(
            self.n_way * self.n_support, *sample_shape)  # support data
        x_b_i = x[:, self.n_support:].contiguous().view(
            self.n_way * self.n_query, *sample_shape)  # query data
        y_a_i = self._episode_labels(self.n_support, device)  # label for support data

        # The first inner-loop gradient is taken w.r.t. the original weights.
        fast_parameters = list(self.parameters())
        for weight in fast_parameters:
            weight.fast = None
        self.zero_grad()

        for _ in range(self.task_update_num):
            scores = self.forward(x_a_i)
            set_loss = self.loss_fn(scores, y_a_i)
            # The second-order graph (gradient of gradient) is only needed for
            # full MAML; skipping it in first-order mode avoids building and
            # retaining a graph that would be discarded anyway.
            grad = torch.autograd.grad(set_loss, fast_parameters,
                                       create_graph=not self.approx)
            if self.approx:
                # do not calculate gradient of gradient if using first order approximation
                grad = [g.detach() for g in grad]

            fast_parameters = []
            for weight, g in zip(self.parameters(), grad):
                # '-' creates a NEW tensor each step, so the autograd graph
                # keeps its links to the older fast weights.
                if weight.fast is None:
                    weight.fast = weight - self.train_lr * g  # create weight.fast
                else:
                    weight.fast = weight.fast - self.train_lr * g  # updated weight.fast
                # The next autograd.grad call differentiates w.r.t. the newest fast weights.
                fast_parameters.append(weight.fast)

        scores = self.forward(x_b_i)
        return scores

    def set_forward_adaptation(self, x, is_feature=False):  # overwrite parent function
        """Not supported: always raises ``ValueError``."""
        raise ValueError('MAML performs further adaptation simply by increasing task_update_num')

    def set_forward_loss(self, x):
        """Return the cross-entropy loss on the query samples of episode ``x`` after adaptation."""
        scores = self.set_forward(x, is_feature=False)
        y_b_i = self._episode_labels(self.n_query, scores.device)
        loss = self.loss_fn(scores, y_b_i)

        return loss

    def train_loop(self, epoch, train_loader, optimizer, tf_writer):  # overwrite parent function
        """Run one training epoch, stepping the optimizer once every ``n_task`` episodes.

        Args:
            epoch: Epoch index, used in the progress print and as the scalar step.
            train_loader: Iterable yielding ``(x, _)`` episodes.
            optimizer: Optimizer over the model parameters.
            tf_writer: Object exposing ``add_scalar(tag, value, step)``.
        """
        print_freq = 10
        avg_loss = 0
        loss_all = []
        optimizer.zero_grad()

        # train
        for i, (x, _) in enumerate(train_loader):
            self.n_query = x.size(1) - self.n_support
            assert self.n_way == x.size(0), "MAML do not support way change"

            loss = self.set_forward_loss(x)
            avg_loss = avg_loss + loss.item()
            loss_all.append(loss)

            if len(loss_all) == self.n_task:  # MAML update several tasks at one time
                loss_q = torch.stack(loss_all).sum(0)
                loss_q.backward()

                optimizer.step()
                # Gradients are only produced by the backward() above, so
                # clearing them right after the step is sufficient.
                optimizer.zero_grad()
                loss_all = []

            if i % print_freq == 0:
                running_loss = avg_loss / float(i + 1)
                print('Epoch {:d} | Batch {:d}/{:d} | Loss {:f}'.format(epoch, i, len(train_loader),
                                                                        running_loss))
                tf_writer.add_scalar('loss/train', running_loss, epoch)
        # NOTE(review): episodes left in loss_all when the number of batches is
        # not a multiple of n_task are never back-propagated -- confirm this
        # is intended.

    def _episode_accuracy(self, x):
        """Set ``n_query`` from ``x`` and return the episode accuracy in percent."""
        self.n_query = x.size(1) - self.n_support
        assert self.n_way == x.size(0), "MAML do not support way change"
        correct_this, count_this = self.correct(x)
        return correct_this / count_this * 100

    @staticmethod
    def _summarize_accuracy(acc_all, iter_num):
        """Print mean accuracy with its 1.96 * std / sqrt(n) interval; return (array, mean, std)."""
        acc_all = np.asarray(acc_all)
        acc_mean = np.mean(acc_all)
        acc_std = np.std(acc_all)
        print('%d Test Acc = %4.2f%% +- %4.2f%%' % (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))
        return acc_all, acc_mean, acc_std

    def test_loop(self, test_loader, return_std=False):  # overwrite parent function
        """Evaluate on every episode of ``test_loader``.

        Returns:
            The mean accuracy in percent, or ``(mean, std)`` when ``return_std`` is true.
        """
        from tqdm import tqdm

        iter_num = len(test_loader)
        acc_all = [self._episode_accuracy(x) for x, _ in tqdm(test_loader)]

        _, acc_mean, acc_std = self._summarize_accuracy(acc_all, iter_num)
        if return_std:
            return acc_mean, acc_std
        return acc_mean

    def test_loop_with_dists(self, test_loader, all_cls_dists, base_dists,
                             base_cls_dists, attr_num):  # overwrite parent function
        """Evaluate every episode and record its distance to the base distributions.

        Args:
            test_loader: Iterable yielding ``(x, y)`` episodes; the unique values
                of ``y`` index rows of ``all_cls_dists``.
            all_cls_dists: Per-class distribution tensor, indexed by class id on dim 0.
            base_dists: Tensor the class-averaged task distribution is compared with.
            base_cls_dists: Unused by the current distance definition; kept for
                interface compatibility.
            attr_num: Divisor applied to each episode distance.

        Returns:
            ``(acc_all, dist_all)``: a NumPy array of per-episode accuracies in
            percent and a list of per-episode distances.
        """
        from tqdm import tqdm

        acc_all = []
        dist_all = []

        iter_num = len(test_loader)
        for x, y in tqdm(test_loader):
            # "original mean-task" distance (author's note: down trend): absolute
            # difference between base_dists and the class-averaged task
            # distribution, summed over the last dim, averaged, divided by attr_num.
            # Other variants the author experimented with and annotated:
            #   task-agnostic (up trend), class-agnostic (down trend),
            #   original (up trend), original mini-task (first up then down ?),
            #   ideal with task-agnostic (down trend).
            sc_cls = y.unique()
            task_dists = all_cls_dists[sc_cls, :].mean(0).unsqueeze(0)
            dist_all.append(torch.abs(base_dists - task_dists).sum(-1).mean().item() / attr_num)

            acc_all.append(self._episode_accuracy(x))

        acc_all, _, _ = self._summarize_accuracy(acc_all, iter_num)
        return acc_all, dist_all

