在 TPU v5e 上使用 PyTorch 训练 Llama 3
本教程介绍了如何使用 WikiText 数据集在 TPU v5e 上使用 PyTorch/XLA 训练 Llama-3-8B 模型。如需了解模型详情,请参阅 Meta-Llama-3-8B。
Llama-3-8B 模型托管在 Hugging Face 平台上。
Meta-Llama-3-8B 有两个版本,一个版本适用于 Transformers,另一个版本适用于原始 Llama 3 代码库。本教程使用 Transformers 版本,因为它:
与 Hugging Face 生态系统无缝集成:这样,您可以更轻松地微调模型、使用预构建的流水线,以及访问大量数据集和工具。
支持灵活性和自定义:Transformers 版本提供了极大的灵活性和自定义选项,可用于微调和部署模型。
提供社区支持:Hugging Face 社区提供了丰富的文档、教程和使用 Transformers 模型的支持。
如需详细了解 Transformer,请参阅 Hugging Face Transformer 文档。
如需访问和使用 Meta-Llama-3-8B 模型(包括下载其权重和分词器),您需要 Hugging Face 用户访问令牌。该令牌提供:
身份验证和授权:访问令牌可用作凭据,让 Hugging Face 服务器授权您访问模型的资源。这样可以确保只有获得授权的用户才能下载和使用该模型。
安全:Hugging Face 使用访问令牌来保护其模型,并防止未经授权的访问或滥用。
如需了解如何为本教程创建和使用访问令牌,请参阅运行模型。如需更全面地了解如何创建和使用访问令牌,请参阅 Hugging Face 文档中的用户访问令牌部分。
您还需要有权访问 Hugging Face 上的 Llama 3 8B 模型。如需获取此权限,请前往 Hugging Face 上的 Meta-Llama-3-8B 模型并请求访问权限。
准备预配 TPU v5litepod-16
本教程使用以下 Cloud TPU 环境变量进行了测试。您可以使用其他变量预配 TPU,前提是加速器类型、可用区和运行时版本兼容。例如,在本教程中,europe-west4-b
始终用作可用区。您可以使用支持您所运行的 TPU 版本(本教程中为 v5litepod-16)的任何其他可用区。
设置以下 TPU VM 环境变量。
export TPU_NAME=queued-resources-node-id #The TPU name is the queued resource node-id export PROJECT_ID=your-project-id export ACCELERATOR_TYPE=v5litepod-16 export ZONE=europe-west4-b export RUNTIME_VERSION=v2-alpha-tpuv5-lite export QUEUED_RESOURCE_ID=queued-resource-id export VALID_UNTIL_DURATION=1d
获得对 Hugging Face 上 Meta-Llama-3-8B 模型的访问权限后,准备 TPU 环境以运行本教程。
请按照设置 Cloud TPU 环境指南操作,确保您拥有使用 Cloud TPU 的适当访问权限。
为 TPU 虚拟机创建服务身份。
gcloud alpha compute tpus tpu-vm service-identity create --zone=zone
创建 TPU 服务账号并授予对 Google Cloud 服务的访问权限。
通过服务账号, Google Cloud TPU 服务可以访问其他 Google Cloud服务。建议使用用户管理的服务账号。您可以通过 Google Cloud 控制台或
gcloud
命令创建服务账号。使用
gcloud
命令行工具创建服务账号:gcloud iam service-accounts create your-service-account-name \ --description="your-sa-description" \ --display-name="your-sa-display-name" export SERVICE_ACCOUNT_NAME=your-service-account-name
通过 Google Cloud 控制台创建服务账号:
- 前往 Google Cloud 控制台中的“服务账号”页面。
- 点击创建服务账号。
- 输入服务账号名称。
- (可选)输入服务账号的说明。
- 点击创建并继续。
- 选择要向服务账号授予的角色。
- 点击继续。
- (可选)指定可以管理该服务账号的用户或群组。
- 点击完成以完成服务账号的创建过程。
创建服务账号后,请按照以下步骤授予服务账号角色。
您需要拥有以下角色:
- TPU 管理员:创建 TPU 所需
- Storage Admin:需要此角色才能访问 Cloud Storage
- Logs Writer
- Monitoring Metric Writer:用于将指标写入 Cloud Monitoring
您必须获得管理员授予的
roles/resourcemanager.projectIamAdmin
才能向用户分配 IAM 角色。拥有 Project IAM Adminroles/resourcemanager.projectIamAdmin
角色的用户也可以授予此角色。使用以下
gcloud
命令添加服务账号角色:gcloud projects add-iam-policy-binding ${PROJECT_ID} \ --member serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \ --role roles/tpu.admin gcloud projects add-iam-policy-binding ${PROJECT_ID} \ --member serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \ --role roles/storage.admin gcloud projects add-iam-policy-binding ${PROJECT_ID} \ --member serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \ --role roles/logging.logWriter gcloud projects add-iam-policy-binding ${PROJECT_ID} \ --member serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \ --role roles/monitoring.metricWriter
您还可以使用 Google Cloud 控制台分配角色。
在 Google Cloud 控制台中,选择以下角色:
- 选择您的服务账号,然后点击添加主账号。
- 在新主账号字段中,输入服务账号的电子邮件地址。
- 在选择角色下拉菜单中,搜索并选择相应角色(例如 Storage Admin)。
- 点击保存。
使用 Google Cloud 进行身份验证,并为 Google Cloud CLI 配置默认项目和区域。
gcloud auth login gcloud config set project PROJECT_ID gcloud config set compute/zone ZONE
保障容量
当您准备好预订 TPU 容量时,请查看配额页面,了解 Cloud Quotas 系统。如果您对如何确保容量还有其他疑问,请与您的 Cloud TPU 销售团队或客户支持团队联系。
预配 Cloud TPU 环境
您可以使用 GKE、GKE 和 XPK 预配 TPU 虚拟机,也可以将其作为队列化资源预配。
前提条件
- 本教程已使用 Python 3.10 或更高版本进行测试。
- 验证您的项目是否有足够的
TPUS_PER_TPU_FAMILY
配额,该配额指定您可以在Google Cloud 项目中访问的芯片数量上限。 - 验证您的项目是否有足够的 TPU 配额:
- TPU 虚拟机配额
- IP 地址配额
- Hyperdisk Balanced 配额
- 用户项目权限
- 如果您将 GKE 与 XPK 搭配使用,请参阅用户账号或服务账号的 Cloud 控制台权限,了解运行 XPK 所需的权限。
预配 TPU v5litepod-16
创建 TPU 虚拟机:
gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT_NAME} \ --spot
验证 TPU 是否处于
ACTIVE
状态:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
TPU 变为活动状态 (ACTIVE
) 后,您将看到类似以下内容的输出:
createTime: '2025-02-28T21:16:08.053492925Z'
name: projects/my-project/locations/zone/queuedResources/tpu-name-zone
spot: {}
state:
state: ACTIVE
tpu:
nodeSpec:
- node:
acceleratorType: v5litepod-16
networkConfig:
enableExternalIps: true
network: default
queuedResource: projects/19672137403/locations/zone/queuedResources/qr-name
runtimeVersion: v2-alpha-tpuv5-lite
schedulingConfig: {}
my-service-account@your-project-id.iam.gserviceaccount.com
email: 19672137854-compute@developer.iam.gserviceaccount.com
shieldedInstanceConfig: {}
nodeId: tpu-name
parent: projects/19672137403/locations/zone
安装
安装 pytorch-tpu/transformers
分支的 hugging face Transformer 和依赖项。本教程已使用以下依赖项版本进行测试:
torch
:与 2.6.0 兼容torch_xla[tpu]
:与 2.6.0 兼容jax
:0.4.38jaxlib
:0.4.38
安装框架软件和依赖项
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git sudo apt install python3.10-venv python -m venv /home/$USER/venv/ source ~/venv/bin/activate cd transformers pip3 install --user -e . pip3 install datasets pip3 install evaluate pip3 install scikit-learn pip3 install accelerate pip install torch~=2.6.0 torch_xla[tpu]~=2.6.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html pip install jax==0.4.38 jaxlib==0.4.38 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'
安装完成后,您将看到类似于以下内容的输出:
Collecting jax==0.4.38
Downloading jax-0.4.38-py3-none-any.whl (2.1 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.1/2.1 MB 18.0 MB/s eta 0:00:00
Collecting jaxlib==0.4.38
Downloading jaxlib-0.4.38-cp310-cp310-manylinux2014_x86_64.whl (85.0 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 85.0/85.0 MB 10.1 MB/s eta 0:00:00
Successfully installed jax-0.4.38 jaxlib-0.4.38 ml-dtypes-0.5.1 opt-einsum-3.4.0
Collecting opt-einsum
Downloading opt_einsum-3.4.0-py3-none-any.whl (71 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 71.9/71.9 KB 186.4 kB/s eta 0:00:00
Requirement already satisfied: numpy>=1.24 in /home/your-username/.local/lib/python3.10/site-packages (from jax==0.4.38) (2.2.3)
Requirement already satisfied: scipy>=1.10 in /home/your-username/.local/lib/python3.10/site-packages (from jax==0.4.38) (1.15.2)
Successfully installed jax-0.4.38 jaxlib-0.4.38 ml-dtypes-0.5.1 opt-einsum-3.4.0
Successfully installed jax-0.4.38 jaxlib-0.4.38 ml-dtypes-0.5.1 opt-einsum-3.4.0
Collecting ml-dtypes>=0.2.0
Downloading ml_dtypes-0.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.7 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.7/4.7 MB 13.8 MB/s eta 0:00:00
Installing collected packages: opt-einsum, ml-dtypes, jaxlib, jax
Successfully installed jax-0.4.38 jaxlib-0.4.38 ml-dtypes-0.5.1 opt-einsum-3.4.0
设置模型配置
下一部分(运行模型)中的训练命令使用两个 JSON 配置文件来定义模型参数和 FSDP(完全分片数据并行)配置。FSDP 分片用于模型权重,以便在训练期间适应更大的批量大小。使用较小模型进行训练时,使用数据并行处理并在每台设备上复制权重可能就足够了。如需详细了解如何在 PyTorch/XLA 中跨设备分片张量,请参阅 PyTorch/XLA SPMD 用户指南。
此命令会为 Llama3-8B 创建模型参数配置文件。如需了解其他模型,请在 Hugging Face 上查找配置。例如,请参阅 Llama2-7B 配置。
cat > llama-config.json <<EOF { "architectures": [ "LlamaForCausalLM" ], "attention_bias": false, "attention_dropout": 0.0, "bos_token_id": 128000, "eos_token_id": 128001, "hidden_act": "silu", "hidden_size": 4096, "initializer_range": 0.02, "intermediate_size": 14336, "max_position_embeddings": 8192, "model_type": "llama", "num_attention_heads": 32, "num_hidden_layers": 32, "num_key_value_heads": 8, "pretraining_tp": 1, "rms_norm_eps": 1e-05, "rope_scaling": null, "rope_theta": 500000.0, "tie_word_embeddings": false, "torch_dtype": "bfloat16", "transformers_version": "4.40.0.dev0", "use_cache": false, "vocab_size": 128256 } EOF
创建 FSDP 配置文件:
cat > fsdp-config.json <<EOF { "fsdp_transformer_layer_cls_to_wrap": [ "LlamaDecoderLayer" ], "xla": true, "xla_fsdp_v2": true, "xla_fsdp_grad_ckpt": true } EOF
如需详细了解 FSDP,请参阅 FSDPv2。
使用以下命令将配置文件上传到 TPU 虚拟机:
ssh-add ~/.ssh/google_compute_engine #Setup SSH Key in the SSH agent. gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json $TPU_NAME:. \ --worker=all \ --project=${PROJECT_ID} \ --zone=${ZONE}
此命令将生成类似于以下内容的输出:
Using scp batch size of 4.Attempting to SCP into 1 nodes with a total of 4 workers. SCP: Attempting to connect to worker 0... SCP: Attempting to connect to worker 1... SCP: Attempting to connect to worker 2... SCP: Attempting to connect to worker 3... llama-config.json 100% 707 4.1KB/s 00:00 llama-config.json 100% 707 4.0KB/s 00:00 llama-config.json 100% 707 4.1KB/s 00:00 llama-config.json 100% 707 4.1KB/s 00:00 fsdp-config.json 100% 156 0.9KB/s 00:00 fsdp-config.json 100% 156 0.9KB/s 00:00 fsdp-config.json 100% 156 0.9KB/s 00:00 fsdp-config.json 100% 156 0.9KB/s 00:00
运行模型
使用您在上一部分中创建的配置文件,运行 run_clm.py
脚本,以便在 WikiText 数据集上训练 Llama 3 8B 模型。训练脚本在 TPU v5litepod-16 上运行大约需要 10 分钟。
如果您还没有 Hugging Face 令牌,请生成一个新令牌:
- 依次点击您的个人资料 > 设置 > 访问令牌。
- 选择新建令牌 (New Token)。
- 指定您选择的名称和一个至少为 Read 的角色。
- 选择生成令牌。
使用 Hugging Face 令牌通过以下命令从 TPU 虚拟机登录 Hugging Face。
将
huggingface-cli login
令牌变量替换为在上一步中通过 Hugging Face 生成的令牌:gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' pip install -U "huggingface_hub[cli]" export PATH="/home/$USER/.local/bin/:$PATH" huggingface-cli login --token hf_abcxyzEFg'
此命令会将您登录 Hugging Face,并显示当前有效的令牌。
运行模型训练:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' source ~/venv/bin/activate export PJRT_DEVICE=TPU export XLA_USE_SPMD=1 export ENABLE_PJRT_COMPATIBILITY=true # Optional variables for debugging: export XLA_IR_DEBUG=1 export XLA_HLO_DEBUG=1 export PROFILE_EPOCH=0 export PROFILE_STEP=3 export PROFILE_DURATION_MS=100000 # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path export PROFILE_LOGDIR=your-bucket/profile_path python3 transformers/examples/pytorch/language-modeling/run_clm.py \ --dataset_name wikitext \ --dataset_config_name wikitext-2-raw-v1 \ --per_device_train_batch_size 16 \ --do_train \ --output_dir /home/$USER/tmp/test-clm \ --overwrite_output_dir \ --config_name /home/$USER/llama-config.json \ --cache_dir /home/$USER/cache \ --tokenizer_name meta-llama/Meta-Llama-3-8B \ --block_size 8192 \ --optim adafactor \ --save_strategy no \ --logging_strategy no \ --fsdp "full_shard" \ --fsdp_config /home/$USER/fsdp-config.json \ --torch_dtype bfloat16 \ --dataloader_drop_last yes \ --flash_attention \ --max_steps 20'
训练步骤大约需要 10 分钟。训练结束时,您会看到类似以下内容的消息:
[INFO|trainer.py:2053] 2025-03-18 22:05:02,536 >> ***** Running training *****
[INFO|trainer.py:2054] 2025-03-18 22:05:02,536 >> Num examples = 272
[INFO|trainer.py:2055] 2025-03-18 22:05:02,536 >> Num Epochs = 2
[INFO|trainer.py:2056] 2025-03-18 22:05:02,536 >> Instantaneous batch size per device = 16
[INFO|trainer.py:2059] 2025-03-18 22:05:02,536 >> Total train batch size (w. parallel, distributed & accumulation) = 16
[INFO|trainer.py:2060] 2025-03-18 22:05:02,536 >> Gradient Accumulation steps = 1
[INFO|trainer.py:2061] 2025-03-18 22:05:02,536 >> Total optimization steps = 20
[INFO|trainer.py:2062] 2025-03-18 22:05:02,537 >> Number of trainable parameters = 8,030,261,248
0%| | 0/20 [00:00<?, ?it/s][INFO|trainer.py:2143] 2025-03-18 22:05:02,540 >> Profiling server started: <_XLAC.profiler.ProfilerServer object at 0x7f01bdcb6770>
5%|▌ | 1/20 [00:07<02:29, 7.86s/it]/home/your-username/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1810: UserWarning: For backward hooks to be called, module output should be a Tensor or a tuple of Tensors but received <class 'transformers.modeling_outputs.CausalLMOutputWithPast'>
5%|▌ | 1/20 [00:07<02:29, 7.89s/it]Compilation at Step 0, time: 213.83555555343628
/home/your-username/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1810:
10%|█ | 2/20 [03:43<38:57, 129.87s/it]Compilation at Step 0, time: 213.12156581878662
/home/your-username/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:"
10%|█ | 2/20 [03:40<38:29, 128.30s/it]Compilation at Step 1, time: 224.5414960384369
15%|█▌ | 3/20 [07:22<48:31, 171.24s/it]Compilation at Step 1, time: 226.23664164543152
15%|█▌ | 3/20 [07:26<48:56, 172.73s/it]Compilation at Step 1, time: 226.9180543422699
Compilation at Step 1, time: 224.3874273300171
20%|██ | 4/20 [07:23<27:45, 104.10s/it]Starting to trace for 100000 ms. Remaining attempt(s): 2
2025-03-18 22:12:32.104419: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 847930 nanoseconds and will start immediately.
Starting to trace for 100000 ms. Remaining attempt(s): 2
2025-03-18 22:12:32.104373: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 763960 nanoseconds and will start immediately.
Starting to trace for 100000 ms. Remaining attempt(s): 2
2025-03-18 22:12:32.104538: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 854020 nanoseconds and will start immediately.
2025-03-18 22:12:32.104347: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 761070 nanoseconds and will start immediately.
Starting to trace for 100000 ms. Remaining attempt(s): 2
85%|████████▌ | 17/20 [07:55<00:06, 2.26s/it]Compilation at Step -1, time: 3.676558494567871
Compilation at Step -1, time: 3.447533130645752
Compilation at Step -1, time: 3.5890843868255615
Compilation at Step -1, time: 3.4956483840942383
100%|██████████| 20/20 [11:39<00:00, 35.14s/it][INFO|trainer.py:2350] 2025-03-18 22:16:42,476 >>
Training completed. Do not forget to share your model on huggingface.co/models =)
100%|██████████| 20/20 [11:47<00:00, 35.23s/it][INFO|trainer.py:2350] 2025-03-18 22:16:43,239 >>
Training completed. Do not forget to share your model on huggingface.co/models =)
清理
训练完成后,请按照以下步骤删除加入队列的资源和 TPU 虚拟机。这将停止对 TPU VM 用量进行计费。
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --force \ --async