Cloud TPU v5p トレーニング

Cloud TPU v5p は、Google Cloud の第 5 世代 Cloud TPU であり、v4 TPU の後継です。v5p は大規模なトレーニング用に最適化されており、基盤となる LLM、拡散モデル、生成 AI を開発するための主要なプラットフォームです。大まかに言うと、v5p は v4 の最大 2 倍の性能を備えながら、Pod に 2 倍の TPU を詰め込み(最大スライスは v4 の 3k に対して 6k)、Pod レベルで最大 4 倍の性能を実現します。また、高いクロック周波数(1.05 Ghz に対して 1.75 Ghz)で動作し、大規模な埋め込み用の SparseCore が追加され、高帯域幅メモリ(HBM)容量を 3 倍に増やしています。

Cloud TPU v5p のコンセプト

Cloud TPU を初めて使用する場合は、TPU ドキュメントのホームページをご覧ください。

Cloud TPU のコンセプト(スライス、ホスト、TensorCore など)と、すべての Cloud TPU バージョンに対する Cloud TPU システム アーキテクチャについては、Cloud TPU システム アーキテクチャ ページをご覧ください。

各 Cloud TPU のバージョンでは、トレーニングまたは推論のために特定のアクセラレータ タイプが必要です。これらのアクセラレータ タイプについては、v5p 構成をご覧ください。

TPU リソースを管理する

このドキュメントのすべてのコマンドは、TPU v5p VM を作成していることを前提としています。TPU VM を作成するコマンドの詳細については、TPU の管理またはキューに入れられたリソースの管理に関するキューに入れられたリソース ユーザーガイドをご覧ください。コマンドを簡単に実行できるように、このドキュメントのコードサンプルでは次の環境変数を使用しています。

export PROJECT_ID=your-project
export ACCELERATOR_TYPE=v5p-8
export ZONE=us-east5-a
export RUNTIME_VERSION=v2-alpha-tpuv5
export TPU_NAME=your-tpu-name

環境変数の説明

PROJECT_ID
TPU を作成する Google Cloud プロジェクト。
ACCELERATOR_TYPE
アクセラレータ タイプでは、作成する Cloud TPU のバージョンとサイズを指定します。TPU の各バージョンでサポートされているアクセラレータ タイプの詳細については、TPU のバージョンをご覧ください。
ZONE
Cloud TPU を作成するゾーン
RUNTIME_VERSION
TPU ソフトウェアのバージョン
TPU_NAME
使用している TPU のユーザー定義名。

フレームワークの設定

このセクションでは、TPU v5p で JAX または PyTorch を使用したモデル トレーニングの一般的な設定プロセスについて説明します。

JAX を設定する

スライス形状が 4 チップを超える場合、1 つのスライスに複数の VM があります。この場合、--worker=all フラグを使用して、1 つのコマンドを使用してすべての TPU VM にインストールを実行する必要があります。

gcloud compute tpus tpu-vm ssh ${TPU_NAME}  \
--project ${PROJECT_ID} \
--zone ${ZONE} \
--worker=all \
--command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

次のコマンドを実行して、デバイスの数を確認できます(ここに表示されている出力は、v5p-32 スライスで生成されたものです)。このコードは、JAX で Cloud TPU TensorCore が表示され、基本オペレーションを実行できることを確認することによって、すべてが正しくインストールされていることをテストします。

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project ${PROJECT_ID} \
--zone ${ZONE} \
--worker=all \
--command='python3 -c "import jax; print(jax.device_count()); print(jax.local_device_count())"'

出力は次のようになります。

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16
4
16
4
16
4
16
4

jax.device_count() は、指定されたスライス内のチップの合計数を示します。jax.local_device_count() は、このスライス内の単一の VM からアクセス可能なチップの数を示します。

# Check the number of chips in the given slice by summing the count of chips
# from all VMs through the
# jax.local_device_count() API call.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project ${PROJECT_ID} \
--zone ${ZONE} \
--worker=all \
--command='python3 -c "import jax; xs=jax.numpy.ones(jax.local_device_count()); print(jax.pmap(lambda x: jax.lax.psum(x, \"i\"), axis_name=\"i\")(xs))"'

出力は次のようになります。

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]

--node=all を使用して、すべてのマルチスライス ワーカーでコマンドを実行します。

gcloud compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
--project ${PROJECT_ID} --zone ${ZONE} --node=all --worker=all \
--command='python3 -c "import jax; print(jax.device_count()); print(jax.local_device_count())"'

このドキュメントの JAX チュートリアルを試して、JAX を使用した v5p トレーニングを開始します。

PyTorch を設定する

PJRT ランタイムは v5p でサポートされている唯一のランタイムで、PyTorch 2.1+ ではすべての TPU バージョンのデフォルト ランタイムとして PJRT が使用されます。このセクションでは、すべてのワーカーに PyTorch/XLA 2.2.0 を使用して v5p Pod で PJRT を使用する方法について説明します。

依存関係をインストールする

gcloud compute tpus tpu-vm ssh ${TPU_NAME}  \
--project ${PROJECT_ID} \
--zone ${ZONE} \
--worker=all \
--command='
sudo apt-get update
sudo apt-get install libopenblas-dev -y
pip install numpy
pip install torch torch_xla[tpu]~=2.5.0 -f https://storage.googleapis.com/libtpu-releases/index.html
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
'

PJRT で Python スクリプトを使用してインストールを検証します。スクリプトには、使用可能な TPU デバイスが表示されます(ここに示す出力は v5p-32 スライスで生成されています)。

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project ${PROJECT_ID} --zone ${ZONE} --worker=all \
--command='
PJRT_DEVICE=TPU python3 -c "import torch_xla.core.xla_model as xm; print(xm.get_xla_supported_devices(\"TPU\"))"
'
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
['xla:0', 'xla:1', 'xla:2', 'xla:3']
['xla:0', 'xla:1', 'xla:2', 'xla:3']
['xla:0', 'xla:1', 'xla:2', 'xla:3']
['xla:0', 'xla:1', 'xla:2', 'xla:3']

--node=all を使用して、すべてのマルチスライス ワーカーでコマンドを実行します。

gcloud compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
--project ${PROJECT_ID} --zone ${ZONE} --node=all --worker=all \
--command='
PJRT_DEVICE=TPU python3 -c "import torch_xla.core.xla_model as xm; print(xm.get_xla_supported_devices(\"TPU\"))"
'

このドキュメントの PyTorch チュートリアルを試して、PyTorch を使用した v5p トレーニングを開始します。

モニタリングとプロファイル

Cloud TPU v5p は、以前の世代の Cloud TPU と同じ方法でモニタリングとプロファイリングをサポートしています。プロファイリングの詳細については、Cloud TPU ツールでモデルをプロファイリングするをご覧ください。モニタリングの詳細については、Cloud TPU VM のモニタリングをご覧ください。

トレーニング チュートリアル

このセクションでは、単一スライスのトレーニング チュートリアルについて説明します。これらのチュートリアルをマルチスライス トレーニングに適応させるには、SSH コマンドに --node=all フラグを追加します。詳細とベスト プラクティスについては、マルチスライスの概要をご覧ください。

JAX チュートリアル

Diffusion 2.1 をトレーニングする

このチュートリアルでは、Cloud TPU v5p で Pokémon データセットを使用して、HuggingFace から Stable Diffusion モデルをトレーニングする方法について説明します。

Stable Diffusion モデルは、テキスト入力からフォトリアリスティックな画像を生成する、潜在的 text-to-image モデルです。詳しくは、次のリソースをご覧ください。

設定

  1. 環境変数を作成します。

    export GCS_BUCKET_NAME=your-bucket
    export PROJECT_ID=your-project-ID
    export ACCELERATOR_TYPE=v5p-32
    export ZONE=europe-west4-b
    export LOCATION=europe-west4
    export RUNTIME_VERSION=v2-alpha-tpuv5
    export SERVICE_ACCOUNT=your-service-account
    export TPU_NAME=your-tpu-name
    export QUEUED_RESOURCE_ID=your-qr-name
    export QUOTA_TYPE=spot
    export VALID_UNTIL_DURATION=1d

    コマンドフラグの説明

    変数 説明
    PROJECT_ID Google Cloud プロジェクト名
    ACCELERATOR_TYPE TPU のバージョンについては、TPU のバージョンをご覧ください。
    ZONE サポートされているゾーンについては、TPU のリージョンとゾーンのドキュメントをご覧ください。
    LOCATION Cloud Storage ストレージ バケットを作成する Google Cloud リージョン。
    RUNTIME_VERSION v5p の場合、RUNTIME_VERSION に v2-alpha-tpuv5 を使用します。
    SERVICE_ACCOUNT これは、Google Cloud コンソール -> IAM] -> サービス アカウント で確認できるサービス アカウントのアドレスです。 例: tpu-service-account@myprojectID。iam.gserviceaccount.com
    TPU_NAME キューに格納されたリソース リクエストの割り当て時に作成される TPU のユーザー割り当て ID。
    QUEUED_RESOURCE_ID キューに格納されたリソース リクエストのユーザー割り当て ID。キューに格納されたリソースについては、キューに格納されたリソースのドキュメントをご覧ください。
    QUOTA_TYPE reservedspot のいずれかを設定できます。どちらも指定されていない場合、QUOTA_TYPE はデフォルトで on-demand になります。Cloud TPU でサポートされている割り当てのさまざまなタイプについては、割り当てをご覧ください。
    VALID_UNTIL_DURATION リクエストが有効である期間。さまざまな有効期間の詳細については、キューに入れられたリソースをご覧ください。
  2. モデル出力用のストレージ バケットを設定します。

    gcloud storage buckets create gs://$GCS_BUCKET_NAME \
     --project=$PROJECT_ID \
     --location=$LOCATION
  3. TPU リソースを作成します。

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
    --node-id ${TPU_NAME} \
    --project ${PROJECT_ID} \
    --zone ${ZONE} \
    --accelerator-type ${ACCELERATOR_TYPE} \
    --runtime-version ${RUNTIME_VERSION} \
    --valid-until-duration ${VALID_UNTIL_DURATION} \
    --service-account ${SERVICE_ACCOUNT} \
    --${QUOTA_TYPE}

    キューに格納されたリソースが ACTIVE 状態になると、TPU VM に SSH 接続できるようになります。 次のコマンドを実行して、キューに入れられたリソースの状態を確認します。

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
    --project ${PROJECT_ID} --zone ${ZONE}

    キューに入れられたリソースが ACTIVE 状態の場合、出力は次のようになります。

    state: ACTIVE
    
  4. モデルのトレーニング

    gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE --project $PROJECT_ID --worker=all --command="
    git clone https://github.com/google/maxdiffusion
    cd maxdiffusion
    git reset --hard 57629bcf4fa32fe5a57096b60b09f41f2fa5c35d # This identifies the GitHub commit to use.
    pip3 install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html # Install the latest version of JAX
    pip3 install -r requirements.txt
    pip3 install .
    export LIBTPU_INIT_ARGS=""
    python -m src.maxdiffusion.models.train src/maxdiffusion/configs/base_2_base.yml run_name=my_run base_output_directory=gs://$GCS_BUCKET_NAME enable_profiler=False"

クリーンアップ

セッションの終了時に TPU とキューに入れられたリソース リクエストを削除するか、「FAILED」状態のキューに入れられたリソース リクエストを削除します。キューに入れられたリソースを削除するには、2 ステップで、スライスを削除した後、キューに入れられたリソース リクエストを削除します。

   gcloud compute tpus tpu-vm delete ${TPU_NAME} --project=${PROJECT_ID} \
      --zone=${ZONE} --quiet
   gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
      --project ${PROJECT_ID} --zone ${ZONE} --quiet

または、--force を使用して、1 ステップでスライスとキューに入れられたリソース リクエストを削除します。

# With --force
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID}
--project ${PROJECT_ID} --zone ${ZONE} --quiet --force

ベンチマークの結果

Stable Diffusion トレーニング スクリプトは、v5p-8、v5p-32、v5p-128 で実行されました。次の表では、スループットを示します。

v5p-8

v5p-32

v5p-128

トレーニング ステップ

150

150

150

グローバル バッチサイズ

32

64

64

スループット(例/秒)

12.10

18.08

19.10

MaxText

このチュートリアルでは、Cloud TPU で合成データセットを使用して MaxText モデルをトレーニングする方法について説明します。

MaxText は、Cloud TPU をターゲットとする純粋な Python/JAX で記述された、高性能で任意に拡張可能なオープンソース LLM です。MaxText は、自然言語処理(NLP)の研究開発の新境地を開拓するための、アクセスしやすく適応性の高いツールを使用して研究者と開発者を支援します。

このチュートリアルを実行する前に、Cloud TPU 環境を設定する必要があります。

  1. 環境変数を設定する

    export PROJECT_ID=your_project_ID
    export TPU_NAME=your_tpu_name # user defined TPU name
    export ACCELERATOR_TYPE=v5p-256
    export ZONE=us-east5-a
    export RUNTIME_VERSION=v2-alpha-tpuv5
    export RUN_NAME=your_experiment_run_name # user defined name for this run
    export GCS_BUCKET_NAME=your_bucket_name # Output cloud folder. Should start with gs://
    export MAXTEXT_OUTPUT_PATH=${GCS_BUCKET_NAME}/your_experiment_output_path
    export NUM_SLICES=1 # Update the value to a number >1 for Multislice.

    コマンドフラグの説明

    変数 説明
    PROJECT_ID Google Cloud プロジェクト名
    TPU_NAME TPU のユーザー定義名。
    ACCELERATOR_TYPE TPU のバージョンについては、TPU のバージョンをご覧ください。
    ZONE サポートされているゾーンについては、TPU のリージョンとゾーンのドキュメントをご覧ください。
    RUNTIME_VERSION v5p の場合は、ランタイム バージョンに v2-alpha-tpuv5 を使用します。
    RUN_NAME ユーザーが指定したテスト実行名。

    マルチスライスに推奨されるオプションの設定:

    export NETWORK_NAME=your_network_name
    export FIREWALL_RULE_NAME=your_firewall_rule_name

    マルチスライス ワークロードを実行していて、最適なネットワーク パフォーマンスが必要な場合は、最大伝送単位(MTU)が 8,896 バイトの専用ネットワークを作成して、適切なファイアウォール ルールを構成することを検討してください。このステップは省略できますが、特にデータセンター ネットワーク(DCN)でスライス数をスケールアップする場合に、パフォーマンスを大幅に改善できます。なお、ネットワークを作成するには、プロジェクトに compute.networks.create 権限が必要です。次の例では、専用ネットワークとファイアウォール ルールを作成する方法を示します。

    専用ネットワークを作成します。

    gcloud compute networks create ${NETWORK_NAME} \
    --mtu=8896 \
    --project=${PROJECT_ID} \
    --subnet-mode=auto \
    --bgp-routing-mode=regional

    ファイアウォール ルールを作成します。

    gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
    --network ${NETWORK_NAME} --allow tcp,icmp,udp --project=${PROJECT_ID}
  2. MaxText リポジトリのクローンを作成します

    git clone https://github.com/google/maxtext.git
  3. モデルのトレーニング

    以降のセクションでは、MaxText をトレーニングする 2 つのオプションについて説明します。

    オプション 1

    Cloud TPU のプロビジョニングと依存関係のインストールから、モデルの実行とリソースの削除まで、ワークフロー全体をスクリプトで管理する場合は、multihost_job.py を使用できます。

    cd maxtext && python3 multihost_job.py --PROJECT=${PROJECT_ID} --ZONE=${ZONE} \
    --NUM_SLICES=${NUM_SLICES} --TPU_TYPE=${ACCELERATOR_TYPE} \
    --VERSION=${RUNTIME_VERSION} --RUN_NAME=${RUN_NAME} #user defined run name \
    --BUCKET_NAME=${GCS_BUCKET_NAME} \ #used to store logs and configs
    --COMMAND="bash setup.sh && bash MaxText/configs/experimental/64b.sh RUN_NAME=${RUN_NAME} OUTPUT_PATH=${MAXTEXT_OUTPUT_PATH} PLATFORM=gce"

    スクリプトを開始すると、次のようなメッセージがログに表示されます。ログの場所は出力メッセージで参照されます。最初のリンクをクリックして、TPU のプロビジョニングが完了したらすべてのワーカーのログにアクセスします。

    ------------------------------------
    
    multihost_job finished running, TPUs are starting up to run your job remotely.
    
    Logs for your job are displayed here:
    https://console.cloud.google.com/logs/query;query=resource.type%3D%22gce_instance%22%20AND%0Alog_id%2528%22_log%22%2529;?project=PROJECT_ID
    
    To see the output of a single host, you may edit the slice and worker
    number in the `log_file_path` property here:
    
    https://console.cloud.google.com/logs/query;query=resource.type%3D%22gce_instance%22%20AND%0Alog_id%2528%22RUN_NAME_log%22%2529%20AND%0Alabels.%22agent.googleapis.com%2Flog_file_path%22%3D%20%22%2FRUN_NAME%2Fmain_command_log_slice_0_worker_0%22;?project=PROJECT_ID
    
    When your job is finished, the main command log is in your Cloud Storage
    bucket:
    https://console.cloud.google.com/storage/browser/YOUR_BUCKET_NAME/RUN_NAME?project=PROJECT_ID
    
    View the status of the created TPUs using:
    gcloud compute tpus queued-resources list --filter=RUN_NAME
    --zone=ZONE --project=PROJECT_ID
    
オプション 2

プロビジョニングされた Cloud TPU でトレーニング スクリプトを複数回実行するには、multihost_runner.py スクリプトを使ってリソースを使用します。

  1. 変数を設定して TPU を作成します。

    export SERVICE_ACCOUNT=your_service_account
    export TPU_NAME=your_tpu_name
    export QUEUED_RESOURCE_ID=your_queued_resource_id
    export VALID_DURATION=1d
    export QUOTA_TYPE=quota_type
    --node-count ${NODE_COUNT} \
    --node-prefix ${NODE_PREFIX} # optional, the default is QUEUED_RESOURCE_ID
  2. TPU リソースを作成します。

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
    --node-id ${TPU_NAME} \
    --project ${PROJECT_ID} \
    --zone ${ZONE} \
    --accelerator-type ${ACCELERATOR_TYPE} \
    --runtime-version ${RUNTIME_VERSION} \
    --valid-until-duration ${VALID_DURATION} \
    --service-account ${SERVICE_ACCOUNT} \
    --${QUOTA_TYPE}

    QueuedResourceACTIVE 状態になると、SSH を使用して TPU VM に接続できるようになります。

    describe コマンドを使用して、キューに入れられたリソースのステータスを確認します。

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}
    --project ${PROJECT_ID} --zone ${ZONE}

    キューに格納されたリソースが ACTIVE 状態の場合、出力は次のようになります。

     state: ACTIVE
    
  3. SSH を使用して TPU に接続します

    gcloud compute tpus tpu-vm ssh ${TPU_NAME}  \
      --project ${PROJECT_ID} \
      --zone ${ZONE}
  4. 依存関係をインストールする

    export TPU_NAME=your_tpu_name
    export MAXTEXT_OUTPUT_PATH=output-path
    cd maxtext && python3 multihost_runner.py --TPU_PREFIX=${TPU_NAME} \
    --COMMAND='bash setup.sh'
  5. 32b.sh、64b.sh などのさまざまな構成スクリプトを使用してモデルを実行します。TPU VM からスクリプトを実行する場合は、--INTERNAL_IP=true フラグを追加する必要があります。

    python3 multihost_runner.py --TPU_PREFIX=${TPU_NAME} \
    --COMMAND="bash MaxText/configs/experimental/64b.sh RUN_NAME=${RUN_NAME}
    OUTPUT_PATH=${MAXTEXT_OUTPUT_PATH} PLATFORM=gce"

クリーンアップ

TPU とキューに格納されたリソースを削除します

ベンチマークの結果

MaxText トレーニング スクリプトは、32 ~ 1,160 バイトを bf16 適合率で実行しました。これらの実行結果を次の表に示します。

パラメータ数

アクセラレータ タイプ

TFLOP/チップ/秒

モデルの FLOP 使用率

(MFU)

32B

v5p-128

3.28E+02

71.47%

64B

v5p-128

3.23E+02

70.31%

128B

v5p-256

3.15E+02

68.68%

128B

v5p-512

3.15E+02

68.53%

256B

v5p-1024

3.16E+02

68.82%

512B

v5p-1024

2.94E+02

63.99%

1024B

v5p-2048

2.49E+02

64.05%

1024B

v5p-4096

2.97E+02

64.80%

1160B

v5p-7680

2.95E+02

64.27%

1160B

v5p-12288

3.04E+02

66.23%

256B パラメータ モデルは、bf16 と int8 の両方の重み付けを使用して、v5p-512 と v5p-1024 でテストされています。次の表では、これらのテスト結果を示します。

v5p-512

v5p-512

v5p-1024

v5p-1024

グローバル バッチサイズ

(トークン)

5.24E+05

5.24E+05

1.05E+06

1.05E+06

適合率

bf16

int8

bf16

int8

TFLOP/チップ/秒

307

408

308

414

モデルの FLOP 使用率

(MFU)

66.98%

88.85%

67.09%

90.23%

TensorFlow チュートリアル

単一ホスト v5p で ResNet をトレーニングする

このチュートリアルでは、架空のデータセットを使用して v5p-8 TPU で ImageNet をトレーニングする方法について説明します。別のデータセットを使用する場合は、データセットの準備をご覧ください。

設定

  1. 環境変数を作成します。

    export PROJECT_ID=your-project-ID
    export ACCELERATOR_TYPE=v5p-32
    export ZONE=us-east1-c
    export RUNTIME_VERSION=tpu-vm-tf-2.18.0-pjrt
    export TPU_NAME=your-tpu-name
    export QUEUED_RESOURCE_ID=your-queued-resource-id
    export QUOTA_TYPE=quota-type

    このチュートリアルでは、ACCELERATOR_TYPE として v5p-8 を使用します。

  2. TPU リソースを作成します。

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
      --node-id ${TPU_NAME} \
      --project ${PROJECT_ID} \
      --zone ${ZONE} \
      --accelerator-type ${ACCELERATOR_TYPE} \
      --runtime-version ${RUNTIME_VERSION} \
      --${QUOTA_TYPE}

    キューに入れられたリソースが ACTIVE 状態になると、SSH を使用して TPU VM に接続できるようになります。キューに格納されたリソースの状態を確認するには、次のコマンドを使用します。

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
      --project ${PROJECT_ID} \
      --zone ${ZONE}
  3. SSH を使用して TPU に接続します

    gcloud compute tpus tpu-vm ssh ${TPU_NAME}  \
      --project ${PROJECT_ID} \
      --zone ${ZONE}
  4. いくつかの環境変数を設定します

    export MODELS_REPO=/usr/share/tpu/models
    export PYTHONPATH="${MODELS_REPO}:${PYTHONPATH}"
    export MODEL_DIR=gcp-directory-to-store-model
    export DATA_DIR=gs://cloud-tpu-test-datasets/fake_imagenet
    export NEXT_PLUGGABLE_DEVICE_USE_C_API=true
    export TF_PLUGGABLE_DEVICE_LIBRARY_PATH=/lib/libtpu.so
  5. モデル リポジトリのディレクトリに移動し、要件をインストールします。

    cd ${MODELS_REPO} && git checkout r2.15.0
    pip install -r official/requirements.txt

モデルのトレーニング

  1. トレーニング スクリプトを実行します。

    python3 official/vision/train.py \
      --tpu=local \
      --experiment=resnet_imagenet \
      --mode=train_and_eval \
      --config_file=official/vision/configs/experiments/image_classification/imagenet_resnet50_tpu.yaml \
      --model_dir=${MODEL_DIR} \
      --params_override="runtime.distribution_strategy=tpu,task.train_data.input_path=${DATA_DIR}/train*,task.validation_data.input_path=${DATA_DIR}/validation*,task.train_data.global_batch_size=2048,task.validation_data.global_batch_size=2048,trainer.train_steps=100"

クリーンアップ

TPU とキューに格納されたリソースを削除します

マルチホスト v5p で ResNet をトレーニングする

このチュートリアルでは、架空のデータセットを使用して v5p-16 以上での ImageNet をトレーニングする方法について説明します。別のデータセットを使用する場合は、データセットの準備をご覧ください。

  1. 環境変数を作成します。

    export PROJECT_ID=your_project_ID
    export TPU_NAME=your_tpu_name
    export ZONE=us-east1-c
    export ACCELERATOR_TYPE=v5p-16
    export RUNTIME_VERSION=tpu-vm-tf-2.18.0-pod-pjrt
    export QUEUED_RESOURCE_ID=your-queued-resource-id
    export QUOTA_TYPE=quota-type

    ACCELERATOR_TYPEv5p-16 かそれ以上にすることができます。

  2. TPU リソースを作成します。

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
      --node-id ${TPU_NAME} \
      --project ${PROJECT_ID} \
      --zone ${ZONE} \
      --accelerator-type ${ACCELERATOR_TYPE} \
      --runtime-version ${RUNTIME_VERSION} \
      --${QUOTA_TYPE}

    キューに入れられたリソースが ACTIVE 状態になると、SSH を使用して TPU VM に接続できるようになります。

    describe コマンドを使用して、キューに入れられたリソースのステータスをクエリします。

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
      --project ${PROJECT_ID} \
      --zone ${ZONE}
  3. SSH を使用して TPU(ワーカーゼロ)に接続します

    gcloud compute tpus tpu-vm ssh ${TPU_NAME}  \
      --project ${PROJECT_ID} \
      --zone ${ZONE}
  4. いくつかの環境変数を設定します

    export TPU_NAME=your_tpu_name
    export MODELS_REPO=/usr/share/tpu/models
    export PYTHONPATH="${MODELS_REPO}:${PYTHONPATH}"
    export MODEL_DIR=gcp-directory-to-store-model
    export DATA_DIR=gs://cloud-tpu-test-datasets/fake_imagenet
    export TPU_LOAD_LIBRARY=0
  5. モデル リポジトリのディレクトリに移動し、要件をインストールします。

    cd $MODELS_REPO && git checkout r2.15.0
    pip install -r official/requirements.txt

モデルのトレーニング

  1. トレーニング スクリプトを実行します。

    python3 official/vision/train.py \
      --tpu=${TPU_NAME} \
      --experiment=resnet_imagenet \
      --mode=train_and_eval \
      --config_file=official/vision/configs/experiments/image_classification/imagenet_resnet50_tpu.yaml \
      --model_dir=${MODEL_DIR} \
      --params_override="runtime.distribution_strategy=tpu,task.train_data.input_path=${DATA_DIR}/train*,task.validation_data.input_path=${DATA_DIR}/validation*,task.train_data.global_batch_size=2048,task.validation_data.global_batch_size=2048,trainer.train_steps=100"

クリーンアップ

TPU とキューに格納されたリソースを削除します

PyTorch/XLA

Llama 2

このチュートリアルでは、ML 計算グラフ(GSPMD)の一般および拡張可能な並列化により、PyTorch/XLA で HuggingFace のリポジトリのフォークを使用して、v5p で Llama 2 7B モデルをトレーニングする方法について説明します。

設定

  1. 環境変数を作成します。

    export PROJECT_ID=your_project_ID
    export ACCELERATOR_TYPE=v5p-8
    export ZONE=us-east5-a
    export RUNTIME_VERSION=v2-alpha-tpuv5
    export SERVICE_ACCOUNT=your_service_account
    export TPU_NAME=your_tpu_name
    export QUEUED_RESOURCE_ID=your_queued_resource_id
    export QUOTA_TYPE=quota_type
    export VALID_DURATION=1d
  2. TPU リソースの作成

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
    --node-id ${TPU_NAME} \
    --project ${PROJECT_ID} \
    --zone ${ZONE} \
    --accelerator-type ${ACCELERATOR_TYPE} \
    --runtime-version ${RUNTIME_VERSION} \
    --valid-until-duration ${VALID_DURATION} \
    --service-account ${SERVICE_ACCOUNT} \
    --${QUOTA_TYPE}

    QueuedResourceACTIVE 状態になると、SSH を使用して TPU VM に接続できるようになります。

    describe コマンドを使用して、キューに入れられたリソースのステータスを確認します。

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
    --project ${PROJECT_ID} \
    --zone ${ZONE}

    キューに格納されたリソースが ACTIVE 状態の場合、出力は次のようになります。

     state: ACTIVE
    

  3. Pytorch/XLA と必要な依存関係をインストールします。

    gcloud compute tpus tpu-vm ssh ${TPU_NAME}  \
    --project ${PROJECT_ID} \
    --zone  ${ZONE} \
    --worker=all \
    --command='
    sudo apt-get update
    sudo apt-get install libopenblas-dev -y
    pip3 install numpy
    pip3 install typing-extensions
    pip install torch torch_xla[tpu]~=2.5.0 -f https://storage.googleapis.com/libtpu-releases/index.html
    '
  4. HuggingFace のリポジトリをダウンロードし、要件をインストールします。

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='
    git clone -b llama2-google-next-training https://github.com/pytorch-tpu/transformers.git
    cd transformers
    pip3 install git+file://$PWD
    pip3 install datasets accelerate evaluate scikit-learn'
  5. 7B モデル構成をダウンロードします。

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command="curl https://huggingface.co/TheBloke/Llama-2-7B-fp16/raw/main/config.json --output ~/config.json"
  6. モデルのトレーニング

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='
    export PJRT_DEVICE=TPU
    export XLA_USE_BF16=1
    export XLA_IR_DEBUG=1
    export XLA_HLO_DEBUG=1
    
    export LIBTPU_INIT_ARGS="--xla_enable_async_collective_permute=true \
    --xla_tpu_enable_async_collective_fusion_multiple_steps=true \
    --xla_tpu_enable_async_collective_fusion=true \
    --xla_tpu_overlap_compute_collective_tc=true \
    --xla_enable_async_all_gather=true \
    --xla_jf_spmd_threshold_for_windowed_einsum_mib=0"
    
    export PROFILE_EPOCH=0
    export PROFILE_STEP=3
    export PROFILE_DURATION_MS=20000
    export PROFILE_LOGDIR=/tmp/home/
    
    cd transformers
    python examples/pytorch/language-modeling/run_clm.py \
     --tokenizer_name hf-internal-testing/llama-tokenizer \
     --dataset_name wikitext \
     --dataset_config_name wikitext-2-raw-v1 \
     --per_device_train_batch_size 96 \
     --per_device_eval_batch_size 8 \
     --num_train_epochs 1 \
     --do_train \
     --output_dir /tmp/output \
     --overwrite_output_dir \
     --config_name ~/config.json \
     --save_strategy no \
     --logging_strategy no \
     --remove_unused_columns no \
     --optim adafactor \
     --torch_dtype bfloat16 \
     --dataloader_drop_last yes \
     --block_size 2048 \
     --spmd_2d_sharding 1 \
     --spmd_grad_chkpt
    '

マルチスライス環境で実行する場合は、フラグ --spmd_dcn_parallelism をスライスの数に設定する必要があります。

SPMD_USER_GUIDE には、HF スクリプトのさまざまな環境変数とトグルを説明する詳細なユーザーガイドが用意されています。なお、LIBTPU_INIT_ARGS は、今後のリリースで PyTorch/XLA に組み込まれ、デフォルトでオンになります。

クリーンアップ

TPU とキューに格納されたリソースを削除します

ベンチマークの結果

次の表では、3 つの Llama 2 モデルのサイズすべてのスループットを示します。

v5p-8

v5p-128

v5p-128

モデルのサイズ

70 億人

13B

70B

グローバル バッチサイズ

96

1024

128

シャーディング メッシュの形状

(4, 1)

(64, 1)

(16, 4)

モデルの FLOP 使用率

(MFU)

56.67%

55.80%

51.85%

サポートとフィードバック

フィードバックをぜひお寄せください。フィードバックを共有したり、サポートをリクエストしたりするには、Cloud TPU サポートまたはフィードバック フォームにご記入ください。