{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Getting started with our ultimate beginner guide!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "This tutorial walks you through the basics of using the `usb` lighting package. Let's get started by training a FixMatch model on CIFAR-10!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from semilearn import get_dataset, get_data_loader, get_net_builder, get_algorithm, get_config, Trainer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Step 1: define the settings and create the config object"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/bin/sh: 1: netstat: not found\n"
     ]
    }
   ],
   "source": [
    "# The raw settings live in a plain dict under their own name; get_config()\n",
    "# turns them into the `config` object that later cells read by attribute\n",
    "# (e.g. config.net, config.batch_size).\n",
    "config_dict = {\n",
    "    'algorithm': 'fixmatch',\n",
    "    'net': 'wrn_28_2',\n",
    "    'use_pretrain': False,  # todo: add pretrain\n",
    "    'pretrain_path': None,\n",
    "\n",
    "    # optimization configs\n",
    "    'epoch': 3,\n",
    "    'num_train_iter': 150,\n",
    "    'num_eval_iter': 50,\n",
    "    'optim': 'SGD',\n",
    "    'lr': 0.03,\n",
    "    'momentum': 0.9,\n",
    "    'batch_size': 64,\n",
    "    'eval_batch_size': 64,\n",
    "\n",
    "    # dataset configs\n",
    "    'dataset': 'cifar10',\n",
    "    'num_labels': 40,\n",
    "    'num_classes': 10,\n",
    "    'img_size': 32,\n",
    "    'crop_ratio': 0.875,\n",
    "    'data_dir': './data',\n",
    "\n",
    "    # algorithm specific configs\n",
    "    'hard_label': True,\n",
    "    'uratio': 3,  # the unlabeled loader in Step 3 uses batch_size * uratio\n",
    "    'ulb_loss_ratio': 1.0,\n",
    "\n",
    "    # device configs\n",
    "    'gpu': 0,\n",
    "    'world_size': 1,\n",
    "    'distributed': False,\n",
    "}\n",
    "config = get_config(config_dict)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Step 2: create model and specify algorithm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Look up the network builder for config.net first, then pass it to get_algorithm.\n",
    "net_builder = get_net_builder(config.net, from_name=False)\n",
    "algorithm = get_algorithm(config, net_builder, tb_log=None, logger=None)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Step 3: create dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "# Build the dataset splits once; each loader below reads one entry of the result.\n",
    "dataset_dict = get_dataset(\n",
    "    config,\n",
    "    config.algorithm,\n",
    "    config.dataset,\n",
    "    config.num_labels,\n",
    "    config.num_classes,\n",
    "    data_dir=config.data_dir,\n",
    ")\n",
    "\n",
    "train_lb_loader = get_data_loader(config, dataset_dict['train_lb'], config.batch_size)\n",
    "\n",
    "# Batch size for the unlabeled loader: the labeled batch size scaled by uratio.\n",
    "ulb_batch_size = int(config.batch_size * config.uratio)\n",
    "train_ulb_loader = get_data_loader(config, dataset_dict['train_ulb'], ulb_batch_size)\n",
    "\n",
    "eval_loader = get_data_loader(config, dataset_dict['eval'], config.eval_batch_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Step 4: train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/haoc/miniconda3/envs/usb/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "[2022-08-21 18:19:28,969 INFO] confusion matrix\n",
      "[2022-08-21 18:19:28,970 INFO] [[0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]]\n",
      "[2022-08-21 18:19:28,973 INFO] evaluation metric\n",
      "[2022-08-21 18:19:28,973 INFO] acc: 0.0953\n",
      "[2022-08-21 18:19:28,974 INFO] precision: 0.0095\n",
      "[2022-08-21 18:19:28,974 INFO] recall: 0.1000\n",
      "[2022-08-21 18:19:28,975 INFO] f1: 0.0174\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "model saved: ./saved_models/fixmatch_cifar10/latest_model.pth\n",
      "model saved: ./saved_models/fixmatch_cifar10/model_best.pth\n",
      "Epoch: 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/haoc/miniconda3/envs/usb/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "[2022-08-21 18:19:38,818 INFO] confusion matrix\n",
      "[2022-08-21 18:19:38,820 INFO] [[0.         0.75079872 0.         0.         0.24920128 0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.87155963 0.         0.         0.12844037 0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.86666667 0.         0.         0.13333333 0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.90634441 0.         0.         0.09365559 0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.9147541  0.         0.         0.0852459  0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.88178914 0.         0.         0.11821086 0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.93589744 0.         0.         0.06410256 0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.88461538 0.         0.         0.11538462 0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.75739645 0.         0.         0.24260355 0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.80239521 0.         0.         0.19760479 0.\n",
      "  0.         0.         0.         0.        ]]\n",
      "[2022-08-21 18:19:38,822 INFO] evaluation metric\n",
      "[2022-08-21 18:19:38,823 INFO] acc: 0.0972\n",
      "[2022-08-21 18:19:38,823 INFO] precision: 0.0161\n",
      "[2022-08-21 18:19:38,824 INFO] recall: 0.0957\n",
      "[2022-08-21 18:19:38,825 INFO] f1: 0.0254\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "model saved: ./saved_models/fixmatch_cifar10/latest_model.pth\n",
      "model saved: ./saved_models/fixmatch_cifar10/model_best.pth\n",
      "Epoch: 2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/haoc/miniconda3/envs/usb/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "[2022-08-21 18:19:49,039 INFO] confusion matrix\n",
      "[2022-08-21 18:19:49,040 INFO] [[0.         0.84025559 0.         0.         0.15974441 0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.91743119 0.         0.         0.08256881 0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.92698413 0.         0.         0.07301587 0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.95166163 0.         0.         0.04833837 0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.93770492 0.         0.         0.06229508 0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.90415335 0.         0.         0.09584665 0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.96474359 0.         0.         0.03525641 0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.93589744 0.         0.         0.06410256 0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.81656805 0.         0.         0.18343195 0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.89221557 0.         0.         0.10778443 0.\n",
      "  0.         0.         0.         0.        ]]\n",
      "[2022-08-21 18:19:49,043 INFO] evaluation metric\n",
      "[2022-08-21 18:19:49,043 INFO] acc: 0.0997\n",
      "[2022-08-21 18:19:49,044 INFO] precision: 0.0168\n",
      "[2022-08-21 18:19:49,044 INFO] recall: 0.0980\n",
      "[2022-08-21 18:19:49,045 INFO] f1: 0.0249\n",
      "[2022-08-21 18:19:49,429 INFO] Best acc 0.0997 at epoch 2\n",
      "[2022-08-21 18:19:49,431 INFO] Training finished.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "model saved: ./saved_models/fixmatch_cifar10/latest_model.pth\n",
      "model saved: ./saved_models/fixmatch_cifar10/model_best.pth\n"
     ]
    }
   ],
   "source": [
    "# Wrap the configured algorithm in a Trainer, then fit it on the labeled and\n",
    "# unlabeled training loaders; eval_loader is passed as the third argument.\n",
    "trainer = Trainer(config, algorithm)\n",
    "trainer.fit(\n",
    "    train_lb_loader,\n",
    "    train_ulb_loader,\n",
    "    eval_loader,\n",
    ")"
   ]
  },
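  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Step 5: evaluate\n",
    "\n",
    "The training run above is very short (`num_train_iter` is 150), so the metrics below are still at chance level for 10 classes (about 10% accuracy). Increase `num_train_iter` in Step 1 to get a meaningful result."
   ]
  },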
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/haoc/miniconda3/envs/usb/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "[2022-08-21 18:21:23,656 INFO] confusion matrix\n",
      "[2022-08-21 18:21:23,658 INFO] [[0.         0.84025559 0.         0.         0.15974441 0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.91743119 0.         0.         0.08256881 0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.92698413 0.         0.         0.07301587 0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.95166163 0.         0.         0.04833837 0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.93770492 0.         0.         0.06229508 0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.90415335 0.         0.         0.09584665 0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.96474359 0.         0.         0.03525641 0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.93589744 0.         0.         0.06410256 0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.81656805 0.         0.         0.18343195 0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.89221557 0.         0.         0.10778443 0.\n",
      "  0.         0.         0.         0.        ]]\n",
      "[2022-08-21 18:21:23,661 INFO] evaluation metric\n",
      "[2022-08-21 18:21:23,661 INFO] acc: 0.0997\n",
      "[2022-08-21 18:21:23,662 INFO] precision: 0.0168\n",
      "[2022-08-21 18:21:23,662 INFO] recall: 0.0980\n",
      "[2022-08-21 18:21:23,663 INFO] f1: 0.0249\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'acc': 0.0996875,\n",
       " 'precision': 0.016786053719491927,\n",
       " 'recall': 0.09797262746277637,\n",
       " 'f1': 0.024902520800984422}"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Keep the returned metrics under a name, then leave them as the last\n",
    "# expression so the notebook displays them.\n",
    "eval_metrics = trainer.evaluate(eval_loader)\n",
    "eval_metrics"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Step 6: predict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# predict() returns a pair for the eval loader; it is unpacked here as y_pred and y_logits.\n",
    "predict_result = trainer.predict(eval_loader)\n",
    "y_pred, y_logits = predict_result"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.10 ('usb')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  },
  "vscode": {
   "interpreter": {
    "hash": "29d8f0db1976a8fbb4040f75e07c81acdf04dbb74e20adf172a112c04d3c0b95"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
