{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import scipy.io as sio\n",
    "import os\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "\n",
    "# Build the data directory with os.path.join so the notebook runs on any OS\n",
    "# (the path used to be hard-coded with a Windows-only separator).\n",
    "datapath = os.path.join('..', 'data')\n",
    "dataset = 'CUB_DNA'  # This is the same CUB data, just the 6 classes lacking any DNA barcodes are removed\n",
    "\n",
    "# CUB barcodes together with their species names\n",
    "x2 = sio.loadmat(os.path.join(datapath, dataset, 'CUB_DNA.mat'))\n",
    "# Dataset containing all bird DNA barcodes extracted from the COI gene\n",
    "x = sio.loadmat(os.path.join(datapath, dataset, 'Bird_DNA.mat'))\n",
    "# Scientific names of the CUB unseen classes\n",
    "us_cls = sio.loadmat(os.path.join(datapath, dataset, 'CUB_unseen_classes_sci_name.mat'))\n",
    "\n",
    "# Pull the raw MATLAB arrays out of the loaded dictionaries\n",
    "barcodes = x['nucleotides_aligned'][0]\n",
    "unseen_cls = us_cls['us_classes']\n",
    "bird_species = x['species'][0]\n",
    "cub_barcodes = x2['nucleotides_aligned'][0]\n",
    "c_labels = x2['species'][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "26409\n"
     ]
    }
   ],
   "source": [
    "# Total number of barcode records loaded from Bird_DNA.mat (before any filtering)\n",
    "N = len(barcodes)\n",
    "print(N)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unwrap the nested arrays returned by loadmat into flat containers.\n",
    "\n",
    "# Bird data: keep only the records whose barcode is non-empty\n",
    "keep = [i for i in range(N) if len(barcodes[i][0]) > 0]\n",
    "b_species = np.asarray([bird_species[i][0] for i in keep])\n",
    "bcodes = np.asarray([barcodes[i][0] for i in keep])\n",
    "\n",
    "# CUB data: every record is kept\n",
    "n_cub_records = len(cub_barcodes)\n",
    "cub_bcodes = [cub_barcodes[i][0] for i in range(n_cub_records)]\n",
    "c_labels_ = [c_labels[i][0] for i in range(n_cub_records)]\n",
    "\n",
    "# Names of the unseen CUB classes\n",
    "unseen_classes = [row[0][0] for row in unseen_cls]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sample barcode: A-----------------------------------------------------CTATACCTAATCTTCGGCGCCTGAGCTGGTATAGTCGGCACCGCCCTCAGCCTACTCATTCGCGCAGAACTCGGCCAACCAGGCACACTCCTAGGCGACGACCAAATCTACAACGTAATCGTTACTGCACATGCCTTTGTAATAATCTTCTTCATAGTTATACCAATCATGATTGGAGGATTCGGAAACTGACTTGTCCCACTCATAATTGGAGCCCCCGACATAGCCTTTCCACGCATAAACAACATAAGCTTCTGGCTACTCCCTCCATCCTTCCTCCTCCTACTGGCCTCCTCAACAGTAGAAGCAGGAGCCGGCACTGGATGAACTGTCTACCCCCCATTAGCTGGCAACATAGCCCATGCCGGAGCTTCAGTAGACTTGGCCATCTTCTCCTTACACTTAGCTGGAGTCTCATCCATTCTAGGAGCAATCAACTTTATCACAACCGCCATCAACATAAAACCCCCAGCCCTCTCCCAGTACCAAACACCCCTATTCGTATGATCTGTCCTCATTACCGCCGTCCTTCTATTACTTTCACTCCCAGTCCTAGCCGCTGGCATCACTATACTACTCACAGACCGAAACCTAAACACAACATTCTTTGACCCTGCCGGCGGAGGCGATCCTATCCTATACCAACACCTCTTCTGATTCTTTGGACACCCAGAAGTCTACATCTTAATCCT---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------A\n"
     ]
    }
   ],
   "source": [
    "# Show the first barcode as a quick sanity check of the loaded data\n",
    "print('Sample barcode: %s' % bcodes[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "25712 697\n"
     ]
    }
   ],
   "source": [
    "# Excluding unseen classes of CUB data from training set\n",
    "# (np.isin replaces np.in1d, which is deprecated in recent NumPy releases)\n",
    "idx = np.isin(b_species, unseen_classes)\n",
    "\n",
    "b_species = b_species[~idx]\n",
    "bcodes = bcodes[~idx]\n",
    "# number of remaining samples, number of removed samples\n",
    "print(len(b_species), idx.sum())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of classes: 999\n"
     ]
    }
   ],
   "source": [
    "# Map every species name to an integer class id (1-based, Matlab indexing)\n",
    "uy, inverse, counts = np.unique(b_species, return_inverse=True, return_counts=True)\n",
    "labels = (inverse.ravel() + 1).astype('int32')\n",
    "\n",
    "# List the species that have fewer than 10 barcodes\n",
    "for name in uy[counts < 10]:\n",
    "    print(name)\n",
    "\n",
    "print('Number of classes: %d' % len(np.unique(labels)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Aechmophorus occidentalis\n",
      "Aethia cristatella\n",
      "Aethia psittacula\n",
      "Aethia pusilla\n",
      "Ammodramus bairdii\n",
      "Ammodramus henslowii\n",
      "Ammodramus leconteii\n",
      "Ammodramus maritimus\n",
      "Ammodramus savannarum\n",
      "Amphispiza bilineata\n",
      "Antrostomus carolinensis\n",
      "Antrostomus vociferus\n",
      "Aphelocoma coerulescens\n",
      "Bombycilla cedrorum\n",
      "Calypte anna\n",
      "Campylorhynchus brunneicapillus\n",
      "Cepphus columba\n",
      "Ceryle rudis\n",
      "Chordeiles minor\n",
      "Coccyzus americanus\n",
      "Coccyzus erythropthalmus\n",
      "Coccyzus minor\n",
      "Colibri thalassinus\n",
      "Contopus cooperi\n",
      "Corvus albicollis\n",
      "Corvus ossifragus\n",
      "Cyanocitta cristata\n",
      "Cyanocorax yncas\n",
      "Dolichonyx oryzivorus\n",
      "Dryocopus pileatus\n",
      "Empidonax flaviventris\n",
      "Empidonax minimus\n",
      "Empidonax virescens\n",
      "Euphagus carolinus\n",
      "Fratercula corniculata\n",
      "Geococcyx californianus\n",
      "Helmitheros vermivorum\n",
      "Icteria virens\n",
      "Icterus cucullatus\n",
      "Icterus parisorum\n",
      "Icterus spurius\n",
      "Lamprotornis nitens\n",
      "Lanius ludovicianus\n",
      "Larus californicus\n",
      "Larus delawarensis\n",
      "Larus glaucescens\n",
      "Larus heermanni\n",
      "Larus occidentalis\n",
      "Leiothlypis ruficapilla\n",
      "Leuconotopicus borealis\n",
      "Leucosticte tephrocotis\n",
      "Limnothlypis swainsonii\n",
      "Megaceryle alcyon\n",
      "Melanerpes erythrocephalus\n",
      "Myiarchus crinitus\n",
      "Oreoscoptes montanus\n",
      "Pagophila eburnea\n",
      "Parkesia motacilla\n",
      "Passerina caerulea\n",
      "Passerina ciris\n",
      "Passerina cyanea\n",
      "Pelecanus erythrorhynchos\n",
      "Pelecanus occidentalis\n",
      "Petrochelidon pyrrhonota\n",
      "Phalacrocorax penicillatus\n",
      "Phalacrocorax urile\n",
      "Pheucticus ludovicianus\n",
      "Picoides dorsalis\n",
      "Pipilo chlorurus\n",
      "Pipilo erythrophthalmus\n",
      "Piranga olivacea\n",
      "Podiceps auritus\n",
      "Podiceps nigricollis\n",
      "Podilymbus podiceps\n",
      "Pooecetes gramineus\n",
      "Quiscalus major\n",
      "Rissa brevirostris\n",
      "Salpinctes obsoletus\n",
      "Sayornis phoebe\n",
      "Selasphorus rufus\n",
      "Setophaga cerulea\n",
      "Setophaga citrina\n",
      "Setophaga discolor\n",
      "Setophaga pensylvanica\n",
      "Setophaga pinus\n",
      "Setophaga tigrina\n",
      "Sitta carolinensis\n",
      "Spizella pallida\n",
      "Spizella pusilla\n",
      "Sterna forsteri\n",
      "Sternula antillarum\n",
      "Sturnella neglecta\n",
      "Thalasseus elegans\n",
      "Thryothorus ludovicianus\n",
      "Troglodytes hiemalis\n",
      "Tyrannus dominicensis\n",
      "Tyrannus forficatus\n",
      "Vermivora chrysoptera\n",
      "Vermivora cyanoptera\n",
      "Vireo atricapilla\n",
      "Vireo flavifrons\n",
      "Xanthocephalus xanthocephalus\n",
      "Zonotrichia querula\n",
      "194 [  1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17  18\n",
      "  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  34  35  36\n",
      "  37  38  39  40  41  42  43  44  45  46  47  48  49  50  51  52  53  54\n",
      "  55  56  57  58  59  60  61  62  63  64  65  66  67  68  69  70  71  72\n",
      "  73  74  75  76  77  78  79  80  81  82  83  84  85  86  87  88  89  90\n",
      "  91  92  93  94  95  96  97  98  99 100 101 102 103 104 105 106 107 108\n",
      " 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126\n",
      " 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144\n",
      " 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162\n",
      " 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180\n",
      " 181 182 183 184 185 186 187 188 189 190 191 192 193 194]\n"
     ]
    }
   ],
   "source": [
    "# Integer class ids (1-based, Matlab indexing) for the CUB barcodes.\n",
    "# The names of the CUB classes with fewer than 10 samples are printed, just for information.\n",
    "c_labels_ = np.asarray(c_labels_)\n",
    "uy, inverse, counts = np.unique(c_labels_, return_inverse=True, return_counts=True)\n",
    "c_labels = (inverse.ravel() + 1).astype('int32')\n",
    "\n",
    "for name in uy[counts < 10]:\n",
    "    print(name)\n",
    "\n",
    "print(len(np.unique(c_labels)), np.unique(c_labels))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras.preprocessing.text import Tokenizer\n",
    "# Converting base pairs, ACTG, into numbers: Tokenization\n",
    "\n",
    "### For BIRD data: the character vocabulary is learned from the training barcodes\n",
    "tokenizer = Tokenizer(char_level=True)\n",
    "tokenizer.fit_on_texts(bcodes)\n",
    "sequence_of_int = tokenizer.texts_to_sequences(bcodes)\n",
    "\n",
    "### For CUB data barcodes\n",
    "# Reuse the tokenizer fitted on the training barcodes. Tokenizer numbers the\n",
    "# characters by their frequency in the texts it was fitted on, so a second\n",
    "# tokenizer fitted on the CUB barcodes could assign different integers (and\n",
    "# hence different one-hot channels) than the ones the CNN is trained with.\n",
    "sequence_of_int_cub = tokenizer.texts_to_sequences(cub_bcodes)\n",
    "\n",
    "# texts_to_sequences silently drops characters that are not in the fitted\n",
    "# vocabulary, which would shift the remaining positions; make that visible.\n",
    "unknown_chars = {ch for code in cub_bcodes for ch in str(code).lower()} - set(tokenizer.word_index)\n",
    "if unknown_chars:\n",
    "    print('Warning: CUB barcodes contain characters unseen in training:', sorted(unknown_chars))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 1548.0\n"
     ]
    }
   ],
   "source": [
    "# Unlike the INSECT data, where the majority of barcodes has a sequence length of 658,\n",
    "# the bird barcode lengths have a high variance.\n",
    "# These barcodes are aligned ones, as in the INSECT data.\n",
    "ll = np.array([len(sequence_of_int[i]) for i in range(len(bcodes))], dtype=float).reshape(-1, 1)\n",
    "cnt = int(np.count_nonzero(ll == 658))\n",
    "# number of barcodes of length 658, median barcode length\n",
    "print(cnt, np.median(ll))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Number of barcodes available for each species\n",
    "cl = len(np.unique(b_species))  # number of classes\n",
    "N = len(b_species)              # number of samples (overrides the earlier record count)\n",
    "seq_len = np.zeros((cl))        # allocated but not filled in\n",
    "# labels are 1-based, so shift them to 0-based bins before counting\n",
    "seq_cnt = np.bincount(labels[:N] - 1, minlength=cl).astype(float)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# classes: 999\n"
     ]
    }
   ],
   "source": [
    "# Number of distinct species (classes) found in b_species\n",
    "print('# classes: %d' % cl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One-hot encode all CUB DNA barcodes. Note that this data is not used during training;\n",
    "# it is prepared here so that the final DNA embeddings can be predicted after training is done.\n",
    "\n",
    "# Fixed sequence length: longer barcodes are truncated, shorter ones are left zero-padded.\n",
    "# Original note: changing it to the median value, 1548, has only an infinitesimal effect.\n",
    "sl = 1500\n",
    "N_cub = len(cub_bcodes)\n",
    "allX = np.zeros((N_cub, sl, 5))\n",
    "for i in range(N_cub):\n",
    "    tokens = np.asarray(sequence_of_int_cub[i][:sl], dtype=int)\n",
    "    # token t is written to channel t-1; all tokens >= 5 land in the last channel\n",
    "    channels = np.minimum(tokens - 1, 4)\n",
    "    allX[i, np.arange(len(channels)), channels] = 1.0\n",
    "            "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize the training matrix and labels\n",
    "# (allocated for all N samples; the arrays are sliced down to Nc rows in a later cell)\n",
    "trainX=np.zeros((N,sl,5))\n",
    "trainY=np.zeros((N,cl))\n",
    "labelY=np.zeros(N)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Here is where we cap the class sample size to 50 for a balanced training set.\n",
    "# Classes with fewer than 10 barcodes in total are skipped entirely.\n",
    "\n",
    "# Nc counts the samples written so far, i.e. it is also the next free row.\n",
    "# (It used to start at -1 and end up as the index of the last filled row, so the\n",
    "# later trainX[0:Nc] slice silently dropped the last selected sample.)\n",
    "Nc = 0\n",
    "clas_cnt = np.zeros((cl))\n",
    "for i in range(N):\n",
    "    cls = labels[i] - 1  # labels are 1-based\n",
    "    clas_cnt[cls] += 1\n",
    "    if seq_cnt[cls] >= 10 and clas_cnt[cls] <= 50:\n",
    "        # one-hot encode the first sl tokens: token t goes to channel t-1,\n",
    "        # all tokens >= 5 land in the last channel\n",
    "        tokens = np.asarray(sequence_of_int[i][:sl], dtype=int)\n",
    "        channels = np.minimum(tokens - 1, 4)\n",
    "        trainX[Nc, np.arange(len(channels)), channels] = 1.0\n",
    "\n",
    "        trainY[Nc][cls] = 1.0\n",
    "        labelY[Nc] = cls\n",
    "        Nc += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Final training size: 21230\n"
     ]
    }
   ],
   "source": [
    "# Report Nc, the counter maintained by the balancing loop in the previous cell\n",
    "print('Final training size: %d' % Nc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(21230, 1500, 5) (21230, 999) (21230,)\n"
     ]
    }
   ],
   "source": [
    "# Selecting the balanced training set: trim the pre-allocated arrays to their first Nc rows\n",
    "trainX, trainY, labelY = (arr[:Nc] for arr in (trainX, trainY, labelY))\n",
    "print(trainX.shape, trainY.shape, labelY.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 10.0\n"
     ]
    }
   ],
   "source": [
    "# Drop the class columns that did not receive a single training sample\n",
    "# NOTE(review): labelY is not re-indexed here, so if a column is ever removed\n",
    "# its ids no longer line up with the columns of trainY.\n",
    "has_samples = trainY.any(axis=0)\n",
    "idx = np.argwhere(~has_samples)\n",
    "trainY = trainY[:, has_samples]\n",
    "\n",
    "# number of removed classes, smallest per-class sample count\n",
    "print(len(idx), min(np.sum(trainY,axis=0)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(21230, 999)\n",
      "(21230,)\n"
     ]
    }
   ],
   "source": [
    "# Sanity check: shapes of the one-hot targets and of the integer labels\n",
    "print(trainY.shape)\n",
    "print(labelY.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(21230, 1500, 5, 1) (2772, 1500, 5, 1)\n"
     ]
    }
   ],
   "source": [
    "# Append a trailing axis of size 1 to both arrays, as expected by the CNN input\n",
    "trainX = trainX[..., np.newaxis]\n",
    "allX = allX[..., np.newaxis]\n",
    "print(trainX.shape, allX.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# Splitting the training set into train and validation\n",
    "# (80/20 split with a fixed seed; X_test / y_test are passed to model.fit as validation data)\n",
    "X_train, X_test, y_train, y_test = train_test_split(trainX, trainY, test_size=0.2, random_state=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(16984, 999)\n"
     ]
    }
   ],
   "source": [
    "# Shape of the training targets after the split\n",
    "print(y_train.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Classic CNN Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras import datasets, layers, models, optimizers, callbacks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CNN model architecture: three Conv2D -> BatchNormalization -> MaxPooling2D stages,\n",
    "# followed by a 400-unit tanh layer and a linear output layer with one unit per class.\n",
    "#\n",
    "# Both Dense layers get explicit names. Auto-generated names depend on how many layers\n",
    "# were already created in the kernel (re-running this cell yields e.g. 'dense_2'),\n",
    "# which would break the model.get_layer('dense') lookup used to extract the embeddings.\n",
    "model = models.Sequential()\n",
    "model.add(layers.Conv2D(64, (3,3), activation='relu', input_shape=(sl, 5,1),padding=\"SAME\"))\n",
    "model.add(layers.BatchNormalization())\n",
    "model.add(layers.MaxPooling2D((3,1)))\n",
    "model.add(layers.Conv2D(32, (3,3), activation='relu',padding=\"SAME\"))\n",
    "model.add(layers.BatchNormalization())\n",
    "model.add(layers.MaxPooling2D((3,1)))\n",
    "model.add(layers.Conv2D(16, (3,3), activation='relu',padding=\"SAME\"))\n",
    "model.add(layers.BatchNormalization())\n",
    "model.add(layers.MaxPooling2D((3,1)))\n",
    "model.add(layers.Flatten())\n",
    "model.add(layers.BatchNormalization())\n",
    "# Embedding layer: its activations are exported as the DNA embeddings\n",
    "model.add(layers.Dense(400, activation='tanh', name='dense'))\n",
    "# No softmax here: the loss is configured with from_logits=True\n",
    "model.add(layers.Dense(y_train.shape[1], name='logits'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "conv2d (Conv2D)              (None, 1500, 5, 64)       640       \n",
      "_________________________________________________________________\n",
      "batch_normalization (BatchNo (None, 1500, 5, 64)       256       \n",
      "_________________________________________________________________\n",
      "max_pooling2d (MaxPooling2D) (None, 500, 5, 64)        0         \n",
      "_________________________________________________________________\n",
      "conv2d_1 (Conv2D)            (None, 500, 5, 32)        18464     \n",
      "_________________________________________________________________\n",
      "batch_normalization_1 (Batch (None, 500, 5, 32)        128       \n",
      "_________________________________________________________________\n",
      "max_pooling2d_1 (MaxPooling2 (None, 166, 5, 32)        0         \n",
      "_________________________________________________________________\n",
      "conv2d_2 (Conv2D)            (None, 166, 5, 16)        4624      \n",
      "_________________________________________________________________\n",
      "batch_normalization_2 (Batch (None, 166, 5, 16)        64        \n",
      "_________________________________________________________________\n",
      "max_pooling2d_2 (MaxPooling2 (None, 55, 5, 16)         0         \n",
      "_________________________________________________________________\n",
      "flatten (Flatten)            (None, 4400)              0         \n",
      "_________________________________________________________________\n",
      "batch_normalization_3 (Batch (None, 4400)              17600     \n",
      "_________________________________________________________________\n",
      "dense (Dense)                (None, 400)               1760400   \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (None, 999)               400599    \n",
      "=================================================================\n",
      "Total params: 2,202,775\n",
      "Trainable params: 2,193,751\n",
      "Non-trainable params: 9,024\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# Layer-by-layer overview of output shapes and parameter counts\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step-decay learning rate scheduler\n",
    "def step_decay(epoch):\n",
    "    \"\"\"Return the learning rate for the given 0-based epoch index.\n",
    "\n",
    "    The rate starts at 0.001 and is halved every 2 epochs; because of the\n",
    "    (1 + epoch) term the first halving already applies to epoch 1.\n",
    "    \"\"\"\n",
    "    initial_lrate = 0.001\n",
    "    drop = 0.5\n",
    "    epochs_drop = 2.0\n",
    "    lrate = initial_lrate * np.power(drop, np.floor((1+epoch)/epochs_drop))\n",
    "    return lrate\n",
    "\n",
    "\n",
    "class LossHistory(callbacks.Callback):\n",
    "    \"\"\"Record the training loss and the scheduled learning rate of every epoch.\"\"\"\n",
    "\n",
    "    def on_train_begin(self, logs=None):\n",
    "        self.losses = []\n",
    "        self.lr = []\n",
    "\n",
    "    def on_epoch_end(self, epoch, logs=None):\n",
    "        logs = logs or {}\n",
    "        self.losses.append(logs.get('loss'))\n",
    "        # `epoch` is the 0-based index of the epoch that just finished, i.e. the value\n",
    "        # the LearningRateScheduler passed to step_decay for that epoch. The previous\n",
    "        # step_decay(len(self.losses)) recorded the rate of the *following* epoch.\n",
    "        self.lr.append(step_decay(epoch))\n",
    "\n",
    "\n",
    "loss_history = LossHistory()\n",
    "lrate = callbacks.LearningRateScheduler(step_decay)\n",
    "callbacks_list = [loss_history, lrate]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 16984 samples, validate on 4246 samples\n",
      "Epoch 1/5\n",
      "16984/16984 [==============================] - 168s 10ms/sample - loss: 1.6933 - accuracy: 0.7631 - top_k_categorical_accuracy: 0.8573 - val_loss: 3.5235 - val_accuracy: 0.5787 - val_top_k_categorical_accuracy: 0.8295\n",
      "Epoch 2/5\n",
      "16984/16984 [==============================] - 167s 10ms/sample - loss: 0.1881 - accuracy: 0.9564 - top_k_categorical_accuracy: 0.9953 - val_loss: 0.3157 - val_accuracy: 0.9444 - val_top_k_categorical_accuracy: 0.9894\n",
      "Epoch 3/5\n",
      "16984/16984 [==============================] - 170s 10ms/sample - loss: 0.1145 - accuracy: 0.9676 - top_k_categorical_accuracy: 0.9972 - val_loss: 0.2422 - val_accuracy: 0.9520 - val_top_k_categorical_accuracy: 0.9903\n",
      "Epoch 4/5\n",
      "16984/16984 [==============================] - 168s 10ms/sample - loss: 0.0731 - accuracy: 0.9761 - top_k_categorical_accuracy: 0.9986 - val_loss: 0.1813 - val_accuracy: 0.9555 - val_top_k_categorical_accuracy: 0.9925\n",
      "Epoch 5/5\n",
      "16984/16984 [==============================] - 169s 10ms/sample - loss: 0.0643 - accuracy: 0.9779 - top_k_categorical_accuracy: 0.9984 - val_loss: 0.1874 - val_accuracy: 0.9564 - val_top_k_categorical_accuracy: 0.9906\n"
     ]
    }
   ],
   "source": [
    "# NOTE(review): the imported function is not referenced below; the metric is selected by its string name\n",
    "from tensorflow.keras.metrics import top_k_categorical_accuracy\n",
    "#opt = tf.keras.optimizers.SGD(momentum=0.9, nesterov=True)\n",
    "opt = tf.keras.optimizers.Adam()\n",
    "# from_logits=True because the last Dense layer of the model has no activation\n",
    "model.compile(optimizer=opt, loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),\n",
    "              metrics=['accuracy','top_k_categorical_accuracy'])\n",
    "# Validation time: train on the 80% split and monitor the held-out 20%\n",
    "history = model.fit(X_train, y_train, epochs=5, batch_size = 32, validation_data=(X_test, y_test),\n",
    "                      callbacks=callbacks_list, verbose=1)\n",
    "\n",
    "# Final Test time: train on the whole balanced training set instead (disabled)\n",
    "#history = model.fit(trainX, trainY, epochs=5, callbacks=callbacks_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Learning rates through epochs: [0.0005, 0.0005, 0.00025, 0.00025, 0.000125]\n"
     ]
    }
   ],
   "source": [
    "# Learning rates collected by the LossHistory callback, one entry per finished epoch\n",
    "print('Learning rates through epochs:', loss_history.lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From C:\\Users\\sbadirli\\.conda\\envs\\resnet\\lib\\site-packages\\tensorflow_core\\python\\ops\\resource_variable_ops.py:1786: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "If using Keras pass *_constraint arguments to layers.\n",
      "INFO:tensorflow:Assets written to: CUB_DNA_5e_adam_aligned\\assets\n"
     ]
    }
   ],
   "source": [
    "### Saving and loading the trained model (uncomment the line that is needed)\n",
    "# model.save('CUB_DNA_5e_adam_aligned')\n",
    "# model = tf.keras.models.load_model('CUB_DNA_5e_adam_aligned')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(2772, 400)\n"
     ]
    }
   ],
   "source": [
    "from tensorflow.keras.models import Model\n",
    "\n",
    "# DNA embeddings of all CUB barcodes: the activations of the layer named 'dense',\n",
    "# obtained from a sub-model that shares the trained model's input.\n",
    "layer_name = 'dense'\n",
    "embedding_layer = model.get_layer(layer_name)\n",
    "model2 = Model(inputs=model.input, outputs=embedding_layer.output)\n",
    "dna_embeddings = model2.predict(allX)\n",
    "print(dna_embeddings.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.savetxt(\"cub_cnn_embeddings_400_5e_adam_aligned.csv\", dna_embeddings, delimiter=\",\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
