import torch
import numpy as np
import platform

import time
class Timer:
    """Interval timer that synchronizes CUDA devices around each timestamp.

    Only the devices whose ``type`` is ``'cuda'`` are kept. Before each
    timestamp is taken, ``torch.cuda.synchronize`` is called on every kept
    device, so queued asynchronous work (e.g. ``non_blocking`` copies) issued
    between ``start()`` and ``end()`` is included in the measured interval.
    """

    def __init__(self, devices):
        # Non-CUDA devices (e.g. the CPU) need no synchronization.
        self.devices = [d for d in devices if d.type == 'cuda']
        self.reset()

    def reset(self):
        """Forget any recorded start/end timestamps."""
        self.startTime = None
        self.endTime = None

    def _synchronize(self):
        """Block until all queued work on the tracked CUDA devices has finished."""
        for d in self.devices:
            torch.cuda.synchronize(d)

    def start(self):
        """Wait for pending device work, then record the start timestamp."""
        self._synchronize()
        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # which follows the (adjustable, sometimes coarse) wall clock.
        self.startTime = time.perf_counter()

    def end(self):
        """Wait for pending device work, then record the end timestamp."""
        self._synchronize()
        self.endTime = time.perf_counter()

    def elapsed(self):
        """Return the time between ``start()`` and ``end()``, in milliseconds."""
        return 1000 * (self.endTime - self.startTime)

    def elapsedAndReset(self):
        """Stop the timer, return the elapsed milliseconds, and immediately restart it."""
        self.end()
        result = self.elapsed()
        self.reset()
        self.start()
        return result
                                                                                                                                                                

def tensorSize(t):
    """Return the total size in bytes of a tensor or a nested collection of tensors.

    ``t`` is either a ``torch.Tensor`` or an iterable (list, tuple, generator,
    ...) whose items are themselves accepted by this function; their sizes are
    summed recursively. An empty iterable yields 0.
    """
    if isinstance(t, torch.Tensor):
        # numel() is a Python int and equals 1 for 0-dim tensors; np.prod(())
        # would return the float 1.0 and turn the byte count into a float.
        return t.element_size() * t.numel()
    return sum(tensorSize(u) for u in t)
    
# Short host name (everything before the first '.'), used to label printed results.
host = platform.node().partition('.')[0]

def measureBW(src, dst, shape, min_duration = 50, nb_parallel = 1):
    """Measure and print the copy bandwidth from device ``src`` to device ``dst``.

    ``nb_parallel`` freshly allocated tensors of shape ``shape`` are copied
    ``src`` -> ``dst`` with ``non_blocking=True``, and the copies are timed
    with device synchronization.

    If that takes less than ``min_duration`` milliseconds, the measurement is
    repeated with ``nb * nb_parallel`` tensors, where
    ``nb = 1 + min_duration // duration``.

    Args:
        src, dst: device specs (strings such as ``'cpu'`` / ``'cuda:0'`` or
            ``torch.device`` objects).
        shape: shape of each tensor copied.
        min_duration: minimum timed duration to aim for, in milliseconds.
        nb_parallel: number of tensors copied per round.

    Prints: host, src, dst, shape, nb, duration (ms), bandwidth (GiB/s), and
    bytes copied.
    """
    src_dev = torch.device(src)
    dst_dev = torch.device(dst)
    # Host-side buffers are pinned so non_blocking copies can be asynchronous.
    # Checking the device type (rather than a string prefix) also accepts
    # torch.device arguments, matching measureBWContention.
    pin_src = src_dev.type == 'cpu'
    pin_dst = dst_dev.type == 'cpu'

    def make_data(device, quantity, pin=False, function=torch.rand):
        xargs = {"pin_memory" : True} if pin else {}
        return [ function(*shape, device=device, **xargs) for _ in range(quantity) ]

    timer = Timer([src_dev, dst_dev])

    def measure_time(multiplicity):
        """Copy ``multiplicity`` fresh tensors src -> dst; return (elapsed ms, bytes copied).

        The buffers are local, so they are released when this returns. The
        first batch is therefore freed before the larger second batch is
        allocated, instead of both being alive at once.
        """
        dataSRC = make_data(src_dev, multiplicity, pin=pin_src)
        dataDST = make_data(dst_dev, multiplicity, pin=pin_dst, function=torch.empty)
        timer.reset()
        timer.start()
        for source, target in zip(dataSRC, dataDST):
            target.copy_(source, non_blocking = True)
        timer.end()
        return timer.elapsed(), tensorSize(dataSRC)

    nb = 1
    duration, dataSize = measure_time(nb_parallel)

    if duration < min_duration:
        # Scale the batch so the timed region lasts roughly min_duration or more.
        nb = 1 + int(min_duration // duration)
        duration, dataSize = measure_time(nb * nb_parallel)

    bw = dataSize / (duration/1000)
    print(host, src, dst, shape, nb, duration, bw/1024/1024/1024, dataSize)

def measureBWContention(pairs, shape, min_duration = 50, nb_parallel = 1):
    """Measure and print the aggregate bandwidth of several copies issued together.

    Each ``(src, dst)`` entry of ``pairs`` gets its own freshly allocated
    tensors of shape ``shape``. The copies of all pairs are issued round-robin
    with ``non_blocking=True`` and timed as a single interval, with every
    involved CUDA device synchronized.

    If one round of ``nb_parallel`` copies per pair takes less than
    ``min_duration`` ms, the measurement is repeated with ``nb * nb_parallel``
    copies per pair, where ``nb = 1 + min_duration // duration``.

    Prints: host, pairs, shape, nb, duration (ms), aggregate bandwidth (GiB/s),
    and total source bytes.
    """
    def make_data(device, quantity, function=torch.rand):
        # Host tensors are pinned so non_blocking copies to/from them can be asynchronous.
        extra = {"pin_memory": True} if device.type == 'cpu' else {}
        return [function(*shape, device=device, **extra) for _ in range(quantity)]

    device_pairs = [(torch.device(pair[0]), torch.device(pair[1])) for pair in pairs]
    # Every device that takes part in any copy must be synchronized by the timer.
    involved = {dev for dev_pair in device_pairs for dev in dev_pair}

    def measure_time(multiplicity):
        batches = []
        for source_dev, target_dev in device_pairs:
            sources = make_data(source_dev, multiplicity)
            targets = make_data(target_dev, multiplicity, function=torch.empty)
            batches.append((sources, targets))

        timer = Timer(involved)
        timer.start()
        # NOTE(review): every copy is enqueued on its device's current stream,
        # so copies that involve the same GPU may serialize rather than
        # contend. Confirm whether separate streams per pair were intended.
        for idx in range(multiplicity):
            for sources, targets in batches:
                targets[idx].copy_(sources[idx], non_blocking = True)
        timer.end()

        copied = tensorSize([sources for sources, _ in batches])
        return timer.elapsed(), copied

    nb = 1
    duration, dataSize = measure_time(nb_parallel)

    too_short = duration < min_duration
    if too_short:
        nb = 1 + int(min_duration // duration)
        duration, dataSize = measure_time(nb * nb_parallel)

    seconds = duration / 1000
    bw = dataSize / seconds
    print(host, pairs, shape, nb, duration, bw/1024/1024/1024, dataSize)


if __name__ == "__main__":
    import sys

    # By default only the quick cuda:0 <-> cpu checks run (as before).
    # Pass --full on the command line to also run the exhaustive device sweep,
    # which used to be unreachable after a bare exit(0).
    full_sweep = '--full' in sys.argv[1:]

    ## For vision networks (reference shape; not used by the runs below)
    batch_size = 32
    image_size = 500
    vision_shape = (batch_size, 10, image_size, image_size)

    # 7 * 512**3 elements per tensor (about 3.5 GiB each with the default float32 dtype).
    shape = (7, 512, 512, 512)
    measureBW('cuda:0', 'cpu', shape, min_duration = 5, nb_parallel = 1)
    measureBW('cpu', 'cuda:0', shape, min_duration = 5, nb_parallel = 1)
    measureBWContention([('cuda:0', 'cpu'), ('cpu', 'cuda:0')], shape, min_duration = 5, nb_parallel = 1)

    if full_sweep:
        all_cudas = ['cuda:%d' % i for i in range(torch.cuda.device_count())]
        all_devices = ['cpu'] + all_cudas
        for src in all_devices:
            for dst in all_devices:
                if src != dst:
                    # measureBW(src, dst, shape, min_duration = 5, nb_parallel = 1)
                    measureBWContention([(src, dst)], shape, min_duration = 5, nb_parallel = 1)

        # The multi-GPU contention runs below address cuda:1 and need >= 2 GPUs.
        if len(all_cudas) >= 2:
            measureBWContention([('cuda:0', 'cpu'), ('cuda:1', 'cpu')], shape, min_duration = 5, nb_parallel = 1)
            measureBWContention([('cpu', 'cuda:0'), ('cpu', 'cuda:1')], shape, min_duration = 5, nb_parallel = 1)
            measureBWContention([('cuda:1', 'cuda:0'), ('cuda:0', 'cuda:1')], shape, min_duration = 5, nb_parallel = 1)
