import os
import sys

from deel.utils.yaml_to_params import getParams, getFunctionFromModules
#from model_samples.model_fftlip_samples import FFTlipModel
from deel.utils.lip_model import deel_lip_vgg, vgg_small_images, get_wass_MLP, vgg_large_images
from deel.utils.lip_depth_model import deel_lip_depthconv, ortho_cifar, ortho_test
from deel.utils.lip_res_model import ResNet50_lip, ResNet50
#from runs.model_samples.model_adv_samples import adversarialCNN


from deel.lip.activations import MaxMin, GroupSort, GroupSort2, FullSort, Absolute
from tensorflow.keras.models import model_from_json
from tensorflow.keras.metrics import binary_accuracy
from tensorflow.keras.optimizers import Adam, SGD, Nadam, Adamax
#from tensorflow.keras.optimizers.experimental import Adafactor


from tensorflow.keras.callbacks import ReduceLROnPlateau, LearningRateScheduler, Callback
from tensorflow.keras.losses import categorical_crossentropy, binary_crossentropy, CategoricalCrossentropy, BinaryCrossentropy
from tensorflow.keras import backend as K
from deel.utils.custom_scheduler import WarmUpCosineDecayScheduler, TimeStepScheduler, SGDRScheduler, LinearScheduler, InvSGDRScheduler
from tensorflow_riemopt.optimizers.riemannian_adam import RiemannianAdam
#from tensorflow_riemopt.optimizers.riemannian_gradient_descent import RiemannianSGD
from deel.lip.losses import (HKR,
                             KR,
                             HingeMargin,
                             MulticlassKR,
                             MulticlassHinge,
                             MulticlassHKR)

from deel.lip.custom_losses import (wasserstein_acc,
                                    BinaryCrossentropyLip,
                                    HKR_binary_auto, HKR_binary,
                                    HKR_multiclass_auto,
                                    HKR_cross_ent,
                                    cosin_lip,
                                    HKR_multiclass_hinge_auto,
                                    CategoricalCrossentropyLip)
from tensorflow.keras.layers import ReLU, Softmax

# Resolver used by load_model / loadFunctionList / load_optimizer_and_loss to
# turn a config "type" string into an object. It is built from this very module
# object, so presumably the names imported above (models, losses, optimizers,
# schedulers, activations) are what it can look up -- NOTE(review): confirm
# against deel.utils.yaml_to_params.getFunctionFromModules.
getFunction = getFunctionFromModules(sys.modules[__name__])


def load_network(net_path, net_name, custom_objects=None):
    """Rebuild a Keras model from ``<name>.json`` architecture + ``<name>.h5`` weights.

    Args:
        net_path: Directory containing the saved model files.
        net_name: Base file name (no extension) shared by the ``.json``
            architecture file and the ``.h5`` weights file.
        custom_objects: Optional mapping forwarded to ``model_from_json`` so
            custom layers/activations/losses can be deserialized. When ``None``
            (the default) the module-level ``CUSTOM_OBJECTS`` is used, exactly
            as before this parameter existed.

    Returns:
        The reconstructed model with its weights loaded.

    Raises:
        OSError: if the ``.json`` file cannot be opened or read.
    """
    model_base = os.path.join(net_path, net_name)
    # The context manager closes the file even if read() raises; the previous
    # open/read/close sequence leaked the handle on error.
    with open(model_base + '.json', 'r') as json_file:
        loaded_model_json = json_file.read()
    if custom_objects is None:
        # NOTE(review): CUSTOM_OBJECTS is not defined or imported anywhere in
        # the visible module; unless it is provided elsewhere this raises
        # NameError. Pass ``custom_objects`` explicitly to avoid relying on it.
        custom_objects = CUSTOM_OBJECTS
    loaded_model = model_from_json(loaded_model_json, custom_objects=custom_objects)
    # load weights into the freshly built model
    loaded_model.load_weights(model_base + '.h5')
    print("Loaded model from disk")
    return loaded_model


# register(load_network)


def load_model(network_config):
    """Build a model by calling the factory named in ``network_config``.

    ``network_config['type']`` is resolved with ``getFunction``. The result is
    called with the keyword arguments that ``getParams`` extracts from
    ``network_config`` under the ``'params'`` key.

    Args:
        network_config: Mapping with at least a ``'type'`` key. How a missing
            ``'params'`` key is handled depends on ``getParams``, which is not
            visible here.

    Returns:
        Whatever the resolved factory returns (presumably a Keras model).
    """
    # The former ``global type_to_function`` declaration was removed: that name
    # is never read or assigned, so the declaration was dead code.
    factory = getFunction(network_config['type'])
    return factory(**getParams(network_config, 'params', getFunction))


def loadFunctionList(configList):
    """Resolve every entry of ``configList`` to a function, instance or name.

    Each entry's ``'type'`` is looked up with ``getFunction``:

    * if the lookup yields a ``str``, that string is kept as-is (presumably a
      name the consumer resolves on its own);
    * otherwise the looked-up object is called with the keyword arguments
      ``getParams`` extracts from the entry's ``'params'`` key.

    Returns:
        A list holding one resolved item per entry, in input order.
    """
    def _resolve(entry):
        target = getFunction(entry['type'])
        if isinstance(target, str):
            return target
        return target(**getParams(entry, 'params', getFunction))

    return [_resolve(entry) for entry in configList]


def load_optimizer_and_loss(config):
    """Resolve an optimizer or loss described by ``config``.

    ``config['type']`` is looked up with ``getFunction``. When ``config`` has a
    ``'params'`` key, the looked-up object is called with the keyword
    arguments ``getParams`` builds from it and that result is returned.
    Otherwise the looked-up object itself is returned, uncalled.
    """
    has_params = 'params' in config
    target = getFunction(config['type'])
    if not has_params:
        return target
    return target(**getParams(config, 'params', getFunction))
