Preserve training progress using Autocheckpoint
Historically, when a TPU VM requires maintenance, the procedure is initiated immediately, without leaving time for users to perform progress-preserving actions such as saving a checkpoint. This is shown in Figure 1(a).
Fig. 1. Illustration of the Autocheckpoint feature: (a) Without Autocheckpoint, the training progress since the last checkpoint is lost when there is an upcoming maintenance event. (b) With Autocheckpoint, the training progress since the last checkpoint can be preserved when there is an upcoming maintenance event.
You can use Autocheckpoint (Figure 1(b)) to preserve training progress by configuring your code to save a non-scheduled checkpoint when a maintenance event occurs, so that progress since the last checkpoint is saved automatically. The feature works on both single slices and Multislice.
The Autocheckpoint feature works with frameworks that can capture SIGTERM signals and subsequently save a checkpoint. The supported frameworks include:
- MaxText
- Pax
- JAX with Orbax
Using Autocheckpoint
The Autocheckpoint feature is disabled by default. When you create a TPU or request a queued resource, you can enable Autocheckpoint by adding the --autocheckpoint-enabled flag when provisioning the TPU.
With the feature enabled, Cloud TPU performs the following steps once it receives notification of a maintenance event:
- Send a SIGTERM signal to the process using the TPU device
- Wait until the process exits, or 5 minutes have elapsed, whichever comes first
- Perform maintenance on the impacted slices
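To make the sequence concrete, the following Python sketch plays the host's role during a maintenance event under simplifying assumptions: the training job is a local child process, and `notify_and_wait` is a hypothetical helper, not part of any Cloud TPU API. It sends SIGTERM, waits until the process exits or the grace period (5 minutes by default) elapses, whichever comes first, and reports which case occurred before maintenance would begin.

```python
import signal
import subprocess
import sys

GRACE_PERIOD_SECONDS = 5 * 60  # Cloud TPU waits at most 5 minutes.


def notify_and_wait(proc: subprocess.Popen,
                    grace_seconds: float = GRACE_PERIOD_SECONDS) -> bool:
    """Hypothetical helper mimicking the host-side steps.

    Sends SIGTERM to the training process, then waits until it exits or
    `grace_seconds` elapse, whichever comes first. Returns True if the
    process exited on its own within the grace period.
    """
    proc.send_signal(signal.SIGTERM)
    try:
        proc.wait(timeout=grace_seconds)
        return True
    except subprocess.TimeoutExpired:
        # Grace period over: maintenance proceeds regardless. In this
        # simulation the process is killed; any unsaved progress is lost.
        proc.kill()
        proc.wait()
        return False


if __name__ == "__main__":
    # Stand-in for a training job: sleeps until it is terminated.
    job = subprocess.Popen([sys.executable, "-c", "import time; time.sleep(600)"])
    exited_cleanly = notify_and_wait(job, grace_seconds=10)
    print("process exited before timeout:", exited_cleanly)
    # Maintenance on the impacted slice would start here.
```

A job that ignores SIGTERM takes the timeout branch, which is exactly the situation Autocheckpoint is designed to avoid.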
The infrastructure used by Autocheckpoint is ML framework-independent. Any ML framework can support Autocheckpoint if it can capture the SIGTERM signal and initiate a checkpointing process.
In the application code, you need to enable the Autocheckpoint capabilities provided by the ML framework. In Pax, for example, this means enabling command-line flags when launching the training. For more information, see the Autocheckpoint quickstart with Pax. Behind the scenes, the framework saves a non-scheduled checkpoint when it receives the SIGTERM signal, and the impacted TPU VM undergoes maintenance once the TPU is no longer in use.
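The framework side of this contract, capturing SIGTERM and starting a checkpoint, can be sketched without any ML framework. The Python example below is illustrative only: `save_checkpoint` writes JSON as a stand-in for a real checkpointer, the training step is a placeholder, and the `exit_after_ondemand_checkpoint` parameter merely mirrors the Pax flag of the same name. The signal handler only sets a flag; the loop saves the non-scheduled checkpoint at the next step boundary, so the state is never captured mid-update.

```python
import json
import os
import signal


def save_checkpoint(ckpt_dir, step, state):
    """Stand-in checkpointer: writes the state for `step` as JSON."""
    path = os.path.join(ckpt_dir, f"checkpoint_{step}.json")
    with open(path, "w") as f:
        json.dump({"step": step, "state": state}, f)
    return path


def train(num_steps, ckpt_dir, save_interval_steps=100,
          exit_after_ondemand_checkpoint=True, on_step=None):
    """Toy training loop that saves a non-scheduled checkpoint on SIGTERM.

    Must run in the main thread, because Python only allows signal
    handlers to be installed there. Returns the checkpoint paths written.
    """
    sigterm_received = False

    def handle_sigterm(signum, frame):
        # Keep the handler minimal: record the request and let the loop
        # checkpoint at a safe point between steps.
        nonlocal sigterm_received
        sigterm_received = True

    previous_handler = signal.signal(signal.SIGTERM, handle_sigterm)
    state = {"loss": 1.0}
    saved = []
    ondemand_saved = False
    try:
        for step in range(1, num_steps + 1):
            state["loss"] *= 0.99  # placeholder for a real training step
            if on_step is not None:
                on_step(step)  # hook used below to simulate a maintenance notice
            scheduled = step % save_interval_steps == 0
            ondemand = sigterm_received and not ondemand_saved
            if scheduled or ondemand:
                saved.append(save_checkpoint(ckpt_dir, step, state))
                ondemand_saved = ondemand_saved or sigterm_received
            if ondemand_saved and exit_after_ondemand_checkpoint:
                break  # exiting lets maintenance start without waiting for the timeout
    finally:
        if previous_handler is not None:
            signal.signal(signal.SIGTERM, previous_handler)
    return saved


if __name__ == "__main__":
    import tempfile

    def simulate_notice(step):
        # Pretend the maintenance notice arrives during step 42.
        if step == 42:
            signal.raise_signal(signal.SIGTERM)

    with tempfile.TemporaryDirectory() as ckpt_dir:
        paths = train(1000, ckpt_dir, on_step=simulate_notice)
        print([os.path.basename(p) for p in paths])
```

Without the exit option, the loop keeps training after the on-demand save and resumes its regular schedule, which in a real job means Cloud TPU waits out the full grace period.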
Quickstart: Autocheckpoint with MaxText
MaxText is a high-performance, arbitrarily scalable, open-source, well-tested LLM written in pure Python/JAX targeting Cloud TPUs. MaxText contains all the necessary setup to use the Autocheckpoint feature.
The MaxText README file describes two ways to run MaxText at scale:
- Using multihost_runner.py, recommended for experimentation
- Using multihost_job.py, recommended for production
When using multihost_runner.py, enable Autocheckpoint by setting the --autocheckpoint-enabled flag when provisioning the queued resource.
When using multihost_job.py, enable Autocheckpoint by specifying the ENABLE_AUTOCHECKPOINT=true command-line flag when launching the job.
Quickstart: Autocheckpoint with Pax on single slices
This section provides an example of how to set up and use Autocheckpoint with Pax on a single slice. With the appropriate setup:
- A checkpoint will be saved when a maintenance event occurs.
- Cloud TPU will perform maintenance on the affected TPU VM(s) after the checkpoint is saved.
- When Cloud TPU completes maintenance, you can use the TPU VM as usual.
Use the --autocheckpoint-enabled flag when creating the TPU VM or requesting a queued resource. For example:
```bash
PROJECT=your-gcp-project-name
ZONE=zone-you-want-to-use
NODE_ID=your-node-id
ACCELERATOR_TYPE=your-accelerator-type

gcloud config set project $PROJECT
gcloud config set compute/zone $ZONE

gcloud alpha compute tpus tpu-vm create $NODE_ID \
  --accelerator-type $ACCELERATOR_TYPE \
  --version tpu-ubuntu2204-base \
  --autocheckpoint-enabled
```
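If you provision TPUs from a Python script, the same command can be assembled as an argument list and run with `subprocess`. The sketch below reuses only the flags shown above; `tpu_vm_create_command` and `DRY_RUN` are hypothetical names, and the placeholder values must be replaced with your own. It assumes the project and zone were already set with `gcloud config set`.

```python
import shlex
import subprocess
from typing import List

DRY_RUN = True  # set to False to actually run gcloud


def tpu_vm_create_command(node_id: str, accelerator_type: str,
                          version: str = "tpu-ubuntu2204-base",
                          autocheckpoint: bool = True) -> List[str]:
    """Build the `gcloud alpha compute tpus tpu-vm create` call as an argv list."""
    cmd = ["gcloud", "alpha", "compute", "tpus", "tpu-vm", "create", node_id,
           "--accelerator-type", accelerator_type,
           "--version", version]
    if autocheckpoint:
        # Autocheckpoint is disabled by default, so the flag must be explicit.
        cmd.append("--autocheckpoint-enabled")
    return cmd


if __name__ == "__main__":
    cmd = tpu_vm_create_command("your-node-id", "your-accelerator-type")
    if DRY_RUN:
        print(shlex.join(cmd))
    else:
        subprocess.run(cmd, check=True)
```

Passing a list rather than a single string avoids shell quoting issues when values come from variables.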
Install Pax on a single slice
The Autocheckpoint feature works on Pax versions 1.1.0 and later. On the TPU VMs, install jax[tpu] and the latest paxml:
```bash
pip install paxml && pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```
Launch the training with the appropriate configuration.
The following example shows how to configure the LmCloudSpmd2B model to save checkpoints triggered by Autocheckpoint to a Cloud Storage bucket:
```bash
JOB_LOG_DIR=gs://your-storage-bucket

{ python3 .local/lib/python3.10/site-packages/paxml/main.py \
    --jax_fully_async_checkpoint=1 \
    --exit_after_ondemand_checkpoint=1 \
    --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \
    --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt
```
Note these two flags passed to the command:
- jax_fully_async_checkpoint: With this flag on, orbax.checkpoint.AsyncCheckpointer is used. The AsyncCheckpointer class automatically saves a checkpoint when the training script receives a SIGTERM signal.
- exit_after_ondemand_checkpoint: With this flag on, the TPU process exits after the Autocheckpoint is successfully saved, which triggers the maintenance to be performed immediately. If you don't use this flag, training continues after the checkpoint is saved, and Cloud TPU waits for the 5-minute timeout before performing the required maintenance.
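The practical effect of exit_after_ondemand_checkpoint follows from the rule that Cloud TPU waits until the process exits or 5 minutes elapse, whichever comes first. The hypothetical helper below computes when maintenance starts, assuming the job would otherwise keep training for longer than the grace period. It also shows the case the flag cannot fix: a checkpoint that takes longer than 5 minutes does not finish before maintenance begins.

```python
GRACE_PERIOD_SECONDS = 5 * 60


def maintenance_start_seconds(checkpoint_seconds: float,
                              exit_after_ondemand_checkpoint: bool,
                              grace_seconds: float = GRACE_PERIOD_SECONDS) -> float:
    """Seconds after SIGTERM at which maintenance begins (simplified model).

    Assumes the process exits as soon as the checkpoint is written when the
    flag is on, and otherwise keeps training past the grace period.
    """
    if exit_after_ondemand_checkpoint:
        return min(checkpoint_seconds, grace_seconds)
    return grace_seconds


def checkpoint_completes(checkpoint_seconds: float,
                         grace_seconds: float = GRACE_PERIOD_SECONDS) -> bool:
    """The on-demand checkpoint only finishes if it fits in the grace period."""
    return checkpoint_seconds <= grace_seconds


print(maintenance_start_seconds(40, True))    # 40
print(maintenance_start_seconds(40, False))   # 300
print(maintenance_start_seconds(400, True))   # 300
print(checkpoint_completes(400))              # False
```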
Quickstart: Autocheckpoint with Pax on Multislice
Autocheckpoint works not only for single slices, but also for Multislice. This section details the steps needed to use Autocheckpoint with Multislice.
Specify Autocheckpoint during queued resource creation.
A Multislice environment can only be provisioned through a queued resource request. Similar to the single-slice case, use the --autocheckpoint-enabled flag in the call to create a queued resource:
```bash
QR_ID=your-qr-id
NODE_COUNT=your-node-count
ACCELERATOR_TYPE=your-accelerator-type

gcloud compute tpus queued-resources create $QR_ID \
  --node-count $NODE_COUNT \
  --accelerator-type $ACCELERATOR_TYPE \
  --runtime-version tpu-ubuntu2204-base \
  --autocheckpoint-enabled
```
For more information about all available options, see the Multislice user guide. When the queued resource request is created and in the ACTIVE state, follow the next steps to run Pax with Autocheckpoint.
Install Pax on all VMs in the Multislice environment.
Install jax[tpu] and the latest paxml on all of the TPU VMs in your Multislice environment:
```bash
pip install paxml && pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```
Launch the training with the appropriate configuration.
This example shows how to configure the LmCloudSpmd2B model for Autocheckpoint when training in a Multislice environment. Before running the training script, set DCN_MESH_SHAPE to [2, 1, 1] as shown in the following example:
```python
# Excerpt from paxml/tasks/lm/params/lm_cloud.py, where experiment_registry,
# layers, and LmCloudSpmd are already available.
@experiment_registry.register
class LmCloudSpmd2B(LmCloudSpmd):
  """SPMD model with 2B params.

  Global batch size = 2 * 2 * 1 * 32 = 128
  """
  PERCORE_BATCH_SIZE = 8

  NUM_LAYERS = 18
  MODEL_DIMS = 3072
  HIDDEN_DIMS = MODEL_DIMS * 4

  CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING
  # Device mesh within each slice (over ICI).
  ICI_MESH_SHAPE = [1, 4, 1]
  # Mesh across slices (over DCN): 2 slices along the first axis.
  DCN_MESH_SHAPE = [2, 1, 1]
```
When launching the training, in addition to the command-line flags discussed in the single-slice case, three more are required:
- num_hosts: the total number of hosts. In this case, it is 2.
- host_index: the index of the host launching the training. It ranges from 0 to N-1, where N is the total number of hosts.
- server_addr: the IP address of worker 0 of node 0, followed by an unused port (for example, 8476). To find the IP address, run hostname -i on worker 0 of node 0.
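Because the three values differ between hosts only in host_index, they can be derived from one ordered list of host IP addresses. The sketch below is hypothetical: `multislice_launch_settings` is not a Pax API, the IP addresses are placeholders, and it assumes hosts are numbered node by node and then worker by worker, so that worker 0 of node 0 gets index 0. It uses the example port 8476; confirm the port is unused on worker 0 before relying on it.

```python
from typing import Dict, List


def multislice_launch_settings(host_ips_by_node: List[List[str]],
                               port: int = 8476) -> List[Dict[str, object]]:
    """Return num_hosts, host_index, and server_addr for every host.

    `host_ips_by_node[n][w]` is the IP address of worker `w` of node `n`.
    The coordinator (server_addr) is worker 0 of node 0.
    """
    # Flatten node by node, so indices run 0..N-1 in (node, worker) order.
    hosts = [ip for node in host_ips_by_node for ip in node]
    server_addr = f"{host_ips_by_node[0][0]}:{port}"
    return [
        {"ip": ip, "num_hosts": len(hosts), "host_index": i,
         "server_addr": server_addr}
        for i, ip in enumerate(hosts)
    ]


for settings in multislice_launch_settings([["10.0.0.2"], ["10.0.1.2"]]):
    print(settings)
```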
Autocheckpoint with Orbax
The Autocheckpoint feature is not limited to MaxText or Pax. Any framework that can capture the SIGTERM signal and initiate a checkpointing process works with the infrastructure provided by Autocheckpoint. Orbax, a namespace of common utility libraries for JAX users, provides these capabilities.
As explained in the Orbax documentation, these capabilities are enabled by default for users of orbax.checkpoint.CheckpointManager. The save method that is called after every step automatically checks whether a maintenance event is impending, and if so, saves a checkpoint even if the step number is not a multiple of save_interval_steps.
The GitHub documentation also illustrates how to make the training exit after saving an Autocheckpoint, with a modification in the user code.
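The save decision described above, together with the exit-after-save change to user code, can be modeled in a few lines. The class below is a toy stand-in, not Orbax code: it borrows the method names `save` and `reached_preemption` from orbax.checkpoint.CheckpointManager for readability (treat the real signatures as an assumption and check the Orbax documentation), and it learns about the impending maintenance through a simple flag instead of the real notification mechanism.

```python
from typing import List


class ToyCheckpointManager:
    """Toy model of the save policy described above (not the Orbax implementation)."""

    def __init__(self, save_interval_steps: int):
        self.save_interval_steps = save_interval_steps
        self.maintenance_pending = False  # set when a maintenance notice arrives
        self._preemption_step = None
        self.saved_steps: List[int] = []

    def reached_preemption(self, step: int) -> bool:
        """True once a checkpoint has been saved because of a maintenance notice."""
        return self._preemption_step is not None and step >= self._preemption_step

    def save(self, step: int) -> bool:
        """Called after every step; saves on schedule or when maintenance is impending."""
        scheduled = step % self.save_interval_steps == 0
        impending = self.maintenance_pending and self._preemption_step is None
        if not (scheduled or impending):
            return False
        self.saved_steps.append(step)  # a real manager would write the state here
        if impending:
            self._preemption_step = step
        return True


def run(num_steps: int, notice_at_step: int, exit_after_save: bool = True) -> List[int]:
    mngr = ToyCheckpointManager(save_interval_steps=100)
    for step in range(1, num_steps + 1):
        if step == notice_at_step:
            mngr.maintenance_pending = True  # simulated maintenance notice
        # ... one training step would run here ...
        mngr.save(step)
        if exit_after_save and mngr.reached_preemption(step):
            # A real job would wait for pending saves to finish, then exit.
            break
    return mngr.saved_steps


print(run(1000, notice_at_step=42))
print(run(1000, notice_at_step=42, exit_after_save=False))
```

In the first call the loop stops right after the off-schedule save at step 42; in the second it continues and resumes saving every 100 steps.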