Trillium (v6e) – Einführung

In dieser Dokumentation, in der TPU API und in Protokollen wird „v6e“ für Trillium verwendet. „v6e“ steht für die sechste Generation von TPU von Google.

Mit 256 Chips pro Pod hat die v6e-Architektur viele Ähnlichkeiten mit v5e. Dieses System ist für das Training, die Feinabstimmung und die Bereitstellung von Transformern, Text-zu-Bild-Modellen und CNNs (Convolutional Neural Networks) optimiert.

Weitere Informationen zur Systemarchitektur und zu den Konfigurationen von v6e finden Sie unter TPU v6e.

In diesem Einführungsdokument liegt der Schwerpunkt auf den Prozessen für das Modelltraining und die Bereitstellung mit den Frameworks JAX oder PyTorch. Mit jedem Framework können Sie TPUs mit Ressourcen in der Warteschlange oder GKE bereitstellen. Die GKE-Einrichtung kann mit XPK- oder GKE-Befehlen erfolgen.

Allgemeines Verfahren zum Trainieren oder Bereitstellen eines Modells mit v6e

  1. Google Cloud Projekt vorbereiten
  2. Sichere Kapazität
  3. Cloud TPU-Umgebung bereitstellen
  4. Eine Arbeitslast für das Training oder die Inferenz eines Modells ausführen

Google Cloud -Projekt vorbereiten

Bevor Sie Cloud TPU verwenden können, müssen Sie Folgendes tun:

Weitere Informationen finden Sie unter Cloud TPU-Umgebung einrichten.

Sichere Kapazität

Wenden Sie sich an den Google Cloud -Support, um ein Cloud TPU v6e-Kontingent anzufordern und Fragen zur Kapazität zu stellen.

Cloud TPU-Umgebung bereitstellen

v6e Cloud TPU kann mit GKE, mit GKE und XPK (einem Befehlszeilen-Wrapper-Tool über GKE) oder als Ressourcen in der Warteschlange bereitgestellt und verwaltet werden.

Vorbereitung

  • Prüfen Sie, ob Ihr Projekt ein ausreichendes TPUS_PER_TPU_FAMILY-Kontingent hat. Dieses gibt die maximale Anzahl von Chips an, auf die Sie in Ihrem Google Cloud-Projekt zugreifen können.
  • v6e wurde mit der folgenden Konfiguration getestet:
    • Python 3.10 oder höher
    • Nightly-Softwareversionen:
      • Täglich JAX 0.4.32.dev20240912
      • Nightly LibTPU 0.1.dev20240912+nightly
    • Stabile Softwareversionen:
      • JAX + JAX-Bibliothek der Version 0.4.37
  • Prüfen Sie, ob Ihr Projekt für Folgendes genügend Kontingente hat:

    • Cloud TPU-VM-Kontingent
    • Kontingent für IP-Adressen
    • Kontingent für Hyperdisk Balanced und alle anderen Laufwerktypen, die Sie verwenden möchten

  • Wenn Sie GKE mit XPK verwenden, finden Sie unter Cloud Console-Berechtigungen für das Nutzer- oder Dienstkonto Informationen zu den Berechtigungen, die zum Ausführen von XPK erforderlich sind.

Umgebungsvariablen erstellen

Erstellen Sie in Cloud Shell die folgenden Umgebungsvariablen:

export NODE_ID=your-tpu-name
export PROJECT_ID=your-project-id
export ACCELERATOR_TYPE=v6e-16
export ZONE=us-east1-d
export RUNTIME_VERSION=v2-alpha-tpuv6e
export SERVICE_ACCOUNT=your-service-account
export QUEUED_RESOURCE_ID=your-queued-resource-id
export VALID_DURATION=your-duration 

# Additional environment variable needed for Multislice:
export NUM_SLICES=number-of-slices

# Use a custom network for better performance as well as to avoid having the default network becoming overloaded.

export NETWORK_NAME=${PROJECT_ID}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw

Beschreibung der Befehls-Flags

Variable Beschreibung
NODE_ID Die vom Nutzer zugewiesene ID der Cloud TPU, die erstellt wird, wenn die in der Warteschlange befindliche Ressourcenanfrage zugewiesen wird.
PROJECT_ID Google Cloud ist der Projektname. Sie können ein vorhandenes Projekt verwenden oder ein neues erstellen. Weitere Informationen finden Sie unter Google Cloud Projekt einrichten.
ZONE Informationen zu den unterstützten Zonen finden Sie im Dokument Cloud TPU-Regionen und ‑Zonen.
ACCELERATOR_TYPE Weitere Informationen finden Sie unter Beschleunigertypen.
RUNTIME_VERSION v2-alpha-tpuv6e
SERVICE_ACCOUNT Das ist die E-Mail-Adresse Ihres Dienstkontos. Sie finden sie unter Google Cloud Console -> IAM -> Dienstkonten.

Beispiel: tpu-service-account@your-project-ID.iam.gserviceaccount.com.com

NUM_SLICES Die Anzahl der zu erstellenden Schichten (nur für Mehrfachaufnahmen erforderlich).
QUEUED_RESOURCE_ID Die vom Nutzer zugewiesene Text-ID der anstehenden Ressourcenanfrage.
VALID_DURATION Die Dauer, für die die angeforderte Ressource gültig ist.
NETWORK_NAME Der Name eines sekundären Netzwerks, das verwendet werden soll.
NETWORK_FW_NAME Der Name einer sekundären Netzwerk-Firewall, die verwendet werden soll.

Netzwerkleistung optimieren

Für eine optimale Leistung sollten Sie ein Netzwerk mit einer MTU (maximale Übertragungseinheit) von 8.896 verwenden.

Standardmäßig bietet eine Virtual Private Cloud (VPC) nur eine MTU von 1.460 Byte, was zu einer suboptimalen Netzwerkleistung führt. Sie können die MTU eines VPC-Netzwerk auf einen beliebigen Wert zwischen 1.300 Byte und 8.896 Byte (einschließlich) festlegen. Gängige benutzerdefinierte MTU-Größen sind 1.500 Byte (Ethernet-Standard) oder 8.896 Byte (maximal möglich). Weitere Informationen finden Sie unter Gültige MTU-VPC-Netzwerk-Netzwerke.

Weitere Informationen zum Ändern der MTU-Einstellung für ein vorhandenes oder Standardnetzwerk finden Sie unter MTU-Einstellung eines VPC-Netzwerks ändern.

Im folgenden Beispiel wird ein Netzwerk mit einer MTU von 8.896 erstellt.

export RESOURCE_NAME=your-resource-name
export NETWORK_NAME=${RESOURCE_NAME}-privatenetwork
export NETWORK_FW_NAME=${RESOURCE_NAME}-privatefirewall
gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT_ID} \
 --subnet-mode=auto --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network=${NETWORK_NAME} \
 --allow tcp,icmp,udp --project=${PROJECT_ID}

Mehrere NICs verwenden (Option für Multislice)

Die folgenden Umgebungsvariablen sind für ein sekundäres Subnetz erforderlich, wenn Sie eine Multislice-Umgebung verwenden.

export NETWORK_NAME_2=${RESOURCE_NAME}
export SUBNET_NAME_2=${RESOURCE_NAME}
export FIREWALL_RULE_NAME=${RESOURCE_NAME}
export ROUTER_NAME=${RESOURCE_NAME}-network-2
export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2
export REGION=your-region

Verwenden Sie die folgenden Befehle, um eine benutzerdefinierte IP-Weiterleitung für das Netzwerk und das Subnetz zu erstellen.

gcloud compute networks create ${NETWORK_NAME_2} --mtu=8896 \
   --bgp-routing-mode=regional --subnet-mode=custom --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_2} \
   --network=${NETWORK_NAME_2} \
   --range=10.10.0.0/18 --region=${REGION} \
   --project=${PROJECT_ID}

gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_2} --allow tcp,icmp,udp \
   --source-ranges 10.10.0.0/18 --project=${PROJECT_ID}

gcloud compute routers create ${ROUTER_NAME} \
  --project=${PROJECT_ID} \
  --network=${NETWORK_NAME_2} \
  --region=${REGION}

gcloud compute routers nats create ${NAT_CONFIG} \
  --router=${ROUTER_NAME} \
  --region=${REGION} \
  --auto-allocate-nat-external-ips \
  --nat-all-subnet-ip-ranges \
  --project=${PROJECT_ID} \
  --enable-logging

Nachdem Sie ein Multi-Netzwerk-Speicher-Slice erstellt haben, können Sie prüfen, ob beide Netzwerkschnittstellenkarten (NICs) verwendet werden. Dazu richten Sie einen XPK-Cluster ein und fügen dem Befehl zum Erstellen einer XPK-Arbeitslast das Flag --command ifconfig hinzu.

Verwenden Sie den folgenden xpk workload-Befehl, um die Ausgabe des ifconfig-Befehls in den Google Cloud -Konsolenprotokollen anzuzeigen, und prüfen Sie, ob sowohl eth0 als auch eth1 die mtu=8896 haben.

python3 xpk.py workload create \
   --cluster CLUSTER_NAME \
   {--base-docker-image maxtext_base_image | --docker-image your-cloud-image-name} \
   --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
   --tpu-type=${ACCELERATOR_TYPE} \
   --num-slices=${NUM_SLICES}  \
   --on-demand \
   --zone=${ZONE} \
   --project=${PROJECT_ID} \
   --command "ifconfig"

Wenn Sie Debug-Logs aktivieren oder Vertex AI TensorBoard verwenden möchten, fügen Sie dem Befehl die folgenden optionalen Argumente hinzu:

    --enable-debug-logs \
    --use-vertex-tensorboard

Prüfen Sie,ob sowohl eth0 als auch eth1 die mtu=8.896 haben. Sie können prüfen, ob die Multi-NIC-Umgebung ausgeführt wird, indem Sie dem Befehl zum Erstellen der XPK-Arbeitslast das Flag --command ifconfig hinzufügen. Prüfen Sie die Ausgabe dieser XPK-Arbeitslast in den Google Cloud Console-Protokollen und achten Sie darauf, dass sowohl eth0 als auch eth1 die mtu=8896 haben.

TCP-Einstellungen optimieren

Wenn Sie Ihre Cloud TPUs über die Benutzeroberfläche für in die Warteschlange gestellte Ressourcen erstellt haben, können Sie den folgenden Befehl ausführen, um die Netzwerkleistung zu verbessern, indem Sie die TCP-Empfangsbufferlimits erhöhen.

gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \
  --project "${PROJECT_ID}" \
  --zone "${ZONE}" \
  --node=all \
  --command='sudo sh -c "echo \"4096 41943040 314572800\" > /proc/sys/net/ipv4/tcp_rmem"' \
  --worker=all

Mit in die Warteschlange gestellten Ressourcen bereitstellen

Sie können eine Cloud TPU v6e mit Ressourcen in der Warteschlange erstellen. Mit Ressourcen in der Warteschlange erhalten Sie Kapazität, sobald sie verfügbar ist. Sie können optional einen Start- und Endzeitpunkt für die Ausführung der Anfrage angeben. Weitere Informationen finden Sie unter In der Warteschlange befindliche Ressourcen verwalten.

Cloud TPUs der Version 6e mit GKE oder XPK bereitstellen

Wenn Sie GKE-Befehle mit v6e verwenden, können Sie Kubernetes-Befehle oder XPK verwenden, um Cloud TPUs bereitzustellen und Modelle zu trainieren oder bereitzustellen. Unter Cloud TPUs in GKE planen erfahren Sie, wie Sie Ihre Cloud TPU-Konfigurationen in GKE-Clustern planen. In den folgenden Abschnitten finden Sie Befehle zum Erstellen eines XPK-Clusters mit Unterstützung für eine einzelne NIC und mehrere NICs.

XPK-Cluster mit Unterstützung für eine einzelne NIC erstellen

export CLUSTER_NAME=xpk-cluster-name
export ZONE=us-east1-d
export PROJECT_ID=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME=${CLUSTER_NAME}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw
gcloud compute networks create ${NETWORK_NAME} \
   --mtu=8896 \
   --project=${PROJECT_ID} \
   --subnet-mode=auto \
   --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} \
   --network=${NETWORK_NAME} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}"
python3 xpk.py cluster create --cluster=${CLUSTER_NAME} \
   --cluster-cpu-machine-type=e2-standard-8 \
   --num-slices=${NUM_SLICES} \
   --tpu-type=${TPU_TYPE} \
   --zone=${ZONE}  \
   --project=${PROJECT_ID} \
   --on-demand \
   --custom-cluster-arguments="${CLUSTER_ARGUMENTS}"  \
   --create-vertex-tensorboard

Beschreibung der Befehls-Flags

Variable Beschreibung
CLUSTER_NAME Der vom Nutzer zugewiesene Name für den XPK-Cluster.
PROJECT_ID Google Cloud ist der Projektname. Sie können ein vorhandenes Projekt verwenden oder ein neues erstellen. Weitere Informationen finden Sie unter Google Cloud Projekt einrichten.
ZONE Informationen zu den unterstützten Zonen finden Sie im Dokument Cloud TPU-Regionen und ‑Zonen.
TPU_TYPE Weitere Informationen finden Sie unter Beschleunigertypen.
NUM_SLICES Die Anzahl der Scheiben, die Sie erstellen möchten
CLUSTER_ARGUMENTS Das zu verwendende Netzwerk und Subnetzwerk.

Beispiel: --network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}

NUM_SLICES Die Anzahl der zu erstellenden Segmente.
NETWORK_NAME Der Name eines sekundären Netzwerks, das verwendet werden soll.
NETWORK_FW_NAME Der Name einer sekundären Netzwerk-Firewall, die verwendet werden soll.

XPK-Cluster mit Multi-NIC-Unterstützung erstellen

export CLUSTER_NAME=xpk-cluster-name
export REGION=your-region
export ZONE=us-east1-d
export PROJECT_ID=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME_1=${CLUSTER_NAME}-mtu9k-1-${ZONE}
export SUBNET_NAME_1=${CLUSTER_NAME}-privatesubnet-1-${ZONE}
export NETWORK_FW_NAME_1=${NETWORK_NAME_1}-fw-1-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-1-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-1-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-1-${ZONE}
gcloud compute networks create ${NETWORK_NAME_1} \
   --mtu=8896 \
   --bgp-routing-mode=regional \
   --subnet-mode=custom \
   --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_1} \
   --network=${NETWORK_NAME_1} \
   --range=10.11.0.0/18 \
   --region=${REGION} \
   --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_1} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \
    --project=${PROJECT_ID} \
    --network=${NETWORK_NAME_1} \
    --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \
    --router=${ROUTER_NAME} \
    --region=${REGION} \
    --auto-allocate-nat-external-ips \
    --nat-all-subnet-ip-ranges \
    --project=${PROJECT_ID} \
    --enable-logging
# Secondary subnet for multi-nic experience.
# Need custom IP routing to be different from the first network's subnet.

export NETWORK_NAME_2=${CLUSTER_NAME}-privatenetwork-2-${ZONE}
export SUBNET_NAME_2=${CLUSTER_NAME}-privatesubnet-2-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-2-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-2-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-2-${ZONE}
gcloud compute networks create ${NETWORK_NAME_2} \
   --mtu=8896 \
   --bgp-routing-mode=regional \
   --subnet-mode=custom \
   --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_2} \
   --network=${NETWORK_NAME_2} \
   --range=10.10.0.0/18 \
   --region=${REGION} \
   --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_2} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \
     --project=${PROJECT_ID} \
     --network=${NETWORK_NAME_2} \
     --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \
     --router=${ROUTER_NAME} \
     --region=${REGION} \
     --auto-allocate-nat-external-ips \
     --nat-all-subnet-ip-ranges \
     --project=${PROJECT_ID} \
     --enable-logging
export CLUSTER_ARGUMENTS="--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}"

export NODE_POOL_ARGUMENTS="--additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}"
python3 xpk.py cluster create \
    --cluster=${CLUSTER_NAME} \
    --cluster-cpu-machine-type=e2-standard-8 \
    --num-slices=${NUM_SLICES} \
    --tpu-type=${TPU_TYPE} \
    --zone=${ZONE}  \
    --project=${PROJECT_ID} \
    --on-demand \
    --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \
    --custom-nodepool-arguments="${NODE_POOL_ARGUMENTS}" \
    --create-vertex-tensorboard

Beschreibung der Befehls-Flags

Variable Beschreibung
CLUSTER_NAME Der vom Nutzer zugewiesene Name für den XPK-Cluster.
PROJECT_ID Google Cloud ist der Projektname. Sie können ein vorhandenes Projekt verwenden oder ein neues erstellen. Weitere Informationen finden Sie unter Google Cloud Projekt einrichten.
ZONE Informationen zu den unterstützten Zonen finden Sie im Dokument Cloud TPU-Regionen und ‑Zonen.
TPU_TYPE Weitere Informationen finden Sie unter Beschleunigertypen.
NUM_SLICES Die Anzahl der Scheiben, die Sie erstellen möchten
CLUSTER_ARGUMENTS Das zu verwendende Netzwerk und Subnetzwerk.

Beispiel: --enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}

NODE_POOL_ARGUMENTS Zu verwendendes zusätzliches Knotennetzwerk.

Beispiel: --additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}

NUM_SLICES Die Anzahl der zu erstellenden Schichten (nur für Mehrfachaufnahmen erforderlich).
NETWORK_NAME Der Name eines sekundären Netzwerks, das verwendet werden soll.
NETWORK_FW_NAME Der Name einer sekundären Netzwerk-Firewall, die verwendet werden soll.

Framework einrichten

In diesem Abschnitt wird die allgemeine Einrichtung für das Training von ML-Modellen mit den Frameworks JAX und PyTorch beschrieben. Wenn Sie GKE verwenden, können Sie XPK- oder Kubernetes-Befehle für die Framework-Einrichtung verwenden.

Einrichtung für JAX

In diesem Abschnitt finden Sie eine Anleitung zum Ausführen von JAX-Arbeitslasten mit oder ohne XPK in GKE sowie zum Verwenden von Ressourcen in der Warteschlange.

JAX mit GKE einrichten

Einzelnes Slice auf einem einzelnen Host

Im folgenden Beispiel wird ein 2 × 2-Knotenpool mit einem einzelnen Host mithilfe einer Kubernetes-YAML-Datei eingerichtet.

apiVersion: v1
kind: Pod
metadata:
  name: tpu-pod-jax-v6e-a
spec:
  restartPolicy: Never
  nodeSelector:
    cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
    cloud.google.com/gke-tpu-topology: 2x2
  containers:
  - name: tpu-job
    image: python:3.10
    securityContext:
      privileged: true
    command:
    - bash
    - -c
    - |
      pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
      JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python3 -c 'import jax; print("Total TPU chips:", jax.device_count())'
    resources:
      requests:
        google.com/tpu: 4
      limits:
        google.com/tpu: 4

Nach dem erfolgreichen Abschluss sollte im GKE-Log die folgende Meldung angezeigt werden:

Total TPU chips: 4

Einzelnes Slice auf mehreren Hosts

Im folgenden Beispiel wird ein 4 × 4-Knotenpool mit mehreren Hosts mithilfe einer Kubernetes-YAML-Datei eingerichtet.

apiVersion: v1
kind: Service
metadata:
  name: headless-svc
spec:
  clusterIP: None
  selector:
    job-name: tpu-available-chips
---
apiVersion: batch/v1
kind: Job
metadata:
  name: tpu-available-chips
spec:
  backoffLimit: 0
  completions: 4
  parallelism: 4
  completionMode: Indexed
  template:
    spec:
      subdomain: headless-svc
      restartPolicy: Never
      nodeSelector:
        cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
        cloud.google.com/gke-tpu-topology: 4x4
      containers:
      - name: tpu-job
        image: python:3.10
        ports:
        - containerPort: 8471 # Default port using which TPU VMs communicate
        - containerPort: 8431 # Port to export TPU runtime metrics, if supported.
        securityContext:
          privileged: true
        command:
        - bash
        - -c
        - |
          pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
          JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
        resources:
          requests:
            google.com/tpu: 4
          limits:
            google.com/tpu: 4

Nach dem erfolgreichen Abschluss sollte im GKE-Log die folgende Meldung angezeigt werden:

Total TPU chips: 16

Multislice auf mehreren Hosts

Im folgenden Beispiel werden zwei 4 × 4-Multihost-Knotenpools mit einer Kubernetes-YAML-Datei eingerichtet.

Als Voraussetzung müssen Sie JobSet v0.2.3 oder höher installieren.

apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
  name: multislice-job
  annotations:
    alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
spec:
  failurePolicy:
    maxRestarts: 4
  replicatedJobs:
    - name: slice
      replicas: 2
      template:
        spec:
          parallelism: 4
          completions: 4
          backoffLimit: 0
          template:
            spec:
              hostNetwork: true
              dnsPolicy: ClusterFirstWithHostNet
              nodeSelector:
                cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                cloud.google.com/gke-tpu-topology: 4x4
              hostNetwork: true
              containers:
              - name: jax-tpu
                image: python:3.10
                ports:
                - containerPort: 8471
                - containerPort: 8080
                - containerPort: 8431
                securityContext:
                  privileged: true
                command:
                - bash
                - -c
                - |
                  pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                  JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
                resources:
                  limits:
                   google.com/tpu: 4
                  requests:
                   google.com/tpu: 4

Nach dem erfolgreichen Abschluss sollte im GKE-Log die folgende Meldung angezeigt werden:

Total TPU chips: 32

Weitere Informationen finden Sie in der GKE-Dokumentation unter Multi-Slice-Arbeitslast ausführen.

Aktivieren Sie für eine bessere Leistung hostNetwork.

Mehrere NICs

Damit Sie das folgende Manifest mit mehreren NICs verwenden können, müssen Sie Ihre Netzwerke einrichten. Weitere Informationen finden Sie unter Unterstützung mehrerer Netzwerke für Kubernetes-Pods einrichten. Wenn Sie mehrere NICs in GKE nutzen möchten, müssen Sie dem Kubernetes-Pod-Manifest einige zusätzliche Anmerkungen hinzufügen. Im Folgenden finden Sie ein Beispielmanifest für eine Arbeitslast mit mehreren NICs ohne TPU.

apiVersion: v1
kind: Pod
metadata:
  name: sample-netdevice-pod-1
  annotations:
    networking.gke.io/default-interface: 'eth0'
    networking.gke.io/interfaces: |
      [
        {"interfaceName":"eth0","network":"default"},
        {"interfaceName":"eth1","network":"netdevice-network"}
      ]
spec:
  containers:
  - name: sample-netdevice-pod
    image: busybox
    command: ["sleep", "infinity"]
    ports:
    - containerPort: 80
  restartPolicy: Always
  tolerations:
  - key: "google.com/tpu"
    operator: "Exists"
    effect: "NoSchedule"

Wenn Sie den Befehl exec verwenden, um eine Verbindung zum Kubernetes-Pod herzustellen, sollte die zusätzliche NIC mit dem folgenden Code angezeigt werden.

$ kubectl exec --stdin --tty sample-netdevice-pod-1 -- /bin/sh
/ # ip a
1: lo: <LOOPBACK,UP,LOWER_UP> mtu 65536 qdisc noqueue qlen 1000
    link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00
    inet 127.0.0.1/8 scope host lo
       valid_lft forever preferred_lft forever
2: eth0@if11: <BROADCAST,MULTICAST,UP,LOWER_UP,M-DOWN> mtu 1460 qdisc noqueue
    link/ether da:be:12:67:d2:25 brd ff:ff:ff:ff:ff:ff
    inet 10.124.2.6/24 brd 10.124.2.255 scope global eth0
       valid_lft forever preferred_lft forever
3: eth1: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1460 qdisc mq qlen 1000
    link/ether 42:01:ac:18:00:04 brd ff:ff:ff:ff:ff:ff
    inet 172.24.0.4/32 scope global eth1
       valid_lft forever preferred_lft forever

JAX mit GKE und XPK einrichten

Informationen zum Einrichten von JAX mit GKE und XPK finden Sie in der README-Datei für XPK.

Informationen zum Einrichten und Ausführen von XPK mit MaxText finden Sie unter MaxText ausführen.

JAX mit in die Warteschlange gestellten Ressourcen einrichten

Sie können JAX mit dem Befehl gcloud alpha compute tpus tpu-vm ssh gleichzeitig auf allen Cloud TPU-VMs in Ihrem Slice oder Ihren Slices installieren. Fügen Sie für „Multislice“ das Flag --node=all hinzu.

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
 --zone ${ZONE} --worker=all \
 --command='pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

Mit dem folgenden Befehl können Sie prüfen, wie viele Cloud TPU-Kerne in Ihrem Slice verfügbar sind, und testen, ob alles richtig installiert ist:

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
   --zone ${ZONE} --worker=all  \
   --command='python3 -c "import jax; print(jax.device_count(), jax.local_device_count())"'

Die Ausgabe sieht in etwa so aus, wenn der Befehl auf einem v6e-16-Speicherblock ausgeführt wird:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16 4
16 4
16 4
16 4

jax.device_count() gibt die Gesamtzahl der Chips im angegebenen Slice an. jax.local_device_count() gibt die Anzahl der Chips an, auf die eine einzelne VM in diesem Slice zugreifen kann.


 gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} --zone=${ZONE} --worker=all  \
   --command='git clone -b mlperf4.1 https://github.com/google/maxdiffusion.git &&
   cd maxdiffusion && git checkout 975fdb7dbddaa9a53ad72a421cdb487dcdc491a3 &&
   pip install setuptools==59.6.0 &&
   pip install -r requirements.txt  && pip install . '

Probleme bei der JAX-Einrichtung beheben

Als allgemeinen Tipp können Sie die ausführliche Protokollierung in Ihrem GKE-Arbeitslastmanifest aktivieren. Senden Sie die Protokolle dann an den GKE-Support.

TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0

Fehlermeldungen

no endpoints available for service 'jobset-webhook-service'

Dieser Fehler bedeutet, dass der Jobsatz nicht richtig installiert wurde. Prüfen Sie, ob die Kubernetes-Pods der Bereitstellung „jobset-controller-manager“ ausgeführt werden. Weitere Informationen finden Sie in der Dokumentation zur Fehlerbehebung bei JobSets.

TPU initialization failed: Failed to connect

Die GKE-Knotenversion muss 1.30.4-gke.1348000 oder höher sein. GKE 1.31 wird nicht unterstützt.

Einrichtung für PyTorch

In diesem Abschnitt wird beschrieben, wie Sie PJRT unter v6e mit PyTorch/XLA verwenden. Python 3.10 ist die empfohlene Python-Version.

PyTorch mit GKE und XPK einrichten

Sie können den folgenden Docker-Container mit XPK verwenden, in dem die PyTorch-Abhängigkeiten bereits installiert sind:

us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028

Führen Sie den folgenden Befehl aus, um eine XPK-Arbeitslast zu erstellen:

python3 xpk.py workload create \
    --cluster ${CLUSTER_NAME} \
    {--base-docker-image maxtext_base_image | --docker-image your-cloud-image-name \
    --workload ${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
    --tpu-type=${ACCELERATOR_TYPE} \
    --num-slices=${NUM_SLICES}  \
    --on-demand \
    --zone ${ZONE} \
    --project ${PROJECT_ID} \
    --enable-debug-logs \
    --command 'python3 -c "import torch; import torch_xla; import torch_xla.runtime as xr; print(xr.global_runtime_device_count())"'

Mit --base-docker-image wird ein neues Docker-Image erstellt, in das das aktuelle Arbeitsverzeichnis eingebunden ist.

PyTorch mit Ressourcen in der Warteschlange einrichten

Führen Sie die folgenden Schritte aus, um PyTorch mit Ressourcen in der Warteschlange zu installieren und ein kleines Script auf v6e auszuführen.

Abhängigkeiten über SSH installieren, um auf die VMs zuzugreifen

Mit dem folgenden Befehl können Sie Abhängigkeiten auf allen Cloud TPU-VMs installieren. Fügen Sie für „Multislice“ das Flag --worker=all hinzu:

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='sudo apt update && sudo apt install -y python3-pip libopenblas-base && \
               pip3 install torch~=2.6.0 "torch_xla[tpu]~=2.6.0" -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'

Leistung von Modellen mit großen, häufigen Zuweisungen verbessern

Bei Modellen mit großen, häufigen Zuweisungen wird mit der tcmalloc-Funktion die Leistung im Vergleich zur Standardimplementierung der malloc-Funktion erheblich verbessert. Daher ist die Standardmalloc-Funktion, die auf Cloud TPU-VMs verwendet wird, tcmalloc. Je nach Arbeitslast (z. B. bei DLRM mit sehr großen Zuweisungen für die Einbettungstabellen) kann die tcmalloc-Funktion jedoch zu einer Verlangsamung führen. In diesem Fall können Sie versuchen, die folgende Variable mit der Standardfunktion malloc zurückzusetzen:

unset LD_PRELOAD

Mit einem Python-Script eine Berechnung auf einer v6e-VM ausführen

Mit dem folgenden Befehl wird ein Script ausgeführt, das zwei Tensoren erstellt, addiert und das Ergebnis ausgibt.

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
   --project ${PROJECT_ID} \
   --zone ${ZONE} --worker all --command='
   unset LD_PRELOAD
   python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"
'

Dadurch wird eine Ausgabe generiert, die etwa so aussieht:

SSH: Attempting to connect to worker 0...
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
xla:0
tensor([[ 0.3355, -1.4628, -3.2610],
        [-1.4656,  0.3196, -2.8766],
        [ 0.8668, -1.5060,  0.7125]], device='xla:0')

v6e mit SkyPilot

Sie können Cloud TPU v6e mit SkyPilot verwenden. So fügen Sie SkyPilot v6e-bezogene Standort- und Preisinformationen hinzu: Weitere Informationen finden Sie im Beispiel für SkyPilot TPU v6e.

Anleitungen für Inferenz

In den folgenden Anleitungen erfahren Sie, wie Sie die Inferenz auf Cloud TPU v6e ausführen:

Trainingsbeispiele

In den folgenden Abschnitten finden Sie Beispiele für das Training von MaxText-, MaxDiffusion- und PyTorch-Modellen auf Cloud TPU v6e.

MaxText- und MaxDiffusion-Training auf einer v6e-Cloud TPU-VM

In den folgenden Abschnitten wird der Trainingszyklus der Modelle MaxText und MaxDiffusion beschrieben.

Im Allgemeinen sind dies die allgemeinen Schritte:

  1. Erstellen Sie das Basis-Image der Arbeitslast.
  2. Führen Sie Ihre Arbeitslast mit XPK aus.
    1. Erstellen Sie den Trainingsbefehl für die Arbeitslast.
    2. Stellen Sie die Arbeitslast bereit.
  3. Arbeitslast verfolgen und Messwerte ansehen
  4. Löschen Sie die XPK-Arbeitslast, wenn sie nicht benötigt wird.
  5. Löschen Sie den XPK-Cluster, wenn er nicht mehr benötigt wird.

Basis-Image erstellen

Installieren Sie MaxText oder MaxDiffusion und erstellen Sie das Docker-Image:

  1. Klonen Sie das gewünschte Repository und wechseln Sie in das Verzeichnis für das Repository:

    MaxText:

    git clone https://github.com/google/maxtext.git && cd maxtext
    

    MaxDiffusion:

    git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion
    
  2. Konfigurieren Sie Docker so, dass die Google Cloud CLI verwendet wird:

    gcloud auth configure-docker
    
  3. Erstellen Sie das Docker-Image mit dem folgenden Befehl oder mit JAX Stable Stack. Weitere Informationen zum JAX Stable Stack finden Sie unter Docker-Image mit JAX Stable Stack erstellen.

    bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.35
    
  4. Legen Sie Ihre Projekt-ID in der aktiven gcloud CLI-Konfiguration fest:

    gcloud config set project ${PROJECT_ID}
    
  5. Wenn Sie die Arbeitslast von einem Computer aus starten, auf dem das Image nicht lokal erstellt wurde, laden Sie das Image hoch.

    1. Legen Sie die Umgebungsvariable CLOUD_IMAGE_NAME fest:

      export CLOUD_IMAGE_NAME=${USER}_runner
      
    2. Laden Sie das Bild hoch:

      bash docker_upload_runner.sh ${CLOUD_IMAGE_NAME}
      

Arbeitslast mit XPK ausführen

  1. Legen Sie die folgenden Umgebungsvariablen fest, wenn Sie nicht die Standardwerte von MaxText oder MaxDiffusion verwenden:

    export BASE_OUTPUT_DIR=gs://YOUR_BUCKET
    export PER_DEVICE_BATCH_SIZE=2
    export NUM_STEPS=30
    export MAX_TARGET_LENGTH=8192
  2. Erstellen Sie das Modellscript. Dieses Script wird in einem späteren Schritt als Trainingsbefehl kopiert.

    Führen Sie das Modellskript noch nicht aus.

    MaxText

    MaxText ist ein leistungsstarkes, hoch skalierbares Open-Source-LLM, das in reiner Python- und JAX-Programmierung geschrieben wurde und für das Training und die Inferenz auf Google Cloud TPUs und GPUs ausgerichtet ist.

    JAX_PLATFORMS=tpu,cpu \
    ENABLE_PJRT_COMPATIBILITY=true \
    TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \
    TPU_SLICE_BUILDER_DUMP_ICI=true && \
    python3 -m MaxText.train MaxText/configs/base.yml \
            base_output_directory=${BASE_OUTPUT_DIR} \
            dataset_type=synthetic \
            per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
            enable_checkpointing=false \
            gcs_metrics=true \
            profiler=xplane \
            skip_first_n_steps_for_profiler=5 \
            steps=${NUM_STEPS}  # attention='dot_product'"
    

    Gemma2

    Gemma ist eine Familie von LLMs mit offenen Gewichten, die von Google DeepMind entwickelt wurden und auf der Gemini-Forschung und -Technologie basieren.

    python3 -m MaxText.train MaxText/configs/base.yml \
        model_name=gemma2-27b \
        run_name=gemma2-27b-run \
        base_output_directory=${BASE_OUTPUT_DIR} \
        max_target_length=${MAX_TARGET_LENGTH} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        steps=${NUM_STEPS} \
        enable_checkpointing=false \
        use_iota_embed=true \
        gcs_metrics=true \
        dataset_type=synthetic \
        profiler=xplane \
        attention=flash
    

    Mixtral 8x7b

    Mixtral ist ein hochmodernes KI-Modell, das von Mistral AI entwickelt wurde und eine sparse MoE-Architektur (Mixture of Experts) verwendet.

    python3 -m MaxText.train MaxText/configs/base.yml \
        base_output_directory=${BASE_OUTPUT_DIR} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        model_name=mixtral-8x7b \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        tokenizer_path=assets/tokenizer.mistral-v1 \
        attention=flash \
        dtype=bfloat16 \
        dataset_type=synthetic \
        profiler=xplane
    

    Llama3-8b

    Llama ist eine Familie von LLMs mit offenen Gewichten, die von Meta entwickelt wurden.

    Ein Beispiel für die Ausführung von Llama3 in PyTorch finden Sie unter torch_xla-Modelle im Repository „torchprime“.

    MaxDiffusion

    MaxDiffusion ist eine Sammlung von Referenzimplementierungen verschiedener latenter Diffusionsmodelle, die in reiner Python- und JAX-Programmierung geschrieben wurden und auf XLA-Geräten wie Cloud TPUs und GPUs ausgeführt werden. Stable Diffusion ist ein latentes Text-zu-Bild-Modell, mit dem fotorealistische Bilder aus beliebigen Texteingaben generiert werden.

    Sie müssen einen bestimmten Git-Branch installieren, um MaxDiffusion auszuführen, wie im folgenden git checkout-Befehl gezeigt.

    git clone https://github.com/google/maxdiffusion.git
    && cd maxdiffusion
    && git checkout e712c9fc4cca764b0930067b6e33daae2433abf0
    && pip install -r requirements.txt
    && pip install .
    

    Trainingsskript:

        cd maxdiffusion && OUT_DIR=${BASE_OUTPUT_DIR} \
        python src/maxdiffusion/train_sdxl.py \
        src/maxdiffusion/configs/base_xl.yml \
        revision=refs/pr/95 \
        activations_dtype=bfloat16 \
        weights_dtype=bfloat16 \
        resolution=1024 \
        per_device_batch_size=1 \
        output_dir=${OUT_DIR}  \
        jax_cache_dir=${OUT_DIR}/cache_dir/ \
        max_train_steps=200 \
        attention=flash run_name=sdxl-ddp-v6e
        
  3. Exportieren Sie die folgenden Variablen:

    export CLUSTER_NAME=CLUSTER_NAME
    export ACCELERATOR_TYPE=ACCELERATOR_TYPE
    export NUM_SLICES=NUM_SLICES
    export YOUR_MODEL_SCRIPT=YOUR_MODEL_SCRIPT

    Beschreibungen von Umgebungsvariablen

    Variable Beschreibung
    CLUSTER_NAME Der Name Ihres XPK-Clusters.
    ACCELERATOR_TYPE Weitere Informationen finden Sie unter Beschleunigertypen.
    NUM_SLICES Die Anzahl der TPU-Slices.
    YOUR_MODEL_SCRIPT Das Modellskript, das als Trainingsbefehl ausgeführt werden soll.
  4. Führen Sie das Modell mit dem im vorherigen Schritt erstellten Script aus. Sie müssen entweder das Flag --base-docker-image angeben, um das MaxText-Basisbild zu verwenden, oder das Flag --docker-image und das gewünschte Bild.

    Optional: Sie können das Debug-Logging aktivieren, indem Sie das Flag --enable-debug-logs einfügen. Weitere Informationen finden Sie unter JAX in MaxText debuggen.

    Optional: Sie können einen Vertex AI-Test erstellen, um Daten in Vertex AI TensorBoard hochzuladen. Fügen Sie dazu das Flag --use-vertex-tensorboard hinzu. Weitere Informationen finden Sie unter JAX auf MaxText mit Vertex AI überwachen.

    python3 xpk.py workload create \
        --cluster ${CLUSTER_NAME} \
        {--base-docker-image maxtext_base_image|--docker-image gcr.io/${PROJECT_ID}/${CLOUD_IMAGE_NAME}:latest} \
        --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
        --tpu-type=${ACCELERATOR_TYPE} \
        --num-slices=${NUM_SLICES}  \
        --on-demand \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        [--enable-debug-logs] \
        [--use-vertex-tensorboard] \
        --command=${YOUR_MODEL_SCRIPT}

    Die Ausgabe enthält einen Link, über den Sie Ihre Arbeitslast verfolgen können. Öffnen Sie den Link und klicken Sie auf den Tab Protokolle, um Ihre Arbeitslast in Echtzeit zu verfolgen.

JAX in MaxText debuggen

Verwenden Sie zusätzliche XPK-Befehle, um zu ermitteln, warum der Cluster oder die Arbeitslast nicht ausgeführt wird.

  • XPK-Arbeitslastliste
  • XPK-Inspektor
  • Aktivieren Sie beim Erstellen der XPK-Arbeitslast mit dem Flag --enable-debug-logs ausführliche Protokolle in Ihren Arbeitslastprotokollen.

JAX auf MaxText mit Vertex AI überwachen

Damit Sie TensorBoard verwenden können, muss Ihr Nutzerkonto die Rolle aiplatform.user haben. Google Cloud Führen Sie den folgenden Befehl aus, um diese Rolle zuzuweisen:

gcloud projects add-iam-policy-binding your-project-id \
--member='user:your-email' \
--role='roles/aiplatform.user'

Skalar- und Profildaten über das verwaltete TensorBoard von Vertex AI aufrufen

  1. Erhöhen Sie die Anzahl der Resource Management (CRUD)-Anfragen für die von Ihnen verwendete Zone von 600 auf 5.000. Bei kleinen Arbeitslasten mit weniger als 16 VMs ist das möglicherweise kein Problem.
  2. Installieren Sie Abhängigkeiten wie cloud-accelerator-diagnostics für Vertex AI:

    # xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI
    cd ~/xpk
    pip install .
  3. Erstellen Sie Ihren XPK-Cluster mit dem Flag --create-vertex-tensorboard, wie unter Vertex AI TensorBoard erstellen beschrieben. Sie können diesen Befehl auch auf vorhandenen Clustern ausführen.

  4. Erstellen Sie Ihren Vertex AI-Test, wenn Sie Ihre XPK-Arbeitslast mit dem Flag --use-vertex-tensorboard und dem optionalen Flag --experiment-name ausführen. Eine vollständige Liste der Schritte finden Sie unter Vertex AI-Test erstellen, um Daten in Vertex AI TensorBoard hochzuladen.

Die Protokolle enthalten einen Link zu einem Vertex AI TensorBoard, ähnlich wie hier:

View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name

Sie finden den Link zum Vertex AI TensorBoard auch in der Google Cloud Console. Rufen Sie in der Google Cloud -Konsole Vertex AI Experiments auf. Wählen Sie im Drop-down-Menü die gewünschte Region aus.

Das TensorBoard-Verzeichnis wird auch in den Cloud Storage-Bucket geschrieben, den Sie mit ${BASE_OUTPUT_DIR} angegeben haben.

XPK-Arbeitslasten löschen

Verwenden Sie den Befehl xpk workload delete, um eine oder mehrere Arbeitslasten basierend auf dem Jobpräfix oder dem Jobstatus zu löschen. Dieser Befehl kann nützlich sein, wenn Sie XPK-Arbeitslasten gesendet haben, die nicht mehr ausgeführt werden müssen, oder wenn Jobs in der Warteschlange hängen.

XPK-Cluster löschen

Verwenden Sie den Befehl xpk cluster delete, um einen Cluster zu löschen:

python3 xpk.py cluster delete --cluster ${CLUSTER_NAME} \
--zone=${ZONE} --project=${PROJECT_ID}

Llama- und PyTorch/XLA-Training auf einer v6e-Cloud TPU-VM

In dieser Anleitung wird beschrieben, wie Sie Llama-Modelle mit PyTorch/XLA auf Cloud TPU v6e mit dem WikiText-Dataset trainieren.

Zugriff auf Hugging Face und das Llama 3-Modell erhalten

Sie benötigen ein Hugging Face-Nutzerzugriffstoken, um dieses Tutorial auszuführen. Informationen zum Erstellen und Verwenden von Nutzerzugriffstokens finden Sie in der Hugging Face-Dokumentation zu Nutzerzugriffstokens.

Außerdem benötigen Sie die Berechtigung, auf das Llama 3 8B-Modell auf Hugging Face zuzugreifen. Wenn Sie Zugriff erhalten möchten, rufen Sie das Meta-Llama-3-8B-Modell auf Hugging Face auf und beantragen Sie Zugriff.

Cloud TPU-VM erstellen

Erstellen Sie eine Cloud TPU v6e mit 8 Chips, um die Anleitung auszuführen.

  1. Richten Sie Umgebungsvariablen ein:

    export ACCELERATOR_TYPE=v6e-8
    export VERSION=v2-alpha-tpuv6e
    export TPU_NAME=$USER-$ACCELERATOR_TYPE
    export PROJECT_ID=your-project-id
    export ZONE=us-east1-d
  2. So erstellen Sie eine Cloud TPU-VM:

    gcloud alpha compute tpus tpu-vm create ${TPU_NAME} --version=${VERSION} \
        --accelerator-type=${ACCELERATOR_TYPE} \
        --zone=${ZONE} \
        --project=${PROJECT_ID}

Installation

Installieren Sie den pytorch-tpu/transformers-Fork der Hugging Face-Transformer und die Abhängigkeiten. Diese Anleitung wurde mit den folgenden Abhängigkeitsversionen getestet, die in diesem Beispiel verwendet werden:

  • torch: kompatibel mit 2.5.0
  • torch_xla[tpu]: kompatibel mit 2.5.0
  • jax: 0.4.33
  • jaxlib: 0.4.33
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} --zone ${ZONE} \
    --worker=all --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
    cd transformers
    sudo pip3 install -e .
    pip3 install datasets
    pip3 install evaluate
    pip3 install scikit-learn
    pip3 install accelerate
    pip install torch~=2.6.0 torch_xla[tpu]~=2.6.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
    pip install jax==0.4.38 jaxlib==0.4.38 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'

Modellkonfigurationen einrichten

Im Trainingsbefehl im nächsten Abschnitt, Modell ausführen, werden zwei JSON-Konfigurationsdateien verwendet, um die Modellparameter und die FSDP-Konfiguration (Fully Sharded Data Parallel) zu definieren. Das FSDP-Sharding wird verwendet, damit die Modellgewichte während des Trainings zu einer größeren Batchgröße passen. Beim Training mit kleineren Modellen reicht es möglicherweise aus, Datenparallelität zu verwenden und die Gewichte auf jedem Gerät zu replizieren. Weitere Informationen zum Sharding von Tensoren auf Geräten in PyTorch/XLA finden Sie im PyTorch/XLA SPMD-Nutzerhandbuch.

  1. Erstellen Sie die Konfigurationsdatei für die Modellparameter. Im Folgenden finden Sie die Modellparameterkonfiguration für Llama3-8B. Für andere Modelle finden Sie die Konfiguration auf Hugging Face. Siehe z. B. die Llama2-7B-Konfiguration.

    cat > llama-config.json << EOF
    {
        "architectures": [
            "LlamaForCausalLM"
        ],
        "attention_bias": false,
        "attention_dropout": 0.0,
        "bos_token_id": 128000,
        "eos_token_id": 128001,
        "hidden_act": "silu",
        "hidden_size": 4096,
        "initializer_range": 0.02,
        "intermediate_size": 14336,
        "max_position_embeddings": 8192,
        "model_type": "llama",
        "num_attention_heads": 32,
        "num_hidden_layers": 32,
        "num_key_value_heads": 8,
        "pretraining_tp": 1,
        "rms_norm_eps": 1e-05,
        "rope_scaling": null,
        "rope_theta": 500000.0,
        "tie_word_embeddings": false,
        "torch_dtype": "bfloat16",
        "transformers_version": "4.40.0.dev0",
        "use_cache": false,
        "vocab_size": 128256
    }
    EOF
    
  2. Erstellen Sie die FSDP-Konfigurationsdatei:

    cat > fsdp-config.json << EOF
    {
        "fsdp_transformer_layer_cls_to_wrap": [
            "LlamaDecoderLayer"
        ],
        "xla": true,
        "xla_fsdp_v2": true,
        "xla_fsdp_grad_ckpt": true
    }
    EOF
    

    Weitere Informationen zu FSDP finden Sie unter FSDPv2.

  3. Laden Sie die Konfigurationsdateien mit dem folgenden Befehl auf Ihre Cloud TPU-VMs hoch:

    gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json ${TPU_NAME}:. \
        --worker=all \
        --project=${PROJECT_ID} \
        --zone=${ZONE}

Modell ausführen

Führen Sie mit den Konfigurationsdateien, die Sie im vorherigen Abschnitt erstellt haben, das run_clm.py-Script aus, um das Llama 3 8B-Modell auf dem WikiText-Dataset zu trainieren. Die Ausführung des Trainingsscripts auf einer Cloud TPU v6e-8 dauert etwa 10 Minuten.

  1. Melden Sie sich mit dem folgenden Befehl auf Ihrer Cloud TPU bei Hugging Face an:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} \
        --zone ${ZONE} \
        --worker=all \
        --command='
        pip3 install "huggingface_hub[cli]"
        huggingface-cli login --token HUGGING_FACE_TOKEN'
  2. Modelltraining ausführen:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} \
        --zone ${ZONE} \
        --worker=all \
        --command='
        export PJRT_DEVICE=TPU
        export XLA_USE_SPMD=1
        export ENABLE_PJRT_COMPATIBILITY=true
            # Optional variables for debugging:
        export XLA_IR_DEBUG=1
        export XLA_HLO_DEBUG=1
        export PROFILE_EPOCH=0
        export PROFILE_STEP=3
        export PROFILE_DURATION_MS=100000
            # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path
        export PROFILE_LOGDIR=PROFILE_PATH
        python3 transformers/examples/pytorch/language-modeling/run_clm.py \
        --dataset_name wikitext \
        --dataset_config_name wikitext-2-raw-v1 \
        --per_device_train_batch_size 16 \
        --do_train \
        --output_dir /home/$USER/tmp/test-clm \
        --overwrite_output_dir \
        --config_name /home/$USER/llama-config.json \
        --cache_dir /home/$USER/cache \
        --tokenizer_name meta-llama/Meta-Llama-3-8B \
        --block_size 8192 \
        --optim adafactor \
        --save_strategy no \
        --logging_strategy no \
        --fsdp "full_shard" \
        --fsdp_config /home/$USER/fsdp-config.json \
        --torch_dtype bfloat16 \
        --dataloader_drop_last yes \
        --flash_attention \
        --max_steps 20'

Fehlerbehebung bei PyTorch/XLA

Wenn Sie die optionalen Variablen für das Debuggen im vorherigen Abschnitt festgelegt haben, wird das Profil für das Modell an dem Speicherort gespeichert, der in der Variablen PROFILE_LOGDIR angegeben ist. Sie können die dort gespeicherte xplane.pb-Datei extrahieren und mit tensorboard die Profile in Ihrem Browser anzeigen lassen. Folgen Sie dazu der TensorBoard-Anleitung. Wenn PyTorch/XLA nicht wie erwartet funktioniert, lesen Sie den Leitfaden zur Fehlerbehebung. Dort finden Sie Vorschläge zum Debuggen, Profilieren und Optimieren Ihres Modells.

Benchmarking-Ergebnisse

Der folgende Abschnitt enthält Benchmarkergebnisse für MaxDiffusion in v6e.

MaxDiffusion

Wir haben das Trainingsskript für MaxDiffusion auf einer v6e-4, einer v6e-16 und zwei v6e-16 ausgeführt. Die Durchlaufraten finden Sie in der folgenden Tabelle.

v6e-4 v6e-16 Zwei v6e-16
Trainingsschritte 0.069 0,073 0,13
Globale Batchgröße 8 32 64
Durchsatz (Beispiele/Sek.) 115,9 438,4 492,3