PyTorch コードを TPU Pod スライスで実行する
PyTorch / XLA では、すべての TPU VM がモデルのコードとデータにアクセスできる必要があります。起動スクリプトを使用して、モデルデータをすべての TPU VM に分散させるのに必要なソフトウェアをダウンロードできます。
TPU VM を Virtual Private Cloud(VPC)に接続する場合は、ポート 8470~8479 の上り(内向き)を許可するファイアウォール ルールをプロジェクトに追加する必要があります。ファイアウォール ルールの追加方法については、ファイアウォール ルールの使用をご覧ください。
環境の設定
Cloud Shell で次のコマンドを実行して、
gcloud
の最新バージョンを実行していることを確認します。$ gcloud components update
gcloud
をインストールする必要がある場合は、次のコマンドを使用します。$ sudo apt install -y google-cloud-sdk
いくつかの環境変数を作成します。
$ export PROJECT_ID=project-id $ export TPU_NAME=tpu-name $ export ZONE=us-central2-b $ export RUNTIME_VERSION=tpu-ubuntu2204-base $ export ACCELERATOR_TYPE=v4-32
TPU VM を作成します。
$ gcloud compute tpus tpu-vm create ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --accelerator-type=${ACCELERATOR_TYPE} \ --version ${RUNTIME_VERSION}
トレーニング スクリプトを構成して実行する
プロジェクトに SSH 証明書を追加します。
ssh-add ~/.ssh/google_compute_engine
すべての TPU VM ワーカーに PyTorch/XLA をインストールする
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --worker=all --command=" pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html"
すべての TPU VM ワーカーで XLA のクローンを作成する
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --worker=all --command="git clone -b r2.5 https://github.com/pytorch/xla.git"
すべてのワーカーでトレーニング スクリプトを実行する
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --worker=all \ --command="PJRT_DEVICE=TPU python3 ~/xla/test/test_train_mp_imagenet.py \ --fake_data \ --model=resnet50 \ --num_epochs=1 2>&1 | tee ~/logs.txt"
トレーニングには約 5 分間を要します。完了すると、次のようなメッセージが表示されます。
Epoch 1 test end 23:49:15, Accuracy=100.00 10.164.0.11 [0] Max Accuracy: 100.00%
クリーンアップ
TPU VM の使用を終了したら、次の手順に沿ってリソースをクリーンアップします。
Compute Engine インスタンスとの接続を切断していない場合は切断します。
(vm)$ exit
プロンプトが
username@projectname
に変わります。これは、現在、Cloud Shell 内にいることを示しています。Cloud TPU と Compute Engine リソースを削除します。
$ gcloud compute tpus tpu-vm delete \ --zone=${ZONE}
gcloud compute tpus execution-groups list
を実行して、リソースが削除されたことを確認します。削除には数分かかることがあります。次のコマンドの出力には、このチュートリアルで作成したリソースを含めないでください。$ gcloud compute tpus tpu-vm list --zone=${ZONE}