import torch
import os
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.checkpoint import checkpoint
from math import prod


class STEBinary(torch.autograd.Function):
    """Sign binarization with a clipped straight-through gradient.

    Forward returns ``sign(x)`` element-wise (so exact zeros stay zero).
    Backward lets the incoming gradient through unchanged where the saved
    input lies strictly inside ``(-1, 1)`` and zeroes it everywhere else.
    """

    @staticmethod
    def forward(ctx, x):
        # The raw input is needed in backward to build the pass-through mask.
        ctx.save_for_backward(x)
        return torch.sign(x)

    @staticmethod
    def backward(ctx, grad_output):
        (saved,) = ctx.saved_tensors
        # Boolean mask of the open interval (-1, 1); both bounds are exclusive.
        inside_unit_interval = saved.gt(-1).logical_and(saved.lt(1))
        return grad_output.mul(inside_unit_interval)

class BinaryMoS(nn.Module):
    """
    Linear layer with sign-binarized weights and token-adaptive channel scales.

    The layer keeps a full-precision ``weight`` (and optional ``bias``) taken
    from an existing linear layer and binarizes the weight with ``STEBinary``
    on every forward pass. ``num_experts`` pairs of per-input-channel and
    per-output-channel scale vectors are stored; for each token a softmax
    gate (``gate_linear``) produces mixing weights, and the mixed scales are
    applied to the input before, and to the output after, the binary matmul:

        y = ((x * in_scale(x)) @ sign(W).T) * out_scale(x) [+ bias]

    Config keys read by ``__init__``:
        ``train_only_scale``: if truthy, ``weight`` and ``bias`` are frozen
            (``requires_grad=False``); the gate and scales stay trainable.
        ``num_experts``: number of scale pairs / gate outputs.
        ``scale_init``: if truthy, initialise the scales from a rank-1 SVD of
            ``|weight|``; otherwise initialise them to zeros.
    """

    def __init__(self, config, weight, bias):
        """
        Args:
            config: mapping providing ``train_only_scale``, ``num_experts``
                and ``scale_init`` (see class docstring).
            weight: 2-D tensor/parameter of shape (out_features, in_features).
            bias: 1-D tensor/parameter broadcastable to the output, or None.
        """
        super().__init__()
        trainable = not config['train_only_scale']
        self.weight = nn.Parameter(weight.data, requires_grad=trainable)
        if bias is not None:
            self.bias = nn.Parameter(bias.data, requires_grad=trainable)
        else:
            self.bias = None

        self.out_channel_shape = self.weight.shape[0]
        self.in_channel_shape = self.weight.shape[1]
        self.global_name = None
        self.hidden_dim = self.weight.shape[1]
        self.num_experts = config['num_experts']
        self.scale_init = config['scale_init']

        self.gate_linear = nn.Linear(self.hidden_dim, self.num_experts, bias=False, device=self.weight.device)

        if self.scale_init:
            # Rank-1 factorisation of |W|: with (u, s, v) the leading singular
            # triple, out = u * sqrt(s) and in = sqrt(s) * v, so that
            # outer(out, in) is the best rank-1 approximation of |W|.
            # Indexing the leading triple directly avoids materialising a
            # dense diag(S) matrix. Every expert starts from the same scales.
            U, S, Vh = torch.linalg.svd(weight.data.float().abs(), full_matrices=False)
            root_sigma = S[0].sqrt()
            out_channel_scale = (U[:, 0] * root_sigma).repeat(self.num_experts, 1)
            in_channel_scale = (root_sigma * Vh[0]).repeat(self.num_experts, 1)
        else:
            # Allocate on the weight's device so the scales live alongside
            # ``gate_linear`` and the weight without a later explicit move.
            # NOTE(review): with both scale tables at zero the forward output
            # is zero (plus bias) and the gradients w.r.t. the scales and the
            # gate vanish, so these are presumably placeholders meant to be
            # overwritten (e.g. by load_state_dict) - confirm.
            scale_device = self.weight.device
            in_channel_scale = torch.zeros(self.num_experts, self.in_channel_shape, device=scale_device)
            out_channel_scale = torch.zeros(self.num_experts, self.out_channel_shape, device=scale_device)

        self.register_parameter('in_channel_scale', nn.Parameter(in_channel_scale))
        self.register_parameter('out_channel_scale', nn.Parameter(out_channel_scale))

    def forward(self, x):
        """
        Apply the scaled binary linear map to ``x``.

        Args:
            x: tensor whose last dimension is the input feature size; all
                leading dimensions are treated as independent tokens.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``out_channel_shape``.
        """
        *leading_dims, hidden_dim = x.shape
        output_shape = (*leading_dims, self.out_channel_shape)
        # reshape (rather than view) so non-contiguous inputs, e.g. transposed
        # or sliced activations, are accepted instead of raising.
        x = x.reshape(-1, hidden_dim)

        # router_logits: (num_tokens, n_experts)
        router_logits = self.gate_linear(x)
        # Softmax is evaluated in float32, then cast back to the input dtype.
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights = routing_weights.to(x.dtype)

        # Per-token scales: gate-weighted mixture of the expert scale vectors.
        in_scale_expert = routing_weights.matmul(self.in_channel_scale)
        out_scale_expert = routing_weights.matmul(self.out_channel_scale)

        final_hidden_states = ((x * in_scale_expert) @ self.binarize().t()) * out_scale_expert
        if self.bias is not None:
            final_hidden_states = final_hidden_states + self.bias

        return final_hidden_states.reshape(output_shape)

    def binarize(self):
        """Return the weight passed through ``STEBinary`` (element-wise sign)."""
        # Autograd Functions are invoked on the class; instantiating one is
        # deprecated in PyTorch and emits a warning.
        return STEBinary.apply(self.weight)

    def to_regular_linear(self):
        """
        Build an ``nn.Linear`` exposing this layer's latent weight and bias.

        The returned layer's parameters share storage with this module's
        ``weight``/``bias`` tensors (no copy is made); the weight is the
        stored full-precision one, not its binarized form.
        """
        linear = nn.Linear(self.in_channel_shape, self.out_channel_shape, bias=self.bias is not None)
        linear.weight.data = self.weight.data
        if self.bias is not None:
            linear.bias.data = self.bias.data
        return linear

    def extra_repr(self):
        """Return the summary string shown inside the module's ``repr``."""
        return f'in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, bias={self.bias is not None}, num_experts={self.num_experts}'
