Scale ML workloads using Ray

This document provides details on how to run machine learning (ML) workloads with Ray and JAX on TPUs.

These instructions assume that you already have a Ray and TPU environment set up, including a software environment that includes JAX and other related packages. To create a Ray TPU cluster, follow the instructions at Start Google Cloud GKE Cluster with TPUs for KubeRay. For more information about using TPUs with KubeRay, see Use TPUs with KubeRay.

Run a JAX workload on a single-host TPU

The following example script demonstrates how to run a JAX function on a Ray cluster with a single-host TPU, such as a v6e-4. If you have a multi-host TPU, this script stops responding due to JAX's multi-controller execution model. For more information about running Ray on a multi-host TPU, see Run a JAX workload on a multi-host TPU.

import ray
import jax

# Reserve all 4 chips on the single-host TPU (for example, a v6e-4).
@ray.remote(resources={"TPU": 4})
def my_function() -> int:
    return jax.device_count()

h = my_function.remote()
print(ray.get(h)) # => 4

If you're used to running Ray with GPUs, there are some key differences when using TPUs:

  • Rather than setting num_gpus, you specify TPU as a custom resource and set the number of TPU chips.
  • You specify the TPU using the number of chips per Ray worker node. For example, if you're using a v6e-4, running a remote function with TPU set to 4 consumes the entire TPU host.
    • This differs from the typical GPU setup of one process per device: with TPUs, a single Ray task or actor drives all of the chips on the host. Setting TPU to a number that isn't 4 is not recommended.
    • Exception: If you have a single-host v6e-8 or v5litepod-8, set this value to 8 instead, as shown in the sketch after this list.
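
The following is a minimal sketch of the v6e-8 case, assuming a single-host v6e-8 (or v5litepod-8) Ray cluster is already running:

import ray
import jax

# On a single-host v6e-8 or v5litepod-8, request all 8 chips so the task
# reserves the entire TPU host.
@ray.remote(resources={"TPU": 8})
def my_function() -> int:
    return jax.device_count()

h = my_function.remote()
print(ray.get(h)) # => 8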

Run a JAX workload on a multi-host TPU

The following example script demonstrates how to run a JAX function on a Ray cluster with a multi-host TPU. The example script uses a v6e-16.

If you want to run your workload on a cluster with multiple TPU slices, see Control individual TPU slices.

import ray
import jax

@ray.remote(resources={"TPU": 4})
def my_function() -> int:
    return jax.device_count()

# Each v6e host exposes 4 chips, so the number of hosts is the total
# TPU chip count divided by 4.
num_tpus = ray.available_resources()["TPU"]
num_hosts = int(num_tpus) // 4

# Launch one task per host; each task sees all 16 devices in the slice.
h = [my_function.remote() for _ in range(num_hosts)]
print(ray.get(h)) # => [16, 16, 16, 16]

Run a Multislice JAX workload

Multislice lets you run workloads that span multiple TPU slices within a single TPU Pod or in multiple Pods over the data center network.

For convenience, you can use the experimental ray-tpu package to simplify Ray's interactions with TPU slices. Install ray-tpu using pip:

pip install ray-tpu

The following example script shows how to use the ray-tpu package to run Multislice workloads using Ray actors or tasks:

from ray_tpu import RayTpuManager
import jax
import ray

ray.init()

# note - don't set resources as they will be overridden
@ray.remote
class MyActor:
    def get_devices(self):
        return jax.device_count()

# note - don't set resources as they will be overridden
@ray.remote
def get_devices() -> int:
    return jax.device_count()

tpus = RayTpuManager.get_available_resources()
print("TPU resources: ", tpus) 
"""
TPU resources:
{'v6e-16': [
    RayTpu(name='tpu-group-1', num_hosts=4, head_ip='10.36.3.5', topology='v6e-16'),
    RayTpu(name='tpu-group-0', num_hosts=4, head_ip='10.36.10.7', topology='v6e-16')
]}
"""

# if using actors
actors = RayTpuManager.remote(
    tpus=tpus["v6e-16"],
    actor_or_fn=MyActor,
    multislice=True,
)
h = [actor.get_devices.remote() for actor in actors]
ray.get(h) # => [32, 32, 32, 32, 32, 32, 32, 32]

# if using tasks
h = RayTpuManager.remote(
    tpus=tpus["v6e-16"],
    actor_or_fn=get_devices,
    multislice=True,
)
ray.get(h) # => [32, 32, 32, 32, 32, 32, 32, 32]

# note - you can also run this without Multislice
h = RayTpuManager.run_task(
    tpus=tpus["v6e-16"],
    actor_or_fn=get_devices,
    multislice=False,
)
ray.get(h) # => [16, 16, 16, 16, 16, 16, 16, 16]

TPU and Ray resources

Ray treats TPUs differently from GPUs to accommodate the difference in usage. In the following example, there are nine Ray nodes in total:

  • The Ray head node is running on an n1-standard-16 VM.
  • The Ray worker nodes are running on two v6e-16 TPUs. Each TPU slice contributes four Ray worker nodes.
$ ray status
======== Autoscaler status: 2024-10-17 09:30:00.854415 ========
Node status
---------------------------------------------------------------
Active:
 1 node_e54a65b81456cee40fcab16ce7b96f85406637eeb314517d9572dab2
 1 node_9a8931136f8d2ab905b07d23375768f41f27cc42f348e9f228dcb1a2
 1 node_c865cf8c0f7d03d4d6cae12781c68a840e113c6c9b8e26daeac23d63
 1 node_435b1f8f1fbcd6a4649c09690915b692a5bac468598e9049a2fac9f1
 1 node_3ed19176e9ecc2ac240c818eeb3bd4888fbc0812afebabd2d32f0a91
 1 node_6a88fe1b74f252a332b08da229781c3c62d8bf00a5ec2b90c0d9b867
 1 node_5ead13d0d60befd3a7081ef8b03ca0920834e5c25c376822b6307393
 1 node_b93cb79c06943c1beb155d421bbd895e161ba13bccf32128a9be901a
 1 node_9072795b8604ead901c5268ffcc8cc8602c662116ac0a0272a7c4e04
Pending:
 (no pending nodes)
Recent failures:
 (no failures)

Resources
---------------------------------------------------------------
Usage:
 0.0/727.0 CPU
 0.0/32.0 TPU
 0.0/2.0 TPU-v6e-16-head
 0B/5.13TiB memory
 0B/1.47TiB object_store_memory
 0.0/4.0 tpu-group-0
 0.0/4.0 tpu-group-1

Demands:
 (no resource demands)

Resource usage field descriptions:

  • CPU: The total number of CPUs available in the cluster.
  • TPU: The number of TPU chips in the cluster.
  • TPU-v6e-16-head: A special identifier for the resource that corresponds with worker 0 of a TPU slice. This is important for accessing individual TPU slices.
  • memory: Worker heap memory used by your application.
  • object_store_memory: Memory used when your application creates objects in the object store using ray.put and when it returns values from remote functions.
  • tpu-group-0 and tpu-group-1: Unique identifiers for the individual TPU slices. This is important for running jobs on specific slices. These fields are set to 4 because there are 4 hosts per TPU slice in a v6e-16. You can query these resources programmatically, as shown in the sketch after this list.
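
The following is a minimal sketch of how you might inspect these resources from a driver script, assuming the two-slice v6e-16 cluster shown above:

import ray

ray.init()

resources = ray.available_resources()

# Each TPU slice exposes one "<accelerator>-head" resource, so the head
# count equals the number of slices in the cluster.
num_slices = int(resources.get("TPU-v6e-16-head", 0))

# Each v6e-16 host exposes 4 chips, so this recovers the total host count.
num_hosts = int(resources.get("TPU", 0)) // 4

print(f"slices={num_slices}, hosts={num_hosts}") # => slices=2, hosts=8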

Control individual TPU slices

A common practice with Ray and TPUs is to run several workloads within the same TPU slice, for example, in hyperparameter tuning or serving.

TPU slices require special consideration when using Ray for both provisioning and job scheduling.

Run single-slice workloads

When the Ray process starts on a TPU slice (running ray start), the process auto-detects information about the slice, such as the topology, the number of workers in the slice, and whether the process is running on worker 0.

For a TPU v6e-16 with the name "my-tpu", the per-worker resource labels look similar to the following:

worker 0: {"TPU-v6e-16-head": 1, "TPU": 4, "my-tpu": 1}
worker 1-3: {"TPU": 4, "my-tpu": 1}

"TPU-v6e-16-head" is the resource label for worker 0 of the slice. "TPU": 4 indicates that each worker has 4 chips. "my-tpu" is the name of the TPU. You can use these values to run a workload on TPUs within the same slice, as in the following example.

Assume that you want to run the following function on all workers in a slice:

# Runs on each worker in the slice; each call returns the total device
# count that JAX detects across the slice (16 for a v6e-16).
@ray.remote
def my_function():
    return jax.device_count()

You need to target worker 0 of the slice, then tell worker 0 how to broadcast my_function to every worker in the slice:

@ray.remote(resources={"TPU-v6e-16-head": 1})
def run_on_pod(remote_fn):
    tpu_name = ray.util.accelerators.tpu.get_current_pod_name()  # -> returns my-tpu
    num_hosts = ray.util.accelerators.tpu.get_current_pod_worker_count() # -> returns 4
    remote_fn = remote_fn.options(resources={tpu_name: 1, "TPU": 4}) # required resources are {"my-tpu": 1, "TPU": 4}
    return ray.get([remote_fn.remote() for _ in range(num_hosts)])

h = run_on_pod.remote(my_function) # -> returns a single remote handle
ray.get(h) # -> returns [16] * 4

The example performs the following steps:

  • @ray.remote(resources={"TPU-v6e-16-head": 1}): The run_on_pod function runs on a worker that has the resource label TPU-v6e-16-head, which targets worker 0 of an arbitrary slice.
  • tpu_name = ray.util.accelerators.tpu.get_current_pod_name(): Get the TPU name.
  • num_hosts = ray.util.accelerators.tpu.get_current_pod_worker_count(): Get the number of workers in the slice.
  • remote_fn = remote_fn.options(resources={tpu_name: 1, "TPU": 4}): Add the resource label containing the TPU name and the "TPU": 4 resource requirement to the function my_function.
    • Because each worker in the TPU slice has a custom resource label for the slice that it's in, Ray only schedules the workload on workers within the same TPU slice. You can also use these labels to pin work to a specific slice, as shown in the sketch after this list.
    • This also reserves 4 TPU workers for the remote function, so Ray won't schedule other TPU workloads on that Ray Pod.
    • Because run_on_pod only uses the TPU-v6e-16-head logical resource, my_function will also run on worker 0 but in a different process.
  • return ray.get([remote_fn.remote() for _ in range(num_hosts)]): Invoke the modified my_function function a number of times equal to the number of workers and return the results.
  • h = run_on_pod.remote(my_function): run_on_pod executes asynchronously and doesn't block the main process.
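
If the cluster contains more than one slice (for example, tpu-group-0 and tpu-group-1 from the earlier ray status output), you can pin run_on_pod to a particular slice by also requesting that slice's name resource. The following is a minimal sketch, assuming the two-slice v6e-16 cluster shown earlier; the slice name is an example:

# Pin the dispatcher to worker 0 of a specific slice by combining the head
# resource with that slice's name resource ("tpu-group-0" is an example
# name taken from the ray status output above).
run_on_slice_0 = run_on_pod.options(
    resources={"TPU-v6e-16-head": 1, "tpu-group-0": 1}
)

h = run_on_slice_0.remote(my_function)
ray.get(h) # -> returns [16] * 4, all from hosts in tpu-group-0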

TPU slice autoscaling

Ray on TPUs supports autoscaling at the granularity of a TPU slice. You can enable this feature with the GKE node auto provisioning (NAP) feature and drive it with the Ray Autoscaler and KubeRay. The head resource type, for example TPU-v6e-32-head, is used to signal autoscaling to Ray.
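
The following is a minimal sketch of how a scale-up request might look from the workload side, assuming a KubeRay cluster configured for TPU autoscaling with NAP; the resource name is an example and the exact provisioning path depends on your GKE setup:

import ray

ray.init()

# Requesting a slice's head resource creates a pending resource demand.
# With the Ray Autoscaler, KubeRay, and GKE node auto provisioning
# configured, this demand can trigger provisioning of a new v6e-32 slice.
@ray.remote(resources={"TPU-v6e-32-head": 1})
def run_on_new_slice() -> str:
    import socket
    return socket.gethostname()

h = run_on_new_slice.remote()
print(ray.get(h)) # blocks until a v6e-32 slice is available and scheduled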