{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Custom Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## In this tutorial, we provide an example of adapting USB to a custom dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "import numpy as np\n",
    "from torchvision import transforms\n",
    "from semilearn import get_data_loader, get_net_builder, get_algorithm, get_config, Trainer\n",
    "from semilearn import split_ssl_data, BasicDataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Specify configs and define the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/bin/sh: 1: netstat: not found\n"
     ]
    }
   ],
   "source": [
    "# Settings are grouped by concern and merged (in this order) into one flat dict,\n",
    "# which get_config converts into the `config` object used by the cells below.\n",
    "\n",
    "# algorithm and backbone\n",
    "model_cfg = {\n",
    "    'algorithm': 'fixmatch',\n",
    "    'net': 'wrn_28_2',\n",
    "    'use_pretrain': False,  # todo: add pretrain\n",
    "}\n",
    "\n",
    "# optimization configs\n",
    "optim_cfg = {\n",
    "    'epoch': 3,\n",
    "    'num_train_iter': 150,\n",
    "    'num_eval_iter': 50,\n",
    "    'optim': 'SGD',\n",
    "    'lr': 0.03,\n",
    "    'momentum': 0.9,\n",
    "    'batch_size': 64,\n",
    "    'eval_batch_size': 64,\n",
    "}\n",
    "\n",
    "# dataset configs\n",
    "data_cfg = {\n",
    "    'dataset': 'none',\n",
    "    'num_labels': 40,\n",
    "    'num_classes': 10,\n",
    "    'img_size': 32,\n",
    "    'crop_ratio': 0.875,\n",
    "    'data_dir': './data',\n",
    "}\n",
    "\n",
    "# algorithm specific configs\n",
    "algo_cfg = {\n",
    "    'hard_label': True,\n",
    "    'uratio': 3,\n",
    "    'ulb_loss_ratio': 1.0,\n",
    "}\n",
    "\n",
    "# device configs\n",
    "device_cfg = {\n",
    "    'gpu': 0,\n",
    "    'world_size': 1,\n",
    "    'num_workers': 2,\n",
    "    'distributed': False,\n",
    "}\n",
    "\n",
    "config = get_config({**model_cfg, **optim_cfg, **data_cfg, **algo_cfg, **device_cfg})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# look up the network builder for config.net, then create the algorithm with it\n",
    "net_builder = get_net_builder(config.net, from_name=False)\n",
    "algorithm = get_algorithm(config, net_builder, tb_log=None, logger=None)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Create dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# placeholder data: 1000 random 32x32 RGB images with integer labels in [0, 10)\n",
    "# replace with your own code\n",
    "data = np.random.randint(0, 255, size=3072 * 1000).reshape((-1, 32, 32, 3))\n",
    "data = np.uint8(data)\n",
    "target = np.random.randint(0, 10, size=1000)\n",
    "\n",
    "# split the data into a labeled part and an unlabeled part\n",
    "lb_data, lb_target, ulb_data, ulb_target = split_ssl_data(config, data, target,\n",
    "                                                          num_labels=config.num_labels,\n",
    "                                                          num_classes=config.num_classes)\n",
    "\n",
    "# weak augmentation: random horizontal flip + reflect-padded random crop\n",
    "train_transform = transforms.Compose([transforms.RandomHorizontalFlip(),\n",
    "                                      transforms.RandomCrop(32, padding=int(32 * 0.125), padding_mode='reflect'),\n",
    "                                      transforms.ToTensor(),\n",
    "                                      transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])\n",
    "\n",
    "# NOTE: this pipeline is currently identical to the weak one above;\n",
    "# add a stronger augmentation (e.g. RandAugment) here for your own data\n",
    "train_strong_transform = transforms.Compose([transforms.RandomHorizontalFlip(),\n",
    "                                             transforms.RandomCrop(32, padding=int(32 * 0.125), padding_mode='reflect'),\n",
    "                                             transforms.ToTensor(),\n",
    "                                             transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])\n",
    "\n",
    "lb_dataset = BasicDataset(config.algorithm, lb_data, lb_target, config.num_classes, train_transform, is_ulb=False)\n",
    "# the unlabeled dataset must wrap the unlabeled split, not the labeled one\n",
    "ulb_dataset = BasicDataset(config.algorithm, ulb_data, ulb_target, config.num_classes, train_transform, is_ulb=True, strong_transform=train_strong_transform)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# placeholder evaluation data: 100 random 32x32 RGB images with labels in [0, 10)\n",
    "# replace with your own code\n",
    "eval_data = np.random.randint(0, 255, size=3072 * 100).reshape((-1, 32, 32, 3))\n",
    "eval_data = np.uint8(eval_data)\n",
    "eval_target = np.random.randint(0, 10, size=100)\n",
    "\n",
    "# no random augmentation at evaluation time\n",
    "eval_transform = transforms.Compose([transforms.Resize(32),\n",
    "                                     transforms.ToTensor(),\n",
    "                                     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])\n",
    "\n",
    "# evaluate on the evaluation split, not on the labeled training samples\n",
    "eval_dataset = BasicDataset(config.algorithm, eval_data, eval_target, config.num_classes, eval_transform, is_ulb=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# the unlabeled batch size is the labeled batch size scaled by config.uratio\n",
    "ulb_batch_size = int(config.batch_size * config.uratio)\n",
    "\n",
    "# one loader per dataset\n",
    "train_lb_loader = get_data_loader(config, lb_dataset, config.batch_size)\n",
    "train_ulb_loader = get_data_loader(config, ulb_dataset, ulb_batch_size)\n",
    "eval_loader = get_data_loader(config, eval_dataset, config.eval_batch_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Training and evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/haoc/miniconda3/envs/usb/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "[2022-08-21 18:38:14,350 INFO] confusion matrix\n",
      "[2022-08-21 18:38:14,352 INFO] [[0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]]\n",
      "[2022-08-21 18:38:14,355 INFO] evaluation metric\n",
      "[2022-08-21 18:38:14,356 INFO] acc: 0.2000\n",
      "[2022-08-21 18:38:14,357 INFO] precision: 0.0200\n",
      "[2022-08-21 18:38:14,357 INFO] recall: 0.1000\n",
      "[2022-08-21 18:38:14,358 INFO] f1: 0.0333\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "model saved: ./saved_models/fixmatch_none/latest_model.pth\n",
      "model saved: ./saved_models/fixmatch_none/model_best.pth\n",
      "Epoch: 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/haoc/miniconda3/envs/usb/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "[2022-08-21 18:38:24,027 INFO] confusion matrix\n",
      "[2022-08-21 18:38:24,029 INFO] [[0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]]\n",
      "[2022-08-21 18:38:24,032 INFO] evaluation metric\n",
      "[2022-08-21 18:38:24,033 INFO] acc: 0.2000\n",
      "[2022-08-21 18:38:24,033 INFO] precision: 0.0200\n",
      "[2022-08-21 18:38:24,034 INFO] recall: 0.1000\n",
      "[2022-08-21 18:38:24,035 INFO] f1: 0.0333\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "model saved: ./saved_models/fixmatch_none/latest_model.pth\n",
      "Epoch: 2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/haoc/miniconda3/envs/usb/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "[2022-08-21 18:38:33,855 INFO] confusion matrix\n",
      "[2022-08-21 18:38:33,856 INFO] [[0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]]\n",
      "[2022-08-21 18:38:33,860 INFO] evaluation metric\n",
      "[2022-08-21 18:38:33,860 INFO] acc: 0.2000\n",
      "[2022-08-21 18:38:33,861 INFO] precision: 0.0200\n",
      "[2022-08-21 18:38:33,862 INFO] recall: 0.1000\n",
      "[2022-08-21 18:38:33,862 INFO] f1: 0.0333\n",
      "[2022-08-21 18:38:34,118 INFO] Best acc 0.2000 at epoch 0\n",
      "[2022-08-21 18:38:34,119 INFO] Training finished.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "model saved: ./saved_models/fixmatch_none/latest_model.pth\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/haoc/miniconda3/envs/usb/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "[2022-08-21 18:38:35,003 INFO] confusion matrix\n",
      "[2022-08-21 18:38:35,004 INFO] [[0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]]\n",
      "[2022-08-21 18:38:35,007 INFO] evaluation metric\n",
      "[2022-08-21 18:38:35,008 INFO] acc: 0.2000\n",
      "[2022-08-21 18:38:35,009 INFO] precision: 0.0200\n",
      "[2022-08-21 18:38:35,009 INFO] recall: 0.1000\n",
      "[2022-08-21 18:38:35,010 INFO] f1: 0.0333\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'acc': 0.2, 'precision': 0.02, 'recall': 0.1, 'f1': 0.03333333333333334}"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# fit on the labeled and unlabeled loaders (eval_loader is passed to fit as well),\n",
    "# then run a final evaluation; its return value is displayed as the cell output\n",
    "trainer = Trainer(config, algorithm)\n",
    "trainer.fit(train_lb_loader, train_ulb_loader, eval_loader)\n",
    "trainer.evaluate(eval_loader)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.10 ('usb')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  },
  "vscode": {
   "interpreter": {
    "hash": "29d8f0db1976a8fbb4040f75e07c81acdf04dbb74e20adf172a112c04d3c0b95"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
