import os
import shutil
import sys
import time
sys.path.append('./')
sys.path.append('../')
import warnings
warnings.filterwarnings("ignore")
import wandb
from wandb.keras import WandbCallback
os.environ['TF_CPP_MIN_LOG_LEVEL']='3'
import tensorflow as tf
tf.get_logger().setLevel('ERROR')
import tensorflow_datasets as tfds
from deel.utils.deep_models import get_mlp
from deel.lip.activations import MaxMin, GroupSort, GroupSort2, FullSort


import random
import numpy as np
from tensorflow.python.keras.layers import ReLU
import matplotlib.pyplot as plt
from deel.utils.yaml_to_params import load_yaml_config,getParams, getFunctionFromModules, dumdict2yaml
from deel.datasets.load_dataset import load_dataset
from deel.utils.yaml_loader import load_model, loadFunctionList
from deel.utils.yaml_loader import load_optimizer_and_loss
from tensorflow.keras.callbacks import ReduceLROnPlateau, LearningRateScheduler
import foolbox as fb
from foolbox.attacks import *
from deel.utils.adversarial_utils import compute_adversarial_robustness,wandb_log_robustness
from tensorflow.keras.optimizers import Adam,SGD
from tensorflow.keras.callbacks import ReduceLROnPlateau, LearningRateScheduler, Callback
from tensorflow.keras.losses import categorical_crossentropy, binary_crossentropy
from tensorflow.keras import backend as K
from deel.utils.custom_scheduler import WarmUpCosineDecayScheduler
from deel.lip.layers import SpectralConv2D, SpectralDense, FrobeniusDense
from deel.lip.losses import (HKR,
                             KR,
                             HingeMargin,
                             MulticlassKR,
                             MulticlassHinge,
                             MulticlassHKR)
from deel.lip.custom_losses import wasserstein_acc,HKR_binary_auto, HKR_multiclass_auto
import yaml
import matplotlib.pyplot as plt
from tensorflow.keras.metrics import top_k_categorical_accuracy,categorical_accuracy
from tensorflow_riemopt.variable import assign_to_manifold
import tensorflow_riemopt as riemopt
from tensorflow_riemopt.manifolds import StiefelCayley,Euclidean
from tensorflow_riemopt.optimizers.riemannian_adam import RiemannianAdam
from deel.datasets.imagenet_dataset import imagenet_dataset
from deel.utils.lip_utils import redressage_grad,get_random_name,fine_tune_model_last,set_global_variable
#from tensorflow.keras.callbacks import CallbackList

from tensorflow.keras import mixed_precision
from deel.lip.normalizers import set_stop_grad_spectral,set_grad_passthrough_bjorck
import deel.lip.normalizers
def apply_constraints(model,verbose = False):
    """Re-apply the kernel constraint of every constrained layer in place.

    For each layer of ``model`` whose exact type is ``SpectralConv2D`` or
    ``FrobeniusDense``, the kernel is overwritten with the result of the
    layer's ``kernel_constraint`` applied to that kernel.

    Args:
        model: object exposing a ``layers`` iterable.
        verbose: when true, print each updated layer with its constraint.
    """
    print("constraints")
    for layer in model.layers:
        # Exact type comparison (not isinstance): subclasses are skipped.
        if type(layer) not in (SpectralConv2D, FrobeniusDense):
            continue
        constraint = layer.kernel_constraint
        layer.kernel.assign(constraint(layer.kernel))
        if verbose:
            print(layer, constraint)


def compute_singular(var,one = False):
    """Print the shape and first singular value of each variable in ``var``.

    Each variable's value is flattened to a 2-D matrix that keeps the last
    axis as columns; the first entry of its singular values (the largest one,
    as NumPy returns them in descending order) is printed next to the
    original shape.

    Args:
        var: iterable of backend variables readable with ``K.get_value``.
        one: when true, stop after the first variable.
    """
    for variable in var:
        w_val = K.get_value(variable)
        matrix = w_val.reshape((-1, w_val.shape[-1]))
        # Only the singular values are reported, so skip U and V: the full
        # decomposition would allocate a (rows x rows) factor for tall
        # flattened kernels.
        singular_values = np.linalg.svd(matrix, compute_uv=False)
        print(w_val.shape, singular_values[0])
        if one:
            break

def save_metrics(model,test_dataset):
    """Evaluate ``model`` on ``test_dataset`` and send the scores to wandb.

    The evaluation runs for 195 steps. The second-to-last returned value is
    logged as ``test_accuracy`` and the last one as ``test_top_k``, each with
    its own ``wandb.log`` call.
    """
    scores = model.evaluate(test_dataset, steps=195)
    print(scores)
    to_log = (
        ('test_accuracy', scores[-2]),
        ('test_top_k', scores[-1]),
    )
    for key, value in to_log:
        wandb.log({key: value})


def init_config(config,xp_id):
    """Prepare the run's output folder and snapshot its configuration.

    Forwards ``config`` to ``set_global_variable``, creates the run folder
    ``{result_path}{run_name}/`` together with its ``curves`` and ``models``
    sub-folders, then dumps ``config`` as YAML twice: to ``config.yml`` and
    to ``{xp_id}_config.yml`` inside the run folder.

    Args:
        config: mapping providing at least ``run_name`` and ``result_path``.
            The two values are concatenated as-is, so ``result_path`` is
            expected to end with a path separator.
        xp_id: string used as prefix of the per-experiment config copy.
    """
    name = config['run_name']
    rep = config['result_path']
    set_global_variable(config,verbose = True)
    folder = rep + name + "/"
    # exist_ok replaces the former exists()-then-makedirs() sequence, which
    # could fail when several runs sharing a run_name start at the same time.
    for sub_folder in ("", "curves", "models"):
        os.makedirs(folder + sub_folder, exist_ok=True)
    # One shared copy plus one copy tagged with this experiment's id.
    for target in ('config.yml', xp_id + '_config.yml'):
        with open(folder + target, 'w') as stream:
            yaml.dump(config, stream)

def init_wand_db(full_config,xp_id):
    """Set the wandb environment variables and start a run for this experiment.

    The run is named ``xp_id``, grouped under ``full_config['run_name']`` and
    attached to ``full_config['project_name']`` (``"imagenet"`` when the key
    is absent). Loss type, loss parameters and network parameters are
    recorded in the run config.

    Args:
        full_config: experiment configuration; must provide ``run_name``,
            ``loss`` (with ``type`` and ``params``) and ``network`` (with
            ``params``).
        xp_id: name given to the wandb run.
    """
    # SECURITY(review): an API key is committed in this source file. It is
    # kept only as a fallback so existing launches keep working; a key that is
    # already set in the environment now takes precedence. The committed key
    # should be revoked and supplied through the environment instead.
    os.environ.setdefault("WANDB_API_KEY", "b43035018da7e61f798d9bf228eeb6c141debe26")
    os.environ["WANDB_MODE"] = "online"
    project_name = full_config.get('project_name', "imagenet")
    print(project_name)
    wandb.init(project=project_name, sync_tensorboard=False, name= xp_id, group=full_config['run_name'],
              config={
        "loss" : full_config['loss']['type'],
        "loss_param" : str(full_config['loss']['params']),
        "network" : str(full_config['network']['params']),
        "dataset": "imagenet"
    })

def save_model(model, config,sufix = ""):
    """Save the model weights under the run's ``models`` folder.

    The target file is ``{result_path}{run_name}/models/{run_name}{sufix}.h5``.
    Only the weights are written (``model.save_weights``); the architecture
    is not exported.

    Args:
        model: object exposing ``save_weights(path)``.
        config: mapping with ``run_name`` and ``result_path`` entries.
        sufix: text inserted between the run name and the ``.h5`` extension.
    """
    run_name, result_root = config['run_name'], config['result_path']
    models_dir = result_root + run_name + "/models/"
    weights_file = run_name + sufix + ".h5"
    model.save_weights(os.path.join(models_dir, weights_file))
def gradient_norm(grads,verbose = False):
    """Return the sum of ``tf.norm`` over the gradients divided by their count.

    ``None`` entries contribute nothing to the sum but are still counted in
    the denominator. ``verbose`` is accepted but not used.

    Args:
        grads: iterable of gradient tensors, possibly containing ``None``.
        verbose: unused.
    """
    entries = list(grads)
    norms = [tf.norm(grad) for grad in entries if grad is not None]
    return sum(norms) / len(entries)
def diff_grad(grad_1,grad_2):
    """Return the summed absolute element-wise difference between two gradient lists.

    The lists are paired with ``zip``, so extra entries of the longer one are
    ignored; with nothing to pair the result is the integer ``0``.
    """
    return sum(
        tf.reduce_sum(tf.math.abs(second - first))
        for first, second in zip(grad_1, grad_2)
    )


def add_model_regularizer_loss(model, lambda_orth = 0):
    """Accumulate the regularization terms of ``model`` into a single value.

    For every layer: the kernel regularizer, scaled by ``lambda_orth``, is
    added when the layer has one and ``lambda_orth`` is non-zero; the bias
    regularizer is added when the layer has one. The sum of ``model.losses``
    is added at the end.

    NOTE(review): Keras may already report layer regularizers through
    ``model.losses``, in which case those terms are counted twice here --
    confirm before re-enabling this function.

    Args:
        model: object exposing ``layers`` and ``losses``.
        lambda_orth: weight of the kernel regularization terms.
    """
    penalty = 0
    kernel_term_enabled = lambda_orth != 0
    for layer in model.layers:
        kernel_reg = getattr(layer, 'kernel_regularizer', None)
        if kernel_reg and kernel_term_enabled:
            penalty += lambda_orth * kernel_reg(layer.kernel)
        bias_reg = getattr(layer, 'bias_regularizer', None)
        if bias_reg:
            penalty += bias_reg(layer.bias)
    return penalty + tf.reduce_sum(model.losses)

def normalize_grad(g):
    """Divide ``g`` by ``tf.norm(g)`` when it has four dimensions, else return it unchanged."""
    n_dims = tf.size(tf.shape(g))
    if n_dims != 4:
        return g
    return g / tf.norm(g)


@tf.function
def train_step(x, y,
               model,
               loss_fn,
               t,
               optimizer,
               optim_marg,
               lambda_orth = 0,
               optim_margin = False,
               grad_coeff = 0.1,
               mixed = False,
               spectral = True,
               finetune_biases = False,
               redress = False,
               verbose = False):
    """Run one optimization step and return per-batch statistics.

    The step does a forward pass, computes the loss, takes gradients with
    respect to the selected weights, optionally passes them through
    ``redressage_grad``, applies them with ``optimizer`` and finally derives
    accuracy and margin statistics from the logits.

    Args:
        x: batch of inputs fed to ``model``.
        y: batch of targets. NOTE(review): the statistics test ``y == 1`` and
            multiply ``logits * y``, so one-hot labels are presumably
            expected -- confirm.
        model: called as ``model(x, training=...)``; its
            ``trainable_weights`` are the variables being updated.
        loss_fn: called as ``loss_fn(y, logits)``.
        t: extra variable appended to the updated weights when
            ``optim_margin`` is true.
        optimizer: applies the gradients. With ``mixed`` it must also provide
            ``get_scaled_loss`` and ``get_unscaled_gradients``.
        optim_marg: not used in this function.
        lambda_orth: not used in this function (the regularization term is
            fixed to 0).
        optim_margin: when true, ``t`` is trained together with the weights.
        grad_coeff: forwarded to ``redressage_grad`` as ``coeff``.
        mixed: when true, scale the loss and unscale the gradients through
            ``optimizer``.
        spectral: forwarded to ``redressage_grad``.
        finetune_biases: when true, only the last trainable weight is updated
            and ``optim_margin`` / ``redress`` are disabled.
        redress: when true, gradients go through ``redressage_grad`` before
            being applied.
        verbose: not used in this function.

    Returns:
        dict with the keys ``loss``, ``categorical_accuracy``, ``regul``,
        ``grad_norm``, ``top_k_categorical_accuracy``, ``robustness``,
        ``avg_value``, ``abs_margin``, ``margin`` and ``margin_std``.
    """
    training = True
    weights = model.trainable_weights
    if finetune_biases:
        # Fine-tuning mode: restrict the update to the last trainable weight
        # (presumably the final bias -- confirm) and switch off the other options.
        optim_margin = False
        redress = False
        training = True
        weights = weights[-1:]
    if optim_margin:
        # Train ``t`` alongside the model weights; it becomes the last entry.
        weights = weights+[t]
    with tf.GradientTape() as w_tape :
        logits = model(x, training=training)
        loss_value = loss_fn(y, logits)
        #regul = add_model_regularizer_loss(model,lambda_orth=lambda_orth)
        regul = 0
        final_loss = loss_value+ regul
        if mixed :
            # Scale the loss before differentiation; undone on the gradients below.
            final_loss = optimizer.get_scaled_loss(final_loss)

    #tf.print(weights)
    grads = w_tape.gradient(final_loss, weights)
    if mixed :
        grads = optimizer.get_unscaled_gradients(grads)
    if redress:
        if optim_margin:
            #tf.print("optim")
            # The last gradient belongs to ``t`` and is left untouched.
            grads[:-1] = redressage_grad(grads[:-1], weights[:-1], coeff = grad_coeff,spectral=spectral)
        else :
            grads = redressage_grad(grads, weights, coeff = grad_coeff,spectral=spectral)
    optimizer.apply_gradients(zip(grads, weights))


    #tf.print("final",diff_grad(saved_weights,new_weights),diff_grad(saved_weights,last_weights))
    # Metrics are computed on float32 logits regardless of the compute dtype.
    logits = tf.cast(logits,tf.float32)
    top_k=top_k_categorical_accuracy( y,logits, k=5)
    acc=categorical_accuracy(y,logits)
    #tf.print(loss_value,tf.reduce_mean(top_k),tf.reduce_mean(acc))
    #train_acc_metric.update_state(y, logits)

    results = {"loss" :loss_value,"categorical_accuracy" :acc, "regul" :regul,"grad_norm" : gradient_norm(grads),"top_k_categorical_accuracy": top_k }
    y_pred = logits
    y_true = y

    H1 = tf.where(y_true==1,tf.reduce_min(y_pred), y_pred) ## set y_true at minimum on batch to avoid being the max
    # Per-sample logit selected by y_true, and the largest logit among the other entries.
    vYtrue = tf.reduce_sum(y_pred * y_true, axis=1)
    maxOthers = tf.reduce_max(H1, axis=1)
    results["robustness"] = tf.reduce_mean(vYtrue)
    results["avg_value"] = tf.reduce_mean(tf.abs(y_pred))
    results["abs_margin"] = tf.reduce_mean(tf.abs(vYtrue-maxOthers))
    results["margin"] = tf.reduce_mean(vYtrue-maxOthers)
    results["margin_std"] = tf.math.reduce_std(vYtrue-maxOthers)
    return results



def _reload_imagenet(config, batch_size):
    """Build fresh ImageNet train/test pipelines from the settings in ``config``.

    Optional augmentation keys fall back to the same defaults the launch
    script writes into its configuration, so a config that was not pre-filled
    works too. ``scale_img`` is required.
    """
    return imagenet_dataset(batch_size = batch_size,
                            preprocess = config['scale_img'],
                            contrast_min = config.get('contrast_min', 1.),
                            contrast_max = config.get('contrast_max', 1.),
                            bright = config.get('bright', 0.),
                            cutout = config.get('cutout', False),
                            compute_train_val=False,
                            write_dir=config.get('write_dir', '/mnt/terminus/imagenet/extracted/'),
                            shuffle = 1024,
                            verbose = True)


def fit_constraints(model,train,loss_fct,optimizer,config,
                    steps_per_epoch=50,
                    finetune_biases = False,
                    callbacks=None,
                    batch_size = 128,
                    epochs=20,
                    verbose=2,
                    mixed = False,
                    spectral = True,
                    test_dataset = None,
                    redress = False,
                    lambda_orth = 0,
                    grad_coeff = 0.1,
                    optim_margin=False):
    """Custom training loop driving ``train_step`` and the given callbacks.

    Every epoch runs ``steps_per_epoch`` calls to ``train_step`` on batches
    pulled from ``train``, averages the returned statistics and hands them to
    the callbacks. Every ``nb_epoch_change`` epochs (never at epoch 0, and
    only when ``test_dataset`` is given) the model is evaluated and both
    datasets are rebuilt from ``config``. Weights are saved every 3200 epochs.

    Args:
        model: model to train; must already be compiled for ``evaluate``.
        train: iterable dataset yielding ``(x, y)`` batches.
        loss_fct: loss callable exposing a ``margins`` variable.
        optimizer: optimizer forwarded to ``train_step``.
        config: experiment configuration, used to save the model and to
            rebuild the datasets.
        steps_per_epoch: number of training steps per epoch.
        finetune_biases, mixed, spectral, redress, lambda_orth, grad_coeff,
            optim_margin: forwarded unchanged to ``train_step``.
        callbacks: list of Keras-style callbacks, or ``None`` for none.
        batch_size: batch size, used for the epoch bookkeeping, the
            evaluation step count and the dataset reload.
        epochs: number of epochs to run.
        verbose: unused.
        test_dataset: dataset used for the periodic evaluation, or ``None``
            to skip evaluation and reload.
    """
    # A fresh list per call instead of a shared mutable default.
    if callbacks is None:
        callbacks = []
    for c in callbacks:
        c.set_model(model)
    # Number of epochs needed to draw 1281167 samples (presumably the
    # ImageNet-1k training set size), rounded up.
    nb_epoch_change = int(1281167/(batch_size*steps_per_epoch))+1
    print("nb steps/epochs :",nb_epoch_change)
    dataset_it = train.__iter__()
    optim_marg = Adam(learning_rate=1e-4)
    model_vars = model.trainable_variables
    compute_singular(model_vars,one = True)
    logs = {}
    for c in callbacks:
        c.on_train_begin(logs=logs)
    metrics = {}
    # Pre-registered so both keys are always present in the epoch logs.
    metrics["val_loss"] = tf.metrics.Mean()
    metrics["val_acc"] = tf.metrics.Mean()
    val_acc = 0.
    val_k_acc = 0.
    for e in range(epochs):
        start_time = time.time()
        for c in callbacks:
            c.on_epoch_begin(e, logs=None)

        for batch in range(steps_per_epoch):
            x,y = next(dataset_it)
            for c in callbacks:
                c.on_batch_begin(batch, logs=None)

            results= train_step(x, y,
                                model,
                                loss_fct,
                                loss_fct.margins,
                                optimizer,
                                optim_marg,
                                mixed = mixed,
                                grad_coeff = grad_coeff,
                                spectral = spectral,
                                redress = redress,
                                finetune_biases=finetune_biases,
                                lambda_orth = lambda_orth,
                                optim_margin=optim_margin,verbose = (batch ==steps_per_epoch-1 ))

            # Per-batch logs are plain NumPy means of each returned statistic.
            logs = {k: v.numpy().mean() for k, v in results.items()}
            for c in callbacks:
                c.on_train_batch_end(batch, logs=logs)
            for k in results.keys():
                if k not in metrics:
                    metrics[k] = tf.metrics.Mean()
                metrics[k].update_state(results[k])

        total_time = time.time() - start_time
        logs = {k: metrics[k].result() for k in metrics.keys()}
        logs['time'] = total_time
        if e%nb_epoch_change ==0 and test_dataset is not None and e!=0:
            results = model.evaluate(test_dataset, steps = 50000//batch_size)
            val_acc = results[-2]
            val_k_acc = results[-1]
            print("OK Tank, reload the matrix !")
            # Rebuild the pipelines from the config passed in, not from a
            # module-level global.
            train, test_dataset, info = _reload_imagenet(config, batch_size)
            dataset_it = train.__iter__()

        if e!= 0 and e%3200 ==0:
            save_model(model, config, sufix = "_"+str(e))
        # Validation scores keep their last evaluated value between evaluations.
        logs['val_acc'] = val_acc
        logs['val_k_acc'] = val_k_acc
        for c in callbacks:
            c.on_epoch_end(e, logs=logs)
        print(f"Epoch {e+1}/{epochs}")
        print(f"time : {total_time:.2f}s *** loss:",metrics["loss"].result(), "accuracy:",metrics["categorical_accuracy"].result(), " top_5:",metrics["top_k_categorical_accuracy"].result() )
        print("max m",tf.reduce_max(loss_fct.margins).numpy(),"min m",tf.reduce_min(loss_fct.margins).numpy(), "men m",tf.reduce_mean(loss_fct.margins).numpy(),optim_margin)
        for k in metrics.keys():
            metrics[k].reset_states()
        sys.stdout.flush()

    for c in callbacks:
        c.on_train_end(logs=logs)

def data_range(data):
    """Print the shape and the min / mean / max of the first batch of ``data``.

    ``data`` must yield ``(x, y)`` pairs whose items expose ``.numpy()``.
    The ``ok1`` / ``ok2`` markers are printed around the fetch of the batch.
    """
    batches = data.__iter__()
    print("ok1")
    first_x, first_y = next(batches)
    print("ok2")
    values = first_x.numpy()
    first_y.numpy()  # converted as well, although only the inputs are reported
    print(values.shape)
    lowest, average, highest = values.min(), values.mean(), values.max()
    print("data range :", lowest, average, highest)

# ---------------------------------------------------------------------------
# Launch script: load the YAML configuration, build the model, the optimizer,
# the loss and the datasets, run the custom training loop, then save the
# weights, the test metrics and the learned margins.
# ---------------------------------------------------------------------------
tf.config.set_soft_device_placement(
    True
)



xp_id = get_random_name()
# The configuration file can be overridden by the first command-line argument.
filename="./configs/imagenet_base.yml"
if len(sys.argv)>=2:
        filename =sys.argv[1]

tf.random.set_seed(random.randint(0,10000000))

full_config = load_yaml_config(filename)
mixed = full_config.get('mixed', False)
print("mixed",mixed)
if mixed:
    policy = mixed_precision.Policy('mixed_float16')
    mixed_precision.set_global_policy(policy)



init_config(full_config,xp_id)
init_wand_db(full_config,xp_id)

name = full_config['run_name']
print(filename,full_config['run_name'])
rep = full_config['result_path']
# From here on, everything printed goes to the per-experiment log file.
file = rep + name + "/"+xp_id+"_logfile.log"
OUTSTREAM = open(file, "w")
sys.stdout = OUTSTREAM
sys.stdout.flush()


redress = full_config.get('redress', False)
finetune_biases = full_config.get('finetune_biases', False)
spectral = full_config.get('spectral', True)
grad_coeff = full_config.get('grad_coeff', 0.1)
print("redress",redress)
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

batch_size = getParams(full_config,'batch_size')
epochs = getParams(full_config,'epochs')
write_dir = '/mnt/terminus/imagenet/extracted/'
if "write_dir" in full_config.keys():
    write_dir = full_config['write_dir']


model = load_model(getParams(full_config,'network'))
if "weights" in full_config.keys():
    # Warm start from the weights saved by a previous experiment.
    expe_name = full_config["weights"]["expe_name"]
    rep_mod = "results/"+expe_name+"/models/"
    model.load_weights(rep_mod+expe_name+'.h5')
finetune = full_config.get('finetune', False)
if finetune:
    fine_tune_model_last(model)
    model.summary()
optimizer = load_optimizer_and_loss(getParams(full_config,'optimizer'))
if mixed:
    optimizer = mixed_precision.LossScaleOptimizer(optimizer)
callbacks = []
if "callbacks" in full_config.keys():
    callbacks = loadFunctionList(getParams(full_config,'callbacks'))
# Default to False so the name is always bound: it is printed and passed to
# fit_constraints below even when the configuration has no such section.
optim_margin = False
if "optim_margin" in full_config.keys():
    optim_margin = full_config['optim_margin']['margin']
loss_fct = load_optimizer_and_loss(getParams(full_config,'loss'))
print("var name",)
metrics = loadFunctionList(getParams(full_config,'metrics'))
print(optimizer.get_config(),optim_margin)
sys.stdout.flush()
if "margins" in full_config.keys():
    # Initialise the loss margins from those saved by a previous experiment,
    # optionally rescaled by ``coeff``.
    expe_name = full_config["margins"]["expe_name"]
    rep_mod = "results/"+expe_name+"/models/"
    margins_array = np.loadtxt(rep_mod+'margins.txt').reshape(1000)
    if "coeff" in full_config["margins"].keys():
        margins_array=margins_array*full_config["margins"]["coeff"]
    loss_fct.margins.assign(margins_array)
    print("margins ",margins_array.shape,tf.reduce_mean(loss_fct.margins).numpy())
steps_per_epoch=getParams(full_config,'steps_per_epoch')



lambda_orth = 0
if "lambda_orth" in full_config.keys():
    lambda_orth = full_config['lambda_orth']
print('lambda_orth :',lambda_orth)
sys.stdout.flush()

callbacks.append(WandbCallback(save_model=False))


model.compile(loss=loss_fct, optimizer=optimizer, metrics=metrics)
# Fill in the optional augmentation settings so later readers of the
# configuration can index them directly.
full_config['contrast_min'] = full_config.get('contrast_min', 1.)
full_config['contrast_max'] = full_config.get('contrast_max', 1.)
full_config['bright'] = full_config.get('bright', 0.)
full_config['cutout'] = full_config.get('cutout', False)
full_config['write_dir'] = full_config.get('write_dir','/mnt/terminus/imagenet/extracted/')

train, test_dataset, info = imagenet_dataset(batch_size = batch_size,
                                             preprocess = full_config['scale_img'],
                                             contrast_min = full_config['contrast_min'],
                                             contrast_max = full_config['contrast_max'],
                                             bright = full_config['bright'] ,
                                             cutout = full_config['cutout'],

                                             compute_train_val=False,
                                             write_dir=full_config['write_dir'],
                                             shuffle = 1024,
                                             verbose = True)
#data_range(train)
#print("biases", model.trainable_weights[-1])
sys.stdout.flush()
fit_constraints(model,
                train,
                loss_fct,
                optimizer,
                full_config,
                spectral = spectral,
                steps_per_epoch=steps_per_epoch,
                callbacks=callbacks,
                epochs=epochs,
                verbose=2,
                mixed = mixed,
                grad_coeff = grad_coeff,
                batch_size = batch_size,
                redress = redress,
                test_dataset = test_dataset,
                optim_margin=optim_margin,
                lambda_orth = lambda_orth)






save_model(model, full_config)
save_metrics(model,test_dataset)
# Persist the margins next to the weights; ``with`` closes the file even if
# the write fails.
with open(full_config['result_path'] + full_config['run_name'] + "/models/margins.txt", "w") as a_file:
    np.savetxt(a_file, loss_fct.margins.numpy())
#make_curves(full_config, hist)
wandb.finish()
