Addestra un modello utilizzando TPU v5e
Con un footprint più piccolo di 256 chip per pod, TPU v5e è ottimizzata per essere un prodotto di alto valore per l'addestramento, l'ottimizzazione e la pubblicazione di transformer, modelli di sintesi di immagini dal testo e reti neurali convoluzionali (CNN). Per saperne di più sull'utilizzo di Cloud TPU v5e per la pubblicazione, consulta Inferenza con v5e.
Per saperne di più sull'hardware e sulle configurazioni TPU v5e, consulta la sezione TPU v5e.
Inizia
Le sezioni seguenti descrivono come iniziare a utilizzare TPU v5e.
Quota per le richieste
Per utilizzare TPU v5e per l'addestramento, devi disporre di una quota. Esistono diversi tipi di quote per TPU on demand, TPU riservate e VM spot TPU. Se utilizzi la TPU v5e per l'inferenza, sono necessarie quote separate. Per ulteriori informazioni sulle quote, consulta Quote. Per richiedere la quota TPU v5e, contatta il team di vendite Cloud.
Crea un account e un progetto Google Cloud
Per utilizzare Cloud TPU, devi disporre di un account e di un progetto Google Cloud . Per maggiori informazioni, vedi Configurare un ambiente Cloud TPU.
Crea una Cloud TPU
La best practice consiste nel eseguire il provisioning di Cloud TPU v5e come risorse in coda
utilizzando il comando queued-resource create
. Per saperne di più, consulta
Gestire le risorse in coda.
Puoi anche utilizzare l'API Create Node (gcloud compute tpus tpu-vm create
) per eseguire il provisioning delle Cloud TPU v5e. Per saperne di più, consulta Gestire le risorse TPU.
Per ulteriori informazioni sulle configurazioni v5e disponibili per l'addestramento, consulta la sezione Tipi di Cloud TPU v5e per l'addestramento.
Configurazione del framework
Questa sezione descrive la procedura di configurazione generale per l'addestramento di modelli personalizzati utilizzando JAX o PyTorch con TPU v5e.
Per le istruzioni di configurazione dell'inferenza, vedi Introduzione all'inferenza v5e.
Definisci alcune variabili di ambiente:
export PROJECT_ID=your_project_ID export ACCELERATOR_TYPE=v5litepod-16 export ZONE=us-west4-a export TPU_NAME=your_tpu_name export QUEUED_RESOURCE_ID=your_queued_resource_id
Configurazione di JAX
Se hai forme di slice maggiori di 8 chip, avrai più VM in uno
slice. In questo caso, devi utilizzare il flag --worker=all
per eseguire l'installazione su tutte le VM TPU in un unico passaggio senza utilizzare SSH per accedere a ciascuna separatamente:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Descrizioni dei flag dei comandi
Variabile | Descrizione |
TPU_NAME | L'ID testo assegnato dall'utente della TPU creato quando viene allocata la richiesta di risorse in coda. |
PROJECT_ID | Google Cloud Nome progetto. Utilizza un progetto esistente o creane uno nuovo in Configura il progetto Google Cloud |
ZONE | Consulta il documento Regioni e zone TPU per le zone supportate. |
worker | La VM TPU che ha accesso alle TPU sottostanti. |
Puoi eseguire il seguente comando per controllare il numero di dispositivi (gli output mostrati qui sono stati prodotti con una sezione v5litepod-16). Questo codice verifica che tutto sia installato correttamente controllando che JAX veda i TensorCore di Cloud TPU e possa eseguire operazioni di base:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='python3 -c "import jax; print(jax.device_count()); print(jax.local_device_count())"'
L'output sarà simile al seguente:
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16
4
16
4
16
4
16
4
jax.device_count()
mostra il numero totale di chip nella sezione specificata.
jax.local_device_count()
indica il numero di chip accessibili da una singola
VM in questa sezione.
# Check the number of chips in the given slice by summing the count of chips
# from all VMs through the
# jax.local_device_count() API call.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='python3 -c "import jax; xs=jax.numpy.ones(jax.local_device_count()); print(jax.pmap(lambda x: jax.lax.psum(x, \"i\"), axis_name=\"i\")(xs))"'
L'output sarà simile al seguente:
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]
Prova i tutorial JAX in questo documento per iniziare l'addestramento v5e utilizzando JAX.
Configurazione per PyTorch
Tieni presente che v5e supporta solo il runtime PJRT e PyTorch 2.1+ utilizzerà PJRT come runtime predefinito per tutte le versioni di TPU.
Questa sezione descrive come iniziare a utilizzare PJRT su v5e con PyTorch/XLA con comandi per tutti i worker.
Installa le dipendenze
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' sudo apt-get update -y sudo apt-get install libomp5 -y pip install mkl mkl-include pip install tf-nightly tb-nightly tbp-nightly pip install numpy sudo apt-get install libopenblas-dev -y pip install torch~=PYTORCH_VERSION torchvision torch_xla[tpu]~=PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'
Sostituisci PYTORCH_VERSION
con la versione di PyTorch che vuoi utilizzare.
PYTORCH_VERSION
viene utilizzato per specificare la stessa versione per PyTorch/XLA. Si consiglia la versione 2.6.0.
Per ulteriori informazioni sulle versioni di PyTorch e PyTorch/XLA, consulta PyTorch - Get Started e PyTorch/XLA releases.
Per saperne di più sull'installazione di PyTorch/XLA, consulta Installazione di PyTorch/XLA.
Se ricevi un errore durante l'installazione delle ruote per torch
, torch_xla
o
torchvision
come
pkg_resources.extern.packaging.requirements.InvalidRequirement: Expected end
or semicolon (after name and no valid version specifier) torch==nightly+20230222
,
esegui il downgrade della versione con questo comando:
pip3 install setuptools==62.1.0
Eseguire uno script con PJRT
unset LD_PRELOAD
Di seguito è riportato un esempio che utilizza uno script Python per eseguire un calcolo su una VM v5e:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.local/lib/
export PJRT_DEVICE=TPU
export PT_XLA_DEBUG=0
export USE_TORCH=ON
unset LD_PRELOAD
export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"'
Viene generato un output simile al seguente:
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
xla:0
tensor([[ 1.8611, -0.3114, -2.4208],
[-1.0731, 0.3422, 3.1445],
[ 0.5743, 0.2379, 1.1105]], device='xla:0')
xla:0
tensor([[ 1.8611, -0.3114, -2.4208],
[-1.0731, 0.3422, 3.1445],
[ 0.5743, 0.2379, 1.1105]], device='xla:0')
Prova i tutorial di PyTorch in questo documento per iniziare l'addestramento v5e utilizzando PyTorch.
Elimina la TPU e la risorsa in coda al termine della sessione. Per eliminare una risorsa in coda, elimina la sezione e poi la risorsa in coda in due passaggi:
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Questi due passaggi possono essere utilizzati anche per rimuovere le richieste di risorse in coda con stato FAILED
.
Esempi di JAX/FLAX
Le sezioni seguenti descrivono esempi di come addestrare modelli JAX e FLAX su TPU v5e.
Addestra ImageNet su v5e
Questo tutorial descrive come addestrare ImageNet su v5e utilizzando dati di input falsi. Se vuoi utilizzare dati reali, consulta il file README su GitHub.
Configura
Crea variabili di ambiente:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west4-a export ACCELERATOR_TYPE=v5litepod-8 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
Descrizioni delle variabili di ambiente
Variabile Descrizione PROJECT_ID
L'ID progetto Google Cloud . Utilizza un progetto esistente o creane uno nuovo. TPU_NAME
Il nome della TPU. ZONE
La zona in cui creare la VM TPU. Per saperne di più sulle zone supportate, consulta Regioni e zone TPU. ACCELERATOR_TYPE
Il tipo di acceleratore specifica la versione e le dimensioni della Cloud TPU che vuoi creare. Per saperne di più sui tipi di acceleratore supportati per ogni versione di TPU, consulta la sezione Versioni di TPU. RUNTIME_VERSION
La versione software di Cloud TPU. SERVICE_ACCOUNT
L'indirizzo email del tuo account di servizio. Puoi trovarlo andando alla pagina Service Accounts nella console Google Cloud . Ad esempio:
tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
QUEUED_RESOURCE_ID
L'ID testo assegnato dall'utente della richiesta di risorsa in coda. -
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}
Potrai connetterti tramite SSH alla tua VM TPU una volta che la risorsa in coda si trova nello stato
ACTIVE
:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Quando QueuedResource è nello stato
ACTIVE
, l'output sarà simile al seguente:state: ACTIVE
Installa la versione più recente di JAX e jaxlib:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Clona il modello ImageNet e installa i requisiti corrispondenti:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command="git clone https://github.com/coolkp/flax.git && cd flax && git checkout pmap-orbax-conversion && git pull"
Per generare dati fittizi, il modello ha bisogno di informazioni sulle dimensioni del set di dati. Questi dati possono essere raccolti dai metadati del set di dati ImageNet:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command="cd flax/examples/imagenet && pip install -r requirements-cloud-tpu.txt"
Addestra il modello
Una volta completati tutti i passaggi precedenti, puoi addestrare il modello.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command="cd flax/examples/imagenet && bash ../../tests/download_dataset_metadata.sh && JAX_PLATFORMS=tpu python imagenet_fake_data_benchmark.py"
Elimina la TPU e la risorsa in coda
Elimina la TPU e la risorsa in coda al termine della sessione.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Modelli FLAX Hugging Face
I modelli Hugging Face implementati in FLAX funzionano immediatamente su Cloud TPU v5e. Questa sezione fornisce istruzioni per l'esecuzione di modelli popolari.
Addestra ViT su Imagenette
Questo tutorial mostra come addestrare il modello Vision Transformer (ViT) di Hugging Face utilizzando il set di dati Imagenette di Fast AI su Cloud TPU v5e.
Il modello ViT è stato il primo a eseguire l'addestramento di un codificatore Transformer su ImageNet con risultati eccellenti rispetto alle reti convoluzionali. Per maggiori informazioni, consulta le seguenti risorse:
Configura
Crea variabili di ambiente:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west4-a export ACCELERATOR_TYPE=v5litepod-16 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
Descrizioni delle variabili di ambiente
Variabile Descrizione PROJECT_ID
L'ID progetto Google Cloud . Utilizza un progetto esistente o creane uno nuovo. TPU_NAME
Il nome della TPU. ZONE
La zona in cui creare la VM TPU. Per saperne di più sulle zone supportate, consulta Regioni e zone TPU. ACCELERATOR_TYPE
Il tipo di acceleratore specifica la versione e le dimensioni della Cloud TPU che vuoi creare. Per saperne di più sui tipi di acceleratore supportati per ogni versione di TPU, consulta la sezione Versioni di TPU. RUNTIME_VERSION
La versione software di Cloud TPU. SERVICE_ACCOUNT
L'indirizzo email del tuo account di servizio. Puoi trovarlo andando alla pagina Service Accounts nella console Google Cloud . Ad esempio:
tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
QUEUED_RESOURCE_ID
L'ID testo assegnato dall'utente della richiesta di risorsa in coda. -
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}
Potrai accedere alla tua VM TPU tramite SSH una volta che la risorsa in coda si trova nello stato
ACTIVE
:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Quando la risorsa in coda si trova nello stato
ACTIVE
, l'output sarà simile al seguente:state: ACTIVE
Installa JAX e la relativa libreria:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Scarica il repository di Hugging Face e installa i requisiti:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='git clone https://github.com/huggingface/transformers.git && cd transformers && pip install . && pip install -r examples/flax/_tests_requirements.txt && pip install --upgrade huggingface-hub urllib3 zipp && pip install tensorflow==2.19 && sed -i 's/torchvision==0.12.0+cpu/torchvision==0.22.1/' examples/flax/vision/requirements.txt && pip install -r examples/flax/vision/requirements.txt && pip install tf-keras'
Scarica il set di dati Imagenette:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='cd transformers && wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz && tar -xvzf imagenette2.tgz'
Addestra il modello
Addestra il modello con un buffer pre-mappato a 4 GB.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='cd transformers && JAX_PLATFORMS=tpu python3 examples/flax/vision/run_image_classification.py --train_dir "imagenette2/train" --validation_dir "imagenette2/val" --output_dir "./vit-imagenette" --learning_rate 1e-3 --preprocessing_num_workers 32 --per_device_train_batch_size 8 --per_device_eval_batch_size 8 --model_name_or_path google/vit-base-patch16-224-in21k --num_train_epochs 3'
Elimina la TPU e la risorsa in coda
Elimina la TPU e la risorsa in coda al termine della sessione.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Risultati del benchmarking di ViT
Lo script di addestramento è stato eseguito su v5litepod-4, v5litepod-16 e v5litepod-64. La tabella seguente mostra i throughput con diversi tipi di acceleratore.
Tipo di acceleratore | v5litepod-4 | v5litepod-16 | v5litepod-64 |
Epoca | 3 | 3 | 3 |
Dimensione batch globale | 32 | 128 | 512 |
Throughput (esempi/sec) | 263,40 | 429,34 | 470.71 |
Train Diffusion sui Pokémon
Questo tutorial mostra come addestrare il modello Stable Diffusion da HuggingFace utilizzando il set di dati Pokémon su Cloud TPU v5e.
Il modello Stable Diffusion è un modello latente da testo a immagine che genera immagini fotorealistiche da qualsiasi input di testo. Per maggiori informazioni, consulta le seguenti risorse:
Configura
Imposta una variabile di ambiente per il nome del bucket di archiviazione:
export GCS_BUCKET_NAME=your_bucket_name
Configura un bucket di archiviazione per l'output del modello:
gcloud storage buckets create gs://GCS_BUCKET_NAME \ --project=your_project \ --location=us-west1
Crea variabili di ambiente:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west1-c export ACCELERATOR_TYPE=v5litepod-16 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
Descrizioni delle variabili di ambiente
Variabile Descrizione PROJECT_ID
L'ID progetto Google Cloud . Utilizza un progetto esistente o creane uno nuovo. TPU_NAME
Il nome della TPU. ZONE
La zona in cui creare la VM TPU. Per saperne di più sulle zone supportate, consulta Regioni e zone TPU. ACCELERATOR_TYPE
Il tipo di acceleratore specifica la versione e le dimensioni della Cloud TPU che vuoi creare. Per saperne di più sui tipi di acceleratore supportati per ogni versione di TPU, consulta la sezione Versioni di TPU. RUNTIME_VERSION
La versione software di Cloud TPU. SERVICE_ACCOUNT
L'indirizzo email del tuo account di servizio. Puoi trovarlo andando alla pagina Service Accounts nella console Google Cloud . Ad esempio:
tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
QUEUED_RESOURCE_ID
L'ID testo assegnato dall'utente della richiesta di risorsa in coda. -
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}
Potrai utilizzare SSH per accedere alla VM TPU una volta che la risorsa in coda si trova nello stato
ACTIVE
:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Quando la risorsa in coda si trova nello stato
ACTIVE
, l'output sarà simile al seguente:state: ACTIVE
Installa JAX e la relativa libreria.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='pip install "jax[tpu]==0.4.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Scarica il repository di HuggingFace e installa i requisiti.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='git clone https://github.com/RissyRan/diffusers.git && cd diffusers && pip install . && pip install -U -r examples/text_to_image/requirements_flax.txt && pip install tensorflow==2.17.1 clu && pip install tensorboard==2.17.1'
Addestra il modello
Addestra il modello con un buffer pre-mappato a 4 GB.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} --zone=${ZONE} --project=${PROJECT_ID} --worker=all --command="
git clone https://github.com/google/maxdiffusion
cd maxdiffusion
pip3 install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip3 install -r requirements.txt
pip3 install .
pip3 install gcsfs
export LIBTPU_INIT_ARGS=''
python -m src.maxdiffusion.train src/maxdiffusion/configs/base_2_base.yml run_name=my_run \
jax_cache_dir=gs://${GCS_BUCKET_NAME} activations_dtype=bfloat16 weights_dtype=bfloat16 \
per_device_batch_size=1 precision=DEFAULT dataset_save_location=gs://${GCS_BUCKET_NAME} \
output_dir=gs://${GCS_BUCKET_NAME}/ attention=flash"
Esegui la pulizia
Elimina la TPU, la risorsa in coda e il bucket Cloud Storage al termine della sessione.
Elimina la TPU:
gcloud compute tpus tpu-vm delete ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --quiet
Elimina la risorsa in coda:
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --quiet
Elimina il bucket Cloud Storage:
gcloud storage rm -r gs://${GCS_BUCKET_NAME}
Risultati del benchmarking per la diffusione
Lo script di addestramento è stato eseguito su v5litepod-4, v5litepod-16 e v5litepod-64. La tabella seguente mostra i throughput.
Tipo di acceleratore | v5litepod-4 | v5litepod-16 | v5litepod-64 |
Passaggio di addestramento | 1500 | 1500 | 1500 |
Dimensione batch globale | 32 | 64 | 128 |
Throughput (esempi/sec) | 36,53 | 43,71 | 49,36 |
PyTorch/XLA
Le sezioni seguenti descrivono esempi di come addestrare modelli PyTorch/XLA su TPU v5e.
Addestra ResNet utilizzando il runtime PJRT
PyTorch/XLA sta eseguendo la migrazione da XRT a PjRt a partire da PyTorch 2.0+. Ecco le istruzioni aggiornate per configurare v5e per i carichi di lavoro di addestramento PyTorch/XLA.
Configura
Crea variabili di ambiente:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west4-a export ACCELERATOR_TYPE=v5litepod-16 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
Descrizioni delle variabili di ambiente
Variabile Descrizione PROJECT_ID
L'ID progetto Google Cloud . Utilizza un progetto esistente o creane uno nuovo. TPU_NAME
Il nome della TPU. ZONE
La zona in cui creare la VM TPU. Per saperne di più sulle zone supportate, consulta Regioni e zone TPU. ACCELERATOR_TYPE
Il tipo di acceleratore specifica la versione e le dimensioni della Cloud TPU che vuoi creare. Per saperne di più sui tipi di acceleratore supportati per ogni versione di TPU, consulta la sezione Versioni di TPU. RUNTIME_VERSION
La versione software di Cloud TPU. SERVICE_ACCOUNT
L'indirizzo email del tuo account di servizio. Puoi trovarlo andando alla pagina Service Accounts nella console Google Cloud . Ad esempio:
tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
QUEUED_RESOURCE_ID
L'ID testo assegnato dall'utente della richiesta di risorsa in coda. -
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}
Potrai accedere tramite SSH alla tua VM TPU una volta che QueuedResource è nello stato
ACTIVE
:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Quando la risorsa in coda si trova nello stato
ACTIVE
, l'output sarà simile al seguente:state: ACTIVE
Installa le dipendenze specifiche di Torch/XLA
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' sudo apt-get update -y sudo apt-get install libomp5 -y pip3 install mkl mkl-include pip3 install tf-nightly tb-nightly tbp-nightly pip3 install numpy sudo apt-get install libopenblas-dev -y pip install torch==PYTORCH_VERSION torchvision torch_xla[tpu]==PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'
Sostituisci
PYTORCH_VERSION
con la versione di PyTorch che vuoi utilizzare.PYTORCH_VERSION
viene utilizzato per specificare la stessa versione per PyTorch/XLA. Si consiglia la versione 2.6.0.Per ulteriori informazioni sulle versioni di PyTorch e PyTorch/XLA, consulta PyTorch - Get Started e PyTorch/XLA releases.
Per saperne di più sull'installazione di PyTorch/XLA, consulta Installazione di PyTorch/XLA.
Addestra il modello ResNet
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
date
export PJRT_DEVICE=TPU
export PT_XLA_DEBUG=0
export USE_TORCH=ON
export XLA_USE_BF16=1
export LIBTPU_INIT_ARGS=--xla_jf_auto_cross_replica_sharding
export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
git clone https://github.com/pytorch/xla.git
cd xla/
git checkout release-r2.6
python3 test/test_train_mp_imagenet.py --model=resnet50 --fake_data --num_epochs=1 —num_workers=16 --log_steps=300 --batch_size=64 --profile'
Elimina la TPU e la risorsa in coda
Elimina la TPU e la risorsa in coda al termine della sessione.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Risultato benchmark
La tabella seguente mostra i throughput di riferimento.
Tipo di acceleratore | Throughput (esempi/secondo) |
v5litepod-4 | 4240 ex/s |
v5litepod-16 | 10.810 ex/s |
v5litepod-64 | 46.154 ex/s |
Addestra ViT su v5e
Questo tutorial spiega come eseguire VIT su v5e utilizzando il repository HuggingFace su PyTorch/XLA sul set di dati cifar10.
Configura
Crea variabili di ambiente:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west4-a export ACCELERATOR_TYPE=v5litepod-16 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
Descrizioni delle variabili di ambiente
Variabile Descrizione PROJECT_ID
L'ID progetto Google Cloud . Utilizza un progetto esistente o creane uno nuovo. TPU_NAME
Il nome della TPU. ZONE
La zona in cui creare la VM TPU. Per saperne di più sulle zone supportate, consulta Regioni e zone TPU. ACCELERATOR_TYPE
Il tipo di acceleratore specifica la versione e le dimensioni della Cloud TPU che vuoi creare. Per saperne di più sui tipi di acceleratore supportati per ogni versione di TPU, consulta la sezione Versioni di TPU. RUNTIME_VERSION
La versione software di Cloud TPU. SERVICE_ACCOUNT
L'indirizzo email del tuo account di servizio. Puoi trovarlo andando alla pagina Service Accounts nella console Google Cloud . Ad esempio:
tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
QUEUED_RESOURCE_ID
L'ID testo assegnato dall'utente della richiesta di risorsa in coda. -
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}
Potrai accedere tramite SSH alla tua VM TPU una volta che QueuedResource si trova nello stato
ACTIVE
:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Quando la risorsa in coda si trova nello stato
ACTIVE
, l'output sarà simile al seguente:state: ACTIVE
Installa le dipendenze di PyTorch/XLA
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' sudo apt-get update -y sudo apt-get install libomp5 -y pip3 install mkl mkl-include pip3 install tf-nightly tb-nightly tbp-nightly pip3 install numpy sudo apt-get install libopenblas-dev -y pip install torch==PYTORCH_VERSION torchvision torch_xla[tpu]==PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html pip install jax==0.4.38 jaxlib==0.4.38 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
Sostituisci
PYTORCH_VERSION
con la versione di PyTorch che vuoi utilizzare.PYTORCH_VERSION
viene utilizzato per specificare la stessa versione per PyTorch/XLA. Si consiglia la versione 2.6.0.Per ulteriori informazioni sulle versioni di PyTorch e PyTorch/XLA, consulta PyTorch - Get Started e PyTorch/XLA releases.
Per saperne di più sull'installazione di PyTorch/XLA, consulta Installazione di PyTorch/XLA.
Scarica il repository di HuggingFace e installa i requisiti.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=" git clone https://github.com/suexu1025/transformers.git vittransformers; \ cd vittransformers; \ pip3 install .; \ pip3 install datasets; \ wget https://github.com/pytorch/xla/blob/master/scripts/capture_profile.py"
Addestra il modello
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
export PJRT_DEVICE=TPU
export PT_XLA_DEBUG=0
export USE_TORCH=ON
export TF_CPP_MIN_LOG_LEVEL=0
export XLA_USE_BF16=1
export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
cd vittransformers
python3 -u examples/pytorch/xla_spawn.py --num_cores 4 examples/pytorch/image-pretraining/run_mae.py --dataset_name=cifar10 \
--remove_unused_columns=False \
--label_names=pixel_values \
--mask_ratio=0.75 \
--norm_pix_loss=True \
--do_train=true \
--do_eval=true \
--base_learning_rate=1.5e-4 \
--lr_scheduler_type=cosine \
--weight_decay=0.05 \
--num_train_epochs=3 \
--warmup_ratio=0.05 \
--per_device_train_batch_size=8 \
--per_device_eval_batch_size=8 \
--logging_strategy=steps \
--logging_steps=30 \
--evaluation_strategy=epoch \
--save_strategy=epoch \
--load_best_model_at_end=True \
--save_total_limit=3 \
--seed=1337 \
--output_dir=MAE \
--overwrite_output_dir=true \
--logging_dir=./tensorboard-metrics \
--tpu_metrics_debug=true'
Elimina la TPU e la risorsa in coda
Elimina la TPU e la risorsa in coda al termine della sessione.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Risultato benchmark
La tabella seguente mostra i throughput di benchmark per i diversi tipi di acceleratori.
v5litepod-4 | v5litepod-16 | v5litepod-64 | |
Epoca | 3 | 3 | 3 |
Dimensione batch globale | 32 | 128 | 512 |
Throughput (esempi/sec) | 201 | 657 | 2844 |