Préserver la progression de l'entraînement à l'aide d'Autocheckpoint
Historiquement, lorsqu'une VM TPU nécessite une maintenance, la procédure est lancée immédiatement, sans laisser le temps aux utilisateurs d'effectuer des actions de préservation de la progression, comme l'enregistrement d'un point de contrôle. C'est ce qu'illustre la figure 1(a).
Fig. 1. Illustration de la fonctionnalité Autocheckpoint : (a) Sans Autocheckpoint, la progression de l'entraînement à partir du dernier point de contrôle est perdue en cas d'événement de maintenance à venir. (b) Avec Autocheckpoint, la progression de l'entraînement depuis le dernier point de contrôle peut être conservée en cas d'événement de maintenance à venir.
Vous pouvez utiliser Autocheckpoint (figure 1(b)) pour préserver la progression de l'entraînement en configurant votre code afin d'enregistrer un point de contrôle non programmé lorsqu'un événement de maintenance se produit. Lorsqu'un événement de maintenance se produit, la progression depuis le dernier point de contrôle est automatiquement enregistrée. Cette fonctionnalité fonctionne à la fois sur les tranches uniques et sur Multislice.
La fonctionnalité Autocheckpoint fonctionne avec les frameworks qui peuvent capturer les signaux SIGTERM et enregistrer ensuite un point de contrôle. Voici les frameworks compatibles :
Utiliser Autocheckpoint
La fonctionnalité Autocheckpoint est désactivée par défaut. Lorsque vous créez une TPU ou demandez une ressource en file d'attente, vous pouvez activer la vérification automatique des points de contrôle en ajoutant l'indicateur --autocheckpoint-enabled
lors du provisionnement de la TPU.
Lorsque cette fonctionnalité est activée, Cloud TPU effectue les étapes suivantes une fois qu'il reçoit une notification d'événement de maintenance :
- Capturer le signal SIGTERM envoyé au processus à l'aide de l'appareil TPU
- Attendez la fin du processus ou cinq minutes, selon la première échéance atteinte.
- Effectuer la maintenance des tranches concernées
L'infrastructure utilisée par Autocheckpoint est indépendante du framework de ML. Tout framework de ML peut prendre en charge la création automatique de points de contrôle s'il peut capturer le signal SIGTERM et lancer un processus de création de points de contrôle.
Dans le code de l'application, vous devez activer les fonctionnalités Autocheckpoint fournies par le framework de ML. Dans Pax, par exemple, cela signifie activer les indicateurs de ligne de commande lors du lancement de l'entraînement. Pour en savoir plus, consultez le guide de démarrage rapide d'Autocheckpoint avec Pax. En arrière-plan, les frameworks enregistrent un point de contrôle non planifié lorsqu'un signal SIGTERM est reçu, et la VM TPU concernée fait l'objet d'une maintenance lorsque le TPU n'est plus utilisé.
Guide de démarrage rapide : Autocheckpoint avec MaxText
MaxText est un LLM Open Source hautes performances, évolutif à volonté et bien testé, écrit en Python/JAX pur et ciblant les Cloud TPU. MaxText contient toute la configuration nécessaire pour utiliser la fonctionnalité Autocheckpoint.
Le fichier MaxText README
décrit deux façons d'exécuter MaxText à grande échelle :
- Utiliser
multihost_runner.py
, recommandé pour les tests - Utiliser
multihost_job.py
, recommandé pour la production
Lorsque vous utilisez multihost_runner.py
, activez Autocheckpoint en définissant l'indicateur autocheckpoint-enabled
lors du provisionnement de la ressource mise en file d'attente.
Lorsque vous utilisez multihost_job.py
, activez Autocheckpoint en spécifiant l'indicateur de ligne de commande ENABLE_AUTOCHECKPOINT=true
lorsque vous lancez le job.
Guide de démarrage rapide : Autocheckpoint avec Pax sur un seul segment
Cette section fournit un exemple de configuration et d'utilisation d'Autocheckpoint avec Pax sur un seul slice. Avec la configuration appropriée :
- Un point de contrôle est enregistré lorsqu'un événement de maintenance se produit.
- Cloud TPU effectuera la maintenance sur les VM TPU concernées une fois le point de contrôle enregistré.
- Une fois la maintenance terminée, vous pouvez utiliser la VM TPU comme d'habitude.
Utilisez l'option
autocheckpoint-enabled
lorsque vous créez la VM TPU ou demandez une ressource mise en file d'attente.Exemple :
Définissez les variables d'environnement :
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=zone-you-want-to-use export ACCELERATOR_TYPE=your-accelerator-type export RUNTIME_VERSION=tpu-ubuntu2204-base
Descriptions des variables d'environnement
Variable Description PROJECT_ID
L'ID de votre projet Google Cloud . Utilisez un projet existant ou créez-en un. TPU_NAME
Nom du TPU. ZONE
Zone dans laquelle créer la VM TPU. Pour en savoir plus sur les zones compatibles, consultez Régions et zones de TPU. ACCELERATOR_TYPE
Le type d'accélérateur spécifie la version et la taille du Cloud TPU que vous souhaitez créer. Pour en savoir plus sur les types d'accélérateurs compatibles avec chaque version de TPU, consultez Versions de TPU. RUNTIME_VERSION
La version logicielle de Cloud TPU. Définissez votre ID de projet et votre zone dans votre configuration active :
gcloud config set project $PROJECT_ID gcloud config set compute/zone $ZONE
Créez un TPU :
gcloud alpha compute tpus tpu-vm create $TPU_NAME \ --accelerator-type $ACCELERATOR_TYPE \ --version $RUNTIME_VERSION \ --autocheckpoint-enabled
Connectez-vous à la TPU à l'aide de SSH :
gcloud compute tpus tpu-vm ssh $TPU_NAME
Installer Pax sur une seule tranche
La fonctionnalité Autocheckpoint fonctionne sur les versions 1.1.0 et ultérieures de Pax. Sur la VM TPU, installez
jax[tpu]
et la dernière version depaxml
:pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Configurez le modèle
LmCloudSpmd2B
. Avant d'exécuter le script d'entraînement, remplacezICI_MESH_SHAPE
par[1, 8, 1]
:@experiment_registry.register class LmCloudSpmd2B(LmCloudSpmd): """SPMD model with 2B params. Global batch size = 2 * 2 * 1 * 32 = 128 """ PERCORE_BATCH_SIZE = 8 NUM_LAYERS = 18 MODEL_DIMS = 3072 HIDDEN_DIMS = MODEL_DIMS * 4 CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING ICI_MESH_SHAPE = [1, 8, 1]
Lancez l'entraînement avec la configuration appropriée.
L'exemple suivant montre comment configurer le modèle
LmCloudSpmd2B
pour enregistrer les points de contrôle déclenchés par Autocheckpoint dans un bucket Cloud Storage. Remplacez your-storage-bucket par le nom d'un bucket existant ou créez-en un.export JOB_LOG_DIR=gs://your-storage-bucket { python3 .local/lib/python3.10/site-packages/paxml/main.py \ --jax_fully_async_checkpoint=1 \ --exit_after_ondemand_checkpoint=1 \ --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \ --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt
Notez les deux indicateurs transmis à la commande :
jax_fully_async_checkpoint
: Lorsque cette option est activée,orbax.checkpoint.AsyncCheckpointer
est utilisé. La classeAsyncCheckpointer
enregistre automatiquement un point de contrôle lorsque le script d'entraînement reçoit un signal SIGTERM.exit_after_ondemand_checkpoint
: lorsque ce signalement est activé, le processus TPU se termine une fois l'autocheckpoint enregistré, ce qui déclenche la maintenance immédiate. Si vous n'utilisez pas cet indicateur, l'entraînement se poursuivra après l'enregistrement du point de contrôle, et Cloud TPU attendra qu'un délai d'expiration se produise (5 minutes) avant d'effectuer la maintenance requise.
Point de contrôle automatique avec Orbax
La fonctionnalité Autocheckpoint n'est pas limitée à MaxText ni à Pax. Tout framework capable de capturer le signal SIGTERM et de lancer un processus de point de contrôle fonctionne avec l'infrastructure fournie par Autocheckpoint. Orbax, un espace de noms qui fournit des bibliothèques utilitaires courantes pour les utilisateurs de JAX, offre ces fonctionnalités.
Comme expliqué dans la documentation Orbax, ces fonctionnalités sont activées par défaut pour les utilisateurs de orbax.checkpoint.CheckpointManager
. La méthode save
appelée après chaque étape vérifie automatiquement si un événement de maintenance est imminent et, le cas échéant, enregistre un point de contrôle même si le numéro d'étape n'est pas un multiple de save_interval_steps
.
La documentation GitHub explique également comment faire en sorte que l'entraînement se termine après l'enregistrement d'un point de contrôle automatique, en modifiant le code utilisateur.