Entraîner Llama 3 avec PyTorch sur TPU v5e
Ce tutoriel explique comment entraîner un modèle Llama-3-8B à l'aide de PyTorch/XLA sur un TPU v5e à l'aide de l'ensemble de données WikiText. Pour en savoir plus sur le modèle, consultez Meta-Llama-3-8B.
Le modèle Llama-3-8B est hébergé sur la plate-forme Hugging Face.
Il existe deux versions de Meta-Llama-3-8B, l'une à utiliser avec Transformers et l'autre avec le codebase Llama 3 d'origine. Ce tutoriel utilise la version Transformers, car elle:
Intégration parfaite à l'écosystème Hugging Face: il est ainsi plus facile d'ajuster le modèle, d'utiliser des pipelines prédéfinis et d'accéder à une vaste collection d'ensembles de données et d'outils.
Permet de faire preuve de flexibilité et de personnalisation: la version Transformers offre une flexibilité et des options de personnalisation importantes pour affiner et déployer le modèle.
Fournit une assistance de la communauté: la communauté Hugging Face fournit une documentation, des tutoriels et une assistance complets pour l'utilisation des modèles Transformers.
Pour en savoir plus sur les Transformers, consultez la documentation sur les Transformers Hugging Face.
Pour accéder au modèle Meta-Llama-3-8B et l'utiliser, y compris pour télécharger ses poids et son tokenizer, vous avez besoin d'un jeton d'accès utilisateur Hugging Face. Le jeton fournit les éléments suivants:
Authentification et autorisation: le jeton d'accès sert d'identifiant et permet aux serveurs Hugging Face d'autoriser votre accès aux ressources du modèle. Ainsi, seuls les utilisateurs autorisés peuvent télécharger et utiliser le modèle.
Sécurité: Hugging Face utilise des jetons d'accès pour protéger ses modèles et empêcher tout accès non autorisé ou toute utilisation abusive.
Pour en savoir plus sur la création et l'utilisation d'un jeton d'accès pour ce tutoriel, consultez la section Exécuter le modèle. Pour en savoir plus sur la création et l'utilisation de jetons d'accès, consultez la documentation de Hugging Face sur les jetons d'accès utilisateur.
Vous devez également disposer d'une autorisation pour accéder au modèle Llama 3 8B sur Hugging Face. Pour obtenir cette autorisation, accédez au modèle Meta-Llama-3-8B sur Hugging Face et demandez l'accès.
Préparation du provisionnement d'un TPU v5litepod-16
Ce tutoriel a été testé à l'aide des variables d'environnement Cloud TPU suivantes. Vous pouvez utiliser d'autres variables pour provisionner votre TPU, à condition que le type d'accélérateur, la zone et la version de l'environnement d'exécution soient compatibles.
Par exemple, dans ce tutoriel, europe-west4-b
est utilisé comme zone. Vous pouvez utiliser n'importe quelle autre zone compatible avec la version de TPU (type d'accélérateur) que vous exécutez (v5litepod-16 dans ce tutoriel).
Définissez les variables d'environnement de la VM TPU suivantes.
export TPU_NAME=queued-resources-node-id #The TPU name is the queued resource node-id export PROJECT_ID=your-project-id export ACCELERATOR_TYPE=v5litepod-16 export ZONE=europe-west4-b export RUNTIME_VERSION=v2-alpha-tpuv5-lite export QUEUED_RESOURCE_ID=queued-resource-id export VALID_UNTIL_DURATION=1d
Une fois que vous avez accès au modèle Meta-Llama-3-8B sur Hugging Face, préparez l'environnement TPU pour exécuter le tutoriel.
Suivez le guide Configurer l'environnement Cloud TPU pour vous assurer que vous disposez des droits d'accès appropriés pour utiliser Cloud TPU.
Créez une identité de service pour la VM TPU.
gcloud alpha compute tpus tpu-vm service-identity create --zone=zone
Créez un compte de service TPU et accordez l'accès aux services Google Cloud .
Les comptes de service permettent au service Google Cloud TPU d'accéder à d'autres services Google Cloud. Un compte de service géré par l'utilisateur est recommandé. Vous pouvez créer un compte de service à partir de la console Google Cloud ou à l'aide de la commande
gcloud
.Créez un compte de service à l'aide de l'outil de ligne de commande
gcloud
:gcloud iam service-accounts create your-service-account-name \ --description="your-sa-description" \ --display-name="your-sa-display-name" export SERVICE_ACCOUNT_NAME=your-service-account-name
Créez un compte de service à partir de la console Google Cloud:
- Accédez à la page "Comptes de service" dans la console Google Cloud.
- Cliquez sur Créer un compte de service.
- Saisissez le nom du compte de service.
- (Facultatif) Saisissez une description du compte de service.
- Cliquez sur Créer et continuez.
- Choisissez les rôles que vous souhaitez accorder au compte de service.
- Cliquez sur Continuer.
- (Facultatif) Spécifiez les utilisateurs ou les groupes autorisés à gérer le compte de service.
- Cliquez sur OK pour terminer la création du compte de service.
Une fois votre compte de service créé, procédez comme suit pour accorder des rôles de compte de service.
Les rôles suivants sont nécessaires:
- Administrateur TPU: rôle nécessaire pour créer un TPU
- Storage Admin: rôle nécessaire pour accéder à Cloud Storage
- Rédacteur de journaux
- Rédacteur de métriques Monitoring: nécessaire pour écrire des métriques dans Cloud Monitoring
Votre administrateur doit vous accorder le rôle
roles/resourcemanager.projectIamAdmin
pour que vous puissiez attribuer des rôles IAM aux utilisateurs. Un utilisateur disposant du rôleroles/resourcemanager.projectIamAdmin
Administrateur IAM de projet peut également attribuer ce rôle.Utilisez les commandes
gcloud
suivantes pour ajouter des rôles de compte de service:gcloud projects add-iam-policy-binding ${PROJECT_ID} \ --member serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \ --role roles/tpu.admin gcloud projects add-iam-policy-binding ${PROJECT_ID} \ --member serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \ --role roles/storage.admin gcloud projects add-iam-policy-binding ${PROJECT_ID} \ --member serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \ --role roles/logging.logWriter gcloud projects add-iam-policy-binding ${PROJECT_ID} \ --member serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \ --role roles/monitoring.metricWriter
Vous pouvez également attribuer des rôles à l'aide de la console Google Cloud.
Dans la console Google Cloud, sélectionnez les rôles suivants:
- Sélectionnez votre compte de service, puis cliquez sur Ajouter un compte principal.
- Dans le champ Nouveaux comptes principaux, saisissez l'adresse e-mail de votre compte de service.
- Dans le menu déroulant Sélectionner un rôle, recherchez le rôle (par exemple, Storage Admin) et sélectionnez-le.
- Cliquez sur Enregistrer.
Authentifiez-vous avec Google Cloud et configurez le projet et la zone par défaut pour Google Cloud CLI.
gcloud auth login gcloud config set project PROJECT_ID gcloud config set compute/zone ZONE
Sécuriser la capacité
Lorsque vous êtes prêt à sécuriser la capacité de TPU, consultez la page sur les quotas pour en savoir plus sur le système Cloud Quotas. Si vous avez d'autres questions sur la sécurisation de la capacité, contactez votre équipe commerciale ou votre équipe chargée de votre compte Cloud TPU.
Provisionner l'environnement Cloud TPU
Vous pouvez provisionner des VM TPU avec GKE, avec GKE et XPK, ou en tant que ressources mises en file d'attente.
Prérequis
- Ce tutoriel a été testé avec Python 3.10 ou version ultérieure.
- Vérifiez que votre projet dispose d'un quota
TPUS_PER_TPU_FAMILY
suffisant, qui spécifie le nombre maximal de chips auxquels vous pouvez accéder dans votre projetGoogle Cloud . - Vérifiez que votre projet dispose d'un quota TPU suffisant pour :
- Quota de VM TPU
- Quota d'adresses IP
- Quota Hyperdisk équilibré
- Autorisations de projet utilisateur
- Si vous utilisez GKE avec XPK, consultez la section Autorisations de la console Cloud sur le compte utilisateur ou de service pour connaître les autorisations requises pour exécuter XPK.
Provisionner un TPU v5litepod-16
Créez une VM TPU:
gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT_NAME} \ --spot
Vérifiez que le TPU est à l'état
ACTIVE
:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Lorsque le TPU devient actif (ACTIVE
), un résultat semblable à celui-ci s'affiche:
createTime: '2025-02-28T21:16:08.053492925Z'
name: projects/my-project/locations/zone/queuedResources/tpu-name-zone
spot: {}
state:
state: ACTIVE
tpu:
nodeSpec:
- node:
acceleratorType: v5litepod-16
networkConfig:
enableExternalIps: true
network: default
queuedResource: projects/19672137403/locations/zone/queuedResources/qr-name
runtimeVersion: v2-alpha-tpuv5-lite
schedulingConfig: {}
my-service-account@your-project-id.iam.gserviceaccount.com
email: 19672137854-compute@developer.iam.gserviceaccount.com
shieldedInstanceConfig: {}
nodeId: tpu-name
parent: projects/19672137403/locations/zone
Installation
Installez le fork pytorch-tpu/transformers
de Hugging Face Transformers et ses dépendances. Ce tutoriel a été testé avec les versions de dépendance suivantes:
torch
: compatible avec la version 2.6.0torch_xla[tpu]
: compatible avec la version 2.6.0jax
: 0.4.38jaxlib
: 0.4.38
Installer le logiciel du framework et ses dépendances
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git sudo apt install python3.10-venv python -m venv /home/$USER/venv/ source ~/venv/bin/activate cd transformers pip3 install --user -e . pip3 install datasets pip3 install evaluate pip3 install scikit-learn pip3 install accelerate pip install torch~=2.6.0 torch_xla[tpu]~=2.6.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html pip install jax==0.4.38 jaxlib==0.4.38 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'
Une fois l'installation terminée, un résultat semblable à celui-ci s'affiche:
Collecting jax==0.4.38
Downloading jax-0.4.38-py3-none-any.whl (2.1 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.1/2.1 MB 18.0 MB/s eta 0:00:00
Collecting jaxlib==0.4.38
Downloading jaxlib-0.4.38-cp310-cp310-manylinux2014_x86_64.whl (85.0 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 85.0/85.0 MB 10.1 MB/s eta 0:00:00
Successfully installed jax-0.4.38 jaxlib-0.4.38 ml-dtypes-0.5.1 opt-einsum-3.4.0
Collecting opt-einsum
Downloading opt_einsum-3.4.0-py3-none-any.whl (71 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 71.9/71.9 KB 186.4 kB/s eta 0:00:00
Requirement already satisfied: numpy>=1.24 in /home/your-username/.local/lib/python3.10/site-packages (from jax==0.4.38) (2.2.3)
Requirement already satisfied: scipy>=1.10 in /home/your-username/.local/lib/python3.10/site-packages (from jax==0.4.38) (1.15.2)
Successfully installed jax-0.4.38 jaxlib-0.4.38 ml-dtypes-0.5.1 opt-einsum-3.4.0
Successfully installed jax-0.4.38 jaxlib-0.4.38 ml-dtypes-0.5.1 opt-einsum-3.4.0
Collecting ml-dtypes>=0.2.0
Downloading ml_dtypes-0.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.7 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.7/4.7 MB 13.8 MB/s eta 0:00:00
Installing collected packages: opt-einsum, ml-dtypes, jaxlib, jax
Successfully installed jax-0.4.38 jaxlib-0.4.38 ml-dtypes-0.5.1 opt-einsum-3.4.0
Configurer les configurations de modèle
La commande d'entraînement de la section suivante, Exécuter le modèle, utilise deux fichiers de configuration JSON pour définir les paramètres du modèle et la configuration FSDP (Fully Sharded Data Parallel). Le fractionnement FSDP permet aux poids du modèle de s'adapter à une taille de lot plus importante lors de l'entraînement. Lors de l'entraînement avec des modèles plus petits, il peut suffire d'utiliser le parallélisme des données et de répliquer les poids sur chaque appareil. Pour en savoir plus sur le fractionnement des tenseurs sur plusieurs appareils dans PyTorch/XLA, consultez le guide de l'utilisateur de SPMD PyTorch/XLA.
Cette commande crée le fichier de configuration des paramètres du modèle pour Llama3-8B. Pour les autres modèles, recherchez la configuration sur Hugging Face. Consultez par exemple la configuration Llama2-7B.
cat > llama-config.json <<EOF { "architectures": [ "LlamaForCausalLM" ], "attention_bias": false, "attention_dropout": 0.0, "bos_token_id": 128000, "eos_token_id": 128001, "hidden_act": "silu", "hidden_size": 4096, "initializer_range": 0.02, "intermediate_size": 14336, "max_position_embeddings": 8192, "model_type": "llama", "num_attention_heads": 32, "num_hidden_layers": 32, "num_key_value_heads": 8, "pretraining_tp": 1, "rms_norm_eps": 1e-05, "rope_scaling": null, "rope_theta": 500000.0, "tie_word_embeddings": false, "torch_dtype": "bfloat16", "transformers_version": "4.40.0.dev0", "use_cache": false, "vocab_size": 128256 } EOF
Créez le fichier de configuration FSDP:
cat > fsdp-config.json <<EOF { "fsdp_transformer_layer_cls_to_wrap": [ "LlamaDecoderLayer" ], "xla": true, "xla_fsdp_v2": true, "xla_fsdp_grad_ckpt": true } EOF
Pour en savoir plus sur le FSDP, consultez FSDPv2.
Importez les fichiers de configuration dans vos VM TPU à l'aide des commandes suivantes:
ssh-add ~/.ssh/google_compute_engine #Setup SSH Key in the SSH agent. gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json $TPU_NAME:. \ --worker=all \ --project=${PROJECT_ID} \ --zone=${ZONE}
Cette commande génère un résultat semblable à celui-ci:
Using scp batch size of 4.Attempting to SCP into 1 nodes with a total of 4 workers. SCP: Attempting to connect to worker 0... SCP: Attempting to connect to worker 1... SCP: Attempting to connect to worker 2... SCP: Attempting to connect to worker 3... llama-config.json 100% 707 4.1KB/s 00:00 llama-config.json 100% 707 4.0KB/s 00:00 llama-config.json 100% 707 4.1KB/s 00:00 llama-config.json 100% 707 4.1KB/s 00:00 fsdp-config.json 100% 156 0.9KB/s 00:00 fsdp-config.json 100% 156 0.9KB/s 00:00 fsdp-config.json 100% 156 0.9KB/s 00:00 fsdp-config.json 100% 156 0.9KB/s 00:00
Exécuter le modèle
À l'aide des fichiers de configuration que vous avez créés dans la section précédente, exécutez le script run_clm.py
pour entraîner le modèle Llama 3 8B sur l'ensemble de données WikiText. L'exécution du script d'entraînement sur un TPU v5litepod-16 prend environ 10 minutes.
Générez un nouveau jeton Hugging Face si vous n'en possédez pas déjà un:
- Cliquez sur Votre profil > Paramètres > Jetons d'accès.
- Sélectionnez New Token (Nouveau jeton).
- Spécifiez le nom de votre choix et un rôle d'au moins "Read" (lecture).
- Sélectionnez Générer un jeton.
Utilisez votre jeton Hugging Face pour vous connecter à Hugging Face depuis votre VM TPU à l'aide de la commande suivante.
Remplacez la variable de jeton
huggingface-cli login
par celle générée à partir de Hugging Face à l'étape précédente:gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' pip install -U "huggingface_hub[cli]" export PATH="/home/$USER/.local/bin/:$PATH" huggingface-cli login --token hf_abcxyzEFg'
Cette commande vous connectera à Hugging Face et affichera le jeton actif actuel.
Exécutez l'entraînement du modèle:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' source ~/venv/bin/activate export PJRT_DEVICE=TPU export XLA_USE_SPMD=1 export ENABLE_PJRT_COMPATIBILITY=true # Optional variables for debugging: export XLA_IR_DEBUG=1 export XLA_HLO_DEBUG=1 export PROFILE_EPOCH=0 export PROFILE_STEP=3 export PROFILE_DURATION_MS=100000 # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path export PROFILE_LOGDIR=your-bucket/profile_path python3 transformers/examples/pytorch/language-modeling/run_clm.py \ --dataset_name wikitext \ --dataset_config_name wikitext-2-raw-v1 \ --per_device_train_batch_size 16 \ --do_train \ --output_dir /home/$USER/tmp/test-clm \ --overwrite_output_dir \ --config_name /home/$USER/llama-config.json \ --cache_dir /home/$USER/cache \ --tokenizer_name meta-llama/Meta-Llama-3-8B \ --block_size 8192 \ --optim adafactor \ --save_strategy no \ --logging_strategy no \ --fsdp "full_shard" \ --fsdp_config /home/$USER/fsdp-config.json \ --torch_dtype bfloat16 \ --dataloader_drop_last yes \ --flash_attention \ --max_steps 20'
L'étape d'entraînement prend environ 10 minutes. Vers la fin de l'entraînement, des messages semblables à ceux-ci s'affichent:
[INFO|trainer.py:2053] 2025-03-18 22:05:02,536 >> ***** Running training *****
[INFO|trainer.py:2054] 2025-03-18 22:05:02,536 >> Num examples = 272
[INFO|trainer.py:2055] 2025-03-18 22:05:02,536 >> Num Epochs = 2
[INFO|trainer.py:2056] 2025-03-18 22:05:02,536 >> Instantaneous batch size per device = 16
[INFO|trainer.py:2059] 2025-03-18 22:05:02,536 >> Total train batch size (w. parallel, distributed & accumulation) = 16
[INFO|trainer.py:2060] 2025-03-18 22:05:02,536 >> Gradient Accumulation steps = 1
[INFO|trainer.py:2061] 2025-03-18 22:05:02,536 >> Total optimization steps = 20
[INFO|trainer.py:2062] 2025-03-18 22:05:02,537 >> Number of trainable parameters = 8,030,261,248
0%| | 0/20 [00:00<?, ?it/s][INFO|trainer.py:2143] 2025-03-18 22:05:02,540 >> Profiling server started: <_XLAC.profiler.ProfilerServer object at 0x7f01bdcb6770>
5%|▌ | 1/20 [00:07<02:29, 7.86s/it]/home/your-username/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1810: UserWarning: For backward hooks to be called, module output should be a Tensor or a tuple of Tensors but received <class 'transformers.modeling_outputs.CausalLMOutputWithPast'>
5%|▌ | 1/20 [00:07<02:29, 7.89s/it]Compilation at Step 0, time: 213.83555555343628
/home/your-username/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1810:
10%|█ | 2/20 [03:43<38:57, 129.87s/it]Compilation at Step 0, time: 213.12156581878662
/home/your-username/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:"
10%|█ | 2/20 [03:40<38:29, 128.30s/it]Compilation at Step 1, time: 224.5414960384369
15%|█▌ | 3/20 [07:22<48:31, 171.24s/it]Compilation at Step 1, time: 226.23664164543152
15%|█▌ | 3/20 [07:26<48:56, 172.73s/it]Compilation at Step 1, time: 226.9180543422699
Compilation at Step 1, time: 224.3874273300171
20%|██ | 4/20 [07:23<27:45, 104.10s/it]Starting to trace for 100000 ms. Remaining attempt(s): 2
2025-03-18 22:12:32.104419: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 847930 nanoseconds and will start immediately.
Starting to trace for 100000 ms. Remaining attempt(s): 2
2025-03-18 22:12:32.104373: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 763960 nanoseconds and will start immediately.
Starting to trace for 100000 ms. Remaining attempt(s): 2
2025-03-18 22:12:32.104538: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 854020 nanoseconds and will start immediately.
2025-03-18 22:12:32.104347: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 761070 nanoseconds and will start immediately.
Starting to trace for 100000 ms. Remaining attempt(s): 2
85%|████████▌ | 17/20 [07:55<00:06, 2.26s/it]Compilation at Step -1, time: 3.676558494567871
Compilation at Step -1, time: 3.447533130645752
Compilation at Step -1, time: 3.5890843868255615
Compilation at Step -1, time: 3.4956483840942383
100%|██████████| 20/20 [11:39<00:00, 35.14s/it][INFO|trainer.py:2350] 2025-03-18 22:16:42,476 >>
Training completed. Do not forget to share your model on huggingface.co/models =)
100%|██████████| 20/20 [11:47<00:00, 35.23s/it][INFO|trainer.py:2350] 2025-03-18 22:16:43,239 >>
Training completed. Do not forget to share your model on huggingface.co/models =)
Effectuer un nettoyage
Une fois l'entraînement terminé, suivez la procédure ci-dessous pour supprimer la ressource mise en file d'attente et la VM TPU. La facturation de votre utilisation de la VM TPU sera alors interrompue.
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --force \ --async