import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

from pathlib import Path

# Current user's home directory as a plain string; the loaders below build their dataset paths from it.
home = f"{Path.home()}"


def _standardize(trn, tst):
    """Fit a ``StandardScaler`` on ``trn`` only and apply it to both splits.

    :return: (transformed trn, transformed tst, fitted scaler)
    """
    scaler = StandardScaler().fit(trn)
    return scaler.transform(trn), scaler.transform(tst), scaler


def get_data(dataset_name,
             test_size=1.0 / 3.0,
             shuffle=True,
             standardize_x=True,
             standardize_y=True,
             random_state=None):
    """
    Load a registered dataset, split it into train/test parts and optionally standardize it.

    :param dataset_name:    key into ``data_loaders`` selecting the loader to call
    :param test_size:       forwarded to ``train_test_split`` (ignored when the loader
                            already returns explicit splits)
    :param shuffle:         forwarded to ``train_test_split``
    :param standardize_x:   if True, scale x with a ``StandardScaler`` fitted on the training part
    :param standardize_y:   if True, scale y with a ``StandardScaler`` fitted on the training part
    :param random_state:    forwarded to ``train_test_split``; pass an int for a reproducible
                            split. The default ``None`` keeps the previous behaviour.
    :return:                x_trn, x_tst, y_trn, y_tst, x_scaler, y_scaler
                            (a scaler is ``None`` when the matching standardize flag is False)
    :raises KeyError:       if ``dataset_name`` is not registered in ``data_loaders``
    """
    try:
        loader = data_loaders[dataset_name]
    except KeyError:
        raise KeyError(
            f"Unknown dataset {dataset_name!r}; available: {sorted(data_loaders)}"
        ) from None

    # The message deliberately has no trailing newline, so flush explicitly;
    # otherwise line-buffered stdout would hold it back until loading is over.
    print(f"Loading dataset {dataset_name}... ", end='', flush=True)
    x, y = loader()

    if isinstance(x, list) and isinstance(y, list):
        # This is for datasets where the data provider has specified explicit {x,y}_{trn,tst} e.g. blog_feedback
        x_trn, x_tst = x
        y_trn, y_tst = y
    else:
        x_trn, x_tst, y_trn, y_tst = train_test_split(
            x, y, test_size=test_size, shuffle=shuffle, random_state=random_state)

    # Scalers are fitted on the training part only, then applied to both parts.
    x_scaler = None
    if standardize_x:
        x_trn, x_tst, x_scaler = _standardize(x_trn, x_tst)

    y_scaler = None
    if standardize_y:
        y_trn, y_tst, y_scaler = _standardize(y_trn, y_tst)

    return x_trn, x_tst, y_trn, y_tst, x_scaler, y_scaler


def load_elevators():
    """
    regression target:  goal (taken from the last column)
    data source:        https://web.archive.org/web/*/http://www.liacc.up.pt/~ltorgo/Regression/*
    preprocessing:      NA
    :note:              read with the pandas default, so the first line of the file is
                        consumed as the header row
    :return:            (x, y) where x is every column but the last and y is the last
                        column kept two-dimensional
    """
    frame = pd.read_csv(f"{home}/datasets/uci/elevators/elevators.data")
    table = frame.values
    features, target = table[:, :-1], table[:, [-1]]
    return features, target


def load_song():
    """
    regression target:  The year of song release (first column)
    data source:        https://archive.ics.uci.edu/ml/datasets/YearPredictionMSD
    preprocessing:      None
    :note:              about half a million rows (515,345 per the UCI page), D=90
    :return:            (x, y) where y is the first column kept two-dimensional and
                        x is every remaining column
    """
    frame = pd.read_csv(f"{home}/datasets/uci/Song/YearPredictionMSD.csv", header=None)
    table = frame.values
    features, target = table[:, 1:], table[:, [0]]
    return features, target


def load_buzz():
    """
    regression target:  "Mean Number of active discussion (NAD). This attribute is a positive integer
                        that describe the popularity of the instance's topic. It is stored is
                        the rightmost column." from Twitter.names
    data source:        https://archive.ics.uci.edu/ml/datasets/Buzz+in+social+media+#
    preprocessing:      log(1+y) transform of target y
    :return:            (x, log(1 + y)) where y is the rightmost column kept two-dimensional
                        and x is every other column
    """
    table = pd.read_csv(f"{home}/datasets/uci/Buzz/Twitter.csv", header=None).values
    features = table[:, :-1]
    raw_target = table[:, [-1]]
    return features, np.log(1 + raw_target)


def load_bike_sharing_hourly():
    """
    regression target:  Number of bike shares per hour (last column)
    data source:        https://archive.ics.uci.edu/ml/datasets/Bike+Sharing+Dataset
    preprocessing:      'dteday' is converted from a date to seconds since 1970-01-01
    :return:            (x, y) where x is every column except the first and the last,
                        and y is the last column kept two-dimensional
    """
    dataset_path = f"{home}/datasets/uci/Bike-Sharing-Dataset/hour.csv"
    data = pd.read_csv(dataset_path)

    # Convert dates to epoch seconds by subtracting the epoch and dividing by one second.
    # Unlike casting to int64 and dividing by 1e9, this does not depend on the
    # datetime column being stored with nanosecond resolution.
    dates = pd.to_datetime(data['dteday'])
    data['dteday'] = (dates - pd.Timestamp("1970-01-01")) / pd.Timedelta(seconds=1)

    values = data.values
    # NOTE(review): the slice below keeps every column between the sample ID and the
    # target. In the UCI hour.csv layout that presumably includes 'casual' and
    # 'registered', which add up to the target — confirm this is intended.
    x = values[:, 1:-1]  # Doesn't make sense to include the sample ID
    y = values[:, [-1]]
    return x, y


def load_airfoil_noise():
    """
    regression target:  sound pressure in decibels (last column)
    data source:        https://archive.ics.uci.edu/ml/datasets/Airfoil+Self-Noise
    preprocessing:      NA
    :note:              fields are separated by one or more tab characters, no header row
    :return:            (x, y) where x is every column but the last and y is the last
                        column kept two-dimensional
    """
    frame = pd.read_csv(
        f"{home}/datasets/uci/airfoil_noise/airfoil_self_noise.dat",
        sep=r'\t+',
        header=None,
        engine='python',
    )
    table = frame.values
    features, target = table[:, 0:-1], table[:, [-1]]
    return features, target


def load_concrete_compressive():
    """
    regression target:  compressive concrete strength (last column)
    data source:        https://archive.ics.uci.edu/ml/datasets/concrete+compressive+strength
    preprocessing:      NA
    :return:            (x, y) where x is every column but the last and y is the last
                        column kept two-dimensional
    """
    sheet = pd.read_excel(f"{home}/datasets/uci/concrete/Concrete_Data.xls")
    table = sheet.values
    features, target = table[:, 0:-1], table[:, [-1]]
    return features, target


def load_protein_structure():
    """
    regression target:  RMSD (first column)
    data source:        https://archive.ics.uci.edu/ml/datasets/Physicochemical+Properties+of+Protein+Tertiary+Structure
    preprocessing:      log(1+y) transform for target y
    :return:            (x, log(1 + y)) where y is the first column kept two-dimensional
                        and x is every remaining column

    """
    table = pd.read_csv(f"{home}/datasets/uci/Protein/CASP.csv").values
    features = table[:, 1:]
    raw_target = table[:, [0]]
    return features, np.log(1 + raw_target)


def load_superconductivity():
    """
    regression target:  Critical Temperature (last column)
    data source:        https://archive.ics.uci.edu/ml/datasets/Superconductivty+Data
    preprocessing:      None
    :return:            (x, y) where x is every column but the last and y is the last
                        column kept two-dimensional

    """
    frame = pd.read_csv(f"{home}/datasets/uci/superconductivity/train.csv")
    table = frame.values
    features, target = table[:, :-1], table[:, [-1]]
    return features, target


def load_ct_slice():
    """
    regression target:  Reference (relative location), taken from the last column
    data source:        https://archive.ics.uci.edu/ml/datasets/Relative+location+of+CT+slices+on+axial+axis
    preprocessing:      - Drop patient ID (first column)
                        - Remove columns that are constant throughout the entire available dataset
    :return:            (x, y) where x holds the remaining non-constant feature columns and
                        y is the last column kept two-dimensional

    """
    table = pd.read_csv(f"{home}/datasets/uci/CTslice/slice_localization_data.csv").values
    features, target = table[:, 1:-1], table[:, [-1]]
    # A column whose max-min spread is exactly zero carries no information; drop it.
    spread = np.ptp(features, axis=0)
    constant_columns = np.flatnonzero(spread == 0)
    features = np.delete(features, constant_columns, axis=1)
    return features, target


def load_parkinsons_total():
    """
    regression target:  total UPDRS (column index 5)
    data source:        https://archive.ics.uci.edu/ml/datasets/Parkinsons+Telemonitoring
    preprocessing:      - Drop the first 5 columns as they are not used in the original problem
                          (x starts at column index 6, right after the target column)
    :return:            (x, y) where y is column 5 kept two-dimensional and x is columns 6 onwards

    """
    frame = pd.read_csv(f"{home}/datasets/uci/parkinsons/parkinsons_updrs.csv")
    table = frame.values
    features, target = table[:, 6:], table[:, [5]]
    return features, target


def load_abalone():
    """
    regression target:  Number of Rings
    data source:        https://web.archive.org/web/*/http://www.liacc.up.pt/~ltorgo/Regression/*
    preprocessing:      dummy (one-hot) encoding via pd.get_dummies with the first level dropped
    :return:            (x, y) as float64 arrays; y is reshaped to a single column
    """
    column_names = ['Sex', 'Length', 'Diam', 'Height', 'Whole', 'Shucked', 'Viscera', 'Shell', 'Rings']
    raw = pd.read_csv(f"{home}/datasets/abalone/abalone.data", names=column_names)
    encoded = pd.get_dummies(raw, drop_first=True)

    features = encoded.drop('Rings', axis=1).values.astype(dtype=np.float64)
    rings = encoded['Rings'].values.astype(dtype=np.float64)
    target = rings.reshape(-1, 1)
    return features, target


def load_creep():
    """
    regression target:  Rupture Stress
    data source:        https://web.archive.org/web/*/http://www.liacc.up.pt/~ltorgo/Regression/*
    preprocessing:      NA
    :return:            (x, y) as float64 arrays; y is reshaped to a single column
    """
    column_names = [
        'Lifetime', 'Rupture_stress', 'Temperature', 'Carbon', 'Silicon', 'Manganese',
        'Phosphorus', 'Sulphur', 'Chromium', 'Molybdenum', 'Tungsten', 'Nickel', 'Copper',
        'Vanadium', 'Niobium', 'Nitrogen', 'Aluminium', 'Boron', 'Cobalt', 'Tantalum', 'Oxygen',
        'Normalising_temperature', 'Normalising_time', 'Cooling_rate', 'Tempering_temperature',
        'Tempering_time', 'Cooling_rate_tempering', 'Annealing_temperature', 'Annealing_time',
        'Cooling_rate_annealing', 'Rhenium',
    ]
    raw = pd.read_table(f"{home}/datasets/creep/creep.data", names=column_names)
    frame = raw.astype('float64')

    features = frame.drop('Rupture_stress', axis=1).values.astype(dtype=np.float64)
    stress = frame['Rupture_stress'].values.astype(dtype=np.float64)
    target = stress.reshape(-1, 1)
    return features, target


def load_ailerons():
    """
    regression target:  goal
    data source:        https://web.archive.org/web/*/http://www.liacc.up.pt/~ltorgo/Regression/*
    preprocessing:      NA
    :return:            (x, y) as float64 arrays; y is reshaped to a single column
    """
    dataset_path = f"{home}/datasets/ailerons/ailerons.data"
    names = [
        'climbRate', 'Sgz', 'p', 'q', 'curPitch', 'curRoll', 'absRoll', 'diffClb', 'diffRollRate',
        'diffDiffClb', 'SeTime1', 'SeTime2', 'SeTime3', 'SeTime4', 'SeTime5', 'SeTime6', 'SeTime7',
        'SeTime8', 'SeTime9', 'SeTime10', 'SeTime11', 'SeTime12', 'SeTime13', 'SeTime14',
        'diffSeTime1', 'diffSeTime2', 'diffSeTime3', 'diffSeTime4', 'diffSeTime5', 'diffSeTime6',
        'diffSeTime7', 'diffSeTime8', 'diffSeTime9', 'diffSeTime10', 'diffSeTime11',
        'diffSeTime12', 'diffSeTime13', 'diffSeTime14', 'alpha', 'Se', 'goal',
    ]
    # The file has no header row: column names are supplied explicitly. The frame is
    # read directly; wrapping a single frame in pd.concat only made an extra copy.
    ailerons = pd.read_csv(dataset_path, names=names).astype('float64')

    # astype on the extracted arrays returns fresh copies, so the results do not
    # alias the local DataFrame.
    x = ailerons.drop('goal', axis=1).values.astype(dtype=np.float64)
    y = ailerons['goal'].values.astype(dtype=np.float64)
    return x, y.reshape(-1, 1)


# Registry mapping a dataset name to the zero-argument loader function that returns (x, y).
data_loaders = dict(
    airfoil_noise=load_airfoil_noise,
    concrete_compressive=load_concrete_compressive,
    parkinsons_total=load_parkinsons_total,
    elevators=load_elevators,
    bike_sharing_hourly=load_bike_sharing_hourly,
    protein_structure=load_protein_structure,
    ct_slice=load_ct_slice,
    superconductivity=load_superconductivity,

    ailerons=load_ailerons,
    creep=load_creep,
    abalone=load_abalone,

    song=load_song,
    buzz=load_buzz,
)
