Cloud TPU v5e 训练

TPU v5e 的每个 Pod 占用空间更小(256 个芯片),经过优化,成为适用于 Transformer、文本转图片和卷积神经网络 (CNN) 训练、微调和服务的高价值产品。如需详细了解如何使用 Cloud TPU v5e 进行服务,请参阅使用 v5e 进行推断

如需详细了解 Cloud TPU v5e TPU 硬件和配置,请参阅 TPU v5e

开始使用

以下部分介绍了如何开始使用 TPU v5e。

请求配额

您需要有配额才能使用 TPU v5e 进行训练。按需 TPU、预留的 TPU 和 TPU Spot 虚拟机有不同的配额类型。如果您将 TPU v5e 用于推理,则需要单独的配额。如需详细了解配额,请参阅配额。如需申请 TPU v5e 配额,请与 Cloud 销售团队联系。

创建 Google Cloud 账号和项目

您需要拥有 Google Cloud 账号和项目才能使用 Cloud TPU。如需了解详情,请参阅设置 Cloud TPU 环境

创建 Cloud TPU

最佳实践是使用 queued-resource create 命令将 Cloud TPU v5e 预配为已排队的资源。如需了解详情,请参阅管理已排队的资源

您还可以使用 Create Node API (gcloud compute tpus tpu-vm create) 来预配 Cloud TPU v5e。如需了解详情,请参阅管理 TPU 资源

如需详细了解可用于训练的 v5e 配置,请参阅用于训练的 Cloud TPU v5e 类型

框架设置

本部分介绍了结合使用 JAX 或 PyTorch 与 TPU v5e 进行自定义模型训练的一般设置过程。

如需查看推理设置说明,请参阅 v5e 推理简介

定义一些环境变量:

export PROJECT_ID=your_project_ID
export ACCELERATOR_TYPE=v5litepod-16
export ZONE=us-west4-a
export TPU_NAME=your_tpu_name
export QUEUED_RESOURCE_ID=your_queued_resource_id

JAX 设置

如果切片形状大于 8 个芯片,则一个切片中会有多个虚拟机。在这种情况下,您需要使用 --worker=all 标志在一个步骤中对所有 TPU 虚拟机运行安装,而无需使用 SSH 单独登录每个虚拟机:

gcloud compute tpus tpu-vm ssh ${TPU_NAME}  \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

命令标志说明

变量 说明
TPU_NAME 用户分配的 TPU 文本 ID,该 ID 是在分配已排队的资源请求时创建的。
PROJECT_ID Google Cloud 项目名称。使用现有项目或在设置 Google Cloud 项目时创建新项目
ZONE 如需了解支持的可用区,请参阅 Cloud TPU 区域和可用区文档。
worker 有权访问底层 TPU 的 TPU 虚拟机。

您可以运行以下命令来检查设备数量(此处显示的输出是使用 v5litepod-16 切片生成的)。此代码通过检查 JAX 是否看到 Cloud TPU TensorCore 并可以运行基本操作来测试是否已正确安装所有组件:

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='python3 -c "import jax; print(jax.device_count()); print(jax.local_device_count())"'

输出将如下所示:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16
4
16
4
16
4
16
4

jax.device_count() 显示给定切片中的芯片总数。jax.local_device_count() 表示此切片中的单个虚拟机可访问的芯片数量。

# Check the number of chips in the given slice by summing the count of chips
# from all VMs through the
# jax.local_device_count() API call.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='python3 -c "import jax; xs=jax.numpy.ones(jax.local_device_count()); print(jax.pmap(lambda x: jax.lax.psum(x, \"i\"), axis_name=\"i\")(xs))"'

输出将如下所示:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]

请尝试本文档中的 JAX 教程,以开始使用 JAX 进行 v5e 训练。

PyTorch 设置

请注意,v5e 仅支持 PJRT 运行时,并且 PyTorch 2.1+ 将 PJRT 用作所有 TPU 版本的默认运行时。

本部分介绍了如何开始结合使用 v5e 上的 PJRT 与 PyTorch/XLA,并为所有工作器提供命令。

安装依赖项

gcloud compute tpus tpu-vm ssh ${TPU_NAME}  \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
      sudo apt-get update -y
      sudo apt-get install libomp5 -y
      pip install mkl mkl-include
      pip install tf-nightly tb-nightly tbp-nightly
      pip install numpy
      sudo apt-get install libopenblas-dev -y
      pip install torch~=PYTORCH_VERSION torchvision torch_xla[tpu]~=PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'

PYTORCH_VERSION 替换为您要使用的 PyTorch 版本。PYTORCH_VERSION 用于为 PyTorch/XLA 指定相同的版本。建议使用 2.6.0。

如需详细了解 PyTorch 和 PyTorch/XLA 的版本,请参阅 PyTorch - 使用入门PyTorch/XLA 版本

如需详细了解如何安装 PyTorch/XLA,请参阅 PyTorch/XLA 安装

如果您在安装 torchtorch_xlatorchvision(如 pkg_resources.extern.packaging.requirements.InvalidRequirement: Expected end or semicolon (after name and no valid version specifier) torch==nightly+20230222)的 wheel 时遇到错误,请使用以下命令降级您的版本:

pip3 install setuptools==62.1.0

使用 PJRT 运行脚本

unset LD_PRELOAD

以下示例展示了如何使用 Python 脚本对 v5e 虚拟机执行计算:

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
      export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.local/lib/
      export PJRT_DEVICE=TPU
      export PT_XLA_DEBUG=0
      export USE_TORCH=ON
      unset LD_PRELOAD
      export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
      python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"'

这将生成如下所示的输出:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
xla:0
tensor([[ 1.8611, -0.3114, -2.4208],
[-1.0731, 0.3422, 3.1445],
[ 0.5743, 0.2379, 1.1105]], device='xla:0')
xla:0
tensor([[ 1.8611, -0.3114, -2.4208],
[-1.0731, 0.3422, 3.1445],
[ 0.5743, 0.2379, 1.1105]], device='xla:0')

请尝试使用本文档中的 PyTorch 教程,开始使用 PyTorch 进行 v5e 训练。

在会话结束时删除 TPU 和已排队的资源。如需删除已排队的资源,请按以下 2 个步骤删除切片,然后删除已排队的资源:

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

这两个步骤还可用于移除处于 FAILED 状态的已排队资源请求。

JAX/FLAX 示例

以下部分介绍了如何在 TPU v5e 上训练 JAX 和 FLAX 模型的示例。

在 v5e 上训练 ImageNet

本教程介绍了如何使用虚构的输入数据在 v5e 上训练 ImageNet。如果您想使用真实数据,请参阅 GitHub 上的自述文件

设置

  1. 创建环境变量:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west4-a
    export ACCELERATOR_TYPE=v5litepod-8
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    环境变量说明

    变量 说明
    PROJECT_ID 您的 Google Cloud 项目 ID。使用现有项目或创建新项目
    TPU_NAME TPU 的名称。
    ZONE 要在其中创建 TPU 虚拟机的可用区。如需详细了解支持的可用区,请参阅 TPU 区域和可用区
    ACCELERATOR_TYPE 加速器类型用于指定您要创建的 Cloud TPU 的版本和大小。如需详细了解每个 TPU 版本支持的加速器类型,请参阅 TPU 版本
    RUNTIME_VERSION Cloud TPU 软件版本
    SERVICE_ACCOUNT 您的服务账号的邮箱。您可以前往 Google Cloud 控制台中的“服务账号”页面找到该账号。

    例如:tpu-service-account@PROJECT_ID.iam.gserviceaccount.com

    QUEUED_RESOURCE_ID 已排队的资源请求的用户分配文本 ID。

  2. 创建 TPU 资源:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --service-account=${SERVICE_ACCOUNT}
    

    已排队的资源处于 ACTIVE 状态后,您将能够通过 SSH 连接到 TPU 虚拟机:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    当已排队的资源处于 ACTIVE 状态时,输出将类似于以下内容:

     state: ACTIVE
    
  3. 安装最新版本的 JAX 和 jaxlib:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  4. 克隆 ImageNet 模型并安装相应要求:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command="git clone https://github.com/coolkp/flax.git && cd flax && git checkout pmap-orbax-conversion && git pull"
    
  5. 为了生成虚构数据,模型需要了解数据集的维度。您可以从 ImageNet 数据集的元数据中收集这些信息:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command="cd flax/examples/imagenet && pip install -r requirements-cloud-tpu.txt"
    

训练模型

完成前面的所有步骤后,您可以训练模型。

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command="cd flax/examples/imagenet && bash ../../tests/download_dataset_metadata.sh && JAX_PLATFORMS=tpu python imagenet_fake_data_benchmark.py"

删除 TPU 和已排队的资源

在会话结束时删除 TPU 和已排队的资源。

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

Hugging Face FLAX 模型

在 FLAX 中实现的 Hugging Face 模型在 Cloud TPU v5e 上开箱即用。本部分介绍了如何运行热门模型。

在 Imagenette 上训练 ViT

本教程介绍如何在 Cloud TPU v5e 上使用 Fast AI Imagenette 数据集训练来自 HuggingFace 的 Vision Transformer (ViT) 模型。

ViT 模型是第一个在 ImageNet 上成功训练 Transformer 编码器的模型,与卷积网络相比,可获得出色的效果。如需了解详情,请参阅以下资源:

设置

  1. 创建环境变量:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west4-a
    export ACCELERATOR_TYPE=v5litepod-16
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    环境变量说明

    变量 说明
    PROJECT_ID 您的 Google Cloud 项目 ID。使用现有项目或创建新项目
    TPU_NAME TPU 的名称。
    ZONE 要在其中创建 TPU 虚拟机的可用区。如需详细了解支持的可用区,请参阅 TPU 区域和可用区
    ACCELERATOR_TYPE 加速器类型用于指定您要创建的 Cloud TPU 的版本和大小。如需详细了解每个 TPU 版本支持的加速器类型,请参阅 TPU 版本
    RUNTIME_VERSION Cloud TPU 软件版本
    SERVICE_ACCOUNT 您的服务账号的邮箱。您可以前往 Google Cloud 控制台中的“服务账号”页面找到该账号。

    例如:tpu-service-account@PROJECT_ID.iam.gserviceaccount.com

    QUEUED_RESOURCE_ID 已排队的资源请求的用户分配文本 ID。

  2. 创建 TPU 资源:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --service-account=${SERVICE_ACCOUNT}
    

    已排队的资源处于 ACTIVE 状态后,您将能够通过 SSH 连接到 TPU 虚拟机:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    当已排队的资源处于 ACTIVE 状态时,输出将类似于以下内容:

     state: ACTIVE
    
  3. 安装 JAX 及其库:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  4. 下载 Hugging Face 仓库并安装要求:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='git clone https://github.com/huggingface/transformers.git && cd transformers && pip install . && pip install -r examples/flax/_tests_requirements.txt && pip install --upgrade huggingface-hub urllib3 zipp && pip install tensorflow==2.19 && sed -i 's/torchvision==0.12.0+cpu/torchvision==0.22.1/' examples/flax/vision/requirements.txt && pip install -r examples/flax/vision/requirements.txt && pip install tf-keras'
    
  5. 下载 Imagenette 数据集:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='cd transformers && wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz && tar -xvzf imagenette2.tgz'
    

训练模型

使用 4GB 预映射缓冲区训练模型。

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='cd transformers && JAX_PLATFORMS=tpu python3 examples/flax/vision/run_image_classification.py --train_dir "imagenette2/train" --validation_dir "imagenette2/val" --output_dir "./vit-imagenette" --learning_rate 1e-3 --preprocessing_num_workers 32 --per_device_train_batch_size 8 --per_device_eval_batch_size 8 --model_name_or_path google/vit-base-patch16-224-in21k --num_train_epochs 3'

删除 TPU 和已排队的资源

在会话结束时删除 TPU 和已排队的资源。

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

ViT 基准测试结果

训练脚本在 v5litepod-4、v5litepod-16 和 v5litepod-64 上运行。下表显示了不同加速器类型的吞吐量。

加速器类型 v5litepod-4 v5litepod-16 v5litepod-64
周期 3 3 3
全局批量大小 32 128 512
吞吐量(样本/秒) 263.40 429.34 470.71

在 Pokémon 上训练 Diffusion

本教程介绍如何在 Cloud TPU v5e 上使用 Pokémon 数据集训练来自 HuggingFace 的 Stable Diffusion 模型。

Stable Diffusion 模型是一种潜在文本转图片模型,可根据任何文本输入生成逼真图片。如需了解详情,请参阅以下资源:

设置

  1. 为您的存储桶名称设置环境变量:

    export GCS_BUCKET_NAME=your_bucket_name
  2. 为模型输出设置存储桶:

    gcloud storage buckets create gs://GCS_BUCKET_NAME \
        --project=your_project \
        --location=us-west1
  3. 创建环境变量:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west1-c
    export ACCELERATOR_TYPE=v5litepod-16
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    环境变量说明

    变量 说明
    PROJECT_ID 您的 Google Cloud 项目 ID。使用现有项目或创建新项目
    TPU_NAME TPU 的名称。
    ZONE 要在其中创建 TPU 虚拟机的可用区。如需详细了解支持的可用区,请参阅 TPU 区域和可用区
    ACCELERATOR_TYPE 加速器类型用于指定您要创建的 Cloud TPU 的版本和大小。如需详细了解每个 TPU 版本支持的加速器类型,请参阅 TPU 版本
    RUNTIME_VERSION Cloud TPU 软件版本
    SERVICE_ACCOUNT 您的服务账号的邮箱。您可以前往 Google Cloud 控制台中的“服务账号”页面找到该账号。

    例如:tpu-service-account@PROJECT_ID.iam.gserviceaccount.com

    QUEUED_RESOURCE_ID 已排队的资源请求的用户分配文本 ID。

  4. 创建 TPU 资源:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --service-account=${SERVICE_ACCOUNT}
    

    已排队的资源处于 ACTIVE 状态后,您将能够通过 SSH 连接到 TPU 虚拟机:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    当已排队的资源处于 ACTIVE 状态时,输出将类似于以下内容:

     state: ACTIVE
    
  5. 安装 JAX 及其库。

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='pip install "jax[tpu]==0.4.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  6. 下载 HuggingFace 仓库并安装要求。

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
         --project=${PROJECT_ID} \
         --zone=${ZONE} \
         --worker=all \
         --command='git clone https://github.com/RissyRan/diffusers.git && cd diffusers && pip install . && pip install -U -r examples/text_to_image/requirements_flax.txt && pip install tensorflow==2.17.1 clu && pip install tensorboard==2.17.1'
    

训练模型

使用 4GB 预映射缓冲区训练模型。

gcloud compute tpus tpu-vm ssh ${TPU_NAME} --zone=${ZONE} --project=${PROJECT_ID} --worker=all --command="
    git clone https://github.com/google/maxdiffusion
    cd maxdiffusion
    pip3 install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    pip3 install -r requirements.txt
    pip3 install .
    pip3 install gcsfs
    export LIBTPU_INIT_ARGS=''
    python -m src.maxdiffusion.train src/maxdiffusion/configs/base_2_base.yml run_name=my_run \
    jax_cache_dir=gs://${GCS_BUCKET_NAME} activations_dtype=bfloat16 weights_dtype=bfloat16 \
    per_device_batch_size=1 precision=DEFAULT dataset_save_location=gs://${GCS_BUCKET_NAME} \
    output_dir=gs://${GCS_BUCKET_NAME}/ attention=flash"

清理

在会话结束时删除 TPU、已排队的资源和 Cloud Storage 存储桶。

  1. 删除 TPU:

    gcloud compute tpus tpu-vm delete ${TPU_NAME} \
        --project=${PROJECT_ID} \
        --zone=${ZONE} \
        --quiet
    
  2. 删除已排队的资源:

    gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
        --project=${PROJECT_ID} \
        --zone=${ZONE} \
        --quiet
    
  3. 删除 Cloud Storage 存储桶:

    gcloud storage rm -r gs://${GCS_BUCKET_NAME}
    

Diffusion 的基准测试结果

训练脚本在 v5litepod-4、v5litepod-16 和 v5litepod-64 上运行。下表显示了吞吐量。

加速器类型 v5litepod-4 v5litepod-16 v5litepod-64
训练步数 1500 1500 1500
全局批量大小 32 64 128
吞吐量(样本/秒) 36.53 43.71 49.36

PyTorch/XLA

以下部分介绍了如何在 TPU v5e 上训练 PyTorch/XLA 模型的示例。

使用 PJRT 运行时训练 ResNet

从 PyTorch 2.0+ 开始,PyTorch/XLA 将从 XRT 迁移到 PjRt。以下是更新后的说明,介绍如何为 PyTorch/XLA 训练工作负载设置 v5e。

设置
  1. 创建环境变量:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west4-a
    export ACCELERATOR_TYPE=v5litepod-16
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    环境变量说明

    变量 说明
    PROJECT_ID 您的 Google Cloud 项目 ID。使用现有项目或创建新项目
    TPU_NAME TPU 的名称。
    ZONE 要在其中创建 TPU 虚拟机的可用区。如需详细了解支持的可用区,请参阅 TPU 区域和可用区
    ACCELERATOR_TYPE 加速器类型用于指定您要创建的 Cloud TPU 的版本和大小。如需详细了解每个 TPU 版本支持的加速器类型,请参阅 TPU 版本
    RUNTIME_VERSION Cloud TPU 软件版本
    SERVICE_ACCOUNT 您的服务账号的邮箱。您可以前往 Google Cloud 控制台中的“服务账号”页面找到该账号。

    例如:tpu-service-account@PROJECT_ID.iam.gserviceaccount.com

    QUEUED_RESOURCE_ID 已排队的资源请求的用户分配文本 ID。

  2. 创建 TPU 资源:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --service-account=${SERVICE_ACCOUNT}
    

    已排队的资源处于 ACTIVE 状态后,您将能够通过 SSH 连接到 TPU 虚拟机:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    当已排队的资源处于 ACTIVE 状态时,输出将类似于以下内容:

     state: ACTIVE
    
  3. 安装 Torch/XLA 特有依赖项

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --project=${PROJECT_ID} \
      --zone=${ZONE} \
      --worker=all \
      --command='
         sudo apt-get update -y
         sudo apt-get install libomp5 -y
         pip3 install mkl mkl-include
         pip3 install tf-nightly tb-nightly tbp-nightly
         pip3 install numpy
         sudo apt-get install libopenblas-dev -y
         pip install torch==PYTORCH_VERSION torchvision torch_xla[tpu]==PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'

    PYTORCH_VERSION 替换为您要使用的 PyTorch 版本。PYTORCH_VERSION 用于为 PyTorch/XLA 指定相同的版本。建议使用 2.6.0。

    如需详细了解 PyTorch 和 PyTorch/XLA 的版本,请参阅 PyTorch - 使用入门PyTorch/XLA 版本

    如需详细了解如何安装 PyTorch/XLA,请参阅 PyTorch/XLA 安装

训练 ResNet 模型
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
      date
      export PJRT_DEVICE=TPU
      export PT_XLA_DEBUG=0
      export USE_TORCH=ON
      export XLA_USE_BF16=1
      export LIBTPU_INIT_ARGS=--xla_jf_auto_cross_replica_sharding
      export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
      export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
      git clone https://github.com/pytorch/xla.git
      cd xla/
      git checkout release-r2.6
      python3 test/test_train_mp_imagenet.py --model=resnet50  --fake_data --num_epochs=1 —num_workers=16  --log_steps=300 --batch_size=64 --profile'

删除 TPU 和已排队的资源

在会话结束时删除 TPU 和已排队的资源。

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet
基准结果

下表显示了基准吞吐量。

加速器类型 吞吐量(样本/秒)
v5litepod-4 4240 ex/s
v5litepod-16 10,810 ex/s
v5litepod-64 46,154 ex/s

在 v5e 上训练 ViT

本教程介绍如何在 cifar10 数据集上使用 PyTorch/XLA 上的 HuggingFace 仓库在 v5e 上运行 VIT。

设置

  1. 创建环境变量:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west4-a
    export ACCELERATOR_TYPE=v5litepod-16
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    环境变量说明

    变量 说明
    PROJECT_ID 您的 Google Cloud 项目 ID。使用现有项目或创建新项目
    TPU_NAME TPU 的名称。
    ZONE 要在其中创建 TPU 虚拟机的可用区。如需详细了解支持的可用区,请参阅 TPU 区域和可用区
    ACCELERATOR_TYPE 加速器类型用于指定您要创建的 Cloud TPU 的版本和大小。如需详细了解每个 TPU 版本支持的加速器类型,请参阅 TPU 版本
    RUNTIME_VERSION Cloud TPU 软件版本
    SERVICE_ACCOUNT 您的服务账号的邮箱。您可以前往 Google Cloud 控制台中的“服务账号”页面找到该账号。

    例如:tpu-service-account@PROJECT_ID.iam.gserviceaccount.com

    QUEUED_RESOURCE_ID 已排队的资源请求的用户分配文本 ID。

  2. 创建 TPU 资源:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --service-account=${SERVICE_ACCOUNT}
    

    已排队的资源处于 ACTIVE 状态后,您将能够通过 SSH 连接到 TPU 虚拟机:

     gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    当已排队的资源处于 ACTIVE 状态时,输出将类似于以下内容:

     state: ACTIVE
    
  3. 安装 PyTorch/XLA 依赖项

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='
      sudo apt-get update -y
      sudo apt-get install libomp5 -y
      pip3 install mkl mkl-include
      pip3 install tf-nightly tb-nightly tbp-nightly
      pip3 install numpy
      sudo apt-get install libopenblas-dev -y
      pip install torch==PYTORCH_VERSION torchvision torch_xla[tpu]==PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
      pip install jax==0.4.38 jaxlib==0.4.38 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/

    PYTORCH_VERSION 替换为您要使用的 PyTorch 版本。PYTORCH_VERSION 用于为 PyTorch/XLA 指定相同的版本。建议使用 2.6.0。

    如需详细了解 PyTorch 和 PyTorch/XLA 的版本,请参阅 PyTorch - 使用入门PyTorch/XLA 版本

    如需详细了解如何安装 PyTorch/XLA,请参阅 PyTorch/XLA 安装

  4. 下载 HuggingFace 仓库并安装要求。

       gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command="
          git clone https://github.com/suexu1025/transformers.git vittransformers; \
          cd vittransformers; \
          pip3 install .; \
          pip3 install datasets; \
          wget https://github.com/pytorch/xla/blob/master/scripts/capture_profile.py"
    

训练模型

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
      export PJRT_DEVICE=TPU
      export PT_XLA_DEBUG=0
      export USE_TORCH=ON
      export TF_CPP_MIN_LOG_LEVEL=0
      export XLA_USE_BF16=1
      export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
      export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
      cd vittransformers
      python3 -u examples/pytorch/xla_spawn.py --num_cores 4 examples/pytorch/image-pretraining/run_mae.py --dataset_name=cifar10 \
      --remove_unused_columns=False \
      --label_names=pixel_values \
      --mask_ratio=0.75 \
      --norm_pix_loss=True \
      --do_train=true \
      --do_eval=true \
      --base_learning_rate=1.5e-4 \
      --lr_scheduler_type=cosine \
      --weight_decay=0.05 \
      --num_train_epochs=3 \
      --warmup_ratio=0.05 \
      --per_device_train_batch_size=8 \
      --per_device_eval_batch_size=8 \
      --logging_strategy=steps \
      --logging_steps=30 \
      --evaluation_strategy=epoch \
      --save_strategy=epoch \
      --load_best_model_at_end=True \
      --save_total_limit=3 \
      --seed=1337 \
      --output_dir=MAE \
      --overwrite_output_dir=true \
      --logging_dir=./tensorboard-metrics \
      --tpu_metrics_debug=true'

删除 TPU 和已排队的资源

在会话结束时删除 TPU 和已排队的资源。

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

基准结果

下表显示了不同加速器类型的基准吞吐量。

v5litepod-4 v5litepod-16 v5litepod-64
周期 3 3 3
全局批量大小 32 128 512
吞吐量(样本/秒) 201 657 2,844