Deploy TPU Multislices in GKE


This page shows you how to deploy workloads in Google Kubernetes Engine (GKE) by using the Cloud TPU Multislice configuration for cost-effective, large-scale training.

This tutorial is for Machine learning (ML) engineers and Platform admins and operators who want to use Kubernetes container orchestration to manage large-scale model training, tuning, and inference workloads using TPUs. To learn more about common roles and example tasks referenced in Google Cloud content, see Common GKE Enterprise user roles and tasks.

Before you configure Multislice in GKE, ensure that you're familiar with the following concepts:

  1. Introduction to Cloud TPU
  2. Cloud TPU system architecture
  3. About TPUs in GKE

What's TPU Multislice

TPU Multislice is the architectural organization of VMs across two or more Cloud TPU slices that communicate over the Data Center Network (DCN). Multislice enables full-stack, cost-effective, large-scale training with near-linear scaling up to tens of thousands of TPU chips. In a Multislice configuration, GKE deploys a Multislice workload on multiple TPU slices. TPU chips within a slice communicate over inter-chip interconnects (ICI), and slices communicate with each other over the DCN.

We recommend that you use Multislice if your Job is too big to fit on a single TPU slice.

Multislice availability in GKE

  • Standard supports Multislice in version 1.27.4-gke.900 and later.
  • Autopilot supports Multislice in version 1.29.2-gke.1521000 and later.
  • Multislice supports JAX and PyTorch frameworks. The minimum supported JAX version is 2.1.
  • Multislice only supports multi-host TPU slice node pools. For example, you cannot use Multislice with a ct4p-hightpu-4t with a 2x2x1 topology or a ct5lp-hightpu-4t with a 2x2 topology, because these are single-host TPU slice node pools.
  • Multislice only supports synchronous multicontroller training.
  • Multislice workloads can only run across TPU slices that share the same TPU type, size, and topology.
  • Multislice doesn't support TPU v3.
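The slice-shape rules above can be checked before you create any node pools. The following is a minimal Python sketch, not part of any Google Cloud library: SliceSpec and multislice_errors are hypothetical names, and you supply the number of chips in each VM for your machine type yourself (see Choose the TPU version). It only covers the slice rules on this list, not GKE versions or frameworks.

```python
from dataclasses import dataclass
from math import prod
from typing import List


@dataclass(frozen=True)
class SliceSpec:
    """Hypothetical description of one TPU slice node pool."""
    tpu_version: str   # for example "v4" or "v5e"
    machine_type: str  # for example "ct4p-hightpu-4t"
    topology: str      # for example "2x2x4"
    chips_per_vm: int  # supplied by you for the machine type


def chip_count(topology: str) -> int:
    """Number of TPU chips in a topology string such as '2x4x4'."""
    return prod(int(dim) for dim in topology.lower().split("x"))


def multislice_errors(slices: List[SliceSpec]) -> List[str]:
    """Return why these slices can't form a Multislice; empty if the rules pass."""
    errors = []
    if len(slices) < 2:
        errors.append("Multislice needs two or more TPU slices")
    for s in slices:
        if s.tpu_version == "v3":
            errors.append(f"{s.machine_type}: TPU v3 isn't supported")
        # A slice whose chips all fit on one VM is a single-host slice.
        if chip_count(s.topology) <= s.chips_per_vm:
            errors.append(f"{s.machine_type} {s.topology}: single-host TPU slice")
    # All slices must have the same TPU type, size, and topology.
    shapes = {(s.tpu_version, s.machine_type, s.topology) for s in slices}
    if len(shapes) > 1:
        errors.append("all slices must share the same TPU type, size, and topology")
    return errors
```

For example, two ct4p-hightpu-4t slices with a 2x2x1 topology each report a single-host error, matching the example in the list above.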

Before you begin

Before you start, make sure you have performed the following tasks:

  • Enable the Google Kubernetes Engine API.
  • If you want to use the Google Cloud CLI for this task, install and then initialize the gcloud CLI. If you previously installed the gcloud CLI, get the latest version by running gcloud components update.

Run a workload on a Multislice

This section shows you how to run a workload on a Multislice. If you use GKE Autopilot mode, skip to the Run a Multislice workload section. Autopilot clusters that run version 1.29.2-gke.1521000 or later enable TPUs by default.

Prepare a Standard mode node pool

This section covers the following steps:

  1. Create three multi-host TPU slice node pools
  2. Verify the node pool status

Create the TPU slice node pool

You can create more than one multi-host TPU slice node pool. For the purpose of this guide, create three multi-host TPU slice node pools to run a Multislice workload. You can create a multi-host TPU slice node pool using the Google Cloud CLI, Terraform, or the Google Cloud console.

gcloud

gcloud container node-pools create POOL_NAME \
    --location=LOCATION \
    --cluster=CLUSTER_NAME \
    --node-locations=NODE_ZONE \
    --machine-type=MACHINE_TYPE \
    --tpu-topology=TPU_TOPOLOGY \
    --num-nodes=NUM_NODES \
    [--spot] \
    [--enable-autoscaling --max-nodes=MAX_NODES] \
    [--reservation-affinity=specific --reservation=RESERVATION_NAME]

Replace the following:

  • POOL_NAME: The name of the new node pool.
  • LOCATION: The name of the zone based on the TPU version you want to use. To identify an available location, see TPU availability in GKE.
  • CLUSTER_NAME: The name of the cluster.
  • NODE_ZONE: The comma-separated list of one or more zones where GKE creates the node pool.
  • MACHINE_TYPE: The type of machine to use for nodes. To learn more about the available machine types, see Choose the TPU version.
  • TPU_TOPOLOGY: The physical topology for the TPU slice. The format of the topology depends on the TPU version. To learn more about TPU topologies, use the table in Choose a topology.

    To learn more, see Topology.

  • NUM_NODES: The number of nodes in the node pool. It must be zero or the product of the values defined in TPU_TOPOLOGY ({A}x{B}x{C}) divided by the number of chips in each VM. For multi-host TPU v4 and TPU v5e, the number of chips in each VM is four. Therefore, if your TPU_TOPOLOGY is 2x4x4 (TPU v4 with four chips in each VM), then NUM_NODES is 32/4, which equals 8.

Optionally, you can also use the following flags:

  • RESERVATION_NAME: The name of the reservation GKE uses when creating the node pool. If you omit this flag, GKE uses available TPU slice node pools. To learn more about TPU reservations, see TPU reservation.
  • --spot: Sets the node pool to use Spot VMs for the TPU slice nodes. This cannot be changed after node pool creation. For more information, see Spot VMs.
  • --enable-autoscaling: Create a node pool with autoscaling enabled. When GKE scales a multi-host TPU slice node pool, it atomically scales up the node pool from zero to the maximum size.
    • MAX_NODES: The maximum size of the node pool. The --max-nodes flag is required if --enable-autoscaling is supplied and must be equal to the product of the values defined in TPU_TOPOLOGY ({A}x{B}x{C}) divided by the number of chips in each VM.
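The NUM_NODES and MAX_NODES rules above are the same arithmetic. As a minimal sketch (nodes_for_topology is a hypothetical helper, and you pass the chips-per-VM value for your machine type), it can be computed like this:

```python
from math import prod


def nodes_for_topology(topology: str, chips_per_vm: int) -> int:
    """Hypothetical helper: NUM_NODES (and MAX_NODES) for a multi-host
    TPU slice node pool.

    Multiplies the topology dimensions ({A}x{B}x{C}) and divides by the
    number of chips in each VM.
    """
    chips = prod(int(dim) for dim in topology.lower().split("x"))
    if chips % chips_per_vm != 0:
        raise ValueError(
            f"{chips} chips don't split evenly into VMs of {chips_per_vm} chips"
        )
    return chips // chips_per_vm


# TPU v4 with a 2x4x4 topology: 32 chips / 4 chips per VM = 8 nodes.
print(nodes_for_topology("2x4x4", chips_per_vm=4))  # prints 8
```

If you set --enable-autoscaling, pass the same value to --max-nodes.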

Terraform

  1. Ensure that you use version 4.84.0 or later of the google provider.
  2. Add the following block to your Terraform configuration:

    resource "google_container_node_pool" "NODE_POOL_RESOURCE_NAME" {
      provider           = google
      project            = PROJECT_ID
      cluster            = CLUSTER_NAME
      name               = POOL_NAME
      location           = CLUSTER_LOCATION
      node_locations     = [NODE_ZONES]
      initial_node_count = NUM_NODES
    
      autoscaling {
        max_node_count  = MAX_NODES
        location_policy = "ANY"
      }
      node_config {
        machine_type = MACHINE_TYPE
        reservation_affinity {
          consume_reservation_type = "SPECIFIC_RESERVATION"
          key                      = "compute.googleapis.com/reservation-name"
          values                   = [RESERVATION_LABEL_VALUES]
        }
        spot = true
      }
    
      placement_policy {
        type = "COMPACT"
        tpu_topology = TPU_TOPOLOGY
      }
    }
    

    Replace the following:

    • NODE_POOL_RESOURCE_NAME: The name of the node pool resource in the Terraform template.
    • PROJECT_ID: Your project ID.
    • CLUSTER_NAME: The name of the existing cluster to add the node pool to.
    • POOL_NAME: The name of the node pool to create.
    • CLUSTER_LOCATION: Compute location for the cluster. We recommend having a regional cluster for higher reliability of the Kubernetes control plane. You can also use a zonal cluster. To learn more, see Select a TPU version and topology.
    • NODE_ZONES: The comma-separated list of one or more zones where GKE creates the node pool.
    • NUM_NODES: The number of nodes in the node pool. It must be zero or the number of TPU chips divided by four, because in multi-host TPU slices each TPU slice node has four chips. For example, if TPU_TOPOLOGY is 4x8, there are 32 chips, so NUM_NODES must be 8. To learn more about TPU topologies, use the table in Choose the TPU version.
    • TPU_TOPOLOGY: This indicates the desired physical topology for the TPU slice. The format of the topology depends on the TPU version you are using. To learn more about TPU topologies, use the table in Choose a topology.

    Optionally, you can also use the following variables:

    • RESERVATION_LABEL_VALUES: If you use a TPU reservation, this is the list of labels of the reservation resources to use when creating the node pool. To learn more about how to populate RESERVATION_LABEL_VALUES in the reservation_affinity field, see Terraform Provider.
    • autoscaling: Create a node pool with autoscaling enabled. When GKE scales a multi-host TPU slice node pool, it atomically scales up the node pool from zero to the maximum size.
      • MAX_NODES: The maximum size of the node pool. It must be equal to the product of the values defined in TPU_TOPOLOGY ({A}x{B}x{C}) divided by the number of chips in each VM.
    • spot: Lets the node pool use Spot VMs for the TPU slice nodes. This cannot be changed after node pool creation. For more information, see Spot VMs.

Console

To create a node pool with TPUs:

  1. Go to the Google Kubernetes Engine page in the Google Cloud console.

  2. In the cluster list, click the name of the cluster you want to modify.

  3. Click Add node pool.

  4. In the Node pool details section, check the Specify node locations box.

  5. Select the name of the zone based on the TPU version you want to use. To identify an available location, see TPU availability in GKE.

  6. From the navigation pane, click Nodes.

  7. In the Machine Configuration section, select TPUs.

  8. In the Series drop-down menu, select one of the following:

    • CT3P: For TPU v3.
    • CT4P: For TPU v4.
    • CT5LP: For TPU v5e.
  9. In the Machine type drop-down menu, select the name of the machine to use for nodes. Use the Choose the TPU version table to learn how to define the machine type and TPU topology that create a multi-host TPU slice node pool.

  10. In the TPU Topology drop-down menu, select the physical topology for the TPU slice.

  11. In the Changes needed dialog, click Make changes.

  12. Ensure that Boot disk type is either Standard persistent disk or SSD persistent disk.

  13. Optionally, select the Enable nodes on spot VMs checkbox to use Spot VMs for the nodes in the node pool.

  14. Click Create.

Verify the node pool status

  1. Get credentials, so that you can use kubectl to access the cluster:

    gcloud container clusters get-credentials CLUSTER_NAME \
        --location=LOCATION \
        --project=PROJECT_ID

    Replace the following:

    • CLUSTER_NAME: The name of the cluster.
    • LOCATION: The location of the cluster.
    • PROJECT_ID: Your project ID.
  2. Use kubectl, in Cloud Shell, to see your TPU slice nodes:

    kubectl get nodes \
        -l cloud.google.com/gke-tpu-accelerator=TPU_ACCELERATOR,cloud.google.com/gke-tpu-topology=TPU_TOPOLOGY
    

    Replace the following:

    • TPU_ACCELERATOR: The type of TPU accelerator you used when you created the node pools. For example, tpu-v4-podslice, tpu-v5-lite-device, or tpu-v5-lite-podslice.
    • TPU_TOPOLOGY: The physical topology for the TPU slice.

    The output is similar to the following:

     NAME                                    STATUS   ROLES    AGE    VERSION
     gke-tpu-20ee2cce-5tv6                   Ready    <none>   34h     v1.28.1-gke.1066000
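With three node pools, you expect NUM_SLICES times NUM_NODES nodes in this output, all with the Ready status. As a minimal sketch that assumes kubectl's default table output (NAME, STATUS, ROLES, AGE, VERSION columns), the hypothetical helpers below check that count from the captured text, for example the stdout of subprocess.run(..., capture_output=True, text=True):

```python
from typing import List


def ready_node_names(kubectl_output: str) -> List[str]:
    """Hypothetical helper: names of nodes whose STATUS is exactly 'Ready'.

    Assumes the default `kubectl get nodes` table: a header row, then one
    row per node with NAME in the first column and STATUS in the second.
    A status such as 'Ready,SchedulingDisabled' is not counted.
    """
    rows = [line.split() for line in kubectl_output.strip().splitlines()]
    return [row[0] for row in rows[1:] if len(row) >= 2 and row[1] == "Ready"]


def all_slice_nodes_ready(kubectl_output: str, num_slices: int, nodes_per_slice: int) -> bool:
    """True if exactly num_slices * nodes_per_slice nodes are Ready."""
    return len(ready_node_names(kubectl_output)) == num_slices * nodes_per_slice
```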
    

Run a Multislice workload

In this section, you run a JAX workload that prints the global number of TPU chips across all slices in the Multislice and then exits.

To run a JAX workload do the following:

  1. Create the following tpu-multislice.yaml manifest:

    Autopilot

    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: multislice-job
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: slice
          replicas: NUM_SLICES
          template:
            spec:
              parallelism: NUM_NODES
              completions: NUM_NODES
              backoffLimit: 0
              template:
                spec:
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: ACCELERATOR_TYPE
                    cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY
                  containers:
                  - name: jax-tpu
                    image: python:3.8
                    ports:
                    - containerPort: 8471
                    - containerPort: 8080
                    - containerPort: 8431
                    command:
                    - bash
                    - -c
                    - |
                      pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                      python -c 'import jax; print("Global device count:", jax.device_count())'
                      sleep 60
                    resources:
                      limits:
                        google.com/tpu: NUM_CHIPS
    

    Standard

    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: multislice-job
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: slice
          replicas: NUM_SLICES
          template:
            spec:
              parallelism: NUM_NODES
              completions: NUM_NODES
              backoffLimit: 0
              template:
                spec:
                  hostNetwork: true
                  dnsPolicy: ClusterFirstWithHostNet
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: ACCELERATOR_TYPE
                    cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY
                  containers:
                  - name: jax-tpu
                    image: python:3.8
                    ports:
                    - containerPort: 8471
                    - containerPort: 8080
                    - containerPort: 8431
                    securityContext:
                      privileged: true
                    command:
                    - bash
                    - -c
                    - |
                      pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                      python -c 'import jax; print("Global device count:", jax.device_count())'
                      sleep 60
                    resources:
                      limits:
                        google.com/tpu: NUM_CHIPS
    

    Replace the following:

    • NUM_SLICES: The number of TPU slice node pools. In this guide, NUM_SLICES equals 3.
    • ACCELERATOR_TYPE: The type of TPU accelerator you used when you created the node pools. For example, tpu-v4-podslice, tpu-v5-lite-device, or tpu-v5-lite-podslice.
    • TPU_TOPOLOGY: The physical topology for the TPU slice. For example, 4x4x4 or 2x2, depending on the TPU version.
    • NUM_NODES: The number of nodes in each TPU slice node pool, which is the product of the values defined in TPU_TOPOLOGY ({A}x{B}x{C}) divided by the number of TPU chips in each VM. For multi-host TPU v4, the number of TPU chips in each VM is four. For multi-host TPU v5e, the number of TPU chips in each VM is one, four, or eight. Therefore, if your TPU_TOPOLOGY is 2x4x4 (TPU v4 with four TPU chips in each VM), then NUM_NODES is 32/4, which equals 8.
    • NUM_CHIPS: The number of TPU chips in each VM. For multi-host TPU v4, the number of TPU chips in each VM is four. For multi-host TPU v5e, the number of TPU chips in each VM is one, four, or eight. To learn more, see TPU chips on the VM in a TPU slice.

    In this manifest:

    • The JobSet creates a headless Service with the same name as the JobSet, in this case multislice-job.
    • The alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool annotation configures Pod affinity so that all Pods of a child Job are scheduled on the same node pool, that is, on the same TPU slice.
    • maxRestarts: 4 is the maximum number of times that GKE restarts the JobSet when a child Job fails. If the number of JobSet restarts reaches this maximum, the JobSet is marked as failed.
    • The parallelism and completions fields equal the number of nodes in each node pool.
    • backoffLimit must be set to 0 so that the Job fails as soon as any Pod fails, because Multislice only supports synchronous multi-controller training.
    • The affinity rules that JobSet derives from the exclusive-topology annotation ensure that each slice runs only one TPU Multislice workload at a time.
    • containerPort: 8080 is the port for the MXLA coordinator.
    • containerPort: 8431 is the port used to export the TPU usage metrics.
    • In the Standard manifest, securityContext: privileged: true runs the container in privileged mode so that it can access the TPUs. Containers on nodes running GKE version 1.28 or later don't need privileged mode to access TPUs. To learn more, see Run containers without privileged mode.
  2. Apply the manifest:

    kubectl apply -f tpu-multislice.yaml
    
  3. Confirm that the workload is admitted:

    kubectl get jobsets
    

    The output is similar to the following:

    NAME            RESTARTS   COMPLETED   AGE
    multislice-job                         3s
    
  4. Monitor the status of the provisioned Pods:

    kubectl get pods
    

    The output is similar to the following:

     NAME                                READY   STATUS      RESTARTS   AGE
     multislice-job-slice-0-0-wzq9t      0/1     Completed   0          2m31s
     multislice-job-slice-0-1-zf4dp      0/1     Completed   0          2m30s
     multislice-job-slice-1-0-hbfn5      0/1     Completed   0          2m31s
     multislice-job-slice-1-1-45fgl      0/1     Completed   0          2m30s
     multislice-job-slice-2-0-wjbp4      0/1     Completed   0          2m30s
     multislice-job-slice-2-1-lwnvs      0/1     Completed   0          2m30s
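Each Pod prints a Global device count. As a rough cross-check, the following minimal sketch (multislice_job_values is a hypothetical helper) derives the manifest placeholder values and the count you would expect, assuming one JAX device per TPU chip, as on TPU v4 in its default megacore mode and on TPU v5e. The example assumes three TPU v4 slices with a 2x2x2 topology, which gives two Pods per slice and six Pods in total, the same number of Pods as in the output above.

```python
from math import prod


def multislice_job_values(num_slices: int, topology: str, chips_per_vm: int) -> dict:
    """Hypothetical helper: placeholder values for tpu-multislice.yaml.

    Assumes one JAX device per TPU chip when computing the expected
    'Global device count'.
    """
    chips_per_slice = prod(int(dim) for dim in topology.lower().split("x"))
    nodes_per_slice = chips_per_slice // chips_per_vm
    return {
        "NUM_SLICES": num_slices,           # replicas
        "NUM_NODES": nodes_per_slice,       # parallelism and completions
        "NUM_CHIPS": chips_per_vm,          # google.com/tpu limit per Pod
        "TOTAL_PODS": num_slices * nodes_per_slice,
        "EXPECTED_DEVICE_COUNT": num_slices * chips_per_slice,
    }


values = multislice_job_values(num_slices=3, topology="2x2x2", chips_per_vm=4)
print(values["TOTAL_PODS"], values["EXPECTED_DEVICE_COUNT"])  # prints 6 24
```

You can compare the expected count with the value in a Pod's logs, for example with kubectl logs multislice-job-slice-0-0-wzq9t.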
    

The multislice-job JobSet schedules, creates, and then runs the Pods to completion. The Pod names are in the format <jobsetName>-<replicatedJobName>-<jobIndex>-<podIndex>-<randomSuffix>. The jobsetName prefix identifies the JobSet that the Pod belongs to.
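As the example output shows, each Pod name carries the index of its child Job (one child Job per slice) and the Pod's index within that Job, followed by a random suffix. The following is a minimal, hypothetical parser for that pattern, which can help when grouping Pods by slice in scripts. It assumes the replicated Job is named slice, as in this guide, because JobSet names can themselves contain hyphens.

```python
import re
from typing import Dict, Union


def parse_pod_name(pod_name: str, replicated_job: str = "slice") -> Dict[str, Union[str, int]]:
    """Hypothetical parser for JobSet Pod names of the form
    <jobsetName>-<replicatedJobName>-<jobIndex>-<podIndex>-<randomSuffix>.

    The replicated Job name anchors the match, so hyphens inside the
    JobSet name are handled.
    """
    pattern = (
        rf"^(?P<jobset>.+)-{re.escape(replicated_job)}"
        r"-(?P<job_index>\d+)-(?P<pod_index>\d+)-(?P<suffix>[a-z0-9]+)$"
    )
    match = re.match(pattern, pod_name)
    if match is None:
        raise ValueError(f"unexpected Pod name: {pod_name!r}")
    return {
        "jobset": match["jobset"],
        "job_index": int(match["job_index"]),  # which child Job, that is, which slice
        "pod_index": int(match["pod_index"]),  # Pod index within that child Job
        "suffix": match["suffix"],
    }
```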

Additional configurations

The following sections describe the additional configurations you can apply to your Multislice.

Improve network performance with hostNetwork

To improve network performance between TPU slices, we recommend that you turn on host networking. Set hostNetwork: true in your Pod spec to bypass the Kubernetes networking stack and let your Pods use the host network directly for VM-to-VM communication.

To turn on host networking, add the following two lines to your Pod spec:

hostNetwork: true
dnsPolicy: ClusterFirstWithHostNet

To keep using Pod hostnames for worker node discovery with hostNetwork, set dnsPolicy: ClusterFirstWithHostNet. This is important when you run auto-resuming training Jobs that need the same names to reload the same checkpoints.

If you use TPU Trillium (v6e) and your Pods use host networking, install the following DaemonSet to tune /proc/sys/net/ipv4/tcp_rmem on the nodes:
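If you generate manifests in Python, for example by building dictionaries and serializing them with a YAML library, the two settings above always go together. A minimal sketch, where with_host_network is a hypothetical helper that operates on a plain Pod spec dictionary:

```python
import copy
from typing import Any, Dict


def with_host_network(pod_spec: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: return a copy of a Pod spec with host networking on.

    dnsPolicy ClusterFirstWithHostNet keeps cluster DNS resolution, so Pod
    hostnames can still be used for worker discovery.
    """
    spec = copy.deepcopy(pod_spec)  # leave the caller's dictionary untouched
    spec["hostNetwork"] = True
    spec["dnsPolicy"] = "ClusterFirstWithHostNet"
    return spec
```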

kubectl apply -f https://raw.githubusercontent.com/GoogleCloudPlatform/ai-on-gke/51bf3dcab6ff658cf62cc32867f96860bf58dfdc/scripts/network-setup/v6e-increase-rmem.yaml

Improve network performance without hostNetwork on TPU Trillium

If you use TPU Trillium and your Pods can't use host networking, enable multi-networking with netdevice mode for the best network performance. In netdevice mode, multi-networking passes the VM NIC directly to the Pod, bypassing Kubernetes and GKE Dataplane V2.

The ct6e-standard-4t machine type is backed by two physical NICs. Kubernetes requires one vNIC that can't be passed to Pods. Therefore, each node must have three vNICs so that Pods get direct access to two vNICs and can use the full performance of both physical NICs.

To enable netdevice mode for ct6e-standard-4t, complete the following steps:

  1. Create two additional VPCs that support netdevice mode.
  2. Create a GKE cluster with multi-network capabilities
  3. Configure two netdevice networks. For example, you can use the following GKENetworkParamSet and Network objects (SECOND_VPC and THIRD_VPC are the VPCs that you created in step 1):

    apiVersion: networking.gke.io/v1
    kind: GKENetworkParamSet
    metadata:
      name: tpu-second
    spec:
      vpc: SECOND_VPC
      vpcSubnet: SECOND_VPC_SUBNET
      deviceMode: NetDevice
    ---
    apiVersion: networking.gke.io/v1
    kind: GKENetworkParamSet
    metadata:
      name: tpu-third
    spec:
      vpc: THIRD_VPC
      vpcSubnet: THIRD_VPC_SUBNET
      deviceMode: NetDevice
    ---
    apiVersion: networking.gke.io/v1
    kind: Network
    metadata:
      name: tpu-second
    spec:
      provider: "GKE"
      type: "Device"
      parametersRef:
        group: networking.gke.io
        kind: GKENetworkParamSet
        name: tpu-second
    ---
    apiVersion: networking.gke.io/v1
    kind: Network
    metadata:
      name: tpu-third
    spec:
      provider: "GKE"
      type: "Device"
      parametersRef:
        group: networking.gke.io
        kind: GKENetworkParamSet
        name: tpu-third
    
  4. Connect your Pods to the three networks. For example, you can use the following annotations in your Pod specification:

    metadata:
      annotations:
        networking.gke.io/default-interface: 'eth0'
        networking.gke.io/interfaces: |
          [
            {"interfaceName":"eth0","network":"default"},
            {"interfaceName":"eth1","network":"tpu-second"},
            {"interfaceName":"eth2","network":"tpu-third"},
          ]
    
  5. Apply network sysctls inside the Pod, in either an init container or the application container. For example, you can add the following init container to the Pod spec:

    initContainers:
    - name: "network-optimization-sysctls"
      image: "busybox"
      securityContext:
        privileged: true
      command:
      - sh
      - -c
      - |
        echo 5000 > /proc/sys/net/ipv4/tcp_rto_min_us
        echo 1 > /proc/sys/net/ipv4/tcp_no_metrics_save
        echo 0 > /proc/sys/net/ipv4/tcp_slow_start_after_idle
        echo 131072 > /proc/sys/net/core/optmem_max
        echo "4096 41943040 314572800" > /proc/sys/net/ipv4/tcp_rmem
    
Best practice:

Use the eth1 and eth2 interfaces instead of eth0 for better networking performance. To do so, add export LIBTPU_INIT_ARGS="$LIBTPU_INIT_ARGS --megascale_grpc_interface_prefixes=eth1,eth2,lo" to the workload specification.
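If your workload sets this flag from Python instead of from the shell, the following minimal sketch does the equivalent of the export above: it appends the flag to any existing LIBTPU_INIT_ARGS value. add_libtpu_flag is a hypothetical helper, and the sketch assumes that the variable must be set before libtpu initializes, so it runs before JAX is imported.

```python
import os
from typing import Optional

INTERFACE_FLAG = "--megascale_grpc_interface_prefixes=eth1,eth2,lo"


def add_libtpu_flag(existing: Optional[str], flag: str = INTERFACE_FLAG) -> str:
    """Hypothetical helper: append `flag` to a LIBTPU_INIT_ARGS value,
    keeping existing flags and not adding the flag twice."""
    args = (existing or "").split()
    if flag not in args:
        args.append(flag)
    return " ".join(args)


# Assumption: set before JAX initializes the TPU backend, for example before `import jax`.
os.environ["LIBTPU_INIT_ARGS"] = add_libtpu_flag(os.environ.get("LIBTPU_INIT_ARGS"))
```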

Logging

Logs emitted by containers running on GKE nodes, including TPU slice nodes, are visible in the Logs Explorer, if you have GKE system logging enabled in your cluster.

You can view your logs from GKE using the Logs Explorer with the following filter to view the container logs for your workload:

resource.type="k8s_container"
resource.labels.cluster_name=CLUSTER_NAME
labels."k8s-pod/jobset_sigs_k8s_io/jobset-name"=JOBSET_NAME

To view the logs for a specific TPU slice and worker, use the following filter:

resource.type="k8s_container"
resource.labels.cluster_name=CLUSTER_NAME
labels."k8s-pod/jobset_sigs_k8s_io/jobset-name"=JOBSET_NAME
resource.labels.pod_name:<jobSetName>-<replicatedJobName>-<job-index>-<worker-index>
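If you build these filters in scripts, the following minimal sketch assembles the same query text. jobset_log_filter is a hypothetical helper; it quotes every value and relies on the Logging query language treating comparisons on separate lines as joined by AND. The pod_name argument is the Pod name prefix from the filter above, without the random suffix.

```python
from typing import Optional


def jobset_log_filter(cluster_name: str, jobset_name: str,
                      pod_name: Optional[str] = None) -> str:
    """Hypothetical helper: Logs Explorer query for a JobSet's container logs.

    Pass pod_name, for example "multislice-job-slice-0-1", to narrow the
    query to one slice and worker with the ':' (has) operator.
    """
    lines = [
        'resource.type="k8s_container"',
        f'resource.labels.cluster_name="{cluster_name}"',
        f'labels."k8s-pod/jobset_sigs_k8s_io/jobset-name"="{jobset_name}"',
    ]
    if pod_name is not None:
        lines.append(f'resource.labels.pod_name:"{pod_name}"')
    return "\n".join(lines)


# "my-cluster" is a hypothetical cluster name.
print(jobset_log_filter("my-cluster", "multislice-job", "multislice-job-slice-0-1"))
```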

To learn more, see View GKE TPU logs.

Observability and metrics

In addition to the general TPU metrics, there are four additional Multislice-specific TPU runtime metrics. These metrics are available in GKE version 1.29.1-gke.1016000 or later. The TPU workload must use JAX version 0.4.24.

The following are the available multislice metrics:

  • DCN (Data Center Network) transfer latencies: Distribution of network transfer latencies for multislice traffic.
  • Collective latencies: Distribution of end to end collective latency for multislice traffic.
  • Host-to-Device transfer latencies: Distribution of host to device transfer latency for each chunk of data for multislice traffic.
  • Device-to-Host transfer latencies: Distribution of device to host transfer latency for each chunk of data for multislice traffic.

These metrics are located in the Kubernetes container (k8s_container) schema:

  • kubernetes.io/container/multislice/network/dcn_transfer_latencies
  • kubernetes.io/container/multislice/network/collective_end_to_end_latencies
  • kubernetes.io/container/multislice/accelerator/host_to_device_transfer_latencies
  • kubernetes.io/container/multislice/accelerator/device_to_host_transfer_latencies
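To chart one of these metrics in Metrics Explorer or query it through the Cloud Monitoring API, you filter on its metric type. A minimal sketch follows; the dictionary keys and the multislice_metric_filter function are hypothetical, and the metric types are the ones listed above.

```python
MULTISLICE_METRICS = {
    "dcn_transfer": "kubernetes.io/container/multislice/network/dcn_transfer_latencies",
    "collective_e2e": "kubernetes.io/container/multislice/network/collective_end_to_end_latencies",
    "host_to_device": "kubernetes.io/container/multislice/accelerator/host_to_device_transfer_latencies",
    "device_to_host": "kubernetes.io/container/multislice/accelerator/device_to_host_transfer_latencies",
}


def multislice_metric_filter(name: str) -> str:
    """Hypothetical helper: Cloud Monitoring filter that selects one
    Multislice latency metric on Kubernetes containers."""
    metric_type = MULTISLICE_METRICS[name]  # raises KeyError for unknown names
    return f'metric.type = "{metric_type}" AND resource.type = "k8s_container"'
```

The resulting string can be used as the filter of a time-series query.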

TPU slice versus Multislice

The following comparison differentiates the architectural organization of a TPU slice and a Multislice:

Interconnectivity
  • TPU slice: The workload runs on a single TPU slice. All TPU chips in a slice are connected with ICI.
  • Multislice: The workload runs on multiple TPU slices. Communication within a slice happens over ICI. Communication between slices occurs over DCN.

Supported node pools
  • TPU slice: Single-host TPU slice and multi-host TPU slice
  • Multislice: Groups of multi-host TPU slices

Recommended workload type
  • TPU slice: IndexedJob or JobSet
  • Multislice: JobSet

What's next