使用 JobSet 和 Kueue 编排多切片工作负载


本教程演示了如何在 Google Kubernetes Engine (GKE) 上编排多个多切片工作负载,以提高资源利用率。您可以部署一个 Jax 工作负载作为示例,在 TPU 多切片上运行该工作负载,并使用 JobSet 和 Kueue 实现 Job 队列。Kueue 会根据可用资源、配额和层次结构来确定 Job 应何时运行,以在团队之间实现公平共享。

本教程适用于希望使用 Kubernetes 的容器编排功能训练 LLM 的机器学习 (ML) 工程师、平台管理员和运维人员。如需详细了解我们在 Google Cloud 内容中提及的常见角色和示例任务,请参阅常见的 GKE Enterprise 用户角色和任务

在阅读本页面之前,请确保您熟悉以下内容:

目标

  1. 使用具有三个 v5e TPU 切片的 GKE 集群准备环境。每个 TPU 切片都有一个包含 8 个芯片的 2x4 拓扑。因此,总共 24 个 TPU v5e TPU 芯片。
  2. 创建 Kueue 资源,以确保在工作负载之间公平共享配额。
  3. 运行多切片工作负载。

准备工作

在开始之前,请确保您已执行以下任务:

  • 启用 Google Kubernetes Engine API。
  • 启用 Google Kubernetes Engine API
  • 如果您要使用 Google Cloud CLI 执行此任务,请安装初始化 gcloud CLI。 如果您之前安装了 gcloud CLI,请运行 gcloud components update 以获取最新版本。

准备环境

  1. 在 Google Cloud 控制台中,启动 Cloud Shell 实例:
    打开 Cloud Shell

  2. 设置默认环境变量:

    gcloud config set project PROJECT_ID
    gcloud config set compute/region COMPUTE_REGION
    

    替换以下值:

默认情况下,运行 1.29.2-gke.1521000 版或更高版本的 Autopilot 集群会启用 TPU。Autopilot 集群上的 TPU 在工作负载规范中进行配置。如需了解详情,请参阅使用 JobSet 定义多切片工作负载部分。

创建 GKE 集群

在 Cloud Shell 中创建一个 GKE 集群:

Autopilot

gcloud container clusters create-auto multislice-cluster \
    --location=LOCATION \
    --cluster-version 1.29.2-gke.1521000 \
    --release-channel rapid

标准

gcloud container clusters create multislice-cluster \
    --location=LOCATION

LOCATION 替换为您要在其中创建集群的位置。确保集群有足够的容量用于 ct5lp-hightpu-4t 机器类型。集群创建过程可能需要几分钟时间。

如果您使用 GKE Autopilot 模式,请跳至创建 Kueue 资源部分。默认情况下,运行 1.29.2-gke.1521000 版或更高版本的 Autopilot 集群会启用 TPU。

创建三个 Standard 模式 TPU 切片节点池

  1. 创建第一个节点池,名为 nodepool1

    gcloud beta container node-pools create nodepool1 \
        --location=LOCATION \
        --cluster=multislice-cluster \
        --node-locations=NODE_LOCATION \
        --machine-type=ct5lp-hightpu-4t \
        --tpu-topology=2x4 \
        --project=PROJECT_ID
    

    NODE_LOCATION 替换为您要在其中创建节点的集群区域中的一个或多个可用区。

  2. 创建第二个节点池,名为 nodepool2

    gcloud beta container node-pools create nodepool2 \
        --location=LOCATION \
        --cluster=multislice-cluster \
        --node-locations=NODE_LOCATION \
        --machine-type=ct5lp-hightpu-4t \
        --tpu-topology=2x4 \
        --project=PROJECT_ID
    
  3. 创建第三个节点池,名为 nodepool3

    gcloud beta container node-pools create nodepool3 \
        --location=LOCATION \
        --cluster=multislice-cluster \
        --node-locations=NODE_LOCATION \
        --machine-type=ct5lp-hightpu-4t \
        --tpu-topology=2x4 \
        --project=PROJECT_ID
    

GKE 会创建三个节点池。每个节点池都是一个单独的 TPU 切片。

创建 Kueue 资源

  1. 创建以下 kueue.yaml 清单:

    apiVersion: kueue.x-k8s.io/v1beta1
    kind: ResourceFlavor
    metadata:
      name: "vlp-24"
    spec:
      nodeLabels:
        cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
        cloud.google.com/gke-tpu-topology: 2x4
    ---
    apiVersion: kueue.x-k8s.io/v1beta1
    kind: ClusterQueue
    metadata:
      name: "cluster-queue"
    spec:
      namespaceSelector: {}
      queueingStrategy: BestEffortFIFO
      resourceGroups:
      - coveredResources: ["google.com/tpu"]
        flavors:
        - name: "vlp-24"
          resources:
          - name: "google.com/tpu"
            nominalQuota: 24
    
    ---
    apiVersion: kueue.x-k8s.io/v1beta1
    kind: LocalQueue
    metadata:
      namespace: default
      name: multislice-queue
    spec:
      clusterQueue: cluster-queue
    
  2. 应用 kueue.yaml 清单:

    kubectl apply -f kueue.yaml
    

    GKE 会创建以下 Kueue 资源:

  • ResourceFlavor:集群中资源的抽象。在此示例中,GKE 会创建三个具有 2x4 拓扑的 TPU 切片。每个 TPU 切片都有一个包含 8 个芯片的 2x4 拓扑(总共 24 个 TPU 芯片)。
  • ClusterQueue:管理工作负载和集群资源的全局队列。
  • LocalQueue:对密切相关的工作负载分组,这些工作负载通常由单个租户(用户)运行。每个 LocalQueue 都指向一个 ClusterQueue,系统会从 ClusterQueue 分配资源以运行其工作负载。Kueue 工作负载是一种表示批量工作负载的抽象,在这种情况下,每个工作负载都是一个 JobSet。

使用 JobSet 定义多切片工作负载

在本部分中,您将创建三个 JobSet。Jobset 是一种工作负载 API,可让您将一组 Kubernetes Job 作为一个单元进行管理。JobSet 最常见的应用场景是分布式训练,但您也可以使用它来运行批量工作负载。

以下 JobSet 运行 Jax 工作负载,该工作负载会输出切片中全局数量的 TPU 芯片,然后休眠 60 秒以模拟一些模型训练时间,然后退出。

  1. 创建以下 jobsets-multislice.yaml 清单:

    Autopilot

    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: multislice-1slice
      labels:
        kueue.x-k8s.io/queue-name: multislice-queue
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: slice
          replicas: 1
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 0
              template:
                spec:
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                    cloud.google.com/gke-tpu-topology: 2x4
                  containers:
                  - name: jax-tpu
                    image: python:3.8
                    ports:
                    - containerPort: 8471
                    - containerPort: 8080
                    command:
                    - bash
                    - -c
                    - |
                      pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                      python -c 'import jax; print("Global device count:", jax.device_count())'
                    resources:
                      limits:
                        google.com/tpu: 4
    
    ---
    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: multislice-2slice
      labels:
        kueue.x-k8s.io/queue-name: multislice-queue
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: slice
          replicas: 2
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 0
              template:
                spec:
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                    cloud.google.com/gke-tpu-topology: 2x4
                  containers:
                  - name: jax-tpu
                    image: python:3.8
                    ports:
                    - containerPort: 8471
                    - containerPort: 8080
                    command:
                    - bash
                    - -c
                    - |
                      pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                      python -c 'import jax; print("Global device count:", jax.device_count())'
                      sleep 60
                    resources:
                      limits:
                        google.com/tpu: 4
    ---
    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: multislice-3slice
      labels:
        kueue.x-k8s.io/queue-name: multislice-queue
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: slice
          replicas: 3
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 0
              template:
                spec:
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                    cloud.google.com/gke-tpu-topology: 2x4
                  containers:
                  - name: jax-tpu
                    image: python:3.8
                    ports:
                    - containerPort: 8471
                    - containerPort: 8080
                    command:
                    - bash
                    - -c
                    - |
                      sleep 60
                    resources:
                      limits:
                        google.com/tpu: 4
    

    标准

    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: multislice-1slice
      labels:
        kueue.x-k8s.io/queue-name: multislice-queue
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: slice
          replicas: 1
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 0
              template:
                spec:
                  hostNetwork: true
                  dnsPolicy: ClusterFirstWithHostNet
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                    cloud.google.com/gke-tpu-topology: 2x4
                  containers:
                  - name: jax-tpu
                    image: python:3.8
                    ports:
                    - containerPort: 8471
                    - containerPort: 8080
                    securityContext:
                      privileged: true
                    command:
                    - bash
                    - -c
                    - |
                      pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                      python -c 'import jax; print("Global device count:", jax.device_count())'
                    resources:
                      limits:
                        google.com/tpu: 4
    
    ---
    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: multislice-2slice
      labels:
        kueue.x-k8s.io/queue-name: multislice-queue
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: slice
          replicas: 2
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 0
              template:
                spec:
                  hostNetwork: true
                  dnsPolicy: ClusterFirstWithHostNet
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                    cloud.google.com/gke-tpu-topology: 2x4
                  containers:
                  - name: jax-tpu
                    image: python:3.8
                    ports:
                    - containerPort: 8471
                    - containerPort: 8080
                    securityContext:
                      privileged: true
                    command:
                    - bash
                    - -c
                    - |
                      pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                      python -c 'import jax; print("Global device count:", jax.device_count())'
                      sleep 60
                    resources:
                      limits:
                        google.com/tpu: 4
    ---
    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: multislice-3slice
      labels:
        kueue.x-k8s.io/queue-name: multislice-queue
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: slice
          replicas: 3
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 0
              template:
                spec:
                  hostNetwork: true
                  dnsPolicy: ClusterFirstWithHostNet
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                    cloud.google.com/gke-tpu-topology: 2x4
                  containers:
                  - name: jax-tpu
                    image: python:3.8
                    ports:
                    - containerPort: 8471
                    - containerPort: 8080
                    securityContext:
                      privileged: true
                    command:
                    - bash
                    - -c
                    - |
                      sleep 60
                    resources:
                      limits:
                        google.com/tpu: 4
    
  2. 应用 jobsets-multislice.yaml 清单:

    kubectl apply -f jobsets-multislice.yaml
    

GKE 会使用以下资源请求创建 Job:

  • multislice-1slice JobSet 创建一个 Job,总共需要一个 TPU 切片。
  • multislice-2slice JobSet 创建两个 Job,总共需要两个 TPU 切片。
  • multislice-3slice JobSet 创建三个 Job,总共需要三个 TPU 切片。

由于集群只有三个 TPU 切片,因此并非所有 JobSet 可以同时运行。当 Kueue 将所有三个 multislice-3slice JobSet 加入队列时,其 Job 会单独运行以完成。multislice-1slicemultislice-2slice 会等待,然后一起运行。

验证 Kueue 是否已允许工作负载

  1. 检查 Kueue 中已加入队列的工作负载:

    kubectl get workloads
    

    输出类似于以下内容:

    NAME                             QUEUE              ADMITTED BY     AGE
    jobset-multislice-1slice-2530a   multislice-queue                   3s
    jobset-multislice-2slice-ffb02   multislice-queue                   4s
    jobset-multislice-3slice-8c695   multislice-queue   cluster-queue   10s
    

Kueue 会将一个或多个工作负载加入队列,具体取决于它们所需的 TPU 资源。

监控工作负载

  1. 监控哪些 pod 正在运行:

    kubectl get pods
    

    输出类似于以下内容:

    NAME                                READY   STATUS      RESTARTS   AGE
    multislice-1slice-slice-0-0-pf2ll   1/1     Running     0          1s
    multislice-1slice-slice-0-1-55g62   1/1     Running     0          1s
    multislice-2slice-slice-0-0-f4hf7   1/1     Running     0          3s
    multislice-2slice-slice-0-1-c8kv7   1/1     Running     0          3s
    multislice-2slice-slice-1-0-7h46t   1/1     Running     0          3s
    multislice-2slice-slice-1-1-lj9hb   1/1     Running     0          3s
    multislice-3slice-slice-0-0-wzq9t   0/1     Completed   0          2m31s
    multislice-3slice-slice-0-1-zf4dp   0/1     Completed   0          2m30s
    multislice-3slice-slice-1-0-hbfn5   0/1     Completed   0          2m31s
    multislice-3slice-slice-1-1-45fgl   0/1     Completed   0          2m30s
    multislice-3slice-slice-2-0-wjbp4   0/1     Completed   0          2m30s
    multislice-3slice-slice-2-1-lwnvs   0/1     Completed   0          2m30s
    

    看到 GKE 先为 multislice-3slice 安排、创建和运行了 pod。然后,GKE 运行了来自 multislice-1slicemultislice-2slice JobSet 的 Pod。

启用 Kueue 工作负载优先级和抢占

(可选)您可以为 Kueue 工作负载分配优先级,以确定 Kueue 允许已加入队列的工作负载的顺序。

  1. 更新 ClusterQueue,使其具有抢占政策:

    apiVersion: kueue.x-k8s.io/v1beta1
    kind: ResourceFlavor
    metadata:
      name: "vlp-24"
    spec:
      nodeLabels:
        cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
        cloud.google.com/gke-tpu-topology: 2x4
    ---
    apiVersion: kueue.x-k8s.io/v1beta1
    kind: ClusterQueue
    metadata:
      name: "cluster-queue"
    spec:
      namespaceSelector: {}
      resourceGroups:
      - coveredResources: ["google.com/tpu"]
        flavors:
        - name: "vlp-24"
          resources:
          - name: "google.com/tpu"
            nominalQuota: 24
     preemption:
        reclaimWithinCohort: Any
        withinClusterQueue: LowerPriority
    ---
    apiVersion: kueue.x-k8s.io/v1beta1
    kind: LocalQueue
    metadata:
      namespace: default
      name: multislice-queue
    spec:
      clusterQueue: cluster-queue
    
  2. 为要分配给工作负载的每个不同的优先级创建一个 PriorityClass

    apiVersion: scheduling.k8s.io/v1
    kind: PriorityClass
    metadata:
      name: low-priority
    value: 100
    globalDefault: false
    description: "This low priority class should be used for some Pods only."
    
  3. 为 JobSet 分配 priorityClassName

    Autopilot

    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: low-priority
      labels:
        kueue.x-k8s.io/queue-name: multislice-queue
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: slice
          replicas: 1
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 0
              template:
                spec:
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                    cloud.google.com/gke-tpu-topology: 2x4
                  priorityClassName: low-priority
                  containers:
                  - name: jax-tpu
                    image: python:3.8
                    ports:
                    - containerPort: 8471
                    - containerPort: 8080
                    command:
                    - bash
                    - -c
                    - |
                      sleep 60
                    resources:
                      limits:
                        google.com/tpu: 4 # Number of TPU chips per worker
    

    标准

    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: low-priority
      labels:
        kueue.x-k8s.io/queue-name: multislice-queue
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: slice
          replicas: 1
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 0
              template:
                spec:
                  hostNetwork: true
                  dnsPolicy: ClusterFirstWithHostNet
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                    cloud.google.com/gke-tpu-topology: 2x4
                  priorityClassName: low-priority
                  containers:
                  - name: jax-tpu
                    image: python:3.8
                    ports:
                    - containerPort: 8471
                    - containerPort: 8080
                    securityContext:
                      privileged: true
                    command:
                    - bash
                    - -c
                    - |
                      sleep 60
                    resources:
                      limits:
                        google.com/tpu: 4 # Number of TPU chips per worker
      ```
    

清理

为避免因本教程中使用的资源导致您的 Google Cloud 账号产生费用,请删除包含这些资源的项目,或者保留项目但删除各个资源。

删除项目

  1. In the Google Cloud console, go to the Manage resources page.

    Go to Manage resources

  2. In the project list, select the project that you want to delete, and then click Delete.
  3. In the dialog, type the project ID, and then click Shut down to delete the project.

逐个删除资源

  1. 删除 Kueue 配额系统:

    kubectl delete -n team-a localqueue
    kubectl delete -n team-b localqueue
    kubectl delete clusterqueue
    kubectl delete clusterqueue
    kubectl delete clusterqueue
    kubectl delete resourceflavor
    kubectl delete resourceflavor
    kubectl delete resourceflavor
    
  2. 删除 Kueue 清单:

    VERSION=kueue.x-k8s.io/v1beta1
    kubectl delete -f \
        https://github.com/kubernetes-sigs/kueue/releases/download/$VERSION/manifests.yaml
    
  3. 删除集群:

    gcloud container clusters delete kueue-cohort --region=COMPUTE_REGION
    

后续步骤