Einführung in Trillium (v6e)
In dieser Dokumentation, der TPU API und den Protokollen wird mit „v6e“ auf Trillium verwiesen. „v6e“ steht für die sechste Generation von TPUs von Google.
Mit 256 Chips pro Pod weist die v6e-Architektur viele Ähnlichkeiten mit v5e auf. Dieses System ist für das Training, die Feinabstimmung und die Bereitstellung von Transformer-, Text-zu-Bild- und Convolutional Neural Network-Modellen (CNN) optimiert.
Weitere Informationen zur Systemarchitektur und zu den Konfigurationen von TPU v6e finden Sie unter TPU v6e.
In diesem Einführungsdokument werden die Prozesse für das Modelltraining und die Bereitstellung mit den Frameworks JAX oder PyTorch beschrieben. Mit jedem Framework können Sie TPUs mit in die Warteschlange eingereihten Ressourcen oder GKE bereitstellen. Die GKE-Einrichtung kann mit XPK- oder GKE-Befehlen erfolgen.
Allgemeine Vorgehensweise zum Trainieren oder Bereitstellen eines Modells mit v6e
- Google Cloud Projekt vorbereiten
- Sichere Kapazität
- Cloud TPU-Umgebung bereitstellen
- Eine Trainings- oder Inferenz-Arbeitslast für ein Modell ausführen
Google Cloud -Projekt vorbereiten
Bevor Sie Cloud TPU verwenden können, müssen Sie Folgendes tun:
- Google Cloud Konto und Projekt mit aktivierter Abrechnung erstellen
- Google Cloud CLI-Alphakomponenten installieren
- Cloud TPU API aktivieren
- Cloud TPU-Dienst-Agent erstellen
- Cloud TPU-Dienstkonto erstellen und Berechtigungen erteilen
Weitere Informationen finden Sie unter Cloud TPU-Umgebung einrichten.
Kapazität sichern
Wenden Sie sich an den Google Cloud -Support, um ein Cloud TPU v6e-Kontingent anzufordern und Fragen zur Kapazität zu klären.
Cloud TPU-Umgebung bereitstellen
v6e Cloud TPU können mit GKE, mit GKE und XPK (ein Wrapper-CLI-Tool für GKE) oder als in die Warteschlange eingereihte Ressourcen bereitgestellt und verwaltet werden.
Vorbereitung
- Prüfen Sie, ob Ihr Projekt über genügend
TPUS_PER_TPU_FAMILY
-Kontingent verfügt. Damit wird die maximale Anzahl von Chips angegeben, auf die Sie in Ihrem Google Cloud-Projekt zugreifen können. - v6e wurde mit der folgenden Konfiguration getestet:
- Python
3.10
oder höher - Nightly-Softwareversionen:
- Nächtlicher JAX-Wert
0.4.32.dev20240912
- Nächtliche LibTPU-Version
0.1.dev20240912+nightly
- Nächtlicher JAX-Wert
- Stabile Softwareversionen:
- JAX + JAX Lib v0.4.37
- Python
Prüfen Sie, ob Ihr Projekt genügend Kontingente für Folgendes hat:
- Cloud TPU-VM-Kontingent
- Kontingent für IP-Adressen
Kontingent für Hyperdisk Balanced und alle anderen Laufwerkstypen, die Sie verwenden möchten
Wenn Sie GKE mit XPK verwenden, finden Sie unter Cloud Console-Berechtigungen für das Nutzer- oder Dienstkonto die Berechtigungen, die zum Ausführen von XPK erforderlich sind.
Umgebungsvariablen erstellen
Erstellen Sie in einer Cloud Shell die folgenden Umgebungsvariablen:
export NODE_ID=your-tpu-name export PROJECT_ID=your-project-id export ACCELERATOR_TYPE=v6e-16 export ZONE=us-east1-d export RUNTIME_VERSION=v2-alpha-tpuv6e export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id export VALID_DURATION=your-duration # Additional environment variable needed for Multislice: export NUM_SLICES=number-of-slices # Use a custom network for better performance as well as to avoid having the default network becoming overloaded. export NETWORK_NAME=${PROJECT_ID}-mtu9k export NETWORK_FW_NAME=${NETWORK_NAME}-fw
Beschreibung der Befehls-Flags
Variable | Beschreibung |
NODE_ID | Die vom Nutzer zugewiesene ID der Cloud TPU, die erstellt wird, wenn die in die Warteschlange gestellte Ressourcenanfrage zugewiesen wird. |
PROJECT_ID | Google Cloud Projektname. Verwenden Sie ein vorhandenes Projekt oder erstellen Sie ein neues. Weitere Informationen finden Sie unter Google Cloud -Projekt einrichten. |
ZONE | Informationen zu den unterstützten Zonen finden Sie im Dokument Cloud TPU-Regionen und -Zonen. |
ACCELERATOR_TYPE | Weitere Informationen finden Sie unter Beschleunigertypen. |
RUNTIME_VERSION | v2-alpha-tpuv6e
|
SERVICE_ACCOUNT | Dies ist die E-Mail-Adresse für Ihr Dienstkonto, die Sie in der Google Cloud Console -> IAM -> Dienstkonten finden.
Beispiel: |
NUM_SLICES | Die Anzahl der zu erstellenden Slices (nur für Multislice erforderlich). |
QUEUED_RESOURCE_ID | Die vom Nutzer zugewiesene Text-ID der in die Warteschlange eingereihten Ressourcenanfrage. |
VALID_DURATION | Die Dauer, für die die in die Warteschlange gestellte Ressourcenanfrage gültig ist. |
NETWORK_NAME | Der Name eines zu verwendenden sekundären Netzwerks. |
NETWORK_FW_NAME | Der Name einer sekundären Netzwerk-Firewall, die verwendet werden soll. |
Netzwerkleistung optimieren
Für eine optimale Leistung sollten Sie ein Netzwerk mit einer MTU (maximale Übertragungseinheit) von 8.896 verwenden.
Standardmäßig bietet eine Virtual Private Cloud (VPC) nur eine MTU von 1.460 Byte, was zu einer suboptimalen Netzwerkleistung führt. Sie können die MTU eines VPC-Netzwerk auf einen beliebigen Wert zwischen 1.300 Byte und 8.896 Byte (einschließlich) festlegen. Gängige benutzerdefinierte MTU-Größen sind 1.500 Byte (Standard-Ethernet) oder 8.896 Byte (das Maximum). Weitere Informationen finden Sie unter Gültige MTU-VPC-Netzwerk-Netzwerke.
Weitere Informationen zum Ändern der MTU-Einstellung für ein vorhandenes oder Standardnetzwerk finden Sie unter MTU-Einstellung eines VPC-Netzwerks ändern.
Im folgenden Beispiel wird ein Netzwerk mit einer MTU von 8.896 erstellt.
export RESOURCE_NAME=your-resource-name export NETWORK_NAME=${RESOURCE_NAME}-privatenetwork export NETWORK_FW_NAME=${RESOURCE_NAME}-privatefirewall gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT_ID} \ --subnet-mode=auto --bgp-routing-mode=regional gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network=${NETWORK_NAME} \ --allow tcp,icmp,udp --project=${PROJECT_ID}
Multi-NIC verwenden (Option für Multislice)
Die folgenden Umgebungsvariablen sind für ein sekundäres Subnetz erforderlich, wenn Sie eine Multislice-Umgebung verwenden.
export NETWORK_NAME_2=${RESOURCE_NAME} export SUBNET_NAME_2=${RESOURCE_NAME} export FIREWALL_RULE_NAME=${RESOURCE_NAME} export ROUTER_NAME=${RESOURCE_NAME}-network-2 export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2 export REGION=your-region
Verwenden Sie die folgenden Befehle, um ein benutzerdefiniertes IP-Routing für das Netzwerk und das Subnetz zu erstellen.
gcloud compute networks create ${NETWORK_NAME_2} --mtu=8896 \
--bgp-routing-mode=regional --subnet-mode=custom --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_2} \
--network=${NETWORK_NAME_2} \
--range=10.10.0.0/18 --region=${REGION} \
--project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
--network=${NETWORK_NAME_2} --allow tcp,icmp,udp \
--source-ranges 10.10.0.0/18 --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \
--project=${PROJECT_ID} \
--network=${NETWORK_NAME_2} \
--region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \
--router=${ROUTER_NAME} \
--region=${REGION} \
--auto-allocate-nat-external-ips \
--nat-all-subnet-ip-ranges \
--project=${PROJECT_ID} \
--enable-logging
Nachdem Sie einen Multi-Network-Slice erstellt haben, können Sie prüfen, ob beide Netzwerkkarten (NICs) verwendet werden. Dazu richten Sie einen XPK-Cluster ein und fügen dem Befehl zum Erstellen von XPK-Arbeitslasten das Flag --command ifconfig
hinzu.
Verwenden Sie den folgenden workload create
-Befehl, um die Ausgabe des ifconfig
-Befehls in den Google Cloud -Konsolenlogs anzuzeigen und zu prüfen, ob sowohl eth0 als auch eth1 mtu=8896 haben.
python3 xpk.py workload create \ --cluster CLUSTER_NAME \ {--base-docker-image maxtext_base_image | --docker-image your-cloud-image-name} \ --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \ --tpu-type=${ACCELERATOR_TYPE} \ --num-slices=${NUM_SLICES} \ --on-demand \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --command "ifconfig"
Wenn Sie Debug-Logs aktivieren oder Vertex AI TensorBoard verwenden möchten, fügen Sie dem Befehl die folgenden optionalen Argumente hinzu:
--enable-debug-logs \ --use-vertex-tensorboard
Prüfen Sie,ob sowohl eth0 als auch eth1 den Wert „mtu=8896“ haben. Sie können prüfen, ob die Multi-NIC ausgeführt wird, indem Sie dem Befehl zum Erstellen der XPK-Arbeitslast das Flag --command ifconfig
hinzufügen. Prüfen Sie die Ausgabe dieser XPK-Arbeitslast in den Google Cloud Konsolenlogs und prüfen Sie,ob sowohl eth0 als auch eth1 den Wert „mtu=8.896“ haben.
TCP-Einstellungen verbessern
Wenn Sie Ihre Cloud TPUs über die Schnittstelle für in die Warteschlange gestellte Ressourcen erstellt haben, können Sie die Netzwerkleistung verbessern, indem Sie die TCP-Empfangspufferlimits mit dem folgenden Befehl erhöhen.
gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \ --project "${PROJECT_ID}" \ --zone "${ZONE}" \ --node=all \ --worker=all \ --command=' sudo sh -c "echo \"4096 41943040 314572800\" > /proc/sys/net/ipv4/tcp_rmem"'
Mit in die Warteschlange gestellten Ressourcen bereitstellen
Sie können eine Cloud TPU v6e mit Ressourcen in der Warteschlange erstellen. Mit Ressourcen in der Warteschlange können Sie Kapazität erhalten, sobald sie verfügbar ist. Sie können optional eine Start- und Endzeit für die Bearbeitung der Anfrage angeben. Weitere Informationen finden Sie unter In die Warteschlange gestellte Ressourcen verwalten.
Cloud TPUs v6e mit GKE oder XPK bereitstellen
Wenn Sie GKE-Befehle mit v6e verwenden, können Sie Kubernetes-Befehle oder XPK verwenden, um Cloud TPUs bereitzustellen und Modelle zu trainieren oder bereitzustellen. Unter Cloud TPUs in GKE planen erfahren Sie, wie Sie Ihre Cloud TPU-Konfigurationen in GKE-Clustern planen. In den folgenden Abschnitten finden Sie Befehle zum Erstellen eines XPK-Clusters mit Unterstützung für einzelne NICs und mehrere NICs.
XPK-Cluster mit Unterstützung für eine einzelne NIC erstellen
export CLUSTER_NAME=xpk-cluster-name export ZONE=us-east1-d export PROJECT_ID=your-project-id export TPU_TYPE=v6e-256 export NUM_SLICES=2 export NETWORK_NAME=${CLUSTER_NAME}-mtu9k export NETWORK_FW_NAME=${NETWORK_NAME}-fw
gcloud compute networks create ${NETWORK_NAME} \ --mtu=8896 \ --project=${PROJECT_ID} \ --subnet-mode=auto \ --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} \ --network=${NETWORK_NAME} \ --allow tcp,icmp,udp \ --project=${PROJECT_ID}
export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}"
python3 xpk.py cluster create --cluster=${CLUSTER_NAME} \ --cluster-cpu-machine-type=e2-standard-8 \ --num-slices=${NUM_SLICES} \ --tpu-type=${TPU_TYPE} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --on-demand \ --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \ --create-vertex-tensorboard
Beschreibung der Befehls-Flags
Variable | Beschreibung |
CLUSTER_NAME | Der vom Nutzer zugewiesene Name für den XPK-Cluster. |
PROJECT_ID | Google Cloud Projektname. Verwenden Sie ein vorhandenes Projekt oder erstellen Sie ein neues. Weitere Informationen finden Sie unter Google Cloud -Projekt einrichten. |
ZONE | Informationen zu den unterstützten Zonen finden Sie im Dokument Cloud TPU-Regionen und -Zonen. |
TPU_TYPE | Weitere Informationen finden Sie unter Beschleunigertypen. |
NUM_SLICES | Die Anzahl der Segmente, die Sie erstellen möchten |
CLUSTER_ARGUMENTS | Das zu verwendende Netzwerk und Subnetzwerk.
Beispiel: |
NUM_SLICES | Die Anzahl der zu erstellenden Slices. |
NETWORK_NAME | Der Name eines zu verwendenden sekundären Netzwerks. |
NETWORK_FW_NAME | Der Name einer sekundären Netzwerk-Firewall, die verwendet werden soll. |
XPK-Cluster mit Unterstützung für mehrere NICs erstellen
export CLUSTER_NAME=xpk-cluster-name export REGION=your-region export ZONE=us-east1-d export PROJECT_ID=your-project-id export TPU_TYPE=v6e-256 export NUM_SLICES=2 export NETWORK_NAME_1=${CLUSTER_NAME}-mtu9k-1-${ZONE} export SUBNET_NAME_1=${CLUSTER_NAME}-privatesubnet-1-${ZONE} export NETWORK_FW_NAME_1=${NETWORK_NAME_1}-fw-1-${ZONE} export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-1-${ZONE} export ROUTER_NAME=${CLUSTER_NAME}-network-1-${ZONE} export NAT_CONFIG=${CLUSTER_NAME}-natconfig-1-${ZONE}
gcloud compute networks create ${NETWORK_NAME_1} \ --mtu=8896 \ --bgp-routing-mode=regional \ --subnet-mode=custom \ --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_1} \ --network=${NETWORK_NAME_1} \ --range=10.11.0.0/18 \ --region=${REGION} \ --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \ --network=${NETWORK_NAME_1} \ --allow tcp,icmp,udp \ --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \ --project=${PROJECT_ID} \ --network=${NETWORK_NAME_1} \ --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \ --router=${ROUTER_NAME} \ --region=${REGION} \ --auto-allocate-nat-external-ips \ --nat-all-subnet-ip-ranges \ --project=${PROJECT_ID} \ --enable-logging
# Secondary subnet for multi-nic experience.
# Need custom IP routing to be different from the first network's subnet.
export NETWORK_NAME_2=${CLUSTER_NAME}-privatenetwork-2-${ZONE}
export SUBNET_NAME_2=${CLUSTER_NAME}-privatesubnet-2-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-2-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-2-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-2-${ZONE}
gcloud compute networks create ${NETWORK_NAME_2} \ --mtu=8896 \ --bgp-routing-mode=regional \ --subnet-mode=custom \ --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_2} \ --network=${NETWORK_NAME_2} \ --range=10.10.0.0/18 \ --region=${REGION} \ --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \ --network=${NETWORK_NAME_2} \ --allow tcp,icmp,udp \ --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \ --project=${PROJECT_ID} \ --network=${NETWORK_NAME_2} \ --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \ --router=${ROUTER_NAME} \ --region=${REGION} \ --auto-allocate-nat-external-ips \ --nat-all-subnet-ip-ranges \ --project=${PROJECT_ID} \ --enable-logging
export CLUSTER_ARGUMENTS="--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}"
export NODE_POOL_ARGUMENTS="--additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}"
python3 xpk.py cluster create \ --cluster=${CLUSTER_NAME} \ --cluster-cpu-machine-type=e2-standard-8 \ --num-slices=${NUM_SLICES} \ --tpu-type=${TPU_TYPE} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --on-demand \ --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \ --custom-nodepool-arguments="${NODE_POOL_ARGUMENTS}" \ --create-vertex-tensorboard
Beschreibung der Befehls-Flags
Variable | Beschreibung |
CLUSTER_NAME | Der vom Nutzer zugewiesene Name für den XPK-Cluster. |
PROJECT_ID | Google Cloud Projektname. Verwenden Sie ein vorhandenes Projekt oder erstellen Sie ein neues. Weitere Informationen finden Sie unter Google Cloud -Projekt einrichten. |
ZONE | Informationen zu den unterstützten Zonen finden Sie im Dokument Cloud TPU-Regionen und -Zonen. |
TPU_TYPE | Weitere Informationen finden Sie unter Beschleunigertypen. |
NUM_SLICES | Die Anzahl der Segmente, die Sie erstellen möchten |
CLUSTER_ARGUMENTS | Das zu verwendende Netzwerk und Subnetzwerk.
Beispiel: |
NODE_POOL_ARGUMENTS | Zusätzliches Knotennetzwerk, das verwendet werden soll.
Beispiel: |
NUM_SLICES | Die Anzahl der zu erstellenden Slices (nur für Multislice erforderlich). |
NETWORK_NAME | Der Name eines zu verwendenden sekundären Netzwerks. |
NETWORK_FW_NAME | Der Name einer sekundären Netzwerk-Firewall, die verwendet werden soll. |
Framework einrichten
In diesem Abschnitt wird die allgemeine Einrichtung für das Trainieren von ML-Modellen mit den Frameworks JAX und PyTorch beschrieben. Wenn Sie GKE verwenden, können Sie XPK- oder Kubernetes-Befehle für die Einrichtung des Frameworks verwenden.
Einrichtung für JAX
In diesem Abschnitt finden Sie eine Einrichtungsanleitung für die Ausführung von JAX-Arbeitslasten in GKE mit oder ohne XPK sowie für die Verwendung von Ressourcen in der Warteschlange.
JAX mit GKE einrichten
Einzelner Slice auf einem einzelnen Host
Im folgenden Beispiel wird ein 2 × 2-Knotenpool mit einem einzelnen Host mithilfe einer Kubernetes-YAML-Datei eingerichtet.
apiVersion: v1
kind: Pod
metadata:
name: tpu-pod-jax-v6e-a
spec:
restartPolicy: Never
nodeSelector:
cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
cloud.google.com/gke-tpu-topology: 2x2
containers:
- name: tpu-job
image: python:3.10
securityContext:
privileged: true
command:
- bash
- -c
- |
pip install -U --pre jax jaxlib libtpu-nightly requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python3 -c 'import jax; print("Total TPU chips:", jax.device_count())'
resources:
requests:
google.com/tpu: 4
limits:
google.com/tpu: 4
Nach dem erfolgreichen Abschluss des Vorgangs sollte im GKE-Log die folgende Meldung angezeigt werden:
Total TPU chips: 4
Einzelner Slice auf mehreren Hosts
Im folgenden Beispiel wird ein Knotenpool mit mehreren Hosts (4 × 4) mithilfe einer Kubernetes-YAML-Datei eingerichtet.
apiVersion: v1
kind: Service
metadata:
name: headless-svc
spec:
clusterIP: None
selector:
job-name: tpu-available-chips
---
apiVersion: batch/v1
kind: Job
metadata:
name: tpu-available-chips
spec:
backoffLimit: 0
completions: 4
parallelism: 4
completionMode: Indexed
template:
spec:
subdomain: headless-svc
restartPolicy: Never
nodeSelector:
cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
cloud.google.com/gke-tpu-topology: 4x4
containers:
- name: tpu-job
image: python:3.10
ports:
- containerPort: 8471 # Default port using which TPU VMs communicate
- containerPort: 8431 # Port to export TPU runtime metrics, if supported.
securityContext:
privileged: true
command:
- bash
- -c
- |
pip install -U --pre jax jaxlib libtpu-nightly requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
resources:
requests:
google.com/tpu: 4
limits:
google.com/tpu: 4
Nach dem erfolgreichen Abschluss des Vorgangs sollte im GKE-Log die folgende Meldung angezeigt werden:
Total TPU chips: 16
Multislice auf mehreren Hosts
Im folgenden Beispiel werden zwei 4x4-Knotenpools mit mehreren Hosts mithilfe einer Kubernetes-YAML-Datei eingerichtet.
Als Voraussetzung müssen Sie JobSet v0.2.3 oder höher installieren.
apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
name: multislice-job
annotations:
alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
spec:
failurePolicy:
maxRestarts: 4
replicatedJobs:
- name: slice
replicas: 2
template:
spec:
parallelism: 4
completions: 4
backoffLimit: 0
template:
spec:
hostNetwork: true
dnsPolicy: ClusterFirstWithHostNet
nodeSelector:
cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
cloud.google.com/gke-tpu-topology: 4x4
hostNetwork: true
containers:
- name: jax-tpu
image: python:3.10
ports:
- containerPort: 8471
- containerPort: 8080
- containerPort: 8431
securityContext:
privileged: true
command:
- bash
- -c
- |
pip install -U --pre jax jaxlib libtpu-nightly requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
resources:
limits:
google.com/tpu: 4
requests:
google.com/tpu: 4
Nach dem erfolgreichen Abschluss des Vorgangs sollte im GKE-Log die folgende Meldung angezeigt werden:
Total TPU chips: 32
Weitere Informationen finden Sie in der GKE-Dokumentation unter Multi-Slice-Arbeitslast ausführen.
Aktivieren Sie hostNetwork, um die Leistung zu verbessern.
Multi-NIC
Damit Sie das folgende Manifest mit mehreren NICs verwenden können, müssen Sie Ihre Netzwerke einrichten. Weitere Informationen finden Sie unter Unterstützung mehrerer Netzwerke für Kubernetes-Pods einrichten.
Wenn Sie mehrere NICs in GKE nutzen möchten, müssen Sie dem Kubernetes-Pod-Manifest einige zusätzliche Annotationen hinzufügen.
Das Folgende ist ein Beispielmanifest für eine Multi-NIC-Arbeitslast ohne TPU.
apiVersion: v1
kind: Pod
metadata:
name: sample-netdevice-pod-1
annotations:
networking.gke.io/default-interface: 'eth0'
networking.gke.io/interfaces: |
[
{"interfaceName":"eth0","network":"default"},
{"interfaceName":"eth1","network":"netdevice-network"}
]
spec:
containers:
- name: sample-netdevice-pod
image: busybox
command: ["sleep", "infinity"]
ports:
- containerPort: 80
restartPolicy: Always
tolerations:
- key: "google.com/tpu"
operator: "Exists"
effect: "NoSchedule"
Wenn Sie den Befehl exec
verwenden, um eine Verbindung zum Kubernetes-Pod herzustellen, sollte die zusätzliche NIC mit dem folgenden Code angezeigt werden:
$ kubectl exec --stdin --tty sample-netdevice-pod-1 -- /bin/sh
/ # ip a
1: lo: <LOOPBACK,UP,LOWER_UP> mtu 65536 qdisc noqueue qlen 1000
link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00
inet 127.0.0.1/8 scope host lo
valid_lft forever preferred_lft forever
2: eth0@if11: <BROADCAST,MULTICAST,UP,LOWER_UP,M-DOWN> mtu 1460 qdisc noqueue
link/ether da:be:12:67:d2:25 brd ff:ff:ff:ff:ff:ff
inet 10.124.2.6/24 brd 10.124.2.255 scope global eth0
valid_lft forever preferred_lft forever
3: eth1: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1460 qdisc mq qlen 1000
link/ether 42:01:ac:18:00:04 brd ff:ff:ff:ff:ff:ff
inet 172.24.0.4/32 scope global eth1
valid_lft forever preferred_lft forever
JAX mit GKE und XPK einrichten
Informationen zum Einrichten von JAX mit GKE und XPK finden Sie in der XPK-README-Datei.
Informationen zum Einrichten und Ausführen von XPK mit MaxText finden Sie unter MaxText ausführen.
JAX mit in die Warteschlange gestellten Ressourcen einrichten
Installieren Sie JAX auf allen Cloud TPU-VMs in Ihrem Slice oder Ihren Slices gleichzeitig mit dem Befehl gcloud alpha compute tpus tpu-vm ssh
. Fügen Sie für Multislice das Flag --node=all
hinzu.
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
--project ${PROJECT_ID} \
--zone ${ZONE} \
--worker=all \
--command='
pip install -U --pre jax jaxlib libtpu-nightly requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Mit dem folgenden Befehl können Sie prüfen, wie viele Cloud TPU-Kerne in Ihrem Slice verfügbar sind, und testen, ob alles korrekt installiert ist:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
--project ${PROJECT_ID} \
--zone ${ZONE} \
--worker=all \
--command='
python3 -c "import jax; print(jax.device_count(), jax.local_device_count())"'
Die Ausgabe sieht bei der Ausführung auf einem v6e-16-Slice in etwa so aus:
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16 4
16 4
16 4
16 4
jax.device_count()
gibt die Gesamtzahl der Chips im angegebenen Slice an.
jax.local_device_count()
gibt die Anzahl der Chips an, auf die eine einzelne VM in diesem Slice zugreifen kann.
gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
git clone -b mlperf4.1 https://github.com/google/maxdiffusion.git &&
cd maxdiffusion && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103 &&
pip install setuptools==59.6.0 &&
pip install -r requirements.txt && pip install .'
Probleme bei der JAX-Einrichtung beheben
Ein allgemeiner Tipp ist, das ausführliche Logging in Ihrem GKE-Arbeitslastmanifest zu aktivieren. Stellen Sie die Logs dann dem GKE-Support zur Verfügung.
TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0
Fehlermeldungen
no endpoints available for service 'jobset-webhook-service'
Dieser Fehler bedeutet, dass das Jobset nicht richtig installiert wurde. Prüfen Sie, ob die Kubernetes-Pods für die jobset-controller-manager-Bereitstellung ausgeführt werden. Weitere Informationen finden Sie in der Dokumentation zur Fehlerbehebung bei JobSets.
TPU initialization failed: Failed to connect
Die GKE-Knotenversion muss 1.30.4-gke.1348000 oder höher sein (GKE 1.31 wird nicht unterstützt).
Einrichtung für PyTorch
In diesem Abschnitt wird beschrieben, wie Sie PJRT auf v6e mit PyTorch/XLA verwenden. Python 3.10 ist die empfohlene Python-Version.
PyTorch mit GKE und XPK einrichten
Sie können den folgenden Docker-Container mit XPK verwenden, in dem PyTorch-Abhängigkeiten bereits installiert sind:
us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028
Verwenden Sie den folgenden Befehl, um eine XPK-Arbeitslast zu erstellen:
python3 xpk.py workload create \ --cluster ${CLUSTER_NAME} \ {--base-docker-image maxtext_base_image | --docker-image your-cloud-image-name} \ --workload ${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \ --tpu-type=${ACCELERATOR_TYPE} \ --num-slices=${NUM_SLICES} \ --on-demand \ --zone ${ZONE} \ --project ${PROJECT_ID} \ --enable-debug-logs \ --command 'python3 -c "import torch; import torch_xla; import torch_xla.runtime as xr; print(xr.global_runtime_device_count())"'
Mit --base-docker-image
wird ein neues Docker-Image erstellt, in das das aktuelle Arbeitsverzeichnis eingebunden ist.
PyTorch mit in die Warteschlange gestellten Ressourcen einrichten
So installieren Sie PyTorch mit in die Warteschlange eingereihten Ressourcen und führen ein kleines Skript auf v6e aus:
Abhängigkeiten über SSH installieren, um auf die VMs zuzugreifen
Verwenden Sie den folgenden Befehl, um Abhängigkeiten auf allen Cloud TPU-VMs zu installieren. Fügen Sie für Multislice das Flag --worker=all
hinzu:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
sudo apt update && sudo apt install -y python3-pip libopenblas-base && \
pip3 install torch~=2.6.0 "torch_xla[tpu]~=2.6.0" -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'
Leistung von Modellen mit großen, häufigen Zuweisungen verbessern
Bei Modellen mit großen, häufigen Zuweisungen wird mit der Funktion tcmalloc
die Leistung im Vergleich zur Standardimplementierung der Funktion malloc
erheblich verbessert. Daher ist die Standardfunktion malloc
, die auf Cloud TPU-VMs verwendet wird, tcmalloc
. Je nach Arbeitslast (z. B. bei DLRM, das sehr große Zuweisungen für seine Einbettungstabellen hat) kann die Funktion tcmalloc
jedoch zu einer Verlangsamung führen. In diesem Fall können Sie versuchen, die folgende Variable mit der Standardfunktion malloc
zu deaktivieren:
unset LD_PRELOAD
Python-Skript zum Ausführen einer Berechnung auf einer v6e-VM verwenden
Verwenden Sie den folgenden Befehl, um ein Skript auszuführen, das zwei Tensoren erstellt, sie addiert und das Ergebnis ausgibt:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
--project ${PROJECT_ID} \
--zone ${ZONE} \
--worker all \
--command='
unset LD_PRELOAD
python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"'
Dadurch wird eine Ausgabe generiert, die etwa so aussieht:
SSH: Attempting to connect to worker 0...
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
xla:0
tensor([[ 0.3355, -1.4628, -3.2610],
[-1.4656, 0.3196, -2.8766],
[ 0.8668, -1.5060, 0.7125]], device='xla:0')
v6e mit SkyPilot
Sie können Cloud TPU v6e mit SkyPilot verwenden. So fügen Sie SkyPilot v6e-bezogene Standort- und Preisinformationen hinzu: Weitere Informationen finden Sie im SkyPilot-Beispiel für TPU v6e.
Anleitungen für die Inferenz
In den folgenden Anleitungen wird beschrieben, wie Sie Inferenz auf Cloud TPU v6e ausführen:
Trainingsbeispiele
In den folgenden Abschnitten finden Sie Beispiele für das Training von MaxText-, MaxDiffusion- und PyTorch-Modellen auf Cloud TPU v6e.
MaxText- und MaxDiffusion-Training auf der v6e Cloud TPU-VM
In den folgenden Abschnitten wird der Trainingszyklus der Modelle MaxText und MaxDiffusion beschrieben.
Im Allgemeinen sind folgende Schritte erforderlich:
- Erstellen Sie das Basis-Image für die Arbeitslast.
- Führen Sie Ihre Arbeitslast mit XPK aus.
- Erstellen Sie den Trainingsbefehl für die Arbeitslast.
- Stellen Sie die Arbeitslast bereit.
- Folgen Sie der Arbeitslast und sehen Sie sich die Messwerte an.
- Löschen Sie die XPK-Arbeitslast, wenn sie nicht benötigt wird.
- Löschen Sie den XPK-Cluster, wenn er nicht mehr benötigt wird.
Basis-Image erstellen
Installieren Sie MaxText oder MaxDiffusion und erstellen Sie das Docker-Image:
Klonen Sie das gewünschte Repository und wechseln Sie in das Verzeichnis des Repositorys:
MaxText:
git clone https://github.com/google/maxtext.git && cd maxtext
MaxDiffusion:
git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103
Konfigurieren Sie Docker für die Verwendung der Google Cloud CLI:
gcloud auth configure-docker
Erstellen Sie das Docker-Image mit dem folgenden Befehl oder mit dem JAX Stable Stack. Weitere Informationen zum JAX Stable Stack finden Sie unter Docker-Image mit JAX Stable Stack erstellen.
MaxText:
bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.35
MaxDiffusion:
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_stack MODE=jax_ai_image PROJECT=${PROJECT_ID} LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest
Legen Sie Ihre Projekt-ID in der aktiven gcloud CLI-Konfiguration fest:
gcloud config set project ${PROJECT_ID}
Wenn Sie die Arbeitslast von einem Computer aus starten, auf dem das Image nicht lokal erstellt wurde, laden Sie das Image hoch.
Legen Sie die Umgebungsvariable
CLOUD_IMAGE_NAME
fest:export CLOUD_IMAGE_NAME=${USER}_runner
Laden Sie das Bild hoch:
bash docker_upload_runner.sh ${CLOUD_IMAGE_NAME}
Arbeitslast mit XPK ausführen
Legen Sie die folgenden Umgebungsvariablen fest, wenn Sie nicht die von MaxText festgelegten Standardwerte oder MaxDiffusion verwenden:
export BASE_OUTPUT_DIR=gs://YOUR_BUCKET export PER_DEVICE_BATCH_SIZE=2 export NUM_STEPS=30 export MAX_TARGET_LENGTH=8192
Modellskript erstellen Dieses Skript wird in einem späteren Schritt als Trainingsbefehl kopiert.
Führen Sie das Modellskript noch nicht aus.
MaxText
MaxText ist ein leistungsstarkes, hochgradig skalierbares Open-Source-LLM, das in reinem Python und JAX geschrieben wurde und auf Google Cloud TPUs und GPUs für Training und Inferenz ausgerichtet ist.
JAX_PLATFORMS=tpu,cpu \ ENABLE_PJRT_COMPATIBILITY=true \ TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \ TPU_SLICE_BUILDER_DUMP_ICI=true && \ python3 -m MaxText.train MaxText/configs/base.yml \ base_output_directory=${BASE_OUTPUT_DIR} \ dataset_type=synthetic \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ enable_checkpointing=false \ gcs_metrics=true \ profiler=xplane \ skip_first_n_steps_for_profiler=5 \ steps=${NUM_STEPS} # attention='dot_product'"
Gemma2
Gemma ist eine Familie von LLMs mit offenen Gewichten, die von Google DeepMind auf Grundlage von Gemini-Forschung und -Technologie entwickelt wurden.
python3 -m MaxText.train MaxText/configs/base.yml \ model_name=gemma2-27b \ run_name=gemma2-27b-run \ base_output_directory=${BASE_OUTPUT_DIR} \ max_target_length=${MAX_TARGET_LENGTH} \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ steps=${NUM_STEPS} \ enable_checkpointing=false \ use_iota_embed=true \ gcs_metrics=true \ dataset_type=synthetic \ profiler=xplane \ attention=flash
Mixtral 8x7b
Mixtral ist ein hochmodernes KI-Modell, das von Mistral AI entwickelt wurde und eine spärliche MoE-Architektur (Mixture of Experts) nutzt.
python3 -m MaxText.train MaxText/configs/base.yml \ base_output_directory=${BASE_OUTPUT_DIR} \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ model_name=mixtral-8x7b \ steps=${NUM_STEPS} \ max_target_length=${MAX_TARGET_LENGTH} \ tokenizer_path=assets/tokenizer.mistral-v1 \ attention=flash \ dtype=bfloat16 \ dataset_type=synthetic \ profiler=xplane
Llama3-8b
Llama ist eine Familie von LLMs mit offenen Gewichten, die von Meta entwickelt wurden.
Ein Beispiel für die Ausführung von Llama3 in PyTorch finden Sie unter torch_xla-Modelle im torchprime-Repository.
MaxDiffusion
MaxDiffusion ist eine Sammlung von Referenzimplementierungen verschiedener latenter Diffusionsmodelle, die in reinem Python und JAX geschrieben sind und auf XLA-Geräten wie Cloud TPUs und GPUs ausgeführt werden. Stable Diffusion ist ein latentes Text-zu-Bild-Modell, mit dem fotorealistische Bilder aus beliebigen Texteingaben generiert werden.
Sie müssen einen bestimmten Git-Branch installieren, um MaxDiffusion auszuführen, wie im folgenden
git clone
-Befehl gezeigt.Trainingsskript:
git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103 && pip install -r requirements.txt && pip install . && pip install huggingface_hub==0.30.2 && OUT_DIR=${BASE_OUTPUT_DIR} && python src/maxdiffusion/train_sdxl.py src/maxdiffusion/configs/base_xl.yml revision=refs/pr/95 activations_dtype=bfloat16 weights_dtype=bfloat16 resolution=1024 per_device_batch_size=1 output_dir=${OUT_DIR} jax_cache_dir=${OUT_DIR}/cache_dir/ max_train_steps=200 attention=flash run_name=sdxl-ddp-v6e
Exportieren Sie die folgenden Variablen:
export CLUSTER_NAME=CLUSTER_NAME export ACCELERATOR_TYPE=ACCELERATOR_TYPE export NUM_SLICES=NUM_SLICES export YOUR_MODEL_SCRIPT=YOUR_MODEL_SCRIPT
Beschreibungen von Umgebungsvariablen
Variable Beschreibung CLUSTER_NAME
Der Name Ihres XPK-Clusters. ACCELERATOR_TYPE
Weitere Informationen finden Sie unter Beschleunigertypen. NUM_SLICES
Die Anzahl der TPU-Slices. YOUR_MODEL_SCRIPT
Das Modellskript, das als Trainingsbefehl ausgeführt werden soll. Führen Sie das Modell mit dem Skript aus, das Sie im vorherigen Schritt erstellt haben. Sie müssen entweder das Flag
--base-docker-image
angeben, um das MaxText-Basis-Image zu verwenden, oder das Flag--docker-image
und das Image, das Sie verwenden möchten.Optional: Sie können das Debug-Logging aktivieren, indem Sie das Flag
--enable-debug-logs
einfügen. Weitere Informationen finden Sie unter JAX auf MaxText debuggen.Optional: Sie können einen Vertex AI-Test erstellen, um Daten in Vertex AI TensorBoard hochzuladen. Dazu müssen Sie das Flag
--use-vertex-tensorboard
einfügen. Weitere Informationen finden Sie unter JAX auf MaxText mit Vertex AI überwachen.python3 xpk.py workload create \ --cluster ${CLUSTER_NAME} \ {--base-docker-image maxtext_base_image | --docker-image gcr.io/${PROJECT_ID}/${CLOUD_IMAGE_NAME}:latest} \ --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \ --tpu-type=${ACCELERATOR_TYPE} \ --num-slices=${NUM_SLICES} \ --on-demand \ --zone=${ZONE} \ --project=${PROJECT_ID} \ [--enable-debug-logs] \ [--use-vertex-tensorboard] \ --command="${YOUR_MODEL_SCRIPT}"
Die Ausgabe enthält einen Link, über den Sie Ihre Arbeitslast verfolgen können. Öffnen Sie den Link und klicken Sie auf den Tab Logs, um Ihre Arbeitslast in Echtzeit zu verfolgen.
JAX in MaxText debuggen
Verwenden Sie zusätzliche XPK-Befehle, um zu ermitteln, warum der Cluster oder die Arbeitslast nicht ausgeführt wird:
- XPK-Arbeitslastliste
- XPK-Prüftool
- Aktivieren Sie das ausführliche Logging in Ihren Arbeitslastlogs mit dem Flag
--enable-debug-logs
, wenn Sie die XPK-Arbeitslast erstellen.
JAX auf MaxText mit Vertex AI überwachen
Damit Sie TensorBoard verwenden können, muss Ihrem Google Cloud -Nutzerkonto die Rolle aiplatform.user
zugewiesen sein. Führen Sie den folgenden Befehl aus, um diese Rolle zuzuweisen:
gcloud projects add-iam-policy-binding your-project-id \ --member='user:your-email' \ --role='roles/aiplatform.user'
Skalar- und Profildaten über das verwaltete TensorBoard von Vertex AI ansehen.
Erhöhen Sie die Resource Management (CRUD)-Anfragen für die Zone, die Sie verwenden, von 600 auf 5.000. Bei kleinen Arbeitslasten mit weniger als 16 VMs ist das möglicherweise kein Problem.
Installieren Sie Abhängigkeiten wie
cloud-accelerator-diagnostics
für Vertex AI:# xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI cd ~/xpk pip install .
Erstellen Sie Ihren XPK-Cluster mit dem Flag
--create-vertex-tensorboard
, wie in Vertex AI TensorBoard erstellen beschrieben. Sie können diesen Befehl auch für vorhandene Cluster ausführen.Erstellen Sie Ihren Vertex AI-Test, wenn Sie Ihren XPK-Arbeitslast mit dem Flag
--use-vertex-tensorboard
und dem optionalen Flag--experiment-name
ausführen. Eine vollständige Liste der Schritte finden Sie unter Vertex AI-Test zum Hochladen von Daten in Vertex AI TensorBoard erstellen.
Die Logs enthalten einen Link zu einem Vertex AI TensorBoard, ähnlich dem folgenden:
View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name
Sie können den Link zu Vertex AI TensorBoard auch in der Google Cloud Console aufrufen. Rufen Sie Vertex AI Experiments in der Google Cloud Konsole auf. Wählen Sie im Drop-down-Menü die gewünschte Region aus.
Das TensorBoard-Verzeichnis wird auch in den Cloud Storage-Bucket geschrieben, den Sie mit ${BASE_OUTPUT_DIR}
angegeben haben.
XPK-Arbeitslasten löschen
Verwenden Sie den Befehl xpk workload delete
, um eine oder mehrere Arbeitslasten basierend auf dem Jobpräfix oder dem Jobstatus zu löschen. Dieser Befehl kann nützlich sein, wenn Sie XPK-Arbeitslasten gesendet haben, die nicht mehr ausgeführt werden müssen, oder wenn Jobs in der Warteschlange hängen bleiben.
XPK-Cluster löschen
Verwenden Sie den Befehl xpk cluster delete
, um einen Cluster zu löschen:
python3 xpk.py cluster delete --cluster ${CLUSTER_NAME} \ --zone=${ZONE} --project=${PROJECT_ID}
Llama- und PyTorch/XLA-Training auf einer Cloud TPU v6e-VM
In dieser Anleitung wird beschrieben, wie Sie Llama-Modelle mit PyTorch/XLA auf Cloud TPU v6e mit dem WikiText-Dataset trainieren.
Zugriff auf Hugging Face und das Llama 3-Modell erhalten
Für dieses Tutorial benötigen Sie ein Hugging Face-Nutzerzugriffstoken. Informationen zum Erstellen von Nutzerzugriffstokens finden Sie in der Hugging Face-Dokumentation zu Nutzerzugriffstokens.
Außerdem benötigen Sie die Berechtigung für den Zugriff auf das Modell „Llama-3-8B“ auf Hugging Face. Wenn Sie Zugriff erhalten möchten, rufen Sie das Meta-Llama-3-8B-Modell auf Hugging Face auf und beantragen Sie den Zugriff.
Cloud TPU-VM erstellen
Erstellen Sie eine Cloud TPU v6e mit 8 Chips, um das Tutorial auszuführen.
Richten Sie Umgebungsvariablen ein:
export NODE_ID=your-tpu-name export PROJECT_ID=your-project-id export ACCELERATOR_TYPE=v6e-8 export ZONE=us-east1-d export RUNTIME_VERSION=v2-alpha-tpuv6e export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id export VALID_DURATION=your-duration
Cloud TPU-VM erstellen:
gcloud alpha compute tpus tpu-vm create ${NODE_ID} --version=${RUNTIME_VERSION} \ --accelerator-type=${ACCELERATOR_TYPE} \ --zone=${ZONE} \ --project=${PROJECT_ID}
Installation
Installieren Sie den pytorch-tpu/transformers
-Fork von Hugging Face-Transformern und Abhängigkeiten. Diese Anleitung wurde mit den folgenden Abhängigkeitsversionen getestet, die in diesem Beispiel verwendet werden:
torch
: kompatibel mit 2.5.0torch_xla[tpu]
: kompatibel mit 2.5.0jax
: 0.4.33jaxlib
: 0.4.33
gcloud alpha compute tpus tpu-vm ssh ${NODE_ID} \ --project=${PROJECT_ID} \ --zone ${ZONE} \ --worker=all \ --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git cd transformers sudo pip3 install -e . pip3 install datasets pip3 install evaluate pip3 install scikit-learn pip3 install accelerate pip install torch~=2.6.0 torch_xla[tpu]~=2.6.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html pip install jax==0.4.38 jaxlib==0.4.38 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/'
Modellkonfigurationen einrichten
Im Trainingsbefehl im nächsten Abschnitt Modell ausführen werden zwei JSON-Konfigurationsdateien verwendet, um Modellparameter und die Konfiguration für Fully Sharded Data Parallel (FSDP) zu definieren. Mit FSDP-Sharding können Sie beim Training eine größere Batchgröße verwenden, da die Modellgewichte auf mehrere TPUs verteilt werden. Beim Training mit kleineren Modellen kann es ausreichen, Datenparallelität zu verwenden und die Gewichte auf jedem Gerät zu replizieren. Weitere Informationen zum Sharding von Tensoren auf Geräten in PyTorch/XLA finden Sie im SPMD-Nutzerhandbuch für PyTorch/XLA.
Erstellen Sie die Konfigurationsdatei für Modellparameter. Im Folgenden finden Sie die Modellparameterkonfiguration für Llama-3-8B. Die Konfiguration für andere Modelle finden Sie auf Hugging Face. Ein Beispiel finden Sie in der Llama-2-7B-Konfiguration.
cat > llama-config.json << EOF { "architectures": [ "LlamaForCausalLM" ], "attention_bias": false, "attention_dropout": 0.0, "bos_token_id": 128000, "eos_token_id": 128001, "hidden_act": "silu", "hidden_size": 4096, "initializer_range": 0.02, "intermediate_size": 14336, "max_position_embeddings": 8192, "model_type": "llama", "num_attention_heads": 32, "num_hidden_layers": 32, "num_key_value_heads": 8, "pretraining_tp": 1, "rms_norm_eps": 1e-05, "rope_scaling": null, "rope_theta": 500000.0, "tie_word_embeddings": false, "torch_dtype": "bfloat16", "transformers_version": "4.40.0.dev0", "use_cache": false, "vocab_size": 128256 } EOF
Erstellen Sie die FSDP-Konfigurationsdatei:
cat > fsdp-config.json << EOF { "fsdp_transformer_layer_cls_to_wrap": [ "LlamaDecoderLayer" ], "xla": true, "xla_fsdp_v2": true, "xla_fsdp_grad_ckpt": true } EOF
Weitere Informationen zum FSDP finden Sie unter FSDPv2.
Laden Sie die Konfigurationsdateien mit dem folgenden Befehl auf Ihre Cloud TPU-VMs hoch:
gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json ${NODE_ID}:. \ --worker=all \ --project=${PROJECT_ID} \ --zone=${ZONE}
Modell ausführen
Führen Sie mit den Konfigurationsdateien, die Sie im vorherigen Abschnitt erstellt haben, das Skript run_clm.py
aus, um das Modell Llama-3-8B mit dem WikiText-Dataset zu trainieren. Die Ausführung des Trainingsskripts dauert auf einer Cloud TPU v6e-8 etwa 10 Minuten.
Melden Sie sich mit dem folgenden Befehl auf Ihrer Cloud TPU bei Hugging Face an:
gcloud alpha compute tpus tpu-vm ssh ${NODE_ID} \ --project=${PROJECT_ID} \ --zone ${ZONE} \ --worker=all \ --command=' pip3 install "huggingface_hub[cli]" huggingface-cli login --token HUGGING_FACE_TOKEN'
Modelltraining ausführen:
gcloud alpha compute tpus tpu-vm ssh ${NODE_ID} \ --project=${PROJECT_ID} \ --zone ${ZONE} \ --worker=all \ --command=' export PJRT_DEVICE=TPU export XLA_USE_SPMD=1 export ENABLE_PJRT_COMPATIBILITY=true # Optional variables for debugging: export XLA_IR_DEBUG=1 export XLA_HLO_DEBUG=1 export PROFILE_EPOCH=0 export PROFILE_STEP=3 export PROFILE_DURATION_MS=100000 # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path export PROFILE_LOGDIR=PROFILE_PATH python3 transformers/examples/pytorch/language-modeling/run_clm.py \ --dataset_name wikitext \ --dataset_config_name wikitext-2-raw-v1 \ --per_device_train_batch_size 16 \ --do_train \ --output_dir /home/$USER/tmp/test-clm \ --overwrite_output_dir \ --config_name /home/$USER/llama-config.json \ --cache_dir /home/$USER/cache \ --tokenizer_name meta-llama/Meta-Llama-3-8B \ --block_size 8192 \ --optim adafactor \ --save_strategy no \ --logging_strategy no \ --fsdp "full_shard" \ --fsdp_config /home/$USER/fsdp-config.json \ --torch_dtype bfloat16 \ --dataloader_drop_last yes \ --flash_attention \ --max_steps 20'
Fehlerbehebung bei PyTorch/XLA
Wenn Sie die optionalen Variablen für das Debugging im vorherigen Abschnitt festgelegt haben, wird das Profil für das Modell am Speicherort gespeichert, der durch die Variable PROFILE_LOGDIR
angegeben wird. Sie können die xplane.pb
-Datei, die an diesem Speicherort gespeichert ist, extrahieren und tensorboard
verwenden, um die Profile in Ihrem Browser gemäß der TensorBoard-Anleitung anzusehen.
Wenn PyTorch/XLA nicht wie erwartet funktioniert, finden Sie im Leitfaden zur Fehlerbehebung Vorschläge zum Debuggen, Profilerstellen und Optimieren Ihres Modells.
Benchmarking-Ergebnisse
Der folgende Abschnitt enthält Benchmark-Ergebnisse für MaxDiffusion auf v6e.
MaxDiffusion
Wir haben das Trainingsskript für MaxDiffusion auf einer v6e-4, einer v6e-16 und zwei v6e-16 ausgeführt. Die Durchsatzraten sind in der folgenden Tabelle aufgeführt.
v6e-4 | v6e-16 | Zwei v6e-16 | |
---|---|---|---|
Trainingsschritte | 0.069 | 0,073 | 0,13 |
Globale Batchgröße | 8 | 32 | 64 |
Durchsatz (Beispiele/s) | 115.9 | 438,4 | 492,3 |