Punto de control automático de Cloud TPU [versión preliminar pública]
Descripción general
Históricamente, cuando una VM de TPU requiere mantenimiento, el procedimiento se inicia de inmediato, sin dejar tiempo para que los usuarios realizar acciones para preservar el progreso, como guardar un punto de control Este es se muestra en la Figura 1(a).
Fig. 1) Ilustración de la función Punto de control automático: (a) Sin el punto de control automático, el progreso del entrenamiento desde el último punto de control se pierda cuando haya un evento de mantenimiento próximo. (b) Con el punto de control automático, el progreso del entrenamiento desde el último punto de control se puede conservar cuando hay un evento de mantenimiento próximo.
Puedes usar el punto de control automático (Figura 1(b)) para preservar el progreso del entrenamiento. Para ello, configura tu código para que guarde un punto de control no programado cuando se produzca un evento de mantenimiento. Cuando ocurre un evento de mantenimiento, el progreso desde el último el punto de control se guarda automáticamente. Esta función funciona en segmentos y Multislice.
La función Punto de control automático funciona con frameworks que pueden capturar usar SIGTERM y, posteriormente, guardar un punto de control. Entre los frameworks compatibles, se incluyen los siguientes: MaxText, Pax, y JAX con Orbax Se anunciará la compatibilidad con frameworks adicionales a medida que estén disponibles.
Solo las TPU (v2-v4 y v5e) creadas a través de la API de Cloud TPU pueden usan esta función por ahora. La compatibilidad con las TPU en GKE se anunciará cuando esté disponible.
Usa el punto de control automático
La funcionalidad de punto de control automático está inhabilitada de forma predeterminada. Cuando creas un
TPU o un recurso en cola,
puedes habilitarlo agregando la marca --autocheckpoint-enabled
cuando aprovisiones
la TPU.
Con la función habilitada, Cloud TPU
realiza los siguientes pasos una vez que recibe la notificación de un
evento de mantenimiento:
- Capturar el SIGTERM enviado al proceso con el dispositivo de TPU
- Espera hasta que el proceso finalice o hayan transcurrido 5 minutos, lo que sea primero y realiza el mantenimiento de las porciones afectadas.
Ten en cuenta que la infraestructura que usa Autocheckpoint no depende del framework del AA. Cualquier framework de AA admite el punto de control automático, siempre que pueda capturar la señal SIGTERM e iniciar un proceso de punto de control.
En el código de la aplicación, debes habilitar el punto de control automático. del framework de AA. En Pax, por ejemplo, Esto significa habilitar las marcas de línea de comandos cuando se inicie el (consulta la Guía de inicio rápido del punto de control automático con Pax). En segundo plano, los frameworks ahorran una punto de control no programado cuando se recibe un SIGTERM y la VM de la TPU afectada pasa por un mantenimiento cuando la TPU deja de estar en uso.
Guía de inicio rápido: Punto de control automático con MaxText
MaxText es una campaña de generación LLM arbitrariamente escalable, de código abierto y bien probado escrito en Python/JAX puro orientadas a las Cloud TPU”. MaxText contiene toda la configuración necesaria para usar el punto de control automático .
En el archivo readme de MaxText, se describen dos formas de ejecutar MaxText a gran escala:
- Uso de
multihost_runner.py
, recomendado para la experimentación - Usando
multihost_job.job
(recomendado para producción)
Cuando se usa multihost_runner.py
, el único cambio necesario
es configurar la marca autocheckpoint-enabled
cuando aprovisiones
el recurso en cola. Cuando uses
multihost_job.py
, el único cambio necesario es especificar el
La marca de línea de comandos ENABLE_AUTOCHECKPOINT=true
cuando inicias el trabajo.
Guía de inicio rápido: Punto de control automático con Pax en una sola porción
En esta sección, proporcionamos un ejemplo de cómo configurar y usar Autocheckpoint con Pax en una sola porción. Con la configuración adecuada:
- Se guardará un punto de control cuando se produzca un evento de mantenimiento.
- Cloud TPU realizará el mantenimiento en las VMs de TPU afectadas después de se guarda el punto de control.
- Cuando Cloud TPU complete el mantenimiento, podrás usar la VM de TPU como de costumbre.
Usa la marca
autocheckpoint-enabled
cuando crees la VM de TPU o recurso en cola.Por ejemplo:
PROJECT=your-gcp-project-name ZONE=zone-you-want-to-use NODE_ID=your-node-id ACCELERATOR_TYPE=your-accelerator-type gcloud config set project $PROJECT gcloud config set compute/zone $ZONE
gcloud alpha compute tpus tpu-vm create $NODE_ID \ --accelerator-type $ACCELERATOR_TYPE \ --version tpu-ubuntu2204-base \ --autocheckpoint-enabled
Instala Pax en una sola porción
La función de punto de control automático es compatible con las versiones de Pax posteriores a la 1.1.0. En las VMs de TPU, instala
jax[tpu]
y la última versión depaxml
:pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Inicia el entrenamiento con la configuración adecuada
En el siguiente ejemplo, se muestra cómo configurar
LmCloudSpmd2B
para guardar los puntos de control activados por el punto de control automático en un bucket de Google Cloud Storage:JOB_LOG_DIR=gs://your-storage-bucket { python3 .local/lib/python3.10/site-packages/paxml/main.py --jax_fully_async_checkpoint=1 \ --exit_after_ondemand_checkpoint=1 \ --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \ --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt
Ten en cuenta las dos marcas que se pasan al comando:
jax_fully_async_checkpoint
: Con esta marca activada, se usaráorbax.checkpoint.AsyncCheckpointer
. La claseAsyncCheckpointer
guarda automáticamente un punto de control cuando la secuencia de comandos de entrenamiento recibe una señal SIGTERM.exit_after_ondemand_checkpoint
: Con esta marca activada, el proceso de TPU finaliza después de El punto de control automático se guardó correctamente, lo que activa el el mantenimiento se realice de inmediato. Si no la usas, marca, el entrenamiento continuará después de que se guarde el punto de control y Cloud TPU esperará a que se agote el tiempo de espera (5 minutos) antes de realizar el mantenimiento requerido.
Guía de inicio rápido: Punto de control automático con Pax en Multislice
El punto de control automático funciona no solo para una sola porción, sino también para Multislice. Esta sección detalla los pasos necesarios para usar el punto de control automático con Multislice.
Especifica el punto de control automático durante la creación de recursos en cola.
Un entorno de Multislice solo se puede aprovisionar a través de una solicitud de recurso en fila. Al igual que en el caso de una sola porción, usa la marca
autocheckpoint-enabled
en la llamada para crear un recurso en cola.QR_ID=your-qr-id NODE_COUNT=your-node-count ACCELERATOR_TYPE=your-accelerator-type gcloud compute tpus queued-resources create $QR_ID \ --node-count $NODE_COUNT \ --accelerator-type $ACCELERATOR_TYPE \ --runtime-version tpu-ubuntu2204-base \ --autocheckpoint-enabled
Consulta la Guía del usuario de Multislice para conocer los detalles de todas las opciones disponibles. Una vez que el recurso en cola se crea una solicitud En el estado
ACTIVE
, sigue los siguientes pasos para ejecutar Pax con Punto de control automáticoInstalar Pax en todas las VMs del entorno de Multislice.
En las VMs de TPU, instala
jax[tpu]
y la últimapaxml
. en todas las VMs de TPU en tu entorno de Multislice:pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Inicia el entrenamiento con la configuración adecuada
En este ejemplo, se muestra cómo configurar el modelo
LmCloudSpmd2B
para el punto de control automático cuando se entrena en un entorno de Multislice. Antes de ejecutar la secuencia de comandos de entrenamiento, establece DCN_MESH_SHAPE en [2, 1, 1], como se muestra en el siguiente código:@experiment_registry.register class LmCloudSpmd2B(LmCloudSpmd): """SPMD model with 2B params. Global batch size = 2 * 2 * 1 * 32 = 128 """ PERCORE_BATCH_SIZE = 8 NUM_LAYERS = 18 MODEL_DIMS = 3072 HIDDEN_DIMS = MODEL_DIMS * 4 CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING ICI_MESH_SHAPE = [1, 4, 1] DCN_MESH_SHAPE = [2, 1, 1]
Cuando inicies el entrenamiento, además de las marcas de línea de comandos analizadas en el caso de una sola porción, se requieren tres más:
num_hosts
: Es la cantidad total de hosts. En este caso, es 2.host_index
: Es el índice del host que inicia el entrenamiento. Varía de 0 aN-1
, en el queN
es la cantidad total de hosts.server_addr
: Es la dirección IP del trabajador 0 del nodo 0, con un valor no utilizado. (por ejemplo, 8476). Para encontrar esta información, usahostname -i
. en el trabajador 0 del nodo 0.
Punto de control automático con Orbax
La función Punto de control automático no se limita a MaxText ni Pax. Cualquier framework que pueda capturar la señal SIGTERM e iniciar una funciona con la infraestructura que proporciona Autocheckpoint. Orbax, un espacio de nombres que proporciona bibliotecas de utilidades comunes para usuarios de JAX, proporciona estas capacidades.
Como se explica en la documentación de Orbax,
estas funciones están habilitadas
de forma predeterminada para los usuarios
de orbax.checkpoint.CheckpointManager
. El método save
al que se llama después de cada paso verifica automáticamente si hay un evento de mantenimiento inminente y, de ser así, guarda un punto de control, incluso si el número de paso no es un múltiplo de save_interval_steps
.
La documentación de GitHub
también ilustra cómo realizar la salida del entrenamiento después de guardar un
Punto de control automático, con una modificación en el código de usuario.