# Entry point: a function from an argument set to the `research` development
# shell defined at the bottom of this file.
#
# Arguments:
#   cuda - no default, so the caller must supply it. It is passed to jaxlib's
#          `cudaSupport` override and gates the CUDA packages and environment
#          variables added to the shell, so it is used as a boolean.
#   pkgs - package set to build from; defaults to importing `<nixpkgs>`.
{ cuda, pkgs ? import <nixpkgs> {}}:

# Bring every attribute of `pkgs` into scope for the rest of the expression
# (this is where unqualified names such as `python3`, `lib`, `stdenv`,
# `mkShell` and `ffmpeg` below come from).
with pkgs;

let
  # Python interpreter whose package set is extended/overridden with the
  # JAX-related packages this project needs.
  py = python3.override {
    packageOverrides = self: super: {
      # Build jaxlib with or without CUDA according to the `cuda` argument.
      jaxlib = super.jaxlib.override { cudaSupport = cuda; };

      # Re-package jax from the upstream source/version/dependencies, with
      # both the check phase and the install-check phase disabled.
      jax = super.buildPythonPackage {
        pname = super.jax.pname;
        version = super.jax.version;
        src = super.jax.src;
        propagatedBuildInputs = super.jax.propagatedBuildInputs;
        doCheck = false;
        doInstallCheck = false;
      };

      # neural-tangents 0.3.8 from PyPI, depending on the jax defined above.
      neural-tangents = super.buildPythonPackage rec {
        pname = "neural-tangents";
        version = "0.3.8";
        src = super.fetchPypi {
          inherit pname version;
          sha256 = "sha256-yqrXB3EXyF4e0xY7gIDiQ1a9T9wdCFAs9CvXgYzQnjI=";
        };
        # Remove the top-level `examples` directory that gets installed into
        # site-packages. The site-packages path is taken from the interpreter
        # instead of being hard-coded to python3.9, so this keeps working when
        # `python3` points at a different minor version.
        postInstall = ''
          rm -r $out/${self.python.sitePackages}/examples
        '';
        propagatedBuildInputs = with self; [
          frozendict jax
        ];
        doCheck = false;
      };

      # tree-math 0.1.0 from PyPI, depending on the jax defined above.
      tree-math = super.buildPythonPackage rec {
        pname = "tree-math";
        version = "0.1.0";
        src = super.fetchPypi {
          inherit pname version;
          sha256 = "sha256-d+uNa6TWz90tmGprw/wtGxYhLwFyhjo8pQlyC6v3WSk=";
        };
        propagatedBuildInputs = with self; [
          jax
        ];
        doCheck = false;
      };
    };
  };
  # Python environment built from the customised interpreter `py`, containing
  # the libraries listed below (explicitly qualified by the package set `ps`).
  pyEnv = py.withPackages (ps: [
    ps.neural-tangents
    ps.numpy
    ps.pandas
    ps.matplotlib
    ps.scikitlearn
    ps.jax
    ps.jaxlib
    ps.tqdm
    ps.toolz
    ps.scipy
    ps.tabulate
    ps.sympy
    ps.seaborn
    ps.tree-math
    ps.imagecorruptions
  ]);
  # `pkgs.hy` with its pname/version/src attributes replaced so that it is
  # built from the 1.0a3 tag of the hylang/hy GitHub repository.
  hyLatest =
    let
      hyVersion = "1.0a3";
    in
    pkgs.hy.overrideAttrs (_: {
      pname = "hy";
      version = hyVersion;
      src = pkgs.fetchFromGitHub {
        owner = "hylang";
        repo = "hy";
        rev = hyVersion;
        sha256 = "1dqw24rvsps2nab1pbjjm1c81vrs34r4kkk691h3xdyxnv9hb84b";
      };
    });
in
let
  # Shorthands for the CUDA packages; being let-bound they are only evaluated
  # when `cuda` is true and one of the expressions below is actually forced.
  cudaTk = pkgs.cudatoolkit_11_2;
  cudnnPkg = pkgs.cudnn_cudatoolkit_11_2;

  # Library directories prepended to LD_LIBRARY_PATH when CUDA is enabled.
  # The last element is the literal shell expansion of the existing value.
  cudaLibPath = lib.concatStringsSep ":" [
    "${cudaTk.lib}/lib"
    "${cudaTk.out}/lib"
    "${cudnnPkg}/lib"
    "${pkgs.linuxPackages.nvidia_x11}/lib"
    "$LD_LIBRARY_PATH"
  ];

  # Always run: point LD_LIBRARY_PATH at the C compiler's runtime libraries
  # (this replaces whatever value the variable had before).
  baseHook = ''
    export LD_LIBRARY_PATH=${stdenv.cc.cc.lib}/lib
  '';

  # Run only with CUDA: expose the toolkit location and extend the library
  # search path set by baseHook.
  cudaHook = ''
    export CUDA_PATH=${cudaTk}
    export XLA_FLAGS=--xla_gpu_cuda_data_dir=${cudaTk}
    export LD_LIBRARY_PATH=${cudaLibPath}
  '';
in
mkShell {
  name = "research";

  buildInputs =
    [ hyLatest pyEnv ffmpeg ]
    ++ lib.optionals cuda [ cudaTk cudnnPkg ];

  shellHook = baseHook + lib.optionalString cuda cudaHook;
}
