import torch

def psd_loss(x, y, psd=None):
    """
    Relative max-norm distance between the empirical power spectra of two datasets.

    The spectrum of each dataset is the batch mean of ``|fft2(sample)|**2``
    (channel 0 only), normalised by ``L**2`` where ``L = x.shape[-1]``.
    The returned value is ``max|psd_x - psd_y| / max(psd_y)``.

    :param x: (N,1,L,L) samples of the first process.
    :param y: (N,1,L,L) samples of the second process; ignored when ``psd`` is given.
    :param psd: optional (L,L) power spectrum imposed for the second process.
    :return: scalar tensor with the relative max-norm error.
    """
    side = x.shape[-1]

    def empirical_psd(samples):
        # 2-D FFT over the last two axes, keep channel 0, average power over the batch.
        spectrum = torch.fft.fft2(samples)[:, 0, ...]
        return torch.mean(torch.abs(spectrum) ** 2, dim=0) / side ** 2

    x_psd = empirical_psd(x)
    y_psd = empirical_psd(y) if psd is None else psd
    return torch.max(torch.abs(x_psd - y_psd)) / torch.max(y_psd)

def dtv_loss(x, y, n_bins=1000):
    """Computes the total-variation distance (DTV) between the pixel marginals of two processes.

    Both samples are histogrammed over the same symmetric range ``[-m, m]``,
    where ``m`` is the largest absolute value found in either sample. This
    guarantees that no sample of ``x`` or ``y`` falls outside the bins.
    The DTV is then ``0.5 * sum |p_x - p_y| * bin_width``, where ``p_x`` and
    ``p_y`` are the histogram densities.

    :param x: (N,1,W,H) samples of the first process (any shape is flattened).
    :param y: (N,1,W,H) samples of the second process (any shape is flattened).
    :param n_bins: number of histogram bins (``n_bins + 1`` edges).
    :return: scalar tensor in [0, 1].
    """
    dtype = torch.float32
    # torch.histogram is CPU-only; .cpu() is a no-op for CPU tensors.
    x = x.to(dtype).cpu()
    y = y.to(dtype).cpu()

    # The range must cover both samples. Otherwise out-of-range values are dropped
    # and density=True renormalises over the remainder, underestimating the DTV.
    m = float(torch.maximum(x.abs().max(), y.abs().max()))
    if m == 0.0:
        # Both samples are identically zero, so the marginals coincide.
        return torch.tensor(0.0, dtype=dtype)

    # n_bins + 1 edges give exactly n_bins bins of width 2m / n_bins.
    bins = torch.linspace(-m, m, n_bins + 1, dtype=dtype)
    delta = 2 * m / n_bins

    x_h = torch.histogram(x, bins=bins, density=True)[0]
    y_h = torch.histogram(y, bins=bins, density=True)[0]
    ret = torch.abs(x_h - y_h).sum()
    return 0.5 * delta * ret

def total_loss(x, y, psd=None, n_bins=1000):
    """
    Combined loss: the power-spectrum error plus the marginal DTV.

    Order of magnitude: for two (10k, 1, 16, 16) gaussian samples, is approx 0.04.

    :param x: (N,1,W,H) samples of the first process.
    :param y: (N,1,W,H) samples of the second process.
    :param psd: optional power spectrum imposed for the second process (forwarded to psd_loss).
    :param n_bins: number of histogram bins (forwarded to dtv_loss).
    :return: psd_loss + dtv_loss
    """
    spectral_term = psd_loss(x, y, psd=psd)
    marginal_term = dtv_loss(x, y, n_bins)
    return spectral_term + marginal_term

def dkl_gaussian(psd1, psd2, fourier_wavelets=False):
    """
    KL divergence KL(N(0, S1) || N(0, S2)) between two centred Gaussians whose
    covariance matrices S1, S2 are co-diagonalized.

    The result is ``0.5 * (tr(S2^-1 S1) - k + log det S2 - log det S1)``, where
    ``k`` is the dimension.

    psd1 and psd2 have the same size. They should contain the eigenvalues of two covariance
    matrices which are co-diagonalized.
    If option fourier_wavelets==True, the psds must have shape (c,c,L,L). In that
    case each of the L*L frequencies carries a c x c block, the blocks are handled
    in complex64, and k = c*L*L.

    :param psd1: spectrum of S1.
    :param psd2: spectrum of S2.
    :param fourier_wavelets: treat the inputs as (c,c,L,L) per-frequency blocks.
    :return: scalar (real) tensor.
    """
    assert psd1.numel() == psd2.numel(), "size mismatch"

    if fourier_wavelets:
        channels = psd1.shape[0]
        ctype = torch.complex64
        # (c, c, L, L) -> (L, L, c, c): one c x c matrix per frequency.
        s1 = psd1.permute(2, 3, 0, 1).to(ctype)
        s2 = psd2.permute(2, 3, 0, 1).to(ctype)

        # Sum over frequencies of tr(S2^-1 S1) for each block.
        aux = torch.matmul(torch.inverse(s2), s1)
        term2 = torch.real(aux.diagonal(dim1=-2, dim2=-1).sum())

        # Sum over frequencies of log det S2 - log det S1 (real part of the complex log).
        term1 = torch.real(
            torch.log(torch.det(s2)).sum() - torch.log(torch.det(s1)).sum()
        )
        # Total dimension: c * L * L.
        k = psd1.numel() / channels
    else:
        k = psd1.numel()
        # Diagonal case: the trace and log-determinants reduce to elementwise sums.
        term2 = torch.abs(psd1 / psd2).sum()
        term1 = torch.log(psd2).sum() - torch.log(psd1).sum()

    return 0.5 * (term1 + term2 - k)

def operator_loss(a, b, fourier_wavelets=False):
    """
    Max-norm distance between two operators given by their spectra.

    Without fourier_wavelets, this returns the largest absolute entry of ``a - b``.

    With fourier_wavelets, all axes of ``a - b`` are reversed. The spectral norm
    (largest singular value) is then taken over the last two axes of the result,
    which are the original first two axes, transposed. The maximum of these
    per-matrix norms is returned.
    Presumably ``a``/``b`` are (c,c,L,L) blocks as in dkl_gaussian; TODO confirm.

    :param a: spectrum of the first operator.
    :param b: spectrum of the second operator (same shape as ``a``).
    :param fourier_wavelets: use per-frequency spectral norms instead of entrywise abs.
    :return: scalar tensor.
    """
    diff = a - b
    if not fourier_wavelets:
        return torch.max(torch.abs(diff))
    # Transposing each block does not change its singular values.
    per_block_norms = torch.linalg.matrix_norm(diff.permute(3, 2, 1, 0), ord=2)
    return torch.max(per_block_norms)












