Cloud TPU v5e トレーニング

Pod あたり 256 チップのフットプリントが小さい TPU v5e は、トランスフォーマー、text-to-image、および、畳み込みニューラル ネットワーク(CNN)のトレーニング、ファインチューニング、サービス提供に適した最適なプロダクトとなるように最適化されています。Cloud TPU v5e をサービングに使用する方法については、v5e を使用した推論をご覧ください。

Cloud TPU v5e TPU のハードウェアと構成の詳細については、TPU v5e をご覧ください。

使ってみる

以降のセクションでは、TPU v5e の使用を開始する方法について説明します。

リクエストの割り当て

トレーニングに TPU v5e を使用するには、割り当てが必要です。オンデマンド TPU、予約 TPU、TPU Spot VM には、さまざまな割り当てタイプがあります。推論に TPU v5e を使用する場合は、個別の割り当てが必要です。割り当ての詳細については、割り当てをご覧ください。TPU v5e の割り当てをリクエストするには、Cloud セールスにお問い合わせください。

Google Cloud アカウントとプロジェクトを作成する

Cloud TPU を使用するには、 Google Cloud アカウントとプロジェクトが必要です。詳細については、Cloud TPU 環境を設定するをご覧ください。

Cloud TPU を作成する

queued-resource create コマンドを使用して、Cloud TPU v5es をキューに格納されたリソースとしてプロビジョニングすることをおすすめします。詳細については、キューに格納されたリソースを管理するをご覧ください。

Create Node API(gcloud compute tpus tpu-vm create)を使用して Cloud TPU v5es をプロビジョニングすることもできます。詳細については、TPU リソースの管理をご覧ください。

トレーニングに使用できる v5e 構成の詳細については、トレーニング用の Cloud TPU v5e タイプをご覧ください。

フレームワークの設定

このセクションでは、TPU v5e で JAX または PyTorch を使用したカスタムモデルのトレーニングの一般的な設定プロセスについて説明します。

推論の設定手順については、v5e 推論の概要をご覧ください。

環境変数をいくつか定義します。

export PROJECT_ID=your_project_ID
export ACCELERATOR_TYPE=v5litepod-16
export ZONE=us-west4-a
export TPU_NAME=your_tpu_name
export QUEUED_RESOURCE_ID=your_queued_resource_id

JAX を設定する

スライス形状が 8 チップを超える場合、1 つのスライスに複数の VM があります。この場合、SSH を使用して個別にログインすることなく、--worker=all フラグを使用してすべての TPU VM に 1 つのステップでインストールを実行する必要があります。

gcloud compute tpus tpu-vm ssh ${TPU_NAME}  \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

コマンドフラグの説明

変数 説明
TPU_NAME キューに入れられたリソース リクエストの割り当て時に作成される TPU のユーザー割り当てテキスト ID。
PROJECT_ID Google Cloud プロジェクト名。既存のプロジェクトを使用するか、 Google Cloud プロジェクトを設定するの説明に従って新しいプロジェクトを作成します。
ZONE サポートされているゾーンについては、TPU のリージョンとゾーンのドキュメントをご覧ください。
worker 基盤となる TPU にアクセスできる TPU VM。

次のコマンドを実行して、デバイスの数を確認できます(ここに表示されている出力は、v5litepod-16 スライスで生成されたものです)。このコードは、JAX が Cloud TPU TensorCore を認識し、基本オペレーションを実行できることを確認することで、すべてが正しくインストールされていることをテストします。

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='python3 -c "import jax; print(jax.device_count()); print(jax.local_device_count())"'

出力は次のようになります。

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16
4
16
4
16
4
16
4

jax.device_count() は、指定されたスライス内のチップの合計数を示します。jax.local_device_count() は、このスライス内の単一の VM からアクセス可能なチップの数を示します。

# Check the number of chips in the given slice by summing the count of chips
# from all VMs through the
# jax.local_device_count() API call.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='python3 -c "import jax; xs=jax.numpy.ones(jax.local_device_count()); print(jax.pmap(lambda x: jax.lax.psum(x, \"i\"), axis_name=\"i\")(xs))"'

出力は次のようになります。

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]

このドキュメントの JAX チュートリアルを試して、JAX を使用した v5e トレーニングを開始します。

PyTorch を設定する

v5e は PJRT ランタイムのみをサポートしているのでご注意ください。PyTorch 2.1 以降では、すべての TPU バージョンのデフォルト ランタイムとして PJRT が使用されます。

このセクションでは、すべてのワーカー用のコマンドで PyTorch/XLA を使用して v5e 上で PJRT の使用を開始する方法について説明します。

依存関係をインストールする

gcloud compute tpus tpu-vm ssh ${TPU_NAME}  \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
      sudo apt-get update -y
      sudo apt-get install libomp5 -y
      pip install mkl mkl-include
      pip install tf-nightly tb-nightly tbp-nightly
      pip install numpy
      sudo apt-get install libopenblas-dev -y
      pip install torch~=PYTORCH_VERSION torchvision torch_xla[tpu]~=PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'

PYTORCH_VERSION は、使用する PyTorch のバージョンに置き換えます。PYTORCH_VERSION は、PyTorch/XLA に同じバージョンを指定するために使用されます。2.6.0 が推奨です。

PyTorch と PyTorch/XLA のバージョンの詳細については、PyTorch - スタートガイドPyTorch/XLA リリースをご覧ください。

PyTorch/XLA のインストールの詳細については、PyTorch/XLA のインストールをご覧ください。

torchtorch_xlatorchvision のホイール(pkg_resources.extern.packaging.requirements.InvalidRequirement: Expected end or semicolon (after name and no valid version specifier) torch==nightly+20230222 など)を取り付ける際にエラーが発生した場合は、次のコマンドを使用してバージョンをダウングレードします。

pip3 install setuptools==62.1.0

PJRT でスクリプトを実行する

unset LD_PRELOAD

Python スクリプトを使用して v5e VM で計算を行う例を次に示します。

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
      export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.local/lib/
      export PJRT_DEVICE=TPU
      export PT_XLA_DEBUG=0
      export USE_TORCH=ON
      unset LD_PRELOAD
      export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
      python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"'

これにより、次のような出力が生成されます。

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
xla:0
tensor([[ 1.8611, -0.3114, -2.4208],
[-1.0731, 0.3422, 3.1445],
[ 0.5743, 0.2379, 1.1105]], device='xla:0')
xla:0
tensor([[ 1.8611, -0.3114, -2.4208],
[-1.0731, 0.3422, 3.1445],
[ 0.5743, 0.2379, 1.1105]], device='xla:0')

このドキュメントの PyTorch チュートリアルを試して、PyTorch を使用した v5e トレーニングを開始します。

セッションの終了時に TPU とキューに格納されたリソースを削除します。キューに格納されたリソースを削除するには、スライスを削除してから、キューに格納されたリソースを削除する、2 つのステップで行います。

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

これら 2 つのステップは、FAILED 状態にある、キューに格納されたリソース リクエストを削除するためにも使用できます。

JAX / FLAX の例

以降のセクションでは、TPU v5e で JAX モデルと FLAX モデルをトレーニングする方法の例について説明します。

v5e で ImageNet をトレーニングする

このチュートリアルでは、架空の入力データを使用して v5e で ImageNet をトレーニングする方法について説明します。実際のデータを使用する場合は、GitHub の README ファイルをご覧ください。

設定

  1. 環境変数を作成します。

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west4-a
    export ACCELERATOR_TYPE=v5litepod-8
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    環境変数の説明

    変数 説明
    PROJECT_ID 実際の Google Cloud のプロジェクト ID。既存のプロジェクトを使用するか、新しいプロジェクトを作成します
    TPU_NAME TPU の名前。
    ZONE TPU VM を作成するゾーン。サポートされているゾーンの詳細については、TPU のリージョンとゾーンをご覧ください。
    ACCELERATOR_TYPE アクセラレータ タイプでは、作成する Cloud TPU のバージョンとサイズを指定します。TPU の各バージョンでサポートされているアクセラレータ タイプの詳細については、TPU のバージョンをご覧ください。
    RUNTIME_VERSION Cloud TPU ソフトウェアのバージョン
    SERVICE_ACCOUNT サービス アカウントのメールアドレス。これは、 Google Cloud コンソールの [サービス アカウント] ページで確認できます。

    例: tpu-service-account@PROJECT_ID.iam.gserviceaccount.com

    QUEUED_RESOURCE_ID キューに格納されたリソース リクエストのユーザー割り当てテキスト ID。

  2. TPU リソースを作成します

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --service-account=${SERVICE_ACCOUNT}
    

    キューに格納されたリソースが ACTIVE 状態になると、TPU VM に SSH 接続できるようになります。

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    QueuedResource が ACTIVE 状態の場合、出力は次のようになります。

     state: ACTIVE
    
  3. 最新バージョンの JAX と jaxlib をインストールします。

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  4. ImageNet モデルのクローンを作成し、対応する要件をインストールします。

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command="git clone https://github.com/coolkp/flax.git && cd flax && git checkout pmap-orbax-conversion && git pull"
    
  5. 架空のデータを生成するには、データセットのサイズに関する情報が必要です。これは、ImageNet データセットのメタデータから収集できます。

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command="cd flax/examples/imagenet && pip install -r requirements-cloud-tpu.txt"
    

モデルをトレーニングする

前の手順をすべて完了したら、モデルをトレーニングできます。

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command="cd flax/examples/imagenet && bash ../../tests/download_dataset_metadata.sh && JAX_PLATFORMS=tpu python imagenet_fake_data_benchmark.py"

TPU とキューに格納されたリソースを削除する

セッションの終了時に TPU とキューに格納されたリソースを削除します。

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

Hugging Face FLAX モデル

FLAX で実装された Hugging Face モデルは Cloud TPU v5e ですぐに使用できます。このセクションでは、一般的なモデルの実行手順について説明します。

Imagenette で ViT をトレーニングする

このチュートリアルでは、Cloud TPU v5e で Fast AI Imagenette データセットを使用して、HuggingFace から Vision Transformer(ViT)モデルをトレーニングする方法について説明します。

ViT モデルは、畳み込みネットワークと比較して優れた結果で ImageNet で Transformer エンコーダをトレーニングした最初のモデルです。詳しくは、次のリソースをご覧ください。

設定

  1. 環境変数を作成します。

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west4-a
    export ACCELERATOR_TYPE=v5litepod-16
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    環境変数の説明

    変数 説明
    PROJECT_ID 実際の Google Cloud のプロジェクト ID。既存のプロジェクトを使用するか、新しいプロジェクトを作成します
    TPU_NAME TPU の名前。
    ZONE TPU VM を作成するゾーン。サポートされているゾーンの詳細については、TPU のリージョンとゾーンをご覧ください。
    ACCELERATOR_TYPE アクセラレータ タイプでは、作成する Cloud TPU のバージョンとサイズを指定します。TPU の各バージョンでサポートされているアクセラレータ タイプの詳細については、TPU のバージョンをご覧ください。
    RUNTIME_VERSION Cloud TPU ソフトウェアのバージョン
    SERVICE_ACCOUNT サービス アカウントのメールアドレス。これは、 Google Cloud コンソールの [サービス アカウント] ページで確認できます。

    例: tpu-service-account@PROJECT_ID.iam.gserviceaccount.com

    QUEUED_RESOURCE_ID キューに格納されたリソース リクエストのユーザー割り当てテキスト ID。

  2. TPU リソースを作成します

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --service-account=${SERVICE_ACCOUNT}
    

    キューに格納されたリソースが ACTIVE 状態になると、TPU VM に SSH 接続できるようになります。

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    キューに格納されたリソースが ACTIVE 状態の場合、出力は次のようになります。

     state: ACTIVE
    
  3. JAX とそのライブラリをインストールします。

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  4. Hugging Face のリポジトリをダウンロードし、要件をインストールします。

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='git clone https://github.com/huggingface/transformers.git && cd transformers && pip install . && pip install -r examples/flax/_tests_requirements.txt && pip install --upgrade huggingface-hub urllib3 zipp && pip install tensorflow==2.19 && sed -i 's/torchvision==0.12.0+cpu/torchvision==0.22.1/' examples/flax/vision/requirements.txt && pip install -r examples/flax/vision/requirements.txt && pip install tf-keras'
    
  5. Imagenette データセットをダウンロードします。

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='cd transformers && wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz && tar -xvzf imagenette2.tgz'
    

モデルをトレーニングする

事前にマッピングされたバッファ(4GB)を使用してモデルをトレーニングします。

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='cd transformers && JAX_PLATFORMS=tpu python3 examples/flax/vision/run_image_classification.py --train_dir "imagenette2/train" --validation_dir "imagenette2/val" --output_dir "./vit-imagenette" --learning_rate 1e-3 --preprocessing_num_workers 32 --per_device_train_batch_size 8 --per_device_eval_batch_size 8 --model_name_or_path google/vit-base-patch16-224-in21k --num_train_epochs 3'

TPU とキューに格納されたリソースを削除する

セッションの終了時に TPU とキューに格納されたリソースを削除します。

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

ViT のベンチマーク結果

トレーニング スクリプトは、v5litepod-4、v5litepod-16、v5litepod-64 で実行されました。次の表に、異なるアクセラレータ タイプでのスループットを示します。

アクセラレータ タイプ v5litepod-4 v5litepod-16 v5litepod-64
エポック 3 3 3
グローバル バッチサイズ 32 128 512
スループット(例/秒) 263.40 429.34 470.71

Pokémon データセットで Diffusion をトレーニングする

このチュートリアルでは、Cloud TPU v5e で Pokémon データセットを使用して、HuggingFace から Stable Diffusion モデルをトレーニングする方法について説明します。

Stable Diffusion モデルは、テキスト入力からフォトリアリスティックな画像を生成する、潜在的 text-to-image モデルです。詳しくは、次のリソースをご覧ください。

設定

  1. ストレージ バケットの名前の環境変数を設定します。

    export GCS_BUCKET_NAME=your_bucket_name
  2. モデル出力用のストレージ バケットを設定します。

    gcloud storage buckets create gs://GCS_BUCKET_NAME \
        --project=your_project \
        --location=us-west1
  3. 環境変数を作成します。

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west1-c
    export ACCELERATOR_TYPE=v5litepod-16
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    環境変数の説明

    変数 説明
    PROJECT_ID 実際の Google Cloud のプロジェクト ID。既存のプロジェクトを使用するか、新しいプロジェクトを作成します
    TPU_NAME TPU の名前。
    ZONE TPU VM を作成するゾーン。サポートされているゾーンの詳細については、TPU のリージョンとゾーンをご覧ください。
    ACCELERATOR_TYPE アクセラレータ タイプでは、作成する Cloud TPU のバージョンとサイズを指定します。TPU の各バージョンでサポートされているアクセラレータ タイプの詳細については、TPU のバージョンをご覧ください。
    RUNTIME_VERSION Cloud TPU ソフトウェアのバージョン
    SERVICE_ACCOUNT サービス アカウントのメールアドレス。これは、 Google Cloud コンソールの [サービス アカウント] ページで確認できます。

    例: tpu-service-account@PROJECT_ID.iam.gserviceaccount.com

    QUEUED_RESOURCE_ID キューに格納されたリソース リクエストのユーザー割り当てテキスト ID。

  4. TPU リソースを作成します

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --service-account=${SERVICE_ACCOUNT}
    

    キューに格納されたリソースが ACTIVE 状態になると、TPU VM に SSH 接続できるようになります。

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    キューに格納されたリソースが ACTIVE 状態の場合、出力は次のようになります。

     state: ACTIVE
    
  5. JAX とそのライブラリをインストールします。

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='pip install "jax[tpu]==0.4.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  6. HuggingFace のリポジトリをダウンロードし、要件をインストールします。

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
         --project=${PROJECT_ID} \
         --zone=${ZONE} \
         --worker=all \
         --command='git clone https://github.com/RissyRan/diffusers.git && cd diffusers && pip install . && pip install -U -r examples/text_to_image/requirements_flax.txt && pip install tensorflow==2.17.1 clu && pip install tensorboard==2.17.1'
    

モデルをトレーニングする

事前にマッピングされたバッファ(4GB)を使用してモデルをトレーニングします。

gcloud compute tpus tpu-vm ssh ${TPU_NAME} --zone=${ZONE} --project=${PROJECT_ID} --worker=all --command="
    git clone https://github.com/google/maxdiffusion
    cd maxdiffusion
    pip3 install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    pip3 install -r requirements.txt
    pip3 install .
    pip3 install gcsfs
    export LIBTPU_INIT_ARGS=''
    python -m src.maxdiffusion.train src/maxdiffusion/configs/base_2_base.yml run_name=my_run \
    jax_cache_dir=gs://${GCS_BUCKET_NAME} activations_dtype=bfloat16 weights_dtype=bfloat16 \
    per_device_batch_size=1 precision=DEFAULT dataset_save_location=gs://${GCS_BUCKET_NAME} \
    output_dir=gs://${GCS_BUCKET_NAME}/ attention=flash"

クリーンアップ

セッションの終了時に TPU、キューに格納されたリソース、Cloud Storage バケットを削除します。

  1. TPU を削除します。

    gcloud compute tpus tpu-vm delete ${TPU_NAME} \
        --project=${PROJECT_ID} \
        --zone=${ZONE} \
        --quiet
    
  2. キューに格納されているリソースを削除します。

    gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
        --project=${PROJECT_ID} \
        --zone=${ZONE} \
        --quiet
    
  3. Cloud Storage バケットを削除します。

    gcloud storage rm -r gs://${GCS_BUCKET_NAME}
    

拡散のベンチマーク結果

トレーニング スクリプトは、v5litepod-4、v5litepod-16、v5litepod-64 で実行されました。次の表に、スループットを示します。

アクセラレータ タイプ v5litepod-4 v5litepod-16 v5litepod-64
トレーニング ステップ 1500 1500 1500
グローバル バッチサイズ 32 64 128
スループット(例/秒) 36.53 43.71 49.36

PyTorch/XLA

以下のセクションでは、TPU v5e で PyTorch/XLA モデルをトレーニングする方法の例について説明します。

PJRT ランタイムを使用して ResNet をトレーニングする

PyTorch/XLA は、PyTorch 2.0 以降、XRT から PjRt に移行しています。PyTorch/XLA トレーニング ワークロード用に v5e を設定する手順が更新されています。

設定
  1. 環境変数を作成します。

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west4-a
    export ACCELERATOR_TYPE=v5litepod-16
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    環境変数の説明

    変数 説明
    PROJECT_ID 実際の Google Cloud のプロジェクト ID。既存のプロジェクトを使用するか、新しいプロジェクトを作成します
    TPU_NAME TPU の名前。
    ZONE TPU VM を作成するゾーン。サポートされているゾーンの詳細については、TPU のリージョンとゾーンをご覧ください。
    ACCELERATOR_TYPE アクセラレータ タイプでは、作成する Cloud TPU のバージョンとサイズを指定します。TPU の各バージョンでサポートされているアクセラレータ タイプの詳細については、TPU のバージョンをご覧ください。
    RUNTIME_VERSION Cloud TPU ソフトウェアのバージョン
    SERVICE_ACCOUNT サービス アカウントのメールアドレス。これは、 Google Cloud コンソールの [サービス アカウント] ページで確認できます。

    例: tpu-service-account@PROJECT_ID.iam.gserviceaccount.com

    QUEUED_RESOURCE_ID キューに格納されたリソース リクエストのユーザー割り当てテキスト ID。

  2. TPU リソースを作成します

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --service-account=${SERVICE_ACCOUNT}
    

    QueuedResource が ACTIVE 状態になると、TPU VM に SSH 接続できるようになります。

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    キューに格納されたリソースが ACTIVE 状態の場合、出力は次のようになります。

     state: ACTIVE
    
  3. Torch/XLA 固有の依存関係をインストールします

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --project=${PROJECT_ID} \
      --zone=${ZONE} \
      --worker=all \
      --command='
         sudo apt-get update -y
         sudo apt-get install libomp5 -y
         pip3 install mkl mkl-include
         pip3 install tf-nightly tb-nightly tbp-nightly
         pip3 install numpy
         sudo apt-get install libopenblas-dev -y
         pip install torch==PYTORCH_VERSION torchvision torch_xla[tpu]==PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'

    PYTORCH_VERSION は、使用する PyTorch のバージョンに置き換えます。PYTORCH_VERSION は、PyTorch/XLA に同じバージョンを指定するために使用されます。2.6.0 が推奨です。

    PyTorch と PyTorch/XLA のバージョンの詳細については、PyTorch - スタートガイドPyTorch/XLA リリースをご覧ください。

    PyTorch/XLA のインストールの詳細については、PyTorch/XLA のインストールをご覧ください。

ResNet モデルをトレーニングする
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
      date
      export PJRT_DEVICE=TPU
      export PT_XLA_DEBUG=0
      export USE_TORCH=ON
      export XLA_USE_BF16=1
      export LIBTPU_INIT_ARGS=--xla_jf_auto_cross_replica_sharding
      export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
      export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
      git clone https://github.com/pytorch/xla.git
      cd xla/
      git checkout release-r2.6
      python3 test/test_train_mp_imagenet.py --model=resnet50  --fake_data --num_epochs=1 —num_workers=16  --log_steps=300 --batch_size=64 --profile'

TPU とキューに格納されたリソースを削除する

セッションの終了時に TPU とキューに格納されたリソースを削除します。

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet
ベンチマークの結果

次の表に、ベンチマークのスループットを示します。

アクセラレータ タイプ スループット(例/秒)
v5litepod-4 4,240 例/秒
v5litepod-16 10,810 例/秒
v5litepod-64 46,154 例/秒

v5e で ViT をトレーニングする

このチュートリアルでは、cifar10 データセット上の PyTorch/XLA で HuggingFace のリポジトリを使用して、v5e で VIT を実行する方法について説明します。

設定

  1. 環境変数を作成します。

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west4-a
    export ACCELERATOR_TYPE=v5litepod-16
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    環境変数の説明

    変数 説明
    PROJECT_ID 実際の Google Cloud のプロジェクト ID。既存のプロジェクトを使用するか、新しいプロジェクトを作成します
    TPU_NAME TPU の名前。
    ZONE TPU VM を作成するゾーン。サポートされているゾーンの詳細については、TPU のリージョンとゾーンをご覧ください。
    ACCELERATOR_TYPE アクセラレータ タイプでは、作成する Cloud TPU のバージョンとサイズを指定します。TPU の各バージョンでサポートされているアクセラレータ タイプの詳細については、TPU のバージョンをご覧ください。
    RUNTIME_VERSION Cloud TPU ソフトウェアのバージョン
    SERVICE_ACCOUNT サービス アカウントのメールアドレス。これは、 Google Cloud コンソールの [サービス アカウント] ページで確認できます。

    例: tpu-service-account@PROJECT_ID.iam.gserviceaccount.com

    QUEUED_RESOURCE_ID キューに格納されたリソース リクエストのユーザー割り当てテキスト ID。

  2. TPU リソースを作成します

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --service-account=${SERVICE_ACCOUNT}
    

    QueuedResource が ACTIVE 状態になると、TPU VM に SSH 接続できるようになります。

     gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    キューに格納されたリソースが ACTIVE 状態の場合、出力は次のようになります。

     state: ACTIVE
    
  3. PyTorch/XLA の依存関係をインストールします。

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='
      sudo apt-get update -y
      sudo apt-get install libomp5 -y
      pip3 install mkl mkl-include
      pip3 install tf-nightly tb-nightly tbp-nightly
      pip3 install numpy
      sudo apt-get install libopenblas-dev -y
      pip install torch==PYTORCH_VERSION torchvision torch_xla[tpu]==PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
      pip install jax==0.4.38 jaxlib==0.4.38 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/

    PYTORCH_VERSION は、使用する PyTorch のバージョンに置き換えます。PYTORCH_VERSION は、PyTorch/XLA に同じバージョンを指定するために使用されます。2.6.0 が推奨です。

    PyTorch と PyTorch/XLA のバージョンの詳細については、PyTorch - スタートガイドPyTorch/XLA リリースをご覧ください。

    PyTorch/XLA のインストールの詳細については、PyTorch/XLA のインストールをご覧ください。

  4. HuggingFace のリポジトリをダウンロードし、要件をインストールします。

       gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command="
          git clone https://github.com/suexu1025/transformers.git vittransformers; \
          cd vittransformers; \
          pip3 install .; \
          pip3 install datasets; \
          wget https://github.com/pytorch/xla/blob/master/scripts/capture_profile.py"
    

モデルをトレーニングする

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
      export PJRT_DEVICE=TPU
      export PT_XLA_DEBUG=0
      export USE_TORCH=ON
      export TF_CPP_MIN_LOG_LEVEL=0
      export XLA_USE_BF16=1
      export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
      export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
      cd vittransformers
      python3 -u examples/pytorch/xla_spawn.py --num_cores 4 examples/pytorch/image-pretraining/run_mae.py --dataset_name=cifar10 \
      --remove_unused_columns=False \
      --label_names=pixel_values \
      --mask_ratio=0.75 \
      --norm_pix_loss=True \
      --do_train=true \
      --do_eval=true \
      --base_learning_rate=1.5e-4 \
      --lr_scheduler_type=cosine \
      --weight_decay=0.05 \
      --num_train_epochs=3 \
      --warmup_ratio=0.05 \
      --per_device_train_batch_size=8 \
      --per_device_eval_batch_size=8 \
      --logging_strategy=steps \
      --logging_steps=30 \
      --evaluation_strategy=epoch \
      --save_strategy=epoch \
      --load_best_model_at_end=True \
      --save_total_limit=3 \
      --seed=1337 \
      --output_dir=MAE \
      --overwrite_output_dir=true \
      --logging_dir=./tensorboard-metrics \
      --tpu_metrics_debug=true'

TPU とキューに格納されたリソースを削除する

セッションの終了時に TPU とキューに格納されたリソースを削除します。

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

ベンチマークの結果

次の表に、さまざまなアクセラレータ タイプのベンチマーク スループットを示します。

v5litepod-4 v5litepod-16 v5litepod-64
エポック 3 3 3
グローバル バッチサイズ 32 128 512
スループット(例/秒) 201 657 2,844