Exécuter des charges de travail TPU dans un conteneur Docker
Les conteneurs Docker facilitent la configuration des applications en combinant votre code et toutes les dépendances nécessaires dans un seul package distribuable. Vous pouvez exécuter des conteneurs Docker dans des VM TPU pour simplifier la configuration et le partage de vos applications Cloud TPU. Ce document explique comment configurer un conteneur Docker pour chaque framework de ML compatible avec Cloud TPU.
Entraîner un modèle PyTorch dans un conteneur Docker
Appareil TPU
Créer une VM Cloud TPU
gcloud compute tpus tpu-vm create your-tpu-name \ --zone=europe-west4-a \ --accelerator-type=v2-8 \ --version=tpu-ubuntu2204-base
Se connecter à la VM TPU à l'aide de SSH
gcloud compute tpus tpu-vm ssh your-tpu-name \ --zone=europe-west4-a
Assurez-vous que le rôle Lecteur du registre des artefacts a été attribué à votre utilisateur Google Cloud . Pour en savoir plus, consultez Attribuer des rôles Artifact Registry.
Démarrer un conteneur dans la VM TPU à l'aide de l'image PyTorch/XLA quotidienne
sudo docker run --net=host -ti --rm --name your-container-name --privileged us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm_cxx11 \ bash
Configurer l'environnement d'exécution TPU
Il existe deux options d'exécution PyTorch/XLA: PJRT et XRT. Nous vous recommandons d'utiliser PJRT, sauf si vous avez une raison d'utiliser XRT. Pour en savoir plus sur les différentes configurations d'exécution, consultez la documentation d'exécution PJRT.
PJRT
export PJRT_DEVICE=TPU
XRT
export XRT_TPU_CONFIG="localservice;0;localhost:51011"
Cloner le dépôt PyTorch XLA
git clone --recursive https://github.com/pytorch/xla.git
Entraîner ResNet50
python3 xla/test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1
Une fois le script d'entraînement terminé, veillez à nettoyer les ressources.
- Saisissez
exit
pour quitter le conteneur Docker. - Saisissez
exit
pour quitter la VM TPU. Supprimer la VM TPU
gcloud compute tpus tpu-vm delete your-tpu-name --zone=europe-west4-a
Tranche TPU
Lorsque vous exécutez du code PyTorch sur une tranche TPU, vous devez exécuter votre code sur tous les nœuds de calcul TPU en même temps. Pour ce faire, vous pouvez utiliser la commande gcloud compute tpus tpu-vm ssh
avec les options --worker=all
et --command
. La procédure suivante vous explique comment créer une image Docker pour faciliter la configuration de chaque nœud de travail TPU.
Créer une VM TPU
gcloud compute tpus tpu-vm create your-tpu-name \ --zone=us-central2-b \ --accelerator-type=v4-32 \ --version=tpu-ubuntu2204-base
Ajouter l'utilisateur actuel au groupe Docker
gcloud compute tpus tpu-vm ssh your-tpu-name \ --zone=us-central2-b \ --worker=all \ --command='sudo usermod -a -G docker $USER'
Cloner le dépôt PyTorch XLA
gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \ --zone=us-central2-b \ --command="git clone --recursive https://github.com/pytorch/xla.git"
Exécuter le script d'entraînement dans un conteneur sur tous les nœuds TPU
gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \ --zone=us-central2-b \ --command="docker run --rm --privileged --net=host -v ~/xla:/xla -e PJRT_DEVICE=TPU us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm_cxx11 python /xla/test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1"
Options de commande Docker:
--rm
supprime le conteneur une fois son processus terminé.--privileged
expose l'appareil TPU au conteneur.--net=host
lie tous les ports du conteneur à la VM TPU pour permettre la communication entre les hôtes du pod.-e
définit des variables d'environnement.
Une fois le script d'entraînement terminé, veillez à nettoyer les ressources.
Supprimez la VM TPU à l'aide de la commande suivante:
gcloud compute tpus tpu-vm delete your-tpu-name \ --zone=us-central2-b
Entraîner un modèle JAX dans un conteneur Docker
Appareil TPU
Créez la VM TPU.
gcloud compute tpus tpu-vm create your-tpu-name \ --zone=europe-west4-a \ --accelerator-type=v2-8 \ --version=tpu-ubuntu2204-base
Se connecter à la VM TPU à l'aide de SSH
gcloud compute tpus tpu-vm ssh your-tpu-name --zone=europe-west4-a
Démarrer le daemon Docker dans la VM TPU
sudo systemctl start docker
Démarrer le conteneur Docker
sudo docker run --net=host -ti --rm --name your-container-name \ --privileged us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm_cxx11 \ bash
Installer JAX
pip install jax[tpu]
Installer FLAX
pip install --upgrade clu git clone https://github.com/google/flax.git pip install --user -e flax
Installer les packages
tensorflow
ettensorflow-dataset
pip install tensorflow pip install tensorflow-datasets
Exécutez le script d'entraînement FLAX MNIST.
cd flax/examples/mnist python3 main.py --workdir=/tmp/mnist \ --config=configs/default.py \ --config.learning_rate=0.05 \ --config.num_epochs=5
Une fois le script d'entraînement terminé, veillez à nettoyer les ressources.
- Saisissez
exit
pour quitter le conteneur Docker. - Saisissez
exit
pour quitter la VM TPU. Supprimer la VM TPU
gcloud compute tpus tpu-vm delete your-tpu-name --zone=europe-west4-a
Tranche TPU
Lorsque vous exécutez du code JAX sur une tranche TPU, vous devez exécuter votre code JAX sur tous les nœuds de travail TPU en même temps. Pour ce faire, vous pouvez utiliser la commande gcloud compute tpus tpu-vm ssh
avec les options --worker=all
et --command
. La procédure suivante vous explique comment créer une image Docker pour faciliter la configuration de chaque nœud de travail TPU.
Créez un fichier nommé
Dockerfile
dans votre répertoire actuel et collez-y le texte suivant :FROM python:3.10 RUN pip install jax[tpu] RUN pip install --upgrade clu RUN git clone https://github.com/google/flax.git RUN pip install --user -e flax RUN pip install tensorflow RUN pip install tensorflow-datasets WORKDIR ./flax/examples/mnist
Préparer un dépôt Artifact Registry
gcloud artifacts repositories create your-repo \ --repository-format=docker \ --location=europe-west4 --description="Docker repository" \ --project=your-project gcloud artifacts repositories list \ --project=your-project gcloud auth configure-docker europe-west4-docker.pkg.dev
Compiler l'image Docker
docker build -t your-image-name .
Ajoutez un tag à votre image Docker avant de la transférer vers Artifact Registry. Pour en savoir plus sur l'utilisation d'Artifact Registry, consultez Utiliser des images de conteneurs.
docker tag your-image-name europe-west4-docker.pkg.dev/your-project/your-repo/your-image-name:your-tag
Transférer votre image Docker dans Artifact Registry
docker push europe-west4-docker.pkg.dev/your-project/your-repo/your-image-name:your-tag
Créer une VM TPU
gcloud compute tpus tpu-vm create your-tpu-name \ --zone=europe-west4-a \ --accelerator-type=v2-8 \ --version=tpu-ubuntu2204-base
Extraire l'image Docker d'Artifact Registry sur tous les nœuds de travail TPU
gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \ --zone=europe-west4-a \ --command='sudo usermod -a -G docker ${USER}'
gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \ --zone=europe-west4-a \ --command="gcloud auth configure-docker europe-west4-docker.pkg.dev --quiet"
gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \ --zone=europe-west4-a \ --command="docker pull europe-west4-docker.pkg.dev/your-project/your-repo/your-image-name:your-tag"
Exécuter le conteneur sur tous les nœuds de travail TPU
gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \ --zone=europe-west4-a \ --command="docker run -ti -d --privileged --net=host --name your-container-name europe-west4-docker.pkg.dev/your-project/your-repo/your-image-name:your-tag bash"
Exécuter le script d'entraînement sur tous les nœuds TPU
gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \ --zone=europe-west4-a \ --command="docker exec --privileged your-container-name python3 main.py --workdir=/tmp/mnist \ --config=configs/default.py \ --config.learning_rate=0.05 \ --config.num_epochs=5"
Une fois le script d'entraînement terminé, veillez à nettoyer les ressources.
Arrêter le conteneur sur tous les nœuds de travail
gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \ --zone=europe-west4-a \ --command="docker kill your-container-name"
Supprimer la VM TPU
gcloud compute tpus tpu-vm delete your-tpu-name \ --zone=europe-west4-a
Entraîner un modèle JAX dans un conteneur Docker à l'aide de la pile stable JAX
Vous pouvez créer les images Docker MaxText et MaxDiffusion à l'aide de l'image de base de la pile stable JAX.
JAX Stable Stack fournit un environnement cohérent pour MaxText et MaxDiffusion en regroupant JAX avec des packages de base tels que orbax
, flax
, optax
et libtpu.so
. Ces bibliothèques sont testées pour garantir la compatibilité et fournir une base stable pour créer et exécuter MaxText et MaxDiffusion.
Cela élimine les conflits potentiels dus à des versions de paquets incompatibles.
La pile stable JAX inclut une libtpu.so
entièrement publiée et qualifiée, la bibliothèque de base qui gère la compilation, l'exécution et la configuration réseau ICI des programmes TPU. La version libtpu remplace le build quotidien précédemment utilisé par JAX et garantit la fonctionnalité cohérente des calculs XLA sur le TPU avec des tests de qualification au niveau PJRT dans les IR HLO/StableHLO.
Pour créer l'image Docker MaxText et MaxDiffusion avec la pile stable JAX, lorsque vous exécutez le script docker_build_dependency_image.sh
, définissez la variable MODE
sur stable_stack
et la variable BASEIMAGE
sur l'image de base que vous souhaitez utiliser.
docker_build_dependency_image.sh
se trouve dans le dépôt GitHub MaxDiffusion et dans le dépôt GitHub MaxText.
Clonez le dépôt que vous souhaitez utiliser et exécutez le script docker_build_dependency_image.sh
à partir de ce dépôt pour compiler l'image Docker.
git clone https://github.com/AI-Hypercomputer/maxdiffusion.git git clone https://github.com/AI-Hypercomputer/maxtext.git
La commande suivante génère une image Docker à utiliser avec MaxText et MaxDiffusion à l'aide de us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1
comme image de base.
sudo bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1
Pour obtenir la liste des images de base de la pile stable JAX disponibles, consultez la section Images de la pile stable JAX dans Artifact Registry.