Train Llama 3 using PyTorch on TPU v5e

This tutorial describes how to train a Llama-3-8B model on TPU v5e with PyTorch/XLA and the WikiText dataset. See Meta-Llama-3-8B for model details.

The Llama-3-8B model is hosted on the Hugging Face platform.

There are two versions of Meta-Llama-3-8B, one for use with Transformers and another with the original Llama 3 codebase. This tutorial uses the Transformers version because it:

  • Integrates seamlessly with the Hugging Face ecosystem: This makes it easier to fine-tune the model, use prebuilt pipelines, and access a vast collection of datasets and tools.

  • Enables flexibility and customization: The Transformers version offers significant flexibility and customization options for fine-tuning and deploying the model.

  • Provides community support: The Hugging Face community provides extensive documentation, tutorials, and support for using Transformers models.

For more information about Transformers, see the Hugging Face Transformers documentation.

To access and use the Meta-Llama-3-8B model, including downloading its weights and tokenizer, you need a Hugging Face user access token. The token provides:

  • Authentication and authorization: The access token acts as a credential that allows the Hugging Face servers to authorize your access to the model's resources. This ensures that only authorized users can download and use the model.

  • Security: Hugging Face uses access tokens to protect its models and prevent unauthorized access or misuse.

For information about creating and using an access token for this tutorial, see Run the model. For more comprehensive information about creating and using access tokens, see the Hugging Face documentation on user access tokens.

You also need permission to access the Llama 3 8B model on Hugging Face. To get that permission, go to the Meta-Llama-3-8B model on Hugging Face and request access.

Prepare to provision a TPU v5litepod-16

This tutorial was tested using the following Cloud TPU environment variables. You can use other values to provision your TPU, as long as the accelerator type, zone, and runtime version are compatible. For example, this tutorial uses europe-west4-b as the zone throughout. You can use any other zone that supports the TPU version (accelerator type) you are running (v5litepod-16 in this tutorial).

Set the following TPU VM environment variables.

   export TPU_NAME=queued-resources-node-id  # The TPU name is the queued resource node ID
   export PROJECT_ID=your-project-id
   export ACCELERATOR_TYPE=v5litepod-16
   export ZONE=europe-west4-b
   export RUNTIME_VERSION=v2-alpha-tpuv5-lite
   export QUEUED_RESOURCE_ID=queued-resource-id
   export VALID_UNTIL_DURATION=1d
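
The later gcloud commands expand these variables, so an unset variable results in an empty flag value. The following Python sketch shows one way to check them before provisioning. The helper name missing_vars is hypothetical and is not part of any Google Cloud tooling.

```python
import os

# Names of the variables exported in this section.
REQUIRED_VARS = [
    "TPU_NAME",
    "PROJECT_ID",
    "ACCELERATOR_TYPE",
    "ZONE",
    "RUNTIME_VERSION",
    "QUEUED_RESOURCE_ID",
    "VALID_UNTIL_DURATION",
]


def missing_vars(env, required=REQUIRED_VARS):
    """Return the required names that are unset or empty in env."""
    return [name for name in required if not env.get(name, "").strip()]


if __name__ == "__main__":
    missing = missing_vars(os.environ)
    if missing:
        print("Unset variables:", ", ".join(missing))
    else:
        print("All tutorial variables are set.")
```

Run the script in the same shell session in which you exported the variables, because exported variables are visible only to that shell and its child processes.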

When you have access to the Meta-Llama-3-8B model on Hugging Face, prepare the TPU environment to run the tutorial.

  1. Follow the Set up the Cloud TPU environment guide to ensure you have appropriate access to use Cloud TPUs.

  2. Create a service identity for the TPU VM.

    gcloud alpha compute tpus tpu-vm service-identity create --zone=${ZONE}
  3. Create a TPU service account and grant access to Google Cloud services.

    Service accounts allow the Google Cloud TPU service to access other Google Cloud services. A user-managed service account is recommended. You can create a service account from the Google Cloud console or through the gcloud command.

    Create a service account using the gcloud command-line tool:

    gcloud iam service-accounts create your-service-account-name \
    --description="your-sa-description" \
    --display-name="your-sa-display-name"
    export SERVICE_ACCOUNT_NAME=your-service-account-name

    Create a service account from the Google Cloud console:

    1. Go to the Service Accounts page in the Google Cloud console.
    2. Click Create service account.
    3. Enter the service account name.
    4. (Optional) Enter a description for the service account.
    5. Click Create and continue.
    6. Choose the roles you want to grant to the service account.
    7. Click Continue.
    8. (Optional) Specify users or groups that can manage the service account.
    9. Click Done to finish creating the service account.

    After creating your service account, follow these steps to grant service account roles.

    The following roles are necessary:

    • TPU Admin: Needed to create a TPU
    • Storage Admin: Needed for accessing Cloud Storage
    • Logs Writer: Needed for writing logs to Cloud Logging
    • Monitoring Metric Writer: Needed for writing metrics to Cloud Monitoring

    To assign IAM roles, you need the Project IAM Admin role (roles/resourcemanager.projectIamAdmin). Your administrator, or any user who already has this role, can grant it to you.

    Use the following gcloud commands to add service account roles:

    gcloud projects add-iam-policy-binding ${PROJECT_ID} \
       --member serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \
       --role roles/tpu.admin
    gcloud projects add-iam-policy-binding ${PROJECT_ID} \
       --member serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \
       --role roles/storage.admin
    gcloud projects add-iam-policy-binding ${PROJECT_ID} \
       --member serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \
       --role roles/logging.logWriter
    gcloud projects add-iam-policy-binding ${PROJECT_ID} \
       --member serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \
       --role roles/monitoring.metricWriter

    You can also assign roles using the Google Cloud console.

    To assign a role from the Google Cloud console, repeat the following steps for each role:

    1. Select your service account and click Add Principal.
    2. In the New Principals field, enter the email address of your service account.
    3. In the Select a role drop-down, search for the role (for example, Storage Admin) and select it.
    4. Click Save.
  4. Authenticate with Google Cloud and configure the default project and zone for Google Cloud CLI.

    gcloud auth login
    gcloud config set project ${PROJECT_ID}
    gcloud config set compute/zone ${ZONE}
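
The four role bindings in step 3 differ only in the role ID, so they can be generated from a list. The following Python sketch builds the same gcloud argument lists and prints them for review. The function name binding_commands and the placeholder values are illustrative assumptions. The role IDs and flags are the ones used in step 3.

```python
import shlex

# Role IDs taken from the four gcloud commands in step 3.
ROLES = [
    "roles/tpu.admin",
    "roles/storage.admin",
    "roles/logging.logWriter",
    "roles/monitoring.metricWriter",
]


def binding_commands(project_id, service_account_name, roles=ROLES):
    """Build one add-iam-policy-binding argument list per role."""
    member = (
        f"serviceAccount:{service_account_name}"
        f"@{project_id}.iam.gserviceaccount.com"
    )
    return [
        ["gcloud", "projects", "add-iam-policy-binding", project_id,
         "--member", member, "--role", role]
        for role in roles
    ]


if __name__ == "__main__":
    # Placeholder values; the commands are printed, not executed.
    for cmd in binding_commands("your-project-id", "your-service-account-name"):
        print(shlex.join(cmd))
```

Printing the commands before running them makes it easier to confirm the member address and role list.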

Secure capacity

When you are ready to secure TPU capacity, review the quotas page to learn about the Cloud Quotas system. If you have additional questions about securing capacity, contact your Cloud TPU sales or account team.

Provision the Cloud TPU environment

You can provision TPU VMs with GKE, with GKE and XPK, or as queued resources.

Prerequisites

  • This tutorial has been tested with Python 3.10 or later.
  • Verify that your project has enough TPUS_PER_TPU_FAMILY quota, which specifies the maximum number of chips you can access within your Google Cloud project.
  • Verify that your project has enough quota for the following:
    • TPU VM quota
    • IP Address quota
    • Hyperdisk balanced quota
  • Verify that you have the required user project permissions.

Provision a TPU v5litepod-16

  1. Create a TPU VM:

    gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID}   \
    --node-id=${TPU_NAME}  \
    --project=${PROJECT_ID}   \
    --zone=${ZONE}   \
    --accelerator-type=${ACCELERATOR_TYPE}   \
    --runtime-version=${RUNTIME_VERSION}   \
    --service-account=${SERVICE_ACCOUNT_NAME}   \
    --spot
  2. Verify the TPU is in the ACTIVE state:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
    --project=${PROJECT_ID} \
    --zone=${ZONE}

When the queued resource is in the ACTIVE state, the output is similar to the following:

createTime: '2025-02-28T21:16:08.053492925Z'
name: projects/my-project/locations/zone/queuedResources/tpu-name-zone
spot: {}
state:
  state: ACTIVE
tpu:
  nodeSpec:
  - node:
      acceleratorType: v5litepod-16
      networkConfig:
        enableExternalIps: true
        network: default
      queuedResource: projects/19672137403/locations/zone/queuedResources/qr-name
      runtimeVersion: v2-alpha-tpuv5-lite
      schedulingConfig: {}
      serviceAccount:
        email: 19672137854-compute@developer.iam.gserviceaccount.com
      shieldedInstanceConfig: {}
    nodeId: tpu-name
    parent: projects/19672137403/locations/zone
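
To check the state programmatically, you can extract the nested state field from the describe output. The following Python sketch is a minimal text-based approach that assumes the output layout shown above. The function name queued_resource_state is hypothetical, and the sample string is an abbreviated copy of that output.

```python
import re


def queued_resource_state(describe_output):
    """Return the value of the nested `state: / state:` field, or None."""
    match = re.search(
        r"^state:[ \t]*\n[ \t]+state:[ \t]*(\S+)",
        describe_output,
        re.MULTILINE,
    )
    return match.group(1) if match else None


# Abbreviated sample of the describe output shown above.
SAMPLE = """createTime: '2025-02-28T21:16:08.053492925Z'
spot: {}
state:
  state: ACTIVE
tpu:
  nodeSpec:
  - node:
      acceleratorType: v5litepod-16
"""

if __name__ == "__main__":
    print(queued_resource_state(SAMPLE))
```

Wait until the function returns ACTIVE before you run the installation commands in the next section.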

Installation

Install the pytorch-tpu/transformers fork of Hugging Face Transformers and dependencies. This tutorial was tested with the following dependency versions:

  • torch: compatible with 2.6.0
  • torch_xla[tpu]: compatible with 2.6.0
  • jax: 0.4.38
  • jaxlib: 0.4.38

Install framework software and dependencies

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
    sudo apt install -y python3.10-venv
    python -m venv /home/$USER/venv/
    source ~/venv/bin/activate
    cd transformers
    pip3 install --user -e .
    pip3 install datasets
    pip3 install evaluate
    pip3 install scikit-learn
    pip3 install accelerate
    pip install torch~=2.6.0 torch_xla[tpu]~=2.6.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
    pip install jax==0.4.38 jaxlib==0.4.38 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'

When the installation completes, you will see output similar to:

Collecting jax==0.4.38
  Downloading jax-0.4.38-py3-none-any.whl (2.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.1/2.1 MB 18.0 MB/s eta 0:00:00
Collecting jaxlib==0.4.38
  Downloading jaxlib-0.4.38-cp310-cp310-manylinux2014_x86_64.whl (85.0 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 85.0/85.0 MB 10.1 MB/s eta 0:00:00
Successfully installed jax-0.4.38 jaxlib-0.4.38 ml-dtypes-0.5.1 opt-einsum-3.4.0
Collecting opt-einsum
  Downloading opt_einsum-3.4.0-py3-none-any.whl (71 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 71.9/71.9 KB 186.4 kB/s eta 0:00:00
Requirement already satisfied: numpy>=1.24 in /home/your-username/.local/lib/python3.10/site-packages (from jax==0.4.38) (2.2.3)
Requirement already satisfied: scipy>=1.10 in /home/your-username/.local/lib/python3.10/site-packages (from jax==0.4.38) (1.15.2)
Successfully installed jax-0.4.38 jaxlib-0.4.38 ml-dtypes-0.5.1 opt-einsum-3.4.0
Successfully installed jax-0.4.38 jaxlib-0.4.38 ml-dtypes-0.5.1 opt-einsum-3.4.0
Collecting ml-dtypes>=0.2.0
  Downloading ml_dtypes-0.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.7 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.7/4.7 MB 13.8 MB/s eta 0:00:00
Installing collected packages: opt-einsum, ml-dtypes, jaxlib, jax
Successfully installed jax-0.4.38 jaxlib-0.4.38 ml-dtypes-0.5.1 opt-einsum-3.4.0
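
To confirm that the installed packages match the tested versions, you can compare them against the list at the start of this section. The following Python sketch is a simplified approximation of pip's ~= and == specifiers for three-part versions and does not replace pip's own resolver. The helper names are hypothetical. Run it on a TPU VM worker inside the environment where you installed the packages.

```python
import re
from importlib import metadata

# Versions listed in this tutorial; the operators mirror the pip specifiers used.
TESTED = {
    "torch": ("~=", "2.6.0"),
    "torch_xla": ("~=", "2.6.0"),
    "jax": ("==", "0.4.38"),
    "jaxlib": ("==", "0.4.38"),
}


def release_tuple(version):
    """Extract the leading numeric components, e.g. '2.6.0+cpu' -> (2, 6, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    return tuple(int(p) for p in match.group(0).split(".")) if match else ()


def is_compatible(installed, tested):
    """Approximate `~=X.Y.Z`: same major.minor and a release >= the tested one."""
    have, want = release_tuple(installed), release_tuple(tested)
    return have[:2] == want[:2] and have >= want


def matches(installed, op, tested):
    """Apply the simplified '==' or '~=' comparison."""
    if op == "==":
        return release_tuple(installed) == release_tuple(tested)
    return is_compatible(installed, tested)


if __name__ == "__main__":
    for name, (op, tested) in TESTED.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            print(f"{name}: not installed")
            continue
        status = "ok" if matches(installed, op, tested) else "differs"
        print(f"{name}: {installed} (tested with {op}{tested}): {status}")
```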

Set up model configurations

The training command in the next section, Run the model, uses two JSON config files to define the model parameters and the FSDP (Fully Sharded Data Parallel) configuration. FSDP shards the model weights across devices, which leaves room for a bigger batch size during training. With smaller models, it might be sufficient to use data parallelism and replicate the weights on each device. For more information about how to shard tensors across devices in PyTorch/XLA, see the PyTorch/XLA SPMD User Guide.
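
A rough calculation shows why sharding matters for this model. The following Python sketch estimates the memory needed for the bfloat16 weights alone, replicated on every chip versus sharded across the 16 chips of a v5litepod-16. The parameter count is the one reported in the training log later in this tutorial. The 16 GiB per-chip HBM figure is an assumption based on the published v5e specification. Gradients, optimizer state, and activations need additional memory and are not counted.

```python
PARAMS = 8_030_261_248   # trainable parameters reported in the training log
BYTES_PER_PARAM = 2      # bfloat16
NUM_CHIPS = 16           # v5litepod-16
HBM_PER_CHIP_GIB = 16    # assumption: published v5e per-chip HBM capacity

GIB = 1024 ** 3


def weights_gib(params, bytes_per_param, shards=1):
    """Memory for the weights on one chip, in GiB, when split into `shards`."""
    return params * bytes_per_param / shards / GIB


replicated = weights_gib(PARAMS, BYTES_PER_PARAM)
sharded = weights_gib(PARAMS, BYTES_PER_PARAM, NUM_CHIPS)

print(f"replicated weights: {replicated:.2f} GiB per chip")
print(f"sharded weights:    {sharded:.2f} GiB per chip")
print(f"assumed HBM:        {HBM_PER_CHIP_GIB} GiB per chip")
```

Under these assumptions, the replicated weights take roughly 15 GiB, which is nearly the whole chip, while the sharded weights take under 1 GiB per chip.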

  1. Create the model parameter config file for Llama3-8B. For other models, find the config on Hugging Face. For example, see the Llama2-7B config.

    cat > llama-config.json <<EOF
    {
        "architectures": [
            "LlamaForCausalLM"
        ],
        "attention_bias": false,
        "attention_dropout": 0.0,
        "bos_token_id": 128000,
        "eos_token_id": 128001,
        "hidden_act": "silu",
        "hidden_size": 4096,
        "initializer_range": 0.02,
        "intermediate_size": 14336,
        "max_position_embeddings": 8192,
        "model_type": "llama",
        "num_attention_heads": 32,
        "num_hidden_layers": 32,
        "num_key_value_heads": 8,
        "pretraining_tp": 1,
        "rms_norm_eps": 1e-05,
        "rope_scaling": null,
        "rope_theta": 500000.0,
        "tie_word_embeddings": false,
        "torch_dtype": "bfloat16",
        "transformers_version": "4.40.0.dev0",
        "use_cache": false,
        "vocab_size": 128256
    }
    EOF
    
  2. Create the FSDP config file:

    cat > fsdp-config.json <<EOF
    {
        "fsdp_transformer_layer_cls_to_wrap": [
            "LlamaDecoderLayer"
        ],
        "xla": true,
        "xla_fsdp_v2": true,
        "xla_fsdp_grad_ckpt": true
    }
    EOF
    

    For more information about FSDP, see FSDPv2.

  3. Upload the config files to your TPU VMs using the following commands:

     ssh-add ~/.ssh/google_compute_engine  # Add the SSH key to the SSH agent.
     gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json $TPU_NAME:. \
        --worker=all \
        --project=${PROJECT_ID} \
        --zone=${ZONE}

    This command will generate output similar to:

    Using scp batch size of 4.Attempting to SCP into 1 nodes with a total of 4 workers.
    SCP: Attempting to connect to worker 0...
    SCP: Attempting to connect to worker 1...
    SCP: Attempting to connect to worker 2...
    SCP: Attempting to connect to worker 3...
    llama-config.json                    100%  707     4.1KB/s   00:00
    llama-config.json                    100%  707     4.0KB/s   00:00
    llama-config.json                    100%  707     4.1KB/s   00:00
    llama-config.json                    100%  707     4.1KB/s   00:00
    fsdp-config.json                     100%  156     0.9KB/s   00:00
    fsdp-config.json                     100%  156     0.9KB/s   00:00
    fsdp-config.json                     100%  156     0.9KB/s   00:00
    fsdp-config.json                     100%  156     0.9KB/s   00:00
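
As a sanity check on llama-config.json, you can derive the parameter count from the config values and compare it with the number of trainable parameters that the trainer reports in the next section. The following Python sketch assumes the standard Llama layer layout: four attention projections without bias, three MLP projections, two RMSNorm weight vectors per layer, one final norm, and an untied output head. The function name llama_param_count is hypothetical.

```python
# Relevant values copied from llama-config.json.
CONFIG = {
    "hidden_size": 4096,
    "intermediate_size": 14336,
    "num_attention_heads": 32,
    "num_hidden_layers": 32,
    "num_key_value_heads": 8,
    "tie_word_embeddings": False,
    "vocab_size": 128256,
}


def llama_param_count(cfg):
    """Count parameters for a Llama-style decoder without attention bias."""
    h = cfg["hidden_size"]
    head_dim = h // cfg["num_attention_heads"]
    kv_dim = head_dim * cfg["num_key_value_heads"]  # grouped-query attention
    attn = 2 * h * h + 2 * h * kv_dim               # q_proj, o_proj, k_proj, v_proj
    mlp = 3 * h * cfg["intermediate_size"]          # gate_proj, up_proj, down_proj
    norms = 2 * h                                   # two RMSNorm weights per layer
    per_layer = attn + mlp + norms
    embed = cfg["vocab_size"] * h
    lm_head = 0 if cfg["tie_word_embeddings"] else embed
    final_norm = h
    return embed + cfg["num_hidden_layers"] * per_layer + final_norm + lm_head


print(f"{llama_param_count(CONFIG):,}")
```

If you edit the config for a different model size, a mismatch between this count and the trainer's reported count suggests that a config value was mistyped.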

Run the model

Using the config files you created in the previous section, run the run_clm.py script to train the Llama 3 8B model on the WikiText dataset. The training script takes approximately 10 minutes to run on a TPU v5litepod-16.

  1. Generate a new Hugging Face token if you don't already have one:

    1. Click Your Profile > Settings > Access Tokens.
    2. Select New Token.
    3. Specify a Name of your choice and a Role of at least Read.
    4. Select Generate a token.
  2. Use your Hugging Face token to sign in to Hugging Face from your TPU VM using the following command.

    In the huggingface-cli login command, replace the example token hf_abcxyzEFg with the token that you generated in the previous step:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='
    pip install -U "huggingface_hub[cli]"
    export PATH="/home/$USER/.local/bin/:$PATH"
    huggingface-cli login --token hf_abcxyzEFg'

    This command will log you into Hugging Face and display the current active token.

  3. Run the model training:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} \
        --zone=${ZONE} \
        --worker=all \
        --command='
        source ~/venv/bin/activate
        export PJRT_DEVICE=TPU
        export XLA_USE_SPMD=1
        export ENABLE_PJRT_COMPATIBILITY=true
        # Optional variables for debugging:
        export XLA_IR_DEBUG=1
        export XLA_HLO_DEBUG=1
        export PROFILE_EPOCH=0
        export PROFILE_STEP=3
        export PROFILE_DURATION_MS=100000
        # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path
        export PROFILE_LOGDIR=your-bucket/profile_path
        python3 transformers/examples/pytorch/language-modeling/run_clm.py \
        --dataset_name wikitext \
        --dataset_config_name wikitext-2-raw-v1 \
        --per_device_train_batch_size 16 \
        --do_train \
        --output_dir /home/$USER/tmp/test-clm \
        --overwrite_output_dir \
        --config_name /home/$USER/llama-config.json \
        --cache_dir /home/$USER/cache \
        --tokenizer_name meta-llama/Meta-Llama-3-8B \
        --block_size 8192 \
        --optim adafactor \
        --save_strategy no \
        --logging_strategy no \
        --fsdp "full_shard" \
        --fsdp_config /home/$USER/fsdp-config.json \
        --torch_dtype bfloat16 \
        --dataloader_drop_last yes \
        --flash_attention \
        --max_steps 20'
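
The flags in the training command determine how many passes the run makes over the dataset. The following Python sketch works through the arithmetic. It takes the batch size, block size, and step count from the command, and the example count (272) and total train batch size (16) from the training log shown later in this section. The function name training_schedule is hypothetical. The epoch count is a simple ceiling estimate, not the trainer's exact internal logic.

```python
import math


def training_schedule(num_examples, batch_size, max_steps, block_size):
    """Return (steps per epoch, epochs needed, tokens per step)."""
    # --dataloader_drop_last discards the final partial batch of each epoch.
    steps_per_epoch = num_examples // batch_size
    epochs = math.ceil(max_steps / steps_per_epoch)
    tokens_per_step = batch_size * block_size
    return steps_per_epoch, epochs, tokens_per_step


steps_per_epoch, epochs, tokens_per_step = training_schedule(
    num_examples=272,  # "Num examples" in the training log
    batch_size=16,     # total train batch size in the training log
    max_steps=20,      # --max_steps
    block_size=8192,   # --block_size
)
print(steps_per_epoch, epochs, tokens_per_step)
```

With these values, one epoch is 17 steps, so 20 steps span 2 epochs, which matches the Num Epochs value in the log. Each step processes 131,072 tokens.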

Training takes about 10 minutes. Because the command runs on all workers, their output is interleaved and similar lines repeat. During training, you will see messages similar to:

[INFO|trainer.py:2053] 2025-03-18 22:05:02,536 >> ***** Running training *****
[INFO|trainer.py:2054] 2025-03-18 22:05:02,536 >>   Num examples = 272
[INFO|trainer.py:2055] 2025-03-18 22:05:02,536 >>   Num Epochs = 2
[INFO|trainer.py:2056] 2025-03-18 22:05:02,536 >>   Instantaneous batch size per device = 16
[INFO|trainer.py:2059] 2025-03-18 22:05:02,536 >>   Total train batch size (w. parallel, distributed & accumulation) = 16
[INFO|trainer.py:2060] 2025-03-18 22:05:02,536 >>   Gradient Accumulation steps = 1
[INFO|trainer.py:2061] 2025-03-18 22:05:02,536 >>   Total optimization steps = 20
[INFO|trainer.py:2062] 2025-03-18 22:05:02,537 >>   Number of trainable parameters = 8,030,261,248
  0%|          | 0/20 [00:00<?, ?it/s][INFO|trainer.py:2143] 2025-03-18 22:05:02,540 >> Profiling server started: <_XLAC.profiler.ProfilerServer object at 0x7f01bdcb6770>
  5%|         | 1/20 [00:07<02:29,  7.86s/it]/home/your-username/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1810: UserWarning: For backward hooks to be called, module output should be a Tensor or a tuple of Tensors but received <class 'transformers.modeling_outputs.CausalLMOutputWithPast'>
  5%|         | 1/20 [00:07<02:29,  7.89s/it]Compilation at Step 0, time: 213.83555555343628
/home/your-username/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1810:
 10%|         | 2/20 [03:43<38:57, 129.87s/it]Compilation at Step 0, time: 213.12156581878662
/home/your-username/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:"
 10%|█         | 2/20 [03:40<38:29, 128.30s/it]Compilation at Step 1, time: 224.5414960384369
 15%|█▌        | 3/20 [07:22<48:31, 171.24s/it]Compilation at Step 1, time: 226.23664164543152
 15%|█▌        | 3/20 [07:26<48:56, 172.73s/it]Compilation at Step 1, time: 226.9180543422699
Compilation at Step 1, time: 224.3874273300171
 20%|██        | 4/20 [07:23<27:45, 104.10s/it]Starting to trace for 100000 ms. Remaining attempt(s): 2
2025-03-18 22:12:32.104419: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 847930 nanoseconds and will start immediately.
Starting to trace for 100000 ms. Remaining attempt(s): 2
2025-03-18 22:12:32.104373: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 763960 nanoseconds and will start immediately.
Starting to trace for 100000 ms. Remaining attempt(s): 2
2025-03-18 22:12:32.104538: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 854020 nanoseconds and will start immediately.
2025-03-18 22:12:32.104347: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 761070 nanoseconds and will start immediately.
Starting to trace for 100000 ms. Remaining attempt(s): 2
 85%|████████▌ | 17/20 [07:55<00:06,  2.26s/it]Compilation at Step -1, time: 3.676558494567871
Compilation at Step -1, time: 3.447533130645752
Compilation at Step -1, time: 3.5890843868255615
Compilation at Step -1, time: 3.4956483840942383
100%|██████████| 20/20 [11:39<00:00, 35.14s/it][INFO|trainer.py:2350] 2025-03-18 22:16:42,476 >>
Training completed. Do not forget to share your model on huggingface.co/models =)

100%|██████████| 20/20 [11:47<00:00, 35.23s/it][INFO|trainer.py:2350] 2025-03-18 22:16:43,239 >>

Training completed. Do not forget to share your model on huggingface.co/models =)
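
In the log above, the first steps are dominated by XLA compilation, with each of the early compilations taking more than 200 seconds. The following Python sketch shows one way to collect the compilation times from a saved copy of the log. The line pattern is taken from the log above, the function name compile_times is hypothetical, and the sample text is an excerpt from that log.

```python
import re

# Matches lines such as "Compilation at Step 0, time: 213.83555555343628".
COMPILE_RE = re.compile(r"Compilation at Step (-?\d+), time: ([0-9.]+)")


def compile_times(log_text):
    """Map each reported step number to its list of compile times in seconds."""
    times = {}
    for step, seconds in COMPILE_RE.findall(log_text):
        times.setdefault(int(step), []).append(float(seconds))
    return times


# Excerpt from the training log above.
SAMPLE_LOG = """Compilation at Step 0, time: 213.83555555343628
Compilation at Step 1, time: 224.5414960384369
Compilation at Step 1, time: 226.23664164543152
Compilation at Step -1, time: 3.676558494567871
"""

if __name__ == "__main__":
    for step, values in sorted(compile_times(SAMPLE_LOG).items()):
        print(f"step {step}: max compile time {max(values):.1f} s")
```

Each step number can appear several times because every worker reports its own compilation.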

Clean up

After the training completes, run the following command to delete the queued resource and the TPU VM. This stops billing for your TPU VM use.

  gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --force \
       --async