#!/bin/bash
#
# Install jax 0.4.1 together with the matching prebuilt jaxlib 0.4.1
# "cuda11.cudnn82" wheel (manylinux2014_x86_64) for the interpreter that
# `python` resolves to on PATH.
#
# Supported interpreters: Python 3.8, 3.9, 3.10 and 3.11.
#
# Exit status:
#   0        both packages were installed
#   1        `python` is missing or its version is unsupported
#   non-zero any pip failure (the script stops at the first failing step)

# Stop at the first failing command so a failed jaxlib install does not get
# followed by a jax install against a missing/mismatched jaxlib.
set -euo pipefail

readonly JAXLIB_RELEASE_URL="https://storage.googleapis.com/jax-releases/cuda11"

if ! command -v python >/dev/null 2>&1; then
    echo "python not found on PATH" >&2
    exit 1
fi

# Printed as a tuple, e.g. "(3, 10)".
PYTHON_VERSION=$(python -c 'import sys; print(sys.version_info[:2])')

# Map the interpreter version to the CPython ABI tag used in the wheel name.
case "$PYTHON_VERSION" in
    "(3, 8)")  PY_TAG="cp38" ;;
    "(3, 9)")  PY_TAG="cp39" ;;
    "(3, 10)") PY_TAG="cp310" ;;
    "(3, 11)") PY_TAG="cp311" ;;
    *)
        echo "Unsupported Python version: $PYTHON_VERSION" >&2
        exit 1
        ;;
esac

JAXLIB_WHEEL="jaxlib-0.4.1+cuda11.cudnn82-${PY_TAG}-${PY_TAG}-manylinux2014_x86_64.whl"

# Use `python -m pip` so the wheel is installed into the same interpreter
# whose version selected it; a bare `pip` on PATH may belong to another Python.
python -m pip install "${JAXLIB_RELEASE_URL}/${JAXLIB_WHEEL}"
python -m pip install jax==0.4.1
