使用 JAX 在 Cloud TPU VM 上執行計算

本文簡要說明如何使用 JAX 和 Cloud TPU。

事前準備

執行本文中的指令前,您必須建立 Google Cloud帳戶、安裝 Google Cloud CLI,並設定 gcloud 指令。詳情請參閱「設定 Cloud TPU 環境」。

使用 gcloud 建立 Cloud TPU VM

  1. 定義一些環境變數,方便使用指令。

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-east5-a
    export ACCELERATOR_TYPE=v5litepod-8
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite

    環境變數說明

    變數 說明
    PROJECT_ID 您的 Google Cloud 專案 ID。使用現有專案或建立新專案
    TPU_NAME TPU 的名稱。
    ZONE 建立 TPU VM 的區域。如要進一步瞭解支援的區域,請參閱「TPU 地區和區域」一文。
    ACCELERATOR_TYPE 加速器類型會指定您要建立的 Cloud TPU 版本和大小。如要進一步瞭解各 TPU 版本支援的加速器類型,請參閱「TPU 版本」。
    RUNTIME_VERSION Cloud TPU 軟體版本

  2. 請在 Cloud Shell 或已安裝 Google Cloud CLI 的電腦終端機中執行下列指令,建立 TPU VM。

    $ gcloud compute tpus tpu-vm create $TPU_NAME \
        --project=$PROJECT_ID \
        --zone=$ZONE \
        --accelerator-type=$ACCELERATOR_TYPE \
        --version=$RUNTIME_VERSION

連線至 Cloud TPU VM

使用下列指令,透過 SSH 連線至 TPU VM:

$ gcloud compute tpus tpu-vm ssh $TPU_NAME \
    --project=$PROJECT_ID \
    --zone=$ZONE

如果您無法使用 SSH 連線至 TPU VM,可能是因為 TPU VM 沒有外部 IP 位址。如要存取沒有外部 IP 位址的 TPU VM,請按照「連線至沒有公開 IP 位址的 TPU VM」中的操作說明進行。

在 Cloud TPU VM 上安裝 JAX

(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

系統檢查

確認 JAX 可存取 TPU,並執行基本運算:

  1. 啟動 Python 3 解譯器:

    (vm)$ python3
    >>> import jax
  2. 顯示可用的 TPU 核心數量:

    >>> jax.device_count()

系統會顯示 TPU 核心數。顯示的核心數量取決於您使用的 TPU 版本。詳情請參閱「TPU 版本」。

執行計算

>>> jax.numpy.add(1, 1)

系統會顯示 numpy 加法結果:

指令的輸出內容:

Array(2, dtype=int32, weak_type=True)

結束 Python 解譯器

>>> exit()

在 TPU VM 上執行 JAX 程式碼

您現在可以執行任何想要的 JAX 程式碼。Flax 範例是您在 JAX 中執行標準機器學習模型的絕佳起點。例如,訓練基本 MNIST 卷積神經網路:

  1. 安裝 Flax 範例依附元件:

    (vm)$ pip install --upgrade clu
    (vm)$ pip install tensorflow
    (vm)$ pip install tensorflow_datasets
  2. 安裝 Flax:

    (vm)$ git clone https://github.com/google/flax.git
    (vm)$ pip install --user flax
  3. 執行 Flax MNIST 訓練指令碼:

    (vm)$ cd flax/examples/mnist
    (vm)$ python3 main.py --workdir=/tmp/mnist \
        --config=configs/default.py \
        --config.learning_rate=0.05 \
        --config.num_epochs=5

指令碼會下載資料集並開始訓練。指令碼輸出內容應如下所示:

I0214 18:00:50.660087 140369022753856 train.py:146] epoch:  1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88
I0214 18:00:52.015867 140369022753856 train.py:146] epoch:  2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72
I0214 18:00:53.377511 140369022753856 train.py:146] epoch:  3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04
I0214 18:00:54.727168 140369022753856 train.py:146] epoch:  4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15
I0214 18:00:56.082807 140369022753856 train.py:146] epoch:  5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18

清除所用資源

如要避免系統向您的 Google Cloud 帳戶收取您在本頁所用資源的費用,請按照下列步驟操作。

使用完 TPU VM 後,請按照下列步驟清理資源。

  1. 如果您尚未中斷與 Cloud TPU 執行個體的連線,請中斷連線:

    (vm)$ exit

    畫面上的提示現在應為 username@projectname,表示您正在 Cloud Shell 中。

  2. 刪除 Cloud TPU:

    $ gcloud compute tpus tpu-vm delete $TPU_NAME \
        --project=$PROJECT_ID \
        --zone=$ZONE
  3. 執行下列指令,確認資源已刪除。確認 TPU 已從清單中移除。刪除作業需要幾分鐘的時間才能完成。

    $ gcloud compute tpus tpu-vm list \
        --zone=$ZONE

效能注意事項

以下是幾項與在 JAX 中使用 TPU 特別相關的重要細節。

填充

TPU 效能緩慢最常見的原因之一,是引入了無意間的邊框:

  • Cloud TPU 之中的陣列為並排顯示,這會導致其中一個維度填充為 8 的倍數,另一個維度填充為 128 的倍數。
  • 矩陣相乘單元最適合處理一對大型矩陣,以盡量減少填充的需求。

bfloat16 dtype

根據預設,在 TPU 上執行 JAX 的矩陣乘法會使用 bfloat16 搭配 float32 累加。您可以使用相關 jax.numpy 函式呼叫 (matmul、dot、einsum 等) 的運算精確度引數來控制這項功能。特別是:

  • precision=jax.lax.Precision.DEFAULT:使用混合 bfloat16 精確度 (最快)
  • precision=jax.lax.Precision.HIGH:使用多個 MXU 傳遞來達成更高的精確度
  • precision=jax.lax.Precision.HIGHEST:使用更多 MXU 傳遞次數,以便達到完整的 float32 精確度

JAX 也新增了 bfloat16 dtype,可用於明確將陣列轉換為 bfloat16。例如:jax.numpy.array(x, dtype=jax.numpy.bfloat16)

後續步驟

如要進一步瞭解 Cloud TPU,請參閱: