import jax
import numpy as np
import jax.numpy as jnp
from models.cvx_relu_mlp import CVX_ReLU_MLP
from optimizers.admm import admm
from os.path import dirname, join, abspath
import os
import pickle
import sys


# Experiment configuration: edit DATASET, MODEL and OUTPUT_DIR by hand before each run.
# NOTE(review): 'imbd' looks like a typo for 'imdb'. It is only used to build the
# output filename, so changing it would rename the saved results. Confirm before fixing.
DATASET: str = 'imbd'
# Other variants: 'gpt2-ft', 'gpt2-da'.
MODEL: str = 'gpt2'
# Results go to OUTPUT_DIR/<MODEL>/. A relative directory works as well.
OUTPUT_DIR: str = '/home/miria/Downloads/ZACH/results/'





# Locate the directory holding the cached hidden-state .npy files.
dataset_rel_path = join('datasets', 'DATA_LM/convex1')
# abspath('content') is <cwd>/content, so its dirname is simply the current working
# directory. The script must therefore be launched from the project root.
project_root = dirname(abspath('content'))
path = join(project_root, dataset_rel_path)
print('path to datasets is = ', path)

# Load the cached hidden-state shards. Each shard index i has a positive file
# (label +1) and a negative file (label -1), loaded in that order.
dats_training = []
labels_training = []
print('Loading data')
for i in range(1, 10):  # shard indices 1..9
    for prefix, label in (('Pos', 1), ('Neg', -1)):
        X = np.load(join(path, f'{prefix}last_hidden_states_{i}.npy'))
        dats_training.append(X)
        labels_training.append(label * np.ones(X.shape[0]))
# Concatenate on the host and convert to JAX once. This avoids copying every shard
# into a JAX array and then concatenating those copies into yet another array.
A = jnp.asarray(np.concatenate(dats_training, axis=0))
y = jnp.asarray(np.concatenate(labels_training, axis=0))
print('Finished loading data!')

# Reshape, shuffle, and split data.
# Flatten every trailing dimension of each sample into a single feature vector.
# For the 3-D case this matches reshape(N, d1*d2); it also accepts already-2-D input.
shape_A = A.shape
A = A.reshape(shape_A[0], -1)

# Use the real sample count instead of a hard-coded 25000.
# - Fewer rows: JAX clamps out-of-range gather indices, which would silently
#   duplicate the last row.
# - More rows: the extra rows would be silently dropped.
n = A.shape[0]
ntr = 20000
if n <= ntr:
    raise ValueError(f'need more than {ntr} samples for the train/test split, got {n}')

# NOTE(review): the shuffle is unseeded, so the split differs between runs.
J = np.random.permutation(n)
A = A[J]
y = y[J]

ntst = n - ntr
Atr = A[:ntr]
Atst = A[ntr:]  # previously A[ntr+1:], which silently discarded sample ntr
ytr = y[:ntr]
ytst = y[ntr:]

# Free the full shuffled matrix; only the train/test slices are needed from here on.
del A

# Train model.
# The positional hyperparameters (10, 0.001, 0.01) are passed straight to CVX_ReLU_MLP;
# their meaning is defined in models.cvx_relu_mlp (TODO confirm). The JAX PRNG key is fixed.
model = CVX_ReLU_MLP(Atr, ytr, 10, 0.001, 0.01, jax.random.key(0))
model.init_model()
# Attach the held-out split to the model; presumably admm reads it to produce
# metrics['val_acc'] (TODO confirm in optimizers.admm).
model.Xtst, model.ytst = Atst, ytst

admm_params = {
    'rank': 10,
    'beta': 0.001,
    'gamma_ratio': 1,
    'admm_iters': 5,
    'pcg_iters': 30,
    'check_opt': False,
}

print('Training model')

# admm returns a pair; only the metrics dict is kept.
_, metrics = admm(model, admm_params)

for metric_name in ('train_acc', 'val_acc'):
    print(metrics[metric_name])


# Persist the metrics dict returned by admm as OUTPUT_DIR/<MODEL>/<DATASET>_<MODEL>_seed.pkl.
# NOTE(review): "_seed" is a literal suffix with no seed value interpolated, so every run
# with the same DATASET/MODEL overwrites the same file. Confirm the intended naming.
model_dir = os.path.join(OUTPUT_DIR, MODEL)
filename = f"{DATASET}_{MODEL}_seed.pkl"
pickle_file_path = os.path.join(model_dir, filename)

# exist_ok=True makes re-runs into an existing folder harmless.
os.makedirs(model_dir, exist_ok=True)

with open(pickle_file_path, 'wb') as handle:
    pickle.dump(metrics, handle, protocol=pickle.HIGHEST_PROTOCOL)