Preservar o progresso do treinamento usando o Autocheckpoint

Historicamente, quando uma VM TPU requer manutenção, o procedimento é iniciado imediatamente, sem deixar tempo para que os usuários realizem ações de preservação de progresso, como salvar um ponto de verificação. Isso é mostrado na Figura 1(a).

Diagrama mostrando o impacto da manutenção do host com e sem checkpoint automático

Figura 1. Ilustração do recurso de checkpoint automático: (a) Sem o checkpoint automático, o progresso do treinamento do último checkpoint é perdido quando há um evento de manutenção. (b) Com o ponto de verificação automático, o progresso do treinamento desde o último ponto de verificação pode ser preservado quando há um evento de manutenção.

Você pode usar o Autocheckpoint (Figura 1(b)) para preservar o progresso do treinamento, configurando o código para salvar um ponto de controle não programado quando um evento de manutenção ocorrer. Quando um evento de manutenção ocorre, o progresso desde o último ponto de verificação é salvo automaticamente. O recurso funciona em fatias únicas e em fatias múltiplas.

O recurso de checkpoint automático funciona com frameworks que podem capturar sinais SIGTERM e, em seguida, salvar um checkpoint. Os frameworks compatíveis incluem:

Como usar o checkpoint automático

O recurso de verificação automática fica desativado por padrão. Ao criar um TPU ou solicitar um recurso enfileirado, é possível ativar o ponto de verificação automático adicionando a flag --autocheckpoint-enabled ao provisionar o TPU. Com o recurso ativado, o Cloud TPU executa as etapas a seguir quando recebe a notificação de um evento de manutenção:

  1. Capturar o sinal SIGTERM enviado para o processo usando o dispositivo TPU
  2. Aguarde até que o processo seja encerrado ou 5 minutos tenham se passado, o que ocorrer primeiro.
  3. Realizar a manutenção das fatias afetadas

A infraestrutura usada pelo Autocheckpoint é independente do framework de ML. Qualquer framework de ML pode oferecer suporte ao autocheckpoint se puder capturar o sinal SIGTERM e iniciar um processo de checkpoint.

No código do aplicativo, é necessário ativar os recursos de verificação automática fornecidos pelo framework de ML. No Pax, por exemplo, isso significa ativar as flags de linha de comando ao iniciar o treinamento. Para mais informações, consulte o Guia de início rápido do Checkpoint automático com Pax. Nos bastidores, os frameworks salvam um ponto de verificação não programado quando um sinal SIGTERM é recebido, e a VM de TPU afetada passa por manutenção quando a TPU não está mais em uso.

Guia de início rápido: checkpoint automático com MaxText

O MaxText é um LLM de alto desempenho, escalonável de forma arbitrária, de código aberto e bem testado, escrito em Python/JAX puro para Cloud TPUs. O MaxText contém toda a configuração necessária para usar o recurso de verificação automática.

O arquivo README MaxText descreve duas maneiras de executar o MaxText em escala:

Ao usar multihost_runner.py, ative o ponto de controle automático definindo a flag autocheckpoint-enabled ao provisionar o recurso na fila.

Ao usar multihost_job.py, ative o ponto de verificação automático especificando a flag de linha de comando ENABLE_AUTOCHECKPOINT=true ao iniciar o job.

Guia de início rápido: verificação automática com Pax em uma única fatia

Esta seção mostra um exemplo de como configurar e usar o Autocheckpoint com Pax em uma única fatia. Com a configuração adequada:

  • Um ponto de controle será salvo quando um evento de manutenção ocorrer.
  • A Cloud TPU vai realizar a manutenção nas VMs afetadas depois que o ponto de verificação for salvo.
  • Quando a manutenção da Cloud TPU for concluída, você poderá usar a VM TPU normalmente.
  1. Use a flag autocheckpoint-enabled ao criar a VM TPU ou solicitar um recurso na fila.

    Exemplo:

    export PROJECT=your-gcp-project-name
    export ZONE=zone-you-want-to-use
    export NODE_ID=your-node-id
    export ACCELERATOR_TYPE=your-accelerator-type
    gcloud config set project $PROJECT
    gcloud config set compute/zone $ZONE
    gcloud alpha compute tpus tpu-vm create $NODE_ID \
        --accelerator-type $ACCELERATOR_TYPE \
        --version tpu-ubuntu2204-base \
        --autocheckpoint-enabled
  2. Conecte-se à TPU usando SSH:

    gcloud compute tpus tpu-vm ssh $NODE_ID
    
  3. Instalar o Pax em uma única fatia

    O recurso de verificação automática funciona nas versões 1.1.0 e mais recentes do Pax. Na VM da TPU, instale o jax[tpu] e o paxml mais recente:

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  4. Configure o modelo LmCloudSpmd2B. Antes de executar o script de treinamento, mude ICI_MESH_SHAPE para [1, 8, 1]:

    @experiment_registry.register
    class LmCloudSpmd2B(LmCloudSpmd):
    
        """SPMD model with 2B params.
    
        Global batch size = 2 * 2 * 1 * 32 = 128
        """
        PERCORE_BATCH_SIZE = 8
    
        NUM_LAYERS = 18
        MODEL_DIMS = 3072
        HIDDEN_DIMS = MODEL_DIMS * 4
    
        CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING
        ICI_MESH_SHAPE = [1, 8, 1]
  5. Inicie o treinamento com a configuração adequada.

    O exemplo a seguir mostra como configurar o modelo LmCloudSpmd2B para salvar pontos de verificação acionados pelo Autocheckpoint em um bucket do Cloud Storage. Substitua your-storage-bucket pelo nome de um bucket existente ou crie um novo bucket.

    export JOB_LOG_DIR=gs://your-storage-bucket
    
    { python3 .local/lib/python3.10/site-packages/paxml/main.py \
        --jax_fully_async_checkpoint=1 \
        --exit_after_ondemand_checkpoint=1 \
        --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \
        --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt

    Observe as duas flags transmitidas para o comando:

    • jax_fully_async_checkpoint: com essa flag ativada, orbax.checkpoint.AsyncCheckpointer será usado. A classe AsyncCheckpointer salva automaticamente um ponto de controle quando o script de treinamento recebe um sinal SIGTERM.
    • exit_after_ondemand_checkpoint: com essa flag ativada, o processo da TPU é encerrado depois que o checkpoint automático é salvo, o que aciona a manutenção para ser realizada imediatamente. Se você não usar essa flag, o treinamento vai continuar depois que o checkpoint for salvo e o Cloud TPU vai esperar que um tempo limite ocorra (5 minutos) antes de realizar a manutenção necessária.

Ponto de controle automático com Orbax

O recurso de verificação automática não é limitado a MaxText ou Pax. Qualquer framework que possa capturar o sinal SIGTERM e iniciar um processo de verificação funciona com a infraestrutura fornecida pelo Autocheckpoint. O Orbax, um namespace que oferece bibliotecas de utilitários comuns para usuários do JAX, oferece esses recursos.

Conforme explicado na documentação do Orbex, esses recursos são ativados por padrão para os usuários do orbax.checkpoint.CheckpointManager. O método save que é chamado após cada etapa verifica automaticamente se um evento de manutenção está iminente e, se for o caso, salva um ponto de controle mesmo que o número da etapa não seja um múltiplo de save_interval_steps. A documentação do GitHub (link em inglês) também ilustra como fazer com que o treinamento saia após salvar um Autocheckpoint, com uma modificação no código do usuário.