import numpy as np
import os
import sys
import matplotlib.pyplot as plt
from cvxopt import matrix, solvers

# Prepend the relative path './' to the module search path. Python resolves it
# against the current working directory at import time, so modules sitting in
# the directory the script is launched from become importable. Presumably this
# is needed by project-local imports elsewhere in the file — verify.
sys.path.insert(0, './')


def calc_inradius(y, Y_knn, w_knn, eta):  # only handles two neighbour columns (dim = 2)
    """Return half of the smaller of two weighted "diagonal" norms.

    Y_knn is augmented with a constant row ``[eta, eta]``. Each of its two
    columns is scaled by the reciprocal of the matching weight in ``w_knn``.
    The norms of the sum and of the difference of the scaled columns are
    compared, and half of the smaller one is returned.

    Args:
        y: unused; kept for interface compatibility with sibling functions.
        Y_knn: 2-D array with exactly two columns.
        w_knn: sequence whose first two entries scale columns 0 and 1.
        eta: value appended as an extra row to both columns.

    Returns:
        ``min(||u + v||, ||u - v||) / 2``, where u and v are the scaled columns.
        On a tie, or if a NaN appears, the difference norm is used.
    """
    augmented = np.vstack([Y_knn, [eta, eta]])
    u = augmented[:, 0] / w_knn[0]
    v = augmented[:, 1] / w_knn[1]
    sum_norm = np.linalg.norm(u + v)
    diff_norm = np.linalg.norm(u - v)
    # Strict '<': ties (and NaN comparisons) fall through to the difference norm.
    return sum_norm / 2 if sum_norm < diff_norm else diff_norm / 2

def solve_dual(y, Y_knn, w_knn, lam, eta):  # works for any number of neighbour columns
    """Solve the box-constrained dual QP with cvxopt and return its minimiser.

    With ``Z`` = ``Y_knn`` plus an appended row of ``eta`` (one entry per
    column), the problem passed to ``cvxopt.solvers.qp`` is::

        minimise    1/2 x^T G x + a^T x,   G = I / (2*lam),  a = -[y, eta]
        subject to  Z[:, j]^T x <=  w_knn[j]
                   -Z[:, j]^T x <=  w_knn[j]      for every column j

    All inputs are cast to float64 first. ``solvers.qp`` only accepts
    ``'d'`` (double) matrices, so integer-valued ``y``, ``w_knn`` or ``eta``
    would otherwise be rejected with a TypeError.

    Side effect: sets ``solvers.options['show_progress'] = False`` globally.

    Args:
        y: 1-D array with one entry per row of ``Y_knn``.
        Y_knn: 2-D array whose columns define the constraints.
        w_knn: bounds; the first ``Y_knn.shape[1]`` entries are used.
        lam: positive regularisation constant.
        eta: value appended to ``y`` and as an extra row of ``Y_knn``.

    Returns:
        ``sol['x']`` from cvxopt, a column ``matrix`` of length
        ``Y_knn.shape[0] + 1``. NOTE(review): the solver status is not
        checked, so callers cannot tell whether the solve converged.
    """
    Y_knn = np.asarray(Y_knn, dtype=float)
    n_cols = Y_knn.shape[1]
    Z = np.vstack([Y_knn, np.full(n_cols, eta, dtype=float)])
    w = np.asarray(w_knn, dtype=float)[:n_cols]

    def solve_cvx(G, a, C, b):
        solvers.options['show_progress'] = False
        sol = solvers.qp(matrix(G), matrix(a), matrix(C), matrix(b))
        return sol['x']

    G = (1 / (2 * lam)) * np.identity(Z.shape[0])
    a = -np.append(np.asarray(y, dtype=float), float(eta))
    # Rows Z_j^T then -Z_j^T encode |Z_j^T x| <= w_j; shape (2*n_cols, len(x)).
    C = np.ascontiguousarray(np.vstack([Z.T, -Z.T]))
    b = np.concatenate([w, w])

    # min. 1/2 x^T G x + a^T x   s.t.   C x <= b
    return solve_cvx(G, a, C, b)

def project_on_norm_subspace(nu, Y_knn, eta):  # works for any number of neighbour columns
    """Orthogonally project ``nu`` onto the span of the eta-augmented columns of ``Y_knn``.

    ``Y_knn`` gets an extra row of ``eta`` (one entry per column). ``nu`` is
    then projected onto the column space of the result. The coefficients come
    from a least-squares solve, not from the normal equations. That keeps the
    projection well defined when the columns are linearly dependent (e.g. two
    identical neighbours), where ``np.linalg.solve`` on the Gram matrix would
    raise LinAlgError. It also avoids squaring the condition number.

    Args:
        nu: 1-D array of length ``Y_knn.shape[0] + 1``.
        Y_knn: 2-D array whose columns span the subspace.
        eta: value appended as an extra row to every column.

    Returns:
        1-D ndarray, the orthogonal projection of ``nu`` onto that span.
    """
    Y_knn = np.asarray(Y_knn, dtype=float)
    basis = np.vstack([Y_knn, np.full(Y_knn.shape[1], eta, dtype=float)])
    coeffs = np.linalg.lstsq(basis, np.asarray(nu, dtype=float), rcond=None)[0]
    return basis @ coeffs

def project_on_subspace(x, Y_knn, *, segment=False):  # uses the first two columns of Y_knn
    """Return the Euclidean distance from point ``x`` to the line through two points.

    The points are ``p1 = Y_knn[:, 0]`` and ``p2 = Y_knn[:, 1]``. Despite the
    name, the function returns a distance, not the projected point.

    Args:
        x: 1-D array, the point being measured.
        Y_knn: 2-D array whose first two columns are p1 and p2.
        segment: if True, clamp the projection onto the segment [p1, p2]
            instead of the infinite line through them. Defaults to False,
            which matches the original behaviour.

    Returns:
        The distance from ``x`` to its projection. If p1 == p2, the "line"
        degenerates to a point and the distance to p1 is returned (after
        printing a notice). Previously this case divided by zero and
        returned NaN.
    """
    Y_knn = np.asarray(Y_knn, dtype=float)
    p3 = np.asarray(x, dtype=float)
    p1 = Y_knn[:, 0]
    p2 = Y_knn[:, 1]
    direction = p2 - p1

    # Squared distance between p1 and p2.
    l2 = np.sum(direction ** 2)
    if l2 == 0:
        print('p1 and p2 are the same points')
        return np.linalg.norm(p3 - p1)

    # Parameter of the orthogonal projection onto the line p1 + t * (p2 - p1).
    t = np.sum((p3 - p1) * direction) / l2
    if segment:
        # Restrict to the closest point on the segment between p1 and p2.
        t = max(0.0, min(1.0, t))

    projection = p1 + t * direction
    return np.linalg.norm(p3 - projection)
    
def calc_gamma(Y, nu_proj, N, N2, w, r, eta):
    """Return the smallest normalised margin over columns N .. N+N2-1 of Y, capped at 10.

    For each column index k in ``range(N, N + N2)``, let ``z`` be ``Y[:, k]``
    with ``eta`` appended. The margin is::

        (w[k-1] * r - |z . nu_proj| / ||nu_proj||) / ||z||

    The result is the minimum of 10 and all those margins. When the range is
    empty, 10 is returned as is.

    NOTE(review): the weight index is ``k-1`` while the data column is ``k``.
    That may be intentional (a different indexing for ``w``); confirm against
    callers.
    """
    columns = range(N, N + N2)
    if not columns:
        return 10

    nu_norm = np.linalg.norm(nu_proj)
    margins = []
    for col in columns:
        point = np.append(Y[:, col], eta)
        slack = w[col - 1] * r - (np.abs(np.dot(point, nu_proj)) / nu_norm)
        margins.append(slack / np.linalg.norm(point))

    # 10 goes first so builtin min() keeps the original strict-less update order.
    return min([10] + margins)
