import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy as np
from torch.nn.parameter import Parameter
from neuralalgo.common.consts import DEVICE, NONLINEARITIES
from neuralalgo.common.utils import diff_soft_threshold, weights_init


class SigmoidDiagEnergy(nn.Module):
    """Map an input ``u`` to a spectrum in ``[mu, L]`` and a vector ``b``.

    Two independent MLPs read ``u``:

    * ``layers_b`` produces ``b`` with ``d`` output features;
    * ``layers_w`` produces ``d - 2`` logits which are softmax-normalised and
      mapped affinely into ``[mu, L]``.

    The returned spectrum is ``[L, e_1, ..., e_{d-2}, mu]`` along the last
    dimension, so its first entry is always exactly ``L`` and its last entry is
    always exactly ``mu``.

    :param args: config object; ``d``, ``dim_in``, ``mu``, ``L`` and
        ``temperature`` are read from it.
    :param hidden_dims_eigs: '-'-separated hidden widths of the diagonal MLP
        (e.g. ``"64-64"``); non-positive widths are skipped, so ``"0"`` yields a
        single linear layer.
    :param hidden_dims_b: same format, for the ``b`` MLP.
    :param activation: key into ``NONLINEARITIES`` applied between layers, or
        ``'none'`` for purely linear layers.
    :param bias_init: forwarded to ``weights_init``.
    """

    def __init__(self, args, hidden_dims_eigs, hidden_dims_b, activation='relu', bias_init='zero'):
        super(SigmoidDiagEnergy, self).__init__()

        self.d = args.d
        self.dim_in = args.dim_in
        self.mu = args.mu
        self.L = args.L
        self.temp = args.temperature  # stored on the module; forward does not use it
        # 'none' used to leave act_fcn unset, which crashed forward whenever a
        # hidden layer existed; an identity keeps the network purely linear.
        if activation != 'none':
            self.act_fcn = NONLINEARITIES[activation]
        else:
            self.act_fcn = nn.Identity()

        # generate b: dim_in -> hidden widths -> d
        layers = []
        hidden_dims_b = tuple(map(int, hidden_dims_b.split("-")))
        prev_size = args.dim_in
        for h in hidden_dims_b:
            if h > 0:
                layers.append(nn.Linear(prev_size, h, bias=True))
                prev_size = h
        layers.append(nn.Linear(prev_size, self.d, bias=True))
        self.layers_b = nn.ModuleList(layers)

        # generate diagonals: dim_in -> hidden widths -> d - 2 (the two extreme
        # eigenvalues are pinned to L and mu in forward)
        layers = []
        hidden_dims_eigs = tuple(map(int, hidden_dims_eigs.split("-")))
        prev_size = args.dim_in
        for h in hidden_dims_eigs:
            if h > 0:
                layers.append(nn.Linear(prev_size, h, bias=True))
                prev_size = h
        layers.append(nn.Linear(prev_size, self.d - 2, bias=True))
        self.layers_w = nn.ModuleList(layers)

        weights_init(self, bias=bias_init)

    def _run_mlp(self, layers, x):
        """Apply ``layers`` in order with ``self.act_fcn`` between (not after) them."""
        last = len(layers) - 1
        for i, layer in enumerate(layers):
            x = layer(x)
            if i < last:
                x = self.act_fcn(x)
        return x

    def forward(self, u):
        """Return ``(eigens, b)`` for input ``u``.

        ``eigens`` has ``d`` entries along the last dimension: ``L``, then the
        ``d - 2`` softmax-derived values in ``[mu, L]``, then ``mu``.
        """
        b = self._run_mlp(self.layers_b, u)
        e = self._run_mlp(self.layers_w, u)

        # softmax weights are positive and sum to 1, so mu + p * (L - mu) lies in [mu, L]
        e = F.softmax(e, dim=-1) * (self.L - self.mu) + self.mu

        # pin the extreme eigenvalues; build them on e's device/dtype so the
        # concat works wherever the module lives (no global DEVICE needed)
        edge_shape = (*e.shape[:-1], 1)
        largest = e.new_full(edge_shape, self.L)
        smallest = e.new_full(edge_shape, self.mu)
        eigens = torch.cat([largest, e, smallest], dim=-1)

        return eigens, b


def qeq(q, e):
    """Return ``Q diag(e) Q^T`` for (batched) matrices ``q`` and vectors ``e``.

    :param q: tensor of shape ``(..., d, d)``.
    :param e: tensor of shape ``(..., d)`` whose leading dims broadcast with
        ``q``'s; the original ``(batch, d)`` form and unbatched ``(d,)`` both work.
    :return: tensor of shape ``(..., d, d)``.
    """
    # scaling column j of q by e[..., j] equals q @ diag(e) without building diag(e)
    qe = q * e.unsqueeze(-2)
    return torch.matmul(qe, q.transpose(-1, -2))


def qeq_inv(q, e):
    """Return ``Q diag(1/e) Q^T`` for (batched) matrices ``q`` and vectors ``e``.

    When ``q`` is orthogonal this is the inverse of ``Q diag(e) Q^T``. Zero
    entries in ``e`` produce infinities, as with plain division.

    :param q: tensor of shape ``(..., d, d)``.
    :param e: tensor of shape ``(..., d)`` whose leading dims broadcast with
        ``q``'s; the original ``(batch, d)`` form and unbatched ``(d,)`` both work.
    :return: tensor of shape ``(..., d, d)``.
    """
    e_inv = 1 / e
    # scale column j of q by 1 / e[..., j], i.e. q @ diag(1/e)
    qe_inv = q * e_inv.unsqueeze(-2)
    return torch.matmul(qe_inv, q.transpose(-1, -2))

###

###



class DiagEnergyNet(nn.Module):
    """Two MLPs on ``u``: one predicting a diagonal ``e``, one predicting ``b``.

    Both outputs have ``d`` features. ``forward`` returns the raw (unscaled)
    output of the diagonal MLP; ``mu`` and ``L`` are only stored on the module
    and are not applied here (subclasses may use them).

    :param d: number of output features of both MLPs.
    :param dim_in: number of input features of ``u``.
    :param hidden_dims_eigs: '-'-separated hidden widths of the diagonal MLP.
    :param hidden_dims_b: '-'-separated hidden widths of the ``b`` MLP.
    :param activation: key into ``NONLINEARITIES`` applied between layers, or
        ``'none'`` for purely linear layers.
    :param mu: stored as ``self.mu``.
    :param L: stored as ``self.L``.
    """

    def __init__(self, d, dim_in, hidden_dims_eigs, hidden_dims_b, activation='relu', mu=None, L=None):
        super(DiagEnergyNet, self).__init__()

        self.d = d
        self.dim_in = dim_in
        self.mu = mu
        self.L = L
        # 'none' used to leave act_fcn unset; since every hidden_dims string
        # yields at least one hidden layer here, forward always crashed.
        if activation != 'none':
            self.act_fcn = NONLINEARITIES[activation]
        else:
            self.act_fcn = nn.Identity()

        # generate b: dim_in -> hidden widths -> d
        layers = []
        hidden_dims_b = tuple(map(int, hidden_dims_b.split("-")))
        prev_size = dim_in
        for h in hidden_dims_b:
            layers.append(nn.Linear(prev_size, h, bias=True))
            prev_size = h
        layers.append(nn.Linear(prev_size, self.d, bias=True))
        self.layers_b = nn.ModuleList(layers)

        # generate diagonals: dim_in -> hidden widths -> d
        layers = []
        hidden_dims_eigs = tuple(map(int, hidden_dims_eigs.split("-")))
        prev_size = dim_in
        for h in hidden_dims_eigs:
            layers.append(nn.Linear(prev_size, h, bias=True))
            prev_size = h
        layers.append(nn.Linear(prev_size, self.d, bias=True))
        self.layers_w = nn.ModuleList(layers)

        weights_init(self)

    def _run_mlp(self, layers, x):
        """Apply ``layers`` in order with ``self.act_fcn`` between (not after) them."""
        last = len(layers) - 1
        for i, layer in enumerate(layers):
            x = layer(x)
            if i < last:
                x = self.act_fcn(x)
        return x

    def forward(self, u):
        """Return ``(e, b)``: the raw diagonal-MLP output and the b-MLP output.

        No rescaling to ``[mu, L]`` happens here; the previous min/max
        computation was never used and has been removed.
        """
        b = self._run_mlp(self.layers_b, u)
        e = self._run_mlp(self.layers_w, u)
        return e, b


class EnergyNetSigmoidDiag(DiagEnergyNet):
    """``DiagEnergyNet`` whose diagonal is squashed into ``[mu, L]`` via softmax.

    The two MLPs are evaluated by ``DiagEnergyNet.forward``, which returns the
    raw diagonal output; this subclass only post-processes that diagonal.
    """

    def __init__(self, d, dim_in, hidden_dims_eigs, hidden_dims_b, activation='relu', mu=None, L=None):
        super().__init__(d, dim_in, hidden_dims_eigs, hidden_dims_b, activation, mu, L)

    def forward(self, u):
        """Return ``(e, b)`` where every entry of ``e`` lies in ``[mu, L]``."""
        raw_diag, b = super().forward(u)
        # softmax weights are positive and sum to one along the last dim, so
        # each entry becomes mu + p * (L - mu)
        weights = F.softmax(raw_diag, dim=-1)
        return self.mu + weights * (self.L - self.mu), b



class EnergyNetGivenQ(nn.Module):
    """Build ``W = Q diag(e) Q^T`` from a supplied ``q`` and a predicted spectrum.

    Two bias-free MLPs read ``u``: one predicts ``b`` (``d`` features), the other
    predicts ``d`` diagonal values that are min-max rescaled per row so the
    smallest becomes ``mu`` and the largest ``L``.

    :param d: number of output features of both MLPs.
    :param dim_in: number of input features of ``u``.
    :param hidden_dims_eigs: '-'-separated hidden widths of the diagonal MLP.
    :param hidden_dims_b: '-'-separated hidden widths of the ``b`` MLP.
    :param activation: key into ``NONLINEARITIES`` applied between layers, or
        ``'none'`` for purely linear layers.
    :param mu: target smallest diagonal value (must be set before ``forward``).
    :param L: target largest diagonal value (must be set before ``forward``).
    """

    def __init__(self, d, dim_in, hidden_dims_eigs, hidden_dims_b, activation='relu', mu=None, L=None):
        super(EnergyNetGivenQ, self).__init__()

        self.d = d
        self.dim_in = dim_in
        self.mu = mu
        self.L = L
        # 'none' used to leave act_fcn unset; since every hidden_dims string
        # yields at least one hidden layer here, forward always crashed.
        if activation != 'none':
            self.act_fcn = NONLINEARITIES[activation]
        else:
            self.act_fcn = nn.Identity()

        # generate b: dim_in -> hidden widths -> d, no biases
        layers = []
        hidden_dims_b = tuple(map(int, hidden_dims_b.split("-")))
        prev_size = dim_in
        for h in hidden_dims_b:
            layers.append(nn.Linear(prev_size, h, bias=False))
            prev_size = h
        layers.append(nn.Linear(prev_size, self.d, bias=False))
        self.layers_b = nn.ModuleList(layers)

        # generate diagonals: dim_in -> hidden widths -> d, no biases
        layers = []
        hidden_dims_eigs = tuple(map(int, hidden_dims_eigs.split("-")))
        prev_size = dim_in
        for h in hidden_dims_eigs:
            layers.append(nn.Linear(prev_size, h, bias=False))
            prev_size = h
        layers.append(nn.Linear(prev_size, self.d, bias=False))
        self.layers_w = nn.ModuleList(layers)

        weights_init(self)

    def _run_mlp(self, layers, x):
        """Apply ``layers`` in order with ``self.act_fcn`` between (not after) them."""
        last = len(layers) - 1
        for i, layer in enumerate(layers):
            x = layer(x)
            if i < last:
                x = self.act_fcn(x)
        return x

    def forward(self, u, q, return_flag=False):
        """Return ``(w, b)`` or, with ``return_flag``, ``(w, b, flag)``.

        :param u: input of shape ``(batch, dim_in)``.
        :param q: tensor broadcastable against ``(batch, 1, d)``; presumably a
            ``(batch, d, d)`` eigenvector basis — TODO confirm with callers.
        :return: ``w = q diag(e) q^T``, ``b``, and ``flag`` which is 1 when some
            row of the predicted diagonal was constant (see below), else 0.
        """
        batch = u.shape[0]
        b = self._run_mlp(self.layers_b, u)
        e = self._run_mlp(self.layers_w, u)

        # min-max rescale each row: smallest entry -> mu, largest -> L
        smallest = torch.min(e, dim=-1)[0].view(batch, 1)
        largest = torch.max(e, dim=-1)[0].view(batch, 1)
        span = largest - smallest
        flag = 0
        if (largest == smallest).any():
            # a constant row would give 0/0; a tiny epsilon on every row's range
            # maps such rows to mu instead.
            # NOTE(review): in a constant row the gradient w.r.t. e then scales
            # like (L - mu) / 1e-14.
            flag = 1
            span = span + 1e-14
        e = (e - smallest) * (self.L - self.mu) / span + self.mu

        # QeQ': scale column j of q by e[:, j], then multiply by q^T
        qe = q * e.view(batch, 1, self.d)
        w = torch.matmul(qe, q.transpose(-1, -2))

        if return_flag:
            return w, b, flag
        return w, b


class EnergyNetSigmoid(EnergyNetGivenQ):
    """``EnergyNetGivenQ`` variant that maps the diagonal into ``[mu, L]`` via softmax.

    Instead of min-max rescaling, the diagonal logits are softmax-normalised and
    mapped affinely, so no degenerate-range fallback is needed and the returned
    flag is always 0.
    """

    def __init__(self, d, dim_in, hidden_dims_eigs, hidden_dims_b, activation='relu', mu=None, L=None):
        super().__init__(d, dim_in, hidden_dims_eigs, hidden_dims_b, activation, mu, L)

    def forward(self, u, q, return_flag=False):
        """Return ``(w, b)`` or ``(w, b, 0)`` where ``w = q diag(e) q^T``."""
        # b-MLP: activation after every layer except the last
        *hidden_b, head_b = self.layers_b
        b = u
        for layer in hidden_b:
            b = self.act_fcn(layer(b))
        b = head_b(b)

        # diagonal MLP, same layout
        *hidden_w, head_w = self.layers_w
        diag = u
        for layer in hidden_w:
            diag = self.act_fcn(layer(diag))
        diag = head_w(diag)

        # positive softmax weights summing to one -> entries mu + p * (L - mu)
        diag = F.softmax(diag, dim=-1) * (self.L - self.mu) + self.mu

        # QeQ': scale q's columns by the diagonal, then multiply by q^T
        scaled_q = q * diag.view(u.shape[0], 1, self.d)
        w = torch.matmul(scaled_q, q.transpose(-1, -2))

        return (w, b, 0) if return_flag else (w, b)



class EnergyNet(nn.Module):
    """Predict a PSD matrix ``W`` (optionally with spectrum in ``[mu, L]``) and ``b``.

    Two MLPs read ``u`` (``d`` features): one outputs ``b`` (``d`` features), the
    other ``d * d`` values reshaped to ``M`` and turned into ``W = M M^T``.
    When ``mu`` is set, ``W`` is shifted and scaled so its smallest eigenvalue
    becomes ``mu`` and its largest ``L``.

    :param d: feature dimension of ``u``, ``b`` and of the square ``W``.
    :param hidden_dims_w: '-'-separated hidden widths of the ``W`` MLP.
    :param hidden_dims_b: '-'-separated hidden widths of the ``b`` MLP.
    :param activation: key into ``NONLINEARITIES`` applied between layers, or
        ``'none'`` for purely linear layers.
    :param mu: target smallest eigenvalue, or ``None`` to skip rescaling.
    :param L: target largest eigenvalue (required when ``mu`` is set).
    """

    def __init__(self, d, hidden_dims_w, hidden_dims_b, activation='relu', mu=None, L=None):
        super(EnergyNet, self).__init__()

        self.d = d
        self.mu = mu
        self.L = L
        # 'none' used to leave act_fcn unset; since every hidden_dims string
        # yields at least one hidden layer here, forward always crashed.
        if activation != 'none':
            self.act_fcn = NONLINEARITIES[activation]
        else:
            self.act_fcn = nn.Identity()

        # generate b: d -> hidden widths -> d
        layers = []
        hidden_dims_b = tuple(map(int, hidden_dims_b.split("-")))
        prev_size = self.d
        for h in hidden_dims_b:
            layers.append(nn.Linear(prev_size, h))
            prev_size = h
        layers.append(nn.Linear(prev_size, self.d))
        self.layers_b = nn.ModuleList(layers)

        # generate W: d -> hidden widths -> d*d
        layers = []
        hidden_dims_w = tuple(map(int, hidden_dims_w.split("-")))
        prev_size = self.d
        for h in hidden_dims_w:
            layers.append(nn.Linear(prev_size, h))
            prev_size = h
        layers.append(nn.Linear(prev_size, self.d * self.d))
        self.layers_w = nn.ModuleList(layers)

        weights_init(self)

    def _run_mlp(self, layers, x):
        """Apply ``layers`` in order with ``self.act_fcn`` between (not after) them."""
        last = len(layers) - 1
        for i, layer in enumerate(layers):
            x = layer(x)
            if i < last:
                x = self.act_fcn(x)
        return x

    def forward(self, u):
        """Return ``(w_matrix, b)`` with ``w_matrix`` of shape ``(batch, d, d)``."""
        batch = u.shape[0]
        b = self._run_mlp(self.layers_b, u)
        w = self._run_mlp(self.layers_w, u)

        w_matrix = w.view(batch, self.d, self.d)
        # M M^T is symmetric positive semi-definite
        w_matrix = torch.matmul(w_matrix, w_matrix.transpose(-1, -2))

        if self.mu is not None:
            # The extreme eigenvalues only set the affine map and are treated as
            # constants. torch.symeig is deprecated/removed in recent PyTorch;
            # eigvalsh is its replacement. The eigendecomposition stays on CPU
            # as before.
            with torch.no_grad():
                eigs = torch.linalg.eigvalsh(w_matrix.detach().cpu())
                smallest = eigs.min(dim=-1)[0].to(w_matrix.device)
                largest = eigs.max(dim=-1)[0].to(w_matrix.device)
                span = largest - smallest
                # a degenerate spectrum (e.g. d == 1) used to divide by zero and
                # produce NaNs; with span 1 such matrices collapse to ~mu * I
                span = torch.where(span > 0, span, torch.ones_like(span))
            eye = torch.eye(self.d, dtype=w_matrix.dtype, device=w_matrix.device)
            c = (self.L - self.mu) / span
            # (W - smallest I) * c + mu I, written out-of-place
            w_matrix = (w_matrix - smallest.view(batch, 1, 1) * eye) * c.view(batch, 1, 1) + self.mu * eye

        return w_matrix, b


class DebugModel(nn.Module):
    """Trivial stand-in model: returns the negated starting point.

    It owns a single scalar parameter ``s``; the output adds ``s * 0.0`` so it
    stays connected to ``s`` in the autograd graph without changing its value.

    :param k: stored as ``self.k``; not used by ``forward``.
    """

    def __init__(self, k):
        super().__init__()
        self.k = k
        self.s = Parameter(torch.tensor(1e-6).to(DEVICE))

    def forward(self, w, b, return_all=False, x_init=None):
        """Return ``-x_init`` (or ``-0`` shaped like ``b``); ``w`` and ``return_all`` are ignored."""
        if x_init is None:
            start = torch.zeros(b.shape, device=DEVICE)
        else:
            start = x_init.clone()
        zero_link = self.s * 0.0
        return -start + zero_link


class ISTA(nn.Module):
    """Unrolled ISTA with learnable thresholds and step-size denominators.

    Each of the ``depth`` iterations computes
    ``z = x - ((x A^T - y) A) / alpha`` followed by
    ``x = diff_soft_threshold(theta, z, k)``.
    """

    def __init__(self, args, A, depth, recurrent, theta_init, alpha_init):
        """
        :param args: config object; only ``args.temp`` is read (passed as the
            third argument of ``diff_soft_threshold``).
        :param A: m*n matrix (tensor); observations are modelled as ``x A^T``.
        :param depth: number of unrolled iterations.
        :param recurrent: if True, one ``theta``/``alpha`` pair is shared by all
            iterations; otherwise each iteration gets its own pair.
        :param theta_init: initial threshold tensor.
        :param alpha_init: initial step-size denominator tensor.
        """
        super(ISTA, self).__init__()
        self.A = A
        self.m, self.n = A.shape
        self.depth = depth
        self.recurrent = recurrent
        self.k = args.temp

        # Parameter(t) shares t's storage; without a copy every untied layer
        # would alias the same memory (and the caller's tensor), so they would
        # silently stay tied.
        if self.recurrent:
            self.theta = Parameter(theta_init.detach().clone())
            self.alpha = Parameter(alpha_init.detach().clone())
        else:
            self.theta_t = nn.ParameterList(
                [Parameter(theta_init.detach().clone()) for _ in range(self.depth)])
            self.alpha_t = nn.ParameterList(
                [Parameter(alpha_init.detach().clone()) for _ in range(self.depth)])

    def forward(self, y, x0=None):
        """Run ``depth`` ISTA iterations from ``x0`` (zeros if ``None``).

        :param y: observations of shape ``(batch, m)``.
        :param x0: optional start of shape ``(batch, n)``; it is cloned, not modified.
        :return: list of ``depth + 1`` iterates; the first is the starting point and
            the last is the final estimate.
        """
        if x0 is None:
            # match y's dtype/device instead of a global default
            xh = y.new_zeros((y.shape[0], self.n))
        else:
            xh = x0.clone()

        xhs_ = [xh]
        for t in range(self.depth):
            # the untied branch used to call undefined linear_b_t/linear_x_t
            # layers; it now runs the same ISTA step with its per-layer parameters
            theta = self.theta if self.recurrent else self.theta_t[t]
            alpha = self.alpha if self.recurrent else self.alpha_t[t]
            # gradient of 0.5 * ||x A^T - y||^2 with respect to x
            g = torch.mm(torch.mm(xh, self.A.t()) - y, self.A)
            z = xh - g / alpha
            xh = diff_soft_threshold(theta, z, self.k)
            xhs_.append(xh)

        return xhs_
