Esegui il codice JAX nelle sezioni TPU

Prima di eseguire i comandi in questo documento, assicurati di aver seguito le istruzioni riportate in Configurare un account e un progetto Cloud TPU.

Dopo aver eseguito il codice JAX su una singola scheda TPU, puoi eseguire lo scaling up del codice eseguendolo su una sezione TPU. Le sezioni TPU sono più schede TPU collegate tra loro tramite connessioni di rete ad alta velocità dedicate. Questo documento è un'introduzione all'esecuzione del codice JAX su sezioni TPU. Per informazioni più approfondite, consulta Utilizzare JAX in ambienti multi-host e multi-processo.

Creare un segmento Cloud TPU

  1. Crea alcune variabili di ambiente:

    export PROJECT_ID=your-project
    export ACCELERATOR_TYPE=v5litepod-32
    export ZONE=europe-west4-b
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export TPU_NAME=your-tpu-name

    Descrizioni delle variabili di ambiente

    PROJECT_ID
    Il tuo Google Cloud ID progetto.
    ACCELERATOR_TYPE
    Il tipo di acceleratore specifica la versione e le dimensioni della Cloud TPU che vuoi creare. Per ulteriori informazioni sui tipi di acceleratori supportati per ogni versione TPU, consulta Versioni TPU.
    ZONE
    La zona in cui prevedi di creare la tua Cloud TPU.
    RUNTIME_VERSION
    La versione del runtime Cloud TPU.
    TPU_NAME
    Il nome assegnato dall'utente alla tua Cloud TPU.
  2. Crea uno slice TPU utilizzando il comando gcloud. Ad esempio, per creare un slice v5litepod-32, utilizza il seguente comando:

    $ gcloud compute tpus tpu-vm create ${TPU_NAME}  \
    --zone=${ZONE} \
    --project=${PROJECT_ID} \
    --accelerator-type=${ACCELERATOR_TYPE}  \
    --version=${RUNTIME_VERSION} 

Installa JAX nella tua slice

Dopo aver creato la sezione TPU, devi installare JAX su tutti gli host della sezione. Puoi farlo utilizzando il comando gcloud compute tpus tpu-vm ssh con i parametri --worker=all e --commamnd.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --worker=all \
  --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

Esegui il codice JAX nella sezione

Per eseguire il codice JAX in una sezione TPU, devi eseguire il codice su ogni host della sezione TPU. La chiamata jax.device_count() smette di rispondere finché non viene invocata su ogni host nel segmento. L'esempio seguente illustra come eseguire un calcolo JAX su uno slice TPU.

Preparare il codice

È necessaria la versione gcloud >= 344.0.0 (per il comando scp). Utilizza gcloud --version per controllare la versione di gcloud ed eseguire gcloud components upgrade, se necessario.

Crea un file denominato example.py con il seguente codice:


import jax

# The total number of TPU cores in the slice
device_count = jax.device_count()

# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()

# The psum is performed over all mapped devices across the slice
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
    print('global device count:', jax.device_count())
    print('local device count:', jax.local_device_count())
    print('pmap result:', r)

Copia example.py in tutte le VM worker TPU del segmento

$ gcloud compute tpus tpu-vm scp ./example.py ${TPU_NAME}: \
  --worker=all \
  --zone=${ZONE} \
  --project=${PROJECT_ID}

Se non hai mai utilizzato il comando scp, potresti visualizzare un errore simile al seguente:

ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH
agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try
again.

Per risolvere l'errore, esegui il comando ssh-add visualizzato nel messaggio di errore e riavvialo.

Esegui il codice sul segmento

Avvia il programma example.py su ogni VM:

$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --worker=all \
  --command="python3 ./example.py"

Output (generato con una fetta v5litepod-32):

global device count: 32
local device count: 4
pmap result: [32. 32. 32. 32.]

Esegui la pulizia

Al termine dell'utilizzo della VM TPU, segui questi passaggi per ripulire le risorse.

  1. Elimina le risorse Cloud TPU e Compute Engine.

    $ gcloud compute tpus tpu-vm delete ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID}
  2. Verifica che le risorse siano state eliminate eseguendo gcloud compute tpus execution-groups list. L'eliminazione potrebbe richiedere alcuni minuti. L'output del seguente comando non deve includere nessuna delle risorse create in questo tutorial:

    $ gcloud compute tpus tpu-vm list --zone=${ZONE} \
    --project=${PROJECT_ID}