import numpy as np
from matplotlib import pyplot as plt

def LU_fac(matrix):
    """Factor a square matrix as ``left_fac @ right_fac`` without row pivoting.

    Each elimination step takes the pivot column ``x`` and pivot row ``y`` of
    the current Schur complement and scales them into a column of ``L`` and a
    row of ``U``. Both end up with Euclidean norm ``sqrt(|x| * |y| / |pivot|)``,
    so the two factors are balanced. The diagonal of ``U`` is non-negative;
    the sign of each pivot is carried by the diagonal of ``L``.

    Parameters
    ----------
    matrix : array_like, shape (n, n)
        Square matrix to factor. It is copied and converted to ``float64``.

    Returns
    -------
    left_fac, right_fac : ndarray, shape (n, n)
        Lower-triangular ``L`` and upper-triangular ``U`` with
        ``left_fac @ right_fac`` equal to ``matrix`` up to rounding.

    Raises
    ------
    ValueError
        If ``matrix`` is not a square 2-D array.
    numpy.linalg.LinAlgError
        If a zero pivot is met before the last step. No pivoting is done, so,
        e.g., any matrix with ``matrix[0, 0] == 0`` and ``n > 1`` is rejected.
    """
    # Private float copy: Schur complements are formed in place below.
    work = np.array(matrix, dtype=float)
    if work.ndim != 2 or work.shape[0] != work.shape[1]:
        raise ValueError(
            f"LU_fac expects a square 2-D matrix, got shape {work.shape}"
        )
    size = work.shape[0]
    left_fac = np.zeros((size, size))
    right_fac = np.zeros((size, size))

    # Iterative elimination avoids the recursion-depth limit and the
    # per-level full-size allocations of a recursive formulation.
    for k in range(size):
        pivot = work[k, k]
        if k == size - 1:
            # 1x1 trailing block: split |pivot| evenly and give the sign to L.
            # A zero here is still a valid factorization (both entries 0).
            root = np.sqrt(abs(pivot))
            left_fac[k, k] = np.sign(pivot) * root
            right_fac[k, k] = root
            break
        if pivot == 0:
            raise np.linalg.LinAlgError(
                f"zero pivot at step {k}; LU_fac does not pivot"
            )
        col = work[k:, k]
        row = work[k, k:]
        col_norm = np.linalg.norm(col)
        row_norm = np.linalg.norm(row)
        left_col = col * np.sqrt(row_norm / (col_norm * abs(pivot)))
        # Negating the row for a negative pivot makes
        # outer(left_col, right_row) == outer(col, row) / pivot.
        signed_row = -row if pivot < 0 else row
        right_row = signed_row * np.sqrt(col_norm / (row_norm * abs(pivot)))
        left_fac[k:, k] = left_col
        right_fac[k, k:] = right_row
        # Schur complement: remove the rank-one contribution just extracted.
        work[k + 1:, k + 1:] -= np.outer(left_col[1:], right_row[1:])
    return left_fac, right_fac

def oracle_generate(matrix, eps_min=0.1, eps_max=0.5, n_samples=100,
                    n_data=10000, *, rng=None):
    """Perturb ``matrix`` at increasing noise levels and record error vs. norm.

    For each of ``n_samples`` noise magnitudes ``epsilon``, linearly spaced
    from ``eps_min * ||matrix||_F`` to ``eps_max * ||matrix||_F``, a random
    noise matrix with unit Frobenius norm is scaled by ``epsilon`` and added
    to ``matrix``. The perturbed matrix is factored with ``LU_fac``.

    Parameters
    ----------
    matrix : array_like, shape (m, n)
        Reference matrix.
    eps_min, eps_max : float
        Noise range, as fractions of ``matrix``'s Frobenius norm.
    n_samples : int
        Number of noise levels.
    n_data : int
        Number of standard-normal input columns used to measure test error.
    rng : numpy.random.Generator, optional
        Source of randomness. ``None`` uses the legacy global ``np.random``
        state, which produces the same draws as ``np.random.randn``/``rand``.

    Returns
    -------
    test_error : list of float
        Per noise level: mean over output rows of
        ``||y_row - y'_row|| / ||y_row||``, where ``y = matrix @ x`` and
        ``y' = perturbed @ x``.
    norm : list of float
        Per noise level: ``sqrt(||L||_F**2 + ||U||_F**2)`` of the factors of
        the perturbed matrix.
    """
    if rng is None:
        rng = np.random  # legacy global state; same stream as randn/rand
    matrix = np.asarray(matrix)
    n_rows, n_cols = matrix.shape

    x = rng.standard_normal((n_cols, n_data))
    y = matrix @ x
    # Loop-invariant quantities, hoisted out of the sweep.
    # NOTE(review): an all-zero row of `matrix` makes this 0 and the
    # relative error nan/inf for that row.
    y_row_norms = np.linalg.norm(y, axis=1)
    matrix_norm = np.linalg.norm(matrix)
    eps = np.linspace(eps_min * matrix_norm, eps_max * matrix_norm, n_samples)

    test_error = []
    norm = []
    for epsilon in eps:
        # NOTE(review): uniform [0, 1) noise has only non-negative entries, so
        # the perturbation direction is not isotropic; confirm this is intended
        # (x above uses standard-normal draws).
        noisy = rng.random((n_rows, n_cols))
        noisy = noisy / np.linalg.norm(noisy)
        new_result = matrix + epsilon * noisy

        a, b = LU_fac(new_result)
        # The factors feed only the norm statistic. The error uses the
        # perturbed matrix directly (a @ b equals it up to rounding).
        y_prime = new_result @ x

        test_error.append(np.mean(np.linalg.norm(y - y_prime, axis=1) / y_row_norms))
        norm.append(np.sqrt(np.linalg.norm(a) ** 2 + np.linalg.norm(b) ** 2))
    return test_error, norm

if __name__ == '__main__':
    # Demo: sweep noise around a square anti-diagonal permutation matrix and
    # scatter the factor norm against the relative test error.
    dim = 100
    idx = np.arange(dim)
    result = np.zeros((dim, dim))
    result[idx, idx[::-1]] = 1.0  # entry (i, dim - 1 - i) for every row i

    test_error, norm = oracle_generate(result)

    plt.scatter(norm, test_error)
    plt.show()