Addestramento di Cloud TPU v5e
Con un'impronta più ridotta di 256 chip per pod, TPU v5e è ottimizzato per essere un prodotto di alto valore per l'addestramento, l'ottimizzazione e il servizio di transformer, text-to-image e reti neurali convoluzionali (CNN). Per ulteriori informazioni sull'utilizzo di Cloud TPU v5e per il servizio, consulta Inferenza con v5e.
Per ulteriori informazioni sull'hardware e sulle configurazioni di Cloud TPU v5e, consulta TPU v5e.
Inizia
Le sezioni seguenti descrivono come iniziare a utilizzare TPU v5e.
Quota per le richieste
Per utilizzare TPU v5e per l'addestramento, devi disporre di una quota. Esistono diversi tipi di quote per le TPU on demand, le TPU riservate e le VM TPU Spot. Sono richieste quote separate se utilizzi TPU v5e per l'inferenza. Per ulteriori informazioni sulle quote, consulta Quote. Per richiedere una quota TPU v5e, contatta il team di Cloud Sales.
Crea un Google Cloud account e un progetto
Per utilizzare Cloud TPU, devi disporre di un Google Cloud account e di un progetto. Per maggiori informazioni, consulta Configurare un ambiente Cloud TPU.
Crea una Cloud TPU
La best practice è eseguire il provisioning di Cloud TPU v5 come risorse in coda utilizzando il comando queued-resource create
. Per ulteriori informazioni, consulta la pagina Gestire le risorse in coda.
Puoi anche utilizzare l'API Create Node (gcloud compute tpus tpu-vm create
) per eseguire il provisioning di Cloud TPU v5e. Per saperne di più, consulta Gestire le risorse TPU.
Per ulteriori informazioni sulle configurazioni v5e disponibili per l'addestramento, consulta Tipi di Cloud TPU v5e per l'addestramento.
Configurazione del framework
Questa sezione descrive la procedura di configurazione generale per l'addestramento di modelli personalizzati utilizzando JAX o PyTorch con TPU v5e.
Per le istruzioni di configurazione dell'inferenza, consulta l'introduzione all'inferenza v5e.
Definisci alcune variabili di ambiente:
export PROJECT_ID=your_project_ID export ACCELERATOR_TYPE=v5litepod-16 export ZONE=us-west4-a export TPU_NAME=your_tpu_name export QUEUED_RESOURCE_ID=your_queued_resource_id
Configurazione per JAX
Se le forme dei segmenti sono superiori a 8 chip, avrai più VM in un
segmento. In questo caso, devi utilizzare il flag --worker=all
per eseguire l'installazione su tutte le VM TPU in un unico passaggio senza utilizzare SSH per accedere a ciascuna singola VM:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Descrizioni dei flag dei comandi
Variabile | Descrizione |
TPU_NAME | L'ID testo assegnato dall'utente della TPU che viene creato quando viene allocata la richiesta di risorse in coda. |
PROJECT_ID | Google Cloud Nome progetto. Utilizza un progetto esistente o creane uno nuovo in Configurare il Google Cloud progetto |
ZONA | Consulta il documento Regioni e zone TPU per le zone supportate. |
worker | La VM TPU che ha accesso alle TPU sottostanti. |
Puoi eseguire il seguente comando per controllare il numero di dispositivi (gli output mostrati qui sono stati prodotti con una fetta v5litepod-16). Questo codice verifica che tutto sia installato correttamente controllando che JAX veda i TensorCore di Cloud TPU e possa eseguire operazioni di base:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='python3 -c "import jax; print(jax.device_count()); print(jax.local_device_count())"'
L'output sarà simile al seguente:
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16
4
16
4
16
4
16
4
jax.device_count()
mostra il numero totale di chip nell'intervallo specificato.
jax.local_device_count()
indica il numero di chip accessibili da una singola VM in questo slice.
# Check the number of chips in the given slice by summing the count of chips
# from all VMs through the
# jax.local_device_count() API call.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='python3 -c "import jax; xs=jax.numpy.ones(jax.local_device_count()); print(jax.pmap(lambda x: jax.lax.psum(x, \"i\"), axis_name=\"i\")(xs))"'
L'output sarà simile al seguente:
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]
Prova i tutorial su JAX in questo documento per iniziare con l'addestramento di v5e utilizzando JAX.
Configurazione per PyTorch
Tieni presente che la versione v5e supporta solo il runtime PJRT e che PyTorch 2.1 e versioni successive utilizzeranno PJRT come runtime predefinito per tutte le versioni di TPU.
Questa sezione descrive come iniziare a utilizzare PJRT su v5e con PyTorch/XLA con comandi per tutti i worker.
Installa le dipendenze
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' sudo apt-get update -y sudo apt-get install libomp5 -y pip install mkl mkl-include pip install tf-nightly tb-nightly tbp-nightly pip install numpy sudo apt-get install libopenblas-dev -y pip install torch~=PYTORCH_VERSION torchvision torch_xla[tpu]~=PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'
Sostituisci PYTORCH_VERSION
con la versione di PyTorch che vuoi utilizzare.
PYTORCH_VERSION
viene utilizzato per specificare la stessa versione per PyTorch/XLA. È consigliata la versione 2.6.0.
Per ulteriori informazioni sulle versioni di PyTorch e PyTorch/XLA, consulta PyTorch - Guida introduttiva e Uscite di PyTorch/XLA.
Per ulteriori informazioni sull'installazione di PyTorch/XLA, consulta la sezione Installazione di PyTorch/XLA.
Se ricevi un errore durante l'installazione dei wheel per torch
, torch_xla
o
torchvision
come
pkg_resources.extern.packaging.requirements.InvalidRequirement: Expected end
or semicolon (after name and no valid version specifier) torch==nightly+20230222
,
esegui il downgrade della versione con questo comando:
pip3 install setuptools==62.1.0
Eseguire uno script con PJRT
unset LD_PRELOAD
Di seguito è riportato un esempio che utilizza uno script Python per eseguire un calcolo su una VM v5e:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.local/lib/
export PJRT_DEVICE=TPU
export PT_XLA_DEBUG=0
export USE_TORCH=ON
unset LD_PRELOAD
export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"'
Viene generato un output simile al seguente:
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
xla:0
tensor([[ 1.8611, -0.3114, -2.4208],
[-1.0731, 0.3422, 3.1445],
[ 0.5743, 0.2379, 1.1105]], device='xla:0')
xla:0
tensor([[ 1.8611, -0.3114, -2.4208],
[-1.0731, 0.3422, 3.1445],
[ 0.5743, 0.2379, 1.1105]], device='xla:0')
Prova i tutorial di PyTorch in questo documento per iniziare con l'addestramento di v5e utilizzando PyTorch.
Elimina la TPU e la risorsa in coda al termine della sessione. Per eliminare una risorsa in coda, elimina il segmento e poi la risorsa in coda in due passaggi:
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Questi due passaggi possono essere utilizzati anche per rimuovere le richieste di risorse in coda nello stato FAILED
.
Esempi di JAX/FLAX
Le sezioni seguenti descrivono esempi di come addestrare i modelli JAX e FLAX su TPU v5e.
Addestra ImageNet su v5e
Questo tutorial descrive come addestrare ImageNet su v5e utilizzando dati di input falsi. Se vuoi utilizzare dati reali, consulta il file README su GitHub.
Configura
Crea le variabili di ambiente:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west4-a export ACCELERATOR_TYPE=v5litepod-8 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
Descrizioni delle variabili di ambiente
Variabile Descrizione PROJECT_ID
Il tuo Google Cloud ID progetto. Utilizza un progetto esistente o creane uno nuovo. TPU_NAME
Il nome della TPU. ZONE
La zona in cui creare la VM TPU. Per saperne di più sulle zone supportate, consulta Regioni e zone TPU. ACCELERATOR_TYPE
Il tipo di acceleratore specifica la versione e le dimensioni della Cloud TPU che vuoi creare. Per ulteriori informazioni sui tipi di acceleratori supportati per ogni versione di TPU, consulta Versioni TPU. RUNTIME_VERSION
La versione software di Cloud TPU. SERVICE_ACCOUNT
L'indirizzo email del tuo account di servizio. Puoi trovarlo nella pagina Account di servizio della console Google Cloud . Ad esempio:
tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
QUEUED_RESOURCE_ID
L'ID di testo assegnato dall'utente della richiesta di risorsa in coda. -
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}
Potrai accedere tramite SSH alla VM TPU quando la risorsa in coda sarà nello stato
ACTIVE
:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Quando QueuedResource è nello stato
ACTIVE
, l'output sarà simile al seguente:state: ACTIVE
Installa la versione più recente di JAX e jaxlib:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Clona il modello ImageNet e installa i requisiti corrispondenti:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command="git clone https://github.com/coolkp/flax.git && cd flax && git checkout pmap-orbax-conversion && git pull"
Per generare dati falsi, il modello ha bisogno di informazioni sulle dimensioni del set di dati. Questi dati possono essere raccolti dai metadati del set di dati ImageNet:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command="cd flax/examples/imagenet && pip install -r requirements-cloud-tpu.txt"
Addestra il modello
Una volta completati tutti i passaggi precedenti, puoi addestrare il modello.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command="cd flax/examples/imagenet && bash ../../tests/download_dataset_metadata.sh && JAX_PLATFORMS=tpu python imagenet_fake_data_benchmark.py"
Elimina la TPU e la risorsa in coda
Elimina la TPU e la risorsa in coda al termine della sessione.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Modelli FLAX di Hugging Face
I modelli Hugging Face implementati in FLAX funzionano subito su Cloud TPU v5e. Questa sezione fornisce istruzioni per eseguire i modelli più diffusi.
Addestramento di ViT su Imagenette
Questo tutorial mostra come addestrare il modello Vision Transformer (ViT) di HuggingFace utilizzando il set di dati Imagenette di Fast AI su Cloud TPU v5e.
Il modello ViT è stato il primo ad aver addestrato correttamente un codificatore Transformer su ImageNet con risultati eccellenti rispetto alle reti convoluzionali. Per maggiori informazioni, consulta le seguenti risorse:
Configura
Crea le variabili di ambiente:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west4-a export ACCELERATOR_TYPE=v5litepod-16 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
Descrizioni delle variabili di ambiente
Variabile Descrizione PROJECT_ID
Il tuo Google Cloud ID progetto. Utilizza un progetto esistente o creane uno nuovo. TPU_NAME
Il nome della TPU. ZONE
La zona in cui creare la VM TPU. Per saperne di più sulle zone supportate, consulta Regioni e zone TPU. ACCELERATOR_TYPE
Il tipo di acceleratore specifica la versione e le dimensioni della Cloud TPU che vuoi creare. Per ulteriori informazioni sui tipi di acceleratori supportati per ogni versione di TPU, consulta Versioni TPU. RUNTIME_VERSION
La versione software di Cloud TPU. SERVICE_ACCOUNT
L'indirizzo email del tuo account di servizio. Puoi trovarlo nella pagina Account di servizio della console Google Cloud . Ad esempio:
tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
QUEUED_RESOURCE_ID
L'ID di testo assegnato dall'utente della richiesta di risorsa in coda. -
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}
Potrai accedere tramite SSH alla VM TPU quando la risorsa in coda sarà nello stato
ACTIVE
:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Quando la risorsa in coda è nello stato
ACTIVE
, l'output sarà simile al seguente:state: ACTIVE
Installa JAX e la relativa libreria:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Scarica il repository di Hugging Face e i requisiti di installazione:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='git clone https://github.com/huggingface/transformers.git && cd transformers && pip install . && pip install -r examples/flax/_tests_requirements.txt && pip install --upgrade huggingface-hub urllib3 zipp && pip install tensorflow==2.19 && sed -i 's/torchvision==0.12.0+cpu/torchvision==0.22.1/' examples/flax/vision/requirements.txt && pip install -r examples/flax/vision/requirements.txt && pip install tf-keras'
Scarica il set di dati Imagenette:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='cd transformers && wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz && tar -xvzf imagenette2.tgz'
Addestra il modello
Addestra il modello con un buffer premappato di 4 GB.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='cd transformers && JAX_PLATFORMS=tpu python3 examples/flax/vision/run_image_classification.py --train_dir "imagenette2/train" --validation_dir "imagenette2/val" --output_dir "./vit-imagenette" --learning_rate 1e-3 --preprocessing_num_workers 32 --per_device_train_batch_size 8 --per_device_eval_batch_size 8 --model_name_or_path google/vit-base-patch16-224-in21k --num_train_epochs 3'
Elimina la TPU e la risorsa in coda
Elimina la TPU e la risorsa in coda al termine della sessione.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Risultati del benchmarking di ViT
Lo script di addestramento è stato eseguito su v5litepod-4, v5litepod-16 e v5litepod-64. La tabella seguente mostra le portate con diversi tipi di acceleratori.
Tipo di acceleratore | v5litepod-4 | v5litepod-16 | v5litepod-64 |
Epoca | 3 | 3 | 3 |
Dimensione del batch globale | 32 | 128 | 512 |
Velocità effettiva (esempi/sec) | 263,40 | 429,34 | 470,71 |
Allenare la diffusione su Pokémon
Questo tutorial mostra come addestrare il modello di diffusione stabile di HuggingFace utilizzando il set di dati Pokémon su Cloud TPU v5e.
Il modello di diffusione stabile è un modello latente di conversione di testo in immagine che genera immagini fotorealistiche da qualsiasi input di testo. Per saperne di più, consulta le seguenti risorse:
Configura
Imposta una variabile di ambiente per il nome del bucket di archiviazione:
export GCS_BUCKET_NAME=your_bucket_name
Configura un bucket di archiviazione per l'output del modello:
gcloud storage buckets create gs://GCS_BUCKET_NAME \ --project=your_project \ --location=us-west1
Crea le variabili di ambiente:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west1-c export ACCELERATOR_TYPE=v5litepod-16 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
Descrizioni delle variabili di ambiente
Variabile Descrizione PROJECT_ID
Il tuo Google Cloud ID progetto. Utilizza un progetto esistente o creane uno nuovo. TPU_NAME
Il nome della TPU. ZONE
La zona in cui creare la VM TPU. Per saperne di più sulle zone supportate, consulta Regioni e zone TPU. ACCELERATOR_TYPE
Il tipo di acceleratore specifica la versione e le dimensioni della Cloud TPU che vuoi creare. Per ulteriori informazioni sui tipi di acceleratori supportati per ogni versione di TPU, consulta Versioni TPU. RUNTIME_VERSION
La versione software di Cloud TPU. SERVICE_ACCOUNT
L'indirizzo email del tuo account di servizio. Puoi trovarlo nella pagina Account di servizio della console Google Cloud . Ad esempio:
tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
QUEUED_RESOURCE_ID
L'ID di testo assegnato dall'utente della richiesta di risorsa in coda. -
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}
Potrai accedere tramite SSH alla VM TPU quando la risorsa in coda sarà nello stato
ACTIVE
:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Quando la risorsa in coda è nello stato
ACTIVE
, l'output sarà simile al seguente:state: ACTIVE
Installa JAX e la relativa libreria.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='pip install "jax[tpu]==0.4.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Scarica il repository HuggingFace e i requisiti di installazione.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='git clone https://github.com/RissyRan/diffusers.git && cd diffusers && pip install . && pip install -U -r examples/text_to_image/requirements_flax.txt && pip install tensorflow==2.17.1 clu && pip install tensorboard==2.17.1'
Addestra il modello
Addestra il modello con un buffer premappato di 4 GB.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} --zone=${ZONE} --project=${PROJECT_ID} --worker=all --command="
git clone https://github.com/google/maxdiffusion
cd maxdiffusion
pip3 install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip3 install -r requirements.txt
pip3 install .
pip3 install gcsfs
export LIBTPU_INIT_ARGS=''
python -m src.maxdiffusion.train src/maxdiffusion/configs/base_2_base.yml run_name=my_run \
jax_cache_dir=gs://${GCS_BUCKET_NAME} activations_dtype=bfloat16 weights_dtype=bfloat16 \
per_device_batch_size=1 precision=DEFAULT dataset_save_location=gs://${GCS_BUCKET_NAME} \
output_dir=gs://${GCS_BUCKET_NAME}/ attention=flash"
Esegui la pulizia
Elimina la TPU, la risorsa in coda e il bucket Cloud Storage al termine della sessione.
Elimina la TPU:
gcloud compute tpus tpu-vm delete ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --quiet
Elimina la risorsa in coda:
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --quiet
Elimina il bucket Cloud Storage:
gcloud storage rm -r gs://${GCS_BUCKET_NAME}
Risultati del benchmarking per la diffusione
Lo script di addestramento è stato eseguito su v5litepod-4, v5litepod-16 e v5litepod-64. La tabella seguente mostra le portate.
Tipo di acceleratore | v5litepod-4 | v5litepod-16 | v5litepod-64 |
Passaggio di addestramento | 1500 | 1500 | 1500 |
Dimensione del batch globale | 32 | 64 | 128 |
Velocità effettiva (esempi/sec) | 36,53 | 43,71 | 49,36 |
PyTorch/XLA
Le sezioni seguenti descrivono esempi di come addestrare i modelli PyTorch/XLA su TPU v5e.
Addestramento di ResNet utilizzando il runtime PJRT
PyTorch/XLA esegue la migrazione da XRT a PjRt da PyTorch 2.0 e versioni successive. Di seguito sono riportate le istruzioni aggiornate per configurare la versione 5e per i workload di addestramento PyTorch/XLA.
Configura
Crea le variabili di ambiente:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west4-a export ACCELERATOR_TYPE=v5litepod-16 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
Descrizioni delle variabili di ambiente
Variabile Descrizione PROJECT_ID
Il tuo Google Cloud ID progetto. Utilizza un progetto esistente o creane uno nuovo. TPU_NAME
Il nome della TPU. ZONE
La zona in cui creare la VM TPU. Per saperne di più sulle zone supportate, consulta Regioni e zone TPU. ACCELERATOR_TYPE
Il tipo di acceleratore specifica la versione e le dimensioni della Cloud TPU che vuoi creare. Per ulteriori informazioni sui tipi di acceleratori supportati per ogni versione di TPU, consulta Versioni TPU. RUNTIME_VERSION
La versione software di Cloud TPU. SERVICE_ACCOUNT
L'indirizzo email del tuo account di servizio. Puoi trovarlo nella pagina Account di servizio della console Google Cloud . Ad esempio:
tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
QUEUED_RESOURCE_ID
L'ID di testo assegnato dall'utente della richiesta di risorsa in coda. -
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}
Potrai accedere tramite SSH alla VM TPU quando la risorsa in coda sarà nello stato
ACTIVE
:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Quando la risorsa in coda è nello stato
ACTIVE
, l'output sarà simile al seguente:state: ACTIVE
Installa le dipendenze specifiche di Torch/XLA
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' sudo apt-get update -y sudo apt-get install libomp5 -y pip3 install mkl mkl-include pip3 install tf-nightly tb-nightly tbp-nightly pip3 install numpy sudo apt-get install libopenblas-dev -y pip install torch==PYTORCH_VERSION torchvision torch_xla[tpu]==PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'
Sostituisci
PYTORCH_VERSION
con la versione di PyTorch che vuoi utilizzare.PYTORCH_VERSION
viene utilizzato per specificare la stessa versione per PyTorch/XLA. È consigliata la versione 2.6.0.Per ulteriori informazioni sulle versioni di PyTorch e PyTorch/XLA, consulta PyTorch - Guida introduttiva e Uscite di PyTorch/XLA.
Per ulteriori informazioni sull'installazione di PyTorch/XLA, consulta la sezione Installazione di PyTorch/XLA.
Addestrare il modello ResNet
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
date
export PJRT_DEVICE=TPU
export PT_XLA_DEBUG=0
export USE_TORCH=ON
export XLA_USE_BF16=1
export LIBTPU_INIT_ARGS=--xla_jf_auto_cross_replica_sharding
export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
git clone https://github.com/pytorch/xla.git
cd xla/
git checkout release-r2.6
python3 test/test_train_mp_imagenet.py --model=resnet50 --fake_data --num_epochs=1 —num_workers=16 --log_steps=300 --batch_size=64 --profile'
Elimina la TPU e la risorsa in coda
Elimina la TPU e la risorsa in coda al termine della sessione.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Risultato del benchmark
La tabella seguente mostra le portate di benchmark.
Tipo di acceleratore | Velocità effettiva (esempi/secondo) |
v5litepod-4 | 4240 ex/s |
v5litepod-16 | 10.810 ex/s |
v5litepod-64 | 46.154 ex/s |
Addestrare ViT su v5e
Questo tutorial spiega come eseguire VIT su v5e utilizzando il repository HuggingFace su PyTorch/XLA sul set di dati cifar10.
Configura
Crea le variabili di ambiente:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west4-a export ACCELERATOR_TYPE=v5litepod-16 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
Descrizioni delle variabili di ambiente
Variabile Descrizione PROJECT_ID
Il tuo Google Cloud ID progetto. Utilizza un progetto esistente o creane uno nuovo. TPU_NAME
Il nome della TPU. ZONE
La zona in cui creare la VM TPU. Per saperne di più sulle zone supportate, consulta Regioni e zone TPU. ACCELERATOR_TYPE
Il tipo di acceleratore specifica la versione e le dimensioni della Cloud TPU che vuoi creare. Per ulteriori informazioni sui tipi di acceleratori supportati per ogni versione di TPU, consulta Versioni TPU. RUNTIME_VERSION
La versione software di Cloud TPU. SERVICE_ACCOUNT
L'indirizzo email del tuo account di servizio. Puoi trovarlo nella pagina Account di servizio della console Google Cloud . Ad esempio:
tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
QUEUED_RESOURCE_ID
L'ID di testo assegnato dall'utente della richiesta di risorsa in coda. -
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}
Potrai connetterti tramite SSH alla VM TPU quando la risorsa in coda è nello stato
ACTIVE
:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Quando la risorsa in coda è nello stato
ACTIVE
, l'output sarà simile al seguente:state: ACTIVE
Installa le dipendenze PyTorch/XLA
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' sudo apt-get update -y sudo apt-get install libomp5 -y pip3 install mkl mkl-include pip3 install tf-nightly tb-nightly tbp-nightly pip3 install numpy sudo apt-get install libopenblas-dev -y pip install torch==PYTORCH_VERSION torchvision torch_xla[tpu]==PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html pip install jax==0.4.38 jaxlib==0.4.38 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
Sostituisci
PYTORCH_VERSION
con la versione di PyTorch che vuoi utilizzare.PYTORCH_VERSION
viene utilizzato per specificare la stessa versione per PyTorch/XLA. È consigliata la versione 2.6.0.Per ulteriori informazioni sulle versioni di PyTorch e PyTorch/XLA, consulta PyTorch - Guida introduttiva e Uscite di PyTorch/XLA.
Per ulteriori informazioni sull'installazione di PyTorch/XLA, consulta la sezione Installazione di PyTorch/XLA.
Scarica i requisiti del repository HuggingFace e installali.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=" git clone https://github.com/suexu1025/transformers.git vittransformers; \ cd vittransformers; \ pip3 install .; \ pip3 install datasets; \ wget https://github.com/pytorch/xla/blob/master/scripts/capture_profile.py"
Addestra il modello
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
export PJRT_DEVICE=TPU
export PT_XLA_DEBUG=0
export USE_TORCH=ON
export TF_CPP_MIN_LOG_LEVEL=0
export XLA_USE_BF16=1
export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
cd vittransformers
python3 -u examples/pytorch/xla_spawn.py --num_cores 4 examples/pytorch/image-pretraining/run_mae.py --dataset_name=cifar10 \
--remove_unused_columns=False \
--label_names=pixel_values \
--mask_ratio=0.75 \
--norm_pix_loss=True \
--do_train=true \
--do_eval=true \
--base_learning_rate=1.5e-4 \
--lr_scheduler_type=cosine \
--weight_decay=0.05 \
--num_train_epochs=3 \
--warmup_ratio=0.05 \
--per_device_train_batch_size=8 \
--per_device_eval_batch_size=8 \
--logging_strategy=steps \
--logging_steps=30 \
--evaluation_strategy=epoch \
--save_strategy=epoch \
--load_best_model_at_end=True \
--save_total_limit=3 \
--seed=1337 \
--output_dir=MAE \
--overwrite_output_dir=true \
--logging_dir=./tensorboard-metrics \
--tpu_metrics_debug=true'
Elimina la TPU e la risorsa in coda
Elimina la TPU e la risorsa in coda al termine della sessione.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Risultato del benchmark
La tabella seguente mostra i throughput del benchmark per i diversi tipi di acceleratori.
v5litepod-4 | v5litepod-16 | v5litepod-64 | |
Epoca | 3 | 3 | 3 |
Dimensione del batch globale | 32 | 128 | 512 |
Velocità effettiva (esempi/sec) | 201 | 657 | 2844 |