import torch
import torch.nn as nn
from functools import partial
from torchvision.models import resnet, resnet18
from torchvision.models.resnet import Bottleneck, BasicBlock, conv1x1


class SplitBatchNorm(nn.BatchNorm2d):
    """BatchNorm2d that normalizes ``num_splits`` sub-batches independently.

    In training mode (or when running statistics are not tracked) the batch
    of ``N`` samples is folded into ``N / num_splits`` "virtual" samples with
    ``C * num_splits`` channels, so every sub-batch gets its own batch
    statistics. Because of how the fold is laid out, sample ``n`` belongs to
    sub-batch ``n % num_splits``. The affine parameters are shared by all
    sub-batches. In evaluation mode the layer behaves like a regular
    ``BatchNorm2d`` using the running statistics.

    Args:
        num_features: number of channels ``C`` of the input.
        num_splits: number of sub-batches; must divide the batch size when
            the split path is taken.
        **kw: forwarded to ``nn.BatchNorm2d``.
    """

    def __init__(self, num_features, num_splits, **kw):
        super().__init__(num_features, **kw)
        self.num_splits = num_splits

    def forward(self, input):
        """Normalize ``input`` of shape ``(N, C, H, W)``.

        Raises:
            ValueError: if the split path is taken and ``N`` is not divisible
                by ``num_splits``.
        """
        N, C, H, W = input.shape
        if self.training or not self.track_running_stats:
            splits = self.num_splits
            if N % splits != 0:
                raise ValueError(
                    "batch size {} is not divisible by num_splits {}".format(N, splits))
            # One copy of the running statistics per sub-batch; batch_norm
            # updates these copies in place.
            running_mean_split = self.running_mean.repeat(splits)
            running_var_split = self.running_var.repeat(splits)
            # reshape (not view) so non-contiguous inputs, e.g. permuted or
            # channels-last tensors, are accepted as well.
            outcome = nn.functional.batch_norm(
                input.reshape(-1, C * splits, H, W), running_mean_split, running_var_split,
                self.weight.repeat(splits), self.bias.repeat(splits),
                True, self.momentum, self.eps).reshape(N, C, H, W)
            # Fold the per-split statistics back into the shared buffers by
            # averaging over the splits.
            self.running_mean.data.copy_(running_mean_split.view(splits, C).mean(dim=0))
            self.running_var.data.copy_(running_var_split.view(splits, C).mean(dim=0))
            return outcome
        else:
            return nn.functional.batch_norm(
                input, self.running_mean, self.running_var,
                self.weight, self.bias, False, self.momentum, self.eps)

class MultiBranchResNet(nn.Module):
    """ResNet trunk with three extra branches that fork off an earlier stage.

    The stem is a 3x3, stride-1 convolution without max-pooling, followed by
    the usual four residual stages ``layer1`` .. ``layer4``. In addition,
    three parallel branches are built that start from the output of
    ``layer{branch_depth}`` and end at 512 * ``block.expansion`` channels.
    There are four linear heads in ``fc``: head 0 for the trunk and heads
    1-3 for the branches.

    Args:
        block: residual block class (e.g. ``BasicBlock`` or ``Bottleneck``).
        layers: four integers, number of blocks in each trunk stage.
        num_classes: output size of every linear head.
        zero_init_residual: zero-initialize the last norm layer of every
            residual block.
        groups, width_per_group: forwarded to ``block``.
        replace_stride_with_dilation: ``None`` or three booleans; each one
            replaces the stride of ``layer2``/``layer3``/``layer4`` with a
            dilation.
        norm_layer: callable building a norm layer from a channel count;
            defaults to ``nn.BatchNorm2d``.
        branch_depth: 1, 2 or 3 - the trunk stage whose output feeds the
            branches. Any other value builds no branches.
        same_branch_size: if true, each branch replicates the remaining
            trunk stages; otherwise each branch is a single stage that
            jumps straight to 512 planes with a correspondingly larger
            stride.
        same_branch: if true, every non-zero ``branch`` passed to
            ``forward`` is routed to branch 1 and head 1.

    Raises:
        ValueError: if ``replace_stride_with_dilation`` does not have three
            elements.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, groups=1, width_per_group=64,
                 replace_stride_with_dilation=None, norm_layer=None, branch_depth=3, same_branch_size=False,
                 same_branch=False):
        super(MultiBranchResNet, self).__init__()

        self.same_branch = same_branch

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.branch_depth = branch_depth
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 64, layers[0])
        # entry_dilation[d] is the dilation in effect at the output of trunk
        # stage d, i.e. what a branch forking after stage d must start from.
        entry_dilation = {1: self.dilation}
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        entry_dilation[2] = self.dilation
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        entry_dilation[3] = self.dilation
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])

        self.branches = nn.ModuleList(
            self._make_branches(block, layers, replace_stride_with_dilation, entry_dilation, same_branch_size))

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.ModuleList([nn.Linear(512 * block.expansion, num_classes) for _ in range(4)])

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]

    def _make_branches(self, block, layers, replace_stride_with_dilation, entry_dilation, same_branch_size,
                       num_branches=3):
        """Build the parallel branches forking after ``layer{branch_depth}``.

        Returns a list of ``num_branches`` modules, or an empty list when
        ``branch_depth`` is not 1, 2 or 3.
        """
        depth = self.branch_depth
        if depth not in entry_dilation:
            return []

        stage_planes = (128, 256, 512)  # planes of layer2, layer3, layer4
        branches = []
        for _ in range(num_branches):
            # Restart from the channel count and dilation at the fork point.
            # The channel count includes block.expansion so that blocks with
            # expansion != 1 (Bottleneck) receive correctly sized inputs, and
            # the dilation is reset so every branch is built identically
            # instead of compounding the dilation left by earlier stages.
            self.inplanes = 64 * 2 ** (depth - 1) * block.expansion
            self.dilation = entry_dilation[depth]
            if same_branch_size:
                # Replicate the remaining trunk stages one by one.
                stages = [
                    self._make_layer(block, stage_planes[i], layers[i + 1], stride=2,
                                     dilate=replace_stride_with_dilation[i])
                    for i in range(depth - 1, 3)
                ]
                branch = stages[0] if len(stages) == 1 else nn.Sequential(*stages)
            else:
                # A single stage that covers the remaining downsampling
                # (stride 2, 4 or 8 for depth 3, 2 or 1).
                branch = self._make_layer(block, 512, layers[3], stride=2 ** (4 - depth),
                                          dilate=replace_stride_with_dilation[2])
            branches.append(branch)
        return branches

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False, stride_expansion=-1):
        """Build one residual stage of ``blocks`` blocks with ``planes`` planes.

        Updates ``self.inplanes`` (and ``self.dilation`` when ``dilate`` is
        set) as a side effect. When ``stride_expansion`` is non-negative the
        blocks after the first one are built with stride
        ``stride * stride_expansion`` instead of the block default.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Project the identity path when the shape changes.
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            if stride_expansion < 0:
                layers.append(block(self.inplanes, planes, groups=self.groups,
                                    base_width=self.base_width, dilation=self.dilation,
                                    norm_layer=norm_layer))
            else:
                layers.append(block(self.inplanes, planes, stride * stride_expansion, groups=self.groups,
                                    base_width=self.base_width, dilation=self.dilation,
                                    norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x, branch):
        """Run the stem and trunk, then the selected branch and its head.

        Args:
            x: input batch.
            branch: 0 selects ``layer4`` and head 0; any other value selects
                ``branches[branch - 1]`` and head ``branch`` (or branch 1 and
                head 1 when ``same_branch`` is set).

        Raises:
            ValueError: if ``branch`` is non-zero and ``branch_depth`` is not
                1, 2 or 3 (no branches exist in that case).
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        # The full trunk up to layer3 is always evaluated, whichever stage
        # the branches fork from.
        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)

        if branch == 0:
            x = self.layer4(x3)
        else:
            trunk_features = {1: x1, 2: x2, 3: x3}
            if self.branch_depth not in trunk_features:
                raise ValueError(
                    "branch_depth must be 1, 2 or 3 to use a branch, got {}".format(self.branch_depth))
            b = 1 if self.same_branch else branch
            x = self.branches[b - 1](trunk_features[self.branch_depth])

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        head = 1 if (self.same_branch and branch > 0) else branch
        return self.fc[head](x)

class MultiBranchResNetFully(nn.Module):
    """Multi-branch ResNet with the full 7x7 stride-2 stem and max-pooling.

    After the stem (7x7 stride-2 convolution, norm, ReLU, 3x3 stride-2
    max-pool) come the four residual stages ``layer1`` .. ``layer4``. Three
    additional parallel branches start from the output of
    ``layer{branch_depth}`` and end at 512 * ``block.expansion`` channels.
    There are four linear heads in ``fc``: head 0 for the trunk and heads
    1-3 for the branches.

    Args:
        block: residual block class (e.g. ``BasicBlock`` or ``Bottleneck``).
        layers: four integers, number of blocks in each trunk stage.
        num_classes: output size of every linear head.
        zero_init_residual: zero-initialize the last norm layer of every
            residual block.
        groups, width_per_group: forwarded to ``block``.
        replace_stride_with_dilation: ``None`` or three booleans; each one
            replaces the stride of ``layer2``/``layer3``/``layer4`` with a
            dilation.
        norm_layer: callable building a norm layer from a channel count;
            defaults to ``nn.BatchNorm2d``.
        branch_depth: 1, 2 or 3 - the trunk stage whose output feeds the
            branches. Any other value builds no branches.
        same_branch_size: if true, each branch replicates the remaining
            trunk stages; otherwise each branch is a single stage that
            jumps straight to 512 planes with a correspondingly larger
            stride.
        same_branch: if true, every non-zero ``branch`` passed to
            ``forward`` is routed to branch 1 and head 1.

    Raises:
        ValueError: if ``replace_stride_with_dilation`` does not have three
            elements.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, groups=1, width_per_group=64,
                 replace_stride_with_dilation=None, norm_layer=None, branch_depth=3, same_branch_size=False,
                 same_branch=False):
        super(MultiBranchResNetFully, self).__init__()

        self.same_branch = same_branch

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.branch_depth = branch_depth
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        # entry_dilation[d] is the dilation in effect at the output of trunk
        # stage d, i.e. what a branch forking after stage d must start from.
        entry_dilation = {1: self.dilation}
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        entry_dilation[2] = self.dilation
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        entry_dilation[3] = self.dilation
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])

        self.branches = nn.ModuleList(
            self._make_branches(block, layers, replace_stride_with_dilation, entry_dilation, same_branch_size))

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.ModuleList([nn.Linear(512 * block.expansion, num_classes) for _ in range(4)])

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]

    def _make_branches(self, block, layers, replace_stride_with_dilation, entry_dilation, same_branch_size,
                       num_branches=3):
        """Build the parallel branches forking after ``layer{branch_depth}``.

        Returns a list of ``num_branches`` modules, or an empty list when
        ``branch_depth`` is not 1, 2 or 3.
        """
        depth = self.branch_depth
        if depth not in entry_dilation:
            return []

        stage_planes = (128, 256, 512)  # planes of layer2, layer3, layer4
        branches = []
        for _ in range(num_branches):
            # Restart from the channel count and dilation at the fork point.
            # The channel count includes block.expansion so that blocks with
            # expansion != 1 (Bottleneck) receive correctly sized inputs, and
            # the dilation is reset so every branch is built identically
            # instead of compounding the dilation left by earlier stages.
            self.inplanes = 64 * 2 ** (depth - 1) * block.expansion
            self.dilation = entry_dilation[depth]
            if same_branch_size:
                # Replicate the remaining trunk stages one by one.
                stages = [
                    self._make_layer(block, stage_planes[i], layers[i + 1], stride=2,
                                     dilate=replace_stride_with_dilation[i])
                    for i in range(depth - 1, 3)
                ]
                branch = stages[0] if len(stages) == 1 else nn.Sequential(*stages)
            else:
                # A single stage that covers the remaining downsampling
                # (stride 2, 4 or 8 for depth 3, 2 or 1).
                branch = self._make_layer(block, 512, layers[3], stride=2 ** (4 - depth),
                                          dilate=replace_stride_with_dilation[2])
            branches.append(branch)
        return branches

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False, stride_expansion=-1):
        """Build one residual stage of ``blocks`` blocks with ``planes`` planes.

        Updates ``self.inplanes`` (and ``self.dilation`` when ``dilate`` is
        set) as a side effect. When ``stride_expansion`` is non-negative the
        blocks after the first one are built with stride
        ``stride * stride_expansion`` instead of the block default.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Project the identity path when the shape changes.
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            if stride_expansion < 0:
                layers.append(block(self.inplanes, planes, groups=self.groups,
                                    base_width=self.base_width, dilation=self.dilation,
                                    norm_layer=norm_layer))
            else:
                layers.append(block(self.inplanes, planes, stride * stride_expansion, groups=self.groups,
                                    base_width=self.base_width, dilation=self.dilation,
                                    norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x, branch):
        """Run the stem and trunk, then the selected branch and its head.

        Args:
            x: input batch.
            branch: 0 selects ``layer4`` and head 0; any other value selects
                ``branches[branch - 1]`` and head ``branch`` (or branch 1 and
                head 1 when ``same_branch`` is set).

        Raises:
            ValueError: if ``branch`` is non-zero and ``branch_depth`` is not
                1, 2 or 3 (no branches exist in that case).
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        # The full trunk up to layer3 is always evaluated, whichever stage
        # the branches fork from.
        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)

        if branch == 0:
            x = self.layer4(x3)
        else:
            trunk_features = {1: x1, 2: x2, 3: x3}
            if self.branch_depth not in trunk_features:
                raise ValueError(
                    "branch_depth must be 1, 2 or 3 to use a branch, got {}".format(self.branch_depth))
            b = 1 if self.same_branch else branch
            x = self.branches[b - 1](trunk_features[self.branch_depth])

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        head = 1 if (self.same_branch and branch > 0) else branch
        return self.fc[head](x)


class ModelBase(nn.Module):
    """
    Single-branch ResNet encoder assembled from a torchvision ResNet.

    The ResNet named by ``arch`` is built with ``feature_dim`` outputs and
    then flattened into one ``nn.Sequential`` with these changes:
    (i) conv1 is replaced by a 3x3, stride-1 convolution
    (ii) the max-pooling layer is dropped
    (iii) an ``nn.Flatten(1)`` is inserted in front of the linear layer

    ``SplitBatchNorm`` with ``bn_splits`` splits is used as the norm layer
    when ``bn_splits > 1``; otherwise plain ``nn.BatchNorm2d``.
    """
    def __init__(self, feature_dim=128, arch=None, bn_splits=16):
        super(ModelBase, self).__init__()

        # pick the normalisation layer
        if bn_splits > 1:
            norm_layer = partial(SplitBatchNorm, num_splits=bn_splits)
        else:
            norm_layer = nn.BatchNorm2d

        backbone = getattr(resnet, arch)(num_classes=feature_dim, norm_layer=norm_layer)

        stages = []
        for child_name, child in backbone.named_children():
            if child_name == 'conv1':
                # swap the stem convolution for the 3x3 / stride-1 variant
                child = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
            if isinstance(child, nn.MaxPool2d):
                continue
            if isinstance(child, nn.Linear):
                # the pooled feature map has to be flattened before the head
                stages.append(nn.Flatten(1))
            stages.append(child)

        self.net = nn.Sequential(*stages)

    def forward(self, x):
        # note: the returned features are not normalized here
        return self.net(x)


class MultiBranchModelBase(nn.Module):
    """
    Multi-branch ResNet-18 style encoder (``BasicBlock``, ``[2, 2, 2, 2]``).

    With ``fully=False`` the encoder is a ``MultiBranchResNet``. With
    ``fully=True`` it is a ``MultiBranchResNetFully`` whose weights are
    loaded (non-strictly) from torchvision's pretrained ``resnet18``.

    Args:
        feature_dim: output size of every head.
        arch: accepted for signature parity with ``ModelBase``; not used
            here, the architecture is fixed.
        bn_splits: number of splits for ``SplitBatchNorm``; plain
            ``nn.BatchNorm2d`` is used when it is not greater than 1.
        branch_depth: trunk stage after which the branches fork.
        same_branch: route every non-zero branch index to branch 1.
        fully: select the pretrained full-stem variant.
    """
    def __init__(self, feature_dim=128, arch=None, bn_splits=16, branch_depth=3, same_branch=False, fully=False):
        super(MultiBranchModelBase, self).__init__()

        # use split batchnorm
        norm_layer = partial(SplitBatchNorm, num_splits=bn_splits) if bn_splits > 1 else nn.BatchNorm2d
        if fully:
            self.net = MultiBranchResNetFully(BasicBlock, [2, 2, 2, 2], num_classes=feature_dim, norm_layer=norm_layer, branch_depth=branch_depth, same_branch=same_branch)
            pretrained = resnet18(pretrained=True)
            # strict=False: keys that exist on only one side (e.g. the
            # branches and the per-branch heads) are skipped.
            self.net.load_state_dict(pretrained.state_dict(), strict=False)
        else:
            self.net = MultiBranchResNet(BasicBlock, [2, 2, 2, 2], num_classes=feature_dim, norm_layer=norm_layer, branch_depth=branch_depth, same_branch=same_branch)

    def forward(self, x, branch=0):
        """Encode ``x`` through the selected branch (0 is the main trunk)."""
        x = self.net(x, branch)
        # note: not normalized here
        return x

    @torch.no_grad()
    def sync_weight(self):
        """Copy the parameters of ``net.layer4`` into every extra branch.

        Only parameters are copied; buffers such as batch-norm running
        statistics are left untouched. The head layers in ``net.fc`` are not
        synchronized.

        Raises:
            ValueError: if a branch does not have the same parameter layout
                as ``net.layer4`` (with the networks built by this class
                that is the case unless ``branch_depth == 3``).
        """
        source = list(self.net.layer4.parameters())
        for index, branch in enumerate(self.net.branches, start=1):
            target = list(branch.parameters())
            # Check the layout first so a mismatch never leaves a branch
            # half-overwritten.
            same_layout = len(source) == len(target) and all(
                src.shape == dst.shape for src, dst in zip(source, target))
            if not same_layout:
                raise ValueError(
                    "branch {} does not share the parameter layout of layer4; "
                    "cannot sync weights".format(index))
            for src, dst in zip(source, target):
                dst.data.copy_(src.data)


class ModelMoCo(nn.Module):
    """Momentum-contrast model with a query encoder, a key encoder and a queue.

    Args:
        dim: feature dimension of the encoders and of the queue entries.
        K: number of keys held in the queue.
        m: momentum used to update the key encoder from the query encoder.
        T: softmax temperature applied to the logits.
        arch, bn_splits: forwarded to the encoder constructors.
        symmetric: if true, the loss is computed in both directions
            (im1 -> im2 and im2 -> im1) and both key batches are enqueued.
        multi_branch: use ``MultiBranchModelBase`` encoders instead of
            ``ModelBase``.
        branch_depth, same_branch, fully: forwarded to
            ``MultiBranchModelBase`` when ``multi_branch`` is set.
        all_branch: additionally add the contrastive loss of branch 1.
    """

    def __init__(self, dim=128, K=4096, m=0.99, T=0.1, arch='resnet18', bn_splits=8, symmetric=True, multi_branch=False, branch_depth=3, all_branch=False, same_branch=False, fully=False):
        super(ModelMoCo, self).__init__()

        self.K = K
        self.m = m
        self.T = T
        self.symmetric = symmetric
        self.all_branch = all_branch
        # Remembered so contrastive_loss knows whether the encoders accept a
        # branch index.
        self.multi_branch = multi_branch

        # create the encoders
        if multi_branch:
            self.encoder_q = MultiBranchModelBase(feature_dim=dim, arch=arch, bn_splits=bn_splits, branch_depth=branch_depth, same_branch=same_branch, fully=fully)
            self.encoder_k = MultiBranchModelBase(feature_dim=dim, arch=arch, bn_splits=bn_splits, branch_depth=branch_depth, same_branch=same_branch, fully=fully)
        else:
            self.encoder_q = ModelBase(feature_dim=dim, arch=arch, bn_splits=bn_splits)
            self.encoder_k = ModelBase(feature_dim=dim, arch=arch, bn_splits=bn_splits)

        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)  # initialize
            param_k.requires_grad = False  # not update by gradient

        # create the queue: one unit-norm random key per column
        self.register_buffer("queue", nn.functional.normalize(torch.randn(dim, K), dim=0))

        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """
        Momentum update of the key encoder
        """
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        """Overwrite the oldest queue columns with ``keys`` and advance the pointer.

        Raises:
            ValueError: if the queue size ``K`` is not a multiple of the
                number of keys.
        """
        batch_size = keys.shape[0]

        ptr = int(self.queue_ptr)
        if self.K % batch_size != 0:  # for simplicity
            raise ValueError(
                "queue size K={} must be a multiple of the key batch size {}".format(self.K, batch_size))

        # replace the keys at ptr (dequeue and enqueue)
        self.queue[:, ptr:ptr + batch_size] = keys.t()  # transpose
        ptr = (ptr + batch_size) % self.K  # move pointer

        self.queue_ptr[0] = ptr

    @torch.no_grad()
    def _batch_shuffle_single_gpu(self, x):
        """
        Batch shuffle, for making use of BatchNorm.

        Returns the shuffled batch and the index that restores the order.
        """
        # random shuffle index, drawn on the CPU and moved to the device of
        # the input (works for CPU and CUDA tensors alike)
        idx_shuffle = torch.randperm(x.shape[0]).to(x.device)

        # index for restoring
        idx_unshuffle = torch.argsort(idx_shuffle)

        return x[idx_shuffle], idx_unshuffle

    @torch.no_grad()
    def _batch_unshuffle_single_gpu(self, x, idx_unshuffle):
        """
        Undo batch shuffle.
        """
        return x[idx_unshuffle]

    def _encode(self, encoder, images, branch):
        """Run ``encoder``, passing ``branch`` only to multi-branch encoders."""
        if self.multi_branch:
            return encoder(images, branch)
        return encoder(images)

    def contrastive_loss(self, im_q, im_k, branch):
        """Contrastive (InfoNCE-style) loss of ``im_q`` against ``im_k`` and the queue.

        Args:
            im_q: batch of query images.
            im_k: batch of key images.
            branch: encoder branch to use; ignored unless the model was
                built with ``multi_branch=True``.

        Returns:
            ``(loss, q, k)`` with the normalized query and key features.
        """
        # NOTE(review): the branch index is applied to both encoders, while
        # the queue is only ever filled with branch-0 keys by forward();
        # confirm this is the intended pairing.
        # compute query features
        q = self._encode(self.encoder_q, im_q, branch)  # queries: NxC
        q = nn.functional.normalize(q, dim=1)  # unit-normalize

        # compute key features
        with torch.no_grad():  # no gradient to keys
            # shuffle for making use of BN
            im_k_, idx_unshuffle = self._batch_shuffle_single_gpu(im_k)

            k = self._encode(self.encoder_k, im_k_, branch)  # keys: NxC
            k = nn.functional.normalize(k, dim=1)  # unit-normalize

            # undo shuffle
            k = self._batch_unshuffle_single_gpu(k, idx_unshuffle)

        # compute logits
        # Einstein sum is more intuitive
        # positive logits: Nx1
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        # negative logits: NxK
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])

        # logits: Nx(1+K), with temperature applied
        logits = torch.cat([l_pos, l_neg], dim=1) / self.T

        # labels: the positive key sits in column 0
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)

        loss = nn.functional.cross_entropy(logits, labels)

        return loss, q, k

    def forward(self, im1, im2):
        """
        Input:
            im1: first batch of augmented images
            im2: second batch of augmented images
        Output:
            loss
        """

        # update the key encoder
        with torch.no_grad():  # no gradient to keys
            self._momentum_update_key_encoder()

        # compute loss
        if self.symmetric:  # symmetric loss: both directions
            loss_12, q1, k2 = self.contrastive_loss(im1, im2, 0)
            loss_21, q2, k1 = self.contrastive_loss(im2, im1, 0)
            loss = loss_12 + loss_21
            k = torch.cat([k1, k2], dim=0)
            if self.all_branch:
                loss_121, _, _ = self.contrastive_loss(im1, im2, 1)
                loss_211, _, _ = self.contrastive_loss(im2, im1, 1)
                loss += loss_121 + loss_211
        else:  # asymmetric loss
            loss, _, k = self.contrastive_loss(im1, im2, 0)
            if self.all_branch:
                # add the branch-1 loss, mirroring the symmetric case
                loss_1, _, _ = self.contrastive_loss(im1, im2, 1)
                loss = loss + loss_1

        # only the branch-0 keys are enqueued
        self._dequeue_and_enqueue(k)

        return loss