{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import scipy.sparse.csgraph as csg\n",
    "from joblib import Parallel, delayed\n",
    "import multiprocessing\n",
    "import networkx as nx\n",
    "import time\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import time\n",
    "import math\n",
    "\n",
    "from __future__ import unicode_literals, print_function, division\n",
    "from io import open\n",
    "import unicodedata\n",
    "import string\n",
    "import re\n",
    "import random\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "import learning_util as lu"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distortion calculations\n",
    "\n",
    "def acosh(x):\n",
    "    \"\"\"Inverse hyperbolic cosine of tensor x, computed as log(x + sqrt(x^2 - 1)).\"\"\"\n",
    "    root = torch.sqrt(x**2-1)\n",
    "    return torch.log(x + root)\n",
    "\n",
    "def dist_h(u,v):\n",
    "    \"\"\"Distance acosh(1 + 2|u-v|^2 / ((1-|u|^2)(1-|v|^2))) between two vectors u and v.\"\"\"\n",
    "    doubled_sq_gap = 2*torch.norm(u-v,2)**2\n",
    "    conformal = (1-torch.norm(u,2)**2)*(1-torch.norm(v,2)**2)\n",
    "    return acosh(1. + torch.div(doubled_sq_gap, conformal))\n",
    "\n",
    "def distance_matrix_euclidean(input):\n",
    "    \"\"\"Pairwise squared Euclidean distances between the rows of a 2-D tensor.\n",
    "\n",
    "    Entry [i, j] is sum((input[j] - input[i])**2); no square root is taken.\n",
    "    The result is always 2-D with shape (rows, rows), including the one-row case.\n",
    "    \"\"\"\n",
    "    # Broadcasting yields the (rows, rows, dim) difference tensor without\n",
    "    # materialising two stacked copies of the input.\n",
    "    diff = input.unsqueeze(0) - input.unsqueeze(1)\n",
    "    return torch.sum(diff**2, 2)\n",
    "\n",
    "def distance_matrix_hyperbolic(input):\n",
    "    \"\"\"Pairwise dist_h between the rows of `input`, with a zero diagonal.\n",
    "\n",
    "    dist_h is symmetric in its arguments, so each unordered pair is evaluated\n",
    "    once and written to both [row, col] and [col, row].\n",
    "    The result is allocated on the notebook-level `device`.\n",
    "    \"\"\"\n",
    "    row_n = input.shape[0]\n",
    "    dist_mat = torch.zeros(row_n, row_n, device=device)\n",
    "    for row in range(row_n):\n",
    "        for col in range(row + 1, row_n):\n",
    "            pair_dist = dist_h(input[row,:], input[col,:])\n",
    "            dist_mat[row, col] = pair_dist\n",
    "            dist_mat[col, row] = pair_dist\n",
    "    return dist_mat\n",
    "\n",
    "def entry_is_good(h, h_rec):\n",
    "    \"\"\"Truthy when h_rec is neither NaN nor infinite and both h_rec and h are non-zero.\"\"\"\n",
    "    if torch.isnan(h_rec) or torch.isinf(h_rec):\n",
    "        return False\n",
    "    return h_rec != 0 and h != 0\n",
    "\n",
    "def distortion_entry(h,h_rec):\n",
    "    \"\"\"Relative error |h_rec - h| / h of a recovered value against the reference h.\"\"\"\n",
    "    return abs(h_rec - h)/h\n",
    "\n",
    "def distortion_row(H1, H2, n, row):\n",
    "    \"\"\"Average distortion_entry over one row, skipping the diagonal and unusable entries.\n",
    "\n",
    "    H1 and H2 are the reference and recovered rows; only the first n entries are read.\n",
    "    Returns (average, number_of_entries_used). When no entry qualifies, both\n",
    "    elements are zero tensors created with requires_grad=True on `device`.\n",
    "    \"\"\"\n",
    "    total, good = 0, 0\n",
    "    for col in range(n):\n",
    "        if col == row or not entry_is_good(H1[col], H2[col]):\n",
    "            continue\n",
    "        total += distortion_entry(H1[col], H2[col])\n",
    "        good += 1\n",
    "    if good > 0:\n",
    "        total /= good\n",
    "        return (total, good)\n",
    "    zero_avg = torch.tensor(0., device=device, requires_grad=True)\n",
    "    zero_good = torch.tensor(0., device=device, requires_grad=True)\n",
    "    return (zero_avg, zero_good)\n",
    "\n",
    "def distortion(H1, H2, n, jobs=16):\n",
    "    \"\"\"Sum of the per-row average distortions of H2 against H1 over the first n rows, divided by n.\n",
    "\n",
    "    `jobs` is accepted but not used: the rows are processed sequentially.\n",
    "    \"\"\"\n",
    "    row_averages = []\n",
    "    for i in range(n):\n",
    "        row_avg, _ = distortion_row(H1[i,:],H2[i,:],n,i)\n",
    "        row_averages.append(row_avg)\n",
    "    return torch.stack(row_averages).sum()/n\n",
    "\n",
    "\n",
    "#Loading the graph and getting the distance matrix.\n",
    "\n",
    "def load_graph(file_name, directed=False):\n",
    "    \"\"\"Read a whitespace-separated edge list into a networkx graph.\n",
    "\n",
    "    Each line is `u v` or `u v weight`, with integer node ids and an optional\n",
    "    float weight. Lines with fewer than two tokens (for example blank lines)\n",
    "    are skipped instead of raising IndexError.\n",
    "\n",
    "    Returns nx.DiGraph when `directed` is true, otherwise nx.Graph.\n",
    "    \"\"\"\n",
    "    G = nx.DiGraph() if directed else nx.Graph()\n",
    "    with open(file_name, \"r\") as f:\n",
    "        for line in f:\n",
    "            tokens = line.split()\n",
    "            if len(tokens) < 2:\n",
    "                # Not an edge record (empty or truncated line).\n",
    "                continue\n",
    "            u, v = int(tokens[0]), int(tokens[1])\n",
    "            if len(tokens) > 2:\n",
    "                G.add_edge(u, v, weight=float(tokens[2]))\n",
    "            else:\n",
    "                G.add_edge(u, v)\n",
    "    return G\n",
    "\n",
    "\n",
    "def compute_row(i, adj_mat):\n",
    "    \"\"\"Unweighted, undirected shortest-path lengths from node i, as a (1, n) array.\"\"\"\n",
    "    source_nodes = [i]\n",
    "    return csg.dijkstra(adj_mat, indices=source_nodes, unweighted=True, directed=False)\n",
    "\n",
    "def get_dist_mat(G):\n",
    "    \"\"\"All-pairs hop-count distance matrix of G as an (n, n) numpy array.\n",
    "\n",
    "    Nodes are taken in the order 0..n-1 (nodelist=range(G.order())), edges are\n",
    "    treated as unweighted and undirected. A single csgraph.dijkstra call\n",
    "    computes every row, avoiding a process pool per graph.\n",
    "    \"\"\"\n",
    "    nodelist = list(range(G.order()))\n",
    "    # networkx 3 removed to_scipy_sparse_matrix in favour of to_scipy_sparse_array.\n",
    "    to_sparse = getattr(nx, \"to_scipy_sparse_array\", None) or nx.to_scipy_sparse_matrix\n",
    "    adj_mat = to_sparse(G, nodelist=nodelist)\n",
    "    return csg.dijkstra(adj_mat, unweighted=True, directed=False)\n",
    "\n",
    "\n",
    "def asMinutes(s):\n",
    "    \"\"\"Format a duration in seconds as '<minutes>m <seconds>s'.\"\"\"\n",
    "    minutes = math.floor(s / 60)\n",
    "    seconds = s - minutes * 60\n",
    "    return '%dm %ds' % (minutes, seconds)\n",
    "\n",
    "\n",
    "def timeSince(since, percent):\n",
    "    \"\"\"Return 'elapsed (- remaining)', extrapolating total time as elapsed / percent.\n",
    "\n",
    "    `since` is a time.time() timestamp; `percent` is the completed fraction and must be non-zero.\n",
    "    \"\"\"\n",
    "    elapsed = time.time() - since\n",
    "    remaining = elapsed / (percent) - elapsed\n",
    "    return '%s (- %s)' % (asMinutes(elapsed), asMinutes(remaining))\n",
    "\n",
    "\n",
    "def showPlot(points):\n",
    "    # Plots `points` as a line chart with major y-axis ticks every 0.2.\n",
    "    # NOTE(review): `plt` and `ticker` are not imported in any visible cell\n",
    "    # (presumably matplotlib.pyplot and matplotlib.ticker -- confirm); unless they\n",
    "    # are defined elsewhere, calling this raises NameError.\n",
    "    # NOTE(review): this figure looks redundant, since plt.subplots() below opens its own -- confirm.\n",
    "    plt.figure()\n",
    "    fig, ax = plt.subplots()\n",
    "    loc = ticker.MultipleLocator(base=0.2)\n",
    "    ax.yaxis.set_major_locator(loc)\n",
    "    plt.plot(points)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Vocab:\n",
    "    \"\"\"Bidirectional word/index mapping with per-word occurrence counts.\n",
    "\n",
    "    Indices are assigned in first-seen order, starting at 0.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, name):\n",
    "        self.name = name\n",
    "        self.n_words = 0\n",
    "        self.word2index, self.word2count, self.index2word = {}, {}, {}\n",
    "\n",
    "    def addSentence(self, sentence):\n",
    "        \"\"\"Register the 'form' field of every token in `sentence`.\"\"\"\n",
    "        for token in sentence:\n",
    "            self.addWord(token['form'])\n",
    "\n",
    "    def addWord(self, word):\n",
    "        \"\"\"Count `word`, assigning it the next free index on first sight.\"\"\"\n",
    "        if word in self.word2index:\n",
    "            self.word2count[word] += 1\n",
    "            return\n",
    "        new_index = self.n_words\n",
    "        self.word2index[word] = new_index\n",
    "        self.word2count[word] = 1\n",
    "        self.index2word[new_index] = word\n",
    "        self.n_words = new_index + 1\n",
    "            \n",
    "\n",
    "def unicodeToAscii(s):\n",
    "    \"\"\"NFD-decompose `s` and drop every character in Unicode category 'Mn'.\"\"\"\n",
    "    decomposed = unicodedata.normalize('NFD', s)\n",
    "    kept = [c for c in decomposed if unicodedata.category(c) != 'Mn']\n",
    "    return ''.join(kept)\n",
    "\n",
    "# Lowercase, trim, and remove non-letter characters\n",
    "def normalizeString(s):\n",
    "    \"\"\"Lowercase and strip `s`, remove 'Mn' marks, put a space before . ! ?, and\n",
    "    replace every run of other non-letter characters with a single space.\"\"\"\n",
    "    ascii_text = unicodeToAscii(s.lower().strip())\n",
    "    spaced = re.sub(r\"([.!?])\", r\" \\1\", ascii_text)\n",
    "    return re.sub(r\"[^a-zA-Z.!?]+\", r\" \", spaced)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "from conllu import parse_tree, parse_tree_incr, parse, parse_incr\n",
    "from io import open\n",
    "import scipy.sparse.csgraph as csg\n",
    "import networkx as nx\n",
    "from collections import defaultdict\n",
    "import json\n",
    "import string\n",
    "\n",
    "\n",
    "def unroll(node, G):\n",
    "    \"\"\"Add a parent-child edge (keyed by token['id']) for every link in the tree\n",
    "    rooted at `node`, depth first, and return G.\"\"\"\n",
    "    for child in node.children:\n",
    "        G.add_edge(node.token['id'], child.token['id'])\n",
    "        unroll(child, G)\n",
    "    return G\n",
    "\n",
    "# Parse the training treebank; the with-block closes the file once parsing is done.\n",
    "sentences = []\n",
    "with open(\"UD_English-EWT/en_ewt-ud-train.conllu\", \"r\", encoding=\"utf-8\") as data_file:\n",
    "    for sentence in parse_incr(data_file):\n",
    "        sentences.append(sentence)\n",
    "    \n",
    "MIN_LENGTH = 10\n",
    "MAX_LENGTH = 50\n",
    "\n",
    "def check_length(sentence):\n",
    "    \"\"\"True when len(sentence) lies strictly between MIN_LENGTH and MAX_LENGTH.\"\"\"\n",
    "    return MIN_LENGTH < len(sentence) < MAX_LENGTH\n",
    "\n",
    "def filterSentences(sentences):\n",
    "    \"\"\"Return the sentences accepted by check_length, in their original order.\"\"\"\n",
    "    return list(filter(check_length, sentences))\n",
    "\n",
    "# Build the vocabulary and keep the raw text of every length-filtered sentence.\n",
    "input_vocab = Vocab(\"ewt_train_trimmed\")\n",
    "filtered_sentences = filterSentences(sentences)\n",
    "\n",
    "sentences_text = []\n",
    "for sent in filtered_sentences:\n",
    "    input_vocab.addSentence(sent)\n",
    "    sentences_text.append(sent.metadata['text'])\n",
    "\n",
    "# nx.write_edgelist opens its target path directly, so the output folders must exist.\n",
    "os.makedirs(\"train\", exist_ok=True)\n",
    "os.makedirs(\"ewt_train\", exist_ok=True)\n",
    "\n",
    "# For each sentence write two edge lists of its dependency tree:\n",
    "#   train/<idx>.edges      node ids shifted down by one\n",
    "#   ewt_train/<idx>.edges  nodes relabelled to integers in decreasing-degree order\n",
    "# NOTE(review): pairfromidx reads from random_trees/, which this cell does not write -- confirm the intended source.\n",
    "dev_dict  = {}\n",
    "for idx in range(0, len(filtered_sentences)):\n",
    "    curr_tree = filtered_sentences[idx].to_tree()\n",
    "    G_curr = nx.Graph()\n",
    "    G_curr = unroll(curr_tree, G_curr)\n",
    "    G = nx.relabel_nodes(G_curr, lambda x: x-1)\n",
    "    nx.write_edgelist(G, \"train/\"+str(idx)+\".edges\", data=False)\n",
    "    G_final = nx.convert_node_labels_to_integers(G_curr, ordering = \"decreasing degree\")\n",
    "    nx.write_edgelist(G_final, \"ewt_train/\"+str(idx)+\".edges\", data=False)\n",
    "    dev_dict[idx] = list(G_final.edges)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def indexesFromSentence(vocab, sentence):\n",
    "    \"\"\"Look up vocab.word2index for the 'form' of each token, preserving order.\"\"\"\n",
    "    indexes = []\n",
    "    for token in sentence:\n",
    "        indexes.append(vocab.word2index[token['form']])\n",
    "    return indexes\n",
    "\n",
    "def tensorFromSentence(vocab, sentence):\n",
    "    \"\"\"Token indices of `sentence` as a long tensor of shape (len, 1) on `device`.\"\"\"\n",
    "    token_ids = indexesFromSentence(vocab, sentence)\n",
    "    id_tensor = torch.tensor(token_ids, dtype=torch.long, device=device)\n",
    "    return id_tensor.view(-1, 1)\n",
    "\n",
    "def pairfromidx(idx):\n",
    "    \"\"\"Assemble the training example for sentence `idx`.\n",
    "\n",
    "    Returns (input_tensor, target_tensor, n, text): the token-index tensor, the\n",
    "    float distance matrix of the graph stored in random_trees/<idx>.edges (on\n",
    "    `device`, gradients disabled), that graph's node count, and the sentence text.\n",
    "    \"\"\"\n",
    "    input_tensor = tensorFromSentence(input_vocab, filtered_sentences[idx])\n",
    "    graph = load_graph(\"random_trees/\"+str(idx)+\".edges\")\n",
    "    target_tensor = torch.from_numpy(get_dist_mat(graph)).float().to(device)\n",
    "    target_tensor.requires_grad = False\n",
    "    return (input_tensor, target_tensor, graph.order(), sentences_text[idx])\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "class EncoderLSTM(nn.Module):\n",
    "    \"\"\"Embedding layer followed by a single nn.GRU, stepped one token at a time.\n",
    "\n",
    "    Despite the class name, the recurrent unit is a GRU.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_size, hidden_size):\n",
    "        super(EncoderLSTM, self).__init__()\n",
    "        self.hidden_size = hidden_size\n",
    "        self.embedding = nn.Embedding(input_size, hidden_size)\n",
    "        self.gru = nn.GRU(hidden_size, hidden_size)\n",
    "\n",
    "    def forward(self, input, hidden):\n",
    "        \"\"\"Embed one token index, run one GRU step, and return (output, hidden).\"\"\"\n",
    "        step_input = self.embedding(input).view(1, 1, -1)\n",
    "        return self.gru(step_input, hidden)\n",
    "\n",
    "    def initHidden(self):\n",
    "        \"\"\"All-zero initial hidden state of shape (1, 1, hidden_size) on `device`.\"\"\"\n",
    "        return torch.zeros(1, 1, self.hidden_size, device=device)\n",
    "    \n",
    "\n",
    "class Attention(nn.Module):\n",
    "    \"\"\"Attention over the encoder outputs for a single token.\n",
    "\n",
    "    The token embedding and its hidden vector are concatenated and projected to\n",
    "    one score per position (max_length of them); the softmax of those scores\n",
    "    weights the encoder outputs, and the weighted sum is combined with the\n",
    "    embedding through a linear layer.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_size, hidden_size, max_length=MAX_LENGTH):\n",
    "        super(Attention, self).__init__()\n",
    "        self.input_size = input_size\n",
    "        self.hidden_size = hidden_size\n",
    "        self.max_length = max_length\n",
    "        self.embedding = nn.Embedding(input_size, hidden_size)\n",
    "        self.attn = nn.Linear(self.hidden_size * 2, self.max_length)\n",
    "        self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)\n",
    "\n",
    "    def forward(self, input, hidden, encoder_outputs):\n",
    "        \"\"\"Return the attended embedding, shape (1, 1, hidden_size).\n",
    "\n",
    "        input: a single token index tensor.\n",
    "        hidden: 1-D hidden vector of length hidden_size.\n",
    "        encoder_outputs: (max_length, hidden_size) tensor of encoder outputs.\n",
    "        \"\"\"\n",
    "        embedded = self.embedding(input).view(1, 1, -1)\n",
    "        attention_scores = self.attn(torch.cat((embedded[0], hidden.unsqueeze(0)), 1))\n",
    "        # attention_scores is (1, max_length): normalise across positions (dim=1).\n",
    "        # A softmax over dim=0 runs over a size-1 axis and makes every weight exactly 1.\n",
    "        attn_weights = F.softmax(attention_scores, dim=1)\n",
    "        attn_applied = torch.bmm(attn_weights.unsqueeze(0), encoder_outputs.unsqueeze(0))\n",
    "        output = torch.cat((embedded[0], attn_applied[0]), 1)\n",
    "        output = self.attn_combine(output).unsqueeze(0)\n",
    "\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Hyperbolic modules.\n",
    "\n",
    "class HypLinear(nn.Module):\n",
    "    \"\"\"Hyperbolic counterpart of a linear layer, :math:`y = xA^T + b`.\n",
    "\n",
    "    The matrix product and the bias addition are delegated to the helpers\n",
    "    `lu.torch_mv_mul_hyp` and `lu.torch_hyp_add` from learning_util.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_features, out_features, bias=True):\n",
    "        super(HypLinear, self).__init__()\n",
    "        self.in_features = in_features\n",
    "        self.out_features = out_features\n",
    "        self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))\n",
    "\n",
    "        if not bias:\n",
    "            self.register_parameter('bias', None)\n",
    "        else:\n",
    "            self.bias = nn.Parameter(torch.FloatTensor(1, out_features))\n",
    "        self.reset_parameters()\n",
    "\n",
    "    def reset_parameters(self):\n",
    "        \"\"\"Fill weight (then bias, if present) uniformly in +/- 1/sqrt(in_features).\"\"\"\n",
    "        bound = 1. / math.sqrt(self.weight.size(1))\n",
    "        for param in (self.weight, self.bias):\n",
    "            if param is not None:\n",
    "                param.data.uniform_(-bound, bound)\n",
    "\n",
    "    def forward(self, input_):\n",
    "        # (batch, input) x (input, output), then the hyperbolic bias addition.\n",
    "        weight_t = torch.transpose(self.weight,0,1)\n",
    "        projected = lu.torch_mv_mul_hyp(weight_t, input_)\n",
    "        return lu.torch_hyp_add(projected, self.bias)\n",
    "\n",
    "    def extra_repr(self):\n",
    "        has_bias = self.bias is not None\n",
    "        return 'in_features={}, out_features={}, bias={}'.format(\n",
    "            self.in_features, self.out_features, has_bias\n",
    "        )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Trains a Euclidean LSTM.\n",
    "\n",
    "def trainVanilla(input_tensor, ground_truth, n, encoder, encoder_optimizer, max_length=MAX_LENGTH):\n",
    "    \"\"\"Run one optimisation step of the encoder on a single sentence.\n",
    "\n",
    "    The encoder outputs (zero-padded to max_length rows) are turned into a\n",
    "    pairwise distance matrix whose distortion against `ground_truth` over the\n",
    "    first n rows/columns is the loss. Returns the loss as a Python float.\n",
    "    \"\"\"\n",
    "    encoder_hidden = encoder.initHidden()\n",
    "    encoder_optimizer.zero_grad()\n",
    "\n",
    "    input_length = input_tensor.size(0)\n",
    "    encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)\n",
    "\n",
    "    for ei in range(input_length):\n",
    "        encoder_output, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)\n",
    "        encoder_outputs[ei] = encoder_output[0, 0]\n",
    "\n",
    "    dist_recovered = distance_matrix_euclidean(encoder_outputs)\n",
    "    loss = distortion(ground_truth, dist_recovered, n)\n",
    "    loss.backward()\n",
    "    encoder_optimizer.step()\n",
    "\n",
    "    return loss.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trains Euclidean LSTM + Attention.\n",
    "\n",
    "def trainWAttention(input_tensor, ground_truth, n, encoder, encoder_optimizer, attention, attention_optimizer, iter, max_length=MAX_LENGTH):\n",
    "    \"\"\"Run one optimisation step of encoder + attention on a single sentence.\n",
    "\n",
    "    Each token is encoded, then passed through `attention` together with all\n",
    "    encoder outputs to produce its final embedding. The loss is the distortion\n",
    "    of the embeddings' pairwise distance matrix against `ground_truth`.\n",
    "    `iter` is accepted for the caller's convenience and not used here.\n",
    "\n",
    "    Returns (loss as float, final embeddings detached from the autograd graph).\n",
    "    \"\"\"\n",
    "    encoder_hidden = encoder.initHidden()\n",
    "    encoder_optimizer.zero_grad()\n",
    "    attention_optimizer.zero_grad()\n",
    "\n",
    "    input_length = input_tensor.size(0)\n",
    "    encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)\n",
    "    encoder_hiddens = torch.zeros(input_length, encoder.hidden_size, device=device)\n",
    "    final_embeddings = torch.zeros(input_length, encoder.hidden_size, device=device)\n",
    "\n",
    "    for ei in range(input_length):\n",
    "        encoder_output, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)\n",
    "        encoder_outputs[ei] = encoder_output[0, 0]\n",
    "        encoder_hiddens[ei] = encoder_hidden[0, 0]\n",
    "\n",
    "    for idx in range(input_length):\n",
    "        output = attention(input_tensor[idx], encoder_hiddens[idx], encoder_outputs)\n",
    "        final_embeddings[idx] = output[0]\n",
    "\n",
    "    dist_recovered = distance_matrix_euclidean(final_embeddings)\n",
    "    loss = distortion(ground_truth, dist_recovered, n)\n",
    "    loss.backward()\n",
    "    encoder_optimizer.step()\n",
    "    attention_optimizer.step()\n",
    "\n",
    "    # The caller caches and saves the embeddings; detach so they do not keep\n",
    "    # this step's autograd graph alive or carry requires_grad into later use.\n",
    "    return loss.item(), final_embeddings.detach()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "def trainEuclidean(encoder, attention, n_epochs=10, n_iters=500, print_every=50, plot_every=100, learning_rate=0.01):\n",
    "    \"\"\"Train encoder + attention with SGD on the first n_iters sentences for n_epochs.\n",
    "\n",
    "    After every step the sentence's embeddings are written to\n",
    "    saved_tensors_tree/<index>.pt (overwritten each epoch) and stored in the\n",
    "    returned dict {sentence index: embeddings}. Every print_every steps the\n",
    "    running mean loss is printed with elapsed / estimated remaining time,\n",
    "    where progress is measured over all epochs.\n",
    "    \"\"\"\n",
    "    start = time.time()\n",
    "    plot_losses = []\n",
    "    print_loss_total = 0\n",
    "    plot_loss_total = 0\n",
    "\n",
    "    encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)\n",
    "    attention_optimizer = optim.SGD(attention.parameters(), lr=learning_rate)\n",
    "    training_pairs = [pairfromidx(idx) for idx in range(n_iters)]\n",
    "\n",
    "    # torch.save does not create missing directories.\n",
    "    os.makedirs(\"saved_tensors_tree\", exist_ok=True)\n",
    "    total_steps = n_epochs * n_iters\n",
    "\n",
    "    euclidean_emb_dict = {}\n",
    "    for i in range(n_epochs):\n",
    "        for iter in range(1, n_iters+1):\n",
    "            training_pair = training_pairs[iter-1]\n",
    "            input_tensor = training_pair[0]\n",
    "            target_matrix = training_pair[1]\n",
    "            n = training_pair[2]\n",
    "            loss, final_embeddings = trainWAttention(input_tensor, target_matrix, n, encoder, encoder_optimizer, attention, attention_optimizer, iter-1)\n",
    "            torch.save(final_embeddings, \"saved_tensors_tree/\"+str(iter-1)+\".pt\")\n",
    "            euclidean_emb_dict[iter-1] = final_embeddings\n",
    "            print_loss_total += loss\n",
    "            plot_loss_total += loss\n",
    "\n",
    "            if iter % print_every == 0:\n",
    "                print_loss_avg = print_loss_total / print_every\n",
    "                print_loss_total = 0\n",
    "                # `start` is taken once before the first epoch, so the completed\n",
    "                # fraction must count earlier epochs for the estimate to hold.\n",
    "                done = (i * n_iters + iter) / total_steps\n",
    "                print('%s (%d %d%%) %.4f' % (timeSince(start, done),\n",
    "                                             iter, done * 100, print_loss_avg))\n",
    "\n",
    "            if iter % plot_every == 0:\n",
    "                plot_loss_avg = plot_loss_total / plot_every\n",
    "                plot_losses.append(plot_loss_avg)\n",
    "                plot_loss_total = 0\n",
    "\n",
    "    return euclidean_emb_dict\n",
    "\n",
    "# hidden_size = 100\n",
    "# encoder = EncoderLSTM(input_vocab.n_words, hidden_size).to(device)\n",
    "# attention = Attention(input_vocab.n_words, hidden_size).to(device)\n",
    "# euclidean_emb_dict = trainEuclidean(encoder, attention)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the saved Euclidean embeddings, keyed by the sentence index in the file name.\n",
    "euclidean_embeddings = {}\n",
    "# os.listdir returns entries in arbitrary order and may include non-tensor files;\n",
    "# keep only <index>.pt files and order them numerically so runs are repeatable.\n",
    "saved_tensors = sorted(\n",
    "    (name for name in os.listdir(\"tree_emb_saved/\") if name.endswith(\".pt\")),\n",
    "    key=lambda name: int(name.split(\".\")[0]),\n",
    ")\n",
    "indices = []\n",
    "for file in saved_tensors:\n",
    "    idx = int(file.split(\".\")[0])\n",
    "    indices.append(idx)\n",
    "    euclidean_embeddings[idx] = torch.load(\"tree_emb_saved/\"+str(file), map_location=torch.device('cpu'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Riemannian SGD\n",
    "\n",
    "from torch.optim.optimizer import Optimizer, required\n",
    "spten_t = torch.sparse.FloatTensor\n",
    "\n",
    "\n",
    "def poincare_grad(p, d_p):\n",
    "    \"\"\"\n",
    "    Rescales a Euclidean gradient by (1 - |p|^2)^2 / 4, row by row.\n",
    "    Args:\n",
    "        p (Tensor): Current point(s); the squared norm is taken over the last dimension\n",
    "        d_p (Tensor): Euclidean gradient at p, dense or sparse\n",
    "    Returns the rescaled gradient (a new tensor of the same layout as d_p).\n",
    "    \"\"\"\n",
    "    if not d_p.is_sparse:\n",
    "        sq_norm = torch.sum(p.data ** 2, dim=-1, keepdim=True)\n",
    "        scale = (1 - sq_norm) ** 2 / 4\n",
    "        return d_p * scale.expand_as(d_p)\n",
    "\n",
    "    # Sparse gradient: only the rows listed in the index tensor are rescaled.\n",
    "    touched_rows = d_p._indices()[0].squeeze()\n",
    "    sq_norm = torch.sum(\n",
    "        p.data[touched_rows] ** 2, dim=1,\n",
    "        keepdim=True\n",
    "    ).expand_as(d_p._values())\n",
    "    n_vals = d_p._values() * ((1 - sq_norm) ** 2) / 4\n",
    "    return spten_t(d_p._indices(), n_vals, d_p.size())\n",
    "\n",
    "\n",
    "def euclidean_grad(p, d_p):\n",
    "    \"\"\"Identity rgrad: return the Euclidean gradient d_p unchanged (p is ignored).\"\"\"\n",
    "    unchanged = d_p\n",
    "    return unchanged\n",
    "\n",
    "\n",
    "def retraction(p, d_p, lr):\n",
    "    \"\"\"Apply p <- p - lr * d_p in place, unless the gradient looks exploded.\n",
    "\n",
    "    This is an all-or-nothing guard, not clipping: if any component of d_p is\n",
    "    outside (-1000, 1000) the whole update is skipped. NaN components fail both\n",
    "    comparisons, so they also cause the update to be skipped.\n",
    "    \"\"\"\n",
    "    if torch.all(d_p < 1000) and torch.all(d_p>-1000):\n",
    "        # Keyword form of the scaled add; the positional (scalar, tensor) overload is deprecated.\n",
    "        p.data.add_(d_p, alpha=-lr)\n",
    "\n",
    "\n",
    "class RiemannianSGD(Optimizer):\n",
    "    r\"\"\"Riemannian stochastic gradient descent.\n",
    "    Args:\n",
    "        params (iterable): iterable of parameters to optimize or dicts defining\n",
    "            parameter groups\n",
    "        rgrad (Function): Function to compute the Riemannian gradient from\n",
    "            an Euclidean gradient\n",
    "        retraction (Function): Function to update the parameters via a\n",
    "            retraction of the Riemannian gradient\n",
    "        lr (float): learning rate\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, params, lr=required, rgrad=required, retraction=required):\n",
    "        defaults = dict(lr=lr, rgrad=rgrad, retraction=retraction)\n",
    "        super(RiemannianSGD, self).__init__(params, defaults)\n",
    "\n",
    "    def step(self, lr=None):\n",
    "        \"\"\"Performs a single optimization step.\n",
    "        Arguments:\n",
    "            lr (float, optional): learning rate for the current update. When\n",
    "                omitted, each parameter group is updated with its own 'lr'.\n",
    "        Returns:\n",
    "            None (no closure is evaluated).\n",
    "        \"\"\"\n",
    "        loss = None\n",
    "\n",
    "        for group in self.param_groups:\n",
    "            # Resolve the rate per group; overwriting `lr` itself would make the\n",
    "            # first group's rate apply to every later group.\n",
    "            group_lr = group['lr'] if lr is None else lr\n",
    "            for p in group['params']:\n",
    "                if p.grad is None:\n",
    "                    continue\n",
    "                d_p = group['rgrad'](p, p.grad.data)\n",
    "                group['retraction'](p, d_p, group_lr)\n",
    "\n",
    "        return loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Maps Euclidean embeddings to hyperbolic ones using a series of FC layers.\n",
    "# We use the ground-truth distance matrix for the pair, since the distortion of the hyperbolic embeddings is really low.\n",
    "\n",
    "def trainFCHyp(input_matrix, ground_truth, n, mapping, mapping_optimizer, max_length=MAX_LENGTH):\n",
    "    \"\"\"One optimisation step of `mapping` on a single embedding matrix.\n",
    "\n",
    "    The mapped rows are compared pairwise with distance_matrix_hyperbolic and the\n",
    "    distortion against `ground_truth` is the loss. `max_length` is not used.\n",
    "    Returns the loss as a Python float.\n",
    "    \"\"\"\n",
    "    mapping_optimizer.zero_grad()\n",
    "\n",
    "    mapped_points = mapping(input_matrix.float())\n",
    "    dist_recovered = distance_matrix_hyperbolic(mapped_points)\n",
    "    loss = 0 + distortion(ground_truth, dist_recovered, n)\n",
    "    loss.backward()\n",
    "    mapping_optimizer.step()\n",
    "\n",
    "    return loss.item()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "def trainFCIters(mapping, n_epochs=5, n_iters=500, print_every=50, plot_every=100, learning_rate=0.01):\n",
    "    \"\"\"Train `mapping` with RiemannianSGD on the saved embeddings for n_epochs.\n",
    "\n",
    "    Each epoch visits every index in the notebook-level `indices`, using\n",
    "    `euclidean_embeddings[idx]` as input and the distance matrix from\n",
    "    pairfromidx(idx) as target. Training pairs are built only for those\n",
    "    indices, so a saved index >= n_iters no longer raises IndexError.\n",
    "    `n_iters` is kept for call compatibility; progress is measured against the\n",
    "    actual number of steps (n_epochs * len(indices)).\n",
    "    \"\"\"\n",
    "    start = time.time()\n",
    "    plot_losses = []\n",
    "    print_loss_total = 0\n",
    "    plot_loss_total = 0\n",
    "\n",
    "    mapping_optimizer = RiemannianSGD(mapping.parameters(), lr=learning_rate, rgrad=poincare_grad, retraction=retraction)\n",
    "    training_pairs = {idx: pairfromidx(idx) for idx in indices}\n",
    "    total_steps = n_epochs * len(indices)\n",
    "\n",
    "    for i in range(n_epochs):\n",
    "        print(\"Starting epoch \"+str(i))\n",
    "        for iter, idx in enumerate(indices, 1):\n",
    "            input_matrix = euclidean_embeddings[idx]\n",
    "            target_matrix = training_pairs[idx][1]\n",
    "            n = training_pairs[idx][2]\n",
    "            loss = trainFCHyp(input_matrix, target_matrix, n, mapping, mapping_optimizer)\n",
    "            print_loss_total += loss\n",
    "            plot_loss_total += loss\n",
    "\n",
    "            if iter % print_every == 0:\n",
    "                print_loss_avg = print_loss_total / print_every\n",
    "                print_loss_total = 0\n",
    "                done = (i * len(indices) + iter) / total_steps\n",
    "                print('%s (%d %d%%) %.4f' % (timeSince(start, done),\n",
    "                                             iter, done * 100, print_loss_avg))\n",
    "\n",
    "            if iter % plot_every == 0:\n",
    "                plot_loss_avg = plot_loss_total / plot_every\n",
    "                plot_losses.append(plot_loss_avg)\n",
    "                plot_loss_total = 0\n",
    "            \n",
    "# Two-layer fully connected network (10 -> 50 -> 10) with a ReLU after each layer.\n",
    "input_size = 10\n",
    "output_size = 10\n",
    "mapping = nn.Sequential(\n",
    "    nn.Linear(input_size, 50),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(50, output_size),\n",
    "    nn.ReLU(),\n",
    ").to(device)\n",
    "          \n",
    "# trainFCIters(mapping)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def trainHyperbolic(input_tensor, ground_truth, n, encoder, encoder_optimizer, mapping, mapping_optimizer, max_length=MAX_LENGTH):\n",
    "    \"\"\"One end-to-end optimisation step of encoder + mapping on a single sentence.\n",
    "\n",
    "    The encoder outputs (zero-padded to max_length rows) are passed through\n",
    "    `mapping`; the distortion of the mapped rows' distance_matrix_hyperbolic\n",
    "    against `ground_truth` is the loss. Returns the loss as a Python float.\n",
    "    \"\"\"\n",
    "    encoder_hidden = encoder.initHidden()\n",
    "    encoder_optimizer.zero_grad()\n",
    "    mapping_optimizer.zero_grad()\n",
    "\n",
    "    input_length = input_tensor.size(0)\n",
    "    encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)\n",
    "\n",
    "    for ei in range(input_length):\n",
    "        encoder_output, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)\n",
    "        encoder_outputs[ei] = encoder_output[0, 0]\n",
    "\n",
    "    # distortion only reads the first n rows and columns, so the Python-level\n",
    "    # pairwise loop is restricted to the rows that can matter instead of all\n",
    "    # max_length padded rows. The mapping itself still sees the full matrix.\n",
    "    rows = max(input_length, n)\n",
    "    final_embeddings = mapping(encoder_outputs)[:rows]\n",
    "\n",
    "    dist_recovered = distance_matrix_hyperbolic(final_embeddings)\n",
    "    loss = distortion(ground_truth, dist_recovered, n)\n",
    "    loss.backward()\n",
    "    encoder_optimizer.step()\n",
    "    mapping_optimizer.step()\n",
    "\n",
    "    return loss.item()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Does end to end hyperbolic training.\n",
    "\n",
    "def trainHypIters(encoder, fc, n_iters=7600, print_every=500, plot_every=100, learning_rate=0.1):\n",
    "    \"\"\"Single pass of end-to-end training over the first n_iters sentences.\n",
    "\n",
    "    Both `encoder` and `fc` are optimised with RiemannianSGD (poincare_grad +\n",
    "    retraction). Every print_every steps the running mean loss is printed with\n",
    "    elapsed / estimated remaining time. Nothing is returned.\n",
    "    \"\"\"\n",
    "    start = time.time()\n",
    "    plot_losses = []\n",
    "    print_loss_total = 0\n",
    "    plot_loss_total = 0\n",
    "\n",
    "    encoder_optimizer = RiemannianSGD(encoder.parameters(), lr=learning_rate, rgrad=poincare_grad, retraction=retraction)\n",
    "    fc_optimizer = RiemannianSGD(fc.parameters(), lr=learning_rate, rgrad=poincare_grad, retraction=retraction)\n",
    "    training_pairs = [pairfromidx(idx) for idx in range(n_iters)]\n",
    "\n",
    "    for step, training_pair in enumerate(training_pairs, 1):\n",
    "        input_tensor, target_matrix, n = training_pair[0], training_pair[1], training_pair[2]\n",
    "        loss = trainHyperbolic(input_tensor, target_matrix, n, encoder, encoder_optimizer, fc, fc_optimizer)\n",
    "        print_loss_total += loss\n",
    "        plot_loss_total += loss\n",
    "\n",
    "        if step % print_every == 0:\n",
    "            print_loss_avg = print_loss_total / print_every\n",
    "            print_loss_total = 0\n",
    "            print('%s (%d %d%%) %.4f' % (timeSince(start, step / n_iters),\n",
    "                                         step, step / n_iters * 100, print_loss_avg))\n",
    "\n",
    "        if step % plot_every == 0:\n",
    "            plot_losses.append(plot_loss_total / plot_every)\n",
    "            plot_loss_total = 0\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encoder plus a four-layer fully connected head (100 -> 1000 -> 500 -> 50 -> 100).\n",
    "hidden_size = 100\n",
    "encoder = EncoderLSTM(input_vocab.n_words, hidden_size).to(device)\n",
    "input_size = 100\n",
    "output_size = 100\n",
    "fc = nn.Sequential(\n",
    "    nn.Linear(input_size, 1000),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(1000, 500),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(500, 50),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(50, output_size),\n",
    ").to(device)\n",
    "          \n",
    "# trainHypIters(encoder, fc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
