"""Python wrappers around the vanilla CRT (conditional randomization test)
and HRT (holdout randomization test), calling the R implementation from
https://github.com/moleibobliu/Distillation-CRT/
"""
from pathlib import Path

import numpy as np
from rpy2 import robjects
from rpy2.robjects import numpy2ri
from sklearn.preprocessing import StandardScaler

from sandbox.gaussian_knockoff import _estimate_distribution

# Companion R script: same path as this module but with a ``.R`` suffix.
# It is expected to define the ``CRT_sMC`` and ``HRT`` functions that are
# called below.
CRT_FILE = str(Path(__file__).with_suffix('.R'))
# Load the R script into the embedded R session. The path is passed as an
# argument to R's ``source`` rather than interpolated into R code text, so
# Windows backslashes (e.g. ``\U``) or quotes in the path cannot break the
# R string literal.
robjects.r['source'](CRT_FILE)
# Enable automatic numpy <-> R conversion for the R calls in this module.
# NOTE(review): ``activate()`` may be deprecated in recent rpy2 releases in
# favor of ``localconverter``; confirm against the pinned rpy2 version.
numpy2ri.activate()


def crt(X, y, n_samplings=100, center=True, method='LASSO', fdr=0.1,
        model='gaussian', verbose=False, n_jobs=1):
    """Run the vanilla CRT through the R function ``CRT_sMC``.

    Parameters
    ----------
    X : array-like, shape (n_samples, n_features)
        Design matrix.
    y : array-like
        Response; flattened and passed to R as a single row.
    n_samplings : int, default=100
        Number of samplings, forwarded to R as ``m``.
    center : bool, default=True
        If True, standardize ``X`` column-wise with ``StandardScaler``
        (zero mean and unit variance) before testing.
    method : str, default='LASSO'
        Currently unused: it is not forwarded to the R routine. It is kept
        so the signature mirrors ``hrt``.
    fdr : float, default=0.1
        Target false discovery rate, forwarded as ``FDR``.
    model : str, default='gaussian'
        Model specification string, forwarded as ``model``.
    verbose : bool, default=False
        If True, also return the second element of the R result
        (presumably the per-feature p-values).
    n_jobs : int, default=1
        Forwarded as ``n_jobs`` to the R routine.

    Returns
    -------
    selected_index : ndarray of int
        Zero-based indices of the selected features. The array is empty
        when nothing is selected.
    pvals : R object
        Only returned when ``verbose`` is True. It is the second element of
        the R result, returned without conversion.

    Raises
    ------
    ValueError
        If ``X`` is not two-dimensional.
    """
    X = np.asarray(X)
    if X.ndim != 2:
        raise ValueError(f"X must be 2-dimensional, got shape {X.shape}")
    if center:
        X = StandardScaler().fit_transform(X)

    # _estimate_distribution returns (mu, Sigma); only Sigma is sent to R.
    _, Sigma = _estimate_distribution(X)
    results = robjects.r.CRT_sMC(np.asarray(y).reshape(1, -1), X, Sigma,
                                 m=n_samplings, FDR=fdr, model=model,
                                 n_jobs=n_jobs)

    # R returns 1-based indices, so shift to 0-based. The explicit integer
    # dtype keeps the result usable for indexing. The empty case matters
    # most, because np.array([]) would be float64 and indexing X with it
    # raises IndexError.
    if len(results[0]) > 0:
        selected_index = np.asarray(results[0], dtype=int) - 1
    else:
        selected_index = np.array([], dtype=int)

    if verbose:
        return selected_index, results[1]
    return selected_index


def hrt(X, y, n_samplings=100, center=True, method='CV', fdr=0.1,
        screening=True, model='gaussian', verbose=False, n_jobs=1):
    """Run the vanilla HRT through the R function ``HRT``.

    Parameters
    ----------
    X : array-like, shape (n_samples, n_features)
        Design matrix.
    y : array-like
        Response; flattened and passed to R as a single row.
    n_samplings : int, default=100
        Number of samplings, forwarded to R as ``N``.
    center : bool, default=True
        If True, standardize ``X`` column-wise with ``StandardScaler``
        (zero mean and unit variance) before testing.
    method : str, default='CV'
        Model-selection strategy, forwarded to R as ``model_select``.
    fdr : float, default=0.1
        Target false discovery rate, forwarded as ``FDR``.
    screening : bool, default=True
        Forwarded to R inverted, as ``pvl_study = not screening``.
    model : str, default='gaussian'
        Model specification string, forwarded as ``model``.
    verbose : bool, default=False
        If True, also return the second element of the R result
        (presumably the per-feature p-values).
    n_jobs : int, default=1
        Forwarded as ``n_jobs`` to the R routine.

    Returns
    -------
    selected_index : ndarray of int
        Zero-based indices of the selected features. The array is empty
        when nothing is selected.
    pvals : R object
        Only returned when ``verbose`` is True. It is the second element of
        the R result, returned without conversion.

    Raises
    ------
    ValueError
        If ``X`` is not two-dimensional.
    """
    X = np.asarray(X)
    if X.ndim != 2:
        raise ValueError(f"X must be 2-dimensional, got shape {X.shape}")
    if center:
        X = StandardScaler().fit_transform(X)

    # _estimate_distribution returns (mu, Sigma); only Sigma is sent to R.
    _, Sigma = _estimate_distribution(X)
    results = robjects.r.HRT(np.asarray(y).reshape(1, -1), X, Sigma,
                             N=n_samplings, FDR=fdr,
                             pvl_study=not screening, model_select=method,
                             model=model, n_jobs=n_jobs)

    # R returns 1-based indices, so shift to 0-based. The explicit integer
    # dtype keeps the result usable for indexing. The empty case matters
    # most, because np.array([]) would be float64 and indexing X with it
    # raises IndexError.
    if len(results[0]) > 0:
        selected_index = np.asarray(results[0], dtype=int) - 1
    else:
        selected_index = np.array([], dtype=int)

    if verbose:
        return selected_index, results[1]
    return selected_index
