# =================================================================================================#
# Description: Plots the sparsity levels of the contextual lasso on the energy consumption data
# Author: Ryan Thompson
# =================================================================================================#

include("Estimators/contextual_lasso.jl")

import Cairo, CSV, DataFrames, Dates, Fontconfig, Gadfly, Random, StatsBase

# Seed a dedicated Mersenne Twister and make it what `Random.default_rng()` returns, so code
# that falls back on the default RNG (presumably `StatsBase.sample` below and the model's
# initialization — confirm) draws from this seeded stream. The author notes that CUDA.jl is
# not reproducible with the default RNG, presumably why plain `Random.seed!(2023)` (kept in
# the trailing comment) was not used.
# NOTE(review): this redefines the zero-argument method of `Random.default_rng`, which is
# owned by the Random stdlib (type piracy); it affects every caller of the default RNG in
# the session, and it returns a non-const global, so calls to it are not type-stable.
rng = Random.MersenneTwister(2023); Random.default_rng() = rng # Random.seed!(2023)

# Read the raw energy data, then convert the textual `date` column (format
# "yyyy-mm-dd HH:MM:SS") into `DateTime` values in place
data = CSV.read("Data/energydata_complete.csv", DataFrames.DataFrame)
data[!, :date] = Dates.DateTime.(data.date, Ref(Dates.dateformat"yyyy-mm-dd HH:MM:SS"))

# Number of observations (rows) in the data
n = DataFrames.nrow(data)

# Milliseconds elapsed between timestamp `t` and the start of its enclosing `period`
# (e.g. `Dates.Year`, `Dates.Week`, `Dates.Day`), as a plain integer
_elapsed_ms(t, period) = Dates.value(t - floor(t, period))

# Fractional position of a timestamp within its year, week, and day. Each elapsed time is
# divided by the period length in milliseconds. The year uses a fixed length of 3.154e10 ms
# (about 365.05 days) rather than the calendar length, so `yearly` can slightly exceed 1 at
# the very end of a leap year.
yearly(t) = _elapsed_ms(t, Dates.Year) / 3.154e+10
weekly(t) = _elapsed_ms(t, Dates.Week) / 6.048e+8   # 7 * 24 * 3600 * 1000 ms
daily(t) = _elapsed_ms(t, Dates.Day) / 8.64e+7      # 24 * 3600 * 1000 ms

# Extract explanatory features, contextual features, and response.
# Explanatory features: every column except the timestamp, the response, and rv1/rv2
# (NOTE(review): presumably the dataset's random-noise columns — confirm).
x = DataFrames.select(data, DataFrames.Not([:date, :Appliances, :rv1, :rv2]))
# Contextual features, all derived from the timestamp: a weekend indicator plus cosine/sine
# encodings of the position within the year (approximate, see `yearly`), the week, and the
# day, so that each cycle wraps around smoothly. The lambda argument is named `t` to avoid
# shadowing the global `x` defined above.
z = DataFrames.select(
    data,
    # Compare the weekday directly (Monday = 1, ..., Sunday = 7) so Saturday 00:00:00 counts
    # as weekend; a fractional-week threshold such as `weekly(t) > 5 / 7` would exclude that
    # instant, because `weekly` evaluates to exactly 5 / 7 there.
    :date => DataFrames.ByRow(t -> float(Dates.dayofweek(t) >= Dates.Saturday)) => :weekend,
    :date => DataFrames.ByRow(t -> cospi(2 * yearly(t))) => :monthcos,
    :date => DataFrames.ByRow(t -> sinpi(2 * yearly(t))) => :monthsin,
    :date => DataFrames.ByRow(t -> cospi(2 * weekly(t))) => :daycos,
    :date => DataFrames.ByRow(t -> sinpi(2 * weekly(t))) => :daysin,
    :date => DataFrames.ByRow(t -> cospi(2 * daily(t))) => :hourcos,
    :date => DataFrames.ByRow(t -> sinpi(2 * daily(t))) => :hoursin
    )
# Response: natural log of the Appliances column
y = DataFrames.select(data, :Appliances => DataFrames.ByRow(log) => :Appliances)

# Randomly split the row indices 1:n, sampling without replacement: 60% for training, 20%
# for validation, and whatever is left for testing. Draws come from the default RNG
# (presumably the seeded one installed at the top of the script).
remaining = 1:n
train_id = StatsBase.sample(remaining, round(Int64, n * 0.6); replace = false)
remaining = setdiff(remaining, train_id)
valid_id = StatsBase.sample(remaining, round(Int64, n * 0.2); replace = false)
test_id = id = setdiff(remaining, valid_id)

# Materialize each split as plain arrays: explanatory and contextual feature matrices for
# training and validation, the response as a vector, and only the contextual features for
# the test set
x_train, x_valid = Matrix(x[train_id, :]), Matrix(x[valid_id, :])
z_train, z_valid, z_test = (Matrix(z[ids, :]) for ids in (train_id, valid_id, test_id))
y_train, y_valid = vec(Matrix(y[train_id, :])), vec(Matrix(y[valid_id, :]))

# Set network configuration: `n_hidden` hidden layers of equal width `n_neuron`. The width is
# chosen so the network's total weight-and-bias count is roughly `param_ratio * m * p`.
# For a network with m inputs (contextual features) and p outputs (NOTE(review): presumably
# one per explanatory feature — confirm against ContextualLasso), width k gives
#     (n_hidden - 1) * k^2 + (m + p + n_hidden) * k + p
# parameters. k is therefore the positive root of that count minus `param_ratio * m * p`;
# with a single hidden layer the equation is linear.
# The layer count is shared between the formula and `hidden_layers`, so the two stay in sync.
m, p = size(z_train, 2), size(x_train, 2)
n_hidden, param_ratio = 3, 32
n_neuron = let
    a = n_hidden - 1                 # coefficient of k^2
    b = m + p + n_hidden             # coefficient of k
    c = p - param_ratio * m * p      # constant term
    width = iszero(a) ? -c / b : (sqrt(b ^ 2 - 4 * a * c) - b) / (2 * a)
    round(Int, width)
end
hidden_layers = repeat([n_neuron], n_hidden)

# Fit the contextual lasso on the training split, passing the validation split as well
# (NOTE(review): presumably used for tuning or early stopping — confirm in `classo`), with
# the relaxed variant enabled and progress output silenced
fit = ContextualLasso.classo(
    x_train, z_train, y_train,
    x_valid, z_valid, y_valid;
    hidden_layers = hidden_layers,
    relax = true,
    verbose = false
    )

# Sparsity level of each test observation (one row of the coefficient matrix each): the
# fraction of its coefficients that are nonzero. Column 1 is dropped before counting
# (NOTE(review): presumably the intercept — confirm against `ContextualLasso.coef`).
coef = ContextualLasso.coef(fit, z_test)[:, 2:end]
sparsity = vec(sum(coef .!= 0; dims = 2)) ./ size(coef, 2)

# Box plots of the per-observation sparsity level on the test set, grouped by hour of day
# (outliers suppressed, y-axis fixed to [0, 0.45]), saved as a 5in x 2in PDF
test_hour = Dates.hour.(data.date[test_id])
sparsity_plot = Gadfly.plot(
    Gadfly.Geom.boxplot(suppress_outliers = true),
    Gadfly.Guide.xlabel("Hour of day"),
    Gadfly.Guide.ylabel("Prop. nonzeros"),
    Gadfly.Coord.cartesian(xmin = 0, xmax = 23, ymin = 0, ymax = 0.45),
    Gadfly.Theme(default_color = "black", plot_padding = [0Gadfly.mm]);
    x = test_hour,
    y = sparsity
    )
Gadfly.draw(Gadfly.PDF("Figures/energy.pdf", 5Gadfly.inch, 2Gadfly.inch), sparsity_plot)