使用 JAX 在 Cloud TPU VM 上執行計算
本文簡要說明如何使用 JAX 和 Cloud TPU。
事前準備
執行本文中的指令前,您必須建立 Google Cloud帳戶、安裝 Google Cloud CLI,並設定 gcloud
指令。詳情請參閱「設定 Cloud TPU 環境」。
使用 gcloud
建立 Cloud TPU VM
定義一些環境變數,方便使用指令。
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-east5-a export ACCELERATOR_TYPE=v5litepod-8 export RUNTIME_VERSION=v2-alpha-tpuv5-lite
請在 Cloud Shell 或已安裝 Google Cloud CLI 的電腦終端機中執行下列指令,建立 TPU VM。
$ gcloud compute tpus tpu-vm create $TPU_NAME \ --project=$PROJECT_ID \ --zone=$ZONE \ --accelerator-type=$ACCELERATOR_TYPE \ --version=$RUNTIME_VERSION
連線至 Cloud TPU VM
使用下列指令,透過 SSH 連線至 TPU VM:
$ gcloud compute tpus tpu-vm ssh $TPU_NAME \ --project=$PROJECT_ID \ --zone=$ZONE
如果您無法使用 SSH 連線至 TPU VM,可能是因為 TPU VM 沒有外部 IP 位址。如要存取沒有外部 IP 位址的 TPU VM,請按照「連線至沒有公開 IP 位址的 TPU VM」中的操作說明進行。
在 Cloud TPU VM 上安裝 JAX
(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
系統檢查
確認 JAX 可存取 TPU,並執行基本運算:
啟動 Python 3 解譯器:
(vm)$ python3
>>> import jax
顯示可用的 TPU 核心數量:
>>> jax.device_count()
系統會顯示 TPU 核心數。顯示的核心數量取決於您使用的 TPU 版本。詳情請參閱「TPU 版本」。
執行計算
>>> jax.numpy.add(1, 1)
系統會顯示 numpy 加法結果:
指令的輸出內容:
Array(2, dtype=int32, weak_type=True)
結束 Python 解譯器
>>> exit()
在 TPU VM 上執行 JAX 程式碼
您現在可以執行任何想要的 JAX 程式碼。Flax 範例是您在 JAX 中執行標準機器學習模型的絕佳起點。例如,訓練基本 MNIST 卷積神經網路:
安裝 Flax 範例依附元件:
(vm)$ pip install --upgrade clu (vm)$ pip install tensorflow (vm)$ pip install tensorflow_datasets
安裝 Flax:
(vm)$ git clone https://github.com/google/flax.git (vm)$ pip install --user flax
執行 Flax MNIST 訓練指令碼:
(vm)$ cd flax/examples/mnist (vm)$ python3 main.py --workdir=/tmp/mnist \ --config=configs/default.py \ --config.learning_rate=0.05 \ --config.num_epochs=5
指令碼會下載資料集並開始訓練。指令碼輸出內容應如下所示:
I0214 18:00:50.660087 140369022753856 train.py:146] epoch: 1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88 I0214 18:00:52.015867 140369022753856 train.py:146] epoch: 2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72 I0214 18:00:53.377511 140369022753856 train.py:146] epoch: 3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04 I0214 18:00:54.727168 140369022753856 train.py:146] epoch: 4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15 I0214 18:00:56.082807 140369022753856 train.py:146] epoch: 5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18
清除所用資源
如要避免系統向您的 Google Cloud 帳戶收取您在本頁所用資源的費用,請按照下列步驟操作。
使用完 TPU VM 後,請按照下列步驟清理資源。
如果您尚未中斷與 Cloud TPU 執行個體的連線,請中斷連線:
(vm)$ exit
畫面上的提示現在應為 username@projectname,表示您正在 Cloud Shell 中。
刪除 Cloud TPU:
$ gcloud compute tpus tpu-vm delete $TPU_NAME \ --project=$PROJECT_ID \ --zone=$ZONE
執行下列指令,確認資源已刪除。確認 TPU 已從清單中移除。刪除作業需要幾分鐘的時間才能完成。
$ gcloud compute tpus tpu-vm list \ --zone=$ZONE
效能注意事項
以下是幾項與在 JAX 中使用 TPU 特別相關的重要細節。
填充
TPU 效能緩慢最常見的原因之一,是引入了無意間的邊框:
- Cloud TPU 之中的陣列為並排顯示,這會導致其中一個維度填充為 8 的倍數,另一個維度填充為 128 的倍數。
- 矩陣相乘單元最適合處理一對大型矩陣,以盡量減少填充的需求。
bfloat16 dtype
根據預設,在 TPU 上執行 JAX 的矩陣乘法會使用 bfloat16 搭配 float32 累加。您可以使用相關 jax.numpy
函式呼叫 (matmul、dot、einsum 等) 的運算精確度引數來控制這項功能。特別是:
precision=jax.lax.Precision.DEFAULT
:使用混合 bfloat16 精確度 (最快)precision=jax.lax.Precision.HIGH
:使用多個 MXU 傳遞來達成更高的精確度precision=jax.lax.Precision.HIGHEST
:使用更多 MXU 傳遞次數,以便達到完整的 float32 精確度
JAX 也新增了 bfloat16 dtype,可用於明確將陣列轉換為 bfloat16
。例如:jax.numpy.array(x, dtype=jax.numpy.bfloat16)
。
後續步驟
如要進一步瞭解 Cloud TPU,請參閱: