Introduction à Trillium (v6e)
Dans cette documentation, l'API TPU et les journaux, v6e fait référence à Trillium. v6e représente la sixième génération de TPU de Google.
Avec 256 puces par pod, l'architecture v6e présente de nombreuses similitudes avec la v5e. Ce système est optimisé pour l'entraînement, l'ajustement et le service des transformateurs, des réseaux de neurones convolutifs (CNN) et des modèles de conversion de texte en image.
Pour en savoir plus sur l'architecture et les configurations du système v6e, consultez TPU v6e.
Ce document d'introduction se concentre sur les processus d'entraînement et de diffusion de modèles à l'aide des frameworks JAX ou PyTorch. Avec chaque framework, vous pouvez provisionner des TPU à l'aide de ressources mises en file d'attente ou de GKE. La configuration de GKE peut être effectuée à l'aide de XPK ou de commandes GKE.
Procédure générale pour entraîner ou diffuser un modèle à l'aide de v6e
- Préparer un projet Google Cloud
- Sécuriser la capacité
- Provisionner l'environnement Cloud TPU
- Exécuter une charge de travail d'entraînement ou d'inférence de modèle
Préparer un projet Google Cloud
Pour pouvoir utiliser Cloud TPU, vous devez :
- Créer un compte Google Cloud et un projet avec la facturation activée
- Installer les composants alpha de Google Cloud CLI
- Activer l'API Cloud TPU
- Créer un agent de service Cloud TPU
- Créer un compte de service Cloud TPU et accorder des autorisations
Pour en savoir plus, consultez Configurer l'environnement Cloud TPU.
Sécuriser la capacité
Contactez l'assistanceGoogle Cloud pour demander un quota Cloud TPU v6e et obtenir des réponses à vos questions sur la capacité.
Provisionner l'environnement Cloud TPU
Les v6e Cloud TPU peuvent être provisionnées et gérées avec GKE, avec GKE et XPK (un outil CLI wrapper sur GKE), ou en tant que ressources mises en file d'attente.
Prérequis
- Vérifiez que votre projet dispose d'un quota
TPUS_PER_TPU_FAMILY
suffisant, qui spécifie le nombre maximal de chips auxquels vous pouvez accéder dans votre projet Google Cloud. - v6e a été testé avec la configuration suivante :
- Python
3.10
ou version ultérieure - Versions logicielles Nightly :
- JAX quotidien
0.4.32.dev20240912
- LibTPU nightly
0.1.dev20240912+nightly
- JAX quotidien
- Versions logicielles stables :
- JAX + JAX Lib de la version 0.4.37
- Python
Vérifiez que votre projet dispose d'un quota suffisant pour :
- Quota de VM Cloud TPU
- Quota d'adresses IP
Quota pour Hyperdisk équilibré et pour tout autre type de disque que vous souhaitez utiliser
Si vous utilisez GKE avec XPK, consultez Autorisations de la console Cloud sur le compte utilisateur ou de service pour connaître les autorisations nécessaires à l'exécution de XPK.
Créer des variables d'environnement
Dans Cloud Shell, créez les variables d'environnement suivantes :
export NODE_ID=your-tpu-name export PROJECT_ID=your-project-id export ACCELERATOR_TYPE=v6e-16 export ZONE=us-east1-d export RUNTIME_VERSION=v2-alpha-tpuv6e export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id export VALID_DURATION=your-duration # Additional environment variable needed for Multislice: export NUM_SLICES=number-of-slices # Use a custom network for better performance as well as to avoid having the default network becoming overloaded. export NETWORK_NAME=${PROJECT_ID}-mtu9k export NETWORK_FW_NAME=${NETWORK_NAME}-fw
Description des options de commande
Variable | Description |
NODE_ID | ID attribué par l'utilisateur au Cloud TPU créé lorsque la demande de ressource en file d'attente est allouée. |
PROJECT_ID | Nom du projet :Google Cloud Utilisez un projet existant ou créez-en un. Pour en savoir plus, consultez Configurer votre projet Google Cloud . |
ZONE | Pour connaître les zones compatibles, consultez le document Régions et zones Cloud TPU. |
ACCELERATOR_TYPE | Consultez la section Types d'accélérateurs. |
RUNTIME_VERSION | v2-alpha-tpuv6e
|
SERVICE_ACCOUNT | Il s'agit de l'adresse e-mail de votre compte de service, que vous trouverez dans la console Google Cloud > IAM > Comptes de service.
Par exemple : |
NUM_SLICES | Nombre de segments à créer (nécessaire uniquement pour Multislice). |
QUEUED_RESOURCE_ID | ID de texte attribué par l'utilisateur à la demande de ressource mise en file d'attente. |
VALID_DURATION | Durée de validité de la demande de ressource mise en file d'attente. |
NETWORK_NAME | Nom d'un réseau secondaire à utiliser. |
NETWORK_FW_NAME | Nom d'un pare-feu réseau secondaire à utiliser. |
Optimiser les performances du réseau
Pour des performances optimales,utilisez un réseau avec une MTU (unité de transmission maximale) de 8 896.
Par défaut, un cloud privé virtuel (VPC) ne fournit qu'une MTU de 1 460 octets,ce qui offre des performances réseau sous-optimales. Vous pouvez définir la MTU d'un réseau VPC sur n'importe quelle valeur comprise entre 1 300 octets et 8 896 octets (inclus). Les tailles de MTU personnalisées les plus courantes sont de 1 500 octets (le standard Ethernet) ou de 8 896 octets (le maximum possible). Pour en savoir plus, consultez Tailles de MTU valides pour le réseau VPC.
Pour en savoir plus sur la modification du paramètre de MTU d'un réseau existant ou par défaut, consultez Modifier le paramètre de MTU d'un réseau VPC.
L'exemple suivant crée un réseau avec une MTU de 8 896.
export RESOURCE_NAME=your-resource-name export NETWORK_NAME=${RESOURCE_NAME}-privatenetwork export NETWORK_FW_NAME=${RESOURCE_NAME}-privatefirewall gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT_ID} \ --subnet-mode=auto --bgp-routing-mode=regional gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network=${NETWORK_NAME} \ --allow tcp,icmp,udp --project=${PROJECT_ID}
Utiliser plusieurs cartes d'interface réseau (option pour Multislice)
Les variables d'environnement suivantes sont nécessaires pour un sous-réseau secondaire lorsque vous utilisez un environnement multislices.
export NETWORK_NAME_2=${RESOURCE_NAME} export SUBNET_NAME_2=${RESOURCE_NAME} export FIREWALL_RULE_NAME=${RESOURCE_NAME} export ROUTER_NAME=${RESOURCE_NAME}-network-2 export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2 export REGION=your-region
Utilisez les commandes suivantes pour créer un routage IP personnalisé pour le réseau et le sous-réseau.
gcloud compute networks create ${NETWORK_NAME_2} --mtu=8896 \
--bgp-routing-mode=regional --subnet-mode=custom --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_2} \
--network=${NETWORK_NAME_2} \
--range=10.10.0.0/18 --region=${REGION} \
--project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
--network=${NETWORK_NAME_2} --allow tcp,icmp,udp \
--source-ranges 10.10.0.0/18 --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \
--project=${PROJECT_ID} \
--network=${NETWORK_NAME_2} \
--region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \
--router=${ROUTER_NAME} \
--region=${REGION} \
--auto-allocate-nat-external-ips \
--nat-all-subnet-ip-ranges \
--project=${PROJECT_ID} \
--enable-logging
Après avoir créé un slice de réseau multiple, vous pouvez vérifier que les deux cartes d'interface réseau (NIC) sont utilisées en configurant un cluster XPK et en ajoutant l'indicateur --command ifconfig
à la commande de création de charge de travail XPK.
Utilisez la commande workload create
suivante pour afficher le résultat de la commande ifconfig
dans les journaux de la console Google Cloud et vérifiez que les interfaces eth0 et eth1 ont toutes les deux la valeur mtu=8896.
python3 xpk.py workload create \ --cluster CLUSTER_NAME \ {--base-docker-image maxtext_base_image | --docker-image your-cloud-image-name} \ --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \ --tpu-type=${ACCELERATOR_TYPE} \ --num-slices=${NUM_SLICES} \ --on-demand \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --command "ifconfig"
Si vous souhaitez activer les journaux de débogage ou utiliser Vertex AI TensorBoard, ajoutez les arguments facultatifs suivants à la commande :
--enable-debug-logs \ --use-vertex-tensorboard
Vérifiez que les deux interfaces eth0 et eth1 ont la valeur mtu=8 896. Pour vérifier que l'interface multi-NIC est en cours d'exécution, ajoutez l'indicateur --command ifconfig
à la commande de création de la charge de travail XPK. Vérifiez la sortie de cette charge de travail XPK dans les journaux de la console Google Cloud et assurez-vous que eth0 et eth1 ont tous deux la valeur mtu=8 896.
Améliorer les paramètres TCP
Si vous avez créé vos Cloud TPU à l'aide de l'interface des ressources en file d'attente, vous pouvez exécuter la commande suivante pour améliorer les performances réseau en augmentant les limites du tampon de réception TCP.
gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \ --project "${PROJECT_ID}" \ --zone "${ZONE}" \ --node=all \ --worker=all \ --command=' sudo sh -c "echo \"4096 41943040 314572800\" > /proc/sys/net/ipv4/tcp_rmem"'
Provisionner avec des ressources en file d'attente
Vous pouvez créer un Cloud TPU v6e à l'aide de ressources mises en file d'attente. Les ressources mises en file d'attente vous permettent de recevoir de la capacité dès qu'elle est disponible. Vous pouvez spécifier une heure de début et de fin facultatives pour le traitement de la demande. Pour en savoir plus, consultez Gérer les ressources mises en file d'attente.
Provisionner des Cloud TPU v6e avec GKE ou XPK
Si vous utilisez des commandes GKE avec v6e, vous pouvez utiliser des commandes Kubernetes ou XPK pour provisionner des Cloud TPU et entraîner ou diffuser des modèles. Consultez Planifier des Cloud TPU dans GKE pour savoir comment planifier vos configurations Cloud TPU dans les clusters GKE. Les sections suivantes fournissent des commandes permettant de créer un cluster XPK compatible avec une ou plusieurs cartes réseau.
Créer un cluster XPK compatible avec une seule carte d'interface réseau
export CLUSTER_NAME=xpk-cluster-name export ZONE=us-east1-d export PROJECT_ID=your-project-id export TPU_TYPE=v6e-256 export NUM_SLICES=2 export NETWORK_NAME=${CLUSTER_NAME}-mtu9k export NETWORK_FW_NAME=${NETWORK_NAME}-fw
gcloud compute networks create ${NETWORK_NAME} \ --mtu=8896 \ --project=${PROJECT_ID} \ --subnet-mode=auto \ --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} \ --network=${NETWORK_NAME} \ --allow tcp,icmp,udp \ --project=${PROJECT_ID}
export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}"
python3 xpk.py cluster create --cluster=${CLUSTER_NAME} \ --cluster-cpu-machine-type=e2-standard-8 \ --num-slices=${NUM_SLICES} \ --tpu-type=${TPU_TYPE} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --on-demand \ --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \ --create-vertex-tensorboard
Description des options de commande
Variable | Description |
CLUSTER_NAME | Nom attribué par l'utilisateur au cluster XPK. |
PROJECT_ID | Nom du projet :Google Cloud Utilisez un projet existant ou créez-en un. Pour en savoir plus, consultez Configurer votre projet Google Cloud . |
ZONE | Pour connaître les zones compatibles, consultez le document Régions et zones Cloud TPU. |
TPU_TYPE | Consultez la section Types d'accélérateurs. |
NUM_SLICES | Nombre de segments à créer |
CLUSTER_ARGUMENTS | Réseau et sous-réseau à utiliser.
Par exemple : |
NUM_SLICES | Nombre de tranches à créer. |
NETWORK_NAME | Nom d'un réseau secondaire à utiliser. |
NETWORK_FW_NAME | Nom d'un pare-feu réseau secondaire à utiliser. |
Créer un cluster XPK compatible avec plusieurs cartes d'interface réseau
export CLUSTER_NAME=xpk-cluster-name export REGION=your-region export ZONE=us-east1-d export PROJECT_ID=your-project-id export TPU_TYPE=v6e-256 export NUM_SLICES=2 export NETWORK_NAME_1=${CLUSTER_NAME}-mtu9k-1-${ZONE} export SUBNET_NAME_1=${CLUSTER_NAME}-privatesubnet-1-${ZONE} export NETWORK_FW_NAME_1=${NETWORK_NAME_1}-fw-1-${ZONE} export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-1-${ZONE} export ROUTER_NAME=${CLUSTER_NAME}-network-1-${ZONE} export NAT_CONFIG=${CLUSTER_NAME}-natconfig-1-${ZONE}
gcloud compute networks create ${NETWORK_NAME_1} \ --mtu=8896 \ --bgp-routing-mode=regional \ --subnet-mode=custom \ --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_1} \ --network=${NETWORK_NAME_1} \ --range=10.11.0.0/18 \ --region=${REGION} \ --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \ --network=${NETWORK_NAME_1} \ --allow tcp,icmp,udp \ --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \ --project=${PROJECT_ID} \ --network=${NETWORK_NAME_1} \ --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \ --router=${ROUTER_NAME} \ --region=${REGION} \ --auto-allocate-nat-external-ips \ --nat-all-subnet-ip-ranges \ --project=${PROJECT_ID} \ --enable-logging
# Secondary subnet for multi-nic experience.
# Need custom IP routing to be different from the first network's subnet.
export NETWORK_NAME_2=${CLUSTER_NAME}-privatenetwork-2-${ZONE}
export SUBNET_NAME_2=${CLUSTER_NAME}-privatesubnet-2-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-2-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-2-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-2-${ZONE}
gcloud compute networks create ${NETWORK_NAME_2} \ --mtu=8896 \ --bgp-routing-mode=regional \ --subnet-mode=custom \ --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_2} \ --network=${NETWORK_NAME_2} \ --range=10.10.0.0/18 \ --region=${REGION} \ --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \ --network=${NETWORK_NAME_2} \ --allow tcp,icmp,udp \ --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \ --project=${PROJECT_ID} \ --network=${NETWORK_NAME_2} \ --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \ --router=${ROUTER_NAME} \ --region=${REGION} \ --auto-allocate-nat-external-ips \ --nat-all-subnet-ip-ranges \ --project=${PROJECT_ID} \ --enable-logging
export CLUSTER_ARGUMENTS="--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}"
export NODE_POOL_ARGUMENTS="--additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}"
python3 xpk.py cluster create \ --cluster=${CLUSTER_NAME} \ --cluster-cpu-machine-type=e2-standard-8 \ --num-slices=${NUM_SLICES} \ --tpu-type=${TPU_TYPE} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --on-demand \ --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \ --custom-nodepool-arguments="${NODE_POOL_ARGUMENTS}" \ --create-vertex-tensorboard
Description des options de commande
Variable | Description |
CLUSTER_NAME | Nom attribué par l'utilisateur au cluster XPK. |
PROJECT_ID | Nom du projet :Google Cloud Utilisez un projet existant ou créez-en un. Pour en savoir plus, consultez Configurer votre projet Google Cloud . |
ZONE | Pour connaître les zones compatibles, consultez le document Régions et zones Cloud TPU. |
TPU_TYPE | Consultez la section Types d'accélérateurs. |
NUM_SLICES | Nombre de segments à créer |
CLUSTER_ARGUMENTS | Réseau et sous-réseau à utiliser.
Par exemple : |
NODE_POOL_ARGUMENTS | Réseau de nœuds supplémentaires à utiliser.
Par exemple : |
NUM_SLICES | Nombre de segments à créer (nécessaire uniquement pour Multislice). |
NETWORK_NAME | Nom d'un réseau secondaire à utiliser. |
NETWORK_FW_NAME | Nom d'un pare-feu réseau secondaire à utiliser. |
Configurer le framework
Cette section décrit la procédure de configuration générale pour l'entraînement de modèles de ML à l'aide des frameworks JAX et PyTorch. Si vous utilisez GKE, vous pouvez utiliser les commandes XPK ou Kubernetes pour configurer le framework.
Configurer pour JAX
Cette section fournit des instructions de configuration pour exécuter des charges de travail JAX sur GKE, avec ou sans XPK, ainsi que pour utiliser des ressources mises en file d'attente.
Configurer JAX à l'aide de GKE
Tranche unique sur un seul hôte
L'exemple suivant configure un pool de nœuds à hôte unique 2x2 à l'aide d'un fichier YAML Kubernetes.
apiVersion: v1
kind: Pod
metadata:
name: tpu-pod-jax-v6e-a
spec:
restartPolicy: Never
nodeSelector:
cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
cloud.google.com/gke-tpu-topology: 2x2
containers:
- name: tpu-job
image: python:3.10
securityContext:
privileged: true
command:
- bash
- -c
- |
pip install -U --pre jax jaxlib libtpu-nightly requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python3 -c 'import jax; print("Total TPU chips:", jax.device_count())'
resources:
requests:
google.com/tpu: 4
limits:
google.com/tpu: 4
Une fois l'opération terminée, le message suivant doit s'afficher dans le journal GKE :
Total TPU chips: 4
Tranche unique sur plusieurs hôtes
L'exemple suivant configure un pool de nœuds multihôtes 4x4 à l'aide d'un fichier YAML Kubernetes.
apiVersion: v1
kind: Service
metadata:
name: headless-svc
spec:
clusterIP: None
selector:
job-name: tpu-available-chips
---
apiVersion: batch/v1
kind: Job
metadata:
name: tpu-available-chips
spec:
backoffLimit: 0
completions: 4
parallelism: 4
completionMode: Indexed
template:
spec:
subdomain: headless-svc
restartPolicy: Never
nodeSelector:
cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
cloud.google.com/gke-tpu-topology: 4x4
containers:
- name: tpu-job
image: python:3.10
ports:
- containerPort: 8471 # Default port using which TPU VMs communicate
- containerPort: 8431 # Port to export TPU runtime metrics, if supported.
securityContext:
privileged: true
command:
- bash
- -c
- |
pip install -U --pre jax jaxlib libtpu-nightly requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
resources:
requests:
google.com/tpu: 4
limits:
google.com/tpu: 4
Une fois l'opération terminée, le message suivant doit s'afficher dans le journal GKE :
Total TPU chips: 16
Multitranches sur plusieurs hôtes
L'exemple suivant configure deux pools de nœuds multihôtes 4x4 à l'aide d'un fichier YAML Kubernetes.
Pour commencer, vous devez installer JobSet v0.2.3 ou version ultérieure.
apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
name: multislice-job
annotations:
alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
spec:
failurePolicy:
maxRestarts: 4
replicatedJobs:
- name: slice
replicas: 2
template:
spec:
parallelism: 4
completions: 4
backoffLimit: 0
template:
spec:
hostNetwork: true
dnsPolicy: ClusterFirstWithHostNet
nodeSelector:
cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
cloud.google.com/gke-tpu-topology: 4x4
hostNetwork: true
containers:
- name: jax-tpu
image: python:3.10
ports:
- containerPort: 8471
- containerPort: 8080
- containerPort: 8431
securityContext:
privileged: true
command:
- bash
- -c
- |
pip install -U --pre jax jaxlib libtpu-nightly requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
resources:
limits:
google.com/tpu: 4
requests:
google.com/tpu: 4
Une fois l'opération terminée, le message suivant doit s'afficher dans le journal GKE :
Total TPU chips: 32
Pour en savoir plus, consultez Exécuter une charge de travail multicouche dans la documentation de GKE.
Pour améliorer les performances, activez hostNetwork.
Multi-NIC
Pour utiliser le fichier manifeste multi-NIC suivant, vous devez configurer vos réseaux. Pour en savoir plus, consultez Configurer la compatibilité multiréseau pour les pods Kubernetes.
Pour profiter de plusieurs cartes d'interface réseau dans GKE, vous devez inclure des annotations supplémentaires dans le fichier manifeste du pod Kubernetes.
Voici un exemple de fichier manifeste de charge de travail multi-NIC non-TPU.
apiVersion: v1
kind: Pod
metadata:
name: sample-netdevice-pod-1
annotations:
networking.gke.io/default-interface: 'eth0'
networking.gke.io/interfaces: |
[
{"interfaceName":"eth0","network":"default"},
{"interfaceName":"eth1","network":"netdevice-network"}
]
spec:
containers:
- name: sample-netdevice-pod
image: busybox
command: ["sleep", "infinity"]
ports:
- containerPort: 80
restartPolicy: Always
tolerations:
- key: "google.com/tpu"
operator: "Exists"
effect: "NoSchedule"
Si vous utilisez la commande exec
pour vous connecter au pod Kubernetes, vous devriez voir la carte d'interface réseau supplémentaire à l'aide du code suivant :
$ kubectl exec --stdin --tty sample-netdevice-pod-1 -- /bin/sh
/ # ip a
1: lo: <LOOPBACK,UP,LOWER_UP> mtu 65536 qdisc noqueue qlen 1000
link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00
inet 127.0.0.1/8 scope host lo
valid_lft forever preferred_lft forever
2: eth0@if11: <BROADCAST,MULTICAST,UP,LOWER_UP,M-DOWN> mtu 1460 qdisc noqueue
link/ether da:be:12:67:d2:25 brd ff:ff:ff:ff:ff:ff
inet 10.124.2.6/24 brd 10.124.2.255 scope global eth0
valid_lft forever preferred_lft forever
3: eth1: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1460 qdisc mq qlen 1000
link/ether 42:01:ac:18:00:04 brd ff:ff:ff:ff:ff:ff
inet 172.24.0.4/32 scope global eth1
valid_lft forever preferred_lft forever
Configurer JAX à l'aide de GKE avec XPK
Pour configurer JAX à l'aide de GKE et de XPK, consultez le fichier README de XPK.
Pour configurer et exécuter XPK avec MaxText, consultez Exécuter MaxText.
Configurer JAX à l'aide de ressources en file d'attente
Installez JAX sur toutes les VM Cloud TPU de votre ou vos tranches simultanément à l'aide de la commande gcloud alpha compute tpus tpu-vm ssh
. Pour Multislice, ajoutez l'indicateur --node=all
.
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
--project ${PROJECT_ID} \
--zone ${ZONE} \
--worker=all \
--command='
pip install -U --pre jax jaxlib libtpu-nightly requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Vous pouvez exécuter la commande suivante pour vérifier le nombre de cœurs Cloud TPU disponibles dans votre tranche et pour vous assurer que tout est correctement installé :
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
--project ${PROJECT_ID} \
--zone ${ZONE} \
--worker=all \
--command='
python3 -c "import jax; print(jax.device_count(), jax.local_device_count())"'
Le résultat ressemble à ce qui suit lors de l'exécution sur une tranche v6e-16 :
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16 4
16 4
16 4
16 4
jax.device_count()
indique le nombre total de puces dans la tranche donnée.
jax.local_device_count()
indique le nombre de puces accessibles par une seule VM dans cette tranche.
gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
git clone -b mlperf4.1 https://github.com/google/maxdiffusion.git &&
cd maxdiffusion && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103 &&
pip install setuptools==59.6.0 &&
pip install -r requirements.txt && pip install .'
Résoudre les problèmes de configuration de JAX
Un conseil général consiste à activer la journalisation détaillée dans le fichier manifeste de votre charge de travail GKE. Ensuite, fournissez les journaux à l'assistance GKE.
TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0
Messages d'erreur
no endpoints available for service 'jobset-webhook-service'
Cette erreur signifie que le jobset n'a pas été installé correctement. Vérifiez si les pods Kubernetes de déploiement jobset-controller-manager sont en cours d'exécution. Pour en savoir plus, consultez la documentation sur la résolution des problèmes liés aux JobSet.
TPU initialization failed: Failed to connect
Assurez-vous que la version de votre nœud GKE est 1.30.4-gke.1348000 ou ultérieure (GKE 1.31 n'est pas compatible).
Configuration pour PyTorch
Cette section explique comment commencer à utiliser PJRT sur v6e avec PyTorch/XLA. Python 3.10 est la version de Python recommandée.
Configurer PyTorch à l'aide de GKE avec XPK
Vous pouvez utiliser le conteneur Docker suivant avec XPK, qui contient les dépendances PyTorch déjà installées :
us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028
Pour créer une charge de travail XPK, utilisez la commande suivante :
python3 xpk.py workload create \ --cluster ${CLUSTER_NAME} \ {--base-docker-image maxtext_base_image | --docker-image your-cloud-image-name} \ --workload ${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \ --tpu-type=${ACCELERATOR_TYPE} \ --num-slices=${NUM_SLICES} \ --on-demand \ --zone ${ZONE} \ --project ${PROJECT_ID} \ --enable-debug-logs \ --command 'python3 -c "import torch; import torch_xla; import torch_xla.runtime as xr; print(xr.global_runtime_device_count())"'
L'utilisation de --base-docker-image
crée une image Docker avec le répertoire de travail actuel intégré au nouveau Docker.
Configurer PyTorch à l'aide de ressources en file d'attente
Suivez ces étapes pour installer PyTorch à l'aide de ressources mises en file d'attente et exécuter un petit script sur v6e.
Installer les dépendances à l'aide de SSH pour accéder aux VM
Utilisez la commande suivante pour installer les dépendances sur toutes les VM Cloud TPU. Pour Multislice, ajoutez l'option --worker=all
:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
sudo apt update && sudo apt install -y python3-pip libopenblas-base && \
pip3 install torch~=2.6.0 "torch_xla[tpu]~=2.6.0" -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'
Améliorer les performances des modèles avec des allocations importantes et fréquentes
Pour les modèles comportant des allocations fréquentes et importantes, l'utilisation de la fonction tcmalloc
améliore considérablement les performances par rapport à l'implémentation par défaut de la fonction malloc
. La fonction malloc
par défaut utilisée sur la VM Cloud TPU est donc tcmalloc
. Toutefois, en fonction de votre charge de travail (par exemple, avec DLRM qui dispose d'allocations très importantes pour ses tables d'intégration), la fonction tcmalloc
peut entraîner un ralentissement. Dans ce cas, vous pouvez essayer de désactiver la variable suivante à l'aide de la fonction malloc
par défaut :
unset LD_PRELOAD
Utiliser un script Python pour effectuer un calcul sur une VM v6e
Utilisez la commande suivante pour exécuter un script qui crée deux tenseurs, les additionne et affiche le résultat :
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
--project ${PROJECT_ID} \
--zone ${ZONE} \
--worker all \
--command='
unset LD_PRELOAD
python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"'
Un résultat semblable à celui-ci doit s'afficher :
SSH: Attempting to connect to worker 0...
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
xla:0
tensor([[ 0.3355, -1.4628, -3.2610],
[-1.4656, 0.3196, -2.8766],
[ 0.8668, -1.5060, 0.7125]], device='xla:0')
v6e avec SkyPilot
Vous pouvez utiliser Cloud TPU v6e avec SkyPilot. Procédez comme suit pour ajouter des informations sur les prix et les emplacements liés à la norme v6e à SkyPilot. Pour en savoir plus, consultez l'exemple SkyPilot TPU v6e.
Tutoriels sur l'inférence
Les tutoriels suivants montrent comment exécuter l'inférence sur Cloud TPU v6e :
Exemples d'entraînement
Les sections suivantes fournissent des exemples d'entraînement de modèles MaxText, MaxDiffusion et PyTorch sur Cloud TPU v6e.
Entraînement MaxText et MaxDiffusion sur une VM Cloud TPU v6e
Les sections suivantes couvrent le cycle de vie de l'entraînement des modèles MaxText et MaxDiffusion.
De manière générale, voici les étapes à suivre :
- Créez l'image de base de la charge de travail.
- Exécutez votre charge de travail à l'aide de XPK.
- Créez la commande d'entraînement pour la charge de travail.
- Déployez la charge de travail.
- Suivez la charge de travail et affichez les métriques.
- Supprimez la charge de travail XPK si vous n'en avez pas besoin.
- Supprimez le cluster XPK lorsqu'il n'est plus nécessaire.
Créer l'image de base
Installez MaxText ou MaxDiffusion, puis créez l'image Docker :
Clonez le dépôt que vous souhaitez utiliser et accédez à son répertoire :
MaxText :
git clone https://github.com/google/maxtext.git && cd maxtext
MaxDiffusion :
git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103
Configurez Docker pour utiliser Google Cloud CLI :
gcloud auth configure-docker
Créez l'image Docker à l'aide de la commande suivante ou de la pile stable JAX. Pour en savoir plus sur JAX Stable Stack, consultez Créer une image Docker avec JAX Stable Stack.
MaxText :
bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.35
MaxDiffusion :
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_stack MODE=jax_ai_image PROJECT=${PROJECT_ID} LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest
Définissez l'ID de votre projet dans votre configuration gcloud CLI active :
gcloud config set project ${PROJECT_ID}
Si vous lancez la charge de travail à partir d'une machine sur laquelle l'image n'a pas été créée localement, importez l'image.
Définissez la variable d'environnement
CLOUD_IMAGE_NAME
:export CLOUD_IMAGE_NAME=${USER}_runner
Importez l'image :
bash docker_upload_runner.sh ${CLOUD_IMAGE_NAME}
Exécuter votre charge de travail à l'aide de XPK
Définissez les variables d'environnement suivantes si vous n'utilisez pas les valeurs par défaut définies par MaxText ou MaxDiffusion :
export BASE_OUTPUT_DIR=gs://YOUR_BUCKET export PER_DEVICE_BATCH_SIZE=2 export NUM_STEPS=30 export MAX_TARGET_LENGTH=8192
Créez le script de votre modèle. Ce script sera copié en tant que commande d'entraînement lors d'une étape ultérieure.
N'exécutez pas encore le script du modèle.
MaxText
MaxText est un LLM Open Source hautes performances et hautement évolutif écrit en Python et JAX purs, et ciblant les TPU et les GPU pour l'entraînement et l'inférence. Google Cloud
JAX_PLATFORMS=tpu,cpu \ ENABLE_PJRT_COMPATIBILITY=true \ TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \ TPU_SLICE_BUILDER_DUMP_ICI=true && \ python3 -m MaxText.train MaxText/configs/base.yml \ base_output_directory=${BASE_OUTPUT_DIR} \ dataset_type=synthetic \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ enable_checkpointing=false \ gcs_metrics=true \ profiler=xplane \ skip_first_n_steps_for_profiler=5 \ steps=${NUM_STEPS} # attention='dot_product'"
Gemma2
Gemma est une famille de LLM à poids ouverts développée par Google DeepMind, basée sur la recherche et la technologie Gemini.
python3 -m MaxText.train MaxText/configs/base.yml \ model_name=gemma2-27b \ run_name=gemma2-27b-run \ base_output_directory=${BASE_OUTPUT_DIR} \ max_target_length=${MAX_TARGET_LENGTH} \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ steps=${NUM_STEPS} \ enable_checkpointing=false \ use_iota_embed=true \ gcs_metrics=true \ dataset_type=synthetic \ profiler=xplane \ attention=flash
Mixtral 8x7b
Mixtral est un modèle d'IA de pointe développé par Mistral AI, qui utilise une architecture MoE (Mixture of Experts) éparse.
python3 -m MaxText.train MaxText/configs/base.yml \ base_output_directory=${BASE_OUTPUT_DIR} \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ model_name=mixtral-8x7b \ steps=${NUM_STEPS} \ max_target_length=${MAX_TARGET_LENGTH} \ tokenizer_path=assets/tokenizer.mistral-v1 \ attention=flash \ dtype=bfloat16 \ dataset_type=synthetic \ profiler=xplane
Llama3-8b
Llama est une famille de LLM à pondération ouverte développée par Meta.
Pour obtenir un exemple d'exécution de Llama3 sur PyTorch, consultez les modèles torch_xla dans le dépôt torchprime.
MaxDiffusion
MaxDiffusion est une collection d'implémentations de référence de divers modèles de diffusion latente écrits en Python et JAX purs, qui s'exécutent sur des appareils XLA, y compris les Cloud TPU et les GPU. Stable Diffusion est un modèle latent de texte vers image qui génère des images photoréalistes à partir de n'importe quelle entrée de texte.
Vous devez installer une branche Git spécifique pour exécuter MaxDiffusion, comme indiqué dans la commande
git clone
suivante.Script d'entraînement :
git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103 && pip install -r requirements.txt && pip install . && pip install huggingface_hub==0.30.2 && OUT_DIR=${BASE_OUTPUT_DIR} && python src/maxdiffusion/train_sdxl.py src/maxdiffusion/configs/base_xl.yml revision=refs/pr/95 activations_dtype=bfloat16 weights_dtype=bfloat16 resolution=1024 per_device_batch_size=1 output_dir=${OUT_DIR} jax_cache_dir=${OUT_DIR}/cache_dir/ max_train_steps=200 attention=flash run_name=sdxl-ddp-v6e
Exportez les variables suivantes :
export CLUSTER_NAME=CLUSTER_NAME export ACCELERATOR_TYPE=ACCELERATOR_TYPE export NUM_SLICES=NUM_SLICES export YOUR_MODEL_SCRIPT=YOUR_MODEL_SCRIPT
Descriptions des variables d'environnement
Variable Description CLUSTER_NAME
Nom de votre cluster XPK. ACCELERATOR_TYPE
Consultez Types d'accélérateurs. NUM_SLICES
Nombre de tranches de TPU. YOUR_MODEL_SCRIPT
Script de modèle à exécuter en tant que commande d'entraînement. Exécutez le modèle à l'aide du script que vous avez créé à l'étape précédente. Vous devez spécifier l'option
--base-docker-image
pour utiliser l'image de base MaxText, ou spécifier l'option--docker-image
et l'image que vous souhaitez utiliser.Facultatif : Vous pouvez activer la journalisation de débogage en incluant l'option
--enable-debug-logs
. Pour en savoir plus, consultez Déboguer JAX sur MaxText.Facultatif : Vous pouvez créer un test Vertex AI pour importer des données dans Vertex AI TensorBoard en incluant l'indicateur
--use-vertex-tensorboard
. Pour en savoir plus, consultez Surveiller JAX sur MaxText à l'aide de Vertex AI.python3 xpk.py workload create \ --cluster ${CLUSTER_NAME} \ {--base-docker-image maxtext_base_image | --docker-image gcr.io/${PROJECT_ID}/${CLOUD_IMAGE_NAME}:latest} \ --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \ --tpu-type=${ACCELERATOR_TYPE} \ --num-slices=${NUM_SLICES} \ --on-demand \ --zone=${ZONE} \ --project=${PROJECT_ID} \ [--enable-debug-logs] \ [--use-vertex-tensorboard] \ --command="${YOUR_MODEL_SCRIPT}"
Le résultat inclut un lien pour suivre votre charge de travail. Ouvrez le lien et cliquez sur l'onglet Journaux pour suivre votre charge de travail en temps réel.
Déboguer JAX sur MaxText
Utilisez des commandes XPK supplémentaires pour déterminer pourquoi le cluster ou la charge de travail ne s'exécutent pas :
- Liste des charges de travail XPK
- Inspecteur XPK
- Activez la journalisation détaillée dans les journaux de charge de travail à l'aide de l'option
--enable-debug-logs
lorsque vous créez la charge de travail XPK.
Surveiller JAX sur MaxText à l'aide de Vertex AI
Pour utiliser TensorBoard, votre compte utilisateur Google Cloud doit disposer du rôle aiplatform.user
. Exécutez la commande suivante pour attribuer ce rôle :
gcloud projects add-iam-policy-binding your-project-id \ --member='user:your-email' \ --role='roles/aiplatform.user'
Affichez les données scalaires et de profil via TensorBoard géré par Vertex AI.
Augmentez le nombre de requêtes de gestion des ressources (CRUD) pour la zone que vous utilisez, de 600 à 5 000. Cela ne devrait pas poser de problème pour les petites charges de travail utilisant moins de 16 VM.
Installez les dépendances telles que
cloud-accelerator-diagnostics
pour Vertex AI :# xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI cd ~/xpk pip install .
Créez votre cluster XPK à l'aide de l'indicateur
--create-vertex-tensorboard
, comme indiqué dans Créer Vertex AI TensorBoard. Vous pouvez également exécuter cette commande sur des clusters existants.Créez votre test Vertex AI lorsque vous exécutez votre charge de travail XPK à l'aide de l'indicateur
--use-vertex-tensorboard
et de l'indicateur facultatif--experiment-name
. Pour obtenir la liste complète des étapes, consultez Créer un test Vertex AI pour importer des données dans Vertex AI TensorBoard.
Les journaux incluent un lien vers un Vertex AI TensorBoard, semblable à ce qui suit :
View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name
Vous pouvez également trouver le lien Vertex AI TensorBoard dans la console Google Cloud . Accédez à Vertex AI Experiments dans la console Google Cloud . Sélectionnez la région appropriée dans le menu déroulant.
Le répertoire TensorBoard est également écrit dans le bucket Cloud Storage que vous avez spécifié avec ${BASE_OUTPUT_DIR}
.
Supprimer les charges de travail XPK
Utilisez la commande xpk workload delete
pour supprimer une ou plusieurs charges de travail en fonction du préfixe ou de l'état du job. Cette commande peut être utile si vous avez envoyé des charges de travail XPK qui n'ont plus besoin d'être exécutées ou si des jobs sont bloqués dans la file d'attente.
Supprimer le cluster XPK
Exécutez la commande xpk cluster delete
pour supprimer un cluster :
python3 xpk.py cluster delete --cluster ${CLUSTER_NAME} \ --zone=${ZONE} --project=${PROJECT_ID}
Entraînement Llama et PyTorch/XLA sur une VM Cloud TPU v6e
Ce tutoriel explique comment entraîner des modèles Llama à l'aide de PyTorch/XLA sur Cloud TPU v6e avec l'ensemble de données WikiText.
Accéder à Hugging Face et au modèle Llama 3
Vous avez besoin d'un jeton d'accès utilisateur Hugging Face pour exécuter ce tutoriel. Pour savoir comment créer des jetons d'accès utilisateur, consultez la documentation Hugging Face sur les jetons d'accès utilisateur.
Vous devez également disposer de l'autorisation d'accéder au modèle Llama-3-8B sur Hugging Face. Pour y accéder, accédez au modèle Meta-Llama-3-8B sur HuggingFace et demandez-y l'accès.
Créer une VM Cloud TPU
Créez un Cloud TPU v6e avec huit puces pour exécuter le tutoriel.
Configurez des variables d'environnement :
export NODE_ID=your-tpu-name export PROJECT_ID=your-project-id export ACCELERATOR_TYPE=v6e-8 export ZONE=us-east1-d export RUNTIME_VERSION=v2-alpha-tpuv6e export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id export VALID_DURATION=your-duration
Créez une VM Cloud TPU :
gcloud alpha compute tpus tpu-vm create ${NODE_ID} --version=${RUNTIME_VERSION} \ --accelerator-type=${ACCELERATOR_TYPE} \ --zone=${ZONE} \ --project=${PROJECT_ID}
Installation
Installez le fork pytorch-tpu/transformers
de Hugging Face Transformers et ses dépendances. Ce tutoriel a été testé avec les versions de dépendances suivantes utilisées dans cet exemple :
torch
: compatible avec la version 2.5.0torch_xla[tpu]
: compatible avec la version 2.5.0jax
: 0.4.33jaxlib
: 0.4.33
gcloud alpha compute tpus tpu-vm ssh ${NODE_ID} \ --project=${PROJECT_ID} \ --zone ${ZONE} \ --worker=all \ --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git cd transformers sudo pip3 install -e . pip3 install datasets pip3 install evaluate pip3 install scikit-learn pip3 install accelerate pip install torch~=2.6.0 torch_xla[tpu]~=2.6.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html pip install jax==0.4.38 jaxlib==0.4.38 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/'
Configurer les configurations de modèle
La commande d'entraînement de la section suivante, Exécuter le modèle, utilise deux fichiers de configuration JSON pour définir les paramètres du modèle et la configuration FSDP (Fully Sharded Data Parallel). Le sharding FSDP vous permet d'utiliser une taille de lot plus importante lors de l'entraînement en shardant les pondérations de votre modèle sur plusieurs TPU. Lorsque vous entraînez des modèles plus petits, il peut suffire d'utiliser le parallélisme des données et de répliquer les pondérations sur chaque appareil. Pour en savoir plus sur le partitionnement des Tensors sur plusieurs appareils dans PyTorch/XLA, consultez le guide de l'utilisateur SPMD PyTorch/XLA.
Créez le fichier de configuration des paramètres du modèle. Voici la configuration des paramètres du modèle pour Llama-3-8B. Pour les autres modèles, recherchez la configuration sur Hugging Face. Par exemple, consultez la configuration Llama-2-7B.
cat > llama-config.json << EOF { "architectures": [ "LlamaForCausalLM" ], "attention_bias": false, "attention_dropout": 0.0, "bos_token_id": 128000, "eos_token_id": 128001, "hidden_act": "silu", "hidden_size": 4096, "initializer_range": 0.02, "intermediate_size": 14336, "max_position_embeddings": 8192, "model_type": "llama", "num_attention_heads": 32, "num_hidden_layers": 32, "num_key_value_heads": 8, "pretraining_tp": 1, "rms_norm_eps": 1e-05, "rope_scaling": null, "rope_theta": 500000.0, "tie_word_embeddings": false, "torch_dtype": "bfloat16", "transformers_version": "4.40.0.dev0", "use_cache": false, "vocab_size": 128256 } EOF
Créez le fichier de configuration FSDP :
cat > fsdp-config.json << EOF { "fsdp_transformer_layer_cls_to_wrap": [ "LlamaDecoderLayer" ], "xla": true, "xla_fsdp_v2": true, "xla_fsdp_grad_ckpt": true } EOF
Pour en savoir plus sur le FSDP, consultez FSDPv2.
Importez les fichiers de configuration dans vos VM Cloud TPU à l'aide de la commande suivante :
gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json ${NODE_ID}:. \ --worker=all \ --project=${PROJECT_ID} \ --zone=${ZONE}
Exécuter le modèle
À l'aide des fichiers de configuration que vous avez créés dans la section précédente, exécutez le script run_clm.py
pour entraîner le modèle Llama-3-8B sur l'ensemble de données WikiText. L'exécution du script d'entraînement prend environ 10 minutes sur un Cloud TPU v6e-8.
Connectez-vous à Hugging Face sur votre Cloud TPU à l'aide de la commande suivante :
gcloud alpha compute tpus tpu-vm ssh ${NODE_ID} \ --project=${PROJECT_ID} \ --zone ${ZONE} \ --worker=all \ --command=' pip3 install "huggingface_hub[cli]" huggingface-cli login --token HUGGING_FACE_TOKEN'
Exécutez l'entraînement du modèle :
gcloud alpha compute tpus tpu-vm ssh ${NODE_ID} \ --project=${PROJECT_ID} \ --zone ${ZONE} \ --worker=all \ --command=' export PJRT_DEVICE=TPU export XLA_USE_SPMD=1 export ENABLE_PJRT_COMPATIBILITY=true # Optional variables for debugging: export XLA_IR_DEBUG=1 export XLA_HLO_DEBUG=1 export PROFILE_EPOCH=0 export PROFILE_STEP=3 export PROFILE_DURATION_MS=100000 # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path export PROFILE_LOGDIR=PROFILE_PATH python3 transformers/examples/pytorch/language-modeling/run_clm.py \ --dataset_name wikitext \ --dataset_config_name wikitext-2-raw-v1 \ --per_device_train_batch_size 16 \ --do_train \ --output_dir /home/$USER/tmp/test-clm \ --overwrite_output_dir \ --config_name /home/$USER/llama-config.json \ --cache_dir /home/$USER/cache \ --tokenizer_name meta-llama/Meta-Llama-3-8B \ --block_size 8192 \ --optim adafactor \ --save_strategy no \ --logging_strategy no \ --fsdp "full_shard" \ --fsdp_config /home/$USER/fsdp-config.json \ --torch_dtype bfloat16 \ --dataloader_drop_last yes \ --flash_attention \ --max_steps 20'
Résoudre les problèmes liés à PyTorch/XLA
Si vous avez défini les variables facultatives pour le débogage dans la section précédente, le profil du modèle sera stocké à l'emplacement spécifié par la variable PROFILE_LOGDIR
. Vous pouvez extraire le fichier xplane.pb
stocké à cet emplacement et utiliser tensorboard
pour afficher les profils dans votre navigateur en suivant les instructions TensorBoard.
Si PyTorch/XLA ne fonctionne pas comme prévu, consultez le guide de dépannage, qui contient des suggestions pour déboguer, profiler et optimiser votre modèle.
Résultats du benchmarking
La section suivante contient les résultats des tests comparatifs pour MaxDiffusion sur v6e.
MaxDiffusion
Nous avons exécuté le script d'entraînement pour MaxDiffusion sur un TPU v6e-4, un TPU v6e-16 et deux TPU v6e-16. Consultez les débits dans le tableau suivant.
v6e-4 | v6e-16 | Deux v6e-16 | |
---|---|---|---|
Étapes de l'entraînement | 0,069 | 0.073 | 0,13 |
Taille du lot global | 8 | 32 | 64 |
Débit (exemples/s) | 115.9 | 438,4 | 492.3 |