TPU スライスで PyTorch コードを実行する

このドキュメントのコマンドを実行する前に、アカウントと Cloud TPU プロジェクトを設定するの説明に従ってください。

PyTorch コードを単一の TPU VM で実行したら、TPU スライスで実行してコードをスケールアップできます。TPU スライスは、専用の高速ネットワーク接続で互いに接続された複数の TPU ボードです。このドキュメントでは、TPU スライスで PyTorch コードを実行する方法について説明します。

Cloud TPU スライスを作成する

  1. コマンドを使いやすくするため、いくつかの環境変数を定義します。

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=europe-west4-b
    export ACCELERATOR_TYPE=v5p-32
    export RUNTIME_VERSION=v2-alpha-tpuv5

    環境変数の説明

    変数 説明
    PROJECT_ID 実際の Google Cloud のプロジェクト ID。既存のプロジェクトを使用するか、新しいプロジェクトを作成します
    TPU_NAME TPU の名前。
    ZONE TPU VM を作成するゾーン。サポートされているゾーンの詳細については、TPU のリージョンとゾーンをご覧ください。
    ACCELERATOR_TYPE アクセラレータ タイプでは、作成する Cloud TPU のバージョンとサイズを指定します。TPU の各バージョンでサポートされているアクセラレータ タイプの詳細については、TPU のバージョンをご覧ください。
    RUNTIME_VERSION Cloud TPU ソフトウェアのバージョン

  2. 次のコマンドを実行して TPU VM を作成します。

    $ gcloud compute tpus tpu-vm create ${TPU_NAME} \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        --accelerator-type=${ACCELERATOR_TYPE} \
        --version=${RUNTIME_VERSION}

スライスに PyTorch/XLA をインストールする

TPU スライスを作成したら、TPU スライスのすべてのホストに PyTorch をインストールする必要があります。これを行うには、--worker=all パラメータと --commamnd パラメータを使用して gcloud compute tpus tpu-vm ssh コマンドを実行します。

SSH 接続エラーが原因で次のコマンドが失敗した場合は、TPU VM に外部 IP アドレスがないことが原因である可能性があります。外部 IP アドレスのない TPU VM にアクセスするには、パブリック IP アドレスを持たない TPU VM に接続するの説明に従ってください。

  1. すべての TPU VM ワーカーに PyTorch/XLA をインストールします。

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        --worker=all \
        --command="pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html"
  2. すべての TPU VM ワーカーで XLA のクローンを作成します。

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        --worker=all \
        --command="git clone https://github.com/pytorch/xla.git"

TPU スライスでトレーニング スクリプトを実行する

すべてのワーカーでトレーニング スクリプトを実行します。トレーニング スクリプトでは、単一プログラム複数データ(SPMD)のシャーディング戦略を使用します。SPMD の詳細については、PyTorch/XLA SPMD ユーザーガイドをご覧ください。

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --zone=${ZONE} \
   --project=${PROJECT_ID} \
   --worker=all \
   --command="PJRT_DEVICE=TPU python3 ~/xla/test/spmd/test_train_spmd_imagenet.py  \
   --fake_data \
   --model=resnet50  \
   --num_epochs=1 2>&1 | tee ~/logs.txt"

トレーニングには 15 分ほどかかります。完了すると、次のようなメッセージが表示されます。

Epoch 1 test end 23:49:15, Accuracy=100.00
     10.164.0.11 [0] Max Accuracy: 100.00%

クリーンアップ

TPU VM の使用を終了したら、次の手順でリソースをクリーンアップします。

  1. Cloud TPU インスタンスとの接続を切断していない場合は切断します。

    (vm)$ exit

    プロンプトが username@projectname に変わります。これは、現在、Cloud Shell 内にいることを示しています。

  2. Cloud TPU リソースを削除します。

    $ gcloud compute tpus tpu-vm delete  \
        --zone=${ZONE}
  3. gcloud compute tpus tpu-vm list を実行して、リソースが削除されたことを確認します。削除には数分かかることがあります。次のコマンドの出力には、このチュートリアルで作成したリソースが含まれていないはずです。

    $ gcloud compute tpus tpu-vm list --zone=${ZONE}