Trillium (v6e) introduction
v6e is used to refer to Trillium in this documentation, the TPU API, and logs. v6e represents Google's 6th generation of TPU.
With 256 chips per Pod, v6e architecture shares many similarities with v5e. This system is optimized for transformer, text-to-image, and convolutional neural network (CNN) training, fine-tuning, and serving.
For more information about v6e system architecture and configurations, see TPU v6e.
This introduction document focuses on the processes for model training and serving using JAX, PyTorch, or TensorFlow frameworks. With each framework, you can provision TPUs using queued resources or GKE. GKE setup can be done using XPK or GKE commands.
General procedure to train or serve a model using v6e
- Prepare a Google Cloud project
- Secure capacity
- Provision the Cloud TPU environment
- Run a model training or inference workload
Prepare a Google Cloud project
Before you can use Cloud TPU, you need to:
- Create a Google Cloud account and project with billing enabled
- Install Google Cloud CLI alpha components
- Enable the Cloud TPU API
- Create a Cloud TPU service agent
- Create a Cloud TPU service account and grant permissions
For more information, see Set up the Cloud TPU environment.
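For example, after you create a project and install the Google Cloud CLI, the API enablement and service agent steps can be sketched as follows (a minimal outline; your-project-id is a placeholder):

gcloud config set project your-project-id
gcloud components install alpha
gcloud services enable tpu.googleapis.com

# Create the Cloud TPU service agent (service identity) for the project.
gcloud beta services identity create --service tpu.googleapis.com --project your-project-id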
Secure capacity
Contact Google Cloud support to request Cloud TPU v6e quota and to answer any questions about capacity.
Provision the Cloud TPU environment
v6e Cloud TPU can be provisioned and managed with GKE, with GKE and XPK (a wrapper CLI tool over GKE), or as queued resources.
Prerequisites
- Verify that your project has enough TPUS_PER_TPU_FAMILY quota, which specifies the maximum number of chips you can access within your Google Cloud project.
- v6e has been tested with the following configuration:
  - Python 3.10 or later
  - Nightly software versions:
    - Nightly JAX 0.4.32.dev20240912
    - Nightly LibTPU 0.1.dev20240912+nightly
  - Stable software versions:
    - JAX + JAX Lib of v0.4.37
- Verify that your project has enough quota for:
  - Cloud TPU VM quota
  - IP address quota
  - Hyperdisk Balanced quota
If you are using GKE with XPK, see Cloud console permissions for the permissions your user or service account needs to run XPK.
Create environment variables
In Cloud Shell, create the following environment variables:

export NODE_ID=your-tpu-name
export PROJECT_ID=your-project-id
export ACCELERATOR_TYPE=v6e-16
export ZONE=us-east1-d
export RUNTIME_VERSION=v2-alpha-tpuv6e
export SERVICE_ACCOUNT=your-service-account
export QUEUED_RESOURCE_ID=your-queued-resource-id
export VALID_DURATION=your-duration

# Additional environment variable needed for Multislice:
export NUM_SLICES=number-of-slices

# Use a custom network for better performance and to avoid having the
# default network become overloaded.
export NETWORK_NAME=${PROJECT_ID}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw
Variable | Description |
NODE_ID | The user-assigned ID of the Cloud TPU which is created when the queued resource request is allocated. |
PROJECT_ID | Google Cloud project name. Use an existing project or create a new one. For more information, see Set up your Google Cloud project. |
ZONE | See the Cloud TPU regions and zones document for the supported zones. |
ACCELERATOR_TYPE | See Accelerator Types. |
RUNTIME_VERSION | The Cloud TPU software version. For example, v2-alpha-tpuv6e. |
SERVICE_ACCOUNT | The email address for your service account. You can find it in the Google Cloud console under IAM -> Service Accounts. For example: tpu-service-account@PROJECT_ID.iam.gserviceaccount.com |
NUM_SLICES | The number of slices to create (needed for Multislice only). |
QUEUED_RESOURCE_ID | The user-assigned text ID of the queued resource request. |
VALID_DURATION | The duration for which the queued resource request is valid. |
NETWORK_NAME | The name of a secondary network to use. |
NETWORK_FW_NAME | The name of a secondary network firewall to use. |
Optimize network performance
For the best performance, use a network with an 8,896-byte MTU (maximum transmission unit).
By default, a Virtual Private Cloud (VPC) network only provides an MTU of 1,460 bytes, which results in suboptimal network performance. You can set a VPC network's MTU to any value between 1,300 bytes and 8,896 bytes (inclusive). Common custom MTU sizes are 1,500 bytes (standard Ethernet) or 8,896 bytes (the maximum possible). For more information, see Valid VPC network MTU sizes.
For more information about changing the MTU setting for an existing or default network, see Change the MTU setting of a VPC network.
The following example creates a network with 8,896 MTU.
export RESOURCE_NAME=your-resource-name
export NETWORK_NAME=${RESOURCE_NAME}-privatenetwork
export NETWORK_FW_NAME=${RESOURCE_NAME}-privatefirewall

gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT_ID} \
    --subnet-mode=auto --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network=${NETWORK_NAME} \
    --allow tcp,icmp,udp --project=${PROJECT_ID}
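Optionally, you can confirm the MTU that was applied to the new network. The following check assumes the network created in the previous example:

gcloud compute networks describe ${NETWORK_NAME} \
    --project=${PROJECT_ID} \
    --format="value(mtu)"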
Using multi-NIC (option for Multislice)
The following environment variables are needed for a secondary subnet when you are using a Multislice environment.
export NETWORK_NAME_2=${RESOURCE_NAME}
export SUBNET_NAME_2=${RESOURCE_NAME}
export FIREWALL_RULE_NAME=${RESOURCE_NAME}
export ROUTER_NAME=${RESOURCE_NAME}-network-2
export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2
export REGION=your-region
Use the following commands to create custom IP routing for the network and subnet.
gcloud compute networks create ${NETWORK_NAME_2} --mtu=8896 \
--bgp-routing-mode=regional --subnet-mode=custom --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_2} \
--network=${NETWORK_NAME_2} \
--range=10.10.0.0/18 --region=${REGION} \
--project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
--network=${NETWORK_NAME_2} --allow tcp,icmp,udp \
--source-ranges 10.10.0.0/18 --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \
--project=${PROJECT_ID} \
--network=${NETWORK_NAME_2} \
--region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \
--router=${ROUTER_NAME} \
--region=${REGION} \
--auto-allocate-nat-external-ips \
--nat-all-subnet-ip-ranges \
--project=${PROJECT_ID} \
--enable-logging
After you create a multi-network slice, you can validate that both network interface cards (NICs) are being used by setting up an XPK cluster and adding the --command "ifconfig" flag to the XPK workload creation command. Use the following xpk workload create command to display the output of the ifconfig command in the Google Cloud console logs.
python3 xpk.py workload create \
    --cluster ${CLUSTER_NAME} \
    {--base-docker-image maxtext_base_image | --docker-image ${CLOUD_IMAGE_NAME}} \
    --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
    --tpu-type=${ACCELERATOR_TYPE} \
    --num-slices=${NUM_SLICES} \
    --on-demand \
    --zone=${ZONE} \
    --project=${PROJECT_ID} \
    [--enable-debug-logs] \
    [--use-vertex-tensorboard] \
    --command "ifconfig"
If you want to enable debug logs or use Vertex AI TensorBoard, add the following optional arguments to the command:
--enable-debug-logs --use-vertex-tensorboard
Verify that both eth0 and eth1 have mtu=8896.
Improve TCP settings
If you created your Cloud TPUs using the queued resources interface, you can run the following command to improve network performance by increasing TCP receive buffer limits.
gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \
    --project "${PROJECT_ID}" \
    --zone "${ZONE}" \
    --node=all \
    --worker=all \
    --command='sudo sh -c "echo \"4096 41943040 314572800\" > /proc/sys/net/ipv4/tcp_rmem"'
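Optionally, you can confirm the new setting by reading the same file back on all workers. This is a sketch that reuses the command shape shown above:

gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \
    --project "${PROJECT_ID}" \
    --zone "${ZONE}" \
    --node=all \
    --worker=all \
    --command='cat /proc/sys/net/ipv4/tcp_rmem'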
Provision with queued resources
You can create a Cloud TPU v6e using queued resources. Queued resources allow you to receive capacity once it becomes available. You can specify an optional start and end time for when the request should be filled. For more information, see Manage queued resources.
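As an illustration, a representative queued resource request using the environment variables defined earlier might look like the following. Treat it as a sketch and check the flags against the gcloud reference for your use case:

gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
    --node-id=${NODE_ID} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --accelerator-type=${ACCELERATOR_TYPE} \
    --runtime-version=${RUNTIME_VERSION} \
    --service-account=${SERVICE_ACCOUNT} \
    --valid-until-duration=${VALID_DURATION}

You can then monitor the request with the gcloud alpha compute tpus queued-resources list or describe commands until its state becomes ACTIVE.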
Provision v6e Cloud TPUs with GKE or XPK
If you are using GKE commands with v6e, you can use Kubernetes commands or XPK to provision Cloud TPUs and train or serve models. See Plan for Cloud TPUs in GKE to learn how to plan your Cloud TPU configurations in GKE clusters. The following sections provide commands to create an XPK cluster with single NIC support and multi-NIC support.
Create an XPK cluster with single NIC support
export CLUSTER_NAME=xpk-cluster-name
export ZONE=us-east1-d
export PROJECT_ID=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2
export NETWORK_NAME=${CLUSTER_NAME}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw

gcloud compute networks create ${NETWORK_NAME} \
    --mtu=8896 \
    --project=${PROJECT_ID} \
    --subnet-mode=auto \
    --bgp-routing-mode=regional

gcloud compute firewall-rules create ${NETWORK_FW_NAME} \
    --network=${NETWORK_NAME} \
    --allow tcp,icmp,udp \
    --project=${PROJECT_ID}

export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}"

python3 xpk.py cluster create --cluster=${CLUSTER_NAME} \
    --cluster-cpu-machine-type=n1-standard-8 \
    --num-slices=${NUM_SLICES} \
    --tpu-type=${TPU_TYPE} \
    --zone=${ZONE} \
    --project=${PROJECT_ID} \
    --on-demand \
    --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \
    --create-vertex-tensorboard
Variable | Description |
CLUSTER_NAME | The user-assigned name for the XPK cluster. |
PROJECT_ID | Google Cloud project name. Use an existing project or create a new one. For more information, see Set up your Google Cloud project. |
ZONE | See the Cloud TPU regions and zones document for the supported zones. |
TPU_TYPE | See Accelerator Types. |
NUM_SLICES | The number of slices you want to create. |
CLUSTER_ARGUMENTS | The network and subnetwork to use. For example: --network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME} |
NETWORK_NAME | The name of a secondary network to use. |
NETWORK_FW_NAME | The name of a secondary network firewall to use. |
Create an XPK cluster with multi-NIC support
export CLUSTER_NAME=xpk-cluster-name
export ZONE=us-east1-d
export REGION=your-region
export PROJECT_ID=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2
export NETWORK_NAME_1=${CLUSTER_NAME}-mtu9k-1-${ZONE}
export SUBNET_NAME_1=${CLUSTER_NAME}-privatesubnet-1-${ZONE}
export NETWORK_FW_NAME_1=${NETWORK_NAME_1}-fw-1-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-1-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-1-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-1-${ZONE}
gcloud compute networks create ${NETWORK_NAME_1} \
    --mtu=8896 \
    --bgp-routing-mode=regional \
    --subnet-mode=custom \
    --project=${PROJECT_ID}

gcloud compute networks subnets create ${SUBNET_NAME_1} \
    --network=${NETWORK_NAME_1} \
    --range=10.11.0.0/18 \
    --region=${REGION} \
    --project=${PROJECT_ID}

gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
    --network=${NETWORK_NAME_1} \
    --allow tcp,icmp,udp \
    --project=${PROJECT_ID}

gcloud compute routers create ${ROUTER_NAME} \
    --project=${PROJECT_ID} \
    --network=${NETWORK_NAME_1} \
    --region=${REGION}

gcloud compute routers nats create ${NAT_CONFIG} \
    --router=${ROUTER_NAME} \
    --region=${REGION} \
    --auto-allocate-nat-external-ips \
    --nat-all-subnet-ip-ranges \
    --project=${PROJECT_ID} \
    --enable-logging
# Secondary subnet for multi-nic experience.
# Need custom IP routing to be different from the first network's subnet.
export NETWORK_NAME_2=${CLUSTER_NAME}-privatenetwork-2-${ZONE}
export SUBNET_NAME_2=${CLUSTER_NAME}-privatesubnet-2-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-2-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-2-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-2-${ZONE}
gcloud compute networks create ${NETWORK_NAME_2} \
    --mtu=8896 \
    --bgp-routing-mode=regional \
    --subnet-mode=custom \
    --project=${PROJECT_ID}

gcloud compute networks subnets create ${SUBNET_NAME_2} \
    --network=${NETWORK_NAME_2} \
    --range=10.10.0.0/18 \
    --region=${REGION} \
    --project=${PROJECT_ID}

gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
    --network=${NETWORK_NAME_2} \
    --allow tcp,icmp,udp \
    --project=${PROJECT_ID}

gcloud compute routers create ${ROUTER_NAME} \
    --project=${PROJECT_ID} \
    --network=${NETWORK_NAME_2} \
    --region=${REGION}

gcloud compute routers nats create ${NAT_CONFIG} \
    --router=${ROUTER_NAME} \
    --region=${REGION} \
    --auto-allocate-nat-external-ips \
    --nat-all-subnet-ip-ranges \
    --project=${PROJECT_ID} \
    --enable-logging
export CLUSTER_ARGUMENTS="--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}"

export NODE_POOL_ARGUMENTS="--additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}"
python3 ~/xpk/xpk.py cluster create \
--cluster=${CLUSTER_NAME} \
--num-slices=${NUM_SLICES} \
--tpu-type=${TPU_TYPE} \
--zone=${ZONE} \
--project=${PROJECT_ID} \
--on-demand \
--custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \
--custom-nodepool-arguments="${NODE_POOL_ARGUMENTS}" \
--create-vertex-tensorboard
Variable | Description |
CLUSTER_NAME | The user-assigned name for the XPK cluster. |
PROJECT_ID | Google Cloud project name. Use an existing project or create a new one. For more information, see Set up your Google Cloud project. |
ZONE | See the Cloud TPU regions and zones document for the supported zones. |
TPU_TYPE | See Accelerator Types. |
NUM_SLICES | The number of slices you want to create. |
CLUSTER_ARGUMENTS | The network and subnetwork to use. For example: --enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1} |
NODE_POOL_ARGUMENTS | Additional node network to use. For example: --additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2} |
NETWORK_NAME | The name of a secondary network to use. |
NETWORK_FW_NAME | The name of a secondary network firewall to use. |
Framework setup
This section describes the general setup process for ML model training using JAX, PyTorch, or TensorFlow frameworks. If you're using GKE, you can use XPK or Kubernetes commands for framework setup.
Setup for JAX
This section provides setup instructions for running JAX workloads on GKE, with or without XPK, as well as using queued resources.
Set up JAX using GKE
Single slice on single host
The following example sets up a 2x2 single-host node pool using a Kubernetes YAML file.
apiVersion: v1
kind: Pod
metadata:
name: tpu-pod-jax-v6e-a
spec:
restartPolicy: Never
nodeSelector:
cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
cloud.google.com/gke-tpu-topology: 2x2
containers:
- name: tpu-job
image: python:3.10
securityContext:
privileged: true
command:
- bash
- -c
- |
pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python3 -c 'import jax; print("Total TPU chips:", jax.device_count())'
resources:
requests:
google.com/tpu: 4
limits:
google.com/tpu: 4
Upon successful completion, you should see the following message in the GKE log:
Total TPU chips: 4
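As a usage sketch, assuming you save the manifest above as tpu-pod-jax-v6e-a.yaml (a hypothetical filename), you can apply it and read the Pod log with kubectl:

kubectl apply -f tpu-pod-jax-v6e-a.yaml    # hypothetical filename for the manifest above
kubectl logs tpu-pod-jax-v6e-a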
Single slice on multi-host
The following example sets up a 4x4 multi-host node pool using a Kubernetes YAML file.
apiVersion: v1
kind: Service
metadata:
name: headless-svc
spec:
clusterIP: None
selector:
job-name: tpu-available-chips
---
apiVersion: batch/v1
kind: Job
metadata:
name: tpu-available-chips
spec:
backoffLimit: 0
completions: 4
parallelism: 4
completionMode: Indexed
template:
spec:
subdomain: headless-svc
restartPolicy: Never
nodeSelector:
cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
cloud.google.com/gke-tpu-topology: 4x4
containers:
- name: tpu-job
image: python:3.10
ports:
- containerPort: 8471 # Default port using which TPU VMs communicate
- containerPort: 8431 # Port to export TPU runtime metrics, if supported.
securityContext:
privileged: true
command:
- bash
- -c
- |
pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
resources:
requests:
google.com/tpu: 4
limits:
google.com/tpu: 4
Upon successful completion, you should see the following message in the GKE log:
Total TPU chips: 16
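As with the single-host example, you can apply the manifest and inspect the Job with kubectl. The filename below is a hypothetical placeholder:

kubectl apply -f tpu-available-chips.yaml       # hypothetical filename for the Job manifest above
kubectl get job tpu-available-chips             # wait for 4/4 completions
kubectl logs -l job-name=tpu-available-chips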
Multislice on multi-host
The following example sets up two 4x4 multi-host node pools using a Kubernetes YAML file.
As a prerequisite, you need to install JobSet v0.2.3 or later.
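JobSet is typically installed by applying its released manifests with kubectl. The following is a hedged sketch; verify the exact version and asset URL against the JobSet releases page:

JOBSET_VERSION=v0.2.3   # example only; use v0.2.3 or later
kubectl apply --server-side -f https://github.com/kubernetes-sigs/jobset/releases/download/${JOBSET_VERSION}/manifests.yaml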
apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
name: multislice-job
annotations:
alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
spec:
failurePolicy:
maxRestarts: 4
replicatedJobs:
- name: slice
replicas: 2
template:
spec:
parallelism: 4
completions: 4
backoffLimit: 0
template:
spec:
hostNetwork: true
dnsPolicy: ClusterFirstWithHostNet
nodeSelector:
cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
cloud.google.com/gke-tpu-topology: 4x4
containers:
- name: jax-tpu
image: python:3.10
ports:
- containerPort: 8471
- containerPort: 8080
- containerPort: 8431
securityContext:
privileged: true
command:
- bash
- -c
- |
pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
resources:
limits:
google.com/tpu: 4
requests:
google.com/tpu: 4
Upon successful completion, you should see the following message in the GKE log:
Total TPU chips: 32
For more information, see Run a Multislice workload in the GKE documentation.
For better performance, enable hostNetwork.
Multi-NIC
To take advantage of multi-NIC in GKE, the Kubernetes Pod manifest needs to have additional annotations. The following is a non-TPU multi-NIC workload example manifest.
apiVersion: v1
kind: Pod
metadata:
name: sample-netdevice-pod-1
annotations:
networking.gke.io/default-interface: 'eth0'
networking.gke.io/interfaces: |
[
{"interfaceName":"eth0","network":"default"},
{"interfaceName":"eth1","network":"netdevice-network"}
]
spec:
containers:
- name: sample-netdevice-pod
image: busybox
command: ["sleep", "infinity"]
ports:
- containerPort: 80
restartPolicy: Always
tolerations:
- key: "google.com/tpu"
operator: "Exists"
effect: "NoSchedule"
If you use the exec
command to connect to the Kubernetes Pod, you can see
the additional NIC by using the following commands.
$ k exec --stdin --tty sample-netdevice-pod-1 -- /bin/sh
/ # ip a
1: lo: <LOOPBACK,UP,LOWER_UP> mtu 65536 qdisc noqueue qlen 1000
link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00
inet 127.0.0.1/8 scope host lo
valid_lft forever preferred_lft forever
2: eth0@if11: <BROADCAST,MULTICAST,UP,LOWER_UP,M-DOWN> mtu 1460 qdisc noqueue
link/ether da:be:12:67:d2:25 brd ff:ff:ff:ff:ff:ff
inet 10.124.2.6/24 brd 10.124.2.255 scope global eth0
valid_lft forever preferred_lft forever
3: eth1: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1460 qdisc mq qlen 1000
link/ether 42:01:ac:18:00:04 brd ff:ff:ff:ff:ff:ff
inet 172.24.0.4/32 scope global eth1
valid_lft forever preferred_lft forever
Set up JAX using GKE with XPK
To set up JAX using GKE and XPK, see the xpk README.
To set up and run XPK with MaxText, see How to run MaxText.
Set up JAX using queued resources
Install JAX on all Cloud TPU VMs in your slice or slices simultaneously using the
gcloud alpha compute tpus tpu-vm ssh
command. For Multislice, add the --node=all
flag.
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all \
--command='pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
You can run the following command to check how many Cloud TPU cores are available in your slice and to test that everything is installed correctly:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all \
--command='python3 -c "import jax; print(jax.device_count(), jax.local_device_count())"'
The output is similar to the following when running on a v6e-16 slice:
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16 4
16 4
16 4
16 4
jax.device_count()
shows the total number of chips in the given slice.
jax.local_device_count()
indicates the count of chips accessible by a single
VM in this slice.
For example, the following command installs MaxDiffusion and its dependencies on all workers in your queued resource:

gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
    --project=${PROJECT_ID} --zone=${ZONE} --worker=all \
    --command='git clone -b mlperf4.1 https://github.com/google/maxdiffusion.git &&
    cd maxdiffusion && git checkout 975fdb7dbddaa9a53ad72a421cdb487dcdc491a3 &&
    pip install -r requirements.txt && pip install .'
Troubleshoot JAX setup
A general tip is to enable verbose logging in your GKE workload manifest. Then, provide the logs to GKE support.
TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0
Error messages
no endpoints available for service 'jobset-webhook-service'
This error means the JobSet wasn't installed properly. Check whether the jobset-controller-manager deployment Pods are running. For more information, see the JobSet troubleshooting documentation.
TPU initialization failed: Failed to connect
Make sure your GKE node version is 1.30.4-gke.1348000 or later (GKE 1.31 is not supported).
Setup for PyTorch
This section describes how to start using PJRT on v6e with PyTorch/XLA. Python 3.10 is the recommended Python version.
Set up PyTorch using GKE with XPK
You can use the following Docker container with XPK which has PyTorch dependencies already installed:
us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028
To create a XPK workload, use the following command:
python3 xpk.py workload create \
--cluster ${CLUSTER_NAME} \
[--docker-image | --base-docker-image] us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028 \
--workload ${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
--tpu-type=${ACCELERATOR_TYPE} \
--num-slices=${NUM_SLICES} \
--on-demand \
--zone ${ZONE} \
--project ${PROJECT_ID} \
--enable-debug-logs \
--command 'python3 -c "import torch; import torch_xla; import torch_xla.runtime as xr; print(xr.global_runtime_device_count())"'
Using the --base-docker-image
flag creates a new Docker image with the current working
directory built into the new image.
Set up PyTorch using queued resources
Follow these steps to install PyTorch using queued resources and run a small script on v6e.
Install dependencies using SSH to access the VMs
Use the following command to install dependencies on all Cloud TPU VMs. For
Multislice, add the --worker=all
flag:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='sudo apt install -y libopenblas-base
    pip3 install --pre torch==2.6.0.dev20241028+cpu torchvision==0.20.0.dev20241028+cpu \
        --index-url https://download.pytorch.org/whl/nightly/cpu
    pip install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241028-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html
    pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'
Improve performance of models with sizable, frequent allocations
For models that make sizable, frequent allocations, tcmalloc improves performance significantly compared to the standard malloc implementation, so tcmalloc is the default malloc used on Cloud TPU VMs. However, depending on your workload (for example, DLRM, which makes very large allocations for its embedding tables), tcmalloc can cause a slowdown. In that case, you can unset the following variable to use the standard malloc instead:
unset LD_PRELOAD
Use a Python script to do a calculation on v6e VM
Use the following command to run a script that creates two tensors, adds them together, and prints the result.
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
--project ${PROJECT_ID} \
--zone ${ZONE} --worker all --command='
unset LD_PRELOAD
python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"
'
This generates output similar to the following:
SSH: Attempting to connect to worker 0...
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
xla:0
tensor([[ 0.3355, -1.4628, -3.2610],
[-1.4656, 0.3196, -2.8766],
[ 0.8668, -1.5060, 0.7125]], device='xla:0')
Setup for TensorFlow
You can reset the Cloud TPU runtime with the v6e-compatible TensorFlow version by running the following commands:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all --command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all --command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'
Use SSH to access worker-0:
$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE}
Install TensorFlow on worker-0:
sudo apt install -y libopenblas-base
pip install cloud-tpu-client
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force
Export the TPU_NAME
environment variable:
export TPU_NAME=v6e-16
You can run the following Python script to check how many Cloud TPU cores are available in your slice and to test that everything is installed correctly:
import tensorflow as tf
print("TensorFlow version " + tf.__version__)
@tf.function
def add_fn(x,y):
z = x + y
return z
cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(cluster_resolver)
tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
strategy = tf.distribute.TPUStrategy(cluster_resolver)
x = tf.constant(1.)
y = tf.constant(1.)
z = strategy.run(add_fn, args=(x,y))
print(z)
The output is similar to the following when running on a v6e-16 slice:
PerReplica:{
0: tf.Tensor(2.0, shape=(), dtype=float32),
1: tf.Tensor(2.0, shape=(), dtype=float32),
2: tf.Tensor(2.0, shape=(), dtype=float32),
3: tf.Tensor(2.0, shape=(), dtype=float32),
4: tf.Tensor(2.0, shape=(), dtype=float32),
5: tf.Tensor(2.0, shape=(), dtype=float32),
6: tf.Tensor(2.0, shape=(), dtype=float32),
7: tf.Tensor(2.0, shape=(), dtype=float32)
}
v6e with SkyPilot
You can use Cloud TPU v6e with SkyPilot. Use the following steps to add v6e-related location and pricing information to SkyPilot.
Add the following to the end of the ~/.sky/catalogs/v5/gcp/vms.csv file:

,,,tpu-v6e-1,1,tpu-v6e-1,us-south1,us-south1-a,0,0
,,,tpu-v6e-1,1,tpu-v6e-1,europe-west4,europe-west4-a,0,0
,,,tpu-v6e-1,1,tpu-v6e-1,us-east5,us-east5-b,0,0
,,,tpu-v6e-4,1,tpu-v6e-4,us-south1,us-south1-a,0,0
,,,tpu-v6e-4,1,tpu-v6e-4,europe-west4,europe-west4-a,0,0
,,,tpu-v6e-4,1,tpu-v6e-4,us-east5,us-east5-b,0,0
,,,tpu-v6e-8,1,tpu-v6e-8,us-south1,us-south1-a,0,0
,,,tpu-v6e-8,1,tpu-v6e-8,europe-west4,europe-west4-a,0,0
,,,tpu-v6e-8,1,tpu-v6e-8,us-east5,us-east5-b,0,0
,,,tpu-v6e-16,1,tpu-v6e-16,us-south1,us-south1-a,0,0
,,,tpu-v6e-16,1,tpu-v6e-16,europe-west4,europe-west4-a,0,0
,,,tpu-v6e-16,1,tpu-v6e-16,us-east5,us-east5-b,0,0
,,,tpu-v6e-32,1,tpu-v6e-32,us-south1,us-south1-a,0,0
,,,tpu-v6e-32,1,tpu-v6e-32,europe-west4,europe-west4-a,0,0
,,,tpu-v6e-32,1,tpu-v6e-32,us-east5,us-east5-b,0,0
,,,tpu-v6e-64,1,tpu-v6e-64,us-south1,us-south1-a,0,0
,,,tpu-v6e-64,1,tpu-v6e-64,europe-west4,europe-west4-a,0,0
,,,tpu-v6e-64,1,tpu-v6e-64,us-east5,us-east5-b,0,0
,,,tpu-v6e-128,1,tpu-v6e-128,us-south1,us-south1-a,0,0
,,,tpu-v6e-128,1,tpu-v6e-128,europe-west4,europe-west4-a,0,0
,,,tpu-v6e-128,1,tpu-v6e-128,us-east5,us-east5-b,0,0
,,,tpu-v6e-256,1,tpu-v6e-256,us-south1,us-south1-a,0,0
,,,tpu-v6e-256,1,tpu-v6e-256,europe-west4,europe-west4-a,0,0
,,,tpu-v6e-256,1,tpu-v6e-256,us-east5,us-east5-b,0,0
Specify the following resources in a YAML file:
# tpu_v6.yaml
resources:
  accelerators: tpu-v6e-16 # Fill in the accelerator type you want to use
  accelerator_args:
    runtime_version: v2-alpha-tpuv6e # Official suggested runtime
Launch a cluster with Cloud TPU v6e:
sky launch tpu_v6.yaml -c tpu_v6
Connect to the Cloud TPU v6e using SSH:
ssh tpu_v6
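When you're finished, you can check the cluster and tear it down with standard SkyPilot commands:

sky status
sky down tpu_v6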
Inference tutorials
The following tutorials show how to run inference on Cloud TPU v6e:
Training examples
The following sections provide examples for training MaxText, MaxDiffusion, and PyTorch models on Cloud TPU v6e.
MaxText and MaxDiffusion training on v6e Cloud TPU VM
The following sections cover the training lifecycle of the MaxText and MaxDiffusion models.
In general, the high-level steps are:
- Build the workload base image.
- Run your workload using XPK.
- Build the training command for the workload.
- Deploy the workload.
- Follow the workload and view metrics.
- Delete the XPK workload if it isn't needed.
- Delete the XPK cluster when it's no longer needed.
Build base image
Install MaxText or MaxDiffusion and build the Docker image:
Clone the repository you want to use and change to the directory for the repository:
MaxText:
git clone https://github.com/google/maxtext.git && cd maxtext
MaxDiffusion:
git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion
Configure Docker to use the Google Cloud CLI:
gcloud auth configure-docker
Build the Docker image using the following command or using JAX Stable Stack. For more information about JAX Stable Stack, see Build Docker image with JAX Stable Stack.
bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.35
If you're launching the workload from a machine that doesn't have the image built locally, upload the image:
bash docker_upload_runner.sh CLOUD_IMAGE_NAME=${USER}_runner
Build a Docker image with JAX Stable Stack
You can build the MaxText and MaxDiffusion Docker images using the JAX Stable Stack base image.
JAX Stable Stack provides a consistent environment for MaxText and MaxDiffusion by bundling JAX with core packages like orbax, flax, and optax, along with a well-qualified libtpu.so that drives Cloud TPU program utilities and other essential tools. These libraries are tested to ensure compatibility and provide a stable foundation to build and run MaxText and MaxDiffusion, eliminating potential conflicts due to incompatible package versions.
JAX Stable Stack includes a fully released and qualified libtpu.so, the core library that drives Cloud TPU program compilation, execution, and ICI network configuration. The libtpu release replaces the nightly build previously used by JAX, and ensures consistent functionality of XLA computations on Cloud TPU with PJRT-level qualification tests in HLO/StableHLO IRs.
To build the MaxText and MaxDiffusion Docker image with JAX Stable Stack, when you run the docker_build_dependency_image.sh script, set the MODE variable to stable_stack and set the BASEIMAGE variable to the base image you want to use.
The following example specifies us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1 as the base image:
bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1
For a list of available JAX Stable Stack base images, see JAX Stable Stack images in Artifact Registry.
Run your workload using XPK
Set the following environment variables if you're not using the default values set by MaxText or MaxDiffusion:
export BASE_OUTPUT_DIR=gs://YOUR_BUCKET
export PER_DEVICE_BATCH_SIZE=2
export NUM_STEPS=30
export MAX_TARGET_LENGTH=8192
Build your model script. This script will be copied as a training command in a later step.
Don't execute the model script yet.
MaxText
MaxText is a high performance, highly scalable, open-source LLM written in pure Python and JAX and targeting Google Cloud TPUs and GPUs for training and inference.
JAX_PLATFORMS=tpu,cpu \
ENABLE_PJRT_COMPATIBILITY=true \
TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \
TPU_SLICE_BUILDER_DUMP_ICI=true && \
python /deps/MaxText/train.py /deps/MaxText/configs/base.yml \
    base_output_directory=${BASE_OUTPUT_DIR} \
    dataset_type=synthetic \
    per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
    enable_checkpointing=false \
    gcs_metrics=true \
    profiler=xplane \
    skip_first_n_steps_for_profiler=5 \
    steps=${NUM_STEPS}  # attention='dot_product'
Gemma2
Gemma is a family of open-weights LLMs developed by Google DeepMind, based on Gemini research and technology.
python3 MaxText/train.py MaxText/configs/base.yml \
    model_name=gemma2-27b \
    run_name=gemma2-27b-run \
    base_output_directory=${BASE_OUTPUT_DIR} \
    max_target_length=${MAX_TARGET_LENGTH} \
    per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
    steps=${NUM_STEPS} \
    enable_checkpointing=false \
    use_iota_embed=true \
    gcs_metrics=true \
    dataset_type=synthetic \
    profiler=xplane \
    attention=flash
Mixtral 8x7b
Mixtral is a state-of-the-art AI model developed by Mistral AI, utilizing a sparse mixture-of-experts (MoE) architecture.
python3 MaxText/train.py MaxText/configs/base.yml \
    base_output_directory=${BASE_OUTPUT_DIR} \
    per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
    model_name=mixtral-8x7b \
    steps=${NUM_STEPS} \
    max_target_length=${MAX_TARGET_LENGTH} \
    tokenizer_path=assets/tokenizer.mistral-v1 \
    attention=flash \
    dtype=bfloat16 \
    dataset_type=synthetic \
    profiler=xplane
Llama3-8b
Llama is a family of open-weights LLMs developed by Meta.
# Set PER_DEVICE_BATCH_SIZE to 4 for this model.
python3 MaxText/train.py MaxText/configs/base.yml \
    model_name=llama3-8b \
    base_output_directory=${BASE_OUTPUT_DIR} \
    dataset_type=synthetic \
    tokenizer_path=assets/tokenizer_llama3.tiktoken \
    per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
    gcs_metrics=true \
    profiler=xplane \
    skip_first_n_steps_for_profiler=5 \
    steps=${NUM_STEPS} \
    max_target_length=${MAX_TARGET_LENGTH} \
    attention=flash
MaxDiffusion
MaxDiffusion is a collection of reference implementations of various latent diffusion models written in pure Python and JAX that run on XLA devices including Cloud TPUs and GPUs. Stable Diffusion is a latent text-to-image model that generates photo-realistic images from any text input.
You need to check out a specific Git commit to run MaxDiffusion, as shown in the following command:

git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion && git checkout e712c9fc4cca764b0930067b6e33daae2433abf0 && pip install -r requirements.txt && pip install .
Training script:
cd maxdiffusion && OUT_DIR=${BASE_OUTPUT_DIR} \
python src/maxdiffusion/train_sdxl.py \
    src/maxdiffusion/configs/base_xl.yml \
    revision=refs/pr/95 \
    activations_dtype=bfloat16 \
    weights_dtype=bfloat16 \
    resolution=1024 \
    per_device_batch_size=1 \
    output_dir=${OUT_DIR} \
    jax_cache_dir=${OUT_DIR}/cache_dir/ \
    max_train_steps=200 \
    attention=flash \
    run_name=sdxl-ddp-v6e
Run the model using the script you created in the previous step. You must either specify the --base-docker-image flag to use the MaxText base image or specify the --docker-image flag and the image you want to use.

Optional: You can enable debug logging by including the --enable-debug-logs flag. For more information, see Debug JAX on MaxText.

Optional: You can create a Vertex AI Experiment to upload data to Vertex AI TensorBoard by including the --use-vertex-tensorboard flag. For more information, see Monitor JAX on MaxText using Vertex AI.

Export the following variables:

export CLUSTER_NAME=CLUSTER_NAME
export ACCELERATOR_TYPE=ACCELERATOR_TYPE
export NUM_SLICES=NUM_SLICES
export YOUR_MODEL_SCRIPT=YOUR_MODEL_SCRIPT

Then create the workload:

python3 xpk.py workload create \
    --cluster ${CLUSTER_NAME} \
    {--base-docker-image maxtext_base_image | --docker-image ${CLOUD_IMAGE_NAME}} \
    --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
    --tpu-type=${ACCELERATOR_TYPE} \
    --num-slices=${NUM_SLICES} \
    --on-demand \
    --zone=${ZONE} \
    --project=${PROJECT_ID} \
    [--enable-debug-logs] \
    [--use-vertex-tensorboard] \
    --command="${YOUR_MODEL_SCRIPT}"

Command flag descriptions

Variable | Description |
CLUSTER_NAME | The name of your XPK cluster. |
ACCELERATOR_TYPE | The version and size of your TPU. For example, v6e-256. |
NUM_SLICES | The number of Cloud TPU slices. |
YOUR_MODEL_SCRIPT | The model script to execute as a training command. |

The output includes a link to follow your workload. Open the link and click the Logs tab to track your workload in real time.
Debug JAX on MaxText
Use supplemental XPK commands to diagnose why the cluster or workload isn't running.
- Use the xpk workload list command to list workloads and their status.
- Use the xpk inspector command to collect cluster diagnostics (see the sketch after this list).
- Enable verbose logging in your workload logs using the --enable-debug-logs flag when you create the XPK workload.
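A minimal sketch of the first two commands, assuming the cluster name exported earlier in this document:

# List workloads on the cluster and their current status.
python3 xpk.py workload list --cluster ${CLUSTER_NAME}

# Collect diagnostic information about the cluster.
python3 xpk.py inspector --cluster ${CLUSTER_NAME}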
Monitor JAX on MaxText using Vertex AI
View scalar and profile data through Vertex AI's managed TensorBoard.
- Increase resource management (CRUD) requests for the zone you're using from 600 to 5,000. This might not be an issue for small workloads using less than 16 VMs.
- Install dependencies such as cloud-accelerator-diagnostics for Vertex AI:

  # xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI
  cd ~/xpk
  pip install .

- Create your XPK cluster using the --create-vertex-tensorboard flag, as documented in Create Vertex AI TensorBoard. You can also run this command on existing clusters.
- Create your Vertex AI experiment when running your XPK workload using the --use-vertex-tensorboard flag and the optional --experiment-name flag. For the full list of steps, see Create Vertex AI Experiment to upload data to Vertex AI TensorBoard.
The logs include a link to a Vertex AI TensorBoard, similar to the following:
View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name
You can also find the Vertex AI TensorBoard link in the Google Cloud console. Go to Vertex AI Experiments in the Google Cloud console. Select the appropriate region from the drop-down.
The TensorBoard directory is also written to the Cloud Storage bucket that you specified with ${BASE_OUTPUT_DIR}.
Delete XPK workloads
Use the xpk workload delete
command
to delete one or more workloads based on the job prefix or job status. This
command might be useful if you sent XPK workloads that no longer need to be run,
or if you have jobs that are stuck in the queue.
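For example, a hedged sketch that deletes by name the single workload created earlier; check the XPK documentation for the prefix- and status-based filters:

python3 xpk.py workload delete \
    --workload ${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
    --cluster ${CLUSTER_NAME}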
Delete XPK cluster
Use the xpk cluster delete
command to delete a cluster:
python3 xpk.py cluster delete --cluster ${CLUSTER_NAME} \
    --zone=${ZONE} --project=${PROJECT_ID}
Llama and PyTorch/XLA training on v6e Cloud TPU VM
This tutorial describes how to train Llama models using PyTorch/XLA on Cloud TPU v6e using the WikiText dataset.
Get access to Hugging Face and the Llama 3 model
You need a Hugging Face user access token to run this tutorial. For information about creating user access tokens, see the Hugging Face documentation on user access tokens.
You also need permission to access the Llama 3 8B model on Hugging Face. To get access, go to the Meta-Llama-3-8B model on HuggingFace and request access.
Create a Cloud TPU VM
Create a Cloud TPU v6e with 8 chips to run the tutorial.
Set up environment variables:
export ACCELERATOR_TYPE=v6e-8
export VERSION=v2-alpha-tpuv6e
export TPU_NAME=$USER-$ACCELERATOR_TYPE
export PROJECT_ID=your-project-id
export ZONE=us-east1-d
Create a Cloud TPU VM:
gcloud alpha compute tpus tpu-vm create ${TPU_NAME} --version=${VERSION} \
    --accelerator-type=${ACCELERATOR_TYPE} \
    --zone=${ZONE} \
    --project=${PROJECT_ID}
Installation
Install the pytorch-tpu/transformers fork of Hugging Face transformers and its dependencies. This tutorial was tested with the following dependency versions:

- torch: compatible with 2.5.0
- torch_xla[tpu]: compatible with 2.5.0
- jax: 0.4.33
- jaxlib: 0.4.33
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} --zone ${ZONE} \
    --worker=all --command='
    git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
    cd transformers
    sudo pip3 install -e .
    pip3 install datasets
    pip3 install evaluate
    pip3 install scikit-learn
    pip3 install accelerate
    pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
    pip install jax==0.4.33 jaxlib==0.4.33 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'
Set up model configs
The training command in the next section, Run the model, uses two JSON config files to define model parameters and FSDP (Fully Sharded Data Parallel) configuration. FSDP sharding is used for the model weights to fit a bigger batch size while training. When training with smaller models, it might be sufficient to use data parallelism and replicate the weights on each device. For more information about how to shard tensors across devices in PyTorch/XLA, see PyTorch/XLA SPMD User Guide.
Create the model parameter config file. The following is the model parameter config for Llama3-8B. For other models, find the config on Hugging Face. For example, see the Llama2-7B config.
cat > llama-config.json << EOF
{
  "architectures": ["LlamaForCausalLM"],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128001,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 8192,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 500000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.40.0.dev0",
  "use_cache": false,
  "vocab_size": 128256
}
EOF
Create the FSDP config file:
cat > fsdp-config.json << EOF
{
  "fsdp_transformer_layer_cls_to_wrap": ["LlamaDecoderLayer"],
  "xla": true,
  "xla_fsdp_v2": true,
  "xla_fsdp_grad_ckpt": true
}
EOF
For more information about FSDP, see FSDPv2.
Upload the config files to your Cloud TPU VMs using the following command:
gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json ${TPU_NAME}:. \
    --worker=all \
    --project=${PROJECT_ID} \
    --zone=${ZONE}
Run the model
Using the config files you created in the previous section, run the run_clm.py
script to train the Llama 3 8B model on the WikiText dataset. The training
script takes approximately 10 minutes to run on a Cloud TPU v6e-8.
Sign in to Hugging Face on your Cloud TPU using the following command:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} \
    --zone ${ZONE} \
    --worker=all \
    --command='
    pip3 install "huggingface_hub[cli]"
    huggingface-cli login --token HUGGING_FACE_TOKEN'
Run the model training:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} \
    --zone ${ZONE} \
    --worker=all \
    --command='
    export PJRT_DEVICE=TPU
    export XLA_USE_SPMD=1
    export ENABLE_PJRT_COMPATIBILITY=true
    # Optional variables for debugging:
    export XLA_IR_DEBUG=1
    export XLA_HLO_DEBUG=1
    export PROFILE_EPOCH=0
    export PROFILE_STEP=3
    export PROFILE_DURATION_MS=100000
    # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path
    export PROFILE_LOGDIR=PROFILE_PATH
    python3 transformers/examples/pytorch/language-modeling/run_clm.py \
        --dataset_name wikitext \
        --dataset_config_name wikitext-2-raw-v1 \
        --per_device_train_batch_size 16 \
        --do_train \
        --output_dir /home/$USER/tmp/test-clm \
        --overwrite_output_dir \
        --config_name /home/$USER/llama-config.json \
        --cache_dir /home/$USER/cache \
        --tokenizer_name meta-llama/Meta-Llama-3-8B \
        --block_size 8192 \
        --optim adafactor \
        --save_strategy no \
        --logging_strategy no \
        --fsdp "full_shard" \
        --fsdp_config /home/$USER/fsdp-config.json \
        --torch_dtype bfloat16 \
        --dataloader_drop_last yes \
        --flash_attention \
        --max_steps 20'
Troubleshooting PyTorch/XLA
If you set the optional variables for debugging in the previous section, the profile for the model is stored at the location specified by the PROFILE_LOGDIR variable. You can extract the xplane.pb file stored at this location and use tensorboard to view the profiles in your browser by following the TensorBoard instructions.

If PyTorch/XLA isn't performing as expected, see the troubleshooting guide, which has suggestions for debugging, profiling, and optimizing your model.
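For example, a minimal sketch of viewing a captured profile locally, assuming PROFILE_LOGDIR points to a directory accessible from the machine where you run TensorBoard:

pip install tensorboard tensorboard-plugin-profile
tensorboard --logdir ${PROFILE_LOGDIR} --port 6006
# Then open http://localhost:6006 and select the Profile tab.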
DLRM DCN v2 training on v6e
This tutorial shows you how to train the DLRM DCN v2 model on Cloud TPU v6e. You need to provision a TPU v6e with 64, 128, or 256 chips.
If you are running on a multi-host TPU, reset tpu-runtime with the appropriate TensorFlow version by running the following commands. If you are running on a single-host TPU, you don't need to run the following two commands.

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
    --zone ${ZONE} --worker=all \
    --command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} \
--worker=all \
--command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'
Connect to worker-0 using SSH
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone ${ZONE} --project ${PROJECT_ID}
Set the Cloud TPU name
export TPU_NAME=your-tpu-name
Run DLRM v2
Copy the following code snippet in a file named script.sh
:
pip install --user setuptools==65.5.0
pip install cloud-tpu-client
pip install gin-config && pip install tensorflow-datasets && pip install tf-keras-nightly --no-deps
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl -f https://storage.googleapis.com/libtpu-tf-releases/index.html --force
git clone https://github.com/tensorflow/recommenders.git
git clone https://github.com/tensorflow/models.git
export PYTHONPATH=~/recommenders/:~/models/
export TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true --tf_xla_sparse_core_disable_table_stacking=true --tf_mlir_enable_convert_control_to_data_outputs_pass=true --tf_mlir_enable_merge_control_flow_pass=true'
TF_USE_LEGACY_KERAS=1 TPU_LOAD_LIBRARY=0 python3 ./models/official/recommendation/ranking/train.py --mode=train --model_dir=gs://ptxla-debug/tf/sc/dlrm/runs/2/ --params_override="
runtime:
distribution_strategy: tpu
mixed_precision_dtype: 'mixed_bfloat16'
task:
use_synthetic_data: false
use_tf_record_reader: true
train_data:
input_path: 'gs://trillium-datasets/criteo/train/day_*/*'
global_batch_size: 16384
use_cached_data: true
validation_data:
input_path: 'gs://trillium-datasets/criteo/eval/day_*/*'
global_batch_size: 16384
use_cached_data: true
model:
num_dense_features: 13
bottom_mlp: [512, 256, 128]
embedding_dim: 128
interaction: 'multi_layer_dcn'
dcn_num_layers: 3
dcn_low_rank_dim: 512
size_threshold: 8000
top_mlp: [1024, 1024, 512, 256, 1]
use_multi_hot: true
concat_dense: false
dcn_use_bias: true
vocab_sizes: [40000000,39060,17295,7424,20265,3,7122,1543,63,40000000,3067956,405282,10,2209,11938,155,4,976,14,40000000,40000000,40000000,590152,12973,108,36]
multi_hot_sizes: [3,2,1,2,6,1,1,1,1,7,3,8,1,6,9,5,1,1,1,12,100,27,10,3,1,1]
max_ids_per_chip_per_sample: 128
max_ids_per_table: [280, 128, 64, 272, 432, 624, 64, 104, 368, 352, 288, 328, 304, 576, 336, 368, 312, 392, 408, 552, 2880, 1248, 720, 112, 320, 256]
max_unique_ids_per_table: [104, 56, 40, 32, 72, 32, 40, 32, 32, 144, 64, 192, 32, 40, 136, 32, 32, 32, 32, 240, 1352, 432, 120, 80, 32, 32]
use_partial_tpu_embedding: false
size_threshold: 0
initialize_tables_on_host: true
trainer:
train_steps: 10000
validation_interval: 1000
validation_steps: 660
summary_interval: 1000
steps_per_loop: 1000
checkpoint_interval: 0
optimizer_config:
embedding_optimizer: 'Adagrad'
dense_optimizer: 'Adagrad'
lr_config:
decay_exp: 2
decay_start_steps: 70000
decay_steps: 30000
learning_rate: 0.025
warmup_steps: 0
dense_sgd_config:
decay_exp: 2
decay_start_steps: 70000
decay_steps: 30000
learning_rate: 0.00025
warmup_steps: 8000
train_tf_function: true
train_tf_while_loop: true
eval_tf_while_loop: true
use_orbit: true
pipeline_sparse_and_dense_execution: true"
If you're running TensorFlow on GKE, install the TensorFlow Cloud TPU wheel and libtpu using the following command:
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force
Set the following flags, which are necessary to run recommendation workloads (such as DLRM DCN):
ENV TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true \
--tf_mlir_enable_tpu_variable_runtime_reformatting_pass=false \
--tf_mlir_enable_convert_control_to_data_outputs_pass=true \
--tf_mlir_enable_merge_control_flow_pass=true --tf_xla_disable_full_embedding_pipelining=true' \
ENV LIBTPU_INIT_ARGS="--xla_sc_splitting_along_feature_dimension=auto \
--copy_with_dynamic_shape_op_output_pjrt_buffer=true"
Run script.sh
:
chmod +x script.sh
./script.sh
Benchmarking results
The following section contains benchmarking results for DLRM DCN v2 and MaxDiffusion on v6e.
DLRM DCN v2
The DLRM DCN v2 training script was run at different scales. See the throughputs in the following table.
 | v6e-64 | v6e-128 | v6e-256 |
---|---|---|---|
Training steps | 7000 | 7000 | 7000 |
Global batch size | 131072 | 262144 | 524288 |
Throughput (examples/sec) | 2975334 | 5111808 | 10066329 |
MaxDiffusion
We ran the training script for MaxDiffusion on a v6e-4, a v6e-16, and two v6e-16 slices. See the throughputs in the following table.
 | v6e-4 | v6e-16 | Two v6e-16 |
---|---|---|---|
Training step time (seconds) | 0.069 | 0.073 | 0.13 |
Global batch size | 8 | 32 | 64 |
Throughput (examples/sec) | 115.9 | 438.4 | 492.3 |