Addestra un modello utilizzando TPU v5e

Con un footprint più piccolo di 256 chip per pod, TPU v5e è ottimizzata per essere un prodotto di alto valore per l'addestramento, l'ottimizzazione e la pubblicazione di transformer, modelli di sintesi di immagini dal testo e reti neurali convoluzionali (CNN). Per saperne di più sull'utilizzo di Cloud TPU v5e per la pubblicazione, consulta Inferenza con v5e.

Per saperne di più sull'hardware e sulle configurazioni TPU v5e, consulta la sezione TPU v5e.

Inizia

Le sezioni seguenti descrivono come iniziare a utilizzare TPU v5e.

Quota per le richieste

Per utilizzare TPU v5e per l'addestramento, devi disporre di una quota. Esistono diversi tipi di quote per TPU on demand, TPU riservate e VM spot TPU. Se utilizzi la TPU v5e per l'inferenza, sono necessarie quote separate. Per ulteriori informazioni sulle quote, consulta Quote. Per richiedere la quota TPU v5e, contatta il team di vendite Cloud.

Crea un account e un progetto Google Cloud

Per utilizzare Cloud TPU, devi disporre di un account e di un progetto Google Cloud . Per maggiori informazioni, vedi Configurare un ambiente Cloud TPU.

Crea una Cloud TPU

La best practice consiste nel eseguire il provisioning di Cloud TPU v5e come risorse in coda utilizzando il comando queued-resource create. Per saperne di più, consulta Gestire le risorse in coda.

Puoi anche utilizzare l'API Create Node (gcloud compute tpus tpu-vm create) per eseguire il provisioning delle Cloud TPU v5e. Per saperne di più, consulta Gestire le risorse TPU.

Per ulteriori informazioni sulle configurazioni v5e disponibili per l'addestramento, consulta la sezione Tipi di Cloud TPU v5e per l'addestramento.

Configurazione del framework

Questa sezione descrive la procedura di configurazione generale per l'addestramento di modelli personalizzati utilizzando JAX o PyTorch con TPU v5e.

Per le istruzioni di configurazione dell'inferenza, vedi Introduzione all'inferenza v5e.

Definisci alcune variabili di ambiente:

export PROJECT_ID=your_project_ID
export ACCELERATOR_TYPE=v5litepod-16
export ZONE=us-west4-a
export TPU_NAME=your_tpu_name
export QUEUED_RESOURCE_ID=your_queued_resource_id

Configurazione di JAX

Se hai forme di slice maggiori di 8 chip, avrai più VM in uno slice. In questo caso, devi utilizzare il flag --worker=all per eseguire l'installazione su tutte le VM TPU in un unico passaggio senza utilizzare SSH per accedere a ciascuna separatamente:

gcloud compute tpus tpu-vm ssh ${TPU_NAME}  \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

Descrizioni dei flag dei comandi

Variabile Descrizione
TPU_NAME L'ID testo assegnato dall'utente della TPU creato quando viene allocata la richiesta di risorse in coda.
PROJECT_ID Google Cloud Nome progetto. Utilizza un progetto esistente o creane uno nuovo in Configura il progetto Google Cloud
ZONE Consulta il documento Regioni e zone TPU per le zone supportate.
worker La VM TPU che ha accesso alle TPU sottostanti.

Puoi eseguire il seguente comando per controllare il numero di dispositivi (gli output mostrati qui sono stati prodotti con una sezione v5litepod-16). Questo codice verifica che tutto sia installato correttamente controllando che JAX veda i TensorCore di Cloud TPU e possa eseguire operazioni di base:

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='python3 -c "import jax; print(jax.device_count()); print(jax.local_device_count())"'

L'output sarà simile al seguente:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16
4
16
4
16
4
16
4

jax.device_count() mostra il numero totale di chip nella sezione specificata. jax.local_device_count() indica il numero di chip accessibili da una singola VM in questa sezione.

# Check the number of chips in the given slice by summing the count of chips
# from all VMs through the
# jax.local_device_count() API call.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='python3 -c "import jax; xs=jax.numpy.ones(jax.local_device_count()); print(jax.pmap(lambda x: jax.lax.psum(x, \"i\"), axis_name=\"i\")(xs))"'

L'output sarà simile al seguente:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]

Prova i tutorial JAX in questo documento per iniziare l'addestramento v5e utilizzando JAX.

Configurazione per PyTorch

Tieni presente che v5e supporta solo il runtime PJRT e PyTorch 2.1+ utilizzerà PJRT come runtime predefinito per tutte le versioni di TPU.

Questa sezione descrive come iniziare a utilizzare PJRT su v5e con PyTorch/XLA con comandi per tutti i worker.

Installa le dipendenze

gcloud compute tpus tpu-vm ssh ${TPU_NAME}  \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
      sudo apt-get update -y
      sudo apt-get install libomp5 -y
      pip install mkl mkl-include
      pip install tf-nightly tb-nightly tbp-nightly
      pip install numpy
      sudo apt-get install libopenblas-dev -y
      pip install torch~=PYTORCH_VERSION torchvision torch_xla[tpu]~=PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'

Sostituisci PYTORCH_VERSION con la versione di PyTorch che vuoi utilizzare. PYTORCH_VERSION viene utilizzato per specificare la stessa versione per PyTorch/XLA. Si consiglia la versione 2.6.0.

Per ulteriori informazioni sulle versioni di PyTorch e PyTorch/XLA, consulta PyTorch - Get Started e PyTorch/XLA releases.

Per saperne di più sull'installazione di PyTorch/XLA, consulta Installazione di PyTorch/XLA.

Se ricevi un errore durante l'installazione delle ruote per torch, torch_xla o torchvision come pkg_resources.extern.packaging.requirements.InvalidRequirement: Expected end or semicolon (after name and no valid version specifier) torch==nightly+20230222, esegui il downgrade della versione con questo comando:

pip3 install setuptools==62.1.0

Eseguire uno script con PJRT

unset LD_PRELOAD

Di seguito è riportato un esempio che utilizza uno script Python per eseguire un calcolo su una VM v5e:

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
      export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.local/lib/
      export PJRT_DEVICE=TPU
      export PT_XLA_DEBUG=0
      export USE_TORCH=ON
      unset LD_PRELOAD
      export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
      python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"'

Viene generato un output simile al seguente:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
xla:0
tensor([[ 1.8611, -0.3114, -2.4208],
[-1.0731, 0.3422, 3.1445],
[ 0.5743, 0.2379, 1.1105]], device='xla:0')
xla:0
tensor([[ 1.8611, -0.3114, -2.4208],
[-1.0731, 0.3422, 3.1445],
[ 0.5743, 0.2379, 1.1105]], device='xla:0')

Prova i tutorial di PyTorch in questo documento per iniziare l'addestramento v5e utilizzando PyTorch.

Elimina la TPU e la risorsa in coda al termine della sessione. Per eliminare una risorsa in coda, elimina la sezione e poi la risorsa in coda in due passaggi:

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

Questi due passaggi possono essere utilizzati anche per rimuovere le richieste di risorse in coda con stato FAILED.

Esempi di JAX/FLAX

Le sezioni seguenti descrivono esempi di come addestrare modelli JAX e FLAX su TPU v5e.

Addestra ImageNet su v5e

Questo tutorial descrive come addestrare ImageNet su v5e utilizzando dati di input falsi. Se vuoi utilizzare dati reali, consulta il file README su GitHub.

Configura

  1. Crea variabili di ambiente:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west4-a
    export ACCELERATOR_TYPE=v5litepod-8
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    Descrizioni delle variabili di ambiente

    Variabile Descrizione
    PROJECT_ID L'ID progetto Google Cloud . Utilizza un progetto esistente o creane uno nuovo.
    TPU_NAME Il nome della TPU.
    ZONE La zona in cui creare la VM TPU. Per saperne di più sulle zone supportate, consulta Regioni e zone TPU.
    ACCELERATOR_TYPE Il tipo di acceleratore specifica la versione e le dimensioni della Cloud TPU che vuoi creare. Per saperne di più sui tipi di acceleratore supportati per ogni versione di TPU, consulta la sezione Versioni di TPU.
    RUNTIME_VERSION La versione software di Cloud TPU.
    SERVICE_ACCOUNT L'indirizzo email del tuo account di servizio. Puoi trovarlo andando alla pagina Service Accounts nella console Google Cloud .

    Ad esempio: tpu-service-account@PROJECT_ID.iam.gserviceaccount.com

    QUEUED_RESOURCE_ID L'ID testo assegnato dall'utente della richiesta di risorsa in coda.

  2. Crea una risorsa TPU:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --service-account=${SERVICE_ACCOUNT}
    

    Potrai connetterti tramite SSH alla tua VM TPU una volta che la risorsa in coda si trova nello stato ACTIVE:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    Quando QueuedResource è nello stato ACTIVE, l'output sarà simile al seguente:

     state: ACTIVE
    
  3. Installa la versione più recente di JAX e jaxlib:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  4. Clona il modello ImageNet e installa i requisiti corrispondenti:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command="git clone https://github.com/coolkp/flax.git && cd flax && git checkout pmap-orbax-conversion && git pull"
    
  5. Per generare dati fittizi, il modello ha bisogno di informazioni sulle dimensioni del set di dati. Questi dati possono essere raccolti dai metadati del set di dati ImageNet:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command="cd flax/examples/imagenet && pip install -r requirements-cloud-tpu.txt"
    

Addestra il modello

Una volta completati tutti i passaggi precedenti, puoi addestrare il modello.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command="cd flax/examples/imagenet && bash ../../tests/download_dataset_metadata.sh && JAX_PLATFORMS=tpu python imagenet_fake_data_benchmark.py"

Elimina la TPU e la risorsa in coda

Elimina la TPU e la risorsa in coda al termine della sessione.

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

Modelli FLAX Hugging Face

I modelli Hugging Face implementati in FLAX funzionano immediatamente su Cloud TPU v5e. Questa sezione fornisce istruzioni per l'esecuzione di modelli popolari.

Addestra ViT su Imagenette

Questo tutorial mostra come addestrare il modello Vision Transformer (ViT) di Hugging Face utilizzando il set di dati Imagenette di Fast AI su Cloud TPU v5e.

Il modello ViT è stato il primo a eseguire l'addestramento di un codificatore Transformer su ImageNet con risultati eccellenti rispetto alle reti convoluzionali. Per maggiori informazioni, consulta le seguenti risorse:

Configura

  1. Crea variabili di ambiente:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west4-a
    export ACCELERATOR_TYPE=v5litepod-16
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    Descrizioni delle variabili di ambiente

    Variabile Descrizione
    PROJECT_ID L'ID progetto Google Cloud . Utilizza un progetto esistente o creane uno nuovo.
    TPU_NAME Il nome della TPU.
    ZONE La zona in cui creare la VM TPU. Per saperne di più sulle zone supportate, consulta Regioni e zone TPU.
    ACCELERATOR_TYPE Il tipo di acceleratore specifica la versione e le dimensioni della Cloud TPU che vuoi creare. Per saperne di più sui tipi di acceleratore supportati per ogni versione di TPU, consulta la sezione Versioni di TPU.
    RUNTIME_VERSION La versione software di Cloud TPU.
    SERVICE_ACCOUNT L'indirizzo email del tuo account di servizio. Puoi trovarlo andando alla pagina Service Accounts nella console Google Cloud .

    Ad esempio: tpu-service-account@PROJECT_ID.iam.gserviceaccount.com

    QUEUED_RESOURCE_ID L'ID testo assegnato dall'utente della richiesta di risorsa in coda.

  2. Crea una risorsa TPU:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --service-account=${SERVICE_ACCOUNT}
    

    Potrai accedere alla tua VM TPU tramite SSH una volta che la risorsa in coda si trova nello stato ACTIVE:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    Quando la risorsa in coda si trova nello stato ACTIVE, l'output sarà simile al seguente:

     state: ACTIVE
    
  3. Installa JAX e la relativa libreria:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  4. Scarica il repository di Hugging Face e installa i requisiti:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='git clone https://github.com/huggingface/transformers.git && cd transformers && pip install . && pip install -r examples/flax/_tests_requirements.txt && pip install --upgrade huggingface-hub urllib3 zipp && pip install tensorflow==2.19 && sed -i 's/torchvision==0.12.0+cpu/torchvision==0.22.1/' examples/flax/vision/requirements.txt && pip install -r examples/flax/vision/requirements.txt && pip install tf-keras'
    
  5. Scarica il set di dati Imagenette:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='cd transformers && wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz && tar -xvzf imagenette2.tgz'
    

Addestra il modello

Addestra il modello con un buffer pre-mappato a 4 GB.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='cd transformers && JAX_PLATFORMS=tpu python3 examples/flax/vision/run_image_classification.py --train_dir "imagenette2/train" --validation_dir "imagenette2/val" --output_dir "./vit-imagenette" --learning_rate 1e-3 --preprocessing_num_workers 32 --per_device_train_batch_size 8 --per_device_eval_batch_size 8 --model_name_or_path google/vit-base-patch16-224-in21k --num_train_epochs 3'

Elimina la TPU e la risorsa in coda

Elimina la TPU e la risorsa in coda al termine della sessione.

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

Risultati del benchmarking di ViT

Lo script di addestramento è stato eseguito su v5litepod-4, v5litepod-16 e v5litepod-64. La tabella seguente mostra i throughput con diversi tipi di acceleratore.

Tipo di acceleratore v5litepod-4 v5litepod-16 v5litepod-64
Epoca 3 3 3
Dimensione batch globale 32 128 512
Throughput (esempi/sec) 263,40 429,34 470.71

Train Diffusion sui Pokémon

Questo tutorial mostra come addestrare il modello Stable Diffusion da HuggingFace utilizzando il set di dati Pokémon su Cloud TPU v5e.

Il modello Stable Diffusion è un modello latente da testo a immagine che genera immagini fotorealistiche da qualsiasi input di testo. Per maggiori informazioni, consulta le seguenti risorse:

Configura

  1. Imposta una variabile di ambiente per il nome del bucket di archiviazione:

    export GCS_BUCKET_NAME=your_bucket_name
  2. Configura un bucket di archiviazione per l'output del modello:

    gcloud storage buckets create gs://GCS_BUCKET_NAME \
        --project=your_project \
        --location=us-west1
  3. Crea variabili di ambiente:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west1-c
    export ACCELERATOR_TYPE=v5litepod-16
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    Descrizioni delle variabili di ambiente

    Variabile Descrizione
    PROJECT_ID L'ID progetto Google Cloud . Utilizza un progetto esistente o creane uno nuovo.
    TPU_NAME Il nome della TPU.
    ZONE La zona in cui creare la VM TPU. Per saperne di più sulle zone supportate, consulta Regioni e zone TPU.
    ACCELERATOR_TYPE Il tipo di acceleratore specifica la versione e le dimensioni della Cloud TPU che vuoi creare. Per saperne di più sui tipi di acceleratore supportati per ogni versione di TPU, consulta la sezione Versioni di TPU.
    RUNTIME_VERSION La versione software di Cloud TPU.
    SERVICE_ACCOUNT L'indirizzo email del tuo account di servizio. Puoi trovarlo andando alla pagina Service Accounts nella console Google Cloud .

    Ad esempio: tpu-service-account@PROJECT_ID.iam.gserviceaccount.com

    QUEUED_RESOURCE_ID L'ID testo assegnato dall'utente della richiesta di risorsa in coda.

  4. Crea una risorsa TPU:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --service-account=${SERVICE_ACCOUNT}
    

    Potrai utilizzare SSH per accedere alla VM TPU una volta che la risorsa in coda si trova nello stato ACTIVE:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    Quando la risorsa in coda si trova nello stato ACTIVE, l'output sarà simile al seguente:

     state: ACTIVE
    
  5. Installa JAX e la relativa libreria.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='pip install "jax[tpu]==0.4.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  6. Scarica il repository di HuggingFace e installa i requisiti.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
         --project=${PROJECT_ID} \
         --zone=${ZONE} \
         --worker=all \
         --command='git clone https://github.com/RissyRan/diffusers.git && cd diffusers && pip install . && pip install -U -r examples/text_to_image/requirements_flax.txt && pip install tensorflow==2.17.1 clu && pip install tensorboard==2.17.1'
    

Addestra il modello

Addestra il modello con un buffer pre-mappato a 4 GB.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} --zone=${ZONE} --project=${PROJECT_ID} --worker=all --command="
    git clone https://github.com/google/maxdiffusion
    cd maxdiffusion
    pip3 install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    pip3 install -r requirements.txt
    pip3 install .
    pip3 install gcsfs
    export LIBTPU_INIT_ARGS=''
    python -m src.maxdiffusion.train src/maxdiffusion/configs/base_2_base.yml run_name=my_run \
    jax_cache_dir=gs://${GCS_BUCKET_NAME} activations_dtype=bfloat16 weights_dtype=bfloat16 \
    per_device_batch_size=1 precision=DEFAULT dataset_save_location=gs://${GCS_BUCKET_NAME} \
    output_dir=gs://${GCS_BUCKET_NAME}/ attention=flash"

Esegui la pulizia

Elimina la TPU, la risorsa in coda e il bucket Cloud Storage al termine della sessione.

  1. Elimina la TPU:

    gcloud compute tpus tpu-vm delete ${TPU_NAME} \
        --project=${PROJECT_ID} \
        --zone=${ZONE} \
        --quiet
    
  2. Elimina la risorsa in coda:

    gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
        --project=${PROJECT_ID} \
        --zone=${ZONE} \
        --quiet
    
  3. Elimina il bucket Cloud Storage:

    gcloud storage rm -r gs://${GCS_BUCKET_NAME}
    

Risultati del benchmarking per la diffusione

Lo script di addestramento è stato eseguito su v5litepod-4, v5litepod-16 e v5litepod-64. La tabella seguente mostra i throughput.

Tipo di acceleratore v5litepod-4 v5litepod-16 v5litepod-64
Passaggio di addestramento 1500 1500 1500
Dimensione batch globale 32 64 128
Throughput (esempi/sec) 36,53 43,71 49,36

PyTorch/XLA

Le sezioni seguenti descrivono esempi di come addestrare modelli PyTorch/XLA su TPU v5e.

Addestra ResNet utilizzando il runtime PJRT

PyTorch/XLA sta eseguendo la migrazione da XRT a PjRt a partire da PyTorch 2.0+. Ecco le istruzioni aggiornate per configurare v5e per i carichi di lavoro di addestramento PyTorch/XLA.

Configura
  1. Crea variabili di ambiente:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west4-a
    export ACCELERATOR_TYPE=v5litepod-16
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    Descrizioni delle variabili di ambiente

    Variabile Descrizione
    PROJECT_ID L'ID progetto Google Cloud . Utilizza un progetto esistente o creane uno nuovo.
    TPU_NAME Il nome della TPU.
    ZONE La zona in cui creare la VM TPU. Per saperne di più sulle zone supportate, consulta Regioni e zone TPU.
    ACCELERATOR_TYPE Il tipo di acceleratore specifica la versione e le dimensioni della Cloud TPU che vuoi creare. Per saperne di più sui tipi di acceleratore supportati per ogni versione di TPU, consulta la sezione Versioni di TPU.
    RUNTIME_VERSION La versione software di Cloud TPU.
    SERVICE_ACCOUNT L'indirizzo email del tuo account di servizio. Puoi trovarlo andando alla pagina Service Accounts nella console Google Cloud .

    Ad esempio: tpu-service-account@PROJECT_ID.iam.gserviceaccount.com

    QUEUED_RESOURCE_ID L'ID testo assegnato dall'utente della richiesta di risorsa in coda.

  2. Crea una risorsa TPU:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --service-account=${SERVICE_ACCOUNT}
    

    Potrai accedere tramite SSH alla tua VM TPU una volta che QueuedResource è nello stato ACTIVE:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    Quando la risorsa in coda si trova nello stato ACTIVE, l'output sarà simile al seguente:

     state: ACTIVE
    
  3. Installa le dipendenze specifiche di Torch/XLA

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --project=${PROJECT_ID} \
      --zone=${ZONE} \
      --worker=all \
      --command='
         sudo apt-get update -y
         sudo apt-get install libomp5 -y
         pip3 install mkl mkl-include
         pip3 install tf-nightly tb-nightly tbp-nightly
         pip3 install numpy
         sudo apt-get install libopenblas-dev -y
         pip install torch==PYTORCH_VERSION torchvision torch_xla[tpu]==PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'

    Sostituisci PYTORCH_VERSION con la versione di PyTorch che vuoi utilizzare. PYTORCH_VERSION viene utilizzato per specificare la stessa versione per PyTorch/XLA. Si consiglia la versione 2.6.0.

    Per ulteriori informazioni sulle versioni di PyTorch e PyTorch/XLA, consulta PyTorch - Get Started e PyTorch/XLA releases.

    Per saperne di più sull'installazione di PyTorch/XLA, consulta Installazione di PyTorch/XLA.

Addestra il modello ResNet
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
      date
      export PJRT_DEVICE=TPU
      export PT_XLA_DEBUG=0
      export USE_TORCH=ON
      export XLA_USE_BF16=1
      export LIBTPU_INIT_ARGS=--xla_jf_auto_cross_replica_sharding
      export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
      export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
      git clone https://github.com/pytorch/xla.git
      cd xla/
      git checkout release-r2.6
      python3 test/test_train_mp_imagenet.py --model=resnet50  --fake_data --num_epochs=1 —num_workers=16  --log_steps=300 --batch_size=64 --profile'

Elimina la TPU e la risorsa in coda

Elimina la TPU e la risorsa in coda al termine della sessione.

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet
Risultato benchmark

La tabella seguente mostra i throughput di riferimento.

Tipo di acceleratore Throughput (esempi/secondo)
v5litepod-4 4240 ex/s
v5litepod-16 10.810 ex/s
v5litepod-64 46.154 ex/s

Addestra ViT su v5e

Questo tutorial spiega come eseguire VIT su v5e utilizzando il repository HuggingFace su PyTorch/XLA sul set di dati cifar10.

Configura

  1. Crea variabili di ambiente:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west4-a
    export ACCELERATOR_TYPE=v5litepod-16
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    Descrizioni delle variabili di ambiente

    Variabile Descrizione
    PROJECT_ID L'ID progetto Google Cloud . Utilizza un progetto esistente o creane uno nuovo.
    TPU_NAME Il nome della TPU.
    ZONE La zona in cui creare la VM TPU. Per saperne di più sulle zone supportate, consulta Regioni e zone TPU.
    ACCELERATOR_TYPE Il tipo di acceleratore specifica la versione e le dimensioni della Cloud TPU che vuoi creare. Per saperne di più sui tipi di acceleratore supportati per ogni versione di TPU, consulta la sezione Versioni di TPU.
    RUNTIME_VERSION La versione software di Cloud TPU.
    SERVICE_ACCOUNT L'indirizzo email del tuo account di servizio. Puoi trovarlo andando alla pagina Service Accounts nella console Google Cloud .

    Ad esempio: tpu-service-account@PROJECT_ID.iam.gserviceaccount.com

    QUEUED_RESOURCE_ID L'ID testo assegnato dall'utente della richiesta di risorsa in coda.

  2. Crea una risorsa TPU:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --service-account=${SERVICE_ACCOUNT}
    

    Potrai accedere tramite SSH alla tua VM TPU una volta che QueuedResource si trova nello stato ACTIVE:

     gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    Quando la risorsa in coda si trova nello stato ACTIVE, l'output sarà simile al seguente:

     state: ACTIVE
    
  3. Installa le dipendenze di PyTorch/XLA

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='
      sudo apt-get update -y
      sudo apt-get install libomp5 -y
      pip3 install mkl mkl-include
      pip3 install tf-nightly tb-nightly tbp-nightly
      pip3 install numpy
      sudo apt-get install libopenblas-dev -y
      pip install torch==PYTORCH_VERSION torchvision torch_xla[tpu]==PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
      pip install jax==0.4.38 jaxlib==0.4.38 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/

    Sostituisci PYTORCH_VERSION con la versione di PyTorch che vuoi utilizzare. PYTORCH_VERSION viene utilizzato per specificare la stessa versione per PyTorch/XLA. Si consiglia la versione 2.6.0.

    Per ulteriori informazioni sulle versioni di PyTorch e PyTorch/XLA, consulta PyTorch - Get Started e PyTorch/XLA releases.

    Per saperne di più sull'installazione di PyTorch/XLA, consulta Installazione di PyTorch/XLA.

  4. Scarica il repository di HuggingFace e installa i requisiti.

       gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command="
          git clone https://github.com/suexu1025/transformers.git vittransformers; \
          cd vittransformers; \
          pip3 install .; \
          pip3 install datasets; \
          wget https://github.com/pytorch/xla/blob/master/scripts/capture_profile.py"
    

Addestra il modello

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
      export PJRT_DEVICE=TPU
      export PT_XLA_DEBUG=0
      export USE_TORCH=ON
      export TF_CPP_MIN_LOG_LEVEL=0
      export XLA_USE_BF16=1
      export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
      export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
      cd vittransformers
      python3 -u examples/pytorch/xla_spawn.py --num_cores 4 examples/pytorch/image-pretraining/run_mae.py --dataset_name=cifar10 \
      --remove_unused_columns=False \
      --label_names=pixel_values \
      --mask_ratio=0.75 \
      --norm_pix_loss=True \
      --do_train=true \
      --do_eval=true \
      --base_learning_rate=1.5e-4 \
      --lr_scheduler_type=cosine \
      --weight_decay=0.05 \
      --num_train_epochs=3 \
      --warmup_ratio=0.05 \
      --per_device_train_batch_size=8 \
      --per_device_eval_batch_size=8 \
      --logging_strategy=steps \
      --logging_steps=30 \
      --evaluation_strategy=epoch \
      --save_strategy=epoch \
      --load_best_model_at_end=True \
      --save_total_limit=3 \
      --seed=1337 \
      --output_dir=MAE \
      --overwrite_output_dir=true \
      --logging_dir=./tensorboard-metrics \
      --tpu_metrics_debug=true'

Elimina la TPU e la risorsa in coda

Elimina la TPU e la risorsa in coda al termine della sessione.

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

Risultato benchmark

La tabella seguente mostra i throughput di benchmark per i diversi tipi di acceleratori.

v5litepod-4 v5litepod-16 v5litepod-64
Epoca 3 3 3
Dimensione batch globale 32 128 512
Throughput (esempi/sec) 201 657 2844