Entrena Llama 3 con PyTorch en TPU v5e
En este instructivo, se describe cómo entrenar un modelo Llama-3-8B con PyTorch/XLA en una TPU v5e con el conjunto de datos WikiText. Consulta Meta-Llama-3-8B para obtener detalles del modelo.
El modelo Llama-3-8B se aloja en la plataforma de Hugging Face.
Existen dos versiones de Meta-Llama-3-8B, una para usar con Transformers y otra con la base de código original de Llama 3. En este instructivo, se usa la versión de Transformers porque tiene las siguientes características:
Se integra sin problemas en el ecosistema de Hugging Face: Esto facilita la optimización del modelo, el uso de canalizaciones precompiladas y el acceso a una amplia colección de conjuntos de datos y herramientas.
Habilita la flexibilidad y la personalización: La versión de Transformers ofrece opciones de personalización y flexibilidad significativas para ajustar y, luego, implementar el modelo.
Proporciona asistencia de la comunidad: La comunidad de Hugging Face proporciona documentación, instructivos y asistencia extensos para usar modelos de Transformers.
Para obtener más información sobre los transformadores, consulta la documentación de Hugging Face Transformers.
Para acceder al modelo Meta-Llama-3-8B y usarlo, incluida la descarga de sus pesos y el analizador, necesitas un token de acceso de usuario de Hugging Face. El token proporciona lo siguiente:
Autenticación y autorización: El token de acceso actúa como una credencial y permite que los servidores de Hugging Face autoricen tu acceso a los recursos del modelo. Esto garantiza que solo los usuarios autorizados puedan descargar y usar el modelo.
Seguridad: Hugging Face usa tokens de acceso para proteger sus modelos y evitar el acceso no autorizado o el uso inadecuado.
Para obtener información sobre cómo crear y usar un token de acceso para este instructivo, consulta Ejecuta el modelo. Para obtener información más detallada sobre cómo crear y usar tokens de acceso, consulta la documentación de Hugging Face sobre los tokens de acceso de los usuarios.
También necesitas permiso para acceder al modelo Llama 3 8B en Hugging Face. Para obtener ese permiso, ve al modelo Meta-Llama-3-8B en Hugging Face y solicita acceso.
Preparación para aprovisionar un TPU v5litepod-16
Este instructivo se probó con las siguientes variables de entorno de Cloud TPU. Puedes usar otras variables para aprovisionar tu TPU, siempre que el tipo de acelerador, la zona y la versión del entorno de ejecución sean compatibles.
Por ejemplo, en este instructivo, se usa europe-west4-b
como la zona. Puedes usar cualquier otra zona que admita la
versión de TPU (tipo de acelerador) que ejecutes (v5litepod-16 en este instructivo).
Configura las siguientes variables de entorno de la VM de TPU.
export TPU_NAME=queued-resources-node-id #The TPU name is the queued resource node-id export PROJECT_ID=your-project-id export ACCELERATOR_TYPE=v5litepod-16 export ZONE=europe-west4-b export RUNTIME_VERSION=v2-alpha-tpuv5-lite export QUEUED_RESOURCE_ID=queued-resource-id export VALID_UNTIL_DURATION=1d
Cuando tengas acceso al modelo Meta-Llama-3-8B en Hugging Face, prepara el entorno de TPU para ejecutar el instructivo.
Sigue la guía Configura el entorno de Cloud TPU para asegurarte de tener el acceso adecuado para usar Cloud TPU.
Crea una identidad de servicio para la VM de TPU.
gcloud alpha compute tpus tpu-vm service-identity create --zone=zone
Crea una cuenta de servicio de TPU y otorga acceso a los servicios de Google Cloud .
Las cuentas de servicio permiten que el servicio de Google Cloud TPU acceda a otros servicios Google Cloud. Se recomienda una cuenta de servicio administrada por el usuario. Puedes crear una cuenta de servicio desde la consola de Google Cloud o con el comando
gcloud
.Crea una cuenta de servicio con la herramienta de línea de comandos de
gcloud
:gcloud iam service-accounts create your-service-account-name \ --description="your-sa-description" \ --display-name="your-sa-display-name" export SERVICE_ACCOUNT_NAME=your-service-account-name
Crea una cuenta de servicio desde la consola de Google Cloud:
- Ve a la página Cuentas de servicio en la consola de Google Cloud.
- Haga clic en Crear cuenta de servicio.
- Ingresa el nombre de la cuenta de servicio.
- Opcional: Ingresa una descripción para la cuenta de servicio.
- Haz clic en Crear y continúa.
- Elige los roles que deseas otorgar a la cuenta de servicio.
- Haz clic en Continuar.
- (Opcional) Especifica los usuarios o grupos que pueden administrar la cuenta de servicio.
- Haz clic en Listo para terminar de crear la cuenta de servicio.
Después de crear tu cuenta de servicio, sigue estos pasos para otorgarle roles.
Se requieren los siguientes roles:
- Administrador de TPU: Es necesario para crear una TPU.
- Administrador de almacenamiento: Es necesario para acceder a Cloud Storage.
- Escritor de registros
- Escritor de métricas de Monitoring: Es necesario para escribir métricas en Cloud Monitoring.
Tu administrador debe otorgarte el rol
roles/resourcemanager.projectIamAdmin
para que puedas asignar roles de IAM a los usuarios. Un usuario con el rolroles/resourcemanager.projectIamAdmin
de administrador de IAM del proyecto también puede otorgar este rol.Usa los siguientes comandos
gcloud
para agregar roles de cuenta de servicio:gcloud projects add-iam-policy-binding ${PROJECT_ID} \ --member serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \ --role roles/tpu.admin gcloud projects add-iam-policy-binding ${PROJECT_ID} \ --member serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \ --role roles/storage.admin gcloud projects add-iam-policy-binding ${PROJECT_ID} \ --member serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \ --role roles/logging.logWriter gcloud projects add-iam-policy-binding ${PROJECT_ID} \ --member serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \ --role roles/monitoring.metricWriter
También puedes asignar roles con la consola de Google Cloud.
En la consola de Google Cloud, selecciona los siguientes roles:
- Selecciona tu cuenta de servicio y haz clic en Agregar principal.
- En el campo Principales nuevas, ingresa la dirección de correo electrónico de tu cuenta de servicio.
- En el menú desplegable Seleccionar un rol, busca el rol (por ejemplo, Administrador de almacenamiento) y selecciónalo.
- Haz clic en Guardar.
Autentícate con Google Cloud y configura el proyecto y la zona predeterminados para Google Cloud CLI.
gcloud auth login gcloud config set project PROJECT_ID gcloud config set compute/zone ZONE
Cómo proteger la capacidad
Cuando tengas todo listo para proteger la capacidad de las TPU, revisa la página de cuotas para obtener información sobre el sistema de Cloud Quotas. Si tienes más preguntas sobre cómo reservar la capacidad, comunícate con tu equipo de ventas o de cuentas de Cloud TPU.
Aprovisiona el entorno de Cloud TPU
Puedes aprovisionar VMs de TPU con GKE, con GKE y XPK, o como recursos en cola.
Requisitos previos
- Este instructivo se probó con Python 3.10 o versiones posteriores.
- Verifica que tu proyecto tenga suficiente cuota de
TPUS_PER_TPU_FAMILY
, que especifica la cantidad máxima de chips a los que puedes acceder en tu proyecto deGoogle Cloud . - Verifica que tu proyecto tenga suficiente cuota de TPU para lo siguiente:
- Cuota de VM de TPU
- Cuota de direcciones IP
- Quota de Hyperdisk Balanced
- Permisos del proyecto del usuario
- Si usas GKE con XPK, consulta Permisos de la consola de Google Cloud en la cuenta de usuario o de servicio para conocer los permisos necesarios para ejecutar XPK.
Aprovisiona un TPU v5litepod-16
Crea una VM de TPU:
gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT_NAME} \ --spot
Verifica que la TPU esté en el estado
ACTIVE
:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Cuando la TPU se active (ACTIVE
), verás un resultado similar al siguiente:
createTime: '2025-02-28T21:16:08.053492925Z'
name: projects/my-project/locations/zone/queuedResources/tpu-name-zone
spot: {}
state:
state: ACTIVE
tpu:
nodeSpec:
- node:
acceleratorType: v5litepod-16
networkConfig:
enableExternalIps: true
network: default
queuedResource: projects/19672137403/locations/zone/queuedResources/qr-name
runtimeVersion: v2-alpha-tpuv5-lite
schedulingConfig: {}
my-service-account@your-project-id.iam.gserviceaccount.com
email: 19672137854-compute@developer.iam.gserviceaccount.com
shieldedInstanceConfig: {}
nodeId: tpu-name
parent: projects/19672137403/locations/zone
Instalación
Instala la división pytorch-tpu/transformers
de Hugging Face Transformers y las dependencias. Este instructivo se probó con las siguientes versiones de dependencias:
torch
: Compatible con 2.6.0torch_xla[tpu]
: Compatible con 2.6.0jax
: 0.4.38jaxlib
: 0.4.38
Instala el software y las dependencias del framework
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git sudo apt install python3.10-venv python -m venv /home/$USER/venv/ source ~/venv/bin/activate cd transformers pip3 install --user -e . pip3 install datasets pip3 install evaluate pip3 install scikit-learn pip3 install accelerate pip install torch~=2.6.0 torch_xla[tpu]~=2.6.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html pip install jax==0.4.38 jaxlib==0.4.38 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'
Cuando se complete la instalación, verás un resultado similar al siguiente:
Collecting jax==0.4.38
Downloading jax-0.4.38-py3-none-any.whl (2.1 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.1/2.1 MB 18.0 MB/s eta 0:00:00
Collecting jaxlib==0.4.38
Downloading jaxlib-0.4.38-cp310-cp310-manylinux2014_x86_64.whl (85.0 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 85.0/85.0 MB 10.1 MB/s eta 0:00:00
Successfully installed jax-0.4.38 jaxlib-0.4.38 ml-dtypes-0.5.1 opt-einsum-3.4.0
Collecting opt-einsum
Downloading opt_einsum-3.4.0-py3-none-any.whl (71 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 71.9/71.9 KB 186.4 kB/s eta 0:00:00
Requirement already satisfied: numpy>=1.24 in /home/your-username/.local/lib/python3.10/site-packages (from jax==0.4.38) (2.2.3)
Requirement already satisfied: scipy>=1.10 in /home/your-username/.local/lib/python3.10/site-packages (from jax==0.4.38) (1.15.2)
Successfully installed jax-0.4.38 jaxlib-0.4.38 ml-dtypes-0.5.1 opt-einsum-3.4.0
Successfully installed jax-0.4.38 jaxlib-0.4.38 ml-dtypes-0.5.1 opt-einsum-3.4.0
Collecting ml-dtypes>=0.2.0
Downloading ml_dtypes-0.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.7 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.7/4.7 MB 13.8 MB/s eta 0:00:00
Installing collected packages: opt-einsum, ml-dtypes, jaxlib, jax
Successfully installed jax-0.4.38 jaxlib-0.4.38 ml-dtypes-0.5.1 opt-einsum-3.4.0
Configura las configuraciones del modelo
El comando de entrenamiento en la siguiente sección, Run the model, usa dos archivos de configuración JSON para definir los parámetros del modelo y la configuración de FSDP (paralelismo de datos completamente fragmentado). El fragmentación de FSDP se usa para que los pesos del modelo se ajusten a un tamaño de lote más grande durante el entrenamiento. Cuando se entrena con modelos más pequeños, puede ser suficiente usar el paralelismo de datos y replicar los pesos en cada dispositivo. Para obtener más información sobre cómo dividir tensores en varios dispositivos en PyTorch/XLA, consulta la Guía del usuario de SPMD de PyTorch/XLA.
Con este comando, se crea el archivo de configuración de parámetros del modelo para Llama3-8B. Para otros modelos, busca la configuración en Hugging Face. Por ejemplo, consulta la configuración de Llama2-7B.
cat > llama-config.json <<EOF { "architectures": [ "LlamaForCausalLM" ], "attention_bias": false, "attention_dropout": 0.0, "bos_token_id": 128000, "eos_token_id": 128001, "hidden_act": "silu", "hidden_size": 4096, "initializer_range": 0.02, "intermediate_size": 14336, "max_position_embeddings": 8192, "model_type": "llama", "num_attention_heads": 32, "num_hidden_layers": 32, "num_key_value_heads": 8, "pretraining_tp": 1, "rms_norm_eps": 1e-05, "rope_scaling": null, "rope_theta": 500000.0, "tie_word_embeddings": false, "torch_dtype": "bfloat16", "transformers_version": "4.40.0.dev0", "use_cache": false, "vocab_size": 128256 } EOF
Crea el archivo de configuración de FSDP:
cat > fsdp-config.json <<EOF { "fsdp_transformer_layer_cls_to_wrap": [ "LlamaDecoderLayer" ], "xla": true, "xla_fsdp_v2": true, "xla_fsdp_grad_ckpt": true } EOF
Para obtener más información sobre FSDP, consulta FSDPv2.
Sube los archivos de configuración a tus VMs de TPU con los siguientes comandos:
ssh-add ~/.ssh/google_compute_engine #Setup SSH Key in the SSH agent. gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json $TPU_NAME:. \ --worker=all \ --project=${PROJECT_ID} \ --zone=${ZONE}
Este comando generará un resultado similar al siguiente:
Using scp batch size of 4.Attempting to SCP into 1 nodes with a total of 4 workers. SCP: Attempting to connect to worker 0... SCP: Attempting to connect to worker 1... SCP: Attempting to connect to worker 2... SCP: Attempting to connect to worker 3... llama-config.json 100% 707 4.1KB/s 00:00 llama-config.json 100% 707 4.0KB/s 00:00 llama-config.json 100% 707 4.1KB/s 00:00 llama-config.json 100% 707 4.1KB/s 00:00 fsdp-config.json 100% 156 0.9KB/s 00:00 fsdp-config.json 100% 156 0.9KB/s 00:00 fsdp-config.json 100% 156 0.9KB/s 00:00 fsdp-config.json 100% 156 0.9KB/s 00:00
Ejecuta el modelo
Con los archivos de configuración que creaste en la sección anterior, ejecuta la secuencia de comandos run_clm.py
para entrenar el modelo Llama 3 8B en el conjunto de datos de WikiText. La secuencia de comandos de entrenamiento tarda aproximadamente 10 minutos en ejecutarse en una TPU v5litepod-16.
Genera un nuevo token de Hugging Face si aún no tienes uno:
- Haz clic en Tu perfil > Configuración > Tokens de acceso.
- Selecciona Token nuevo.
- Especifica el nombre que desees y un rol de al menos Leer.
- Selecciona Generate un token.
Usa tu token de Hugging Face para acceder a Hugging Face desde tu VM de TPU con el siguiente comando.
Reemplaza la variable de token
huggingface-cli login
por la que se generó desde Hugging Face en el paso anterior:gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' pip install -U "huggingface_hub[cli]" export PATH="/home/$USER/.local/bin/:$PATH" huggingface-cli login --token hf_abcxyzEFg'
Este comando te permitirá acceder a Hugging Face y mostrar el token activo actual.
Ejecuta el entrenamiento del modelo:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' source ~/venv/bin/activate export PJRT_DEVICE=TPU export XLA_USE_SPMD=1 export ENABLE_PJRT_COMPATIBILITY=true # Optional variables for debugging: export XLA_IR_DEBUG=1 export XLA_HLO_DEBUG=1 export PROFILE_EPOCH=0 export PROFILE_STEP=3 export PROFILE_DURATION_MS=100000 # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path export PROFILE_LOGDIR=your-bucket/profile_path python3 transformers/examples/pytorch/language-modeling/run_clm.py \ --dataset_name wikitext \ --dataset_config_name wikitext-2-raw-v1 \ --per_device_train_batch_size 16 \ --do_train \ --output_dir /home/$USER/tmp/test-clm \ --overwrite_output_dir \ --config_name /home/$USER/llama-config.json \ --cache_dir /home/$USER/cache \ --tokenizer_name meta-llama/Meta-Llama-3-8B \ --block_size 8192 \ --optim adafactor \ --save_strategy no \ --logging_strategy no \ --fsdp "full_shard" \ --fsdp_config /home/$USER/fsdp-config.json \ --torch_dtype bfloat16 \ --dataloader_drop_last yes \ --flash_attention \ --max_steps 20'
El paso de entrenamiento tarda alrededor de 10 minutos. Hacia el final de la capacitación, verás mensajes similares a los siguientes:
[INFO|trainer.py:2053] 2025-03-18 22:05:02,536 >> ***** Running training *****
[INFO|trainer.py:2054] 2025-03-18 22:05:02,536 >> Num examples = 272
[INFO|trainer.py:2055] 2025-03-18 22:05:02,536 >> Num Epochs = 2
[INFO|trainer.py:2056] 2025-03-18 22:05:02,536 >> Instantaneous batch size per device = 16
[INFO|trainer.py:2059] 2025-03-18 22:05:02,536 >> Total train batch size (w. parallel, distributed & accumulation) = 16
[INFO|trainer.py:2060] 2025-03-18 22:05:02,536 >> Gradient Accumulation steps = 1
[INFO|trainer.py:2061] 2025-03-18 22:05:02,536 >> Total optimization steps = 20
[INFO|trainer.py:2062] 2025-03-18 22:05:02,537 >> Number of trainable parameters = 8,030,261,248
0%| | 0/20 [00:00<?, ?it/s][INFO|trainer.py:2143] 2025-03-18 22:05:02,540 >> Profiling server started: <_XLAC.profiler.ProfilerServer object at 0x7f01bdcb6770>
5%|▌ | 1/20 [00:07<02:29, 7.86s/it]/home/your-username/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1810: UserWarning: For backward hooks to be called, module output should be a Tensor or a tuple of Tensors but received <class 'transformers.modeling_outputs.CausalLMOutputWithPast'>
5%|▌ | 1/20 [00:07<02:29, 7.89s/it]Compilation at Step 0, time: 213.83555555343628
/home/your-username/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1810:
10%|█ | 2/20 [03:43<38:57, 129.87s/it]Compilation at Step 0, time: 213.12156581878662
/home/your-username/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:"
10%|█ | 2/20 [03:40<38:29, 128.30s/it]Compilation at Step 1, time: 224.5414960384369
15%|█▌ | 3/20 [07:22<48:31, 171.24s/it]Compilation at Step 1, time: 226.23664164543152
15%|█▌ | 3/20 [07:26<48:56, 172.73s/it]Compilation at Step 1, time: 226.9180543422699
Compilation at Step 1, time: 224.3874273300171
20%|██ | 4/20 [07:23<27:45, 104.10s/it]Starting to trace for 100000 ms. Remaining attempt(s): 2
2025-03-18 22:12:32.104419: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 847930 nanoseconds and will start immediately.
Starting to trace for 100000 ms. Remaining attempt(s): 2
2025-03-18 22:12:32.104373: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 763960 nanoseconds and will start immediately.
Starting to trace for 100000 ms. Remaining attempt(s): 2
2025-03-18 22:12:32.104538: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 854020 nanoseconds and will start immediately.
2025-03-18 22:12:32.104347: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 761070 nanoseconds and will start immediately.
Starting to trace for 100000 ms. Remaining attempt(s): 2
85%|████████▌ | 17/20 [07:55<00:06, 2.26s/it]Compilation at Step -1, time: 3.676558494567871
Compilation at Step -1, time: 3.447533130645752
Compilation at Step -1, time: 3.5890843868255615
Compilation at Step -1, time: 3.4956483840942383
100%|██████████| 20/20 [11:39<00:00, 35.14s/it][INFO|trainer.py:2350] 2025-03-18 22:16:42,476 >>
Training completed. Do not forget to share your model on huggingface.co/models =)
100%|██████████| 20/20 [11:47<00:00, 35.23s/it][INFO|trainer.py:2350] 2025-03-18 22:16:43,239 >>
Training completed. Do not forget to share your model on huggingface.co/models =)
Limpia
Una vez que se complete el entrenamiento, sigue el siguiente paso para borrar el recurso en fila y la VM de TPU. De esta manera, se detendrá la facturación por el uso de tu VM de TPU.
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --force \ --async