Cómo conservar el progreso del entrenamiento con Autocheckpoint

Históricamente, cuando una VM de TPU requiere mantenimiento, el procedimiento se inicia de inmediato, sin dejar tiempo para que los usuarios realicen acciones que preserven el progreso, como guardar un punto de control. Esto se muestra en la figura 1(a).

Diagrama que muestra el impacto del mantenimiento del host con y sin la función de autocheckpointing

Fig. 1. Ilustración de la función Autocheckpoint: (a) Sin Autocheckpoint, el progreso del entrenamiento desde el último punto de control se pierde cuando hay un evento de mantenimiento próximo. (b) Con Autocheckpoint, se puede conservar el progreso del entrenamiento desde el último punto de control cuando hay un evento de mantenimiento próximo.

Puedes usar Autocheckpoint (figura 1(b)) para conservar el progreso del entrenamiento. Para ello, configura tu código para que guarde un punto de control no programado cuando ocurra un evento de mantenimiento. Cuando se produce un evento de mantenimiento, se guarda automáticamente el progreso desde el último punto de control. La función funciona tanto en porciones únicas como en Multislice.

La función Autocheckpoint funciona con frameworks que pueden capturar señales SIGTERM y, luego, guardar un punto de control. Los frameworks compatibles incluyen los siguientes:

Cómo usar Autocheckpoint

La función de Autocheckpoint está inhabilitada de forma predeterminada. Cuando creas una TPU o solicitas un recurso en cola, puedes habilitar Autocheckpoint agregando la marca --autocheckpoint-enabled cuando aprovisiones la TPU. Con la función habilitada, Cloud TPU realiza los siguientes pasos una vez que recibe la notificación de un evento de mantenimiento:

  1. Captura la señal SIGTERM enviada al proceso con el dispositivo de TPU
  2. Espera hasta que finalice el proceso o hasta que transcurran 5 minutos, lo que ocurra primero.
  3. Realiza el mantenimiento de las segmentaciones afectadas

La infraestructura que usa Autocheckpoint es independiente del framework de AA. Cualquier framework de AA puede admitir Autocheckpoint si puede capturar el indicador SIGTERM y, luego, iniciar un proceso de creación de puntos de control.

En el código de la aplicación, debes habilitar las capacidades de Autocheckpoint que proporciona el framework de AA. En Pax, por ejemplo, esto significa habilitar marcas de línea de comandos cuando se inicia el entrenamiento. Para obtener más información, consulta la guía de inicio rápido de Autocheckpoint con Pax. En segundo plano, los frameworks guardan un punto de control no programado cuando se recibe un indicador SIGTERM, y la VM de TPU afectada se somete a mantenimiento cuando la TPU ya no está en uso.

Guía de inicio rápido: Autocheckpoint con MaxText

MaxText es un LLM de código abierto, de alto rendimiento y escalable de forma arbitraria, que se escribió en Python/JAX puro y se orienta a las Cloud TPU. MaxText contiene toda la configuración necesaria para usar la función de Autocheckpoint.

El archivo README de MaxText describe dos formas de ejecutar MaxText a gran escala:

Cuando uses multihost_runner.py, habilita Autocheckpoint configurando la marca autocheckpoint-enabled cuando aprovisiones el recurso en cola.

Cuando uses multihost_job.py, habilita Autocheckpoint especificando la marca de línea de comandos ENABLE_AUTOCHECKPOINT=true cuando inicies el trabajo.

Guía de inicio rápido: Autocheckpoint con Pax en una sola división

En esta sección, se proporciona un ejemplo de cómo configurar y usar Autocheckpoint con Pax en una sola división. Con la configuración adecuada, sucede lo siguiente:

  • Se guardará un punto de control cuando se produzca un evento de mantenimiento.
  • Cloud TPU realizará el mantenimiento de las VMs de TPU afectadas después de que se guarde el punto de control.
  • Cuando la Cloud TPU complete el mantenimiento, podrás usar la VM de TPU como de costumbre.
  1. Usa la marca autocheckpoint-enabled cuando crees la VM de TPU o solicites un recurso en cola.

    Por ejemplo:

    1. Establece las variables de entorno:

      export PROJECT_ID=your-project-id
      export TPU_NAME=your-tpu-name
      export ZONE=zone-you-want-to-use
      export ACCELERATOR_TYPE=your-accelerator-type
      export RUNTIME_VERSION=tpu-ubuntu2204-base

      Descripciones de las variables de entorno

      Variable Descripción
      PROJECT_ID El ID de tu proyecto Google Cloud . Usa un proyecto existente o crea uno nuevo.
      TPU_NAME Nombre de la TPU.
      ZONE Es la zona en la que se creará la VM de TPU. Para obtener más información sobre las zonas admitidas, consulta Regiones y zonas de TPU.
      ACCELERATOR_TYPE El tipo de acelerador especifica la versión y el tamaño de la Cloud TPU que deseas crear. Para obtener más información sobre los tipos de aceleradores compatibles con cada versión de TPU, consulta Versiones de TPU.
      RUNTIME_VERSION Versión de software de Cloud TPU.

    2. Establece el ID y la zona del proyecto en tu configuración activa:

      gcloud config set project $PROJECT_ID
      gcloud config set compute/zone $ZONE
    3. Crea una TPU:

      gcloud alpha compute tpus tpu-vm create $TPU_NAME \
          --accelerator-type $ACCELERATOR_TYPE \
          --version $RUNTIME_VERSION \
          --autocheckpoint-enabled
  2. Conéctate a la TPU con SSH:

    gcloud compute tpus tpu-vm ssh $TPU_NAME
    
  3. Instala Pax en una sola porción

    La función Autocheckpoint funciona en las versiones 1.1.0 y posteriores de Pax. En la VM de TPU, instala jax[tpu] y la versión más reciente de paxml:

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  4. Configura el modelo LmCloudSpmd2B. Antes de ejecutar la secuencia de comandos de entrenamiento, cambia ICI_MESH_SHAPE por [1, 8, 1]:

    @experiment_registry.register
    class LmCloudSpmd2B(LmCloudSpmd):
    
        """SPMD model with 2B params.
    
        Global batch size = 2 * 2 * 1 * 32 = 128
        """
        PERCORE_BATCH_SIZE = 8
    
        NUM_LAYERS = 18
        MODEL_DIMS = 3072
        HIDDEN_DIMS = MODEL_DIMS * 4
    
        CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING
        ICI_MESH_SHAPE = [1, 8, 1]
  5. Inicia el entrenamiento con la configuración adecuada.

    En el siguiente ejemplo, se muestra cómo configurar el modelo LmCloudSpmd2B para guardar los puntos de control activados por Autocheckpoint en un bucket de Cloud Storage. Reemplaza your-storage-bucket por el nombre de un bucket existente o crea uno nuevo.

    export JOB_LOG_DIR=gs://your-storage-bucket
    
    { python3 .local/lib/python3.10/site-packages/paxml/main.py \
        --jax_fully_async_checkpoint=1 \
        --exit_after_ondemand_checkpoint=1 \
        --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \
        --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt

    Ten en cuenta las dos marcas que se pasan al comando:

    • jax_fully_async_checkpoint: Con esta marca activada, se usará orbax.checkpoint.AsyncCheckpointer. La clase AsyncCheckpointer guarda automáticamente un punto de control cuando la secuencia de comandos de entrenamiento recibe un indicador SIGTERM.
    • exit_after_ondemand_checkpoint: Con esta marca activada, el proceso de TPU se cierra después de que se guarda correctamente el Autocheckpoint, lo que activa el mantenimiento para que se realice de inmediato. Si no usas esta marca, el entrenamiento continuará después de que se guarde el punto de control y la Cloud TPU esperará a que se produzca un tiempo de espera (5 minutos) antes de realizar el mantenimiento requerido.

Cómo crear puntos de control automáticamente con Orbax

La función Autocheckpoint no se limita a MaxText ni a Pax. Cualquier framework que pueda capturar la señal SIGTERM y, luego, iniciar un proceso de creación de puntos de control funciona con la infraestructura que proporciona Autocheckpoint. Orbax, un espacio de nombres que proporciona bibliotecas de utilidades comunes para los usuarios de JAX, ofrece estas capacidades.

Como se explica en la documentación de Orbax, estas capacidades están habilitadas de forma predeterminada para los usuarios de orbax.checkpoint.CheckpointManager. El método save que se llama después de cada paso verifica automáticamente si se acerca un evento de mantenimiento y, si es así, guarda un punto de control incluso si el número de paso no es un múltiplo de save_interval_steps. En la documentación de GitHub, también se ilustra cómo hacer que el entrenamiento finalice después de guardar un Autocheckpoint, con una modificación en el código del usuario.