import hdf5storage
import torch as t
from Core.NN_FLM import NN_FLM


def run(result_dir, train_data_set, test_data_set, fea_ext_net_type, fea_ext_net_structure,
        fuzzy_para_alpha=0.2, fuzzy_para_beta=0.8,
        max_epoch_num=500, min_loss_gap=1e-8, decay_lr_how_often=200, decay_lr_rate=0.9,
        class_exemplar_num=5,
        batch_size=5000,
        device=t.device('cpu')):
    """Train an NN_FLM model, select class exemplars, and predict on the test set.

    Every output file is named ``result_dir + <file name>``. The two parts are
    concatenated as plain strings, so ``result_dir`` should normally end with a
    path separator, or be a deliberate file-name prefix.

    Pipeline, in order:
      1. Construct the model, passing ``network-structure.txt`` as its
         ``model_design_file``.
      2. Build its optimizers (both learning rates fixed at 1e-3).
      3. Move the model to ``device`` and apply ``NN_FLM.init_fn`` recursively.
      4. Train via ``train_net_with_log``, with ``train-log`` as the log file.
         The test set is passed along as well.
      5. Save the trained model to ``train_end.model``.
      6. Select class exemplars from the training set and store their
         indices in ``exemplar_sam_indices.mat`` (MATLAB format, via
         ``hdf5storage``).
      7. Predict the test set against those exemplars and write the result
         to ``exemplar-predict-result.txt``.

    Args:
        result_dir: String prefix for every output file.
        train_data_set: Training data. Must provide ``get_data_information()``,
            whose result is used as the model's ``sample_type_shape``.
        test_data_set: Test data, used during training and for final prediction.
        fea_ext_net_type: Forwarded to ``NN_FLM`` unchanged.
        fea_ext_net_structure: Forwarded to ``NN_FLM`` unchanged.
        fuzzy_para_alpha: Forwarded to ``NN_FLM`` unchanged.
        fuzzy_para_beta: Forwarded to ``NN_FLM`` unchanged.
        max_epoch_num: Forwarded to ``train_net_with_log``.
        min_loss_gap: Forwarded to ``train_net_with_log``.
        decay_lr_how_often: Forwarded to ``train_net_with_log``.
        decay_lr_rate: Forwarded to ``train_net_with_log``.
        class_exemplar_num: Forwarded to ``select_class_exemplar``.
        batch_size: Forwarded to ``train_net_with_log``.
        device: torch device used for the model, training, exemplar selection
            and prediction. The CPU default is created once at definition time.

    Returns:
        None. All results are written to disk.
    """

    def artifact(name):
        # Plain string concatenation (not os.path.join): result_dir is used verbatim as a prefix.
        return result_dir + name

    sample_shape = train_data_set.get_data_information()
    model = NN_FLM(sample_type_shape=sample_shape,
                   fea_ext_net_type=fea_ext_net_type,
                   fea_ext_net_structure=fea_ext_net_structure,
                   fuzzy_para_alpha=fuzzy_para_alpha,
                   fuzzy_para_beta=fuzzy_para_beta,
                   # Loss weights are fixed here rather than exposed as parameters.
                   gamma_FPL=1.0,
                   gamma_Reg=0.1,
                   model_design_file=artifact('network-structure.txt'))

    # NOTE(review): the optimizers are built *before* .to(device) and init_fn.
    # That is only safe if both modify parameters in place. PyTorch's optim docs
    # recommend moving a model to its device before constructing optimizers.
    # Confirm NN_FLM.build_optimizer / init_fn semantics before reordering.
    model.build_optimizer(fea_ext_net_lr=1e-3, equ_rel_net_lr=1e-3)
    model.to(device)
    model.apply(NN_FLM.init_fn)

    model.train_net_with_log(train_data_set=train_data_set,
                             max_epoch_num=max_epoch_num,
                             min_loss_gap=min_loss_gap,
                             batch_size=batch_size,
                             device=device,
                             decay_lr_how_often=decay_lr_how_often,
                             decay_lr_rate=decay_lr_rate,
                             train_log_file=artifact('train-log'),
                             test_data_set=test_data_set)
    model.save_model(file_name=artifact('train_end.model'))

    # select_class_exemplar returns a mapping; only its 'exemplar_sam_indices' and
    # 'exemplar_data_set' entries are used below.
    selection = model.select_class_exemplar(train_data_set=train_data_set,
                                            class_exemplar_num=class_exemplar_num,
                                            device=device)
    hdf5storage.savemat(file_name=artifact('exemplar_sam_indices.mat'),
                        mdict={'exemplar_sam_indices': selection['exemplar_sam_indices']})

    predictions = model.predict(exemplar_data_set=selection['exemplar_data_set'],
                                test_data_set=test_data_set,
                                device=device)
    NN_FLM.SavePredictResult(result=predictions,
                             file_name=artifact('exemplar-predict-result.txt'))
