在 TPU 切片上运行 PyTorch 代码
在运行本文档中的命令之前,请确保已按照设置账号和 Cloud TPU 项目中的说明操作。
在单个 TPU 虚拟机上运行 PyTorch 代码后,您可以通过在 TPU 切片上运行代码来扩容代码。TPU 切片是通过专用高速网络连接相互连接的多个 TPU 板。本文档介绍了如何在 TPU 切片上运行 PyTorch 代码。
创建 Cloud TPU 切片
定义一些环境变量,以便更轻松地使用这些命令。
export PROJECT_ID=your-project export ACCELERATOR_TYPE=v5p-32 export ZONE=europe-west4-b export RUNTIME_VERSION=v2-alpha-tpuv5 export TPU_NAME=your-tpu-name
运行以下命令,创建 TPU 虚拟机:
$ gcloud compute tpus tpu-vm create ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --accelerator-type=${ACCELERATOR_TYPE} \ --version=${RUNTIME_VERSION}
在您的 slice 上安装 PyTorch/XLA
创建 TPU 切片后,您必须在 TPU 切片中的所有主机上安装 PyTorch。您可以使用 gcloud compute tpus tpu-vm ssh
命令并使用 --worker=all
和 --commamnd
参数来执行此操作。
如果以下命令因 SSH 连接错误而失败,可能是因为 TPU 虚拟机没有外部 IP 地址。如需访问没有外部 IP 地址的 TPU 虚拟机,请按照连接到没有公共 IP 地址的 TPU 虚拟机中的说明操作。
在所有 TPU 虚拟机工作器上安装 PyTorch/XLA:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --worker=all \ --command="pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html"
在所有 TPU VM 工作器上克隆 XLA:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --worker=all \ --command="git clone https://github.com/pytorch/xla.git"
在 TPU 切片上运行训练脚本
在所有工作器上运行训练脚本。训练脚本使用单程序多数据 (SPMD) 分片策略。如需详细了解 SPMD,请参阅 PyTorch/XLA SPMD 用户指南。
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --worker=all \ --command="PJRT_DEVICE=TPU python3 ~/xla/test/spmd/test_train_spmd_imagenet.py \ --fake_data \ --model=resnet50 \ --num_epochs=1 2>&1 | tee ~/logs.txt"
训练大约需要 15 分钟。完成后,您应该会看到类似于下面这样的消息:
Epoch 1 test end 23:49:15, Accuracy=100.00 10.164.0.11 [0] Max Accuracy: 100.00%
清理
完成 TPU 虚拟机的操作后,请按照以下步骤清理资源。
断开与 Cloud TPU 实例的连接(如果您尚未这样做):
(vm)$ exit
您的提示符现在应为
username@projectname
,表明您位于 Cloud Shell 中。删除您的 Cloud TPU 资源。
$ gcloud compute tpus tpu-vm delete \ --zone=${ZONE}
通过运行
gcloud compute tpus tpu-vm list
验证资源是否已删除。删除操作可能需要几分钟时间才能完成。以下命令的输出不应包含本教程中创建的任何资源:$ gcloud compute tpus tpu-vm list --zone=${ZONE}