{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys, os\n",
    "import pickle\n",
    "sys.path.insert(0, '/home/cvds_lab/yury/mxt-experiments/mlperf/recommendation/pytorch')\n",
    "from alias_generator import AliasSample\n",
    "from convert import generate_negatives\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "import torch.nn as nn\n",
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done\n",
      "855776\n",
      "1223962043\n",
      "1223962043\n"
     ]
    }
   ],
   "source": [
    "data_dir = '/home/cvds_lab/yury/mxt-experiments/mlperf/recommendation/pytorch/ml-20mx16x32'\n",
    "\n",
    "sampler_cache = os.path.join(data_dir, \"alias_tbl_16x32_cached_sampler.pkl\")\n",
    "with open(sampler_cache, \"rb\") as f:\n",
    "    sampler, pos_users_full, pos_items_full, nb_items_full, pck = pickle.load(f)\n",
    "print(\"Done\")\n",
    "print(nb_items_full)\n",
    "print(len(pos_users_full))\n",
    "print(len(pos_items_full))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max user: 2197224\n",
      "max item: 855775\n"
     ]
    }
   ],
   "source": [
    "max_user = pos_users_full.max()\n",
    "print(\"max user: {}\".format(max_user))\n",
    "max_item = pos_items_full.max()\n",
    "print(\"max item: {}\".format(max_item))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 10000\n",
    "p = np.random.choice(np.arange(len(pos_users_full)), size=N, replace=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max user: 2196900\n",
      "max item: 855453\n"
     ]
    }
   ],
   "source": [
    "nb_items = nb_items_full\n",
    "pos_users = pos_users_full[p]\n",
    "pos_items = pos_items_full[p]\n",
    "print(\"max user: {}\".format(pos_users.max()))\n",
    "print(\"max item: {}\".format(pos_items.max()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "nb_users = len(sampler.num_regions)\n",
    "train_users = torch.from_numpy(pos_users).type(torch.LongTensor)\n",
    "train_items = torch.from_numpy(pos_items).type(torch.LongTensor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# produce things not change between epoch\n",
    "# mask for filtering duplicates with real sample\n",
    "# note: test data is removed before create mask, same as reference\n",
    "# create label\n",
    "negative_samples = 4\n",
    "train_label = torch.ones_like(train_users, dtype=torch.float32)\n",
    "neg_label = torch.zeros_like(train_label, dtype=torch.float32)\n",
    "neg_label = neg_label.repeat(negative_samples)\n",
    "train_label = torch.cat((train_label,neg_label))\n",
    "del neg_label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      ":::MLPv0.5.0 ncf 1569736180.287664652 (/home/cvds_lab/yury/mxt-experiments/mlperf/recommendation/pytorch/neumf.py:19) model_hp_mf_dim: 64\n",
      "\n",
      ":::MLPv0.5.0 ncf 1569736193.064679861 (/home/cvds_lab/yury/mxt-experiments/mlperf/recommendation/pytorch/neumf.py:27) model_hp_mlp_layer_sizes: [256, 256, 128, 64]\n",
      "NeuMF(\n",
      "  (mf_user_embed): Embedding(2197225, 64)\n",
      "  (mf_item_embed): Embedding(855776, 64)\n",
      "  (mlp_user_embed): Embedding(2197225, 128)\n",
      "  (mlp_item_embed): Embedding(855776, 128)\n",
      "  (mlp): ModuleList(\n",
      "    (0): Linear(in_features=256, out_features=256, bias=True)\n",
      "    (1): ReLU()\n",
      "    (2): Linear(in_features=256, out_features=128, bias=True)\n",
      "    (3): ReLU()\n",
      "    (4): Linear(in_features=128, out_features=64, bias=True)\n",
      "    (5): ReLU()\n",
      "  )\n",
      "  (final_mf): Linear(in_features=64, out_features=1, bias=True)\n",
      "  (final_mlp): Linear(in_features=64, out_features=1, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "from neumf import NeuMF\n",
    "\n",
    "# Create model\n",
    "model = NeuMF(2197225, 855776,\n",
    "              mf_dim=64, mf_reg=0.,\n",
    "              mlp_layer_sizes=[256, 256, 128, 64],\n",
    "              mlp_layer_regs=[0. for i in [256, 256, 128, 64]])\n",
    "\n",
    "# Add optimizer and loss to graph\n",
    "params = model.parameters()\n",
    "\n",
    "learning_rate = 0.001\n",
    "beta1 = 0.9\n",
    "beta2 = 0.999\n",
    "eps = 1e-8\n",
    "optimizer = torch.optim.Adam(params, lr=learning_rate, betas=(beta1, beta2), eps=eps)\n",
    "criterion = nn.BCEWithLogitsLoss(reduction = 'none') # use torch.mean() with dim later to avoid copy to host\n",
    "\n",
    "model = model.cuda()\n",
    "criterion = criterion.cuda()\n",
    "        \n",
    "print(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "negatives = generate_negatives(sampler, negative_samples, train_users.numpy())\n",
    "negatives = torch.from_numpy(negatives)\n",
    "neg_users = negatives[:, 0]\n",
    "neg_items = negatives[:, 1]\n",
    "epoch_users = torch.cat((train_users,neg_users))\n",
    "epoch_items = torch.cat((train_items,neg_items))\n",
    "del neg_users, neg_items\n",
    "\n",
    "# shuffle prepared data and split into batches\n",
    "epoch_indices = torch.randperm(len(epoch_users))\n",
    "epoch_size = len(epoch_indices)\n",
    "epoch_users = epoch_users[epoch_indices]\n",
    "epoch_items = epoch_items[epoch_indices]\n",
    "epoch_label = train_label[epoch_indices]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "cal_set_f = os.path.join(data_dir, 'cal_set')\n",
    "# f = open(cal_set_f, 'w')\n",
    "cal_set_dict = {'users': epoch_users, 'items': epoch_items, 'labels': epoch_label}\n",
    "torch.save(cal_set_dict, cal_set_f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'items': tensor([638857, 531419, 625382,  ..., 162897, 678570, 546731]),\n",
       " 'labels': tensor([0., 0., 0.,  ..., 0., 0., 0.]),\n",
       " 'users': tensor([1222745,   74998, 1579138,  ...,  363161, 1196984, 1181740])}"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.load(cal_set_f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 0/196 [00:00<?, ?it/s]"
     ]
    }
   ],
   "source": [
    "batch_size = 256\n",
    "local_batch = batch_size\n",
    "epochs = 1\n",
    "for epoch in range(epochs):\n",
    "    negatives = generate_negatives(sampler, negative_samples, train_users.numpy())\n",
    "    negatives = torch.from_numpy(negatives)\n",
    "    neg_users = negatives[:, 0]\n",
    "    neg_items = negatives[:, 1]\n",
    "    epoch_users = torch.cat((train_users,neg_users))\n",
    "    epoch_items = torch.cat((train_items,neg_items))\n",
    "    del neg_users, neg_items\n",
    "    \n",
    "    # shuffle prepared data and split into batches\n",
    "    epoch_indices = torch.randperm(len(epoch_users))\n",
    "    epoch_size = len(epoch_indices)\n",
    "    epoch_users = epoch_users[epoch_indices]\n",
    "    epoch_items = epoch_items[epoch_indices]\n",
    "    epoch_label = train_label[epoch_indices]\n",
    "    epoch_users_list = epoch_users.split(local_batch)\n",
    "    epoch_items_list = epoch_items.split(local_batch)\n",
    "    epoch_label_list = epoch_label.split(local_batch)\n",
    "    \n",
    "    # only print progress bar on rank 0\n",
    "    num_batches = (epoch_size + batch_size - 1) // batch_size\n",
    "    qbar = tqdm(range(num_batches))\n",
    "    \n",
    "#     for i in qbar:\n",
    "        # selecting input from prepared data\n",
    "#         user = epoch_users_list[i].cuda()\n",
    "#         item = epoch_items_list[i].cuda()\n",
    "#         label = epoch_label_list[i].view(-1,1).cuda()\n",
    "\n",
    "#         for p in model.parameters():\n",
    "#             p.grad = None\n",
    "\n",
    "#         outputs = model(user, item)\n",
    "#         loss = criterion(outputs, label).float()\n",
    "#         loss = torch.mean(loss.view(-1), 0)\n",
    "#         print(loss.item())\n",
    "\n",
    "#         loss.backward()\n",
    "#         optimizer.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 6,  5, 14,  5, 12,  2, 13,  3,  3, 13, 14,  6,  4,  3,  9,  3,  9,  6,\n",
       "         4,  2,  2, 13,  6,  0, 12, 11,  6,  9,  1,  0,  9,  5,  9, 10,  6, 13,\n",
       "        13,  7,  4,  4,  9, 12,  3,  2,  5,  3,  3, 11, 11,  7,  6,  6, 12,  9,\n",
       "         9,  0,  6,  5,  3, 13,  6,  3,  3,  7,  3,  8,  6,  9,  6, 13,  3, 12,\n",
       "         9,  9,  9,  1, 14,  1, 10,  3, 13,  9,  2, 11,  3,  2,  3,  7,  4,  9,\n",
       "         0, 10, 13,  3,  1,  0, 14, 13,  6,  6,  6, 11,  2,  3, 12,  5, 12,  8,\n",
       "         6,  6,  4, 14,  4,  5,  8,  9, 12,  4,  2,  9, 13,  9,  7,  7,  7, 10,\n",
       "        12,  6,  0,  9,  6, 12, 11, 12,  2,  6,  3,  2, 11,  9,  5, 13,  0,  3,\n",
       "         4,  3,  1,  6,  6, 11,  3,  2,  3,  6,  4, 13, 10,  8,  5, 13, 12,  8,\n",
       "         7,  6,  8, 11,  5,  3,  1,  5,  7,  3,  5,  0,  2,  4, 12,  4, 12,  3,\n",
       "         7,  9,  1, 11, 10,  3,  1, 12, 15,  8, 13,  1,  7, 10,  6, 11,  9,  3,\n",
       "        12,  9,  3,  2,  1,  1, 13,  7, 13, 11, 12,  1, 10, 11, 13,  1, 14,  8,\n",
       "        11, 11,  3,  0,  3, 14,  7,  6,  5, 11, 12,  4,  4,  2,  5, 13, 14,  9,\n",
       "         2, 12, 12, 10,  2, 14,  6, 10,  4, 11, 13, 15,  1, 10, 13,  6, 10, 12,\n",
       "         7,  9,  2, 11])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "epoch_users_list[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([738434, 435846, 151808, 832544, 420195, 177795, 382987, 563540, 696714,\n",
       "        150602, 359062, 843844, 534328, 398248, 736186, 241708, 168527, 677968,\n",
       "        130986, 313454,  60469, 701909, 582572, 215296, 614220, 377902, 203404,\n",
       "        615900, 296711,  39836, 389947, 594673, 300492, 489974, 412205, 498236,\n",
       "        728268, 510405, 394506, 188835, 443500, 432449, 234825, 255006, 694706,\n",
       "        741851, 375773, 320001, 537916,  17696,  69316, 581796, 203522, 670856,\n",
       "        851786, 629876, 789018, 766496, 498923, 339141, 322131, 611526, 506599,\n",
       "        187971, 356328, 266931,  52219, 622618, 402324, 547895, 628472, 637481,\n",
       "         35095, 770339, 121111, 283057, 616152, 660983,   1013, 137968, 457527,\n",
       "        387486, 232585, 575139, 248926,   1937, 669186, 386263, 194287, 183478,\n",
       "        419780, 245735, 319916, 376053, 192176, 148524, 384232, 625605,  76647,\n",
       "        217490, 665915, 394617, 113411, 390937, 417765, 619537, 364686,  96900,\n",
       "        230661, 411551,  95168, 507191, 823957, 134680, 235354, 514839, 130342,\n",
       "        335744, 360129, 344848,  35244, 674393, 215640, 556761, 568300, 192358,\n",
       "        591826, 661963, 222002, 236989, 478114, 151792, 114619, 535813, 194579,\n",
       "        233194, 775116, 337148, 491541, 702655, 620253, 556566, 100530, 573750,\n",
       "        422964, 268319, 206238,  51704, 450501, 593402, 241639, 827090, 249737,\n",
       "        677479, 161412, 660144, 479663, 515468, 599979,  45435, 186324, 625606,\n",
       "        821712, 387946, 280882, 538528, 193373, 385732, 161483, 415298, 446320,\n",
       "         34505, 553337,  11480, 510496, 702642, 259629, 139930, 564312, 146379,\n",
       "        362456, 713095, 238642, 541356, 226755, 664893, 187416,  30025, 638746,\n",
       "        491861,  71523, 181447,   3648,   2211, 212724, 635825, 186456, 617249,\n",
       "        499190, 259743, 848268, 568955, 153183, 773245, 259411, 614240, 280343,\n",
       "        367936,  50428,  60232, 197631, 291021, 411645, 503317, 319685, 181915,\n",
       "         44393, 191639, 247104, 762857, 370643, 210927, 381100, 640193, 747015,\n",
       "        173095, 745717, 139860,  79467, 526345, 724957,  56258, 152401, 594502,\n",
       "        417328, 535647,  43808, 116336, 677389, 625360,  54629, 539403, 228163,\n",
       "        835042, 379052,  20775, 512559, 476061, 146903, 552055, 304643, 818337,\n",
       "        250835, 829281, 665536,  83084])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "epoch_items_list[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([13,  3,  3,  9,  0, 11,  9,  9,  4, 11,  7,  9,  3,  2,  7,  0,  3,  1,\n",
      "        11,  3,  8,  6,  4,  5,  8,  0, 11,  3,  5,  3,  1,  9, 10,  1, 12,  6,\n",
      "         9,  1,  8, 11,  3, 14,  7,  9, 12,  4, 13, 15])\n",
      "tensor([382987, 398248, 241708, 168527, 215296, 377902, 615900, 389947, 188835,\n",
      "        320001, 187971, 387486, 248926,   1937, 386263, 419780, 376053, 192176,\n",
      "        394617, 390937,  96900, 411551, 823957, 134680, 235354, 100530, 593402,\n",
      "        241639, 193373, 385732, 161483, 713095, 226755, 187416,  30025, 212724,\n",
      "        186456, 153183, 181915, 191639, 247104, 210927, 381100, 594502,  43808,\n",
      "        228163, 379052,  20775])\n"
     ]
    }
   ],
   "source": [
    "nz = epoch_label_list[0].nonzero().flatten()\n",
    "print(epoch_users_list[0][nz])\n",
    "print(epoch_items_list[0][nz])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
