TPU v5e で PyTorch を使用して Llama 3 をトレーニングする

このチュートリアルでは、WikiText データセットを使用して、TPU v5e で PyTorch/XLA を使用して Llama-3-8B モデルをトレーニングする方法について説明します。モデルの詳細については、Meta-Llama-3-8B をご覧ください。

Llama-3-8B モデルは Hugging Face プラットフォームでホストされています。

Meta-Llama-3-8B には 2 つのバージョンがあります。1 つは Transformer で使用するためのバージョンで、もう 1 つは元の Llama 3 コードベースで使用するためのバージョンです。このチュートリアルでは、Transformers バージョンを使用します。これは、次の理由からです。

  • Hugging Face エコシステムとシームレスに統合: モデルの微調整、事前構築済みパイプラインの使用、膨大なデータセットとツールへのアクセスが容易になります。

  • 柔軟性とカスタマイズを可能にする: Transformers バージョンでは、モデルのファインチューニングとデプロイに柔軟性とカスタマイズ オプションが大幅に強化されています。

  • コミュニティ サポートを提供: Hugging Face コミュニティは、Transformers モデルの使用に関する豊富なドキュメント、チュートリアル、サポートを提供しています。

Transformer の詳細については、Hugging Face Transformers のドキュメントをご覧ください。

Meta-Llama-3-8B モデルにアクセスして使用するには(重みとトークンのダウンロードを含む)、Hugging Face ユーザー アクセス トークンが必要です。トークンには次の機能があります。

  • 認証と認可: アクセス トークンは認証情報として機能し、Hugging Face サーバーがモデルのリソースへのアクセスを承認できるようにします。これにより、承認されたユーザーのみがモデルをダウンロードして使用できます。

  • セキュリティ: Hugging Face は、アクセス トークンを使用してモデルを保護し、不正アクセスや不正使用を防ぎます。

このチュートリアルでアクセス トークンを作成して使用する方法については、モデルを実行するをご覧ください。アクセス トークンの作成と使用の詳細については、ユーザー アクセス トークンの Hugging Face ドキュメントをご覧ください。

また、Hugging Face の Llama 3 8B モデルにアクセスする権限も必要です。この権限を取得するには、Hugging Face の Meta-Llama-3-8B モデルにアクセスしてアクセスをリクエストします。

TPU v5litepod-16 をプロビジョニングする準備を行う

このチュートリアルは、次の Cloud TPU 環境変数を使用してテストされました。アクセラレータ タイプ、ゾーン、ランタイム バージョンが互換性がある限り、他の変数を使用して TPU をプロビジョニングできます。たとえば、このチュートリアルでは、ゾーンとして europe-west4-b が使用されます。実行している TPU バージョン(アクセラレータ タイプ)(このチュートリアルでは v5litepod-16)をサポートする他のゾーンを使用できます。

次の TPU VM 環境変数を設定します。

   export TPU_NAME=queued-resources-node-id #The TPU name is the queued resource node-id
   export PROJECT_ID=your-project-id
   export ACCELERATOR_TYPE=v5litepod-16
   export ZONE=europe-west4-b
   export RUNTIME_VERSION=v2-alpha-tpuv5-lite
   export QUEUED_RESOURCE_ID=queued-resource-id
   export VALID_UNTIL_DURATION=1d

Hugging Face で Meta-Llama-3-8B モデルにアクセスしたら、チュートリアルを実行する TPU 環境を準備します。

  1. Cloud TPU 環境を設定するガイドに沿って、Cloud TPU を使用するための適切なアクセス権があることを確認します。

  2. TPU VM のサービス ID を作成します。

    gcloud alpha compute tpus tpu-vm service-identity create --zone=zone
  3. TPU サービス アカウントを作成し、 Google Cloud サービスへのアクセス権を付与します。

    サービス アカウントにより、 Google Cloud TPU サービスが他の Google Cloudサービスにアクセスできるようになります。ユーザー管理のサービス アカウントの使用をおすすめします。サービス アカウントは、Google Cloud コンソールまたは gcloud コマンドを使用して作成できます。

    gcloud コマンドライン ツールを使用してサービス アカウントを作成します。

    gcloud iam service-accounts create your-service-account-name \
    --description="your-sa-description" \
    --display-name="your-sa-display-name"
    export SERVICE_ACCOUNT_NAME=your-service-account-name

    Google Cloud コンソールからサービス アカウントを作成します。

    1. Google Cloud コンソールで [サービス アカウント] ページに移動します。
    2. [サービス アカウントを作成] をクリックします。
    3. サービス アカウント名を入力します。
    4. (省略可)サービス アカウントの説明を入力します。
    5. [作成] をクリックして続行します。
    6. サービス アカウントに付与するロールを選択します。
    7. [続行] をクリックします。
    8. (省略可)サービス アカウントを管理できるユーザーまたはグループを指定します。
    9. [完了] をクリックして、サービス アカウントの作成を完了します。

    サービス アカウントを作成したら、次の手順でサービス アカウントのロールを付与します。

    次のロールが必要です。

    • TPU 管理者: TPU を作成するために必要です
    • ストレージ管理者: Cloud Storage にアクセスするために必要です。
    • ログ書き込み
    • モニタリング指標の書き込み: Cloud Monitoring に指標を書き込むために必要

    ユーザーに IAM ロールを割り当てるには、管理者から roles/resourcemanager.projectIamAdmin が付与されている必要があります。プロジェクト IAM 管理者 roles/resourcemanager.projectIamAdmin ロールを持つユーザーも、このロールを付与できます。

    次の gcloud コマンドを使用して、サービス アカウントのロールを追加します。

    gcloud projects add-iam-policy-binding ${PROJECT_ID} \
       --member serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \
       --role roles/tpu.admin
    gcloud projects add-iam-policy-binding ${PROJECT_ID} \
       --member serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \
       --role roles/storage.admin
    gcloud projects add-iam-policy-binding ${PROJECT_ID} \
       --member serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \
       --role roles/logging.logWriter
    gcloud projects add-iam-policy-binding ${PROJECT_ID} \
       --member serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \
       --role roles/monitoring.metricWriter

    Google Cloud コンソールを使用してロールを割り当てることもできます。

    Google Cloud コンソールで、次のロールを選択します。

    1. サービス アカウントを選択し、[プリンシパルを追加] をクリックします。
    2. [新しいプリンシパル] フィールドに、サービス アカウントのメールアドレスを入力します。
    3. [ロールを選択] プルダウンで、ロール(ストレージ管理者など)を検索して選択します。
    4. [保存] をクリックします。
  4. Google Cloud で認証し、Google Cloud CLI のデフォルトのプロジェクトとゾーンを構成します。

    gcloud auth login
    gcloud config set project PROJECT_ID
    gcloud config set compute/zone ZONE

容量を確保する

TPU 容量を確保する準備ができたら、割り当てページで Cloud Quotas システムについて確認します。容量の確保について他にご不明な点がございましたら、Cloud TPU のセールスチームまたはアカウント チームにお問い合わせください。

Cloud TPU 環境をプロビジョニングする

TPU VM は、GKE、GKE と XPK、またはキューに入れられたリソースとしてプロビジョニングできます。

前提条件

  • このチュートリアルは Python 3.10 以降でテストされています。
  • プロジェクトに十分な TPUS_PER_TPU_FAMILY 割り当てがあることを確認します。これは、Google Cloud プロジェクト内でアクセスできるチップの最大数を指定します。
  • プロジェクトに次の TPU 割り当てがあることを確認します。
    • TPU VM の割り当て
    • IP アドレスの割り当て
    • Hyperdisk Balanced の割り当て
  • ユーザー プロジェクトの権限

TPU v5litepod-16 をプロビジョニングする

  1. TPU VM を作成します。

    gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID}   \
    --node-id=${TPU_NAME}  \
    --project=${PROJECT_ID}   \
    --zone=${ZONE}   \
    --accelerator-type=${ACCELERATOR_TYPE}   \
    --runtime-version=${RUNTIME_VERSION}   \
    --service-account=${SERVICE_ACCOUNT_NAME}   \
    --spot
  2. TPU が ACTIVE 状態であることを確認します。

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
    --project=${PROJECT_ID} \
    --zone=${ZONE}

TPU がアクティブになると(ACTIVE)、次のような出力が表示されます。

createTime: '2025-02-28T21:16:08.053492925Z'
name: projects/my-project/locations/zone/queuedResources/tpu-name-zone
spot: {}
state:
  state: ACTIVE
tpu:
  nodeSpec:
  - node:
      acceleratorType: v5litepod-16
      networkConfig:
        enableExternalIps: true
        network: default
      queuedResource: projects/19672137403/locations/zone/queuedResources/qr-name
      runtimeVersion: v2-alpha-tpuv5-lite
      schedulingConfig: {}
      my-service-account@your-project-id.iam.gserviceaccount.com
        email: 19672137854-compute@developer.iam.gserviceaccount.com
      shieldedInstanceConfig: {}
    nodeId: tpu-name
    parent: projects/19672137403/locations/zone

インストール

Hugging Face Transformers の pytorch-tpu/transformers フォークと依存関係をインストールします。このチュートリアルは、次の依存関係のバージョンでテストされています。

  • torch: 2.6.0 と互換性あり
  • torch_xla[tpu]: 2.6.0 と互換性あり
  • jax: 0.4.38
  • jaxlib: 0.4.38

フレームワーク ソフトウェアと依存関係をインストールする

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
    sudo apt install python3.10-venv
    python -m venv /home/$USER/venv/
    source ~/venv/bin/activate
    cd transformers
    pip3 install --user -e .
    pip3 install datasets
    pip3 install evaluate
    pip3 install scikit-learn
    pip3 install accelerate
    pip install torch~=2.6.0 torch_xla[tpu]~=2.6.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
    pip install jax==0.4.38 jaxlib==0.4.38 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'

インストールが完了すると、次のような出力が表示されます。

Collecting jax==0.4.38
  Downloading jax-0.4.38-py3-none-any.whl (2.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.1/2.1 MB 18.0 MB/s eta 0:00:00
Collecting jaxlib==0.4.38
  Downloading jaxlib-0.4.38-cp310-cp310-manylinux2014_x86_64.whl (85.0 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 85.0/85.0 MB 10.1 MB/s eta 0:00:00
Successfully installed jax-0.4.38 jaxlib-0.4.38 ml-dtypes-0.5.1 opt-einsum-3.4.0
Collecting opt-einsum
  Downloading opt_einsum-3.4.0-py3-none-any.whl (71 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 71.9/71.9 KB 186.4 kB/s eta 0:00:00
Requirement already satisfied: numpy>=1.24 in /home/your-username/.local/lib/python3.10/site-packages (from jax==0.4.38) (2.2.3)
Requirement already satisfied: scipy>=1.10 in /home/your-username/.local/lib/python3.10/site-packages (from jax==0.4.38) (1.15.2)
Successfully installed jax-0.4.38 jaxlib-0.4.38 ml-dtypes-0.5.1 opt-einsum-3.4.0
Successfully installed jax-0.4.38 jaxlib-0.4.38 ml-dtypes-0.5.1 opt-einsum-3.4.0
Collecting ml-dtypes>=0.2.0
  Downloading ml_dtypes-0.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.7 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.7/4.7 MB 13.8 MB/s eta 0:00:00
Installing collected packages: opt-einsum, ml-dtypes, jaxlib, jax
Successfully installed jax-0.4.38 jaxlib-0.4.38 ml-dtypes-0.5.1 opt-einsum-3.4.0

モデル構成を設定する

次のセクションのトレーニング コマンド(モデルを実行する)では、2 つの JSON 構成ファイルを使用して、モデル パラメータと FSDP(完全にシャーディングされたデータ パラレル)構成を定義します。FSDP シャーディングは、トレーニング中にモデルの重みが大きなバッチサイズに適合するために使用されます。小規模なモデルでトレーニングする場合は、データ並列処理を使用して各デバイスに重みを複製するだけで十分な場合があります。PyTorch/XLA でデバイス間でテンソルをシャーディングする方法については、PyTorch/XLA SPMD ユーザーガイドをご覧ください。

  1. このコマンドは、Llama3-8B のモデル パラメータ構成ファイルを作成します。他のモデルについては、Hugging Face で構成を確認してください。たとえば、Llama2-7B 構成をご覧ください。

    cat > llama-config.json <<EOF
    {
        "architectures": [
            "LlamaForCausalLM"
        ],
        "attention_bias": false,
        "attention_dropout": 0.0,
        "bos_token_id": 128000,
        "eos_token_id": 128001,
        "hidden_act": "silu",
        "hidden_size": 4096,
        "initializer_range": 0.02,
        "intermediate_size": 14336,
        "max_position_embeddings": 8192,
        "model_type": "llama",
        "num_attention_heads": 32,
        "num_hidden_layers": 32,
        "num_key_value_heads": 8,
        "pretraining_tp": 1,
        "rms_norm_eps": 1e-05,
        "rope_scaling": null,
        "rope_theta": 500000.0,
        "tie_word_embeddings": false,
        "torch_dtype": "bfloat16",
        "transformers_version": "4.40.0.dev0",
        "use_cache": false,
        "vocab_size": 128256
    }
    EOF
    
  2. FSDP 構成ファイルを作成します。

    cat > fsdp-config.json <<EOF
    {
        "fsdp_transformer_layer_cls_to_wrap": [
            "LlamaDecoderLayer"
        ],
        "xla": true,
        "xla_fsdp_v2": true,
        "xla_fsdp_grad_ckpt": true
    }
    EOF
    

    FSDP の詳細については、FSDPv2 をご覧ください。

  3. 次のコマンドを使用して、構成ファイルを TPU VM にアップロードします。

     ssh-add ~/.ssh/google_compute_engine #Setup SSH Key in the SSH agent.
     gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json $TPU_NAME:. \
        --worker=all \
        --project=${PROJECT_ID} \
        --zone=${ZONE}

    このコマンドを実行すると、次のような出力が生成されます。

    Using scp batch size of 4.Attempting to SCP into 1 nodes with a total of 4 workers.
    SCP: Attempting to connect to worker 0...
    SCP: Attempting to connect to worker 1...
    SCP: Attempting to connect to worker 2...
    SCP: Attempting to connect to worker 3...
    llama-config.json                    100%  707     4.1KB/s   00:00
    llama-config.json                    100%  707     4.0KB/s   00:00
    llama-config.json                    100%  707     4.1KB/s   00:00
    llama-config.json                    100%  707     4.1KB/s   00:00
    fsdp-config.json                     100%  156     0.9KB/s   00:00
    fsdp-config.json                     100%  156     0.9KB/s   00:00
    fsdp-config.json                     100%  156     0.9KB/s   00:00
    fsdp-config.json                     100%  156     0.9KB/s   00:00

モデルを実行する

前のセクションで作成した構成ファイルを使用して run_clm.py スクリプトを実行し、WikiText データセットで Llama 3 8B モデルをトレーニングします。トレーニング スクリプトを TPU v5litepod-16 で実行すると、約 10 分かかります。

  1. Hugging Face トークンをまだ生成していない場合は、新しいトークンを生成します。

    1. [Your Profile] > [Settings] > [Access Tokens] の順にクリックします。
    2. [New Token] を選択します。
    3. 任意の名前と、少なくとも Read ロールを指定します。
    4. [Generate a token] を選択します。
  2. Hugging Face トークンを使用して、次のコマンドを使用して TPU VM から Hugging Face にログインします。

    huggingface-cli login トークン変数を、前の手順で Hugging Face から生成されたトークンに置き換えます。

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='
    pip install -U "huggingface_hub[cli]"
    export PATH="/home/$USER/.local/bin/:$PATH"
    huggingface-cli login --token hf_abcxyzEFg'

    このコマンドを実行すると、Hugging Face にログインし、現在アクティブなトークンが表示されます。

  3. モデルのトレーニングを実行します。

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} \
        --zone=${ZONE} \
        --worker=all \
        --command='
        source ~/venv/bin/activate
        export PJRT_DEVICE=TPU
        export XLA_USE_SPMD=1
        export ENABLE_PJRT_COMPATIBILITY=true
            # Optional variables for debugging:
        export XLA_IR_DEBUG=1
        export XLA_HLO_DEBUG=1
        export PROFILE_EPOCH=0
        export PROFILE_STEP=3
        export PROFILE_DURATION_MS=100000
            # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path
        export PROFILE_LOGDIR=your-bucket/profile_path
        python3 transformers/examples/pytorch/language-modeling/run_clm.py \
        --dataset_name wikitext \
        --dataset_config_name wikitext-2-raw-v1 \
        --per_device_train_batch_size 16 \
        --do_train \
        --output_dir /home/$USER/tmp/test-clm \
        --overwrite_output_dir \
        --config_name /home/$USER/llama-config.json \
        --cache_dir /home/$USER/cache \
        --tokenizer_name meta-llama/Meta-Llama-3-8B \
        --block_size 8192 \
        --optim adafactor \
        --save_strategy no \
        --logging_strategy no \
        --fsdp "full_shard" \
        --fsdp_config /home/$USER/fsdp-config.json \
        --torch_dtype bfloat16 \
        --dataloader_drop_last yes \
        --flash_attention \
        --max_steps 20'

トレーニング ステップには約 10 分かかります。トレーニングの終わりに、次のようなメッセージが表示されます。

[INFO|trainer.py:2053] 2025-03-18 22:05:02,536 >> ***** Running training *****
[INFO|trainer.py:2054] 2025-03-18 22:05:02,536 >>   Num examples = 272
[INFO|trainer.py:2055] 2025-03-18 22:05:02,536 >>   Num Epochs = 2
[INFO|trainer.py:2056] 2025-03-18 22:05:02,536 >>   Instantaneous batch size per device = 16
[INFO|trainer.py:2059] 2025-03-18 22:05:02,536 >>   Total train batch size (w. parallel, distributed & accumulation) = 16
[INFO|trainer.py:2060] 2025-03-18 22:05:02,536 >>   Gradient Accumulation steps = 1
[INFO|trainer.py:2061] 2025-03-18 22:05:02,536 >>   Total optimization steps = 20
[INFO|trainer.py:2062] 2025-03-18 22:05:02,537 >>   Number of trainable parameters = 8,030,261,248
  0%|          | 0/20 [00:00<?, ?it/s][INFO|trainer.py:2143] 2025-03-18 22:05:02,540 >> Profiling server started: <_XLAC.profiler.ProfilerServer object at 0x7f01bdcb6770>
  5%|         | 1/20 [00:07<02:29,  7.86s/it]/home/your-username/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1810: UserWarning: For backward hooks to be called, module output should be a Tensor or a tuple of Tensors but received <class 'transformers.modeling_outputs.CausalLMOutputWithPast'>
  5%|         | 1/20 [00:07<02:29,  7.89s/it]Compilation at Step 0, time: 213.83555555343628
/home/your-username/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1810:
 10%|         | 2/20 [03:43<38:57, 129.87s/it]Compilation at Step 0, time: 213.12156581878662
/home/your-username/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:"
 10%|█         | 2/20 [03:40<38:29, 128.30s/it]Compilation at Step 1, time: 224.5414960384369
 15%|█▌        | 3/20 [07:22<48:31, 171.24s/it]Compilation at Step 1, time: 226.23664164543152
 15%|█▌        | 3/20 [07:26<48:56, 172.73s/it]Compilation at Step 1, time: 226.9180543422699
Compilation at Step 1, time: 224.3874273300171
 20%|██        | 4/20 [07:23<27:45, 104.10s/it]Starting to trace for 100000 ms. Remaining attempt(s): 2
2025-03-18 22:12:32.104419: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 847930 nanoseconds and will start immediately.
Starting to trace for 100000 ms. Remaining attempt(s): 2
2025-03-18 22:12:32.104373: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 763960 nanoseconds and will start immediately.
Starting to trace for 100000 ms. Remaining attempt(s): 2
2025-03-18 22:12:32.104538: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 854020 nanoseconds and will start immediately.
2025-03-18 22:12:32.104347: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 761070 nanoseconds and will start immediately.
Starting to trace for 100000 ms. Remaining attempt(s): 2
 85%|████████▌ | 17/20 [07:55<00:06,  2.26s/it]Compilation at Step -1, time: 3.676558494567871
Compilation at Step -1, time: 3.447533130645752
Compilation at Step -1, time: 3.5890843868255615
Compilation at Step -1, time: 3.4956483840942383
100%|██████████| 20/20 [11:39<00:00, 35.14s/it][INFO|trainer.py:2350] 2025-03-18 22:16:42,476 >>
Training completed. Do not forget to share your model on huggingface.co/models =)

100%|██████████| 20/20 [11:47<00:00, 35.23s/it][INFO|trainer.py:2350] 2025-03-18 22:16:43,239 >>

Training completed. Do not forget to share your model on huggingface.co/models =)

クリーンアップ

トレーニングが完了したら、次の手順でキューに登録されたリソースと TPU VM を削除します。これにより、TPU VM の使用に対する課金が停止されます。

  gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --force \
       --async