# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors, and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch RoBERTa model. """


import logging

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss, MSELoss

from transformers.configuration_roberta import RobertaConfig
from transformers.file_utils import add_start_docstrings
from transformers.modeling_bert import BertEmbeddings, BertLayerNorm, BertModel, BertPreTrainedModel, gelu


logger = logging.getLogger(__name__)

# Shortcut name -> download URL of the pretrained PyTorch weights. Every checkpoint is stored
# under the same S3 prefix as "<shortcut>-pytorch_model.bin"; insertion order is preserved.
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP = {
    shortcut: "https://s3.amazonaws.com/models.huggingface.co/bert/{}-pytorch_model.bin".format(shortcut)
    for shortcut in (
        "roberta-base",
        "roberta-large",
        "roberta-large-mnli",
        "distilroberta-base",
        "roberta-base-openai-detector",
        "roberta-large-openai-detector",
    )
}


class RobertaEmbeddings(BertEmbeddings):
  """
  BertEmbeddings variant whose position ids skip over padding, as RoBERTa expects.

  The word and position embedding tables are rebuilt with ``padding_idx`` set, and when the caller
  does not pass ``position_ids`` they are derived from the inputs before delegating to BertEmbeddings.
  """

  def __init__(self, config):
    super(RobertaEmbeddings, self).__init__(config)
    # Pad token id. Hard-coded to 1 rather than read from ``config`` (assumed to match the
    # RoBERTa tokenizer's <pad> id — confirm if a custom vocabulary is used).
    self.padding_idx = 1
    hidden = config.hidden_size
    self.word_embeddings = nn.Embedding(config.vocab_size, hidden, padding_idx=self.padding_idx)
    self.position_embeddings = nn.Embedding(
        config.max_position_embeddings, hidden, padding_idx=self.padding_idx
    )

  def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
    if position_ids is None:
      if input_ids is None:
        # Only embeddings were given: padding cannot be detected, so number positions sequentially.
        position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
      else:
        # Padding tokens keep position ``padding_idx``; real tokens count up from padding_idx + 1.
        position_ids = self.create_position_ids_from_input_ids(input_ids).to(input_ids.device)

    return super(RobertaEmbeddings, self).forward(
        input_ids, token_type_ids=token_type_ids, position_ids=position_ids, inputs_embeds=inputs_embeds
    )

  def create_position_ids_from_input_ids(self, x):
    """ Number non-padding tokens starting at ``padding_idx + 1``; padding tokens get ``padding_idx``.

    Adapted from fairseq's `utils.make_positions`. The count runs along dim 1 and only advances on
    non-padding tokens, so padding never consumes a position.

    :param torch.Tensor x: token ids (the running count is taken along dim 1).
    :return torch.Tensor: long tensor of position ids, same shape as ``x``.
    """
    not_padding = x.ne(self.padding_idx).long()
    running_count = not_padding.cumsum(dim=1) * not_padding
    return running_count + self.padding_idx

  def create_position_ids_from_inputs_embeds(self, inputs_embeds):
    """ Build sequential position ids when only embeddings are available.

    Padding cannot be recognized from embeddings, so every slot is treated as a real token.

    :param torch.Tensor inputs_embeds: embeddings whose last dimension is the embedding size.
    :return torch.Tensor: ids ``padding_idx + 1 .. padding_idx + seq_len`` expanded over the batch.
    """
    batch_shape = inputs_embeds.size()[:-1]
    start = self.padding_idx + 1
    positions = torch.arange(start, start + batch_shape[1], dtype=torch.long, device=inputs_embeds.device)
    return positions.unsqueeze(0).expand(batch_shape)


# Shared model description. It is passed to ``add_start_docstrings`` as the start of the
# generated docstring of every decorated RoBERTa model class in this file.
ROBERTA_START_DOCSTRING = r"""    The RoBERTa model was proposed in
    `RoBERTa: A Robustly Optimized BERT Pretraining Approach`_
    by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer,
    Veselin Stoyanov. It is based on Google's BERT model released in 2018.

    It builds on BERT and modifies key hyperparameters, removing the next-sentence pretraining
    objective and training with much larger mini-batches and learning rates.

    This implementation is the same as BertModel with a tiny embeddings tweak as well as a setup for Roberta pretrained
    models.

    This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
    refer to the PyTorch documentation for all matter related to general usage and behavior.

    .. _`RoBERTa: A Robustly Optimized BERT Pretraining Approach`:
        https://arxiv.org/abs/1907.11692

    .. _`torch.nn.Module`:
        https://pytorch.org/docs/stable/nn.html#module

    Parameters:
        config (:class:`~transformers.RobertaConfig`): Model configuration class with all the parameters of the
            model. Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

# Shared description of the common forward() inputs. It is appended by ``add_start_docstrings``
# to the docstring of every decorated RoBERTa model class in this file.
ROBERTA_INPUTS_DOCSTRING = r"""
    Inputs:
        **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Indices of input sequence tokens in the vocabulary.
            To match pre-training, RoBERTa input sequence should be formatted with <s> and </s> tokens as follows:

            (a) For sequence pairs:

                ``tokens:         <s> Is this Jacksonville ? </s> </s> No it is not . </s>``

            (b) For single sequences:

                ``tokens:         <s> the dog is hairy . </s>``

            Fully encoded sequences or sequence pairs can be obtained using the RobertaTokenizer.encode function with
            the ``add_special_tokens`` parameter set to ``True``.

            RoBERTa is a model with absolute position embeddings so it's usually advised to pad the inputs on
            the right rather than the left.

            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
        **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
        **token_type_ids**: (`optional` need to be trained) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Optional segment token indices to indicate first and second portions of the inputs.
            This embedding matrix is not trained (not pretrained during RoBERTa pretraining), you will have to train it
            during finetuning.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token
            (see `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_ for more details).
        **position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1[``.
        **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
        **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.

    .. _`BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`:
        https://arxiv.org/abs/1810.04805
"""


@add_start_docstrings(
    "The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
    ROBERTA_START_DOCSTRING,
    ROBERTA_INPUTS_DOCSTRING,
)
class RobertaModel(BertModel):
  r"""
  Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
      **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
          Sequence of hidden-states at the output of the last layer of the model.
      **pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)``
          Last layer hidden-state of the first token of the sequence (the ``<s>`` token)
          further processed by a Linear layer and a Tanh activation function (the pooler
          inherited from BertModel). Unlike BERT, RoBERTa is pretrained without a
          next-sentence prediction objective, so this layer is not trained by one. This output
          is usually *not* a good summary of the semantic content of the input; you're often
          better off averaging or pooling the sequence of hidden-states for the whole input sequence.
      **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
          list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
          of shape ``(batch_size, sequence_length, hidden_size)``:
          Hidden-states of the model at the output of each layer plus the initial embedding outputs.
      **attentions**: (`optional`, returned when ``config.output_attentions=True``)
          list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
          Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

  Examples::

      tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
      model = RobertaModel.from_pretrained('roberta-base')
      input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
      outputs = model(input_ids)
      last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple

  """
  config_class = RobertaConfig
  pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
  base_model_prefix = "roberta"

  def __init__(self, config):
    super(RobertaModel, self).__init__(config)

    # Replace BERT's embeddings with the padding-aware RoBERTa ones, then run init_weights()
    # again (presumably so the newly built embedding tables get the model's init scheme).
    self.embeddings = RobertaEmbeddings(config)
    self.init_weights()

  def get_input_embeddings(self):
    """Return the word-embedding table of the RoBERTa embeddings module."""
    return self.embeddings.word_embeddings

  def set_input_embeddings(self, value):
    """Install ``value`` as the word-embedding table of the RoBERTa embeddings module."""
    self.embeddings.word_embeddings = value


@add_start_docstrings(
  """RoBERTa Model with a `language modeling` head on top. """, ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING
)
class RobertaForMaskedLM(BertPreTrainedModel):
  r"""
      **masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
          Target token ids for the masked language modeling loss.
          Values lie in ``[-100, 0, ..., config.vocab_size - 1]`` (see ``input_ids`` docstring).
          Positions labelled ``-100`` are ignored; the loss is averaged over the positions whose label
          is in ``[0, ..., config.vocab_size - 1]``.

  Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
      **loss**: (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
          Masked language modeling loss.
      **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
          Per-token vocabulary scores from the language modeling head (before SoftMax).
      **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
          list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
          of shape ``(batch_size, sequence_length, hidden_size)``:
          Hidden-states of the model at the output of each layer plus the initial embedding outputs.
      **attentions**: (`optional`, returned when ``config.output_attentions=True``)
          list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
          Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

  Examples::

      tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
      model = RobertaForMaskedLM.from_pretrained('roberta-base')
      input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
      outputs = model(input_ids, masked_lm_labels=input_ids)
      loss, prediction_scores = outputs[:2]

  """
  config_class = RobertaConfig
  pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
  base_model_prefix = "roberta"

  def __init__(self, config):
    super(RobertaForMaskedLM, self).__init__(config)

    # Encoder followed by the vocabulary-projection head.
    self.roberta = RobertaModel(config)
    self.lm_head = RobertaLMHead(config)

    self.init_weights()

  def get_output_embeddings(self):
    """Return the head's decoder Linear layer (the output embedding matrix)."""
    return self.lm_head.decoder

  def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None,
              head_mask=None, inputs_embeds=None, masked_lm_labels=None):
    encoder_outputs = self.roberta(
        input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
    )
    scores = self.lm_head(encoder_outputs[0])
    # Index 1 (pooled output) is dropped; optional hidden states / attentions are passed through.
    extras = encoder_outputs[2:]

    if masked_lm_labels is None:
      return (scores,) + extras  # prediction_scores, (hidden_states), (attentions)

    flat_scores = scores.view(-1, self.config.vocab_size)
    mlm_loss = CrossEntropyLoss()(flat_scores, masked_lm_labels.view(-1))
    return (mlm_loss, scores) + extras  # masked_lm_loss, prediction_scores, (hidden_states), (attentions)


class RobertaLMHead(nn.Module):
  """Roberta Head for masked language modeling.

  Projects encoder hidden states to vocabulary scores:
  dense -> gelu -> layer norm -> decoder (Linear to ``config.vocab_size``, with bias).
  """

  def __init__(self, config):
    super(RobertaLMHead, self).__init__()
    self.dense = nn.Linear(config.hidden_size, config.hidden_size)
    self.layer_norm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
    self.bias = nn.Parameter(torch.zeros(config.vocab_size))
    # Register the very same Parameter as the decoder's bias. Code that resizes the output
    # embeddings and pads ``decoder.bias`` along with ``decoder.weight`` then keeps ``self.bias``
    # in sync instead of leaving it at the old vocabulary size. Both state-dict keys
    # (``bias`` and ``decoder.bias``) alias one tensor.
    self.decoder.bias = self.bias

  def forward(self, features, **kwargs):
    x = self.dense(features)
    x = gelu(x)
    x = self.layer_norm(x)

    # project back to size of vocabulary; the (shared) bias is applied by the decoder itself
    x = self.decoder(x)

    return x


@add_start_docstrings(
    """RoBERTa Model transformer with a sequence classification/regression head on top (a small
    dense + tanh + linear head on the final hidden state of the first ``<s>`` token) e.g. for GLUE tasks. """,
    ROBERTA_START_DOCSTRING,
    ROBERTA_INPUTS_DOCSTRING,
)
class RobertaForSequenceClassification(BertPreTrainedModel):
  r"""
      **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
          Labels for computing the sequence classification/regression loss.
          Indices should be in ``[0, ..., config.num_labels - 1]``.
          If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
          If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).

  Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
      **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
          Classification (or regression if config.num_labels==1) loss.
      **logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
          Classification (or regression if config.num_labels==1) scores (before SoftMax).
      **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
          list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
          of shape ``(batch_size, sequence_length, hidden_size)``:
          Hidden-states of the model at the output of each layer plus the initial embedding outputs.
      **attentions**: (`optional`, returned when ``config.output_attentions=True``)
          list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
          Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

  Examples::

      tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
      model = RobertaForSequenceClassification.from_pretrained('roberta-base')
      input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
      labels = torch.tensor([1]).unsqueeze(0)  # Batch size 1
      outputs = model(input_ids, labels=labels)
      loss, logits = outputs[:2]

  """
  config_class = RobertaConfig
  pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
  base_model_prefix = "roberta"

  def __init__(self, config):
    super(RobertaForSequenceClassification, self).__init__(config)
    self.num_labels = config.num_labels

    self.roberta = RobertaModel(config)
    self.classifier = RobertaClassificationHead(config)

    # Initialize the new head the same way every other RoBERTa task model in this file does;
    # without this the classifier keeps PyTorch's default Linear initialization.
    self.init_weights()

  def forward(
      self,
      input_ids=None,
      attention_mask=None,
      token_type_ids=None,
      position_ids=None,
      head_mask=None,
      inputs_embeds=None,
      labels=None,
  ):
    outputs = self.roberta(
        input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
    )
    sequence_output = outputs[0]
    # The head selects the first (<s>) token itself, so it receives the full sequence output
    # rather than the pooled output (outputs[1]).
    logits = self.classifier(sequence_output)

    outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
    if labels is not None:
      if self.num_labels == 1:
        #  We are doing regression
        loss_fct = MSELoss()
        loss = loss_fct(logits.view(-1), labels.view(-1))
      else:
        loss_fct = CrossEntropyLoss()
        loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
      outputs = (loss,) + outputs

    return outputs  # (loss), logits, (hidden_states), (attentions)


@add_start_docstrings(
    """Roberta Model with a multiple choice classification head on top (a linear layer on top of
    the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
    ROBERTA_START_DOCSTRING,
    ROBERTA_INPUTS_DOCSTRING,
)
class RobertaForMultipleChoice(BertPreTrainedModel):
  r"""
  Inputs:
      **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
          Indices of input sequence tokens in the vocabulary.
          The second dimension of the input (`num_choices`) indicates the number of choices to score.
          To match pre-training, RoBERTa input sequence should be formatted with <s> and </s> tokens as follows:

          (a) For sequence pairs:

              ``tokens:         <s> Is this Jacksonville ? </s> </s> No it is not . </s>``

          (b) For single sequences:

              ``tokens:         <s> the dog is hairy . </s>``

          Indices can be obtained using :class:`transformers.RobertaTokenizer`.
          See :func:`transformers.PreTrainedTokenizer.encode` and
          :func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
      **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
          Segment token indices to indicate first and second portions of the inputs.
          The second dimension of the input (`num_choices`) indicates the number of choices to score.
          Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
          corresponds to a `sentence B` token. This embedding is not trained during RoBERTa
          pretraining, so it has to be learned during finetuning.
      **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
          Mask to avoid performing attention on padding token indices.
          The second dimension of the input (`num_choices`) indicates the number of choices to score.
          Mask values selected in ``[0, 1]``:
          ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
      **position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
          Indices of positions of each input sequence tokens in the position embeddings.
      **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
          Mask to nullify selected heads of the self-attention modules.
          Mask values selected in ``[0, 1]``:
          ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
      **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, num_choices, sequence_length, embedding_dim)``:
          Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
          This is useful if you want more control over how to convert `input_ids` indices into associated vectors
          than the model's internal embedding lookup matrix.
      **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
          Labels for computing the multiple choice classification loss.
          Indices should be in ``[0, ..., num_choices - 1]`` where `num_choices` is the size of the second dimension
          of the input tensors. (see `input_ids` above)

  Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
      **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
          Classification loss.
      **classification_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices)`` where `num_choices` is the size of the second dimension
          of the input tensors. (see `input_ids` above).
          Classification scores (before SoftMax).
      **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
          list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
          of shape ``(batch_size * num_choices, sequence_length, hidden_size)``:
          Hidden-states of the model at the output of each layer plus the initial embedding outputs.
      **attentions**: (`optional`, returned when ``config.output_attentions=True``)
          list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size * num_choices, num_heads, sequence_length, sequence_length)``:
          Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

  Examples::

      tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
      model = RobertaForMultipleChoice.from_pretrained('roberta-base')
      choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
      input_ids = torch.tensor([tokenizer.encode(s, add_special_tokens=True) for s in choices]).unsqueeze(0)  # Batch size 1, 2 choices
      labels = torch.tensor(1).unsqueeze(0)  # Batch size 1
      outputs = model(input_ids, labels=labels)
      loss, classification_scores = outputs[:2]

  """
  config_class = RobertaConfig
  pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
  base_model_prefix = "roberta"

  def __init__(self, config):
    super(RobertaForMultipleChoice, self).__init__(config)

    self.roberta = RobertaModel(config)
    self.dropout = nn.Dropout(config.hidden_dropout_prob)
    # One score per (example, choice) pair; scores are regrouped per example in forward().
    self.classifier = nn.Linear(config.hidden_size, 1)

    self.init_weights()

  def forward(
      self,
      input_ids=None,
      token_type_ids=None,
      attention_mask=None,
      labels=None,
      position_ids=None,
      head_mask=None,
      inputs_embeds=None,
  ):
    # The choice dimension is folded into the batch for the encoder and unfolded again for scoring.
    if input_ids is not None:
      num_choices = input_ids.shape[1]
    elif inputs_embeds is not None:
      num_choices = inputs_embeds.shape[1]
    else:
      raise ValueError("You have to specify either input_ids or inputs_embeds")

    flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
    flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
    flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
    flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
    flat_inputs_embeds = (
        inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
        if inputs_embeds is not None
        else None
    )
    outputs = self.roberta(
        flat_input_ids,
        position_ids=flat_position_ids,
        token_type_ids=flat_token_type_ids,
        attention_mask=flat_attention_mask,
        head_mask=head_mask,
        inputs_embeds=flat_inputs_embeds,
    )
    pooled_output = outputs[1]

    pooled_output = self.dropout(pooled_output)
    logits = self.classifier(pooled_output)
    reshaped_logits = logits.view(-1, num_choices)

    outputs = (reshaped_logits,) + outputs[2:]  # add hidden states and attention if they are here

    if labels is not None:
      loss_fct = CrossEntropyLoss()
      loss = loss_fct(reshaped_logits, labels)
      outputs = (loss,) + outputs

    return outputs  # (loss), reshaped_logits, (hidden_states), (attentions)


@add_start_docstrings(
    """Roberta Model with a token classification head on top (a linear layer on top of
    the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
    ROBERTA_START_DOCSTRING,
    ROBERTA_INPUTS_DOCSTRING,
)
class RobertaForTokenClassification(BertPreTrainedModel):
  r"""
      **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
          Per-token class labels for the token classification loss.
          Values lie in ``[0, ..., config.num_labels - 1]``. When ``attention_mask`` is given, only
          positions where it equals ``1`` contribute to the loss.

  Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
      **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
          Classification loss.
      **scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.num_labels)``
          Per-token classification scores (before SoftMax).
      **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
          list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
          of shape ``(batch_size, sequence_length, hidden_size)``:
          Hidden-states of the model at the output of each layer plus the initial embedding outputs.
      **attentions**: (`optional`, returned when ``config.output_attentions=True``)
          list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
          Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

  Examples::

      tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
      model = RobertaForTokenClassification.from_pretrained('roberta-base')
      input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
      labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0)  # Batch size 1
      outputs = model(input_ids, labels=labels)
      loss, scores = outputs[:2]

  """
  config_class = RobertaConfig
  pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
  base_model_prefix = "roberta"

  def __init__(self, config):
    super(RobertaForTokenClassification, self).__init__(config)
    self.num_labels = config.num_labels

    self.roberta = RobertaModel(config)
    self.dropout = nn.Dropout(config.hidden_dropout_prob)
    self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    self.init_weights()

  def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None,
              head_mask=None, inputs_embeds=None, labels=None):
    encoder_outputs = self.roberta(
        input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
    )
    token_states = self.dropout(encoder_outputs[0])
    logits = self.classifier(token_states)
    result = (logits,) + encoder_outputs[2:]  # scores, (hidden_states), (attentions)

    if labels is None:
      return result

    loss_fct = CrossEntropyLoss()
    flat_logits = logits.view(-1, self.num_labels)
    flat_labels = labels.view(-1)
    if attention_mask is not None:
      # Score only the positions the attention mask marks as real tokens.
      keep = attention_mask.view(-1) == 1
      flat_logits = flat_logits[keep]
      flat_labels = flat_labels[keep]
    token_loss = loss_fct(flat_logits, flat_labels)
    return (token_loss,) + result  # loss, scores, (hidden_states), (attentions)


class RobertaClassificationHead(nn.Module):
  """Head for sentence-level classification tasks.

  Takes the final hidden state of the first token (``<s>``, RoBERTa's counterpart of ``[CLS]``)
  and maps it to ``config.num_labels`` scores via dropout -> dense -> tanh -> dropout -> out_proj.
  """

  def __init__(self, config):
    super(RobertaClassificationHead, self).__init__()
    self.dense = nn.Linear(config.hidden_size, config.hidden_size)
    self.dropout = nn.Dropout(config.hidden_dropout_prob)
    self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

  def forward(self, features, **kwargs):
    # features[:, 0, :] selects the <s> token of every sequence; extra kwargs are ignored.
    first_token = features[:, 0, :]
    hidden = torch.tanh(self.dense(self.dropout(first_token)))
    return self.out_proj(self.dropout(hidden))


@add_start_docstrings(
    """Roberta Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer on top of
    the hidden-states output to compute `span start logits` and `span end logits`). """,
    ROBERTA_START_DOCSTRING,
    ROBERTA_INPUTS_DOCSTRING,
)
class RobertaForQuestionAnswering(BertPreTrainedModel):
  r"""
      **start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
          Labels for position (index) of the start of the labelled span for computing the token classification loss.
          Positions are clamped to the length of the sequence (`sequence_length`).
          Position outside of the sequence are not taken into account for computing the loss.
          The caller's tensor is not modified.
      **end_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
          Labels for position (index) of the end of the labelled span for computing the token classification loss.
          Positions are clamped to the length of the sequence (`sequence_length`).
          Position outside of the sequence are not taken into account for computing the loss.
          The caller's tensor is not modified.
  Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
      **loss**: (`optional`, returned when both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
          Total span extraction loss is the average of a Cross-Entropy for the start and end positions.
      **start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
          Span-start scores (before SoftMax).
      **end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
          Span-end scores (before SoftMax).
      **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
          list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
          of shape ``(batch_size, sequence_length, hidden_size)``:
          Hidden-states of the model at the output of each layer plus the initial embedding outputs.
      **attentions**: (`optional`, returned when ``config.output_attentions=True``)
          list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
          Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
  Examples::
      tokenizer = RobertaTokenizer.from_pretrained('roberta-large')
      model = RobertaForQuestionAnswering.from_pretrained('roberta-large')
      question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
      input_ids = tokenizer.encode(question, text)
      start_scores, end_scores = model(torch.tensor([input_ids]))
      all_tokens = tokenizer.convert_ids_to_tokens(input_ids)
      answer = ' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores)+1])
  """
  config_class = RobertaConfig
  pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
  base_model_prefix = "roberta"

  def __init__(self, config):
    super(RobertaForQuestionAnswering, self).__init__(config)
    # forward() unpacks the head output into exactly (start, end), so this assumes
    # config.num_labels == 2.
    self.num_labels = config.num_labels

    self.roberta = RobertaModel(config)
    self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

    self.init_weights()

  def forward(
      self,
      input_ids=None,
      attention_mask=None,
      token_type_ids=None,
      position_ids=None,
      head_mask=None,
      start_positions=None,
      end_positions=None,
      inputs_embeds=None,
  ):
    # ``inputs_embeds`` is appended last so existing positional callers are unaffected.
    outputs = self.roberta(
        input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
    )

    sequence_output = outputs[0]

    logits = self.qa_outputs(sequence_output)
    start_logits, end_logits = logits.split(1, dim=-1)
    start_logits = start_logits.squeeze(-1)
    end_logits = end_logits.squeeze(-1)

    outputs = (start_logits, end_logits,) + outputs[2:]
    if start_positions is not None and end_positions is not None:
      # If we are on multi-GPU, split add a dimension
      if len(start_positions.size()) > 1:
        start_positions = start_positions.squeeze(-1)
      if len(end_positions.size()) > 1:
        end_positions = end_positions.squeeze(-1)
      # sometimes the start/end positions are outside our model inputs, we ignore these terms
      ignored_index = start_logits.size(1)
      # Out-of-place clamp: the labels (or squeeze() views of them) belong to the caller and
      # must not be overwritten.
      start_positions = start_positions.clamp(0, ignored_index)
      end_positions = end_positions.clamp(0, ignored_index)

      loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
      start_loss = loss_fct(start_logits, start_positions)
      end_loss = loss_fct(end_logits, end_positions)
      total_loss = (start_loss + end_loss) / 2
      outputs = (total_loss,) + outputs

    return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)