import os
import random
import math
import time
import numpy as np
from PIL import Image
from scipy.special import logsumexp

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
from torch.nn import ModuleList
import torchvision.datasets
from torchvision.utils import make_grid

class AverageMeter(object):
    """Keeps the most recent value together with a running weighted mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every tracked statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted as ``n`` observations."""
        self.val = val
        self.sum += val * n
        self.count += n
        # The mean is recomputed from the totals on every update.
        self.avg = self.sum / self.count


def setup_runtime(seed=0, cuda_dev_id=(0,)):
    """Initialize CUDA, CuDNN and the random seeds.

    Args:
        seed: Seed applied to Python's ``random``, NumPy and PyTorch (CPU and,
            when CUDA is available, every CUDA device).
        cuda_dev_id: Iterable of GPU ids written, in the given order and
            comma-separated, to ``CUDA_VISIBLE_DEVICES``. An empty iterable
            yields an empty string.

    Side effects:
        Sets the ``CUDA_DEVICE_ORDER`` and ``CUDA_VISIBLE_DEVICES``
        environment variables and, when CUDA is available, the global CuDNN
        flags (enabled, benchmark mode on, deterministic mode off).
    """
    # Setup CUDA: enumerate devices by PCI bus id and expose only the
    # requested ones.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(dev) for dev in cuda_dev_id)

    cuda_available = torch.cuda.is_available()
    if cuda_available:
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False

    # Fix random seeds
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if cuda_available:
        torch.cuda.manual_seed_all(seed)

class TotalAverage():
    """Mass-weighted running average of a stream of values."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the accumulated mass, sum and average."""
        self.val = self.mass = self.sum = self.avg = 0.

    def update(self, val, mass=1):
        """Add ``val`` carrying weight ``mass`` and refresh the average."""
        self.val = val
        self.mass += mass
        self.sum += val * mass
        # Average over all mass seen since the last reset.
        self.avg = self.sum / self.mass


class MovingAverage():
    """Exponentially weighted moving average.

    ``intertia`` is the weight kept from the previous average on each update;
    the new value receives ``1 - intertia``. The average starts at 0.
    """

    def __init__(self, intertia=0.9):
        self.intertia = intertia
        self.reset()

    def reset(self):
        """Restart the average from zero."""
        self.avg = 0.

    def update(self, val):
        """Blend ``val`` into the running average."""
        keep = self.intertia
        self.avg = keep * self.avg + (1 - keep) * val


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Args:
        output: Score tensor; the top-k is taken along dimension 1, so rows
            are samples and columns are classes.
        target: Tensor of class indices, one per sample (it is flattened
            before comparison).
        topk: Iterable of ``k`` values to evaluate.

    Returns:
        A list with one tensor of shape ``(1,)`` per entry of ``topk``,
        holding the percentage (0-100) of samples whose target is among the
        ``k`` highest-scoring classes.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # Indices of the maxk best classes per sample, transposed so that
        # row r holds every sample's (r+1)-th best guess.
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # `correct` inherits the transposed (non-contiguous) layout, so
            # `.view(-1)` can fail for k > 1; `.reshape(-1)` copies if needed.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

def write_conv(writer, model, epoch, sobel=False):
    """Log the first layer's filters as image grids.

    The filters are the first parameter of the first child of the model's
    first child (presumably the weight of the first convolution -- verify
    against the model definition).

    Args:
        writer: Object exposing ``add_image(tag, image, step)``.
        model: Module whose first child contains the layer to visualise.
        epoch: Step value passed to ``add_image``.
        sobel: If False, the whole filter bank is logged under ``'conv1'``.
            If True, input channels 0 and 1 are logged separately as
            ``'conv1_sobel_1'`` / ``'conv1_sobel_2'`` and their sum over the
            input-channel dimension is logged as ``'conv1'``.
    """
    first_block = list(model.children())[0]
    first_layer = list(first_block.children())[0]
    conv1_w = list(first_layer.parameters())[0]

    if not sobel:
        conv1_ = make_grid(conv1_w, nrow=8, normalize=True, scale_each=True)
        writer.add_image('conv1', conv1_, epoch)
        return

    # This is a free function: log through the `writer` argument (the
    # previous `self.writer` raised NameError whenever sobel=True).
    conv1_ = make_grid(conv1_w[:, 0:1, :, :], nrow=8,
                       normalize=True, scale_each=True)
    writer.add_image('conv1_sobel_1', conv1_, epoch)
    conv2_ = make_grid(conv1_w[:, 1:2, :, :], nrow=8,
                       normalize=True, scale_each=True)
    writer.add_image('conv1_sobel_2', conv2_, epoch)
    conv1_x = make_grid(torch.sum(conv1_w, 1, keepdim=True), nrow=8,
                        normalize=True, scale_each=True)
    writer.add_image('conv1', conv1_x, epoch)


### LP stuff ###
def absorb_bn(module, bn_module):
    """Fold ``bn_module``'s running statistics (and affine terms) into ``module``.

    ``module``'s weight and bias are rescaled in place; a zero bias is created
    first when the module has none. Afterwards ``bn_module`` is stripped: its
    running buffers and affine parameters are set to ``None`` and ``affine``
    is switched off.
    """
    weight = module.weight.data
    if module.bias is None:
        n_out = module.out_features if isinstance(module, nn.Linear) else module.out_channels
        module.bias = nn.Parameter(torch.Tensor(n_out).zero_().type(weight.type()))
    bias = module.bias.data
    conv_layout = isinstance(module, nn.Conv2d)

    def _over_weight(per_output):
        # Broadcast a per-output vector across the full weight tensor.
        if conv_layout:
            return per_output.view(weight.size(0), 1, 1, 1).expand_as(weight)
        return per_output.unsqueeze(1).expand_as(weight)

    # Normalisation part: (x - running_mean) / sqrt(running_var + eps)
    inv_std = bn_module.running_var.clone().add_(bn_module.eps).pow_(-0.5)
    weight.mul_(_over_weight(inv_std))
    bias.add_(-bn_module.running_mean).mul_(inv_std)

    # Affine part: scale by the BN weight, shift by the BN bias.
    if bn_module.affine:
        gamma = bn_module.weight.data
        weight.mul_(_over_weight(gamma))
        bias.mul_(gamma).add_(bn_module.bias.data)

    # Strip the batch-norm layer of its state.
    bn_module.reset_parameters()
    for buffer_name in ('running_mean', 'running_var'):
        bn_module.register_buffer(buffer_name, None)
    bn_module.affine = False
    for param_name in ('weight', 'bias'):
        bn_module.register_parameter(param_name, None)


def is_bn(m):
    """Return True when ``m`` is a 1-D or 2-D batch-norm layer."""
    return isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d))


def is_absorbing(m):
    """Return True when ``m`` is a layer a batch-norm can be folded into (Conv2d or Linear)."""
    return isinstance(m, (nn.Conv2d, nn.Linear))


def search_absorb_bn(model):
    """Walk the module tree and fold each batch-norm layer into the
    conv/linear sibling that directly precedes it."""
    children = list(model.children())
    # Pair every child with its immediate predecessor (None for the first).
    for predecessor, child in zip([None] + children, children):
        if is_absorbing(predecessor) and is_bn(child):
            print("absorbing", child)
            absorb_bn(predecessor, child)
        # Descend regardless of whether anything was absorbed at this level.
        search_absorb_bn(child)


class View(nn.Module):
    """Shape adapter that flattens everything after the first dimension."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        leading = x.size(0)
        return x.view(leading, -1)


def sequential_skipping_bn_cut(model):
    """Build a flat ``nn.Sequential`` from the model's feature layers without batch-norm.

    The stack is ``model.sobel`` (only when the model has a child of that
    name), then ``model.features``, then a flattening ``View``; every
    batch-norm layer in that stack is left out.
    """
    stack = [*model.features, View()]
    if 'sobel' in dict(model.named_children()):
        stack = [*model.sobel, *stack]
    kept = [layer for layer in nn.Sequential(*stack).children() if not is_bn(layer)]
    return nn.Sequential(*kept)


def py_softmax(x, axis=None):
    """Softmax of ``x`` along ``axis`` (over all elements when ``axis`` is None),
    computed as ``exp(x - logsumexp(x))``."""
    log_normaliser = logsumexp(x, axis=axis, keepdims=True)
    return np.exp(x - log_normaliser)

def _warmup_batchnorm(model, data_loader, device, batches=100):
    """
    Run some batches through all parts of the model to warmup the running
    stats for batchnorm layers.

    Args:
        model: Module to warm up; it is switched to (and left in) train mode.
        data_loader: Iterable of batches whose first element holds the inputs.
        device: Device the inputs are moved to before the forward pass.
        batches: Number of batches to run. A negative value runs the whole
            loader.
    """
    from itertools import islice

    model.train()
    # A negative count never matched the former `i == batches` stop
    # condition, so it keeps meaning "no limit".
    limit = batches if batches >= 0 else None
    # Outputs are discarded, so no autograd graph is needed; batch-norm
    # running statistics are still updated in train mode.
    with torch.no_grad():
        # islice stops before pulling a batch it would not use.
        for batch in islice(data_loader, limit):
            model(batch[0].to(device))


def aggreg_multi_gpu(model, dataloader, hc, dim, TYPE, model_gpus=1):
    """Run ``model`` over ``dataloader`` and store its outputs sharded over several GPUs.

    The GPUs with index ``model_gpus`` and above are used as storage. One
    ``(rows, dim)`` tensor of dtype ``TYPE`` is allocated on each of them and
    filled batch by batch with the model output.

    Args:
        model: Called as ``model(video, audio)`` with inputs moved to
            ``cuda:0``. Its ``headcount`` attribute is set to 1 as a side
            effect (NOTE(review): meaning of ``headcount`` is defined by the
            model class -- not visible here).
        dataloader: Yields 5-tuples ``(video, audio, _, _, idx)``; must expose
            ``dataset`` and ``batch_size``. ``idx`` is presumably the dataset
            index of each sample -- confirm against the dataset class.
        hc: When equal to 1, a softmax over dimension 1 is applied to the
            model output and the result is cast to ``TYPE`` before storing;
            otherwise the raw output is stored.
        dim: Second dimension of every storage tensor.
        TYPE: dtype of the storage tensors.
        model_gpus: Number of leading GPUs that are not used for storage.

    Returns:
        ``(PS, indices)`` where ``PS`` is the list of per-GPU storage tensors
        and ``indices`` is the permutation returned by ``torch.sort`` over the
        recorded ``idx`` values.

    Note:
        At least one storage GPU is required: ``ngpu_store`` is used as a
        divisor below.
    """
    ngpu_store = torch.cuda.device_count() - model_gpus  # GPUs left over for storage
    l_dl = len(dataloader)              # number of batches in DL
    batches_per_gpu = l_dl // ngpu_store
    points_per_gpu = len(dataloader.dataset) // ngpu_store  # number of samples each gpu gets
    print(f"Points per GPU (before): {points_per_gpu}")
    # Round the shard size to the nearest multiple of the batch size.
    points_per_gpu = int(dataloader.batch_size * round(points_per_gpu / float(dataloader.batch_size)))
    print(f"Points per GPU (after): {points_per_gpu}")
    print(f"NGPU store: {ngpu_store}")
    print(f"Dataset len: {len(dataloader.dataset)}")
    # Uninitialised buffer (default device); only positions of samples the
    # loader actually yields get written in the loop below.
    indices = torch.empty(len(dataloader.dataset), dtype=torch.long)
    # All storage GPUs except the last get a full-size shard.
    PS = [torch.empty(points_per_gpu, dim,
                      device='cuda:' + str(i), dtype=TYPE)
          for i in range(model_gpus, model_gpus + ngpu_store-1)]
    print(f"Remainder: {len(dataloader.dataset) - (ngpu_store-1)*points_per_gpu}")
    print(f"ON CUDA: {model_gpus + ngpu_store - 1}")
    PS.append(torch.empty(len(dataloader.dataset) - (ngpu_store-1)*points_per_gpu,
                          dim, device='cuda:' + str(model_gpus + ngpu_store - 1), dtype=TYPE))  # accommodate remainder
    slices = [qq.shape[0] for qq in PS]
    print("slice sizes: ", slices, flush=True)
    batch_time = MovingAverage(intertia=0.9)
    now = time.time()
    st = 0  # running count of samples stored so far (global row offset)
    softmax = torch.nn.Softmax(dim=1).to('cuda:0')
    # model.to('cuda:0')
    model.headcount = 1
    for batch_idx, batch in enumerate(dataloader):
        video, audio, _, _, idx = batch
        video = video.to(torch.device('cuda:0'))
        audio = audio.to(torch.device('cuda:0'))
        mass = video.size(0)
        en = st + mass
        # Shard that receives this batch; clamped so that everything past the
        # full-size shards lands in the last (remainder) tensor.
        j = min(((batch_idx * dataloader.batch_size) // points_per_gpu), ngpu_store - 1)
        subs = j * points_per_gpu  # global row offset at which shard j starts
        if hc == 1:
            p = softmax(model(video, audio)).detach().to(TYPE)
            PS[j][st-subs:en-subs, :].copy_(p)
        else:
            PS[j][st-subs:en-subs, :].copy_(model(video, audio).detach())
        # Remember which idx value each stored row corresponds to.
        indices[st:en].copy_(idx)
        st = en
        batch_time.update(time.time() - now)
        now = time.time()
        if batch_idx % 50 == 0:
            print(f"Aggregating batch {batch_idx:03}/{l_dl}, speed: {mass / batch_time.avg:04.1f}Hz. To rGPU {j+1}",
                  end='\r', flush=True)
    torch.cuda.synchronize()
    # Keep only the sorting permutation: entry k is the storage row whose
    # recorded idx is the k-th smallest.
    _, indices = torch.sort(indices)
    return PS, indices


def gpu_mul_Ax(A, b, ngpu, splits, TYPE=torch.float32, model_gpus=1):
    """Multiply a row-sharded matrix by ``b`` and gather the result on ``cuda:0``.

    ``A`` is a list of shards, one per GPU in ``range(model_gpus, ngpu)``
    (assumes ``A[k]`` already lives on ``cuda:{model_gpus + k}`` -- TODO
    confirm with the caller). ``splits`` holds the row boundaries of the
    shards in the output; its last entry is the total number of rows.
    Returns a ``(splits[-1], 1)`` tensor of dtype ``TYPE``.
    """
    total_rows = splits[-1]
    # One copy of b per storage GPU.
    replicas = [b.to('cuda:' + str(dev)) for dev in range(model_gpus, ngpu)]

    # Run each shard's product on its own GPU and copy it into place.
    out = torch.empty(total_rows, 1, device='cuda:0', dtype=TYPE)
    for shard, vec in enumerate(replicas):
        lo, hi = splits[shard], splits[shard + 1]
        out[lo:hi, :].copy_(torch.matmul(A[shard], vec))
    return out


def gpu_mul_AB(A, B, c, dim, TYPE=torch.float32, model_gpus=1):
    """Compute ``exp(A[k] @ B + c)`` normalised over dimension 1 for every shard.

    One shard is processed per GPU in ``range(model_gpus, device_count)``;
    ``B`` and ``c`` are copied to each of those GPUs. The per-shard results
    are float64 tensors returned as a list, each left on its own GPU.
    ``dim`` and ``TYPE`` are accepted but not used.
    """
    ngpu = torch.cuda.device_count()
    device_names = ['cuda:' + str(dev) for dev in range(model_gpus, ngpu)]
    # One copy of B per storage GPU.
    replicas = [B.to(name) for name in device_names]

    PS = []
    for shard, name in enumerate(device_names):
        scores = (torch.matmul(A[shard], replicas[shard]) + c.to(name)).to(torch.float64)
        # Softmax done by hand, in place: exponentiate, then divide by row sums.
        torch.exp(scores, out=scores)
        scores /= torch.sum(scores, dim=1, keepdim=True)
        PS.append(scores)

    return PS


def gpu_mul_xA(b, A, ngpu, splits, TYPE=torch.float32, model_gpus=1):
    """Left-multiply a row-sharded matrix by ``b`` and reduce on ``cuda:0``.

    ``b`` is cut column-wise at the boundaries in ``splits``; slice ``k`` is
    sent to ``cuda:{model_gpus + k}`` and multiplied with ``A[k]`` (assumes
    ``A[k]`` already lives on that GPU -- TODO confirm with the caller). The
    partial products are summed into a single-row tensor of dtype ``TYPE``.
    """
    n_shards = ngpu - model_gpus
    # Send each column slice of b to the GPU holding the matching shard.
    pieces = [
        b[:, splits[k]:splits[k + 1]].to('cuda:' + str(model_gpus + k))
        for k in range(n_shards)
    ]
    # One row of partial results per shard.
    partial = torch.empty(n_shards, A[0].size(1), device='cuda:0', dtype=TYPE)
    for k, piece in enumerate(pieces):
        partial[k:k + 1, :].copy_(torch.matmul(piece, A[k]))
    # Wait for every GPU before reducing.
    torch.cuda.synchronize()
    return torch.sum(partial, 0, keepdim=True)


def init_pytorch_defaults(m, version='041'):
    '''
    Re-initialise ``m`` in place with one of several initialisation schemes.

    copied from AMDIM repo: https://github.com/Philip-Bachman/amdim-public/

    Args:
        m: Module to initialise. ``'041'`` and ``'100'`` accept ``nn.Linear``,
            ``nn.Conv2d``, ``nn.BatchNorm1d`` and ``nn.BatchNorm2d``;
            ``'custom'`` accepts only the two batch-norm types.
        version: Scheme to apply.
            ``'041'``: weight and bias drawn from U(-s, s) with
            ``s = 1 / sqrt(fan_in)``.
            ``'100'``: ``kaiming_uniform_(a=sqrt(5))`` weight and a
            U(-1/sqrt(fan_in), 1/sqrt(fan_in)) bias.
            ``'custom'``: batch-norm weight ~ N(1, 0.02), bias 0.
            In every scheme non-affine batch-norm layers are left untouched;
            in ``'041'`` and ``'100'`` affine batch-norm gets a U(0, 1) weight
            and a zero bias.

    Raises:
        AssertionError: for an unknown ``version`` or a module type the
            chosen scheme does not handle. It is raised explicitly so the
            check is not stripped when Python runs with ``-O``.
    '''
    is_batchnorm = isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d))

    def _unsupported():
        return AssertionError(
            'init_pytorch_defaults: unsupported module type {0} for version {1!r}'
            .format(type(m).__name__, version))

    if version == '041':
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            if isinstance(m, nn.Linear):
                fan_in = m.weight.size(1)
            else:
                # in_channels times the kernel area
                fan_in = m.in_channels
                for k in m.kernel_size:
                    fan_in *= k
            stdv = 1. / math.sqrt(fan_in)
            m.weight.data.uniform_(-stdv, stdv)
            if m.bias is not None:
                m.bias.data.uniform_(-stdv, stdv)
        elif is_batchnorm:
            if m.affine:
                m.weight.data.uniform_()
                m.bias.data.zero_()
        else:
            raise _unsupported()
    elif version == '100':
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            init.kaiming_uniform_(m.weight, a=math.sqrt(5))
            if m.bias is not None:
                fan_in, _ = init._calculate_fan_in_and_fan_out(m.weight)
                bound = 1 / math.sqrt(fan_in)
                init.uniform_(m.bias, -bound, bound)
        elif is_batchnorm:
            if m.affine:
                m.weight.data.uniform_()
                m.bias.data.zero_()
        else:
            raise _unsupported()
    elif version == 'custom':
        if is_batchnorm:
            # Non-affine layers have weight/bias set to None; skip them as the
            # other schemes do instead of failing on `None.data`.
            if m.affine:
                init.normal_(m.weight.data, mean=1, std=0.02)
                init.constant_(m.bias.data, 0)
        else:
            raise _unsupported()
    else:
        raise AssertionError(
            'init_pytorch_defaults: unknown version {0!r}'.format(version))


def weight_init(m):
    '''
    Initialise a single module according to its type; unknown types are
    left untouched.

    Usage:
        model = Model()
        model.apply(weight_init)
    '''
    def _pytorch_041(mod):
        init_pytorch_defaults(mod, version='041')

    def _normal_bias(mod):
        if mod.bias is not None:
            init.normal_(mod.bias.data)

    def _normal(mod):
        init.normal_(mod.weight.data)
        _normal_bias(mod)

    def _xavier(mod):
        init.xavier_normal_(mod.weight.data)
        _normal_bias(mod)

    def _batchnorm3d(mod):
        init.normal_(mod.weight.data, mean=1, std=0.02)
        init.constant_(mod.bias.data, 0)

    def _recurrent(mod):
        # Matrices get an orthogonal init, vectors a standard normal one.
        for param in mod.parameters():
            if len(param.shape) >= 2:
                init.orthogonal_(param.data)
            else:
                init.normal_(param.data)

    # Checked in order; the first matching type decides.
    dispatch = (
        (nn.Linear, _pytorch_041),
        (nn.Conv2d, _pytorch_041),
        (nn.BatchNorm1d, _pytorch_041),
        (nn.BatchNorm2d, _pytorch_041),
        (nn.Conv1d, _normal),
        (nn.Conv3d, _xavier),
        (nn.ConvTranspose1d, _normal),
        (nn.ConvTranspose2d, _xavier),
        (nn.ConvTranspose3d, _xavier),
        (nn.BatchNorm3d, _batchnorm3d),
        (nn.LSTM, _recurrent),
        (nn.LSTMCell, _recurrent),
        (nn.GRU, _recurrent),
        (nn.GRUCell, _recurrent),
    )
    for layer_type, initialise in dispatch:
        if isinstance(m, layer_type):
            initialise(m)
            return


def search_set_bn_eval(model, toeval):
    """Recursively put every 1-D/2-D batch-norm layer below ``model`` into
    eval mode (``toeval`` truthy) or train mode (``toeval`` falsy)."""
    bn_types = (nn.BatchNorm2d, nn.BatchNorm1d)
    for child in model.children():
        if isinstance(child, bn_types):
            switch_mode = child.eval if toeval else child.train
            switch_mode()
        # Keep descending: nested containers may hold further BN layers.
        search_set_bn_eval(child, toeval)
