import numpy as np
import torch
from numpy.fft import fft, fft2
from scipy.signal import convolve2d
import scipy

def fold2d(x):
    """
    Fold the last two dimensions in half and average the four quadrants.

    The array is split at ``L // 2`` along both of its last two axes and
    the four resulting quadrants (top-left, bottom-left, top-right,
    bottom-right) are summed and multiplied by 0.25.

    input : (B,C,L,L) with L even (odd L gives mismatched quadrant shapes)
    output : (B,C,L/2,L/2)
    """
    half = x.shape[-1] // 2
    low, high = slice(None, half), slice(half, None)

    # Accumulate in the same order as ((TL + BL) + TR) + BR; out-of-place
    # addition so that ``x`` (of which the slices are views) is never mutated.
    total = x[..., low, low]
    for rows, cols in ((high, low), (low, high), (high, high)):
        total = total + x[..., rows, cols]
    return 0.25 * total

def translate(data, n1, n2):
    """
    Translate an array by (n1, n2) along its last two dimensions with
    periodic boundary conditions.

    ``result[..., i, j] == data[..., (i - n1) % L1, (j - n2) % L2]``: the
    content moves by ``n1`` rows and ``n2`` columns and wraps around.

    Args:
        data: ndarray of size (..., L, L)
        n1 (int): shift along the second-to-last axis. May be negative or
            at least the axis length; it is taken modulo the length.
        n2 (int): shift along the last axis, same conventions as ``n1``.

    Returns:
        ndarray of size (..., L, L). Always a new array; ``data`` is not
        modified.
    """
    # np.roll performs exactly the periodic shift the previous slice-copy
    # code produced for 0 <= n < L. It also handles negative shifts (which
    # were silently ignored before) and shifts >= L (which used to return
    # the input unchanged instead of wrapping).
    return np.roll(data, shift=(n1, n2), axis=(-2, -1))


def cut_off_filter(filter, cut_off):
    """
    Zero every entry of ``filter`` outside a low-index band on the last two axes.

    Along each of the last two axes, index ``k`` is kept when
    ``k <= cut_off`` or ``k >= L - cut_off``; every other entry is set to
    zero. For a filter stored in NumPy's FFT ordering this keeps the
    frequencies ``-cut_off .. cut_off`` (presumably the intended use).

    :param filter: np array of size (...,L,L)
    :param cut_off: non-negative int
    :return: np array of size (..., L,L); ``filter`` itself is not modified.
        If ``2*cut_off + 1 >= L`` nothing is zeroed and a copy is returned.
        If ``cut_off == 0`` only ``[..., 0, 0]`` is kept, and the result has a
        floating dtype: float64 for real input, complex for complex input.
    """
    filter = filter.copy()
    L = filter.shape[-1]

    if cut_off == 0:
        # Promote to at least float64 (as before), but keep complex inputs
        # complex: a plain float64 buffer would discard their imaginary part.
        result = np.zeros(filter.shape,
                          dtype=np.result_type(filter.dtype, np.float64))
        result[..., 0, 0] = filter[..., 0, 0]
        return result
    if 2 * cut_off + 1 >= L:
        return filter

    # Indices cut_off+1 .. L-cut_off-1 form the discarded band. Zeroing those
    # rows and those columns clears exactly the same entries as the five
    # separate rectangles the earlier implementation zeroed.
    high = slice(cut_off + 1, L - cut_off)
    filter[..., high, :] = 0.
    filter[..., :, high] = 0.
    return filter

def multiply_filters(filter1, filter2, filter3):
    """
    Combine three filters by multiplying their spectra.

    Each filter is transformed with an orthonormal 2D FFT over its last two
    axes. The three spectra are multiplied pointwise, the product is
    transformed back with the orthonormal inverse FFT, and only the real
    part is returned.
    """
    spectra = [np.fft.fft2(f, norm='ortho') for f in (filter1, filter2, filter3)]
    # Same association as before: (F1 * F2) * F3.
    product = (spectra[0] * spectra[1]) * spectra[2]
    return np.fft.ifft2(product, norm='ortho').real

def scharr_convolution(x):
    """
    Smooth each sample's first channel with a 3x3 uniform kernel.

    Despite the name, the kernel is a 3x3 mean (every weight 1/9), not a
    Scharr gradient kernel. Borders use symmetric padding and the output
    keeps the input's spatial size. The output has ``x``'s dtype; channels
    other than 0, if any, are left at zero.

    x is a (n, 1, L, L) arr
    output : a (n,1,L,L) arr
    """
    kernel = np.full((3, 3), 1 / 9)
    smoothed = np.zeros_like(x)
    for idx, image in enumerate(x[:, 0]):
        smoothed[idx, 0] = convolve2d(image, kernel, boundary='symm', mode='same').real
    return smoothed


#################################################################
#        various utilities for the hamiltonians
#################################################################


def compute_stationary_covariance(phi):
    """
    Periodic autocovariance of each sample's first channel, via the FFT.

    The squared modulus of the 2D FFT of ``phi[:, 0]`` is inverse
    transformed. The real part is divided by ``L**2``, where L is the size
    of the last axis.

    input: torch tensor (n,1,L,L)
    output : torch tensor (n,L,L)
    output[i, :, :] = stationary covariance of input[i, 0, :, :]
    """
    spectrum = torch.fft.fft2(phi[:, 0])
    power = spectrum.abs() ** 2
    n_sites = phi.shape[-1] ** 2
    return torch.fft.ifft2(power).real / n_sites


def compute_quadratic_energy(quadratic_coupling, phi):
    """
    Computes < phi, quadratic_coupling * phi > where * is the convolution.

    The inner product is evaluated in Fourier space as
    sum_k fft2(K)[k] * |fft2(phi)[k]|**2 / L**2 (unnormalised FFTs), with
    one value per sample.

    input:
    - quadratic_coupling array (L,L), NumPy-compatible (passed to np.fft.fft2)
    - phi = torch tensor (n,1,L,L) (passed to torch.fft.fft2); only channel 0
      is used

    output : array (n,). The values have a complex dtype because
    fft2(quadratic_coupling) is complex. For real inputs the imaginary part
    should be round-off only (Hermitian symmetry of the FFT).
    NOTE(review): whether the returned object is an ndarray or a tensor
    depends on how np.multiply / np.sum dispatch on a torch tensor -- confirm.
    """
    # NOTE(review): `n` is not used below.
    n = phi.shape[0]
    # |FFT(phi)|^2 per sample, computed in torch.
    phi_fft_sqnorm = torch.abs(torch.fft.fft2(phi[:,0,:,:]))**2
    # Unnormalised FFT of the coupling kernel, computed in NumPy.
    k_fft = np.fft.fft2(quadratic_coupling)
    # Mixes a NumPy array with a torch tensor; presumably requires phi to be a
    # CPU tensor that NumPy can convert (e.g. no grad) -- TODO confirm.
    # The 1/L**2 factor is Parseval's normalisation for unnormalised FFTs and
    # assumes square L x L fields.
    p_fft = np.multiply(k_fft, phi_fft_sqnorm) / phi.shape[-1]**2
    # Sum over both frequency axes -> one energy per sample.
    ret = np.sum(p_fft, axis=(1,2))
    return ret

def compute_power_spectrum(operator):
    '''
    Radially averaged power spectrum of a 2D operator.

    The operator is transformed with an orthonormal 2D FFT and shifted so
    that the zero frequency sits at index (h // 2, w // 2). The squared
    modulus is then averaged over rings of integer (truncated) radius
    around that centre.

    :param operator: 2D array of shape (L,L)
    :return: 1D array of length L // 2. Entry k is the mean of
        |fft2(operator, norm="ortho")|**2 over the frequencies whose
        truncated distance from the centre equals k.
    '''
    # Import the subpackage explicitly: a bare ``import scipy`` does not
    # guarantee that ``scipy.ndimage`` is loaded on every SciPy version.
    from scipy import ndimage

    fft_operator = np.fft.fftshift(np.fft.fft2(operator, norm="ortho"))
    power_spectrum_2d = np.abs(fft_operator) ** 2

    h, w = power_spectrum_2d.shape[0], power_spectrum_2d.shape[1]
    hc, wc = h // 2, w // 2

    # Integer radial distance of every pixel from the centre.
    Y, X = np.ogrid[0:h, 0:w]
    r = np.hypot(X - wc, Y - hc).astype(int)

    # Total power on each ring r = 0 .. wc - 1.
    psd1D = ndimage.sum(power_spectrum_2d, r, index=np.arange(0, wc))
    # Pixel count of each ring in a single pass. This replaces the loop that
    # re-scanned the whole image once per radius. Every r < wc has at least
    # one pixel (the centre row), so there is no division by zero.
    nb_mode = np.bincount(r.ravel())[:wc]
    return psd1D / nb_mode

def circular_average(x):
    """
    Average a 2D array over rings of integer radius around its centre.

    The centre is taken at index (h // 2, w // 2). Each pixel is assigned
    the truncated integer part of its Euclidean distance to the centre.

    :param x: 2D array of shape (h, w)
    :return: 1D array of length w // 2. Entry k is the mean of x over the
        pixels whose truncated distance from the centre equals k.
    """
    # Import the subpackage explicitly: a bare ``import scipy`` does not
    # guarantee that ``scipy.ndimage`` is loaded on every SciPy version.
    from scipy import ndimage

    h, w = x.shape[0], x.shape[1]
    hc, wc = h // 2, w // 2

    # Integer radial distance of every pixel from the centre.
    Y, X = np.ogrid[0:h, 0:w]
    r = np.hypot(X - wc, Y - hc).astype(int)

    # Sum of x on each ring r = 0 .. wc - 1.
    ring_sums = ndimage.sum(x, r, index=np.arange(0, wc))
    # Pixel count of each ring in a single pass. This replaces the loop that
    # re-scanned the whole image once per radius. Every r < wc has at least
    # one pixel (the centre row), so there is no division by zero.
    nb_mode = np.bincount(r.ravel())[:wc]
    return ring_sums / nb_mode