在 TPU 配量上執行 JAX 程式碼
執行本文件中的指令前,請務必按照「設定帳戶和 Cloud TPU 專案」中的操作說明操作。
在單一 TPU 板上執行 JAX 程式碼後,您可以透過在 TPU 配量上執行程式碼來擴大程式碼。TPU 配量是指透過專用高速網路連線,彼此相連的多個 TPU 板。本文將介紹如何在 TPU 配量上執行 JAX 程式碼。如需更深入的資訊,請參閱「在多主機和多程序環境中使用 JAX」。
建立 Cloud TPU 分片
建立一些環境變數:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=europe-west4-b export ACCELERATOR_TYPE=v5litepod-32 export RUNTIME_VERSION=v2-alpha-tpuv5-lite
使用
gcloud
指令建立 TPU 切片。舉例來說,如要建立 v5litepod-32 配量,請使用下列指令:$ gcloud compute tpus tpu-vm create ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --accelerator-type=${ACCELERATOR_TYPE} \ --version=${RUNTIME_VERSION}
在切片中安裝 JAX
建立 TPU 配量後,您必須在 TPU 配量中的所有主機上安裝 JAX。您可以使用 gcloud compute tpus tpu-vm ssh
指令,並使用 --worker=all
和 --commamnd
參數執行這項操作。
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --worker=all \ --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
在配量上執行 JAX 程式碼
如要在 TPU 配量上執行 JAX 程式碼,您必須在 TPU 配量中的每個主機上執行程式碼。jax.device_count()
呼叫會停止回應,直到在切片中的每個主機上呼叫為止。以下範例說明如何在 TPU 切片上執行 JAX 計算。
準備程式碼
您需要 gcloud
344.0.0 以上版本 (適用於 scp
指令)。使用 gcloud --version
檢查 gcloud
版本,並視需要執行 gcloud components upgrade
。
建立名為 example.py
的檔案,並加入以下程式碼:
import jax
# The total number of TPU cores in the slice
device_count = jax.device_count()
# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()
# The psum is performed over all mapped devices across the slice
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)
# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
print('global device count:', jax.device_count())
print('local device count:', jax.local_device_count())
print('pmap result:', r)
將 example.py
複製到區塊中的所有 TPU 工作站 VM
$ gcloud compute tpus tpu-vm scp ./example.py ${TPU_NAME}: \ --worker=all \ --zone=${ZONE} \ --project=${PROJECT_ID}
如果您先前未使用 scp
指令,可能會看到類似下列的錯誤訊息:
ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH
agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try
again.
如要解決錯誤,請按照錯誤訊息中顯示的內容執行 ssh-add
指令,然後重新執行指令。
在切片上執行程式碼
在每個 VM 上啟動 example.py
程式:
$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --worker=all \ --command="python3 ./example.py"
輸出內容 (使用 v5litepod-32 切片產生):
global device count: 32
local device count: 4
pmap result: [32. 32. 32. 32.]
清除所用資源
使用完 TPU VM 後,請按照下列步驟清理資源。
刪除 Cloud TPU 和 Compute Engine 資源。
$ gcloud compute tpus tpu-vm delete ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID}
執行
gcloud compute tpus execution-groups list
來驗證資源是否已刪除。刪除作業可能需要幾分鐘才能完成。下列指令的輸出內容不應包含本教學課程中建立的任何資源:$ gcloud compute tpus tpu-vm list --zone=${ZONE} \ --project=${PROJECT_ID}