Preservare i progressi dell'addestramento utilizzando il controllo automatico
In passato, quando una VM TPU richiedeva manutenzione, la procedura veniva avviata immediatamente, senza lasciare tempo agli utenti di eseguire azioni che preservano i progressi, come il salvataggio di un checkpoint. Come mostrato nella Figura 1(a).
Figura 1. Illustrazione della funzionalità Autocheckpoint: (a) senza Autocheckpoint, l'avanzamento dell'addestramento dall'ultimo checkpoint viene perso quando è imminente un evento di manutenzione. (b) Con il controllo automatico, i progressi dell'addestramento dall'ultimo controllo possono essere preservati in caso di evento di manutenzione imminente.
Puoi utilizzare il controllo automatico (Figura 1(b)) per preservare i progressi dell'addestramento configurando il codice in modo da salvare un controllo non pianificato quando si verifica un evento di manutenzione. Quando si verifica un evento di manutenzione, i progressi dall'ultimo checkpoint vengono salvati automaticamente. La funzionalità funziona sia su singole sezioni che su Multislice.
La funzionalità Autocheckpoint funziona con i framework che possono acquisire indicatori SIGTERM e successivamente salvare un checkpoint. I framework supportati include:
Utilizzo di Controllo automatico
La funzionalità di controllo automatico è disattivata per impostazione predefinita. Quando crei un TPU o richiedi una risorsa in coda, puoi attivare il controllo automatico aggiungendo il flag --autocheckpoint-enabled
durante il provisioning del TPU.
Con la funzionalità attivata, Cloud TPU
esegue i seguenti passaggi dopo aver ricevuto la notifica di un
evento di manutenzione:
- Acquisisci il segnale SIGTERM inviato al processo che utilizza il dispositivo TPU
- Attendi che il processo termini o che siano trascorsi 5 minuti, a seconda dell'evento che si verifica per primo
- Eseguire la manutenzione dei segmenti interessati
L'infrastruttura utilizzata da Autocheckpoint è indipendente dal framework ML. Qualsiasi framework ML può supportare il controllo automatico se è in grado di acquisire l'indicatore SIGTERM e avviare un processo di controllo.
Nel codice dell'applicazione, devi attivare le funzionalità di controllo automatico fornite dal framework ML. In Pax, ad esempio, questo significa attivare i flag della riga di comando al momento dell'avvio dell'addestramento. Per ulteriori informazioni, consulta la guida rapida all'utilizzo di Autocheckpoint con Pax. Dietro le quinte, i framework salvano un controllo non pianificato quando viene ricevuto un segnale SIGTERM e la VM TPU interessata viene sottoposta a manutenzione quando la TPU non è più in uso.
Guida rapida: controllo automatico con MaxText
MaxText è un LLM open source ad alte prestazioni, scalabile in modo arbitrario e ben testato, scritto in puro Python/JAX e rivolto alle Cloud TPU. MaxText contiene tutta la configurazione necessaria per utilizzare la funzionalità Punto di controllo automatico.
Il file README
MaxText descrive
due modi per eseguire MaxText su larga scala:
- Utilizzo di
multihost_runner.py
, consigliato per la sperimentazione - Utilizzo di
multihost_job.py
, consigliato per la produzione
Quando utilizzi multihost_runner.py
, attiva il controllo automatico impostando il flag autocheckpoint-enabled
durante il provisioning della risorsa in coda.
Quando utilizzi
multihost_job.py
, abilita il controllo automatico specificando il
ENABLE_AUTOCHECKPOINT=true
flag a riga di comando al momento dell'avvio del job.
Guida rapida: controllo automatico con Pax su un singolo slice
Questa sezione fornisce un esempio di come configurare e utilizzare il controllo automatico con Pax su un singolo slice. Con la configurazione appropriata:
- Un checkpoint verrà salvato quando si verifica un evento di manutenzione.
- Cloud TPU eseguirà la manutenzione delle VM TPU interessate dopo la salvataggio del checkpoint.
- Al termine della manutenzione di Cloud TPU, puoi utilizzare la VM TPU come di consueto.
Utilizza il flag
autocheckpoint-enabled
quando crei la VM TPU o richiedi una risorsa in coda.Ad esempio:
Imposta le variabili di ambiente:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=zone-you-want-to-use export ACCELERATOR_TYPE=your-accelerator-type export RUNTIME_VERSION=tpu-ubuntu2204-base
Descrizioni delle variabili di ambiente
Variabile Descrizione PROJECT_ID
Il tuo Google Cloud ID progetto. Utilizza un progetto esistente o creane uno nuovo. TPU_NAME
Il nome della TPU. ZONE
La zona in cui creare la VM TPU. Per saperne di più sulle zone supportate, consulta Regioni e zone TPU. ACCELERATOR_TYPE
Il tipo di acceleratore specifica la versione e le dimensioni della Cloud TPU che vuoi creare. Per ulteriori informazioni sui tipi di acceleratori supportati per ogni versione di TPU, consulta Versioni TPU. RUNTIME_VERSION
La versione software di Cloud TPU. Imposta l'ID progetto e la zona nella configurazione attiva:
gcloud config set project $PROJECT_ID gcloud config set compute/zone $ZONE
Crea una TPU:
gcloud alpha compute tpus tpu-vm create $TPU_NAME \ --accelerator-type $ACCELERATOR_TYPE \ --version $RUNTIME_VERSION \ --autocheckpoint-enabled
Connettiti alla TPU tramite SSH:
gcloud compute tpus tpu-vm ssh $TPU_NAME
Installare Pax su un singolo slice
La funzionalità di controllo automatico funziona su Pax 1.1.0 e versioni successive. Nella VM TPU, installa
jax[tpu]
e la versione più recente dipaxml
:pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Configura il modello
LmCloudSpmd2B
. Prima di eseguire lo script di addestramento, cambiaICI_MESH_SHAPE
in[1, 8, 1]
:@experiment_registry.register class LmCloudSpmd2B(LmCloudSpmd): """SPMD model with 2B params. Global batch size = 2 * 2 * 1 * 32 = 128 """ PERCORE_BATCH_SIZE = 8 NUM_LAYERS = 18 MODEL_DIMS = 3072 HIDDEN_DIMS = MODEL_DIMS * 4 CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING ICI_MESH_SHAPE = [1, 8, 1]
Avvia l'addestramento con la configurazione appropriata.
L'esempio seguente mostra come configurare il modello
LmCloudSpmd2B
per salvare i checkpoint attivati da Autocheckpoint in un bucket Cloud Storage. Sostituisci your-storage-bucket con il nome di un bucket esistente o creane uno nuovo.export JOB_LOG_DIR=gs://your-storage-bucket { python3 .local/lib/python3.10/site-packages/paxml/main.py \ --jax_fully_async_checkpoint=1 \ --exit_after_ondemand_checkpoint=1 \ --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \ --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt
Tieni presente i due flag passati al comando:
jax_fully_async_checkpoint
: se questo flag è attivo, verrà utilizzatoorbax.checkpoint.AsyncCheckpointer
. La classeAsyncCheckpointer
salva automaticamente un checkpoint quando lo script di addestramento riceve un segnale SIGTERM.exit_after_ondemand_checkpoint
: se questo flag è attivo, il processo TPU esce dopo il salvataggio corretto del checkpoint automatico, attivando l'esecuzione immediata della manutenzione. Se non utilizzi questo flag, l'addestramento continuerà dopo il salvataggio del checkpoint e Cloud TPU attenderà che si verifichi un timeout (5 minuti) prima di eseguire la manutenzione richiesta.
Controllo automatico con Orbax
La funzionalità di controllo automatico non è limitata a MaxText o Pax. Qualsiasi framework che può acquisire l'indicatore SIGTERM e avviare un procedura di checkpointing funziona con l'infrastruttura fornita da Autocheckpoint. Orbax, uno spazio dei nomi che fornisce librerie di utilità comuni per gli utenti di JAX, offre queste funzionalità.
Come spiegato nella documentazione di Orbax, queste funzionalità sono abilitate per impostazione predefinita per gli utenti di orbax.checkpoint.CheckpointManager
. Il metodo save
chiamato dopo ogni passaggio controlla automaticamente se è imminente un evento di manutenzione e, in questo caso, salva un checkpoint anche se il numero del passaggio non è un multiplo di save_interval_steps
.
La documentazione di GitHub illustra anche come far uscire l'addestramento dopo aver salvato un checkpoint automatico, con una modifica nel codice utente.