在 TPU 切片上运行 PyTorch 代码

在运行本文档中的命令之前,请确保已按照设置账号和 Cloud TPU 项目中的说明操作。

在单个 TPU 虚拟机上运行 PyTorch 代码后,您可以通过在 TPU 切片上运行代码来扩容代码。TPU 切片是通过专用高速网络连接相互连接的多个 TPU 板。本文档介绍了如何在 TPU 切片上运行 PyTorch 代码。

创建 Cloud TPU 切片

  1. 定义一些环境变量,以便更轻松地使用这些命令。

    export PROJECT_ID=your-project
    export ACCELERATOR_TYPE=v5p-32
    export ZONE=europe-west4-b
    export RUNTIME_VERSION=v2-alpha-tpuv5
    export TPU_NAME=your-tpu-name

    环境变量说明

    PROJECT_ID
    您的 Google Cloud 项目 ID。
    ACCELERATOR_TYPE
    加速器类型用于指定您要创建的 Cloud TPU 的版本和大小。如需详细了解每个 TPU 版本支持的加速器类型,请参阅 TPU 版本
    ZONE
    拟在其中创建 Cloud TPU 的可用区
    RUNTIME_VERSION
    Cloud TPU 软件版本
    TPU_NAME
    用户为 Cloud TPU 分配的名称。
  2. 运行以下命令,创建 TPU 虚拟机:

    $ gcloud compute tpus tpu-vm create ${TPU_NAME} \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        --accelerator-type=${ACCELERATOR_TYPE} \
        --version=${RUNTIME_VERSION}

在您的 slice 上安装 PyTorch/XLA

创建 TPU 切片后,您必须在 TPU 切片中的所有主机上安装 PyTorch。您可以使用 gcloud compute tpus tpu-vm ssh 命令并使用 --worker=all--commamnd 参数来执行此操作。

如果以下命令因 SSH 连接错误而失败,可能是因为 TPU 虚拟机没有外部 IP 地址。如需访问没有外部 IP 地址的 TPU 虚拟机,请按照连接到没有公共 IP 地址的 TPU 虚拟机中的说明操作。

  1. 在所有 TPU 虚拟机工作器上安装 PyTorch/XLA:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        --worker=all \
        --command="pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html"
  2. 在所有 TPU VM 工作器上克隆 XLA:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        --worker=all \
        --command="git clone https://github.com/pytorch/xla.git"

在 TPU 切片上运行训练脚本

在所有工作器上运行训练脚本。训练脚本使用单程序多数据 (SPMD) 分片策略。如需详细了解 SPMD,请参阅 PyTorch/XLA SPMD 用户指南

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --zone=${ZONE} \
   --project=${PROJECT_ID} \
   --worker=all \
   --command="PJRT_DEVICE=TPU python3 ~/xla/test/spmd/test_train_spmd_imagenet.py  \
   --fake_data \
   --model=resnet50  \
   --num_epochs=1 2>&1 | tee ~/logs.txt"

训练大约需要 15 分钟。完成后,您应该会看到类似于下面这样的消息:

Epoch 1 test end 23:49:15, Accuracy=100.00
     10.164.0.11 [0] Max Accuracy: 100.00%

清理

完成 TPU 虚拟机的操作后,请按照以下步骤清理资源。

  1. 断开与 Cloud TPU 实例的连接(如果您尚未这样做):

    (vm)$ exit

    您的提示符现在应为 username@projectname,表明您位于 Cloud Shell 中。

  2. 删除您的 Cloud TPU 资源。

    $ gcloud compute tpus tpu-vm delete  \
        --zone=${ZONE}
  3. 通过运行 gcloud compute tpus tpu-vm list 验证资源是否已删除。删除操作可能需要几分钟时间才能完成。以下命令的输出不应包含本教程中创建的任何资源:

    $ gcloud compute tpus tpu-vm list --zone=${ZONE}