XLA is a compiler used internally by JAX. JAX is distributed via PyPI wheels. The JAX Continuous Integration documentation explains how to build JAX wheels using the tensorflow/ml-build:latest Docker container.
We can extend these instructions to build XLA targets inside the same container. This keeps the XLA targets’ build configuration consistent with the JAX/XLA build configuration, which is useful if we want to use XLA tools to reproduce results for workloads originally created in JAX.
- Clone the JAX repository and navigate to the `jax` directory:

  ```sh
  git clone https://github.com/jax-ml/jax.git
  cd jax
  ```
- Start the JAX CI/Release Docker container by running:

  ```sh
  ./ci/utilities/run_docker_container.sh
  ```

  This starts a Docker container named `jax`.
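If you want to confirm the container came up before continuing, you can list running containers filtered by name (standard Docker CLI; no project-specific assumptions):

```sh
# Show the running container named "jax", if any
docker ps --filter "name=jax"
```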
- Build the `jax-cuda-plugin` target inside the container:

  ```sh
  docker exec jax ./ci/build_artifacts.sh jax-cuda-plugin
  ```

  This creates the `.jax_configure.bazelrc` file with the required build configuration, including CUDA/cuDNN support.
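To inspect the generated configuration before building XLA targets, you can print the file from the host. (The exact contents depend on your container and GPU setup; this only assumes the file lands in `/jax`, the container's working directory.)

```sh
# Print the Bazel configuration generated by the previous step
docker exec jax cat /jax/.jax_configure.bazelrc
```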
- Access an interactive shell inside the container:

  ```sh
  docker exec -ti jax /bin/bash
  ```

  You should now be in the `/jax` directory within the container.
- Build the XLA target with the following command, e.g.:

  ```sh
  /usr/local/bin/bazel build \
      --config=cuda_libraries_from_stubs \
      --verbose_failures=true \
      @xla//xla/tools/multihost_hlo_runner:hlo_runner_main
  ```

  Optionally, you can override the hermetic CUDA environment variables, e.g.:

  ```sh
  --repo_env=HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_90"
  ```
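Once the build finishes, the binary can be run directly from `bazel-bin` inside the container. A minimal sketch, where `my_module.hlo` is a placeholder for an HLO dump you supply:

```sh
# Run the freshly built HLO runner on a GPU against an HLO module
bazel-bin/external/xla/xla/tools/multihost_hlo_runner/hlo_runner_main \
    --device_type=gpu my_module.hlo
```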
- Copy the resulting artifacts to `/jax/dist` to access them from the host OS if needed:

  ```sh
  cp bazel-bin/external/xla/xla/tools/multihost_hlo_runner/hlo_runner_main \
      ./dist/
  ```
- Exit the interactive shell:

  ```sh
  exit
  ```
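Assuming the container is already running (steps 1–2 above), the remaining non-interactive steps can also be driven entirely from the host via `docker exec`, without opening a shell. A sketch combining them:

```sh
#!/usr/bin/env bash
set -euo pipefail

# Generate .jax_configure.bazelrc by building the jax-cuda-plugin target
docker exec jax ./ci/build_artifacts.sh jax-cuda-plugin

# Build the XLA target with the same configuration
docker exec jax /usr/local/bin/bazel build \
    --config=cuda_libraries_from_stubs \
    --verbose_failures=true \
    @xla//xla/tools/multihost_hlo_runner:hlo_runner_main

# Copy the binary to /jax/dist so it is accessible from the host OS
docker exec jax cp \
    bazel-bin/external/xla/xla/tools/multihost_hlo_runner/hlo_runner_main \
    /jax/dist/
```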