Docker コンテナで TPU ワークロードを実行する

Docker コンテナでは、コードとすべての必要な依存関係が 1 つの配布可能なパッケージにまとまることで、アプリケーションの構成が簡単になります。TPU VM 内で Docker コンテナを実行すると、Cloud TPU アプリケーションの構成と共有を簡素化できます。このドキュメントでは、Cloud TPU でサポートされている各 ML フレームワークに Docker コンテナを設定する方法について説明します。

Docker コンテナで PyTorch モデルをトレーニングする

TPU デバイス

  1. Cloud TPU VM を作成します。

    gcloud compute tpus tpu-vm create your-tpu-name \
    --zone=europe-west4-a \
    --accelerator-type=v2-8 \
    --version=tpu-ubuntu2204-base
  2. SSH を使用して TPU VM に接続します。

    gcloud compute tpus tpu-vm ssh your-tpu-name \
    --zone=europe-west4-a
  3. Google Cloud ユーザーに Artifact Registry 読み取りロールが付与されていることを確認します。詳細については、Artifact Registry ロールの付与をご覧ください。

  4. PyTorch/XLA ナイトリー イメージを使用して TPU VM でコンテナを起動します。

    sudo docker run --net=host -ti --rm --name your-container-name --privileged us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm_cxx11 \
    bash
  5. TPU ランタイムを構成します。

    PyTorch/XLA ランタイムには、PJRT と XRT の 2 つのオプションがあります。XRT を使用する理由がない限り、PJRT を使用することをおすすめします。さまざまなランタイム構成の詳細については、PJRT ランタイム ドキュメントをご覧ください。

    PJRT

    export PJRT_DEVICE=TPU

    XRT

    export XRT_TPU_CONFIG="localservice;0;localhost:51011"
  6. PyTorch XLA リポジトリのクローンを作成します。

    git clone --recursive https://github.com/pytorch/xla.git
  7. ResNet50 をトレーニングします。

    python3 xla/test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1

トレーニング スクリプトが完了したら、リソースをクリーンアップしてください。

  1. exit」と入力して Docker コンテナを終了します。
  2. exit」と入力して TPU VM を終了します。
  3. TPU VM を削除します。

    gcloud compute tpus tpu-vm delete your-tpu-name --zone=europe-west4-a

TPU スライス

TPU スライスで PyTorch コードを実行する場合は、すべての TPU ワーカーでコードを同時に実行する必要があります。これを行う方法の一つに、--worker=all フラグと --command フラグを指定して gcloud compute tpus tpu-vm ssh コマンドを使用する方法があります。次の手順では、各 TPU ワーカーを簡単に設定できるように Docker イメージを作成する方法を示します。

  1. TPU VM を作成します。

    gcloud compute tpus tpu-vm create your-tpu-name \
    --zone=us-central2-b \
    --accelerator-type=v4-32 \
    --version=tpu-ubuntu2204-base
  2. 現在のユーザーを Docker グループに追加します。

    gcloud compute tpus tpu-vm ssh your-tpu-name \
    --zone=us-central2-b \
    --worker=all \
    --command='sudo usermod -a -G docker $USER'
  3. PyTorch XLA リポジトリのクローンを作成します。

    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=us-central2-b \
    --command="git clone --recursive https://github.com/pytorch/xla.git"
  4. すべての TPU ワーカーでコンテナ内のトレーニング スクリプトを実行します。

    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=us-central2-b \
    --command="docker run --rm --privileged --net=host  -v ~/xla:/xla -e PJRT_DEVICE=TPU us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm_cxx11 python /xla/test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1"

    Docker コマンドフラグ:

    • --rm は、プロセスが終了した後にコンテナを削除します。
    • --privileged は、TPU デバイスをコンテナに公開します。
    • --net=host は、コンテナのすべてのポートを TPU VM にバインドして、Pod 内のホスト間の通信を可能にします。
    • -e は環境変数を設定します。

トレーニング スクリプトが完了したら、リソースをクリーンアップしてください。

次のコマンドを使用して TPU VM を削除します。

gcloud compute tpus tpu-vm delete your-tpu-name \
--zone=us-central2-b

Docker コンテナで JAX モデルをトレーニングする

TPU デバイス

  1. TPU VM を作成します。

    gcloud compute tpus tpu-vm create your-tpu-name \
    --zone=europe-west4-a \
    --accelerator-type=v2-8 \
    --version=tpu-ubuntu2204-base
  2. SSH を使用して TPU VM に接続します。

    gcloud compute tpus tpu-vm ssh your-tpu-name  --zone=europe-west4-a
  3. TPU VM で Docker デーモンを起動します。

    sudo systemctl start docker
  4. Docker コンテナを起動します。

    sudo docker run --net=host -ti --rm --name your-container-name \
    --privileged us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm_cxx11 \
    bash
  5. JAX をインストールします。

    pip install jax[tpu]
  6. Flax をインストールします。

    pip install --upgrade clu
    git clone https://github.com/google/flax.git
    pip install --user -e flax
  7. tensorflow パッケージと tensorflow-dataset パッケージをインストールします。

    pip install tensorflow
    pip install tensorflow-datasets
  8. Flax MNIST トレーニング スクリプトを実行します。

    cd flax/examples/mnist
    python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=5

トレーニング スクリプトが完了したら、リソースをクリーンアップしてください。

  1. exit」と入力して Docker コンテナを終了します。
  2. exit」と入力して TPU VM を終了します。
  3. TPU VM を削除します。

    gcloud compute tpus tpu-vm delete your-tpu-name --zone=europe-west4-a

TPU スライス

TPU スライスで JAX コードを実行する場合は、すべての TPU ワーカーで JAX コードを同時に実行する必要があります。これを行う方法の一つに、--worker=all フラグと --command フラグを指定して gcloud compute tpus tpu-vm ssh コマンドを使用する方法があります。次の手順では、各 TPU ワーカーを簡単に設定できるように Docker イメージを作成する方法を示します。

  1. 現在のディレクトリに Dockerfile という名前のファイルを作成し、次のテキストを貼り付けます。

    FROM python:3.10
    RUN pip install jax[tpu]
    RUN pip install --upgrade clu
    RUN git clone https://github.com/google/flax.git
    RUN pip install --user -e flax
    RUN pip install tensorflow
    RUN pip install tensorflow-datasets
    WORKDIR ./flax/examples/mnist
  2. Artifact Registry を準備します。

    gcloud artifacts repositories create your-repo \
    --repository-format=docker \
    --location=europe-west4 --description="Docker repository" \
    --project=your-project
    
    gcloud artifacts repositories list \
    --project=your-project
    
    gcloud auth configure-docker europe-west4-docker.pkg.dev
  3. Docker イメージをビルドします。

    docker build -t your-image-name .
  4. Docker イメージにタグを追加してから、Artifact Registry に push します。Artifact Registry の操作の詳細については、コンテナ イメージの使用をご覧ください。

    docker tag your-image-name europe-west4-docker.pkg.dev/your-project/your-repo/your-image-name:your-tag
  5. Docker イメージを Artifact Registry に push します。

    docker push europe-west4-docker.pkg.dev/your-project/your-repo/your-image-name:your-tag
  6. TPU VM を作成します。

    gcloud compute tpus tpu-vm create your-tpu-name \
    --zone=europe-west4-a \
    --accelerator-type=v2-8 \
    --version=tpu-ubuntu2204-base
  7. すべての TPU ワーカーで Artifact Registry から Docker イメージを pull します。

    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=europe-west4-a \
    --command='sudo usermod -a -G docker ${USER}'
    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=europe-west4-a \
    --command="gcloud auth configure-docker europe-west4-docker.pkg.dev --quiet"
    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=europe-west4-a \
    --command="docker pull europe-west4-docker.pkg.dev/your-project/your-repo/your-image-name:your-tag"
  8. すべての TPU ワーカーでコンテナを実行します。

    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=europe-west4-a \
    --command="docker run -ti -d --privileged --net=host --name your-container-name europe-west4-docker.pkg.dev/your-project/your-repo/your-image-name:your-tag bash"
  9. すべての TPU ワーカーでトレーニング スクリプトを実行します。

    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=europe-west4-a \
    --command="docker exec --privileged your-container-name python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=5"

トレーニング スクリプトが完了したら、リソースをクリーンアップしてください。

  1. すべてのワーカーでコンテナをシャットダウンします。

    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=europe-west4-a \
    --command="docker kill your-container-name"
  2. TPU VM を削除します。

    gcloud compute tpus tpu-vm delete your-tpu-name \
    --zone=europe-west4-a

JAX Stable Stack を使用して Docker コンテナで JAX モデルをトレーニングする

JAX Stable Stack ベースイメージを使用して、MaxTextMaxDiffusion Docker イメージをビルドできます。

JAX Stable Stack は、JAX を orbaxflaxoptaxlibtpu.so などのコアパッケージとバンドルすることで、MaxText と MaxDiffusion の一貫した環境を提供します。これらのライブラリは、互換性を確保し、MaxText と MaxDiffusion のビルドと実行のための安定した基盤を提供するためにテストされています。これにより、互換性のないパッケージ バージョンによる競合の発生を防ぐことができます。

JAX Stable Stack には、完全にリリースされ、適格性を確認済みの libtpu.so が含まれています。これは、TPU プログラムのコンパイル、実行、ICI ネットワーク構成を駆動するコアライブラリです。libtpu リリースは、JAX で以前に使用されていたナイトリー ビルドに代わるもので、HLO/StableHLO IR で PJRT レベルの適格性テストを行い、TPU での XLA 計算の一貫した機能を保証します。

JAX Stable Stack で MaxText と MaxDiffusion の Docker イメージをビルドするには、docker_build_dependency_image.sh スクリプトを実行するときに、MODE 変数を stable_stack に設定し、BASEIMAGE 変数を使用するベースイメージに設定します。

docker_build_dependency_image.sh は、MaxDiffusion GitHub リポジトリMaxText GitHub リポジトリにあります。使用するリポジトリのクローンを作成し、そのリポジトリの docker_build_dependency_image.sh スクリプトを実行して Docker イメージをビルドします。

git clone https://github.com/AI-Hypercomputer/maxdiffusion.git
git clone https://github.com/AI-Hypercomputer/maxtext.git

次のコマンドは、us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1 をベースイメージとして使用して、MaxText と MaxDiffusion で使用する Docker イメージを生成します。

sudo bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1

使用可能な JAX Stable Stack ベースイメージの一覧については、Artifact Registry の JAX Stable Stack イメージをご覧ください。

次のステップ