import torch
import torch.nn as nn

import os, contextlib
from thop import profile

def analyse_model(net, input_size=(1, 3, 32, 32)):
    ''' Profile ``net`` with thop and return its ``(flops, params)`` pair.

    Anything ``profile`` prints to stdout while running is discarded.

    Args:
        net: model handed to ``thop.profile``.
        input_size: input shape forwarded to ``profile``.

    Returns:
        Tuple ``(flops, params)`` exactly as produced by ``profile``.
    '''
    # Route stdout to the null device for the duration of the profiling call.
    with open(os.devnull, 'w') as sink, contextlib.redirect_stdout(sink):
        result = profile(net, input_size=input_size, device="cuda")
    flops, params = result
    return flops, params


def finetune(pack, lr_min, lr_max, T, mute=False):
    ''' Run ``T`` train/test epochs under a triangular learning-rate schedule.

    The learning rate is interpolated linearly from ``lr_min`` to ``lr_max``
    over the first half of all iterations and back down to ``lr_min`` over the
    second half. It is written into every optimizer parameter group before
    each iteration (clamped at 0), and each group's momentum is forced to 0.0.

    Args:
        pack: object exposing ``optimizer.param_groups`` and a ``trainer`` with
            ``train(pack, iter_hook=...)`` and ``test(pack)``, both returning
            dict-like results.
        lr_min: learning rate at the start and end of the schedule.
        lr_max: learning rate at the midpoint of the schedule.
        T: number of epochs to run.
        mute: when true, the per-epoch info is not printed.

    Returns:
        List with one info mapping per epoch (train info merged with test info
        plus the key ``'LR'``).
    '''
    history = []
    epoch = 0

    def iter_hook(curr_iter, total_iter):
        # Midpoint of the schedule, measured in iterations.
        half = T * total_iter / 2
        if epoch * total_iter + curr_iter < half:
            # Rising edge: lr_min -> lr_max.
            step = epoch * total_iter + curr_iter
            start_lr, end_lr = lr_min, lr_max
        else:
            # Falling edge: lr_max -> lr_min, counted from the midpoint.
            step = (epoch - T / 2) * total_iter + curr_iter
            start_lr, end_lr = lr_max, lr_min

        new_lr = (1 - step / half) * start_lr + (step / half) * end_lr
        clamped = max(new_lr, 0)
        for group in pack.optimizer.param_groups:
            group['lr'] = clamped
            group['momentum'] = 0.0

    for _ in range(T):
        info = pack.trainer.train(pack, iter_hook=iter_hook)
        info.update(pack.trainer.test(pack))
        info.update(LR=pack.optimizer.param_groups[0]['lr'])
        # The hook reads ``epoch`` through the closure, so bump it here.
        epoch += 1
        if not mute:
            print(info)
        history.append(info)

    return history


class DoRealPrune():
    ''' Given a mask pruned model, turn the model into a real pruned model '''

    @staticmethod
    def _shrink_linear(linear, in_keep):
        ''' Build a linear layer that only reads the input features in ``in_keep``. '''
        shrunk = nn.Linear(int(in_keep.sum()), linear.weight.shape[0],
                           bias=linear.bias is not None)
        shrunk.to(device=linear.weight.device, dtype=linear.weight.dtype)
        # Copy under no_grad: ``param.data.set_(...)`` is rejected by current
        # PyTorch (storage changes through ``.data`` raise a RuntimeError).
        with torch.no_grad():
            shrunk.weight.copy_(linear.weight[:, in_keep])
            if linear.bias is not None:
                shrunk.bias.copy_(linear.bias)
        return shrunk

    @staticmethod
    def _shrink_conv(conv, in_keep, out_keep):
        ''' Build a conv holding only the kept input/output channels of ``conv``. '''
        # Inherit the geometry of the original layer instead of assuming 3x3/pad 1.
        shrunk = nn.Conv2d(int(in_keep.sum()), int(out_keep.sum()),
                           kernel_size=conv.kernel_size, stride=conv.stride,
                           padding=conv.padding, dilation=conv.dilation,
                           bias=conv.bias is not None)
        shrunk.to(device=conv.weight.device, dtype=conv.weight.dtype)
        with torch.no_grad():
            shrunk.weight.copy_(conv.weight[out_keep][:, in_keep])
            if conv.bias is not None:
                shrunk.bias.copy_(conv.bias[out_keep])
        return shrunk

    @staticmethod
    def _shrink_bn(gbn):
        ''' Build a plain BatchNorm2d equivalent to ``gbn`` on its kept channels. '''
        bn = gbn.bn
        keep = gbn.bn_mask.view(-1) != 0
        shrunk = nn.BatchNorm2d(int(keep.sum()), eps=bn.eps, momentum=bn.momentum)
        shrunk.to(device=bn.weight.device, dtype=bn.weight.dtype)
        with torch.no_grad():
            weight, bias = bn.weight, bn.bias
            if hasattr(gbn, 'g'):
                # Fold the per-channel gate into the affine parameters.
                gate = gbn.g.view(-1)
                weight, bias = weight * gate, bias * gate
            shrunk.running_var.copy_(bn.running_var[keep])
            shrunk.running_mean.copy_(bn.running_mean[keep])
            shrunk.weight.copy_(weight[keep])
            shrunk.bias.copy_(bias[keep])
        return shrunk

    @classmethod
    def for_cifar_vgg(cls, net, GBNs):
        ''' Replace the masked layers of a VGG-style ``net`` with smaller ones.

        The layout is fixed by the index tables below: 13 convolutions, each
        followed by its batch-norm slot, addressed by position inside
        ``net.features``.

        Args:
            net: model with a ``features`` container and a linear
                ``classifier``. It is modified in place.
            GBNs: gated batch-norm wrappers, paired in order with the
                ``nn.Conv2d`` modules yielded by ``net.modules()``. Each
                provides ``bn`` (expected to be an ``nn.BatchNorm2d``),
                ``bn_mask`` (non-zero entries mark kept channels) and,
                optionally, a per-channel gate ``g`` that is folded into the
                batch-norm weight and bias.

        Returns:
            ``net`` itself, with the classifier, the convolutions and the
            batch-norm slots replaced by newly constructed modules. New
            modules start in PyTorch's default training mode, so call
            ``net.eval()`` / ``net.train()`` afterwards as needed.
        '''
        bn_index = [1, 4, 8, 11, 15, 18, 21, 25, 28, 31, 35, 38, 41]
        conv_idx = [0, 3, 7, 10, 14, 17, 20, 24, 27, 30, 34, 37, 40]

        # The classifier consumes the channels kept by the last gate.
        last_keep = GBNs[-1].bn_mask.view(-1) != 0
        net.classifier = cls._shrink_linear(net.classifier, last_keep)

        convs = [m for m in net.modules() if isinstance(m, nn.Conv2d)]

        new_convs = []
        in_keep = None
        for conv, gbn in zip(convs, GBNs):
            if in_keep is None:
                # Network input: keep every channel the first conv expects.
                in_keep = torch.ones(conv.weight.shape[1],
                                     device=conv.weight.device) != 0
            out_keep = gbn.bn_mask.view(-1) != 0
            new_convs.append(cls._shrink_conv(conv, in_keep, out_keep))
            # Channels kept here are the inputs of the next convolution.
            in_keep = out_keep

        for idx, conv in zip(conv_idx, new_convs):
            original = net.features._modules[str(idx)]
            assert isinstance(original, nn.Conv2d)
            net.features._modules[str(idx)] = conv

        for idx, gbn in zip(bn_index, GBNs):
            net.features._modules[str(idx)] = cls._shrink_bn(gbn)
        return net
