Cloud TPU v5p training
Cloud TPU v5p is Google Cloud's fifth-generation Cloud TPU and the successor to the v4 TPU. v5p is optimized for large-scale training and is designed to be a leading platform for developing foundational LLMs, diffusion models, and generative AI. At a high level, v5p provides up to 2x the performance of v4 while packing 2x more TPUs into a Pod (6k largest slice versus 3k in v4), resulting in up to 4x performance at the Pod level. It also runs at a higher clock frequency (1.75 GHz vs. 1.05 GHz), adds SparseCore for large-scale embeddings, and triples High Bandwidth Memory (HBM) capacity.
Cloud TPU v5p concepts
If you are new to Cloud TPUs, check out the TPU documentation home.
Cloud TPU concepts (for example, slices, hosts, and TensorCores) and Cloud TPU system architecture for all Cloud TPU versions are described in the Cloud TPU System Architecture page.
Each Cloud TPU version requires specific accelerator types for training or inference. These accelerator types are described in v5p configurations.
Manage TPU resources
All of the commands in this document assume you are creating a TPU v5p VM. For more information about the commands used to create TPU VMs, see Managing TPUs, or see the Queued resources user guide for managing queued resources. To make the commands easier to run, the code samples in this document use the following environment variables:
export PROJECT_ID=your-project
export ACCELERATOR_TYPE=v5p-8
export ZONE=us-east5-a
export RUNTIME_VERSION=v2-alpha-tpuv5
export TPU_NAME=your-tpu-name
Environment variable descriptions
- PROJECT_ID: The Google Cloud project in which you are creating your TPU.
- ACCELERATOR_TYPE: The accelerator type specifies the version and size of the Cloud TPU you want to create. For more information about supported accelerator types for each TPU version, see TPU versions.
- ZONE: The zone where you plan to create your Cloud TPU.
- RUNTIME_VERSION: The TPU software version.
- TPU_NAME: The user-defined name of the TPU you are using.
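Because every command in this document expands these variables, an unset variable tends to surface as a confusing gcloud error. The following Python sketch (the helper name `missing_tpu_vars` is hypothetical, not part of any Google Cloud tool) reports which of the five variables are unset or empty, assuming they were exported in the shell that runs the script.

```python
import os
from typing import Mapping

# Environment variables expected by the commands in this document.
REQUIRED_VARS = ("PROJECT_ID", "ACCELERATOR_TYPE", "ZONE", "RUNTIME_VERSION", "TPU_NAME")


def missing_tpu_vars(env: Mapping[str, str] = os.environ) -> list:
    """Return the names of required variables that are unset or empty in `env`."""
    return [name for name in REQUIRED_VARS if not env.get(name)]
```

For example, `missing_tpu_vars({"ZONE": "us-east5-a"})` returns the other four names, and `missing_tpu_vars()` returns an empty list when run from a shell where all five exports above have been executed.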
Framework setup
This section describes the general setup process for model training using JAX or PyTorch with TPU v5p.
Setup for JAX
If your slice has more than 4 chips, it spans multiple VMs. In this case, use the --worker=all flag to run the installation on all TPU VMs with a single command:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --project ${PROJECT_ID} \
  --zone ${ZONE} \
  --worker=all \
  --command='pip install -U "jax[tpu]" -f'
You can run the following command to check the number of devices (the outputs shown here were produced with a v5p-32 slice). This tests that everything is installed correctly by checking that JAX sees the Cloud TPU TensorCores and can run basic operations:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --project ${PROJECT_ID} \
  --zone ${ZONE} \
  --worker=all \
  --command='python3 -c "import jax; print(jax.device_count()); print(jax.local_device_count())"'
The output will be similar to the following:
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16
4
16
4
16
4
16
4
jax.device_count() shows the total number of chips in the given slice. jax.local_device_count() indicates the count of chips accessible by a single VM in this slice.
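The counts above follow from how v5p slices are sized: the number in a v5p accelerator type counts TensorCores, each v5p chip has two TensorCores, and each TPU VM host has four chips. Under those assumptions, the following Python sketch (the helper `expected_v5p_counts` is hypothetical, not part of JAX or gcloud) predicts what `jax.device_count()` and `jax.local_device_count()` should report, so you can compare against your own output.

```python
def expected_v5p_counts(accelerator_type: str) -> dict:
    """Predict JAX device counts for a v5p accelerator type such as "v5p-32".

    Assumes the suffix counts TensorCores (two per chip) and that each
    TPU VM host has four chips.
    """
    prefix, _, cores = accelerator_type.partition("-")
    if prefix != "v5p" or not cores.isdigit():
        raise ValueError(f"not a v5p accelerator type: {accelerator_type!r}")
    chips = int(cores) // 2
    chips_per_host = 4
    return {
        "device_count": chips,  # expected jax.device_count()
        "local_device_count": min(chips, chips_per_host),  # expected jax.local_device_count()
        "num_workers": max(1, chips // chips_per_host),  # TPU VMs in the slice
    }


print(expected_v5p_counts("v5p-32"))
# {'device_count': 16, 'local_device_count': 4, 'num_workers': 4}
```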
# Check the number of chips in the given slice by summing the count of chips
# from all VMs through the jax.local_device_count() API call.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --project ${PROJECT_ID} \
  --zone ${ZONE} \
  --worker=all \
  --command='python3 -c "import jax; xs=jax.numpy.ones(jax.local_device_count()); print(jax.pmap(lambda x: jax.lax.psum(x, \"i\"), axis_name=\"i\")(xs))"'
The output will be similar to the following:
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]
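The `psum` call above is an all-reduce: every device receives the sum of the values held across the named axis, and under `pmap` on a multi-host slice that axis spans the chips of all workers. Because each chip contributes 1.0, every printed element equals the total chip count. The following pure-Python simulation of those semantics needs no TPU or JAX; `simulate_psum` is a hypothetical illustration, not a JAX API.

```python
def simulate_psum(per_worker_values: list) -> list:
    """Mimic jax.lax.psum over one axis that spans every device on every worker.

    per_worker_values[w][d] is the value held by local device d on worker w.
    Every device ends up holding the global sum, so the shape is preserved.
    """
    total = sum(v for worker in per_worker_values for v in worker)
    return [[total] * len(worker) for worker in per_worker_values]


# v5p-32: 4 workers, each with 4 local chips, each chip contributing 1.0.
result = simulate_psum([[1.0] * 4 for _ in range(4)])
for row in result:
    print(row)  # [16.0, 16.0, 16.0, 16.0], once per worker
```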
Use --node=all to run the command on all Multislice workers.
gcloud compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
  --project ${PROJECT_ID} --zone ${ZONE} --node=all --worker=all \
  --command='python3 -c "import jax; print(jax.device_count()); print(jax.local_device_count())"'
Try the JAX tutorials in this document to get started with v5p training using JAX.
Setup for PyTorch
The PJRT runtime is the only supported runtime for v5p, and PyTorch 2.1+ uses PJRT as the default runtime for all TPU versions. This section describes how to start using PJRT on v5p Pods with PyTorch/XLA 2.2.0 for all workers.
Install dependencies
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --project ${PROJECT_ID} \
  --zone ${ZONE} \
  --worker=all \
  --command='
sudo apt-get update
sudo apt-get install libopenblas-dev -y
pip install numpy
pip install torch torch_xla[tpu] torchvision -f -f
'
Use a Python script with PJRT to validate your installation. The script shows the available TPU devices (the outputs shown here were produced with a v5p-32 slice).
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --project ${PROJECT_ID} --zone ${ZONE} --worker=all \
  --command='
PJRT_DEVICE=TPU python3 -c "import torch_xla.core.xla_model as xm; print(xm.get_xla_supported_devices(\"TPU\"))"
'
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
['xla:0', 'xla:1', 'xla:2', 'xla:3']
['xla:0', 'xla:1', 'xla:2', 'xla:3']
['xla:0', 'xla:1', 'xla:2', 'xla:3']
['xla:0', 'xla:1', 'xla:2', 'xla:3']
Use --node=all to run the command on all Multislice workers.
gcloud compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
  --project ${PROJECT_ID} --zone ${ZONE} --node=all --worker=all \
  --command='
PJRT_DEVICE=TPU python3 -c "import torch_xla.core.xla_model as xm; print(xm.get_xla_supported_devices(\"TPU\"))"
'
Try the PyTorch tutorials in this document to get started with v5p training using PyTorch.
Monitor and profile
Cloud TPU v5p supports monitoring and profiling using the same methods as previous generations of Cloud TPU. To learn more, see Profile your model with Cloud TPU tools for profiling and Monitoring Cloud TPU VMs for monitoring.
Training tutorials
This section focuses on training tutorials for a single slice. You can adapt these tutorials to Multislice training by adding the --node=all flag to the SSH commands. For details and best practices, refer to the Multislice introduction.
JAX tutorials
Train Diffusion 2.1
This tutorial shows you how to train the Stable Diffusion model from HuggingFace using the Pokémon dataset on Cloud TPU v5p.
The Stable Diffusion model is a latent text-to-image model that generates photo-realistic images from any text input. For more information, see the following resources:
Set up
Create environment variables:
export GCS_BUCKET_NAME=your-bucket
export PROJECT_ID=your-project-ID
export ACCELERATOR_TYPE=v5p-32
export ZONE=europe-west4-b
export LOCATION=europe-west4
export RUNTIME_VERSION=v2-alpha-tpuv5
export SERVICE_ACCOUNT=your-service-account
export TPU_NAME=your-tpu-name
export QUEUED_RESOURCE_ID=your-qr-name
export QUOTA_TYPE=spot
export VALID_UNTIL_DURATION=1d
Command flag descriptions
- PROJECT_ID: Google Cloud project name.
- ACCELERATOR_TYPE: See the TPU versions page for your TPU version.
- ZONE: See the TPU regions and zones document for the supported zones.
- LOCATION: The Google Cloud region in which to create the Cloud Storage bucket.
- RUNTIME_VERSION: For v5p, use v2-alpha-tpuv5 for the RUNTIME_VERSION.
- SERVICE_ACCOUNT: The address of your service account, which you can find in the Google Cloud console under IAM > Service Accounts.
- TPU_NAME: The user-assigned text ID of the TPU, which is created when the queued resource request is allocated.
- QUEUED_RESOURCE_ID: The user-assigned text ID of the queued resource request. See the Queued Resources document for information about queued resources.
- QUOTA_TYPE: Can be reserved or spot. If neither is specified, QUOTA_TYPE defaults to on-demand. See quotas for information on the different types of quotas supported by Cloud TPU.
- VALID_UNTIL_DURATION: The duration for which the request is valid. See Queued resources for information about the different valid durations.
Set up a storage bucket for your model output.
gcloud storage buckets create gs://$GCS_BUCKET_NAME \
  --project=$PROJECT_ID \
  --location=$LOCATION
Create a TPU resource:
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
  --node-id ${TPU_NAME} \
  --project ${PROJECT_ID} \
  --zone ${ZONE} \
  --accelerator-type ${ACCELERATOR_TYPE} \
  --runtime-version ${RUNTIME_VERSION} \
  --valid-until-duration ${VALID_UNTIL_DURATION} \
  --service-account ${SERVICE_ACCOUNT} \
  --${QUOTA_TYPE}
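If you drive provisioning from Python rather than the shell, building the argument list explicitly lets you omit the quota flag entirely for on-demand quota instead of expanding `--${QUOTA_TYPE}`. The following is a minimal sketch that uses only the flags shown above; the function name `queued_resource_create_cmd` is hypothetical.

```python
import shlex
from typing import List, Optional


def queued_resource_create_cmd(
    qr_id: str,
    node_id: str,
    project: str,
    zone: str,
    accelerator_type: str,
    runtime_version: str,
    valid_until: str,
    service_account: str,
    quota_type: Optional[str] = None,
) -> List[str]:
    """Build the argv for the `gcloud compute tpus queued-resources create` call above.

    quota_type may be "reserved" or "spot"; None omits the flag (on-demand quota).
    """
    cmd = [
        "gcloud", "compute", "tpus", "queued-resources", "create", qr_id,
        "--node-id", node_id,
        "--project", project,
        "--zone", zone,
        "--accelerator-type", accelerator_type,
        "--runtime-version", runtime_version,
        "--valid-until-duration", valid_until,
        "--service-account", service_account,
    ]
    if quota_type is not None:
        if quota_type not in ("reserved", "spot"):
            raise ValueError(f"unsupported quota type: {quota_type!r}")
        cmd.append(f"--{quota_type}")
    return cmd


cmd = queued_resource_create_cmd(
    "your-qr-name", "your-tpu-name", "your-project-ID", "europe-west4-b",
    "v5p-32", "v2-alpha-tpuv5", "1d", "your-service-account", "spot",
)
# Inspect the command before running it with subprocess.run(cmd, check=True).
print(shlex.join(cmd))
```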
You can connect to your TPU VM using SSH once your queued resource is in the ACTIVE state. Check the state of your queued resource by running the following command:
gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
  --project ${PROJECT_ID} \
  --zone ${ZONE}
When the queued resource is in the ACTIVE state, the output will be similar to the following:
state: ACTIVE
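If you automate provisioning, you can poll the describe command until the state becomes ACTIVE. The following Python sketch assumes only that the output contains a line of the form `state: ACTIVE`, as shown above, and stops early on the FAILED state mentioned in the clean-up steps; the helper names, polling interval, and poll limit are hypothetical choices, not gcloud defaults.

```python
import re
import subprocess
import time
from typing import Callable, Optional

# Matches a line such as "state: ACTIVE" (possibly indented).
_STATE_RE = re.compile(r"^[ \t]*state:[ \t]*([A-Z_]+)[ \t]*$", re.MULTILINE)


def parse_state(describe_output: str) -> Optional[str]:
    """Return the first state value found in describe output, or None."""
    match = _STATE_RE.search(describe_output)
    return match.group(1) if match else None


def describe_queued_resource(qr_id: str, project: str, zone: str) -> str:
    """Run the describe command shown above and return its standard output."""
    return subprocess.run(
        ["gcloud", "compute", "tpus", "queued-resources", "describe", qr_id,
         "--project", project, "--zone", zone],
        check=True, capture_output=True, text=True,
    ).stdout


def wait_until_active(fetch: Callable[[], str], poll_seconds: float = 60.0,
                      max_polls: int = 120,
                      sleep: Callable[[float], None] = time.sleep) -> str:
    """Poll `fetch` until the queued resource is ACTIVE; fail fast on FAILED."""
    for _ in range(max_polls):
        state = parse_state(fetch())
        if state == "ACTIVE":
            return state
        if state == "FAILED":
            raise RuntimeError("queued resource entered the FAILED state")
        sleep(poll_seconds)
    raise TimeoutError("queued resource did not become ACTIVE in time")
```

A typical call is `wait_until_active(lambda: describe_queued_resource(qr_id, project, zone))`. Passing the fetch function in, rather than calling gcloud directly, keeps the polling logic testable without a real TPU.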
Train the model
gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE --project $PROJECT_ID --worker=all --command="
git clone
cd maxdiffusion
git reset --hard 57629bcf4fa32fe5a57096b60b09f41f2fa5c35d # This identifies the GitHub commit to use.
pip3 install jax[tpu] -f # Install the latest version of JAX
pip3 install -r requirements.txt
pip3 install .
export LIBTPU_INIT_ARGS=""
python -m src.maxdiffusion.models.train src/maxdiffusion/configs/base_2_base.yml run_name=my_run base_output_directory=gs://$GCS_BUCKET_NAME enable_profiler=False"
Clean up
Delete your TPU and queued resource request at the end of your session, or to remove queued resource requests that are in the FAILED state. To delete a queued resource, first delete the slice(s) and then the queued resource request:
gcloud compute tpus tpu-vm delete ${TPU_NAME} --project=${PROJECT_ID} \
  --zone=${ZONE} --quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
  --project ${PROJECT_ID} --zone ${ZONE} --quiet
Or, use --force to delete the slice(s) and the queued resource request in a single step:
# With --force
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
  --project ${PROJECT_ID} --zone ${ZONE} --quiet --force
Benchmark results
The Stable Diffusion training script ran on v5p-8, v5p-32, and v5p-128. The following table shows the throughput.
Metric | v5p-8 | v5p-32 | v5p-128
Train step | 150 | 150 | 150
Global batch size | 32 | 64 | 64
Throughput (examples/sec) | 12.10 | 18.08 | 19.10
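Normalizing the table by chip count makes the three runs easier to compare. The following Python sketch assumes two TensorCores per v5p chip (so v5p-8, v5p-32, and v5p-128 have 4, 16, and 64 chips) and divides both the global batch size and the throughput by that count. Because the global batch size stays at 64 on v5p-32 and v5p-128, the per-chip batch drops from 4 to 1, which is worth keeping in mind when reading the throughput row.

```python
# Values copied from the benchmark table above.
THROUGHPUT = {"v5p-8": 12.10, "v5p-32": 18.08, "v5p-128": 19.10}  # examples/sec
GLOBAL_BATCH = {"v5p-8": 32, "v5p-32": 64, "v5p-128": 64}


def chips_in(accelerator_type: str) -> int:
    """v5p accelerator types count TensorCores; each v5p chip has two."""
    return int(accelerator_type.split("-")[1]) // 2


def per_chip_throughput(accelerator_type: str, examples_per_sec: float) -> float:
    return examples_per_sec / chips_in(accelerator_type)


for accel, eps in THROUGHPUT.items():
    chips = chips_in(accel)
    print(f"{accel}: {chips} chips, "
          f"{GLOBAL_BATCH[accel] // chips} examples per chip per step, "
          f"{per_chip_throughput(accel, eps):.3f} examples/sec per chip")
```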
MaxText
This tutorial shows you how to train the MaxText model using a synthetic dataset on Cloud TPU.
MaxText is a high-performance, arbitrarily scalable, open-source, well-tested LLM written in pure Python/JAX targeting Cloud TPUs. MaxText gives researchers and developers an accessible and adaptable tool for advancing Natural Language Processing (NLP) research and development.
Before running this tutorial, you need to set up your Cloud TPU environment.
Set up environment variables
export PROJECT_ID=your_project_ID
export TPU_NAME=your_tpu_name # user defined TPU name
export ACCELERATOR_TYPE=v5p-256
export ZONE=us-east5-a
export RUNTIME_VERSION=v2-alpha-tpuv5
export RUN_NAME=your_experiment_run_name # user defined name for this run
export GCS_BUCKET_NAME=your_bucket_name # Output cloud folder. Should start with gs://
export MAXTEXT_OUTPUT_PATH=${GCS_BUCKET_NAME}/your_experiment_output_path
export NUM_SLICES=1 # Update the value to a number >1 for Multislice.
Command flag descriptions
- PROJECT_ID: Google Cloud project name.
- TPU_NAME: A user-defined name for your TPU.
- ACCELERATOR_TYPE: See the TPU versions page for your TPU version.
- ZONE: See the TPU regions and zones document for the supported zones.
- RUNTIME_VERSION: For v5p, use v2-alpha-tpuv5 for the runtime version.
- RUN_NAME: User-supplied experiment run name.
Optional setup recommended for Multislice:
export NETWORK_NAME=your_network_name
export FIREWALL_RULE_NAME=your_firewall_rule_name
If you're running Multislice workloads and want optimal network performance, consider creating a dedicated network with a Maximum Transmission Unit (MTU) of 8896 bytes and configuring appropriate firewall rules. While optional, this step can significantly improve performance, especially when scaling up the number of slices over the data-center network (DCN). Note that creating a network requires the appropriate permission in the project. The following examples show how to create a dedicated network and firewall rule.
Create a dedicated network:
gcloud compute networks create ${NETWORK_NAME} \
  --mtu=8896 \
  --project=${PROJECT_ID} \
  --subnet-mode=auto \
  --bgp-routing-mode=regional
Create a firewall rule:
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
  --network ${NETWORK_NAME} --allow tcp,icmp,udp --project=${PROJECT_ID}
Clone the MaxText repository
git clone
Train the model
The following sections describe two options for training MaxText.
Option 1
If you want a script to manage the entire workflow, from provisioning Cloud TPUs and installing dependencies to running your model and tearing down resources, you can use the multihost_job script. RUN_NAME is your user-defined run name, and BUCKET_NAME is used to store logs and configs.
cd maxtext && python3 --PROJECT=${PROJECT_ID} --ZONE=${ZONE} \
  --NUM_SLICES=${NUM_SLICES} --TPU_TYPE=${ACCELERATOR_TYPE} \
  --VERSION=${RUNTIME_VERSION} --RUN_NAME=${RUN_NAME} \
  --BUCKET_NAME=${GCS_BUCKET_NAME} \
  --COMMAND="bash && bash MaxText/configs/experimental/ RUN_NAME=${RUN_NAME} OUTPUT_PATH=${MAXTEXT_OUTPUT_PATH} PLATFORM=gce"
After initiating the script, you should see a message similar to the following in the log. The log location is referenced in the output message. Click the first link to access the logs of all workers once TPU provisioning is complete.
------------------------------------
multihost_job finished running, TPUs are starting up to run your job remotely.
Logs for your job are displayed here:
;query=resource.type%3D%22gce_instance%22%20AND%0Alog_id%2528%22_log%22%2529;?project=PROJECT_ID
To see the output of a single host, you may edit the slice and worker number in the `log_file_path` property here:
;;?project=PROJECT_ID
When your job is finished, the main command log is in your Cloud Storage bucket:
View the status of the created TPUs using:
gcloud compute tpus queued-resources list --filter=RUN_NAME --zone=ZONE --project=PROJECT_ID
Option 2
To run the training script multiple times on an already provisioned Cloud TPU, use the runner script in the following steps so that you can reuse the resource.
Set up variables to create a TPU.
export SERVICE_ACCOUNT=your_service_account
export TPU_NAME=your_tpu_name
export QUEUED_RESOURCE_ID=your_queued_resource_id
export VALID_DURATION=1d
export QUOTA_TYPE=quota_type
For Multislice, also pass the following flags when you create the TPU resource:
  --node-count ${NODE_COUNT} \
  --node-prefix ${NODE_PREFIX} # optional, the default is QUEUED_RESOURCE_ID
Create a TPU resource.
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
  --node-id ${TPU_NAME} \
  --project ${PROJECT_ID} \
  --zone ${ZONE} \
  --accelerator-type ${ACCELERATOR_TYPE} \
  --runtime-version ${RUNTIME_VERSION} \
  --valid-until-duration ${VALID_DURATION} \
  --service-account ${SERVICE_ACCOUNT} \
  --${QUOTA_TYPE}
You can connect to your TPU VMs using SSH once your queued resource is in the ACTIVE state. Use the describe command to query the status of your queued resource:
gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
  --project ${PROJECT_ID} \
  --zone ${ZONE}
When your queued resource is in the ACTIVE state, the output will be similar to the following:
state: ACTIVE
Connect to your TPU using SSH
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --project ${PROJECT_ID} \
  --zone ${ZONE}
Install dependencies
export TPU_NAME=your_tpu_name
export MAXTEXT_OUTPUT_PATH=output-path
cd maxtext && python3 --TPU_PREFIX=${TPU_NAME} \
  --COMMAND='bash'
Run the model with one of the configuration scripts, such as those in MaxText/configs/experimental/. If you are running the script from a TPU VM, you need to add an additional flag.
python3 --TPU_PREFIX=${TPU_NAME} \
  --COMMAND="bash MaxText/configs/experimental/ RUN_NAME=${RUN_NAME} OUTPUT_PATH=${MAXTEXT_OUTPUT_PATH} PLATFORM=gce"
Clean up
Delete your TPU and queued resources.
Benchmark results
The MaxText training script was run on models from 32B to 1160B parameters with bf16 precision. The results of these runs are shown in the following table.
No. of params | Accelerator type | TFLOP/chip/sec | Model FLOPs utilization (MFU)
32B | v5p-128 | 3.28E+02 | 71.47%
64B | v5p-128 | 3.23E+02 | 70.31%
128B | v5p-256 | 3.15E+02 | 68.68%
128B | v5p-512 | 3.15E+02 | 68.53%
256B | v5p-1024 | 3.16E+02 | 68.82%
512B | v5p-1024 | 2.94E+02 | 63.99%
1024B | v5p-2048 | 2.49E+02 | 64.05%
1024B | v5p-4096 | 2.97E+02 | 64.80%
1160B | v5p-7680 | 2.95E+02 | 64.27%
1160B | v5p-12288 | 3.04E+02 | 66.23%
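MFU is the ratio of the model FLOPs actually achieved per second to the hardware's peak FLOPs per second. Inverting that ratio for several rows of the table gives an implied reference peak of about 459 TFLOP/s per chip, so the table appears to use that value as the bf16 peak. This is an inference from the table rather than a figure stated here; the sketch below only shows the arithmetic.

```python
# (model, accelerator type) -> (TFLOP/chip/sec, MFU), copied from the table above.
ROWS = {
    ("32B", "v5p-128"): (328.0, 0.7147),
    ("64B", "v5p-128"): (323.0, 0.7031),
    ("128B", "v5p-256"): (315.0, 0.6868),
    ("1160B", "v5p-12288"): (304.0, 0.6623),
}


def mfu(achieved_tflops: float, peak_tflops: float) -> float:
    """Model FLOPs utilization: achieved model FLOP/s divided by peak FLOP/s."""
    return achieved_tflops / peak_tflops


def implied_peak(achieved_tflops: float, utilization: float) -> float:
    """Invert the MFU definition to recover the peak the row was measured against."""
    return achieved_tflops / utilization


for (model, accel), (tflops, utilization) in ROWS.items():
    print(f"{model} on {accel}: implied peak ~{implied_peak(tflops, utilization):.0f} TFLOP/s per chip")
```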
The 256B parameter model has been tested on v5p-512 and v5p-1024 using both bf16 and int8 weights. The following table displays the results of these tests.
Metric | v5p-512 | v5p-512 | v5p-1024 | v5p-1024
Global batch size (tokens) | 5.24E+05 | 5.24E+05 | 1.05E+06 | 1.05E+06
Precision | bf16 | int8 | bf16 | int8
TFLOP/chip/sec | 307 | 408 | 308 | 414
Model FLOPs utilization (MFU) | 66.98% | 88.85% | 67.09% | 90.23%
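Two quick calculations help read this table: the int8 speedup over bf16 on each slice, and the peak implied by each column (achieved TFLOP/s divided by MFU). All four columns imply roughly 458 to 459 TFLOP/s per chip, which suggests that the int8 MFU figures are reported against the same reference peak as bf16; treat that as an inference from the numbers, not a documented definition.

```python
# TFLOP/chip/sec and MFU for the 256B model, copied from the table above.
RESULTS = {
    ("v5p-512", "bf16"): (307.0, 0.6698),
    ("v5p-512", "int8"): (408.0, 0.8885),
    ("v5p-1024", "bf16"): (308.0, 0.6709),
    ("v5p-1024", "int8"): (414.0, 0.9023),
}


def int8_speedup(accelerator_type: str) -> float:
    """Ratio of int8 to bf16 TFLOP/chip/sec on the same slice."""
    int8_tflops, _ = RESULTS[(accelerator_type, "int8")]
    bf16_tflops, _ = RESULTS[(accelerator_type, "bf16")]
    return int8_tflops / bf16_tflops


for accel in ("v5p-512", "v5p-1024"):
    print(f"{accel}: int8 delivers {int8_speedup(accel):.2f}x the bf16 TFLOP/chip/sec")

for (accel, precision), (tflops, utilization) in RESULTS.items():
    # Implied reference peak = achieved / MFU.
    print(f"{accel} {precision}: implied peak ~{tflops / utilization:.0f} TFLOP/s")
```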
TensorFlow tutorials
Train ResNet on a single host v5p
This tutorial describes how to train a ResNet model on a v5p-8 using a fake ImageNet dataset. If you want to use a different dataset, refer to Preparing the dataset.
Set up
Create environment variables:
export PROJECT_ID=your-project-ID
export ACCELERATOR_TYPE=v5p-8
export ZONE=us-east1-c
export RUNTIME_VERSION=tpu-vm-tf-2.18.0-pjrt
export TPU_NAME=your-tpu-name
export QUEUED_RESOURCE_ID=your-queued-resource-id
export QUOTA_TYPE=quota-type
For this tutorial, use v5p-8 as the ACCELERATOR_TYPE. Create a TPU resource:
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
  --node-id ${TPU_NAME} \
  --project ${PROJECT_ID} \
  --zone ${ZONE} \
  --accelerator-type ${ACCELERATOR_TYPE} \
  --runtime-version ${RUNTIME_VERSION} \
  --${QUOTA_TYPE}
You can connect to your TPU VM using SSH once your queued resource is in the ACTIVE state. To check the state of your queued resource, use the following command:
gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
  --project ${PROJECT_ID} \
  --zone ${ZONE}
Connect to your TPU using SSH
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --project ${PROJECT_ID} \
  --zone ${ZONE}
Set some environment variables
export MODELS_REPO=/usr/share/tpu/models
export PYTHONPATH="${MODELS_REPO}:${PYTHONPATH}"
export MODEL_DIR=gcp-directory-to-store-model
export DATA_DIR=gs://cloud-tpu-test-datasets/fake_imagenet
export NEXT_PLUGGABLE_DEVICE_USE_C_API=true
export TF_PLUGGABLE_DEVICE_LIBRARY_PATH=/lib/
Change to the models repository directory and install requirements.
cd ${MODELS_REPO} && git checkout r2.15.0
pip install -r official/requirements.txt
Train the model
Run the training script.
python3 official/vision/ \
  --tpu=local \
  --experiment=resnet_imagenet \
  --mode=train_and_eval \
  --config_file=official/vision/configs/experiments/image_classification/imagenet_resnet50_tpu.yaml \
  --model_dir=${MODEL_DIR} \
  --params_override="runtime.distribution_strategy=tpu,task.train_data.input_path=${DATA_DIR}/train*,task.validation_data.input_path=${DATA_DIR}/validation*,task.train_data.global_batch_size=2048,task.validation_data.global_batch_size=2048,trainer.train_steps=100"
Clean up
Delete your TPU and queued resources.
Train ResNet on a multi-host v5p
This tutorial describes how to train a ResNet model on v5p-16 or larger using a fake ImageNet dataset. If you want to use a different dataset, see Preparing the dataset.
Create environment variables:
export PROJECT_ID=your_project_ID
export TPU_NAME=your_tpu_name
export ZONE=us-east1-c
export ACCELERATOR_TYPE=v5p-16
export RUNTIME_VERSION=tpu-vm-tf-2.18.0-pod-pjrt
export QUEUED_RESOURCE_ID=your-queued-resource-id
export QUOTA_TYPE=quota-type
ACCELERATOR_TYPE can be either v5p-16 or larger.
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
  --node-id ${TPU_NAME} \
  --project ${PROJECT_ID} \
  --zone ${ZONE} \
  --accelerator-type ${ACCELERATOR_TYPE} \
  --runtime-version ${RUNTIME_VERSION} \
  --${QUOTA_TYPE}
You can connect to your TPU VM using SSH once your queued resource is in the ACTIVE state. Use the describe command to query the status of your queued resource:
gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
  --project ${PROJECT_ID} \
  --zone ${ZONE}
Connect to your TPU (worker zero) using SSH
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --project ${PROJECT_ID} \
  --zone ${ZONE}
Set some environment variables
export TPU_NAME=your_tpu_name
export MODELS_REPO=/usr/share/tpu/models
export PYTHONPATH="${MODELS_REPO}:${PYTHONPATH}"
export MODEL_DIR=gcp-directory-to-store-model
export DATA_DIR=gs://cloud-tpu-test-datasets/fake_imagenet
export TPU_LOAD_LIBRARY=0
Change to the models repository directory and install requirements.
cd $MODELS_REPO && git checkout r2.15.0
pip install -r official/requirements.txt
Train the model
Run the training script.
python3 official/vision/ \
  --tpu=${TPU_NAME} \
  --experiment=resnet_imagenet \
  --mode=train_and_eval \
  --config_file=official/vision/configs/experiments/image_classification/imagenet_resnet50_tpu.yaml \
  --model_dir=${MODEL_DIR} \
  --params_override="runtime.distribution_strategy=tpu,task.train_data.input_path=${DATA_DIR}/train*,task.validation_data.input_path=${DATA_DIR}/validation*,task.train_data.global_batch_size=2048,task.validation_data.global_batch_size=2048,trainer.train_steps=100"
Clean up
Delete your TPU and queued resources.
PyTorch tutorials
Llama 2
This tutorial shows how to train the Llama 2 7B model on v5p using a fork of the HuggingFace repository on PyTorch/XLA with General and Scalable Parallelization for ML Computation Graphs (GSPMD).
Create environment variables.
export PROJECT_ID=your_project_ID
export ACCELERATOR_TYPE=v5p-8
export ZONE=us-east5-a
export RUNTIME_VERSION=v2-alpha-tpuv5
export SERVICE_ACCOUNT=your_service_account
export TPU_NAME=your_tpu_name
export QUEUED_RESOURCE_ID=your_queued_resource_id
export QUOTA_TYPE=quota_type
export VALID_DURATION=1d
Create a TPU resource
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
  --node-id ${TPU_NAME} \
  --project ${PROJECT_ID} \
  --zone ${ZONE} \
  --accelerator-type ${ACCELERATOR_TYPE} \
  --runtime-version ${RUNTIME_VERSION} \
  --valid-until-duration ${VALID_DURATION} \
  --service-account ${SERVICE_ACCOUNT} \
  --${QUOTA_TYPE}
You can connect to your TPU VM using SSH once your queued resource is in the ACTIVE state. Use the describe command to query the status of your queued resource:
gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
  --project ${PROJECT_ID} \
  --zone ${ZONE}
When your queued resource is in the ACTIVE state, the output will be similar to the following:
state: ACTIVE
Install PyTorch/XLA and the required dependencies.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --project ${PROJECT_ID} \
  --zone ${ZONE} \
  --worker=all \
  --command='
sudo apt-get update
sudo apt-get install libopenblas-dev -y
pip install numpy
pip install typing-extensions
pip install torch torch_xla[tpu] -f -f
'
Download the HuggingFace repository and install requirements.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --project=${PROJECT_ID} \
  --zone=${ZONE} \
  --worker=all \
  --command='
git clone -b llama2-google-next-training
cd transformers
pip3 install git+file://$PWD
pip3 install datasets accelerate evaluate scikit-learn'
Download the 7B model config.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --project=${PROJECT_ID} \
  --zone=${ZONE} \
  --worker=all \
  --command="curl --output ~/config.json"
Train the model
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --project=${PROJECT_ID} \
  --zone=${ZONE} \
  --worker=all \
  --command='
export PJRT_DEVICE=TPU
export XLA_USE_BF16=1
export XLA_IR_DEBUG=1
export XLA_HLO_DEBUG=1
export LIBTPU_INIT_ARGS="--xla_enable_async_collective_permute=true \
  --xla_tpu_enable_async_collective_fusion_multiple_steps=true \
  --xla_tpu_enable_async_collective_fusion=true \
  --xla_tpu_overlap_compute_collective_tc=true \
  --xla_enable_async_all_gather=true \
  --xla_jf_spmd_threshold_for_windowed_einsum_mib=0"
export PROFILE_EPOCH=0
export PROFILE_STEP=3
export PROFILE_DURATION_MS=20000
export PROFILE_LOGDIR=/tmp/home/
cd transformers
python examples/pytorch/language-modeling/ \
  --tokenizer_name hf-internal-testing/llama-tokenizer \
  --dataset_name wikitext \
  --dataset_config_name wikitext-2-raw-v1 \
  --per_device_train_batch_size 96 \
  --per_device_eval_batch_size 8 \
  --num_train_epochs 1 \
  --do_train \
  --output_dir /tmp/output \
  --overwrite_output_dir \
  --config_name ~/config.json \
  --save_strategy no \
  --logging_strategy no \
  --remove_unused_columns no \
  --optim adafactor \
  --torch_dtype bfloat16 \
  --dataloader_drop_last yes \
  --block_size 2048 \
  --spmd_2d_sharding 1 \
  --spmd_grad_chkpt
'
If you're running in a Multislice environment, you need to set the flag --spmd_dcn_parallelism to the number of slices.
The SPMD_USER_GUIDE provides an in-depth guide that explains the different environment variables and options of the HF script. Note that the LIBTPU_INIT_ARGS flags will be incorporated into PyTorch/XLA and enabled by default in future releases.
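When experimenting with these flags, it can be easier to keep them in a mapping and render the LIBTPU_INIT_ARGS string from it than to edit the long export line. The following is a minimal Python sketch, assuming the space-separated `--name=value` form used in the training command above; the helper name `libtpu_init_args` is hypothetical.

```python
# The libtpu/XLA flags from the training command above.
LIBTPU_FLAGS = {
    "xla_enable_async_collective_permute": True,
    "xla_tpu_enable_async_collective_fusion_multiple_steps": True,
    "xla_tpu_enable_async_collective_fusion": True,
    "xla_tpu_overlap_compute_collective_tc": True,
    "xla_enable_async_all_gather": True,
    "xla_jf_spmd_threshold_for_windowed_einsum_mib": 0,
}


def libtpu_init_args(flags: dict) -> str:
    """Render flags as a space-separated "--name=value" string.

    Booleans are lowercased ("true"/"false") to match the command above.
    """
    def render(value) -> str:
        return str(value).lower() if isinstance(value, bool) else str(value)

    return " ".join(f"--{name}={render(value)}" for name, value in flags.items())


print(libtpu_init_args(LIBTPU_FLAGS))
```

A launcher script could then set `os.environ["LIBTPU_INIT_ARGS"] = libtpu_init_args(LIBTPU_FLAGS)` before starting the training process, which has the same effect as the export line.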
Clean up
Delete your TPU and queued resources.
Benchmark results
Results for all three Llama 2 model sizes are included in the following table.
Metric | v5p-8 | v5p-128 | v5p-128
Model size | 7B | 13B | 70B
Global batch size | 96 | 1024 | 128
Sharding mesh shape | (4, 1) | (64, 1) | (16, 4)
Model FLOPs utilization (MFU) | 56.67% | 55.80% | 51.85%
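In each configuration, the product of the sharding mesh dimensions equals the number of chips in the slice (for example, 16 x 4 = 64 chips on v5p-128), which is consistent with one mesh position per chip. A quick Python check of that relationship, assuming two TensorCores per v5p chip:

```python
import math

# (accelerator type, sharding mesh shape) pairs copied from the table above.
CONFIGS = [("v5p-8", (4, 1)), ("v5p-128", (64, 1)), ("v5p-128", (16, 4))]


def chips_in(accelerator_type: str) -> int:
    """v5p accelerator types count TensorCores; each v5p chip has two."""
    return int(accelerator_type.split("-")[1]) // 2


def mesh_size(shape: tuple) -> int:
    """Total number of positions in a device mesh of the given shape."""
    return math.prod(shape)


for accel, shape in CONFIGS:
    assert mesh_size(shape) == chips_in(accel), (accel, shape)
    print(f"{accel}: mesh {shape} covers {mesh_size(shape)} chips")
```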
Support and Feedback
We welcome all feedback! To share feedback or request support, fill out the Cloud TPU Support or Feedback form.