Preserve training progress using Autocheckpoint

Historically, when a TPU VM requires maintenance, the procedure is initiated immediately, without leaving time for users to perform progress-preserving actions such as saving a checkpoint. This is shown in Figure 1(a).


Fig. 1. Illustration of the Autocheckpoint feature: (a) Without Autocheckpoint, the training progress since the last checkpoint is lost when there is an upcoming maintenance event. (b) With Autocheckpoint, the training progress since the last checkpoint can be preserved when there is an upcoming maintenance event.

You can use Autocheckpoint (Figure 1(b)) to preserve training progress by configuring your code to save a non-scheduled checkpoint when a maintenance event occurs. With Autocheckpoint enabled, the progress made since the last checkpoint is saved automatically before maintenance begins. The feature works on both single slices and Multislice.

The Autocheckpoint feature works with frameworks that can capture SIGTERM signals and subsequently save a checkpoint. The supported frameworks include:

  • MaxText
  • Pax
  • JAX with Orbax

Using Autocheckpoint

The Autocheckpoint feature is disabled by default. To enable it, add the --autocheckpoint-enabled flag when you create a TPU or request a queued resource. With the feature enabled, Cloud TPU performs the following steps once it receives notification of a maintenance event:

  1. Send a SIGTERM signal to the process using the TPU device
  2. Wait until the process exits, or 5 minutes have elapsed, whichever comes first
  3. Perform maintenance on the impacted slices

The infrastructure used by Autocheckpoint is ML framework-independent. Any ML framework can support Autocheckpoint if it can capture the SIGTERM signal and initiate a checkpointing process.
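On the framework side, capturing SIGTERM and checkpointing usually means installing a handler that only records the signal, and letting the training loop save and exit at the next safe point. The following is a minimal, framework-agnostic sketch of that pattern in plain Python; SigtermWatcher, save_checkpoint, train, and the JSON checkpoint format are hypothetical placeholders, not the API of Pax, MaxText, or Orbax.

```python
import json
import os
import signal
import tempfile


class SigtermWatcher:
    """Records a SIGTERM instead of letting it terminate the process."""

    def __init__(self):
        self.received = False
        self._previous = signal.signal(signal.SIGTERM, self._on_sigterm)

    def _on_sigterm(self, signum, frame):
        # Only set a flag here; the training loop decides when it is safe to save.
        self.received = True

    def close(self):
        # Restore the handler that was installed before, if Python knows it.
        if self._previous is not None:
            signal.signal(signal.SIGTERM, self._previous)


def save_checkpoint(ckpt_dir, step, state):
    """Hypothetical checkpoint writer; a real job would call its framework's API."""
    path = os.path.join(ckpt_dir, f"step_{step}.json")
    with open(path, "w") as f:
        json.dump({"step": step, "state": state}, f)
    return path


def train(ckpt_dir, num_steps, on_step_end=None):
    """Toy training loop; returns the on-demand checkpoint path, or None."""
    watcher = SigtermWatcher()
    state = {"loss": 1.0}
    try:
        for step in range(1, num_steps + 1):
            state["loss"] *= 0.9  # Stand-in for one real training step.
            if on_step_end is not None:
                on_step_end(step)
            if watcher.received:
                # Save progress since the last checkpoint, then stop so the
                # process exits well within the 5-minute window.
                return save_checkpoint(ckpt_dir, step, state)
        return None
    finally:
        watcher.close()


if __name__ == "__main__":
    def notice_at_step_7(step):
        if step == 7:
            # Simulate the maintenance notice by signaling this process.
            signal.raise_signal(signal.SIGTERM)

    with tempfile.TemporaryDirectory() as ckpt_dir:
        path = train(ckpt_dir, num_steps=100, on_step_end=notice_at_step_7)
        print("on-demand checkpoint:", os.path.basename(path))
```

Setting a flag in the handler, rather than saving inside it, avoids starting a long checkpoint write in the middle of an arbitrary training step.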

In the application code, you need to enable the Autocheckpoint capabilities provided by the ML framework. In Pax, for example, this means enabling command-line flags when launching the training. For more information, see the Autocheckpoint quickstart with Pax. Behind the scenes, the frameworks save a non-scheduled checkpoint when a SIGTERM signal is received, and the impacted TPU VM goes through maintenance when the TPU is no longer in use.

Quickstart: Autocheckpoint with MaxText

MaxText is a high-performance, arbitrarily scalable, open-source, well-tested LLM written in pure Python/JAX targeting Cloud TPUs. MaxText contains all the necessary setup to use the Autocheckpoint feature.

The MaxText README file describes two ways to run MaxText at scale:

  • Using multihost_runner.py: enable Autocheckpoint by setting the --autocheckpoint-enabled flag when provisioning the queued resource.
  • Using multihost_job.py: enable Autocheckpoint by specifying the ENABLE_AUTOCHECKPOINT=true command-line flag when launching the job.

Quickstart: Autocheckpoint with Pax on single slices

This section provides an example of how to set up and use Autocheckpoint with Pax on a single slice. With the appropriate setup:

  • A checkpoint is saved when a maintenance event occurs.
  • Cloud TPU performs maintenance on the affected TPU VMs after the checkpoint is saved.
  • When Cloud TPU completes maintenance, you can use the TPU VM as usual.

  1. Use the --autocheckpoint-enabled flag when creating the TPU VM or requesting a queued resource.

    For example:

    PROJECT=your-gcp-project-name
    ZONE=zone-you-want-to-use
    NODE_ID=your-node-id
    ACCELERATOR_TYPE=your-accelerator-type
    gcloud config set project $PROJECT
    gcloud config set compute/zone $ZONE
    gcloud alpha compute tpus tpu-vm create $NODE_ID \
        --accelerator-type $ACCELERATOR_TYPE \
        --version tpu-ubuntu2204-base \
        --autocheckpoint-enabled
  2. Install Pax on a single slice

    The Autocheckpoint feature works on Pax versions 1.1.0 and later. On the TPU VMs, install jax[tpu] and the latest paxml:

    pip install paxml && pip install 'jax[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  3. Launch the training with the appropriate configuration.

    The following example shows how to configure the LmCloudSpmd2B model to save checkpoints triggered by Autocheckpoint to a Cloud Storage bucket:

    JOB_LOG_DIR=gs://your-storage-bucket
    
    { python3 .local/lib/python3.10/site-packages/paxml/main.py \
        --jax_fully_async_checkpoint=1 \
        --exit_after_ondemand_checkpoint=1 \
        --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \
        --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt

    Note the two flags that are passed to the command:

    • jax_fully_async_checkpoint: With this flag on, orbax.checkpoint.AsyncCheckpointer is used. The AsyncCheckpointer class automatically saves a checkpoint when the training script receives a SIGTERM signal.
    • exit_after_ondemand_checkpoint: With this flag on, the TPU process exits after the Autocheckpoint is successfully saved, which triggers the maintenance to be performed immediately. If you don't use this flag, training continues after the checkpoint is saved, and Cloud TPU waits for the 5-minute timeout to expire before performing the required maintenance.
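The timing difference that exit_after_ondemand_checkpoint makes follows from the rule that Cloud TPU waits until the process exits or 5 minutes have elapsed. The function below is a simplified model of that rule, not Pax or Cloud TPU code; it assumes the save starts as soon as the SIGTERM arrives and, without the flag, that training keeps running past the grace period.

```python
GRACE_PERIOD_S = 5 * 60  # Cloud TPU waits at most 5 minutes for the process to exit.


def seconds_until_maintenance(checkpoint_save_s, exit_after_ondemand_checkpoint):
    """Simplified model of the delay between SIGTERM and the start of maintenance.

    Assumptions (not taken from Pax or Cloud TPU code): the save starts as soon
    as SIGTERM arrives, and without --exit_after_ondemand_checkpoint training
    keeps running until the grace period ends.
    """
    if checkpoint_save_s < 0:
        raise ValueError("checkpoint_save_s must be non-negative")
    if exit_after_ondemand_checkpoint:
        # The process exits right after the save, but never later than the timeout.
        return min(checkpoint_save_s, GRACE_PERIOD_S)
    # Training continues, so Cloud TPU waits for the full timeout.
    return GRACE_PERIOD_S


if __name__ == "__main__":
    for flag in (True, False):
        print(flag, seconds_until_maintenance(40, flag))
```

For example, with a 40-second save (an arbitrary illustrative value), maintenance would start after about 40 seconds with the flag and after 300 seconds without it. Under this model, a save that takes longer than the grace period would not finish before maintenance starts, whichever way the flag is set.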

Quickstart: Autocheckpoint with Pax on Multislice

Autocheckpoint works not only for single slices, but also for Multislice. This section details the steps needed to use Autocheckpoint with Multislice.

  1. Specify Autocheckpoint during queued resource creation.

    A Multislice environment can only be provisioned through a queued resource request. As in the single-slice case, use the --autocheckpoint-enabled flag in the call to create a queued resource.

    QR_ID=your-qr-id
    NODE_COUNT=your-node-count
    ACCELERATOR_TYPE=your-accelerator-type
    
    gcloud compute tpus queued-resources create $QR_ID \
        --node-count $NODE_COUNT \
        --accelerator-type $ACCELERATOR_TYPE \
        --runtime-version tpu-ubuntu2204-base \
        --autocheckpoint-enabled

    For more information about all available options, see the Multislice user guide. When the queued resource request is created and in the ACTIVE state, follow the next steps to run Pax with Autocheckpoint.

  2. Install Pax on all VMs in the Multislice environment.

    Install jax[tpu] and the latest paxml on all of the TPU VMs in your Multislice environment:

    pip install paxml && pip install 'jax[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  3. Launch the training with the appropriate configuration.

    This example shows how to configure the LmCloudSpmd2B model for Autocheckpoint when training in a Multislice environment. Before running the training script, set DCN_MESH_SHAPE to [2, 1, 1] in the LmCloudSpmd2B experiment definition (tasks.lm.params.lm_cloud), as shown in the following example:

    @experiment_registry.register
    class LmCloudSpmd2B(LmCloudSpmd):
    
        """SPMD model with 2B params.
    
        Global batch size = 2 * 2 * 1 * 32 = 128
        """
        PERCORE_BATCH_SIZE = 8
    
        NUM_LAYERS = 18
        MODEL_DIMS = 3072
        HIDDEN_DIMS = MODEL_DIMS * 4
    
        CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING
        ICI_MESH_SHAPE = [1, 4, 1]
        DCN_MESH_SHAPE = [2, 1, 1]

    When launching the training, in addition to the command line flags discussed in the single-slice case, three more are required:

    • num_hosts: the total number of hosts. In this case, it is 2.
    • host_index: the index of the host launching the training. It ranges from 0 to N-1, where N is the total number of hosts.
    • server_addr: the IP address of worker 0 of node 0, with an unused port (for example, 8476). To find this information, use hostname -i on worker 0 of node 0.
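Because every host runs the same command with a different host_index, the launch is easy to script. The sketch below checks the mesh against the queued resource and builds one command per host. It assumes that each flag is passed as --name=value like the single-slice flags, that server_addr takes an IP:port value, and that the product of DCN_MESH_SHAPE must equal the number of slices (NODE_COUNT). devices_in_mesh and pax_command are hypothetical helpers, and the IP address 10.0.0.2 is a placeholder for the output of hostname -i on worker 0 of node 0.

```python
import math
import shlex

# Path used in the single-slice example; adjust to your Python installation.
PAX_MAIN = ".local/lib/python3.10/site-packages/paxml/main.py"


def devices_in_mesh(ici_mesh_shape, dcn_mesh_shape, node_count):
    """Checks that DCN_MESH_SHAPE spans one entry per slice; returns the device count."""
    if len(ici_mesh_shape) != len(dcn_mesh_shape):
        raise ValueError("ICI_MESH_SHAPE and DCN_MESH_SHAPE must have the same length")
    slices = math.prod(dcn_mesh_shape)
    if slices != node_count:
        raise ValueError(f"DCN_MESH_SHAPE spans {slices} slices, but NODE_COUNT is {node_count}")
    return math.prod(ici_mesh_shape) * slices


def pax_command(host_index, num_hosts, server_ip, job_log_dir, port=8476):
    """Builds the Pax launch command for one host as a shell-quoted string."""
    if not 0 <= host_index < num_hosts:
        raise ValueError("host_index must be between 0 and num_hosts - 1")
    args = [
        "python3", PAX_MAIN,
        "--jax_fully_async_checkpoint=1",
        "--exit_after_ondemand_checkpoint=1",
        "--exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B",
        f"--job_log_dir={job_log_dir}",
        # The three additional Multislice flags.
        f"--num_hosts={num_hosts}",
        f"--host_index={host_index}",
        f"--server_addr={server_ip}:{port}",
    ]
    return shlex.join(args)


if __name__ == "__main__":
    print(devices_in_mesh([1, 4, 1], [2, 1, 1], node_count=2))
    for i in range(2):
        print(pax_command(i, 2, "10.0.0.2", "gs://your-storage-bucket"))
```

Run the command printed for host_index=i on host i. With ICI_MESH_SHAPE = [1, 4, 1] and DCN_MESH_SHAPE = [2, 1, 1], devices_in_mesh reports 8 devices across the 2 slices.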

Autocheckpoint with Orbax

The Autocheckpoint feature is not limited to MaxText or Pax. Any framework that can capture the SIGTERM signal and initiate a checkpointing process works with the infrastructure provided by Autocheckpoint. Orbax, a namespace that provides common utility libraries for JAX users, provides these capabilities.

As explained in the Orbax documentation, these capabilities are enabled by default for users of orbax.checkpoint.CheckpointManager. The save method, which is called after every step, automatically checks whether a maintenance event is impending and, if so, saves a checkpoint even if the step number is not a multiple of save_interval_steps. The GitHub documentation also shows how to modify your training code so that it exits after saving an Autocheckpoint.
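The save behavior described above can be summarized as a simple decision rule. The following plain-Python model illustrates that rule only: should_save and simulate_run are hypothetical names, and the code does not call Orbax or reproduce how CheckpointManager detects a pending maintenance event.

```python
def should_save(step, save_interval_steps, maintenance_pending):
    """Model of the save decision: on schedule, or off schedule if maintenance is pending."""
    return step % save_interval_steps == 0 or maintenance_pending


def simulate_run(num_steps, save_interval_steps, maintenance_step=None,
                 exit_after_save=True):
    """Returns the steps at which a checkpoint would be written."""
    saved = []
    for step in range(num_steps):
        # In this model, the maintenance notice is observed at exactly one step.
        pending = step == maintenance_step
        if should_save(step, save_interval_steps, pending):
            saved.append(step)
            if pending and exit_after_save:
                # Exit right away so maintenance does not wait for the timeout.
                break
    return saved


if __name__ == "__main__":
    print(simulate_run(100, 10, maintenance_step=23))
    print(simulate_run(35, 10, maintenance_step=23, exit_after_save=False))
```

With save_interval_steps=10 and a notice at step 23, the model writes checkpoints at steps 0, 10, 20, and 23 and then stops; without the early exit, scheduled saves resume at step 30.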