TPU-Arbeitslasten in einem Docker-Container ausführen

Docker-Container erleichtern die Konfiguration von Anwendungen, da Code und alle erforderlichen Abhängigkeiten in einem distribuierbaren Paket kombiniert werden. Sie können Docker-Container in TPU-VMs ausführen, um die Konfiguration und Freigabe Ihrer Cloud TPU-Anwendungen zu vereinfachen. In diesem Dokument wird beschrieben, wie Sie einen Docker-Container für jedes von Cloud TPU unterstützte ML-Framework einrichten.

PyTorch-Modell in einem Docker-Container trainieren

TPU-Gerät

  1. Cloud TPU-VM erstellen

    gcloud compute tpus tpu-vm create your-tpu-name \
    --zone=europe-west4-a \
    --accelerator-type=v2-8 \
    --version=tpu-ubuntu2204-base
  2. SSH-Verbindung zur TPU-VM herstellen

    gcloud compute tpus tpu-vm ssh your-tpu-name \
    --zone=europe-west4-a
  3. Prüfen Sie, ob Ihrem Nutzer Google Cloud die Rolle „Artifact Registry Reader“ zugewiesen wurde. Weitere Informationen finden Sie unter Artifact Registry-Rollen gewähren.

  4. Container mit dem nächtlichen PyTorch/XLA-Image in der TPU-VM starten

    sudo docker run --net=host -ti --rm --name your-container-name --privileged us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm_cxx11 \
    bash
  5. TPU-Laufzeit konfigurieren

    Es gibt zwei PyTorch/XLA-Laufzeitoptionen: PJRT und XRT. Wir empfehlen, PJRT zu verwenden, es sei denn, Sie haben einen Grund, XRT zu verwenden. Weitere Informationen zu den verschiedenen Laufzeitkonfigurationen finden Sie in der PJRT-Laufzeitdokumentation.

    PJRT

    export PJRT_DEVICE=TPU

    XRT

    export XRT_TPU_CONFIG="localservice;0;localhost:51011"
  6. PyTorch XLA-Repository klonen

    git clone --recursive https://github.com/pytorch/xla.git
  7. ResNet50 trainieren

    python3 xla/test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1

Bereinigen Sie die Ressourcen nach Abschluss des Trainingsskripts.

  1. Geben Sie exit ein, um den Docker-Container zu beenden.
  2. Geben Sie exit ein, um die TPU-VM zu beenden.
  3. TPU-VM löschen

    gcloud compute tpus tpu-vm delete your-tpu-name --zone=europe-west4-a

TPU-Slice

Wenn Sie PyTorch-Code auf einem TPU-Speicherplatz ausführen, müssen Sie ihn gleichzeitig auf allen TPU-Workern ausführen. Dazu können Sie den Befehl gcloud compute tpus tpu-vm ssh mit den Flags --worker=all und --command verwenden. Im Folgenden wird beschrieben, wie Sie ein Docker-Image erstellen, um die Einrichtung der einzelnen TPU-Arbeitsstationen zu vereinfachen.

  1. TPU-VM erstellen

    gcloud compute tpus tpu-vm create your-tpu-name \
    --zone=us-central2-b \
    --accelerator-type=v4-32 \
    --version=tpu-ubuntu2204-base
  2. Den aktuellen Nutzer der Docker-Gruppe hinzufügen

    gcloud compute tpus tpu-vm ssh your-tpu-name \
    --zone=us-central2-b \
    --worker=all \
    --command='sudo usermod -a -G docker $USER'
  3. PyTorch XLA-Repository klonen

    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=us-central2-b \
    --command="git clone --recursive https://github.com/pytorch/xla.git"
  4. Trainingsskript in einem Container auf allen TPU-Workern ausführen

    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=us-central2-b \
    --command="docker run --rm --privileged --net=host  -v ~/xla:/xla -e PJRT_DEVICE=TPU us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm_cxx11 python /xla/test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1"

    Docker-Befehls-Flags:

    • --rm entfernt den Container, nachdem der Prozess beendet wurde.
    • --privileged stellt das TPU-Gerät dem Container zur Verfügung.
    • --net=host bindet alle Ports des Containers an die TPU-VM, um die Kommunikation zwischen den Hosts im Pod zu ermöglichen.
    • -e legt Umgebungsvariablen fest.

Bereinigen Sie die Ressourcen nach Abschluss des Trainingsskripts.

Löschen Sie die TPU-VM mit dem folgenden Befehl:

gcloud compute tpus tpu-vm delete your-tpu-name \
--zone=us-central2-b

JAX-Modell in einem Docker-Container trainieren

TPU-Gerät

  1. TPU-VM erstellen

    gcloud compute tpus tpu-vm create your-tpu-name \
    --zone=europe-west4-a \
    --accelerator-type=v2-8 \
    --version=tpu-ubuntu2204-base
  2. SSH-Verbindung zur TPU-VM herstellen

    gcloud compute tpus tpu-vm ssh your-tpu-name  --zone=europe-west4-a
  3. Docker-Daemon in der TPU-VM starten

    sudo systemctl start docker
  4. Docker-Container starten

    sudo docker run --net=host -ti --rm --name your-container-name \
    --privileged us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm_cxx11 \
    bash
  5. JAX installieren

    pip install jax[tpu]
  6. FLAX installieren

    pip install --upgrade clu
    git clone https://github.com/google/flax.git
    pip install --user -e flax
  7. tensorflow- und tensorflow-dataset-Pakete installieren

    pip install tensorflow
    pip install tensorflow-datasets
  8. FLAX MNIST-Trainingsskript ausführen

    cd flax/examples/mnist
    python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=5

Bereinigen Sie die Ressourcen nach Abschluss des Trainingsskripts.

  1. Geben Sie exit ein, um den Docker-Container zu beenden.
  2. Geben Sie exit ein, um die TPU-VM zu beenden.
  3. TPU-VM löschen

    gcloud compute tpus tpu-vm delete your-tpu-name --zone=europe-west4-a

TPU-Slice

Wenn Sie JAX-Code auf einem TPU-Speicherplatz ausführen, müssen Sie ihn gleichzeitig auf allen TPU-Workern ausführen. Dazu können Sie den Befehl gcloud compute tpus tpu-vm ssh mit den Flags --worker=all und --command verwenden. Im Folgenden wird beschrieben, wie Sie ein Docker-Image erstellen, um die Einrichtung der einzelnen TPU-Arbeitsstationen zu vereinfachen.

  1. Erstellen Sie im aktuellen Verzeichnis eine Datei mit dem Namen Dockerfile und fügen Sie den folgenden Text ein:

    FROM python:3.10
    RUN pip install jax[tpu]
    RUN pip install --upgrade clu
    RUN git clone https://github.com/google/flax.git
    RUN pip install --user -e flax
    RUN pip install tensorflow
    RUN pip install tensorflow-datasets
    WORKDIR ./flax/examples/mnist
  2. Artifact Registry vorbereiten

    gcloud artifacts repositories create your-repo \
    --repository-format=docker \
    --location=europe-west4 --description="Docker repository" \
    --project=your-project
    
    gcloud artifacts repositories list \
    --project=your-project
    
    gcloud auth configure-docker europe-west4-docker.pkg.dev
  3. Docker-Image erstellen

    docker build -t your-image-name .
  4. Fügen Sie Ihrem Docker-Image ein Tag hinzu, bevor Sie es in Artifact Registry veröffentlichen. Weitere Informationen zur Arbeit mit Artifact Registry finden Sie unter Mit Container-Images arbeiten.

    docker tag your-image-name europe-west4-docker.pkg.dev/your-project/your-repo/your-image-name:your-tag
  5. Docker-Image in Artifact Registry veröffentlichen

    docker push europe-west4-docker.pkg.dev/your-project/your-repo/your-image-name:your-tag
  6. TPU-VM erstellen

    gcloud compute tpus tpu-vm create your-tpu-name \
    --zone=europe-west4-a \
    --accelerator-type=v2-8 \
    --version=tpu-ubuntu2204-base
  7. Docker-Image auf allen TPU-Arbeitsstationen aus der Artifact Registry abrufen

    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=europe-west4-a \
    --command='sudo usermod -a -G docker ${USER}'
    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=europe-west4-a \
    --command="gcloud auth configure-docker europe-west4-docker.pkg.dev --quiet"
    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=europe-west4-a \
    --command="docker pull europe-west4-docker.pkg.dev/your-project/your-repo/your-image-name:your-tag"
  8. Container auf allen TPU-Workern ausführen

    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=europe-west4-a \
    --command="docker run -ti -d --privileged --net=host --name your-container-name europe-west4-docker.pkg.dev/your-project/your-repo/your-image-name:your-tag bash"
  9. Trainingsskript auf allen TPU-Arbeitsstationen ausführen

    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=europe-west4-a \
    --command="docker exec --privileged your-container-name python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=5"

Bereinigen Sie die Ressourcen nach Abschluss des Trainingsskripts.

  1. Container auf allen Arbeitsstationen herunterfahren

    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=europe-west4-a \
    --command="docker kill your-container-name"
  2. TPU-VM löschen

    gcloud compute tpus tpu-vm delete your-tpu-name \
    --zone=europe-west4-a

JAX-Modell in einem Docker-Container mit JAX Stable Stack trainieren

Sie können die Docker-Images MaxText und MaxDiffusion mit dem Basis-Image JAX Stable Stack erstellen.

JAX Stable Stack bietet eine einheitliche Umgebung für MaxText und MaxDiffusion, da JAX mit Kernpaketen wie orbax, flax, optax und libtpu.so gebündelt wird. Diese Bibliotheken werden auf Kompatibilität getestet und bieten eine stabile Grundlage für die Erstellung und Ausführung von MaxText und MaxDiffusion. So werden potenzielle Konflikte aufgrund von inkompatiblen Paketversionen vermieden.

Der stabile JAX-Stack enthält eine vollständig veröffentlichte und qualifizierte libtpu.so, die Kernbibliothek, die die Kompilierung, Ausführung und ICI-Netzwerkkonfiguration von TPU-Programmen steuert. Der libtpu-Release ersetzt den bisher von JAX verwendeten Nightly-Build und sorgt mit Qualifikationstests auf PJRT-Ebene in HLO/StableHLO-IRs für eine konsistente Funktion von XLA-Berechnungen auf TPUs.

Wenn Sie das Docker-Image für MaxText und MaxDiffusion mit dem JAX Stable Stack erstellen möchten, legen Sie beim Ausführen des docker_build_dependency_image.sh-Scripts die Variable MODE auf stable_stack und die Variable BASEIMAGE auf das gewünschte Basis-Image fest.

docker_build_dependency_image.sh befindet sich im GitHub-Repository von MaxDiffusion und im GitHub-Repository von MaxText. Klonen Sie das gewünschte Repository und führen Sie das docker_build_dependency_image.sh-Script aus diesem Repository aus, um das Docker-Image zu erstellen.

git clone https://github.com/AI-Hypercomputer/maxdiffusion.git
git clone https://github.com/AI-Hypercomputer/maxtext.git

Mit dem folgenden Befehl wird ein Docker-Image für die Verwendung mit MaxText und MaxDiffusion generiert. Dabei wird us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1 als Basis-Image verwendet.

sudo bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1

Eine Liste der verfügbaren JAX Stable Stack-Basis-Images finden Sie unter JAX Stable Stack-Images in Artifact Registry.

Nächste Schritte