This page shows you how to deploy workloads in Google Kubernetes Engine (GKE) by using the Cloud TPU Multislice configuration for cost-effective, large-scale training.
Before you configure Multislice in GKE, ensure that you're familiar with the following concepts:
What's TPU Multislice
TPU Multislice is the architectural organization of VMs in a TPU slice where two or more Cloud TPU slices communicate over the Data Center Network (DCN). Multislice enables full-stack, cost-effective, large-scale training with near-linear scaling up to tens of thousands of TPU chips. In a Multislice configuration, GKE deploys a Multislice workload on multiple TPU slices. The communication between TPU chips within a slice happens over inter-chip interconnects (ICI). The communication between slices happens over the DCN.
We recommend that you use Multislice if your Job is too big to fit on a single TPU slice.
Multislice availability in GKE
- Standard supports Multislice in version 1.27.4-gke.900 and later.
- Autopilot supports Multislice in version 1.29.2-gke.1521000 and later.
- Multislice supports JAX and PyTorch frameworks. The minimum supported JAX version is 2.1.
- Multislice only supports multi-host TPU slice node pools. For example, you cannot use Multislice with a ct4p-hightpu-4t with a 2x2x1 topology or a ct5lp-hightpu-4t with a 2x2 topology, because these are single-host TPU slice node pools.
- Multislice only supports synchronous multicontroller training.
- Multislice workloads can only run across TPU slices that share the same TPU type, size, and topology.
- Multislice doesn't support TPU v3.
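If you already have a cluster, you can check which GKE version its control plane runs before relying on these minimum versions. For example, with the standard describe command (replace the placeholders with your cluster name and location):

gcloud container clusters describe CLUSTER_NAME \
    --location=LOCATION \
    --format="value(currentMasterVersion)"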
Before you begin
Before you start, make sure you have performed the following tasks:
- Enable the Google Kubernetes Engine API.
- If you want to use the Google Cloud CLI for this task, install and then initialize the gcloud CLI. If you previously installed the gcloud CLI, get the latest version by running gcloud components update.
- Create a Standard cluster or an Autopilot cluster that runs a version that supports Multislice. For supported versions, see Multislice availability in GKE.
- Ensure your project has sufficient quota for Cloud TPU in GKE.
- Install JobSet v0.2.3 or later.
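For the JobSet prerequisite, one way to install the controller is to apply the upstream release manifest from kubernetes-sigs. This is a sketch that assumes the standard release URL; check the JobSet releases page for the version you want:

# Installs the JobSet CRDs and controller from the upstream release manifest.
kubectl apply --server-side -f https://github.com/kubernetes-sigs/jobset/releases/download/v0.2.3/manifests.yaml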
Run a workload on a Multislice
This section shows you how to run a workload on a Multislice. If you use GKE Autopilot mode, skip to the Run a Multislice workload section. Autopilot clusters that run version 1.29.2-gke.1521000 or later enable TPUs by default.
Prepare a Standard mode node pool
This section covers the following steps:
- Create three multi-host TPU slice node pools
- Verify the node pool status
Create the TPU slice node pool
You can create more than one multi-host TPU slice node pool. For the purpose of this guide, create three multi-host TPU slice node pools to run a Multislice workload. You can create a multi-host TPU slice node pool using the Google Cloud CLI, Terraform, or the Google Cloud console.
gcloud
gcloud container node-pools create POOL_NAME \
    --location=LOCATION \
    --cluster=CLUSTER_NAME \
    --node-locations=NODE_ZONE \
    --machine-type=MACHINE_TYPE \
    --tpu-topology=TPU_TOPOLOGY \
    --num-nodes=NUM_NODES \
    [--spot \]
    [--enable-autoscaling \
     --max-nodes MAX_NODES] \
    [--reservation-affinity=specific \
     --reservation=RESERVATION_NAME]
Replace the following:
- POOL_NAME: The name of the new node pool.
- LOCATION: The name of the zone based on the TPU version you want to use. To identify an available location, see TPU availability in GKE.
- CLUSTER_NAME: The name of the cluster.
- NODE_ZONE: The comma-separated list of one or more zones where GKE creates the node pool.
- MACHINE_TYPE: The type of machine to use for nodes. To learn more about the available machine types, see Choose the TPU version.
- TPU_TOPOLOGY: The physical topology for the TPU slice. The format of the topology depends on the TPU version. To learn more about TPU topologies, use the table in Choose a topology. To learn more, see Topology.
- NUM_NODES: The number of nodes in the node pool. It must be zero or the product of the values defined in TPU_TOPOLOGY ({A}x{B}x{C}) divided by the number of chips in each VM. For multi-host TPU v4 and TPU v5e, the number of chips in each VM is four. Therefore, if your TPU_TOPOLOGY is 2x4x4 (TPU v4 with four chips in each VM), then NUM_NODES is 32/4, which equals 8. (A worked example of this calculation follows the optional flags list below.)
Optionally, you can also use the following flags:
- RESERVATION_NAME: The name of the reservation GKE uses when creating the node pool. If you omit this flag, GKE uses available TPU slice node pools. To learn more about TPU reservations, see TPU reservation.
- --spot: Sets the node pool to use Spot VMs for the TPU slice nodes. This cannot be changed after node pool creation. For more information, see Spot VMs.
- --enable-autoscaling: Create a node pool with autoscaling enabled. When GKE scales a multi-host TPU slice node pool, it atomically scales up the node pool from zero to the maximum size.
- MAX_NODES: The maximum size of the node pool. The --max-nodes flag is required if --enable-autoscaling is supplied and must be equal to the product of the values defined in TPU_TOPOLOGY ({A}x{B}x{C}) divided by the number of chips in each VM.
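As a quick sanity check on the NUM_NODES and MAX_NODES arithmetic, the following sketch derives the node count from an example topology. The values are illustrative only; adjust the chips-per-VM count for your machine type:

# Example only: derive the node count from a TPU topology string,
# assuming 4 TPU chips per VM (multi-host TPU v4 in this example).
TPU_TOPOLOGY="2x4x4"
CHIPS_PER_VM=4
TOTAL_CHIPS=$(( ${TPU_TOPOLOGY//x/*} ))       # 2*4*4 = 32 chips
NUM_NODES=$(( TOTAL_CHIPS / CHIPS_PER_VM ))   # 32/4 = 8 nodes
echo "topology=${TPU_TOPOLOGY} chips=${TOTAL_CHIPS} nodes=${NUM_NODES}"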
Terraform
- Ensure that you use version 4.84.0 or later of the google provider. Add the following block to your Terraform configuration:
resource "google_container_node_pool" "NODE_POOL_RESOURCE_NAME" { provider = google project = PROJECT_ID cluster = CLUSTER_NAME name = POOL_NAME location = CLUSTER_LOCATION node_locations = [NODE_ZONES] initial_node_count = NUM_NODES autoscaling { max_node_count = MAX_NODES location_policy = "ANY" } node_config { machine_type = MACHINE_TYPE reservation_affinity { consume_reservation_type = "SPECIFIC_RESERVATION" key = "compute.googleapis.com/reservation-name" values = [RESERVATION_LABEL_VALUES] } spot = true } placement_policy { type = "COMPACT" tpu_topology = TPU_TOPOLOGY } }
Replace the following:
- NODE_POOL_RESOURCE_NAME: The name of the node pool resource in the Terraform template.
- PROJECT_ID: Your project ID.
- CLUSTER_NAME: The name of the existing cluster to add the node pool to.
- POOL_NAME: The name of the node pool to create.
- CLUSTER_LOCATION: Compute location for the cluster. We recommend having a regional cluster for higher reliability of the Kubernetes control plane. You can also use a zonal cluster. To learn more, see Select a TPU version and topology.
- NODE_ZONES: The comma-separated list of one or more zones where GKE creates the node pool.
- NUM_NODES: The number of nodes in the node pool. It must be zero or the number of TPU chips divided by four, because in multi-host TPU slices each TPU slice node has 4 chips. For example, if TPU_TOPOLOGY is 4x8, then there are 32 chips, which means NUM_NODES must be 8. To learn more about TPU topologies, use the table in Choose the TPU version.
- TPU_TOPOLOGY: This indicates the desired physical topology for the TPU slice. The format of the topology depends on the TPU version you are using. To learn more about TPU topologies, use the table in Choose a topology.
Optionally, you can also use the following variables:
- RESERVATION_NAME: If you use TPU reservation, this is the list of labels of the reservation resources to use when creating the node pool. To learn more about how to populate the RESERVATION_LABEL_VALUES in the reservation_affinity field, see Terraform Provider.
- autoscaling: Create a node pool with autoscaling enabled. When GKE scales a multi-host TPU slice node pool, it atomically scales up the node pool from zero to the maximum size.
- MAX_NODES: The maximum size of the node pool. It must be equal to the product of the values defined in TPU_TOPOLOGY ({A}x{B}x{C}) divided by the number of chips in each VM.
- spot: Lets the node pool use Spot VMs for the TPU slice nodes. This cannot be changed after node pool creation. For more information, see Spot VMs.
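After you fill in the block, apply it with the usual Terraform workflow. For example, assuming your working directory already contains the provider configuration:

terraform init
terraform plan
terraform apply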
Console
To create a node pool with TPUs:
Go to the Google Kubernetes Engine page in the Google Cloud console.
In the cluster list, click the name of the cluster you want to modify.
Click Add node pool.
In the Node pool details section, check the Specify node locations box.
Select the name of the zone based on the TPU version you want to use. To identify an available location, see TPU availability in GKE.
From the navigation pane, click Nodes.
In the Machine Configuration section, select TPUs.
In the Series drop-down menu, select one of the following:
- CT3P: For TPU v3.
- CT4P: For TPU v4.
- CT5LP: For TPU v5e.
In the Machine type drop-down menu, select the name of the machine to use for nodes. Use the Choose the TPU version table to learn how to define the machine type and TPU topology that create a multi-host TPU slice node pool.
In the TPU Topology drop-down menu, select the physical topology for the TPU slice.
In the Changes needed dialog, click Make changes.
Ensure that Boot disk type is either Standard persistent disk or SSD persistent disk.
Optionally, select the Enable nodes on spot VMs checkbox to use Spot VMs for the nodes in the node pool.
Click Create.
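Whichever method you used, you can optionally confirm the node pool configuration from the command line before you continue. For example, using the same placeholders as the gcloud tab:

gcloud container node-pools describe POOL_NAME \
    --cluster=CLUSTER_NAME \
    --location=LOCATION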
Verify the node pool status
Get credentials, so that you can use kubectl to access the cluster:

gcloud container clusters get-credentials CLUSTER_NAME \
    --project=PROJECT_ID
Replace the following:
- CLUSTER_NAME: The name of the cluster.
- PROJECT_ID: Your project ID.
Use kubectl, in Cloud Shell, to see your TPU slice nodes:

kubectl get nodes -l cloud.google.com/gke-tpu-accelerator=TPU_ACCELERATOR \
    -l cloud.google.com/gke-tpu-topology=TPU_TOPOLOGY
Replace the following:
- TPU_ACCELERATOR: The type of TPU accelerator you used when you created the node pools. For example, tpu-v4-podslice, tpu-v5-lite-device, or tpu-v5-lite-podslice.
- TPU_TOPOLOGY: The physical topology for the TPU slice.
The output is similar to the following:
NAME                    STATUS   ROLES    AGE   VERSION
gke-tpu-20ee2cce-5tv6   Ready    <none>   34h   v1.28.1-gke.1066000
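Optionally, you can also count the TPU slice nodes and compare the result against NUM_NODES multiplied by the number of node pools you created (three in this guide). For example, reusing the same label selector:

kubectl get nodes -l cloud.google.com/gke-tpu-accelerator=TPU_ACCELERATOR --no-headers | wc -l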
Run a Multislice workload
In this section, you run a JAX workload that shows the global number of TPU chips across the Multislice and then exits.
To run a JAX workload, do the following:
Create the following tpu-multislice.yaml manifest:

Autopilot
apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
  name: multislice-job
  annotations:
    alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
spec:
  failurePolicy:
    maxRestarts: 4
  replicatedJobs:
    - name: slice
      replicas: NUM_SLICES
      template:
        spec:
          parallelism: NUM_NODES
          completions: NUM_NODES
          backoffLimit: 0
          template:
            spec:
              nodeSelector:
                cloud.google.com/gke-tpu-accelerator: ACCELERATOR_TYPE
                cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY
              containers:
              - name: jax-tpu
                image: python:3.8
                ports:
                - containerPort: 8471
                - containerPort: 8080
                - containerPort: 8431
                command:
                - bash
                - -c
                - |
                  pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                  python -c 'import jax; print("Global device count:", jax.device_count())'
                  sleep 60
                resources:
                  limits:
                    google.com/tpu: NUM_CHIPS
Standard
apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
  name: multislice-job
  annotations:
    alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
spec:
  failurePolicy:
    maxRestarts: 4
  replicatedJobs:
    - name: slice
      replicas: NUM_SLICES
      template:
        spec:
          parallelism: NUM_NODES
          completions: NUM_NODES
          backoffLimit: 0
          template:
            spec:
              hostNetwork: true
              dnsPolicy: ClusterFirstWithHostNet
              nodeSelector:
                cloud.google.com/gke-tpu-accelerator: ACCELERATOR_TYPE
                cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY
              containers:
              - name: jax-tpu
                image: python:3.8
                ports:
                - containerPort: 8471
                - containerPort: 8080
                - containerPort: 8431
                securityContext:
                  privileged: true
                command:
                - bash
                - -c
                - |
                  pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                  python -c 'import jax; print("Global device count:", jax.device_count())'
                  sleep 60
                resources:
                  limits:
                    google.com/tpu: NUM_CHIPS
Replace the following:
- NUM_SLICES: The number of TPU slice node pools. In this case, NUM_SLICES equals 3.
- ACCELERATOR_TYPE: The type of TPU accelerator you used when you created the node pools. For example, tpu-v4-podslice, tpu-v5-lite-device, or tpu-v5-lite-podslice.
- TPU_TOPOLOGY: The physical topology for the TPU slice. For example, 4x4x4 or 2x2 depending on the TPU version.
- NUM_NODES: The number of nodes in the node pool. It must be zero or the product of the values defined in TPU_TOPOLOGY ({A}x{B}x{C}) divided by the number of TPU chips in each VM. For multi-host TPU v4, the number of TPU chips in each VM is four. For multi-host TPU v5e, the number of TPU chips in each VM is one, four, or eight. Therefore, if your TPU_TOPOLOGY is 2x4x4 (TPU v4 with four TPU chips in each VM), then NUM_NODES is 32/4, which equals 8.
- NUM_CHIPS: For multi-host TPU v4, the number of TPU chips in each VM is four. For multi-host TPU v5e, the number of TPU chips in each VM is one, four, or eight. To learn more, see TPU chips on the VM in a TPU slice.
In this manifest:
- The JobSet has a headless Service with the same name as the JobSet, in this case multislice-job.
- The alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool annotation configures pod affinity to ensure all Pods are scheduled on the same slice.
- maxRestarts: 4 indicates the maximum number of times that GKE restarts the JobSet when a child Job fails. If the number of JobSet restarts reaches the defined maximum, then the JobSet is marked as failed.
- The parallelism and completions fields equal the number of nodes in each node pool.
- backoffLimit is 0 because Multislice only supports synchronous multi-controller training. It must be set to 0 so that the Job fails when any Pod fails.
- The values in the affinity section ensure that there is only one TPU Multislice workload running in a group of Multislices.
- containerPort: 8080 is the port for the MXLA coordinator.
- containerPort: 8431 is the port to export the TPU usage metrics.
- securityContext: privileged: true indicates that nodes have privileged mode enabled to access TPUs. Nodes in GKE version 1.28 or later don't need to have privileged mode enabled to access TPUs. To learn more, see Run containers without privileged mode.
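Optionally, before applying the manifest you can ask the API server to validate it. This requires the JobSet CRD from the prerequisites to already be installed:

kubectl apply -f tpu-multislice.yaml --dry-run=server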
Apply the manifest:
kubectl apply -f tpu-multislice.yaml
Confirm that the workload is admitted:
kubectl get jobsets
The output is similar to the following:
NAME             RESTARTS   COMPLETED   AGE
multislice-job                          3s
Monitor the status of the provisioned Pods:
kubectl get pods
The output is similar to the following:
NAME                             READY   STATUS      RESTARTS   AGE
multislice-job-slice-0-0-wzq9t   0/1     Completed   0          2m31s
multislice-job-slice-0-1-zf4dp   0/1     Completed   0          2m30s
multislice-job-slice-1-0-hbfn5   0/1     Completed   0          2m31s
multislice-job-slice-1-1-45fgl   0/1     Completed   0          2m30s
multislice-job-slice-2-0-wjbp4   0/1     Completed   0          2m30s
multislice-job-slice-2-1-lwnvs   0/1     Completed   0          2m30s
The multislice-job JobSet schedules, creates, then runs the Pods to completion. The Pod names are in the format <jobsetName>-<jobName>-<jobReplicaIndex>-<randomSuffix>. The jobsetName prefix determines the JobSet the Pod belongs to.
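To confirm that the slices formed a single Multislice, you can read the log of any worker Pod; the reported global device count should equal the total number of TPU chips across all slices. For example, using one of the Pod names from the preceding output:

kubectl logs multislice-job-slice-0-0-wzq9t | grep "Global device count"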
Additional configurations
The following sections describe the additional configurations you can apply to your Multislice.
Improve network performance with hostNetwork
To improve network performance between TPU slices, we recommend that you turn on hostNetworking. Use hostNetwork: true in your Pod spec to skip the Kubernetes networking stack and let your Kubernetes Pods use the host network directly for VM-to-VM communication.
To turn on hostNetworking, add the following two lines to your Pod spec:
hostNetwork: true
dnsPolicy: ClusterFirstWithHostNet
To keep using the podHostnames for worker node discovery with hostNetwork, set dnsPolicy: ClusterFirstWithHostNet. This is important when you are running auto-resuming training Jobs and you need to have the same names for reloading the same checkpoints.
If you use TPU Trillium (v6e) and your Pods use hostNetworking, install the following DaemonSet for tuning /proc/sys/net/ipv4/tcp_rmem on the node.
kubectl apply -f https://raw.githubusercontent.com/GoogleCloudPlatform/ai-on-gke/51bf3dcab6ff658cf62cc32867f96860bf58dfdc/scripts/network-setup/v6e-increase-rmem.yaml
Improve network performance without hostNetwork on Trillium (v6e)
If you use TPU Trillium (v6e) and your Pods can't use hostNetworking, enable multi-networking with netdevice mode for the best network performance. netdevice mode NIC support with multi-network passes the VM NIC directly to the Pod, bypassing Kubernetes and GKE Dataplane V2.
The ct6e-standard-4t machine type is backed by two physical NICs. Kubernetes requires one vNIC that can't be passed to Pods, so each node needs three vNICs in total. This lets Pods get direct access to two vNICs and achieve the best performance from both physical NICs.
To enable netdevice mode for ct6e-standard-4t, complete the following steps:
- Create two additional VPCs that support netdevice mode.
- Create a GKE cluster with multi-network capabilities.
Configure two netdevice networks. For example, you can use the following GKENetworkParamSet and Network objects (SECOND_VPC and THIRD_VPC are the VPCs created in the preceding step):

apiVersion: networking.gke.io/v1
kind: GKENetworkParamSet
metadata:
  name: tpu-second
spec:
  vpc: SECOND_VPC
  vpcSubnet: SECOND_VPC_SUBNET
  deviceMode: NetDevice
---
apiVersion: networking.gke.io/v1
kind: GKENetworkParamSet
metadata:
  name: tpu-third
spec:
  vpc: THIRD_VPC
  vpcSubnet: THIRD_VPC_SUBNET
  deviceMode: NetDevice
---
apiVersion: networking.gke.io/v1
kind: Network
metadata:
  name: tpu-second
spec:
  provider: "GKE"
  type: "Device"
  parametersRef:
    group: networking.gke.io
    kind: GKENetworkParamSet
    name: tpu-second
---
apiVersion: networking.gke.io/v1
kind: Network
metadata:
  name: tpu-third
spec:
  provider: "GKE"
  type: "Device"
  parametersRef:
    group: networking.gke.io
    kind: GKENetworkParamSet
    name: tpu-third
Connect your Pods with three networks. For example, you can use the following annotations in your Pod specification:
metadata:
  annotations:
    networking.gke.io/default-interface: 'eth0'
    networking.gke.io/interfaces: |
      [
        {"interfaceName":"eth0","network":"default"},
        {"interfaceName":"eth1","network":"tpu-second"},
        {"interfaceName":"eth2","network":"tpu-third"}
      ]
Apply network sysctls inside the Pod, in an init container or in the application container. For example, you can add the following init container to the Pod spec:
initContainers:
- name: "network-optimization-sysctls"
  image: "busybox"
  securityContext:
    privileged: true
  command:
  - sh
  - -c
  - |
    echo 5000 > /proc/sys/net/ipv4/tcp_rto_min_us
    echo 1 > /proc/sys/net/ipv4/tcp_no_metrics_save
    echo 0 > /proc/sys/net/ipv4/tcp_slow_start_after_idle
    echo 131072 > /proc/sys/net/core/optmem_max
    echo "4096 41943040 314572800" > /proc/sys/net/ipv4/tcp_rmem
Use eth1 and eth2 interfaces for better networking performance, instead of the eth0 interface. This can be done by adding export LIBTPU_INIT_ARGS="$LIBTPU_INIT_ARGS --megascale_grpc_interface_prefixes=eth1,eth2,lo" to the workload specification.
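For example, in a JobSet like the one earlier on this page, the export can go at the start of the container's bash command, before the training program launches. The entry point shown here is hypothetical:

# Inside the container command block of the workload manifest:
export LIBTPU_INIT_ARGS="$LIBTPU_INIT_ARGS --megascale_grpc_interface_prefixes=eth1,eth2,lo"
python train.py   # hypothetical training entry point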
Logging
Logs emitted by containers running on GKE nodes, including TPU slice nodes, are visible in the Logs Explorer, if you have GKE system logging enabled in your cluster.
To view the container logs for your workload, use the following filter in the Logs Explorer:
resource.type="k8s_container"
resource.labels.cluster_name=CLUSTER_NAME
labels."k8s-pod/jobset_sigs_k8s_io/jobset-name"=JOBSET_NAME
Use the following filter for a specific TPU slice and worker:
resource.type="k8s_container"
resource.labels.cluster_name=CLUSTER_NAME
labels."k8s-pod/jobset_sigs_k8s_io/jobset-name"=JOBSET_NAME
resource.labels.pod_name:<jobSetName>-<replicateJobName>-<job-index>-<worker-index>
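If you prefer the command line over the Logs Explorer, the same filters can be passed to gcloud logging read. For example, for the workload-level filter (replace the placeholders as described earlier):

gcloud logging read 'resource.type="k8s_container" AND resource.labels.cluster_name="CLUSTER_NAME" AND labels."k8s-pod/jobset_sigs_k8s_io/jobset-name"="JOBSET_NAME"' \
    --project=PROJECT_ID --limit=20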
To learn more, see View GKE TPU logs.
Observability and metrics
In addition to the general TPU metrics, there are four additional Multislice-specific TPU runtime metrics. These metrics are available in GKE version 1.29.1-gke.1016000 or later. The TPU workload must use JAX version 0.4.24.
The following are the available multislice metrics:
- DCN (Data Center Network) transfer latencies: Distribution of network transfer latencies for multislice traffic.
- Collective latencies: Distribution of end-to-end collective latency for multislice traffic.
- Host-to-Device transfer latencies: Distribution of host-to-device transfer latency for each chunk of data for multislice traffic.
- Device-to-Host transfer latencies: Distribution of device-to-host transfer latency for each chunk of data for multislice traffic.
These metrics are located in the Kubernetes container (k8s_container) schema:
kubernetes.io/container/multislice/network/dcn_transfer_latencies
kubernetes.io/container/multislice/network/collective_end_to_end_latencies
kubernetes.io/container/multislice/accelerator/host_to_device_transfer_latencies
kubernetes.io/container/multislice/accelerator/device_to_host_transfer_latencies
TPU slice versus Multislice
The following table differentiates the architectural organization of a TPU slice and a Multislice:
| | TPU slice | Multislice |
|---|---|---|
| Interconnectivity | The workload runs on a single TPU slice. All TPU chips in a slice are connected with ICI. | The workload runs on multiple TPU slices. Communication within a slice happens over ICI. Communication between slices occurs over DCN. |
| Supported node pools | Single-host TPU slice and multi-host TPU slice | Groups of multi-host TPU slices |
| Recommended workload type | IndexedJob or JobSet | JobSet |
What's next
- Learn how to Orchestrate Multislice workloads with TPU slices