import argparse
import os
import random


def run(cmd):
    """Echo a shell command, then execute it through the system shell.

    Args:
        cmd: Complete shell command line. It may contain pipes, redirections
            and environment-variable prefixes such as ``CUDA_VISIBLE_DEVICES=``.

    Returns:
        The status value reported by ``os.system`` (0 on success). Existing
        callers ignore it; it is returned so callers can detect failures.
    """
    # Flush so the echoed command is written before the child's own output
    # when stdout is redirected to a file or pipe (block-buffered).
    print(cmd, flush=True)
    return os.system(cmd)


def preprocess_cmd():
    """Run ``preprocess.py`` over the IWSLT'14 tokenized train/valid/test splits.

    Reads the module-level ``args`` (``source``/``target``) parsed in
    ``__main__``. Output goes to ``data-bin/iwslt14.tokenized.<src>-<tgt>``,
    using 20 workers.
    """
    # Raw-text directory. It is fixed to "de-en" whatever --source/--target are;
    # presumably the same corpus serves both directions (TODO confirm).
    TEXT = "iwslt14.tokenized.de-en"
    # Backslash-newline continues the string literal; each continuation line's
    # leading spaces become part of the command (harmless to the shell).
    cmd = f"python preprocess.py --source-lang {args.source} --target-lang {args.target} \
    --trainpref {TEXT}/train --validpref {TEXT}/valid --testpref {TEXT}/test \
    --destdir data-bin/iwslt14.tokenized.{args.source}-{args.target} \
    --workers 20"
    run(cmd)


def train_cmd():
    """Launch baseline ``train.py`` training on GPUs 0-7.

    Reads the module-level ``args`` (``source``/``target``). Trains the
    ``transformer_repro_iwslt`` architecture with Adam, an inverse-sqrt
    schedule and label-smoothed cross-entropy for up to 100 epochs.
    Checkpoints are saved to ``checkpoints/iwslt14-<src>-<tgt>/``. The best
    checkpoint is selected by validation BLEU (maximised), keeping the last
    10 epochs and the 10 best checkpoints.
    """
    # The command is an f-string joined (" + ") to a plain string. The second
    # part is deliberately not an f-string, so the braces of the JSON passed to
    # --eval-bleu-args stay literal instead of being parsed as placeholders.
    cmd = f"CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7  python train.py \
    data-bin/iwslt14.tokenized.{args.source}-{args.target}  \
    --arch transformer_repro_iwslt --share-decoder-input-output-embed \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
    --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
    --dropout 0.3 --weight-decay 0.0001 \
    --max-epoch 100 --save-dir checkpoints/iwslt14-{args.source}-{args.target}/ \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --max-tokens 4096 " + " --eval-bleu \
    --eval-bleu-args '{\"beam\": 5, \"max_len_a\": 1.2, \"max_len_b\": 10}' \
    --eval-bleu-detok moses \
    --eval-bleu-remove-bpe \
    --eval-bleu-print-samples \
    --best-checkpoint-metric bleu --keep-last-epochs 10 \
    --keep-best-checkpoints 10 --maximize-best-checkpoint-metric"
    run(cmd)


def gen_cmd():
    """Decode the test split with ``generate.py`` and save the log.

    Reads the module-level ``args`` parsed in ``__main__``. The checkpoint is
    loaded from ``checkpoints/iwslt14-<src>-<tgt>/``, the ``--save-dir`` used
    by ``train_cmd``. With ``--avg_ckpt True``, the last 5 epoch checkpoints in
    that directory are first averaged into ``averaged_model.pt`` in the same
    directory, and that averaged model is decoded instead of ``--ckpt_name``.

    The log is written to ``gen_results/iwslt14_<ckpt>.out`` when
    ``--nbest`` is 1, or to
    ``gen_results_multi/iwslt14_<ckpt>_<method>_<nbest>.out`` otherwise. The
    output directory is created if it is missing.
    """
    ckpt_dir = f"checkpoints/iwslt14-{args.source}-{args.target}"
    ckpt_name = args.ckpt_name
    # Shrink the batch as the beam grows (about 4096 hypotheses per batch).
    # Clamp at 1 so very wide beams cannot round the batch size down to 0.
    batch_size = max(1, 2 * round((32 / args.beam) * 64))
    if args.avg_ckpt == "True":
        # Average inside the directory train_cmd saves to, then decode the
        # averaged weights rather than a single checkpoint.
        averaged = "averaged_model"
        run(f"python scripts/average_checkpoints.py --inputs {ckpt_dir}/ "
            f"--num-epoch-checkpoints 5 --output {ckpt_dir}/{averaged}.pt")
        ckpt_name = averaged

    shared_cmd = (
        f"CUDA_VISIBLE_DEVICES={args.gpu} python generate.py "
        f"data-bin/iwslt14.tokenized.{args.source}-{args.target} "
        f"--path {ckpt_dir}/{ckpt_name}.pt "
        f"--batch-size {batch_size} "
        f"--beam {args.beam} --diverse-beam-groups {args.diverse_beam_groups} "
        f"--diverse-beam-strength {args.diverse_beam_strength} "
        f"--diversity-rate {args.diversity_rate} --max-len-a 1.2 --max-len-b 10 "
    )
    if args.decoding_method == "sampling":
        # Emits "--sampling"; reads the top-p threshold from args.sampling_topp.
        shared_cmd += f"--{args.decoding_method} --sampling-topp {args.sampling_topp} "

    if args.nbest > 1:
        out_dir = "gen_results_multi"
        os.makedirs(out_dir, exist_ok=True)  # otherwise tee cannot open the log
        out_file = f"{out_dir}/iwslt14_{ckpt_name}_{args.decoding_method}_{args.nbest}.out"
        # Keep dropout active in the decoder at inference time (presumably to
        # diversify the n-best list -- confirm).
        run(shared_cmd + " --remove-bpe --retain-dropout "
            "--retain-dropout-modules '[\"TransformerDecoder\"]' "
            f"--nbest {args.nbest} | tee {out_file}")
    else:
        out_dir = "gen_results"
        os.makedirs(out_dir, exist_ok=True)  # otherwise tee cannot open the log
        out_file = f"{out_dir}/iwslt14_{ckpt_name}.out"
        run(shared_cmd + f" --remove-bpe | tee {out_file}")
        # Show the final summary line of the log (presumably the BLEU score).
        run(f"tail -1 {out_file}")


def gen_cl_cmd():
    """Decode the test split with ``my_generate.py`` under the ``translation_cl`` task.

    Reads the module-level ``args`` parsed in ``__main__``. The checkpoint is
    ``<ckpt_dir>/<ckpt_name>.pt``. An empty ``--ckpt_dir`` (the parser default)
    means the current directory. With ``--avg_ckpt True``, the last 5 epoch
    checkpoints in ``ckpt_dir`` are first averaged into
    ``<ckpt_dir>/averaged_model.pt``, and that model is decoded instead.

    The log is written to ``gen_results/iwslt14_<ckpt>.out`` when
    ``--nbest`` is 1, or to
    ``gen_results_multi/iwslt14_<ckpt>_<method>_<nbest>.out`` otherwise. The
    output directory is created if it is missing.
    """
    # Without the fallback, an empty --ckpt_dir would yield "/<ckpt_name>.pt".
    ckpt_dir = args.ckpt_dir or "."
    ckpt_name = args.ckpt_name
    # Shrink the batch as the beam grows. Clamp at 1 so very wide beams cannot
    # round the batch size down to 0.
    batch_size = max(1, 2 * round((32 / args.beam) * 64))
    if args.avg_ckpt == "True":
        # Average inside the checkpoint directory, then decode the averaged
        # weights rather than a single checkpoint.
        averaged = "averaged_model"
        run(f"python scripts/average_checkpoints.py --inputs {ckpt_dir}/ "
            f"--num-epoch-checkpoints 5 --output {os.path.join(ckpt_dir, averaged)}.pt")
        ckpt_name = averaged

    # NOTE(review): unlike gen_cmd, --beam and --diverse-beam-* are not passed,
    # although batch_size is scaled by args.beam. Presumably my_generate.py /
    # translation_cl chooses the beam itself -- confirm.
    shared_cmd = (
        f"CUDA_VISIBLE_DEVICES={args.gpu} python my_generate.py "
        f"data-bin/iwslt14.tokenized.{args.source}-{args.target} "
        f"--path {os.path.join(ckpt_dir, ckpt_name)}.pt --task translation_cl "
        f"--batch-size {batch_size} "
        f"--diversity-rate {args.diversity_rate} --max-len-a 1.2 --max-len-b 10 "
    )
    if args.decoding_method == "sampling":
        # Emits "--sampling"; reads the top-p threshold from args.sampling_topp.
        shared_cmd += f"--{args.decoding_method} --sampling-topp {args.sampling_topp} "

    if args.nbest > 1:
        out_dir = "gen_results_multi"
        os.makedirs(out_dir, exist_ok=True)  # otherwise tee cannot open the log
        out_file = f"{out_dir}/iwslt14_{ckpt_name}_{args.decoding_method}_{args.nbest}.out"
        # Decoder dropout is intentionally not retained here (gen_cmd adds
        # --retain-dropout --retain-dropout-modules for its n-best runs).
        run(shared_cmd + f" --remove-bpe --nbest {args.nbest} | tee {out_file}")
    else:
        out_dir = "gen_results"
        os.makedirs(out_dir, exist_ok=True)  # otherwise tee cannot open the log
        out_file = f"{out_dir}/iwslt14_{ckpt_name}.out"
        run(shared_cmd + f" --remove-bpe | tee {out_file}")
        # Show the final summary line of the log (presumably the BLEU score).
        run(f"tail -1 {out_file}")


def score_cmd():
    """Score a single-hypothesis generation log with ``score.py``.

    Reads ``gen_results/iwslt14_<ckpt_name>.out``. ``ckpt_name`` comes from
    ``args.ckpt_name`` (default ``checkpoint_best``). The references (T-lines)
    are extracted to ``.ref`` and the hypotheses (H-lines) to ``.sys`` next to
    the log. On both sides, hyphenated compounds are split as
    ``a ##AT##-##AT## b`` before scoring.
    """
    file_name = f"iwslt14_{args.ckpt_name}"
    base_dir = "gen_results"
    stem = f"{base_dir}/{file_name}"
    # Raw string: "\S" belongs to the Perl regex. In a normal Python literal it
    # is an invalid escape sequence (a SyntaxWarning on Python 3.12+).
    split_compounds = r" | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' "
    # T-lines are "T-<id>\t<target>" and H-lines are
    # "H-<id>\t<score>\t<hypothesis>". Take field 3 onward for H so the model
    # score is not scored as part of the hypothesis text.
    run(f"grep ^T {stem}.out | cut -f2-" + split_compounds + f"> {stem}.ref")
    run(f"grep ^H {stem}.out | cut -f3-" + split_compounds + f"> {stem}.sys")
    run(f"python score.py --sys {stem}.sys --ref {stem}.ref")


def score_multi_cmd(file_name="iwslt14_checkpoint60_wp_beam_search_16",
                    base_dir="gen_results_multi"):
    """Score an n-best generation log with ``score_multi.py``.

    Args:
        file_name: Log stem inside *base_dir*, without ``.out``. Its last
            ``_``-separated field is passed to ``score_multi.py`` as
            ``--beam``. In the n-best log names written by the gen modes,
            that field is the ``--nbest`` value.
        base_dir: Directory holding the generation log.

    References (T-lines, field 2 onward) go to ``.ref``. Hypotheses (H-lines,
    field 3 onward, which skips the score column) go to ``.sys``. Hyphenated
    compounds are split as ``a ##AT##-##AT## b`` on both sides.
    """
    stem = f"{base_dir}/{file_name}"
    # Raw string: "\S" belongs to the Perl regex, not to Python.
    split_compounds = r" | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' "
    run(f"grep ^T {stem}.out | cut -f2-" + split_compounds + f"> {stem}.ref")
    run(f"grep ^H {stem}.out | cut -f3-" + split_compounds + f"> {stem}.sys")
    # Print both line counts so the .sys/.ref sizes can be compared by eye.
    run(f"wc -l {stem}.sys")
    run(f"wc -l {stem}.ref")
    run(f"python score_multi.py --beam {file_name.split('_')[-1]} "
        f"--sys {stem}.sys --ref {stem}.ref")


def train_cl_cmd():
    """Launch ``train.py`` for the ``translation_cl`` task on GPUs 0-3.

    Reads the module-level ``args`` (source, target, version, cl_loss,
    diverse_bias). The model is the ``transformer_clg<version>_iwslt``
    architecture, initialised through ``--skip_warmup_ckpt``. Checkpoints are
    saved to
    ``checkpoints/iwslt14<version>_<cl_loss><diverse_bias>-<src>-<tgt>/``.
    The best checkpoint is selected by validation BLEU (maximised). The random
    seed is drawn fresh on every call and shows up in the echoed command.
    """
    # NOTE(review): this warmed-up checkpoint is always the de-en one, even
    # when --source/--target select another direction -- confirm.
    model_dict = "skip_warmup_ckpts/iwslt14-de-en/model-warmed-up.pt"
    seed = random.randint(0, 2023)
    cmd = f"CUDA_VISIBLE_DEVICES=0,1,2,3  python train.py \
    data-bin/iwslt14.tokenized.{args.source}-{args.target}  --task translation_cl  \
    --arch transformer_clg{args.version}_iwslt --share-decoder-input-output-embed \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
    --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
    --dropout 0.3 --weight-decay 0.0001 --skip_warmup_ckpt {model_dict} \
    --max-epoch 100 --save-dir checkpoints/iwslt14{args.version}_{args.cl_loss}{args.diverse_bias}-{args.source}-{args.target}/ \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --max-tokens 4000  --update-freq 4 --seed {seed} \
    --cl_loss {args.cl_loss} --n_gram 2 --eval-bleu \
    --max_len_a 1.2 --validate-interval-updates 100 --max_len_b 10 \
    --eval-bleu-detok moses \
    --eval-bleu-remove-bpe \
    --eval-bleu-print-samples  --log-interval 5 \
    --best-checkpoint-metric bleu --maximize-best-checkpoint-metric  \
    --keep-last-epochs 10 --keep-best-checkpoints 10  --fp16 "
    run(cmd)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='run cmd python wrapper'
    )
    # Each mode dispatches to the module-level function "<mode>_cmd".
    parser.add_argument('--mode', required=True,
                        choices=["preprocess", "train", "gen", "train_cl", "gen_cl", "score", "score_multi"])
    parser.add_argument('--version', type=str, default="")
    # training parameter
    parser.add_argument('--source', default="de")
    parser.add_argument('--target', default="en")
    parser.add_argument('--cl_loss', default='ranking')
    parser.add_argument('--diverse_bias', default='2.8')

    # gen parameter
    parser.add_argument('--avg_ckpt', default="False")  # compared against the string "True"
    parser.add_argument('--ckpt_dir', default="")
    parser.add_argument('--ckpt_name', default="checkpoint_best")
    parser.add_argument('--nbest', default=1, type=int)
    parser.add_argument('--gpu', default="0")
    parser.add_argument('--beam', default=5, type=int)
    parser.add_argument('--decoding_method', default="beam_search",
                        choices=["beam_search", "div_sibling_search", "div_beam_search", "sampling"])

    # no need to set
    # "----sampling_topk" (four dashes) is kept so existing invocations still
    # parse. "--sampling_topk" is the conventional spelling. Both set
    # args.sampling_topk.
    parser.add_argument('--sampling_topk', '----sampling_topk', default="5")
    # Read as args.sampling_topk's sibling args.sampling_topp by the gen modes
    # when --decoding_method is "sampling". -1.0 is presumably fairseq's
    # "top-p disabled" default -- confirm.
    parser.add_argument('--sampling_topp', default=-1.0, type=float)
    parser.add_argument('--diversity_rate', default=-1.0, type=float)
    parser.add_argument("--diverse_beam_strength", default=0.5, type=float)
    parser.add_argument('--diverse_beam_groups', default=-1, type=int)
    args = parser.parse_args()
    # Presets: the diverse decoding methods override their related knobs.
    if args.decoding_method == "div_sibling_search":
        args.diversity_rate = 1.0
    elif args.decoding_method == "div_beam_search":
        args.diverse_beam_groups = args.beam

    # Look the handler up by name instead of eval()-ing a code string.
    globals()[f"{args.mode}_cmd"]()
