Entraînement Cloud TPU v5e
Avec une empreinte de 256 puces par pod, le TPU v5e est optimisé pour être un produit de grande valeur pour l'entraînement, le réglage et le traitement des transformateurs, du texte en image et des réseaux de neurones convolutifs (RNN). Pour en savoir plus sur l'utilisation de Cloud TPU v5e pour la diffusion, consultez la section Inférence à l'aide de v5e.
Pour en savoir plus sur le matériel et les configurations des TPU Cloud TPU v5e, consultez la page TPU v5e.
Commencer
Les sections suivantes expliquent comment commencer à utiliser les TPU v5e.
Quota de requêtes
Vous avez besoin d'un quota pour utiliser des TPU v5e pour l'entraînement. Il existe différents types de quotas pour les TPU à la demande, les TPU réservés et les VM Spot TPU. Des quotas distincts sont requis si vous utilisez votre TPU v5e pour l'inférence. Pour en savoir plus sur les quotas, consultez la page Quotas. Pour demander un quota TPU v5e, contactez le service commercial Cloud.
Créer un compte et un projet Google Cloud
Vous avez besoin d'un Google Cloud compte et d'un projet pour utiliser Cloud TPU. Pour en savoir plus, consultez la page Configurer un environnement Cloud TPU.
Créer une instance Cloud TPU
Il est recommandé de provisionner des Cloud TPU v5 en tant que ressources mises en file d'attente à l'aide de la commande queued-resource create
. Pour en savoir plus, consultez la section Gérer les ressources en file d'attente.
Vous pouvez également utiliser l'API Create Node (gcloud compute tpus tpu-vm create
) pour provisionner des Cloud TPU v5e. Pour en savoir plus, consultez Gérer les ressources TPU.
Pour en savoir plus sur les configurations v5e disponibles pour l'entraînement, consultez la section Types de Cloud TPU v5e pour l'entraînement.
Configuration du framework
Cette section décrit le processus de configuration général pour l'entraînement de modèles personnalisés à l'aide de JAX ou de PyTorch avec TPU v5e.
Pour obtenir des instructions de configuration de l'inférence, consultez la page Présentation de l'inférence v5e.
Définissez des variables d'environnement:
export PROJECT_ID=your_project_ID export ACCELERATOR_TYPE=v5litepod-16 export ZONE=us-west4-a export TPU_NAME=your_tpu_name export QUEUED_RESOURCE_ID=your_queued_resource_id
Configuration pour JAX
Si vos formes de tranche contiennent plus de huit chips, vous aurez plusieurs VM dans une même tranche. Dans ce cas, vous devez utiliser l'indicateur --worker=all
pour exécuter l'installation sur toutes les VM TPU en une seule étape, sans utiliser SSH pour vous connecter à chacune d'elles séparément:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Description des options de commande
Variable | Description |
TPU_NAME | ID de texte attribué par l'utilisateur du TPU créé lors de l'allocation de la requête de ressource mise en file d'attente. |
PROJECT_ID | Google Cloud Nom du projet. Utilisez un projet existant ou créez-en un à l'adresse Configurer votre Google Cloud projet. |
ZONE | Pour connaître les zones compatibles, consultez le document Régions et zones TPU. |
Worker [class name, see definition and ref site provided] | VM TPU ayant accès aux TPU sous-jacents. |
Vous pouvez exécuter la commande suivante pour vérifier le nombre d'appareils (les sorties affichées ici ont été produites avec une tranche v5litepod-16). Ce code vérifie que tout est correctement installé en vérifiant que JAX voit les TensorCores Cloud TPU et qu'il peut exécuter des opérations de base:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='python3 -c "import jax; print(jax.device_count()); print(jax.local_device_count())"'
Le résultat doit ressembler à ce qui suit :
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16
4
16
4
16
4
16
4
jax.device_count()
indique le nombre total de chips dans la tranche donnée.
jax.local_device_count()
indique le nombre de puces accessibles par une seule VM dans cette tranche.
# Check the number of chips in the given slice by summing the count of chips
# from all VMs through the
# jax.local_device_count() API call.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='python3 -c "import jax; xs=jax.numpy.ones(jax.local_device_count()); print(jax.pmap(lambda x: jax.lax.psum(x, \"i\"), axis_name=\"i\")(xs))"'
Le résultat doit ressembler à ce qui suit :
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]
Suivez les tutoriels JAX de ce document pour commencer à utiliser l'entraînement v5e avec JAX.
Configuration pour PyTorch
Notez que la version v5e n'est compatible qu'avec l'environnement d'exécution PJRT et que PyTorch 2.1 et versions ultérieures utiliseront PJRT comme environnement d'exécution par défaut pour toutes les versions de TPU.
Cette section explique comment commencer à utiliser PJRT sur la version v5e avec PyTorch/XLA avec des commandes pour tous les nœuds de calcul.
Installer des dépendances
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' sudo apt-get update -y sudo apt-get install libomp5 -y pip install mkl mkl-include pip install tf-nightly tb-nightly tbp-nightly pip install numpy sudo apt-get install libopenblas-dev -y pip install torch~=PYTORCH_VERSION torchvision torch_xla[tpu]~=PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'
Remplacez PYTORCH_VERSION
par la version de PyTorch que vous souhaitez utiliser.
PYTORCH_VERSION
permet de spécifier la même version pour PyTorch/XLA. La version 2.6.0 est recommandée.
Pour en savoir plus sur les versions de PyTorch et de PyTorch/XLA, consultez PyTorch - Premiers pas et Versions de PyTorch/XLA.
Pour en savoir plus sur l'installation de PyTorch/XLA, consultez la section Installation de PyTorch/XLA.
Si vous recevez une erreur lors de l'installation des roues pour torch
, torch_xla
ou torchvision
comme pkg_resources.extern.packaging.requirements.InvalidRequirement: Expected end
or semicolon (after name and no valid version specifier) torch==nightly+20230222
, rétrogradez votre version avec cette commande:
pip3 install setuptools==62.1.0
Exécuter un script avec PJRT
unset LD_PRELOAD
Vous trouverez ci-dessous un exemple d'utilisation d'un script Python pour effectuer un calcul sur une VM v5e:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.local/lib/
export PJRT_DEVICE=TPU
export PT_XLA_DEBUG=0
export USE_TORCH=ON
unset LD_PRELOAD
export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"'
Un résultat semblable à celui-ci doit s'afficher :
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
xla:0
tensor([[ 1.8611, -0.3114, -2.4208],
[-1.0731, 0.3422, 3.1445],
[ 0.5743, 0.2379, 1.1105]], device='xla:0')
xla:0
tensor([[ 1.8611, -0.3114, -2.4208],
[-1.0731, 0.3422, 3.1445],
[ 0.5743, 0.2379, 1.1105]], device='xla:0')
Suivez les tutoriels PyTorch de ce document pour commencer à entraîner v5e avec PyTorch.
Supprimez votre TPU et votre ressource en file d'attente à la fin de votre session. Pour supprimer une ressource en file d'attente, supprimez d'abord la tranche, puis la ressource en file d'attente en deux étapes:
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Ces deux étapes peuvent également être utilisées pour supprimer les requêtes de ressources en file d'attente qui sont à l'état FAILED
.
Exemples JAX/FLAX
Les sections suivantes décrivent des exemples d'entraînement de modèles JAX et FLAX sur un TPU v5e.
Entraîner ImageNet sur v5e
Ce tutoriel explique comment entraîner ImageNet sur v5e à l'aide de fausses données d'entrée. Si vous souhaitez utiliser des données réelles, consultez le fichier README sur GitHub.
Configurer
Créez des variables d'environnement :
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west4-a export ACCELERATOR_TYPE=v5litepod-8 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
Descriptions des variables d'environnement
Variable Description PROJECT_ID
L'ID de votre Google Cloud projet. Utilisez un projet existant ou créez-en un. TPU_NAME
Nom du TPU. ZONE
Zone dans laquelle créer la VM TPU. Pour en savoir plus sur les zones compatibles, consultez la section Régions et zones de TPU. ACCELERATOR_TYPE
Le type d'accélérateur spécifie la version et la taille du Cloud TPU que vous souhaitez créer. Pour en savoir plus sur les types d'accélérateurs compatibles avec chaque version de TPU, consultez la section Versions de TPU. RUNTIME_VERSION
Version logicielle de Cloud TPU. SERVICE_ACCOUNT
Adresse e-mail de votre compte de service. Pour le trouver, accédez à la page Comptes de service dans la console Google Cloud . Par exemple :
tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
QUEUED_RESOURCE_ID
ID de texte attribué par l'utilisateur de la requête de ressource mise en file d'attente. -
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}
Vous pourrez vous connecter en SSH à votre VM TPU une fois que votre ressource en file d'attente sera à l'état
ACTIVE
:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Lorsque QueuedResource est dans l'état
ACTIVE
, le résultat ressemble à ce qui suit:state: ACTIVE
Installez la dernière version de JAX et de jaxlib:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Clonez le modèle ImageNet et installez les exigences correspondantes:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command="git clone https://github.com/coolkp/flax.git && cd flax && git checkout pmap-orbax-conversion && git pull"
Pour générer de fausses données, le modèle a besoin d'informations sur les dimensions de l'ensemble de données. Vous pouvez le faire à partir des métadonnées de l'ensemble de données ImageNet:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command="cd flax/examples/imagenet && pip install -r requirements-cloud-tpu.txt"
Entraîner le modèle
Une fois toutes les étapes précédentes terminées, vous pouvez entraîner le modèle.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command="cd flax/examples/imagenet && bash ../../tests/download_dataset_metadata.sh && JAX_PLATFORMS=tpu python imagenet_fake_data_benchmark.py"
Supprimer le TPU et la ressource en file d'attente
Supprimez votre TPU et votre ressource en file d'attente à la fin de votre session.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Modèles Hugging Face FLAX
Les modèles Hugging Face implémentés dans FLAX fonctionnent immédiatement sur Cloud TPU v5e. Cette section fournit des instructions pour exécuter des modèles populaires.
Entraîner ViT sur Imagenette
Ce tutoriel vous explique comment entraîner le modèle Vision Transformer (ViT) de HuggingFace à l'aide de l'ensemble de données Imagenette de Fast AI sur Cloud TPU v5e.
Le modèle ViT a été le premier à entraîner avec succès un encodeur Transformer sur ImageNet, avec d'excellents résultats par rapport aux réseaux de neurones convolutifs. Pour en savoir plus, consultez les ressources suivantes:
Configurer
Créez des variables d'environnement :
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west4-a export ACCELERATOR_TYPE=v5litepod-16 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
Descriptions des variables d'environnement
Variable Description PROJECT_ID
L'ID de votre Google Cloud projet. Utilisez un projet existant ou créez-en un. TPU_NAME
Nom du TPU. ZONE
Zone dans laquelle créer la VM TPU. Pour en savoir plus sur les zones compatibles, consultez la section Régions et zones de TPU. ACCELERATOR_TYPE
Le type d'accélérateur spécifie la version et la taille du Cloud TPU que vous souhaitez créer. Pour en savoir plus sur les types d'accélérateurs compatibles avec chaque version de TPU, consultez la section Versions de TPU. RUNTIME_VERSION
Version logicielle de Cloud TPU. SERVICE_ACCOUNT
Adresse e-mail de votre compte de service. Pour le trouver, accédez à la page Comptes de service dans la console Google Cloud . Par exemple :
tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
QUEUED_RESOURCE_ID
ID de texte attribué par l'utilisateur de la requête de ressource mise en file d'attente. -
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}
Vous pourrez vous connecter en SSH à votre VM TPU une fois que votre ressource en file d'attente sera à l'état
ACTIVE
:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Lorsque la ressource mise en file d'attente est dans l'état
ACTIVE
, le résultat ressemble à ce qui suit:state: ACTIVE
Installez JAX et sa bibliothèque:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Téléchargez le dépôt Hugging Face et les exigences d'installation:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='git clone https://github.com/huggingface/transformers.git && cd transformers && pip install . && pip install -r examples/flax/_tests_requirements.txt && pip install --upgrade huggingface-hub urllib3 zipp && pip install tensorflow==2.19 && sed -i 's/torchvision==0.12.0+cpu/torchvision==0.22.1/' examples/flax/vision/requirements.txt && pip install -r examples/flax/vision/requirements.txt && pip install tf-keras'
Téléchargez l'ensemble de données Imagenette:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='cd transformers && wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz && tar -xvzf imagenette2.tgz'
Entraîner le modèle
Entraînez le modèle avec un tampon prémappé de 4 Go.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='cd transformers && JAX_PLATFORMS=tpu python3 examples/flax/vision/run_image_classification.py --train_dir "imagenette2/train" --validation_dir "imagenette2/val" --output_dir "./vit-imagenette" --learning_rate 1e-3 --preprocessing_num_workers 32 --per_device_train_batch_size 8 --per_device_eval_batch_size 8 --model_name_or_path google/vit-base-patch16-224-in21k --num_train_epochs 3'
Supprimer le TPU et la ressource en file d'attente
Supprimez votre TPU et votre ressource en file d'attente à la fin de votre session.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Résultats des benchmarks sur la qualité de l'image
Le script d'entraînement a été exécuté sur v5litepod-4, v5litepod-16 et v5litepod-64. Le tableau suivant présente les débits avec différents types d'accélérateurs.
Type d'accélérateur | v5litepod-4 | v5litepod-16 | v5litepod-64 |
Époque | 3 | 3 | 3 |
Taille du lot global | 32 | 128 | 512 |
Débit (exemples/s) | 263,40 | 429,34 | 470,71 |
Train Diffusion sur Pokémon
Ce tutoriel vous explique comment entraîner le modèle Stable Diffusion de HuggingFace à l'aide de l'ensemble de données Pokémon sur Cloud TPU v5e.
Le modèle Stable Diffusion est un modèle de texte vers image latent qui génère des images photoréalistes à partir de n'importe quelle entrée textuelle. Pour en savoir plus, consultez les ressources suivantes :
Configurer
Définissez une variable d'environnement pour le nom de votre bucket de stockage:
export GCS_BUCKET_NAME=your_bucket_name
Configurez un bucket de stockage pour la sortie de votre modèle:
gcloud storage buckets create gs://GCS_BUCKET_NAME \ --project=your_project \ --location=us-west1
Créez des variables d'environnement :
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west1-c export ACCELERATOR_TYPE=v5litepod-16 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
Descriptions des variables d'environnement
Variable Description PROJECT_ID
L'ID de votre Google Cloud projet. Utilisez un projet existant ou créez-en un. TPU_NAME
Nom du TPU. ZONE
Zone dans laquelle créer la VM TPU. Pour en savoir plus sur les zones compatibles, consultez la section Régions et zones de TPU. ACCELERATOR_TYPE
Le type d'accélérateur spécifie la version et la taille du Cloud TPU que vous souhaitez créer. Pour en savoir plus sur les types d'accélérateurs compatibles avec chaque version de TPU, consultez la section Versions de TPU. RUNTIME_VERSION
Version logicielle de Cloud TPU. SERVICE_ACCOUNT
Adresse e-mail de votre compte de service. Pour le trouver, accédez à la page Comptes de service dans la console Google Cloud . Par exemple :
tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
QUEUED_RESOURCE_ID
ID de texte attribué par l'utilisateur de la requête de ressource mise en file d'attente. -
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}
Vous pourrez vous connecter en SSH à votre VM TPU une fois que votre ressource en file d'attente sera à l'état
ACTIVE
:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Lorsque la ressource mise en file d'attente est dans l'état
ACTIVE
, le résultat ressemble à ceci:state: ACTIVE
Installez JAX et sa bibliothèque.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='pip install "jax[tpu]==0.4.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Téléchargez le dépôt HuggingFace et les exigences d'installation.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='git clone https://github.com/RissyRan/diffusers.git && cd diffusers && pip install . && pip install -U -r examples/text_to_image/requirements_flax.txt && pip install tensorflow==2.17.1 clu && pip install tensorboard==2.17.1'
Entraîner le modèle
Entraînez le modèle avec un tampon prémappé de 4 Go.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} --zone=${ZONE} --project=${PROJECT_ID} --worker=all --command="
git clone https://github.com/google/maxdiffusion
cd maxdiffusion
pip3 install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip3 install -r requirements.txt
pip3 install .
pip3 install gcsfs
export LIBTPU_INIT_ARGS=''
python -m src.maxdiffusion.train src/maxdiffusion/configs/base_2_base.yml run_name=my_run \
jax_cache_dir=gs://${GCS_BUCKET_NAME} activations_dtype=bfloat16 weights_dtype=bfloat16 \
per_device_batch_size=1 precision=DEFAULT dataset_save_location=gs://${GCS_BUCKET_NAME} \
output_dir=gs://${GCS_BUCKET_NAME}/ attention=flash"
Effectuer un nettoyage
Supprimez votre TPU, votre ressource mise en file d'attente et votre bucket Cloud Storage à la fin de votre session.
Supprimez votre TPU:
gcloud compute tpus tpu-vm delete ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --quiet
Supprimez la ressource en file d'attente:
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --quiet
Supprimez le bucket Cloud Storage :
gcloud storage rm -r gs://${GCS_BUCKET_NAME}
Résultats des benchmarks pour la diffusion
Le script d'entraînement a été exécuté sur v5litepod-4, v5litepod-16 et v5litepod-64. Le tableau suivant présente les débits.
Type d'accélérateur | v5litepod-4 | v5litepod-16 | v5litepod-64 |
Étape d'entraînement | 1500 | 1500 | 1500 |
Taille du lot global | 32 | 64 | 128 |
Débit (exemples/s) | 36,53 | 43,71 | 49.36 |
PyTorch/XLA
Les sections suivantes décrivent des exemples d'entraînement de modèles PyTorch/XLA sur un TPU v5e.
Entraîner ResNet à l'aide de l'environnement d'exécution PJRT
PyTorch/XLA passe de XRT à PjRt à partir de PyTorch 2.0 et versions ultérieures. Voici les instructions mises à jour pour configurer la version 5e pour les charges de travail d'entraînement PyTorch/XLA.
Configurer
Créez des variables d'environnement :
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west4-a export ACCELERATOR_TYPE=v5litepod-16 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
Descriptions des variables d'environnement
Variable Description PROJECT_ID
L'ID de votre Google Cloud projet. Utilisez un projet existant ou créez-en un. TPU_NAME
Nom du TPU. ZONE
Zone dans laquelle créer la VM TPU. Pour en savoir plus sur les zones compatibles, consultez la section Régions et zones de TPU. ACCELERATOR_TYPE
Le type d'accélérateur spécifie la version et la taille du Cloud TPU que vous souhaitez créer. Pour en savoir plus sur les types d'accélérateurs compatibles avec chaque version de TPU, consultez la section Versions de TPU. RUNTIME_VERSION
Version logicielle de Cloud TPU. SERVICE_ACCOUNT
Adresse e-mail de votre compte de service. Pour le trouver, accédez à la page Comptes de service dans la console Google Cloud . Par exemple :
tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
QUEUED_RESOURCE_ID
ID de texte attribué par l'utilisateur de la requête de ressource mise en file d'attente. -
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}
Vous pourrez vous connecter en SSH à votre VM TPU une fois que votre QueuedResource sera dans l'état
ACTIVE
:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Lorsque la ressource mise en file d'attente est dans l'état
ACTIVE
, le résultat ressemble à ceci:state: ACTIVE
Installer les dépendances spécifiques à Torch/XLA
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' sudo apt-get update -y sudo apt-get install libomp5 -y pip3 install mkl mkl-include pip3 install tf-nightly tb-nightly tbp-nightly pip3 install numpy sudo apt-get install libopenblas-dev -y pip install torch==PYTORCH_VERSION torchvision torch_xla[tpu]==PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'
Remplacez
PYTORCH_VERSION
par la version de PyTorch que vous souhaitez utiliser.PYTORCH_VERSION
permet de spécifier la même version pour PyTorch/XLA. La version 2.6.0 est recommandée.Pour en savoir plus sur les versions de PyTorch et de PyTorch/XLA, consultez PyTorch - Premiers pas et Versions de PyTorch/XLA.
Pour en savoir plus sur l'installation de PyTorch/XLA, consultez la section Installation de PyTorch/XLA.
Entraîner le modèle ResNet
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
date
export PJRT_DEVICE=TPU
export PT_XLA_DEBUG=0
export USE_TORCH=ON
export XLA_USE_BF16=1
export LIBTPU_INIT_ARGS=--xla_jf_auto_cross_replica_sharding
export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
git clone https://github.com/pytorch/xla.git
cd xla/
git checkout release-r2.6
python3 test/test_train_mp_imagenet.py --model=resnet50 --fake_data --num_epochs=1 —num_workers=16 --log_steps=300 --batch_size=64 --profile'
Supprimer le TPU et la ressource en file d'attente
Supprimez votre TPU et votre ressource en file d'attente à la fin de votre session.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Résultat du benchmark
Le tableau suivant présente les débits de référence.
Type d'accélérateur | Débit (exemples/seconde) |
v5litepod-4 | 4 240 ex/s |
v5litepod-16 | 10 810 ex/s |
v5litepod-64 | 46 154 ex/s |
Entraîner ViT sur v5e
Ce tutoriel explique comment exécuter VIT sur v5e à l'aide du dépôt HuggingFace sur PyTorch/XLA sur l'ensemble de données cifar10.
Configurer
Créez des variables d'environnement :
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west4-a export ACCELERATOR_TYPE=v5litepod-16 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
Descriptions des variables d'environnement
Variable Description PROJECT_ID
L'ID de votre Google Cloud projet. Utilisez un projet existant ou créez-en un. TPU_NAME
Nom du TPU. ZONE
Zone dans laquelle créer la VM TPU. Pour en savoir plus sur les zones compatibles, consultez la section Régions et zones de TPU. ACCELERATOR_TYPE
Le type d'accélérateur spécifie la version et la taille du Cloud TPU que vous souhaitez créer. Pour en savoir plus sur les types d'accélérateurs compatibles avec chaque version de TPU, consultez la section Versions de TPU. RUNTIME_VERSION
Version logicielle de Cloud TPU. SERVICE_ACCOUNT
Adresse e-mail de votre compte de service. Pour le trouver, accédez à la page Comptes de service dans la console Google Cloud . Par exemple :
tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
QUEUED_RESOURCE_ID
ID de texte attribué par l'utilisateur de la requête de ressource mise en file d'attente. -
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}
Vous pourrez vous connecter en SSH à votre VM TPU une fois que votre QueuedResource sera dans l'état
ACTIVE
:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Lorsque la ressource mise en file d'attente est dans l'état
ACTIVE
, le résultat ressemble à ceci:state: ACTIVE
Installer les dépendances PyTorch/XLA
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' sudo apt-get update -y sudo apt-get install libomp5 -y pip3 install mkl mkl-include pip3 install tf-nightly tb-nightly tbp-nightly pip3 install numpy sudo apt-get install libopenblas-dev -y pip install torch==PYTORCH_VERSION torchvision torch_xla[tpu]==PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html pip install jax==0.4.38 jaxlib==0.4.38 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
Remplacez
PYTORCH_VERSION
par la version de PyTorch que vous souhaitez utiliser.PYTORCH_VERSION
permet de spécifier la même version pour PyTorch/XLA. La version 2.6.0 est recommandée.Pour en savoir plus sur les versions de PyTorch et de PyTorch/XLA, consultez PyTorch - Premiers pas et Versions de PyTorch/XLA.
Pour en savoir plus sur l'installation de PyTorch/XLA, consultez la section Installation de PyTorch/XLA.
Téléchargez le dépôt HuggingFace et les exigences d'installation.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=" git clone https://github.com/suexu1025/transformers.git vittransformers; \ cd vittransformers; \ pip3 install .; \ pip3 install datasets; \ wget https://github.com/pytorch/xla/blob/master/scripts/capture_profile.py"
Entraîner le modèle
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
export PJRT_DEVICE=TPU
export PT_XLA_DEBUG=0
export USE_TORCH=ON
export TF_CPP_MIN_LOG_LEVEL=0
export XLA_USE_BF16=1
export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
cd vittransformers
python3 -u examples/pytorch/xla_spawn.py --num_cores 4 examples/pytorch/image-pretraining/run_mae.py --dataset_name=cifar10 \
--remove_unused_columns=False \
--label_names=pixel_values \
--mask_ratio=0.75 \
--norm_pix_loss=True \
--do_train=true \
--do_eval=true \
--base_learning_rate=1.5e-4 \
--lr_scheduler_type=cosine \
--weight_decay=0.05 \
--num_train_epochs=3 \
--warmup_ratio=0.05 \
--per_device_train_batch_size=8 \
--per_device_eval_batch_size=8 \
--logging_strategy=steps \
--logging_steps=30 \
--evaluation_strategy=epoch \
--save_strategy=epoch \
--load_best_model_at_end=True \
--save_total_limit=3 \
--seed=1337 \
--output_dir=MAE \
--overwrite_output_dir=true \
--logging_dir=./tensorboard-metrics \
--tpu_metrics_debug=true'
Supprimer le TPU et la ressource en file d'attente
Supprimez votre TPU et votre ressource en file d'attente à la fin de votre session.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Résultat du benchmark
Le tableau suivant présente les débits de référence pour différents types d'accélérateurs.
v5litepod-4 | v5litepod-16 | v5litepod-64 | |
Époque | 3 | 3 | 3 |
Taille du lot global | 32 | 128 | 512 |
Débit (exemples/s) | 201 | 657 | 2 844 |