Trillium (v6e) – Einführung

In dieser Dokumentation, in der TPU API und in Protokollen wird „v6e“ für Trillium verwendet. „v6e“ steht für die sechste Generation von TPU von Google.

Mit 256 Chips pro Pod hat die v6e-Architektur viele Ähnlichkeiten mit v5e. Dieses System ist für das Training, die Feinabstimmung und die Bereitstellung von Transformern, Text-zu-Bild-Modellen und CNNs (Convolutional Neural Networks) optimiert.

Weitere Informationen zur Systemarchitektur und zu den Konfigurationen von v6e finden Sie unter TPU v6e.

In diesem Einführungsdokument liegt der Schwerpunkt auf den Prozessen für das Modelltraining und die Modellbereitstellung mit den Frameworks JAX, PyTorch oder TensorFlow. Mit jedem Framework können Sie TPUs mit Ressourcen in der Warteschlange oder GKE bereitstellen. Die GKE-Einrichtung kann mit XPK- oder GKE-Befehlen erfolgen.

Allgemeines Verfahren zum Trainieren oder Bereitstellen eines Modells mit v6e

  1. Google Cloud Projekt vorbereiten
  2. Sichere Kapazität
  3. Cloud TPU-Umgebung bereitstellen
  4. Eine Arbeitslast für das Training oder die Inferenz eines Modells ausführen

Google Cloud -Projekt vorbereiten

Bevor Sie Cloud TPU verwenden können, müssen Sie Folgendes tun:

Weitere Informationen finden Sie unter Cloud TPU-Umgebung einrichten.

Sichere Kapazität

Wenden Sie sich an den Google Cloud -Support, um ein Cloud TPU v6e-Kontingent anzufordern und Fragen zur Kapazität zu stellen.

Cloud TPU-Umgebung bereitstellen

v6e Cloud TPU kann mit GKE, mit GKE und XPK (einem Befehlszeilen-Wrapper-Tool über GKE) oder als Ressourcen in der Warteschlange bereitgestellt und verwaltet werden.

Vorbereitung

  • Prüfen Sie, ob Ihr Projekt ein ausreichendes TPUS_PER_TPU_FAMILY-Kontingent hat. Dieses gibt die maximale Anzahl von Chips an, auf die Sie in Ihrem Projekt zugreifen können. Google Cloud
  • v6e wurde mit der folgenden Konfiguration getestet:
    • Python 3.10 oder höher
    • Nightly-Softwareversionen:
      • Täglich JAX 0.4.32.dev20240912
      • Nightly LibTPU 0.1.dev20240912+nightly
    • Stabile Softwareversionen:
      • JAX + JAX-Bibliothek der Version 0.4.37
  • Prüfen Sie, ob Ihr Projekt für Folgendes ausreichend Kontingente hat:

    • Cloud TPU-VM-Kontingent
    • Kontingent für IP-Adressen
    • Hyperdisk Balanced-Kontingent

  • Wenn Sie GKE mit XPK verwenden, finden Sie unter Cloud Console-Berechtigungen für das Nutzer- oder Dienstkonto Informationen zu den Berechtigungen, die zum Ausführen von XPK erforderlich sind.

Umgebungsvariablen erstellen

Erstellen Sie in Cloud Shell die folgenden Umgebungsvariablen:

export NODE_ID=your-tpu-name
export PROJECT_ID=your-project-id
export ACCELERATOR_TYPE=v6e-16
export ZONE=us-east1-d
export RUNTIME_VERSION=v2-alpha-tpuv6e
export SERVICE_ACCOUNT=your-service-account
export QUEUED_RESOURCE_ID=your-queued-resource-id
export VALID_DURATION=your-duration 

# Additional environment variable needed for Multislice:
export NUM_SLICES=number-of-slices

# Use a custom network for better performance as well as to avoid having the default network becoming overloaded.

export NETWORK_NAME=${PROJECT_ID}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw

Beschreibung der Befehls-Flags

Variable Beschreibung
NODE_ID Die vom Nutzer zugewiesene ID der Cloud TPU, die erstellt wird, wenn die in der Warteschlange befindliche Ressourcenanfrage zugewiesen wird.
PROJECT_ID Google Cloud ist der Projektname. Sie können ein vorhandenes Projekt verwenden oder ein neues erstellen. Weitere Informationen finden Sie unter Google Cloud Projekt einrichten.
ZONE Eine Liste der unterstützten Zonen finden Sie im Dokument Cloud TPU-Regionen und ‑Zonen.
ACCELERATOR_TYPE Weitere Informationen finden Sie unter Beschleunigertypen.
RUNTIME_VERSION v2-alpha-tpuv6e
SERVICE_ACCOUNT Das ist die E-Mail-Adresse Ihres Dienstkontos. Sie finden sie unter Google Cloud Console -> IAM -> Dienstkonten.

Beispiel: tpu-service-account@your-project-ID.iam.gserviceaccount.com.com

NUM_SLICES Die Anzahl der zu erstellenden Schichten (nur für Mehrfachaufnahmen erforderlich).
QUEUED_RESOURCE_ID Die vom Nutzer zugewiesene Text-ID der anstehenden Ressourcenanfrage.
VALID_DURATION Die Dauer, für die die angeforderte Ressource gültig ist.
NETWORK_NAME Der Name eines sekundären Netzwerks, das verwendet werden soll.
NETWORK_FW_NAME Der Name einer sekundären Netzwerk-Firewall, die verwendet werden soll.

Netzwerkleistung optimieren

Für eine optimale Leistung sollten Sie ein Netzwerk mit einer MTU (maximale Übertragungseinheit) von 8.896 verwenden.

Standardmäßig bietet eine Virtual Private Cloud (VPC) nur eine MTU von 1.460 Byte, was zu einer suboptimalen Netzwerkleistung führt. Sie können die MTU eines VPC-Netzwerks auf einen beliebigen Wert zwischen 1.300 Byte und 8.896 Byte (einschließlich) festlegen. Gängige benutzerdefinierte MTU-Größen sind 1.500 Byte (Ethernet-Standard) oder 8.896 Byte (maximal möglich). Weitere Informationen finden Sie unter Gültige MTU-Größen für VPC-Netzwerke.

Weitere Informationen zum Ändern der MTU-Einstellung für ein vorhandenes oder Standardnetzwerk finden Sie unter MTU-Einstellung eines VPC-Netzwerks ändern.

Im folgenden Beispiel wird ein Netzwerk mit einer MTU von 8.896 erstellt.

export RESOURCE_NAME=your-resource-name
export NETWORK_NAME=${RESOURCE_NAME}-privatenetwork
export NETWORK_FW_NAME=${RESOURCE_NAME}-privatefirewall
gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT_ID} \
 --subnet-mode=auto --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network=${NETWORK_NAME} \
 --allow tcp,icmp,udp --project=${PROJECT_ID}

Mehrere NICs verwenden (Option für Multislice)

Die folgenden Umgebungsvariablen sind für ein sekundäres Subnetz erforderlich, wenn Sie eine Multislice-Umgebung verwenden.

export NETWORK_NAME_2=${RESOURCE_NAME}
export SUBNET_NAME_2=${RESOURCE_NAME}
export FIREWALL_RULE_NAME=${RESOURCE_NAME}
export ROUTER_NAME=${RESOURCE_NAME}-network-2
export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2
export REGION=your-region

Verwenden Sie die folgenden Befehle, um eine benutzerdefinierte IP-Weiterleitung für das Netzwerk und das Subnetz zu erstellen.

gcloud compute networks create ${NETWORK_NAME_2} --mtu=8896 \
   --bgp-routing-mode=regional --subnet-mode=custom --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_2} \
   --network=${NETWORK_NAME_2} \
   --range=10.10.0.0/18 --region=${REGION} \
   --project=${PROJECT_ID}

gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_2} --allow tcp,icmp,udp \
   --source-ranges 10.10.0.0/18 --project=${PROJECT_ID}

gcloud compute routers create ${ROUTER_NAME} \
  --project=${PROJECT_ID} \
  --network=${NETWORK_NAME_2} \
  --region=${REGION}

gcloud compute routers nats create ${NAT_CONFIG} \
  --router=${ROUTER_NAME} \
  --region=${REGION} \
  --auto-allocate-nat-external-ips \
  --nat-all-subnet-ip-ranges \
  --project=${PROJECT_ID} \
  --enable-logging

Nachdem Sie ein Multi-Netzwerk-Stück erstellt haben, können Sie prüfen, ob beide Netzwerkschnittstellenkarten (NICs) verwendet werden. Dazu richten Sie einen XPK-Cluster ein und fügen dem Befehl zum Erstellen einer XPK-Arbeitslast das Flag --command ifconfig hinzu.

Verwenden Sie den folgenden xpk workload-Befehl, um die Ausgabe des ifconfig-Befehls in den Google Cloud Console-Protokollen anzuzeigen, und prüfen Sie, ob sowohl eth0 als auch eth1 die mtu=8896 haben.

python3 xpk.py workload create \
   --cluster CLUSTER_NAME \
   {--base-docker-image maxtext_base_image | --docker-image your-cloud-image-name} \
   --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
   --tpu-type=${ACCELERATOR_TYPE} \
   --num-slices=${NUM_SLICES}  \
   --on-demand \
   --zone=${ZONE} \
   --project=${PROJECT_ID} \
   --command "ifconfig"

Wenn Sie Debug-Logs aktivieren oder Vertex AI TensorBoard verwenden möchten, fügen Sie dem Befehl die folgenden optionalen Argumente hinzu:

    --enable-debug-logs \
    --use-vertex-tensorboard

Prüfen Sie,ob sowohl eth0 als auch eth1 die mtu=8.896 haben. Sie können prüfen, ob die Multi-NIC-Umgebung ausgeführt wird, indem Sie dem Befehl zum Erstellen der XPK-Arbeitslast das Flag --command ifconfig hinzufügen. Prüfen Sie die Ausgabe dieser XPK-Arbeitslast in den Protokollen der Google Cloud Console und achten Sie darauf, dass sowohl eth0 als auch eth1 eine mtu von 8896 haben.

TCP-Einstellungen optimieren

Wenn Sie Ihre Cloud TPUs über die Benutzeroberfläche für in die Warteschlange gestellte Ressourcen erstellt haben, können Sie den folgenden Befehl ausführen, um die Netzwerkleistung zu verbessern, indem Sie die TCP-Empfangsbufferlimits erhöhen.

gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \
  --project "${PROJECT}" \
  --zone "${ZONE}" \
  --node=all \
  --command='sudo sh -c "echo \"4096 41943040 314572800\" > /proc/sys/net/ipv4/tcp_rmem"' \
  --worker=all

Mit in die Warteschlange gestellten Ressourcen bereitstellen

Sie können eine Cloud TPU v6e mit Ressourcen in der Warteschlange erstellen. Mit Ressourcen in der Warteschlange erhalten Sie Kapazität, sobald sie verfügbar ist. Sie können optional einen Start- und Endzeitpunkt für die Ausführung der Anfrage angeben. Weitere Informationen finden Sie unter In der Warteschlange befindliche Ressourcen verwalten.

Cloud TPUs der Version 6e mit GKE oder XPK bereitstellen

Wenn Sie GKE-Befehle mit v6e verwenden, können Sie Kubernetes-Befehle oder XPK verwenden, um Cloud TPUs bereitzustellen und Modelle zu trainieren oder bereitzustellen. Unter Cloud TPUs in GKE planen erfahren Sie, wie Sie Ihre Cloud TPU-Konfigurationen in GKE-Clustern planen. In den folgenden Abschnitten finden Sie Befehle zum Erstellen eines XPK-Clusters mit Unterstützung für eine einzelne NIC und mehrere NICs.

XPK-Cluster mit Unterstützung für eine einzelne NIC erstellen

export CLUSTER_NAME=xpk-cluster-name
export ZONE=us-east1-d
export PROJECT_ID=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME=${CLUSTER_NAME}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw
gcloud compute networks create ${NETWORK_NAME} \
   --mtu=8896 \
   --project=${PROJECT_ID} \
   --subnet-mode=auto \
   --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} \
   --network=${NETWORK_NAME} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}"
python3 xpk.py cluster create --cluster=${CLUSTER_NAME} \
   --cluster-cpu-machine-type=n1-standard-8 \
   --num-slices=${NUM_SLICES} \
   --tpu-type=${TPU_TYPE} \
   --zone=${ZONE}  \
   --project=${PROJECT_ID} \
   --on-demand \
   --custom-cluster-arguments="${CLUSTER_ARGUMENTS}"  \
   --create-vertex-tensorboard

Beschreibung der Befehls-Flags

Variable Beschreibung
CLUSTER_NAME Der vom Nutzer zugewiesene Name für den XPK-Cluster.
PROJECT_ID Google Cloud ist der Projektname. Sie können ein vorhandenes Projekt verwenden oder ein neues erstellen. Weitere Informationen finden Sie unter Google Cloud Projekt einrichten.
ZONE Eine Liste der unterstützten Zonen finden Sie im Dokument Cloud TPU-Regionen und ‑Zonen.
TPU_TYPE Weitere Informationen finden Sie unter Beschleunigertypen.
NUM_SLICES Die Anzahl der Scheiben, die Sie erstellen möchten
CLUSTER_ARGUMENTS Das zu verwendende Netzwerk und Subnetzwerk.

Beispiel: --network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}

NUM_SLICES Die Anzahl der zu erstellenden Segmente.
NETWORK_NAME Der Name eines sekundären Netzwerks, das verwendet werden soll.
NETWORK_FW_NAME Der Name einer sekundären Netzwerk-Firewall, die verwendet werden soll.

XPK-Cluster mit Multi-NIC-Unterstützung erstellen

export CLUSTER_NAME=xpk-cluster-name
export ZONE=us-east1-d
export PROJECT_ID=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME_1=${CLUSTER_NAME}-mtu9k-1-${ZONE}
export SUBNET_NAME_1=${CLUSTER_NAME}-privatesubnet-1-${ZONE}
export NETWORK_FW_NAME_1=${NETWORK_NAME_1}-fw-1-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-1-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-1-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-1-${ZONE}
gcloud compute networks create ${NETWORK_NAME_1} \
   --mtu=8896 \
   --bgp-routing-mode=regional \
   --subnet-mode=custom \
   --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_1} \
   --network=${NETWORK_NAME_1} \
   --range=10.11.0.0/18 \
   --region=${REGION} \
   --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_1} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \
    --project=${PROJECT_ID} \
    --network=${NETWORK_NAME_1} \
    --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \
    --router=${ROUTER_NAME} \
    --region=${REGION} \
    --auto-allocate-nat-external-ips \
    --nat-all-subnet-ip-ranges \
    --project=${PROJECT_ID} \
    --enable-logging
# Secondary subnet for multi-nic experience.
# Need custom IP routing to be different from the first network's subnet.

export NETWORK_NAME_2=${CLUSTER_NAME}-privatenetwork-2-${ZONE}
export SUBNET_NAME_2=${CLUSTER_NAME}-privatesubnet-2-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-2-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-2-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-2-${ZONE}
gcloud compute networks create ${NETWORK_NAME_2} \
   --mtu=8896 \
   --bgp-routing-mode=regional \
   --subnet-mode=custom \
   --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_2} \
   --network=${NETWORK_NAME_2} \
   --range=10.10.0.0/18 \
   --region=${REGION} \
   --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_2} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \
     --project=${PROJECT_ID} \
     --network=${NETWORK_NAME_2} \
     --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \
     --router=${ROUTER_NAME} \
     --region=${REGION} \
     --auto-allocate-nat-external-ips \
     --nat-all-subnet-ip-ranges \
     --project=${PROJECT_ID} \
     --enable-logging
export CLUSTER_ARGUMENTS="--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking
--network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}"

export NODE_POOL_ARGUMENTS="--additional-node-network
network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}"
python3 ~/xpk/xpk.py cluster create \
    --cluster=${CLUSTER_NAME} \
    --num-slices=${NUM_SLICES} \
    --tpu-type=${TPU_TYPE} \
    --zone=${ZONE}  \
    --project=${PROJECT_ID} \
    --on-demand \
    --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \
    --custom-nodepool-arguments="${NODE_POOL_ARGUMENTS}" \
    --create-vertex-tensorboard

Beschreibung der Befehls-Flags

Variable Beschreibung
CLUSTER_NAME Der vom Nutzer zugewiesene Name für den XPK-Cluster.
PROJECT_ID Google Cloud ist der Projektname. Sie können ein vorhandenes Projekt verwenden oder ein neues erstellen. Weitere Informationen finden Sie unter Google Cloud Projekt einrichten.
ZONE Eine Liste der unterstützten Zonen finden Sie im Dokument Cloud TPU-Regionen und ‑Zonen.
TPU_TYPE Weitere Informationen finden Sie unter Beschleunigertypen.
NUM_SLICES Die Anzahl der Scheiben, die Sie erstellen möchten
CLUSTER_ARGUMENTS Das zu verwendende Netzwerk und Subnetzwerk.

Beispiel: --enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}

NODE_POOL_ARGUMENTS Zu verwendendes zusätzliches Knotennetzwerk.

Beispiel: --additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}

NUM_SLICES Die Anzahl der zu erstellenden Schichten (nur für Mehrfachaufnahmen erforderlich).
NETWORK_NAME Der Name eines sekundären Netzwerks, das verwendet werden soll.
NETWORK_FW_NAME Der Name einer sekundären Netzwerk-Firewall, die verwendet werden soll.

Framework einrichten

In diesem Abschnitt wird die allgemeine Einrichtung für das Training von ML-Modellen mit den Frameworks JAX, PyTorch oder TensorFlow beschrieben. Wenn Sie GKE verwenden, können Sie XPK- oder Kubernetes-Befehle für die Framework-Einrichtung verwenden.

Einrichtung für JAX

In diesem Abschnitt finden Sie eine Anleitung zum Ausführen von JAX-Arbeitslasten mit oder ohne XPK in GKE sowie zum Verwenden von Ressourcen in der Warteschlange.

JAX mit GKE einrichten

Einzelnes Slice auf einem einzelnen Host

Im folgenden Beispiel wird ein 2 × 2-Knotenpool mit einem einzelnen Host mithilfe einer Kubernetes-YAML-Datei eingerichtet.

apiVersion: v1
kind: Pod
metadata:
  name: tpu-pod-jax-v6e-a
spec:
  restartPolicy: Never
  nodeSelector:
    cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
    cloud.google.com/gke-tpu-topology: 2x2
  containers:
  - name: tpu-job
    image: python:3.10
    securityContext:
      privileged: true
    command:
    - bash
    - -c
    - |
      pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
      JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python3 -c 'import jax; print("Total TPU chips:", jax.device_count())'
    resources:
      requests:
        google.com/tpu: 4
      limits:
        google.com/tpu: 4

Nach dem erfolgreichen Abschluss sollte im GKE-Log die folgende Meldung angezeigt werden:

Total TPU chips: 4

Einzelnes Slice auf mehreren Hosts

Im folgenden Beispiel wird ein 4 × 4-Knotenpool mit mehreren Hosts mithilfe einer Kubernetes-YAML-Datei eingerichtet.

apiVersion: v1
kind: Service
metadata:
  name: headless-svc
spec:
  clusterIP: None
  selector:
    job-name: tpu-available-chips
---
apiVersion: batch/v1
kind: Job
metadata:
  name: tpu-available-chips
spec:
  backoffLimit: 0
  completions: 4
  parallelism: 4
  completionMode: Indexed
  template:
    spec:
      subdomain: headless-svc
      restartPolicy: Never
      nodeSelector:
        cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
        cloud.google.com/gke-tpu-topology: 4x4
      containers:
      - name: tpu-job
        image: python:3.10
        ports:
        - containerPort: 8471 # Default port using which TPU VMs communicate
        - containerPort: 8431 # Port to export TPU runtime metrics, if supported.
        securityContext:
          privileged: true
        command:
        - bash
        - -c
        - |
          pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
          JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
        resources:
          requests:
            google.com/tpu: 4
          limits:
            google.com/tpu: 4

Nach dem erfolgreichen Abschluss sollte im GKE-Log die folgende Meldung angezeigt werden:

Total TPU chips: 16

Multislice auf mehreren Hosts

Im folgenden Beispiel werden zwei 4 × 4-Multihost-Knotenpools mit einer Kubernetes-YAML-Datei eingerichtet.

Als Voraussetzung müssen Sie JobSet v0.2.3 oder höher installieren.

apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
  name: multislice-job
  annotations:
    alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
spec:
  failurePolicy:
    maxRestarts: 4
  replicatedJobs:
    - name: slice
      replicas: 2
      template:
        spec:
          parallelism: 4
          completions: 4
          backoffLimit: 0
          template:
            spec:
              hostNetwork: true
              dnsPolicy: ClusterFirstWithHostNet
              nodeSelector:
                cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                cloud.google.com/gke-tpu-topology: 4x4
              hostNetwork: true
              containers:
              - name: jax-tpu
                image: python:3.10
                ports:
                - containerPort: 8471
                - containerPort: 8080
                - containerPort: 8431
                securityContext:
                  privileged: true
                command:
                - bash
                - -c
                - |
                  pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                  JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
                resources:
                  limits:
                   google.com/tpu: 4
                  requests:
                   google.com/tpu: 4

Nach dem erfolgreichen Abschluss sollte im GKE-Log die folgende Meldung angezeigt werden:

Total TPU chips: 32

Weitere Informationen finden Sie in der GKE-Dokumentation unter Multi-Slice-Arbeitslast ausführen.

Aktivieren Sie für eine bessere Leistung hostNetwork.

Mehrere NICs

Damit Sie Multi-NIC in GKE nutzen können, muss das Kubernetes-Pod-Manifest zusätzliche Anmerkungen enthalten. Im Folgenden finden Sie ein Beispielmanifest für eine Arbeitslast mit mehreren NICs ohne TPU.

apiVersion: v1
kind: Pod
metadata:
  name: sample-netdevice-pod-1
  annotations:
    networking.gke.io/default-interface: 'eth0'
    networking.gke.io/interfaces: |
      [
        {"interfaceName":"eth0","network":"default"},
        {"interfaceName":"eth1","network":"netdevice-network"}
      ]
spec:
  containers:
  - name: sample-netdevice-pod
    image: busybox
    command: ["sleep", "infinity"]
    ports:
    - containerPort: 80
  restartPolicy: Always
  tolerations:
  - key: "google.com/tpu"
    operator: "Exists"
    effect: "NoSchedule"

Wenn Sie den Befehl exec verwenden, um eine Verbindung zum Kubernetes-Pod herzustellen, sollte die zusätzliche NIC mit dem folgenden Code angezeigt werden.

$ k exec --stdin --tty sample-netdevice-pod-1 -- /bin/sh
/ # ip a
1: lo: <LOOPBACK,UP,LOWER_UP> mtu 65536 qdisc noqueue qlen 1000
    link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00
    inet 127.0.0.1/8 scope host lo
       valid_lft forever preferred_lft forever
2: eth0@if11: <BROADCAST,MULTICAST,UP,LOWER_UP,M-DOWN> mtu 1460 qdisc noqueue
    link/ether da:be:12:67:d2:25 brd ff:ff:ff:ff:ff:ff
    inet 10.124.2.6/24 brd 10.124.2.255 scope global eth0
       valid_lft forever preferred_lft forever
3: eth1: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1460 qdisc mq qlen 1000
    link/ether 42:01:ac:18:00:04 brd ff:ff:ff:ff:ff:ff
    inet 172.24.0.4/32 scope global eth1
       valid_lft forever preferred_lft forever

JAX mit GKE und XPK einrichten

Informationen zum Einrichten von JAX mit GKE und XPK finden Sie in der README-Datei für XPK.

Informationen zum Einrichten und Ausführen von XPK mit MaxText finden Sie unter MaxText ausführen.

JAX mit in die Warteschlange gestellten Ressourcen einrichten

Sie können JAX mit dem Befehl gcloud alpha compute tpus tpu-vm ssh gleichzeitig auf allen Cloud TPU-VMs in Ihrem Slice oder Ihren Slices installieren. Fügen Sie für „Multislice“ das Flag --node=all hinzu.

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
 --zone ${ZONE} --worker=all \
 --command='pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

Mit dem folgenden Befehl können Sie prüfen, wie viele Cloud TPU-Kerne in Ihrem Slice verfügbar sind, und testen, ob alles richtig installiert ist:

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
   --zone ${ZONE} --worker=all  \
   --command='python3 -c "import jax; print(jax.device_count(), jax.local_device_count())"'

Die Ausgabe sieht in etwa so aus, wenn der Befehl auf einem v6e-16-Speicherblock ausgeführt wird:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16 4
16 4
16 4
16 4

jax.device_count() gibt die Gesamtzahl der Chips im angegebenen Slice an. jax.local_device_count() gibt die Anzahl der Chips an, auf die eine einzelne VM in diesem Slice zugreifen kann.

gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} --zone=${ZONE} --worker=all  \
   --command='git clone -b mlperf4.1 https://github.com/google/maxdiffusion.git &&
   cd maxdiffusion && git checkout 975fdb7dbddaa9a53ad72a421cdb487dcdc491a3 &&
   && pip install -r requirements.txt  && pip install . '

Fehlerbehebung bei der JAX-Einrichtung

Als allgemeinen Tipp können Sie die ausführliche Protokollierung in Ihrem GKE-Arbeitslastmanifest aktivieren. Senden Sie die Protokolle dann an den GKE-Support.

TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0

Fehlermeldungen

no endpoints available for service 'jobset-webhook-service'

Dieser Fehler bedeutet, dass der Jobsatz nicht richtig installiert wurde. Prüfen Sie, ob die Kubernetes-Pods der Bereitstellung „jobset-controller-manager“ ausgeführt werden. Weitere Informationen finden Sie in der Dokumentation zur Fehlerbehebung bei JobSets.

TPU initialization failed: Failed to connect

Die GKE-Knotenversion muss 1.30.4-gke.1348000 oder höher sein. GKE 1.31 wird nicht unterstützt.

Einrichtung für PyTorch

In diesem Abschnitt wird beschrieben, wie Sie PJRT in v6e mit PyTorch/XLA verwenden. Python 3.10 ist die empfohlene Python-Version.

PyTorch mit GKE und XPK einrichten

Sie können den folgenden Docker-Container mit XPK verwenden, in dem die PyTorch-Abhängigkeiten bereits installiert sind:

us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028

Führen Sie den folgenden Befehl aus, um eine XPK-Arbeitslast zu erstellen:

python3 xpk.py workload create \
    --cluster ${CLUSTER_NAME} \
    {--docker-image | --base-docker-image} us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028 \
    --workload ${USER} -xpk-${ACCELERATOR_TYPE} -${NUM_SLICES} \
    --tpu-type=${ACCELERATOR_TYPE} \
    --num-slices=${NUM_SLICES}  \
    --on-demand \
    --zone ${ZONE} \
    --project ${PROJECT_ID} \
    --enable-debug-logs \
    --command 'python3 -c "import torch; import torch_xla; import torch_xla.runtime as xr; print(xr.global_runtime_device_count())"'

Mit --base-docker-image wird ein neues Docker-Image erstellt, in das das aktuelle Arbeitsverzeichnis eingebunden ist.

PyTorch mit Ressourcen in der Warteschlange einrichten

Führen Sie die folgenden Schritte aus, um PyTorch mit Ressourcen in der Warteschlange zu installieren und ein kleines Script in v6e auszuführen.

Abhängigkeiten über SSH installieren, um auf die VMs zuzugreifen

Mit dem folgenden Befehl können Sie Abhängigkeiten auf allen Cloud TPU-VMs installieren. Fügen Sie für „Multislice“ das Flag --worker=all hinzu:

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='sudo apt install -y libopenblas-base pip3 && \
               pip install torch~=2.6.0 "torch_xla[tpu]~=2.6.0" -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'

Leistung von Modellen mit großen, häufigen Zuweisungen verbessern

Bei Modellen mit großen, häufigen Zuweisungen wird mit der tcmalloc-Funktion die Leistung im Vergleich zur Standardimplementierung der malloc-Funktion erheblich verbessert. Daher ist tcmalloc die Standardmalloc-Funktion, die auf Cloud TPU-VMs verwendet wird. Je nach Arbeitslast (z. B. bei DLRM mit sehr großen Zuweisungen für die Einbettungstabellen) kann die tcmalloc-Funktion jedoch zu einer Verlangsamung führen. In diesem Fall können Sie versuchen, die folgende Variable mit der Standardfunktion malloc zurückzusetzen:

unset LD_PRELOAD

Mit einem Python-Script eine Berechnung auf einer v6e-VM ausführen

Mit dem folgenden Befehl wird ein Script ausgeführt, das zwei Tensoren erstellt, addiert und das Ergebnis ausgibt.

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
   --project ${PROJECT_ID} \
   --zone ${ZONE} --worker all --command='
   unset LD_PRELOAD
   python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"
'

Dadurch wird eine Ausgabe generiert, die etwa so aussieht:

SSH: Attempting to connect to worker 0...
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
xla:0
tensor([[ 0.3355, -1.4628, -3.2610],
        [-1.4656,  0.3196, -2.8766],
        [ 0.8668, -1.5060,  0.7125]], device='xla:0')

Einrichtung für TensorFlow

Sie können die Cloud TPU-Laufzeit mit der v6e-kompatiblen TensorFlow-Version zurücksetzen. Führen Sie dazu die folgenden Befehle aus:

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
    --zone  ${ZONE} --worker=all --command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} \
    --zone ${ZONE} --worker=all --command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'

Verwenden Sie SSH, um auf worker-0 zuzugreifen:

$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
     --zone ${ZONE}

Installieren Sie TensorFlow auf worker-0:

sudo apt install -y libopenblas-base
pip install cloud-tpu-client
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310
pip install cloud-tpu-client

pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

Exportieren Sie die Umgebungsvariable TPU_NAME:

export TPU_NAME=v6e-16

Mit dem folgenden Python-Script können Sie prüfen, wie viele Cloud TPU-Kerne in Ihrem Slice verfügbar sind, und testen, ob alles korrekt installiert ist:

import TensorFlow as tf
print("TensorFlow version " + tf.__version__)

@tf.function
  def add_fn(x,y):
  z = x + y
  return z

  cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
  tf.config.experimental_connect_to_cluster(cluster_resolver)
  tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
  strategy = tf.distribute.TPUStrategy(cluster_resolver)

  x = tf.constant(1.)
  y = tf.constant(1.)
  z = strategy.run(add_fn, args=(x,y))
  print(z)

Die Ausgabe sieht in etwa so aus, wenn der Befehl auf einem v6e-16-Speicherblock ausgeführt wird:

PerReplica:{
  0: tf.Tensor(2.0, shape=(), dtype=float32),
  1: tf.Tensor(2.0, shape=(), dtype=float32),
  2: tf.Tensor(2.0, shape=(), dtype=float32),
  3: tf.Tensor(2.0, shape=(), dtype=float32),
  4: tf.Tensor(2.0, shape=(), dtype=float32),
  5: tf.Tensor(2.0, shape=(), dtype=float32),
  6: tf.Tensor(2.0, shape=(), dtype=float32),
  7: tf.Tensor(2.0, shape=(), dtype=float32)
}

v6e mit SkyPilot

Sie können Cloud TPU v6e mit SkyPilot verwenden. So fügen Sie SkyPilot v6e-bezogene Standort- und Preisinformationen hinzu:

  1. Fügen Sie am Ende der Datei ~/.sky/catalogs/v5/gcp/vms.csv Folgendes hinzu:

    ,,,tpu-v6e-1,1,tpu-v6e-1,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-1,1,tpu-v6e-1,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-1,1,tpu-v6e-1,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,us-east5,us-east5-b,0,0
    
  2. Geben Sie die folgenden Ressourcen in einer YAML-Datei an:

    # tpu_v6.yaml
    resources:
      accelerators: tpu-v6e-16 # Fill in the accelerator type you want to use
      accelerator_args:
        runtime_version: v2-alpha-tpuv6e # Official suggested runtime
    
  3. So starten Sie einen Cluster mit Cloud TPU v6e:

       sky launch tpu_v6.yaml -c tpu_v6
    
  4. Stellen Sie eine SSH-Verbindung zur Cloud TPU v6e her: ssh tpu_v6

Anleitungen für Inferenz

In den folgenden Anleitungen erfahren Sie, wie Sie die Inferenz auf Cloud TPU v6e ausführen:

Trainingsbeispiele

In den folgenden Abschnitten finden Sie Beispiele für das Training von MaxText-, MaxDiffusion- und PyTorch-Modellen auf Cloud TPU v6e.

MaxText- und MaxDiffusion-Training auf einer v6e-Cloud TPU-VM

In den folgenden Abschnitten wird der Trainingszyklus der Modelle MaxText und MaxDiffusion beschrieben.

Im Allgemeinen sind dies die allgemeinen Schritte:

  1. Erstellen Sie das Basis-Image der Arbeitslast.
  2. Führen Sie Ihre Arbeitslast mit XPK aus.
    1. Erstellen Sie den Trainingsbefehl für die Arbeitslast.
    2. Stellen Sie die Arbeitslast bereit.
  3. Arbeitslast verfolgen und Messwerte ansehen
  4. Löschen Sie die XPK-Arbeitslast, wenn sie nicht benötigt wird.
  5. Löschen Sie den XPK-Cluster, wenn er nicht mehr benötigt wird.

Basis-Image erstellen

Installieren Sie MaxText oder MaxDiffusion und erstellen Sie das Docker-Image:

  1. Klonen Sie das gewünschte Repository und wechseln Sie in das Verzeichnis für das Repository:

    MaxText:

    git clone https://github.com/google/maxtext.git && cd maxtext
    

    MaxDiffusion:

    git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion
    
  2. Konfigurieren Sie Docker so, dass die Google Cloud CLI verwendet wird:

    gcloud auth configure-docker
    
  3. Erstellen Sie das Docker-Image mit dem folgenden Befehl oder mit JAX Stable Stack. Weitere Informationen zum JAX Stable Stack finden Sie unter Docker-Image mit JAX Stable Stack erstellen.

    bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.35
    
  4. Wenn Sie die Arbeitslast von einem Computer aus starten, auf dem das Image nicht lokal erstellt wurde, laden Sie das Image hoch:

    bash docker_upload_runner.sh CLOUD_IMAGE_NAME=${USER}_runner
    

Arbeitslast mit XPK ausführen

  1. Legen Sie die folgenden Umgebungsvariablen fest, wenn Sie nicht die Standardwerte von MaxText oder MaxDiffusion verwenden:

    export BASE_OUTPUT_DIR=gs://YOUR_BUCKET
    export PER_DEVICE_BATCH_SIZE=2
    export NUM_STEPS=30
    export MAX_TARGET_LENGTH=8192
  2. Erstellen Sie das Modellscript. Dieses Script wird in einem späteren Schritt als Trainingsbefehl kopiert.

    Führen Sie das Modellskript noch nicht aus.

    MaxText

    MaxText ist ein leistungsstarkes, hoch skalierbares Open-Source-LLM, das in reiner Python- und JAX-Programmierung geschrieben wurde und für das Training und die Inferenz auf Google Cloud TPUs und GPUs ausgerichtet ist.

    JAX_PLATFORMS=tpu,cpu \
    ENABLE_PJRT_COMPATIBILITY=true \
    TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \
    TPU_SLICE_BUILDER_DUMP_ICI=true && \
    python /deps/MaxText/train.py /deps/MaxText/configs/base.yml \
            base_output_directory=${BASE_OUTPUT_DIR} \
            dataset_type=synthetic \
            per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
            enable_checkpointing=false \
            gcs_metrics=true \
            profiler=xplane \
            skip_first_n_steps_for_profiler=5 \
            steps=${NUM_STEPS}  # attention='dot_product'"
    

    Gemma2

    Gemma ist eine Familie von LLMs mit offenen Gewichten, die von Google DeepMind entwickelt wurden und auf der Gemini-Forschung und -Technologie basieren.

    python3 MaxText/train.py MaxText/configs/base.yml \
        model_name=gemma2-27b \
        run_name=gemma2-27b-run \
        base_output_directory=${BASE_OUTPUT_DIR} \
        max_target_length=${MAX_TARGET_LENGTH} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        steps=${NUM_STEPS} \
        enable_checkpointing=false \
        use_iota_embed=true \
        gcs_metrics=true \
        dataset_type=synthetic \
        profiler=xplane \
        attention=flash
    

    Mixtral 8x7b

    Mixtral ist ein hochmodernes KI-Modell, das von Mistral AI entwickelt wurde und eine sparse MoE-Architektur (Mixture of Experts) verwendet.

    python3 MaxText/train.py MaxText/configs/base.yml \
        base_output_directory=${BASE_OUTPUT_DIR} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        model_name=mixtral-8x7b \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        tokenizer_path=assets/tokenizer.mistral-v1 \
        attention=flash \
        dtype=bfloat16 \
        dataset_type=synthetic \
        profiler=xplane
    

    Llama3-8b

    Llama ist eine Familie von LLMs mit offenen Gewichten, die von Meta entwickelt wurden.

    python3 MaxText/train.py MaxText/configs/base.yml \
        model_name=llama3-8b \
        base_output_directory=${BASE_OUTPUT_DIR} \
        dataset_type=synthetic \
        tokenizer_path=assets/tokenizer_llama3.tiktoken \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} # set to 4 \
        gcs_metrics=true \
        profiler=xplane \
        skip_first_n_steps_for_profiler=5 \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        attention=flash
    

    MaxDiffusion

    MaxDiffusion ist eine Sammlung von Referenzimplementierungen verschiedener latenter Diffusionsmodelle, die in reiner Python- und JAX-Programmierung geschrieben wurden und auf XLA-Geräten wie Cloud TPUs und GPUs ausgeführt werden. Stable Diffusion ist ein latentes Text-zu-Bild-Modell, mit dem fotorealistische Bilder aus beliebigen Texteingaben generiert werden.

    Sie müssen einen bestimmten Git-Branch installieren, um MaxDiffusion auszuführen, wie im folgenden git checkout-Befehl gezeigt.

    git clone https://github.com/google/maxdiffusion.git
    && cd maxdiffusion
    && git checkout e712c9fc4cca764b0930067b6e33daae2433abf0
    && pip install -r requirements.txt
    && pip install .
    

    Trainingsskript:

        cd maxdiffusion && OUT_DIR=${BASE_OUTPUT_DIR} \
        python src/maxdiffusion/train_sdxl.py \
        src/maxdiffusion/configs/base_xl.yml \
        revision=refs/pr/95 \
        activations_dtype=bfloat16 \
        weights_dtype=bfloat16 \
        resolution=1024 \
        per_device_batch_size=1 \
        output_dir=${OUT_DIR}  \
        jax_cache_dir=${OUT_DIR}/cache_dir/ \
        max_train_steps=200 \
        attention=flash run_name=sdxl-ddp-v6e
        
  3. Exportieren Sie die folgenden Variablen:

    export CLUSTER_NAME=CLUSTER_NAME
    export ACCELERATOR_TYPE=ACCELERATOR_TYPE
    export NUM_SLICES=NUM_SLICES
    export YOUR_MODEL_SCRIPT=YOUR_MODEL_SCRIPT

    Beschreibungen von Umgebungsvariablen

    Variable Beschreibung
    CLUSTER_NAME Der Name Ihres XPK-Clusters.
    ACCELERATOR_TYPE Weitere Informationen finden Sie unter Beschleunigertypen.
    NUM_SLICES Die Anzahl der TPU-Slices.
    YOUR_MODEL_SCRIPT Das Modellskript, das als Trainingsbefehl ausgeführt werden soll.
  4. Führen Sie das Modell mit dem Skript aus, das Sie im vorherigen Schritt erstellt haben. Sie müssen entweder das Flag --base-docker-image angeben, um das MaxText-Basisbild zu verwenden, oder das Flag --docker-image und das gewünschte Bild.

    Optional: Sie können das Debug-Logging aktivieren, indem Sie das Flag --enable-debug-logs einfügen. Weitere Informationen finden Sie unter JAX in MaxText debuggen.

    Optional: Sie können einen Vertex AI-Test erstellen, um Daten in Vertex AI TensorBoard hochzuladen. Fügen Sie dazu das Flag --use-vertex-tensorboard hinzu. Weitere Informationen finden Sie unter JAX auf MaxText mit Vertex AI überwachen.

    python3 xpk.py workload create \
        --cluster ${CLUSTER_NAME} \
        {--base-docker-image maxtext_base_image|--docker-image ${CLOUD_IMAGE_NAME}} \
        --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
        --tpu-type=${ACCELERATOR_TYPE} \
        --num-slices=${NUM_SLICES}  \
        --on-demand \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        [--enable-debug-logs] \
        [--use-vertex-tensorboard] \
        --command=${YOUR_MODEL_SCRIPT}

    Die Ausgabe enthält einen Link, über den Sie Ihre Arbeitslast verfolgen können. Öffnen Sie den Link und klicken Sie auf den Tab Protokolle, um Ihre Arbeitslast in Echtzeit zu verfolgen.

JAX in MaxText debuggen

Verwenden Sie zusätzliche XPK-Befehle, um zu ermitteln, warum der Cluster oder die Arbeitslast nicht ausgeführt wird.

  • XPK-Arbeitslastliste
  • XPK-Inspektor
  • Aktivieren Sie beim Erstellen der XPK-Arbeitslast mit dem Flag --enable-debug-logs ausführliche Protokolle in Ihren Arbeitslastprotokollen.

JAX auf MaxText mit Vertex AI überwachen

Skalar- und Profildaten über das verwaltete TensorBoard von Vertex AI aufrufen

  1. Erhöhen Sie die Anzahl der Resource Management (CRUD)-Anfragen für die von Ihnen verwendete Zone von 600 auf 5.000. Bei kleinen Arbeitslasten mit weniger als 16 VMs ist das möglicherweise kein Problem.
  2. Installieren Sie Abhängigkeiten wie cloud-accelerator-diagnostics für Vertex AI:

    # xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI
    cd ~/xpk
    pip install .
  3. Erstellen Sie Ihren XPK-Cluster mit dem Flag --create-vertex-tensorboard, wie unter Vertex AI TensorBoard erstellen beschrieben. Sie können diesen Befehl auch auf vorhandenen Clustern ausführen.

  4. Erstellen Sie Ihren Vertex AI-Test, wenn Sie Ihre XPK-Arbeitslast mit dem Flag --use-vertex-tensorboard und dem optionalen Flag --experiment-name ausführen. Eine vollständige Liste der Schritte finden Sie unter Vertex AI-Test erstellen, um Daten in Vertex AI TensorBoard hochzuladen.

Die Protokolle enthalten einen Link zu einem Vertex AI TensorBoard, ähnlich wie hier:

View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name

Sie finden den Link zum Vertex AI TensorBoard auch in der Google Cloud Console. Rufen Sie in der Google Cloud Console Vertex AI Experiments auf. Wählen Sie im Drop-down-Menü die gewünschte Region aus.

Das TensorBoard-Verzeichnis wird auch in den Cloud Storage-Bucket geschrieben, den Sie mit ${BASE_OUTPUT_DIR} angegeben haben.

XPK-Arbeitslasten löschen

Verwenden Sie den Befehl xpk workload delete, um eine oder mehrere Arbeitslasten basierend auf dem Jobpräfix oder dem Jobstatus zu löschen. Dieser Befehl kann nützlich sein, wenn Sie XPK-Arbeitslasten gesendet haben, die nicht mehr ausgeführt werden müssen, oder wenn Jobs in der Warteschlange hängen.

XPK-Cluster löschen

Verwenden Sie den Befehl xpk cluster delete, um einen Cluster zu löschen:

python3 xpk.py cluster delete --cluster ${CLUSTER_NAME} \
--zone=${ZONE} --project=${PROJECT_ID}

Llama- und PyTorch/XLA-Training auf einer v6e-Cloud TPU-VM

In dieser Anleitung wird beschrieben, wie Sie Llama-Modelle mit PyTorch/XLA auf Cloud TPU v6e mit dem WikiText-Dataset trainieren.

Zugriff auf Hugging Face und das Llama 3-Modell erhalten

Sie benötigen ein Hugging Face-Nutzerzugriffstoken, um dieses Tutorial auszuführen. Informationen zum Erstellen und Verwenden von Nutzerzugriffstokens finden Sie in der Hugging Face-Dokumentation zu Nutzerzugriffstokens.

Außerdem benötigen Sie die Berechtigung, auf das Llama 3 8B-Modell auf Hugging Face zuzugreifen. Wenn Sie Zugriff erhalten möchten, rufen Sie das Meta-Llama-3-8B-Modell auf Hugging Face auf und beantragen Sie Zugriff.

Cloud TPU-VM erstellen

Erstellen Sie eine Cloud TPU v6e mit 8 Chips, um die Anleitung auszuführen.

  1. Richten Sie Umgebungsvariablen ein:

    export ACCELERATOR_TYPE=v6e-8
    export VERSION=v2-alpha-tpuv6e
    export TPU_NAME=$USER-$ACCELERATOR_TYPE
    export PROJECT_ID=your-project-id
    export ZONE=us-east1-d
  2. So erstellen Sie eine Cloud TPU-VM:

    gcloud alpha compute tpus tpu-vm create ${TPU_NAME} --version=${VERSION} \
        --accelerator-type=${ACCELERATOR_TYPE} \
        --zone=${ZONE} \
        --project=${PROJECT_ID}

Installation

Installieren Sie den pytorch-tpu/transformers-Fork der Hugging Face-Transformer und die Abhängigkeiten. Diese Anleitung wurde mit den folgenden Abhängigkeitsversionen getestet, die in diesem Beispiel verwendet werden:

  • torch: kompatibel mit 2.5.0
  • torch_xla[tpu]: kompatibel mit 2.5.0
  • jax: 0.4.33
  • jaxlib: 0.4.33
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} --zone ${ZONE} \
    --worker=all --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
    cd transformers
    sudo pip3 install -e .
    pip3 install datasets
    pip3 install evaluate
    pip3 install scikit-learn
    pip3 install accelerate
    pip install torch~=2.6.0 torch_xla[tpu]~=2.6.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
    pip install jax==0.4.38 jaxlib==0.4.38 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'

Modellkonfigurationen einrichten

Im Trainingsbefehl im nächsten Abschnitt, Modell ausführen, werden zwei JSON-Konfigurationsdateien verwendet, um die Modellparameter und die FSDP-Konfiguration (Fully Sharded Data Parallel) zu definieren. Das FSDP-Sharding wird verwendet, damit die Modellgewichte während des Trainings zu einer größeren Batchgröße passen. Beim Training mit kleineren Modellen reicht es möglicherweise aus, Datenparallelität zu verwenden und die Gewichte auf jedem Gerät zu replizieren. Weitere Informationen zum Sharding von Tensoren auf Geräten in PyTorch/XLA finden Sie im PyTorch/XLA SPMD-Nutzerhandbuch.

  1. Erstellen Sie die Konfigurationsdatei für die Modellparameter. Im Folgenden finden Sie die Modellparameterkonfiguration für Llama3-8B. Für andere Modelle finden Sie die Konfiguration auf Hugging Face. Siehe z. B. die Llama2-7B-Konfiguration.

    cat > llama-config.json << EOF
    {
        "architectures": [
            "LlamaForCausalLM"
        ],
        "attention_bias": false,
        "attention_dropout": 0.0,
        "bos_token_id": 128000,
        "eos_token_id": 128001,
        "hidden_act": "silu",
        "hidden_size": 4096,
        "initializer_range": 0.02,
        "intermediate_size": 14336,
        "max_position_embeddings": 8192,
        "model_type": "llama",
        "num_attention_heads": 32,
        "num_hidden_layers": 32,
        "num_key_value_heads": 8,
        "pretraining_tp": 1,
        "rms_norm_eps": 1e-05,
        "rope_scaling": null,
        "rope_theta": 500000.0,
        "tie_word_embeddings": false,
        "torch_dtype": "bfloat16",
        "transformers_version": "4.40.0.dev0",
        "use_cache": false,
        "vocab_size": 128256
    }
    EOF
    
  2. Erstellen Sie die FSDP-Konfigurationsdatei:

    cat > fsdp-config.json << EOF
    {
        "fsdp_transformer_layer_cls_to_wrap": [
            "LlamaDecoderLayer"
        ],
        "xla": true,
        "xla_fsdp_v2": true,
        "xla_fsdp_grad_ckpt": true
    }
    EOF
    

    Weitere Informationen zu FSDP finden Sie unter FSDPv2.

  3. Laden Sie die Konfigurationsdateien mit dem folgenden Befehl auf Ihre Cloud TPU-VMs hoch:

    gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json ${TPU_NAME}:. \
        --worker=all \
        --project=${PROJECT_ID} \
        --zone=${ZONE}

Modell ausführen

Führen Sie mit den Konfigurationsdateien, die Sie im vorherigen Abschnitt erstellt haben, das run_clm.py-Script aus, um das Llama 3 8B-Modell mit dem WikiText-Dataset zu trainieren. Die Ausführung des Trainingsscripts auf einer Cloud TPU v6e-8 dauert etwa 10 Minuten.

  1. Melden Sie sich mit dem folgenden Befehl auf Ihrer Cloud TPU bei Hugging Face an:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT} \
        --zone ${ZONE} \
        --worker=all \
        --command='
        pip3 install "huggingface_hub[cli]"
        huggingface-cli login --token HUGGING_FACE_TOKEN'
  2. Modelltraining ausführen:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT} \
        --zone ${ZONE} \
        --worker=all \
        --command='
        export PJRT_DEVICE=TPU
        export XLA_USE_SPMD=1
        export ENABLE_PJRT_COMPATIBILITY=true
            # Optional variables for debugging:
        export XLA_IR_DEBUG=1
        export XLA_HLO_DEBUG=1
        export PROFILE_EPOCH=0
        export PROFILE_STEP=3
        export PROFILE_DURATION_MS=100000
            # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path
        export PROFILE_LOGDIR=PROFILE_PATH
        python3 transformers/examples/pytorch/language-modeling/run_clm.py \
        --dataset_name wikitext \
        --dataset_config_name wikitext-2-raw-v1 \
        --per_device_train_batch_size 16 \
        --do_train \
        --output_dir /home/$USER/tmp/test-clm \
        --overwrite_output_dir \
        --config_name /home/$USER/llama-config.json \
        --cache_dir /home/$USER/cache \
        --tokenizer_name meta-llama/Meta-Llama-3-8B \
        --block_size 8192 \
        --optim adafactor \
        --save_strategy no \
        --logging_strategy no \
        --fsdp "full_shard" \
        --fsdp_config /home/$USER/fsdp-config.json \
        --torch_dtype bfloat16 \
        --dataloader_drop_last yes \
        --flash_attention \
        --max_steps 20'

Fehlerbehebung bei PyTorch/XLA

Wenn Sie die optionalen Variablen für das Debuggen im vorherigen Abschnitt festgelegt haben, wird das Profil für das Modell an dem Speicherort gespeichert, der in der Variablen PROFILE_LOGDIR angegeben ist. Sie können die dort gespeicherte xplane.pb-Datei extrahieren und mit tensorboard die Profile in Ihrem Browser anzeigen lassen. Folgen Sie dazu der TensorBoard-Anleitung. Wenn PyTorch/XLA nicht wie erwartet funktioniert, lesen Sie den Leitfaden zur Fehlerbehebung. Dort finden Sie Vorschläge zum Debuggen, Profilieren und Optimieren Ihres Modells.

DLRM DCN v2-Training auf v6e

In dieser Anleitung wird beschrieben, wie Sie das DLRM DCN v2-Modell auf Cloud TPU v6e trainieren. Sie müssen eine TPU v6e mit 64, 128 oder 256 Chips bereitstellen.

Wenn Sie eine TPU mit mehreren Hosts verwenden, setzen Sie tpu-runtime mit der entsprechenden TensorFlow-Version zurück. Führen Sie dazu die folgenden Befehle aus. Wenn Sie eine TPU mit einem einzelnen Host verwenden, müssen Sie die folgenden beiden Befehle nicht ausführen.

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  \
    --project ${PROJECT_ID} \
    --zone ${ZONE} \
    --worker=all \
    --command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  \
    --project ${PROJECT_ID} \
    --zone  ${ZONE} \
    --worker=all \
    --command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'

Über SSH eine Verbindung zu worker-0 herstellen

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone ${ZONE} --project ${PROJECT_ID}

Cloud TPU-Namen festlegen

export TPU_NAME=your-tpu-name

DLRM v2 ausführen

Kopieren Sie das folgende Code-Snippet in eine Datei mit dem Namen script.sh:

pip install --user setuptools==65.5.0

pip install cloud-tpu-client

pip install gin-config && pip install tensorflow-datasets && pip install tf-keras-nightly --no-deps

pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl -f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

git clone https://github.com/tensorflow/recommenders.git
git clone https://github.com/tensorflow/models.git

export PYTHONPATH=~/recommenders/:~/models/
export TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true --tf_xla_sparse_core_disable_table_stacking=true --tf_mlir_enable_convert_control_to_data_outputs_pass=true --tf_mlir_enable_merge_control_flow_pass=true'

TF_USE_LEGACY_KERAS=1 TPU_LOAD_LIBRARY=0 python3 ./models/official/recommendation/ranking/train.py  --mode=train     --model_dir=gs://ptxla-debug/tf/sc/dlrm/runs/2/ --params_override="
runtime:
  distribution_strategy: tpu
  mixed_precision_dtype: 'mixed_bfloat16'
task:
  use_synthetic_data: false
  use_tf_record_reader: true
  train_data:
    input_path: 'gs://trillium-datasets/criteo/train/day_*/*'
    global_batch_size: 16384
    use_cached_data: true
  validation_data:
    input_path: 'gs://trillium-datasets/criteo/eval/day_*/*'
    global_batch_size: 16384
    use_cached_data: true
  model:
    num_dense_features: 13
    bottom_mlp: [512, 256, 128]
    embedding_dim: 128
    interaction: 'multi_layer_dcn'
    dcn_num_layers: 3
    dcn_low_rank_dim: 512
    size_threshold: 8000
    top_mlp: [1024, 1024, 512, 256, 1]
    use_multi_hot: true
    concat_dense: false
    dcn_use_bias: true
    vocab_sizes: [40000000,39060,17295,7424,20265,3,7122,1543,63,40000000,3067956,405282,10,2209,11938,155,4,976,14,40000000,40000000,40000000,590152,12973,108,36]
    multi_hot_sizes: [3,2,1,2,6,1,1,1,1,7,3,8,1,6,9,5,1,1,1,12,100,27,10,3,1,1]
    max_ids_per_chip_per_sample: 128
    max_ids_per_table: [280, 128, 64, 272, 432, 624, 64, 104, 368, 352, 288, 328, 304, 576, 336, 368, 312, 392, 408, 552, 2880, 1248, 720, 112, 320, 256]
    max_unique_ids_per_table: [104, 56, 40, 32, 72, 32, 40, 32, 32, 144, 64, 192, 32, 40, 136, 32, 32, 32, 32, 240, 1352, 432, 120, 80, 32, 32]
    use_partial_tpu_embedding: false
    size_threshold: 0
    initialize_tables_on_host: true
trainer:
  train_steps: 10000
  validation_interval: 1000
  validation_steps: 660
  summary_interval: 1000
  steps_per_loop: 1000
  checkpoint_interval: 0
  optimizer_config:
    embedding_optimizer: 'Adagrad'
    dense_optimizer: 'Adagrad'
    lr_config:
      decay_exp: 2
      decay_start_steps: 70000
      decay_steps: 30000
      learning_rate: 0.025
      warmup_steps: 0
    dense_sgd_config:
      decay_exp: 2
      decay_start_steps: 70000
      decay_steps: 30000
      learning_rate: 0.00025
      warmup_steps: 8000
  train_tf_function: true
  train_tf_while_loop: true
  eval_tf_while_loop: true
  use_orbit: true
  pipeline_sparse_and_dense_execution: true"

Wenn Sie TensorFlow in GKE ausführen, installieren Sie das TensorFlow Cloud TPU-Wheel und libtpu mit dem folgenden Befehl:

pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

Legen Sie die folgenden Flags fest, die zum Ausführen von Empfehlungsarbeitslasten (z. B. DLRM DCN) erforderlich sind:

ENV TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true \
--tf_mlir_enable_tpu_variable_runtime_reformatting_pass=false \
--tf_mlir_enable_convert_control_to_data_outputs_pass=true \
--tf_mlir_enable_merge_control_flow_pass=true --tf_xla_disable_full_embedding_pipelining=true' \
ENV LIBTPU_INIT_ARGS="--xla_sc_splitting_along_feature_dimension=auto \
--copy_with_dynamic_shape_op_output_pjrt_buffer=true"

Führen Sie script.sh aus.

chmod +x script.sh
./script.sh

Benchmarking-Ergebnisse

Der folgende Abschnitt enthält Benchmarking-Ergebnisse für DLRM DCN v2 und MaxDiffusion auf v6e.

DLRM DCN v2

Das DLRM DCN v2-Trainingsskript wurde in verschiedenen Skalen ausgeführt. Die Durchlaufraten finden Sie in der folgenden Tabelle.

v6e-64 v6e-128 v6e-256
Trainingsschritte 7000 7000 7000
Globale Batchgröße 131.072 262144 524.288
Durchsatz (Beispiele/Sek.) 2975334 5111808 10066329

MaxDiffusion

Wir haben das Trainingsskript für MaxDiffusion auf einer v6e-4, einer v6e-16 und zwei v6e-16 ausgeführt. Die Durchlaufraten finden Sie in der folgenden Tabelle.

v6e-4 v6e-16 Zwei v6e-16
Trainingsschritte 0.069 0,073 0,13
Globale Batchgröße 8 32 64
Durchsatz (Beispiele/Sek.) 115,9 438,4 492,3