# Jump Self-attention: Capturing High-order Statistics in Transformers

This is the PyTorch implementation of JAT (Jump Self-attention) from the following paper: [Jump Self-attention: Capturing High-order Statistics in Transformers]().
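As a rough intuition for what "high-order" can mean in attention, here is a minimal NumPy sketch. It is **not** the paper's JAT formulation. It takes a standard row-normalized attention matrix A and propagates values through A raised to the power `order`, so information can flow from token j to token i through intermediate tokens ("jumps"). The function `multi_hop_attention` and its arguments are hypothetical and are not part of this repository's API.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Numerically stable softmax along the given axis.
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def multi_hop_attention(q: np.ndarray, k: np.ndarray, v: np.ndarray, order: int = 2):
    """Illustrative multi-hop attention: out = A**order @ v.

    Hypothetical sketch, NOT the JAT method from the paper.
    q, k: (seq_len, d_k); v: (seq_len, d_v).
    """
    d_k = q.shape[-1]
    # Standard scaled dot-product attention matrix; every row sums to 1.
    a = softmax(q @ k.T / np.sqrt(d_k), axis=-1)
    # Matrix powers compose attention over intermediate tokens;
    # a product of row-stochastic matrices is still row-stochastic.
    a_k = np.linalg.matrix_power(a, order)
    return a_k @ v, a_k


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((5, 8)) for _ in range(3))
out, a2 = multi_hop_attention(q, k, v, order=2)
print(out.shape)                           # (5, 8)
print(np.allclose(a2.sum(axis=-1), 1.0))   # True
```

With `order=1` this reduces to ordinary scaled dot-product attention.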


## Requirements
+ Python 3.6
+ numpy==1.17.3
+ scipy==1.1.0
+ pandas==0.25.1
+ torch==1.2.0
+ tqdm==4.36.1
+ matplotlib==3.1.1
+ tokenizers==0.10.3
+ ...

Dependencies can be installed using the following command:

```
pip install -r requirements.txt
```

## Usage
We implement BERT-JAT inside `huggingface transformers`, so the BERT-JAT model can be used in the same way as the BERT model in `huggingface transformers`.

Build BERT-JAT:

```
from transformers import BertTokenizer, BertModel, BertConfig

config = BertConfig.from_pretrained('bert-base-uncased')

config.order = 2                # order of JAT
config.oheads = 6               # number of JAT heads (total_heads-oheads)
config.olayers = '0,1,2,7,8,9'  # layers using JAT heads
config.Atype = 3                # adjacency matrix construction method of graph convolution in JAT (0,1,2,3)

bert_JAT = BertModel.from_pretrained('bert-base-uncased', config=config)
```
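The four JAT fields above are plain attributes set on the config. Below is a minimal, hypothetical sanity check you could run on them before calling `from_pretrained`. It assumes `olayers` holds 0-indexed layer numbers (not confirmed by this README) and uses the fact that `bert-base-uncased` has 12 layers with 12 attention heads each. The helper names `parse_olayers` and `check_jat_settings` are illustrative and are not part of the repository.

```python
from typing import List


def parse_olayers(olayers: str, num_hidden_layers: int) -> List[int]:
    """Parse a comma-separated layer string such as '0,1,2,7,8,9'.

    Hypothetical helper; assumes 0-indexed layer numbers.
    """
    layers = sorted({int(tok) for tok in olayers.split(",") if tok.strip()})
    bad = [i for i in layers if not 0 <= i < num_hidden_layers]
    if bad:
        raise ValueError(f"layer indices out of range: {bad}")
    return layers


def check_jat_settings(order: int, oheads: int, olayers: str, atype: int,
                       num_attention_heads: int = 12,
                       num_hidden_layers: int = 12) -> List[int]:
    # Defaults match bert-base: 12 layers, 12 heads per layer.
    if order < 1:
        raise ValueError("order must be >= 1")
    if not 0 <= oheads <= num_attention_heads:
        raise ValueError(f"oheads must be in [0, {num_attention_heads}]")
    if atype not in (0, 1, 2, 3):
        raise ValueError("Atype must be one of 0, 1, 2, 3")
    # Return the sorted list of layers that will use JAT heads.
    return parse_olayers(olayers, num_hidden_layers)


print(check_jat_settings(2, 6, '0,1,2,7,8,9', 3))  # [0, 1, 2, 7, 8, 9]
```

For the config above you would call it as `check_jat_settings(config.order, config.oheads, config.olayers, config.Atype, config.num_attention_heads, config.num_hidden_layers)`.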

Use BERT-JAT:

```
from transformers import BertTokenizer, BertModel, BertConfig

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

config = BertConfig.from_pretrained('bert-base-uncased')
# set the JAT options as in the build example above
config.order = 2
config.oheads = 6
config.olayers = '0,1,2,7,8,9'
config.Atype = 3
bert_JAT = BertModel.from_pretrained('bert-base-uncased', config=config)

inputs = tokenizer('hello world', return_tensors="pt")
outputs = bert_JAT(**inputs)
```

## Train Commands
Command for training and evaluating BERT-JAT on the GLUE MRPC task:

```
python run_glue.py \
  --model_name_or_path bert-base-cased \
  --task_name mrpc \
  --do_train \
  --do_eval \
  --max_seq_length 128 \
  --per_device_train_batch_size 32 \
  --learning_rate 2e-5 \
  --num_train_epochs 10 \
  --config_oheads 8 \
  --config_order 2 \
  --config_Atype 3 \
  --overwrite_output_dir \
  --output_dir /tmp/mrpc/
```
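To compare several JAT orders, a small, hypothetical sweep script could reuse exactly the flags from the MRPC command and vary only `--config_order` and `--output_dir`. The per-order output directories and the choice of orders are assumptions for illustration. The script only prints the commands; the commented `subprocess.run` line shows how they could be launched.

```python
import shlex
from typing import List

# Flags copied from the MRPC command; only --config_order and
# --output_dir vary in this hypothetical sweep.
BASE_ARGS = [
    "python", "run_glue.py",
    "--model_name_or_path", "bert-base-cased",
    "--task_name", "mrpc",
    "--do_train", "--do_eval",
    "--max_seq_length", "128",
    "--per_device_train_batch_size", "32",
    "--learning_rate", "2e-5",
    "--num_train_epochs", "10",
    "--config_oheads", "8",
    "--config_Atype", "3",
    "--overwrite_output_dir",
]


def sweep_commands(orders: List[int]) -> List[List[str]]:
    # Build one argument list per JAT order, each with its own output dir.
    return [
        BASE_ARGS + ["--config_order", str(order),
                     "--output_dir", f"/tmp/mrpc/order{order}/"]
        for order in orders
    ]


for cmd in sweep_commands([2, 3]):
    # shlex.quote keeps each argument shell-safe when printed.
    print(" ".join(shlex.quote(arg) for arg in cmd))
    # To launch instead: subprocess.run(cmd, check=True)
```

Passing the argument list (rather than a single shell string) to `subprocess.run` avoids shell-quoting issues.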