import copy
import time

import numpy as np
import torch
import torch.nn as nn

# try:
#     from util import gpu_mul_Ax, gpu_mul_xA, aggreg_multi_gpu, gpu_mul_AB, py_softmax
# except:
#     from .util import gpu_mul_Ax, gpu_mul_xA, aggreg_multi_gpu, gpu_mul_AB, py_softmax
# from utils import print_or_log,get_dataloader


def _np_softmax(x, axis=1):
    """Return a numerically stable softmax of NumPy array ``x`` along ``axis``."""
    shifted = x - np.max(x, axis=axis, keepdims=True)
    np.exp(shifted, out=shifted)
    shifted /= np.sum(shifted, axis=axis, keepdims=True)
    return shifted


def cpu_sk(args, dataset, model, device, outs, logger=None):
    """Assign cluster labels with the NumPy Sinkhorn-Knopp solver.

    The model is run once over ``dataset``, its outputs are gathered into NumPy
    matrices, and ``optimize_L_sk`` solves the label assignment (once per head
    when ``args.hc > 1``).

    Args:
        args: namespace; reads ``batch_size``, ``workers``, ``hc`` (number of
            heads), ``ncl`` (clusters used for the initial random labels),
            ``global_rank``, ``model``, ``dtype`` (``'f32'`` selects float32,
            anything else float64) and, via ``optimize_L_sk``, ``lamb``.
        dataset: dataset whose batches unpack as
            ``(video, audio, _, vid_idx, idx)``; ``idx`` selects the output row
            (presumably the dataset index of each sample - confirm).
        model: called as ``model(video, audio)``; for ``args.hc > 1`` it must
            have a ``headcount`` attribute and ``top_layer{i}`` layers with
            ``weight``/``bias``.
        device: device the model inputs are moved to.
        outs: unused here; kept for signature parity with ``gpu_sk``.
        logger: optional logger for the initial "Before" message; the message
            is printed when no logger is given.

    Returns:
        LongTensor of labels with shape ``(N,)`` when ``args.hc == 1`` and
        ``(args.hc, N)`` otherwise.
    """
    # Every sample needs a label: dropping the last partial batch would leave
    # all-zero rows that turn the Sinkhorn scaling into inf/NaN.
    orig_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=None,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=False,
    )

    # Set up clustering code: random initial labels.
    N = len(dataset)
    if args.hc == 1:
        L = torch.LongTensor(N).random_(0, args.ncl).cuda()
    else:
        L = torch.LongTensor(args.hc, N).random_(0, args.ncl).cuda()

    before = ('Before:', L[0:10], args.global_rank)
    if logger is not None:
        logger.info(str(before))
    else:
        print(before, flush=True)

    presize = 4096 if args.model == 'alexnet' else 512 * 2
    # Everything below is NumPy, so the dtype must always be a NumPy dtype.
    dtype = np.float32 if args.dtype == 'f32' else np.float64

    # 1. aggregate inputs:
    if args.hc == 1:
        PS = np.zeros((N, args.ncl), dtype=dtype)
    else:
        PS_pre = np.zeros((N, presize), dtype=dtype)
        model.headcount = 1  # collect the shared pre-head features
    now = time.time()
    with torch.no_grad():
        for video, audio, _, _, idx in orig_loader:
            video, audio = video.to(device), audio.to(device)
            # NumPy arrays cannot be indexed with a (possibly CUDA) tensor.
            rows = idx.cpu().numpy()
            out = model(video, audio)
            if args.hc == 1:
                probs = nn.functional.softmax(out, 1)
                PS[rows, :] = probs.detach().cpu().numpy().astype(dtype)
            else:
                PS_pre[rows, :] = out.detach().cpu().numpy().astype(dtype)
    if args.hc != 1:
        # Restore all heads, as gpu_sk does after aggregation.
        model.headcount = args.hc
    print("Aggreg of outputs  took {0:.2f} min".format((time.time() - now) / 60.), flush=True)

    # 2. solve label assignment via sinkhorn-knopp:
    if args.hc == 1:
        _, L = optimize_L_sk(args, PS)
    else:
        for nh in range(args.hc):
            print("computing head %s " % nh, end="\r", flush=True)
            tl = getattr(model, "top_layer%d" % nh)
            time_mat = time.time()
            PS = None  # release the previous head's matrix before the next matmul
            weight = tl.weight.detach().cpu().numpy().T.astype(dtype)
            bias = tl.bias.detach().cpu().numpy().astype(dtype)
            PS = PS_pre @ weight + bias
            print("matmul took %smin" % ((time.time() - time_mat) / 60.), flush=True)
            PS = _np_softmax(PS, 1)
            _, L[nh] = optimize_L_sk(args, PS)
    return L

def gpu_sk(args, dataset, model, device, outs, logger=None):
    """Assign cluster labels with the multi-GPU Sinkhorn-Knopp solver.

    Model outputs are gathered with ``aggreg_multi_gpu``. Each head's
    assignment is solved with ``optimize_L_sk_multi``, which is expected to
    write the new labels into ``L``. These helpers are defined outside this
    section; their contracts are assumed here, not shown.

    Args:
        args: namespace; reads ``sk_centercrop``, ``batch_size``, ``workers``,
            ``hc`` (number of heads), ``ncl``, ``global_rank``, ``model`` and
            ``dtype`` (``'f32'`` selects float32, anything else float64).
        dataset: dataset to label. It is replaced by a center-cropped version
            when ``args.sk_centercrop`` is set.
        model: model passed to ``aggreg_multi_gpu``; for ``args.hc > 1`` it
            has ``headcount`` and ``top_layer{i}`` layers.
        device: forwarded to ``optimize_L_sk_multi``.
        outs: per-head output dimensions (``outs[nh]``).
        logger: forwarded to ``print_or_log``.

    Returns:
        LongTensor of labels with shape ``(N,)`` when ``args.hc == 1`` and
        ``(args.hc, N)`` otherwise.
    """
    if args.sk_centercrop:
        # Use a shallow copy so center cropping applies only to this labelling
        # pass. Mutating ``args`` itself would also switch the caller's
        # configuration to center crops.
        vargs = copy.copy(args)
        vargs.center_crop = True
        dataset, _ = get_dataloader(vargs, 0)
    orig_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=None,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True
    )

    # Set up clustering code: random initial labels.
    N = len(dataset)
    if args.hc == 1:
        L = torch.LongTensor(N).random_(0, args.ncl).cuda()
    else:
        L = torch.LongTensor(args.hc, N).random_(0, args.ncl).cuda()

    print_or_log(('Before:', L[0:10], args.global_rank), logger=logger)

    presize = 4096 if args.model == 'alexnet' else 512 * 2
    if args.dtype == 'f32':
        dtype = torch.float32 if not args.device == 'cpu' else np.float32
    else:
        dtype = torch.float64 if not args.device == 'cpu' else np.float64

    # 1. aggregate inputs:
    start_t = time.time()
    if args.hc == 1:
        PS, indices = aggreg_multi_gpu(model, orig_loader,
                                       hc=args.hc, dim=outs[0], TYPE=dtype)
    else:
        # Free cached GPU blocks before the large feature aggregation.
        torch.cuda.empty_cache()
        time.sleep(1)
        PS_pre, indices = aggreg_multi_gpu(model, orig_loader,
                                           hc=args.hc, dim=presize, TYPE=torch.float32)
        model.headcount = args.hc
    print("Aggreg of outputs  took {0:.2f} min".format((time.time() - start_t) / 60.), flush=True)

    # 2. solve label assignment via sinkhorn-knopp:
    # NOTE(review): ``L[...][indices]`` gathers by ``indices``. This is right
    # only if ``indices`` maps dataset order to aggregation order; confirm
    # against aggreg_multi_gpu (a scatter ``L[indices] = ...`` may be meant).
    if args.hc == 1:
        # NOTE(review): this relies on optimize_L_sk_multi updating ``L``;
        # rebinding ``L`` inside that function would not propagate here.
        optimize_L_sk_multi(args, PS, L, outs, dtype, device, nh=0)
        L = L[indices]
    else:
        for nh in range(args.hc):
            tl = getattr(model, "top_layer%d" % nh)
            time_mat = time.time()
            # Drop the previous head's activations before allocating new ones.
            PS = None
            torch.cuda.empty_cache()
            PS = gpu_mul_AB(PS_pre, tl.weight.t(),
                            c=tl.bias, dim=outs[nh], TYPE=dtype)
            print("matmul took %smin" % ((time.time() - time_mat) / 60.), flush=True)
            optimize_L_sk_multi(args, PS, L, outs, dtype, device, nh=nh)
            L[nh] = L[nh][indices]
    return L

def optimize_L_sk(args, PS, device='cpu'):
    """Solve the equipartitioned label assignment with NumPy Sinkhorn-Knopp.

    Args:
        args: namespace; only ``args.lamb`` is read.
        PS: ``(N, K)`` NumPy array of non-negative sample-to-cluster scores.
            It is modified in place; on return it holds (up to rounding)
            ``PS ** (0.5 * args.lamb)``.
        device: unused; kept for backward compatibility.

    Returns:
        ``(cost, newL)``. ``cost`` is ``-(1 / lamb) * nansum(log q) / N`` over
        each sample's chosen entry of ``PS ** (0.5 * lamb)``. ``newL`` is an
        int64 tensor of length ``N`` with the label of each sample; it is moved
        to CUDA when CUDA is available.
    """
    N, K = PS.shape
    tt = time.time()
    Q = PS.T  # K x N view: the in-place ops below write through to PS
    Q **= 0.5 * args.lamb
    c = np.ones((N, 1)) / N
    inv_K = 1. / K
    inv_N = 1. / N
    err = 1e6
    _counter = 0
    while err > 1e-2:
        r = inv_K / (Q @ c)          # (K x N) @ (N x 1) = K x 1
        c_new = inv_N / (r.T @ Q).T  # ((1 x K) @ (K x N)).T = N x 1
        if _counter % 10 == 0:
            err = np.nansum(np.abs(c / c_new - 1))
        c = c_new
        _counter += 1
    print("error: ", err, 'step ', _counter, flush=True)

    # Scale Q in place to diag(r) @ Q @ diag(c) and take per-sample argmaxes.
    Q *= c.T
    Q *= r
    argmaxes = np.nanargmax(Q, 0)  # size N
    newL = torch.as_tensor(argmaxes, dtype=torch.long)
    if torch.cuda.is_available():
        newL = newL.cuda()

    # Undo the scaling so the cost is measured on PS ** (0.5 * lamb).
    Q /= r
    Q /= c.T
    sol = np.log(Q[argmaxes, np.arange(N)])
    cost = -(1./args.lamb)*np.nansum(sol)/N
    print('cost: ', cost,'nans in log(q): ', np.sum(np.isnan(sol)), flush=True)
    print('opt took {0:.2f}min, {1:4d}iters'.format(((time.time() - tt) / 60.), _counter), flush=True)
    return cost, newL

def optimize_L_sk_gpu(args, PS):
    """Solve the label assignment with Sinkhorn-Knopp on PS's torch device.

    Args:
        args: namespace; reads ``lamb`` and ``distribution``. With
            ``'default'`` every cluster targets an equal share. With ``'gauss'``
            or ``'zipf'``, random target cluster sizes are drawn; the smallest
            target goes to the cluster with the smallest current score mass.
        PS: ``(N, K)`` tensor of non-negative scores. It is modified in place;
            on return it holds (up to rounding) ``PS ** (0.5 * args.lamb)``.
            All working tensors use PS's device and dtype.

    Returns:
        ``(cost, newL)``. ``cost`` is ``-(1 / lamb) * nansum(log q) / N`` over
        each sample's chosen entry of ``PS ** (0.5 * lamb)``. ``newL`` is a
        length-``N`` int64 tensor of labels on PS's device.
    """
    N, K = PS.size()  # PS is N x K
    device, dtype = PS.device, PS.dtype
    tt = time.time()
    # Target cluster marginals (K x 1); uniform unless a distribution is set.
    r_target = torch.ones((K, 1), dtype=dtype, device=device) / K
    if args.distribution != 'default':
        marginals_argsort = torch.argsort(PS.sum(0))  # clusters by mass, ascending
        r = r_target
        if args.distribution == 'gauss':
            r = (torch.randn(size=(K, 1), dtype=dtype, device=device)/0.1 +1)*N/K
            r = torch.clamp(r, min=1)
        if args.distribution == 'zipf':
            zipf = np.random.zipf(a=2, size=K)
            r = torch.from_numpy(zipf).to(device=device, dtype=dtype).view(K, 1)
            r = torch.clamp(r, min=1)

        print(f"distribution used: {r}", flush=True)
        r = r / r.sum()
        # Sort along the cluster axis (dim 0; the last dim has size 1) so the
        # smallest target goes to the cluster with the least current mass.
        r_target = torch.empty_like(r)
        r_target[marginals_argsort] = torch.sort(r, dim=0)[0]

    c = torch.ones((N, 1), dtype=dtype, device=device) / N
    PS.pow_(0.5*args.lamb)  # N x K
    inv_N = 1./N
    err = 1e6
    _counter = 0
    ones = torch.ones(N, device=device, dtype=dtype)
    while err > 1e-1:
        r = r_target / torch.matmul(c.t(), PS).t()  # ((1xN)@(NxK)).T = Kx1
        c_new = inv_N / torch.matmul(PS, r)         # (NxK)@(K,1) = N x 1
        if _counter % 10 == 0:
            err = torch.sum(torch.abs((c.squeeze() / c_new.squeeze()) - ones)).item()
        c = c_new
        _counter += 1
    print("error: ", err, 'step ', _counter, flush=True)

    # Scale PS in place to diag(c) @ PS @ diag(r) and take per-sample argmaxes.
    PS.mul_(c)
    PS.mul_(r.t())
    newL = torch.argmax(PS, 1)

    # Undo the scaling to obtain the cost on PS ** (0.5 * lamb).
    PS.mul_((1./r).t())
    PS.mul_(1./c)
    picked = PS[torch.arange(N, device=device), newL]
    sol = np.nansum(torch.log(picked).cpu().numpy())
    cost = -(1. / args.lamb) * sol / N
    print('opt took {0:.2f}min, {1:4d}iters'.format(((time.time() - tt) / 60.), _counter), flush=True)
    return cost, newL

# def optimize_L_sk_multi(args, PS, L, outs, dtype, device, nh=0):
#     """ optimizes label assignment via Sinkhorn-Knopp.
#
#          this implementation uses multiple GPUs to store the activations which allow fast matrix multiplies
#
#          Parameters:
#              nh (int) number of the head that is being optimized.
#
#              """
#     N = max(L.size())
#     tt = time.time()
#     r = torch.ones((outs[nh], 1), device='cuda:0', dtype=dtype) / outs[nh]
#     c = torch.ones((N, 1), device='cuda:0', dtype=dtype) / N
#     ones = torch.ones(N, device='cuda:0', dtype=dtype)
#     inv_K = 1. / outs[nh]
#     inv_N = 1. / N
#
#     # inplace power of softmax activations:
#     [qq.pow_(args.lamb) for qq in PS]  # K x N
#
#     err = 1e6
#     _counter = 0
#     ngpu = torch.cuda.device_count()
#     splits = np.cumsum([0] + [a.size(0) for a in PS])
#     while err > 1e-1:
#         r = inv_K / (gpu_mul_xA(c.t(), PS,
#                                 ngpu=ngpu, splits=splits, TYPE=dtype)).t()  # ((1xN)@(NxK)).T = Kx1
#         c_new = inv_N / (gpu_mul_Ax(PS, r,
#                                     ngpu=ngpu, splits=splits, TYPE=dtype))  # (NxK)@(K,1) = N x 1
#         torch.cuda.synchronize()  # just in case
#         if _counter % 10 == 0:
#             err = torch.sum(torch.abs((c.squeeze() / c_new.squeeze()) - ones)).cpu().item()
#         c = c_new
#         _counter += 1
#     print("error: ", err, 'step ', _counter, flush=True)
#
#     # getting the final transportation matrix #####################
#     for i, qq in enumerate(PS):
#         torch.mul(qq, c[splits[i]:splits[i + 1], :].to('cuda:' + str(i + 1)), out=qq)
#     [torch.mul(r.to('cuda:' + str(i + 1)).t(), qq, out=qq) for i, qq in enumerate(PS)]
#     argmaxes = torch.empty(N, dtype=torch.int64, device='cuda:0')
#
#     start_idx = 0
#     sol = 0
#     for i, qq in enumerate(PS):
#         amax = torch.argmax(qq, 1)
#         argmaxes[start_idx:start_idx + len(qq)].copy_(amax)
#         torch.mul((1. / r).to('cuda:' + str(i + 1)).t(), qq, out=qq)
#         torch.mul(qq, (1. / c[splits[i]:splits[i + 1], :]).to('cuda:' + str(i + 1)), out=qq)
#         sol += np.nansum(np.log(qq[torch.range(0, len(amax) - 1).long(), amax].cpu().numpy()))
#         start_idx += len(qq)
#     newL = argmaxes
#     tt = time.time()
#     print('opt took {0:.2f}min, {1:4d}iters'.format(((time.time() - tt) / 60.), _counter), flush=True)
#
#     # finally, assign the new labels ########################
#     if args.hc == 1:
#         L = newL
#     else:
#         L[nh] = newL
#     cost = -(1. / args.lamb) * sol / N
#     return cost
