Training mit Autocheckpoint fortsetzen

Bisher wurde die Wartung einer TPU-VM sofort gestartet, ohne dass Nutzer Zeit hatten, Aktionen auszuführen, die den Fortschritt sichern, z. B. das Speichern eines Checkpoints. Dies ist in Abbildung 1(a) dargestellt.

Diagramm mit den Auswirkungen der Hostwartung mit und ohne automatisches Checkpointing

Abbildung 1 Abbildung der Funktion „Autocheckpoint“: (a) Ohne Autocheckpoint geht der Trainingsfortschritt ab dem letzten Checkpoint verloren, wenn ein Wartungsereignis bevorsteht. (b) Mit der automatischen Checkpoint-Funktion kann der Trainingsfortschritt seit dem letzten Checkpoint bei einem bevorstehenden Wartungsereignis beibehalten werden.

Mit dem automatischen Checkpoint (Abbildung 1(b)) können Sie den Trainingsfortschritt beibehalten, indem Sie Ihren Code so konfigurieren, dass ein nicht geplanter Checkpoint gespeichert wird, wenn ein Wartungsereignis auftritt. Wenn ein Wartungsereignis auftritt, wird der Fortschritt seit dem letzten Checkpoint automatisch gespeichert. Die Funktion funktioniert sowohl für einzelne Slices als auch für Multi-Slices.

Die Funktion „Autocheckpoint“ funktioniert mit Frameworks, die SIGTERM-Signale erfassen und anschließend einen Checkpoint speichern können. Zu den unterstützten Frameworks gehören:

Autocheckpoint verwenden

Die Funktion „Automatische Checkpoints“ ist standardmäßig deaktiviert. Wenn Sie eine TPU erstellen oder eine in die Warteschlange gestellte Ressource anfordern, können Sie Autocheckpoint aktivieren, indem Sie bei der Bereitstellung der TPU das Flag --autocheckpoint-enabled hinzufügen. Wenn die Funktion aktiviert ist, führt Cloud TPU die folgenden Schritte aus, sobald eine Benachrichtigung zu einem Wartungsereignis eingegangen ist:

  1. SIGTERM-Signal erfassen, das über das TPU-Gerät an den Prozess gesendet wird
  2. Warten Sie, bis der Vorgang beendet ist oder fünf Minuten vergangen sind, je nachdem, was zuerst eintritt.
  3. Wartung der betroffenen Segmente ausführen

Die von Autocheckpoint verwendete Infrastruktur ist unabhängig vom ML-Framework. Jedes ML-Framework kann Autocheckpoint unterstützen, wenn es das SIGTERM-Signal erfassen und einen Prüfpunktprozess initiieren kann.

Im Anwendungscode müssen Sie die vom ML-Framework bereitgestellten Funktionen für automatische Checkpoints aktivieren. In Pax bedeutet das beispielsweise, dass Befehlszeilen-Flags beim Starten des Trainings aktiviert werden müssen. Weitere Informationen finden Sie in der Kurzanleitung für Autocheckpoints mit Pax. Im Hintergrund speichern die Frameworks einen nicht geplanten Checkpoint, wenn ein SIGTERM-Signal empfangen wird. Die betroffene TPU-VM wird dann gewartet, wenn die TPU nicht mehr verwendet wird.

Kurzanleitung: Autocheckpoint mit MaxText

MaxText ist ein leistungsstarker, beliebig skalierbarer, Open-Source-LLM, der in reiner Python/JAX geschrieben wurde und auf Cloud TPUs ausgerichtet ist. MaxText enthält alle erforderlichen Einstellungen für die Verwendung der Funktion „Autocheckpoint“.

In der README-Datei für MaxText werden zwei Möglichkeiten beschrieben, MaxText im großen Maßstab auszuführen:

Wenn Sie multihost_runner.py verwenden, aktivieren Sie „Autocheckpoint“, indem Sie das Flag autocheckpoint-enabled beim Bereitstellen der Ressourcen in der Warteschlange setzen.

Wenn Sie multihost_job.py verwenden, aktivieren Sie den automatischen Checkpoint, indem Sie beim Starten des Jobs das Befehlszeilenflag ENABLE_AUTOCHECKPOINT=true angeben.

Kurzanleitung: Automatischer Checkpoint mit Pax auf einer einzelnen Scheibe

In diesem Abschnitt wird ein Beispiel für die Einrichtung und Verwendung von Autocheckpoint mit Pax auf einem einzelnen Slice beschrieben. Bei entsprechender Einrichtung:

  • Wenn ein Wartungsereignis auftritt, wird ein Checkpoint gespeichert.
  • Nach dem Speichern des Checkpoints führt Cloud TPU Wartungsarbeiten an den betroffenen TPU-VMs durch.
  • Sobald die Wartung von Cloud TPU abgeschlossen ist, können Sie die TPU-VM wie gewohnt verwenden.
  1. Verwenden Sie das Flag autocheckpoint-enabled, wenn Sie die TPU-VM erstellen oder eine in der Warteschlange befindliche Ressource anfordern.

    Beispiel:

    export PROJECT=your-gcp-project-name
    export ZONE=zone-you-want-to-use
    export NODE_ID=your-node-id
    export ACCELERATOR_TYPE=your-accelerator-type
    gcloud config set project $PROJECT
    gcloud config set compute/zone $ZONE
    gcloud alpha compute tpus tpu-vm create $NODE_ID \
        --accelerator-type $ACCELERATOR_TYPE \
        --version tpu-ubuntu2204-base \
        --autocheckpoint-enabled
  2. Stellen Sie eine SSH-Verbindung zur TPU her:

    gcloud compute tpus tpu-vm ssh $NODE_ID 
    
  3. Pax auf einem einzelnen Slice installieren

    Die Funktion „Autocheckpoint“ funktioniert mit Pax-Versionen ab 1.1.0. Installieren Sie auf der TPU-VM jax[tpu] und die neueste paxml:

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  4. Starten Sie das Training mit der entsprechenden Konfiguration.

    Im folgenden Beispiel wird gezeigt, wie Sie das LmCloudSpmd2B-Modell so konfigurieren, dass von Autocheckpoint ausgelöste Checkpoints in einem Cloud Storage-Bucket gespeichert werden. Ersetzen Sie your-storage-bucket durch den Namen eines vorhandenen Buckets oder erstellen Sie einen neuen Bucket.

    export JOB_LOG_DIR=gs://your-storage-bucket
    
    { python3 .local/lib/python3.10/site-packages/paxml/main.py \
        --jax_fully_async_checkpoint=1 \
        --exit_after_ondemand_checkpoint=1 \
        --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \
        --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt

    Beachten Sie die beiden Flags, die an den Befehl übergeben werden:

    • jax_fully_async_checkpoint: Wenn dieses Flag aktiviert ist, wird orbax.checkpoint.AsyncCheckpointer verwendet. Die Klasse AsyncCheckpointer speichert automatisch einen Checkpoint, wenn das Trainingsscript ein SIGTERM-Signal empfängt.
    • exit_after_ondemand_checkpoint: Wenn dieses Flag aktiviert ist, wird der TPU-Prozess beendet, nachdem der Autocheckpoint erfolgreich gespeichert wurde. Dadurch wird die Wartung sofort ausgeführt. Wenn Sie dieses Flag nicht verwenden, wird das Training nach dem Speichern des Checkpoints fortgesetzt und Cloud TPU wartet 5 Minuten, bevor die erforderliche Wartung durchgeführt wird.

Kurzanleitung: Autocheckpoint mit Pax bei Multislice

Die automatische Checkpoint-Funktion funktioniert nicht nur für einzelne Scheiben, sondern auch für Mehrere Scheiben. In diesem Abschnitt wird beschrieben, wie Sie automatische Checkpoints mit Multislice verwenden.

  1. Geben Sie „Autocheckpoint“ an, wenn Sie Ressourcen in der Warteschlange erstellen.

    Eine Multi-Slice-Umgebung kann nur über eine in der Warteschlange befindliche Ressourcenanfrage bereitgestellt werden. Ähnlich wie beim Fall mit einer einzelnen Scheibe verwenden Sie das Flag autocheckpoint-enabled im Aufruf, um eine Ressourcenwarteschlange zu erstellen.

    export QR_ID=your-qr-id
    export NODE_COUNT=your-node-count
    export ACCELERATOR_TYPE=your-accelerator-type
    
    gcloud alpha compute tpus queued-resources create $QR_ID \
        --node-count $NODE_COUNT \
        --accelerator-type $ACCELERATOR_TYPE \
        --runtime-version tpu-ubuntu2204-base \
        --autocheckpoint-enabled

    Weitere Informationen zu allen verfügbaren Optionen finden Sie im Multislice-Nutzerhandbuch. Wenn die angeforderte Ressource in der Warteschlange erstellt wurde und den Status ACTIVE hat, führen Sie die folgenden Schritte aus, um Pax mit Autocheckpoint auszuführen.

  2. Installieren Sie jax[tpu] und die neueste paxml auf allen TPU-VMs in Ihrer Multislice-Umgebung.

    gcloud compute tpus queued-resources ssh $QR_ID \
        --node=all \
        --worker=all \
        --batch-size=your-batch-size \
        --command="pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"

    Legen Sie das Flag --batch-size auf die Anzahl der gleichzeitigen Verbindungen fest, die mit TPU-Workern hergestellt werden sollen. Weitere Informationen zur Auswahl einer Batchgröße für Multislice-Arbeitslasten finden Sie unter Training optimieren.

  3. Konfigurieren Sie das LmCloudSpmd2B-Modell für den automatischen Checkpoint, wenn Sie in einer Multislice-Umgebung trainieren. Legen Sie vor dem Ausführen des Trainingsscripts DCN_MESH_SHAPE auf [2, 1, 1] fest, wie im folgenden Beispiel gezeigt:

    @experiment_registry.register
    class LmCloudSpmd2B(LmCloudSpmd):
    
        """SPMD model with 2B params.
    
        Global batch size = 2 * 2 * 1 * 32 = 128
        """
        PERCORE_BATCH_SIZE = 8
    
        NUM_LAYERS = 18
        MODEL_DIMS = 3072
        HIDDEN_DIMS = MODEL_DIMS * 4
    
        CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING
        ICI_MESH_SHAPE = [1, 4, 1]
        DCN_MESH_SHAPE = [2, 1, 1]
  4. Wenn Sie Wegpunkte häufiger setzen möchten, legen Sie task_p.train.save_interval_steps und task_p.train.save_max_to_keep fest, wie im folgenden Beispiel gezeigt:

    @experiment_registry.register
    class LmCloudSpmd2BLimitSteps(LmCloudSpmd2B):
    """SPMD model with 2B params and limited steps.
    
    Global batch size = 2 * 2 * 1 * 32 = 128
    """
    
    def task(self) -> pax_fiddle.Config[tasks_lib.SingleTask]:
        task_p = super().task()
        task_p.train.save_interval_steps = 50
        task_p.train.save_max_to_keep = 5
        return task_p
    
  5. Starten Sie das Training, indem Sie den folgenden Befehl für jeden Host ausführen. Ersetzen Sie your-storage-bucket durch den Namen eines vorhandenen Buckets oder erstellen Sie einen neuen Bucket.

    export TF_CPP_MIN_LOG_LEVEL=0
    export JOB_LOG_DIR=gs://your-storage-bucket
    
    { python3 .local/lib/python3.10/site-packages/paxml/main.py \
        --jax_fully_async_checkpoint=1 \
        --num_hosts=2 \
        --host_idx=host-index \
        --server_addr=worker0-node0-ip-address \
        --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \
        --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt

    Beim Starten des Trainings sind zusätzlich zu den Befehlszeilenoptionen, die im Fall mit einer einzelnen Scheibe beschrieben wurden, noch drei weitere erforderlich:

    • num_hosts: die Gesamtzahl der Hosts. In diesem Fall ist das 2.
    • host_idx: Der Index des Hosts, der das Training startet. Sie variiert zwischen 0 und N-1, wobei N die Gesamtzahl der Hosts ist.
    • server_addr: die IP-Adresse von Worker 0 von Knoten 0 mit einem nicht verwendeten Port (z. B. 8476). Verwenden Sie dazu hostname -i auf Worker 0 von Knoten 0.

Automatischer Checkpoint mit Orbax

Die Funktion „Autocheckpoint“ ist nicht auf MaxText oder Pax beschränkt. Jedes Framework, das das SIGTERM-Signal erfassen und einen Checkpoint-Prozess initiieren kann, funktioniert mit der von Autocheckpoint bereitgestellten Infrastruktur. Diese Funktionen bietet Orbax, ein Namespace mit gängigen Dienstbibliotheken für JAX-Nutzer.

Wie in der Orbax-Dokumentation erläutert, sind diese Funktionen für Nutzer von orbax.checkpoint.CheckpointManager standardmäßig aktiviert. Die Methode save, die nach jedem Schritt aufgerufen wird, prüft automatisch, ob ein Wartungsereignis bevorsteht. Falls ja, wird ein Checkpoint gespeichert, auch wenn die Schrittnummer kein Vielfaches von save_interval_steps ist. In der GitHub-Dokumentation wird auch veranschaulicht, wie das Training nach dem Speichern eines automatischen Checkpoints beendet werden kann, indem der Nutzercode geändert wird.