在 TPU 配量上執行 JAX 程式碼

執行本文件中的指令前,請務必按照「設定帳戶和 Cloud TPU 專案」中的操作說明操作。

在單一 TPU 板上執行 JAX 程式碼後,您可以透過在 TPU 配量上執行程式碼來擴大程式碼。TPU 配量是指透過專用高速網路連線,彼此相連的多個 TPU 板。本文將介紹如何在 TPU 配量上執行 JAX 程式碼。如需更深入的資訊,請參閱「在多主機和多程序環境中使用 JAX」。

建立 Cloud TPU 分片

  1. 建立一些環境變數:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=europe-west4-b
    export ACCELERATOR_TYPE=v5litepod-32
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite

    環境變數說明

    變數 說明
    PROJECT_ID 您的 Google Cloud 專案 ID。使用現有專案或建立新專案
    TPU_NAME TPU 的名稱。
    ZONE 建立 TPU VM 的區域。如要進一步瞭解支援的區域,請參閱「TPU 地區和區域」一文。
    ACCELERATOR_TYPE 加速器類型會指定您要建立的 Cloud TPU 版本和大小。如要進一步瞭解各 TPU 版本支援的加速器類型,請參閱「TPU 版本」。
    RUNTIME_VERSION Cloud TPU 軟體版本

  2. 使用 gcloud 指令建立 TPU 切片。舉例來說,如要建立 v5litepod-32 配量,請使用下列指令:

    $ gcloud compute tpus tpu-vm create ${TPU_NAME}  \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        --accelerator-type=${ACCELERATOR_TYPE}  \
        --version=${RUNTIME_VERSION} 

在切片中安裝 JAX

建立 TPU 配量後,您必須在 TPU 配量中的所有主機上安裝 JAX。您可以使用 gcloud compute tpus tpu-vm ssh 指令,並使用 --worker=all--commamnd 參數執行這項操作。

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --worker=all \
  --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

在配量上執行 JAX 程式碼

如要在 TPU 配量上執行 JAX 程式碼,您必須在 TPU 配量中的每個主機上執行程式碼jax.device_count() 呼叫會停止回應,直到在切片中的每個主機上呼叫為止。以下範例說明如何在 TPU 切片上執行 JAX 計算。

準備程式碼

您需要 gcloud 344.0.0 以上版本 (適用於 scp 指令)。使用 gcloud --version 檢查 gcloud 版本,並視需要執行 gcloud components upgrade

建立名為 example.py 的檔案,並加入以下程式碼:


import jax

# The total number of TPU cores in the slice
device_count = jax.device_count()

# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()

# The psum is performed over all mapped devices across the slice
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
    print('global device count:', jax.device_count())
    print('local device count:', jax.local_device_count())
    print('pmap result:', r)

example.py 複製到區塊中的所有 TPU 工作站 VM

$ gcloud compute tpus tpu-vm scp ./example.py ${TPU_NAME}: \
  --worker=all \
  --zone=${ZONE} \
  --project=${PROJECT_ID}

如果您先前未使用 scp 指令,可能會看到類似下列的錯誤訊息:

ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH
agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try
again.

如要解決錯誤,請按照錯誤訊息中顯示的內容執行 ssh-add 指令,然後重新執行指令。

在切片上執行程式碼

在每個 VM 上啟動 example.py 程式:

$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --worker=all \
  --command="python3 ./example.py"

輸出內容 (使用 v5litepod-32 切片產生):

global device count: 32
local device count: 4
pmap result: [32. 32. 32. 32.]

清除所用資源

使用完 TPU VM 後,請按照下列步驟清理資源。

  1. 刪除 Cloud TPU 和 Compute Engine 資源。

    $ gcloud compute tpus tpu-vm delete ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID}
  2. 執行 gcloud compute tpus execution-groups list 來驗證資源是否已刪除。刪除作業可能需要幾分鐘才能完成。下列指令的輸出內容不應包含本教學課程中建立的任何資源:

    $ gcloud compute tpus tpu-vm list --zone=${ZONE} \
    --project=${PROJECT_ID}