Scale ML workloads using Ray
This document provides details on how to run machine learning (ML) workloads with Ray and JAX on TPUs.
These instructions assume that you already have a Ray and TPU environment set up, including a software environment that includes JAX and other related packages. To create a Ray TPU cluster, follow the instructions at Start Google Cloud GKE Cluster with TPUs for KubeRay. For more information about using TPUs with KubeRay, see Use TPUs with KubeRay.
Run a JAX workload on a single-host TPU
The following example script demonstrates how to run a JAX function on a Ray cluster with a single-host TPU, such as a v6e-4. If you have a multi-host TPU, this script stops responding due to JAX's multi-controller execution model. For more information about running Ray on a multi-host TPU, see Run a JAX workload on a multi-host TPU.
import ray
import jax
@ray.remote(resources={"TPU": 4})
def my_function() -> int:
return jax.device_count()
h = my_function.remote()
print(ray.get(h)) # => 4
If you're used to running Ray with GPUs, there are some key differences when using TPUs:
- Rather than setting num_gpus, you specify TPU as a custom resource and set the number of TPU chips.
- You specify the TPU using the number of chips per Ray worker node. For example, if you're using a v6e-4, running a remote function with TPU set to 4 consumes the entire TPU host.
  - This is different from how GPUs typically run, with one process per host. Setting TPU to a number that isn't 4 is not recommended.
  - Exception: If you have a single-host v6e-8 or v5litepod-8, you should set this value to 8, as shown in the sketch after this list.
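For instance, a minimal sketch for a single-host v6e-8 (assuming your Ray worker advertises all 8 chips under the TPU custom resource) looks like the following:
import ray
import jax

# On a single-host v6e-8 or v5litepod-8, one task consumes all 8 chips.
@ray.remote(resources={"TPU": 8})
def my_function() -> int:
    return jax.device_count()

h = my_function.remote()
print(ray.get(h))  # => 8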
Run a JAX workload on a multi-host TPU
The following example script demonstrates how to run a JAX function on a Ray cluster with a multi-host TPU. The example script uses a v6e-16.
If you want to run your workload on a cluster with multiple TPU slices, see Control individual TPU slices.
import ray
import jax
@ray.remote(resources={"TPU": 4})
def my_function() -> int:
return jax.device_count()
num_tpus = ray.available_resources()["TPU"]
num_hosts = int(num_tpus) // 4
h = [my_function.remote() for _ in range(num_hosts)]
print(ray.get(h)) # [16, 16, 16, 16]
If you're used to running Ray with GPUs, there are some key differences when using TPUs:
- Similar to PyTorch workloads on GPUs:
  - JAX workloads on TPUs run in a multi-controller, single program, multiple data (SPMD) fashion.
  - Collectives between devices are handled by the machine learning framework.
- Unlike PyTorch workloads on GPUs, JAX has a global view of the available devices in the cluster, as illustrated in the sketch after this list.
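To illustrate that global view, a small variation of the example above (a sketch, assuming the same v6e-16 with 4 chips per host) reports both the global and per-host device counts:
import ray
import jax

@ray.remote(resources={"TPU": 4})
def report_devices():
    # Global view: every chip in the slice; local view: chips attached to this host.
    return jax.device_count(), jax.local_device_count()

num_hosts = int(ray.available_resources()["TPU"]) // 4
print(ray.get([report_devices.remote() for _ in range(num_hosts)]))
# => [(16, 4), (16, 4), (16, 4), (16, 4)]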
Run a Multislice JAX workload
Multislice lets you run workloads that span multiple TPU slices within a single TPU Pod or in multiple Pods over the data center network.
For convenience, you can use the experimental ray-tpu package to simplify Ray's interactions with TPU slices. Install ray-tpu using pip:
pip install ray-tpu
The following example script shows how to use the ray-tpu package to run Multislice workloads using Ray actors or tasks:
from ray_tpu import RayTpuManager
import jax
import ray
ray.init()
# note - don't set resources as they will be overridden
@ray.remote
class MyActor:
def get_devices(self):
return jax.device_count()
# note - don't set resources as they will be overridden
@ray.remote
def get_devices() -> int:
return jax.device_count()
tpus = RayTpuManager.get_available_resources()
print("TPU resources: ", tpus)
"""
TPU resources:
{'v6e-16': [
RayTpu(name='tpu-group-1', num_hosts=4, head_ip='10.36.3.5', topology='v6e-16'),
RayTpu(name='tpu-group-0', num_hosts=4, head_ip='10.36.10.7', topology='v6e-16')
]}
"""
# if using actors
actors = RayTpuManager.remote(
tpus=tpus["v6e-16"],
actor_or_fn=MyActor,
multislice=True,
)
h = [actor.get_devices.remote() for actor in actors]
ray.get(h) # => [32, 32, 32, 32, 32, 32, 32, 32]
# if using tasks
h = RayTpuManager.remote(
tpus=tpus["v6e-16"],
actor_or_fn=get_devices,
multislice=True,
)
ray.get(h) # [32, 32, 32, 32, 32, 32, 32, 32]
# note - you can also run this without Multislice
h = RayTpuManager.run_task(
tpus=tpus["v6e-16"],
actor_or_fn=get_devices,
multislice=False,
)
ray.get(h) # => [16, 16, 16, 16, 16, 16, 16, 16]
TPU and Ray resources
Ray treats TPUs differently from GPUs to accommodate the difference in usage. In the following example, there are nine Ray nodes in total:
- The Ray head node is running on an n1-standard-16 VM.
- The Ray worker nodes are running on two v6e-16 TPUs. Each TPU constitutes four workers.
$ ray status
======== Autoscaler status: 2024-10-17 09:30:00.854415 ========
Node status
---------------------------------------------------------------
Active:
1 node_e54a65b81456cee40fcab16ce7b96f85406637eeb314517d9572dab2
1 node_9a8931136f8d2ab905b07d23375768f41f27cc42f348e9f228dcb1a2
1 node_c865cf8c0f7d03d4d6cae12781c68a840e113c6c9b8e26daeac23d63
1 node_435b1f8f1fbcd6a4649c09690915b692a5bac468598e9049a2fac9f1
1 node_3ed19176e9ecc2ac240c818eeb3bd4888fbc0812afebabd2d32f0a91
1 node_6a88fe1b74f252a332b08da229781c3c62d8bf00a5ec2b90c0d9b867
1 node_5ead13d0d60befd3a7081ef8b03ca0920834e5c25c376822b6307393
1 node_b93cb79c06943c1beb155d421bbd895e161ba13bccf32128a9be901a
1 node_9072795b8604ead901c5268ffcc8cc8602c662116ac0a0272a7c4e04
Pending:
(no pending nodes)
Recent failures:
(no failures)
Resources
---------------------------------------------------------------
Usage:
0.0/727.0 CPU
0.0/32.0 TPU
0.0/2.0 TPU-v6e-16-head
0B/5.13TiB memory
0B/1.47TiB object_store_memory
0.0/4.0 tpu-group-0
0.0/4.0 tpu-group-1
Demands:
(no resource demands)
Resource usage field descriptions:
- CPU: The total number of CPUs available in the cluster.
- TPU: The number of TPU chips in the cluster.
- TPU-v6e-16-head: A special identifier for the resource that corresponds with worker 0 of a TPU slice. This is important for accessing individual TPU slices.
- memory: Worker heap memory used by your application.
- object_store_memory: Memory used when your application creates objects in the object store using ray.put and when it returns values from remote functions.
- tpu-group-0 and tpu-group-1: Unique identifiers for the individual TPU slices. This is important for running jobs on slices. These fields are set to 4 because there are 4 hosts per TPU slice in a v6e-16.
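If you need these values in code, for example to decide how many hosts to dispatch to, you can read them through Ray's standard resource APIs. The following is a small sketch; the tpu-group- prefix matches the slice names shown above and may differ in your cluster:
import ray

ray.init()

# Total resources registered across the cluster (matches the `ray status` output).
total = ray.cluster_resources()
# Resources not currently claimed by tasks or actors.
available = ray.available_resources()

print("Total TPU chips:", total.get("TPU", 0))
print("Available TPU chips:", available.get("TPU", 0))

# Custom labels such as tpu-group-0 and tpu-group-1 identify individual slices.
slice_labels = [name for name in total if name.startswith("tpu-group-")]
print("TPU slices:", slice_labels)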
Control individual TPU slices
A common practice with Ray and TPUs is to run several workloads within the same TPU slice, for example, in hyperparameter tuning or serving.
TPU slices require special consideration when using Ray for both provisioning and job scheduling.
Run single-slice workloads
When the Ray process starts on TPU slices (running ray start), the process auto-detects information about the slice, such as the topology, the number of workers in the slice, and whether the process is running on worker 0.
When you run ray status on a TPU v6e-16 with the name "my-tpu", the output looks similar to the following:
worker 0: {"TPU-v6e-16-head": 1, "TPU": 4, "my-tpu": 1}
worker 1-3: {"TPU": 4, "my-tpu": 1}
"TPU-v6e-16-head"
is the resource label for worker 0 of the slice.
"TPU": 4
indicates that each worker
has 4 chips. "my-tpu"
is the name of the TPU. You can use these values to run
a workload on TPUs within the same slice, as in the following example.
Assume that you want to run the following function on all workers in a slice:
@ray.remote
def my_function():
return jax.device_count()
You need to target worker 0 of the slice, then tell worker 0 how to broadcast my_function to every worker in the slice:
@ray.remote(resources={"TPU-v6e-16-head": 1})
def run_on_pod(remote_fn):
tpu_name = ray.util.accelerators.tpu.get_current_pod_name() # -> returns my-tpu
num_hosts = ray.util.accelerators.tpu.get_current_pod_worker_count() # -> returns 4
remote_fn = remote_fn.options(resources={tpu_name: 1, "TPU": 4}) # required resources are {"my-tpu": 1, "TPU": 4}
return ray.get([remote_fn.remote() for _ in range(num_hosts)])
h = run_on_pod.remote(my_function) # -> returns a single remote handle
ray.get(h) # -> returns [16, 16, 16, 16]
The example performs the following steps:
- @ray.remote(resources={"TPU-v6e-16-head": 1}): The run_on_pod function runs on a worker that has the resource label TPU-v6e-16-head, which targets an arbitrary worker 0.
- tpu_name = ray.util.accelerators.tpu.get_current_pod_name(): Get the TPU name.
- num_hosts = ray.util.accelerators.tpu.get_current_pod_worker_count(): Get the number of workers in the slice.
- remote_fn = remote_fn.options(resources={tpu_name: 1, "TPU": 4}): Add the resource label containing the TPU name and the "TPU": 4 resource requirement to the function my_function.
  - Because each worker in the TPU slice has a custom resource label for the slice that it's in, Ray will only schedule the workload on workers within the same TPU slice.
  - This also reserves 4 TPU workers for the remote function, so Ray won't schedule other TPU workloads on that Ray Pod.
  - Because run_on_pod only uses the TPU-v6e-16-head logical resource, my_function will also run on worker 0 but in a different process.
- return ray.get([remote_fn.remote() for _ in range(num_hosts)]): Invoke the modified my_function a number of times equal to the number of workers and return the results.
- h = run_on_pod.remote(my_function): run_on_pod executes asynchronously and doesn't block the main process.
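Because each worker also carries its slice's name as a custom resource (tpu-group-0 and tpu-group-1 in the ray status output above), you can pin the dispatcher to one specific slice by requesting that name alongside the head resource. The following is a sketch that assumes those slice names; substitute the names your own cluster reports:
import ray
import jax

@ray.remote
def my_function():
    return jax.device_count()

# Pin the dispatcher to worker 0 of the slice named "tpu-group-0".
@ray.remote(resources={"TPU-v6e-16-head": 1, "tpu-group-0": 1})
def run_on_named_pod(remote_fn):
    tpu_name = ray.util.accelerators.tpu.get_current_pod_name()
    num_hosts = ray.util.accelerators.tpu.get_current_pod_worker_count()
    remote_fn = remote_fn.options(resources={tpu_name: 1, "TPU": 4})
    return ray.get([remote_fn.remote() for _ in range(num_hosts)])

h = run_on_named_pod.remote(my_function)
print(ray.get(h))  # -> [16, 16, 16, 16] if tpu-group-0 is a v6e-16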
TPU slice autoscaling
Ray on TPUs supports autoscaling at the granularity of a TPU slice. You can enable this feature using the GKE node auto provisioning (NAP) feature, and you can drive it with the Ray Autoscaler and KubeRay. The head resource type is used to signal autoscaling to Ray, for example, TPU-v6e-32-head.
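As a rough sketch of that scheduling signal (assuming a KubeRay cluster configured with an autoscaling v6e-32 worker group), a task that requests the head resource creates the demand the autoscaler acts on:
import ray

# A pending demand for the slice's head resource signals the Ray Autoscaler
# (backed by GKE node auto provisioning) to provision a v6e-32 slice if none
# is currently available.
@ray.remote(resources={"TPU-v6e-32-head": 1})
def start_slice_workload() -> str:
    # From worker 0, dispatch work to the rest of the slice as shown earlier.
    return "running on worker 0 of a v6e-32 slice"

h = start_slice_workload.remote()
print(ray.get(h))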