'''
Script to reproduce the distance figure of the article.
'''
import numpy as np

import matplotlib.pyplot as plt
from sklearn.linear_model import RidgeCV
from sklearn.dummy import DummyRegressor
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.model_selection import cross_val_score

from pyriemann.tangentspace import TangentSpace as GeomVect
from wasserstein_tangent import TangentSpace as WassVect
from generation import generate_covariances
from utils import logDiag


# Announce the start of the run on stdout.
print('Running distance experiment...')

# Value forwarded as the `rng` argument of the covariance generator.
rng = 5

# Distances range: 10 evenly spaced values from 0.001 to 3 (both included)
distances = np.linspace(start=0.001, stop=3, num=10)


# Simulation parameters
n_matrices = 100  # number of matrices
n_channels = 5  # number of channels
n_sources = 2  # number of sources
sigma = 0.31  # noise level
f_powers = 'log'  # link function between the y and the source powers

# Direction for the change in A: a Gaussian draw from a fixed-seed
# generator, rescaled so that its matrix norm equals one.
direction_A = np.random.RandomState(4).randn(n_channels, n_channels)
direction_A = direction_A / np.linalg.norm(direction_A)


# Choose embeddings: each display name is paired with the transformer it
# labels. 'Chance level' has no transformer of its own, hence the None.
_methods = [
    ('Chance level', None),
    ('Log-powers', logDiag()),
    ('Wasserstein', WassVect(n_channels)),
    ('Geometric', GeomVect()),
]
names = [label for label, _ in _methods]
embeddings = [transformer for _, transformer in _methods]

# Run experiments
# results[row, col] is the mean absolute error of method `row` at distance
# `col`, averaged over the cross-validation folds.
results = np.zeros((len(names), len(distances)))
for col, distance_A_id in enumerate(distances):
    # One dataset per distance, shared by all the methods compared below.
    X, y = generate_covariances(
        n_matrices, n_channels, n_sources,
        sigma=sigma, distance_A_id=distance_A_id, f_p=f_powers,
        direction_A=direction_A, rng=rng,
    )
    for row, (name, embedding) in enumerate(zip(names, embeddings)):
        print('distance = {}, {} method'.format(distance_A_id, name))
        if name == 'Chance level':
            # Baseline: a DummyRegressor with default settings, fed through
            # a fresh log-diagonal embedding so the pipeline has a first step.
            first_step = logDiag()
            regressor = DummyRegressor()
        else:
            first_step = embedding
            regressor = RidgeCV(alphas=np.logspace(-7, 3, 10),
                                scoring='neg_mean_absolute_error')

        pipeline = Pipeline([
            ('emb', first_step),
            ('sc', StandardScaler()),
            ('lr', regressor),
        ])
        fold_scores = cross_val_score(pipeline, X, y,
                                      scoring='neg_mean_absolute_error',
                                      cv=10, n_jobs=3)
        # The scorer returns negated errors; flip the sign to store the MAE.
        results[row, col] = -fold_scores.mean()


# Plot
f, ax = plt.subplots(figsize=(4, 3))

# Normalize every method by the first row of `results` (the row filled for
# names[0]); the copy keeps the divisor independent of the array being
# modified in place.
reference_row = results[0].copy()
results /= reference_row

for curve, name in zip(results, names):
    # Only the chance-level curve is dashed; the others use the default.
    style = '--' if name == 'Chance level' else None
    ax.plot(distances, curve, label=name, linewidth=3, linestyle=style)


ax.set_xlabel('distance')
plt.grid()
ax.set_ylabel('Normalized M.A.E.')
# Horizontal reference line at zero error across the whole distance range.
ax.hlines(
    0, distances[0], distances[-1],
    label=r'Perfect', color='k', linestyle='--', linewidth=3,
)
ax.legend(loc='lower right')
f.tight_layout()
plt.show()
