# encoding: utf-8
"""Class Declaration of Transformer's Attention."""

import chainer

import chainer.functions as F
import chainer.links as L

import numpy as np

# Most negative finite float32 value; MultiHeadAttention fills masked-out
# attention scores with it before the softmax.
MIN_VALUE = float(np.finfo(np.float32).min)


class MultiHeadAttention(chainer.Chain):
    """Multi Head Attention Layer.

    Projects the inputs to queries, keys and values with three linear links,
    computes scaled dot-product attention independently for each head, and
    merges the heads again through an output linear link.

    Args:
        n_units (int): Number of input/output features. Must be divisible
            by ``h``.
        h (int): Number of attention heads.
        dropout (float): Dropout rate applied to the attention weights.
        initialW (callable or None): Factory called as
            ``initialW(scale=...)`` that returns the weight initializer for
            each linear link. If ``None``, ``None`` is forwarded to
            ``L.Linear`` so that the link uses its own default.
        initial_bias (callable or None): Factory called as
            ``initial_bias(scale=...)`` that returns the bias initializer for
            each linear link. If ``None``, ``None`` is forwarded to
            ``L.Linear`` so that the link uses its own default.

    """

    def __init__(self, n_units, h=8, dropout=0.1, initialW=None, initial_bias=None):
        """Initialize MultiHeadAttention."""
        super(MultiHeadAttention, self).__init__()
        assert n_units % h == 0, "n_units must be divisible by h"
        stvd = 1.0 / np.sqrt(n_units)

        def _make_linear():
            # Build fresh initializer instances for every link. The factories
            # default to None in the signature, so they must not be called
            # unconditionally.
            return L.Linear(
                n_units,
                n_units,
                initialW=None if initialW is None else initialW(scale=stvd),
                initial_bias=(
                    None if initial_bias is None else initial_bias(scale=stvd)
                ),
            )

        with self.init_scope():
            self.linear_q = _make_linear()
            self.linear_k = _make_linear()
            self.linear_v = _make_linear()
            self.linear_out = _make_linear()
        self.d_k = n_units // h
        self.h = h
        self.dropout = dropout
        # Softmax attention weights (before dropout) of the latest forward call.
        self.attn = None

    def forward(self, e_var, s_var=None, mask=None, batch=1):
        """Core function of the Multi-head attention layer.

        Args:
            e_var (chainer.Variable): Variable of input array; used for the
                queries, and also for keys/values when ``s_var`` is None
                (self-attention).
            s_var (chainer.Variable): Variable of source array from encoder;
                when given, keys and values are computed from it.
            mask (chainer.Variable): Attention mask. Positions where it is
                false have their score replaced by ``MIN_VALUE``. It is
                repeated ``h`` times along a new axis 1, so it is presumably
                shaped (batch, time1, time2) -- TODO confirm against callers.
            batch (int): Batch size.

        Returns:
            chainer.Variable: Output of multi-head attention layer, with the
                batch and time axes flattened into the first axis and
                ``h * d_k`` features on the second.

        """
        xp = self.xp
        # Self-attention when no source is supplied, otherwise attend to s_var.
        kv_src = e_var if s_var is None else s_var

        # (batch, time1/2, head, d_k)
        split_shape = (batch, -1, self.h, self.d_k)
        Q = self.linear_q(e_var).reshape(*split_shape)
        K = self.linear_k(kv_src).reshape(*split_shape)
        V = self.linear_v(kv_src).reshape(*split_shape)

        # (batch, head, time1, d_k) x (batch, head, d_k, time2)
        #   -> (batch, head, time1, time2)
        scores = F.matmul(F.swapaxes(Q, 1, 2), K.transpose(0, 2, 3, 1)) / np.sqrt(
            self.d_k
        )
        if mask is not None:
            # Replicate the mask for every head, then push masked-out scores
            # to the lowest float32 value so the softmax effectively ignores
            # them.
            mask = xp.stack([mask] * self.h, axis=1)
            scores = F.where(mask, scores, xp.full(scores.shape, MIN_VALUE, "f"))
        self.attn = F.softmax(scores, axis=-1)
        p_attn = F.dropout(self.attn, self.dropout)
        # (batch, head, time1, time2) x (batch, head, time2, d_k)
        #   -> (batch, head, time1, d_k)
        x = F.matmul(p_attn, F.swapaxes(V, 1, 2))
        # Move heads next to d_k again and merge them into one feature axis.
        x = F.swapaxes(x, 1, 2).reshape(-1, self.h * self.d_k)
        return self.linear_out(x)
