Einführung in Trillium (v6e)

In dieser Dokumentation, der TPU API und den Protokollen wird mit „v6e“ auf Trillium verwiesen. „v6e“ steht für die sechste Generation von TPUs von Google.

Mit 256 Chips pro Pod weist die v6e-Architektur viele Ähnlichkeiten mit v5e auf. Dieses System ist für das Training, die Feinabstimmung und die Bereitstellung von Transformer-, Text-zu-Bild- und Convolutional Neural Network-Modellen (CNN) optimiert.

Weitere Informationen zur Systemarchitektur und zu den Konfigurationen von TPU v6e finden Sie unter TPU v6e.

In diesem Einführungsdokument werden die Prozesse für das Modelltraining und die Bereitstellung mit den Frameworks JAX oder PyTorch beschrieben. Mit jedem Framework können Sie TPUs mit in die Warteschlange eingereihten Ressourcen oder GKE bereitstellen. Die GKE-Einrichtung kann mit XPK- oder GKE-Befehlen erfolgen.

Allgemeine Vorgehensweise zum Trainieren oder Bereitstellen eines Modells mit v6e

  1. Google Cloud Projekt vorbereiten
  2. Sichere Kapazität
  3. Cloud TPU-Umgebung bereitstellen
  4. Eine Trainings- oder Inferenz-Arbeitslast für ein Modell ausführen

Google Cloud -Projekt vorbereiten

Bevor Sie Cloud TPU verwenden können, müssen Sie Folgendes tun:

  • Google Cloud Konto und Projekt mit aktivierter Abrechnung erstellen
  • Google Cloud CLI-Alphakomponenten installieren
  • Cloud TPU API aktivieren
  • Cloud TPU-Dienst-Agent erstellen
  • Cloud TPU-Dienstkonto erstellen und Berechtigungen erteilen

Weitere Informationen finden Sie unter Cloud TPU-Umgebung einrichten.

Kapazität sichern

Wenden Sie sich an den Google Cloud -Support, um ein Cloud TPU v6e-Kontingent anzufordern und Fragen zur Kapazität zu klären.

Cloud TPU-Umgebung bereitstellen

v6e Cloud TPU können mit GKE, mit GKE und XPK (ein Wrapper-CLI-Tool für GKE) oder als in die Warteschlange eingereihte Ressourcen bereitgestellt und verwaltet werden.

Vorbereitung

  • Prüfen Sie, ob Ihr Projekt über genügend TPUS_PER_TPU_FAMILY-Kontingent verfügt. Damit wird die maximale Anzahl von Chips angegeben, auf die Sie in Ihrem Google Cloud-Projekt zugreifen können.
  • v6e wurde mit der folgenden Konfiguration getestet:
    • Python 3.10 oder höher
    • Nightly-Softwareversionen:
      • Nächtlicher JAX-Wert 0.4.32.dev20240912
      • Nächtliche LibTPU-Version 0.1.dev20240912+nightly
    • Stabile Softwareversionen:
      • JAX + JAX Lib v0.4.37
  • Prüfen Sie, ob Ihr Projekt genügend Kontingente für Folgendes hat:

    • Cloud TPU-VM-Kontingent
    • Kontingent für IP-Adressen
    • Kontingent für Hyperdisk Balanced und alle anderen Laufwerkstypen, die Sie verwenden möchten

  • Wenn Sie GKE mit XPK verwenden, finden Sie unter Cloud Console-Berechtigungen für das Nutzer- oder Dienstkonto die Berechtigungen, die zum Ausführen von XPK erforderlich sind.

Umgebungsvariablen erstellen

Erstellen Sie in einer Cloud Shell die folgenden Umgebungsvariablen:

export NODE_ID=your-tpu-name
export PROJECT_ID=your-project-id
export ACCELERATOR_TYPE=v6e-16
export ZONE=us-east1-d
export RUNTIME_VERSION=v2-alpha-tpuv6e
export SERVICE_ACCOUNT=your-service-account
export QUEUED_RESOURCE_ID=your-queued-resource-id
export VALID_DURATION=your-duration 

# Additional environment variable needed for Multislice:
export NUM_SLICES=number-of-slices

# Use a custom network for better performance as well as to avoid having the default network becoming overloaded.
export NETWORK_NAME=${PROJECT_ID}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw

Beschreibung der Befehls-Flags

Variable Beschreibung
NODE_ID Die vom Nutzer zugewiesene ID der Cloud TPU, die erstellt wird, wenn die in die Warteschlange gestellte Ressourcenanfrage zugewiesen wird.
PROJECT_ID Google Cloud Projektname. Verwenden Sie ein vorhandenes Projekt oder erstellen Sie ein neues. Weitere Informationen finden Sie unter Google Cloud -Projekt einrichten.
ZONE Informationen zu den unterstützten Zonen finden Sie im Dokument Cloud TPU-Regionen und -Zonen.
ACCELERATOR_TYPE Weitere Informationen finden Sie unter Beschleunigertypen.
RUNTIME_VERSION v2-alpha-tpuv6e
SERVICE_ACCOUNT Dies ist die E-Mail-Adresse für Ihr Dienstkonto, die Sie in der Google Cloud Console -> IAM -> Dienstkonten finden.

Beispiel: tpu-service-account@your-project-ID.iam.gserviceaccount.com.com

NUM_SLICES Die Anzahl der zu erstellenden Slices (nur für Multislice erforderlich).
QUEUED_RESOURCE_ID Die vom Nutzer zugewiesene Text-ID der in die Warteschlange eingereihten Ressourcenanfrage.
VALID_DURATION Die Dauer, für die die in die Warteschlange gestellte Ressourcenanfrage gültig ist.
NETWORK_NAME Der Name eines zu verwendenden sekundären Netzwerks.
NETWORK_FW_NAME Der Name einer sekundären Netzwerk-Firewall, die verwendet werden soll.

Netzwerkleistung optimieren

Für eine optimale Leistung sollten Sie ein Netzwerk mit einer MTU (maximale Übertragungseinheit) von 8.896 verwenden.

Standardmäßig bietet eine Virtual Private Cloud (VPC) nur eine MTU von 1.460 Byte, was zu einer suboptimalen Netzwerkleistung führt. Sie können die MTU eines VPC-Netzwerk auf einen beliebigen Wert zwischen 1.300 Byte und 8.896 Byte (einschließlich) festlegen. Gängige benutzerdefinierte MTU-Größen sind 1.500 Byte (Standard-Ethernet) oder 8.896 Byte (das Maximum). Weitere Informationen finden Sie unter Gültige MTU-VPC-Netzwerk-Netzwerke.

Weitere Informationen zum Ändern der MTU-Einstellung für ein vorhandenes oder Standardnetzwerk finden Sie unter MTU-Einstellung eines VPC-Netzwerks ändern.

Im folgenden Beispiel wird ein Netzwerk mit einer MTU von 8.896 erstellt.

export RESOURCE_NAME=your-resource-name
export NETWORK_NAME=${RESOURCE_NAME}-privatenetwork
export NETWORK_FW_NAME=${RESOURCE_NAME}-privatefirewall
gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT_ID} \
   --subnet-mode=auto --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network=${NETWORK_NAME} \
   --allow tcp,icmp,udp --project=${PROJECT_ID}

Multi-NIC verwenden (Option für Multislice)

Die folgenden Umgebungsvariablen sind für ein sekundäres Subnetz erforderlich, wenn Sie eine Multislice-Umgebung verwenden.

export NETWORK_NAME_2=${RESOURCE_NAME}
export SUBNET_NAME_2=${RESOURCE_NAME}
export FIREWALL_RULE_NAME=${RESOURCE_NAME}
export ROUTER_NAME=${RESOURCE_NAME}-network-2
export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2
export REGION=your-region

Verwenden Sie die folgenden Befehle, um ein benutzerdefiniertes IP-Routing für das Netzwerk und das Subnetz zu erstellen.

gcloud compute networks create ${NETWORK_NAME_2} --mtu=8896 \
   --bgp-routing-mode=regional --subnet-mode=custom --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_2} \
   --network=${NETWORK_NAME_2} \
   --range=10.10.0.0/18 --region=${REGION} \
   --project=${PROJECT_ID}

gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_2} --allow tcp,icmp,udp \
   --source-ranges 10.10.0.0/18 --project=${PROJECT_ID}

gcloud compute routers create ${ROUTER_NAME} \
   --project=${PROJECT_ID} \
   --network=${NETWORK_NAME_2} \
   --region=${REGION}

gcloud compute routers nats create ${NAT_CONFIG} \
   --router=${ROUTER_NAME} \
   --region=${REGION} \
   --auto-allocate-nat-external-ips \
   --nat-all-subnet-ip-ranges \
   --project=${PROJECT_ID} \
   --enable-logging

Nachdem Sie einen Multi-Network-Slice erstellt haben, können Sie prüfen, ob beide Netzwerkkarten (NICs) verwendet werden. Dazu richten Sie einen XPK-Cluster ein und fügen dem Befehl zum Erstellen von XPK-Arbeitslasten das Flag --command ifconfig hinzu.

Verwenden Sie den folgenden workload create-Befehl, um die Ausgabe des ifconfig-Befehls in den Google Cloud -Konsolenlogs anzuzeigen und zu prüfen, ob sowohl eth0 als auch eth1 mtu=8896 haben.

python3 xpk.py workload create \
   --cluster CLUSTER_NAME \
   {--base-docker-image maxtext_base_image | --docker-image your-cloud-image-name} \
   --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
   --tpu-type=${ACCELERATOR_TYPE} \
   --num-slices=${NUM_SLICES}  \
   --on-demand \
   --zone=${ZONE} \
   --project=${PROJECT_ID} \
   --command "ifconfig"

Wenn Sie Debug-Logs aktivieren oder Vertex AI TensorBoard verwenden möchten, fügen Sie dem Befehl die folgenden optionalen Argumente hinzu:

   --enable-debug-logs \
   --use-vertex-tensorboard

Prüfen Sie,ob sowohl eth0 als auch eth1 den Wert „mtu=8896“ haben. Sie können prüfen, ob die Multi-NIC ausgeführt wird, indem Sie dem Befehl zum Erstellen der XPK-Arbeitslast das Flag --command ifconfig hinzufügen. Prüfen Sie die Ausgabe dieser XPK-Arbeitslast in den Google Cloud Konsolenlogs und prüfen Sie,ob sowohl eth0 als auch eth1 den Wert „mtu=8.896“ haben.

TCP-Einstellungen verbessern

Wenn Sie Ihre Cloud TPUs über die Schnittstelle für in die Warteschlange gestellte Ressourcen erstellt haben, können Sie die Netzwerkleistung verbessern, indem Sie die TCP-Empfangspufferlimits mit dem folgenden Befehl erhöhen.

gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \
   --project "${PROJECT_ID}" \
   --zone "${ZONE}" \
   --node=all \
   --worker=all \
   --command='
   sudo sh -c "echo \"4096 41943040 314572800\" > /proc/sys/net/ipv4/tcp_rmem"'

Mit in die Warteschlange gestellten Ressourcen bereitstellen

Sie können eine Cloud TPU v6e mit Ressourcen in der Warteschlange erstellen. Mit Ressourcen in der Warteschlange können Sie Kapazität erhalten, sobald sie verfügbar ist. Sie können optional eine Start- und Endzeit für die Bearbeitung der Anfrage angeben. Weitere Informationen finden Sie unter In die Warteschlange gestellte Ressourcen verwalten.

Cloud TPUs v6e mit GKE oder XPK bereitstellen

Wenn Sie GKE-Befehle mit v6e verwenden, können Sie Kubernetes-Befehle oder XPK verwenden, um Cloud TPUs bereitzustellen und Modelle zu trainieren oder bereitzustellen. Unter Cloud TPUs in GKE planen erfahren Sie, wie Sie Ihre Cloud TPU-Konfigurationen in GKE-Clustern planen. In den folgenden Abschnitten finden Sie Befehle zum Erstellen eines XPK-Clusters mit Unterstützung für einzelne NICs und mehrere NICs.

XPK-Cluster mit Unterstützung für eine einzelne NIC erstellen

export CLUSTER_NAME=xpk-cluster-name
export ZONE=us-east1-d
export PROJECT_ID=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME=${CLUSTER_NAME}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw
gcloud compute networks create ${NETWORK_NAME} \
   --mtu=8896 \
   --project=${PROJECT_ID} \
   --subnet-mode=auto \
   --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} \
   --network=${NETWORK_NAME} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}"
python3 xpk.py cluster create --cluster=${CLUSTER_NAME} \
   --cluster-cpu-machine-type=e2-standard-8 \
   --num-slices=${NUM_SLICES} \
   --tpu-type=${TPU_TYPE} \
   --zone=${ZONE} \
   --project=${PROJECT_ID} \
   --on-demand \
   --custom-cluster-arguments="${CLUSTER_ARGUMENTS}"  \
   --create-vertex-tensorboard

Beschreibung der Befehls-Flags

Variable Beschreibung
CLUSTER_NAME Der vom Nutzer zugewiesene Name für den XPK-Cluster.
PROJECT_ID Google Cloud Projektname. Verwenden Sie ein vorhandenes Projekt oder erstellen Sie ein neues. Weitere Informationen finden Sie unter Google Cloud -Projekt einrichten.
ZONE Informationen zu den unterstützten Zonen finden Sie im Dokument Cloud TPU-Regionen und -Zonen.
TPU_TYPE Weitere Informationen finden Sie unter Beschleunigertypen.
NUM_SLICES Die Anzahl der Segmente, die Sie erstellen möchten
CLUSTER_ARGUMENTS Das zu verwendende Netzwerk und Subnetzwerk.

Beispiel: --network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}

NUM_SLICES Die Anzahl der zu erstellenden Slices.
NETWORK_NAME Der Name eines zu verwendenden sekundären Netzwerks.
NETWORK_FW_NAME Der Name einer sekundären Netzwerk-Firewall, die verwendet werden soll.

XPK-Cluster mit Unterstützung für mehrere NICs erstellen

export CLUSTER_NAME=xpk-cluster-name
export REGION=your-region
export ZONE=us-east1-d
export PROJECT_ID=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME_1=${CLUSTER_NAME}-mtu9k-1-${ZONE}
export SUBNET_NAME_1=${CLUSTER_NAME}-privatesubnet-1-${ZONE}
export NETWORK_FW_NAME_1=${NETWORK_NAME_1}-fw-1-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-1-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-1-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-1-${ZONE}
gcloud compute networks create ${NETWORK_NAME_1} \
   --mtu=8896 \
   --bgp-routing-mode=regional \
   --subnet-mode=custom \
   --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_1} \
   --network=${NETWORK_NAME_1} \
   --range=10.11.0.0/18 \
   --region=${REGION} \
   --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_1} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \
   --project=${PROJECT_ID} \
   --network=${NETWORK_NAME_1} \
   --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \
   --router=${ROUTER_NAME} \
   --region=${REGION} \
   --auto-allocate-nat-external-ips \
   --nat-all-subnet-ip-ranges \
   --project=${PROJECT_ID} \
   --enable-logging
# Secondary subnet for multi-nic experience.
# Need custom IP routing to be different from the first network's subnet.

export NETWORK_NAME_2=${CLUSTER_NAME}-privatenetwork-2-${ZONE}
export SUBNET_NAME_2=${CLUSTER_NAME}-privatesubnet-2-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-2-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-2-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-2-${ZONE}
gcloud compute networks create ${NETWORK_NAME_2} \
   --mtu=8896 \
   --bgp-routing-mode=regional \
   --subnet-mode=custom \
   --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_2} \
   --network=${NETWORK_NAME_2} \
   --range=10.10.0.0/18 \
   --region=${REGION} \
   --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_2} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \
   --project=${PROJECT_ID} \
   --network=${NETWORK_NAME_2} \
   --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \
   --router=${ROUTER_NAME} \
   --region=${REGION} \
   --auto-allocate-nat-external-ips \
   --nat-all-subnet-ip-ranges \
   --project=${PROJECT_ID} \
   --enable-logging
export CLUSTER_ARGUMENTS="--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}"
export NODE_POOL_ARGUMENTS="--additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}"
python3 xpk.py cluster create \
   --cluster=${CLUSTER_NAME} \
   --cluster-cpu-machine-type=e2-standard-8 \
   --num-slices=${NUM_SLICES} \
   --tpu-type=${TPU_TYPE} \
   --zone=${ZONE}  \
   --project=${PROJECT_ID} \
   --on-demand \
   --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \
   --custom-nodepool-arguments="${NODE_POOL_ARGUMENTS}" \
   --create-vertex-tensorboard

Beschreibung der Befehls-Flags

Variable Beschreibung
CLUSTER_NAME Der vom Nutzer zugewiesene Name für den XPK-Cluster.
PROJECT_ID Google Cloud Projektname. Verwenden Sie ein vorhandenes Projekt oder erstellen Sie ein neues. Weitere Informationen finden Sie unter Google Cloud -Projekt einrichten.
ZONE Informationen zu den unterstützten Zonen finden Sie im Dokument Cloud TPU-Regionen und -Zonen.
TPU_TYPE Weitere Informationen finden Sie unter Beschleunigertypen.
NUM_SLICES Die Anzahl der Segmente, die Sie erstellen möchten
CLUSTER_ARGUMENTS Das zu verwendende Netzwerk und Subnetzwerk.

Beispiel: --enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}

NODE_POOL_ARGUMENTS Zusätzliches Knotennetzwerk, das verwendet werden soll.

Beispiel: --additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}

NUM_SLICES Die Anzahl der zu erstellenden Slices (nur für Multislice erforderlich).
NETWORK_NAME Der Name eines zu verwendenden sekundären Netzwerks.
NETWORK_FW_NAME Der Name einer sekundären Netzwerk-Firewall, die verwendet werden soll.

Framework einrichten

In diesem Abschnitt wird die allgemeine Einrichtung für das Trainieren von ML-Modellen mit den Frameworks JAX und PyTorch beschrieben. Wenn Sie GKE verwenden, können Sie XPK- oder Kubernetes-Befehle für die Einrichtung des Frameworks verwenden.

Einrichtung für JAX

In diesem Abschnitt finden Sie eine Einrichtungsanleitung für die Ausführung von JAX-Arbeitslasten in GKE mit oder ohne XPK sowie für die Verwendung von Ressourcen in der Warteschlange.

JAX mit GKE einrichten

Einzelner Slice auf einem einzelnen Host

Im folgenden Beispiel wird ein 2 × 2-Knotenpool mit einem einzelnen Host mithilfe einer Kubernetes-YAML-Datei eingerichtet.

apiVersion: v1
kind: Pod
metadata:
  name: tpu-pod-jax-v6e-a
spec:
  restartPolicy: Never
  nodeSelector:
    cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
    cloud.google.com/gke-tpu-topology: 2x2
  containers:
  - name: tpu-job
    image: python:3.10
    securityContext:
      privileged: true
    command:
    - bash
    - -c
    - |
      pip install -U --pre jax jaxlib libtpu-nightly requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
      JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python3 -c 'import jax; print("Total TPU chips:", jax.device_count())'
    resources:
      requests:
        google.com/tpu: 4
      limits:
        google.com/tpu: 4

Nach dem erfolgreichen Abschluss des Vorgangs sollte im GKE-Log die folgende Meldung angezeigt werden:

Total TPU chips: 4

Einzelner Slice auf mehreren Hosts

Im folgenden Beispiel wird ein Knotenpool mit mehreren Hosts (4 × 4) mithilfe einer Kubernetes-YAML-Datei eingerichtet.

apiVersion: v1
kind: Service
metadata:
  name: headless-svc
spec:
  clusterIP: None
  selector:
    job-name: tpu-available-chips
---
apiVersion: batch/v1
kind: Job
metadata:
  name: tpu-available-chips
spec:
  backoffLimit: 0
  completions: 4
  parallelism: 4
  completionMode: Indexed
  template:
    spec:
      subdomain: headless-svc
      restartPolicy: Never
      nodeSelector:
        cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
        cloud.google.com/gke-tpu-topology: 4x4
      containers:
      - name: tpu-job
        image: python:3.10
        ports:
        - containerPort: 8471 # Default port using which TPU VMs communicate
        - containerPort: 8431 # Port to export TPU runtime metrics, if supported.
        securityContext:
          privileged: true
        command:
        - bash
        - -c
        - |
          pip install -U --pre jax jaxlib libtpu-nightly requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
          JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
        resources:
          requests:
            google.com/tpu: 4
          limits:
            google.com/tpu: 4

Nach dem erfolgreichen Abschluss des Vorgangs sollte im GKE-Log die folgende Meldung angezeigt werden:

Total TPU chips: 16

Multislice auf mehreren Hosts

Im folgenden Beispiel werden zwei 4x4-Knotenpools mit mehreren Hosts mithilfe einer Kubernetes-YAML-Datei eingerichtet.

Als Voraussetzung müssen Sie JobSet v0.2.3 oder höher installieren.

apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
  name: multislice-job
  annotations:
    alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
spec:
  failurePolicy:
    maxRestarts: 4
  replicatedJobs:
    - name: slice
      replicas: 2
      template:
        spec:
          parallelism: 4
          completions: 4
          backoffLimit: 0
          template:
            spec:
              hostNetwork: true
              dnsPolicy: ClusterFirstWithHostNet
              nodeSelector:
                cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                cloud.google.com/gke-tpu-topology: 4x4
              hostNetwork: true
              containers:
              - name: jax-tpu
                image: python:3.10
                ports:
                - containerPort: 8471
                - containerPort: 8080
                - containerPort: 8431
                securityContext:
                  privileged: true
                command:
                - bash
                - -c
                - |
                  pip install -U --pre jax jaxlib libtpu-nightly requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                  JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
                resources:
                  limits:
                   google.com/tpu: 4
                  requests:
                   google.com/tpu: 4

Nach dem erfolgreichen Abschluss des Vorgangs sollte im GKE-Log die folgende Meldung angezeigt werden:

Total TPU chips: 32

Weitere Informationen finden Sie in der GKE-Dokumentation unter Multi-Slice-Arbeitslast ausführen.

Aktivieren Sie hostNetwork, um die Leistung zu verbessern.

Multi-NIC

Damit Sie das folgende Manifest mit mehreren NICs verwenden können, müssen Sie Ihre Netzwerke einrichten. Weitere Informationen finden Sie unter Unterstützung mehrerer Netzwerke für Kubernetes-Pods einrichten.

Wenn Sie mehrere NICs in GKE nutzen möchten, müssen Sie dem Kubernetes-Pod-Manifest einige zusätzliche Annotationen hinzufügen.

Das Folgende ist ein Beispielmanifest für eine Multi-NIC-Arbeitslast ohne TPU.

apiVersion: v1
kind: Pod
metadata:
  name: sample-netdevice-pod-1
  annotations:
    networking.gke.io/default-interface: 'eth0'
    networking.gke.io/interfaces: |
      [
        {"interfaceName":"eth0","network":"default"},
        {"interfaceName":"eth1","network":"netdevice-network"}
      ]
spec:
  containers:
  - name: sample-netdevice-pod
    image: busybox
    command: ["sleep", "infinity"]
    ports:
    - containerPort: 80
  restartPolicy: Always
  tolerations:
  - key: "google.com/tpu"
    operator: "Exists"
    effect: "NoSchedule"

Wenn Sie den Befehl exec verwenden, um eine Verbindung zum Kubernetes-Pod herzustellen, sollte die zusätzliche NIC mit dem folgenden Code angezeigt werden:

$ kubectl exec --stdin --tty sample-netdevice-pod-1 -- /bin/sh
/ # ip a
1: lo: <LOOPBACK,UP,LOWER_UP> mtu 65536 qdisc noqueue qlen 1000
    link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00
    inet 127.0.0.1/8 scope host lo
       valid_lft forever preferred_lft forever
2: eth0@if11: <BROADCAST,MULTICAST,UP,LOWER_UP,M-DOWN> mtu 1460 qdisc noqueue
    link/ether da:be:12:67:d2:25 brd ff:ff:ff:ff:ff:ff
    inet 10.124.2.6/24 brd 10.124.2.255 scope global eth0
       valid_lft forever preferred_lft forever
3: eth1: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1460 qdisc mq qlen 1000
    link/ether 42:01:ac:18:00:04 brd ff:ff:ff:ff:ff:ff
    inet 172.24.0.4/32 scope global eth1
       valid_lft forever preferred_lft forever

JAX mit GKE und XPK einrichten

Informationen zum Einrichten von JAX mit GKE und XPK finden Sie in der XPK-README-Datei.

Informationen zum Einrichten und Ausführen von XPK mit MaxText finden Sie unter MaxText ausführen.

JAX mit in die Warteschlange gestellten Ressourcen einrichten

Installieren Sie JAX auf allen Cloud TPU-VMs in Ihrem Slice oder Ihren Slices gleichzeitig mit dem Befehl gcloud alpha compute tpus tpu-vm ssh. Fügen Sie für Multislice das Flag --node=all hinzu.

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
   --project ${PROJECT_ID} \
   --zone ${ZONE} \
   --worker=all \
   --command='
   pip install -U --pre jax jaxlib libtpu-nightly requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

Mit dem folgenden Befehl können Sie prüfen, wie viele Cloud TPU-Kerne in Ihrem Slice verfügbar sind, und testen, ob alles korrekt installiert ist:

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
   --project ${PROJECT_ID} \
   --zone ${ZONE} \
   --worker=all \
   --command='
   python3 -c "import jax; print(jax.device_count(), jax.local_device_count())"'

Die Ausgabe sieht bei der Ausführung auf einem v6e-16-Slice in etwa so aus:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16 4
16 4
16 4
16 4

jax.device_count() gibt die Gesamtzahl der Chips im angegebenen Slice an. jax.local_device_count() gibt die Anzahl der Chips an, auf die eine einzelne VM in diesem Slice zugreifen kann.

gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
   git clone -b mlperf4.1 https://github.com/google/maxdiffusion.git &&
   cd maxdiffusion && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103 &&
   pip install setuptools==59.6.0 &&
   pip install -r requirements.txt && pip install .'

Probleme bei der JAX-Einrichtung beheben

Ein allgemeiner Tipp ist, das ausführliche Logging in Ihrem GKE-Arbeitslastmanifest zu aktivieren. Stellen Sie die Logs dann dem GKE-Support zur Verfügung.

TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0

Fehlermeldungen

no endpoints available for service 'jobset-webhook-service'

Dieser Fehler bedeutet, dass das Jobset nicht richtig installiert wurde. Prüfen Sie, ob die Kubernetes-Pods für die jobset-controller-manager-Bereitstellung ausgeführt werden. Weitere Informationen finden Sie in der Dokumentation zur Fehlerbehebung bei JobSets.

TPU initialization failed: Failed to connect

Die GKE-Knotenversion muss 1.30.4-gke.1348000 oder höher sein (GKE 1.31 wird nicht unterstützt).

Einrichtung für PyTorch

In diesem Abschnitt wird beschrieben, wie Sie PJRT auf v6e mit PyTorch/XLA verwenden. Python 3.10 ist die empfohlene Python-Version.

PyTorch mit GKE und XPK einrichten

Sie können den folgenden Docker-Container mit XPK verwenden, in dem PyTorch-Abhängigkeiten bereits installiert sind:

us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028

Verwenden Sie den folgenden Befehl, um eine XPK-Arbeitslast zu erstellen:

python3 xpk.py workload create \
   --cluster ${CLUSTER_NAME} \
   {--base-docker-image maxtext_base_image | --docker-image your-cloud-image-name} \
   --workload ${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
   --tpu-type=${ACCELERATOR_TYPE} \
   --num-slices=${NUM_SLICES}  \
   --on-demand \
   --zone ${ZONE} \
   --project ${PROJECT_ID} \
   --enable-debug-logs \
   --command 'python3 -c "import torch; import torch_xla; import torch_xla.runtime as xr; print(xr.global_runtime_device_count())"'

Mit --base-docker-image wird ein neues Docker-Image erstellt, in das das aktuelle Arbeitsverzeichnis eingebunden ist.

PyTorch mit in die Warteschlange gestellten Ressourcen einrichten

So installieren Sie PyTorch mit in die Warteschlange eingereihten Ressourcen und führen ein kleines Skript auf v6e aus:

Abhängigkeiten über SSH installieren, um auf die VMs zuzugreifen

Verwenden Sie den folgenden Befehl, um Abhängigkeiten auf allen Cloud TPU-VMs zu installieren. Fügen Sie für Multislice das Flag --worker=all hinzu:

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
   sudo apt update && sudo apt install -y python3-pip libopenblas-base && \
   pip3 install torch~=2.6.0 "torch_xla[tpu]~=2.6.0" -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'

Leistung von Modellen mit großen, häufigen Zuweisungen verbessern

Bei Modellen mit großen, häufigen Zuweisungen wird mit der Funktion tcmalloc die Leistung im Vergleich zur Standardimplementierung der Funktion malloc erheblich verbessert. Daher ist die Standardfunktion malloc, die auf Cloud TPU-VMs verwendet wird, tcmalloc. Je nach Arbeitslast (z. B. bei DLRM, das sehr große Zuweisungen für seine Einbettungstabellen hat) kann die Funktion tcmalloc jedoch zu einer Verlangsamung führen. In diesem Fall können Sie versuchen, die folgende Variable mit der Standardfunktion malloc zu deaktivieren:

unset LD_PRELOAD

Python-Skript zum Ausführen einer Berechnung auf einer v6e-VM verwenden

Verwenden Sie den folgenden Befehl, um ein Skript auszuführen, das zwei Tensoren erstellt, sie addiert und das Ergebnis ausgibt:

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
   --project ${PROJECT_ID} \
   --zone ${ZONE} \
   --worker all \
   --command='
   unset LD_PRELOAD
   python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"'

Dadurch wird eine Ausgabe generiert, die etwa so aussieht:

SSH: Attempting to connect to worker 0...
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
xla:0
tensor([[ 0.3355, -1.4628, -3.2610],
        [-1.4656,  0.3196, -2.8766],
        [ 0.8668, -1.5060,  0.7125]], device='xla:0')

v6e mit SkyPilot

Sie können Cloud TPU v6e mit SkyPilot verwenden. So fügen Sie SkyPilot v6e-bezogene Standort- und Preisinformationen hinzu: Weitere Informationen finden Sie im SkyPilot-Beispiel für TPU v6e.

Anleitungen für die Inferenz

In den folgenden Anleitungen wird beschrieben, wie Sie Inferenz auf Cloud TPU v6e ausführen:

Trainingsbeispiele

In den folgenden Abschnitten finden Sie Beispiele für das Training von MaxText-, MaxDiffusion- und PyTorch-Modellen auf Cloud TPU v6e.

MaxText- und MaxDiffusion-Training auf der v6e Cloud TPU-VM

In den folgenden Abschnitten wird der Trainingszyklus der Modelle MaxText und MaxDiffusion beschrieben.

Im Allgemeinen sind folgende Schritte erforderlich:

  1. Erstellen Sie das Basis-Image für die Arbeitslast.
  2. Führen Sie Ihre Arbeitslast mit XPK aus.
    1. Erstellen Sie den Trainingsbefehl für die Arbeitslast.
    2. Stellen Sie die Arbeitslast bereit.
  3. Folgen Sie der Arbeitslast und sehen Sie sich die Messwerte an.
  4. Löschen Sie die XPK-Arbeitslast, wenn sie nicht benötigt wird.
  5. Löschen Sie den XPK-Cluster, wenn er nicht mehr benötigt wird.

Basis-Image erstellen

Installieren Sie MaxText oder MaxDiffusion und erstellen Sie das Docker-Image:

  1. Klonen Sie das gewünschte Repository und wechseln Sie in das Verzeichnis des Repositorys:

    MaxText:

    git clone https://github.com/google/maxtext.git && cd maxtext
    

    MaxDiffusion:

    git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103
    
  2. Konfigurieren Sie Docker für die Verwendung der Google Cloud CLI:

    gcloud auth configure-docker
    
  3. Erstellen Sie das Docker-Image mit dem folgenden Befehl oder mit dem JAX Stable Stack. Weitere Informationen zum JAX Stable Stack finden Sie unter Docker-Image mit JAX Stable Stack erstellen.

    MaxText:

    bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.35
    

    MaxDiffusion:

    bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_stack MODE=jax_ai_image PROJECT=${PROJECT_ID} LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest
    
  4. Legen Sie Ihre Projekt-ID in der aktiven gcloud CLI-Konfiguration fest:

    gcloud config set project ${PROJECT_ID}
    
  5. Wenn Sie die Arbeitslast von einem Computer aus starten, auf dem das Image nicht lokal erstellt wurde, laden Sie das Image hoch.

    1. Legen Sie die Umgebungsvariable CLOUD_IMAGE_NAME fest:

      export CLOUD_IMAGE_NAME=${USER}_runner
      
    2. Laden Sie das Bild hoch:

      bash docker_upload_runner.sh ${CLOUD_IMAGE_NAME}
      

Arbeitslast mit XPK ausführen

  1. Legen Sie die folgenden Umgebungsvariablen fest, wenn Sie nicht die von MaxText festgelegten Standardwerte oder MaxDiffusion verwenden:

    export BASE_OUTPUT_DIR=gs://YOUR_BUCKET
    export PER_DEVICE_BATCH_SIZE=2
    export NUM_STEPS=30
    export MAX_TARGET_LENGTH=8192
  2. Modellskript erstellen Dieses Skript wird in einem späteren Schritt als Trainingsbefehl kopiert.

    Führen Sie das Modellskript noch nicht aus.

    MaxText

    MaxText ist ein leistungsstarkes, hochgradig skalierbares Open-Source-LLM, das in reinem Python und JAX geschrieben wurde und auf Google Cloud TPUs und GPUs für Training und Inferenz ausgerichtet ist.

    JAX_PLATFORMS=tpu,cpu \
    ENABLE_PJRT_COMPATIBILITY=true \
    TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \
    TPU_SLICE_BUILDER_DUMP_ICI=true && \
    python3 -m MaxText.train MaxText/configs/base.yml \
         base_output_directory=${BASE_OUTPUT_DIR} \
         dataset_type=synthetic \
         per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
         enable_checkpointing=false \
         gcs_metrics=true \
         profiler=xplane \
         skip_first_n_steps_for_profiler=5 \
         steps=${NUM_STEPS}  # attention='dot_product'"
    

    Gemma2

    Gemma ist eine Familie von LLMs mit offenen Gewichten, die von Google DeepMind auf Grundlage von Gemini-Forschung und -Technologie entwickelt wurden.

    python3 -m MaxText.train MaxText/configs/base.yml \
        model_name=gemma2-27b \
        run_name=gemma2-27b-run \
        base_output_directory=${BASE_OUTPUT_DIR} \
        max_target_length=${MAX_TARGET_LENGTH} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        steps=${NUM_STEPS} \
        enable_checkpointing=false \
        use_iota_embed=true \
        gcs_metrics=true \
        dataset_type=synthetic \
        profiler=xplane \
        attention=flash
    

    Mixtral 8x7b

    Mixtral ist ein hochmodernes KI-Modell, das von Mistral AI entwickelt wurde und eine spärliche MoE-Architektur (Mixture of Experts) nutzt.

    python3 -m MaxText.train MaxText/configs/base.yml \
        base_output_directory=${BASE_OUTPUT_DIR} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        model_name=mixtral-8x7b \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        tokenizer_path=assets/tokenizer.mistral-v1 \
        attention=flash \
        dtype=bfloat16 \
        dataset_type=synthetic \
        profiler=xplane
    

    Llama3-8b

    Llama ist eine Familie von LLMs mit offenen Gewichten, die von Meta entwickelt wurden.

    Ein Beispiel für die Ausführung von Llama3 in PyTorch finden Sie unter torch_xla-Modelle im torchprime-Repository.

    MaxDiffusion

    MaxDiffusion ist eine Sammlung von Referenzimplementierungen verschiedener latenter Diffusionsmodelle, die in reinem Python und JAX geschrieben sind und auf XLA-Geräten wie Cloud TPUs und GPUs ausgeführt werden. Stable Diffusion ist ein latentes Text-zu-Bild-Modell, mit dem fotorealistische Bilder aus beliebigen Texteingaben generiert werden.

    Sie müssen einen bestimmten Git-Branch installieren, um MaxDiffusion auszuführen, wie im folgenden git clone-Befehl gezeigt.

    Trainingsskript:

    git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103 && pip install -r requirements.txt && pip install . && pip install huggingface_hub==0.30.2 && OUT_DIR=${BASE_OUTPUT_DIR} && python src/maxdiffusion/train_sdxl.py src/maxdiffusion/configs/base_xl.yml revision=refs/pr/95 activations_dtype=bfloat16 weights_dtype=bfloat16  resolution=1024  per_device_batch_size=1 output_dir=${OUT_DIR} jax_cache_dir=${OUT_DIR}/cache_dir/ max_train_steps=200 attention=flash run_name=sdxl-ddp-v6e
    
  3. Exportieren Sie die folgenden Variablen:

    export CLUSTER_NAME=CLUSTER_NAME
    export ACCELERATOR_TYPE=ACCELERATOR_TYPE
    export NUM_SLICES=NUM_SLICES
    export YOUR_MODEL_SCRIPT=YOUR_MODEL_SCRIPT

    Beschreibungen von Umgebungsvariablen

    Variable Beschreibung
    CLUSTER_NAME Der Name Ihres XPK-Clusters.
    ACCELERATOR_TYPE Weitere Informationen finden Sie unter Beschleunigertypen.
    NUM_SLICES Die Anzahl der TPU-Slices.
    YOUR_MODEL_SCRIPT Das Modellskript, das als Trainingsbefehl ausgeführt werden soll.
  4. Führen Sie das Modell mit dem Skript aus, das Sie im vorherigen Schritt erstellt haben. Sie müssen entweder das Flag --base-docker-image angeben, um das MaxText-Basis-Image zu verwenden, oder das Flag --docker-image und das Image, das Sie verwenden möchten.

    Optional: Sie können das Debug-Logging aktivieren, indem Sie das Flag --enable-debug-logs einfügen. Weitere Informationen finden Sie unter JAX auf MaxText debuggen.

    Optional: Sie können einen Vertex AI-Test erstellen, um Daten in Vertex AI TensorBoard hochzuladen. Dazu müssen Sie das Flag --use-vertex-tensorboard einfügen. Weitere Informationen finden Sie unter JAX auf MaxText mit Vertex AI überwachen.

    python3 xpk.py workload create \
      --cluster ${CLUSTER_NAME} \
      {--base-docker-image maxtext_base_image | --docker-image gcr.io/${PROJECT_ID}/${CLOUD_IMAGE_NAME}:latest} \
      --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
      --tpu-type=${ACCELERATOR_TYPE} \
      --num-slices=${NUM_SLICES}  \
      --on-demand \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      [--enable-debug-logs] \
      [--use-vertex-tensorboard] \
      --command="${YOUR_MODEL_SCRIPT}"

    Die Ausgabe enthält einen Link, über den Sie Ihre Arbeitslast verfolgen können. Öffnen Sie den Link und klicken Sie auf den Tab Logs, um Ihre Arbeitslast in Echtzeit zu verfolgen.

JAX in MaxText debuggen

Verwenden Sie zusätzliche XPK-Befehle, um zu ermitteln, warum der Cluster oder die Arbeitslast nicht ausgeführt wird:

  • XPK-Arbeitslastliste
  • XPK-Prüftool
  • Aktivieren Sie das ausführliche Logging in Ihren Arbeitslastlogs mit dem Flag --enable-debug-logs, wenn Sie die XPK-Arbeitslast erstellen.

JAX auf MaxText mit Vertex AI überwachen

Damit Sie TensorBoard verwenden können, muss Ihrem Google Cloud -Nutzerkonto die Rolle aiplatform.user zugewiesen sein. Führen Sie den folgenden Befehl aus, um diese Rolle zuzuweisen:

gcloud projects add-iam-policy-binding your-project-id \
   --member='user:your-email' \
   --role='roles/aiplatform.user'

Skalar- und Profildaten über das verwaltete TensorBoard von Vertex AI ansehen.

  1. Erhöhen Sie die Resource Management (CRUD)-Anfragen für die Zone, die Sie verwenden, von 600 auf 5.000. Bei kleinen Arbeitslasten mit weniger als 16 VMs ist das möglicherweise kein Problem.

  2. Installieren Sie Abhängigkeiten wie cloud-accelerator-diagnostics für Vertex AI:

    # xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI
    cd ~/xpk
    pip install .
  3. Erstellen Sie Ihren XPK-Cluster mit dem Flag --create-vertex-tensorboard, wie in Vertex AI TensorBoard erstellen beschrieben. Sie können diesen Befehl auch für vorhandene Cluster ausführen.

  4. Erstellen Sie Ihren Vertex AI-Test, wenn Sie Ihren XPK-Arbeitslast mit dem Flag --use-vertex-tensorboard und dem optionalen Flag --experiment-name ausführen. Eine vollständige Liste der Schritte finden Sie unter Vertex AI-Test zum Hochladen von Daten in Vertex AI TensorBoard erstellen.

Die Logs enthalten einen Link zu einem Vertex AI TensorBoard, ähnlich dem folgenden:

View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name

Sie können den Link zu Vertex AI TensorBoard auch in der Google Cloud Console aufrufen. Rufen Sie Vertex AI Experiments in der Google Cloud Konsole auf. Wählen Sie im Drop-down-Menü die gewünschte Region aus.

Das TensorBoard-Verzeichnis wird auch in den Cloud Storage-Bucket geschrieben, den Sie mit ${BASE_OUTPUT_DIR} angegeben haben.

XPK-Arbeitslasten löschen

Verwenden Sie den Befehl xpk workload delete, um eine oder mehrere Arbeitslasten basierend auf dem Jobpräfix oder dem Jobstatus zu löschen. Dieser Befehl kann nützlich sein, wenn Sie XPK-Arbeitslasten gesendet haben, die nicht mehr ausgeführt werden müssen, oder wenn Jobs in der Warteschlange hängen bleiben.

XPK-Cluster löschen

Verwenden Sie den Befehl xpk cluster delete, um einen Cluster zu löschen:

python3 xpk.py cluster delete --cluster ${CLUSTER_NAME} \
   --zone=${ZONE} --project=${PROJECT_ID}

Llama- und PyTorch/XLA-Training auf einer Cloud TPU v6e-VM

In dieser Anleitung wird beschrieben, wie Sie Llama-Modelle mit PyTorch/XLA auf Cloud TPU v6e mit dem WikiText-Dataset trainieren.

Zugriff auf Hugging Face und das Llama 3-Modell erhalten

Für dieses Tutorial benötigen Sie ein Hugging Face-Nutzerzugriffstoken. Informationen zum Erstellen von Nutzerzugriffstokens finden Sie in der Hugging Face-Dokumentation zu Nutzerzugriffstokens.

Außerdem benötigen Sie die Berechtigung für den Zugriff auf das Modell „Llama-3-8B“ auf Hugging Face. Wenn Sie Zugriff erhalten möchten, rufen Sie das Meta-Llama-3-8B-Modell auf Hugging Face auf und beantragen Sie den Zugriff.

Cloud TPU-VM erstellen

Erstellen Sie eine Cloud TPU v6e mit 8 Chips, um das Tutorial auszuführen.

  1. Richten Sie Umgebungsvariablen ein:

    export NODE_ID=your-tpu-name
    export PROJECT_ID=your-project-id
    export ACCELERATOR_TYPE=v6e-8
    export ZONE=us-east1-d
    export RUNTIME_VERSION=v2-alpha-tpuv6e
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id
    export VALID_DURATION=your-duration 
  2. Cloud TPU-VM erstellen:

    gcloud alpha compute tpus tpu-vm create ${NODE_ID} --version=${RUNTIME_VERSION} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --zone=${ZONE} \
       --project=${PROJECT_ID}

Installation

Installieren Sie den pytorch-tpu/transformers-Fork von Hugging Face-Transformern und Abhängigkeiten. Diese Anleitung wurde mit den folgenden Abhängigkeitsversionen getestet, die in diesem Beispiel verwendet werden:

  • torch: kompatibel mit 2.5.0
  • torch_xla[tpu]: kompatibel mit 2.5.0
  • jax: 0.4.33
  • jaxlib: 0.4.33
gcloud alpha compute tpus tpu-vm ssh ${NODE_ID} \
   --project=${PROJECT_ID} \
   --zone ${ZONE} \
   --worker=all \
   --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
   cd transformers
   sudo pip3 install -e .
   pip3 install datasets
   pip3 install evaluate
   pip3 install scikit-learn
   pip3 install accelerate
   pip install torch~=2.6.0 torch_xla[tpu]~=2.6.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
   pip install jax==0.4.38 jaxlib==0.4.38 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/'

Modellkonfigurationen einrichten

Im Trainingsbefehl im nächsten Abschnitt Modell ausführen werden zwei JSON-Konfigurationsdateien verwendet, um Modellparameter und die Konfiguration für Fully Sharded Data Parallel (FSDP) zu definieren. Mit FSDP-Sharding können Sie beim Training eine größere Batchgröße verwenden, da die Modellgewichte auf mehrere TPUs verteilt werden. Beim Training mit kleineren Modellen kann es ausreichen, Datenparallelität zu verwenden und die Gewichte auf jedem Gerät zu replizieren. Weitere Informationen zum Sharding von Tensoren auf Geräten in PyTorch/XLA finden Sie im SPMD-Nutzerhandbuch für PyTorch/XLA.

  1. Erstellen Sie die Konfigurationsdatei für Modellparameter. Im Folgenden finden Sie die Modellparameterkonfiguration für Llama-3-8B. Die Konfiguration für andere Modelle finden Sie auf Hugging Face. Ein Beispiel finden Sie in der Llama-2-7B-Konfiguration.

    cat > llama-config.json << EOF
    {
      "architectures": [
        "LlamaForCausalLM"
      ],
      "attention_bias": false,
      "attention_dropout": 0.0,
      "bos_token_id": 128000,
      "eos_token_id": 128001,
      "hidden_act": "silu",
      "hidden_size": 4096,
      "initializer_range": 0.02,
      "intermediate_size": 14336,
      "max_position_embeddings": 8192,
      "model_type": "llama",
      "num_attention_heads": 32,
      "num_hidden_layers": 32,
      "num_key_value_heads": 8,
      "pretraining_tp": 1,
      "rms_norm_eps": 1e-05,
      "rope_scaling": null,
      "rope_theta": 500000.0,
      "tie_word_embeddings": false,
      "torch_dtype": "bfloat16",
      "transformers_version": "4.40.0.dev0",
      "use_cache": false,
      "vocab_size": 128256
    }
    EOF
    
  2. Erstellen Sie die FSDP-Konfigurationsdatei:

    cat > fsdp-config.json << EOF
    {
      "fsdp_transformer_layer_cls_to_wrap": [
        "LlamaDecoderLayer"
      ],
      "xla": true,
      "xla_fsdp_v2": true,
      "xla_fsdp_grad_ckpt": true
    }
    EOF
    

    Weitere Informationen zum FSDP finden Sie unter FSDPv2.

  3. Laden Sie die Konfigurationsdateien mit dem folgenden Befehl auf Ihre Cloud TPU-VMs hoch:

    gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json ${NODE_ID}:. \
       --worker=all \
       --project=${PROJECT_ID} \
       --zone=${ZONE}

Modell ausführen

Führen Sie mit den Konfigurationsdateien, die Sie im vorherigen Abschnitt erstellt haben, das Skript run_clm.py aus, um das Modell Llama-3-8B mit dem WikiText-Dataset zu trainieren. Die Ausführung des Trainingsskripts dauert auf einer Cloud TPU v6e-8 etwa 10 Minuten.

  1. Melden Sie sich mit dem folgenden Befehl auf Ihrer Cloud TPU bei Hugging Face an:

    gcloud alpha compute tpus tpu-vm ssh ${NODE_ID} \
       --project=${PROJECT_ID} \
       --zone ${ZONE} \
       --worker=all \
       --command='
       pip3 install "huggingface_hub[cli]"
       huggingface-cli login --token HUGGING_FACE_TOKEN'
  2. Modelltraining ausführen:

    gcloud alpha compute tpus tpu-vm ssh ${NODE_ID} \
       --project=${PROJECT_ID} \
       --zone ${ZONE} \
       --worker=all \
       --command='
       export PJRT_DEVICE=TPU
       export XLA_USE_SPMD=1
       export ENABLE_PJRT_COMPATIBILITY=true
       # Optional variables for debugging:
       export XLA_IR_DEBUG=1
       export XLA_HLO_DEBUG=1
       export PROFILE_EPOCH=0
       export PROFILE_STEP=3
       export PROFILE_DURATION_MS=100000
       # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path
       export PROFILE_LOGDIR=PROFILE_PATH
       python3 transformers/examples/pytorch/language-modeling/run_clm.py \
         --dataset_name wikitext \
         --dataset_config_name wikitext-2-raw-v1 \
         --per_device_train_batch_size 16 \
         --do_train \
         --output_dir /home/$USER/tmp/test-clm \
         --overwrite_output_dir \
         --config_name /home/$USER/llama-config.json \
         --cache_dir /home/$USER/cache \
         --tokenizer_name meta-llama/Meta-Llama-3-8B \
         --block_size 8192 \
         --optim adafactor \
         --save_strategy no \
         --logging_strategy no \
         --fsdp "full_shard" \
         --fsdp_config /home/$USER/fsdp-config.json \
         --torch_dtype bfloat16 \
         --dataloader_drop_last yes \
         --flash_attention \
         --max_steps 20'

Fehlerbehebung bei PyTorch/XLA

Wenn Sie die optionalen Variablen für das Debugging im vorherigen Abschnitt festgelegt haben, wird das Profil für das Modell am Speicherort gespeichert, der durch die Variable PROFILE_LOGDIR angegeben wird. Sie können die xplane.pb-Datei, die an diesem Speicherort gespeichert ist, extrahieren und tensorboard verwenden, um die Profile in Ihrem Browser gemäß der TensorBoard-Anleitung anzusehen.

Wenn PyTorch/XLA nicht wie erwartet funktioniert, finden Sie im Leitfaden zur Fehlerbehebung Vorschläge zum Debuggen, Profilerstellen und Optimieren Ihres Modells.

Benchmarking-Ergebnisse

Der folgende Abschnitt enthält Benchmark-Ergebnisse für MaxDiffusion auf v6e.

MaxDiffusion

Wir haben das Trainingsskript für MaxDiffusion auf einer v6e-4, einer v6e-16 und zwei v6e-16 ausgeführt. Die Durchsatzraten sind in der folgenden Tabelle aufgeführt.

v6e-4 v6e-16 Zwei v6e-16
Trainingsschritte 0.069 0,073 0,13
Globale Batchgröße 8 32 64
Durchsatz (Beispiele/s) 115.9 438,4 492,3