在 TPU v5e 上使用 PyTorch 训练 Llama 3

本教程介绍了如何使用 WikiText 数据集在 TPU v5e 上使用 PyTorch/XLA 训练 Llama-3-8B 模型。如需了解模型详情,请参阅 Meta-Llama-3-8B

Llama-3-8B 模型托管在 Hugging Face 平台上。

Meta-Llama-3-8B 有两个版本,一个版本适用于 Transformers,另一个版本适用于原始 Llama 3 代码库。本教程使用 Transformers 版本,因为它:

  • 与 Hugging Face 生态系统无缝集成:这样,您可以更轻松地微调模型、使用预构建的流水线,以及访问大量数据集和工具。

  • 支持灵活性和自定义:Transformers 版本提供了极大的灵活性和自定义选项,可用于微调和部署模型。

  • 提供社区支持:Hugging Face 社区提供了丰富的文档、教程和使用 Transformers 模型的支持。

如需详细了解 Transformer,请参阅 Hugging Face Transformer 文档

如需访问和使用 Meta-Llama-3-8B 模型(包括下载其权重和分词器),您需要 Hugging Face 用户访问令牌。该令牌提供:

  • 身份验证和授权:访问令牌可用作凭据,让 Hugging Face 服务器授权您访问模型的资源。这样可以确保只有获得授权的用户才能下载和使用该模型。

  • 安全:Hugging Face 使用访问令牌来保护其模型,并防止未经授权的访问或滥用。

如需了解如何为本教程创建和使用访问令牌,请参阅运行模型。如需更全面地了解如何创建和使用访问令牌,请参阅 Hugging Face 文档中的用户访问令牌部分

您还需要有权访问 Hugging Face 上的 Llama 3 8B 模型。如需获取此权限,请前往 Hugging Face 上的 Meta-Llama-3-8B 模型并请求访问权限。

准备预配 TPU v5litepod-16

本教程使用以下 Cloud TPU 环境变量进行了测试。您可以使用其他变量预配 TPU,前提是加速器类型、可用区和运行时版本兼容。例如,在本教程中,europe-west4-b 始终用作可用区。您可以使用支持您所运行的 TPU 版本(本教程中为 v5litepod-16)的任何其他可用区。

设置以下 TPU VM 环境变量。

   export TPU_NAME=queued-resources-node-id #The TPU name is the queued resource node-id
   export PROJECT_ID=your-project-id
   export ACCELERATOR_TYPE=v5litepod-16
   export ZONE=europe-west4-b
   export RUNTIME_VERSION=v2-alpha-tpuv5-lite
   export QUEUED_RESOURCE_ID=queued-resource-id
   export VALID_UNTIL_DURATION=1d

获得对 Hugging Face 上 Meta-Llama-3-8B 模型的访问权限后,准备 TPU 环境以运行本教程。

  1. 请按照设置 Cloud TPU 环境指南操作,确保您拥有使用 Cloud TPU 的适当访问权限。

  2. 为 TPU 虚拟机创建服务身份。

    gcloud alpha compute tpus tpu-vm service-identity create --zone=zone
  3. 创建 TPU 服务账号并授予对 Google Cloud 服务的访问权限。

    通过服务账号, Google Cloud TPU 服务可以访问其他 Google Cloud服务。建议使用用户管理的服务账号。您可以通过 Google Cloud 控制台或 gcloud 命令创建服务账号。

    使用 gcloud 命令行工具创建服务账号:

    gcloud iam service-accounts create your-service-account-name \
    --description="your-sa-description" \
    --display-name="your-sa-display-name"
    export SERVICE_ACCOUNT_NAME=your-service-account-name

    通过 Google Cloud 控制台创建服务账号:

    1. 前往 Google Cloud 控制台中的“服务账号”页面。
    2. 点击创建服务账号
    3. 输入服务账号名称。
    4. (可选)输入服务账号的说明。
    5. 点击创建并继续。
    6. 选择要向服务账号授予的角色。
    7. 点击继续
    8. (可选)指定可以管理该服务账号的用户或群组。
    9. 点击完成以完成服务账号的创建过程。

    创建服务账号后,请按照以下步骤授予服务账号角色。

    您需要拥有以下角色:

    • TPU 管理员:创建 TPU 所需
    • Storage Admin:需要此角色才能访问 Cloud Storage
    • Logs Writer
    • Monitoring Metric Writer:用于将指标写入 Cloud Monitoring

    您必须获得管理员授予的 roles/resourcemanager.projectIamAdmin 才能向用户分配 IAM 角色。拥有 Project IAM Admin roles/resourcemanager.projectIamAdmin 角色的用户也可以授予此角色。

    使用以下 gcloud 命令添加服务账号角色:

    gcloud projects add-iam-policy-binding ${PROJECT_ID} \
       --member serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \
       --role roles/tpu.admin
    gcloud projects add-iam-policy-binding ${PROJECT_ID} \
       --member serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \
       --role roles/storage.admin
    gcloud projects add-iam-policy-binding ${PROJECT_ID} \
       --member serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \
       --role roles/logging.logWriter
    gcloud projects add-iam-policy-binding ${PROJECT_ID} \
       --member serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \
       --role roles/monitoring.metricWriter

    您还可以使用 Google Cloud 控制台分配角色。

    在 Google Cloud 控制台中,选择以下角色:

    1. 选择您的服务账号,然后点击添加主账号
    2. 新主账号字段中,输入服务账号的电子邮件地址。
    3. 选择角色下拉菜单中,搜索并选择相应角色(例如 Storage Admin)。
    4. 点击保存
  4. 使用 Google Cloud 进行身份验证,并为 Google Cloud CLI 配置默认项目和区域。

    gcloud auth login
    gcloud config set project PROJECT_ID
    gcloud config set compute/zone ZONE

保障容量

当您准备好预订 TPU 容量时,请查看配额页面,了解 Cloud Quotas 系统。如果您对如何确保容量还有其他疑问,请与您的 Cloud TPU 销售团队或客户支持团队联系。

预配 Cloud TPU 环境

您可以使用 GKE、GKE 和 XPK 预配 TPU 虚拟机,也可以将其作为队列化资源预配。

前提条件

  • 本教程已使用 Python 3.10 或更高版本进行测试。
  • 验证您的项目是否有足够的 TPUS_PER_TPU_FAMILY 配额,该配额指定您可以在Google Cloud 项目中访问的芯片数量上限。
  • 验证您的项目是否有足够的 TPU 配额:
    • TPU 虚拟机配额
    • IP 地址配额
    • Hyperdisk Balanced 配额
  • 用户项目权限

预配 TPU v5litepod-16

  1. 创建 TPU 虚拟机:

    gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID}   \
    --node-id=${TPU_NAME}  \
    --project=${PROJECT_ID}   \
    --zone=${ZONE}   \
    --accelerator-type=${ACCELERATOR_TYPE}   \
    --runtime-version=${RUNTIME_VERSION}   \
    --service-account=${SERVICE_ACCOUNT_NAME}   \
    --spot
  2. 验证 TPU 是否处于 ACTIVE 状态:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
    --project=${PROJECT_ID} \
    --zone=${ZONE}

TPU 变为活动状态 (ACTIVE) 后,您将看到类似以下内容的输出:

createTime: '2025-02-28T21:16:08.053492925Z'
name: projects/my-project/locations/zone/queuedResources/tpu-name-zone
spot: {}
state:
  state: ACTIVE
tpu:
  nodeSpec:
  - node:
      acceleratorType: v5litepod-16
      networkConfig:
        enableExternalIps: true
        network: default
      queuedResource: projects/19672137403/locations/zone/queuedResources/qr-name
      runtimeVersion: v2-alpha-tpuv5-lite
      schedulingConfig: {}
      my-service-account@your-project-id.iam.gserviceaccount.com
        email: 19672137854-compute@developer.iam.gserviceaccount.com
      shieldedInstanceConfig: {}
    nodeId: tpu-name
    parent: projects/19672137403/locations/zone

安装

安装 pytorch-tpu/transformers 分支的 hugging face Transformer 和依赖项。本教程已使用以下依赖项版本进行测试:

  • torch:与 2.6.0 兼容
  • torch_xla[tpu]:与 2.6.0 兼容
  • jax:0.4.38
  • jaxlib:0.4.38

安装框架软件和依赖项

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
    sudo apt install python3.10-venv
    python -m venv /home/$USER/venv/
    source ~/venv/bin/activate
    cd transformers
    pip3 install --user -e .
    pip3 install datasets
    pip3 install evaluate
    pip3 install scikit-learn
    pip3 install accelerate
    pip install torch~=2.6.0 torch_xla[tpu]~=2.6.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
    pip install jax==0.4.38 jaxlib==0.4.38 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'

安装完成后,您将看到类似于以下内容的输出:

Collecting jax==0.4.38
  Downloading jax-0.4.38-py3-none-any.whl (2.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.1/2.1 MB 18.0 MB/s eta 0:00:00
Collecting jaxlib==0.4.38
  Downloading jaxlib-0.4.38-cp310-cp310-manylinux2014_x86_64.whl (85.0 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 85.0/85.0 MB 10.1 MB/s eta 0:00:00
Successfully installed jax-0.4.38 jaxlib-0.4.38 ml-dtypes-0.5.1 opt-einsum-3.4.0
Collecting opt-einsum
  Downloading opt_einsum-3.4.0-py3-none-any.whl (71 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 71.9/71.9 KB 186.4 kB/s eta 0:00:00
Requirement already satisfied: numpy>=1.24 in /home/your-username/.local/lib/python3.10/site-packages (from jax==0.4.38) (2.2.3)
Requirement already satisfied: scipy>=1.10 in /home/your-username/.local/lib/python3.10/site-packages (from jax==0.4.38) (1.15.2)
Successfully installed jax-0.4.38 jaxlib-0.4.38 ml-dtypes-0.5.1 opt-einsum-3.4.0
Successfully installed jax-0.4.38 jaxlib-0.4.38 ml-dtypes-0.5.1 opt-einsum-3.4.0
Collecting ml-dtypes>=0.2.0
  Downloading ml_dtypes-0.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.7 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.7/4.7 MB 13.8 MB/s eta 0:00:00
Installing collected packages: opt-einsum, ml-dtypes, jaxlib, jax
Successfully installed jax-0.4.38 jaxlib-0.4.38 ml-dtypes-0.5.1 opt-einsum-3.4.0

设置模型配置

下一部分(运行模型)中的训练命令使用两个 JSON 配置文件来定义模型参数和 FSDP(完全分片数据并行)配置。FSDP 分片用于模型权重,以便在训练期间适应更大的批量大小。使用较小模型进行训练时,使用数据并行处理并在每台设备上复制权重可能就足够了。如需详细了解如何在 PyTorch/XLA 中跨设备分片张量,请参阅 PyTorch/XLA SPMD 用户指南

  1. 此命令会为 Llama3-8B 创建模型参数配置文件。如需了解其他模型,请在 Hugging Face 上查找配置。例如,请参阅 Llama2-7B 配置

    cat > llama-config.json <<EOF
    {
        "architectures": [
            "LlamaForCausalLM"
        ],
        "attention_bias": false,
        "attention_dropout": 0.0,
        "bos_token_id": 128000,
        "eos_token_id": 128001,
        "hidden_act": "silu",
        "hidden_size": 4096,
        "initializer_range": 0.02,
        "intermediate_size": 14336,
        "max_position_embeddings": 8192,
        "model_type": "llama",
        "num_attention_heads": 32,
        "num_hidden_layers": 32,
        "num_key_value_heads": 8,
        "pretraining_tp": 1,
        "rms_norm_eps": 1e-05,
        "rope_scaling": null,
        "rope_theta": 500000.0,
        "tie_word_embeddings": false,
        "torch_dtype": "bfloat16",
        "transformers_version": "4.40.0.dev0",
        "use_cache": false,
        "vocab_size": 128256
    }
    EOF
    
  2. 创建 FSDP 配置文件:

    cat > fsdp-config.json <<EOF
    {
        "fsdp_transformer_layer_cls_to_wrap": [
            "LlamaDecoderLayer"
        ],
        "xla": true,
        "xla_fsdp_v2": true,
        "xla_fsdp_grad_ckpt": true
    }
    EOF
    

    如需详细了解 FSDP,请参阅 FSDPv2

  3. 使用以下命令将配置文件上传到 TPU 虚拟机:

     ssh-add ~/.ssh/google_compute_engine #Setup SSH Key in the SSH agent.
     gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json $TPU_NAME:. \
        --worker=all \
        --project=${PROJECT_ID} \
        --zone=${ZONE}

    此命令将生成类似于以下内容的输出:

    Using scp batch size of 4.Attempting to SCP into 1 nodes with a total of 4 workers.
    SCP: Attempting to connect to worker 0...
    SCP: Attempting to connect to worker 1...
    SCP: Attempting to connect to worker 2...
    SCP: Attempting to connect to worker 3...
    llama-config.json                    100%  707     4.1KB/s   00:00
    llama-config.json                    100%  707     4.0KB/s   00:00
    llama-config.json                    100%  707     4.1KB/s   00:00
    llama-config.json                    100%  707     4.1KB/s   00:00
    fsdp-config.json                     100%  156     0.9KB/s   00:00
    fsdp-config.json                     100%  156     0.9KB/s   00:00
    fsdp-config.json                     100%  156     0.9KB/s   00:00
    fsdp-config.json                     100%  156     0.9KB/s   00:00

运行模型

使用您在上一部分中创建的配置文件,运行 run_clm.py 脚本,以便在 WikiText 数据集上训练 Llama 3 8B 模型。训练脚本在 TPU v5litepod-16 上运行大约需要 10 分钟。

  1. 如果您还没有 Hugging Face 令牌,请生成一个新令牌:

    1. 依次点击您的个人资料 > 设置 > 访问令牌
    2. 选择新建令牌 (New Token)。
    3. 指定您选择的名称和一个至少为 Read 的角色。
    4. 选择生成令牌
  2. 使用 Hugging Face 令牌通过以下命令从 TPU 虚拟机登录 Hugging Face。

    huggingface-cli login 令牌变量替换为在上一步中通过 Hugging Face 生成的令牌:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='
    pip install -U "huggingface_hub[cli]"
    export PATH="/home/$USER/.local/bin/:$PATH"
    huggingface-cli login --token hf_abcxyzEFg'

    此命令会将您登录 Hugging Face,并显示当前有效的令牌。

  3. 运行模型训练:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} \
        --zone=${ZONE} \
        --worker=all \
        --command='
        source ~/venv/bin/activate
        export PJRT_DEVICE=TPU
        export XLA_USE_SPMD=1
        export ENABLE_PJRT_COMPATIBILITY=true
            # Optional variables for debugging:
        export XLA_IR_DEBUG=1
        export XLA_HLO_DEBUG=1
        export PROFILE_EPOCH=0
        export PROFILE_STEP=3
        export PROFILE_DURATION_MS=100000
            # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path
        export PROFILE_LOGDIR=your-bucket/profile_path
        python3 transformers/examples/pytorch/language-modeling/run_clm.py \
        --dataset_name wikitext \
        --dataset_config_name wikitext-2-raw-v1 \
        --per_device_train_batch_size 16 \
        --do_train \
        --output_dir /home/$USER/tmp/test-clm \
        --overwrite_output_dir \
        --config_name /home/$USER/llama-config.json \
        --cache_dir /home/$USER/cache \
        --tokenizer_name meta-llama/Meta-Llama-3-8B \
        --block_size 8192 \
        --optim adafactor \
        --save_strategy no \
        --logging_strategy no \
        --fsdp "full_shard" \
        --fsdp_config /home/$USER/fsdp-config.json \
        --torch_dtype bfloat16 \
        --dataloader_drop_last yes \
        --flash_attention \
        --max_steps 20'

训练步骤大约需要 10 分钟。训练结束时,您会看到类似以下内容的消息:

[INFO|trainer.py:2053] 2025-03-18 22:05:02,536 >> ***** Running training *****
[INFO|trainer.py:2054] 2025-03-18 22:05:02,536 >>   Num examples = 272
[INFO|trainer.py:2055] 2025-03-18 22:05:02,536 >>   Num Epochs = 2
[INFO|trainer.py:2056] 2025-03-18 22:05:02,536 >>   Instantaneous batch size per device = 16
[INFO|trainer.py:2059] 2025-03-18 22:05:02,536 >>   Total train batch size (w. parallel, distributed & accumulation) = 16
[INFO|trainer.py:2060] 2025-03-18 22:05:02,536 >>   Gradient Accumulation steps = 1
[INFO|trainer.py:2061] 2025-03-18 22:05:02,536 >>   Total optimization steps = 20
[INFO|trainer.py:2062] 2025-03-18 22:05:02,537 >>   Number of trainable parameters = 8,030,261,248
  0%|          | 0/20 [00:00<?, ?it/s][INFO|trainer.py:2143] 2025-03-18 22:05:02,540 >> Profiling server started: <_XLAC.profiler.ProfilerServer object at 0x7f01bdcb6770>
  5%|         | 1/20 [00:07<02:29,  7.86s/it]/home/your-username/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1810: UserWarning: For backward hooks to be called, module output should be a Tensor or a tuple of Tensors but received <class 'transformers.modeling_outputs.CausalLMOutputWithPast'>
  5%|         | 1/20 [00:07<02:29,  7.89s/it]Compilation at Step 0, time: 213.83555555343628
/home/your-username/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1810:
 10%|         | 2/20 [03:43<38:57, 129.87s/it]Compilation at Step 0, time: 213.12156581878662
/home/your-username/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:"
 10%|█         | 2/20 [03:40<38:29, 128.30s/it]Compilation at Step 1, time: 224.5414960384369
 15%|█▌        | 3/20 [07:22<48:31, 171.24s/it]Compilation at Step 1, time: 226.23664164543152
 15%|█▌        | 3/20 [07:26<48:56, 172.73s/it]Compilation at Step 1, time: 226.9180543422699
Compilation at Step 1, time: 224.3874273300171
 20%|██        | 4/20 [07:23<27:45, 104.10s/it]Starting to trace for 100000 ms. Remaining attempt(s): 2
2025-03-18 22:12:32.104419: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 847930 nanoseconds and will start immediately.
Starting to trace for 100000 ms. Remaining attempt(s): 2
2025-03-18 22:12:32.104373: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 763960 nanoseconds and will start immediately.
Starting to trace for 100000 ms. Remaining attempt(s): 2
2025-03-18 22:12:32.104538: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 854020 nanoseconds and will start immediately.
2025-03-18 22:12:32.104347: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 761070 nanoseconds and will start immediately.
Starting to trace for 100000 ms. Remaining attempt(s): 2
 85%|████████▌ | 17/20 [07:55<00:06,  2.26s/it]Compilation at Step -1, time: 3.676558494567871
Compilation at Step -1, time: 3.447533130645752
Compilation at Step -1, time: 3.5890843868255615
Compilation at Step -1, time: 3.4956483840942383
100%|██████████| 20/20 [11:39<00:00, 35.14s/it][INFO|trainer.py:2350] 2025-03-18 22:16:42,476 >>
Training completed. Do not forget to share your model on huggingface.co/models =)

100%|██████████| 20/20 [11:47<00:00, 35.23s/it][INFO|trainer.py:2350] 2025-03-18 22:16:43,239 >>

Training completed. Do not forget to share your model on huggingface.co/models =)

清理

训练完成后,请按照以下步骤删除加入队列的资源和 TPU 虚拟机。这将停止对 TPU VM 用量进行计费。

  gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --force \
       --async