import torch
import torch.distributed as dist
from torch import nn
from torch.utils.dlpack import to_dlpack
from torch.utils.dlpack import from_dlpack
import time
import numpy as np

class OnebitBinarySGDState:
    """Mutable state threaded through every call of ``quantization_onebit_hook``.

    Holds the quantization hyper-parameters, the per-bucket buffers (keyed by
    bucket index) used by the quantizers, warm-up bookkeeping, optional timing
    statistics and, when ``hierarchy`` is enabled, the two levels of process
    subgroups used for hierarchical communication.
    """

    def __init__(
        self,
        beta,
        eps=1.0e-8,
        comp_flag=False,
        record_time=False,
        packbits_by_cupy=False,
        hierarchy=False,
        all_gather_by_chunks=False,
    ):
        # Decay of the moving averages and the denominator guard used when
        # turning them into sampling probabilities.
        self.beta = beta
        self.eps = eps
        # Enables error feedback (residual compensation) in the quantizers.
        self.comp_flag = comp_flag
        # Error-feedback residuals for the worker-side and server-side quantizers.
        self.comp_delta = {}
        self.comp_delta_server = {}
        # Moving averages of the gradient (m) and of its absolute value (v).
        self.m = {}
        self.v = {}
        # Warm-up detection: bucket indices seen so far and whether one repeated.
        self.index_flag = set()
        self.not_first_iter_flag = False
        # Optional profiling: bucket index -> accumulated timestamp deltas.
        self.record_time = record_time
        self.time_counter = {}
        self.packbits_by_cupy = packbits_by_cupy
        if packbits_by_cupy:
            # Only checks that CuPy is importable; the name stays local to __init__.
            import cupy
        self.hierarchy = hierarchy
        self.all_gather_by_chunks = all_gather_by_chunks
        self.new_dist_groups()

    def new_dist_groups(self):
        """Build the stage-1 / stage-2 subgroups when ``hierarchy`` is enabled.

        Stage 1 comes from ``dist.new_subgroups()`` (its default group size is
        the number of CUDA devices on the machine). Stage 2 joins the ranks that
        share the same position inside their stage-1 group. This assumes the
        world size is a multiple of ``torch.cuda.device_count()``.
        """
        if not self.hierarchy:
            return
        self.stage_1_cur_subgroup, self.stage_1_subgroups = dist.new_subgroups()
        devices_per_machine = torch.cuda.device_count()
        num_machines = dist.get_world_size() // devices_per_machine
        peer_ranks = [
            [local_rank + machine * devices_per_machine for machine in range(num_machines)]
            for local_rank in range(devices_per_machine)
        ]
        self.stage_2_cur_subgroup, self.stage_2_subgroups = dist.new_subgroups_by_enumeration(
            ranks_per_subgroup_list=peer_ranks
        )

    def state_dict(self):
        """Return the per-bucket quantizer buffers (by reference) for checkpointing."""
        return {key: getattr(self, key) for key in ('comp_delta', 'comp_delta_server', 'm', 'v')}

    def load_state_dict(self, state_dict):
        """Restore the buffers produced by :meth:`state_dict`; missing keys raise ``KeyError``."""
        for key in ('comp_delta', 'comp_delta_server', 'm', 'v'):
            setattr(self, key, state_dict[key])
    


def _packbits(torch_tensor, by_cupy=False):
    """Pack a 1-D tensor of 0/1 values into bytes, eight values per byte.

    Values are packed most-significant-bit first (the same order as
    ``numpy.packbits``). The input is zero-padded to a multiple of 8.

    Args:
        torch_tensor: 1-D tensor of 0/1 values (callers in this module pass
            ``torch.uint8``). The PyTorch path returns the input's dtype; the
            CuPy path returns ``uint8``.
        by_cupy: if True, pack on the tensor's CUDA device with
            ``cupy.packbits``, exchanging memory via DLPack.

    Returns:
        Tensor of ``ceil(len(torch_tensor) / 8)`` packed values.
    """
    if by_cupy:
        # Imported here because the ``import cupy`` in
        # OnebitBinarySGDState.__init__ binds only a local name, so no
        # module-level ``cupy`` exists for this function to use.
        import cupy

        with cupy.cuda.Device(torch_tensor.device.index):
            cupy_tensor = cupy.fromDlpack(to_dlpack(torch_tensor))
            return from_dlpack(cupy.packbits(cupy_tensor).toDlpack())

    remainder = torch_tensor.shape[0] % 8
    if remainder:
        # Zero-pad the tail so the tensor splits cleanly into rows of 8.
        padding = torch.zeros(8 - remainder, device=torch_tensor.device, dtype=torch_tensor.dtype)
        torch_tensor = torch.cat([torch_tensor, padding])
    # Bit weights 128, 64, ..., 1: the first element of each group lands in the MSB.
    weights = torch.tensor(
        [1 << (7 - bit) for bit in range(8)], device=torch_tensor.device, dtype=torch_tensor.dtype
    )
    return torch.sum(torch_tensor.reshape(-1, 8) * weights, dim=1, dtype=torch_tensor.dtype)

def _unpackbits(torch_packed_tensor, by_cupy=False):
    """Expand each packed value into eight 0/1 values, most-significant bit first.

    This is the inverse of ``_packbits`` (same bit order as ``numpy.unpackbits``).
    Any zero padding added when packing is kept; callers slice it off.

    Args:
        torch_packed_tensor: tensor of packed bytes. The PyTorch path returns
            the input's dtype; the CuPy path returns ``uint8``.
        by_cupy: if True, unpack on the tensor's CUDA device with
            ``cupy.unpackbits``, exchanging memory via DLPack.

    Returns:
        1-D tensor with ``8 * torch_packed_tensor.numel()`` 0/1 values.
    """
    if by_cupy:
        # Imported here because the ``import cupy`` in
        # OnebitBinarySGDState.__init__ binds only a local name, so no
        # module-level ``cupy`` exists for this function to use.
        import cupy

        with cupy.cuda.Device(torch_packed_tensor.device.index):
            cupy_packed_tensor = cupy.fromDlpack(to_dlpack(torch_packed_tensor))
            return from_dlpack(cupy.unpackbits(cupy_packed_tensor).toDlpack())

    # Bit weights 128, 64, ..., 1: column j tests bit (7 - j) of each byte.
    weights = torch.tensor(
        [1 << (7 - bit) for bit in range(8)],
        device=torch_packed_tensor.device,
        dtype=torch_packed_tensor.dtype,
    )
    masked = torch.bitwise_and(torch_packed_tensor.reshape(-1, 1).expand(-1, 8), weights)
    return masked.to(torch.bool).to(torch_packed_tensor.dtype).reshape(-1)


def _quantize_onebit_tensor_cuda(state, tensor, index, by_cupy=False):
    """Binarize ``tensor`` stochastically and return its bit-packed encoding.

    For each ``index`` two exponential moving averages with decay ``state.beta``
    are kept: ``state.m[index]`` of the (inf-sanitized, float32) tensor and
    ``state.v[index]`` of its absolute value. Each element becomes 1 with
    probability ``(m / (v + eps) + 1) / 2`` and 0 otherwise. When
    ``state.comp_flag`` is set, the accumulated sampling error
    ``state.comp_delta[index]`` is added to that probability before sampling
    (clamped to [0, 1]) and updated afterwards.

    Args:
        state: OnebitBinarySGDState holding the moving-average and residual buffers.
        tensor: tensor to quantize; presumably 1-D, since ``_packbits`` pads
            along dim 0. Entries equal to +/-inf are replaced by +/-0.1.
        index: key for the per-bucket buffers in ``state``.
        by_cupy: forwarded to ``_packbits``.

    Returns:
        uint8 tensor of packed bits (8 elements per byte, zero-padded).
    """
    decay = state.beta
    grad = torch.where(torch.isinf(tensor), 0.1 * tensor.sign(), tensor).float()

    if index in state.m:
        state.m[index].mul_(decay).add_(grad, alpha=(1 - decay))
        state.v[index].mul_(decay).add_(grad.abs(), alpha=(1 - decay))
    else:
        state.m[index] = grad * (1 - decay)
        state.v[index] = grad.abs() * (1 - decay)

    mean, mean_abs = state.m[index], state.v[index]
    # m and v are the same weighted sums of x and |x|, so in exact arithmetic
    # |m| <= v and the affine map below lands in [0, 1].
    prob = (mean / (mean_abs + state.eps) + 1) / 2

    # Debug aid: report NaN probabilities (e.g. from NaN gradients); the
    # Bernoulli sampling below is then expected to fail on them.
    nan_mask = torch.isnan(prob)
    if nan_mask.any():
        inf_mask = torch.isinf(prob)
        print("m_nan: ", mean[nan_mask])
        print("v_nan: ", mean_abs[nan_mask])
        print("m_inf: ", mean[inf_mask])
        print("v_inf: ", mean_abs[inf_mask])

    if not state.comp_flag:
        sampled = torch.bernoulli(prob)
    elif index in state.comp_delta:
        residual = state.comp_delta[index]
        sampled = torch.bernoulli((prob + residual).clamp(0, 1))
        residual.add_(prob - sampled)
    else:
        sampled = torch.bernoulli(prob)
        state.comp_delta[index] = prob - sampled

    return _packbits(sampled.to(torch.uint8), by_cupy)

def _quantize_onebit_tensor_cuda_server(state, tensor, index, by_cupy=False):
    """Re-binarize an already aggregated tensor and return its bit-packed encoding.

    Each element becomes 1 with probability ``(tensor + 1) / 2``, so ``tensor``
    must lie in [-1, 1] (the hook passes a mean of +/-1 votes). With
    ``state.comp_flag`` set, the residual ``state.comp_delta_server[index]`` is
    added before sampling (clamped to [0, 1]) and updated afterwards.

    Args:
        state: OnebitBinarySGDState holding the server-side residual buffers.
        tensor: values in [-1, 1] to quantize.
        index: key for ``state.comp_delta_server``.
        by_cupy: forwarded to ``_packbits``.

    Returns:
        uint8 tensor of packed bits (8 elements per byte, zero-padded).
    """
    prob = (tensor + 1) / 2
    if not state.comp_flag:
        sampled = torch.bernoulli(prob)
    elif index in state.comp_delta_server:
        residual = state.comp_delta_server[index]
        sampled = torch.bernoulli((prob + residual).clamp(0, 1))
        residual.add_(prob - sampled)
    else:
        sampled = torch.bernoulli(prob)
        state.comp_delta_server[index] = prob - sampled

    return _packbits(sampled.to(torch.uint8), by_cupy)


def _dequantize_onebit_tensor_cuda(torch_packed_tensor, shape, by_cupy=False):
    """Unpack ``torch_packed_tensor`` and drop the zero padding added by packing.

    Args:
        torch_packed_tensor: packed bytes produced by ``_packbits``.
        shape: number of elements to keep (an ``int`` length, despite the name).
        by_cupy: forwarded to ``_unpackbits``.

    Returns:
        The first ``shape`` unpacked 0/1 values.
    """
    unpacked = _unpackbits(torch_packed_tensor, by_cupy)
    return unpacked[:shape]

def _get_allgather_out_list(all_gather_in_list, world_size):
    """Allocate ``world_size`` zero buffers shaped like the all_gather input.

    Args:
        all_gather_in_list: the tensor this rank contributes to ``all_gather``.
        world_size: number of ranks in the group (one output buffer each).

    Returns:
        A list of independent zero tensors with the input's shape, dtype and device.
    """
    return [torch.zeros_like(all_gather_in_list) for _rank in range(world_size)]


def quantization_onebit_hook(state: OnebitBinarySGDState, bucket):
    """Communication hook that averages a gradient bucket using 1-bit votes.

    Signature ``(state, bucket) -> Future``; presumably registered on a
    ``DistributedDataParallel`` model via ``register_comm_hook``.

    Warm-up: until some bucket index is seen a second time (tracked in
    ``state.index_flag`` / ``state.not_first_iter_flag``), the bucket is cast
    to float32 with +/-inf replaced by +/-1, all-reduced over the WORLD group
    in full precision, and divided by the world size.

    After warm-up:

    * Flat mode (``state.hierarchy`` False): every rank quantizes the bucket
      to packed bits with ``_quantize_onebit_tensor_cuda`` and all-gathers
      them over WORLD. Each rank then counts the 1-bits per element and maps
      the count to ``2 * count / world_size - 1``.
    * Hierarchical mode (NCCL only): a stage-1 ``reduce_scatter`` averages
      the bucket inside ``state.stage_1_cur_subgroup``. The resulting shard is
      quantized and exchanged inside ``state.stage_2_cur_subgroup``, either
      by a full all_gather or (``state.all_gather_by_chunks``) by
      ``all_to_all_single`` plus server-side re-quantization
      (``_quantize_onebit_tensor_cuda_server``) and an all_gather of chunks.
      Finally the shards are all-gathered inside the stage-1 subgroup and
      copied back into the bucket.

    After warm-up the bucket therefore holds averaged +/-1 votes in [-1, 1],
    not a rescaled gradient magnitude.

    Args:
        state: quantizer buffers, subgroups and timing options.
        bucket: the gradient bucket; ``bucket.buffer()`` is the flat gradient
            tensor and ``bucket.index()`` its id.

    Returns:
        A future resolving to ``bucket.buffer()`` overwritten with the result.

    When ``state.record_time`` is set, wall-clock checkpoints are collected in
    ``time_list`` and their differences accumulated in
    ``state.time_counter[bucket_index]``.
    """
    if state.record_time:
        time_list = []
        time_list.append(time.time())

    group_to_use = dist.group.WORLD
    world_size = group_to_use.size()
    bucket_tensor = bucket.buffer()
    bucket_index = bucket.index()
    # print("*********bucket_index_{}_bucket_tensor_{}********".format(bucket_index, bucket_tensor))
    if state.hierarchy:
        # Pad the element count up to a multiple of world_size * 8 so the
        # shards below split evenly and pack into whole bytes (assuming the
        # stage-1 and stage-2 subgroup sizes multiply to the world size).
        bucket_tensor_comp_shape = bucket_tensor.shape[0]
        empty_flag = bucket_tensor.shape[0] % (dist.get_world_size() * 8)
        if empty_flag:
            bucket_tensor_comp_shape += (dist.get_world_size() * 8) - empty_flag
        # Length of each rank's shard after the stage-1 reduce_scatter.
        stage_1_comp_shape = bucket_tensor_comp_shape // dist.get_world_size(group=state.stage_1_cur_subgroup)
        # Length of the chunk each rank owns in the stage-2 "by chunks" exchange.
        stage_2_comp_shape = stage_1_comp_shape // dist.get_world_size(group=state.stage_2_cur_subgroup)

    # Warm-up detection: switch to compressed communication as soon as a
    # bucket index repeats (the call that sees the repeat is already compressed).
    if not state.not_first_iter_flag:
        if bucket_index in state.index_flag:
            state.not_first_iter_flag = True
        else:
            state.index_flag.add(bucket_index)

    if state.record_time:
        time_list.append(time.time())

    def first_iter_avg(fut):
        """Warm-up: copy the all-reduced sum into the bucket and divide by world size."""
        decompressed_tensor = bucket.buffer()
        tensor = fut.value()[0]
        decompressed_tensor.copy_(tensor)
        decompressed_tensor.div_(dist.group.WORLD.size())
        return decompressed_tensor

    def dequantize_and_aggregate(fut):
        """Flat mode: sum every rank's unpacked bits and map the count to [-1, 1]."""
        decompressed_tensor = bucket.buffer()
        if dist.get_backend() == 'nccl':
            # NOTE(review): the NCCL all_gather future's value()[0] is expected
            # to be iterable per rank (one packed tensor per rank) — confirm for
            # the torch version in use.
            all_ranks_quantized_tensor = fut.wait()[0]
            # Synchronize so timestamps cover finished GPU work, and presumably
            # because CuPy kernels are not ordered with PyTorch's current stream.
            if state.record_time or state.packbits_by_cupy:
                torch.cuda.synchronize()
        else:
            assert dist.get_backend() == 'gloo'
            # Gloo: value() is the per-rank list of gathered tensors.
            all_ranks_quantized_tensor = fut.value()
        if state.record_time:
            time_list.append(time.time())
        # Integer count of 1-bits per element across all ranks.
        aggregated_dequantized_tensor = torch.zeros_like(
            decompressed_tensor, device=decompressed_tensor.device, dtype=torch.int32
        )
        for quantized_tensor in all_ranks_quantized_tensor:
            aggregated_dequantized_tensor += _dequantize_onebit_tensor_cuda(quantized_tensor, decompressed_tensor.shape[0], by_cupy=state.packbits_by_cupy)
        if state.record_time:
            time_list.append(time.time())
        decompressed_tensor.copy_(aggregated_dequantized_tensor)
        # count / world_size is the fraction of +1 votes; 2 * frac - 1 is the
        # mean of the +/-1 votes.
        decompressed_tensor.mul_(2.0 / dist.group.WORLD.size()).sub_(1.0)
        if state.record_time:
            time_list.append(time.time())
            if bucket_index not in state.time_counter:
                state.time_counter[bucket_index] = np.diff(np.array(time_list))
            else:
                state.time_counter[bucket_index] += np.diff(np.array(time_list))

        return decompressed_tensor

    def stage_2_process(fut):
        """Hierarchy stage 2: quantize this rank's stage-1 shard and all_gather
        the packed bits inside the stage-2 subgroup."""

        assert dist.get_backend() == 'nccl'
        stage_2_input = fut.wait()[0]
        # print('*****bucket_index_{}_stage_2_input_{}*********'.format(bucket_index, stage_2_input))
        if state.record_time or state.packbits_by_cupy:
            torch.cuda.synchronize()
        stage_2_input = _quantize_onebit_tensor_cuda(state, stage_2_input, bucket_index, by_cupy=state.packbits_by_cupy)

        cur_subgroup = state.stage_2_cur_subgroup
        cur_size = dist.get_world_size(group=cur_subgroup)
        stage_2_out_list = _get_allgather_out_list(stage_2_input, cur_size)

        fut = dist.all_gather(
            stage_2_out_list,
            stage_2_input,
            group=cur_subgroup,
            async_op=True,
        ).get_future()

        # Blocks this callback until the inner collective has completed.
        return fut.wait()

    def stage_3_process(fut):
        """Hierarchy stage 3: turn the stage-2 peers' bits into a mean vote in
        [-1, 1] for this shard, then all_gather the shards in full precision
        inside the stage-1 subgroup."""

        aggregated_dequantized_tensor = torch.zeros(
            stage_1_comp_shape, device=bucket_tensor.device, dtype=bucket_tensor.dtype
        )

        assert dist.get_backend() == 'nccl'
        quantized_tensor_list = fut.wait()[0]
        # print('*****bucket_index_{}_quantized_tensor_list_{}*********'.format(bucket_index, quantized_tensor_list))
        if state.record_time or state.packbits_by_cupy:
            torch.cuda.synchronize()

        for quantized_tensor in quantized_tensor_list:
            aggregated_dequantized_tensor += _dequantize_onebit_tensor_cuda(quantized_tensor, stage_1_comp_shape, by_cupy=state.packbits_by_cupy)
        if state.record_time:
            time_list.append(time.time())
        aggregated_dequantized_tensor.mul_(2.0 / dist.get_world_size(group=state.stage_2_cur_subgroup)).sub_(1.0)
        if state.record_time:
            time_list.append(time.time())
            if bucket_index not in state.time_counter:
                state.time_counter[bucket_index] = np.diff(np.array(time_list))
            else:
                state.time_counter[bucket_index] += np.diff(np.array(time_list))

        cur_subgroup = state.stage_1_cur_subgroup
        cur_size = dist.get_world_size(group=cur_subgroup)
        stage_3_input = aggregated_dequantized_tensor
        stage_3_output_list = _get_allgather_out_list(stage_3_input, cur_size)

        fut = dist.all_gather(
            stage_3_output_list,
            stage_3_input,
            group=cur_subgroup,
            async_op=True,
        ).get_future()
        return fut.wait()

    def stage_4_process(fut):
        """Final step of both hierarchical variants: flatten the gathered shards
        in rank order, drop the padding and copy the result into the bucket."""
        decompressed_tensor = bucket.buffer()

        # NOTE(review): value()[0] is expected to be a single tensor holding all
        # gathered shards (it is flattened with .view(-1)) — confirm for the
        # torch version in use.
        final_comp_tensor = fut.wait()[0]
        final_comp_tensor = final_comp_tensor.view(-1)[:decompressed_tensor.shape[0]]
        decompressed_tensor.copy_(final_comp_tensor)
        # print('*****bucket_index_{}_decompressed_tensor_{}*********'.format(bucket_index, decompressed_tensor))

        return decompressed_tensor


    def stage_2_all_gather_by_chunks_process(fut):
        """Chunked stage 2: quantize the stage-1 shard and all_to_all the packed
        bytes, so each stage-2 rank receives its own chunk from every peer."""
        assert dist.get_backend() == 'nccl'
        stage_2_input = fut.wait()[0]

        if state.record_time or state.packbits_by_cupy:
            torch.cuda.synchronize()
        stage_2_input = _quantize_onebit_tensor_cuda(state, stage_2_input, bucket_index, by_cupy=state.packbits_by_cupy)
        stage_2_output = torch.zeros_like(
            stage_2_input, device=stage_2_input.device, dtype=stage_2_input.dtype
        )
        cur_subgroup = state.stage_2_cur_subgroup

        fut = dist.all_to_all_single(
            stage_2_output,
            stage_2_input,
            group=cur_subgroup,
            async_op=True,
        ).get_future()
        return fut.wait()

    def stage_3_all_gahter_by_chunks_process(fut):
        """Chunked stage 3: average the received votes for this rank's chunk,
        re-quantize the mean server-side (error feedback in
        state.comp_delta_server) and all_gather the packed chunks inside the
        stage-2 subgroup."""

        cur_subgroup = state.stage_2_cur_subgroup
        cur_size = dist.get_world_size(group=cur_subgroup)
        assert dist.get_backend() == 'nccl'
        quantized_tensor_list = fut.wait()[0]

        # One equal-sized packed chunk per stage-2 peer.
        quantized_tensor_list = quantized_tensor_list.chunk(cur_size)
        if state.record_time or state.packbits_by_cupy:
            torch.cuda.synchronize()

        aggregated_dequantized_tensor = torch.zeros(
            stage_2_comp_shape, device=bucket_tensor.device, dtype=bucket_tensor.dtype
        )
        for quantized_tensor in quantized_tensor_list:
            aggregated_dequantized_tensor += _dequantize_onebit_tensor_cuda(quantized_tensor, stage_2_comp_shape, by_cupy=state.packbits_by_cupy)

        if state.record_time:
            time_list.append(time.time())
        aggregated_dequantized_tensor.mul_(2.0 / dist.get_world_size(group=state.stage_2_cur_subgroup)).sub_(1.0)

        if state.record_time:
            time_list.append(time.time())
            if bucket_index not in state.time_counter:
                state.time_counter[bucket_index] = np.diff(np.array(time_list))
            else:
                state.time_counter[bucket_index] += np.diff(np.array(time_list))

        stage_3_input = _quantize_onebit_tensor_cuda_server(state, aggregated_dequantized_tensor, bucket_index, by_cupy=state.packbits_by_cupy)
        stage_3_output_list = _get_allgather_out_list(stage_3_input, cur_size)

        fut = dist.all_gather(
            stage_3_output_list,
            stage_3_input,
            group=cur_subgroup,
            async_op=True,
        ).get_future()
        return fut.wait()

    def stage_4_all_gather_by_chunks_process(fut):
        """Chunked stage 4: unpack each stage-2 peer's chunk into its slot of the
        stage-1 shard, map bits to +/-1 and all_gather the shards inside the
        stage-1 subgroup."""
        cur_subgroup = state.stage_1_cur_subgroup
        cur_size = dist.get_world_size(group=cur_subgroup)

        assert dist.get_backend() == 'nccl'
        quantized_tensor_list = fut.wait()[0]
        if state.record_time or state.packbits_by_cupy:
            torch.cuda.synchronize()

        all_aggregated_dequantized_tensor = torch.zeros(
            stage_1_comp_shape, device=bucket_tensor.device, dtype=bucket_tensor.dtype
        )
        # The chunks are views, so copy_ below fills all_aggregated_dequantized_tensor.
        aggregated_dequantized_tensor_list = all_aggregated_dequantized_tensor.chunk(
            dist.get_world_size(group=state.stage_2_cur_subgroup)
        )

        for quantized_tensor, aggregated_dequantized_tensor in zip(quantized_tensor_list, aggregated_dequantized_tensor_list):
            aggregated_dequantized_tensor.copy_(
                _dequantize_onebit_tensor_cuda(quantized_tensor, stage_2_comp_shape, by_cupy=state.packbits_by_cupy)
            )

        # Map bits {0, 1} to votes {-1, +1}.
        all_aggregated_dequantized_tensor.mul_(2.0).sub_(1.0)
        stage_4_input = all_aggregated_dequantized_tensor

        stage_4_output_list = _get_allgather_out_list(stage_4_input, cur_size)

        fut = dist.all_gather(
            stage_4_output_list,
            stage_4_input,
            group=cur_subgroup,
            async_op=True,
        ).get_future()
        return fut.wait()

    # float32 copy with +/-inf replaced by +/-1. The closures above read
    # ``bucket_tensor`` when they run, so they see this copy's device/dtype.
    # NOTE(review): only the warm-up all_reduce and the flat path use this copy
    # as input; the hierarchical path below reads bucket.buffer() directly,
    # so inf values there are not sanitized before the reduce_scatter.
    bucket_tensor = torch.where(torch.isinf(bucket_tensor), bucket_tensor.sign(), bucket_tensor).float()
    if  state.not_first_iter_flag:
        if state.hierarchy:

            # Stage 1: pre-divide by the stage-1 group size, then reduce_scatter
            # so each rank holds the stage-1 group mean of its shard. div_ works
            # in place on bucket.buffer(), which stage_4_process overwrites later.
            cur_subgroup = state.stage_1_cur_subgroup
            cur_size = dist.get_world_size(group=cur_subgroup)
            cur_rank = dist.get_rank(group=cur_subgroup)
            stage_1_input = bucket.buffer()
            stage_1_input.div_(cur_size)
            # Same padding rule as bucket_tensor_comp_shape above.
            empty_flag = stage_1_input.shape[0] % (dist.get_world_size() * 8)
            if empty_flag:
                empty_tensor = torch.zeros((dist.get_world_size() * 8) - empty_flag, device=stage_1_input.device, dtype=stage_1_input.dtype)
                stage_1_input = torch.cat([stage_1_input, empty_tensor])

            stage_1_input_list = list(stage_1_input.chunk(cur_size))
            stage_1_output = torch.zeros(
                stage_1_comp_shape, device=stage_1_input.device, dtype=stage_1_input.dtype
            )

            fut = dist.reduce_scatter(stage_1_output, stage_1_input_list, group=cur_subgroup, async_op=True).get_future()

            if state.all_gather_by_chunks:
                return fut.then(stage_2_all_gather_by_chunks_process).then(stage_3_all_gahter_by_chunks_process).then(stage_4_all_gather_by_chunks_process).then(stage_4_process)
            else:
                return fut.then(stage_2_process).then(stage_3_process).then(stage_4_process)

        else:
            # Flat mode: quantize locally and gather every rank's packed bits.
            quantized_tensor = _quantize_onebit_tensor_cuda(state, bucket_tensor, bucket_index, by_cupy=state.packbits_by_cupy)
            if state.record_time:
                time_list.append(time.time())

            out_list = _get_allgather_out_list(quantized_tensor, world_size)
            fut = dist.all_gather(
                out_list,
                quantized_tensor,
                group=group_to_use,
                async_op=True,
            ).get_future()
            return fut.then(dequantize_and_aggregate)
    else:
        # Warm-up: full-precision sum of the float32 copy; first_iter_avg
        # writes it back into the bucket and divides by the world size.
        fut = dist.all_reduce(bucket_tensor, group=group_to_use, async_op=True).get_future()
        return fut.then(first_iter_avg)



class SGDState:
    """State for ``my_allreduce_hook``: a timing switch plus per-bucket timing totals."""

    def __init__(self, record_time=False):
        # When True, my_allreduce_hook collects wall-clock checkpoints per call.
        self.record_time = record_time
        # bucket index -> numpy array of accumulated checkpoint deltas (seconds).
        self.time_counter = dict()


def my_allreduce_hook(state, bucket):
    """Baseline communication hook: average the bucket with a plain all-reduce.

    The bucket is divided by the WORLD size in place and then summed with
    ``dist.all_reduce``, which yields the mean. When ``state.record_time`` is
    set, wall-clock checkpoints are taken (after a CUDA synchronize on NCCL)
    and their differences accumulated in ``state.time_counter[bucket.index()]``.

    Args:
        state: SGDState with ``record_time`` and ``time_counter``.
        bucket: the gradient bucket (``buffer()`` / ``index()``).

    Returns:
        A future resolving to ``bucket.buffer()`` holding the averaged gradient.
    """
    timestamps = [time.time()] if state.record_time else None
    grad = bucket.buffer()
    idx = bucket.index()
    world = dist.group.WORLD

    if state.record_time:
        timestamps.append(time.time())

    # Pre-divide so the summed all-reduce result is already the mean.
    grad.div_(world.size())

    def _finish(fut):
        out = bucket.buffer()
        reduced = fut.value()[0]
        if state.record_time:
            if dist.get_backend() == 'nccl':
                torch.cuda.synchronize()
            timestamps.append(time.time())
        out.copy_(reduced)
        if state.record_time:
            timestamps.append(time.time())
            deltas = np.diff(np.array(timestamps))
            if idx in state.time_counter:
                state.time_counter[idx] += deltas
            else:
                state.time_counter[idx] = deltas
        return out

    work = dist.all_reduce(grad, group=world, async_op=True)
    return work.get_future().then(_finish)


