Préserver la progression de l'entraînement à l'aide d'un point de contrôle automatique

Historiquement, lorsqu'une VM TPU nécessite une maintenance, la procédure est lancée immédiatement, sans laisser aux utilisateurs le temps d'effectuer des actions de préservation de la progression telles que l'enregistrement d'un point de contrôle. C'est ce que montre la figure 1(a).

Schéma illustrant l'impact de la maintenance de l'hôte avec et sans point de contrôle automatique

Fig. 1. Illustration de la fonctionnalité de point de contrôle automatique : (a) Sans point de contrôle automatique, la progression de l'entraînement à partir du dernier point de contrôle est perdue en cas d'événement de maintenance à venir. (b) Avec Autocheckpoint, la progression de l'entraînement depuis le dernier point de contrôle peut être préservée en cas d'événement de maintenance à venir.

Vous pouvez utiliser le point de contrôle automatique (figure 1b) pour préserver la progression de l'entraînement en configurant votre code pour enregistrer un point de contrôle non planifié lorsqu'un événement de maintenance se produit. Lorsqu'un événement de maintenance se produit, la progression depuis le dernier point de contrôle est automatiquement enregistrée. Cette fonctionnalité fonctionne à la fois sur les tranches uniques et sur Multislice.

La fonctionnalité Autocheckpoint fonctionne avec les frameworks pouvant capturer des signaux SIGTERM et enregistrer ensuite un point de contrôle. Les frameworks compatibles incluent:

Utiliser le point de contrôle automatique

La fonctionnalité de point de contrôle automatique est désactivée par défaut. Lorsque vous créez un TPU ou demandez une ressource en file d'attente, vous pouvez activer le point de contrôle automatique en ajoutant l'indicateur --autocheckpoint-enabled lors du provisionnement du TPU. Lorsque cette fonctionnalité est activée, Cloud TPU effectue les étapes suivantes lorsqu'il reçoit une notification d'événement de maintenance:

  1. Capturer le signal SIGTERM envoyé au processus à l'aide de l'appareil TPU
  2. Attendez la fin du processus ou cinq minutes, selon la première échéance atteinte.
  3. Effectuer la maintenance des segments concernés

L'infrastructure utilisée par Autocheckpoint est indépendante du framework de ML. N'importe quel framework de ML peut prendre en charge la création automatique de points de contrôle s'il peut capturer le signal SIGTERM et lancer un processus de création de points de contrôle.

Dans le code de l'application, vous devez activer les fonctionnalités de point de contrôle automatique fournies par le framework de ML. Dans Pax, par exemple, cela signifie activer les indicateurs de ligne de commande lors du lancement de l'entraînement. Pour en savoir plus, consultez le guide de démarrage rapide Autocheckpoint avec Pax. En coulisses, les frameworks enregistrent un point de contrôle non planifié lorsqu'un signal SIGTERM est reçu, et la VM TPU concernée est soumise à une maintenance lorsque le TPU n'est plus utilisé.

Guide de démarrage rapide: Point de contrôle automatique avec MaxText

MaxText est un LLM Open Source hautes performances, évolutif de manière arbitraire et bien testé, écrit en Python/JAX pur et ciblant les Cloud TPU. MaxText contient toute la configuration nécessaire pour utiliser la fonctionnalité de point de contrôle automatique.

Le fichier README MaxText décrit deux façons d'exécuter MaxText à grande échelle:

Lorsque vous utilisez multihost_runner.py, activez le point de contrôle automatique en définissant l'indicateur autocheckpoint-enabled lors du provisionnement de la ressource mise en file d'attente.

Lorsque vous utilisez multihost_job.py, activez le point de contrôle automatique en spécifiant l'indicateur de ligne de commande ENABLE_AUTOCHECKPOINT=true lors du lancement de la tâche.

Démarrage rapide: point de contrôle automatique avec Pax sur une seule tranche

Cette section fournit un exemple de configuration et d'utilisation d'Autocheckpoint avec Pax sur une seule tranche. Avec la configuration appropriée:

  • Un point de contrôle est enregistré lorsqu'un événement de maintenance se produit.
  • Cloud TPU effectuera la maintenance des VM TPU concernées une fois le point de contrôle enregistré.
  • Une fois la maintenance terminée, vous pouvez utiliser la VM TPU comme d'habitude.
  1. Utilisez l'option autocheckpoint-enabled lorsque vous créez la VM TPU ou demandez une ressource mise en file d'attente.

    Exemple :

    1. Définissez les variables d'environnement :

      export PROJECT_ID=your-project-id
      export TPU_NAME=your-tpu-name
      export ZONE=zone-you-want-to-use
      export ACCELERATOR_TYPE=your-accelerator-type
      export RUNTIME_VERSION=tpu-ubuntu2204-base

      Descriptions des variables d'environnement

      Variable Description
      PROJECT_ID L'ID de votre Google Cloud projet. Utilisez un projet existant ou créez-en un.
      TPU_NAME Nom du TPU.
      ZONE Zone dans laquelle créer la VM TPU. Pour en savoir plus sur les zones compatibles, consultez la section Régions et zones de TPU.
      ACCELERATOR_TYPE Le type d'accélérateur spécifie la version et la taille du Cloud TPU que vous souhaitez créer. Pour en savoir plus sur les types d'accélérateurs compatibles avec chaque version de TPU, consultez la section Versions de TPU.
      RUNTIME_VERSION Version logicielle de Cloud TPU.

    2. Définissez votre ID de projet et votre zone dans votre configuration active:

      gcloud config set project $PROJECT_ID
      gcloud config set compute/zone $ZONE
    3. Créez un TPU:

      gcloud alpha compute tpus tpu-vm create $TPU_NAME \
          --accelerator-type $ACCELERATOR_TYPE \
          --version $RUNTIME_VERSION \
          --autocheckpoint-enabled
  2. Connectez-vous au TPU à l'aide de SSH:

    gcloud compute tpus tpu-vm ssh $TPU_NAME
    
  3. Installer Pax sur une seule tranche

    La fonctionnalité de point de contrôle automatique fonctionne avec les versions Pax 1.1.0 et ultérieures. Sur la VM TPU, installez jax[tpu] et la dernière version de paxml:

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  4. Configurez le modèle LmCloudSpmd2B. Avant d'exécuter le script d'entraînement, remplacez ICI_MESH_SHAPE par [1, 8, 1]:

    @experiment_registry.register
    class LmCloudSpmd2B(LmCloudSpmd):
    
        """SPMD model with 2B params.
    
        Global batch size = 2 * 2 * 1 * 32 = 128
        """
        PERCORE_BATCH_SIZE = 8
    
        NUM_LAYERS = 18
        MODEL_DIMS = 3072
        HIDDEN_DIMS = MODEL_DIMS * 4
    
        CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING
        ICI_MESH_SHAPE = [1, 8, 1]
  5. Lancez l'entraînement avec la configuration appropriée.

    L'exemple suivant montre comment configurer le modèle LmCloudSpmd2B pour enregistrer les points de contrôle déclenchés par Autocheckpoint dans un bucket Cloud Storage. Remplacez your-storage-bucket par le nom d'un bucket existant ou créez un bucket.

    export JOB_LOG_DIR=gs://your-storage-bucket
    
    { python3 .local/lib/python3.10/site-packages/paxml/main.py \
        --jax_fully_async_checkpoint=1 \
        --exit_after_ondemand_checkpoint=1 \
        --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \
        --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt

    Notez les deux options transmises à la commande:

    • jax_fully_async_checkpoint : lorsque cette option est activée, orbax.checkpoint.AsyncCheckpointer est utilisé. La classe AsyncCheckpointer enregistre automatiquement un point de contrôle lorsque le script d'entraînement reçoit un signal SIGTERM.
    • exit_after_ondemand_checkpoint : lorsque cet indicateur est activé, le processus TPU se termine une fois le point de contrôle automatique enregistré, ce qui déclenche immédiatement la maintenance. Si vous n'utilisez pas cet indicateur, l'entraînement se poursuit une fois le point de contrôle enregistré, et Cloud TPU attend un délai avant expiration (cinq minutes) avant d'effectuer la maintenance requise.

Point de contrôle automatique avec Orbax

La fonctionnalité de contrôle automatique n'est pas limitée à MaxText ou Pax. Tout framework capable de capturer le signal SIGTERM et d'initier un processus de point de contrôle fonctionne avec l'infrastructure fournie par Autocheckpoint. Orbax, un espace de noms qui fournit des bibliothèques d'utilitaires courantes pour les utilisateurs de JAX, fournit ces fonctionnalités.

Comme expliqué dans la documentation Orbax, ces fonctionnalités sont activées par défaut pour les utilisateurs de orbax.checkpoint.CheckpointManager. La méthode save appelée après chaque étape vérifie automatiquement si un événement de maintenance est imminent. Si c'est le cas, elle enregistre un point de contrôle, même si le numéro d'étape n'est pas un multiple de save_interval_steps. La documentation GitHub explique également comment arrêter l'entraînement après avoir enregistré un point de contrôle automatique, avec une modification dans le code utilisateur.