{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Dataset_Generator_SingleFeature",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "P7V1GOVFKmNu",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import numpy as np\n",
        "from numpy import genfromtxt\n",
        "from sklearn import svm\n",
        "from sklearn.metrics import accuracy_score\n",
        "import random\n",
        "import matplotlib.pyplot as plt\n",
        "import xgboost as xgb\n",
        "from xgboost import XGBClassifier\n",
        "from sklearn.tree import DecisionTreeClassifier\n",
        "from sklearn import svm\n",
        "from sklearn.linear_model import LogisticRegression\n",
        "import statsmodels.api as sm\n",
        "import numpy.linalg as la\n",
        "import scipy.io as sio\n",
        "import pickle\n",
        "from cvxopt import matrix, solvers\n",
        "from sklearn.decomposition import PCA\n",
        "from sklearn.discriminant_analysis import LinearDiscriminantAnalysis"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "orVUL-X2RI-X",
        "colab_type": "text"
      },
      "source": [
        "**Code for datasets:**\n",
        "\n",
        "The next two sections generate the .p and .m files used for training and testing. You do not need to run this again if those files are loaded."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "o9XGDWirRH59",
        "colab_type": "code",
        "outputId": "1826af01-1aba-48b1-dad9-69190821916a",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 170
        }
      },
      "source": [
        "def preprocess(W):\n",
        "    u = np.min(W,axis = 0)\n",
        "    v = np.max(W,axis = 0)\n",
        "    l = W.shape[0]\n",
        "    t = W.shape[1]\n",
        "    for i in range(l):\n",
        "        for j in range(t):\n",
        "            W[i,j] = W[i,j] - u[j]\n",
        "            W[i,j]/=(v[j]-u[j])\n",
        "    return W\n",
        "\n",
        "datasets = []\n",
        "# Breast Cancer\n",
        "data_file = open('/content/breast-cancer-wisconsin.data', 'r')\n",
        "data = np.loadtxt(data_file, delimiter=\",\")[:, 1:]\n",
        "zero_indices = np.where(data[:, -1] == 2.0)\n",
        "one_indices = np.where(data[:, -1] != 2.0)\n",
        "data[:, -1][zero_indices] = 0\n",
        "data[:, -1][one_indices] = 1\n",
        "data = preprocess(data)\n",
        "cancer = data\n",
        "datasets.append(('cancer', data))\n",
        "\n",
        "# Pima Indians Diabetes\n",
        "data_file = open('/content/pima-indians-diabetes.csv', 'r')\n",
        "data = np.loadtxt(data_file, delimiter=\",\")\n",
        "data = preprocess(data)\n",
        "diabetes = data\n",
        "datasets.append(('diabetes', data))\n",
        "\n",
        "# Banknote\n",
        "data_file = open('/content/data_banknote_authentication.txt', 'r')\n",
        "data = np.loadtxt(data_file, delimiter=\",\")\n",
        "data = preprocess(data)\n",
        "banknote = data\n",
        "datasets.append(('banknote', data))\n",
        "\n",
        "# Ringnorm\n",
        "data = []\n",
        "file = open('/content/ringnorm.data', 'r')\n",
        "for line in file.readlines():\n",
        "  data.append([float(x) for x in line.split()])\n",
        "data = np.array(data)\n",
        "data = preprocess(data)\n",
        "ringnorm = data\n",
        "datasets.append(('ringnorm', data))\n",
        "\n",
        "# twonorm\n",
        "data = []\n",
        "file = open('/content/twonorm.data', 'r')\n",
        "for line in file.readlines():\n",
        "  data.append([float(x) for x in line.split()])\n",
        "data = np.array(data)\n",
        "data = preprocess(data)\n",
        "twonorm = data\n",
        "datasets.append(('twonorm', data))\n",
        "\n",
        "mat_file = dict()\n",
        "mat_file['cancer'] = cancer\n",
        "mat_file['ringnorm'] = ringnorm\n",
        "mat_file['twonorm'] = twonorm\n",
        "mat_file['diabetes'] = diabetes\n",
        "mat_file['banknote'] = banknote"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "[5. 1. 1. 1. 2. 1. 3. 1. 1. 0.]\n",
            "[  6.    148.     72.     35.      0.     33.6     0.627  50.      1.   ]\n",
            "[ 3.6216   8.6661  -2.8073  -0.44699  0.     ]\n",
            "[ 0.8494  2.177   0.5982  1.6894  3.1137 -3.406   3.7986 -2.6421  1.5779\n",
            " -0.1808 -0.2118  1.6327  4.664   1.0808 -1.1717 -1.6605  0.5775  1.6638\n",
            "  3.0895 -3.0276  0.    ]\n",
            "[-1.2036 -2.624   0.5963  1.3859 -1.3597  0.6758  1.0008 -0.9589 -1.3487\n",
            " -0.5572 -0.4398 -1.1223 -0.1817 -1.317  -0.3551 -1.422   0.1983 -3.0514\n",
            " -1.065  -0.8541  1.    ]\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HLcIk7ibonj3",
        "colab_type": "text"
      },
      "source": [
        "Generate biased training sets"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "yPlUBFFDEd4b",
        "colab_type": "code",
        "outputId": "0c5a059c-cee1-4d85-ff7f-8a8245619170",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        }
      },
      "source": [
        "trials = 100 # trials per dataset\n",
        "training_points = 100 # size of training dataset\n",
        "testing_points = 500 # size of testing dataset\n",
        "\n",
        "for name, data in datasets:\n",
        "  trial = 0\n",
        "  training_sets = []\n",
        "  testing_sets = []\n",
        "  while trial < trials:\n",
        "    print(\"trial {}\".format(trial))\n",
        "    data_dimensions = len(data[0]) - 2\n",
        "\n",
        "    # First select random feature\n",
        "    feature_to_bias = random.randint(0, data_dimensions)\n",
        "\n",
        "    # Now decide whether to bias it up or down\n",
        "    bias_direction = bool(random.getrandbits(1))\n",
        "\n",
        "    # Larger values will be selected with x4 probability or smaller values with x4 probability\n",
        "    factor = .25\n",
        "    if bias_direction:\n",
        "      factor = 4.0\n",
        "\n",
        "    train_sample_probs = np.array([1.0 for i in range(len(data))])\n",
        "    median = np.median(data[:, feature_to_bias])\n",
        "    biased_inds = np.where(data[:, feature_to_bias] > median)\n",
        "\n",
        "    train_sample_probs[biased_inds] = factor\n",
        "    train_sample_probs = train_sample_probs / sum(train_sample_probs)\n",
        "    inds = [i for i in range(len(data))]\n",
        "\n",
        "    train_inds = np.random.choice(inds, training_points, replace = False, p=train_sample_probs)\n",
        "    possible_test_inds = list(set(range(len(data))) - set(train_inds))\n",
        "    test_inds = np.random.choice(possible_test_inds, testing_points, replace = False)\n",
        "\n",
        "    # Check to make sure its well balanced\n",
        "    y_train = data[:, -1]\n",
        "    if sum(y_train) < 20:\n",
        "      print(\"labels not balanced, skipping this set\")\n",
        "      continue\n",
        "    \n",
        "    trial += 1\n",
        "    training_sets.append(train_inds)\n",
        "    testing_sets.append(test_inds)\n",
        "  train_inds_full = np.array(training_sets)\n",
        "  test_inds_full = np.array(testing_sets)\n",
        "  mat_file['train_inds_{}'.format(name)] = train_inds_full\n",
        "  mat_file['test_inds_{}'.format(name)] = test_inds_full\n",
        "sio.savemat('single_datasets.mat', mat_file)\n",
        "dsets = open('single_datasets.p', 'wb')\n",
        "pickle.dump(mat_file, dsets)\n"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "trial 0\n",
            "trial 1\n",
            "trial 2\n",
            "trial 3\n",
            "trial 4\n",
            "trial 5\n",
            "trial 6\n",
            "trial 7\n",
            "trial 8\n",
            "trial 9\n",
            "trial 10\n",
            "trial 11\n",
            "trial 12\n",
            "trial 13\n",
            "trial 14\n",
            "trial 15\n",
            "trial 16\n",
            "trial 17\n",
            "trial 18\n",
            "trial 19\n",
            "trial 20\n",
            "trial 21\n",
            "trial 22\n",
            "trial 23\n",
            "trial 24\n",
            "trial 25\n",
            "trial 26\n",
            "trial 27\n",
            "trial 28\n",
            "trial 29\n",
            "trial 30\n",
            "trial 31\n",
            "trial 32\n",
            "trial 33\n",
            "trial 34\n",
            "trial 35\n",
            "trial 36\n",
            "trial 37\n",
            "trial 38\n",
            "trial 39\n",
            "trial 40\n",
            "trial 41\n",
            "trial 42\n",
            "trial 43\n",
            "trial 44\n",
            "trial 45\n",
            "trial 46\n",
            "trial 47\n",
            "trial 48\n",
            "trial 49\n",
            "trial 50\n",
            "trial 51\n",
            "trial 52\n",
            "trial 53\n",
            "trial 54\n",
            "trial 55\n",
            "trial 56\n",
            "trial 57\n",
            "trial 58\n",
            "trial 59\n",
            "trial 60\n",
            "trial 61\n",
            "trial 62\n",
            "trial 63\n",
            "trial 64\n",
            "trial 65\n",
            "trial 66\n",
            "trial 67\n",
            "trial 68\n",
            "trial 69\n",
            "trial 70\n",
            "trial 71\n",
            "trial 72\n",
            "trial 73\n",
            "trial 74\n",
            "trial 75\n",
            "trial 76\n",
            "trial 77\n",
            "trial 78\n",
            "trial 79\n",
            "trial 80\n",
            "trial 81\n",
            "trial 82\n",
            "trial 83\n",
            "trial 84\n",
            "trial 85\n",
            "trial 86\n",
            "trial 87\n",
            "trial 88\n",
            "trial 89\n",
            "trial 90\n",
            "trial 91\n",
            "trial 92\n",
            "trial 93\n",
            "trial 94\n",
            "trial 95\n",
            "trial 96\n",
            "trial 97\n",
            "trial 98\n",
            "trial 99\n",
            "trial 0\n",
            "trial 1\n",
            "trial 2\n",
            "trial 3\n",
            "trial 4\n",
            "trial 5\n",
            "trial 6\n",
            "trial 7\n",
            "trial 8\n",
            "trial 9\n",
            "trial 10\n",
            "trial 11\n",
            "trial 12\n",
            "trial 13\n",
            "trial 14\n",
            "trial 15\n",
            "trial 16\n",
            "trial 17\n",
            "trial 18\n",
            "trial 19\n",
            "trial 20\n",
            "trial 21\n",
            "trial 22\n",
            "trial 23\n",
            "trial 24\n",
            "trial 25\n",
            "trial 26\n",
            "trial 27\n",
            "trial 28\n",
            "trial 29\n",
            "trial 30\n",
            "trial 31\n",
            "trial 32\n",
            "trial 33\n",
            "trial 34\n",
            "trial 35\n",
            "trial 36\n",
            "trial 37\n",
            "trial 38\n",
            "trial 39\n",
            "trial 40\n",
            "trial 41\n",
            "trial 42\n",
            "trial 43\n",
            "trial 44\n",
            "trial 45\n",
            "trial 46\n",
            "trial 47\n",
            "trial 48\n",
            "trial 49\n",
            "trial 50\n",
            "trial 51\n",
            "trial 52\n",
            "trial 53\n",
            "trial 54\n",
            "trial 55\n",
            "trial 56\n",
            "trial 57\n",
            "trial 58\n",
            "trial 59\n",
            "trial 60\n",
            "trial 61\n",
            "trial 62\n",
            "trial 63\n",
            "trial 64\n",
            "trial 65\n",
            "trial 66\n",
            "trial 67\n",
            "trial 68\n",
            "trial 69\n",
            "trial 70\n",
            "trial 71\n",
            "trial 72\n",
            "trial 73\n",
            "trial 74\n",
            "trial 75\n",
            "trial 76\n",
            "trial 77\n",
            "trial 78\n",
            "trial 79\n",
            "trial 80\n",
            "trial 81\n",
            "trial 82\n",
            "trial 83\n",
            "trial 84\n",
            "trial 85\n",
            "trial 86\n",
            "trial 87\n",
            "trial 88\n",
            "trial 89\n",
            "trial 90\n",
            "trial 91\n",
            "trial 92\n",
            "trial 93\n",
            "trial 94\n",
            "trial 95\n",
            "trial 96\n",
            "trial 97\n",
            "trial 98\n",
            "trial 99\n",
            "trial 0\n",
            "trial 1\n",
            "trial 2\n",
            "trial 3\n",
            "trial 4\n",
            "trial 5\n",
            "trial 6\n",
            "trial 7\n",
            "trial 8\n",
            "trial 9\n",
            "trial 10\n",
            "trial 11\n",
            "trial 12\n",
            "trial 13\n",
            "trial 14\n",
            "trial 15\n",
            "trial 16\n",
            "trial 17\n",
            "trial 18\n",
            "trial 19\n",
            "trial 20\n",
            "trial 21\n",
            "trial 22\n",
            "trial 23\n",
            "trial 24\n",
            "trial 25\n",
            "trial 26\n",
            "trial 27\n",
            "trial 28\n",
            "trial 29\n",
            "trial 30\n",
            "trial 31\n",
            "trial 32\n",
            "trial 33\n",
            "trial 34\n",
            "trial 35\n",
            "trial 36\n",
            "trial 37\n",
            "trial 38\n",
            "trial 39\n",
            "trial 40\n",
            "trial 41\n",
            "trial 42\n",
            "trial 43\n",
            "trial 44\n",
            "trial 45\n",
            "trial 46\n",
            "trial 47\n",
            "trial 48\n",
            "trial 49\n",
            "trial 50\n",
            "trial 51\n",
            "trial 52\n",
            "trial 53\n",
            "trial 54\n",
            "trial 55\n",
            "trial 56\n",
            "trial 57\n",
            "trial 58\n",
            "trial 59\n",
            "trial 60\n",
            "trial 61\n",
            "trial 62\n",
            "trial 63\n",
            "trial 64\n",
            "trial 65\n",
            "trial 66\n",
            "trial 67\n",
            "trial 68\n",
            "trial 69\n",
            "trial 70\n",
            "trial 71\n",
            "trial 72\n",
            "trial 73\n",
            "trial 74\n",
            "trial 75\n",
            "trial 76\n",
            "trial 77\n",
            "trial 78\n",
            "trial 79\n",
            "trial 80\n",
            "trial 81\n",
            "trial 82\n",
            "trial 83\n",
            "trial 84\n",
            "trial 85\n",
            "trial 86\n",
            "trial 87\n",
            "trial 88\n",
            "trial 89\n",
            "trial 90\n",
            "trial 91\n",
            "trial 92\n",
            "trial 93\n",
            "trial 94\n",
            "trial 95\n",
            "trial 96\n",
            "trial 97\n",
            "trial 98\n",
            "trial 99\n",
            "trial 0\n",
            "trial 1\n",
            "trial 2\n",
            "trial 3\n",
            "trial 4\n",
            "trial 5\n",
            "trial 6\n",
            "trial 7\n",
            "trial 8\n",
            "trial 9\n",
            "trial 10\n",
            "trial 11\n",
            "trial 12\n",
            "trial 13\n",
            "trial 14\n",
            "trial 15\n",
            "trial 16\n",
            "trial 17\n",
            "trial 18\n",
            "trial 19\n",
            "trial 20\n",
            "trial 21\n",
            "trial 22\n",
            "trial 23\n",
            "trial 24\n",
            "trial 25\n",
            "trial 26\n",
            "trial 27\n",
            "trial 28\n",
            "trial 29\n",
            "trial 30\n",
            "trial 31\n",
            "trial 32\n",
            "trial 33\n",
            "trial 34\n",
            "trial 35\n",
            "trial 36\n",
            "trial 37\n",
            "trial 38\n",
            "trial 39\n",
            "trial 40\n",
            "trial 41\n",
            "trial 42\n",
            "trial 43\n",
            "trial 44\n",
            "trial 45\n",
            "trial 46\n",
            "trial 47\n",
            "trial 48\n",
            "trial 49\n",
            "trial 50\n",
            "trial 51\n",
            "trial 52\n",
            "trial 53\n",
            "trial 54\n",
            "trial 55\n",
            "trial 56\n",
            "trial 57\n",
            "trial 58\n",
            "trial 59\n",
            "trial 60\n",
            "trial 61\n",
            "trial 62\n",
            "trial 63\n",
            "trial 64\n",
            "trial 65\n",
            "trial 66\n",
            "trial 67\n",
            "trial 68\n",
            "trial 69\n",
            "trial 70\n",
            "trial 71\n",
            "trial 72\n",
            "trial 73\n",
            "trial 74\n",
            "trial 75\n",
            "trial 76\n",
            "trial 77\n",
            "trial 78\n",
            "trial 79\n",
            "trial 80\n",
            "trial 81\n",
            "trial 82\n",
            "trial 83\n",
            "trial 84\n",
            "trial 85\n",
            "trial 86\n",
            "trial 87\n",
            "trial 88\n",
            "trial 89\n",
            "trial 90\n",
            "trial 91\n",
            "trial 92\n",
            "trial 93\n",
            "trial 94\n",
            "trial 95\n",
            "trial 96\n",
            "trial 97\n",
            "trial 98\n",
            "trial 99\n",
            "trial 0\n",
            "trial 1\n",
            "trial 2\n",
            "trial 3\n",
            "trial 4\n",
            "trial 5\n",
            "trial 6\n",
            "trial 7\n",
            "trial 8\n",
            "trial 9\n",
            "trial 10\n",
            "trial 11\n",
            "trial 12\n",
            "trial 13\n",
            "trial 14\n",
            "trial 15\n",
            "trial 16\n",
            "trial 17\n",
            "trial 18\n",
            "trial 19\n",
            "trial 20\n",
            "trial 21\n",
            "trial 22\n",
            "trial 23\n",
            "trial 24\n",
            "trial 25\n",
            "trial 26\n",
            "trial 27\n",
            "trial 28\n",
            "trial 29\n",
            "trial 30\n",
            "trial 31\n",
            "trial 32\n",
            "trial 33\n",
            "trial 34\n",
            "trial 35\n",
            "trial 36\n",
            "trial 37\n",
            "trial 38\n",
            "trial 39\n",
            "trial 40\n",
            "trial 41\n",
            "trial 42\n",
            "trial 43\n",
            "trial 44\n",
            "trial 45\n",
            "trial 46\n",
            "trial 47\n",
            "trial 48\n",
            "trial 49\n",
            "trial 50\n",
            "trial 51\n",
            "trial 52\n",
            "trial 53\n",
            "trial 54\n",
            "trial 55\n",
            "trial 56\n",
            "trial 57\n",
            "trial 58\n",
            "trial 59\n",
            "trial 60\n",
            "trial 61\n",
            "trial 62\n",
            "trial 63\n",
            "trial 64\n",
            "trial 65\n",
            "trial 66\n",
            "trial 67\n",
            "trial 68\n",
            "trial 69\n",
            "trial 70\n",
            "trial 71\n",
            "trial 72\n",
            "trial 73\n",
            "trial 74\n",
            "trial 75\n",
            "trial 76\n",
            "trial 77\n",
            "trial 78\n",
            "trial 79\n",
            "trial 80\n",
            "trial 81\n",
            "trial 82\n",
            "trial 83\n",
            "trial 84\n",
            "trial 85\n",
            "trial 86\n",
            "trial 87\n",
            "trial 88\n",
            "trial 89\n",
            "trial 90\n",
            "trial 91\n",
            "trial 92\n",
            "trial 93\n",
            "trial 94\n",
            "trial 95\n",
            "trial 96\n",
            "trial 97\n",
            "trial 98\n",
            "trial 99\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ks9ywY9SWK7z",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}