Trillium (v6e) 簡介

在本說明文件、TPU API 和記錄中,v6e 是指 Trillium。v6e 代表 Google 第 6 代 TPU。

每個 Pod 有 256 個晶片,因此 v6e 架構與 v5e 有許多相似之處。這個系統經過最佳化,可用於轉換器、文字轉圖像和卷積類神經網路 (CNN) 的訓練、微調和服務。

如要進一步瞭解 v6e 系統架構和設定,請參閱 TPU v6e

這份簡介文件著重於使用 JAXPyTorch 架構訓練及提供模型的程序。您可以使用佇列資源或 GKE,透過每個架構來佈建 TPU。您可以使用 XPK 或 GKE 指令設定 GKE。

使用 v6e 訓練或提供模型的一般程序

  1. 準備 Google Cloud 專案
  2. 安全容量
  3. 佈建 Cloud TPU 環境
  4. 執行模型訓練推論工作負載

準備 Google Cloud 專案

您必須先完成下列步驟,才能使用 Cloud TPU:

  • 建立 Google Cloud 帳戶和專案並啟用計費功能
  • 安裝 Google Cloud CLI alpha 元件
  • 啟用 Cloud TPU API
  • 建立 Cloud TPU 服務代理
  • 建立 Cloud TPU 服務帳戶並授予權限

詳情請參閱「設定 Cloud TPU 環境」。

安全容量

如要申請 Cloud TPU v6e 配額,或解答任何有關容量的問題,請與Google Cloud 支援團隊聯絡。

佈建 Cloud TPU 環境

您可以使用 GKE、GKE 和 XPK (GKE 上的包裝 CLI 工具),或以佇列資源的方式來配置及管理 v6e Cloud TPU。

必要條件

  • 請確認專案具備足夠的 TPUS_PER_TPU_FAMILY 配額,這項配額會指定您可在 Google Cloud專案中存取的晶片數量上限。
  • 我們已使用下列設定測試 v6e:
    • Python 3.10 以上版本
    • 夜間軟體版本:
      • Nightly JAX 0.4.32.dev20240912
      • 夜間版本 LibTPU 0.1.dev20240912+nightly
    • 穩定版軟體版本:
      • JAX + JAX Lib 0.4.37 版
  • 請確認您的專案具備充足的配額,可供下列項目使用:

    • Cloud TPU VM 配額
    • IP 位址配額
    • Hyperdisk Balanced 和其他要使用的磁碟類型的配額

  • 如果您要搭配 XPK 使用 GKE,請參閱「使用者或服務帳戶的 Cloud Console 權限」,瞭解執行 XPK 所需的權限。

建立環境變數

在 Cloud Shell 中建立下列環境變數:

export NODE_ID=your-tpu-name
export PROJECT_ID=your-project-id
export ACCELERATOR_TYPE=v6e-16
export ZONE=us-east1-d
export RUNTIME_VERSION=v2-alpha-tpuv6e
export SERVICE_ACCOUNT=your-service-account
export QUEUED_RESOURCE_ID=your-queued-resource-id
export VALID_DURATION=your-duration 

# Additional environment variable needed for Multislice:
export NUM_SLICES=number-of-slices

# Use a custom network for better performance as well as to avoid having the default network becoming overloaded.

export NETWORK_NAME=${PROJECT_ID}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw

指令旗標說明

變數 說明
NODE_ID 在分配佇列資源要求時,使用者指派的 Cloud TPU ID。
PROJECT_ID Google Cloud 專案名稱。使用現有專案或建立新專案。詳情請參閱「設定 Google Cloud 專案」。
可用區 如要瞭解支援的區域,請參閱 Cloud TPU 地區和區域說明文件。
ACCELERATOR_TYPE 請參閱「加速器類型」。
RUNTIME_VERSION v2-alpha-tpuv6e
SERVICE_ACCOUNT 這是服務帳戶的電子郵件地址,您可以在 Google Cloud 控制台 ->「IAM」->「服務帳戶」中找到這項資訊。

例如:tpu-service-account@your-project-ID.iam.gserviceaccount.com.com

NUM_SLICES 要建立的切片數量 (僅限多切片)。
QUEUED_RESOURCE_ID 排入佇列的資源要求,由使用者指定的文字 ID。
VALID_DURATION 佇列資源要求的有效時間。
NETWORK_NAME 要使用的次要網路名稱。
NETWORK_FW_NAME 要使用的次要網路防火牆名稱。

提升網路效能

為獲得最佳效能,請使用 8,896 MTU (最大傳輸單位) 的網路。

根據預設,虛擬私有雲 (VPC) 只提供 1,460 個位元組的 MTU,這會導致網路效能不佳。您可以將虛擬私人雲端網路的 MTU 設為 1,300 位元組至 8,896 位元組之間的任何值 (含括)。常見的自訂 MTU 大小為 1,500 位元組 (標準乙太網路) 或 8,896 位元組 (可能的最大值)。詳情請參閱「有效的虛擬私有雲網路 MTU 大小」。

如要進一步瞭解如何變更現有或預設網路的 MTU 設定,請參閱「變更 VPC 網路的 MTU 設定」。

以下範例會建立 MTU 為 8,896 的網路。

export RESOURCE_NAME=your-resource-name
export NETWORK_NAME=${RESOURCE_NAME}-privatenetwork
export NETWORK_FW_NAME=${RESOURCE_NAME}-privatefirewall
gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT_ID} \
 --subnet-mode=auto --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network=${NETWORK_NAME} \
 --allow tcp,icmp,udp --project=${PROJECT_ID}

使用多個 NIC (多配量選項)

使用多切片環境時,次要子網路需要下列環境變數。

export NETWORK_NAME_2=${RESOURCE_NAME}
export SUBNET_NAME_2=${RESOURCE_NAME}
export FIREWALL_RULE_NAME=${RESOURCE_NAME}
export ROUTER_NAME=${RESOURCE_NAME}-network-2
export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2
export REGION=your-region

使用下列指令,為網路和子網路建立自訂 IP 路由。

gcloud compute networks create ${NETWORK_NAME_2} --mtu=8896 \
   --bgp-routing-mode=regional --subnet-mode=custom --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_2} \
   --network=${NETWORK_NAME_2} \
   --range=10.10.0.0/18 --region=${REGION} \
   --project=${PROJECT_ID}

gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_2} --allow tcp,icmp,udp \
   --source-ranges 10.10.0.0/18 --project=${PROJECT_ID}

gcloud compute routers create ${ROUTER_NAME} \
  --project=${PROJECT_ID} \
  --network=${NETWORK_NAME_2} \
  --region=${REGION}

gcloud compute routers nats create ${NAT_CONFIG} \
  --router=${ROUTER_NAME} \
  --region=${REGION} \
  --auto-allocate-nat-external-ips \
  --nat-all-subnet-ip-ranges \
  --project=${PROJECT_ID} \
  --enable-logging

建立多網路區塊後,您可以設定 XPK 叢集,並在 XPK 工作負載建立指令中加入 --command ifconfig 標記,驗證是否有使用兩張網路介面卡 (NIC)。

使用下列 xpk workload 指令,在 Google Cloud 主控台記錄中顯示 ifconfig 指令的輸出內容,並確認 eth0 和 eth1 都具有 mtu=8896。

python3 xpk.py workload create \
   --cluster CLUSTER_NAME \
   {--base-docker-image maxtext_base_image | --docker-image your-cloud-image-name} \
   --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
   --tpu-type=${ACCELERATOR_TYPE} \
   --num-slices=${NUM_SLICES}  \
   --on-demand \
   --zone=${ZONE} \
   --project=${PROJECT_ID} \
   --command "ifconfig"

如要啟用偵錯記錄或使用 Vertex AI TensorBoard,請將下列選用引數新增至指令:

    --enable-debug-logs \
    --use-vertex-tensorboard

確認 eth0 和 eth1 都設有 mtu=8,896。如要確認多個 NIC 是否正在執行,請在 XPK 工作負載建立指令中加入 --command ifconfig 標記。在 Google Cloud 主控台記錄中查看該 xpk 工作負載的輸出內容,並確認 eth0 和 eth1 都具有 mtu=8896。

改善 TCP 設定

如果您是使用排序資源介面建立 Cloud TPU,可以執行下列指令,藉由提高 TCP 接收緩衝區限制來改善網路效能。

gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \
  --project "${PROJECT_ID}" \
  --zone "${ZONE}" \
  --node=all \
  --command='sudo sh -c "echo \"4096 41943040 314572800\" > /proc/sys/net/ipv4/tcp_rmem"' \
  --worker=all

使用排入佇列的資源進行佈建

您可以使用排入佇列的資源建立 Cloud TPU v6e。排隊資源可讓您在可用容量釋出時收到容量。您可以指定要填寫要求的選用開始和結束時間。詳情請參閱「管理排隊中的資源」。

使用 GKE 或 XPK 配置 v6e Cloud TPU

如果您使用 v6e 搭配 GKE 指令,可以使用 Kubernetes 指令或 XPK 來佈建 Cloud TPU,並訓練或提供模型。請參閱「在 GKE 中規劃 Cloud TPU」,瞭解如何在 GKE 叢集中規劃 Cloud TPU 設定。下列各節提供指令,可用於建立支援單一 NIC 和多個 NIC 的 XPK 叢集。

建立支援單一 NIC 的 XPK 叢集

export CLUSTER_NAME=xpk-cluster-name
export ZONE=us-east1-d
export PROJECT_ID=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME=${CLUSTER_NAME}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw
gcloud compute networks create ${NETWORK_NAME} \
   --mtu=8896 \
   --project=${PROJECT_ID} \
   --subnet-mode=auto \
   --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} \
   --network=${NETWORK_NAME} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}"
python3 xpk.py cluster create --cluster=${CLUSTER_NAME} \
   --cluster-cpu-machine-type=e2-standard-8 \
   --num-slices=${NUM_SLICES} \
   --tpu-type=${TPU_TYPE} \
   --zone=${ZONE}  \
   --project=${PROJECT_ID} \
   --on-demand \
   --custom-cluster-arguments="${CLUSTER_ARGUMENTS}"  \
   --create-vertex-tensorboard

指令旗標說明

變數 說明
CLUSTER_NAME 使用者指派的 XPK 叢集名稱。
PROJECT_ID Google Cloud 專案名稱。使用現有專案或建立新專案。詳情請參閱「設定 Google Cloud 專案」。
可用區 如要瞭解支援的區域,請參閱 Cloud TPU 地區和區域說明文件。
TPU_TYPE 請參閱「加速器類型」。
NUM_SLICES 要建立的切片數量
CLUSTER_ARGUMENTS 要使用的網路和子網路。

例如:--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}

NUM_SLICES 要建立的片段數量。
NETWORK_NAME 要使用的次要網路名稱。
NETWORK_FW_NAME 要使用的次要網路防火牆名稱。

建立支援多個 NIC 的 XPK 叢集

export CLUSTER_NAME=xpk-cluster-name
export REGION=your-region
export ZONE=us-east1-d
export PROJECT_ID=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME_1=${CLUSTER_NAME}-mtu9k-1-${ZONE}
export SUBNET_NAME_1=${CLUSTER_NAME}-privatesubnet-1-${ZONE}
export NETWORK_FW_NAME_1=${NETWORK_NAME_1}-fw-1-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-1-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-1-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-1-${ZONE}
gcloud compute networks create ${NETWORK_NAME_1} \
   --mtu=8896 \
   --bgp-routing-mode=regional \
   --subnet-mode=custom \
   --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_1} \
   --network=${NETWORK_NAME_1} \
   --range=10.11.0.0/18 \
   --region=${REGION} \
   --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_1} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \
    --project=${PROJECT_ID} \
    --network=${NETWORK_NAME_1} \
    --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \
    --router=${ROUTER_NAME} \
    --region=${REGION} \
    --auto-allocate-nat-external-ips \
    --nat-all-subnet-ip-ranges \
    --project=${PROJECT_ID} \
    --enable-logging
# Secondary subnet for multi-nic experience.
# Need custom IP routing to be different from the first network's subnet.

export NETWORK_NAME_2=${CLUSTER_NAME}-privatenetwork-2-${ZONE}
export SUBNET_NAME_2=${CLUSTER_NAME}-privatesubnet-2-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-2-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-2-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-2-${ZONE}
gcloud compute networks create ${NETWORK_NAME_2} \
   --mtu=8896 \
   --bgp-routing-mode=regional \
   --subnet-mode=custom \
   --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_2} \
   --network=${NETWORK_NAME_2} \
   --range=10.10.0.0/18 \
   --region=${REGION} \
   --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_2} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \
     --project=${PROJECT_ID} \
     --network=${NETWORK_NAME_2} \
     --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \
     --router=${ROUTER_NAME} \
     --region=${REGION} \
     --auto-allocate-nat-external-ips \
     --nat-all-subnet-ip-ranges \
     --project=${PROJECT_ID} \
     --enable-logging
export CLUSTER_ARGUMENTS="--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}"

export NODE_POOL_ARGUMENTS="--additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}"
python3 xpk.py cluster create \
    --cluster=${CLUSTER_NAME} \
    --cluster-cpu-machine-type=e2-standard-8 \
    --num-slices=${NUM_SLICES} \
    --tpu-type=${TPU_TYPE} \
    --zone=${ZONE}  \
    --project=${PROJECT_ID} \
    --on-demand \
    --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \
    --custom-nodepool-arguments="${NODE_POOL_ARGUMENTS}" \
    --create-vertex-tensorboard

指令旗標說明

變數 說明
CLUSTER_NAME 使用者指派的 XPK 叢集名稱。
PROJECT_ID Google Cloud 專案名稱。使用現有專案或建立新專案。詳情請參閱「設定 Google Cloud 專案」。
可用區 如要瞭解支援的區域,請參閱 Cloud TPU 地區和區域說明文件。
TPU_TYPE 請參閱「加速器類型」。
NUM_SLICES 要建立的切片數量
CLUSTER_ARGUMENTS 要使用的網路和子網路。

例如:--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}

NODE_POOL_ARGUMENTS 要使用的額外節點網路。

例如:--additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}

NUM_SLICES 要建立的切片數量 (僅限多切片)。
NETWORK_NAME 要使用的次要網路名稱。
NETWORK_FW_NAME 要使用的次要網路防火牆名稱。

設定架構

本節將說明使用 JAXPyTorch 架構訓練機器學習模型的一般設定程序。如果您使用的是 GKE,可以使用 XPK 或 Kubernetes 指令設定架構。

設定 JAX

本節提供在 GKE 上執行 JAX 工作負載的設定說明 (不論是否使用 XPK),以及使用排隊資源的說明。

使用 GKE 設定 JAX

單一主機上的單一切片

以下範例會使用 Kubernetes YAML 檔案設定 2x2 單主機節點集區。

apiVersion: v1
kind: Pod
metadata:
  name: tpu-pod-jax-v6e-a
spec:
  restartPolicy: Never
  nodeSelector:
    cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
    cloud.google.com/gke-tpu-topology: 2x2
  containers:
  - name: tpu-job
    image: python:3.10
    securityContext:
      privileged: true
    command:
    - bash
    - -c
    - |
      pip install -U --pre jax jaxlib libtpu-nightly requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
      JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python3 -c 'import jax; print("Total TPU chips:", jax.device_count())'
    resources:
      requests:
        google.com/tpu: 4
      limits:
        google.com/tpu: 4

成功完成後,您應該會在 GKE 記錄檔中看到以下訊息:

Total TPU chips: 4

多主機上的單一切片

以下範例會使用 Kubernetes YAML 檔案設定 4x4 多主機節點集區。

apiVersion: v1
kind: Service
metadata:
  name: headless-svc
spec:
  clusterIP: None
  selector:
    job-name: tpu-available-chips
---
apiVersion: batch/v1
kind: Job
metadata:
  name: tpu-available-chips
spec:
  backoffLimit: 0
  completions: 4
  parallelism: 4
  completionMode: Indexed
  template:
    spec:
      subdomain: headless-svc
      restartPolicy: Never
      nodeSelector:
        cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
        cloud.google.com/gke-tpu-topology: 4x4
      containers:
      - name: tpu-job
        image: python:3.10
        ports:
        - containerPort: 8471 # Default port using which TPU VMs communicate
        - containerPort: 8431 # Port to export TPU runtime metrics, if supported.
        securityContext:
          privileged: true
        command:
        - bash
        - -c
        - |
          pip install -U --pre jax jaxlib libtpu-nightly requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
          JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
        resources:
          requests:
            google.com/tpu: 4
          limits:
            google.com/tpu: 4

成功完成後,您應該會在 GKE 記錄檔中看到以下訊息:

Total TPU chips: 16

多主機上的多配量

以下範例會使用 Kubernetes YAML 檔案設定兩個 4x4 多主機節點集區。

前置條件:您必須安裝 JobSet 0.2.3 以上版本。

apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
  name: multislice-job
  annotations:
    alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
spec:
  failurePolicy:
    maxRestarts: 4
  replicatedJobs:
    - name: slice
      replicas: 2
      template:
        spec:
          parallelism: 4
          completions: 4
          backoffLimit: 0
          template:
            spec:
              hostNetwork: true
              dnsPolicy: ClusterFirstWithHostNet
              nodeSelector:
                cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                cloud.google.com/gke-tpu-topology: 4x4
              hostNetwork: true
              containers:
              - name: jax-tpu
                image: python:3.10
                ports:
                - containerPort: 8471
                - containerPort: 8080
                - containerPort: 8431
                securityContext:
                  privileged: true
                command:
                - bash
                - -c
                - |
                  pip install -U --pre jax jaxlib libtpu-nightly requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                  JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
                resources:
                  limits:
                   google.com/tpu: 4
                  requests:
                   google.com/tpu: 4

成功完成後,您應該會在 GKE 記錄檔中看到以下訊息:

Total TPU chips: 32

如需更多資訊,請參閱 GKE 說明文件中的「執行多切片工作負載」。

為提升效能,請啟用 hostNetwork

多 NIC

如要使用下列多 NIC 資訊清單,您必須設定網路。詳情請參閱「設定 Kubernetes Pod 的多網路支援功能」。如要充分利用 GKE 中的多 NIC,您必須在 Kubernetes Pod 資訊清單中加入一些額外的註解。以下是非 TPU 多 NIC 工作負載範例資訊清單。

apiVersion: v1
kind: Pod
metadata:
  name: sample-netdevice-pod-1
  annotations:
    networking.gke.io/default-interface: 'eth0'
    networking.gke.io/interfaces: |
      [
        {"interfaceName":"eth0","network":"default"},
        {"interfaceName":"eth1","network":"netdevice-network"}
      ]
spec:
  containers:
  - name: sample-netdevice-pod
    image: busybox
    command: ["sleep", "infinity"]
    ports:
    - containerPort: 80
  restartPolicy: Always
  tolerations:
  - key: "google.com/tpu"
    operator: "Exists"
    effect: "NoSchedule"

如果您使用 exec 指令連線至 Kubernetes Pod,應會看到使用下列程式碼的額外 NIC。

$ kubectl exec --stdin --tty sample-netdevice-pod-1 -- /bin/sh
/ # ip a
1: lo: <LOOPBACK,UP,LOWER_UP> mtu 65536 qdisc noqueue qlen 1000
    link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00
    inet 127.0.0.1/8 scope host lo
       valid_lft forever preferred_lft forever
2: eth0@if11: <BROADCAST,MULTICAST,UP,LOWER_UP,M-DOWN> mtu 1460 qdisc noqueue
    link/ether da:be:12:67:d2:25 brd ff:ff:ff:ff:ff:ff
    inet 10.124.2.6/24 brd 10.124.2.255 scope global eth0
       valid_lft forever preferred_lft forever
3: eth1: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1460 qdisc mq qlen 1000
    link/ether 42:01:ac:18:00:04 brd ff:ff:ff:ff:ff:ff
    inet 172.24.0.4/32 scope global eth1
       valid_lft forever preferred_lft forever

使用 GKE 搭配 XPK 設定 JAX

如要使用 GKE 和 XPK 設定 JAX,請參閱 xpk README

如要使用 MaxText 設定及執行 XPK,請參閱「如何執行 MaxText」。

使用排入佇列的資源設定 JAX

使用 gcloud alpha compute tpus tpu-vm ssh 指令,同時在單一或多個配方的所有 Cloud TPU VM 上安裝 JAX。如為多切片,請新增 --node=all 標記。

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
 --zone ${ZONE} --worker=all \
 --command='pip install -U --pre jax jaxlib libtpu-nightly requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

您可以執行下列指令,查看切片中可用的 Cloud TPU 核心數量,並測試所有資源是否正確安裝:

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
   --zone ${ZONE} --worker=all  \
   --command='python3 -c "import jax; print(jax.device_count(), jax.local_device_count())"'

在 v6e-16 切片上執行時,輸出結果會類似以下內容:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16 4
16 4
16 4
16 4

jax.device_count() 會顯示指定切片中的方塊總數。jax.local_device_count() 會指出單一 VM 在這個切片中可存取的晶片數量。


 gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} --zone=${ZONE} --worker=all  \
   --command='git clone -b mlperf4.1 https://github.com/google/maxdiffusion.git &&
   cd maxdiffusion && git checkout 975fdb7dbddaa9a53ad72a421cdb487dcdc491a3 &&
   pip install setuptools==59.6.0 &&
   pip install -r requirements.txt  && pip install . '

排解 JAX 設定問題

一般建議是在 GKE 工作負載資訊清單中啟用詳細記錄功能。然後將記錄檔提供給 GKE 支援團隊。

TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0

錯誤訊息

no endpoints available for service 'jobset-webhook-service'

這項錯誤表示工作集未正確安裝。檢查 jobset-controller-manager 部署的 Kubernetes Pod 是否正在執行。詳情請參閱JobSet 疑難排解說明文件

TPU initialization failed: Failed to connect

請確認您的 GKE 節點版本為 1.30.4-gke.1348000 以上版本 (不支援 GKE 1.31)。

設定 PyTorch

本節說明如何在 v6e 上使用 PyTorch/XLA 開始使用 PJRT。建議使用 Python 3.10 版本。

使用 GKE 和 XPK 設定 PyTorch

您可以使用下列 Docker 容器搭配 XPK,其中已安裝 PyTorch 依附元件:

us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028

如要建立 XPK 工作負載,請使用下列指令:

python3 xpk.py workload create \
    --cluster ${CLUSTER_NAME} \
    {--base-docker-image maxtext_base_image | --docker-image your-cloud-image-name \
    --workload ${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
    --tpu-type=${ACCELERATOR_TYPE} \
    --num-slices=${NUM_SLICES}  \
    --on-demand \
    --zone ${ZONE} \
    --project ${PROJECT_ID} \
    --enable-debug-logs \
    --command 'python3 -c "import torch; import torch_xla; import torch_xla.runtime as xr; print(xr.global_runtime_device_count())"'

使用 --base-docker-image 建立新的 Docker 映像檔,並將目前的工作目錄建構到新的 Docker 中。

使用排入佇列的資源設定 PyTorch

請按照下列步驟,使用排入佇列的資源安裝 PyTorch,並在 v6e 上執行小型指令碼。

使用 SSH 安裝依附元件,以便存取 VM

使用下列指令,在所有 Cloud TPU VM 上安裝依附元件。如要使用多切片,請新增 --worker=all 標記:

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='sudo apt update && sudo apt install -y python3-pip libopenblas-base && \
               pip3 install torch~=2.6.0 "torch_xla[tpu]~=2.6.0" -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'

改善經常分配大量資源的模型效能

對於有大量頻繁配置的模型,使用 tcmalloc 函式可大幅提升效能,相較於預設的 malloc 函式實作,因此在 Cloud TPU VM 上使用的預設 malloc 函式為 tcmalloc。不過,視工作負載而定 (例如,DLRM 為其嵌入式資料表分配的空間非常大),tcmalloc 函式可能會導致速度變慢,在這種情況下,您可以改用預設 malloc 函式來取消設定下列變數:

unset LD_PRELOAD

使用 Python 指令碼在 v6e VM 上執行計算

使用下列指令執行指令碼,建立兩個張量、將兩者加總,並列印結果。

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
   --project ${PROJECT_ID} \
   --zone ${ZONE} --worker all --command='
   unset LD_PRELOAD
   python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"
'

這麼做會產生類似以下內容的輸出結果:

SSH: Attempting to connect to worker 0...
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
xla:0
tensor([[ 0.3355, -1.4628, -3.2610],
        [-1.4656,  0.3196, -2.8766],
        [ 0.8668, -1.5060,  0.7125]], device='xla:0')

v6e 搭配 SkyPilot

您可以將 Cloud TPU v6e 與 SkyPilot 搭配使用。請按照下列步驟,在 SkyPilot 中新增 v6e 相關位置和定價資訊。詳情請參閱 GitHub 上的 SkyPilot TPU v6e 範例

推論教學課程

以下教學課程說明如何在 Cloud TPU v6e 上執行推論:

訓練範例

以下各節提供在 Cloud TPU v6e 上訓練 MaxText、MaxDiffusion 和 PyTorch 模型的範例。

在 v6e Cloud TPU VM 上進行 MaxText 和 MaxDiffusion 訓練

以下各節將說明 MaxTextMaxDiffusion 模型的訓練生命週期。

一般來說,高階步驟如下:

  1. 建構工作負載基本映像檔。
  2. 使用 XPK 執行工作負載。
    1. 為工作負載建構訓練指令。
    2. 部署工作負載。
  3. 追蹤工作負載並查看指標。
  4. 如果不需要 XPK 工作負載,請將其刪除。
  5. 不再需要 XPK 叢集時,請將其刪除。

建構基本映像檔

安裝 MaxText 或 MaxDiffusion,並建構 Docker 映像檔:

  1. 複製要使用的存放區,然後變更為存放區的目錄:

    MaxText:

    git clone https://github.com/google/maxtext.git && cd maxtext
    

    MaxDiffusion:

    git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion
    
  2. 將 Docker 設定為使用 Google Cloud CLI:

    gcloud auth configure-docker
    
  3. 請使用下列指令或 JAX Stable Stack 建構 Docker 映像檔。如要進一步瞭解 JAX Stable Stack,請參閱「使用 JAX Stable Stack 建構 Docker 映像檔」。

    bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.35
    
  4. 在有效的 gcloud CLI 設定中設定專案 ID:

    gcloud config set project ${PROJECT_ID}
    
  5. 如果您要從未在本機建構映像檔的電腦啟動工作負載,請上傳映像檔。

    1. 設定 CLOUD_IMAGE_NAME 環境變數:

      export CLOUD_IMAGE_NAME=${USER}_runner
      
    2. 上傳圖片:

      bash docker_upload_runner.sh ${CLOUD_IMAGE_NAME}
      

使用 XPK 執行工作負載

  1. 如果您不使用 MaxText 設定的預設值MaxDiffusion,請設定下列環境變數:

    export BASE_OUTPUT_DIR=gs://YOUR_BUCKET
    export PER_DEVICE_BATCH_SIZE=2
    export NUM_STEPS=30
    export MAX_TARGET_LENGTH=8192
  2. 建立模型指令碼。這個指令碼會在後續步驟中複製為訓練指令。

    請先不要執行模型指令碼。

    MaxText

    MaxText 是一種高效能、高擴充性的開放原始碼 LLM,以純 Python 和 JAX 編寫,並鎖定 Google Cloud TPU 和 GPU 進行訓練和推論。

    JAX_PLATFORMS=tpu,cpu \
    ENABLE_PJRT_COMPATIBILITY=true \
    TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \
    TPU_SLICE_BUILDER_DUMP_ICI=true && \
    python3 -m MaxText.train MaxText/configs/base.yml \
            base_output_directory=${BASE_OUTPUT_DIR} \
            dataset_type=synthetic \
            per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
            enable_checkpointing=false \
            gcs_metrics=true \
            profiler=xplane \
            skip_first_n_steps_for_profiler=5 \
            steps=${NUM_STEPS}  # attention='dot_product'"
    

    Gemma2

    Gemma 是一系列開放權重 LLM,由 Google DeepMind 根據 Gemini 研究和技術開發而成。

    python3 -m MaxText.train MaxText/configs/base.yml \
        model_name=gemma2-27b \
        run_name=gemma2-27b-run \
        base_output_directory=${BASE_OUTPUT_DIR} \
        max_target_length=${MAX_TARGET_LENGTH} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        steps=${NUM_STEPS} \
        enable_checkpointing=false \
        use_iota_embed=true \
        gcs_metrics=true \
        dataset_type=synthetic \
        profiler=xplane \
        attention=flash
    

    Mixtral 8x7b

    Mixtral 是 Mistral AI 開發的尖端 AI 模型,採用稀疏的專家混合 (MoE) 架構。

    python3 -m MaxText.train MaxText/configs/base.yml \
        base_output_directory=${BASE_OUTPUT_DIR} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        model_name=mixtral-8x7b \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        tokenizer_path=assets/tokenizer.mistral-v1 \
        attention=flash \
        dtype=bfloat16 \
        dataset_type=synthetic \
        profiler=xplane
    

    Llama3-8b

    Llama 是 Meta 開發的開放權重 LLM 系列。

    如需在 PyTorch 上執行 Llama3 的範例,請參閱 torchprime GitHub 存放區中的 torch_xla 模型

    MaxDiffusion

    MaxDiffusion 是各種隱含擴散模型的參考實作項目集合,這些模型是以純 Python 和 JAX 編寫,可在 XLA 裝置 (包括 Cloud TPU 和 GPU) 上執行。Stable Diffusion 是一種隱含的文字轉圖像模型,可根據任何文字輸入內容生成逼真的圖像。

    您需要安裝特定 Git 分支,才能執行 MaxDiffusion,如以下 git checkout 指令所示。

    git clone https://github.com/google/maxdiffusion.git
    && cd maxdiffusion
    && git checkout e712c9fc4cca764b0930067b6e33daae2433abf0
    && pip install -r requirements.txt
    && pip install .
    

    訓練指令碼:

        cd maxdiffusion && OUT_DIR=${BASE_OUTPUT_DIR} \
        python src/maxdiffusion/train_sdxl.py \
        src/maxdiffusion/configs/base_xl.yml \
        revision=refs/pr/95 \
        activations_dtype=bfloat16 \
        weights_dtype=bfloat16 \
        resolution=1024 \
        per_device_batch_size=1 \
        output_dir=${OUT_DIR}  \
        jax_cache_dir=${OUT_DIR}/cache_dir/ \
        max_train_steps=200 \
        attention=flash run_name=sdxl-ddp-v6e
        
  3. 匯出下列變數:

    export CLUSTER_NAME=CLUSTER_NAME
    export ACCELERATOR_TYPE=ACCELERATOR_TYPE
    export NUM_SLICES=NUM_SLICES
    export YOUR_MODEL_SCRIPT=YOUR_MODEL_SCRIPT

    環境變數說明

    變數 說明
    CLUSTER_NAME XPK 叢集的名稱。
    ACCELERATOR_TYPE 請參閱「加速器類型」。
    NUM_SLICES TPU 配量數量。
    YOUR_MODEL_SCRIPT 要做為訓練指令執行的模型指令碼。
  4. 使用先前步驟中建立的指令碼執行模型。您必須指定 --base-docker-image 標記,才能使用 MaxText 基礎圖片,或是指定 --docker-image 標記和要使用的圖片。

    選用:您可以加入 --enable-debug-logs 標記來啟用偵錯記錄功能。詳情請參閱「在 MaxText 上偵錯 JAX」。

    選用:您可以建立 Vertex AI 實驗,並加入 --use-vertex-tensorboard 標記,將資料上傳至 Vertex AI TensorBoard。詳情請參閱「使用 Vertex AI 監控 MaxText 上的 JAX」。

    python3 xpk.py workload create \
        --cluster ${CLUSTER_NAME} \
        {--base-docker-image maxtext_base_image|--docker-image gcr.io/${PROJECT_ID}/${CLOUD_IMAGE_NAME}:latest} \
        --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
        --tpu-type=${ACCELERATOR_TYPE} \
        --num-slices=${NUM_SLICES}  \
        --on-demand \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        [--enable-debug-logs] \
        [--use-vertex-tensorboard] \
        --command=${YOUR_MODEL_SCRIPT}

    輸出內容包含追蹤工作負載的連結。開啟連結,然後按一下「Logs」分頁標籤,即可即時追蹤工作負載。

在 MaxText 上偵錯 JAX

使用輔助 XPK 指令,診斷叢集或工作負載無法執行的原因。

使用 Vertex AI 監控 MaxText 上的 JAX

如要使用 TensorBoard,您的 Google Cloud 使用者帳戶必須具備 aiplatform.user 角色。執行下列指令來授予這些角色:

gcloud projects add-iam-policy-binding your-project-id \
--member='user:your-email' \
--role='roles/aiplatform.user'

透過 Vertex AI 的受管理 TensorBoard 查看向量和設定檔資料。

  1. 將您使用的區域的資源管理 (CRUD) 要求從 600 提高至 5000。對於使用少於 16 個 VM 的小型工作負載而言,這可能不是問題。
  2. 安裝 Vertex AI 的 cloud-accelerator-diagnostics 等依附元件:

    # xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI
    cd ~/xpk
    pip install .
  3. 如「建立 Vertex AI TensorBoard」一文所述,請使用 --create-vertex-tensorboard 標記建立 XPK 叢集。您也可以在現有叢集中執行這項指令。

  4. 使用 --use-vertex-tensorboard 標記和選用的 --experiment-name 標記執行 XPK 工作負載時,建立 Vertex AI 實驗。如需完整步驟清單,請參閱「建立 Vertex AI 實驗,將資料上傳至 Vertex AI TensorBoard」一文。

記錄包含 Vertex AI TensorBoard 的連結,類似以下內容:

View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name

您也可以在 Google Cloud 控制台中找到 Vertex AI TensorBoard 連結。前往 Google Cloud 控制台中的 Vertex AI Experiments。從下拉式選單中選取適當的地區。

TensorBoard 目錄也會寫入您使用 ${BASE_OUTPUT_DIR} 指定的 Cloud Storage 值區。

刪除 XPK 工作負載

使用 xpk workload delete 指令,根據工作前置字串或工作狀態刪除一或多個工作負載。如果您傳送的 XPK 工作負載已無需執行,或是工作卡在佇列中,這項指令可能會很實用。

刪除 XPK 叢集

使用 xpk cluster delete 指令刪除叢集:

python3 xpk.py cluster delete --cluster ${CLUSTER_NAME} \
--zone=${ZONE} --project=${PROJECT_ID}

在 v6e Cloud TPU VM 上執行 Llama 和 PyTorch/XLA 訓練

本教學課程將說明如何在 Cloud TPU v6e 上使用 WikiText 資料集,透過 PyTorch/XLA 訓練 Llama 模型。

取得 Hugging Face 和 Llama 3 模型的存取權

您需要 Hugging Face 使用者存取權杖,才能執行本教學課程。如要瞭解如何建立使用者存取權杖,請參閱 Hugging Face 使用者存取權杖說明文件

您也需要取得 Hugging Face 的 Llama 3 8B 模型存取權。如要取得存取權,請前往 HuggingFace 上的 Meta-Llama-3-8B 模型並要求存取權。

建立 Cloud TPU VM

建立具有 8 個晶片的 Cloud TPU v6e,以便執行教學課程。

  1. 設定環境變數:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-east1-d
    export ACCELERATOR_TYPE=v6e-8
    export RUNTIME_VERSION=v2-alpha-tpuv6e

    環境變數說明

    變數 說明
    PROJECT_ID 您的 Google Cloud 專案 ID。使用現有專案或建立新專案
    TPU_NAME TPU 的名稱。
    ZONE 建立 TPU VM 的區域。如要進一步瞭解支援的區域,請參閱「TPU 地區和區域」一文。
    ACCELERATOR_TYPE 加速器類型會指定您要建立的 Cloud TPU 版本和大小。如要進一步瞭解各 TPU 版本支援的加速器類型,請參閱「TPU 版本」。
    RUNTIME_VERSION Cloud TPU 軟體版本

  2. 建立 Cloud TPU VM:

    gcloud alpha compute tpus tpu-vm create ${TPU_NAME} --version=${RUNTIME_VERSION} \
        --accelerator-type=${ACCELERATOR_TYPE} \
        --zone=${ZONE} \
        --project=${PROJECT_ID}

安裝

安裝 Hugging Face 轉換器和依附元件的 pytorch-tpu/transformers 分支。本教學課程已使用以下範例中使用的依附元件版本進行測試:

  • torch:與 2.5.0 相容
  • torch_xla[tpu]:與 2.5.0 相容
  • jax:0.4.33
  • jaxlib:0.4.33
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} --zone ${ZONE} \
    --worker=all --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
    cd transformers
    sudo pip3 install -e .
    pip3 install datasets
    pip3 install evaluate
    pip3 install scikit-learn
    pip3 install accelerate
    pip install torch~=2.6.0 torch_xla[tpu]~=2.6.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
    pip install jax==0.4.38 jaxlib==0.4.38 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/'

設定模型設定

下一節的訓練指令「Run the model」會使用兩個 JSON 設定檔定義模型參數和 FSDP (Fully Sharded Data Parallel) 設定。在訓練時,FSDP 區塊劃分可用於模型權重,以便配合較大的批次大小。使用較小模型進行訓練時,只要使用資料平行處理並在每部裝置上複製權重即可。如要進一步瞭解如何在 PyTorch/XLA 中跨裝置分割張量,請參閱 PyTorch/XLA SPMD 使用手冊

  1. 建立模型參數設定檔。以下是 Llama3-8B 的模型參數設定。如要查看其他模型的設定,請前往 Hugging Face。例如,請參閱 Llama2-7B 設定

    cat > llama-config.json << EOF
    {
        "architectures": [
            "LlamaForCausalLM"
        ],
        "attention_bias": false,
        "attention_dropout": 0.0,
        "bos_token_id": 128000,
        "eos_token_id": 128001,
        "hidden_act": "silu",
        "hidden_size": 4096,
        "initializer_range": 0.02,
        "intermediate_size": 14336,
        "max_position_embeddings": 8192,
        "model_type": "llama",
        "num_attention_heads": 32,
        "num_hidden_layers": 32,
        "num_key_value_heads": 8,
        "pretraining_tp": 1,
        "rms_norm_eps": 1e-05,
        "rope_scaling": null,
        "rope_theta": 500000.0,
        "tie_word_embeddings": false,
        "torch_dtype": "bfloat16",
        "transformers_version": "4.40.0.dev0",
        "use_cache": false,
        "vocab_size": 128256
    }
    EOF
    
  2. 建立 FSDP 設定檔:

    cat > fsdp-config.json << EOF
    {
        "fsdp_transformer_layer_cls_to_wrap": [
            "LlamaDecoderLayer"
        ],
        "xla": true,
        "xla_fsdp_v2": true,
        "xla_fsdp_grad_ckpt": true
    }
    EOF
    

    如要進一步瞭解 FSDP,請參閱 FSDPv2

  3. 使用下列指令,將設定檔上傳至 Cloud TPU VM:

    gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json ${TPU_NAME}:. \
        --worker=all \
        --project=${PROJECT_ID} \
        --zone=${ZONE}

執行模型

使用您在上一節建立的設定檔,執行 run_clm.py 指令碼,在 WikiText 資料集上訓練 Llama 3 8B 模型。在 Cloud TPU v6e-8 上執行訓練指令碼大約需要 10 分鐘。

  1. 使用下列指令,在 Cloud TPU 上登入 Hugging Face:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} \
        --zone ${ZONE} \
        --worker=all \
        --command='
        pip3 install "huggingface_hub[cli]"
        huggingface-cli login --token HUGGING_FACE_TOKEN'
  2. 執行模型訓練:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} \
        --zone ${ZONE} \
        --worker=all \
        --command='
        export PJRT_DEVICE=TPU
        export XLA_USE_SPMD=1
        export ENABLE_PJRT_COMPATIBILITY=true
            # Optional variables for debugging:
        export XLA_IR_DEBUG=1
        export XLA_HLO_DEBUG=1
        export PROFILE_EPOCH=0
        export PROFILE_STEP=3
        export PROFILE_DURATION_MS=100000
            # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path
        export PROFILE_LOGDIR=PROFILE_PATH
        python3 transformers/examples/pytorch/language-modeling/run_clm.py \
        --dataset_name wikitext \
        --dataset_config_name wikitext-2-raw-v1 \
        --per_device_train_batch_size 16 \
        --do_train \
        --output_dir /home/$USER/tmp/test-clm \
        --overwrite_output_dir \
        --config_name /home/$USER/llama-config.json \
        --cache_dir /home/$USER/cache \
        --tokenizer_name meta-llama/Meta-Llama-3-8B \
        --block_size 8192 \
        --optim adafactor \
        --save_strategy no \
        --logging_strategy no \
        --fsdp "full_shard" \
        --fsdp_config /home/$USER/fsdp-config.json \
        --torch_dtype bfloat16 \
        --dataloader_drop_last yes \
        --flash_attention \
        --max_steps 20'

PyTorch/XLA 疑難排解

如果您在前一個部分設定用於偵錯的選用變數,模型的設定檔會儲存在變數 PROFILE_LOGDIR 指定的位置。您可以擷取儲存在這個位置的 xplane.pb 檔案,並使用 tensorboard 搭配 TensorBoard 操作說明在瀏覽器中查看設定檔。如果 PyTorch/XLA 的效能不如預期,請參閱疑難排解指南,其中提供有關偵錯、設定檔和模型最佳化的建議。

基準化結果

以下部分包含 v6e 中 MaxDiffusion 的基準測試結果。

MaxDiffusion

我們在 v6e-4、v6e-16 和兩個 v6e-16 上執行 MaxDiffusion 訓練指令碼。請參閱下表中的吞吐量。

v6e-4 v6e-16 兩個 v6e-16
訓練步驟 0.069 0.073 0.13
全域批次大小 8 32 64
處理量 (例項/秒) 115.9 438.4 492.3