Trillium(v6e)の概要

v6e は、このドキュメント、TPU API、ログで Trillium を指すために使用されます。v6e は Google の第 6 世代の TPU を表します。

Pod あたり 256 個のチップを備えた v6e アーキテクチャは、v5e と多くの類似点があります。このシステムは、トランスフォーマー、text-to-image、畳み込みニューラル ネットワーク(CNN)のトレーニング、微調整、サービス提供に最適化されています。

v6e システム アーキテクチャと構成の詳細については、TPU v6e をご覧ください。

この概要ドキュメントでは、JAXPyTorch TensorFlow フレームワークを使用したモデルのトレーニングとサービス提供のプロセスについて説明します。各フレームワークで、キューに入れられたリソースまたは GKE を使用して TPU をプロビジョニングできます。GKE の設定は、XPK または GKE コマンドを使用して行うことができます。

v6e を使用してモデルをトレーニングまたはサービングする一般的な手順

  1. Google Cloud プロジェクトを準備する
  2. 容量を確保する
  3. Cloud TPU 環境をプロビジョニングする
  4. モデルのトレーニングまたは推論ワークロードを実行する

Google Cloud プロジェクトを準備する

Cloud TPU を使用するには、次の操作が必要です。

  • 課金を有効にした Google Cloud アカウントとプロジェクトを作成する
  • Google Cloud CLI アルファ版コンポーネントをインストールする
  • Cloud TPU API を有効にする
  • Cloud TPU サービス エージェントを作成する
  • Cloud TPU サービス アカウントを作成して権限を付与する

詳細については、Cloud TPU 環境を設定するをご覧ください。

容量を確保する

Cloud TPU v6e の割り当てをリクエストし、容量に関する質問に回答するには、Google Cloud サポートにお問い合わせください。

Cloud TPU 環境をプロビジョニングする

v6e Cloud TPU は、GKE、GKE と XPK(GKE をラップする CLI ツール)、またはキューに入れられたリソースでプロビジョニングして管理できます。

前提条件

  • プロジェクトに十分な TPUS_PER_TPU_FAMILY 割り当てがあることを確認します。これは、プロジェクト内でアクセスできるチップの最大数を指定します。 Google Cloud
  • v6e は、次の構成でテストされています。
    • Python 3.10 以降
    • ナイトリー ソフトウェアのバージョン:
      • ナイトリー JAX 0.4.32.dev20240912
      • ナイトリー LibTPU 0.1.dev20240912+nightly
    • 安定版ソフトウェア バージョン:
      • JAX + v0.4.37 の JAX Lib
  • プロジェクトに次の割り当てが十分にあることを確認します。

    • Cloud TPU VM の割り当て
    • IP アドレスの割り当て
    • Hyperdisk Balanced の割り当て

  • XPK で GKE を使用している場合は、XPK の実行に必要な権限について、ユーザーまたはサービス アカウントに対する Cloud コンソールの権限をご覧ください。

環境変数を作成する

Cloud Shell で、次の環境変数を作成します。

export NODE_ID=your-tpu-name
export PROJECT_ID=your-project-id
export ACCELERATOR_TYPE=v6e-16
export ZONE=us-east1-d
export RUNTIME_VERSION=v2-alpha-tpuv6e
export SERVICE_ACCOUNT=your-service-account
export QUEUED_RESOURCE_ID=your-queued-resource-id
export VALID_DURATION=your-duration 

# Additional environment variable needed for Multislice:
export NUM_SLICES=number-of-slices

# Use a custom network for better performance as well as to avoid having the default network becoming overloaded.

export NETWORK_NAME=${PROJECT_ID}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw

コマンドフラグの説明

変数 説明
NODE_ID キューに入れられたリソース リクエストの割り当て時に作成される Cloud TPU のユーザー割り当て ID。
PROJECT_ID Google Cloud プロジェクト名。既存のプロジェクトを使用するか、新しいプロジェクトを作成します。詳細については、プロジェクトを設定する Google Cloud をご覧ください。
ゾーン サポートされているゾーンについては、Cloud TPU のリージョンとゾーンのドキュメントをご覧ください。
ACCELERATOR_TYPE アクセラレータ タイプをご覧ください。
RUNTIME_VERSION v2-alpha-tpuv6e
SERVICE_ACCOUNT これは、サービス アカウントのメールアドレスです。 Google Cloud コンソール -> IAM -> サービス アカウントで確認できます。

例: tpu-service-account@your-project-ID.iam.gserviceaccount.com.com

NUM_SLICES 作成するスライスの数(マルチスライスの場合のみ必要)。
QUEUED_RESOURCE_ID キューに格納されたリソース リクエストのユーザー割り当てテキスト ID。
VALID_DURATION キューに入れられたリソース リクエストが有効である期間。
NETWORK_NAME 使用するセカンダリ ネットワークの名前。
NETWORK_FW_NAME 使用するセカンダリ ネットワーク ファイアウォールの名前。

ネットワーク パフォーマンスを最適化する

最適なパフォーマンスを得るには、8,896 MTU(最大伝送単位)のネットワークを使用します。

デフォルトでは、Virtual Private Cloud(VPC)は 1,460 バイトの MTU のみを提供します。これにより、ネットワーク パフォーマンスが最適化されません。VPC ネットワークの MTU は、1,300 ~ 8,896 バイトの任意の値に設定できます。一般的なカスタム MTU サイズは 1,500 バイト(標準イーサネット)または 8,896 バイト(可能な最大値)です。詳細については、有効な VPC ネットワークの MTU サイズをご覧ください。

既存またはデフォルトのネットワークの MTU 設定の変更の詳細については、VPC ネットワークの MTU 設定を変更するをご覧ください。

次の例では、8,896 MTU のネットワークを作成します。

export RESOURCE_NAME=your-resource-name
export NETWORK_NAME=${RESOURCE_NAME}-privatenetwork
export NETWORK_FW_NAME=${RESOURCE_NAME}-privatefirewall
gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT_ID} \
 --subnet-mode=auto --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network=${NETWORK_NAME} \
 --allow tcp,icmp,udp --project=${PROJECT_ID}

マルチ NIC の使用(マルチスライス オプション)

マルチスライス環境を使用している場合、セカンダリ サブネットには次の環境変数が必要です。

export NETWORK_NAME_2=${RESOURCE_NAME}
export SUBNET_NAME_2=${RESOURCE_NAME}
export FIREWALL_RULE_NAME=${RESOURCE_NAME}
export ROUTER_NAME=${RESOURCE_NAME}-network-2
export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2
export REGION=your-region

次のコマンドを使用して、ネットワークとサブネットのカスタム IP ルーティングを作成します。

gcloud compute networks create ${NETWORK_NAME_2} --mtu=8896 \
   --bgp-routing-mode=regional --subnet-mode=custom --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_2} \
   --network=${NETWORK_NAME_2} \
   --range=10.10.0.0/18 --region=${REGION} \
   --project=${PROJECT_ID}

gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_2} --allow tcp,icmp,udp \
   --source-ranges 10.10.0.0/18 --project=${PROJECT_ID}

gcloud compute routers create ${ROUTER_NAME} \
  --project=${PROJECT_ID} \
  --network=${NETWORK_NAME_2} \
  --region=${REGION}

gcloud compute routers nats create ${NAT_CONFIG} \
  --router=${ROUTER_NAME} \
  --region=${REGION} \
  --auto-allocate-nat-external-ips \
  --nat-all-subnet-ip-ranges \
  --project=${PROJECT_ID} \
  --enable-logging

マルチネットワーク スライスを作成したら、XPK クラスタを設定し、XPK ワークロード作成コマンド--command ifconfig フラグを追加して、両方のネットワーク インターフェース カード(NIC)が使用されていることを確認できます。

次の xpk workload コマンドを使用して、Google Cloud コンソール ログに ifconfig コマンドの出力を表示し、eth0 と eth1 の両方に mtu=8896 があることを確認します。

python3 xpk.py workload create \
   --cluster CLUSTER_NAME \
   {--base-docker-image maxtext_base_image | --docker-image your-cloud-image-name} \
   --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
   --tpu-type=${ACCELERATOR_TYPE} \
   --num-slices=${NUM_SLICES}  \
   --on-demand \
   --zone=${ZONE} \
   --project=${PROJECT_ID} \
   --command "ifconfig"

デバッグログを有効にするVertex AI TensorBoard を使用する場合は、次のオプション引数をコマンドに追加します。

    --enable-debug-logs \
    --use-vertex-tensorboard

eth0 と eth1 の両方に mtu=8,896 があることを確認します。マルチ NIC が実行されていることを確認するには、XPK ワークロード作成コマンドに --command ifconfig フラグを追加します。Google Cloud コンソールのログでその xpk ワークロードの出力を確認し、eth0 と eth1 の両方に mtu=8896 があることを確認します。

TCP 設定を改善

キューに格納されたリソース インターフェースを使用して Cloud TPU を作成した場合は、次のコマンドを実行して TCP 受信バッファの上限を増やし、ネットワーク パフォーマンスを改善できます。

gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \
  --project "${PROJECT}" \
  --zone "${ZONE}" \
  --node=all \
  --command='sudo sh -c "echo \"4096 41943040 314572800\" > /proc/sys/net/ipv4/tcp_rmem"' \
  --worker=all

キューに入れられたリソースでプロビジョニングする

キューに入れられたリソースを使用して Cloud TPU v6e を作成できます。キューに入れられたリソースを使用すると、容量が利用可能になり次第、容量を受け取ることができます。リクエストが入力される開始時間と終了時間を指定できます(省略可)。詳細については、キューに入れられたリソースを管理するをご覧ください。

GKE または XPK で v6e Cloud TPU をプロビジョニングする

v6e で GKE コマンドを使用している場合は、Kubernetes コマンドまたは XPK を使用して Cloud TPU をプロビジョニングし、モデルをトレーニングまたは提供できます。GKE クラスタで Cloud TPU 構成を計画する方法については、GKE で Cloud TPU を計画するをご覧ください。以降のセクションでは、単一 NIC とマルチ NIC をサポートする XPK クラスタを作成するコマンドについて説明します。

単一の NIC をサポートする XPK クラスタを作成する

export CLUSTER_NAME=xpk-cluster-name
export ZONE=us-east1-d
export PROJECT_ID=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME=${CLUSTER_NAME}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw
gcloud compute networks create ${NETWORK_NAME} \
   --mtu=8896 \
   --project=${PROJECT_ID} \
   --subnet-mode=auto \
   --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} \
   --network=${NETWORK_NAME} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}"
python3 xpk.py cluster create --cluster=${CLUSTER_NAME} \
   --cluster-cpu-machine-type=n1-standard-8 \
   --num-slices=${NUM_SLICES} \
   --tpu-type=${TPU_TYPE} \
   --zone=${ZONE}  \
   --project=${PROJECT_ID} \
   --on-demand \
   --custom-cluster-arguments="${CLUSTER_ARGUMENTS}"  \
   --create-vertex-tensorboard

コマンドフラグの説明

変数 説明
CLUSTER_NAME XPK クラスタにユーザーが割り当てた名前。
PROJECT_ID Google Cloud プロジェクト名。既存のプロジェクトを使用するか、新しいプロジェクトを作成します。詳細については、プロジェクトを設定する Google Cloud をご覧ください。
ゾーン サポートされているゾーンについては、Cloud TPU のリージョンとゾーンのドキュメントをご覧ください。
TPU_TYPE アクセラレータ タイプをご覧ください。
NUM_SLICES 作成するスライスの数
CLUSTER_ARGUMENTS 使用するネットワークとサブネットワーク。

例: --network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}

NUM_SLICES 作成するスライス数。
NETWORK_NAME 使用するセカンダリ ネットワークの名前。
NETWORK_FW_NAME 使用するセカンダリ ネットワーク ファイアウォールの名前。

マルチ NIC をサポートする XPK クラスタを作成する

export CLUSTER_NAME=xpk-cluster-name
export ZONE=us-east1-d
export PROJECT_ID=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME_1=${CLUSTER_NAME}-mtu9k-1-${ZONE}
export SUBNET_NAME_1=${CLUSTER_NAME}-privatesubnet-1-${ZONE}
export NETWORK_FW_NAME_1=${NETWORK_NAME_1}-fw-1-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-1-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-1-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-1-${ZONE}
gcloud compute networks create ${NETWORK_NAME_1} \
   --mtu=8896 \
   --bgp-routing-mode=regional \
   --subnet-mode=custom \
   --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_1} \
   --network=${NETWORK_NAME_1} \
   --range=10.11.0.0/18 \
   --region=${REGION} \
   --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_1} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \
    --project=${PROJECT_ID} \
    --network=${NETWORK_NAME_1} \
    --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \
    --router=${ROUTER_NAME} \
    --region=${REGION} \
    --auto-allocate-nat-external-ips \
    --nat-all-subnet-ip-ranges \
    --project=${PROJECT_ID} \
    --enable-logging
# Secondary subnet for multi-nic experience.
# Need custom IP routing to be different from the first network's subnet.

export NETWORK_NAME_2=${CLUSTER_NAME}-privatenetwork-2-${ZONE}
export SUBNET_NAME_2=${CLUSTER_NAME}-privatesubnet-2-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-2-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-2-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-2-${ZONE}
gcloud compute networks create ${NETWORK_NAME_2} \
   --mtu=8896 \
   --bgp-routing-mode=regional \
   --subnet-mode=custom \
   --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_2} \
   --network=${NETWORK_NAME_2} \
   --range=10.10.0.0/18 \
   --region=${REGION} \
   --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_2} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \
     --project=${PROJECT_ID} \
     --network=${NETWORK_NAME_2} \
     --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \
     --router=${ROUTER_NAME} \
     --region=${REGION} \
     --auto-allocate-nat-external-ips \
     --nat-all-subnet-ip-ranges \
     --project=${PROJECT_ID} \
     --enable-logging
export CLUSTER_ARGUMENTS="--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking
--network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}"

export NODE_POOL_ARGUMENTS="--additional-node-network
network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}"
python3 ~/xpk/xpk.py cluster create \
    --cluster=${CLUSTER_NAME} \
    --num-slices=${NUM_SLICES} \
    --tpu-type=${TPU_TYPE} \
    --zone=${ZONE}  \
    --project=${PROJECT_ID} \
    --on-demand \
    --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \
    --custom-nodepool-arguments="${NODE_POOL_ARGUMENTS}" \
    --create-vertex-tensorboard

コマンドフラグの説明

変数 説明
CLUSTER_NAME XPK クラスタにユーザーが割り当てた名前。
PROJECT_ID Google Cloud プロジェクト名。既存のプロジェクトを使用するか、新しいプロジェクトを作成します。詳細については、プロジェクトを設定する Google Cloud をご覧ください。
ゾーン サポートされているゾーンについては、Cloud TPU のリージョンとゾーンのドキュメントをご覧ください。
TPU_TYPE アクセラレータ タイプをご覧ください。
NUM_SLICES 作成するスライスの数
CLUSTER_ARGUMENTS 使用するネットワークとサブネットワーク。

例: --enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}

NODE_POOL_ARGUMENTS 使用する追加のノード ネットワーク。

例: --additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}

NUM_SLICES 作成するスライスの数(マルチスライスの場合のみ必要)。
NETWORK_NAME 使用するセカンダリ ネットワークの名前。
NETWORK_FW_NAME 使用するセカンダリ ネットワーク ファイアウォールの名前。

フレームワークの設定

このセクションでは、JAXPyTorchTensorFlow フレームワークを使用した ML モデル トレーニングの一般的な設定プロセスについて説明します。GKE を使用している場合は、フレームワークの設定に XPK または Kubernetes コマンドを使用できます。

JAX を設定する

このセクションでは、XPK の有無にかかわらず GKE で JAX ワークロードを実行する方法と、キューに入れられたリソースを使用する方法について説明します。

GKE を使用して JAX を設定する

単一ホスト上の単一スライス

次の例では、Kubernetes YAML ファイルを使用して 2x2 の単一ホストノードプールを設定します。

apiVersion: v1
kind: Pod
metadata:
  name: tpu-pod-jax-v6e-a
spec:
  restartPolicy: Never
  nodeSelector:
    cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
    cloud.google.com/gke-tpu-topology: 2x2
  containers:
  - name: tpu-job
    image: python:3.10
    securityContext:
      privileged: true
    command:
    - bash
    - -c
    - |
      pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
      JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python3 -c 'import jax; print("Total TPU chips:", jax.device_count())'
    resources:
      requests:
        google.com/tpu: 4
      limits:
        google.com/tpu: 4

正常に完了すると、GKE ログに次のメッセージが表示されます。

Total TPU chips: 4

マルチホスト上の単一スライス

次の例では、Kubernetes YAML ファイルを使用して 4x4 マルチホスト ノードプールを設定します。

apiVersion: v1
kind: Service
metadata:
  name: headless-svc
spec:
  clusterIP: None
  selector:
    job-name: tpu-available-chips
---
apiVersion: batch/v1
kind: Job
metadata:
  name: tpu-available-chips
spec:
  backoffLimit: 0
  completions: 4
  parallelism: 4
  completionMode: Indexed
  template:
    spec:
      subdomain: headless-svc
      restartPolicy: Never
      nodeSelector:
        cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
        cloud.google.com/gke-tpu-topology: 4x4
      containers:
      - name: tpu-job
        image: python:3.10
        ports:
        - containerPort: 8471 # Default port using which TPU VMs communicate
        - containerPort: 8431 # Port to export TPU runtime metrics, if supported.
        securityContext:
          privileged: true
        command:
        - bash
        - -c
        - |
          pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
          JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
        resources:
          requests:
            google.com/tpu: 4
          limits:
            google.com/tpu: 4

正常に完了すると、GKE ログに次のメッセージが表示されます。

Total TPU chips: 16

マルチホストでのマルチスライス

次の例では、Kubernetes YAML ファイルを使用して 2 つの 4x4 マルチホスト ノードプールを設定します。

前提条件として、v0.2.3 以降の JobSet をインストールする必要があります。

apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
  name: multislice-job
  annotations:
    alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
spec:
  failurePolicy:
    maxRestarts: 4
  replicatedJobs:
    - name: slice
      replicas: 2
      template:
        spec:
          parallelism: 4
          completions: 4
          backoffLimit: 0
          template:
            spec:
              hostNetwork: true
              dnsPolicy: ClusterFirstWithHostNet
              nodeSelector:
                cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                cloud.google.com/gke-tpu-topology: 4x4
              hostNetwork: true
              containers:
              - name: jax-tpu
                image: python:3.10
                ports:
                - containerPort: 8471
                - containerPort: 8080
                - containerPort: 8431
                securityContext:
                  privileged: true
                command:
                - bash
                - -c
                - |
                  pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                  JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
                resources:
                  limits:
                   google.com/tpu: 4
                  requests:
                   google.com/tpu: 4

正常に完了すると、GKE ログに次のメッセージが表示されます。

Total TPU chips: 32

詳細については、GKE ドキュメントのマルチスライス ワークロードを実行するをご覧ください。

パフォーマンスを向上させるには、hostNetwork を有効化します。

マルチ NIC

GKE でマルチ NIC を利用する場合は、Kubernetes Pod マニフェストに追加のアノテーションが必要です。次の例は、TPU 以外のマルチ NIC ワークロードのマニフェストです。

apiVersion: v1
kind: Pod
metadata:
  name: sample-netdevice-pod-1
  annotations:
    networking.gke.io/default-interface: 'eth0'
    networking.gke.io/interfaces: |
      [
        {"interfaceName":"eth0","network":"default"},
        {"interfaceName":"eth1","network":"netdevice-network"}
      ]
spec:
  containers:
  - name: sample-netdevice-pod
    image: busybox
    command: ["sleep", "infinity"]
    ports:
    - containerPort: 80
  restartPolicy: Always
  tolerations:
  - key: "google.com/tpu"
    operator: "Exists"
    effect: "NoSchedule"

exec コマンドを使用して Kubernetes Pod に接続すると、次のコードを使用して追加の NIC が表示されます。

$ k exec --stdin --tty sample-netdevice-pod-1 -- /bin/sh
/ # ip a
1: lo: <LOOPBACK,UP,LOWER_UP> mtu 65536 qdisc noqueue qlen 1000
    link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00
    inet 127.0.0.1/8 scope host lo
       valid_lft forever preferred_lft forever
2: eth0@if11: <BROADCAST,MULTICAST,UP,LOWER_UP,M-DOWN> mtu 1460 qdisc noqueue
    link/ether da:be:12:67:d2:25 brd ff:ff:ff:ff:ff:ff
    inet 10.124.2.6/24 brd 10.124.2.255 scope global eth0
       valid_lft forever preferred_lft forever
3: eth1: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1460 qdisc mq qlen 1000
    link/ether 42:01:ac:18:00:04 brd ff:ff:ff:ff:ff:ff
    inet 172.24.0.4/32 scope global eth1
       valid_lft forever preferred_lft forever

XPK で GKE を使用して JAX を設定する

GKE と XPK を使用して JAX を設定するには、xpk README をご覧ください。

MaxText で XPK を設定して実行する方法については、MaxText の実行方法をご覧ください。

キューに入れられたリソースを使用して JAX を設定する

gcloud alpha compute tpus tpu-vm ssh コマンドを使用して、スライス内のすべての Cloud TPU VM に JAX を同時にインストールします。マルチスライスの場合は、--node=all フラグを追加します。

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
 --zone ${ZONE} --worker=all \
 --command='pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

次のコマンドを実行すると、スライスで使用可能な Cloud TPU コアの数を確認して、すべてが正しくインストールされていることをテストできます。

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
   --zone ${ZONE} --worker=all  \
   --command='python3 -c "import jax; print(jax.device_count(), jax.local_device_count())"'

v6e-16 スライスで実行した場合の出力は次のようになります。

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16 4
16 4
16 4
16 4

jax.device_count() は、指定されたスライス内のチップの合計数を示します。jax.local_device_count() は、このスライス内の単一の VM からアクセス可能なチップの数を示します。

gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} --zone=${ZONE} --worker=all  \
   --command='git clone -b mlperf4.1 https://github.com/google/maxdiffusion.git &&
   cd maxdiffusion && git checkout 975fdb7dbddaa9a53ad72a421cdb487dcdc491a3 &&
   && pip install -r requirements.txt  && pip install . '

JAX の設定に関するトラブルシューティング

一般的なヒントとして、GKE ワークロード マニフェストで詳細なロギングを有効にします。次に、ログを GKE サポートに提供します。

TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0

エラー メッセージ

no endpoints available for service 'jobset-webhook-service'

このエラーは、ジョブセットが正しくインストールされていないことを意味します。jobset-controller-manager Deployment Kubernetes Pod が実行されていることを確認します。詳細については、JobSet のトラブルシューティングに関するドキュメントをご覧ください。

TPU initialization failed: Failed to connect

GKE ノード バージョンが 1.30.4-gke.1348000 以降であることを確認します(GKE 1.31 はサポートされていません)。

PyTorch を設定する

このセクションでは、PyTorch/XLA を使用して v6e で PJRT の使用を開始する方法について説明します。Python 3.10 が推奨される Python バージョンです。

XPK で GKE を使用して PyTorch を設定する

PyTorch の依存関係がすでにインストールされている XPK で、次の Docker コンテナを使用できます。

us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028

XPK ワークロードを作成するには、次のコマンドを使用します。

python3 xpk.py workload create \
    --cluster ${CLUSTER_NAME} \
    {--docker-image | --base-docker-image} us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028 \
    --workload ${USER} -xpk-${ACCELERATOR_TYPE} -${NUM_SLICES} \
    --tpu-type=${ACCELERATOR_TYPE} \
    --num-slices=${NUM_SLICES}  \
    --on-demand \
    --zone ${ZONE} \
    --project ${PROJECT_ID} \
    --enable-debug-logs \
    --command 'python3 -c "import torch; import torch_xla; import torch_xla.runtime as xr; print(xr.global_runtime_device_count())"'

--base-docker-image を使用すると、現在の作業ディレクトリが新しい Docker にビルドされた新しい Docker イメージが作成されます。

キューに格納されたリソースを使用して PyTorch を設定する

キューに登録されたリソースを使用して PyTorch をインストールし、v6e で小さなスクリプトを実行する手順は次のとおりです。

SSH を使用して依存関係をインストールし、VM にアクセスする

次のコマンドを使用して、すべての Cloud TPU VM に依存関係をインストールします。マルチスライスの場合は、--worker=all フラグを追加します。

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='sudo apt install -y libopenblas-base pip3 && \
               pip install torch~=2.6.0 "torch_xla[tpu]~=2.6.0" -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'

サイズが大きく、頻繁な割り当てがあるモデルのパフォーマンスを改善する

サイズが大きく、頻繁な割り当てがあるモデルの場合、tcmalloc 関数を使用すると、デフォルトの malloc 関数の実装と比較してパフォーマンスが大幅に向上するため、Cloud TPU VM で使用されるデフォルトの malloc 関数は tcmalloc です。ただし、ワークロードによっては(たとえば、埋め込みテーブルへの割り当てが非常に大きい DLRM の場合)、tcmalloc 関数で速度が低下する可能性があります。その場合は、次の変数の設定を解除して、代わりにデフォルトの malloc 関数を使用してください。

unset LD_PRELOAD

Python スクリプトを使用して v6e VM で計算を行う

次のコマンドを使用して、2 つのテンソルを作成し、それらを加算して結果を出力するスクリプトを実行します。

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
   --project ${PROJECT_ID} \
   --zone ${ZONE} --worker all --command='
   unset LD_PRELOAD
   python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"
'

これにより、次のような出力が生成されます。

SSH: Attempting to connect to worker 0...
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
xla:0
tensor([[ 0.3355, -1.4628, -3.2610],
        [-1.4656,  0.3196, -2.8766],
        [ 0.8668, -1.5060,  0.7125]], device='xla:0')

TensorFlow の設定

次のコマンドを実行すると、v6e 互換の TensorFlow バージョンで Cloud TPU ランタイムをリセットできます。

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
    --zone  ${ZONE} --worker=all --command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} \
    --zone ${ZONE} --worker=all --command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'

SSH を使用して worker-0 にアクセスします。

$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
     --zone ${ZONE}

worker-0 に TensorFlow をインストールします。

sudo apt install -y libopenblas-base
pip install cloud-tpu-client
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310
pip install cloud-tpu-client

pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

TPU_NAME 環境変数をエクスポートします。

export TPU_NAME=v6e-16

次の Python スクリプトを実行して、スライスで使用可能な Cloud TPU コアの数を確認し、すべてが正しくインストールされていることをテストできます。

import TensorFlow as tf
print("TensorFlow version " + tf.__version__)

@tf.function
  def add_fn(x,y):
  z = x + y
  return z

  cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
  tf.config.experimental_connect_to_cluster(cluster_resolver)
  tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
  strategy = tf.distribute.TPUStrategy(cluster_resolver)

  x = tf.constant(1.)
  y = tf.constant(1.)
  z = strategy.run(add_fn, args=(x,y))
  print(z)

v6e-16 スライスで実行した場合の出力は次のようになります。

PerReplica:{
  0: tf.Tensor(2.0, shape=(), dtype=float32),
  1: tf.Tensor(2.0, shape=(), dtype=float32),
  2: tf.Tensor(2.0, shape=(), dtype=float32),
  3: tf.Tensor(2.0, shape=(), dtype=float32),
  4: tf.Tensor(2.0, shape=(), dtype=float32),
  5: tf.Tensor(2.0, shape=(), dtype=float32),
  6: tf.Tensor(2.0, shape=(), dtype=float32),
  7: tf.Tensor(2.0, shape=(), dtype=float32)
}

v6e with SkyPilot

Cloud TPU v6e は SkyPilot で使用できます。v6e に関連する場所と料金情報を SkyPilot に追加する手順は次のとおりです。

  1. ~/.sky/catalogs/v5/gcp/vms.csv ファイルの末尾に以下を追加します。

    ,,,tpu-v6e-1,1,tpu-v6e-1,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-1,1,tpu-v6e-1,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-1,1,tpu-v6e-1,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,us-east5,us-east5-b,0,0
    
  2. YAML ファイルで次のリソースを指定します。

    # tpu_v6.yaml
    resources:
      accelerators: tpu-v6e-16 # Fill in the accelerator type you want to use
      accelerator_args:
        runtime_version: v2-alpha-tpuv6e # Official suggested runtime
    
  3. Cloud TPU v6e を使用してクラスタを起動します。

       sky launch tpu_v6.yaml -c tpu_v6
    
  4. SSH を使用して Cloud TPU v6e に接続します。ssh tpu_v6

推論チュートリアル

次のチュートリアルでは、Cloud TPU v6e で推論を実行する方法について説明します。

トレーニング サンプル

以降のセクションでは、Cloud TPU v6e で MaxText、MaxDiffusion、PyTorch モデルをトレーニングする例について説明します。

v6e Cloud TPU VM での MaxText と MaxDiffusion のトレーニング

以降のセクションでは、MaxText モデルと MaxDiffusion モデルのトレーニング ライフサイクルについて説明します。

一般的な手順は次のとおりです。

  1. ワークロードのベースイメージをビルドします。
  2. XPK を使用してワークロードを実行します。
    1. ワークロードのトレーニング コマンドをビルドします。
    2. ワークロードをデプロイします。
  3. ワークロードを追跡して指標を表示する。
  4. XPK ワークロードが不要な場合は削除します。
  5. XPK クラスタが不要になったら削除します。

ベースイメージをビルドする

MaxText または MaxDiffusion をインストールして Docker イメージをビルドします。

  1. 使用するリポジトリのクローンを作成し、リポジトリのディレクトリに移動します。

    MaxText:

    git clone https://github.com/google/maxtext.git && cd maxtext
    

    MaxDiffusion:

    git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion
    
  2. Google Cloud CLI を使用するように Docker を構成します。

    gcloud auth configure-docker
    
  3. 次のコマンドまたは JAX Stable Stack を使用して Docker イメージをビルドします。JAX Stable Stack の詳細については、JAX Stable Stack を使用して Docker イメージを作成するをご覧ください。

    bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.35
    
  4. ローカルにビルドされたイメージがないマシンからワークロードを起動する場合は、イメージをアップロードします。

    bash docker_upload_runner.sh CLOUD_IMAGE_NAME=${USER}_runner
    

XPK を使用してワークロードを実行する

  1. MaxText によって設定されたデフォルト値または MaxDiffusion を使用していない場合は、次の環境変数を設定します。

    export BASE_OUTPUT_DIR=gs://YOUR_BUCKET
    export PER_DEVICE_BATCH_SIZE=2
    export NUM_STEPS=30
    export MAX_TARGET_LENGTH=8192
  2. モデル スクリプトを作成します。このスクリプトは、後でトレーニング コマンドとしてコピーされます。

    モデル スクリプトはまだ実行しないでください。

    MaxText

    MaxText は、ピュア Python と JAX で記述された、高性能でスケーラビリティに優れたオープンソースの LLM です。トレーニングと推論に TPU と GPU をターゲットとしています。 Google Cloud

    JAX_PLATFORMS=tpu,cpu \
    ENABLE_PJRT_COMPATIBILITY=true \
    TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \
    TPU_SLICE_BUILDER_DUMP_ICI=true && \
    python /deps/MaxText/train.py /deps/MaxText/configs/base.yml \
            base_output_directory=${BASE_OUTPUT_DIR} \
            dataset_type=synthetic \
            per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
            enable_checkpointing=false \
            gcs_metrics=true \
            profiler=xplane \
            skip_first_n_steps_for_profiler=5 \
            steps=${NUM_STEPS}  # attention='dot_product'"
    

    Gemma2

    Gemma は、Gemini の研究とテクノロジーに基づいて Google DeepMind が開発したオープン重み LLM のファミリーです。

    python3 MaxText/train.py MaxText/configs/base.yml \
        model_name=gemma2-27b \
        run_name=gemma2-27b-run \
        base_output_directory=${BASE_OUTPUT_DIR} \
        max_target_length=${MAX_TARGET_LENGTH} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        steps=${NUM_STEPS} \
        enable_checkpointing=false \
        use_iota_embed=true \
        gcs_metrics=true \
        dataset_type=synthetic \
        profiler=xplane \
        attention=flash
    

    Mixtral 8x7b

    Mixtral は、Mistral AI によって開発された最先端の AI モデルであり、スパースな Mixture of Experts(MoE)アーキテクチャを利用しています。

    python3 MaxText/train.py MaxText/configs/base.yml \
        base_output_directory=${BASE_OUTPUT_DIR} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        model_name=mixtral-8x7b \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        tokenizer_path=assets/tokenizer.mistral-v1 \
        attention=flash \
        dtype=bfloat16 \
        dataset_type=synthetic \
        profiler=xplane
    

    Llama3-8b

    Llama は、Meta が開発したオープン ウェイト LLM のファミリーです。

    python3 MaxText/train.py MaxText/configs/base.yml \
        model_name=llama3-8b \
        base_output_directory=${BASE_OUTPUT_DIR} \
        dataset_type=synthetic \
        tokenizer_path=assets/tokenizer_llama3.tiktoken \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} # set to 4 \
        gcs_metrics=true \
        profiler=xplane \
        skip_first_n_steps_for_profiler=5 \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        attention=flash
    

    MaxDiffusion

    MaxDiffusion は、Cloud TPU や GPU などの XLA デバイスで実行される、純粋な Python と JAX で記述されたさまざまな潜在拡散モデルのリファレンス実装のコレクションです。Stable Diffusion は、テキスト入力からフォトリアリスティックな画像を生成する、潜在的 text-to-image モデルです。

    MaxDiffusion を実行するには、次の git checkout コマンドに示すように、特定の Git ブランチをインストールする必要があります。

    git clone https://github.com/google/maxdiffusion.git
    && cd maxdiffusion
    && git checkout e712c9fc4cca764b0930067b6e33daae2433abf0
    && pip install -r requirements.txt
    && pip install .
    

    トレーニング スクリプト:

        cd maxdiffusion && OUT_DIR=${BASE_OUTPUT_DIR} \
        python src/maxdiffusion/train_sdxl.py \
        src/maxdiffusion/configs/base_xl.yml \
        revision=refs/pr/95 \
        activations_dtype=bfloat16 \
        weights_dtype=bfloat16 \
        resolution=1024 \
        per_device_batch_size=1 \
        output_dir=${OUT_DIR}  \
        jax_cache_dir=${OUT_DIR}/cache_dir/ \
        max_train_steps=200 \
        attention=flash run_name=sdxl-ddp-v6e
        
  3. 次の変数をエクスポートします。

    export CLUSTER_NAME=CLUSTER_NAME
    export ACCELERATOR_TYPE=ACCELERATOR_TYPE
    export NUM_SLICES=NUM_SLICES
    export YOUR_MODEL_SCRIPT=YOUR_MODEL_SCRIPT

    環境変数の説明

    変数 説明
    CLUSTER_NAME XPK クラスタの名前。
    ACCELERATOR_TYPE アクセラレータのタイプをご覧ください。
    NUM_SLICES TPU スライスの数。
    YOUR_MODEL_SCRIPT トレーニング コマンドとして実行するモデル スクリプト。
  4. 前の手順で作成したスクリプトを使用してモデルを実行します。MaxText ベースイメージを使用するには、--base-docker-image フラグを指定するか、--docker-image フラグと使用するイメージを指定する必要があります。

    省略可: --enable-debug-logs フラグを含めると、デバッグ ロギングを有効にできます。詳細については、MaxText で JAX をデバッグするをご覧ください。

    省略可: --use-vertex-tensorboard フラグを指定して Vertex AI Experiments を作成し、Vertex AI TensorBoard にデータをアップロードできます。詳細については、Vertex AI を使用して MaxText で JAX をモニタリングするをご覧ください。

    python3 xpk.py workload create \
        --cluster ${CLUSTER_NAME} \
        {--base-docker-image maxtext_base_image|--docker-image ${CLOUD_IMAGE_NAME}} \
        --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
        --tpu-type=${ACCELERATOR_TYPE} \
        --num-slices=${NUM_SLICES}  \
        --on-demand \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        [--enable-debug-logs] \
        [--use-vertex-tensorboard] \
        --command=${YOUR_MODEL_SCRIPT}

    出力には、ワークロードを追跡するためのリンクが含まれます。リンクを開き、[ログ] タブをクリックして、ワークロードをリアルタイムで追跡します。

MaxText で JAX をデバッグする

補足 XPK コマンドを使用して、クラスタまたはワークロードが実行されていない理由を診断します。

Vertex AI を使用して MaxText の JAX をモニタリングする

Vertex AI のマネージド TensorBoard を使用して、スカラーデータとプロファイル データを表示します。

  1. 使用しているゾーンのリソース管理(CRUD)リクエストを 600 から 5,000 に増やします。16 個未満の VM を使用する小規模なワークロードでは、問題にならない場合があります。
  2. Vertex AI の cloud-accelerator-diagnostics などの依存関係をインストールします。

    # xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI
    cd ~/xpk
    pip install .
  3. Vertex AI TensorBoard を作成するで説明されているように、--create-vertex-tensorboard フラグを使用して XPK クラスタを作成します。このコマンドは既存のクラスタでも実行できます。

  4. --use-vertex-tensorboard フラグとオプションの --experiment-name フラグを使用して XPK ワークロードを実行するときに、Vertex AI テストを作成します。手順の一覧については、Vertex AI Experiments を作成して Vertex AI TensorBoard にデータをアップロードするをご覧ください。

ログには、次のような Vertex AI TensorBoard へのリンクが含まれています。

View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name

Vertex AI TensorBoard のリンクは Google Cloud コンソールで確認することもできます。Google Cloud コンソールで Vertex AI Experiments に移動します。プルダウンから適切なリージョンを選択します。

TensorBoard ディレクトリも、${BASE_OUTPUT_DIR} で指定した Cloud Storage バケットに書き込まれます。

XPK ワークロードを削除する

ジョブの接頭辞またはジョブのステータスに基づいて 1 つ以上のワークロードを削除するには、xpk workload delete コマンドを使用します。このコマンドは、実行する必要がない XPK ワークロードを送信した場合や、キューにスタックしているジョブがある場合に便利です。

XPK クラスタを削除する

クラスタを削除するには、xpk cluster delete コマンドを使用します。

python3 xpk.py cluster delete --cluster ${CLUSTER_NAME} \
--zone=${ZONE} --project=${PROJECT_ID}

v6e Cloud TPU VM での Llama と PyTorch/XLA トレーニング

このチュートリアルでは、WikiText データセットを使用して、Cloud TPU v6e で PyTorch/XLA を使用して Llama モデルをトレーニングする方法について説明します。

Hugging Face と Llama 3 モデルにアクセスする

このチュートリアルを実行するには、Hugging Face ユーザー アクセス トークンが必要です。ユーザー アクセス トークンの作成と使用については、ユーザー アクセス トークンの Hugging Face ドキュメントをご覧ください。

また、Hugging Face の Llama 3 8B モデルにアクセスする権限も必要です。アクセス権を取得するには、HuggingFace の Meta-Llama-3-8B モデルにアクセスしてアクセスをリクエストします。

Cloud TPU VM を作成する

チュートリアルを実行するために、8 個のチップを含む Cloud TPU v6e を作成します。

  1. 環境変数を設定します。

    export ACCELERATOR_TYPE=v6e-8
    export VERSION=v2-alpha-tpuv6e
    export TPU_NAME=$USER-$ACCELERATOR_TYPE
    export PROJECT_ID=your-project-id
    export ZONE=us-east1-d
  2. Cloud TPU VM を作成します。

    gcloud alpha compute tpus tpu-vm create ${TPU_NAME} --version=${VERSION} \
        --accelerator-type=${ACCELERATOR_TYPE} \
        --zone=${ZONE} \
        --project=${PROJECT_ID}

インストール

Hugging Face Transformer と依存関係の pytorch-tpu/transformers フォークをインストールします。このチュートリアルは、この例で使用されている次の依存関係のバージョンでテストされています。

  • torch: 2.5.0 と互換性あり
  • torch_xla[tpu]: 2.5.0 と互換性あり
  • jax: 0.4.33
  • jaxlib: 0.4.33
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} --zone ${ZONE} \
    --worker=all --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
    cd transformers
    sudo pip3 install -e .
    pip3 install datasets
    pip3 install evaluate
    pip3 install scikit-learn
    pip3 install accelerate
    pip install torch~=2.6.0 torch_xla[tpu]~=2.6.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
    pip install jax==0.4.38 jaxlib==0.4.38 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'

モデル構成を設定する

次のセクションのトレーニング コマンド(モデルを実行する)では、2 つの JSON 構成ファイルを使用して、モデル パラメータと FSDP(完全にシャーディングされたデータ パラレル)構成を定義します。FSDP シャーディングは、トレーニング中にモデルの重みが大きなバッチサイズに適合するように使用されます。小規模なモデルでトレーニングする場合は、データ並列処理を使用して各デバイスに重みを複製するだけで十分な場合があります。PyTorch/XLA でデバイス間でテンソルをシャーディングする方法の詳細については、PyTorch/XLA SPMD ユーザーガイドをご覧ください。

  1. モデル パラメータ構成ファイルを作成します。Llama3-8B のモデル パラメータ構成は次のとおりです。他のモデルについては、Hugging Face で構成を確認してください。たとえば、Llama2-7B 構成をご覧ください。

    cat > llama-config.json << EOF
    {
        "architectures": [
            "LlamaForCausalLM"
        ],
        "attention_bias": false,
        "attention_dropout": 0.0,
        "bos_token_id": 128000,
        "eos_token_id": 128001,
        "hidden_act": "silu",
        "hidden_size": 4096,
        "initializer_range": 0.02,
        "intermediate_size": 14336,
        "max_position_embeddings": 8192,
        "model_type": "llama",
        "num_attention_heads": 32,
        "num_hidden_layers": 32,
        "num_key_value_heads": 8,
        "pretraining_tp": 1,
        "rms_norm_eps": 1e-05,
        "rope_scaling": null,
        "rope_theta": 500000.0,
        "tie_word_embeddings": false,
        "torch_dtype": "bfloat16",
        "transformers_version": "4.40.0.dev0",
        "use_cache": false,
        "vocab_size": 128256
    }
    EOF
    
  2. FSDP 構成ファイルを作成します。

    cat > fsdp-config.json << EOF
    {
        "fsdp_transformer_layer_cls_to_wrap": [
            "LlamaDecoderLayer"
        ],
        "xla": true,
        "xla_fsdp_v2": true,
        "xla_fsdp_grad_ckpt": true
    }
    EOF
    

    FSDP の詳細については、FSDPv2 をご覧ください。

  3. 次のコマンドを使用して、構成ファイルを Cloud TPU VM にアップロードします。

    gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json ${TPU_NAME}:. \
        --worker=all \
        --project=${PROJECT_ID} \
        --zone=${ZONE}

モデルを実行する

前のセクションで作成した構成ファイルを使用して run_clm.py スクリプトを実行し、WikiText データセットで Llama 3 8B モデルをトレーニングします。トレーニング スクリプトを Cloud TPU v6e-8 で実行すると、約 10 分かかります。

  1. 次のコマンドを使用して、Cloud TPU で Hugging Face にログインします。

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT} \
        --zone ${ZONE} \
        --worker=all \
        --command='
        pip3 install "huggingface_hub[cli]"
        huggingface-cli login --token HUGGING_FACE_TOKEN'
  2. モデル トレーニングを実行します。

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT} \
        --zone ${ZONE} \
        --worker=all \
        --command='
        export PJRT_DEVICE=TPU
        export XLA_USE_SPMD=1
        export ENABLE_PJRT_COMPATIBILITY=true
            # Optional variables for debugging:
        export XLA_IR_DEBUG=1
        export XLA_HLO_DEBUG=1
        export PROFILE_EPOCH=0
        export PROFILE_STEP=3
        export PROFILE_DURATION_MS=100000
            # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path
        export PROFILE_LOGDIR=PROFILE_PATH
        python3 transformers/examples/pytorch/language-modeling/run_clm.py \
        --dataset_name wikitext \
        --dataset_config_name wikitext-2-raw-v1 \
        --per_device_train_batch_size 16 \
        --do_train \
        --output_dir /home/$USER/tmp/test-clm \
        --overwrite_output_dir \
        --config_name /home/$USER/llama-config.json \
        --cache_dir /home/$USER/cache \
        --tokenizer_name meta-llama/Meta-Llama-3-8B \
        --block_size 8192 \
        --optim adafactor \
        --save_strategy no \
        --logging_strategy no \
        --fsdp "full_shard" \
        --fsdp_config /home/$USER/fsdp-config.json \
        --torch_dtype bfloat16 \
        --dataloader_drop_last yes \
        --flash_attention \
        --max_steps 20'

PyTorch/XLA のトラブルシューティング

前のセクションでデバッグ用のオプション変数を設定した場合、モデルのプロファイルは変数 PROFILE_LOGDIR で指定された場所に保存されます。この場所に保存されている xplane.pb ファイルを抽出し、tensorboard を使用して、TensorBoard の手順に沿ってブラウザでプロファイルを表示できます。PyTorch/XLA が期待どおりに動作しない場合は、トラブルシューティング ガイドをご覧ください。モデルのデバッグ、プロファイリング、最適化に関する推奨事項が記載されています。

v6e での DLRM DCN v2 トレーニング

このチュートリアルでは、Cloud TPU v6e で DLRM DCN v2 モデルをトレーニングする方法について説明します。64、128、256 個のチップで TPU v6e をプロビジョニングする必要があります。

マルチホスト TPU で実行している場合は、次のコマンドを実行して、適切な TensorFlow バージョンで tpu-runtime をリセットします。単一ホストの TPU で実行している場合は、次の 2 つのコマンドを実行する必要はありません。

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  \
    --project ${PROJECT_ID} \
    --zone ${ZONE} \
    --worker=all \
    --command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  \
    --project ${PROJECT_ID} \
    --zone  ${ZONE} \
    --worker=all \
    --command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'

SSH を使用して worker-0 に接続する

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone ${ZONE} --project ${PROJECT_ID}

Cloud TPU 名を設定する

export TPU_NAME=your-tpu-name

DLRM v2 を実行する

次のコードスニペットを script.sh という名前のファイルにコピーします。

pip install --user setuptools==65.5.0

pip install cloud-tpu-client

pip install gin-config && pip install tensorflow-datasets && pip install tf-keras-nightly --no-deps

pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl -f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

git clone https://github.com/tensorflow/recommenders.git
git clone https://github.com/tensorflow/models.git

export PYTHONPATH=~/recommenders/:~/models/
export TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true --tf_xla_sparse_core_disable_table_stacking=true --tf_mlir_enable_convert_control_to_data_outputs_pass=true --tf_mlir_enable_merge_control_flow_pass=true'

TF_USE_LEGACY_KERAS=1 TPU_LOAD_LIBRARY=0 python3 ./models/official/recommendation/ranking/train.py  --mode=train     --model_dir=gs://ptxla-debug/tf/sc/dlrm/runs/2/ --params_override="
runtime:
  distribution_strategy: tpu
  mixed_precision_dtype: 'mixed_bfloat16'
task:
  use_synthetic_data: false
  use_tf_record_reader: true
  train_data:
    input_path: 'gs://trillium-datasets/criteo/train/day_*/*'
    global_batch_size: 16384
    use_cached_data: true
  validation_data:
    input_path: 'gs://trillium-datasets/criteo/eval/day_*/*'
    global_batch_size: 16384
    use_cached_data: true
  model:
    num_dense_features: 13
    bottom_mlp: [512, 256, 128]
    embedding_dim: 128
    interaction: 'multi_layer_dcn'
    dcn_num_layers: 3
    dcn_low_rank_dim: 512
    size_threshold: 8000
    top_mlp: [1024, 1024, 512, 256, 1]
    use_multi_hot: true
    concat_dense: false
    dcn_use_bias: true
    vocab_sizes: [40000000,39060,17295,7424,20265,3,7122,1543,63,40000000,3067956,405282,10,2209,11938,155,4,976,14,40000000,40000000,40000000,590152,12973,108,36]
    multi_hot_sizes: [3,2,1,2,6,1,1,1,1,7,3,8,1,6,9,5,1,1,1,12,100,27,10,3,1,1]
    max_ids_per_chip_per_sample: 128
    max_ids_per_table: [280, 128, 64, 272, 432, 624, 64, 104, 368, 352, 288, 328, 304, 576, 336, 368, 312, 392, 408, 552, 2880, 1248, 720, 112, 320, 256]
    max_unique_ids_per_table: [104, 56, 40, 32, 72, 32, 40, 32, 32, 144, 64, 192, 32, 40, 136, 32, 32, 32, 32, 240, 1352, 432, 120, 80, 32, 32]
    use_partial_tpu_embedding: false
    size_threshold: 0
    initialize_tables_on_host: true
trainer:
  train_steps: 10000
  validation_interval: 1000
  validation_steps: 660
  summary_interval: 1000
  steps_per_loop: 1000
  checkpoint_interval: 0
  optimizer_config:
    embedding_optimizer: 'Adagrad'
    dense_optimizer: 'Adagrad'
    lr_config:
      decay_exp: 2
      decay_start_steps: 70000
      decay_steps: 30000
      learning_rate: 0.025
      warmup_steps: 0
    dense_sgd_config:
      decay_exp: 2
      decay_start_steps: 70000
      decay_steps: 30000
      learning_rate: 0.00025
      warmup_steps: 8000
  train_tf_function: true
  train_tf_while_loop: true
  eval_tf_while_loop: true
  use_orbit: true
  pipeline_sparse_and_dense_execution: true"

GKE で TensorFlow を実行している場合は、次のコマンドを使用して TensorFlow Cloud TPU ホイールおよび libtpu をインストールします。

pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

推奨ワークロード(DLRM DCN など)の実行に必要な次のフラグを設定します。

ENV TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true \
--tf_mlir_enable_tpu_variable_runtime_reformatting_pass=false \
--tf_mlir_enable_convert_control_to_data_outputs_pass=true \
--tf_mlir_enable_merge_control_flow_pass=true --tf_xla_disable_full_embedding_pipelining=true' \
ENV LIBTPU_INIT_ARGS="--xla_sc_splitting_along_feature_dimension=auto \
--copy_with_dynamic_shape_op_output_pjrt_buffer=true"

script.sh を実行します。

chmod +x script.sh
./script.sh

ベンチマーク結果

次のセクションでは、v6e での DLRM DCN v2 と MaxDiffusion のベンチマーク結果について説明します。

DLRM DCN v2

DLRM DCN v2 トレーニング スクリプトは、さまざまなスケールで実行されました。スループットは次の表を参照してください。

v6e-64 v6e-128 v6e-256
トレーニングのステップ 7000 7000 7000
グローバル バッチサイズ 131072 262144 524288
スループット(例/秒) 2975334 5111808 10066329

MaxDiffusion

MaxDiffusion のトレーニング スクリプトは、v6e-4、v6e-16、2 つの v6e-16 で実行しました。スループットは次の表を参照してください。

v6e-4 v6e-16 2 個の v6e-16
トレーニングのステップ 0.069 0.073 0.13
グローバル バッチサイズ 8 32 64
スループット(例/秒) 115.9 438.4 492.3