Because JAX with Pathways is distributed, some operations might not scale well because of communication overheads. While Pathways minimizes these overheads with features like asynchronous dispatch, there are some things you need to be aware of when you port JAX workloads to Pathways or scale a JAX with Pathways workload to a large number of accelerators.
Before you begin
Make sure you have:
- Installed Kubernetes tools
- Installed the gcloud CLI
- Enabled the TPU API
- Enabled the Google Kubernetes Engine API
- Ensured your Google Cloud project is allowlisted for Pathways
Process index
JAX with Pathways treats all devices across your Pathways cluster as local. This simplifies device management and allows JAX to utilize all available resources. In practice, this means:
- `jax.process_index()` is always 0 for all devices.
- `jax.devices()` and `jax.local_devices()` return all TPU devices across the entire job.
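As a quick sanity check, here is a minimal sketch, assuming your job is already configured to use the Pathways proxy backend:

```python
import jax
import pathwaysutils

pathwaysutils.initialize()

print(jax.process_index())       # always 0 under Pathways
print(len(jax.devices()))        # all TPU devices across the entire job
print(len(jax.local_devices()))  # the same devices: all of them are "local"
```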
Hardware type and colocation
For best performance, place all Pathways components and the user job in the
same Google Cloud zone. Use a large CPU VM for the Pathways head components,
such as the IFRT proxy server and the resource manager. We recommend at least
a dedicated n2-standard-64, which comes with 64 vCPUs and 256 GB of memory.
PathwaysUtils
Pathways-utils is a Python package, developed on GitHub, that provides essential utilities and tools to streamline the deployment and execution of JAX workloads on the Pathways on Cloud architecture. This package handles the necessary adaptations for the cloud environment, allowing JAX developers to focus on their core machine learning workflows with minimal platform-specific configuration. Specifically, it offers:
- A "proxy" JAX backend: this custom backend enables your JAX application to
use the Pathways infrastructure by setting the
JAX_PLATFORMS=proxy
environment variable. - Integrated Profiling Utilities: profiling capabilities that let you understand
your application's performance. By using standard JAX profiling APIs like
jax.profiler.start_trace
andjax.profiler.start_server
, you can profile not only your JAX code but also the underlying Pathways components, providing a holistic view of execution within the cloud environment. - Distributed Checkpointing with Orbax: a custom Orbax checkpoint handler that
lets you use distributed checkpoints and restore your checkpoints when using
the Orbax library within the Pathways environment. This integration aims to
work without requiring any changes to your existing Orbax checkpointing code
as long as it imports
pathwaysutils
. - Elastic Training Primitives: provides foundational elastic training primitives that you can use to build robust and scalable training workflows using Pathways. These primitives allow your training jobs to dynamically adapt to changes in available resources, improving efficiency and resilience in cloud environments.
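A minimal sketch of how a training script might opt into the proxy backend; the environment variable can also be set in your container spec instead of in code:

```python
import os

# Select the Pathways "proxy" JAX backend before JAX is imported.
os.environ["JAX_PLATFORMS"] = "proxy"

import jax
import pathwaysutils

# Apply the Pathways-specific adaptations provided by pathwaysutils.
pathwaysutils.initialize()

print(jax.devices())  # TPU devices served through the Pathways proxy
```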
Checkpointing
Orbax is thoroughly tested with Pathways for distributed checkpointing and
restoring with Cloud Storage. When you call
`import pathwaysutils; pathwaysutils.initialize()` in your `train.py`, a custom
`ArrayHandler` is registered that efficiently handles checkpoint operations
through the IFRT proxy, allowing Pathways workers on accelerators to directly
save and restore data.
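A minimal sketch of what this looks like in practice, assuming a `gs://BUCKET/checkpoints` path you can write to; the exact Orbax API surface varies slightly between versions:

```python
import jax.numpy as jnp
import orbax.checkpoint as ocp
import pathwaysutils

# initialize() registers the Pathways-aware ArrayHandler with Orbax.
pathwaysutils.initialize()

state = {"step": 0, "params": jnp.ones((1024, 1024))}

# Hypothetical Cloud Storage path; replace with your own bucket.
mngr = ocp.CheckpointManager("gs://BUCKET/checkpoints")
mngr.save(0, args=ocp.args.StandardSave(state))
mngr.wait_until_finished()  # saves run asynchronously

restored = mngr.restore(0, args=ocp.args.StandardRestore(state))
```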
Colocated Python
Colocated Python is an open-source JAX API that lets you run user-specified Python code directly on the TPU or GPU hosts, something that is more straightforward in multi-controller JAX. This lets compute-intensive tasks, such as data loading and checkpointing, avoid data transfers between the client and the TPU machines. To configure your Pathways cluster to run the colocated Python JAX API, follow the instructions in the colocated Python README. These instructions explain how to start a colocated Python sidecar alongside your Pathways workers.
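As an illustration, here is a sketch using the experimental `jax.experimental.colocated_python` module; the API is experimental and may change, and names here reflect recent JAX releases:

```python
import jax
import numpy as np
from jax.experimental import colocated_python

# CPU devices that live on the same hosts as the accelerators.
cpu_devices = colocated_python.colocated_cpu_devices(jax.devices())

@colocated_python.colocated_python
def load_and_preprocess(x):
    # This body runs on the TPU hosts, next to the accelerators,
    # not on the JAX client VM.
    return x * 2

# The function executes wherever its inputs are placed; the first call
# may block while output specs are inferred.
x = jax.device_put(np.arange(8), cpu_devices[0])
y = load_and_preprocess(x)
```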
Data loading
During training, we repeatedly load batches from a dataset to feed them into the model. An efficient, asynchronous data loader that shards each batch across hosts is important to avoid starving the accelerators of work. When running training with Pathways, the data loader runs on a CPU VM (unlike in multi-controller setups, where it runs on the TPU VMs) and dispatches data to the TPU VMs. This incurs a higher latency in reading data, but that is partially mitigated by reading ahead a number of batches on the CPU host and dispatching the read data asynchronously to the TPUs. This solution is sufficient when running at small to medium scale.
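A minimal sketch of the read-ahead pattern described above, working with any Python batch iterator; `PrefetchIterator` is illustrative, not part of a Pathways API:

```python
import queue
import threading

class PrefetchIterator:
    """Reads ahead up to `prefetch` batches on a background thread so the
    accelerators are not starved while the next batch is being fetched."""

    def __init__(self, batches, prefetch=4):
        self._queue = queue.Queue(maxsize=prefetch)
        self._done = object()  # sentinel marking the end of the stream
        threading.Thread(
            target=self._fill, args=(iter(batches),), daemon=True
        ).start()

    def _fill(self, it):
        for batch in it:
            self._queue.put(batch)  # blocks once the read-ahead buffer is full
        self._queue.put(self._done)

    def __iter__(self):
        return self

    def __next__(self):
        batch = self._queue.get()
        if batch is self._done:
            raise StopIteration
        return batch
```

Wrapping an existing iterator, for example `PrefetchIterator(my_dataset_iterator, prefetch=4)`, keeps up to four batches in flight while the model consumes the current one.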
For optimal performance at scale, we strongly recommend co-locating your input data pipeline by using colocated Python to run your data pipeline directly on the accelerators. This eliminates the CPU bottleneck and leverages the TPU's fast interconnects for data transfer.
You can find a reference implementation of migrating a TFDS-based input
pipeline in the `RemoteIterator` implementation in multihost_dataloading.py.
This implementation works on both multi-controller JAX and Pathways in a
distributed manner using the colocated Python JAX API.
JAX Versioning
Pathways releases are tightly coupled with JAX versions to ensure compatibility
and stability. To avoid potential issues, verify that your Pathways artifacts
and your JAX version are aligned. Each Pathways release clearly specifies the
compatible JAX versions through a tag of the form `jax-<version>`.
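For example, a quick way to print the JAX version inside your container and compare it against the release tag:

```python
import jax

# Compare this version against your Pathways release's jax-<version> tag.
print(jax.__version__)
```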
Compilation Cache
The Pathways persistent compilation cache is a feature that allows Pathways
servers to store compiled XLA executables in a persistent location, such as
Cloud Storage, to avoid redundant compilation. This feature is enabled by
default. The location of the cache is passed in as the `--gcs_scratch_location`
flag to the resource manager and Pathways worker containers. To keep the
associated storage costs to a minimum, the cache attaches a lifecycle policy to
the Cloud Storage location. There is a limit of 50 lifecycle policies per
Cloud Storage bucket, so we recommend using a common Cloud Storage location
across all workloads.
This cache is similar to the JAX compilation cache, which
`pathwaysutils.initialize()` disables for Pathways workloads.
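For comparison, this is how the client-side JAX compilation cache is typically enabled outside of Pathways; under Pathways, `pathwaysutils.initialize()` turns it off because the server-side persistent cache takes over (the bucket path is a placeholder):

```python
import jax

# Standard client-side JAX compilation cache configuration; not needed
# (and disabled by pathwaysutils.initialize()) for Pathways workloads.
jax.config.update("jax_compilation_cache_dir", "gs://BUCKET/jax-cache")
```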
Profiling
You can use the JAX profiler to generate traces of a JAX program. There are two common ways supported with Pathways:
- Programmatic: capture profiles programmatically from your JAX code.
- Manual: capture profiles on demand after starting the profiler server from your JAX code.
In both cases, the profiles are written to a Cloud Storage bucket. Multiple trace files are created in the bucket, potentially under different timestamp folders, for example:
- Main Python process that invoked the trace (typically your notebook VM): `<jax-client-vm-name>.xplane.pb`
- Pathways resource manager: `server.-1.0000-<pathways-head-pod-node-name>.xplane.pb`
- Pathways worker(s): `server.-1.0000-<tpu-node-name>.xplane.pb`
You can analyze these trace files with TensorBoard by running the following commands. For more information about TensorBoard and all of its profiling tools, see Optimize TensorFlow performance using the Profiler.
```
# Verify trace files are present
gsutil ls -l -r gs://BUCKET/PREFIX

# View in TensorBoard
tensorboard --logdir=gs://BUCKET/PREFIX
```
Replace the following:
- BUCKET: a Cloud Storage bucket to store the trace files
- PREFIX: a path within your Cloud Storage bucket to store the trace files
Programmatic profile capture
Capture a profile from inside your code. The profiles are saved inside
`gs://BUCKET/PREFIX` under a timestamp directory:
```python
import jax
import pathwaysutils

pathwaysutils.initialize()

jax.profiler.start_trace("gs://BUCKET/PREFIX")

# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()

jax.profiler.stop_trace()
```
Manual profile capture
To manually capture profile information, you must start the profiler server from your Python code:
```python
import jax
import pathwaysutils

pathwaysutils.initialize()

# JAX_PROFILER_PORT is a placeholder; choose any free port.
jax.profiler.start_server(JAX_PROFILER_PORT)

# Your JAX code

jax.profiler.stop_server()  # this is functionally a no-op
```
While the profiler server is running, you can capture a profile and export the data to the target Cloud Storage location:
```
export DURATION_IN_SECS=6
export OUTPUT_DIR=gs://BUCKET/PREFIX

curl -d "{\"duration_ms\": $((DURATION_IN_SECS * 1000)), \"repository_path\": \"${OUTPUT_DIR}\"}" \
  -H "Content-Type: application/json" \
  -X POST http://localhost:JAX_PROFILER_PORT
```
You can find timing information for IFRT proxy client methods like `Compile`
and `Execute` within your program's trace. These events, which detail the
interactions with the IFRT proxy gRPC server during compilation and execution,
appear on the thread named `GrpcClientSessionUserFuturesWorkQueue`. By
examining this thread in your trace, you can gain insights into the performance
of these operations.
XLA Flags
When you use Pathways, you need to set XLA flags on the pathways-proxy container. You can do so using XPK or the PathwaysJob API.
When using XPK, set XLA flags like the following:
```
--custom-pathways-proxy-server-args="--xla_flag_1=value1 --xla_flag_2=value2"
```
When using PathwaysJob API, set XLA flags like the following:
```yaml
apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-USER
spec:
  customComponents:
  - componentType: proxy_server
    customFlags:
    - --xla_flag_1=value1
    - --xla_flag_2=value2
```
Replace the following:
- USER: your Google Cloud username
- value[n]: the XLA flags you want to set
HLO Dump
To deep dive into the High Level Optimizer (HLO) input that is given to the XLA compiler, you can configure Pathways to dump the HLO to a specified Cloud Storage location as follows:
```yaml
apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-USER
spec:
  customComponents:
  - componentType: proxy_server
    customEnv:
    - name: XLA_FLAGS
      value: "--xla_dump_to=gs://your-gcs-bucket/your-desired-prefix/"
```