This page shows you how to use Multi-Tier Checkpointing to reliably store and manage checkpoints during machine learning model training on GKE. Checkpoint storage and management is crucial for large-scale training jobs, defined as those using thousands of nodes. Interruptions to these large-scale jobs are frequent (potentially hourly), and recovery from them can be slow.
Benefits
Using Multi-Tier Checkpointing provides the following benefits:
- Fully orchestrated checkpoint data management, including backups, replication,
and automatic restoration for the following workloads:
- JAX training runs using Orbax for state management running on TPUs.
- PyTorch workloads on GPUs.
- Quick recovery of training jobs from a checkpoint stored in the local node. You can also recover using checkpoints stored in another node in the training cluster.
- Quick restoration of training jobs from a checkpoint stored on a Cloud Storage backup in worst-case scenarios, where there are no in-cluster checkpoints.
Before you begin
Before you start, make sure you have performed the following tasks:
- Enable the Google Kubernetes Engine API.
- If you want to use the Google Cloud CLI for this task, install and then initialize the gcloud CLI. If you previously installed the gcloud CLI, get the latest version by running gcloud components update.
- Create a Cloud Storage bucket, if you don't have one for your project. Make sure that you enable hierarchical namespace on the bucket; otherwise, checkpoint backups will fail.
Requirements
Multi-Tier Checkpointing requires GKE cluster version 1.32.4-gke.1415000 or later.
Limitations
- Autopilot clusters aren't supported.
Configure GKE nodes for using Multi-Tier Checkpointing
This section covers how to configure GKE nodes on new and existing clusters.
Configure nodes on a new cluster
Create a cluster with Multi-Tier Checkpointing, the Cloud Storage FUSE CSI driver, and Workload Identity Federation for GKE enabled. If you use TPU slices for your machine learning workload, you'll need to adjust the cluster creation command to include the configuration for a TPU slice node pool.
gcloud container clusters create CLUSTER_NAME \
    --addons=HighScaleCheckpointing,GcsFuseCsiDriver \
    --node-locations=NODE_LOCATION \
    --workload-pool=PROJECT_ID.svc.id.goog \
    --cluster-version=CLUSTER_VERSION \
    --location=CLUSTER_LOCATION \
    --machine-type=MACHINE_TYPE \
    --num-nodes=NUM_NODES
Replace the following values:
- CLUSTER_NAME: the name of the cluster.
- NODE_LOCATION: the zone for your cluster nodes. This is where your TPU capacity lives.
- PROJECT_ID: your Google Cloud project ID.
- CLUSTER_VERSION: the version of your cluster. 1.32.4-gke.1415000 is the minimum supported version.
- CLUSTER_LOCATION: the region that you want to create your cluster in.
- MACHINE_TYPE: the machine type used for nodes that run components like the JobSet controller and the Multi-Tier Checkpointing controller. For large-scale training, we recommend using at least e2-standard-4 machines. You won't use this machine type for model training; instead, you'll create separate node pools for that purpose, often using accelerator-optimized VM families.
- NUM_NODES: the number of nodes to be created in each of the cluster's zones.
Configure nodes on an existing cluster
To use Multi-Tier Checkpointing with an existing cluster, enable it, together with the Cloud Storage FUSE CSI driver and Workload Identity Federation for GKE, by running the following commands. Your existing cluster version needs to be later than 1.32.3-gke.1170000.
Enable Workload Identity Federation for GKE:
gcloud container clusters update CLUSTER_NAME \
    --workload-pool=PROJECT_ID.svc.id.goog \
    --location=CLUSTER_LOCATION
Replace the following values:
- CLUSTER_NAME: the name of the cluster.
- PROJECT_ID: your Google Cloud project ID.
- CLUSTER_LOCATION: the region of the cluster.
Enable Multi-Tier Checkpointing and the Cloud Storage FUSE CSI driver:
gcloud container clusters update CLUSTER_NAME \
    --update-addons=HighScaleCheckpointing=ENABLED,GcsFuseCsiDriver=ENABLED \
    --location=CLUSTER_LOCATION
Configure permissions to use Multi-Tier Checkpointing
This section covers how to configure permissions to use Multi-Tier Checkpointing.
Grant access to Cloud Storage buckets
The ephemeral volumes used by the Multi-Tier Checkpointing CSI driver must use existing Cloud Storage buckets.
To store checkpoints in a Cloud Storage bucket, Multi-Tier Checkpointing needs
access to the bucket. Grant the Storage Object User (roles/storage.objectUser
) IAM role on the bucket to the Kubernetes service account for Multi-Tier Checkpointing.
gcloud storage buckets add-iam-policy-binding gs://GCS_BUCKET \
--member "principal://iam.googleapis.com/projects/PROJECT_NUMBER/locations/global/workloadIdentityPools/PROJECT_ID.svc.id.goog/subject/ns/gke-managed-checkpointing/sa/gke-checkpointing-multitier-node" \
--role "roles/storage.objectUser"
Replace the following values:
- GCS_BUCKET: the name of the Cloud Storage bucket where checkpoint data is stored.
- PROJECT_ID: your Google Cloud project ID.
- PROJECT_NUMBER: an automatically generated unique identifier for your project. To find this value, refer to Creating and managing projects.
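The Workload Identity Federation principal in the command above follows a fixed pattern; the namespace (gke-managed-checkpointing) and service account name (gke-checkpointing-multitier-node) are fixed by GKE. As an illustrative sketch (the helper function is an assumption, not part of any Google library), the member string can be assembled like this:

```python
def mtc_iam_member(project_number: str, project_id: str) -> str:
    """Build the Workload Identity principal for the Multi-Tier
    Checkpointing node service account. The namespace and service
    account name are fixed by GKE; only the project values vary."""
    return (
        "principal://iam.googleapis.com/projects/"
        f"{project_number}/locations/global/workloadIdentityPools/"
        f"{project_id}.svc.id.goog/subject/ns/gke-managed-checkpointing"
        "/sa/gke-checkpointing-multitier-node"
    )
```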
(Optional) Grant Compute Engine default service account access
If your Compute Engine instances need read access to the Cloud Storage
bucket, grant the Storage Object Viewer (roles/storage.objectViewer
) IAM
role to the Compute Engine default service account.
gcloud storage buckets add-iam-policy-binding gs://GCS_BUCKET \
--member serviceAccount:PROJECT_NUMBER-compute@developer.gserviceaccount.com \
--role roles/storage.objectViewer
Deploy the JobSet controller
The JobSet controller is responsible for managing the batch jobs that run your model training on GKE, and its resource allocation is adjusted to handle the workload efficiently. Make sure that your training job launcher deploys and uses JobSet.
To increase the memory request to 1 Gi, the memory limit to 2 Gi, and the CPU request to 1 for the manager container in your JobSet deployment, run the following patch command:
kubectl patch -n jobset-system deploy jobset-controller-manager --type json \
--patch '[{"op": "add", "path": "/spec/template/spec/containers/0/resources", "value": {"limits": {"memory": "2Gi"}, "requests": {"cpu": "1", "memory": "1Gi"}}}]'
Initialize the Multi-Tier Checkpointing CSI driver
This section describes how to initialize the Multi-Tier Checkpointing CSI driver on nodes where your workloads will run. The CSI driver is responsible for handling the storage and management of checkpoints during your model training process.
Create a CheckpointConfiguration
A CheckpointConfiguration is a Kubernetes custom resource that specifies properties for deploying the Multi-Tier Checkpointing CSI driver. This resource is cluster-scoped.
Create the following checkpoint.yaml manifest:

apiVersion: checkpointing.gke.io/v1
kind: CheckpointConfiguration
metadata:
  name: MTC_CONFIG_NAME-configuration
spec:
  cloudStorageBucketName: GCS_BUCKET
  nodeSelector:
    node.kubernetes.io/instance-type: MACHINE_TYPE
  tolerations:
  - key: TOLERATION_KEY
    operator: Exists
    effect: NoSchedule
  inMemoryVolumeSize: IN_MEMORY_VOLUME_SIZE
  gcsFuseMountOptions:
  - implicit-dirs
  - metadata-cache:negative-ttl-secs:0
  - metadata-cache:ttl-secs:-1
  - metadata-cache:stat-cache-max-size-mb:-1
  - metadata-cache:type-cache-max-size-mb:-1
  - file-cache:max-size-mb:-1
  - file-cache:cache-file-for-range-read:true
  - file-system:kernel-list-cache-ttl-secs:0
  - file-cache:enable-parallel-downloads:true
  - read_ahead_kb=1024
  - write:enable-streaming-writes:true
Replace the following:
- MTC_CONFIG_NAME: the name of your CheckpointConfiguration. This name is global for the cluster and is not job-specific.
- GCS_BUCKET: the name of the Cloud Storage bucket where you'll store checkpoint data. Use the bucket that you set up in the Set up a Cloud Storage bucket with permissions step.
- MACHINE_TYPE: the machine type for the corresponding accelerators. The value can be one of the following:
  - TPU v5p: ct5p-hightpu-4t
  - TPU v5e: ct5e-hightpu-4t
  - TPU v6e: ct6e-standard-4t
  - NVIDIA H100 80GB GPUs (A3 series)
  For more information on running distributed workloads on GPUs with GKE, see Running multi-instance GPUs. For TPUs, see Create the TPU slice node pool.
- TOLERATION_KEY: this field allows the CSI driver to be scheduled on nodes with matching taints. For more information about how taints work on different accelerator types, see the GKE documentation for your accelerator type.
- IN_MEMORY_VOLUME_SIZE: the size of the in-memory checkpointing cache. Specify the quantity and unit (for example, 200Gi). This value should be:
  - the local checkpoint size for TPUs, multiplied by 2.2.
  - the local checkpoint size for GPUs with a single peer, multiplied by 6.6.
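The sizing rules above can be sketched as a small helper (illustrative only; this function and its names are not part of any GKE tooling, and the multipliers come directly from the guidance above):

```python
import math

def in_memory_volume_size_gib(local_ckpt_size_gib: float, accelerator: str) -> int:
    """Compute the in-memory checkpoint cache size in GiB:
    2.2x the local checkpoint size on TPUs, 6.6x on GPUs with a
    single replication peer."""
    factors = {"tpu": 2.2, "gpu": 6.6}
    # Round to avoid float artifacts, then round up to a whole GiB.
    return math.ceil(round(local_ckpt_size_gib * factors[accelerator], 6))
```

For example, a 100 GiB local TPU checkpoint would call for roughly a 220Gi inMemoryVolumeSize.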
Apply the manifest:
kubectl apply -f checkpoint.yaml
Check that the CSI driver is running:
kubectl get pod -n gke-managed-checkpointing
The output is similar to the following, with one entry per accelerator node:
NAME                                                          READY   STATUS    RESTARTS   AGE
multitier-driver-e2b033a7-a4e7-496a-87a3-ffd7fcc2e57b-2d4fz   5/5     Running   0          114s
Uninstall the Multi-Tier Checkpointing CSI driver
If you want to undeploy the Multi-Tier Checkpointing CSI driver, delete the CheckpointConfiguration
resource. The Multi-Tier Checkpointing controller removes the CSI driver from the nodes. This removes the RAM disks and frees up memory for other workloads. For example:
kubectl delete -f checkpoint.yaml
Manage data retention and garbage collection for Cloud Storage backups
You're responsible for implementing retention policies for the Cloud Storage backups of checkpoints. Multi-Tier Checkpointing only writes checkpoint backups to Cloud Storage and never modifies or deletes them.
Many open source tools can handle retention and garbage collection. The following example uses backup-warden, where the backup directory is mounted to a backup location that uses Cloud Storage FUSE:
# Add the --delete option to actually delete backups; by default, backup-warden only shows what would be deleted (dry run)
backup-warden -p backup \
--hourly 24 \
--daily 7 \
--weekly 5 \
--monthly always \
--yearly always \
--prefer-recent
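If you prefer to implement a minimal retention policy yourself instead of using a tool like backup-warden, the core idea (keep only the newest backup per calendar day, for the most recent days) can be sketched as follows. The naming scheme, with an ISO timestamp in the object name, is an assumption for illustration:

```python
def select_deletions(backup_names, daily_keep=7):
    """Given backup object names that contain an ISO timestamp
    (for example "ckpt/2024-05-01T03:00:00"), keep the newest backup
    for each of the most recent `daily_keep` days and return the
    remaining names as deletion candidates."""
    by_day = {}
    for name in sorted(backup_names):  # lexicographic sort -> newest last
        day = name.split("T")[0].rsplit("/", 1)[-1]
        by_day[day] = name  # the newest backup per day wins
    keep_days = sorted(by_day)[-daily_keep:]
    keep = {by_day[d] for d in keep_days}
    return sorted(n for n in backup_names if n not in keep)
```

The selected names could then be deleted with a storage client or gcloud; backups that Multi-Tier Checkpointing writes are never deleted by the system itself.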
Update the workload JobSet manifest
Update the JobSet manifest for your job to include the large-scale checkpoint volume. The details depend on your workload.
For example, to extend the sample JobSet from Deploy TPU Multislices in GKE, do the following steps:
Add the following lines to the jax-tpu container:

volumeMounts:
- name: checkpoint
  mountPath: CHECKPOINT_DIR
Replace CHECKPOINT_DIR with the path to your checkpoint directory. This is the location where the replicator.yaml file is generated and where Multi-Tier Checkpointing performs the checkpoint save operation. For more information, see Integrate Multi-Tier Checkpointing in your application.

Add the following lines to the spec.template.spec field of the Job specification:

volumes:
- name: checkpoint
  csi:
    driver: multitier-checkpoint.csi.storage.gke.io
Integrate Multi-Tier Checkpointing in your application
To share information about checkpoint locations and replication readiness, modify your application to use the following protocol to communicate with Multi-Tier Checkpointing.
Startup
This section describes the initial steps the application needs to interact with Multi-Tier Checkpointing.
The Replicator is a core component of Multi-Tier Checkpointing, running on every node as part of the CSI driver. The Replicator manages checkpoint replication across storage tiers, from the local RAM disk to peer nodes and to external storage like Cloud Storage.
The replicator.yaml
file acts as a dynamic control plane between your ML
training job (framework code) and the Replicator component. Your ML application
programmatically generates this file on the local volume (RAMDisk), which is
accessible by both the training job and the Replicator service. This manifest
allows the ML framework to provide runtime configuration and lifecycle management
instructions to the Replicator, distinct from static infrastructure parameters (
for example, Cloud Storage upload frequency) defined during backend setup.
For a concrete example of this interaction, see:
- The MaxText project, which utilizes this architecture for JAX on Cloud TPU.
- The PyTorch reference example, which utilizes this architecture with PyTorch and NVIDIA NeMO on GPUs.
Your application should do the following steps during startup:
1. Wait until the replicator.yaml file is absent, which indicates that the Replicator is ready to be configured by your application. The replicator.yaml file is generated in the CHECKPOINT_DIR location that you configured in the Update the workload JobSet manifest section.

   When the model training job is first created, the replicator.yaml file won't exist, and your application can proceed immediately. However, if the job was restarted (for example, due to a failure or manual intervention), the system might still be processing the previous job instance, and the replicator.yaml from that instance might still be present on the local volume.

2. Your application or ML job creates the replicator.yaml file with a configuration similar to the following.

   Orbax:

       job-name: orbax
       framework: orbax
       assume-data-parallelism: 3
       node-rank: 0
       nodes: 32
       peer-ranks: [1, 16] # or peers-per-node: 2
       backup-interval-minutes: 30

   PyTorch:

       job-name: nemo
       framework: pytorch.distributed
       node-rank: 0
       nodes: 32
       peer-ranks: [1, 16] # or peers-per-node: 2
       backup-interval-minutes: 30

   This example configuration has the following fields:
   - job-name: the name of the training job.
   - framework: the ML framework being used by the training job.
   - node-rank: the unique identifier of the current node within the distributed training job. This represents the node rank of the node creating this file. Each node participating in the run has its own rank.
   - nodes: the total number of nodes participating in the distributed training job. This value comes from the Pod's metadata. The ML training job can also view this value.
   - peer-ranks or peers-per-node: two alternative ways to specify the replication topology. Only one of these two parameters should be present.
     - peer-ranks: the explicit ranks of peer nodes to which the current node's checkpoint data should be replicated. This gives fine-grained control over which specific nodes serve as replication partners.
     - peers-per-node: the number of peer nodes per node that the Replicator should automatically select for replication.
   - backup-interval-minutes: the frequency, in minutes, at which checkpoints are backed up to Cloud Storage. We recommend that you set this value to 30 minutes or more.

3. Wait until the new replicator.yaml file is deleted by the system. This signals that the Replicator has restarted and performed cleanup. This step lets you avoid stale or temporary files on the local volume when your application performs the steps in the next section.
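A minimal sketch of this startup handshake in Python (illustrative only: the function names, polling interval, and the hard-coded Orbax field values are assumptions; a real job would derive node-rank and nodes from its own environment):

```python
import os
import time

def wait_for_absence(path, timeout_s=600, poll_s=5):
    """Block until `path` does not exist: used both before writing the
    config (Replicator ready) and after (Replicator consumed it)."""
    deadline = time.monotonic() + timeout_s
    while os.path.exists(path):
        if time.monotonic() > deadline:
            raise TimeoutError(f"{path} still present after {timeout_s}s")
        time.sleep(poll_s)

def write_replicator_config(checkpoint_dir, node_rank, nodes):
    """Write replicator.yaml into the mounted local checkpoint volume."""
    cfg = "\n".join([
        "job-name: orbax",
        "framework: orbax",
        f"node-rank: {node_rank}",
        f"nodes: {nodes}",
        "peers-per-node: 2",
        "backup-interval-minutes: 30",
    ]) + "\n"
    path = os.path.join(checkpoint_dir, "replicator.yaml")
    with open(path, "w") as f:
        f.write(cfg)
    return path
```

A job would call wait_for_absence on CHECKPOINT_DIR/replicator.yaml, then write_replicator_config, then wait_for_absence again before training proceeds.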
Restore from the last known good (LKG) checkpoint
After the Replicator is initialized, Multi-Tier Checkpointing creates one symbolic link per TPU or GPU worker. These symbolic links are created in the same mounted local volume as the replicator.yaml file, where the job saves checkpoints.

The symbolic links have the form <job-name>-s{step}-n<node-rank>-w<worker-index>.restore.

Restore each worker from the corresponding .restore file. For an example, see the Orbax replicated checkpoint manager example in the next section.
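As a hypothetical helper (the regex and function are illustrative assumptions, not part of any framework), the naming scheme above can be parsed to locate the newest restore link for a given worker:

```python
import re

# Matches <job-name>-s{step}-n<node-rank>-w<worker-index>.restore
RESTORE_RE = re.compile(
    r"^(?P<job>.+)-s(?P<step>\d+)-n(?P<node>\d+)-w(?P<worker>\d+)\.restore$"
)

def latest_restore(filenames, node_rank, worker_index):
    """Return (step, filename) of the newest .restore link for the given
    node rank and worker index, or None if no link matches."""
    best = None
    for name in filenames:
        m = RESTORE_RE.match(name)
        if not m:
            continue
        if int(m.group("node")) == node_rank and int(m.group("worker")) == worker_index:
            step = int(m.group("step"))
            if best is None or step > best[0]:
                best = (step, name)
    return best
```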
Save the checkpoint
Your application performs these steps multiple times while the training job progresses. The save operations happen in the CHECKPOINT_DIR location you configured in the Update the workload JobSet manifest.
Orbax
Create Orbax checkpoints. The directory is named with the step number. The Replicator detects the newly created checkpoint directory, performs replication or backup as needed, and cleans up automatically.
For more information about how to use the Orbax replicated checkpoint manager, see the MaxText checkpointing file.
For an example of replicator service interaction, see the MaxText max_utils
file.
PyTorch
Use InClusterLocalCheckpointIO
as a custom pytorch_lightning.CheckpointIO
to enable correct distributed checkpointing with local storage. The following
example command enables multi-tier checkpointing using a reference
implementation built on the NVIDIA NeMo framework:
torchrun train.py <other_train_flags> \
--local-ckpt-dir=CHECKPOINT_DIR \
--local-ckpt-interval=20 \
--job-name=JOB_NAME \
--enable-high-scale-ckpt
Replace the following:
- CHECKPOINT_DIR: the path to your checkpoint directory.
- JOB_NAME: the name of your training job workload.
Troubleshoot
This section provides troubleshooting guidance to resolve issues with Multi-Tier Checkpointing. For general storage troubleshooting, see Troubleshooting Cloud Storage in GKE.
Multi-Tier Checkpointing not enabled
The following error indicates that Multi-Tier Checkpointing is not enabled on your cluster:
error: unable to recognize "checkpoint.yaml": no matches for kind "CheckpointConfiguration" in version "checkpointing.gke.io/v1"
You might encounter this error after running kubectl apply -f checkpoint.yaml
in the Create a CheckpointConfiguration step.
To resolve this issue, check if you have enabled Multi-Tier Checkpointing on your cluster with the following command:
gcloud container clusters describe CLUSTER_NAME \
    --project PROJECT_ID \
    --location CLUSTER_LOCATION
If Multi-Tier Checkpointing is enabled, the output is similar to the following:
addonsConfig:
gcePersistentDiskCsiDriverConfig:
enabled: true
gcsFuseCsiDriverConfig:
enabled: true
highScaleCheckpointingConfig:
enabled: true
kubernetesDashboard:
disabled: true
networkPolicyConfig:
disabled: true
If Multi-Tier Checkpointing is not enabled, update your cluster to enable Multi-Tier Checkpointing.
Multi-Tier Checkpointing CSI driver unable to mount volumes
You might encounter this issue if the CSI driver is unable to mount the Cloud Storage volume. In this case, the Pod listing contains multiple lines similar to the following:
kubectl get pod -n gke-managed-checkpointing
NAME READY STATUS RESTARTS AGE
multitier-driver-14694e4d-774f-4104-8bba-f0bd82fd7557-5vxr9 0/5 Init:0/1 0 6m32s
To resolve this issue, check the CSI driver Pod events, as shown in the following example:
kubectl describe pod multitier-driver-14694e4d-774f-4104-8bba-f0bd82fd7557-5vxr9 -n gke-managed-checkpointing
Events:
Type Reason Age From Message
---- ------ ---- ---- -------
Normal Scheduled 17m default-scheduler Successfully assigned gke-managed-checkpointing/multitier-driver-14694e4d-774f-4104-8bba-f0bd82fd7557-5vxr9 to gke-my-cluster-default-pool-353c773f-6d8q
Warning FailedMount 82s (x16 over 17m) kubelet MountVolume.SetUp failed for volume "gcs" : rpc error: code = PermissionDenied desc = failed to get GCS bucket "checkpointing-test-bucket": googleapi: Error 403: Caller does not have storage.objects.list access to the Google Cloud Storage bucket. Permission 'storage.objects.list' denied on resource (or it may not exist)., forbidden
If the issue occurs because of a Cloud Storage bucket PermissionDenied error, as shown in the example, you can resolve the problem by correctly setting up permissions.
What's next
- Learn more about deploying TPU Multislice on Google Kubernetes Engine.
- Learn how to optimize the Cloud Storage FUSE CSI Driver for Google Kubernetes Engine for performance.
- Explore Orbax checkpointing options.