Run a small batch workload with TPUs and flex-start provisioning mode


This guide shows you how to optimize TPU provisioning for medium- and small-scale training workloads by using flex-start provisioning mode. In this guide, you use flex-start to deploy a workload that runs on a TPU slice node pool.

This guide is intended for Machine learning (ML) engineers, Platform admins and operators, and Data and AI specialists who are interested in using Kubernetes container orchestration capabilities to run batch workloads. For more information about common roles and example tasks that we reference in Google Cloud content, see Common GKE Enterprise user roles and tasks.

Before you begin

Before you start, make sure you have performed the following tasks:

  • Enable the Google Kubernetes Engine API.
  • If you want to use the Google Cloud CLI for this task, install and then initialize the gcloud CLI. If you previously installed the gcloud CLI, get the latest version by running gcloud components update.
  • Verify that you have an Autopilot cluster or a Standard cluster that's running version 1.33.0-gke.1712000 or later.
  • Verify that you're familiar with the limitations of flex-start.
  • When using a Standard cluster, verify that you maintain at least one node pool without flex-start enabled for the cluster to function correctly.
  • Verify that you have quota for preemptible TPUs in your node locations.
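
To check the version requirement from a script, you can compare the cluster's version string against the minimum. The following is a minimal sketch, not an official tool: it assumes that GNU coreutils sort -V (version sort) is available, as it is in Cloud Shell, and the function name version_at_least is hypothetical.

```shell
#!/usr/bin/env bash
# Minimal sketch: check whether a GKE version string meets a minimum version.
# Assumes GNU "sort -V" (version sort), which compares the numeric runs in
# strings such as 1.33.0-gke.1712000 as numbers rather than as text.

MIN_GKE_VERSION="1.33.0-gke.1712000"

# version_at_least HAVE WANT (hypothetical helper)
# Exits with status 0 if HAVE is the same as or newer than WANT.
version_at_least() {
  local have="$1" want="$2"
  local lowest
  # Sort both versions in ascending order and keep the smaller one.
  lowest="$(printf '%s\n%s\n' "$want" "$have" | sort -V | head -n 1)"
  # If the minimum sorts first (or the two are equal), HAVE is new enough.
  [ "$lowest" = "$want" ]
}
```

For example, take the control plane version that gcloud container clusters describe CLUSTER_NAME --location=LOCATION_NAME --format="value(currentMasterVersion)" reports, pass it as the first argument together with "$MIN_GKE_VERSION", and check the exit status.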

Create a node pool with flex-start

To create a node pool with flex-start enabled on an existing Standard cluster, use the gcloud CLI.

If you use a cluster in Autopilot mode, skip this section and go to the Run a batch workload section.

You can create a single-host or multi-host TPU slice node pool with flex-start:

  1. Create a node pool with flex-start:

    Single-host

    gcloud container node-pools create NODE_POOL_NAME \
        --cluster=CLUSTER_NAME \
        --location=LOCATION_NAME \
        --node-locations=NODE_ZONES \
        --machine-type=MACHINE_TYPE \
        --reservation-affinity=none \
        --enable-autoscaling \
        --flex-start \
        --num-nodes=0 \
        --min-nodes=0 \
        --max-nodes=1
    

    Replace the following:

    • NODE_POOL_NAME: the name you choose for your node pool.
    • CLUSTER_NAME: the name of the cluster.
    • LOCATION_NAME: the compute region for the cluster control plane.
    • NODE_ZONES: the comma-separated list of one or more zones where GKE creates the node pool.
    • MACHINE_TYPE: the type of machine to use for nodes. For more information about TPU-compatible machine types, use the table in Choose the TPU version.

    Multi-host

    gcloud container node-pools create NODE_POOL_NAME \
        --cluster=CLUSTER_NAME \
        --location=LOCATION_NAME \
        --node-locations=NODE_ZONES \
        --machine-type=MACHINE_TYPE \
        --tpu-topology=TPU_TOPOLOGY \
        --flex-start \
        --enable-autoscaling \
        --num-nodes=0 \
        --max-nodes=2 \
        --reservation-affinity=none \
        --no-enable-autorepair
    

    Replace the following:

    • NODE_POOL_NAME: the name you choose for your node pool.
    • CLUSTER_NAME: the name of the cluster.
    • LOCATION_NAME: the compute region for the cluster control plane.
    • NODE_ZONES: the comma-separated list of one or more zones where GKE creates the node pool.
    • MACHINE_TYPE: the type of machine to use for nodes. For example, you can use ct6e-standard-4t for TPU Trillium. To learn more about the available machine types, see Choose the TPU version.
    • TPU_TOPOLOGY: the physical topology for the TPU slice. The format of the topology depends on the TPU version. To learn more about TPU topologies, use the table in Choose a topology.

    The preceding commands use the following required flags when you create a node pool with flex-start:

    • --enable-autoscaling: flex-start provisions only the necessary compute resources when your workload runs. You must set the following parameters:

      • --num-nodes=0
      • --min-nodes=0
      • --max-nodes set to the number of virtual machines that your TPU slice requires.

        For example, your node pool creation command can include the following parameters:

        ...
        --machine-type=ct6e-standard-4t \
        --tpu-topology=4x4 \
        --enable-autoscaling \
        --num-nodes=0 \
        --max-nodes=4 \
        

        This command sets the --max-nodes field to 4 because a 4x4 topology consists of 16 chips and each ct6e-standard-4t VM has 4 chips.

      The cluster autoscaler scales up to the number of nodes that your workload requires. After your workload completes, the cluster autoscaler scales down to zero nodes.

    • --reservation-affinity=none: flex-start doesn't use or require reservations.

  2. Verify the status of flex-start in the node pool:

    gcloud container node-pools describe NODE_POOL_NAME \
        --cluster CLUSTER_NAME \
        --location LOCATION_NAME \
        --format="get(config.flexStart)"
    

    If flex-start is enabled in the node pool, the field flexStart is set to True.
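
The --max-nodes arithmetic generalizes: multiply the topology dimensions to get the total number of chips, then divide by the number of chips per VM for your machine type. The following bash function is a minimal sketch of that calculation; the name max_nodes_for_slice is hypothetical, and you still need to look up the chips per VM for your machine type in Choose the TPU version.

```shell
#!/usr/bin/env bash
# Minimal sketch: compute the --max-nodes value for a TPU slice node pool.
# max_nodes_for_slice TOPOLOGY CHIPS_PER_VM (hypothetical helper)
#   TOPOLOGY:     a topology string such as 4x4 or 2x2x4
#   CHIPS_PER_VM: chips per VM for the machine type (for example, 4 for ct6e-standard-4t)
max_nodes_for_slice() {
  local topology="$1" chips_per_vm="$2"
  local total=1 dim
  local IFS=x              # split the topology string on "x"
  for dim in $topology; do # unquoted on purpose so that IFS splitting applies
    total=$(( total * dim ))
  done
  # A slice is made of whole VMs, so the chip count must divide evenly.
  if (( total % chips_per_vm != 0 )); then
    echo "error: ${total} chips is not a multiple of ${chips_per_vm} chips per VM" >&2
    return 1
  fi
  echo "$(( total / chips_per_vm ))"
}
```

For example, max_nodes_for_slice 4x4 4 prints 4, which matches the ct6e-standard-4t example earlier in this section. When all of a topology's chips fit on one VM, as in a single-host slice, the function prints 1.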

Run a batch workload

In this section, you create a Job that requests TPU nodes provisioned with flex-start. The Job controller in Kubernetes creates one or more Pods and ensures that they successfully execute a specific task.

  1. In the Google Cloud console, launch a Cloud Shell session by clicking Activate Cloud Shell. A session opens in the bottom pane of the Google Cloud console.

  2. Create a file named dws-flex-start.yaml:

    Single-host

    Use the following manifest for the dws-flex-start.yaml file:

    apiVersion: batch/v1
    kind: Job
    metadata:
      name: job-1
    spec:
      template:
        spec:
          nodeSelector:
            cloud.google.com/gke-flex-start: "true"
            cloud.google.com/gke-tpu-accelerator: ACCELERATOR_TYPE
            cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY
          containers:
          - name: container-1
            image: gcr.io/k8s-staging-perf-tests/sleep:latest
            args: ["3600s"] # Sleep for 1 hour
            resources:
              requests:
                google.com/tpu: NUM_CHIPS
              limits:
                google.com/tpu: NUM_CHIPS
          restartPolicy: OnFailure
    

    Multi-host

    Use the following manifest for the dws-flex-start.yaml file:

    apiVersion: v1
    kind: Service
    metadata:
      name: headless-svc
    spec:
      clusterIP: None
      selector:
        job-name: job-1
    ---
    apiVersion: batch/v1
    kind: Job
    metadata:
      name: job-1
    spec:
      backoffLimit: 0
      completions: 2
      parallelism: 2
      completionMode: Indexed
      template:
        spec:
          subdomain: headless-svc
          restartPolicy: Never
          nodeSelector:
            cloud.google.com/gke-flex-start: "true"
            cloud.google.com/gke-tpu-accelerator: ACCELERATOR_TYPE
            cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY
          containers:
          - name: tpu-job
            image: python:3.10
            ports:
            - containerPort: 8471 # Default port that TPU VMs use to communicate
            - containerPort: 8431 # Port to export TPU runtime metrics, if supported
            securityContext:
              privileged: true
            command:
            - bash
            - -c
            - |
              pip install 'jax[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
              python -c 'import jax; print("TPU cores:", jax.device_count())'
            resources:
              requests:
                google.com/tpu: NUM_CHIPS
              limits:
                google.com/tpu: NUM_CHIPS
    

    Replace the following:

    • ACCELERATOR_TYPE: the type of TPU accelerator that matches the machine type you used when you created the node pool. For example, tpu-v4-podslice or tpu-v5-lite-podslice.
    • TPU_TOPOLOGY: the physical topology for the TPU slice. For example, the value might be 4x4x4 or 2x2, depending on the TPU version.
    • NUM_CHIPS: the number of TPU chips in each VM. Depending on the TPU version, this value is one, four, or eight. To learn more, see TPU versions.
  3. Apply the dws-flex-start.yaml manifest:

    kubectl apply -f dws-flex-start.yaml
    
  4. Verify that the Pods that the Job created ran to completion, and check which nodes they ran on:

    kubectl get pods -o wide
    

    The output is similar to the following:

    NAME    READY   STATUS      RESTARTS   AGE   IP       NODE               NOMINATED NODE   READINESS GATES
    job-1   0/1     Completed   0          19m   10.(...) gke-flex-zonal-a2  <none>           <none>
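
If you keep a manifest with its placeholders as a reusable template, a short script can fill in the values. The following sketch is an assumption, not part of the flex-start workflow: it uses sed to replace ACCELERATOR_TYPE, TPU_TOPOLOGY, and NUM_CHIPS, and, assuming one Pod per TPU VM as in the multi-host manifest, it also sets completions and parallelism to the number of VMs in the slice. The function name render_manifest is hypothetical.

```shell
#!/usr/bin/env bash
# Minimal sketch: fill in the placeholders of a manifest template.
# render_manifest TEMPLATE ACCELERATOR_TYPE TPU_TOPOLOGY NUM_CHIPS NUM_VMS (hypothetical helper)
#   NUM_CHIPS: TPU chips per VM (the google.com/tpu request and limit).
#   NUM_VMS:   VMs in the slice (total chips / chips per VM). It replaces the
#              completions and parallelism values, assuming one Pod per TPU VM.
# Values must not contain "/" or "&", because they are inserted into sed expressions.
render_manifest() {
  local template="$1" accelerator="$2" topology="$3" num_chips="$4" num_vms="$5"
  sed -E \
    -e "s/ACCELERATOR_TYPE/${accelerator}/g" \
    -e "s/TPU_TOPOLOGY/${topology}/g" \
    -e "s/NUM_CHIPS/${num_chips}/g" \
    -e "s/(completions|parallelism): [0-9]+/\1: ${num_vms}/" \
    "$template"
}
```

For example, if you save the multi-host manifest with its placeholders as dws-flex-start.template.yaml (a hypothetical file name), render_manifest dws-flex-start.template.yaml tpu-v4-podslice 2x2x2 4 2 > dws-flex-start.yaml writes a manifest for a two-VM slice. The single-host manifest has no completions or parallelism fields, so for that manifest only the placeholders change. Confirm the accelerator type, topology, and chips per VM for your TPU version before you apply the result.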
    

Clean up

To avoid incurring charges to your Google Cloud account for the resources that you used on this page, either delete the project that contains the resources, or keep the project and delete the individual resources.

Delete the project

  1. In the Google Cloud console, go to the Manage resources page.

  2. In the project list, select the project that you want to delete, and then click Delete.
  3. In the dialog, type the project ID, and then click Shut down to delete the project.

Delete the individual resources

  1. Delete the Job and, for the multi-host example, the headless Service that the manifest created:

    kubectl delete -f dws-flex-start.yaml
    
  2. Delete the node pool:

    gcloud container node-pools delete NODE_POOL_NAME \
          --cluster CLUSTER_NAME \
          --location LOCATION_NAME
    
  3. Delete the cluster:

    gcloud container clusters delete CLUSTER_NAME \
          --location LOCATION_NAME
    

What's next