import shap
from sklearn.model_selection import train_test_split
import pandas as pd
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import trange
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

# Specify data paths here
# Both files are read with ``torch.load`` when the experiment starts, so they
# must be set to real paths before the script is run.
FEATURE_PATH = ''  # saved feature tensor (split into train/test per run)
SHAPLEY_PATH = ''  # saved Shapley-value tensor, split alongside the features

# Dataset wrappers around the feature / Shapley-value tensors
class SHAPTrain(Dataset):
    """Dataset pairing each feature row with its Shapley-value row.

    Args:
        Feat: Tensor of input features with samples along the first
            dimension.
        Shap: Tensor of target Shapley values; ``Shap[i]`` belongs to the
            same sample as ``Feat[i]``.

    Raises:
        ValueError: If ``Feat`` and ``Shap`` differ in their number of rows.
    """

    def __init__(self, Feat, Shap):
        # Fail fast: the dataset length is taken from ``Feat`` alone, so a
        # shorter ``Shap`` would otherwise only surface as an IndexError in
        # the middle of an epoch.
        if Feat.shape[0] != Shap.shape[0]:
            raise ValueError(
                'Feat and Shap must have the same number of samples, got '
                f'{Feat.shape[0]} and {Shap.shape[0]}'
            )
        self.Feat = Feat
        self.Shap = Shap

    def __len__(self):
        """Return the total number of samples in the dataset."""
        return self.Feat.shape[0]

    def __getitem__(self, index):
        """Return ``(features, shapley_values)`` of one sample as float tensors."""
        # Plain first-axis indexing is equivalent to ``[index, :]`` for 2-D
        # data and additionally supports 1-D targets.
        return (
            self.Feat[index].float(),
            self.Shap[index].float()
        )

class SHAPTest(Dataset):
    """Dataset pairing each held-out feature row with its Shapley-value row.

    Args:
        Feat: Tensor of input features with samples along the first
            dimension.
        Shap: Tensor of target Shapley values; ``Shap[i]`` belongs to the
            same sample as ``Feat[i]``.

    Raises:
        ValueError: If ``Feat`` and ``Shap`` differ in their number of rows.
    """

    def __init__(self, Feat, Shap):
        # Fail fast: the dataset length is taken from ``Feat`` alone, so a
        # shorter ``Shap`` would otherwise only surface as an IndexError
        # part-way through evaluation.
        if Feat.shape[0] != Shap.shape[0]:
            raise ValueError(
                'Feat and Shap must have the same number of samples, got '
                f'{Feat.shape[0]} and {Shap.shape[0]}'
            )
        self.Feat = Feat
        self.Shap = Shap

    def __len__(self):
        """Return the total number of samples in the dataset."""
        return self.Feat.shape[0]

    def __getitem__(self, index):
        """Return ``(features, shapley_values)`` of one sample as float tensors."""
        return (
            self.Feat[index].float(),
            self.Shap[index].float()
        )

class MLP_toshap(nn.Module):
    """Fully connected network built from identical hidden blocks.

    The stack is ``n_layers - 1`` hidden blocks, each
    ``Linear -> ReLU -> Dropout(drop_prob)``, followed by a single
    ``Linear`` projection to ``output_size`` with no activation.
    """

    def __init__(self, input_size, hidden_size, output_size, n_layers, drop_prob):
        super().__init__()

        # Width at the boundary of every hidden block: the first block reads
        # the raw input, every later block reads the previous hidden output.
        widths = [input_size] + [hidden_size] * (n_layers - 1)

        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.extend((
                nn.Linear(fan_in, fan_out),
                nn.ReLU(inplace=True),
                nn.Dropout(drop_prob),
            ))

        # Output projection reads whatever width the last block produced
        stack.append(nn.Linear(widths[-1], output_size))

        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the whole layer stack to ``x``."""
        network = self.layers
        return network(x)

def train(model, loader, optimizer, loss_fn, device):
    """Run one optimisation pass over every batch produced by ``loader``.

    The model is not switched to training mode here; the caller is expected
    to do that before calling.

    Args:
        model: Module invoked as ``model(X_batch)``.
        loader: Iterable yielding ``(X_batch, y_batch)`` tensor pairs.
        optimizer: Optimizer holding the parameters of ``model``.
        loss_fn: Callable invoked as ``loss_fn(prediction, target)`` that
            returns a scalar tensor.
        device: Device each batch is moved to before the forward pass.

    Returns:
        The sum of the per-batch loss values (not an average); ``0.0`` when
        ``loader`` yields no batches.
    """
    running_loss = 0.0

    for X_batch, y_batch in loader:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)

        # Zero parameter gradients
        optimizer.zero_grad()

        # Forward pass
        y_hat = model(X_batch)

        # PyTorch losses take (input, target). The order is irrelevant for a
        # symmetric loss such as MSE but matters for any asymmetric one.
        loss = loss_fn(y_hat, y_batch)

        # Backward pass and parameter update
        loss.backward()
        optimizer.step()

        # Store
        running_loss += loss.item()

    return running_loss


# Start experiment
# For every training fraction in ``train_splits`` the data is split, the MLP
# is trained for EPOCHS epochs on the training part, and the MSE / MAE of its
# Shapley-value predictions on the held-out part are recorded and plotted.

# Fall back to the CPU so the script also runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
train_splits_tiny = np.arange(0.001, 1, 0.001)
EPOCHS = 100
BATCH_SIZE = 128
EVAL_BATCH_SIZE = 1024

# Load obtained features and shapley value.
# The tensors stay on the CPU (batches are moved to ``device`` one at a time),
# which also lets files saved from a GPU session load on a CPU-only machine.
Features = torch.load(FEATURE_PATH, map_location='cpu')
Shapvals = torch.load(SHAPLEY_PATH, map_location='cpu')

# Use smaller increments in the beginning
train_splits = train_splits_tiny

# Build model; the output width follows the target tensor being regressed.
# NOTE(review): the model and optimizer are created once and keep their state
# across splits, so every split continues training from the previous one
# instead of starting from scratch - confirm this is intended for the curve.
model_shap = MLP_toshap(input_size=Features.shape[1],
                        hidden_size=128,
                        n_layers=3,
                        output_size=Shapvals.shape[1],
                        drop_prob=0.1,
                        ).to(device)

# Set optimizer and objective
optimizer = torch.optim.Adam(model_shap.parameters(), lr=1e-4)
loss_fn = nn.MSELoss()

MSEs_overall = torch.zeros(len(train_splits))
MAEs_overall = torch.zeros(len(train_splits))

# Start the splits
for split_idx, train_split in enumerate(train_splits):

    print(f'Train split = {train_split} | iter {split_idx}')

    # Split the data to train test
    X_train, X_test, Y_train, Y_test = train_test_split(
        Features, Shapvals, train_size=train_split, random_state=42)

    # Make data loaders
    train_set = SHAPTrain(X_train, Y_train)
    # Only drop the incomplete last batch when at least one full batch exists;
    # otherwise a training set smaller than BATCH_SIZE would produce no
    # batches at all and the model would silently not be trained.
    train_loader = DataLoader(
        train_set,
        batch_size=BATCH_SIZE,
        shuffle=True,
        drop_last=len(train_set) >= BATCH_SIZE,
    )

    test_set = SHAPTest(X_test, Y_test)
    test_loader = DataLoader(test_set, batch_size=EVAL_BATCH_SIZE)

    # Evaluation below leaves the model in eval mode, so training mode has to
    # be re-enabled at the start of every split.
    model_shap.train(True)
    for epoch in trange(EPOCHS):
        train(model_shap, train_loader, optimizer, loss_fn, device)

    # Test model on test data.
    # Errors are summed over all elements and divided by the element count,
    # which equals the mean of the per-sample means but allows large batches.
    sq_err_sum = 0.0
    abs_err_sum = 0.0
    n_elements = 0

    with torch.no_grad():
        model_shap.eval()

        for features, actual_shap in test_loader:

            features, actual_shap = features.to(device), actual_shap.to(device)

            # Predict
            diff = model_shap(features) - actual_shap

            # Store error
            sq_err_sum += diff.pow(2).sum().item()
            abs_err_sum += diff.abs().sum().item()
            n_elements += diff.numel()

    MSEs_overall[split_idx] = sq_err_sum / n_elements
    MAEs_overall[split_idx] = abs_err_sum / n_elements

    # Intermediate saving
    if (split_idx % 100) == 0:
        torch.save(MSEs_overall, 'MelbourneMSEs.pt')
        torch.save(MAEs_overall, 'MelbourneMAEs.pt')

# Final saving, so that the results of the last splits are not lost
torch.save(MSEs_overall, 'MelbourneMSEs.pt')
torch.save(MAEs_overall, 'MelbourneMAEs.pt')

# Plotting
plt.plot(train_splits, MSEs_overall.sqrt().numpy(), '.-', linewidth=1, c='r')
plt.grid(True, alpha=0.3)
plt.xlabel('Training fraction', fontsize=14)
plt.ylabel(r'$RMSE(\phi, \hat{\phi})$', fontsize=14)
plt.yscale('log')
plt.xticks(np.arange(0, 1.1, 0.1), fontsize=14)
plt.yticks(fontsize=14)
sns.despine()
plt.show()
