使用 Autocheckpoint 保留訓練進度
以往,當 TPU VM 需要維護時,程序會立即啟動,使用者無法執行保留進度的動作,例如儲存檢查點。如圖 1(a) 所示。
圖 1. 自動查核點功能的插圖:(a) 如果沒有自動查核點,在即將發生維護事件時,系統會遺失上一個查核點的訓練進度。(b) 使用自動檢查點功能,在即將發生維護事件時,可以保留自上次檢查點以來的訓練進度。
您可以使用自動查核點 (圖 1(b)) 來保留訓練進度,方法是設定程式碼,在維護事件發生時儲存非排程查核點。當維護事件發生時,系統會自動儲存自上次檢查點以來的進度。這項功能適用於單一切片和多切片。
Autocheckpoint 功能可搭配可擷取 SIGTERM 信號的架構,並隨後儲存檢查點。支援的架構包括:
使用自動檢查點
自動檢查點功能預設為停用。建立 TPU 或要求排入佇列的資源時,您可以在佈建 TPU 時新增 --autocheckpoint-enabled
標記,啟用自動檢查點。啟用這項功能後,Cloud TPU 收到維護事件通知後,就會執行下列步驟:
- 使用 TPU 裝置,擷取傳送至程序的 SIGTERM 信號
- 等待程序結束或 5 分鐘過後,兩者取其先
- 對受影響的區塊執行維護作業
Autocheckpoint 使用的基礎架構不受 ML 架構限制。只要機器學習架構能夠擷取 SIGTERM 信號並啟動檢查點程序,就能支援自動檢查點。
您必須在應用程式程式碼中啟用機器學習架構提供的 Autocheckpoint 功能。舉例來說,在 Pax 中,這表示在啟動訓練時啟用指令列標記。詳情請參閱使用 Pax 的 Autocheckpoint 快速入門指南。在幕後,架構會在收到 SIGTERM 信號時儲存非排程檢查點,並在 TPU 不再使用時對受影響的 TPU VM 進行維護。
快速入門:使用 MaxText 自動檢查點
MaxText 是高效能、可任意擴充、開放原始碼且經過充分測試的 LLM,以純 Python/JAX 編寫,並以 Cloud TPU 為目標。MaxText 包含使用自動檢查點功能所需的所有設定。
MaxText README
檔案說明瞭兩種大規模執行 MaxText 的方式:
- 使用
multihost_runner.py
,建議用於實驗 - 使用
multihost_job.py
,建議用於正式環境
使用 multihost_runner.py
時,請在佈建佇列資源時設定 autocheckpoint-enabled
標記,啟用自動檢查點。
使用 multihost_job.py
時,請在啟動工作時指定 ENABLE_AUTOCHECKPOINT=true
指令列旗標,啟用自動檢查點。
快速入門:在單一切片上使用 Pax 自動檢查點
本節將舉例說明如何在單一區塊中設定及使用 Autocheckpoint 與 Pax。在適當的設定下:
- 發生維護事件時,系統會儲存檢查點。
- 儲存檢查點後,Cloud TPU 會對受影響的 TPU VM 執行維護作業。
- Cloud TPU 完成維護作業後,您就可以照常使用 TPU VM。
建立 TPU VM 或要求排隊資源時,請使用
autocheckpoint-enabled
標記。例如:
設定環境變數:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=zone-you-want-to-use export ACCELERATOR_TYPE=your-accelerator-type export RUNTIME_VERSION=tpu-ubuntu2204-base
在有效配置中設定專案 ID 和區域:
gcloud config set project $PROJECT_ID gcloud config set compute/zone $ZONE
建立 TPU:
gcloud alpha compute tpus tpu-vm create $TPU_NAME \ --accelerator-type $ACCELERATOR_TYPE \ --version $RUNTIME_VERSION \ --autocheckpoint-enabled
使用 SSH 連線至 TPU:
gcloud compute tpus tpu-vm ssh $TPU_NAME
在單一切片上安裝 Pax
Autocheckpoint 功能適用於 Pax 1.1.0 以上版本。在 TPU VM 上安裝
jax[tpu]
和最新的paxml
:pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
設定
LmCloudSpmd2B
模型。執行訓練指令碼前,請將ICI_MESH_SHAPE
變更為[1, 8, 1]
:@experiment_registry.register class LmCloudSpmd2B(LmCloudSpmd): """SPMD model with 2B params. Global batch size = 2 * 2 * 1 * 32 = 128 """ PERCORE_BATCH_SIZE = 8 NUM_LAYERS = 18 MODEL_DIMS = 3072 HIDDEN_DIMS = MODEL_DIMS * 4 CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING ICI_MESH_SHAPE = [1, 8, 1]
使用適當的設定啟動訓練。
以下範例說明如何設定
LmCloudSpmd2B
模型,將 Autocheckpoint 觸發的檢查點儲存至 Cloud Storage 值區。將 your-storage-bucket 替換為現有值區的名稱,或建立新值區。export JOB_LOG_DIR=gs://your-storage-bucket { python3 .local/lib/python3.10/site-packages/paxml/main.py \ --jax_fully_async_checkpoint=1 \ --exit_after_ondemand_checkpoint=1 \ --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \ --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt
請注意傳遞至指令的兩個標記:
jax_fully_async_checkpoint
:開啟這個標記後,系統會使用orbax.checkpoint.AsyncCheckpointer
。當訓練指令碼收到 SIGTERM 信號時,AsyncCheckpointer
類別會自動儲存檢查點。exit_after_ondemand_checkpoint
:啟用這個標記後,Autocheckpoint 成功儲存後,TPU 程序就會結束,並觸發立即執行維護作業。如果您未使用這個標記,系統會在儲存查核點後繼續訓練,而 Cloud TPU 會等待逾時 (5 分鐘) 才執行必要的維護作業。
使用 Orbax 自動檢查點
自動檢查點功能不限於 MaxText 或 Pax。任何可擷取 SIGTERM 信號並啟動檢查點程序的架構,都會與 Autocheckpoint 提供的基礎架構搭配運作。Orbax 是提供常用公用程式庫的命名空間,可為 JAX 使用者提供這些功能。
如Orbax 說明文件所述,系統會預設為啟用 orbax.checkpoint.CheckpointManager
使用者的這些功能。在每個步驟後呼叫的 save
方法會自動檢查是否有即將發生的維護事件,如果有,即使步驟編號不是 save_interval_steps
的倍數,也會儲存查核點。GitHub 說明文件也說明如何在儲存 Autocheckpoint 後讓訓練程序結束,並修改使用者程式碼。