Pathways provides resiliency benefits in the following ways:
- Suspend-resume: tolerance of planned interruptions, such as preemption notices, without requiring you to write any custom preemption-handling code.
- Elastic training: tolerance of unplanned hardware failures without crashing the client, although you must write model-specific recovery code.
Before you begin
Make sure you have:
- Installed XPK
- Installed Kubernetes tools
- Installed the gcloud CLI
- Enabled the TPU API
- Enabled the GKE API
- Confirmed that your Google Cloud project is allowlisted for Pathways
Suspend-resume
Typically, GKE sends a preemption notice to an accelerator pod before the pod is preempted. Pathways preemption tolerance is enabled by default on all cloud deployments, and Pathways accelerator jobs listen for these notices.
When a preemption notice arrives, Pathways first determines whether the current workload is restorable, that is, whether Pathways can transparently save and restore it. If so, Pathways attempts to transparently suspend your ML workload by writing its current state to persistent storage, such as Cloud Storage, before GKE evicts your accelerator jobs. When GKE later reschedules your jobs, Pathways resumes your ML workload by reading back its persisted state.
If the workload is not restorable, Pathways shuts down the accelerator job and, if elastic training is configured, forwards the failure to your job. If elastic training is not configured, GKE restarts the entire workload based on the JobSet restart policy.
Typical ML workloads defined using JAX rely on stateless Pathways XLA components, which can be restored from a high bandwidth memory (HBM) snapshot. Certain ML workloads, such as those defined using the JAX colocated Python API, rely on stateful Pathways components; these are not restorable.
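The decision Pathways makes when a preemption notice arrives can be summarized as a small function. This is a toy model of the behavior described above, not Pathways code: the names Outcome and on_preemption_notice are hypothetical, and the single uses_stateful_components flag is a simplification of how Pathways decides whether a workload is restorable.

```python
from enum import Enum


class Outcome(Enum):
    SUSPEND_AND_RESUME = "persist state, then resume after rescheduling"
    FORWARD_TO_JOB = "shut down and forward the failure to the job"
    JOBSET_RESTART = "restart the whole workload per the JobSet restart policy"


def on_preemption_notice(
    uses_stateful_components: bool, elastic_configured: bool
) -> Outcome:
    """Toy model of the documented behavior; not Pathways source code."""
    # Hypothetical simplification: stateless XLA workloads are restorable
    # from an HBM snapshot, stateful ones (e.g. colocated Python) are not.
    restorable = not uses_stateful_components
    if restorable:
        return Outcome.SUSPEND_AND_RESUME
    if elastic_configured:
        # The failure is surfaced to the job so its elastic handler can react.
        return Outcome.FORWARD_TO_JOB
    return Outcome.JOBSET_RESTART


# A typical JAX workload vs. a colocated-Python workload with elastic training.
print(on_preemption_notice(uses_stateful_components=False, elastic_configured=False))
print(on_preemption_notice(uses_stateful_components=True, elastic_configured=True))
```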
Elastic training
Elastic training allows your training job to continue even when hardware failures occur. This is achieved through a combination of Pathways system capabilities and user-defined model recovery logic:
- Detection of failure: When a hardware failure happens (for example, a TPU worker crashes), the Pathways system detects this and notifies the user's training job through an exception the next time data that was located on that hardware is accessed. This notification doesn't crash your workload; it allows your code to handle the notification and reconfigure your resources to either continue processing or exit gracefully.
- User-defined elastic handler: Your model code must handle this exception. This is what makes the recovery model-specific.
- Snapshotting: The most common approach is to periodically save snapshots of your model's state. When a failure occurs, you can load from the most recent snapshot to resume training.
- Reconfiguration: You will likely need to reconfigure your training job to adjust for the number of available slices. For example, if one slice stops working, you might reduce the number of active slices by one until a replacement is available. For more information, see Implement an elastic handler.
- Data/Computation graph updates: Your code needs to handle any changes in the number of devices available to your computation by re-creating the computation graph as needed. This might involve re-partitioning data or re-compiling your model.
- Pathways' role in recovery: Pathways provides the primitives to support user-defined reconfiguration:
  - Slice replacement: If a failed slice is replaced, the client can be informed once the new slice is available. Your code can then reconfigure to use this new slice.
  - Transparent recovery: Pathways handles the lower-level details of the recovery, such as re-establishing connections to the healthy portions of the cluster.
  - Utilities in pathwaysutils: A set of Pathways utilities defined in pathways-utils.
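As one illustration of the reconfiguration and data re-partitioning steps, the following sketch splits a fixed global batch across whichever slices are still healthy. The function per_slice_batch_sizes is hypothetical and is not part of Pathways or pathwaysutils; keeping the global batch size constant is only one possible strategy (you could instead keep the per-slice batch size fixed and let the global batch shrink).

```python
def per_slice_batch_sizes(global_batch_size: int, num_slices: int) -> list:
    """Hypothetical helper: split a global batch evenly across healthy slices."""
    if num_slices < 1:
        raise ValueError("at least one slice is required")
    # Require an even split so every slice runs the same per-slice shape.
    if global_batch_size % num_slices:
        raise ValueError(
            f"global batch {global_batch_size} is not divisible by {num_slices} slices"
        )
    return [global_batch_size // num_slices] * num_slices


global_batch_size = 96
print(per_slice_batch_sizes(global_batch_size, 4))  # all four slices healthy
print(per_slice_batch_sizes(global_batch_size, 3))  # one slice lost
```

Because the per-slice batch changes when a slice is lost, the train step generally has to be re-compiled for the new shapes, which is the "re-compiling your model" case described above.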
Implement an elastic handler
Most of the code you write goes in a user-defined elastic handler. This handler reacts to elastic events, such as a TPU slice becoming unavailable, by re-creating the mesh and reinitializing the training loop.
Each workload is unique, and the complexity of the elastic handler tends to scale with the complexity of the workload. The handler's inputs and outputs should be the minimum arguments and return values needed to reinitialize the training loop.
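```python
def elastic_handler(elastic_utils, *args, **kwargs):
    # Re-create the mesh for the slices that are currently available.
    mesh = initialize_mesh(**kwargs["mesh_kwargs"])
    # Rebuild the train state, step function, and other loop variables for the new mesh.
    initial_state, initial_step, jitted_train_step, other_variables = (
        initialize_training_loop(mesh, **kwargs["initialize_training_loop_kwargs"])
    )
    # Fetch a saved snapshot and load it into the freshly initialized state.
    step, snapshot = elastic_utils.get_next_snapshot()
    state = initial_state.replace(**snapshot)
    return state, step, mesh, jitted_train_step, other_variables
```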
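To see the handler's contract in isolation, the following sketch calls the same handler logic with hypothetical stand-ins: FakeElasticUtils replaces the object passed as elastic_utils, TrainState is a plain frozen dataclass, and dataclasses.replace plays the role of the .replace method that, for example, Flax struct dataclasses provide. None of these stand-ins are pathwaysutils APIs.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class TrainState:
    params: dict
    opt_state: dict


class FakeElasticUtils:
    """Hypothetical stand-in for the object passed as `elastic_utils`."""

    def __init__(self, step, snapshot):
        self._step = step
        self._snapshot = snapshot

    def get_next_snapshot(self):
        return self._step, self._snapshot


def initialize_mesh(num_slices):
    # Hypothetical: a "mesh" is just a list of slice IDs in this sketch.
    return list(range(num_slices))


def initialize_training_loop(mesh, learning_rate):
    def jitted_train_step(state):
        return state  # placeholder for a real compiled train step

    initial_state = TrainState(params={"w": 0.0}, opt_state={"lr": learning_rate})
    return initial_state, 0, jitted_train_step, {"num_slices": len(mesh)}


def elastic_handler(elastic_utils, *args, **kwargs):
    mesh = initialize_mesh(**kwargs["mesh_kwargs"])
    initial_state, initial_step, jitted_train_step, other_variables = (
        initialize_training_loop(mesh, **kwargs["initialize_training_loop_kwargs"])
    )
    step, snapshot = elastic_utils.get_next_snapshot()
    # dataclasses.replace stands in for initial_state.replace(**snapshot).
    state = replace(initial_state, **snapshot)
    return state, step, mesh, jitted_train_step, other_variables


utils = FakeElasticUtils(step=40, snapshot={"params": {"w": 1.5}})
state, step, mesh, _, other = elastic_handler(
    utils,
    mesh_kwargs={"num_slices": 3},
    initialize_training_loop_kwargs={"learning_rate": 0.1},
)
print(step, state.params, state.opt_state, mesh)
```

Fields present in the snapshot (params) are restored, while fields that are not (opt_state here) keep the values from the freshly initialized state.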
Update your training loop
You need to make the following changes to your training loop:
- Create an elastic manager
- Wrap your training loop in a try-except block that handles jax.errors.JaxRuntimeError exceptions.
- Within your jax.errors.JaxRuntimeError handler, call maybe_reshard_down. The elastic manager reshards down if the error is related to an elastic event; otherwise, it re-raises the error.
- Call maybe_snapshot and maybe_reshard_up at the end of each training step.
```python
import jax
import pathwaysutils
from pathwaysutils.elastic import manager


def initialize_mesh(**kwargs):
    ...


def initialize_training_loop(mesh, **kwargs):
    ...


def elastic_handler(elastic_utils, *args, **kwargs):
    ...  # As defined in Implement an elastic handler.


def train_loop(
    final_step,
    elastic_manager,
    mesh_kwargs,
    initialize_training_loop_kwargs,
):
    mesh = initialize_mesh(**mesh_kwargs)
    initial_state, initial_step, jitted_train_step, other_variables = (
        initialize_training_loop(mesh, **initialize_training_loop_kwargs)
    )

    # Arguments forwarded to elastic_handler whenever the manager reshards.
    handler_kwargs = dict(
        mesh_kwargs=mesh_kwargs,
        initialize_training_loop_kwargs=initialize_training_loop_kwargs,
    )

    state = initial_state
    step = initial_step
    while step < final_step:
        try:
            state = jitted_train_step(state)

            # Based on the manager configuration, snapshot the state to host memory.
            elastic_manager.maybe_snapshot(step=step, snapshot=state)

            # Based on the manager configuration, check whether lost slices are
            # available again and, if so, scale up by calling elastic_handler.
            handler_returns = elastic_manager.maybe_reshard_up(
                step=step,
                snapshot=state,
                elastic_handler=elastic_handler,
                handler_args=(),
                handler_kwargs=handler_kwargs,
            )
            if handler_returns:
                state, step, mesh, jitted_train_step, other_variables = handler_returns
            step += 1
        except jax.errors.JaxRuntimeError as error:
            # Scale down if the error comes from an elastic event;
            # otherwise maybe_reshard_down re-raises the error.
            handler_returns = elastic_manager.maybe_reshard_down(
                error=error,
                elastic_handler=elastic_handler,
                handler_args=(),
                handler_kwargs=handler_kwargs,
            )
            if handler_returns:
                state, step, mesh, jitted_train_step, other_variables = handler_returns

    return state


def main():
    # The manager is a singleton; create it only once.
    elastic_manager = manager.Manager(
        devices=jax.devices(),
        snapshot_period=10,
        snapshot_buffer_size=1,
        reshard_check_period=5,
        max_elastic_down_event_count=10,
        max_reshard_retry_count=3,
    )
    train_loop(100, elastic_manager, {}, {})
```
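The control flow of the training loop above can be exercised without TPUs by substituting an in-memory mock for the elastic manager. MockManager, SliceLostError, and the simplified elastic_handler below are hypothetical: they reuse the method and keyword names from the loop above but implement assumed behavior (snapshot on multiples of snapshot_period, never scale up, scale down by one slice on a SliceLostError). A plain RuntimeError stands in for jax.errors.JaxRuntimeError.

```python
class SliceLostError(RuntimeError):
    """Hypothetical stand-in for a JaxRuntimeError caused by a lost slice."""


class MockManager:
    """Hypothetical in-memory mock; the real pathwaysutils manager differs."""

    def __init__(self, snapshot_period):
        self.snapshot_period = snapshot_period
        self.snapshot = None  # (step, state) of the most recent snapshot
        self.reshard_down_count = 0

    def maybe_snapshot(self, step, snapshot):
        # Assumption: snapshot whenever step is a multiple of the period.
        if step % self.snapshot_period == 0:
            self.snapshot = (step, snapshot)

    def maybe_reshard_up(self, step, snapshot, elastic_handler, handler_args, handler_kwargs):
        return None  # In this mock, lost slices never come back.

    def maybe_reshard_down(self, error, elastic_handler, handler_args, handler_kwargs):
        if not isinstance(error, SliceLostError):
            raise error  # Not an elastic event, so re-raise it.
        self.reshard_down_count += 1
        return elastic_handler(self, *handler_args, **handler_kwargs)

    def get_next_snapshot(self):
        return self.snapshot


def elastic_handler(elastic_utils, num_slices):
    # Rebuild for one fewer slice and restore the latest snapshot.
    step, state = elastic_utils.get_next_snapshot()
    return state, step, num_slices - 1


def train_loop(final_step, elastic_manager, fail_at_step):
    state, step, num_slices = 0, 0, 4
    failed = False
    while step < final_step:
        try:
            if step == fail_at_step and not failed:
                failed = True
                raise SliceLostError(f"lost a slice at step {step}")
            state += 1  # Stands in for jitted_train_step(state).
            elastic_manager.maybe_snapshot(step=step, snapshot=state)
            handler_returns = elastic_manager.maybe_reshard_up(
                step=step, snapshot=state, elastic_handler=elastic_handler,
                handler_args=(), handler_kwargs={"num_slices": num_slices},
            )
            if handler_returns:
                state, step, num_slices = handler_returns
            step += 1
        except RuntimeError as error:
            handler_returns = elastic_manager.maybe_reshard_down(
                error=error, elastic_handler=elastic_handler,
                handler_args=(), handler_kwargs={"num_slices": num_slices},
            )
            if handler_returns:
                state, step, num_slices = handler_returns
    return step, num_slices


elastic_manager = MockManager(snapshot_period=5)
print(train_loop(final_step=10, elastic_manager=elastic_manager, fail_at_step=7))
print(elastic_manager.reshard_down_count)
```

With a slice lost at step 7 and snapshots every 5 steps, the mock loop restores the step-5 snapshot, continues on three slices, and still reaches step 10. A RuntimeError that is not a SliceLostError is re-raised by maybe_reshard_down and stops the loop.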
Configure the elastic manager
You can configure the elastic manager with several parameters. The snapshot period (snapshot_period) determines how often snapshots are taken, and so affects the average number of steps lost to an elastic event: the longer the period, the more steps may need to be recomputed. The reshard check period (reshard_check_period) determines how often your training loop polls for slice availability. The max_elastic_down_event_count parameter sets the number of elastic events caused by slice loss that your training loop will tolerate. The max_reshard_retry_count parameter specifies the number of times the elastic manager retries resharding. The manager is a singleton object and should be created only once.
Snapshots
Based on the elastic manager configuration, maybe_snapshot may save a snapshot of the data to host memory. Your elastic handler can then use this snapshot during an elastic event.
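The effect of snapshot_period and snapshot_buffer_size can be illustrated with a toy buffer. SnapshotBuffer is hypothetical; it assumes that a snapshot is taken whenever the step is a multiple of the period and that the buffer keeps only the newest snapshot_buffer_size entries. The real manager's scheduling may differ.

```python
from collections import deque


class SnapshotBuffer:
    """Toy model of periodic snapshots kept in host memory (hypothetical)."""

    def __init__(self, snapshot_period, snapshot_buffer_size):
        self.snapshot_period = snapshot_period
        # Only the newest `snapshot_buffer_size` snapshots are retained.
        self.snapshots = deque(maxlen=snapshot_buffer_size)

    def maybe_snapshot(self, step, state):
        # Assumption: snapshot when the step is a multiple of the period.
        if step % self.snapshot_period == 0:
            # Copy the state so later steps don't mutate the saved snapshot,
            # mimicking a device-to-host copy.
            self.snapshots.append((step, dict(state)))

    def latest(self):
        return self.snapshots[-1]


# Same values as in main(): snapshot_period=10, snapshot_buffer_size=1.
buffer = SnapshotBuffer(snapshot_period=10, snapshot_buffer_size=1)
state = {"last_completed_step": None}
for step in range(38):  # steps 0-37 complete; a slice is lost during step 38
    state["last_completed_step"] = step
    buffer.maybe_snapshot(step, state)

snapshot_step, snapshot_state = buffer.latest()
print(snapshot_step, len(buffer.snapshots))
print("completed steps to redo:", state["last_completed_step"] - snapshot_step)
```

In this model, the step-30 snapshot is restored and the seven completed steps 31 through 37 are recomputed. Up to snapshot_period - 1 completed steps can be lost, so a shorter period trades more frequent host-memory copies for less recomputation.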
Reduce sharding
After you catch a jax.errors.JaxRuntimeError and pass it to maybe_reshard_down, Pathways checks whether the error is due to an elastic event caused by a lost slice. If so, it calls the elastic handler in a loop until the handler succeeds or the maximum number of retry attempts is reached. If the error is not due to an elastic event, it is re-raised. The return values of the elastic handler are passed through to the caller.
Increase sharding
Based on the elastic manager configuration, and if any slices are unavailable, Pathways checks whether additional slices have become available. If so, it immediately saves a snapshot (unless one was already taken for the current step) and calls the elastic handler in a loop until the handler succeeds or the maximum number of retry attempts is reached. If resharding occurs, the return values of the elastic handler are passed through to the caller. Otherwise, None is returned.
Hot-swap
Hot-swap refers to a feature of the GKE JobSet API where a higher-priority job can quickly take over resources from a lower-priority job, minimizing downtime and ensuring faster recovery.
When a JobSet is created, GKE schedules the workload across multiple slices, as specified in the JobSet configuration. If a hardware failure occurs on one or more slices, the affected Pods are marked as failed. If you keep a spare slice in your GKE cluster that a lower-priority job can use in the meantime, then when the JobSet is rescheduled, the JobSet system remaps the workload of the higher-priority job's failed slice onto the slice used by the lower-priority job in the same GKE cluster. This remapping typically takes less than a minute.
Upon JobSet restart, hot-swap can occur in the following situations:
- Default Mode: If spare, idle TPU slices are available within the same cluster, the Kubernetes scheduler will prioritize scheduling the restarted jobs onto these slices rather than waiting for the failed slices to be repaired. This provides faster recovery.
- Heterogeneous Workloads: In clusters running multiple workloads with a configured Kubernetes PriorityClass, a restarted JobSet can trigger a hot-swap. If the restarted job's affinity matches a lower-priority job's resources, Kubernetes preempts the lower-priority job, allowing the higher-priority job to start immediately. For example, you can configure your Pathways worker pods with different priorities using PriorityClass.
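The two hot-swap situations can be sketched as a toy placement function. This is a deliberately simplified model, not the Kubernetes scheduler or the JobSet controller: Slice and place_restarted_job are hypothetical, affinity is ignored, and only the workload of one failed slice is placed.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class Slice:
    name: str
    healthy: bool = True
    job: Optional[str] = None  # job currently running on the slice, if any
    priority: int = 0  # priority of that job


def place_restarted_job(
    slices: List[Slice], job: str, priority: int
) -> Tuple[Optional[str], Optional[str]]:
    """Toy model: return (slice chosen, job evicted) for a restarted workload."""
    healthy = [s for s in slices if s.healthy]
    # Default mode: prefer a healthy, idle spare slice.
    for s in healthy:
        if s.job is None:
            s.job, s.priority = job, priority
            return s.name, None
    # Heterogeneous workloads: preempt the lowest-priority job, but only if
    # its priority is strictly lower than the restarted job's.
    victim = min(healthy, key=lambda s: s.priority, default=None)
    if victim is not None and victim.priority < priority:
        evicted = victim.job
        victim.job, victim.priority = job, priority
        return victim.name, evicted
    # No capacity: wait for the failed slice to be repaired.
    return None, None


slices = [
    Slice("slice-a", job="train", priority=2000),
    Slice("slice-b", healthy=False),  # failed slice awaiting repair
    Slice("slice-c", job="batch-eval", priority=0),
]
print(place_restarted_job(slices, "train", 2000))
```

Here the restarted train workload finds no idle slice, so it preempts batch-eval on slice-c. If an idle spare slice were in the list, the function would place the workload there without evicting anything, matching the default mode.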
To use priorities in your cluster, define a priority class, for example:
```yaml
apiVersion: scheduling.k8s.io/v1
kind: PriorityClass
metadata:
  name: high-prior-job
value: 2000
globalDefault: false
description: "This priority class should be used for high-priority jobs."
```
Save this YAML as high-prior-job.yaml and apply it to your GKE cluster:
```shell
kubectl apply -f high-prior-job.yaml
```
Next, attach the new priority class to your Pathways worker job by adding the following field to the Pod spec of your pathways-worker Pod:
```yaml
priorityClassName: high-prior-job
```
What's next