import torch
import torch.nn as nn

from main import adjust_learning_rate
from prune.masks import GatedBatchNorm2d
import numpy as np

def get_gate_sparse_loss(masks, sparse_lambda):
    '''Build a training loss hook that penalises the L1 norm of the gates.

    Args:
        masks: iterable of layers; only `GatedBatchNorm2d` instances
            contribute, through their `g` tensor.
        sparse_lambda: multiplier applied to the summed gate magnitudes.

    Returns:
        A callable `(data, label, logits) -> loss`. The three arguments are
        accepted to fit the hook signature but are not used.
    '''
    def _loss_hook(data, label, logits):
        # Start from a plain float so an empty/non-gated `masks` yields 0.0.
        gate_l1 = 0.0
        gated_layers = (m for m in masks if isinstance(m, GatedBatchNorm2d))
        for layer in gated_layers:
            gate_l1 += layer.g.abs().sum()
        return sparse_lambda * gate_l1

    return _loss_hook

class IterRecoverFramework():
    '''Iterative prune / recover driver for a network with gated BN layers.

    The class reads the following from its collaborators:

    * `pack`: object exposing `net`, `optimizer`, `trainer`, `train_loader`
      and `tick_trainset`. `pack.trainer.train(pack, ...)` and
      `pack.trainer.test(pack)` are expected to return dicts; the test dict
      must contain `'test_loss'` and `'test_accuracy'`.
    * `masks`: layers exposing `bn.weight`, `g`, `bn_mask`, `get_score()`,
      `reset_score()` and `cal_score` (presumably all `GatedBatchNorm2d`
      instances -- verify against the caller).

    Args:
        pack: training bundle described above.
        masks: list of gated layers to prune.
        sparse_lambda: weight of the gate sparsity loss used during `tock`.
        minium_filter: number of highest-scoring filters that are never
            pruned in each layer.
    '''

    def __init__(self, pack, masks, sparse_lambda=1e-5, minium_filter=10):
        self.pack = pack
        self.masks = masks
        self.sparse_loss_hook = get_gate_sparse_loss(masks, sparse_lambda)
        self.logs = []
        self.minium_filter = minium_filter
        self.sparse_lambda = sparse_lambda

        self.total_filters = sum(m.bn.weight.shape[0] for m in masks)
        self.pruned_filters = 0
        # id(param) -> requires_grad flag saved by freeze_conv().
        self._status = {}

    def recover(self, lr, test):
        '''Run one training pass on `pack.tick_trainset` while scoring gates.

        Scores of every `GatedBatchNorm2d` in `masks` are reset and a gradient
        hook (`gbn.cal_score`) is attached to its gate for the duration of the
        pass.

        Args:
            lr: learning rate written into every optimizer param group.
            test: when truthy, the result of `trainer.test` is merged in.

        Returns:
            The dict returned by `trainer.train`, extended with `'LR'` and,
            optionally, the test results.
        '''
        hooks = []
        original_loader = self.pack.train_loader
        try:
            for gbn in self.masks:
                if isinstance(gbn, GatedBatchNorm2d):
                    gbn.reset_score()
                    hooks.append(gbn.g.register_hook(gbn.cal_score))

            for group in self.pack.optimizer.param_groups:
                group['lr'] = lr

            self.pack.train_loader = self.pack.tick_trainset
            try:
                info = self.pack.trainer.train(self.pack)
            finally:
                # Always hand the regular loader back, even if training fails.
                self.pack.train_loader = original_loader

            if test:
                info.update(self.pack.trainer.test(self.pack))

            info.update({'LR': lr})
        finally:
            # Never leave scoring hooks attached to the gates.
            for hook in hooks:
                hook.remove()

        return info

    def get_threshold(self, score_list, num):
        '''Pick the global score threshold for the next pruning step.

        In every layer the `minium_filter` highest scores are excluded from
        the candidate pool, and zero scores are treated as already pruned.
        The threshold is the candidate at sorted position `num`.

        Args:
            score_list: list of 1-D numpy arrays, one per layer.
            num: index into the sorted pool of prunable scores.

        Returns:
            `(threshold, info)` where `info` holds `'left'` (non-zero scores
            before this step), `'to_prune'` (non-zero scores <= threshold) and
            `'total_pruned_ratio'`.

        Raises:
            IndexError: if `num` is outside the pool of prunable scores.
        '''
        protected = max(self.minium_filter, 0)
        scores = np.concatenate(score_list)
        left_scores = scores[scores != 0]
        # An explicit length is used because `a[:-0]` is an empty slice,
        # which would wrongly protect every filter when minium_filter is 0.
        candidates = np.concatenate(
            [np.sort(g)[:max(len(g) - protected, 0)] for g in score_list]
        )
        candidates = np.sort(candidates[candidates != 0])
        if not -candidates.size <= num < candidates.size:
            raise IndexError(
                'cannot select prunable score #%d: only %d prunable filters left'
                % (num, candidates.size)
            )
        threshold = candidates[num]
        to_prune = int((left_scores <= threshold).sum())

        already_pruned = len(scores) - len(left_scores)
        info = {
            'left': len(left_scores),
            'to_prune': to_prune,
            'total_pruned_ratio': (already_pruned + to_prune) / len(scores),
        }
        return threshold, info

    def set_mask(self, layers, threshold):
        '''Zero the mask entries of filters whose score is <= `threshold`.

        Each layer keeps at least its `minium_filter` highest-scoring filters
        (all of them when it has fewer), matching the candidate pool used by
        `get_threshold`. Entries that are already zero stay zero.

        Args:
            layers: gated layers exposing `get_score()` and `bn_mask`.
            threshold: global score threshold (float).
        '''
        protected = max(self.minium_filter, 0)
        for gbn in layers:
            scores = gbn.get_score()
            ordered = np.sort(scores.cpu().data.numpy())
            prunable = ordered.size - protected
            if prunable > 0:
                # Largest score that may still be pruned in this layer.
                hard_threshold = float(ordered[prunable - 1])
            else:
                # Layer is at or below the floor: nothing may be pruned.
                hard_threshold = float('-inf')
            keep = (scores > min(threshold, hard_threshold)).float().view(1, -1, 1, 1)
            # Update the mask values in place. Rebinding storage through
            # `.data.set_()` only touches the temporary `.data` tensor and is
            # rejected by recent PyTorch releases.
            gbn.bn_mask.data.mul_(keep)

    def freeze_conv(self):
        '''Disable gradients for all `nn.Conv2d` parameters of `pack.net`.

        The previous `requires_grad` flags are remembered for `restore_conv`.
        '''
        self._status = {}
        for m in self.pack.net.modules():
            if isinstance(m, nn.Conv2d):
                for p in m.parameters():
                    self._status[id(p)] = p.requires_grad
                    p.requires_grad = False

    def restore_conv(self):
        '''Restore the `requires_grad` flags saved by `freeze_conv`.

        Parameters without a saved flag are left untouched.
        '''
        saved = getattr(self, '_status', {})
        for m in self.pack.net.modules():
            if isinstance(m, nn.Conv2d):
                for p in m.parameters():
                    p.requires_grad = saved.get(id(p), p.requires_grad)

    def tock(self, lr_min=0.001, lr_max=0.01, tock_epoch = 20, mute=False):
        '''Train with the gate sparsity loss under a triangular LR schedule.

        The learning rate rises linearly from `lr_min` to `lr_max` over the
        first half of the epochs and falls back to `lr_min` over the second
        half; it is updated on every iteration through `iter_hook`.

        Args:
            lr_min: learning rate at the start and end of the schedule.
            lr_max: learning rate at the middle of the schedule.
            tock_epoch: number of epochs to train.
            mute: when falsy, a summary line is printed after each epoch.

        Returns:
            List with one info dict per epoch (train results, test results
            and the final `'LR'` of that epoch).
        '''
        logs = []
        epoch = 0
        T = tock_epoch
        half = T / 2

        def iter_hook(curr_iter, total_iter):
            # `epoch` is read from the enclosing scope and advances per epoch.
            _total = half * total_iter
            if epoch < half:
                _iter = epoch * total_iter + curr_iter
                _lr = (1 - _iter / _total) * lr_min + (_iter / _total) * lr_max
            else:
                _iter = (epoch - half) * total_iter + curr_iter
                _lr = (1 - _iter / _total) * lr_max + (_iter / _total) * lr_min

            for g in self.pack.optimizer.param_groups:
                g['lr'] = max(_lr, 0)

        for i in range(T):
            info = self.pack.trainer.train(self.pack, loss_hook = self.sparse_loss_hook, iter_hook = iter_hook)
            info.update(self.pack.trainer.test(self.pack))
            info.update({'LR': self.pack.optimizer.param_groups[0]['lr']})
            epoch += 1
            if not mute:
                print('Tock - %d,\t Test Loss: %.4f,\t Test Acc: %.2f, Final LR: %.5f' % (i, info['test_loss'], info['test_accuracy'], info['LR']))
            logs.append(info)
        return logs

    def tick(self, lr, test):
        '''Score the gates with conv weights frozen; see `recover`.'''
        self.freeze_conv()
        try:
            info = self.recover(lr, test)
        finally:
            # Unfreeze the conv layers even when the scoring pass fails.
            self.restore_conv()
        return info

    def prune(self, num, tick=False, lr=0.01, test=True):
        '''Prune filters below the score at sorted position `num`.

        Args:
            num: index into the sorted prunable scores (see `get_threshold`).
            tick: when truthy, run `tick` first to refresh the scores.
            lr: learning rate for the optional tick pass.
            test: when truthy, evaluate after (and, with `tick`, before)
                pruning.

        Returns:
            Info dict, also appended to `self.logs`, with the threshold
            statistics, `'total'` and, if tested, `'after_prune_test_acc'`.
        '''
        info = {}
        if tick:
            info = self.tick(lr, test)

        layer_scores = [g.get_score().cpu().data.numpy() for g in self.masks]
        threshold, stats = self.get_threshold(layer_scores, num)
        info.update(stats)
        self.set_mask(self.masks, float(threshold))
        if test:
            info.update({'after_prune_test_acc': self.pack.trainer.test(self.pack)['test_accuracy']})
        self.logs.append(info)
        # NOTE(review): info['left'] is counted before this step's masks are
        # applied, so this excludes the filters pruned just now -- confirm
        # that is what callers expect.
        self.pruned_filters = self.total_filters - info['left']
        info['total'] = self.total_filters
        return info
