Train Llama 3 using PyTorch on TPU v5e
This tutorial describes how to train a Llama-3-8B model with PyTorch/XLA on TPU v5e, using the WikiText dataset. See Meta-Llama-3-8B for model details.
The Llama-3-8B model is hosted on the Hugging Face platform.
There are two versions of Meta-Llama-3-8B: one for use with Transformers and another for use with the original Llama 3 codebase. This tutorial uses the Transformers version because it:
Integrates seamlessly with the Hugging Face ecosystem: This makes it easier to fine-tune the model, use prebuilt pipelines, and access a vast collection of datasets and tools.
Enables flexibility and customization: The Transformers version offers significant flexibility and customization options for fine-tuning and deploying the model.
Provides community support: The Hugging Face community provides extensive documentation, tutorials, and support for using Transformers models.
For more information about Transformers, see the Hugging Face Transformers documentation.
To access and use the Meta-Llama-3-8B model, including downloading its weights and tokenizer, you need a Hugging Face user access token. The token provides:
Authentication and Authorization: The access token acts as a credential that allows the Hugging Face servers to authorize your access to the model's resources. This ensures that only authorized users can download and use the model.
Security: Hugging Face uses access tokens to protect its models and prevent unauthorized access or misuse.
For information about creating and using an access token for this tutorial, see Run the model. For more comprehensive information about creating and using access tokens, see the Hugging Face documentation on user access tokens.
You also need permission to access the Llama 3 8B model on Hugging Face. To get that permission, go to the Meta-Llama-3-8B model on Hugging Face and request access.
Prepare to provision a TPU v5litepod-16
This tutorial was tested using the following Cloud TPU
environment variables. You can use other values to provision your TPU,
as long as the accelerator type, zone, and runtime version are compatible.
For example, this tutorial uses europe-west4-b as the zone
throughout. You can use any other zone that supports the
TPU version (accelerator type) you are running (v5litepod-16 in this tutorial).
Set the following TPU VM environment variables.
export TPU_NAME=queued-resources-node-id # The TPU name is the queued resource node-id
export PROJECT_ID=your-project-id
export ACCELERATOR_TYPE=v5litepod-16
export ZONE=europe-west4-b
export RUNTIME_VERSION=v2-alpha-tpuv5-lite
export QUEUED_RESOURCE_ID=queued-resource-id
export VALID_UNTIL_DURATION=1d
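Every later command depends on these variables, so it can help to confirm they are all set before you continue. The following is a minimal Python sketch, assuming you run it from the same shell session that exported the variables; the helper name `missing_tpu_vars` is hypothetical.

```python
import os

# Variables exported in the step above; the later gcloud commands read them.
REQUIRED_VARS = [
    "TPU_NAME",
    "PROJECT_ID",
    "ACCELERATOR_TYPE",
    "ZONE",
    "RUNTIME_VERSION",
    "QUEUED_RESOURCE_ID",
    "VALID_UNTIL_DURATION",
]


def missing_tpu_vars(env=os.environ):
    """Return the names of required variables that are unset or empty."""
    return [name for name in REQUIRED_VARS if not env.get(name)]
```

For example, `print(missing_tpu_vars())` prints an empty list when every variable is set, and otherwise lists the ones you still need to export.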
When you have access to the Meta-Llama-3-8B model on Hugging Face, prepare the TPU environment to run the tutorial.
Follow the Set up the Cloud TPU environment guide to ensure you have appropriate access to use Cloud TPUs.
Create a service identity for the TPU VM.
gcloud alpha compute tpus tpu-vm service-identity create --zone=${ZONE}
Create a TPU service account and grant access to Google Cloud services.
Service accounts allow the Google Cloud TPU service to access other Google Cloud services. A user-managed service account is recommended. You can create a service account from the Google Cloud console or with the gcloud command-line tool.
Create a service account using the gcloud command-line tool:
gcloud iam service-accounts create your-service-account-name \
  --description="your-sa-description" \
  --display-name="your-sa-display-name"
export SERVICE_ACCOUNT_NAME=your-service-account-name
Create a service account from the Google Cloud console:
- Go to the Service Accounts page in the Google Cloud console.
- Click Create service account.
- Enter the service account name.
- (Optional) Enter a description for the service account.
- Click Create and continue.
- Choose the roles you want to grant to the service account.
- Click Continue.
- (Optional) Specify users or groups that can manage the service account.
- Click Done to finish creating the service account.
After creating your service account, follow these steps to grant service account roles.
The following roles are necessary:
- TPU Admin: Needed to create a TPU
- Storage Admin: Needed for accessing Cloud Storage
- Logs Writer: Needed for writing logs to Cloud Logging
- Monitoring Metric Writer: Needed for writing metrics to Cloud Monitoring
Your administrator must grant you the roles/resourcemanager.projectIamAdmin role in order for you to assign IAM roles to users. A user with the Project IAM Admin (roles/resourcemanager.projectIamAdmin) role can also grant this role.
Use the following gcloud commands to add service account roles:
gcloud projects add-iam-policy-binding ${PROJECT_ID} \
  --member serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \
  --role roles/tpu.admin
gcloud projects add-iam-policy-binding ${PROJECT_ID} \
  --member serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \
  --role roles/storage.admin
gcloud projects add-iam-policy-binding ${PROJECT_ID} \
  --member serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \
  --role roles/logging.logWriter
gcloud projects add-iam-policy-binding ${PROJECT_ID} \
  --member serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \
  --role roles/monitoring.metricWriter
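The four role bindings for the service account differ only in the `--role` value. If you prefer to script them, the following minimal Python sketch builds the same `gcloud projects add-iam-policy-binding` commands from a list of roles and prints them as a dry run. The helper names are hypothetical; to actually run a command, you could pass its argument list to `subprocess.run(argv, check=True)`.

```python
import shlex

# Roles this tutorial lists as necessary for the TPU service account.
ROLES = [
    "roles/tpu.admin",
    "roles/storage.admin",
    "roles/logging.logWriter",
    "roles/monitoring.metricWriter",
]


def service_account_email(name, project_id):
    """Build the email address of a user-managed service account."""
    return f"{name}@{project_id}.iam.gserviceaccount.com"


def iam_binding_commands(project_id, sa_name, roles=ROLES):
    """Return one add-iam-policy-binding argument list per role."""
    member = f"serviceAccount:{service_account_email(sa_name, project_id)}"
    return [
        ["gcloud", "projects", "add-iam-policy-binding", project_id,
         "--member", member, "--role", role]
        for role in roles
    ]


# Dry run: print the commands instead of executing them.
for argv in iam_binding_commands("your-project-id", "your-service-account-name"):
    print(shlex.join(argv))
```

Keeping the roles in one list makes it easy to see, and change, exactly what the service account is granted.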
You can also assign roles using the Google Cloud console:
- Select your service account and click Add Principal.
- In the New Principals field, enter the email address of your service account.
- In the Select a role drop-down, search for the role (for example, Storage Admin) and select it.
- Click Save.
Authenticate with Google Cloud and configure the default project and zone for Google Cloud CLI.
gcloud auth login
gcloud config set project ${PROJECT_ID}
gcloud config set compute/zone ${ZONE}
Secure capacity
When you are ready to secure TPU capacity, review the quotas page to learn about the Cloud Quotas system. If you have additional questions about securing capacity, contact your Cloud TPU sales or account team.
Provision the Cloud TPU environment
You can provision TPU VMs with GKE, with GKE and XPK, or as queued resources.
Prerequisites
- This tutorial has been tested with Python 3.10 or later.
- Verify that your project has enough TPUS_PER_TPU_FAMILY quota, which specifies the maximum number of chips you can access within your Google Cloud project.
- Verify that your project has enough TPU quota for:
- TPU VM quota
- IP Address quota
- Hyperdisk balanced quota
- User project permissions
- If you are using GKE with XPK, see Cloud Console Permissions on the user or service account for the permissions needed to run XPK.
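For TPU v5e, the number at the end of the accelerator type is the chip count, so a v5litepod-16 slice needs at least 16 chips of TPUS_PER_TPU_FAMILY quota. The following minimal Python sketch makes that check explicit. It only understands `v5litepod-N` names, the quota value is something you look up on the quotas page yourself, and the helper names are hypothetical.

```python
def chips_in_accelerator_type(accelerator_type):
    """Parse the chip count from a TPU v5e accelerator type such as 'v5litepod-16'."""
    prefix, sep, count = accelerator_type.rpartition("-")
    if prefix != "v5litepod" or not sep or not count.isdigit():
        raise ValueError(f"unexpected accelerator type: {accelerator_type!r}")
    return int(count)


def has_enough_chip_quota(accelerator_type, tpus_per_tpu_family_quota):
    """True if the TPUS_PER_TPU_FAMILY quota covers every chip in the slice."""
    return chips_in_accelerator_type(accelerator_type) <= tpus_per_tpu_family_quota
```

For example, `has_enough_chip_quota("v5litepod-16", 8)` is `False`, which tells you to request more quota before provisioning.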
Provision a TPU v5litepod-16
Create a TPU VM:
gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
  --node-id=${TPU_NAME} \
  --project=${PROJECT_ID} \
  --zone=${ZONE} \
  --accelerator-type=${ACCELERATOR_TYPE} \
  --runtime-version=${RUNTIME_VERSION} \
  --service-account=${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \
  --spot
Verify the TPU is in the ACTIVE state:
gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
  --project=${PROJECT_ID} \
  --zone=${ZONE}
When the TPU becomes active (ACTIVE), you will see output similar to:
createTime: '2025-02-28T21:16:08.053492925Z'
name: projects/my-project/locations/zone/queuedResources/tpu-name-zone
spot: {}
state:
state: ACTIVE
tpu:
nodeSpec:
- node:
acceleratorType: v5litepod-16
networkConfig:
enableExternalIps: true
network: default
queuedResource: projects/19672137403/locations/zone/queuedResources/qr-name
runtimeVersion: v2-alpha-tpuv5-lite
schedulingConfig: {}
my-service-account@your-project-id.iam.gserviceaccount.com
email: 19672137854-compute@developer.iam.gserviceaccount.com
shieldedInstanceConfig: {}
nodeId: tpu-name
parent: projects/19672137403/locations/zone
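Provisioning can take a while, especially for Spot capacity. If you would rather script the check than rerun the describe command by hand, the following minimal Python sketch polls the queued resource until its state is ACTIVE. It assumes gcloud is on your PATH and uses gcloud's `--format=value(state.state)` projection to print only the nested state field shown in the output above. The helper names, timeout, and polling interval are illustrative, and treating FAILED as a terminal state is an assumption.

```python
import subprocess
import time


def queued_resource_state(queued_resource_id, project_id, zone):
    """Return the queued resource state (for example, ACTIVE) by calling gcloud."""
    result = subprocess.run(
        [
            "gcloud", "compute", "tpus", "queued-resources", "describe",
            queued_resource_id,
            f"--project={project_id}",
            f"--zone={zone}",
            "--format=value(state.state)",
        ],
        check=True,
        capture_output=True,
        text=True,
    )
    return result.stdout.strip()


def wait_until_active(get_state, timeout_s=3600, poll_s=30, sleep=time.sleep):
    """Call get_state() every poll_s seconds until it returns 'ACTIVE'."""
    waited = 0
    while True:
        state = get_state()
        if state == "ACTIVE":
            return state
        if state == "FAILED":  # assumed terminal: stop instead of waiting out the timeout
            raise RuntimeError("queued resource entered the FAILED state")
        if waited >= timeout_s:
            raise TimeoutError(f"queued resource still {state!r} after {waited} s")
        sleep(poll_s)
        waited += poll_s
```

A typical call is `wait_until_active(lambda: queued_resource_state(qr_id, project_id, zone))`, where the three arguments hold the values of QUEUED_RESOURCE_ID, PROJECT_ID, and ZONE. Passing `get_state` and `sleep` as parameters keeps the polling logic testable without a real TPU.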
Installation
Install the pytorch-tpu/transformers
fork of
Hugging Face Transformers and dependencies. This tutorial was tested with the
following dependency versions:
- torch: compatible with 2.6.0
- torch_xla[tpu]: compatible with 2.6.0
- jax: 0.4.38
- jaxlib: 0.4.38
Install framework software and dependencies
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} \
  --zone=${ZONE} \
  --worker=all \
  --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
sudo apt install python3.10-venv
python -m venv /home/$USER/venv/
source ~/venv/bin/activate
cd transformers
pip3 install --user -e .
pip3 install datasets
pip3 install evaluate
pip3 install scikit-learn
pip3 install accelerate
pip install torch~=2.6.0 torch_xla[tpu]~=2.6.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
pip install jax==0.4.38 jaxlib==0.4.38 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'
When the installation completes, you will see output similar to:
Collecting jax==0.4.38
Downloading jax-0.4.38-py3-none-any.whl (2.1 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.1/2.1 MB 18.0 MB/s eta 0:00:00
Collecting jaxlib==0.4.38
Downloading jaxlib-0.4.38-cp310-cp310-manylinux2014_x86_64.whl (85.0 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 85.0/85.0 MB 10.1 MB/s eta 0:00:00
Successfully installed jax-0.4.38 jaxlib-0.4.38 ml-dtypes-0.5.1 opt-einsum-3.4.0
Collecting opt-einsum
Downloading opt_einsum-3.4.0-py3-none-any.whl (71 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 71.9/71.9 KB 186.4 kB/s eta 0:00:00
Requirement already satisfied: numpy>=1.24 in /home/your-username/.local/lib/python3.10/site-packages (from jax==0.4.38) (2.2.3)
Requirement already satisfied: scipy>=1.10 in /home/your-username/.local/lib/python3.10/site-packages (from jax==0.4.38) (1.15.2)
Successfully installed jax-0.4.38 jaxlib-0.4.38 ml-dtypes-0.5.1 opt-einsum-3.4.0
Successfully installed jax-0.4.38 jaxlib-0.4.38 ml-dtypes-0.5.1 opt-einsum-3.4.0
Collecting ml-dtypes>=0.2.0
Downloading ml_dtypes-0.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.7 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.7/4.7 MB 13.8 MB/s eta 0:00:00
Installing collected packages: opt-einsum, ml-dtypes, jaxlib, jax
Successfully installed jax-0.4.38 jaxlib-0.4.38 ml-dtypes-0.5.1 opt-einsum-3.4.0
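Because the installation runs on every worker, you may want to confirm that each one ended up with the tested versions. The following minimal Python sketch compares installed distributions against the versions listed above using `importlib.metadata`. It treats "compatible with 2.6.0" as "same major.minor release", in the spirit of the `~=2.6.0` specifier in the install command. The helper name is hypothetical; you could run the script with the venv's python on each worker, for example through `gcloud alpha compute tpus tpu-vm ssh` with `--worker=all`.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions this tutorial was tested with: torch and torch_xla only need the
# same major.minor release; jax and jaxlib are pinned exactly.
EXPECTED = {
    "torch": ("2.6", "minor"),
    "torch_xla": ("2.6", "minor"),
    "jax": ("0.4.38", "exact"),
    "jaxlib": ("0.4.38", "exact"),
}


def version_problems(get_version=version):
    """Return readable mismatches between installed and expected versions."""
    problems = []
    for package, (expected, mode) in EXPECTED.items():
        try:
            installed = get_version(package)
        except PackageNotFoundError:
            problems.append(f"{package}: not installed")
            continue
        base = installed.split("+")[0]  # drop local build tags such as "+cpu"
        if mode == "exact":
            ok = base == expected
        else:
            ok = base.split(".")[:2] == expected.split(".")[:2]
        if not ok:
            problems.append(f"{package}: found {installed}, expected {expected}")
    return problems
```

An empty list from `print(version_problems())` means that worker matches the tested setup.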
Set up model configurations
The training command in the next section, Run the model, uses two JSON config files to define the model parameters and the FSDP (Fully Sharded Data Parallel) configuration. FSDP shards the model weights across devices so that a larger batch size fits in memory during training. When training smaller models, it might be sufficient to use data parallelism and replicate the weights on each device. For more information about how to shard tensors across devices in PyTorch/XLA, see the PyTorch/XLA SPMD User Guide.
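A rough calculation shows why sharding helps here. Each TPU v5e chip has about 16 GB of high-bandwidth memory. The following Python arithmetic estimates the memory taken by the bfloat16 weights alone, using the 8,030,261,248 trainable parameters reported in the training log in Run the model. It deliberately ignores gradients, optimizer state, activations, and XLA buffers, all of which also consume memory, so treat it as a lower bound.

```python
# Rough per-chip memory for the Llama-3-8B weights alone, in bfloat16.
# Ignores gradients, optimizer state, activations, and XLA buffers.
NUM_PARAMS = 8_030_261_248   # "Number of trainable parameters" from the training log
BYTES_PER_PARAM = 2          # bfloat16
NUM_CHIPS = 16               # v5litepod-16
GIB = 1024 ** 3

replicated_gib = NUM_PARAMS * BYTES_PER_PARAM / GIB   # every chip holds a full copy
sharded_gib = replicated_gib / NUM_CHIPS              # FSDP: each chip holds 1/16

print(f"Replicated weights per chip:   {replicated_gib:.2f} GiB")  # ~14.96 GiB
print(f"FSDP-sharded weights per chip: {sharded_gib:.2f} GiB")     # ~0.93 GiB
```

Replicating the weights would use almost all of a chip's memory before training even starts, while sharding them leaves most of it free for activations and a larger batch.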
This command creates the model parameter config file for Llama3-8B. For other models, find the config on Hugging Face. For example, see the Llama2-7B config.
cat > llama-config.json <<EOF
{
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128001,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 8192,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 500000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.40.0.dev0",
  "use_cache": false,
  "vocab_size": 128256
}
EOF
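As a sanity check on the model config, the following minimal Python sketch counts the parameters it implies. It assumes the standard Llama layout in Transformers: bias-free q/k/v/o attention projections with grouped-query attention, a gated SiLU MLP (gate, up, and down projections), two RMSNorm weight vectors per layer plus a final norm, and untied input and output embeddings. For this config the result matches the "Number of trainable parameters = 8,030,261,248" line in the training log. The function names are hypothetical.

```python
import json


def llama_param_count(cfg):
    """Count parameters of a Llama decoder-only model from its config dict."""
    if cfg.get("attention_bias"):
        raise NotImplementedError("this sketch assumes bias-free attention projections")
    h = cfg["hidden_size"]
    head_dim = h // cfg["num_attention_heads"]
    kv_dim = cfg["num_key_value_heads"] * head_dim   # grouped-query attention
    vocab = cfg["vocab_size"]

    attn = h * h + 2 * h * kv_dim + h * h            # q, k + v, o projections
    mlp = 3 * h * cfg["intermediate_size"]           # gate, up, down projections
    norms = 2 * h                                    # input + post-attention RMSNorm
    per_layer = attn + mlp + norms

    embed = vocab * h
    lm_head = 0 if cfg["tie_word_embeddings"] else vocab * h
    final_norm = h
    return cfg["num_hidden_layers"] * per_layer + embed + lm_head + final_norm


def count_from_file(path="llama-config.json"):
    """Load a Hugging Face model config file and count its parameters."""
    with open(path) as f:
        return llama_param_count(json.load(f))
```

Running `print(f"{count_from_file():,}")` next to the generated file is a quick way to catch a mistyped dimension before uploading the config.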
Create the FSDP config file:
cat > fsdp-config.json <<EOF
{
  "fsdp_transformer_layer_cls_to_wrap": [
    "LlamaDecoderLayer"
  ],
  "xla": true,
  "xla_fsdp_v2": true,
  "xla_fsdp_grad_ckpt": true
}
EOF
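Typos in these heredocs usually surface only after the files are copied to every worker and training starts. The following minimal Python sketch parses both files locally and cross-checks a few values. The checks are assumptions drawn from this tutorial's own configs (a LlamaForCausalLM model wraps LlamaDecoderLayer, and both xla and xla_fsdp_v2 are true), not an exhaustive validation of what the training script accepts, and the helper names are hypothetical.

```python
import json


def check_configs(model_cfg, fsdp_cfg):
    """Return a list of problems found when cross-checking the two config dicts."""
    problems = []
    heads = model_cfg["num_attention_heads"]
    kv_heads = model_cfg["num_key_value_heads"]
    if model_cfg["hidden_size"] % heads:
        problems.append("hidden_size must be divisible by num_attention_heads")
    if heads % kv_heads:
        problems.append("num_attention_heads must be divisible by num_key_value_heads")
    # Assumption from this tutorial: LlamaForCausalLM wraps LlamaDecoderLayer.
    if "LlamaForCausalLM" in model_cfg.get("architectures", []):
        if "LlamaDecoderLayer" not in fsdp_cfg.get("fsdp_transformer_layer_cls_to_wrap", []):
            problems.append("fsdp_transformer_layer_cls_to_wrap should include LlamaDecoderLayer")
    if not (fsdp_cfg.get("xla") and fsdp_cfg.get("xla_fsdp_v2")):
        problems.append("xla and xla_fsdp_v2 should both be true for FSDPv2 on TPU")
    return problems


def check_config_files(model_path="llama-config.json", fsdp_path="fsdp-config.json"):
    """Parse both JSON files (raising on syntax errors) and cross-check them."""
    with open(model_path) as f_model, open(fsdp_path) as f_fsdp:
        return check_configs(json.load(f_model), json.load(f_fsdp))
```

Calling `check_config_files()` in the directory where you ran the `cat` commands returns an empty list when both files parse and agree.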
For more information about FSDP, see FSDPv2.
Upload the config files to your TPU VMs using the following commands:
ssh-add ~/.ssh/google_compute_engine # Set up the SSH key in the SSH agent.
gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json ${TPU_NAME}:. \
  --worker=all \
  --project=${PROJECT_ID} \
  --zone=${ZONE}
This command will generate output similar to:
Using scp batch size of 4.
Attempting to SCP into 1 nodes with a total of 4 workers.
SCP: Attempting to connect to worker 0...
SCP: Attempting to connect to worker 1...
SCP: Attempting to connect to worker 2...
SCP: Attempting to connect to worker 3...
llama-config.json 100% 707 4.1KB/s 00:00
llama-config.json 100% 707 4.0KB/s 00:00
llama-config.json 100% 707 4.1KB/s 00:00
llama-config.json 100% 707 4.1KB/s 00:00
fsdp-config.json 100% 156 0.9KB/s 00:00
fsdp-config.json 100% 156 0.9KB/s 00:00
fsdp-config.json 100% 156 0.9KB/s 00:00
fsdp-config.json 100% 156 0.9KB/s 00:00
Run the model
Using the config files you created in the previous section, run the run_clm.py
script to train the Llama 3 8B model on the WikiText dataset. The training
script takes approximately 10 minutes to run on a TPU v5litepod-16.
Generate a new Hugging Face token if you don't already have one:
- Click Your Profile > Settings > Access Tokens.
- Select New Token.
- Specify a Name of your choice and a Role of at least Read.
- Select Generate a token.
Use your Hugging Face token to sign in to Hugging Face from your TPU VM using the following command. Replace the token passed to huggingface-cli login (hf_abcxyzEFg in this example) with the one you generated in the previous step:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} \
  --zone=${ZONE} \
  --worker=all \
  --command='
pip install -U "huggingface_hub[cli]"
export PATH="/home/$USER/.local/bin/:$PATH"
huggingface-cli login --token hf_abcxyzEFg'
This command logs you in to Hugging Face and displays the current active token.
Run the model training:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} \
  --zone=${ZONE} \
  --worker=all \
  --command='
source ~/venv/bin/activate
export PJRT_DEVICE=TPU
export XLA_USE_SPMD=1
export ENABLE_PJRT_COMPATIBILITY=true
# Optional variables for debugging:
export XLA_IR_DEBUG=1
export XLA_HLO_DEBUG=1
export PROFILE_EPOCH=0
export PROFILE_STEP=3
export PROFILE_DURATION_MS=100000
# Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path
export PROFILE_LOGDIR=your-bucket/profile_path
python3 transformers/examples/pytorch/language-modeling/run_clm.py \
  --dataset_name wikitext \
  --dataset_config_name wikitext-2-raw-v1 \
  --per_device_train_batch_size 16 \
  --do_train \
  --output_dir /home/$USER/tmp/test-clm \
  --overwrite_output_dir \
  --config_name /home/$USER/llama-config.json \
  --cache_dir /home/$USER/cache \
  --tokenizer_name meta-llama/Meta-Llama-3-8B \
  --block_size 8192 \
  --optim adafactor \
  --save_strategy no \
  --logging_strategy no \
  --fsdp "full_shard" \
  --fsdp_config /home/$USER/fsdp-config.json \
  --torch_dtype bfloat16 \
  --dataloader_drop_last yes \
  --flash_attention \
  --max_steps 20'
Training takes about 10 minutes. Toward the end of training, you will see messages similar to:
[INFO|trainer.py:2053] 2025-03-18 22:05:02,536 >> ***** Running training *****
[INFO|trainer.py:2054] 2025-03-18 22:05:02,536 >> Num examples = 272
[INFO|trainer.py:2055] 2025-03-18 22:05:02,536 >> Num Epochs = 2
[INFO|trainer.py:2056] 2025-03-18 22:05:02,536 >> Instantaneous batch size per device = 16
[INFO|trainer.py:2059] 2025-03-18 22:05:02,536 >> Total train batch size (w. parallel, distributed & accumulation) = 16
[INFO|trainer.py:2060] 2025-03-18 22:05:02,536 >> Gradient Accumulation steps = 1
[INFO|trainer.py:2061] 2025-03-18 22:05:02,536 >> Total optimization steps = 20
[INFO|trainer.py:2062] 2025-03-18 22:05:02,537 >> Number of trainable parameters = 8,030,261,248
0%| | 0/20 [00:00<?, ?it/s][INFO|trainer.py:2143] 2025-03-18 22:05:02,540 >> Profiling server started: <_XLAC.profiler.ProfilerServer object at 0x7f01bdcb6770>
5%|▌ | 1/20 [00:07<02:29, 7.86s/it]/home/your-username/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1810: UserWarning: For backward hooks to be called, module output should be a Tensor or a tuple of Tensors but received <class 'transformers.modeling_outputs.CausalLMOutputWithPast'>
5%|▌ | 1/20 [00:07<02:29, 7.89s/it]Compilation at Step 0, time: 213.83555555343628
/home/your-username/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1810:
10%|█ | 2/20 [03:43<38:57, 129.87s/it]Compilation at Step 0, time: 213.12156581878662
/home/your-username/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:"
10%|█ | 2/20 [03:40<38:29, 128.30s/it]Compilation at Step 1, time: 224.5414960384369
15%|█▌ | 3/20 [07:22<48:31, 171.24s/it]Compilation at Step 1, time: 226.23664164543152
15%|█▌ | 3/20 [07:26<48:56, 172.73s/it]Compilation at Step 1, time: 226.9180543422699
Compilation at Step 1, time: 224.3874273300171
20%|██ | 4/20 [07:23<27:45, 104.10s/it]Starting to trace for 100000 ms. Remaining attempt(s): 2
2025-03-18 22:12:32.104419: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 847930 nanoseconds and will start immediately.
Starting to trace for 100000 ms. Remaining attempt(s): 2
2025-03-18 22:12:32.104373: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 763960 nanoseconds and will start immediately.
Starting to trace for 100000 ms. Remaining attempt(s): 2
2025-03-18 22:12:32.104538: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 854020 nanoseconds and will start immediately.
2025-03-18 22:12:32.104347: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 761070 nanoseconds and will start immediately.
Starting to trace for 100000 ms. Remaining attempt(s): 2
85%|████████▌ | 17/20 [07:55<00:06, 2.26s/it]Compilation at Step -1, time: 3.676558494567871
Compilation at Step -1, time: 3.447533130645752
Compilation at Step -1, time: 3.5890843868255615
Compilation at Step -1, time: 3.4956483840942383
100%|██████████| 20/20 [11:39<00:00, 35.14s/it][INFO|trainer.py:2350] 2025-03-18 22:16:42,476 >>
Training completed. Do not forget to share your model on huggingface.co/models =)
100%|██████████| 20/20 [11:47<00:00, 35.23s/it][INFO|trainer.py:2350] 2025-03-18 22:16:43,239 >>
Training completed. Do not forget to share your model on huggingface.co/models =)
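The numbers in the log header follow from the training arguments. The following Python arithmetic reproduces them: with --dataloader_drop_last, 272 examples at a batch size of 16 give 17 full steps per epoch, so 20 steps need 2 epochs, matching "Num Epochs = 2". The log reports a total train batch size of 16, the same as --per_device_train_batch_size; this appears to be because, under SPMD, the Trainer sees a single logical device and the batch of 16 is sharded across all 16 chips, but treat that reading as an assumption.

```python
import math

# Values taken from the training command and log above.
num_examples = 272        # "Num examples": WikiText-2 training text packed into blocks
block_size = 8192         # --block_size (tokens per example)
global_batch_size = 16    # "Total train batch size"
max_steps = 20            # --max_steps

steps_per_epoch = num_examples // global_batch_size   # drop_last discards a partial batch
num_epochs = math.ceil(max_steps / steps_per_epoch)   # epochs needed to reach max_steps
tokens_per_step = global_batch_size * block_size
total_tokens = max_steps * tokens_per_step

print(steps_per_epoch, num_epochs, tokens_per_step, total_tokens)
```

Each optimization step therefore processes 131,072 tokens, or about 2.6 million tokens over the 20-step run.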
Clean up
After the training completes, use the following command to delete the queued resource and TPU VM. This stops billing for your TPU VM use.
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
  --project=${PROJECT_ID} \
  --zone=${ZONE} \
  --force \
  --async