TPU スライスで PyTorch コードを実行する
このドキュメントのコマンドを実行する前に、アカウントと Cloud TPU プロジェクトを設定するの説明に従ってください。
PyTorch コードを単一の TPU VM で実行したら、TPU スライスで実行してコードをスケールアップできます。TPU スライスは、専用の高速ネットワーク接続で互いに接続された複数の TPU ボードです。このドキュメントでは、TPU スライスで PyTorch コードを実行する方法について説明します。
Cloud TPU スライスを作成する
コマンドを使いやすくするため、いくつかの環境変数を定義します。
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=europe-west4-b export ACCELERATOR_TYPE=v5p-32 export RUNTIME_VERSION=v2-alpha-tpuv5
環境変数の説明
変数 説明 PROJECT_ID
実際の Google Cloud のプロジェクト ID。既存のプロジェクトを使用するか、新しいプロジェクトを作成します。 TPU_NAME
TPU の名前。 ZONE
TPU VM を作成するゾーン。サポートされているゾーンの詳細については、TPU のリージョンとゾーンをご覧ください。 ACCELERATOR_TYPE
アクセラレータ タイプでは、作成する Cloud TPU のバージョンとサイズを指定します。TPU の各バージョンでサポートされているアクセラレータ タイプの詳細については、TPU のバージョンをご覧ください。 RUNTIME_VERSION
Cloud TPU ソフトウェアのバージョン。 次のコマンドを実行して TPU VM を作成します。
$ gcloud compute tpus tpu-vm create ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --accelerator-type=${ACCELERATOR_TYPE} \ --version=${RUNTIME_VERSION}
スライスに PyTorch/XLA をインストールする
TPU スライスを作成したら、TPU スライスのすべてのホストに PyTorch をインストールする必要があります。これを行うには、--worker=all
パラメータと --commamnd
パラメータを使用して gcloud compute tpus tpu-vm ssh
コマンドを実行します。
SSH 接続エラーが原因で次のコマンドが失敗した場合は、TPU VM に外部 IP アドレスがないことが原因である可能性があります。外部 IP アドレスのない TPU VM にアクセスするには、パブリック IP アドレスを持たない TPU VM に接続するの説明に従ってください。
すべての TPU VM ワーカーに PyTorch/XLA をインストールします。
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --worker=all \ --command="pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html"
すべての TPU VM ワーカーで XLA のクローンを作成します。
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --worker=all \ --command="git clone https://github.com/pytorch/xla.git"
TPU スライスでトレーニング スクリプトを実行する
すべてのワーカーでトレーニング スクリプトを実行します。トレーニング スクリプトでは、単一プログラム複数データ(SPMD)のシャーディング戦略を使用します。SPMD の詳細については、PyTorch/XLA SPMD ユーザーガイドをご覧ください。
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --worker=all \ --command="PJRT_DEVICE=TPU python3 ~/xla/test/spmd/test_train_spmd_imagenet.py \ --fake_data \ --model=resnet50 \ --num_epochs=1 2>&1 | tee ~/logs.txt"
トレーニングには 15 分ほどかかります。完了すると、次のようなメッセージが表示されます。
Epoch 1 test end 23:49:15, Accuracy=100.00 10.164.0.11 [0] Max Accuracy: 100.00%
クリーンアップ
TPU VM の使用を終了したら、次の手順でリソースをクリーンアップします。
Cloud TPU インスタンスとの接続を切断していない場合は切断します。
(vm)$ exit
プロンプトが
username@projectname
に変わります。これは、現在、Cloud Shell 内にいることを示しています。Cloud TPU リソースを削除します。
$ gcloud compute tpus tpu-vm delete \ --zone=${ZONE}
gcloud compute tpus tpu-vm list
を実行して、リソースが削除されたことを確認します。削除には数分かかることがあります。次のコマンドの出力には、このチュートリアルで作成したリソースが含まれていないはずです。$ gcloud compute tpus tpu-vm list --zone=${ZONE}