import matplotlib.pyplot as plt
import numpy as np
from sklearn.decomposition import PCA

def draw_all(latent, latent_gmm, data, data_hat, labels_gmm, preds, data_labels, epoch, dir_out):
    """Save a 2x2 grid of scatter plots comparing latents, reconstructions and data.

    Panels, in row-major order:
      1. encoder latents ``latent`` coloured by ``preds``;
      2. GMM latent samples ``latent_gmm`` coloured by ``labels_gmm``;
      3. decoder output ``data_hat`` coloured by ``preds``;
      4. observed data ``data`` coloured by ``data_labels``.

    Only the first two columns of each array are plotted; no dimensionality
    reduction is applied here.

    The figure is written to ``dir_out + 'learnt_latents_<epoch>.png'`` (plain
    string concatenation, so ``dir_out`` presumably needs a trailing path
    separator) and then closed. NOTE: ``draw_together`` writes to the same
    file name, so calling both for one epoch overwrites the first image.
    """
    # (points, colour values, title) for each subplot, in subplot order.
    panels = (
        (latent, preds, "Encoder's z"),
        (latent_gmm, labels_gmm, 'Latent GMM'),
        (data_hat, preds, "Decoder's output"),
        (data, data_labels, "Observed distribution"),
    )

    plt.figure(figsize=(12, 3.5))
    for position, (points, colours, title) in enumerate(panels, start=1):
        plt.subplot(2, 2, position)
        plt.scatter(points[:, 0], points[:, 1], c=colours, s=6, alpha=0.3)
        # Axis ticks carry no meaning for these embeddings, so hide them.
        plt.xticks([])
        plt.yticks([])
        plt.title(title, fontsize=16, family='serif')

    plt.tight_layout()
    plt.savefig(dir_out + 'learnt_latents_{}.png'.format(epoch))
    plt.close()
    
def run_pca(matrix, dim=2):
    """Project ``matrix`` onto its first ``dim`` principal components.

    A fresh ``sklearn.decomposition.PCA`` is fitted on ``matrix`` on every
    call, so two calls on different inputs do not share a basis.

    Args:
        matrix: array-like of shape (n_samples, n_features) accepted by
            ``PCA.fit_transform``.
        dim: number of principal components to keep (default 2).

    Returns:
        ndarray of shape (n_samples, dim) holding the projected samples.
    """
    # ``dim`` was previously ignored (n_components hard-coded to 2).
    pca = PCA(n_components=dim)
    return pca.fit_transform(matrix)
    
def _pca_pair_2d(first, second):
    """Project two point sets onto one shared 2-component PCA basis.

    The basis is fitted on both sets stacked row-wise, so the two returned
    arrays live in the same coordinate system and can be overlaid in a
    single plot. Both inputs must have the same number of columns.

    Returns:
        Tuple ``(first_2d, second_2d)`` of ndarrays with two columns each.
    """
    first = np.asarray(first)
    second = np.asarray(second)
    pca = PCA(n_components=2)
    pca.fit(np.vstack((first, second)))
    return pca.transform(first), pca.transform(second)


def draw_together(latent, latent_gmm, data, data_hat, labels_gmm, preds, data_labels, epoch, dir_out):
    """Save a 1x2 figure overlaying model outputs on their reference sets.

    Left panel: encoder latents ``latent`` (coloured by ``preds``) over GMM
    latent samples ``latent_gmm`` (coloured by ``labels_gmm``).
    Right panel: decoder output ``data_hat`` (coloured by ``preds``) over
    observed data ``data`` (coloured by ``data_labels``).

    Inputs with more than two columns are reduced to 2-D with PCA. Each
    overlaid pair is projected with a single shared PCA fit so both sets are
    drawn in the same coordinates. Two-column inputs are plotted unchanged.

    The figure is written to ``dir_out + 'learnt_latents_<epoch>.png'`` and
    then closed. NOTE: ``draw_all`` writes to the same file name, so calling
    both for one epoch overwrites the first image.
    """
    # One PCA fit per overlaid pair: independent fits would give each set its
    # own centring, axes and sign convention, misaligning the overlay.
    if latent.shape[1] > 2:
        latent, latent_gmm = _pca_pair_2d(latent, latent_gmm)
    if data.shape[1] > 2:
        data_hat, data = _pca_pair_2d(data_hat, data)

    # Shift the reference labels past the largest prediction so the two sets
    # occupy disjoint colour ranges. This only works with a shared vmin/vmax
    # per panel: otherwise each scatter call rescales its own colour values to
    # the full colormap and the shift has no visible effect.
    offset = np.max(preds) + 1
    gmm_colours = np.asarray(labels_gmm) + offset
    data_colours = np.asarray(data_labels) + offset

    plt.figure(figsize=(12, 7))

    plt.subplot(1, 2, 1)
    vmin = min(np.min(preds), np.min(gmm_colours))
    vmax = max(np.max(preds), np.max(gmm_colours))
    plt.scatter(latent[:, 0], latent[:, 1], c=preds, s=6, alpha=0.9,
                vmin=vmin, vmax=vmax)
    plt.scatter(latent_gmm[:, 0], latent_gmm[:, 1], c=gmm_colours, s=6, alpha=0.3,
                vmin=vmin, vmax=vmax)
    plt.xticks([])
    plt.yticks([])
    plt.title("Latent GMM+Encoder's z", fontsize=16, family='serif')

    plt.subplot(1, 2, 2)
    vmin = min(np.min(preds), np.min(data_colours))
    vmax = max(np.max(preds), np.max(data_colours))
    plt.scatter(data_hat[:, 0], data_hat[:, 1], c=preds, s=6, alpha=0.9,
                vmin=vmin, vmax=vmax)
    plt.scatter(data[:, 0], data[:, 1], c=data_colours, s=6, alpha=0.2,
                vmin=vmin, vmax=vmax)
    plt.xticks([])
    plt.yticks([])
    plt.title("Observed distribution+Decoder's output", fontsize=16, family='serif')

    plt.tight_layout()
    plt.savefig(dir_out + 'learnt_latents_{}.png'.format(epoch))
    plt.close()