TPU-Arbeitslasten in einem Docker-Container ausführen
Docker-Container erleichtern die Konfiguration von Anwendungen, da Code und alle erforderlichen Abhängigkeiten in einem distribuierbaren Paket kombiniert werden. Sie können Docker-Container in TPU-VMs ausführen, um die Konfiguration und Freigabe Ihrer Cloud TPU-Anwendungen zu vereinfachen. In diesem Dokument wird beschrieben, wie Sie einen Docker-Container für jedes von Cloud TPU unterstützte ML-Framework einrichten.
PyTorch-Modell in einem Docker-Container trainieren
TPU-Gerät
Cloud TPU-VM erstellen
gcloud compute tpus tpu-vm create your-tpu-name \ --zone=europe-west4-a \ --accelerator-type=v2-8 \ --version=tpu-ubuntu2204-base
SSH-Verbindung zur TPU-VM herstellen
gcloud compute tpus tpu-vm ssh your-tpu-name \ --zone=europe-west4-a
Prüfen Sie, ob Ihrem Nutzer Google Cloud die Rolle „Artifact Registry Reader“ zugewiesen wurde. Weitere Informationen finden Sie unter Artifact Registry-Rollen gewähren.
Container mit dem nächtlichen PyTorch/XLA-Image in der TPU-VM starten
sudo docker run --net=host -ti --rm --name your-container-name --privileged us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm_cxx11 \ bash
TPU-Laufzeit konfigurieren
Es gibt zwei PyTorch/XLA-Laufzeitoptionen: PJRT und XRT. Wir empfehlen, PJRT zu verwenden, es sei denn, Sie haben einen Grund, XRT zu verwenden. Weitere Informationen zu den verschiedenen Laufzeitkonfigurationen finden Sie in der PJRT-Laufzeitdokumentation.
PJRT
export PJRT_DEVICE=TPU
XRT
export XRT_TPU_CONFIG="localservice;0;localhost:51011"
PyTorch XLA-Repository klonen
git clone --recursive https://github.com/pytorch/xla.git
ResNet50 trainieren
python3 xla/test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1
Bereinigen Sie die Ressourcen nach Abschluss des Trainingsskripts.
- Geben Sie
exit
ein, um den Docker-Container zu beenden. - Geben Sie
exit
ein, um die TPU-VM zu beenden. TPU-VM löschen
gcloud compute tpus tpu-vm delete your-tpu-name --zone=europe-west4-a
TPU-Slice
Wenn Sie PyTorch-Code auf einem TPU-Speicherplatz ausführen, müssen Sie ihn gleichzeitig auf allen TPU-Workern ausführen. Dazu können Sie den Befehl gcloud compute tpus tpu-vm ssh
mit den Flags --worker=all
und --command
verwenden. Im Folgenden wird beschrieben, wie Sie ein Docker-Image erstellen, um die Einrichtung der einzelnen TPU-Arbeitsstationen zu vereinfachen.
TPU-VM erstellen
gcloud compute tpus tpu-vm create your-tpu-name \ --zone=us-central2-b \ --accelerator-type=v4-32 \ --version=tpu-ubuntu2204-base
Den aktuellen Nutzer der Docker-Gruppe hinzufügen
gcloud compute tpus tpu-vm ssh your-tpu-name \ --zone=us-central2-b \ --worker=all \ --command='sudo usermod -a -G docker $USER'
PyTorch XLA-Repository klonen
gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \ --zone=us-central2-b \ --command="git clone --recursive https://github.com/pytorch/xla.git"
Trainingsskript in einem Container auf allen TPU-Workern ausführen
gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \ --zone=us-central2-b \ --command="docker run --rm --privileged --net=host -v ~/xla:/xla -e PJRT_DEVICE=TPU us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm_cxx11 python /xla/test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1"
Docker-Befehls-Flags:
--rm
entfernt den Container, nachdem der Prozess beendet wurde.--privileged
stellt das TPU-Gerät dem Container zur Verfügung.--net=host
bindet alle Ports des Containers an die TPU-VM, um die Kommunikation zwischen den Hosts im Pod zu ermöglichen.-e
legt Umgebungsvariablen fest.
Bereinigen Sie die Ressourcen nach Abschluss des Trainingsskripts.
Löschen Sie die TPU-VM mit dem folgenden Befehl:
gcloud compute tpus tpu-vm delete your-tpu-name \ --zone=us-central2-b
JAX-Modell in einem Docker-Container trainieren
TPU-Gerät
TPU-VM erstellen
gcloud compute tpus tpu-vm create your-tpu-name \ --zone=europe-west4-a \ --accelerator-type=v2-8 \ --version=tpu-ubuntu2204-base
SSH-Verbindung zur TPU-VM herstellen
gcloud compute tpus tpu-vm ssh your-tpu-name --zone=europe-west4-a
Docker-Daemon in der TPU-VM starten
sudo systemctl start docker
Docker-Container starten
sudo docker run --net=host -ti --rm --name your-container-name \ --privileged us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm_cxx11 \ bash
JAX installieren
pip install jax[tpu]
FLAX installieren
pip install --upgrade clu git clone https://github.com/google/flax.git pip install --user -e flax
tensorflow
- undtensorflow-dataset
-Pakete installierenpip install tensorflow pip install tensorflow-datasets
FLAX MNIST-Trainingsskript ausführen
cd flax/examples/mnist python3 main.py --workdir=/tmp/mnist \ --config=configs/default.py \ --config.learning_rate=0.05 \ --config.num_epochs=5
Bereinigen Sie die Ressourcen nach Abschluss des Trainingsskripts.
- Geben Sie
exit
ein, um den Docker-Container zu beenden. - Geben Sie
exit
ein, um die TPU-VM zu beenden. TPU-VM löschen
gcloud compute tpus tpu-vm delete your-tpu-name --zone=europe-west4-a
TPU-Slice
Wenn Sie JAX-Code auf einem TPU-Speicherplatz ausführen, müssen Sie ihn gleichzeitig auf allen TPU-Workern ausführen. Dazu können Sie den Befehl gcloud compute tpus tpu-vm ssh
mit den Flags --worker=all
und --command
verwenden. Im Folgenden wird beschrieben, wie Sie ein Docker-Image erstellen, um die Einrichtung der einzelnen TPU-Arbeitsstationen zu vereinfachen.
Erstellen Sie im aktuellen Verzeichnis eine Datei mit dem Namen
Dockerfile
und fügen Sie den folgenden Text ein:FROM python:3.10 RUN pip install jax[tpu] RUN pip install --upgrade clu RUN git clone https://github.com/google/flax.git RUN pip install --user -e flax RUN pip install tensorflow RUN pip install tensorflow-datasets WORKDIR ./flax/examples/mnist
Artifact Registry vorbereiten
gcloud artifacts repositories create your-repo \ --repository-format=docker \ --location=europe-west4 --description="Docker repository" \ --project=your-project gcloud artifacts repositories list \ --project=your-project gcloud auth configure-docker europe-west4-docker.pkg.dev
Docker-Image erstellen
docker build -t your-image-name .
Fügen Sie Ihrem Docker-Image ein Tag hinzu, bevor Sie es in Artifact Registry veröffentlichen. Weitere Informationen zur Arbeit mit Artifact Registry finden Sie unter Mit Container-Images arbeiten.
docker tag your-image-name europe-west4-docker.pkg.dev/your-project/your-repo/your-image-name:your-tag
Docker-Image in Artifact Registry veröffentlichen
docker push europe-west4-docker.pkg.dev/your-project/your-repo/your-image-name:your-tag
TPU-VM erstellen
gcloud compute tpus tpu-vm create your-tpu-name \ --zone=europe-west4-a \ --accelerator-type=v2-8 \ --version=tpu-ubuntu2204-base
Docker-Image auf allen TPU-Arbeitsstationen aus der Artifact Registry abrufen
gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \ --zone=europe-west4-a \ --command='sudo usermod -a -G docker ${USER}'
gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \ --zone=europe-west4-a \ --command="gcloud auth configure-docker europe-west4-docker.pkg.dev --quiet"
gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \ --zone=europe-west4-a \ --command="docker pull europe-west4-docker.pkg.dev/your-project/your-repo/your-image-name:your-tag"
Container auf allen TPU-Workern ausführen
gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \ --zone=europe-west4-a \ --command="docker run -ti -d --privileged --net=host --name your-container-name europe-west4-docker.pkg.dev/your-project/your-repo/your-image-name:your-tag bash"
Trainingsskript auf allen TPU-Arbeitsstationen ausführen
gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \ --zone=europe-west4-a \ --command="docker exec --privileged your-container-name python3 main.py --workdir=/tmp/mnist \ --config=configs/default.py \ --config.learning_rate=0.05 \ --config.num_epochs=5"
Bereinigen Sie die Ressourcen nach Abschluss des Trainingsskripts.
Container auf allen Arbeitsstationen herunterfahren
gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \ --zone=europe-west4-a \ --command="docker kill your-container-name"
TPU-VM löschen
gcloud compute tpus tpu-vm delete your-tpu-name \ --zone=europe-west4-a
JAX-Modell in einem Docker-Container mit JAX Stable Stack trainieren
Sie können die Docker-Images MaxText und MaxDiffusion mit dem Basis-Image JAX Stable Stack erstellen.
JAX Stable Stack bietet eine einheitliche Umgebung für MaxText und MaxDiffusion, da JAX mit Kernpaketen wie orbax
, flax
, optax
und libtpu.so
gebündelt wird. Diese Bibliotheken werden auf Kompatibilität getestet und bieten eine stabile Grundlage für die Erstellung und Ausführung von MaxText und MaxDiffusion.
So werden potenzielle Konflikte aufgrund von inkompatiblen Paketversionen vermieden.
Der stabile JAX-Stack enthält eine vollständig veröffentlichte und qualifizierte libtpu.so
, die Kernbibliothek, die die Kompilierung, Ausführung und ICI-Netzwerkkonfiguration von TPU-Programmen steuert. Der libtpu-Release ersetzt den bisher von JAX verwendeten Nightly-Build und sorgt mit Qualifikationstests auf PJRT-Ebene in HLO/StableHLO-IRs für eine konsistente Funktion von XLA-Berechnungen auf TPUs.
Wenn Sie das Docker-Image für MaxText und MaxDiffusion mit dem JAX Stable Stack erstellen möchten, legen Sie beim Ausführen des docker_build_dependency_image.sh
-Scripts die Variable MODE
auf stable_stack
und die Variable BASEIMAGE
auf das gewünschte Basis-Image fest.
docker_build_dependency_image.sh
befindet sich im GitHub-Repository von MaxDiffusion und im GitHub-Repository von MaxText.
Klonen Sie das gewünschte Repository und führen Sie das docker_build_dependency_image.sh
-Script aus diesem Repository aus, um das Docker-Image zu erstellen.
git clone https://github.com/AI-Hypercomputer/maxdiffusion.git git clone https://github.com/AI-Hypercomputer/maxtext.git
Mit dem folgenden Befehl wird ein Docker-Image für die Verwendung mit MaxText und MaxDiffusion generiert. Dabei wird us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1
als Basis-Image verwendet.
sudo bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1
Eine Liste der verfügbaren JAX Stable Stack-Basis-Images finden Sie unter JAX Stable Stack-Images in Artifact Registry.