使用 Autocheckpoint 保留訓練進度

以往,當 TPU VM 需要維護時,程序會立即啟動,使用者無法執行保留進度的動作,例如儲存檢查點。如圖 1(a) 所示。

圖表顯示主機維護作業在有無自動檢查點的情況下所造成的影響

圖 1. 自動查核點功能的插圖:(a) 如果沒有自動查核點,在即將發生維護事件時,系統會遺失上一個查核點的訓練進度。(b) 使用自動檢查點功能,在即將發生維護事件時,可以保留自上次檢查點以來的訓練進度。

您可以使用自動查核點 (圖 1(b)) 來保留訓練進度,方法是設定程式碼,在維護事件發生時儲存非排程查核點。當維護事件發生時,系統會自動儲存自上次檢查點以來的進度。這項功能適用於單一切片和多切片。

Autocheckpoint 功能可搭配可擷取 SIGTERM 信號的架構,並隨後儲存檢查點。支援的架構包括:

使用自動檢查點

自動檢查點功能預設為停用。建立 TPU 或要求排入佇列的資源時,您可以在佈建 TPU 時新增 --autocheckpoint-enabled 標記,啟用自動檢查點。啟用這項功能後,Cloud TPU 收到維護事件通知後,就會執行下列步驟:

  1. 使用 TPU 裝置,擷取傳送至程序的 SIGTERM 信號
  2. 等待程序結束或 5 分鐘過後,兩者取其先
  3. 對受影響的區塊執行維護作業

Autocheckpoint 使用的基礎架構不受 ML 架構限制。只要機器學習架構能夠擷取 SIGTERM 信號並啟動檢查點程序,就能支援自動檢查點。

您必須在應用程式程式碼中啟用機器學習架構提供的 Autocheckpoint 功能。舉例來說,在 Pax 中,這表示在啟動訓練時啟用指令列標記。詳情請參閱使用 Pax 的 Autocheckpoint 快速入門指南。在幕後,架構會在收到 SIGTERM 信號時儲存非排程檢查點,並在 TPU 不再使用時對受影響的 TPU VM 進行維護。

快速入門:使用 MaxText 自動檢查點

MaxText 是高效能、可任意擴充、開放原始碼且經過充分測試的 LLM,以純 Python/JAX 編寫,並以 Cloud TPU 為目標。MaxText 包含使用自動檢查點功能所需的所有設定。

MaxText README 檔案說明瞭兩種大規模執行 MaxText 的方式:

使用 multihost_runner.py 時,請在佈建佇列資源時設定 autocheckpoint-enabled 標記,啟用自動檢查點。

使用 multihost_job.py 時,請在啟動工作時指定 ENABLE_AUTOCHECKPOINT=true 指令列旗標,啟用自動檢查點。

快速入門:在單一切片上使用 Pax 自動檢查點

本節將舉例說明如何在單一區塊中設定及使用 Autocheckpoint 與 Pax。在適當的設定下:

  • 發生維護事件時,系統會儲存檢查點。
  • 儲存檢查點後,Cloud TPU 會對受影響的 TPU VM 執行維護作業。
  • Cloud TPU 完成維護作業後,您就可以照常使用 TPU VM。
  1. 建立 TPU VM 或要求排隊資源時,請使用 autocheckpoint-enabled 標記。

    例如:

    1. 設定環境變數:

      export PROJECT_ID=your-project-id
      export TPU_NAME=your-tpu-name
      export ZONE=zone-you-want-to-use
      export ACCELERATOR_TYPE=your-accelerator-type
      export RUNTIME_VERSION=tpu-ubuntu2204-base

      環境變數說明

      變數 說明
      PROJECT_ID 您的 Google Cloud 專案 ID。使用現有專案或建立新專案
      TPU_NAME TPU 的名稱。
      ZONE 建立 TPU VM 的區域。如要進一步瞭解支援的區域,請參閱「TPU 地區和區域」一文。
      ACCELERATOR_TYPE 加速器類型會指定您要建立的 Cloud TPU 版本和大小。如要進一步瞭解各 TPU 版本支援的加速器類型,請參閱「TPU 版本」。
      RUNTIME_VERSION Cloud TPU 軟體版本

    2. 在有效配置中設定專案 ID 和區域:

      gcloud config set project $PROJECT_ID
      gcloud config set compute/zone $ZONE
    3. 建立 TPU:

      gcloud alpha compute tpus tpu-vm create $TPU_NAME \
          --accelerator-type $ACCELERATOR_TYPE \
          --version $RUNTIME_VERSION \
          --autocheckpoint-enabled
  2. 使用 SSH 連線至 TPU:

    gcloud compute tpus tpu-vm ssh $TPU_NAME
    
  3. 在單一切片上安裝 Pax

    Autocheckpoint 功能適用於 Pax 1.1.0 以上版本。在 TPU VM 上安裝 jax[tpu] 和最新的 paxml

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  4. 設定 LmCloudSpmd2B 模型。執行訓練指令碼前,請將 ICI_MESH_SHAPE 變更為 [1, 8, 1]

    @experiment_registry.register
    class LmCloudSpmd2B(LmCloudSpmd):
    
        """SPMD model with 2B params.
    
        Global batch size = 2 * 2 * 1 * 32 = 128
        """
        PERCORE_BATCH_SIZE = 8
    
        NUM_LAYERS = 18
        MODEL_DIMS = 3072
        HIDDEN_DIMS = MODEL_DIMS * 4
    
        CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING
        ICI_MESH_SHAPE = [1, 8, 1]
  5. 使用適當的設定啟動訓練。

    以下範例說明如何設定 LmCloudSpmd2B 模型,將 Autocheckpoint 觸發的檢查點儲存至 Cloud Storage 值區。將 your-storage-bucket 替換為現有值區的名稱,或建立新值區

    export JOB_LOG_DIR=gs://your-storage-bucket
    
    { python3 .local/lib/python3.10/site-packages/paxml/main.py \
        --jax_fully_async_checkpoint=1 \
        --exit_after_ondemand_checkpoint=1 \
        --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \
        --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt

    請注意傳遞至指令的兩個標記:

    • jax_fully_async_checkpoint:開啟這個標記後,系統會使用 orbax.checkpoint.AsyncCheckpointer。當訓練指令碼收到 SIGTERM 信號時,AsyncCheckpointer 類別會自動儲存檢查點。
    • exit_after_ondemand_checkpoint:啟用這個標記後,Autocheckpoint 成功儲存後,TPU 程序就會結束,並觸發立即執行維護作業。如果您未使用這個標記,系統會在儲存查核點後繼續訓練,而 Cloud TPU 會等待逾時 (5 分鐘) 才執行必要的維護作業。

使用 Orbax 自動檢查點

自動檢查點功能不限於 MaxText 或 Pax。任何可擷取 SIGTERM 信號並啟動檢查點程序的架構,都會與 Autocheckpoint 提供的基礎架構搭配運作。Orbax 是提供常用公用程式庫的命名空間,可為 JAX 使用者提供這些功能。

Orbax 說明文件所述,系統會預設為啟用 orbax.checkpoint.CheckpointManager 使用者的這些功能。在每個步驟後呼叫的 save 方法會自動檢查是否有即將發生的維護事件,如果有,即使步驟編號不是 save_interval_steps 的倍數,也會儲存查核點。GitHub 說明文件也說明如何在儲存 Autocheckpoint 後讓訓練程序結束,並修改使用者程式碼。