#!/usr/bin/env python
# coding: utf-8

import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import ortho_group
from scipy.linalg import block_diag
from scipy import ndimage
import time
import scipy.io

# Wall-clock start time of the whole script; the total elapsed time is printed
# at the very end of the script.
tbegScript = time.time()

#%% # Functions

def R(theta):
    """Return the 2x2 counter-clockwise rotation matrix for angle `theta`.

    The matrix is returned as a nested Python list ``[[c, -s], [s, c]]`` with
    ``c = cos(theta)`` and ``s = sin(theta)``.
    """
    c, s = np.cos(theta), np.sin(theta)
    return [[c, -s], [s, c]]

def rotation(Q, theta_vector):
    """Return ``Q @ blockdiag(R(theta_0), R(theta_1), ...) @ Q.T``.

    Angle ``theta_vector[i]`` rotates the 2-D plane spanned by columns
    ``2i`` and ``2i+1`` of `Q` (expressed in the basis given by `Q`).
    """
    planes = [R(angle) for angle in theta_vector]
    rot = block_diag(*planes)
    left = Q @ rot
    return left @ Q.T

def bivec(a, b):
    """Return the antisymmetric matrix ``outer(a, b) - outer(b, a)``.

    This is the matrix form of the wedge product a ^ b; entry [i, j] equals
    ``a[i] * b[j] - a[j] * b[i]``.
    """
    return np.outer(a, b) - np.outer(b, a)

def cosine_similarity(a, b):
    """Return the cosine of the angle between vectors `a` and `b`.

    Computed as ``<a, b> / |a| / |b|``. Zero vectors are not guarded against:
    the result is then nan (numpy emits a RuntimeWarning).
    """
    norm_a = np.linalg.norm(a)
    norm_b = np.linalg.norm(b)
    return np.dot(a, b) / norm_a / norm_b

def plot_convergence(data):
    """Plot the fit loss ``1 - data`` on a log scale in a new figure.

    Parameters
    ----------
    data : array_like, shape (iterations, n_subspaces) or (iterations,)
        Per-iteration fit (e.g. a cosine similarity) of each tracked subspace;
        each column is drawn as one curve.
    """
    data = np.asarray(data)
    n_sub = data.shape[1] if data.ndim > 1 else 1
    plt.figure(figsize=(4, 6))
    plt.semilogy(1 - data)
    # One legend entry per plotted column (previously hard-coded to three).
    plt.legend([f'Subspace {i + 1}' for i in range(n_sub)])
    plt.xlabel('Iterations')
    plt.ylabel('subspace fit')

def plot_angles(dim, true_angles, est_angles, color=None, title=None):
    """Scatter estimated vs. true angles, one stacked subplot per plane.

    Parameters
    ----------
    dim : int
        Number of planes, i.e. number of stacked subplots.
    true_angles, est_angles : ndarray, shape (dim, n_points)
        True and estimated angles (radians); row i goes to subplot i.
    color : sequence of numbers, optional
        Matplotlib color-cycle index per plane (drawn as ``'C<k>'``);
        defaults to ``'C0'`` for every plane.
    title : str, optional
        Title placed on the top subplot.
    """
    if color is None:
        color = np.zeros(dim)
    diag = np.linspace(-np.pi, np.pi, 3)
    # Size the figure from `dim` (it used to read the global p_dim), and use
    # squeeze=False so `axs` is indexable even when dim == 1.
    _, axs = plt.subplots(dim, sharex=True, figsize=(4, 2 * dim), squeeze=False)
    axs = axs[:, 0]
    for i, ax in enumerate(axs):
        # Dashed identity line: points on it are perfect estimates.
        ax.plot(diag, diag, c='black', lw=1, ls='--')
        ax.scatter(true_angles[i, :], est_angles[i, :], s=2, c='C%g' % color[i])
        if i == dim - 1:
            ax.set_xlabel('true angle', fontsize=16)
        ax.set_xlim(-np.pi, np.pi)
        ax.set_ylabel(f'estimated angle {i+1}')
        ax.set_ylim(-np.pi, np.pi)
    axs[0].set_title(title)
    
    
# Anti-symm SVD online by deflation, a-la GHA (Sanger 89) 
# from Matlab: skewGhaXYStep
def xpowerStep(x, y, u, v, eta):
    """Perform one online update of the anti-symmetric SVD estimates, with deflation.

    Each plane k, spanned by ``(u[:, k], v[:, k])``, is updated in order from
    the sample pair (x, y). The residual terms for plane k use ``x`` and ``y``
    minus the reconstructions accumulated from planes 0..k-1 (GHA/Sanger-style
    deflation). The projection coefficients use the raw ``x`` and ``y``.

    Parameters
    ----------
    x, y : ndarray, shape (n,)
        Input sample and its paired (e.g. rotated) sample.
    u, v : ndarray, shape (n, p)
        Current plane estimates; they are not modified.
    eta : float
        Learning rate.

    Returns
    -------
    u1, v1 : ndarray, shape (n, p)
        Updated plane estimates.
    """
    n, kk = u.shape
    # Size everything from the inputs: the original used the module-level
    # x_dim/p_dim, which breaks for any other dimensions.
    u1 = np.zeros((n, kk))
    v1 = np.zeros((n, kk))
    xCumul = np.zeros(n)
    yCumul = np.zeros(n)
    for k in range(kk):
        uk = u[:, k]
        vk = v[:, k]
        ax = np.dot(uk, x)
        ay = np.dot(uk, y)
        bx = np.dot(vk, x)
        by = np.dot(vk, y)
        z = ay * bx - ax * by
        u1[:, k] = uk + eta * (bx * (y - yCumul) - by * (x - xCumul) - z * uk)
        v1[:, k] = vk - eta * (ax * (y - yCumul) - ay * (x - xCumul) + z * vk)
        # Reconstruction of x and y from planes 0..k (uses the pre-update u, v).
        xCumul = xCumul + ax * uk + bx * vk
        yCumul = yCumul + ay * uk + by * vk

    return u1, v1

#%% # Generating the timeseries
# seed=0
# np.random.seed(seed)
tbeg = time.time()

# x lives in R^x_dim, split into r_dim rotation planes; the estimators track
# p_dim planes (a y_dim = 2*p_dim dimensional output).
r_dim = 5
x_dim = 2 * r_dim
p_dim = 3
y_dim = 2 * p_dim
samples = 10**5
eps = 0  # std of the additive Gaussian noise on the rotated samples

# Per-sample angle of plane i: shifts[i] + sigma[i] * N(0, 1).
sigma = np.array([.4, .3, .2, 0, 0])
shifts = np.array([.3, .2, .1, .01, .01])  # [.1,.2,.3,.01,.01]
# Random orthogonal basis; columns (2i, 2i+1) span rotation plane i.
Q = ortho_group.rvs(x_dim)
proj_Q = Q[:, :y_dim] @ Q[:, :y_dim].T

X = np.random.randn(x_dim, samples)
Theta = np.diag(sigma) @ np.random.randn(r_dim, samples) + shifts.reshape(-1, 1)

# Top-p_dim plane orderings: by decreasing mean angle (used for the SVD
# targets) and by increasing mean cos(angle) (used for the PCA targets).
svdOrd = np.argsort(np.mean(Theta, 1))[::-1][:p_dim]
pcaOrd = np.argsort(np.mean(np.cos(Theta), 1))[:p_dim]
#%% rotate
X_rot = np.zeros((x_dim, samples))

for t in range(samples):
    X_rot[:, t] = rotation(Q, Theta[:, t]) @ X[:, t] + eps * np.random.randn(x_dim)

X_dot = X_rot - X

C_dot = np.cov(X_dot)

print("Time data gen", time.time() - tbeg)

# #%%  PCA eigvals
# C_dot is a symmetric covariance matrix, so use the symmetric solver: it
# returns real eigenvalues (ascending). np.linalg.eig is the general solver
# and returns them in no particular order. Plot them largest first.
plt.bar(range(x_dim), np.linalg.eigvalsh(C_dot)[::-1])
plt.show()

print(np.mean(Theta, 1).T)
print(np.mean(np.cos(Theta), 1).T)
#%%  true subspaces
# Target bivector of plane i: bivec(Q[:, 2i], Q[:, 2i+1]), flattened to one
# row of length x_dim**2. Built with the shared bivec() helper, as runSVD
# does for its estimates, instead of an inline copy of it.
Q_ss = np.zeros((p_dim, x_dim * x_dim))
for i in range(p_dim):
    Q_ss[i, :] = bivec(Q[:, 2 * i], Q[:, 2 * i + 1]).flatten()

#%% run SVD
def runSVD(x_dim, p_dim, samples, X, X_rot, Q_ss_ord, eta):
    """Run the online anti-symmetric SVD (``xpowerStep``) over the samples.

    Parameters
    ----------
    x_dim, p_dim : int
        Input dimension and number of tracked planes (columns of u and v).
    samples : int
        Number of columns of `X` / `X_rot` processed, in order.
    X, X_rot : ndarray, shape (x_dim, >= samples)
        Input samples and their paired (rotated) samples.
    Q_ss_ord : ndarray, shape (p_dim, x_dim * x_dim)
        Flattened target bivector for each tracked plane.
    eta : float
        Learning rate passed to ``xpowerStep``.

    Returns
    -------
    fitLog : ndarray, shape (samples, p_dim)
        Cosine similarity between each estimated plane bivector and its
        target, after every step.
    speedest : ndarray, shape (samples, p_dim)
        Per-sample angle estimate in each estimated plane, in [-pi/2, pi/2].
    """
    # Random unit-norm initial estimates (u drawn before v).
    u = np.random.randn(x_dim, p_dim)
    u = u / np.linalg.norm(u, axis=0)
    v = np.random.randn(x_dim, p_dim)
    v = v / np.linalg.norm(v, axis=0)
    fitLog = np.zeros((samples, p_dim))
    speedest = np.zeros((samples, p_dim))

    tbeg = time.time()
    for i in range(samples):
        x = X[:, i]
        x_rot = X_rot[:, i]

        u, v = xpowerStep(x, x_rot, u, v, eta)

        # Plane fit. bivec(u, v).T == bivec(v, u), so the estimate is compared
        # with orientation v ^ u.
        for k in range(p_dim):
            fitLog[i, k] = cosine_similarity(
                bivec(u[:, k], v[:, k]).T.flatten(), Q_ss_ord[k, :])

        # Angle estimate: project x and x_rot onto each (u_k, v_k) plane and
        # take arcsin of the normalized 2-D cross product. Clip so rounding
        # cannot push the ratio outside [-1, 1] and produce NaN.
        a = u.T @ x_rot
        b = v.T @ x_rot
        ylen = np.sqrt(a**2 + b**2)
        atau = u.T @ x
        btau = v.T @ x
        ytaulen = np.sqrt(atau**2 + btau**2)
        speedest[i, :] = np.arcsin(
            np.clip((a * btau - b * atau) / ylen / ytaulen, -1.0, 1.0))
    print("SVD time", time.time() - tbeg)
    return fitLog, speedest

#%% 10 reps
# Run the online anti-symmetric SVD `reps` times on the same data; each run
# starts from fresh random u, v (drawn inside runSVD). Keep the full fit
# history of every run and its angle estimates for the last 1000 samples.
eta=0.0005;
reps=10
# svdFit[t, k, r]: fit of plane k at iteration t in repetition r.
svdFit=np.zeros((samples, p_dim, reps))
# svdSpeed[t, k, r]: angle estimate of plane k over the last 1000 samples.
svdSpeed=np.zeros((1000, p_dim, reps))
for r in range(reps) :
    fitLog,speedest = runSVD(x_dim,p_dim,samples,X,X_rot,Q_ss[svdOrd,:],eta)
    svdFit[:,:,r]=fitLog
    svdSpeed[:,:,r]=speedest[-1000:,:]
    print('rep ',r)
    # Per-rep diagnostics: each call below opens a new figure that is never
    # closed. NOTE(review): 2 figures per rep may exceed matplotlib's default
    # open-figure warning threshold.
    plot_convergence(fitLog)
    plot_angles( p_dim, Theta[svdOrd,-1000:], speedest[-1000:,:].T )
#%% joint fit plots
def plotAnglesReps(speed, Theta, ord_, title=None):
    """Scatter true vs. estimated angles, pooling samples from all repetitions.

    From each repetition, ``n // reps`` random time points are drawn, so the
    pooled scatter has the same number of points as one repetition's tail
    (up to the remainder of the division).

    Parameters
    ----------
    speed : ndarray, shape (n, k, reps)
        Angle estimates over the last `n` samples of each repetition.
    Theta : ndarray, shape (>= k, >= n)
        True angles; its first `k` rows must be ordered to match the columns
        of `speed`.
    ord_ : sequence
        Per-plane color indices, forwarded to ``plot_angles``.
    title : str, optional
        Plot title.
    """
    n, k, reps = speed.shape
    # Align with the tail that `speed` holds (was hard-coded to 1000).
    true = Theta[:k, -n:].T
    per_rep = n // reps
    # Allocate only rows that get filled, so no zero rows are plotted at (0, 0).
    total = per_rep * reps
    trueSelect = np.zeros((total, k))
    estSelect = np.zeros((total, k))
    for irep in range(reps):
        ind = np.random.permutation(n)[:per_rep]
        rows = slice(irep * per_rep, (irep + 1) * per_rep)
        trueSelect[rows, :] = true[ind, :]
        estSelect[rows, :] = speed[ind, :, irep]
    # One subplot per estimated plane (was the global p_dim).
    plot_angles(k, trueSelect.T, estSelect.T, ord_, title)
#%%
def plotFitReps(fit, color=None):
    """Plot the subspace-fit loss ``1 - fit`` across repetitions (log scale).

    For each plane, a median-filtered (window 501) mean over repetitions is
    drawn, together with a shaded min-max band across repetitions.

    Parameters
    ----------
    fit : ndarray, shape (iterations, n_planes, reps)
        Per-iteration fit of each plane in each repetition.
    color : sequence of numbers, optional
        Matplotlib color-cycle index per plane (``'C<k>'``); defaults to 'C0'.
    """
    n, k, _ = fit.shape
    if color is None:
        color = np.zeros(k)
    plt.figure(figsize=(4, 6))
    steps = range(n)
    for i in range(k):
        loss = 1 - fit[:, i, :]  # computed once per plane
        col = 'C%g' % color[i]
        plt.plot(steps, scipy.ndimage.median_filter(np.mean(loss, 1), 501), c=col)
        plt.fill_between(steps, np.min(loss, 1), np.max(loss, 1), alpha=.3, color=col)
    # Loop-invariant: set the log scale once.
    plt.yscale('log')
    plt.xlabel('Iterations', fontsize=16)
    plt.ylabel('Subspace fit loss', fontsize=16)
#%% run PCA online

def runPCA(x_dim, p_dim, y_dim, samples, X, X_rot, Q_ss, etaPCA):
    """Run the online PCA-style network on the sample increments.

    The network output is ``y = F @ x`` with ``F = M^{-1} W``. W is a running
    average of ``outer(dy, dx)`` and M is updated with ``outer(dy, dy)``, decayed
    through ``Lam @ M @ Lam``, where ``dy = F @ (x_rot - x)``. Output
    coordinates (2i, 2i+1) form the estimate of plane i.

    Parameters
    ----------
    x_dim, p_dim, y_dim : int
        Input dimension, number of tracked planes, output dimension
        (must be >= 2 * p_dim).
    samples : int
        Number of columns of `X` / `X_rot` processed, in order.
    X, X_rot : ndarray, shape (x_dim, >= samples)
        Input samples and their paired (rotated) samples.
    Q_ss : ndarray, shape (p_dim, x_dim * x_dim)
        Flattened target bivector for each tracked plane.
    etaPCA : float
        Learning rate.

    Returns
    -------
    cosine : ndarray, shape (p_dim, samples)
        Cosine similarity between each estimated plane bivector and its target.
    tangent : ndarray, shape (p_dim, samples)
        sin/cos of the angle from ``y`` to ``y_rot`` within each output pair.
    """
    # Diagonal weights: p_dim - i on both coordinates of output pair i, 1 elsewhere.
    Lam = np.eye(y_dim)
    for i in range(p_dim):
        Lam[2*i, 2*i] = p_dim - i
        Lam[2*i+1, 2*i+1] = p_dim - i

    pcaTbeg = time.time()
    cosine = np.zeros((p_dim, samples))
    dot = np.zeros((p_dim, samples))
    sign = np.ones((p_dim, samples))

    W = np.random.randn(y_dim, x_dim)
    M = np.eye(y_dim)

    for t in range(samples):
        # F = M^{-1} W via a linear solve, which is more stable than forming the inverse.
        F = np.linalg.solve(M, W)

        x = X[:, t]
        x_rot = X_rot[:, t]
        y = F @ x
        y_rot = F @ x_rot
        dy = y_rot - y

        W = W + etaPCA * (np.outer(dy, x_rot - x) - W)
        M = M + etaPCA * (np.outer(dy, dy) - Lam @ M @ Lam)

        for i in range(p_dim):
            pair = slice(2*i, 2*i + 2)
            norm = np.linalg.norm(y[pair]) * np.linalg.norm(y_rot[pair])
            # Normalized dot (cos) and 2-D cross product (sin) between pairs.
            dot[i, t] = np.dot(y[pair], y_rot[pair]) / norm
            sign[i, t] = (y[2*i] * y_rot[2*i+1] - y[2*i+1] * y_rot[2*i]) / norm
            cosine[i, t] = cosine_similarity(
                bivec(F[2*i, :], F[2*i+1, :]).flatten(), Q_ss[i, :])

    print("PCA time", time.time() - pcaTbeg)

    return cosine, sign / dot
#%% 10 reps
# Run the online PCA network `reps` times on the same data; each run starts
# from a fresh random W (drawn inside runPCA). Keep the full fit history and
# the angle estimates for the last 1000 samples of every run.
etaPCA = 5e-4
reps=10
# pcaFit[t, k, r]: fit of plane k at iteration t in repetition r.
pcaFit=np.zeros((samples, p_dim, reps))
# pcaSpeed[t, k, r]: angle estimate of plane k over the last 1000 samples.
pcaSpeed=np.zeros((1000, p_dim, reps))
for r in range(reps) :
    cosine,tangent = runPCA(x_dim,p_dim,y_dim,samples,X,X_rot,Q_ss[pcaOrd,:],etaPCA)
    # +1/-1 per plane: sign of the run-averaged cosine. Both the fit and the
    # angle of a plane are multiplied by it (presumably to remove the sign
    # ambiguity of the learned plane orientation).
    cosSign=(np.mean(cosine,1)>0)*2-1
    pcaFit[:,:,r]=(np.diag(cosSign) @ cosine).T
    # arctan maps the returned tangent back to an angle in (-pi/2, pi/2).
    pcaSpeed[:,:,r]=(np.diag(cosSign)@np.arctan(tangent[:,-1000:])).T
    print('rep ',r)
    # Per-rep diagnostics: each call below opens a new figure that is never
    # closed. NOTE(review): 2 figures per rep may exceed matplotlib's default
    # open-figure warning threshold.
    plot_convergence(pcaFit[:,:,r])
    plot_angles( p_dim, Theta[pcaOrd,-1000:], pcaSpeed[:,:,r].T )
#%%  P L O T S   with REPS
# Pooled angle scatters and fit-loss curves over all repetitions, saved as
# PNGs in the working directory. `transparent` takes a bool (it was passed
# the string 'true'); the file names have no placeholders, so plain strings.
plotAnglesReps(svdSpeed, Theta[svdOrd, :], svdOrd, 'SVD')
plt.savefig('anglesSVDreps.png', dpi=300, transparent=True, bbox_inches='tight')
plotFitReps(svdFit, svdOrd)
plt.title('SVD')
plt.savefig('convSVDreps.png', dpi=300, transparent=True, bbox_inches='tight')
plotAnglesReps(pcaSpeed, Theta[pcaOrd, :], pcaOrd, 'PCA')
plt.savefig('anglesPCAreps.png', dpi=300, transparent=True, bbox_inches='tight')
plotFitReps(pcaFit, pcaOrd)
plt.title('PCA')
plt.savefig('convPCAreps.png', dpi=300, transparent=True, bbox_inches='tight')
#%%
print("Script Time", time.time() - tbegScript)

