#! /bin/sh

# Strict mode: stop at the first failing command (errexit) and treat the
# expansion of an unset variable as an error (nounset).
set -o errexit -o nounset

# Default settings. The option parser further down may overwrite any of them.

# What to run and how many processes to start.
BACKEND='mpi'
WORLD_SIZE='2'
engine="${PWD}/benchmark.py"

# Where to run it.
hosts='localhost'
master_hostname='localhost'
MASTER_PORT='29500'
# Derived from the two values above as "<host>:<port>".
MASTER_ADDR="${master_hostname}:${MASTER_PORT}"

# File to source before launching, and file that receives the results.
environment='/dev/null'
output_file='/dev/stdout'

# Report a fatal error and terminate the script.
#
# Arguments: any number of words; they are joined into a single line (using
#            the first character of IFS, a space by default).
# Outputs:   the joined line on standard error; nothing on standard output.
# Exits:     always, with status 1.
errxit() {
    {
        printf '%s\n' "$*"
    } >&2
    exit 1
}

# Print the command-line help text to standard output.
#
# The here-document is unquoted, so '$' must be escaped wherever it is meant
# to appear literally in the help text (see the --engine default).
usage() {
    cat <<-EOF
Usage: ./run_benchmark [ OPTIONS ]

Optional arguments:
    -., --env FILE
        Set the path to a file to source. When using the MPI backend, the file
        will be sourced before running 'mpirun'. In case of the TCP backend
        the file will be sourced on every host after establishing a successful
        SSH connection. Default: '/dev/null'.

    -b, --backend BACKEND
        Set the backend to benchmark. Default: 'mpi'.

    -e, --engine ENGINE
        Set the path to the benchmarking script to run. Use absolute paths if
        you'll be using the tcp backend. Default: '\$PWD/benchmark.py'.

    --help
        Show this help and exit successfully.

    -h, --hosts HOSTS
        Set the list of hosts to run the benchmark on. Format: 'host1,host2'.
        Default: 'localhost'.

    --max-bytes MAX_BYTES
        Set the inclusive upper limit for tensor size.
        Default: 22 (2**22 = 4 MB).

    --max-num-tensors MAX_NUM_TENSORS
        Set the inclusive upper limit for the number of tensors to be sent
        during one test run. Default: 3 (10**3 = 1000).

    --min-bytes MIN_BYTES
        Set the inclusive lower limit for tensor bytes.
        Default: 19 (2**19 = 512 KB).

    --min-num-tensors MIN_NUM_TENSORS
        Set the inclusive lower limit for the number of tensors to be sent
        during one test run. Default: 2 (10**2 = 100).

    -n, --name NAME
        Set the ip address/host name of the master node. Default: 'localhost'.

    -o, --output FILE
        Set the path to the output file where the master host will append
        benchmark results. Default: '/dev/stdout'.

    -p, --port PORT
        Set the port number master is listening on. Default: '29500'.

    -s, --world-size WORLD_SIZE
        Set the number of processes to be spawned. Default: '2'.
EOF
}

# Abort with a readable message unless the option in $1 is followed by a
# value. Without this check a trailing option such as '--backend' would expand
# an unset $2, which under 'set -u' kills the script with a shell-internal
# "parameter not set" style diagnostic.
#
# Call as: require_value "$@"
require_value() {
    [ $# -ge 2 ] || errxit "Option '$1' requires an argument"
}

# Parse the command-line options and store the results in global variables.
#
# Arguments: the script's command-line arguments.
# Globals written (only for options that are actually given):
#   environment, BACKEND, engine, hosts, MASTER_PORT, master_hostname,
#   output_file, WORLD_SIZE
#   min_num_tensors, min_bytes, max_num_tensors, max_bytes -- each is stored
#   as a ready-to-append "--flag value" string.
# Exits: with status 0 after printing the help text for --help; through
#   errxit for an unknown option or an option that lacks its value.
parse_args() {
    while [ $# -gt 0 ]; do
        case "$1" in
            '-.'|--env)
                require_value "$@"
                environment="$2"
                shift 2
                ;;
            --backend|-b)
                require_value "$@"
                BACKEND="$2"
                shift 2
                ;;
            --engine|-e)
                require_value "$@"
                engine="$2"
                shift 2
                ;;
            --help)
                usage
                exit 0
                ;;
            --hosts|-h)
                require_value "$@"
                hosts="$2"
                shift 2
                ;;
            --port|-p)
                require_value "$@"
                MASTER_PORT="$2"
                shift 2
                ;;
            --min-num-tensors)
                require_value "$@"
                min_num_tensors="--min-num-tensors $2"
                shift 2
                ;;
            --min-bytes)
                require_value "$@"
                min_bytes="--min-bytes $2"
                shift 2
                ;;
            --max-num-tensors)
                require_value "$@"
                max_num_tensors="--max-num-tensors $2"
                shift 2
                ;;
            --max-bytes)
                require_value "$@"
                max_bytes="--max-bytes $2"
                shift 2
                ;;
            --name|-n)
                require_value "$@"
                master_hostname="$2"
                shift 2
                ;;
            --output|-o)
                require_value "$@"
                output_file="$2"
                shift 2
                ;;
            --world-size|-s)
                require_value "$@"
                WORLD_SIZE="$2"
                shift 2
                ;;
            *)
                errxit "Unknown option '$1'"
                ;;
        esac
    done
}

parse_args "$@"

# TCP backend: start one benchmark process per host over SSH and wait for all
# of them.
#
# Globals read:    hosts, WORLD_SIZE, environment, BACKEND, MASTER_ADDR,
#                  MASTER_PORT, engine, output_file and the optional
#                  min_/max_ limit strings.
# Globals written: RANK, host_list, pids, failed.
# Exits through errxit when WORLD_SIZE is not a non-negative integer, when
# the number of hosts differs from WORLD_SIZE, or when any remote process
# ends with a non-zero status.
run_tcp_backend() {
    # A non-numeric value would make the '-ne' comparison below fail with an
    # error status, which 'if' treats as "false" -- silently skipping the
    # host-count check.
    case "$WORLD_SIZE" in
        ''|*[!0-9]*)
            errxit "Invalid world size: '$WORLD_SIZE'"
            ;;
    esac

    host_list="$(printf '%s\n' "$hosts" | tr ',' ' ')"
    if [ "$(printf '%s\n' "$host_list" | wc -w)" -ne "$WORLD_SIZE" ]; then
        errxit "Number of hosts ($host_list) doesn't match" \
            "the world size ($WORLD_SIZE)"
    fi

    # Ranks are handed out from 0 in the order the hosts were listed.
    RANK=0
    pids=
    # $host_list is intentionally unquoted: it is split into one word per host.
    for host in $host_list; do
        # NOTE(review): ssh joins these words with spaces and hands the result
        # to the remote shell, so values containing whitespace or shell
        # metacharacters are re-interpreted on the remote side.
        ssh "$host" ". $environment &&" \
            "BACKEND=$BACKEND MASTER_ADDR=$MASTER_ADDR" \
            "MASTER_PORT=$MASTER_PORT WORLD_SIZE=$WORLD_SIZE RANK=$RANK" \
            "python $engine" \
            ">> ${output_file:-}" \
            "${min_num_tensors:-} ${min_bytes:-} ${max_num_tensors:-} ${max_bytes:-}" &
        pids="$pids $!"
        RANK=$((RANK+1))
    done

    # A bare 'wait' always returns 0, which would hide failed remote runs.
    # Waiting for each PID individually exposes every job's exit status.
    failed=0
    for pid in $pids; do
        wait "$pid" || failed=$((failed+1))
    done
    [ "$failed" -eq 0 ] ||
        errxit "$failed of $WORLD_SIZE benchmark processes failed"
}

# MPI backend: source the environment file locally, export BACKEND and start
# the benchmark through mpirun, appending its output to the output file.
#
# Globals read: environment, BACKEND, hosts, WORLD_SIZE, output_file, engine
#               and the optional min_/max_ limit strings.
run_mpi_backend() {
    . "$environment"
    export BACKEND
    # The redirect target is quoted so that paths containing spaces work.
    # The limit variables are intentionally unquoted: each one is either
    # unset or holds "--flag value", which must split into two words.
    mpirun -hosts "$hosts" -n "$WORLD_SIZE" >> "${output_file:-}" \
        python "$engine" \
        ${min_num_tensors:-} ${min_bytes:-} ${max_num_tensors:-} ${max_bytes:-}
}

# Launch the benchmark with the backend selected in BACKEND ('tcp' or 'mpi').
#
# Recomputes MASTER_ADDR as "<master_hostname>:<MASTER_PORT>" first, so that
# it reflects the final values of both settings.
# Exits through errxit for any other backend name.
run_benchmark() {
    MASTER_ADDR="$master_hostname:$MASTER_PORT"
    case "$BACKEND" in
        tcp)
            run_tcp_backend
            ;;
        mpi)
            run_mpi_backend
            ;;
        *)
            errxit "Invalid backend: '$BACKEND'"
            ;;
    esac
}

run_benchmark

