Cloud TPU Multislice Overview
Cloud TPU Multislice is a full-stack performance-scaling technology that enables a training job to use multiple TPU slices within a single Pod or on slices in multiple Pods with simple data parallelism. With TPU v4 chips, this means training jobs can use more than 4096 chips in a single run. For training jobs that require fewer than 4096 chips, a single slice can offer the best performance. However, multiple smaller slices are more readily available, allowing for a faster startup time when Multislice is used with smaller slices.
When deployed in Multislice configurations, TPU chips in each slice communicate through inter-chip-interconnect (ICI). TPU chips in different slices communicate by transferring data to CPUs (hosts) which in turn transmit the data over the data-center network (DCN).
Developers don't have to write code to implement inter-slice DCN communication. The XLA compiler generates that code for you and overlaps communication with computation for maximum performance.
Concepts
- Accelerator type
- The shape of each TPU slice that comprises a Multislice. Each slice in a Multislice request is of the same accelerator type. An accelerator type consists of a TPU type (v4 or v5e) followed by the number of TensorCores. For example, v4-128 specifies a TPU v4 with 128 TensorCores.
- Auto-repair
- When a slice encounters a maintenance event, preemption, or hardware failure, Cloud TPU will create a new slice. In the rare case when there are insufficient resources to create a new slice, the creation won't complete until hardware becomes available. After the new slice is created, all other slices in the Multislice environment will be restarted so training can continue. With a properly configured startup script, the training script can automatically relaunch without user intervention, loading and resuming from the latest checkpoint.
- Dataset
- The data that is used by a model for training or inference.
- Data Center Networking (DCN)
- A higher latency, lower-throughput network (when compared with ICI) that connects TPU slices in a Multislice configuration.
- Gang scheduling
- When all TPU slices are provisioned together, at the same time, guaranteeing either all or none of the slices are provisioned successfully.
- Host
- A host is a physical computer that runs VMs. A host can run at most four VMs at one time. Each VM has a dedicated TPU.
- Inference
- Load a pre-trained machine learning model onto a host and make predictions on data.
- Interchip interconnect (ICI)
- High speed, low latency internal links that connect TPUs within a TPU Pod.
- Multislice
- Two or more TPU chip slices that can communicate over DCN.
- Node
- In the Multislice context, node refers to a single TPU slice. Each TPU slice in a Multislice is given a node ID.
- Pod
- A collection of TPU chips connected by dedicated ICI network interfaces. A Pod lets you distribute the processing load across multiple TPUs.
- Queued resource (QR)
- A representation of TPU resources, used to enqueue and manage a request for a single-slice or Multislice TPU environment.
- Startup script
- A standard Compute Engine startup script that is run every time a VM is booted or rebooted. For Multislice, it is specified in the QR creation request. For more information about Cloud TPU startup scripts, see Manage TPU resources.
- TPU slice
- A logical subsection of a TPU Pod consisting of TPU chips. All chips in a slice communicate with each other using the ICI network.
- TPU VM
- A virtual machine running Linux that has access to the underlying TPUs. For v4 TPUs, each TPU VM has direct access to four chips. Sometimes we call a TPU VM a worker.
- Tensor
- A data structure that is used to represent multidimensional data in a machine learning model.
- Tensor processing unit (TPU)
- Google's internally developed ML acceleration chip. They are designed to offer fast and power-efficient compute for key machine learning tasks like matrix multiplication.
- Types of Cloud TPU capacity
TPUs can be created from different types of capacity (see Usage Options in How TPU pricing works):
- Reservation: Targets reserved quota. To use reserved quota you must have a reservation agreement with Google. Use the --reserved flag when creating your resources.
- Spot: Targets preemptible quota using Spot VMs. Your resources may be preempted to make room for requests for a higher-priority job. Use the --spot flag when creating your resources.
- On-demand: Targets on-demand quota, which doesn't need a reservation and won't be preempted. The TPU request is enqueued to an on-demand quota queue offered by Cloud TPU, and the availability of resources is not guaranteed. Selected by default, no flags needed.
Get started
If you have not used TPUs before, start by installing the Google Cloud CLI, and set up your Cloud TPU environment. To use Multislice, your TPU resources must be managed as Queued Resources.
If you are an existing TPU v4 user and have a reservation you may need to migrate your reservation to a new reservation system. For more information, contact your Google Cloud account representative.
Introductory example
This tutorial uses code from the MaxText GitHub repo. MaxText is a high performance, arbitrarily scalable, open source, and well-tested basic LLM written in Python and Jax. MaxText was designed to train efficiently on Cloud TPU.
The code in shardings.py
is designed to help you get started experimenting with different parallelism
options. For example, data parallelism, fully sharded data parallelism (FSDP),
and tensor parallelism. The code scales from single slice to Multislice
environments.
ICI parallelism
ICI refers to the high speed interconnect that connects the TPUs in a single
slice. ICI sharding corresponds to sharding within a slice. shardings.py
provides three ICI parallelism parameters:
ici_data_parallelism
ici_fsdp_parallelism
ici_tensor_parallelism
The values you specify for these parameters determine the number of shards for each parallelism method.
These inputs must be constrained so that
ici_data_parallelism * ici_fsdp_parallelism * ici_tensor_parallelism
is equal
to the number of chips in the slice.
The following table shows example user inputs for ICI parallelism for the four chips available in v4-8:
Strategy | ici_data_parallelism | ici_fsdp_parallelism | ici_tensor_parallelism |
4-way FSDP | 1 | 4 | 1 |
4-way Tensor parallelism | 1 | 1 | 4 |
2-way FSDP + 2-way Tensor parallelism | 1 | 2 | 2 |
Note that ici_data_parallelism
should be left as 1 in most cases because the
ICI network is fast enough to almost always prefer FSDP to data parallelism.
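The constraint on the three ICI parameters can be checked before a job is launched. The following is a minimal sketch in Python. The helper names valid_ici_shardings and check_ici_sharding are illustrative only and are not part of shardings.py. It lists every (data, FSDP, tensor) combination whose product equals the number of chips in the slice.

```python
from itertools import product


def valid_ici_shardings(num_chips):
    """Return every (data, fsdp, tensor) triple whose product is num_chips."""
    divisors = [d for d in range(1, num_chips + 1) if num_chips % d == 0]
    return [
        combo for combo in product(divisors, repeat=3)
        if combo[0] * combo[1] * combo[2] == num_chips
    ]


def check_ici_sharding(num_chips, data=1, fsdp=1, tensor=1):
    """Raise ValueError unless data * fsdp * tensor equals num_chips."""
    shards = data * fsdp * tensor
    if shards != num_chips:
        raise ValueError(
            f"{data} * {fsdp} * {tensor} = {shards}, expected {num_chips} chips"
        )


# A v4-8 slice has four chips, so 2-way FSDP + 2-way tensor parallelism is valid.
check_ici_sharding(4, data=1, fsdp=2, tensor=2)
for combo in valid_ici_shardings(4):
    print(combo)
```

The three rows of the table appear in the printed list as (1, 4, 1), (1, 1, 4), and (1, 2, 2).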
This example assumes you are familiar with running code on a single TPU slice
such as in Run a calculation on a Cloud TPU VM using JAX.
This example shows how to run shardings.py
on a single slice.
Set up the environment:
$ gcloud auth login
$ gcloud config set project your-project-id
$ gcloud config set compute/zone your-zone
Create SSH keys for gcloud. We recommend leaving a blank password (press enter twice after running the following command). If you are prompted that the google_compute_engine file already exists, replace the existing version.
$ ssh-keygen -f ~/.ssh/google_compute_engine
Provision your TPUs with the following command:
$ gcloud compute tpus queued-resources \
create your-qr-id \
--accelerator-type your-accelerator-type \
--runtime-version tpu-ubuntu2204-base \
--node-id qr-id \
[--reserved |--spot]
Command flag descriptions
your-qr-id
- A user-defined string that identifies the QR request.
accelerator-type
- The accelerator type specifies the version and size of the Cloud TPU you want to create. For more information about supported accelerator types for each TPU version, see TPU versions.
runtime-version
- The Cloud TPU software version.
node-id
- The ID of the TPU resources that will be created in response to the QR request.
reserved
- Use reserved quota when creating the slices.
spot
- Use Spot VMs quota when creating the slices.
The Google Cloud CLI does not support all create QR options, such as tags. For more information, see Create QRs.
Wait until the QR is in the ACTIVE state, which means the worker nodes are in the READY state. Once the QR provisioning starts, it may take one to five minutes to complete depending on the size of the QR. You can check the status of a QR request using the following command:
$ gcloud compute tpus queued-resources \
list --filter=your-qr-id
A v4-8 slice has a single TPU VM. Connect to the TPU VM using SSH:
$ gcloud compute tpus tpu-vm ssh your-qr-id
Clone MaxText (which includes shardings.py) to your TPU VM.
Within the MaxText repository directory, run the setup script to install JAX and other dependencies on your TPU slice. The setup script takes a few minutes to run.
$ bash setup.sh
Run the following command to run shardings.py on your TPU slice.
$ python3 pedagogical_examples/shardings.py \
--ici_fsdp_parallelism 4 \
--batch_size 131072 \
--embedding_dimension 2048
You can see the results in the logs. Your TPUs should achieve about 260 TFLOPs per second, or more than 90% FLOP utilization. In this case, we've selected approximately the maximum batch that fits into the TPU's High Bandwidth Memory (HBM).
Feel free to explore other sharding strategies over ICI. For example, you could try the following combination:
$ python3 pedagogical_examples/shardings.py \
--ici_tensor_parallelism 4 \
--batch_size 131072 \
--embedding_dimension 2048
Delete the QR and TPU slice when finished. You should run these cleanup steps from the environment where you set up the slice (first run exit to exit the SSH session). The deletion will take two to five minutes to complete, and can be run in the background with the optional --async flag.
$ gcloud compute tpus queued-resources delete your-qr-id --force [--async]
Multislice sharding using DCN parallelism
The shardings.py
script takes three parameters that specify DCN parallelism,
corresponding to the number of shards of each type of parallelism:
- dcn_data_parallelism
- dcn_fsdp_parallelism
- dcn_tensor_parallelism
The values of these parameters must be constrained so that
dcn_data_parallelism * dcn_fsdp_parallelism * dcn_tensor_parallelism
equals
the number of slices.
As an example, for two slices, use --dcn_data_parallelism=2.
Strategy | dcn_data_parallelism | dcn_fsdp_parallelism | dcn_tensor_parallelism | # of slices |
2-way data parallelism | 2 | 1 | 1 | 2 |
dcn_tensor_parallelism
should always be set to 1
because the DCN is a poor
fit for such sharding. For typical LLM workloads on v4 chips,
dcn_fsdp_parallelism
should also be set to 1
and therefore
dcn_data_parallelism
should be set to the number of slices, but this is
application dependent.
As you increase the number of slices (assuming you keep the slice size and batch per slice constant), you increase the amount of data parallelism.
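The DCN and ICI constraints combine per axis: the DCN shards must multiply to the number of slices, and the ICI shards must multiply to the number of chips in each slice. The following Python sketch checks both and reports the total amount of parallelism on each axis. The function name multislice_parallelism is hypothetical, and the example assumes two v4-16 slices with eight chips each.

```python
from math import prod


def multislice_parallelism(num_slices, chips_per_slice, dcn, ici):
    """Validate (data, fsdp, tensor) shardings and return per-axis totals.

    dcn: shards across slices, must multiply to num_slices.
    ici: shards within a slice, must multiply to chips_per_slice.
    """
    if prod(dcn) != num_slices:
        raise ValueError(f"DCN shards {dcn} do not multiply to {num_slices} slices")
    if prod(ici) != chips_per_slice:
        raise ValueError(f"ICI shards {ici} do not multiply to {chips_per_slice} chips")
    # Total parallelism on each axis is the DCN factor times the ICI factor.
    return tuple(d * i for d, i in zip(dcn, ici))


# Two slices of eight chips: data parallelism over DCN, FSDP over ICI.
totals = multislice_parallelism(2, 8, dcn=(2, 1, 1), ici=(1, 8, 1))
print(totals)  # (2, 8, 1)
```

Doubling the slice count while keeping the ICI values fixed doubles only the first entry, which is the growth in data parallelism described above.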
Running shardings.py
in a Multislice environment
You can run shardings.py
in a Multislice environment using
multihost_runner.py
or by running shardings.py
on each TPU VM. Here we use
multihost_runner.py
. The following steps are very similar to those in
Getting Started: Quick Experiments on Multiple slices
from the MaxText repository, except here we run shardings.py
instead of the
more complex LLM in train.py
.
The multihost_runner.py
tool is optimized for quick experiments, repeatedly
re-using the same TPUs. Because the multihost_runner.py
script depends on
long-lived SSH connections, we don't recommend it for any long-running jobs.
If you want to run a longer job (for example, hours or days), we recommend you
use multihost_job.py
.
In this tutorial, we use the term runner to indicate the machine on which you
run the multihost_runner.py
script. We use the term workers to indicate the
TPU VMs that make up your slices. You can run multihost_runner.py
on a local
machine or any Compute Engine VM in the same project as your slices. Running
multihost_runner.py
on a worker is not supported.
multihost_runner.py
automatically connects to TPU workers using SSH.
In this example, we run shardings.py
over two v4-16 slices, a total of four
VMs and 16 TPU chips. You can modify the example to run on more TPUs.
Set up your environment
Clone MaxText on your runner machine.
Go to the repository directory.
Create SSH keys for gcloud. We recommend leaving a blank password (press enter twice after running the following command). If you are prompted that the google_compute_engine file already exists, select not to keep your existing version.
$ ssh-keygen -f ~/.ssh/google_compute_engine
Add an environment variable to set the TPU slice count to 2.
$ export SLICE_COUNT=2
Create a Multislice environment using queued-resources create.
The following command shows how to create a v4 Multislice TPU. To use v5e, specify a v5e accelerator-type (for example v5litepod-16) and the v5e runtime-version (v2-alpha-tpuv5-lite).
$ gcloud compute tpus queued-resources \
create your-qr-id \
--accelerator-type=your-accelerator-type \
--runtime-version=tpu-vm-runtime-version \
--node-count=node-count \
--node-prefix=your-qr-id \
[--reserved|--spot]
Command flag descriptions
your-qr-id
- A user-defined string that identifies the QR request.
accelerator-type
- The accelerator type specifies the version and size of the Cloud TPU you want to create. For more information about supported accelerator types for each TPU version, see TPU versions.
runtime-version
- The Cloud TPU software version.
node-count
- The number of slices to create.
node-prefix
- The prefix used to generate names for each slice. A number is appended to the prefix for each slice. For example, if you set node-prefix to mySlice, the slices are named mySlice-0, mySlice-1, continuing numerically for each slice.
reserved
- Use reserved quota when creating the slices.
spot
- Use Spot VMs quota when creating the slices.
When the QR provisioning starts, it may take up to five minutes to complete depending on the size of the QR. Wait until the Queued Resource (QR) is in the ACTIVE state. You can check the status of a QR request using the following command:
$ gcloud compute tpus queued-resources list \
--filter=your-qr-id
This should generate output that looks like:
NAME ZONE NODE_COUNT ACCELERATOR_TYPE STATE
...
que-res-id us-central2-b 4 v4-16 ACTIVE
...
Contact your Google Cloud account representative if the QR status is in the WAITING_FOR_RESOURCES or PROVISIONING state for more than 15 minutes.
Install dependencies.
$ python3 multihost_runner.py \
--TPU_PREFIX=your-qr-id \
--COMMAND="bash setup.sh"
Run shardings.py on each worker using multihost_runner.py.
$ python3 multihost_runner.py \
--TPU_PREFIX=your-qr-id \
--COMMAND="python3 pedagogical_examples/shardings.py \
--dcn_data_parallelism $SLICE_COUNT \
--ici_fsdp_parallelism 8 \
--batch_size 131072 \
--embedding_dimension 2048"
You'll see approximately 230 TFLOPs per second of performance in the log files.
Clean up the TPUs and QR when finished. The deletion will take two to five minutes to complete, and can be run in the background with the optional
--async
flag.
Scaling a workload to Multislice
Before running your model in a Multislice environment, make the following code changes:
- Use jax.experimental.mesh_utils.create_hybrid_device_mesh instead of jax.experimental.mesh_utils.create_device_mesh when creating your mesh.
These should be the only necessary code changes when moving to Multislice. To achieve high performance, DCN needs to be mapped onto data parallel, fully sharded data parallel or pipeline parallel axes. Performance considerations and sharding strategies are discussed in more detail in Sharding With Multislice for Maximum Performance.
To validate that your code can access all the devices, you can assert that
len(jax.devices())
is equal to the number of chips in your Multislice
environment. For example, if you are using four slices of v4-16
, you have
eight chips per slice * 4 slices, so len(jax.devices())
should return 32.
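The expected device count can be derived from the accelerator type instead of being hard-coded. The following is a sketch under stated assumptions: it handles only v4 accelerator types, where the number after the dash counts TensorCores and each chip has two TensorCores (so v4-16 has eight chips). The function name expected_device_count is illustrative.

```python
def expected_device_count(accelerator_type, num_slices):
    """Return the number of chips in a v4 Multislice environment.

    Assumes a 'v4-N' accelerator type where N is the TensorCore count
    and every v4 chip contributes two TensorCores.
    """
    version, _, cores = accelerator_type.partition("-")
    if version != "v4" or not cores.isdigit():
        raise ValueError(f"only v4-N accelerator types are handled: {accelerator_type!r}")
    chips_per_slice = int(cores) // 2
    return chips_per_slice * num_slices


expected = expected_device_count("v4-16", 4)
print(expected)  # 32

# On a TPU VM with JAX installed, the check would then be:
#   import jax
#   assert len(jax.devices()) == expected
```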
Choosing slice sizes for Multislice environments
To get a linear speed up, add new slices of the same size as your existing
slice. For example, if you use a v4-512
slice, Multislice will
achieve approximately twice the performance by adding a second v4-512
slice
and doubling your global batch size. For more information, see
Sharding With Multislice for Maximum Performance.
Running your Job on multiple slices
There are three different approaches to running your custom workload in a Multislice environment:
- Using the experimentation runner script,
multihost_runner.py
- Using the production runner script,
multihost_job.py
- Using a manual approach
Experimentation runner script
The multihost_runner.py
script distributes code to an existing Multislice environment, and runs
your command on each host, copies your logs back, and tracks each command's error
status. The multihost_runner.py
script is documented in
MaxText README.
Because multihost_runner.py
maintains persistent SSH connections, it is only
suitable for modestly sized, relatively short-running experimentation. You can
adapt the steps in the multihost_runner.py tutorial
to your workload and hardware configuration.
Production runner script
For production jobs that need resiliency against hardware failures and other
preemptions, it is best to integrate directly with the Create Queued Resource
API. As a working example, we provide multihost_job.py
,
which triggers the Create Queued Resource API call with the appropriate startup
script to run your training and resume on preemption. The multihost_job.py
script is documented in the
MaxText README.
Because multihost_job.py
must provision resources for each run, it doesn't
provide as fast an iteration cycle as multihost_runner.py
.
Manual approach
We recommend you use or adapt multihost_runner.py or multihost_job.py to run your custom workload in your Multislice configuration. However, if you prefer to provision and manage your environment using QR commands directly, see Manage a Multislice Environment.
Manage a Multislice environment
To manually provision and manage QRs without using the tools provided in the MaxText repo, read the following sections.
Create QRs
Set the following environment variables before provisioning capacity:
$ export QR_ID=your-queued-resource-id
$ export PROJECT=your-project-name
$ export ZONE=us-central2-b
$ export NETWORK_NAME=your-network-name
$ export SUBNETWORK_NAME=your-subnetwork-name
$ export RUNTIME_VERSION=tpu-ubuntu2204-base
$ export ACCELERATOR_TYPE=v4-16
$ export SLICE_COUNT=4
$ export STARTUP_SCRIPT="#!/bin/bash\n ..."
$ gcloud config set project project-name
$ gcloud config set compute/zone zone
Input | Description |
QR_ID | The user-assigned ID of the QR. |
PROJECT | Google Cloud Project Name |
ZONE | us-central2-b |
NETWORK_NAME | Name of the VPC networks. |
SUBNETWORK_NAME | Name of the subnet in VPC networks |
RUNTIME_VERSION | tpu-ubuntu2204-base |
ACCELERATOR_TYPE | v4-16 |
EXAMPLE_TAG_1, EXAMPLE_TAG_2 … | Tags used to identify valid sources or targets for network firewalls |
SLICE_COUNT | Number of slices. Limited to a maximum of 256 slices. |
STARTUP_SCRIPT | If added to the create request, a startup script can run whenever a TPU slice is provisioned or restarted and if the TPU slice is repaired or reset. |
Create a QR request using gcloud
$ gcloud compute tpus queued-resources \
create ${QR_ID} \
--project your-project-id \
--zone your-zone \
--node-count ${SLICE_COUNT} \
--accelerator-type ${ACCELERATOR_TYPE} \
--runtime-version ${RUNTIME_VERSION} \
--network ${NETWORK_NAME} \
--subnetwork ${SUBNETWORK_NAME} \
--tags ${EXAMPLE_TAG_1},${EXAMPLE_TAG_2} \
--metadata=startup-script="${STARTUP_SCRIPT}" \
[--reserved|--spot]
Command flag descriptions
your-qr-id
- A user-defined string that identifies the QR request.
project
- The ID of the Google Cloud project in which to create the QR.
zone
- The Google Cloud zone in which to create the QR.
node-count
- The number of slices to create.
accelerator-type
- The accelerator type specifies the version and size of the Cloud TPU you want to create. For more information about supported accelerator types for each TPU version, see TPU versions.
runtime-version
- The Cloud TPU software version.
network
- The name of a VPC network to which to attach the TPU resource.
subnetwork
- The name of a VPC subnetwork to which to attach the TPU resource.
reserved
- Use reserved quota when creating the slices.
spot
- Use Spot VMs quota when creating the slices.
Ensure you have the respective quota before selecting --reserved
,
--spot
, or the default on-demand quota. For information on quota types,
see Quota Policy.
Create a QR request using curl
Create a file named queued-resource-req.json
and copy the following JSON into it.
{
  "guaranteed": { "reserved": true },
  "tpu": {
    "node_spec": [
      {
        "parent": "projects/your-project-number/locations/your-zone",
        "node": {
          "accelerator_type": "accelerator-type",
          "runtime_version": "tpu-vm-runtime-version",
          "network_config": {
            "network": "your-network-name",
            "subnetwork": "your-subnetwork-name",
            "enable_external_ips": true
          },
          "tags": ["example-tag-1"],
          "metadata": {
            "startup-script": "your-startup-script"
          }
        },
        "multi_node_params": {
          "node_count": slice-count,
          "node_id_prefix": "your-queued-resource-id"
        }
      }
    ]
  }
}
- your-project-number - Your Google Cloud project number
- your-zone - The zone in which you want to create your QR
- accelerator-type - The version and size of a single slice
- tpu-vm-runtime-version - The TPU VM runtime versions
- your-network-name - Optional, a network to which the QR will be attached
- your-subnetwork-name - Optional, a subnetwork to which the QR will be attached
- example-tag-1 - Optional, an arbitrary tag string
- your-startup-script - A startup script that will be run when the QR is allocated
- slice-count - The number of TPU slices in your Multislice environment
- your-queued-resource-id - The user supplied ID for the QR
For more information, see the REST Queued Resource API documentation for all available options.
To use Spot capacity, replace:
"guaranteed": { "reserved": true }
with "spot": {}
Remove the line to use the default on-demand capacity.
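Switching between capacity types by hand-editing the JSON is easy to get wrong, so the payload can also be generated. The following Python sketch builds the request body with the field names used in the JSON in this section. The function name build_qr_request and its parameters are hypothetical, and the optional network, tags, and metadata fields are left out for brevity.

```python
import json


def build_qr_request(project_number, zone, accelerator_type, runtime_version,
                     slice_count, qr_id, capacity="on-demand"):
    """Build a queued resource request body for the given capacity type."""
    body = {
        "tpu": {
            "node_spec": [
                {
                    "parent": f"projects/{project_number}/locations/{zone}",
                    "node": {
                        "accelerator_type": accelerator_type,
                        "runtime_version": runtime_version,
                    },
                    "multi_node_params": {
                        "node_count": slice_count,
                        "node_id_prefix": qr_id,
                    },
                }
            ]
        }
    }
    if capacity == "reserved":
        body["guaranteed"] = {"reserved": True}
    elif capacity == "spot":
        body["spot"] = {}
    elif capacity != "on-demand":
        # On-demand is the default and needs no extra field.
        raise ValueError(f"unknown capacity type: {capacity!r}")
    return body


# Placeholder values; substitute your own project number, zone, and IDs.
request = build_qr_request("your-project-number", "your-zone", "v4-16",
                           "tpu-ubuntu2204-base", 4, "your-queued-resource-id",
                           capacity="spot")
print(json.dumps(request, indent=2))
```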
Submit the QR create request with the JSON payload:
$ curl -X POST \
-H "Authorization: Bearer $(gcloud auth print-access-token)" \
-H "Content-Type: application/json" \
-d @queued-resource-req.json \
https://tpu.googleapis.com/v2alpha1/projects/your-project-id/locations/your-zone/queuedResources\?queued_resource_id\=your-qr-id
- your-project-id - Your Google Cloud project ID
- your-zone - The zone in which you want to create your QR
- your-qr-id - The user supplied ID for the QR
The response should look like the following:
{
  "name": "projects/<your-project-id>/locations/<your-zone>/operations/operation-<your-qr-guid>",
  "metadata": {
    "@type": "type.googleapis.com/google.cloud.common.OperationMetadata",
    "createTime": "2023-11-01T00:17:05.742546311Z",
    "target": "projects/<your-project-id>/locations/<your-zone>/queuedResources/<your-qr-id>",
    "verb": "create",
    "cancelRequested": false,
    "apiVersion": "v2alpha1"
  },
  "done": false
}
Use the GUID value at the end of the string value for the name attribute to get information about the QR request.
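Extracting the GUID can be scripted. The following Python sketch assumes the name attribute always ends in operations/operation-GUID, as in the response shown in this section. The function name operation_guid and the sample GUID are made up for illustration.

```python
def operation_guid(operation_name):
    """Return the GUID that follows 'operation-' in an operation name."""
    last_segment = operation_name.rsplit("/", 1)[-1]
    prefix = "operation-"
    if not last_segment.startswith(prefix):
        raise ValueError(f"unexpected operation name: {operation_name!r}")
    return last_segment[len(prefix):]


# The GUID below is a made-up example value.
name = "projects/my-project/locations/us-central2-b/operations/operation-1698797825-abc123"
print(operation_guid(name))  # 1698797825-abc123
```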
Retrieve status of a QR
To get the status of the QR request, use the following command:
$ curl -X GET \
-H "Authorization: Bearer $(gcloud auth print-access-token)" \
-H "Content-Type: application/json" \
https://tpu.googleapis.com/v2/projects/your-project-id/locations/your-zone/operations/operation-your-qr-guid
- your-project-id - Your Google Cloud project ID
- your-zone - The zone in which to create the QR
- your-qr-guid - The GUID following name in the output from the QR creation request.
The response of this command contains the status of the operation:
{
  "name": "projects/<your-project-id>/locations/<your-zone>/operations/operation-<your-qr-guid>",
  "metadata": {...},
  "done": true,
  "response": {
    "@type": "type.googleapis.com/google.cloud.tpu.v2.QueuedResource",
    ...
    "state": {
      "state": "WAITING_FOR_RESOURCES"
    }
  }
}
If the QR is created successfully ("done": true), the state within the response field will be either WAITING_FOR_RESOURCES or FAILED.
If the QR is in the WAITING_FOR_RESOURCES
state, the QR has been
enqueued and will start provisioning when there are enough resources. If the QR
is in the FAILED
state, the failure reason will be in the output. For more
information about other possible states, see the
Queued resources user guide.
Once the operation is done, use the describe command (see Describe QRs) to monitor the stages of the QR.
In a rare scenario, you might find your QR in the FAILED
state while some
slices are ACTIVE
. If this happens, delete the resources that were created,
and try again in a few minutes or reach out
to the Cloud TPU team to resolve the issue.
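Reading the state out of the operation response can be sketched in Python as follows. The field layout (done, then response.state.state) is taken from the sample response in this section. The function name qr_state_from_operation is hypothetical, and any fields not shown in the sample are ignored.

```python
def qr_state_from_operation(operation):
    """Return the QR state string, or None while the operation is not done."""
    if not operation.get("done", False):
        return None
    # The state is nested twice: response -> state -> state.
    return operation.get("response", {}).get("state", {}).get("state")


pending = {"name": "operations/operation-example", "done": False}
finished = {
    "name": "operations/operation-example",
    "done": True,
    "response": {"state": {"state": "WAITING_FOR_RESOURCES"}},
}
print(qr_state_from_operation(pending))   # None
print(qr_state_from_operation(finished))  # WAITING_FOR_RESOURCES
```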
SSH and install dependencies
Run JAX code on TPU Pod slices
describes how to connect to your TPU VMs using SSH in a single slice. To
connect to all TPU VMs in your Multislice environment over SSH and
install dependencies, use the following gcloud
command:
$ gcloud compute tpus queued-resources ssh your-qr-id \
--zone your-zone \
--node=all \
--worker=all \
--command="command-to-run" \
--batch-size=4
This gcloud command sends the specified command to all workers and nodes in the QR using SSH. Commands are sent in batches of four, and the commands in a batch run simultaneously. The next batch is sent when the current batch completes execution. If there is a failure with one of the commands, processing
stops, and no further batches are sent. For more information, see the
Queued resource API reference.
If the number of slices you are using exceeds your local computer's threading
limit (also called batching limit) you will run into a deadlock. As an example,
assume the batching limit on your local machine is 64. If you try to run a
training script on more than 64 slices, say 100, the SSH command will break up
the slices into batches. It will run the training script on the first batch of 64
slices and wait for the scripts to complete before running the script on the
remaining batch of 36 slices. However, the first batch of 64 slices cannot
complete until the remaining 36 slices start running the script, causing a
deadlock.
To prevent this scenario, you can run the training script in the background on
each VM by appending an ampersand (&
) to the script command you specify
with the --command
flag. When you do this, after starting the training script
on the first batch of slices, control will immediately return to
the SSH command. The SSH command can then start running the training script on
the remaining batch of 36 slices. You'll need to pipe your stdout
and stderr
streams appropriately when running the commands in the background. To increase
parallelism within the same QR, you can select specific slices using the --node
parameter.
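The batching behavior behind this deadlock can be modeled in a few lines. The following Python sketch is a simplified model of the scenario described above and does not reflect how gcloud is implemented. The names ssh_batches and can_deadlock are illustrative.

```python
def ssh_batches(num_slices, batch_limit):
    """Split num_slices into consecutive batch sizes of at most batch_limit."""
    return [
        min(batch_limit, num_slices - start)
        for start in range(0, num_slices, batch_limit)
    ]


def can_deadlock(num_slices, batch_limit, run_in_background):
    """A foreground job that needs every slice stalls if a second batch is required.

    The first batch waits for the remaining slices to start, while the
    remaining slices wait for the first batch to finish.
    """
    needs_multiple_batches = len(ssh_batches(num_slices, batch_limit)) > 1
    return needs_multiple_batches and not run_in_background


print(ssh_batches(100, 64))                            # [64, 36]
print(can_deadlock(100, 64, run_in_background=False))  # True
print(can_deadlock(100, 64, run_in_background=True))   # False
```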
Network setup
Ensure TPU slices can communicate with each other by running the following steps.
Install JAX on each of the slices. For more information, see
Run JAX code on TPU Pod slices. Assert that
len(jax.devices())
is equal to the number of chips in your Multislice
environment. To do this, on each slice, run:
$ python3 -c 'import jax; print(jax.devices())'
If you run this code on four slices of v4-16, there are eight chips per slice and four slices, so jax.devices() should return a total of 32 chips (devices).
List QRs
You can view the state of your QRs using the queued-resources list
command:
$ gcloud compute tpus queued-resources list
NAME ZONE NODE_COUNT ACCELERATOR_TYPE STATE
...
que-res-id us-central2-b 4 v4-16 ACTIVE
...
Describe QRs
To view the detailed configuration and state of a QR, use the
describe QR API. You can call this API using gcloud
or curl
.
Using gcloud
:
$ gcloud compute tpus queued-resources describe your-qr-id
...
state:
  state: ACTIVE
...
Using curl
:
$ curl -X GET \
-H "Authorization: Bearer $(gcloud auth print-access-token)" \
-H "Content-Type: application/json" \
https://tpu.googleapis.com/v2/projects/your-project-id/locations/your-zone/queuedResources/your-qr-id
{
"name": your-queued-res,
"tpu": {
"nodeSpec": [
{
... // node 1
},
{
... // node 2
},
...
]
},
...
"state": "ACTIVE"
}
state
represents the status of a QR. For more information on the possible
states of QRs, see Queued resources.
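Waiting for a QR to become usable amounts to polling its state. The following Python sketch shows one way to structure that loop. It is not tied to any client library: fetch_state is a caller-supplied function (for example, one that wraps the describe command and returns the state string), and the choice to stop only on ACTIVE or FAILED is an assumption made for this example.

```python
import time

STOP_STATES = {"ACTIVE", "FAILED"}  # assumption: every other state is polled through


def wait_for_qr(fetch_state, poll_seconds=30.0, timeout_seconds=900.0, sleep=time.sleep):
    """Poll fetch_state() until it returns ACTIVE or FAILED, or the timeout passes."""
    waited = 0.0
    while True:
        state = fetch_state()
        if state in STOP_STATES:
            return state
        if waited >= timeout_seconds:
            raise TimeoutError(f"QR still in state {state!r} after {waited:.0f} seconds")
        sleep(poll_seconds)
        waited += poll_seconds


# Simulated state sequence so the example runs without a real QR.
states = iter(["WAITING_FOR_RESOURCES", "PROVISIONING", "ACTIVE"])
print(wait_for_qr(lambda: next(states), sleep=lambda _seconds: None))  # ACTIVE
```

The default timeout of 900 seconds is an arbitrary choice for the sketch; pick a value that suits the size of your QR.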
Start your job on a provisioned environment
You can manually run workloads by connecting to all hosts in each slice over SSH and running the following command on all hosts.
$ gcloud compute tpus tpu-vm ssh your-qr-id \
--zone=your-zone \
--worker=all \
--node=all \
--command="command-to-run"
Resetting QRs
The ResetQueuedResource
API can be used to reset
all the VMs in an ACTIVE
QR. Resetting the VMs forcibly erases the memory of
the machine and resets the VM to its initial state. Any data stored locally will
remain intact and the startup script will be invoked after a reset. The
ResetQueuedResource
API can be useful when you want to restart all TPUs. For
example, when training is stuck and resetting all VMs is easier than debugging.
The resets of all VMs are performed in parallel and a ResetQueuedResource
operation takes one to two minutes to complete. To invoke the API, use the following
command:
$ gcloud compute tpus queued-resources reset your-qr-id
Deleting QRs
To release resources at the end of your training session, delete the queued
resource with the --force
flag. The deletion will take two to five minutes to
complete, and can be run in the background with the optional --async
flag.
$ gcloud compute tpus queued-resources \
delete your-qr-id --force [--async]
Automatic failure recovery
In the event of a disruption, Multislice offers intervention-free repair of the impacted slice and reset of all slices afterward. The impacted slice is replaced with a new slice and the remaining otherwise healthy slices are reset. If no capacity is available to allocate a replacement slice, training stops.
To resume training automatically after an interruption, you must specify a startup script that checks for and loads the last saved checkpoints. Your startup script is automatically run every time a slice is reallocated or a VM is reset. You specify a startup script in the JSON payload you send to the create QR request API.
The following startup script (used in Create QRs) lets you automatically recover from failures and resume training from checkpoints stored in a Cloud Storage bucket during MaxText training:
{
  "tpu": {
    "node_spec": [
      {
        ...
        "metadata": {
          "startup-script": "#! /bin/bash \n pwd \n runuser -l user1 -c 'cd /home/user1/MaxText && python3 MaxText/train.py MaxText/configs/base.yml run_name=run_test_failure_recovery dcn_data_parallelism=4 ici_fsdp_parallelism=8 steps=10000 save_period=10 base_output_directory='gs://user1-us-central2'' EOF"
        }
        ...
      }
    ]
  }
}
Clone the MaxText repo before trying this out.
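The resume logic that a training script needs can be illustrated independently of MaxText. The following Python sketch assumes a hypothetical layout in which each checkpoint is a directory named after its step number inside a local checkpoint directory. This layout, and the names latest_checkpoint_step and resume_step, are assumptions for illustration and do not describe how MaxText stores checkpoints in Cloud Storage.

```python
import os
import tempfile


def latest_checkpoint_step(checkpoint_dir):
    """Return the highest step with a checkpoint, or None if there is none."""
    if not os.path.isdir(checkpoint_dir):
        return None
    steps = [int(name) for name in os.listdir(checkpoint_dir) if name.isdigit()]
    return max(steps, default=None)


def resume_step(checkpoint_dir):
    """Return the step to start from: 0 on a fresh run, else one past the last save."""
    latest = latest_checkpoint_step(checkpoint_dir)
    return 0 if latest is None else latest + 1


# Simulate a run that saved checkpoints at steps 10, 20, and 30 before a restart.
with tempfile.TemporaryDirectory() as ckpt_dir:
    print(resume_step(ckpt_dir))  # 0
    for step in (10, 20, 30):
        os.mkdir(os.path.join(ckpt_dir, str(step)))
    print(resume_step(ckpt_dir))  # 31
```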
Profiling and debugging
Profiling is the same in single slice and Multislice environments. For more information, see Profiling JAX programs.
Optimizing training
Sharding with Multislice for maximum performance
Achieving maximum performance in Multislice environments requires considering how to shard across the multiple slices. There are typically three choices (data parallelism, fully-sharded data parallelism and pipeline parallelism). We don't recommend sharding activations across the model dimensions (sometimes called tensor parallelism) because it requires too much inter-slice bandwidth. For all these strategies, you can keep the same sharding strategy within a slice that has worked for you in the past.
We recommend starting with pure data parallelism. Fully-sharded data parallelism is useful for reducing memory usage. The drawback is that communication between slices uses the DCN network and will slow down your workload. Use pipeline parallelism only when necessary based on batch size (as analyzed below).
When to use data parallelism
Pure data parallelism will work well in cases where you have a workload that is running well, but you'd like to improve its performance by scaling across multiple slices.
To achieve strong scaling across multiple slices, the amount of time required to perform all-reduce over DCN needs to be less than the amount of time required to perform a backwards pass. DCN is used for communication between slices and is a limiting factor in workload throughput.
Each v4 TPU chip performs at a peak of 275 * 10^12 FLOPS.
There are four chips per TPU host and each host has a maximum network bandwidth of 50 Gbps.
That means the arithmetic intensity is 4 * 275 * 10^12 FLOPS / 50 Gbps = 22000 FLOPS / bit.
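As a quick check of the arithmetic above, the following Python sketch reproduces the hardware figure from the three numbers given in the text. The function and parameter names are illustrative only.

```python
def hardware_flops_per_dcn_bit(
    chip_peak_flops: float = 275e12,         # peak FLOPS of one TPU v4 chip
    chips_per_host: int = 4,                 # chips attached to each TPU host
    host_dcn_bits_per_second: float = 50e9,  # 50 Gbps of DCN bandwidth per host
) -> float:
    """FLOPS a host can perform in the time it takes to send one bit over DCN."""
    return chips_per_host * chip_peak_flops / host_dcn_bits_per_second


v4_intensity = hardware_flops_per_dcn_bit()
print(f"{v4_intensity:,.0f} FLOPS per bit")  # prints "22,000 FLOPS per bit"
```

Keeping the inputs as parameters makes it easy to redo the estimate for a different chip generation or host bandwidth.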
Your model will use 32 to 64 bits of DCN bandwidth for each parameter per step. If you use two slices, your model will use 32 bits of DCN bandwidth for each parameter per step. If you use more than two slices, the compiler will perform a full shuffle all-reduce operation and you'll use up to 64 bits of DCN bandwidth for each parameter per step. The number of FLOPS needed for each parameter varies depending on your model. Specifically, for Transformer-based language models, the number of FLOPS required for a forward and a backward pass is approximately 6 * B * P where:
- B is the batch size in tokens
- P is the number of parameters
The number of FLOPS per parameter is 6 * B and the number of FLOPS per parameter during the backwards pass is 4 * B.
To achieve strong scaling across multiple slices, make sure that the operational intensity exceeds the arithmetic intensity of the TPU hardware. To calculate the operational intensity, divide the number of FLOPS per parameter during the backwards pass by the DCN bandwidth used (in bits) per parameter per step:
Operational intensity = FLOPS_backwards_pass / DCN bandwidth
Therefore, for a Transformer based language model, if you are using two slices:
Operational intensity = 4 * B / 32
If you are using more than two slices:
Operational intensity = 4 * B / 64
Setting the operational intensity equal to the hardware's 22000 FLOPS / bit suggests a minimum batch size of between 176k tokens (two slices) and 352k tokens (more than two slices) for Transformer-based language models. Because DCN can briefly drop packets, it's best to maintain a significant margin for error, deploying data parallelism only if the batch size per Pod is at least 350k (two Pods) to 700k (many Pods).
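The minimum batch sizes follow from solving 4 * B / bits >= 22000 for B. The sketch below works through that step in Python. The function name is illustrative, and the result is the theoretical break-even point, before the safety margin recommended above.

```python
def min_batch_tokens(
    hardware_flops_per_bit: float,
    dcn_bits_per_param: int,
    backward_flops_per_param_per_token: int = 4,  # 4 * B FLOPS per parameter
) -> float:
    """Smallest batch size B (in tokens) with 4 * B / bits >= hardware intensity."""
    return (
        hardware_flops_per_bit
        * dcn_bits_per_param
        / backward_flops_per_param_per_token
    )


HARDWARE_INTENSITY = 22_000  # FLOPS per bit, from the TPU v4 figures above

two_slices = min_batch_tokens(HARDWARE_INTENSITY, dcn_bits_per_param=32)
many_slices = min_batch_tokens(HARDWARE_INTENSITY, dcn_bits_per_param=64)
print(two_slices, many_slices)  # prints "176000.0 352000.0"
```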
For other model architectures, you'll need to estimate the runtime of your backwards pass per slice (either by timing it using a profiler or by counting FLOPS). Then you can compare that to the expected time to all-reduce over DCN to get a good estimate of whether data parallelism makes sense for you.
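The comparison described above can be sketched as follows. This is a rough estimate under stated assumptions, not a measurement: it assumes every host in a slice can drive its full 50 Gbps of DCN bandwidth in parallel, it ignores latency and packet loss, and the example workload (1 billion parameters, 16 hosts per slice, a profiled backwards pass of 0.25 seconds) is hypothetical.

```python
def dcn_all_reduce_seconds(
    num_params: float,
    dcn_bits_per_param: int,        # 32 for two slices, up to 64 for more
    hosts_per_slice: int,
    host_dcn_bits_per_second: float = 50e9,
) -> float:
    """Ideal time to move one step's gradient traffic for a slice over DCN."""
    slice_bandwidth = hosts_per_slice * host_dcn_bits_per_second
    return num_params * dcn_bits_per_param / slice_bandwidth


def all_reduce_can_be_hidden(
    backward_pass_seconds: float, all_reduce_seconds: float
) -> bool:
    """True if the all-reduce finishes within the backwards pass."""
    return all_reduce_seconds < backward_pass_seconds


# Hypothetical workload: 1e9 parameters, more than two slices, 16 hosts per slice.
comm_seconds = dcn_all_reduce_seconds(1e9, dcn_bits_per_param=64, hosts_per_slice=16)
measured_backward_seconds = 0.25  # hypothetical value read from a profiler
fits = all_reduce_can_be_hidden(measured_backward_seconds, comm_seconds)
print(f"all-reduce ~{comm_seconds:.2f} s, hidden behind backwards pass: {fits}")
```

If the estimated all-reduce time is close to the backwards-pass time, treat data parallelism as marginal, since real DCN throughput is usually below the peak figure.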
When to use fully sharded data parallelism (FSDP)
Fully sharded data parallelism (FSDP) combines data parallelism (sharding the data across nodes) with sharding the weights across nodes. For each operation in the forward and backward passes, the weights are all-gathered so that each slice has the weights it needs. Instead of synchronizing the gradients using all-reduce, the gradients are reduce-scattered as they are produced. In this way, each slice only gets the gradients for the weights it's responsible for.
Similar to data parallelism, FSDP requires scaling the global batch size linearly with the number of slices. FSDP decreases the memory pressure as you increase the number of slices, because the number of weights and the amount of optimizer state per slice decrease. This comes at the price of increased network traffic and a greater possibility of blocking due to a delayed collective.
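The two scaling effects can be illustrated with a small Python sketch. The numbers are assumptions chosen for illustration: 12 bytes of state per parameter corresponds to 32-bit weights plus two 32-bit optimizer moments, and the estimate ignores activations and gradients. The model size and per-slice batch are hypothetical.

```python
def global_batch_tokens(batch_tokens_per_slice: int, num_slices: int) -> int:
    """Global batch grows linearly with the number of slices."""
    return batch_tokens_per_slice * num_slices


def state_bytes_per_slice(
    num_params: int, num_slices: int, bytes_per_param: int = 12
) -> float:
    """Weight and optimizer state held by one slice when sharded across slices."""
    return num_params * bytes_per_param / num_slices


NUM_PARAMS = 8_000_000_000       # hypothetical 8B-parameter model
BATCH_PER_SLICE = 350_000        # hypothetical per-slice batch in tokens

for slices in (1, 2, 4, 8):
    gigabytes = state_bytes_per_slice(NUM_PARAMS, slices) / 1e9
    batch = global_batch_tokens(BATCH_PER_SLICE, slices)
    print(f"{slices} slice(s): {gigabytes:.0f} GB of state per slice, "
          f"global batch {batch:,} tokens")
```

Doubling the slice count halves the state each slice must hold while doubling the global batch, which is the trade-off the paragraph above describes.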
In practice, FSDP across slices is best if you are increasing the batch per slice, storing more activations to minimize re-materialization during the backwards pass or increasing the number of parameters in your neural network.
The all-gather and reduce-scatter operations in FSDP work similarly to the all-reduce operation in DP, so you can determine if your FSDP workload is limited by DCN performance in the same way as described in the previous section.
When to use pipeline parallelism
Pipeline parallelism becomes relevant when achieving high performance with the other parallelism strategies requires a global batch size greater than your preferred maximum batch size. Pipeline parallelism allows the slices comprising a pipeline to "share" a batch. However, pipeline parallelism has two significant downsides:
- It incurs the "pipeline bubble" where chips are idle because they are waiting for data.
- It requires micro-batching, which decreases the effective batch size, the arithmetic intensity, and ultimately the model FLOP utilization.
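The tension between these two downsides can be made concrete with a small Python sketch. It assumes a simple GPipe-style schedule in which every stage takes the same time per micro-batch, so the idle fraction is (stages - 1) / (micro-batches + stages - 1). Other schedules behave differently, and the batch size used here is hypothetical.

```python
def bubble_fraction(num_stages: int, num_micro_batches: int) -> float:
    """Fraction of time chips sit idle, for a GPipe-style schedule (assumption)."""
    return (num_stages - 1) / (num_micro_batches + num_stages - 1)


def micro_batch_tokens(batch_tokens: int, num_micro_batches: int) -> float:
    """Tokens processed per micro-batch; smaller values lower arithmetic intensity."""
    return batch_tokens / num_micro_batches


BATCH_TOKENS = 262_144  # hypothetical batch shared by one 4-stage pipeline

for micro_batches in (4, 16, 64):
    idle = bubble_fraction(num_stages=4, num_micro_batches=micro_batches)
    tokens = micro_batch_tokens(BATCH_TOKENS, micro_batches)
    print(f"{micro_batches:>2} micro-batches: {idle:.1%} idle, "
          f"{tokens:,.0f} tokens per micro-batch")
```

More micro-batches shrink the bubble but also shrink each micro-batch, so neither downside can be removed without aggravating the other.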
Pipeline parallelism should be used only if the other parallelism strategies require too large a global batch size. Before trying pipeline parallelism, it is worth checking empirically whether convergence per sample slows down at the batch size necessary to achieve high-performance FSDP. FSDP tends to achieve higher model FLOP utilization, but if the convergence per sample slows as the batch size grows, pipeline parallelism may still be the better choice. Most workloads can tolerate batch sizes large enough that they don't benefit from pipeline parallelism, but your workload may be different.
If pipeline parallelism is necessary, we recommend combining it with data parallelism or FSDP. This will enable you to minimize the pipeline depth while increasing the per pipeline batch size until DCN latency becomes less of a factor in throughput. Concretely, if you have N slices, consider pipelines of depth 2 and N/2 replicas of data parallelism, then pipelines of depth 4 and N/4 replicas of data parallelism and so on, until the batch per pipeline gets large enough that the DCN collectives can be hidden behind the arithmetic in the backwards pass. This will minimize the slowdown introduced by pipeline parallelism while allowing you to scale past the global batch size limit.
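The search described above can be sketched in Python. The function names are illustrative, and `min_tokens_per_pipeline` is a hypothetical threshold standing in for "large enough that the DCN collectives can be hidden"; in practice you would find that value by profiling.

```python
from typing import Iterator, Optional, Tuple


def pipeline_candidates(num_slices: int) -> Iterator[Tuple[int, int]]:
    """Yield (pipeline_depth, data_parallel_replicas) for depths 2, 4, 8, ..."""
    depth = 2
    while depth <= num_slices:
        if num_slices % depth == 0:  # slices must divide evenly into pipelines
            yield depth, num_slices // depth
        depth *= 2


def shallowest_viable_pipeline(
    num_slices: int, global_batch_tokens: int, min_tokens_per_pipeline: int
) -> Optional[Tuple[int, int]]:
    """Return the first (depth, replicas) whose per-pipeline batch is large enough."""
    for depth, replicas in pipeline_candidates(num_slices):
        if global_batch_tokens / replicas >= min_tokens_per_pipeline:
            return depth, replicas
    return None  # even the deepest pipeline does not reach the threshold


choice = shallowest_viable_pipeline(
    num_slices=8, global_batch_tokens=1_000_000, min_tokens_per_pipeline=350_000
)
print(choice)  # prints "(4, 2)"
```

In this example, depth 2 leaves only 250,000 tokens per pipeline, so the search moves on to depth 4 with two data-parallel replicas, which gives each pipeline 500,000 tokens.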
Multislice best practices
Data loading
During training, we repeatedly load batches from a dataset to feed into the model. Having an efficient, asynchronous data loader that shards the batch across hosts is important to avoid starving the TPUs of work. The current data loader in MaxText has each host load an equal subset of the examples. This solution is adequate for text but requires a reshard within the model. Additionally, MaxText doesn't yet offer deterministic snapshotting, which would allow the data iterator to load the same data before and after preemption.
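The idea of each host loading an equal subset of the examples can be sketched as follows. This is not MaxText's data loader: the function name and the contiguous-block assignment are assumptions made for illustration.

```python
def host_example_indices(
    global_batch_size: int, num_hosts: int, host_id: int
) -> range:
    """Indices of the examples in a global batch that one host should load."""
    if global_batch_size % num_hosts != 0:
        raise ValueError("global batch must divide evenly across hosts")
    if not 0 <= host_id < num_hosts:
        raise ValueError("host_id out of range")
    per_host = global_batch_size // num_hosts
    start = host_id * per_host
    return range(start, start + per_host)


# Usage: a global batch of 16 examples split across 4 hosts.
for host in range(4):
    print(host, list(host_example_indices(16, 4, host)))
```

Every example is loaded by exactly one host, so no host reads data it doesn't need and the per-host input work stays constant as hosts are added.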
Checkpointing
The Orbax checkpointing library provides primitives for checkpointing JAX PyTrees to local storage or Cloud Storage. We provide a reference integration with synchronous checkpointing into MaxText in checkpointing.py.
Supported configurations
Shapes
All slices must be of the same shape (for example, the same AcceleratorType). Heterogeneous slice shapes are not supported.
Orchestration
Orchestration is supported with GKE. For more information, see TPUs in GKE.
Frameworks
Multislice supports only JAX and PyTorch workloads.
Parallelism
We recommend users test Multislice with data parallelism. To learn more about implementing pipeline parallelism with Multislice, contact your Google Cloud account representative.
Support and Feedback
We welcome all feedback! To share feedback or request support, reach out to us using the Cloud TPU Support or Feedback form.