本教程演示了如何在 Google Kubernetes Engine (GKE) 上编排多个多切片工作负载,以提高资源利用率。您可以部署一个 Jax 工作负载作为示例,在 TPU 多切片上运行该工作负载,并使用 JobSet 和 Kueue 实现 Job 队列。Kueue 会根据可用资源、配额和层次结构来确定 Job 应何时运行,以在团队之间实现公平共享。
本教程适用于希望使用 Kubernetes 的容器编排功能训练 LLM 的机器学习 (ML) 工程师、平台管理员和运维人员。如需详细了解我们在 Google Cloud 内容中提及的常见角色和示例任务,请参阅常见的 GKE Enterprise 用户角色和任务。
在阅读本页面之前,请确保您熟悉以下内容:
- Cloud TPU 系统架构中的当前 TPU 版本可用性
- GKE 中的 TPU 多切片
目标
- 使用具有三个 v5e TPU 切片的 GKE 集群准备环境。每个 TPU 切片都有一个包含 8 个芯片的
2x4
拓扑。因此,总共 24 个 TPU v5e TPU 芯片。 - 创建 Kueue 资源,以确保在工作负载之间公平共享配额。
- 运行多切片工作负载。
准备工作
在开始之前,请确保您已执行以下任务:
- 启用 Google Kubernetes Engine API。 启用 Google Kubernetes Engine API
- 如果您要使用 Google Cloud CLI 执行此任务,请安装并初始化 gcloud CLI。 如果您之前安装了 gcloud CLI,请运行
gcloud components update
以获取最新版本。
准备环境
在 Google Cloud 控制台中,启动 Cloud Shell 实例:
打开 Cloud Shell设置默认环境变量:
gcloud config set project PROJECT_ID gcloud config set compute/region COMPUTE_REGION
替换以下值:
- PROJECT_ID:您的 Google Cloud 项目 ID。
- COMPUTE_REGION:Compute Engine 区域。
默认情况下,运行 1.29.2-gke.1521000 版或更高版本的 Autopilot 集群会启用 TPU。Autopilot 集群上的 TPU 在工作负载规范中进行配置。如需了解详情,请参阅使用 JobSet 定义多切片工作负载部分。
创建 GKE 集群
在 Cloud Shell 中创建一个 GKE 集群:
Autopilot
gcloud container clusters create-auto multislice-cluster \
--location=LOCATION \
--cluster-version 1.29.2-gke.1521000 \
--release-channel rapid
标准
gcloud container clusters create multislice-cluster \
--location=LOCATION
将 LOCATION 替换为您要在其中创建集群的位置。确保集群有足够的容量用于 ct5lp-hightpu-4t
机器类型。集群创建过程可能需要几分钟时间。
如果您使用 GKE Autopilot 模式,请跳至创建 Kueue 资源部分。默认情况下,运行 1.29.2-gke.1521000 版或更高版本的 Autopilot 集群会启用 TPU。
创建三个 Standard 模式 TPU 切片节点池
创建第一个节点池,名为
nodepool1
:gcloud beta container node-pools create nodepool1 \ --location=LOCATION \ --cluster=multislice-cluster \ --node-locations=NODE_LOCATION \ --machine-type=ct5lp-hightpu-4t \ --tpu-topology=2x4 \ --project=PROJECT_ID
将 NODE_LOCATION 替换为您要在其中创建节点的集群区域中的一个或多个可用区。
创建第二个节点池,名为
nodepool2
:gcloud beta container node-pools create nodepool2 \ --location=LOCATION \ --cluster=multislice-cluster \ --node-locations=NODE_LOCATION \ --machine-type=ct5lp-hightpu-4t \ --tpu-topology=2x4 \ --project=PROJECT_ID
创建第三个节点池,名为
nodepool3
:gcloud beta container node-pools create nodepool3 \ --location=LOCATION \ --cluster=multislice-cluster \ --node-locations=NODE_LOCATION \ --machine-type=ct5lp-hightpu-4t \ --tpu-topology=2x4 \ --project=PROJECT_ID
GKE 会创建三个节点池。每个节点池都是一个单独的 TPU 切片。
创建 Kueue 资源
创建以下
kueue.yaml
清单:apiVersion: kueue.x-k8s.io/v1beta1 kind: ResourceFlavor metadata: name: "vlp-24" spec: nodeLabels: cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice cloud.google.com/gke-tpu-topology: 2x4 --- apiVersion: kueue.x-k8s.io/v1beta1 kind: ClusterQueue metadata: name: "cluster-queue" spec: namespaceSelector: {} queueingStrategy: BestEffortFIFO resourceGroups: - coveredResources: ["google.com/tpu"] flavors: - name: "vlp-24" resources: - name: "google.com/tpu" nominalQuota: 24 --- apiVersion: kueue.x-k8s.io/v1beta1 kind: LocalQueue metadata: namespace: default name: multislice-queue spec: clusterQueue: cluster-queue
应用
kueue.yaml
清单:kubectl apply -f kueue.yaml
GKE 会创建以下 Kueue 资源:
- ResourceFlavor:集群中资源的抽象。在此示例中,GKE 会创建三个具有
2x4
拓扑的 TPU 切片。每个 TPU 切片都有一个包含 8 个芯片的2x4
拓扑(总共 24 个 TPU 芯片)。 - ClusterQueue:管理工作负载和集群资源的全局队列。
- LocalQueue:对密切相关的工作负载分组,这些工作负载通常由单个租户(用户)运行。每个 LocalQueue 都指向一个 ClusterQueue,系统会从 ClusterQueue 分配资源以运行其工作负载。Kueue 工作负载是一种表示批量工作负载的抽象,在这种情况下,每个工作负载都是一个 JobSet。
使用 JobSet 定义多切片工作负载
在本部分中,您将创建三个 JobSet。Jobset 是一种工作负载 API,可让您将一组 Kubernetes Job 作为一个单元进行管理。JobSet 最常见的应用场景是分布式训练,但您也可以使用它来运行批量工作负载。
以下 JobSet 运行 Jax 工作负载,该工作负载会输出切片中全局数量的 TPU 芯片,然后休眠 60 秒以模拟一些模型训练时间,然后退出。
创建以下
jobsets-multislice.yaml
清单:Autopilot
apiVersion: jobset.x-k8s.io/v1alpha2 kind: JobSet metadata: name: multislice-1slice labels: kueue.x-k8s.io/queue-name: multislice-queue annotations: alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool spec: failurePolicy: maxRestarts: 4 replicatedJobs: - name: slice replicas: 1 template: spec: parallelism: 2 completions: 2 backoffLimit: 0 template: spec: nodeSelector: cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice cloud.google.com/gke-tpu-topology: 2x4 containers: - name: jax-tpu image: python:3.8 ports: - containerPort: 8471 - containerPort: 8080 command: - bash - -c - | pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html python -c 'import jax; print("Global device count:", jax.device_count())' resources: limits: google.com/tpu: 4 --- apiVersion: jobset.x-k8s.io/v1alpha2 kind: JobSet metadata: name: multislice-2slice labels: kueue.x-k8s.io/queue-name: multislice-queue annotations: alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool spec: failurePolicy: maxRestarts: 4 replicatedJobs: - name: slice replicas: 2 template: spec: parallelism: 2 completions: 2 backoffLimit: 0 template: spec: nodeSelector: cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice cloud.google.com/gke-tpu-topology: 2x4 containers: - name: jax-tpu image: python:3.8 ports: - containerPort: 8471 - containerPort: 8080 command: - bash - -c - | pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html python -c 'import jax; print("Global device count:", jax.device_count())' sleep 60 resources: limits: google.com/tpu: 4 --- apiVersion: jobset.x-k8s.io/v1alpha2 kind: JobSet metadata: name: multislice-3slice labels: kueue.x-k8s.io/queue-name: multislice-queue annotations: alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool spec: failurePolicy: maxRestarts: 4 replicatedJobs: - name: slice replicas: 3 template: spec: parallelism: 2 completions: 2 backoffLimit: 0 template: spec: nodeSelector: cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice cloud.google.com/gke-tpu-topology: 2x4 containers: - name: jax-tpu image: python:3.8 ports: - containerPort: 8471 - containerPort: 8080 command: - bash - -c - | sleep 60 resources: limits: google.com/tpu: 4
标准
apiVersion: jobset.x-k8s.io/v1alpha2 kind: JobSet metadata: name: multislice-1slice labels: kueue.x-k8s.io/queue-name: multislice-queue annotations: alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool spec: failurePolicy: maxRestarts: 4 replicatedJobs: - name: slice replicas: 1 template: spec: parallelism: 2 completions: 2 backoffLimit: 0 template: spec: hostNetwork: true dnsPolicy: ClusterFirstWithHostNet nodeSelector: cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice cloud.google.com/gke-tpu-topology: 2x4 containers: - name: jax-tpu image: python:3.8 ports: - containerPort: 8471 - containerPort: 8080 securityContext: privileged: true command: - bash - -c - | pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html python -c 'import jax; print("Global device count:", jax.device_count())' resources: limits: google.com/tpu: 4 --- apiVersion: jobset.x-k8s.io/v1alpha2 kind: JobSet metadata: name: multislice-2slice labels: kueue.x-k8s.io/queue-name: multislice-queue annotations: alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool spec: failurePolicy: maxRestarts: 4 replicatedJobs: - name: slice replicas: 2 template: spec: parallelism: 2 completions: 2 backoffLimit: 0 template: spec: hostNetwork: true dnsPolicy: ClusterFirstWithHostNet nodeSelector: cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice cloud.google.com/gke-tpu-topology: 2x4 containers: - name: jax-tpu image: python:3.8 ports: - containerPort: 8471 - containerPort: 8080 securityContext: privileged: true command: - bash - -c - | pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html python -c 'import jax; print("Global device count:", jax.device_count())' sleep 60 resources: limits: google.com/tpu: 4 --- apiVersion: jobset.x-k8s.io/v1alpha2 kind: JobSet metadata: name: multislice-3slice labels: kueue.x-k8s.io/queue-name: multislice-queue annotations: alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool spec: failurePolicy: maxRestarts: 4 replicatedJobs: - name: slice replicas: 3 template: spec: parallelism: 2 completions: 2 backoffLimit: 0 template: spec: hostNetwork: true dnsPolicy: ClusterFirstWithHostNet nodeSelector: cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice cloud.google.com/gke-tpu-topology: 2x4 containers: - name: jax-tpu image: python:3.8 ports: - containerPort: 8471 - containerPort: 8080 securityContext: privileged: true command: - bash - -c - | sleep 60 resources: limits: google.com/tpu: 4
应用
jobsets-multislice.yaml
清单:kubectl apply -f jobsets-multislice.yaml
GKE 会使用以下资源请求创建 Job:
multislice-1slice
JobSet 创建一个 Job,总共需要一个 TPU 切片。multislice-2slice
JobSet 创建两个 Job,总共需要两个 TPU 切片。multislice-3slice
JobSet 创建三个 Job,总共需要三个 TPU 切片。
由于集群只有三个 TPU 切片,因此并非所有 JobSet 可以同时运行。当 Kueue 将所有三个 multislice-3slice
JobSet 加入队列时,其 Job 会单独运行以完成。multislice-1slice
和 multislice-2slice
会等待,然后一起运行。
验证 Kueue 是否已允许工作负载
检查 Kueue 中已加入队列的工作负载:
kubectl get workloads
输出类似于以下内容:
NAME QUEUE ADMITTED BY AGE jobset-multislice-1slice-2530a multislice-queue 3s jobset-multislice-2slice-ffb02 multislice-queue 4s jobset-multislice-3slice-8c695 multislice-queue cluster-queue 10s
Kueue 会将一个或多个工作负载加入队列,具体取决于它们所需的 TPU 资源。
监控工作负载
监控哪些 pod 正在运行:
kubectl get pods
输出类似于以下内容:
NAME READY STATUS RESTARTS AGE multislice-1slice-slice-0-0-pf2ll 1/1 Running 0 1s multislice-1slice-slice-0-1-55g62 1/1 Running 0 1s multislice-2slice-slice-0-0-f4hf7 1/1 Running 0 3s multislice-2slice-slice-0-1-c8kv7 1/1 Running 0 3s multislice-2slice-slice-1-0-7h46t 1/1 Running 0 3s multislice-2slice-slice-1-1-lj9hb 1/1 Running 0 3s multislice-3slice-slice-0-0-wzq9t 0/1 Completed 0 2m31s multislice-3slice-slice-0-1-zf4dp 0/1 Completed 0 2m30s multislice-3slice-slice-1-0-hbfn5 0/1 Completed 0 2m31s multislice-3slice-slice-1-1-45fgl 0/1 Completed 0 2m30s multislice-3slice-slice-2-0-wjbp4 0/1 Completed 0 2m30s multislice-3slice-slice-2-1-lwnvs 0/1 Completed 0 2m30s
看到 GKE 先为
multislice-3slice
安排、创建和运行了 pod。然后,GKE 运行了来自multislice-1slice
和multislice-2slice
JobSet 的 Pod。
启用 Kueue 工作负载优先级和抢占
(可选)您可以为 Kueue 工作负载分配优先级,以确定 Kueue 允许已加入队列的工作负载的顺序。
更新
ClusterQueue
,使其具有抢占政策:apiVersion: kueue.x-k8s.io/v1beta1 kind: ResourceFlavor metadata: name: "vlp-24" spec: nodeLabels: cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice cloud.google.com/gke-tpu-topology: 2x4 --- apiVersion: kueue.x-k8s.io/v1beta1 kind: ClusterQueue metadata: name: "cluster-queue" spec: namespaceSelector: {} resourceGroups: - coveredResources: ["google.com/tpu"] flavors: - name: "vlp-24" resources: - name: "google.com/tpu" nominalQuota: 24 preemption: reclaimWithinCohort: Any withinClusterQueue: LowerPriority --- apiVersion: kueue.x-k8s.io/v1beta1 kind: LocalQueue metadata: namespace: default name: multislice-queue spec: clusterQueue: cluster-queue
为要分配给工作负载的每个不同的优先级创建一个
PriorityClass
:apiVersion: scheduling.k8s.io/v1 kind: PriorityClass metadata: name: low-priority value: 100 globalDefault: false description: "This low priority class should be used for some Pods only."
为 JobSet 分配
priorityClassName
:Autopilot
apiVersion: jobset.x-k8s.io/v1alpha2 kind: JobSet metadata: name: low-priority labels: kueue.x-k8s.io/queue-name: multislice-queue annotations: alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool spec: failurePolicy: maxRestarts: 4 replicatedJobs: - name: slice replicas: 1 template: spec: parallelism: 2 completions: 2 backoffLimit: 0 template: spec: nodeSelector: cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice cloud.google.com/gke-tpu-topology: 2x4 priorityClassName: low-priority containers: - name: jax-tpu image: python:3.8 ports: - containerPort: 8471 - containerPort: 8080 command: - bash - -c - | sleep 60 resources: limits: google.com/tpu: 4 # Number of TPU chips per worker
标准
apiVersion: jobset.x-k8s.io/v1alpha2 kind: JobSet metadata: name: low-priority labels: kueue.x-k8s.io/queue-name: multislice-queue annotations: alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool spec: failurePolicy: maxRestarts: 4 replicatedJobs: - name: slice replicas: 1 template: spec: parallelism: 2 completions: 2 backoffLimit: 0 template: spec: hostNetwork: true dnsPolicy: ClusterFirstWithHostNet nodeSelector: cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice cloud.google.com/gke-tpu-topology: 2x4 priorityClassName: low-priority containers: - name: jax-tpu image: python:3.8 ports: - containerPort: 8471 - containerPort: 8080 securityContext: privileged: true command: - bash - -c - | sleep 60 resources: limits: google.com/tpu: 4 # Number of TPU chips per worker ```
清理
为避免因本教程中使用的资源导致您的 Google Cloud 账号产生费用,请删除包含这些资源的项目,或者保留项目但删除各个资源。
删除项目
- In the Google Cloud console, go to the Manage resources page.
- In the project list, select the project that you want to delete, and then click Delete.
- In the dialog, type the project ID, and then click Shut down to delete the project.
逐个删除资源
删除 Kueue 配额系统:
kubectl delete -n team-a localqueue kubectl delete -n team-b localqueue kubectl delete clusterqueue kubectl delete clusterqueue kubectl delete clusterqueue kubectl delete resourceflavor kubectl delete resourceflavor kubectl delete resourceflavor
删除 Kueue 清单:
VERSION=kueue.x-k8s.io/v1beta1 kubectl delete -f \ https://github.com/kubernetes-sigs/kueue/releases/download/$VERSION/manifests.yaml
删除集群:
gcloud container clusters delete kueue-cohort --region=COMPUTE_REGION
后续步骤
- 详细了解 Kueue。
- 了解如何在 GKE 上通过命名空间之间的配额共享实现 Job 排队系统。