Trillium(v6e)の概要

v6e は、このドキュメント、TPU API、ログで Trillium を指すために使用され、Google の第 6 世代の TPU を表します。

Pod あたり 256 個のチップを備えた v6e アーキテクチャは、v5e と多くの類似点があります。このシステムは、トランスフォーマー、テキスト画像変換、畳み込みニューラル ネットワーク(CNN)のトレーニング、ファインチューニング、サービングに最適化されています。

v6e システム アーキテクチャと構成の詳細については、TPU v6e をご覧ください。

この概要ドキュメントでは、JAX または PyTorch フレームワークを使用したモデルのトレーニングとサービングのプロセスに焦点を当てます。各フレームワークで、キューに格納されたリソースまたは GKE を使用して TPU をプロビジョニングできます。GKE の設定は、XPK または GKE コマンドを使用して行うことができます。

v6e を使用したモデルのトレーニングまたはサービングの一般的な手順

  1. Google Cloud プロジェクトを準備する
  2. 容量を確保する
  3. Cloud TPU 環境をプロビジョニングする
  4. モデルのトレーニングまたは推論ワークロードを実行する

Google Cloud プロジェクトを準備する

Cloud TPU を使用するには、次の操作が必要です。

  • 課金を有効にした Google Cloud アカウントおよびプロジェクトを作成する
  • Google Cloud CLI アルファ版コンポーネントをインストールする
  • Cloud TPU API を有効にする
  • Cloud TPU サービス エージェントを作成する
  • Cloud TPU サービス アカウントを作成して権限を付与する

詳細については、Cloud TPU 環境を設定するをご覧ください。

容量を確保する

Cloud TPU v6e の割り当てをリクエストし、容量に関する疑問を解決するには、Google Cloud サポートにお問い合わせください。

Cloud TPU 環境をプロビジョニングする

v6e Cloud TPU は、GKE を使用するか、GKE と XPK(GKE のラッパー CLI ツール)を使用するか、キューに格納されたリソースとしてプロビジョニングおよび管理できます。

前提条件

  • プロジェクトに十分な TPUS_PER_TPU_FAMILY 割り当てがあることを確認します。これは、 Google Cloudプロジェクト内でアクセスできるチップの最大数を指します。
  • v6e は次の構成でテストされています。
    • Python 3.10 以降
    • ナイトリー ソフトウェア バージョン:
      • ナイトリー JAX 0.4.32.dev20240912
      • ナイトリー LibTPU 0.1.dev20240912+nightly
    • 安定版ソフトウェア バージョン:
      • JAX + JAX Lib バージョン 0.4.37
  • プロジェクトの次の割り当てが十分にあることを確認します。

    • Cloud TPU VM の割り当て
    • IP アドレスの割り当て
    • Hyperdisk Balanced と使用する他のディスクタイプの割り当て

  • GKE と XPK を使用している場合は、XPK の実行に必要な権限について、ユーザーまたはサービス アカウントに対する Google Cloud コンソールの権限をご覧ください。

環境変数を作成する

Cloud Shell で、次の環境変数を作成します。

export NODE_ID=your-tpu-name
export PROJECT_ID=your-project-id
export ACCELERATOR_TYPE=v6e-16
export ZONE=us-east1-d
export RUNTIME_VERSION=v2-alpha-tpuv6e
export SERVICE_ACCOUNT=your-service-account
export QUEUED_RESOURCE_ID=your-queued-resource-id
export VALID_DURATION=your-duration 

# Additional environment variable needed for Multislice:
export NUM_SLICES=number-of-slices

# Use a custom network for better performance as well as to avoid having the default network becoming overloaded.

export NETWORK_NAME=${PROJECT_ID}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw

コマンドフラグの説明

変数 説明
NODE_ID キューに格納されたリソース リクエストの割り当て時に作成される Cloud TPU のユーザー割り当て ID。
PROJECT_ID Google Cloud プロジェクト名。既存のプロジェクトを使用するか、新しいプロジェクトを作成します。 詳細については、 Google Cloud プロジェクトを設定するをご覧ください。
ZONE サポートされているゾーンについては、Cloud TPU のリージョンとゾーンのドキュメントをご覧ください。
ACCELERATOR_TYPE アクセラレータ タイプをご覧ください。
RUNTIME_VERSION v2-alpha-tpuv6e
SERVICE_ACCOUNT サービス アカウントのメールアドレス。 Google Cloud コンソール -> [IAM] -> [サービス アカウント] で確認できます。

例: tpu-service-account@your-project-ID.iam.gserviceaccount.com.com

NUM_SLICES 作成するスライスの数(マルチスライスのみで必要)。
QUEUED_RESOURCE_ID キューに格納されたリソース リクエストのユーザー割り当てテキスト ID。
VALID_DURATION キューに格納されたリソース リクエストが有効である期間。
NETWORK_NAME 使用するセカンダリ ネットワークの名前。
NETWORK_FW_NAME 使用するセカンダリ ネットワーク ファイアウォールの名前。

ネットワーク パフォーマンスを最適化する

パフォーマンスを最大限に高めるには、8,896 MTU(最大伝送単位)のネットワークを使用します。

デフォルトでは、Virtual Private Cloud(VPC)では 1,460 バイトの MTU のみが提供され、ネットワーク パフォーマンスが最適化されません。VPC ネットワークの MTU は、1,300~8,896 バイト(両端を含む)の任意の値に設定できます。一般的なカスタム MTU サイズは 1,500 バイト(標準イーサネット)または 8,896 バイト(可能な最大値)です。詳細については、有効な VPC ネットワークの MTU サイズをご覧ください。

既存またはデフォルトのネットワークの MTU 設定の変更について詳しくは、VPC ネットワークの MTU 設定を変更するをご覧ください。

次の例では、8,896 MTU のネットワークを作成します。

export RESOURCE_NAME=your-resource-name
export NETWORK_NAME=${RESOURCE_NAME}-privatenetwork
export NETWORK_FW_NAME=${RESOURCE_NAME}-privatefirewall
gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT_ID} \
 --subnet-mode=auto --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network=${NETWORK_NAME} \
 --allow tcp,icmp,udp --project=${PROJECT_ID}

マルチ NIC の使用(マルチスライスのオプション)

マルチスライス環境を使用している場合、セカンダリ サブネットには次の環境変数が必要です。

export NETWORK_NAME_2=${RESOURCE_NAME}
export SUBNET_NAME_2=${RESOURCE_NAME}
export FIREWALL_RULE_NAME=${RESOURCE_NAME}
export ROUTER_NAME=${RESOURCE_NAME}-network-2
export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2
export REGION=your-region

次のコマンドを使用して、ネットワークとサブネットのカスタム IP ルーティングを作成します。

gcloud compute networks create ${NETWORK_NAME_2} --mtu=8896 \
   --bgp-routing-mode=regional --subnet-mode=custom --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_2} \
   --network=${NETWORK_NAME_2} \
   --range=10.10.0.0/18 --region=${REGION} \
   --project=${PROJECT_ID}

gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_2} --allow tcp,icmp,udp \
   --source-ranges 10.10.0.0/18 --project=${PROJECT_ID}

gcloud compute routers create ${ROUTER_NAME} \
  --project=${PROJECT_ID} \
  --network=${NETWORK_NAME_2} \
  --region=${REGION}

gcloud compute routers nats create ${NAT_CONFIG} \
  --router=${ROUTER_NAME} \
  --region=${REGION} \
  --auto-allocate-nat-external-ips \
  --nat-all-subnet-ip-ranges \
  --project=${PROJECT_ID} \
  --enable-logging

マルチネットワーク スライスを作成したら、XPK クラスタを設定し、XPK ワークロード作成コマンド--command ifconfig フラグを追加して、両方のネットワーク インターフェース カード(NIC)が使用されていることを確認できます。

次の xpk workload コマンドを使用して、 Google Cloud コンソールログに ifconfig コマンドの出力を表示し、eth0 と eth1 の両方が mtu=8896 であることを確認します。

python3 xpk.py workload create \
   --cluster CLUSTER_NAME \
   {--base-docker-image maxtext_base_image | --docker-image your-cloud-image-name} \
   --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
   --tpu-type=${ACCELERATOR_TYPE} \
   --num-slices=${NUM_SLICES}  \
   --on-demand \
   --zone=${ZONE} \
   --project=${PROJECT_ID} \
   --command "ifconfig"

デバッグログを有効にするVertex AI TensorBoard を使用する場合は、次のオプション引数をコマンドに追加します。

    --enable-debug-logs \
    --use-vertex-tensorboard

eth0 と eth1 の両方が mtu=8,896 であることを確認します。マルチ NIC が実行されていることを確認するには、XPK ワークロード作成コマンドに --command ifconfig フラグを追加します。 Google Cloud コンソールログでその xpk ワークロードの出力をチェックし、eth0 と eth1 の両方が mtu=8896 であることを確認します。

TCP 設定の改善

キューに格納されたリソース インターフェースを使用して Cloud TPU を作成した場合は、次のコマンドを実行して TCP 受信バッファの上限を増やし、ネットワーク パフォーマンスを改善できます。

gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \
  --project "${PROJECT_ID}" \
  --zone "${ZONE}" \
  --node=all \
  --command='sudo sh -c "echo \"4096 41943040 314572800\" > /proc/sys/net/ipv4/tcp_rmem"' \
  --worker=all

キューに格納されたリソースを使用してプロビジョニングする

キューに格納されたリソースを使用して Cloud TPU v6e を作成できます。キューに格納されたリソースを使用することで、容量が利用可能になった時点で使用できるようになります。必要に応じて、リクエストを処理する開始時刻と終了時刻を指定できます。詳細については、キューに格納されたリソースを管理するをご覧ください。

GKE または XPK を使用して v6e Cloud TPU をプロビジョニングする

v6e で GKE コマンドを使用している場合は、Kubernetes コマンドまたは XPK を使用して Cloud TPU をプロビジョニングし、モデルをトレーニングまたはサービングできます。GKE クラスタで Cloud TPU 構成を計画する方法については、GKE で Cloud TPU を計画するをご覧ください。以降のセクションでは、単一 NIC とマルチ NIC をサポートする XPK クラスタを作成するコマンドについて説明します。

単一 NIC をサポートする XPK クラスタを作成する

export CLUSTER_NAME=xpk-cluster-name
export ZONE=us-east1-d
export PROJECT_ID=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME=${CLUSTER_NAME}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw
gcloud compute networks create ${NETWORK_NAME} \
   --mtu=8896 \
   --project=${PROJECT_ID} \
   --subnet-mode=auto \
   --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} \
   --network=${NETWORK_NAME} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}"
python3 xpk.py cluster create --cluster=${CLUSTER_NAME} \
   --cluster-cpu-machine-type=e2-standard-8 \
   --num-slices=${NUM_SLICES} \
   --tpu-type=${TPU_TYPE} \
   --zone=${ZONE}  \
   --project=${PROJECT_ID} \
   --on-demand \
   --custom-cluster-arguments="${CLUSTER_ARGUMENTS}"  \
   --create-vertex-tensorboard

コマンドフラグの説明

変数 説明
CLUSTER_NAME ユーザーが割り当てた XPK クラスタの名前。
PROJECT_ID Google Cloud プロジェクト名。既存のプロジェクトを使用するか、新しいプロジェクトを作成します。 詳細については、 Google Cloud プロジェクトを設定するをご覧ください。
ZONE サポートされているゾーンについては、Cloud TPU のリージョンとゾーンのドキュメントをご覧ください。
TPU_TYPE アクセラレータ タイプをご覧ください。
NUM_SLICES 作成するスライスの数。
CLUSTER_ARGUMENTS 使用するネットワークとサブネットワーク。

例: --network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}

NUM_SLICES 作成するスライスの数。
NETWORK_NAME 使用するセカンダリ ネットワークの名前。
NETWORK_FW_NAME 使用するセカンダリ ネットワーク ファイアウォールの名前。

マルチ NIC をサポートする XPK クラスタを作成する

export CLUSTER_NAME=xpk-cluster-name
export REGION=your-region
export ZONE=us-east1-d
export PROJECT_ID=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME_1=${CLUSTER_NAME}-mtu9k-1-${ZONE}
export SUBNET_NAME_1=${CLUSTER_NAME}-privatesubnet-1-${ZONE}
export NETWORK_FW_NAME_1=${NETWORK_NAME_1}-fw-1-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-1-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-1-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-1-${ZONE}
gcloud compute networks create ${NETWORK_NAME_1} \
   --mtu=8896 \
   --bgp-routing-mode=regional \
   --subnet-mode=custom \
   --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_1} \
   --network=${NETWORK_NAME_1} \
   --range=10.11.0.0/18 \
   --region=${REGION} \
   --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_1} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \
    --project=${PROJECT_ID} \
    --network=${NETWORK_NAME_1} \
    --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \
    --router=${ROUTER_NAME} \
    --region=${REGION} \
    --auto-allocate-nat-external-ips \
    --nat-all-subnet-ip-ranges \
    --project=${PROJECT_ID} \
    --enable-logging
# Secondary subnet for multi-nic experience.
# Need custom IP routing to be different from the first network's subnet.

export NETWORK_NAME_2=${CLUSTER_NAME}-privatenetwork-2-${ZONE}
export SUBNET_NAME_2=${CLUSTER_NAME}-privatesubnet-2-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-2-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-2-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-2-${ZONE}
gcloud compute networks create ${NETWORK_NAME_2} \
   --mtu=8896 \
   --bgp-routing-mode=regional \
   --subnet-mode=custom \
   --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_2} \
   --network=${NETWORK_NAME_2} \
   --range=10.10.0.0/18 \
   --region=${REGION} \
   --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_2} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \
     --project=${PROJECT_ID} \
     --network=${NETWORK_NAME_2} \
     --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \
     --router=${ROUTER_NAME} \
     --region=${REGION} \
     --auto-allocate-nat-external-ips \
     --nat-all-subnet-ip-ranges \
     --project=${PROJECT_ID} \
     --enable-logging
export CLUSTER_ARGUMENTS="--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}"

export NODE_POOL_ARGUMENTS="--additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}"
python3 xpk.py cluster create \
    --cluster=${CLUSTER_NAME} \
    --cluster-cpu-machine-type=e2-standard-8 \
    --num-slices=${NUM_SLICES} \
    --tpu-type=${TPU_TYPE} \
    --zone=${ZONE}  \
    --project=${PROJECT_ID} \
    --on-demand \
    --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \
    --custom-nodepool-arguments="${NODE_POOL_ARGUMENTS}" \
    --create-vertex-tensorboard

コマンドフラグの説明

変数 説明
CLUSTER_NAME ユーザーが割り当てた XPK クラスタの名前。
PROJECT_ID Google Cloud プロジェクト名。既存のプロジェクトを使用するか、新しいプロジェクトを作成します。 詳細については、 Google Cloud プロジェクトを設定するをご覧ください。
ZONE サポートされているゾーンについては、Cloud TPU のリージョンとゾーンのドキュメントをご覧ください。
TPU_TYPE アクセラレータ タイプをご覧ください。
NUM_SLICES 作成するスライスの数。
CLUSTER_ARGUMENTS 使用するネットワークとサブネットワーク。

例: --enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}

NODE_POOL_ARGUMENTS 使用する追加のノード ネットワーク。

例: --additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}

NUM_SLICES 作成するスライスの数(マルチスライスのみで必要)。
NETWORK_NAME 使用するセカンダリ ネットワークの名前。
NETWORK_FW_NAME 使用するセカンダリ ネットワーク ファイアウォールの名前。

フレームワークの設定

このセクションでは、JAX フレームワークと PyTorch フレームワークを使用した ML モデル トレーニングの一般的な設定プロセスについて説明します。GKE を使用している場合は、フレームワークの設定に XPK または Kubernetes コマンドを使用できます。

JAX を設定する

このセクションでは、XPK を使用して、XPK を使用しないで、またはキューに格納されたリソースを使用して、GKE で JAX ワークロードを実行するための設定方法について説明します。

GKE を使用して JAX を設定する

単一ホストの単一スライス

次の例では、Kubernetes YAML ファイルを使用して 2x2 の単一ホスト ノードプールを設定します。

apiVersion: v1
kind: Pod
metadata:
  name: tpu-pod-jax-v6e-a
spec:
  restartPolicy: Never
  nodeSelector:
    cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
    cloud.google.com/gke-tpu-topology: 2x2
  containers:
  - name: tpu-job
    image: python:3.10
    securityContext:
      privileged: true
    command:
    - bash
    - -c
    - |
      pip install -U --pre jax jaxlib libtpu-nightly requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
      JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python3 -c 'import jax; print("Total TPU chips:", jax.device_count())'
    resources:
      requests:
        google.com/tpu: 4
      limits:
        google.com/tpu: 4

正常に完了すると、GKE ログに次のメッセージが表示されます。

Total TPU chips: 4

マルチホストの単一スライス

次の例では、Kubernetes YAML ファイルを使用して 4x4 のマルチホスト ノードプールを設定します。

apiVersion: v1
kind: Service
metadata:
  name: headless-svc
spec:
  clusterIP: None
  selector:
    job-name: tpu-available-chips
---
apiVersion: batch/v1
kind: Job
metadata:
  name: tpu-available-chips
spec:
  backoffLimit: 0
  completions: 4
  parallelism: 4
  completionMode: Indexed
  template:
    spec:
      subdomain: headless-svc
      restartPolicy: Never
      nodeSelector:
        cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
        cloud.google.com/gke-tpu-topology: 4x4
      containers:
      - name: tpu-job
        image: python:3.10
        ports:
        - containerPort: 8471 # Default port using which TPU VMs communicate
        - containerPort: 8431 # Port to export TPU runtime metrics, if supported.
        securityContext:
          privileged: true
        command:
        - bash
        - -c
        - |
          pip install -U --pre jax jaxlib libtpu-nightly requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
          JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
        resources:
          requests:
            google.com/tpu: 4
          limits:
            google.com/tpu: 4

正常に完了すると、GKE ログに次のメッセージが表示されます。

Total TPU chips: 16

マルチホストのマルチスライス

次の例では、Kubernetes YAML ファイルを使用して 2 つの 4x4 マルチホスト ノードプールを設定します。

前提条件として、v0.2.3 以降の JobSet をインストールする必要があります。

apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
  name: multislice-job
  annotations:
    alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
spec:
  failurePolicy:
    maxRestarts: 4
  replicatedJobs:
    - name: slice
      replicas: 2
      template:
        spec:
          parallelism: 4
          completions: 4
          backoffLimit: 0
          template:
            spec:
              hostNetwork: true
              dnsPolicy: ClusterFirstWithHostNet
              nodeSelector:
                cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                cloud.google.com/gke-tpu-topology: 4x4
              hostNetwork: true
              containers:
              - name: jax-tpu
                image: python:3.10
                ports:
                - containerPort: 8471
                - containerPort: 8080
                - containerPort: 8431
                securityContext:
                  privileged: true
                command:
                - bash
                - -c
                - |
                  pip install -U --pre jax jaxlib libtpu-nightly requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                  JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
                resources:
                  limits:
                   google.com/tpu: 4
                  requests:
                   google.com/tpu: 4

正常に完了すると、GKE ログに次のメッセージが表示されます。

Total TPU chips: 32

詳細については、GKE ドキュメントのマルチスライス ワークロードを実行するをご覧ください。

パフォーマンスを向上させるには、hostNetwork を有効にします

マルチ NIC

次のマルチ NIC マニフェストを使用するには、ネットワークを設定する必要があります。詳細については、Kubernetes Pod のマルチネットワーク サポートを設定するをご覧ください。GKE でマルチ NIC を利用するには、Kubernetes Pod マニフェストに追加のアノテーションを含める必要があります。以下は、TPU 以外のマルチ NIC ワークロード マニフェストの例です。

apiVersion: v1
kind: Pod
metadata:
  name: sample-netdevice-pod-1
  annotations:
    networking.gke.io/default-interface: 'eth0'
    networking.gke.io/interfaces: |
      [
        {"interfaceName":"eth0","network":"default"},
        {"interfaceName":"eth1","network":"netdevice-network"}
      ]
spec:
  containers:
  - name: sample-netdevice-pod
    image: busybox
    command: ["sleep", "infinity"]
    ports:
    - containerPort: 80
  restartPolicy: Always
  tolerations:
  - key: "google.com/tpu"
    operator: "Exists"
    effect: "NoSchedule"

exec コマンドを使用して Kubernetes Pod に接続した場合、次のコードを使用すると、追加の NIC が表示されます。

$ kubectl exec --stdin --tty sample-netdevice-pod-1 -- /bin/sh
/ # ip a
1: lo: <LOOPBACK,UP,LOWER_UP> mtu 65536 qdisc noqueue qlen 1000
    link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00
    inet 127.0.0.1/8 scope host lo
       valid_lft forever preferred_lft forever
2: eth0@if11: <BROADCAST,MULTICAST,UP,LOWER_UP,M-DOWN> mtu 1460 qdisc noqueue
    link/ether da:be:12:67:d2:25 brd ff:ff:ff:ff:ff:ff
    inet 10.124.2.6/24 brd 10.124.2.255 scope global eth0
       valid_lft forever preferred_lft forever
3: eth1: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1460 qdisc mq qlen 1000
    link/ether 42:01:ac:18:00:04 brd ff:ff:ff:ff:ff:ff
    inet 172.24.0.4/32 scope global eth1
       valid_lft forever preferred_lft forever

GKE と XPK を使用して JAX を設定する

GKE と XPK を使用して JAX を設定するには、xpk README をご覧ください。

MaxText で XPK を設定、実行する方法については、MaxText の実行方法をご覧ください。

キューに格納されたリソースを使用して JAX を設定する

gcloud alpha compute tpus tpu-vm ssh コマンドを使用して、スライスのすべての Cloud TPU VM に JAX を同時にインストールします。マルチスライスの場合は、--node=all フラグを追加します。

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
 --zone ${ZONE} --worker=all \
 --command='pip install -U --pre jax jaxlib libtpu-nightly requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

次のコマンドを実行すると、スライスで使用可能な Cloud TPU コアの数を確認して、すべてが正しくインストールされていることをテストできます。

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
   --zone ${ZONE} --worker=all  \
   --command='python3 -c "import jax; print(jax.device_count(), jax.local_device_count())"'

v6e-16 スライスで実行した場合の出力は次のようになります。

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16 4
16 4
16 4
16 4

jax.device_count() は、指定されたスライス内のチップの合計数を示します。jax.local_device_count() は、このスライス内の単一の VM からアクセス可能なチップの数を示します。


 gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} --zone=${ZONE} --worker=all  \
   --command='git clone -b mlperf4.1 https://github.com/google/maxdiffusion.git &&
   cd maxdiffusion && git checkout 975fdb7dbddaa9a53ad72a421cdb487dcdc491a3 &&
   pip install setuptools==59.6.0 &&
   pip install -r requirements.txt  && pip install . '

JAX の設定に関するトラブルシューティングを行う

一般的なヒントとして、GKE ワークロード マニフェストで詳細ログを有効にします。その後、ログを GKE サポートに提供します。

TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0

エラー メッセージ

no endpoints available for service 'jobset-webhook-service'

このエラーは、ジョブセットが正しくインストールされていないことを意味します。jobset-controller-manager Deployment Kubernetes Pod が実行されているかどうかを確認してください。詳細については、JobSet のトラブルシューティングに関するドキュメントをご覧ください。

TPU initialization failed: Failed to connect

GKE ノード バージョンが 1.30.4-gke.1348000 以降であることを確認します(GKE 1.31 はサポートされていません)。

PyTorch を設定する

このセクションでは、PyTorch/XLA を使用して v6e で PJRT の使用を開始する方法について説明します。推奨される Python バージョンは Python 3.10 です。

GKE と XPK を使用して PyTorch を設定する

PyTorch の依存関係がすでにインストールされている XPK で、次の Docker コンテナを使用できます。

us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028

XPK ワークロードを作成するには、次のコマンドを使用します。

python3 xpk.py workload create \
    --cluster ${CLUSTER_NAME} \
    {--base-docker-image maxtext_base_image | --docker-image your-cloud-image-name \
    --workload ${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
    --tpu-type=${ACCELERATOR_TYPE} \
    --num-slices=${NUM_SLICES}  \
    --on-demand \
    --zone ${ZONE} \
    --project ${PROJECT_ID} \
    --enable-debug-logs \
    --command 'python3 -c "import torch; import torch_xla; import torch_xla.runtime as xr; print(xr.global_runtime_device_count())"'

--base-docker-image を使用すると、新しい Docker イメージが作成され、現在の作業ディレクトリが新しい Docker に組み込まれます。

キューに格納されたリソースを使用して PyTorch を設定する

キューに格納されたリソースを使用して PyTorch をインストールし、v6e で小さなスクリプトを実行する手順は次のとおりです。

SSH を使用して依存関係をインストールし、VM にアクセスする

次のコマンドを使用して、すべての Cloud TPU VM に依存関係をインストールします。マルチスライスの場合は、--worker=all フラグを追加します。

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='sudo apt update && sudo apt install -y python3-pip libopenblas-base && \
               pip3 install torch~=2.6.0 "torch_xla[tpu]~=2.6.0" -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'

大きい割り当てが頻繁に発生するモデルのパフォーマンスを改善する

大きい割り当てが頻繁に発生するモデルの場合、tcmalloc 関数を使用すると、デフォルトの malloc 関数の実装と比較してパフォーマンスが大幅に向上するため、Cloud TPU VM で使用されるデフォルトの malloc 関数は tcmalloc です。ただし、ワークロードによっては(たとえば、埋め込みテーブルへの割り当てが非常に大きい DLRM の場合)、tcmalloc 関数を使用すると速度が低下する可能性があります。その場合は、次の変数の設定を解除して、代わりにデフォルトの malloc 関数を使用できます。

unset LD_PRELOAD

Python スクリプトを使用して v6e VM で計算を行う

次のコマンドを使用して、2 つのテンソルを作成し、それらを加算して結果を出力するスクリプトを実行します。

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
   --project ${PROJECT_ID} \
   --zone ${ZONE} --worker all --command='
   unset LD_PRELOAD
   python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"
'

これにより、次のような出力が生成されます。

SSH: Attempting to connect to worker 0...
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
xla:0
tensor([[ 0.3355, -1.4628, -3.2610],
        [-1.4656,  0.3196, -2.8766],
        [ 0.8668, -1.5060,  0.7125]], device='xla:0')

SkyPilot で v6e を使用する

Cloud TPU v6e は SkyPilot で使用できます。v6e に関連するロケーションと料金の情報を SkyPilot に追加する手順は次のとおりです。詳細については、GitHub の SkyPilot TPU v6e の例をご覧ください。

推論チュートリアル

次のチュートリアルでは、Cloud TPU v6e で推論を実行する方法について説明します。

トレーニング サンプル

以降のセクションでは、Cloud TPU v6e で MaxText、MaxDiffusion、PyTorch の各モデルをトレーニングする例について説明します。

v6e Cloud TPU VM での MaxText と MaxDiffusion のトレーニング

以降のセクションでは、MaxText モデルと MaxDiffusion モデルのトレーニング ライフサイクルについて説明します。

一般的な手順は次のとおりです。

  1. ワークロードのベースイメージをビルドします。
  2. XPK を使用してワークロードを実行します。
    1. ワークロードのトレーニング コマンドをビルドします。
    2. ワークロードをデプロイします。
  3. ワークロードを追跡して指標を表示します。
  4. 不要な XPK ワークロードを削除します。
  5. 不要になった XPK クラスタを削除します。

ベースイメージをビルドする

MaxText または MaxDiffusion をインストールして Docker イメージをビルドします。

  1. 使用するリポジトリのクローンを作成し、リポジトリのディレクトリに移動します。

    MaxText:

    git clone https://github.com/google/maxtext.git && cd maxtext
    

    MaxDiffusion:

    git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion
    
  2. Google Cloud CLI を使用するように Docker を構成します。

    gcloud auth configure-docker
    
  3. 次のコマンドまたは JAX Stable Stack を使用して Docker イメージをビルドします。JAX Stable Stack の詳細については、JAX Stable Stack を使用して Docker イメージをビルドするをご覧ください。

    bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.35
    
  4. アクティブな gcloud CLI 構成でプロジェクト ID を設定します。

    gcloud config set project ${PROJECT_ID}
    
  5. ローカルにビルドされたイメージがないマシンからワークロードを起動する場合は、イメージをアップロードします。

    1. CLOUD_IMAGE_NAME 環境変数を設定します。

      export CLOUD_IMAGE_NAME=${USER}_runner
      
    2. 画像をアップロードします。

      bash docker_upload_runner.sh ${CLOUD_IMAGE_NAME}
      

XPK を使用してワークロードを実行する

  1. MaxText または MaxDiffusion によって設定されたデフォルト値を使用していない場合は、次の環境変数を設定します。

    export BASE_OUTPUT_DIR=gs://YOUR_BUCKET
    export PER_DEVICE_BATCH_SIZE=2
    export NUM_STEPS=30
    export MAX_TARGET_LENGTH=8192
  2. モデル スクリプトをビルドします。このスクリプトは、後でトレーニング コマンドとしてコピーされます。

    モデル スクリプトはまだ実行しないでください。

    MaxText

    MaxText は、高パフォーマンスでスケーラビリティに優れたオープンソースの LLM です。ピュア Python と JAX で記述され、トレーニングと推論で Google Cloud TPU および GPU をターゲットとします。

    JAX_PLATFORMS=tpu,cpu \
    ENABLE_PJRT_COMPATIBILITY=true \
    TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \
    TPU_SLICE_BUILDER_DUMP_ICI=true && \
    python3 -m MaxText.train MaxText/configs/base.yml \
            base_output_directory=${BASE_OUTPUT_DIR} \
            dataset_type=synthetic \
            per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
            enable_checkpointing=false \
            gcs_metrics=true \
            profiler=xplane \
            skip_first_n_steps_for_profiler=5 \
            steps=${NUM_STEPS}  # attention='dot_product'"
    

    Gemma2

    Gemma は、Gemini の研究とテクノロジーに基づいて Google DeepMind が開発したオープン ウェイト LLM のファミリーです。

    python3 -m MaxText.train MaxText/configs/base.yml \
        model_name=gemma2-27b \
        run_name=gemma2-27b-run \
        base_output_directory=${BASE_OUTPUT_DIR} \
        max_target_length=${MAX_TARGET_LENGTH} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        steps=${NUM_STEPS} \
        enable_checkpointing=false \
        use_iota_embed=true \
        gcs_metrics=true \
        dataset_type=synthetic \
        profiler=xplane \
        attention=flash
    

    Mixtral 8x7b

    Mixtral は、Mistral AI が開発した最先端の AI モデルであり、スパースな Mixture of Experts(MoE)アーキテクチャを利用します。

    python3 -m MaxText.train MaxText/configs/base.yml \
        base_output_directory=${BASE_OUTPUT_DIR} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        model_name=mixtral-8x7b \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        tokenizer_path=assets/tokenizer.mistral-v1 \
        attention=flash \
        dtype=bfloat16 \
        dataset_type=synthetic \
        profiler=xplane
    

    Llama3-8b

    Llama は、Meta が開発したオープン ウェイト LLM のファミリーです。

    PyTorch で Llama3 を実行する方法の例については、torchprime GitHub リポジトリの torch_xla モデルをご覧ください。

    MaxDiffusion

    MaxDiffusion は、さまざまな潜在拡散モデルのリファレンス実装のコレクションです。ピュア Python と JAX で記述され、Cloud TPU や GPU などの XLA デバイスで実行されます。Stable Diffusion は、テキスト入力からフォトリアリスティックな画像を生成する、潜在テキスト画像変換モデルです。

    MaxDiffusion を実行するには、次の git checkout コマンドに示すように、特定の Git ブランチをインストールする必要があります。

    git clone https://github.com/google/maxdiffusion.git
    && cd maxdiffusion
    && git checkout e712c9fc4cca764b0930067b6e33daae2433abf0
    && pip install -r requirements.txt
    && pip install .
    

    トレーニング スクリプト:

        cd maxdiffusion && OUT_DIR=${BASE_OUTPUT_DIR} \
        python src/maxdiffusion/train_sdxl.py \
        src/maxdiffusion/configs/base_xl.yml \
        revision=refs/pr/95 \
        activations_dtype=bfloat16 \
        weights_dtype=bfloat16 \
        resolution=1024 \
        per_device_batch_size=1 \
        output_dir=${OUT_DIR}  \
        jax_cache_dir=${OUT_DIR}/cache_dir/ \
        max_train_steps=200 \
        attention=flash run_name=sdxl-ddp-v6e
        
  3. 次の変数をエクスポートします。

    export CLUSTER_NAME=CLUSTER_NAME
    export ACCELERATOR_TYPE=ACCELERATOR_TYPE
    export NUM_SLICES=NUM_SLICES
    export YOUR_MODEL_SCRIPT=YOUR_MODEL_SCRIPT

    環境変数の説明

    変数 説明
    CLUSTER_NAME XPK クラスタの名前。
    ACCELERATOR_TYPE アクセラレータ タイプをご覧ください。
    NUM_SLICES TPU スライスの数。
    YOUR_MODEL_SCRIPT トレーニング コマンドとして実行するモデル スクリプト。
  4. 前の手順で作成したスクリプトを使用してモデルを実行します。MaxText ベースイメージを使用するために --base-docker-image フラグを指定するか、--docker-image フラグと使用するイメージを指定する必要があります。

    省略可: --enable-debug-logs フラグを含めると、デバッグ ロギングを有効にできます。詳細については、MaxText で JAX をデバッグするをご覧ください。

    省略可: --use-vertex-tensorboard フラグを含めると、Vertex AI Experiments を作成し、Vertex AI TensorBoard にデータをアップロードできます。詳細については、Vertex AI を使用して MaxText で JAX をモニタリングするをご覧ください。

    python3 xpk.py workload create \
        --cluster ${CLUSTER_NAME} \
        {--base-docker-image maxtext_base_image|--docker-image gcr.io/${PROJECT_ID}/${CLOUD_IMAGE_NAME}:latest} \
        --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
        --tpu-type=${ACCELERATOR_TYPE} \
        --num-slices=${NUM_SLICES}  \
        --on-demand \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        [--enable-debug-logs] \
        [--use-vertex-tensorboard] \
        --command=${YOUR_MODEL_SCRIPT}

    出力には、ワークロードを追跡するためのリンクが含まれます。リンクを開き、[ログ] タブをクリックしてワークロードをリアルタイムで追跡します。

MaxText で JAX をデバッグする

補足 XPK コマンドを使用して、クラスタまたはワークロードが実行されていない理由を診断します。

Vertex AI を使用して MaxText の JAX をモニタリングする

TensorBoard を使用するには、 Google Cloud ユーザー アカウントに aiplatform.user ロールが必要です。次のコマンドを実行して、これらのロールを付与します。

gcloud projects add-iam-policy-binding your-project-id \
--member='user:your-email' \
--role='roles/aiplatform.user'

Vertex AI のマネージド TensorBoard を使用して、スカラーデータとプロファイル データを表示します。

  1. 使用しているゾーンのリソース管理(CRUD)リクエストを 600 から 5,000 に増やします。これは、16 台未満の VM を使用する小規模なワークロードでは問題にならない可能性があります。
  2. Vertex AI の cloud-accelerator-diagnostics などの依存関係をインストールします。

    # xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI
    cd ~/xpk
    pip install .
  3. Vertex AI TensorBoard を作成するで説明されているように、--create-vertex-tensorboard フラグを使用して XPK クラスタを作成します。このコマンドは既存のクラスタでも実行できます。

  4. --use-vertex-tensorboard フラグとオプションの --experiment-name フラグを使用して、XPK ワークロードを実行するときに、Vertex AI テストを作成します。手順の一覧については、Vertex AI テストを作成して Vertex AI TensorBoard にデータをアップロードするをご覧ください。

ログには、次のような Vertex AI TensorBoard へのリンクが含まれています。

View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name

Vertex AI TensorBoard のリンクは、 Google Cloud コンソールでも確認できます。 Google Cloud コンソールで Vertex AI Experiments に移動します。プルダウンから適切なリージョンを選択します。

TensorBoard ディレクトリも、${BASE_OUTPUT_DIR} で指定した Cloud Storage バケットに書き込まれます。

XPK ワークロードを削除する

ジョブの接頭辞またはジョブ ステータスに基づいて 1 つ以上のワークロードを削除するには、xpk workload delete コマンドを使用します。このコマンドは、実行する必要がなくなった XPK ワークロードを送信した場合や、キュー内に停止しているジョブがある場合に便利です。

XPK クラスタを削除する

クラスタを削除するには、xpk cluster delete コマンドを使用します。

python3 xpk.py cluster delete --cluster ${CLUSTER_NAME} \
--zone=${ZONE} --project=${PROJECT_ID}

v6e Cloud TPU VM での Llama と PyTorch/XLA のトレーニング

このチュートリアルでは、WikiText データセットを使用し、Cloud TPU v6e で PyTorch/XLA を使用して Llama モデルをトレーニングする方法について説明します。

Hugging Face および Llama 3 モデルにアクセスする

このチュートリアルを実行するには、Hugging Face ユーザー アクセス トークンが必要です。ユーザー アクセス トークンの作成については、Hugging Face のユーザー アクセス トークンに関するドキュメントをご覧ください。

また、Hugging Face の Llama 3 8B モデルにアクセスする権限も必要です。アクセス権を取得するには、HuggingFace の Meta-Llama-3-8B モデルにアクセスしてアクセス権をリクエストしてください。

Cloud TPU VM を作成する

チュートリアルを実行するために、8 チップを搭載した Cloud TPU v6e を作成します。

  1. 環境変数を設定します。

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-east1-d
    export ACCELERATOR_TYPE=v6e-8
    export RUNTIME_VERSION=v2-alpha-tpuv6e

    環境変数の説明

    変数 説明
    PROJECT_ID 実際の Google Cloud のプロジェクト ID。既存のプロジェクトを使用するか、新しいプロジェクトを作成します
    TPU_NAME TPU の名前。
    ZONE TPU VM を作成するゾーン。サポートされているゾーンの詳細については、TPU のリージョンとゾーンをご覧ください。
    ACCELERATOR_TYPE アクセラレータ タイプでは、作成する Cloud TPU のバージョンとサイズを指定します。TPU の各バージョンでサポートされているアクセラレータ タイプの詳細については、TPU のバージョンをご覧ください。
    RUNTIME_VERSION Cloud TPU ソフトウェアのバージョン

  2. Cloud TPU VM を作成します。

    gcloud alpha compute tpus tpu-vm create ${TPU_NAME} --version=${RUNTIME_VERSION} \
        --accelerator-type=${ACCELERATOR_TYPE} \
        --zone=${ZONE} \
        --project=${PROJECT_ID}

インストール

Hugging Face Transformers と依存関係の pytorch-tpu/transformers フォークをインストールします。このチュートリアルは、この例で使用されている次の依存関係のバージョンでテストされています。

  • torch: 2.5.0 との互換性あり
  • torch_xla[tpu]: 2.5.0 との互換性あり
  • jax: 0.4.33
  • jaxlib: 0.4.33
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} --zone ${ZONE} \
    --worker=all --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
    cd transformers
    sudo pip3 install -e .
    pip3 install datasets
    pip3 install evaluate
    pip3 install scikit-learn
    pip3 install accelerate
    pip install torch~=2.6.0 torch_xla[tpu]~=2.6.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
    pip install jax==0.4.38 jaxlib==0.4.38 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/'

モデル構成を設定する

次のセクション(モデルを実行する)のトレーニング コマンドでは、2 つの JSON 構成ファイルを使用して、モデル パラメータと FSDP(完全にシャーディングされたデータ並列処理)構成を定義します。FSDP シャーディングは、トレーニング中にモデルの重みを大きなバッチサイズに適合させるために使用します。小さなモデルでトレーニングする場合は、データ並列処理を使用して各デバイスに重みを複製するだけで十分な可能性があります。PyTorch/XLA でデバイス間でテンソルをシャーディングする方法の詳細については、PyTorch/XLA SPMD ユーザーガイドをご覧ください。

  1. モデル パラメータ構成ファイルを作成します。Llama3-8B のモデル パラメータ構成は次のとおりです。他のモデルについては、Hugging Face で構成を確認してください。たとえば、Llama2-7B 構成をご覧ください。

    cat > llama-config.json << EOF
    {
        "architectures": [
            "LlamaForCausalLM"
        ],
        "attention_bias": false,
        "attention_dropout": 0.0,
        "bos_token_id": 128000,
        "eos_token_id": 128001,
        "hidden_act": "silu",
        "hidden_size": 4096,
        "initializer_range": 0.02,
        "intermediate_size": 14336,
        "max_position_embeddings": 8192,
        "model_type": "llama",
        "num_attention_heads": 32,
        "num_hidden_layers": 32,
        "num_key_value_heads": 8,
        "pretraining_tp": 1,
        "rms_norm_eps": 1e-05,
        "rope_scaling": null,
        "rope_theta": 500000.0,
        "tie_word_embeddings": false,
        "torch_dtype": "bfloat16",
        "transformers_version": "4.40.0.dev0",
        "use_cache": false,
        "vocab_size": 128256
    }
    EOF
    
  2. FSDP 構成ファイルを作成します。

    cat > fsdp-config.json << EOF
    {
        "fsdp_transformer_layer_cls_to_wrap": [
            "LlamaDecoderLayer"
        ],
        "xla": true,
        "xla_fsdp_v2": true,
        "xla_fsdp_grad_ckpt": true
    }
    EOF
    

    FSDP の詳細については、FSDPv2 をご覧ください。

  3. 次のコマンドを使用して、構成ファイルを Cloud TPU VM にアップロードします。

    gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json ${TPU_NAME}:. \
        --worker=all \
        --project=${PROJECT_ID} \
        --zone=${ZONE}

モデルを実行する

前のセクションで作成した構成ファイルを使用して run_clm.py スクリプトを実行し、WikiText データセットで Llama 3 8B モデルをトレーニングします。Cloud TPU v6e-8 でトレーニング スクリプトを実行すると、約 10 分かかります。

  1. 次のコマンドを使用して、Cloud TPU で Hugging Face にログインします。

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} \
        --zone ${ZONE} \
        --worker=all \
        --command='
        pip3 install "huggingface_hub[cli]"
        huggingface-cli login --token HUGGING_FACE_TOKEN'
  2. モデル トレーニングを実行します。

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} \
        --zone ${ZONE} \
        --worker=all \
        --command='
        export PJRT_DEVICE=TPU
        export XLA_USE_SPMD=1
        export ENABLE_PJRT_COMPATIBILITY=true
            # Optional variables for debugging:
        export XLA_IR_DEBUG=1
        export XLA_HLO_DEBUG=1
        export PROFILE_EPOCH=0
        export PROFILE_STEP=3
        export PROFILE_DURATION_MS=100000
            # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path
        export PROFILE_LOGDIR=PROFILE_PATH
        python3 transformers/examples/pytorch/language-modeling/run_clm.py \
        --dataset_name wikitext \
        --dataset_config_name wikitext-2-raw-v1 \
        --per_device_train_batch_size 16 \
        --do_train \
        --output_dir /home/$USER/tmp/test-clm \
        --overwrite_output_dir \
        --config_name /home/$USER/llama-config.json \
        --cache_dir /home/$USER/cache \
        --tokenizer_name meta-llama/Meta-Llama-3-8B \
        --block_size 8192 \
        --optim adafactor \
        --save_strategy no \
        --logging_strategy no \
        --fsdp "full_shard" \
        --fsdp_config /home/$USER/fsdp-config.json \
        --torch_dtype bfloat16 \
        --dataloader_drop_last yes \
        --flash_attention \
        --max_steps 20'

PyTorch/XLA のトラブルシューティングを行う

前のセクションでデバッグ用のオプション変数を設定した場合、モデルのプロファイルは PROFILE_LOGDIR 変数で指定された場所に保存されます。この場所に保存されている xplane.pb ファイルを抽出し、tensorboard を使用して、TensorBoard の手順に沿ってブラウザでプロファイルを表示できます。PyTorch/XLA が想定どおりに動作しない場合は、トラブルシューティング ガイドをご覧ください。モデルのデバッグ、プロファイリング、最適化に関する推奨事項が記載されています。

ベンチマーク結果

次のセクションでは、v6e での MaxDiffusion のベンチマーク結果について説明します。

MaxDiffusion

MaxDiffusion のトレーニング スクリプトを v6e-4、v6e-16、2 つの v6e-16 で実行しました。スループットは次の表を参照してください。

v6e-4 v6e-16 2 つの v6e-16
トレーニング ステップ数 0.069 0.073 0.13
グローバル バッチサイズ 8 32 64
スループット(例/秒) 115.9 438.4 492.3