Trillium (v6e) 简介

在本文档、TPU API 和日志中,v6e 用于指代 Trillium。v6e 代表 Google 的第 6 代 TPU。

v6e 架构每个 Pod 包含 256 个芯片,与 v5e 有许多相似之处。此系统针对转换器、文本到图像和卷积神经网络 (CNN) 训练、微调和服务进行了优化。

如需详细了解 v6e 系统架构和配置,请参阅 TPU v6e

本文档简介将重点介绍使用 JAXPyTorch 框架进行模型训练和服务的过程。在每个框架中,您都可以使用队列化资源或 GKE 预配 TPU。您可以使用 XPK 或 GKE 命令进行 GKE 设置。

使用 v6e 训练或部署模型的一般流程

  1. 准备 Google Cloud 项目
  2. 安全容量
  3. 预配 Cloud TPU 环境
  4. 运行模型训练推理工作负载

准备 Google Cloud 项目

在使用 Cloud TPU 之前,您需要:

  • 创建 Google Cloud 已启用结算功能的账号和项目
  • 安装 Google Cloud CLI Alpha 版组件
  • 启用 Cloud TPU API
  • 创建 Cloud TPU 服务代理
  • 创建 Cloud TPU 服务账号并授予权限

如需了解详情,请参阅设置 Cloud TPU 环境

保障容量

如需申请 Cloud TPU v6e 配额,并解答与容量有关的任何问题,请与 Google Cloud 支持团队联系。

预配 Cloud TPU 环境

v6e Cloud TPU 可以使用 GKE、GKE 和 XPK(一种基于 GKE 的封装容器 CLI 工具)进行预配和管理,也可以作为队列化资源进行管理。

前提条件

  • 验证您的项目是否有足够的 TPUS_PER_TPU_FAMILY 配额,该配额指定您可以在 Google Cloud项目中访问的芯片数量上限。
  • v6e 已通过以下配置进行测试:
    • Python 3.10 或更高版本
    • 每夜软件版本:
      • 每夜 JAX 0.4.32.dev20240912
      • 每夜 LibTPU 0.1.dev20240912+nightly
    • 稳定版软件版本:
      • JAX + v0.4.37 的 JAX 库
  • 请验证您的项目是否有足够的配额来执行以下操作:

    • Cloud TPU 虚拟机配额
    • IP 地址配额
    • Hyperdisk Balanced 以及您要使用的任何其他磁盘类型的配额

  • 如果您将 GKE 与 XPK 搭配使用,请参阅用户或服务账号的 Cloud 控制台权限,了解运行 XPK 所需的权限。

创建环境变量

在 Cloud Shell 中,创建以下环境变量:

export NODE_ID=your-tpu-name
export PROJECT_ID=your-project-id
export ACCELERATOR_TYPE=v6e-16
export ZONE=us-east1-d
export RUNTIME_VERSION=v2-alpha-tpuv6e
export SERVICE_ACCOUNT=your-service-account
export QUEUED_RESOURCE_ID=your-queued-resource-id
export VALID_DURATION=your-duration 

# Additional environment variable needed for Multislice:
export NUM_SLICES=number-of-slices

# Use a custom network for better performance as well as to avoid having the default network becoming overloaded.

export NETWORK_NAME=${PROJECT_ID}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw

命令标志说明

变量 说明
NODE_ID 用户分配给 Cloud TPU 的 ID,该 ID 在分配已排队的资源请求时创建。
PROJECT_ID Google Cloud 项目名称。使用现有项目或创建新项目。 如需了解详情,请参阅设置 Google Cloud 项目
ZONE 如需了解支持的区域,请参阅 Cloud TPU 区域和可用区文档。
ACCELERATOR_TYPE 请参阅加速器类型
RUNTIME_VERSION v2-alpha-tpuv6e
SERVICE_ACCOUNT 这是您的服务账号的电子邮件地址,您可以在 Google Cloud Console -> IAM -> 服务账号

例如:tpu-service-account@your-project-ID.iam.gserviceaccount.com.com

NUM_SLICES 要创建的 Slice 的数量(仅适用于多 Slice)。
QUEUED_RESOURCE_ID 已加入队列的资源请求的用户分配的文本 ID。
VALID_DURATION 队列中资源请求的有效时长。
NETWORK_NAME 要使用的辅助网络的名称。
NETWORK_FW_NAME 要使用的次要网络防火墙的名称。

优化网络性能

为了获得最佳性能,请使用 8,896 MTU(最大传输单元)的网络。

默认情况下,虚拟私有云 (VPC) 仅提供 1,460 字节的 MTU,这会导致网络性能不佳。您可以将 VPC 网络的 MTU 设置为 1300 字节到 8896 字节之间(含边界值)的任何值。常见的自定义 MTU 大小为 1500 字节(标准以太网)或 8896 字节(可能的最大值)。如需了解详情,请参阅有效的 VPC 网络 MTU 大小

如需详细了解如何更改现有网络或默认网络的 MTU 设置,请参阅更改 VPC 网络的 MTU 设置

以下示例会创建一个 MTU 为 8,896 的网络。

export RESOURCE_NAME=your-resource-name
export NETWORK_NAME=${RESOURCE_NAME}-privatenetwork
export NETWORK_FW_NAME=${RESOURCE_NAME}-privatefirewall
gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT_ID} \
 --subnet-mode=auto --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network=${NETWORK_NAME} \
 --allow tcp,icmp,udp --project=${PROJECT_ID}

使用多 NIC(适用于多切片)

使用多 slice 环境时,辅助子网需要以下环境变量。

export NETWORK_NAME_2=${RESOURCE_NAME}
export SUBNET_NAME_2=${RESOURCE_NAME}
export FIREWALL_RULE_NAME=${RESOURCE_NAME}
export ROUTER_NAME=${RESOURCE_NAME}-network-2
export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2
export REGION=your-region

使用以下命令为网络和子网创建自定义 IP 路由。

gcloud compute networks create ${NETWORK_NAME_2} --mtu=8896 \
   --bgp-routing-mode=regional --subnet-mode=custom --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_2} \
   --network=${NETWORK_NAME_2} \
   --range=10.10.0.0/18 --region=${REGION} \
   --project=${PROJECT_ID}

gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_2} --allow tcp,icmp,udp \
   --source-ranges 10.10.0.0/18 --project=${PROJECT_ID}

gcloud compute routers create ${ROUTER_NAME} \
  --project=${PROJECT_ID} \
  --network=${NETWORK_NAME_2} \
  --region=${REGION}

gcloud compute routers nats create ${NAT_CONFIG} \
  --router=${ROUTER_NAME} \
  --region=${REGION} \
  --auto-allocate-nat-external-ips \
  --nat-all-subnet-ip-ranges \
  --project=${PROJECT_ID} \
  --enable-logging

创建多网络 slice 后,您可以通过设置 XPK 集群并将 --command ifconfig 标志添加到 XPK 工作负载创建命令,验证是否使用了两个网络接口卡 (NIC)。

使用以下 xpk workload 命令在控制台日志中显示 ifconfig 命令的输出,并检查 eth0 和 eth1 是否均为 mtu=8896。 Google Cloud

python3 xpk.py workload create \
   --cluster CLUSTER_NAME \
   {--base-docker-image maxtext_base_image | --docker-image your-cloud-image-name} \
   --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
   --tpu-type=${ACCELERATOR_TYPE} \
   --num-slices=${NUM_SLICES}  \
   --on-demand \
   --zone=${ZONE} \
   --project=${PROJECT_ID} \
   --command "ifconfig"

如果您想启用调试日志或使用 Vertex AI TensorBoard,请将以下可选参数添加到该命令中:

    --enable-debug-logs \
    --use-vertex-tensorboard

验证 eth0 和 eth1 是否均为 mtu=8,896。您可以通过向 XPK 工作负载创建命令添加 --command ifconfig 标志来验证多 NIC 是否正在运行。在控制台日志中检查该 xpk 工作负载的输出,并验证 eth0 和 eth1 的 mtu 均为 8896。 Google Cloud

改进 TCP 设置

如果您使用已排队的资源界面创建了 Cloud TPU,则可以运行以下命令,通过增加 TCP 接收缓冲区限制来提升网络性能。

gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \
  --project "${PROJECT_ID}" \
  --zone "${ZONE}" \
  --node=all \
  --command='sudo sh -c "echo \"4096 41943040 314572800\" > /proc/sys/net/ipv4/tcp_rmem"' \
  --worker=all

使用已排队的资源进行预配

您可以使用排队资源创建 Cloud TPU v6e。通过加入队列的资源,您可以在容量可用时接收容量。您可以指定请求填充的开始时间和结束时间(可选)。如需了解详情,请参阅管理队列中的资源

使用 GKE 或 XPK 预配 v6e Cloud TPU

如果您将 GKE 命令与 v6e 搭配使用,则可以使用 Kubernetes 命令或 XPK 预配 Cloud TPU,以及训练或部署模型。如需了解如何在 GKE 集群中规划 Cloud TPU 配置,请参阅在 GKE 中规划 Cloud TPU。以下部分提供了用于创建支持单个 NIC 和多 NIC 的 XPK 集群的命令。

创建支持单个 NIC 的 XPK 集群

export CLUSTER_NAME=xpk-cluster-name
export ZONE=us-east1-d
export PROJECT_ID=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME=${CLUSTER_NAME}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw
gcloud compute networks create ${NETWORK_NAME} \
   --mtu=8896 \
   --project=${PROJECT_ID} \
   --subnet-mode=auto \
   --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} \
   --network=${NETWORK_NAME} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}"
python3 xpk.py cluster create --cluster=${CLUSTER_NAME} \
   --cluster-cpu-machine-type=e2-standard-8 \
   --num-slices=${NUM_SLICES} \
   --tpu-type=${TPU_TYPE} \
   --zone=${ZONE}  \
   --project=${PROJECT_ID} \
   --on-demand \
   --custom-cluster-arguments="${CLUSTER_ARGUMENTS}"  \
   --create-vertex-tensorboard

命令标志说明

变量 说明
CLUSTER_NAME XPK 集群的用户分配的名称。
PROJECT_ID Google Cloud 项目名称。使用现有项目或创建新项目。 如需了解详情,请参阅设置 Google Cloud 项目
ZONE 如需了解支持的区域,请参阅 Cloud TPU 区域和可用区文档。
TPU_TYPE 请参阅加速器类型
NUM_SLICES 您要创建的 Slice 的数量
CLUSTER_ARGUMENTS 要使用的网络和子网。

例如:--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}

NUM_SLICES 要创建的切片数量。
NETWORK_NAME 要使用的辅助网络的名称。
NETWORK_FW_NAME 要使用的次要网络防火墙的名称。

创建支持多 NIC 的 XPK 集群

export CLUSTER_NAME=xpk-cluster-name
export REGION=your-region
export ZONE=us-east1-d
export PROJECT_ID=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME_1=${CLUSTER_NAME}-mtu9k-1-${ZONE}
export SUBNET_NAME_1=${CLUSTER_NAME}-privatesubnet-1-${ZONE}
export NETWORK_FW_NAME_1=${NETWORK_NAME_1}-fw-1-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-1-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-1-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-1-${ZONE}
gcloud compute networks create ${NETWORK_NAME_1} \
   --mtu=8896 \
   --bgp-routing-mode=regional \
   --subnet-mode=custom \
   --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_1} \
   --network=${NETWORK_NAME_1} \
   --range=10.11.0.0/18 \
   --region=${REGION} \
   --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_1} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \
    --project=${PROJECT_ID} \
    --network=${NETWORK_NAME_1} \
    --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \
    --router=${ROUTER_NAME} \
    --region=${REGION} \
    --auto-allocate-nat-external-ips \
    --nat-all-subnet-ip-ranges \
    --project=${PROJECT_ID} \
    --enable-logging
# Secondary subnet for multi-nic experience.
# Need custom IP routing to be different from the first network's subnet.

export NETWORK_NAME_2=${CLUSTER_NAME}-privatenetwork-2-${ZONE}
export SUBNET_NAME_2=${CLUSTER_NAME}-privatesubnet-2-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-2-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-2-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-2-${ZONE}
gcloud compute networks create ${NETWORK_NAME_2} \
   --mtu=8896 \
   --bgp-routing-mode=regional \
   --subnet-mode=custom \
   --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_2} \
   --network=${NETWORK_NAME_2} \
   --range=10.10.0.0/18 \
   --region=${REGION} \
   --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_2} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \
     --project=${PROJECT_ID} \
     --network=${NETWORK_NAME_2} \
     --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \
     --router=${ROUTER_NAME} \
     --region=${REGION} \
     --auto-allocate-nat-external-ips \
     --nat-all-subnet-ip-ranges \
     --project=${PROJECT_ID} \
     --enable-logging
export CLUSTER_ARGUMENTS="--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}"

export NODE_POOL_ARGUMENTS="--additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}"
python3 xpk.py cluster create \
    --cluster=${CLUSTER_NAME} \
    --cluster-cpu-machine-type=e2-standard-8 \
    --num-slices=${NUM_SLICES} \
    --tpu-type=${TPU_TYPE} \
    --zone=${ZONE}  \
    --project=${PROJECT_ID} \
    --on-demand \
    --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \
    --custom-nodepool-arguments="${NODE_POOL_ARGUMENTS}" \
    --create-vertex-tensorboard

命令标志说明

变量 说明
CLUSTER_NAME XPK 集群的用户分配的名称。
PROJECT_ID Google Cloud 项目名称。使用现有项目或创建新项目。 如需了解详情,请参阅设置 Google Cloud 项目
ZONE 如需了解支持的区域,请参阅 Cloud TPU 区域和可用区文档。
TPU_TYPE 请参阅加速器类型
NUM_SLICES 您要创建的 Slice 的数量
CLUSTER_ARGUMENTS 要使用的网络和子网。

例如:--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}

NODE_POOL_ARGUMENTS 要使用的额外节点网络。

例如:--additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}

NUM_SLICES 要创建的 Slice 的数量(仅适用于多 Slice)。
NETWORK_NAME 要使用的辅助网络的名称。
NETWORK_FW_NAME 要使用的次要网络防火墙的名称。

框架设置

本部分介绍了使用 JAXPyTorch 框架进行机器学习模型训练的一般设置流程。如果您使用的是 GKE,则可以使用 XPK 或 Kubernetes 命令进行框架设置。

JAX 设置

本部分介绍了在 GKE 上运行 JAX 工作负载(无论是否使用 XPK)以及使用队列化资源的设置说明。

使用 GKE 设置 JAX

单个主机上的单个切片

以下示例使用 Kubernetes YAML 文件设置了 2x2 单主机节点池。

apiVersion: v1
kind: Pod
metadata:
  name: tpu-pod-jax-v6e-a
spec:
  restartPolicy: Never
  nodeSelector:
    cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
    cloud.google.com/gke-tpu-topology: 2x2
  containers:
  - name: tpu-job
    image: python:3.10
    securityContext:
      privileged: true
    command:
    - bash
    - -c
    - |
      pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
      JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python3 -c 'import jax; print("Total TPU chips:", jax.device_count())'
    resources:
      requests:
        google.com/tpu: 4
      limits:
        google.com/tpu: 4

成功完成后,您应该会在 GKE 日志中看到以下消息:

Total TPU chips: 4

多主机上的单个切片

以下示例使用 Kubernetes YAML 文件设置了 4x4 多主机节点池。

apiVersion: v1
kind: Service
metadata:
  name: headless-svc
spec:
  clusterIP: None
  selector:
    job-name: tpu-available-chips
---
apiVersion: batch/v1
kind: Job
metadata:
  name: tpu-available-chips
spec:
  backoffLimit: 0
  completions: 4
  parallelism: 4
  completionMode: Indexed
  template:
    spec:
      subdomain: headless-svc
      restartPolicy: Never
      nodeSelector:
        cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
        cloud.google.com/gke-tpu-topology: 4x4
      containers:
      - name: tpu-job
        image: python:3.10
        ports:
        - containerPort: 8471 # Default port using which TPU VMs communicate
        - containerPort: 8431 # Port to export TPU runtime metrics, if supported.
        securityContext:
          privileged: true
        command:
        - bash
        - -c
        - |
          pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
          JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
        resources:
          requests:
            google.com/tpu: 4
          limits:
            google.com/tpu: 4

成功完成后,您应该会在 GKE 日志中看到以下消息:

Total TPU chips: 16

多主机上的多切片

以下示例使用 Kubernetes YAML 文件设置了两个 4x4 多主机节点池。

前提是,您需要安装 JobSet v0.2.3 或更高版本。

apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
  name: multislice-job
  annotations:
    alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
spec:
  failurePolicy:
    maxRestarts: 4
  replicatedJobs:
    - name: slice
      replicas: 2
      template:
        spec:
          parallelism: 4
          completions: 4
          backoffLimit: 0
          template:
            spec:
              hostNetwork: true
              dnsPolicy: ClusterFirstWithHostNet
              nodeSelector:
                cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                cloud.google.com/gke-tpu-topology: 4x4
              hostNetwork: true
              containers:
              - name: jax-tpu
                image: python:3.10
                ports:
                - containerPort: 8471
                - containerPort: 8080
                - containerPort: 8431
                securityContext:
                  privileged: true
                command:
                - bash
                - -c
                - |
                  pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                  JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
                resources:
                  limits:
                   google.com/tpu: 4
                  requests:
                   google.com/tpu: 4

成功完成后,您应该会在 GKE 日志中看到以下消息:

Total TPU chips: 32

如需了解详情,请参阅 GKE 文档中的运行多切片工作负载

为了获得更好的性能,请启用 hostNetwork

多 NIC

如需使用以下多 NIC 清单,您需要设置网络。如需了解详情,请参阅为 Kubernetes Pod 设置多网络支持。 如需在 GKE 中使用多 NIC,您必须在 Kubernetes Pod 清单中添加一些额外的注解。以下是非 TPU 多 NIC 工作负载示例清单。

apiVersion: v1
kind: Pod
metadata:
  name: sample-netdevice-pod-1
  annotations:
    networking.gke.io/default-interface: 'eth0'
    networking.gke.io/interfaces: |
      [
        {"interfaceName":"eth0","network":"default"},
        {"interfaceName":"eth1","network":"netdevice-network"}
      ]
spec:
  containers:
  - name: sample-netdevice-pod
    image: busybox
    command: ["sleep", "infinity"]
    ports:
    - containerPort: 80
  restartPolicy: Always
  tolerations:
  - key: "google.com/tpu"
    operator: "Exists"
    effect: "NoSchedule"

如果您使用 exec 命令连接到 Kubernetes Pod,则应使用以下代码看到额外的 NIC。

$ kubectl exec --stdin --tty sample-netdevice-pod-1 -- /bin/sh
/ # ip a
1: lo: <LOOPBACK,UP,LOWER_UP> mtu 65536 qdisc noqueue qlen 1000
    link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00
    inet 127.0.0.1/8 scope host lo
       valid_lft forever preferred_lft forever
2: eth0@if11: <BROADCAST,MULTICAST,UP,LOWER_UP,M-DOWN> mtu 1460 qdisc noqueue
    link/ether da:be:12:67:d2:25 brd ff:ff:ff:ff:ff:ff
    inet 10.124.2.6/24 brd 10.124.2.255 scope global eth0
       valid_lft forever preferred_lft forever
3: eth1: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1460 qdisc mq qlen 1000
    link/ether 42:01:ac:18:00:04 brd ff:ff:ff:ff:ff:ff
    inet 172.24.0.4/32 scope global eth1
       valid_lft forever preferred_lft forever

使用 GKE 搭配 XPK 设置 JAX

如需使用 GKE 和 XPK 设置 JAX,请参阅 xpk README

如需使用 MaxText 设置和运行 XPK,请参阅如何运行 MaxText

使用已排队的资源设置 JAX

使用 gcloud alpha compute tpus tpu-vm ssh 命令同时在切片中的所有 Cloud TPU 虚拟机上安装 JAX。对于多 Slice,请添加 --node=all 标志。

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
 --zone ${ZONE} --worker=all \
 --command='pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

您可以运行以下命令,检查您的 slice 中可用的 Cloud TPU 核心数量,并测试是否已正确安装所有组件:

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
   --zone ${ZONE} --worker=all  \
   --command='python3 -c "import jax; print(jax.device_count(), jax.local_device_count())"'

在 v6e-16 slice 上运行时,输出类似于以下内容:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16 4
16 4
16 4
16 4

jax.device_count() 显示给定 slice 中的芯片总数。jax.local_device_count() 表示此 slice 中单个虚拟机可访问的芯片数量。


 gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} --zone=${ZONE} --worker=all  \
   --command='git clone -b mlperf4.1 https://github.com/google/maxdiffusion.git &&
   cd maxdiffusion && git checkout 975fdb7dbddaa9a53ad72a421cdb487dcdc491a3 &&
   pip install setuptools==59.6.0 &&
   pip install -r requirements.txt  && pip install . '

排查 JAX 设置问题

一般提示是在 GKE 工作负载清单中启用详细日志记录。然后,将日志提供给 GKE 支持团队。

TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0

错误消息

no endpoints available for service 'jobset-webhook-service'

此错误表示作业集未正确安装。检查 jobset-controller-manager 部署 Kubernetes Pod 是否正在运行。如需了解详情,请参阅 JobSet 问题排查文档

TPU initialization failed: Failed to connect

确保您的 GKE 节点版本为 1.30.4-gke.1348000 或更高版本(不支持 GKE 1.31)。

PyTorch 设置

本部分介绍了如何开始在 v6e 上使用 PyTorch/XLA 的 PJRT。建议使用 Python 3.10。

使用 GKE 搭配 XPK 设置 PyTorch

您可以将以下 Docker 容器与已安装 PyTorch 依赖项的 XPK 搭配使用:

us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028

如需创建 XPK 工作负载,请使用以下命令:

python3 xpk.py workload create \
    --cluster ${CLUSTER_NAME} \
    {--base-docker-image maxtext_base_image | --docker-image your-cloud-image-name \
    --workload ${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
    --tpu-type=${ACCELERATOR_TYPE} \
    --num-slices=${NUM_SLICES}  \
    --on-demand \
    --zone ${ZONE} \
    --project ${PROJECT_ID} \
    --enable-debug-logs \
    --command 'python3 -c "import torch; import torch_xla; import torch_xla.runtime as xr; print(xr.global_runtime_device_count())"'

使用 --base-docker-image 会创建一个新的 Docker 映像,并将当前工作目录内置到新的 Docker 中。

使用已排队的资源设置 PyTorch

请按照以下步骤使用队列化资源安装 PyTorch,并在 v6e 上运行一个小脚本。

使用 SSH 安装依赖项以访问虚拟机

使用以下命令在所有 Cloud TPU 虚拟机上安装依赖项。对于多 Slice,请添加 --worker=all 标志:

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='sudo apt update && sudo apt install -y python3-pip libopenblas-base && \
               pip3 install torch~=2.6.0 "torch_xla[tpu]~=2.6.0" -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'

提高具有可扩缩、频繁分配的模型的性能

对于具有可扩缩、频繁分配的模型,与使用默认的 malloc 函数实现相比,使用 tcmalloc 函数可以显著提升性能,因此 Cloud TPU VM 上默认使用的 malloc 函数是 tcmalloc。但是,根据您的工作负载(例如为其嵌入表进行了超大规模分配的 DLRM),tcmalloc 函数可能会造成运行缓慢,在这种情况下,您可以尝试设置以下变量以改为使用默认 malloc 函数:

unset LD_PRELOAD

使用 Python 脚本在 v6e 虚拟机上执行计算

使用以下命令运行一个脚本,该脚本会创建两个张量,将它们相加,然后输出结果。

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
   --project ${PROJECT_ID} \
   --zone ${ZONE} --worker all --command='
   unset LD_PRELOAD
   python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"
'

这将生成如下所示的输出:

SSH: Attempting to connect to worker 0...
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
xla:0
tensor([[ 0.3355, -1.4628, -3.2610],
        [-1.4656,  0.3196, -2.8766],
        [ 0.8668, -1.5060,  0.7125]], device='xla:0')

配备 SkyPilot 的 v6e

您可以将 Cloud TPU v6e 与 SkyPilot 搭配使用。请按照以下步骤向 SkyPilot 添加与 v6e 相关的位置和价格信息。如需了解详情,请参阅 SkyPilot TPU v6e 示例

推理教程

以下教程介绍了如何在 Cloud TPU v6e 上运行推理:

训练示例

以下部分提供了在 Cloud TPU v6e 上训练 MaxText、MaxDiffusion 和 PyTorch 模型的示例。

在 v6e Cloud TPU 虚拟机上进行 MaxText 和 MaxDiffusion 训练

以下部分介绍了 MaxTextMaxDiffusion 模型的训练生命周期。

一般而言,大致步骤如下:

  1. 构建工作负载基础映像。
  2. 使用 XPK 运行工作负载。
    1. 为工作负载构建训练命令。
    2. 部署工作负载。
  3. 跟踪工作负载并查看指标。
  4. 如果不需要 XPK 工作负载,请将其删除。
  5. 不再需要 XPK 集群时,请将其删除。

构建基础映像

安装 MaxText 或 MaxDiffusion 并构建 Docker 映像:

  1. 克隆要使用的代码库,然后切换到该代码库的目录:

    MaxText:

    git clone https://github.com/google/maxtext.git && cd maxtext
    

    MaxDiffusion:

    git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion
    
  2. 将 Docker 配置为使用 Google Cloud CLI:

    gcloud auth configure-docker
    
  3. 使用以下命令或 JAX 稳定版堆栈构建 Docker 映像。如需详细了解 JAX Stable Stack,请参阅使用 JAX Stable Stack 构建 Docker 映像

    bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.35
    
  4. 在有效的 gcloud CLI 配置中设置您的项目 ID:

    gcloud config set project ${PROJECT_ID}
    
  5. 如果您要从未在本地构建映像的机器启动工作负载,请上传映像。

    1. 设置 CLOUD_IMAGE_NAME 环境变量:

      export CLOUD_IMAGE_NAME=${USER}_runner
      
    2. 上传图片:

      bash docker_upload_runner.sh ${CLOUD_IMAGE_NAME}
      

使用 XPK 运行工作负载

  1. 如果您不使用 MaxText 设置的默认值MaxDiffusion,请设置以下环境变量:

    export BASE_OUTPUT_DIR=gs://YOUR_BUCKET
    export PER_DEVICE_BATCH_SIZE=2
    export NUM_STEPS=30
    export MAX_TARGET_LENGTH=8192
  2. 构建模型脚本。在后续步骤中,此脚本将作为训练命令复制。

    暂时不要执行模型脚本。

    MaxText

    MaxText 是一个高性能、高度可伸缩的开源 LLM,采用纯 Python 和 JAX 编写,可在 TPU 和 GPU 上进行训练和推理。 Google Cloud

    JAX_PLATFORMS=tpu,cpu \
    ENABLE_PJRT_COMPATIBILITY=true \
    TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \
    TPU_SLICE_BUILDER_DUMP_ICI=true && \
    python3 -m MaxText.train MaxText/configs/base.yml \
            base_output_directory=${BASE_OUTPUT_DIR} \
            dataset_type=synthetic \
            per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
            enable_checkpointing=false \
            gcs_metrics=true \
            profiler=xplane \
            skip_first_n_steps_for_profiler=5 \
            steps=${NUM_STEPS}  # attention='dot_product'"
    

    Gemma2

    Gemma 是 Google DeepMind 基于 Gemini 研究和技术开发的一系列开放权重 LLM。

    python3 -m MaxText.train MaxText/configs/base.yml \
        model_name=gemma2-27b \
        run_name=gemma2-27b-run \
        base_output_directory=${BASE_OUTPUT_DIR} \
        max_target_length=${MAX_TARGET_LENGTH} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        steps=${NUM_STEPS} \
        enable_checkpointing=false \
        use_iota_embed=true \
        gcs_metrics=true \
        dataset_type=synthetic \
        profiler=xplane \
        attention=flash
    

    Mixtral 8x7b

    Mixtral 是 Mistral AI 开发的利用稀疏混合专家 (MoE) 架构的先进 AI 模型。

    python3 -m MaxText.train MaxText/configs/base.yml \
        base_output_directory=${BASE_OUTPUT_DIR} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        model_name=mixtral-8x7b \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        tokenizer_path=assets/tokenizer.mistral-v1 \
        attention=flash \
        dtype=bfloat16 \
        dataset_type=synthetic \
        profiler=xplane
    

    Llama3-8b

    Llama 是由 Meta 开发的一系列开放权重 LLM。

    如需查看如何在 PyTorch 上运行 Llama3 的示例,请参阅 torchprime 代码库中的 torch_xla 模型

    MaxDiffusion

    MaxDiffusion 是一系列用纯 Python 和 JAX 编写的参考实现,其中包含在 XLA 设备(包括 Cloud TPU 和 GPU)上运行的各种潜在 diffusion 模型。Stable Diffusion 是一种潜在的文本到图像模型,可根据任何文本输入生成逼真的图片。

    您需要安装特定的 Git 分支才能运行 MaxDiffusion,如以下 git checkout 命令所示。

    git clone https://github.com/google/maxdiffusion.git
    && cd maxdiffusion
    && git checkout e712c9fc4cca764b0930067b6e33daae2433abf0
    && pip install -r requirements.txt
    && pip install .
    

    训练脚本:

        cd maxdiffusion && OUT_DIR=${BASE_OUTPUT_DIR} \
        python src/maxdiffusion/train_sdxl.py \
        src/maxdiffusion/configs/base_xl.yml \
        revision=refs/pr/95 \
        activations_dtype=bfloat16 \
        weights_dtype=bfloat16 \
        resolution=1024 \
        per_device_batch_size=1 \
        output_dir=${OUT_DIR}  \
        jax_cache_dir=${OUT_DIR}/cache_dir/ \
        max_train_steps=200 \
        attention=flash run_name=sdxl-ddp-v6e
        
  3. 导出以下变量:

    export CLUSTER_NAME=CLUSTER_NAME
    export ACCELERATOR_TYPE=ACCELERATOR_TYPE
    export NUM_SLICES=NUM_SLICES
    export YOUR_MODEL_SCRIPT=YOUR_MODEL_SCRIPT

    环境变量说明

    变量 说明
    CLUSTER_NAME XPK 集群的名称。
    ACCELERATOR_TYPE 请参阅加速器类型
    NUM_SLICES TPU 切片数量。
    YOUR_MODEL_SCRIPT 要作为训练命令执行的模型脚本。
  4. 使用您在上一步中创建的脚本运行模型。您必须指定 --base-docker-image 标志才能使用 MaxText 基础图片,或者指定 --docker-image 标志和要使用的图片。

    可选:您可以通过添加 --enable-debug-logs 标志来启用调试日志记录。如需了解详情,请参阅在 MaxText 上调试 JAX

    可选:您可以创建 Vertex AI 实验,通过添加 --use-vertex-tensorboard 标志将数据上传到 Vertex AI TensorBoard。如需了解详情,请参阅使用 Vertex AI 监控 MaxText 上的 JAX

    python3 xpk.py workload create \
        --cluster ${CLUSTER_NAME} \
        {--base-docker-image maxtext_base_image|--docker-image gcr.io/${PROJECT_ID}/${CLOUD_IMAGE_NAME}:latest} \
        --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
        --tpu-type=${ACCELERATOR_TYPE} \
        --num-slices=${NUM_SLICES}  \
        --on-demand \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        [--enable-debug-logs] \
        [--use-vertex-tensorboard] \
        --command=${YOUR_MODEL_SCRIPT}

    输出包含用于跟踪工作负载的链接。打开链接,然后点击日志标签页以实时跟踪工作负载。

在 MaxText 上调试 JAX

使用补充 XPK 命令诊断集群或工作负载未运行的原因。

使用 Vertex AI 监控 MaxText 上的 JAX

如需使用 TensorBoard,您的用户账号必须具有 aiplatform.user 角色。 Google Cloud 运行以下命令以授予这些角色:

gcloud projects add-iam-policy-binding your-project-id \
--member='user:your-email' \
--role='roles/aiplatform.user'

通过 Vertex AI 的托管式 TensorBoard 查看标量和性能数据。

  1. 将您所用可用区的资源管理 (CRUD) 请求次数从 600 提高到 5,000。对于使用少于 16 个虚拟机的小型工作负载,这可能不是问题。
  2. 为 Vertex AI 安装 cloud-accelerator-diagnostics 等依赖项:

    # xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI
    cd ~/xpk
    pip install .
  3. 使用 --create-vertex-tensorboard 标志创建 XPK 集群,如创建 Vertex AI TensorBoard 中所述。您也可以在现有集群上运行此命令。

  4. 在运行 XPK 工作负载时,使用 --use-vertex-tensorboard 标志和可选的 --experiment-name 标志创建 Vertex AI 实验。如需查看完整步骤列表,请参阅创建 Vertex AI 实验以将数据上传到 Vertex AI TensorBoard

日志包含指向 Vertex AI TensorBoard 的链接,如下所示:

View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name

您还可以在 Google Cloud 控制台中找到 Vertex AI TensorBoard 链接。 在 Google Cloud 控制台中前往 Vertex AI Experiments。从下拉菜单中选择适当的区域。

TensorBoard 目录也会写入您使用 ${BASE_OUTPUT_DIR} 指定的 Cloud Storage 存储桶。

删除 XPK 工作负载

您可以使用 xpk workload delete 命令根据作业前缀或作业状态删除一个或多个工作负载。如果您发送的 XPK 工作负载不再需要运行,或者有作业卡在队列中,此命令可能会很有用。

删除 XPK 集群

使用 xpk cluster delete 命令删除集群:

python3 xpk.py cluster delete --cluster ${CLUSTER_NAME} \
--zone=${ZONE} --project=${PROJECT_ID}

在 v6e Cloud TPU 虚拟机上进行 Llama 和 PyTorch/XLA 训练

本教程介绍了如何使用 WikiText 数据集在 Cloud TPU v6e 上使用 PyTorch/XLA 训练 Llama 模型。

获取对 Hugging Face 和 Llama 3 模型的访问权限

您需要 Hugging Face 用户访问令牌才能运行本教程。如需了解如何创建和使用访问令牌,请参阅 Hugging Face 文档中的用户访问令牌部分

您还需要有权访问 Hugging Face 上的 Llama 3 8B 模型。如需获取访问权限,请前往 HuggingFace 上的 Meta-Llama-3-8B 模型并请求访问权限。

创建 Cloud TPU 虚拟机

创建一个包含 8 个芯片的 Cloud TPU v6e 来运行本教程。

  1. 设置环境变量:

    export ACCELERATOR_TYPE=v6e-8
    export VERSION=v2-alpha-tpuv6e
    export TPU_NAME=$USER-$ACCELERATOR_TYPE
    export PROJECT_ID=your-project-id
    export ZONE=us-east1-d
  2. 创建 Cloud TPU 虚拟机:

    gcloud alpha compute tpus tpu-vm create ${TPU_NAME} --version=${VERSION} \
        --accelerator-type=${ACCELERATOR_TYPE} \
        --zone=${ZONE} \
        --project=${PROJECT_ID}

安装

安装 Hugging Face Transformer 的 pytorch-tpu/transformers 分支及其依赖项。本教程是使用以下示例中使用的依赖项版本进行测试的:

  • torch:与 2.5.0 兼容
  • torch_xla[tpu]:与 2.5.0 兼容
  • jax:0.4.33
  • jaxlib:0.4.33
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} --zone ${ZONE} \
    --worker=all --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
    cd transformers
    sudo pip3 install -e .
    pip3 install datasets
    pip3 install evaluate
    pip3 install scikit-learn
    pip3 install accelerate
    pip install torch~=2.6.0 torch_xla[tpu]~=2.6.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
    pip install jax==0.4.38 jaxlib==0.4.38 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'

设置模型配置

下一部分(运行模型)中的训练命令使用两个 JSON 配置文件来定义模型参数和 FSDP(完全分片数据并行)配置。FSDP 分片用于模型权重,以便在训练过程中适应更大的批次大小。使用较小模型进行训练时,使用数据并行处理并在每台设备上复制权重可能就足够了。如需详细了解如何在 PyTorch/XLA 中跨设备分片张量,请参阅 PyTorch/XLA SPMD 用户指南

  1. 创建模型参数配置文件。以下是 Llama3-8B 的模型参数配置。对于其他模型,请在 Hugging Face 上查找配置。例如,请参阅 Llama2-7B 配置

    cat > llama-config.json << EOF
    {
        "architectures": [
            "LlamaForCausalLM"
        ],
        "attention_bias": false,
        "attention_dropout": 0.0,
        "bos_token_id": 128000,
        "eos_token_id": 128001,
        "hidden_act": "silu",
        "hidden_size": 4096,
        "initializer_range": 0.02,
        "intermediate_size": 14336,
        "max_position_embeddings": 8192,
        "model_type": "llama",
        "num_attention_heads": 32,
        "num_hidden_layers": 32,
        "num_key_value_heads": 8,
        "pretraining_tp": 1,
        "rms_norm_eps": 1e-05,
        "rope_scaling": null,
        "rope_theta": 500000.0,
        "tie_word_embeddings": false,
        "torch_dtype": "bfloat16",
        "transformers_version": "4.40.0.dev0",
        "use_cache": false,
        "vocab_size": 128256
    }
    EOF
    
  2. 创建 FSDP 配置文件:

    cat > fsdp-config.json << EOF
    {
        "fsdp_transformer_layer_cls_to_wrap": [
            "LlamaDecoderLayer"
        ],
        "xla": true,
        "xla_fsdp_v2": true,
        "xla_fsdp_grad_ckpt": true
    }
    EOF
    

    如需详细了解 FSDP,请参阅 FSDPv2

  3. 使用以下命令将配置文件上传到 Cloud TPU 虚拟机:

    gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json ${TPU_NAME}:. \
        --worker=all \
        --project=${PROJECT_ID} \
        --zone=${ZONE}

运行模型

使用您在上一部分中创建的配置文件,运行 run_clm.py 脚本,以便在 WikiText 数据集上训练 Llama 3 8B 模型。训练脚本在 Cloud TPU v6e-8 上大约需要 10 分钟才能运行完毕。

  1. 使用以下命令在 Cloud TPU 上登录 Hugging Face:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} \
        --zone ${ZONE} \
        --worker=all \
        --command='
        pip3 install "huggingface_hub[cli]"
        huggingface-cli login --token HUGGING_FACE_TOKEN'
  2. 运行模型训练:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} \
        --zone ${ZONE} \
        --worker=all \
        --command='
        export PJRT_DEVICE=TPU
        export XLA_USE_SPMD=1
        export ENABLE_PJRT_COMPATIBILITY=true
            # Optional variables for debugging:
        export XLA_IR_DEBUG=1
        export XLA_HLO_DEBUG=1
        export PROFILE_EPOCH=0
        export PROFILE_STEP=3
        export PROFILE_DURATION_MS=100000
            # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path
        export PROFILE_LOGDIR=PROFILE_PATH
        python3 transformers/examples/pytorch/language-modeling/run_clm.py \
        --dataset_name wikitext \
        --dataset_config_name wikitext-2-raw-v1 \
        --per_device_train_batch_size 16 \
        --do_train \
        --output_dir /home/$USER/tmp/test-clm \
        --overwrite_output_dir \
        --config_name /home/$USER/llama-config.json \
        --cache_dir /home/$USER/cache \
        --tokenizer_name meta-llama/Meta-Llama-3-8B \
        --block_size 8192 \
        --optim adafactor \
        --save_strategy no \
        --logging_strategy no \
        --fsdp "full_shard" \
        --fsdp_config /home/$USER/fsdp-config.json \
        --torch_dtype bfloat16 \
        --dataloader_drop_last yes \
        --flash_attention \
        --max_steps 20'

PyTorch/XLA 问题排查

如果您在上一部分中设置了用于调试的可选变量,则模型的配置文件将存储在变量 PROFILE_LOGDIR 指定的位置。您可以提取存储在此位置的 xplane.pb 文件,并使用 tensorboard 按照 TensorBoard 说明在浏览器中查看配置文件。如果 PyTorch/XLA 的性能不符合预期,请参阅问题排查指南,其中提供了有关调试、性能分析和优化模型的建议。

基准测试结果

以下部分包含 v6e 上 MaxDiffusion 的基准测试结果。

MaxDiffusion

我们在 v6e-4、v6e-16 和两个 v6e-16 上运行了 MaxDiffusion 的训练脚本。请参阅下表中的吞吐量。

v6e-4 v6e-16 两个 v6e-16
训练步骤 0.069 0.073 0.13
全局批次大小 8 32 64
吞吐量(示例/秒) 115.9 438.4 492.3