{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "import os\n",
    "import sys\n",
    "import torch\n",
    "import torch.optim as optim\n",
    "import json\n",
    "import time\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import re\n",
    "\n",
    "sys.path.append(os.path.realpath('../..'))\n",
    "from toy.ops import KDLoss_TS, KDLoss, KDLoss_min_TS, KDLoss_min, err, KDLoss_min_T1\n",
    "import toy.ops as ops\n",
    "import toy.data as data\n",
    "import toy.net as net\n",
    "import toy.train as train\n",
    "import toy.ground_truth as gt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def train(\n",
    "    dir,\n",
    "    target_function_path,\n",
    "    init_net_path,\n",
    "    train_dataset_path,\n",
    "    test_data_path,\n",
    "    model_config={\n",
    "        'rho': 0.0,\n",
    "        'T': 10.0,\n",
    "        'teacher_reduction': 1.0,\n",
    "        'datanum': 4096,\n",
    "        'regenerate_data': False\n",
    "    },\n",
    "    training_strategry={\n",
    "        'batch_size': 4096,\n",
    "        'lr': 0.01,\n",
    "        'epoch': 4096,\n",
    "        'test_interval': 64,\n",
    "        'display_interval': 16,\n",
    "        'save_interval': 64,\n",
    "        'record_interval': 8,\n",
    "        'test_datanum': 32768,\n",
    "    },\n",
    "    device_name='cuda:0',\n",
    "    seed=0,\n",
    "):\n",
    "    \"\"\"Train a student network to match a teacher (target) network with a KD loss.\n",
    "\n",
    "    The teacher, the initial student and both datasets are loaded with\n",
    "    ``torch.load``. Every epoch accumulates gradients over the whole train\n",
    "    loader and then takes a single Adam step. Outputs written under ``dir``:\n",
    "\n",
    "    - ``dataset/train_dataset``, ``dataset/test_dataset``: the datasets used\n",
    "    - ``network/target``, ``network/init_net``: teacher and initial student\n",
    "    - ``network/student_epoch-XXXXXX``: checkpoint every ``save_interval`` epochs\n",
    "    - ``network/student``: final student\n",
    "    - ``log/train.csv``, ``log/test.csv``: training / test curves (overwritten)\n",
    "    - ``config/json``, ``config/train.json``: the two config dicts\n",
    "\n",
    "    Args:\n",
    "        dir: output directory, created if missing.\n",
    "        target_function_path: file holding the teacher network.\n",
    "        init_net_path: file holding the initial student; it must provide\n",
    "            ``vec()``, used to report the distance from the initial weights.\n",
    "        train_dataset_path: file holding the training dataset.\n",
    "        test_data_path: file holding the test dataset.\n",
    "        model_config: loss / data options. Required keys: ``'rho'``, ``'T'``\n",
    "            and ``'teacher_reduction'`` (scale applied to teacher outputs).\n",
    "            If ``'datanum'`` is present the train set is switched to fixed\n",
    "            (non-online) data of that size; a truthy ``'regenerate_data'``\n",
    "            then asks the dataset to regenerate that data.\n",
    "        training_strategry: optimisation / logging options; every key in the\n",
    "            default dict is required.\n",
    "        device_name: torch device string.\n",
    "        seed: seed for torch's CPU and CUDA generators.\n",
    "\n",
    "    The default dicts are only read, never mutated, so sharing them across\n",
    "    calls is safe. Note that defining this function rebinds the notebook name\n",
    "    ``train``, which the import cell bound to the ``toy.train`` module.\n",
    "    \"\"\"\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed_all(seed)\n",
    "\n",
    "    os.makedirs(dir, exist_ok=True)\n",
    "    device = torch.device(device_name)\n",
    "\n",
    "    # get net and target_function\n",
    "    target = torch.load(target_function_path, map_location=device)\n",
    "    net = torch.load(init_net_path, map_location=device)\n",
    "    init_weight = net.vec().detach()\n",
    "\n",
    "    # get data\n",
    "    train_dataset = torch.load(train_dataset_path, map_location=device)\n",
    "    train_dataset.device = device\n",
    "    if 'datanum' in model_config:\n",
    "        train_dataset.online = False\n",
    "        train_dataset.datanum = model_config['datanum']\n",
    "        if model_config.get('regenerate_data', False):\n",
    "            with torch.no_grad():\n",
    "                train_dataset.generate_data()\n",
    "\n",
    "    train_dataloader = DataLoader(\n",
    "        train_dataset,\n",
    "        batch_size=min(training_strategry['batch_size'], len(train_dataset)),\n",
    "        shuffle=True\n",
    "    )\n",
    "\n",
    "    test_dataset = torch.load(test_data_path, map_location=device)\n",
    "    test_dataset.device = device\n",
    "    test_dataset.datanum = training_strategry['test_datanum']\n",
    "    test_dataloader = DataLoader(test_dataset, batch_size=training_strategry['batch_size'])\n",
    "\n",
    "    # loss\n",
    "    Loss = KDLoss(model_config['rho'], model_config['T'])\n",
    "    if model_config['T'] == 1.0:\n",
    "        Loss_min = KDLoss_min_T1(model_config['rho'])\n",
    "    else:\n",
    "        Loss_min = KDLoss_min_TS(model_config['rho'], model_config['T'])\n",
    "\n",
    "    def teacher_targets(x):\n",
    "        # Scaled teacher logits for x and the per-sample Loss_min evaluated at\n",
    "        # those logits (with (logits > 0) as second argument), without autograd.\n",
    "        with torch.no_grad():\n",
    "            logits = model_config['teacher_reduction'] * target(x)\n",
    "            loss_min = Loss_min(logits, (logits > 0).to(logits.dtype))\n",
    "        return logits, loss_min\n",
    "\n",
    "    # if data is fixed, compute the teacher targets once before training;\n",
    "    # None marks online data, whose targets are computed per batch\n",
    "    train_fixed = None if train_dataset.online else teacher_targets(train_dataset[:][1])\n",
    "    test_fixed = None if test_dataset.online else teacher_targets(test_dataset[:][1])\n",
    "\n",
    "    def batch_targets(fixed, index, x):\n",
    "        # Teacher logits and the mean Loss_min for one batch; shared by the\n",
    "        # train and test loops so both paths stay identical.\n",
    "        if fixed is None:\n",
    "            logits, loss_min = teacher_targets(x)\n",
    "        else:\n",
    "            logits, loss_min = fixed[0][index], fixed[1][index]\n",
    "        return logits, torch.mean(loss_min)\n",
    "\n",
    "    # optim\n",
    "    optimizer = optim.Adam(net.parameters(), lr=training_strategry['lr'])\n",
    "\n",
    "    # # print/save\n",
    "    # data\n",
    "    os.makedirs(dir + '/dataset', exist_ok=True)\n",
    "    torch.save(train_dataset, dir + '/dataset/train_dataset')\n",
    "    torch.save(test_dataset, dir + '/dataset/test_dataset')\n",
    "\n",
    "    # network\n",
    "    os.makedirs(dir + '/network', exist_ok=True)\n",
    "    torch.save(target, dir + '/network/target')\n",
    "    torch.save(net, dir + '/network/init_net')\n",
    "\n",
    "    # train log (headers rewritten, i.e. previous logs discarded, on every call)\n",
    "    os.makedirs(dir + '/log', exist_ok=True)\n",
    "    with open(dir + '/log/train.csv', 'w') as f:\n",
    "        f.write('epoch,train_loss,train_error,weight_change\\n')\n",
    "    with open(dir + '/log/test.csv', 'w') as f:\n",
    "        f.write('epoch,test_loss,test_error\\n')\n",
    "\n",
    "    # config\n",
    "    os.makedirs(dir + '/config', exist_ok=True)\n",
    "    # NOTE(review): the model config goes to a file literally named 'json'\n",
    "    # (no basename); kept as-is so existing readers still find it.\n",
    "    with open(dir + '/config/json', 'w', encoding='utf-8') as json_file:\n",
    "        json.dump(model_config, json_file, ensure_ascii=False, indent=4)\n",
    "    with open(dir + '/config/train.json', 'w', encoding='utf-8') as json_file:\n",
    "        json.dump(training_strategry, json_file, ensure_ascii=False, indent=4)\n",
    "\n",
    "    record_interval = training_strategry['record_interval']\n",
    "    display_interval = training_strategry['display_interval']\n",
    "\n",
    "    timer = time.time()\n",
    "\n",
    "    # run epoch\n",
    "    for epoch in range(training_strategry['epoch']):\n",
    "        record = epoch % record_interval == 0\n",
    "        display = epoch % display_interval == 0\n",
    "        # train_error / weight_change are needed on every recorded OR displayed\n",
    "        # epoch; refreshing them only on record epochs printed stale values\n",
    "        # whenever display_interval was not a multiple of record_interval\n",
    "        track = record or display\n",
    "\n",
    "        # train: accumulate the gradient over all batches, then step once\n",
    "        train_loss = torch.tensor(0.0)\n",
    "        train_error = torch.tensor(0.0)\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        for index, batch_x in train_dataloader:\n",
    "            output = net(batch_x)\n",
    "            target_logits, loss_min = batch_targets(train_fixed, index, batch_x)\n",
    "            loss = Loss(output, target_logits) - loss_min\n",
    "            # weight each batch mean by its share of the data so the summed\n",
    "            # gradient is that of the dataset-mean loss reported as train_loss\n",
    "            # (the factor is exactly 1 when a single batch covers the dataset)\n",
    "            (loss * (len(batch_x) / len(train_dataset))).backward()\n",
    "\n",
    "            with torch.no_grad():\n",
    "                if track:\n",
    "                    train_error = train_error + torch.sum(err(output, target_logits))\n",
    "                train_loss = train_loss + loss.item() * len(batch_x)\n",
    "\n",
    "        optimizer.step()\n",
    "        train_loss = train_loss / len(train_dataset)\n",
    "\n",
    "        if track:\n",
    "            with torch.no_grad():\n",
    "                train_error = train_error / len(train_dataset)\n",
    "                weight_change = torch.norm(net.vec() - init_weight)\n",
    "\n",
    "        # record\n",
    "        if record:\n",
    "            with open(dir + '/log/train.csv', 'a') as f:\n",
    "                # no trailing delimiter: rows must have as many fields as the\n",
    "                # 4-column header, or CSV readers shift the column labels\n",
    "                f.write(\n",
    "                    '{:05d},{:.20e},{:.20f},{:.20e}\\n'.format(\n",
    "                        epoch,\n",
    "                        train_loss.item(),\n",
    "                        train_error.item(),\n",
    "                        weight_change.item(),\n",
    "                    )\n",
    "                )\n",
    "\n",
    "        # display\n",
    "        if display:\n",
    "            print(\n",
    "                'TRAIN==> epoch:{:5d}, train_loss:{:.06e}, train_err:{:.5f}, weight_change:{:.5f}'\n",
    "                .format(\n",
    "                    epoch,\n",
    "                    train_loss.item(),\n",
    "                    train_error.item(),\n",
    "                    weight_change.item(),\n",
    "                )\n",
    "            )\n",
    "\n",
    "        # save model\n",
    "        if (epoch + 1) % training_strategry['save_interval'] == 0:\n",
    "            print(\n",
    "                'Save student net at epoch {:d} with loss {:.4e}'.format(epoch, train_loss.item())\n",
    "            )\n",
    "            torch.save(net, dir + '/network/student_epoch-{:06d}'.format(epoch))\n",
    "\n",
    "        # test\n",
    "        if (epoch + 1) % training_strategry['test_interval'] == 0:\n",
    "            test_loss = torch.tensor(0.0)\n",
    "            test_error = torch.tensor(0.0)\n",
    "\n",
    "            with torch.no_grad():\n",
    "                for index, batch_x in test_dataloader:\n",
    "                    output = net(batch_x)\n",
    "                    target_logits, loss_min = batch_targets(test_fixed, index, batch_x)\n",
    "                    loss = Loss(output, target_logits) - loss_min\n",
    "                    test_loss = test_loss + loss.item() * len(batch_x)\n",
    "                    test_error = test_error + torch.sum(err(output, target_logits))\n",
    "\n",
    "            test_loss = test_loss / len(test_dataset)\n",
    "            test_error = test_error / len(test_dataset)\n",
    "\n",
    "            # record (no trailing delimiter, matching the 3-column header)\n",
    "            with open(dir + '/log/test.csv', 'a') as f:\n",
    "                f.write(\n",
    "                    '{:05d},{:.20e},{:.20f}\\n'.format(\n",
    "                        epoch,\n",
    "                        test_loss.item(),\n",
    "                        test_error.item(),\n",
    "                    )\n",
    "                )\n",
    "\n",
    "            # display\n",
    "            print(\n",
    "                'TEST===> epoch:{:5d}, test__loss:{:.06e}, test__err:{:.5f}'.format(\n",
    "                    epoch,\n",
    "                    test_loss.item(),\n",
    "                    test_error.item(),\n",
    "                )\n",
    "            )\n",
    "            print('time used {:.02f}s\\n'.format(time.time() - timer))\n",
    "            timer = time.time()\n",
    "\n",
    "    torch.save(net, dir + '/network/student')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend",
     "outputPrepend"
    ]
   },
   "outputs": [],
   "source": [
    "# Inputs produced by the data-generation experiment, and where this run writes.\n",
    "generate_dir = '../../experiment/KD_training/Teacher_Gaussian_Stendent_Real_NN'\n",
    "output_dir = '../../experiment/05c/teacher'\n",
    "teacher_path = '../../experiment/KD_training/Gaussian_Function/function'\n",
    "\n",
    "# T == 1.0 makes train() use KDLoss_min_T1; with no 'datanum' key, train()\n",
    "# leaves the loaded train dataset's online/datanum settings unchanged.\n",
    "teacher_model_config = {\n",
    "    'rho': 0.0,\n",
    "    'T': 1.0,\n",
    "    'teacher_reduction': 1.0,\n",
    "}\n",
    "\n",
    "teacher_training_strategy = {\n",
    "    'batch_size': 4096,\n",
    "    'lr': 0.0001,\n",
    "    # 512 * 20 + 1 epochs: the last epoch index, 10240, is a multiple of the\n",
    "    # record/display intervals (presumably the reason for the +1)\n",
    "    'epoch': 512 * 20 + 1,\n",
    "    'test_interval': 64,\n",
    "    'display_interval': 16,\n",
    "    'save_interval': 128,\n",
    "    'record_interval': 8,\n",
    "    'test_datanum': 32768,\n",
    "}\n",
    "\n",
    "train(\n",
    "    dir=output_dir,\n",
    "    target_function_path=teacher_path,\n",
    "    init_net_path=generate_dir + '/init_net',\n",
    "    train_dataset_path=generate_dir + '/train_data',\n",
    "    test_data_path=generate_dir + '/test_data',\n",
    "    model_config=teacher_model_config,\n",
    "    training_strategry=teacher_training_strategy,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3-final"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python38364bit5172cfd22f324156974f51e47e17b07a",
   "display_name": "Python 3.8.3 64-bit"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}