在 TPU 配量上執行 PyTorch 程式碼
執行本文件中的指令前,請務必按照「設定帳戶和 Cloud TPU 專案」中的操作說明操作。
在單一 TPU VM 上執行 PyTorch 程式碼後,您可以將程式碼擴展到 TPU 配量。TPU 配量是指透過專用高速網路連線,彼此相連的多個 TPU 板。本文將介紹如何在 TPU 配量上執行 PyTorch 程式碼。
建立 Cloud TPU 分片
定義一些環境變數,方便使用指令。
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=europe-west4-b export ACCELERATOR_TYPE=v5p-32 export RUNTIME_VERSION=v2-alpha-tpuv5
請執行下列指令,建立 TPU VM:
$ gcloud compute tpus tpu-vm create ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --accelerator-type=${ACCELERATOR_TYPE} \ --version=${RUNTIME_VERSION}
在切片中安裝 PyTorch/XLA
建立 TPU 配量後,您必須在 TPU 配量中的所有主機上安裝 PyTorch。您可以使用 gcloud compute tpus tpu-vm ssh
指令,並使用 --worker=all
和 --commamnd
參數執行這項操作。
如果下列指令因 SSH 連線錯誤而失敗,可能是因為 TPU VM 沒有外部 IP 位址。如要連線至沒有外部 IP 位址的 TPU VM,請按照「連線至沒有公開 IP 位址的 TPU VM」一文中的操作說明進行。
在所有 TPU VM worker 上安裝 PyTorch/XLA:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --worker=all \ --command="pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html"
在所有 TPU VM worker 上複製 XLA:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --worker=all \ --command="git clone https://github.com/pytorch/xla.git"
在 TPU 分片上執行訓練指令碼
在所有工作站上執行訓練指令碼。訓練指令碼會使用單一程式多重資料 (SPMD) 分割策略。如要進一步瞭解 SPMD,請參閱 PyTorch/XLA SPMD 使用者指南。
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --worker=all \ --command="PJRT_DEVICE=TPU python3 ~/xla/test/spmd/test_train_spmd_imagenet.py \ --fake_data \ --model=resnet50 \ --num_epochs=1 2>&1 | tee ~/logs.txt"
訓練作業大約需要 15 分鐘。完成後,您應該會看到類似下方的訊息:
Epoch 1 test end 23:49:15, Accuracy=100.00 10.164.0.11 [0] Max Accuracy: 100.00%
清除所用資源
使用完 TPU VM 後,請按照下列步驟清理資源。
如果您尚未中斷與 Cloud TPU 執行個體的連線,請中斷連線:
(vm)$ exit
畫面上的提示現在應為
username@projectname
,表示您正在 Cloud Shell 中。刪除 Cloud TPU 資源。
$ gcloud compute tpus tpu-vm delete \ --zone=${ZONE}
執行
gcloud compute tpus tpu-vm list
來驗證資源是否已刪除。刪除作業可能需要幾分鐘才能完成。下列指令的輸出內容不應包含本教學課程中建立的任何資源:$ gcloud compute tpus tpu-vm list --zone=${ZONE}