import os
import sys
# sys.path.append(os.path.realpath('../..'))
sys.path.append(os.path.realpath('.'))
import toy.ops as ops
import toy.data as data
import toy.net as net
import toy.train as train
import toy.ground_truth as gt
import matplotlib.pyplot as plt
import numpy as np
import torch
import pandas as pd
import re
import toy.ground_truth as gt

# Values of the teacher-mixing weight `rho` to sweep. Each value is passed to
# train.train_TS_linear via model_config['rho']; its exact meaning is defined
# there (presumably the weight given to the teacher signal -- TODO confirm).
rhos = [1.0, 0.9999, 0.9995, 0.999, 0.995, 0.99, 0.95, 0.90]
# Finer alternative grids. NOTE: run directories format rho with 4 decimals
# ('rho-{:.04f}'), so values closer than 1e-4 map to the same directory; the
# naming precision must be raised before using any of these.
# rhos = [1.0, 0.999999, 0.99999, 0.9999, 0.999]
# rhos = [1.0, 0.999999, 0.99999, 0.9999]
# rhos = np.concatenate([[1.0], 1 - np.exp(np.linspace(np.log(1e-6), np.log(1e-3), 15, endpoint=True))])
# epoch_list = [1535, 3455, 7423]

# Epoch number of the teacher checkpoint to load ('student_epoch-{:06d}').
epoch = 5119
# epoch = 8959

# Device string forwarded to train.train_TS_linear.
device_name = 'cuda:0'
# Number of repetitions per (datanum, rho) setting; the repetition index is
# also used as the random seed.
fold = 5

# Training-set sizes to sweep (forwarded as model_config['datanum']).
datanums = [16384, 8192, 4096, 2048, 1024, 512, 256, 128, 64, 32]

# Source of the shared init network and train/test data for every run.
generate_dir = './experiment/KD_training/Teacher_NN_Stendent_Linear_NN_Infinite_Data'
# Root directory under which each run gets its own sub-directory.
current_dir = './experiment/KD_training/Imperfect_Teacher_Distillation'

# Sweep every (datanum, fold, rho) combination and train one linear student per
# combination into its own directory. A combination is skipped when its
# directory already exists, so re-running the script resumes a partial sweep.

# Run directories encode rho with only 4 decimals. Two rhos closer than 1e-4
# (e.g. 0.999999 and 0.99999) would share one directory, and the existence
# check below would then silently skip every run after the first. Detect that
# up front instead of producing an incomplete sweep without any error.
rho_labels = ['{:.04f}'.format(rho) for rho in rhos]
if len(set(rho_labels)) != len(rho_labels):
    raise ValueError(
        'rhos {} map to duplicate directory labels {}; increase the precision '
        'of the rho directory name'.format(list(rhos), rho_labels)
    )

# The teacher checkpoint is the same for every run in the sweep.
teacher_net_path = './experiment/05c/teacher/network/student_epoch-{:06d}'.format(epoch)

for datanum in datanums:
    for i in range(fold):
        for rho, rho_label in zip(rhos, rho_labels):
            name = current_dir + '/rho-{}/datanum-{:06d}/fold-{:02d}'.format(rho_label, datanum, i)
            if not os.path.exists(name):
                print(name)
                train.train_TS_linear(
                    dir=name,
                    target_function_path='./experiment/KD_training/Gaussian_Function/function',
                    teacher_net_path=teacher_net_path,
                    init_net_path=generate_dir + '/init_net',
                    train_dataset_path=generate_dir + '/train_data',
                    test_data_path=generate_dir + '/test_data',
                    model_config={
                        'rho': rho,
                        'T': 10.0,
                        'teacher_reduction': 0.3,
                        'datanum': datanum,
                        'regenerate_data': True,
                    },
                    # Keyword spelling must match train.train_TS_linear's
                    # signature; do not "fix" it here without changing the callee.
                    training_strategry={
                        'batch_size': 512,
                        'lr': 0.01,
                        'epoch': 4096 + 1,
                        'test_interval': 64,
                        'display_interval': 16,
                        'save_interval': 64,
                        'record_interval': 8,
                        'test_datanum': 32768,
                    },
                    # The fold index doubles as the random seed for the run.
                    seed=i,
                    device_name=device_name,
                )
