Entrena Llama 3 con PyTorch en TPU v5e

En este instructivo, se describe cómo entrenar un modelo Llama-3-8B con PyTorch/XLA en una TPU v5e con el conjunto de datos WikiText. Consulta Meta-Llama-3-8B para obtener detalles del modelo.

El modelo Llama-3-8B se aloja en la plataforma de Hugging Face.

Existen dos versiones de Meta-Llama-3-8B, una para usar con Transformers y otra con la base de código original de Llama 3. En este instructivo, se usa la versión de Transformers porque tiene las siguientes características:

  • Se integra sin problemas en el ecosistema de Hugging Face: Esto facilita la optimización del modelo, el uso de canalizaciones precompiladas y el acceso a una amplia colección de conjuntos de datos y herramientas.

  • Habilita la flexibilidad y la personalización: La versión de Transformers ofrece opciones de personalización y flexibilidad significativas para ajustar y, luego, implementar el modelo.

  • Proporciona asistencia de la comunidad: La comunidad de Hugging Face proporciona documentación, instructivos y asistencia extensos para usar modelos de Transformers.

Para obtener más información sobre los transformadores, consulta la documentación de Hugging Face Transformers.

Para acceder al modelo Meta-Llama-3-8B y usarlo, incluida la descarga de sus pesos y el analizador, necesitas un token de acceso de usuario de Hugging Face. El token proporciona lo siguiente:

  • Autenticación y autorización: El token de acceso actúa como una credencial y permite que los servidores de Hugging Face autoricen tu acceso a los recursos del modelo. Esto garantiza que solo los usuarios autorizados puedan descargar y usar el modelo.

  • Seguridad: Hugging Face usa tokens de acceso para proteger sus modelos y evitar el acceso no autorizado o el uso inadecuado.

Para obtener información sobre cómo crear y usar un token de acceso para este instructivo, consulta Ejecuta el modelo. Para obtener información más detallada sobre cómo crear y usar tokens de acceso, consulta la documentación de Hugging Face sobre los tokens de acceso de los usuarios.

También necesitas permiso para acceder al modelo Llama 3 8B en Hugging Face. Para obtener ese permiso, ve al modelo Meta-Llama-3-8B en Hugging Face y solicita acceso.

Preparación para aprovisionar un TPU v5litepod-16

Este instructivo se probó con las siguientes variables de entorno de Cloud TPU. Puedes usar otras variables para aprovisionar tu TPU, siempre que el tipo de acelerador, la zona y la versión del entorno de ejecución sean compatibles. Por ejemplo, en este instructivo, se usa europe-west4-b como la zona. Puedes usar cualquier otra zona que admita la versión de TPU (tipo de acelerador) que ejecutes (v5litepod-16 en este instructivo).

Configura las siguientes variables de entorno de la VM de TPU.

   export TPU_NAME=queued-resources-node-id #The TPU name is the queued resource node-id
   export PROJECT_ID=your-project-id
   export ACCELERATOR_TYPE=v5litepod-16
   export ZONE=europe-west4-b
   export RUNTIME_VERSION=v2-alpha-tpuv5-lite
   export QUEUED_RESOURCE_ID=queued-resource-id
   export VALID_UNTIL_DURATION=1d

Cuando tengas acceso al modelo Meta-Llama-3-8B en Hugging Face, prepara el entorno de TPU para ejecutar el instructivo.

  1. Sigue la guía Configura el entorno de Cloud TPU para asegurarte de tener el acceso adecuado para usar Cloud TPU.

  2. Crea una identidad de servicio para la VM de TPU.

    gcloud alpha compute tpus tpu-vm service-identity create --zone=zone
  3. Crea una cuenta de servicio de TPU y otorga acceso a los servicios de Google Cloud .

    Las cuentas de servicio permiten que el servicio de Google Cloud TPU acceda a otros servicios Google Cloud. Se recomienda una cuenta de servicio administrada por el usuario. Puedes crear una cuenta de servicio desde la consola de Google Cloud o con el comando gcloud.

    Crea una cuenta de servicio con la herramienta de línea de comandos de gcloud:

    gcloud iam service-accounts create your-service-account-name \
    --description="your-sa-description" \
    --display-name="your-sa-display-name"
    export SERVICE_ACCOUNT_NAME=your-service-account-name

    Crea una cuenta de servicio desde la consola de Google Cloud:

    1. Ve a la página Cuentas de servicio en la consola de Google Cloud.
    2. Haga clic en Crear cuenta de servicio.
    3. Ingresa el nombre de la cuenta de servicio.
    4. Opcional: Ingresa una descripción para la cuenta de servicio.
    5. Haz clic en Crear y continúa.
    6. Elige los roles que deseas otorgar a la cuenta de servicio.
    7. Haz clic en Continuar.
    8. (Opcional) Especifica los usuarios o grupos que pueden administrar la cuenta de servicio.
    9. Haz clic en Listo para terminar de crear la cuenta de servicio.

    Después de crear tu cuenta de servicio, sigue estos pasos para otorgarle roles.

    Se requieren los siguientes roles:

    • Administrador de TPU: Es necesario para crear una TPU.
    • Administrador de almacenamiento: Es necesario para acceder a Cloud Storage.
    • Escritor de registros
    • Escritor de métricas de Monitoring: Es necesario para escribir métricas en Cloud Monitoring.

    Tu administrador debe otorgarte el rol roles/resourcemanager.projectIamAdmin para que puedas asignar roles de IAM a los usuarios. Un usuario con el rol roles/resourcemanager.projectIamAdmin de administrador de IAM del proyecto también puede otorgar este rol.

    Usa los siguientes comandos gcloud para agregar roles de cuenta de servicio:

    gcloud projects add-iam-policy-binding ${PROJECT_ID} \
       --member serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \
       --role roles/tpu.admin
    gcloud projects add-iam-policy-binding ${PROJECT_ID} \
       --member serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \
       --role roles/storage.admin
    gcloud projects add-iam-policy-binding ${PROJECT_ID} \
       --member serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \
       --role roles/logging.logWriter
    gcloud projects add-iam-policy-binding ${PROJECT_ID} \
       --member serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \
       --role roles/monitoring.metricWriter

    También puedes asignar roles con la consola de Google Cloud.

    En la consola de Google Cloud, selecciona los siguientes roles:

    1. Selecciona tu cuenta de servicio y haz clic en Agregar principal.
    2. En el campo Principales nuevas, ingresa la dirección de correo electrónico de tu cuenta de servicio.
    3. En el menú desplegable Seleccionar un rol, busca el rol (por ejemplo, Administrador de almacenamiento) y selecciónalo.
    4. Haz clic en Guardar.
  4. Autentícate con Google Cloud y configura el proyecto y la zona predeterminados para Google Cloud CLI.

    gcloud auth login
    gcloud config set project PROJECT_ID
    gcloud config set compute/zone ZONE

Cómo proteger la capacidad

Cuando tengas todo listo para proteger la capacidad de las TPU, revisa la página de cuotas para obtener información sobre el sistema de Cloud Quotas. Si tienes más preguntas sobre cómo reservar la capacidad, comunícate con tu equipo de ventas o de cuentas de Cloud TPU.

Aprovisiona el entorno de Cloud TPU

Puedes aprovisionar VMs de TPU con GKE, con GKE y XPK, o como recursos en cola.

Requisitos previos

  • Este instructivo se probó con Python 3.10 o versiones posteriores.
  • Verifica que tu proyecto tenga suficiente cuota de TPUS_PER_TPU_FAMILY, que especifica la cantidad máxima de chips a los que puedes acceder en tu proyecto deGoogle Cloud .
  • Verifica que tu proyecto tenga suficiente cuota de TPU para lo siguiente:
    • Cuota de VM de TPU
    • Cuota de direcciones IP
    • Quota de Hyperdisk Balanced
  • Permisos del proyecto del usuario

Aprovisiona un TPU v5litepod-16

  1. Crea una VM de TPU:

    gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID}   \
    --node-id=${TPU_NAME}  \
    --project=${PROJECT_ID}   \
    --zone=${ZONE}   \
    --accelerator-type=${ACCELERATOR_TYPE}   \
    --runtime-version=${RUNTIME_VERSION}   \
    --service-account=${SERVICE_ACCOUNT_NAME}   \
    --spot
  2. Verifica que la TPU esté en el estado ACTIVE:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
    --project=${PROJECT_ID} \
    --zone=${ZONE}

Cuando la TPU se active (ACTIVE), verás un resultado similar al siguiente:

createTime: '2025-02-28T21:16:08.053492925Z'
name: projects/my-project/locations/zone/queuedResources/tpu-name-zone
spot: {}
state:
  state: ACTIVE
tpu:
  nodeSpec:
  - node:
      acceleratorType: v5litepod-16
      networkConfig:
        enableExternalIps: true
        network: default
      queuedResource: projects/19672137403/locations/zone/queuedResources/qr-name
      runtimeVersion: v2-alpha-tpuv5-lite
      schedulingConfig: {}
      my-service-account@your-project-id.iam.gserviceaccount.com
        email: 19672137854-compute@developer.iam.gserviceaccount.com
      shieldedInstanceConfig: {}
    nodeId: tpu-name
    parent: projects/19672137403/locations/zone

Instalación

Instala la división pytorch-tpu/transformers de Hugging Face Transformers y las dependencias. Este instructivo se probó con las siguientes versiones de dependencias:

  • torch: Compatible con 2.6.0
  • torch_xla[tpu]: Compatible con 2.6.0
  • jax: 0.4.38
  • jaxlib: 0.4.38

Instala el software y las dependencias del framework

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
    sudo apt install python3.10-venv
    python -m venv /home/$USER/venv/
    source ~/venv/bin/activate
    cd transformers
    pip3 install --user -e .
    pip3 install datasets
    pip3 install evaluate
    pip3 install scikit-learn
    pip3 install accelerate
    pip install torch~=2.6.0 torch_xla[tpu]~=2.6.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
    pip install jax==0.4.38 jaxlib==0.4.38 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'

Cuando se complete la instalación, verás un resultado similar al siguiente:

Collecting jax==0.4.38
  Downloading jax-0.4.38-py3-none-any.whl (2.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.1/2.1 MB 18.0 MB/s eta 0:00:00
Collecting jaxlib==0.4.38
  Downloading jaxlib-0.4.38-cp310-cp310-manylinux2014_x86_64.whl (85.0 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 85.0/85.0 MB 10.1 MB/s eta 0:00:00
Successfully installed jax-0.4.38 jaxlib-0.4.38 ml-dtypes-0.5.1 opt-einsum-3.4.0
Collecting opt-einsum
  Downloading opt_einsum-3.4.0-py3-none-any.whl (71 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 71.9/71.9 KB 186.4 kB/s eta 0:00:00
Requirement already satisfied: numpy>=1.24 in /home/your-username/.local/lib/python3.10/site-packages (from jax==0.4.38) (2.2.3)
Requirement already satisfied: scipy>=1.10 in /home/your-username/.local/lib/python3.10/site-packages (from jax==0.4.38) (1.15.2)
Successfully installed jax-0.4.38 jaxlib-0.4.38 ml-dtypes-0.5.1 opt-einsum-3.4.0
Successfully installed jax-0.4.38 jaxlib-0.4.38 ml-dtypes-0.5.1 opt-einsum-3.4.0
Collecting ml-dtypes>=0.2.0
  Downloading ml_dtypes-0.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.7 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.7/4.7 MB 13.8 MB/s eta 0:00:00
Installing collected packages: opt-einsum, ml-dtypes, jaxlib, jax
Successfully installed jax-0.4.38 jaxlib-0.4.38 ml-dtypes-0.5.1 opt-einsum-3.4.0

Configura las configuraciones del modelo

El comando de entrenamiento en la siguiente sección, Run the model, usa dos archivos de configuración JSON para definir los parámetros del modelo y la configuración de FSDP (paralelismo de datos completamente fragmentado). El fragmentación de FSDP se usa para que los pesos del modelo se ajusten a un tamaño de lote más grande durante el entrenamiento. Cuando se entrena con modelos más pequeños, puede ser suficiente usar el paralelismo de datos y replicar los pesos en cada dispositivo. Para obtener más información sobre cómo dividir tensores en varios dispositivos en PyTorch/XLA, consulta la Guía del usuario de SPMD de PyTorch/XLA.

  1. Con este comando, se crea el archivo de configuración de parámetros del modelo para Llama3-8B. Para otros modelos, busca la configuración en Hugging Face. Por ejemplo, consulta la configuración de Llama2-7B.

    cat > llama-config.json <<EOF
    {
        "architectures": [
            "LlamaForCausalLM"
        ],
        "attention_bias": false,
        "attention_dropout": 0.0,
        "bos_token_id": 128000,
        "eos_token_id": 128001,
        "hidden_act": "silu",
        "hidden_size": 4096,
        "initializer_range": 0.02,
        "intermediate_size": 14336,
        "max_position_embeddings": 8192,
        "model_type": "llama",
        "num_attention_heads": 32,
        "num_hidden_layers": 32,
        "num_key_value_heads": 8,
        "pretraining_tp": 1,
        "rms_norm_eps": 1e-05,
        "rope_scaling": null,
        "rope_theta": 500000.0,
        "tie_word_embeddings": false,
        "torch_dtype": "bfloat16",
        "transformers_version": "4.40.0.dev0",
        "use_cache": false,
        "vocab_size": 128256
    }
    EOF
    
  2. Crea el archivo de configuración de FSDP:

    cat > fsdp-config.json <<EOF
    {
        "fsdp_transformer_layer_cls_to_wrap": [
            "LlamaDecoderLayer"
        ],
        "xla": true,
        "xla_fsdp_v2": true,
        "xla_fsdp_grad_ckpt": true
    }
    EOF
    

    Para obtener más información sobre FSDP, consulta FSDPv2.

  3. Sube los archivos de configuración a tus VMs de TPU con los siguientes comandos:

     ssh-add ~/.ssh/google_compute_engine #Setup SSH Key in the SSH agent.
     gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json $TPU_NAME:. \
        --worker=all \
        --project=${PROJECT_ID} \
        --zone=${ZONE}

    Este comando generará un resultado similar al siguiente:

    Using scp batch size of 4.Attempting to SCP into 1 nodes with a total of 4 workers.
    SCP: Attempting to connect to worker 0...
    SCP: Attempting to connect to worker 1...
    SCP: Attempting to connect to worker 2...
    SCP: Attempting to connect to worker 3...
    llama-config.json                    100%  707     4.1KB/s   00:00
    llama-config.json                    100%  707     4.0KB/s   00:00
    llama-config.json                    100%  707     4.1KB/s   00:00
    llama-config.json                    100%  707     4.1KB/s   00:00
    fsdp-config.json                     100%  156     0.9KB/s   00:00
    fsdp-config.json                     100%  156     0.9KB/s   00:00
    fsdp-config.json                     100%  156     0.9KB/s   00:00
    fsdp-config.json                     100%  156     0.9KB/s   00:00

Ejecuta el modelo

Con los archivos de configuración que creaste en la sección anterior, ejecuta la secuencia de comandos run_clm.py para entrenar el modelo Llama 3 8B en el conjunto de datos de WikiText. La secuencia de comandos de entrenamiento tarda aproximadamente 10 minutos en ejecutarse en una TPU v5litepod-16.

  1. Genera un nuevo token de Hugging Face si aún no tienes uno:

    1. Haz clic en Tu perfil > Configuración > Tokens de acceso.
    2. Selecciona Token nuevo.
    3. Especifica el nombre que desees y un rol de al menos Leer.
    4. Selecciona Generate un token.
  2. Usa tu token de Hugging Face para acceder a Hugging Face desde tu VM de TPU con el siguiente comando.

    Reemplaza la variable de token huggingface-cli login por la que se generó desde Hugging Face en el paso anterior:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='
    pip install -U "huggingface_hub[cli]"
    export PATH="/home/$USER/.local/bin/:$PATH"
    huggingface-cli login --token hf_abcxyzEFg'

    Este comando te permitirá acceder a Hugging Face y mostrar el token activo actual.

  3. Ejecuta el entrenamiento del modelo:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} \
        --zone=${ZONE} \
        --worker=all \
        --command='
        source ~/venv/bin/activate
        export PJRT_DEVICE=TPU
        export XLA_USE_SPMD=1
        export ENABLE_PJRT_COMPATIBILITY=true
            # Optional variables for debugging:
        export XLA_IR_DEBUG=1
        export XLA_HLO_DEBUG=1
        export PROFILE_EPOCH=0
        export PROFILE_STEP=3
        export PROFILE_DURATION_MS=100000
            # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path
        export PROFILE_LOGDIR=your-bucket/profile_path
        python3 transformers/examples/pytorch/language-modeling/run_clm.py \
        --dataset_name wikitext \
        --dataset_config_name wikitext-2-raw-v1 \
        --per_device_train_batch_size 16 \
        --do_train \
        --output_dir /home/$USER/tmp/test-clm \
        --overwrite_output_dir \
        --config_name /home/$USER/llama-config.json \
        --cache_dir /home/$USER/cache \
        --tokenizer_name meta-llama/Meta-Llama-3-8B \
        --block_size 8192 \
        --optim adafactor \
        --save_strategy no \
        --logging_strategy no \
        --fsdp "full_shard" \
        --fsdp_config /home/$USER/fsdp-config.json \
        --torch_dtype bfloat16 \
        --dataloader_drop_last yes \
        --flash_attention \
        --max_steps 20'

El paso de entrenamiento tarda alrededor de 10 minutos. Hacia el final de la capacitación, verás mensajes similares a los siguientes:

[INFO|trainer.py:2053] 2025-03-18 22:05:02,536 >> ***** Running training *****
[INFO|trainer.py:2054] 2025-03-18 22:05:02,536 >>   Num examples = 272
[INFO|trainer.py:2055] 2025-03-18 22:05:02,536 >>   Num Epochs = 2
[INFO|trainer.py:2056] 2025-03-18 22:05:02,536 >>   Instantaneous batch size per device = 16
[INFO|trainer.py:2059] 2025-03-18 22:05:02,536 >>   Total train batch size (w. parallel, distributed & accumulation) = 16
[INFO|trainer.py:2060] 2025-03-18 22:05:02,536 >>   Gradient Accumulation steps = 1
[INFO|trainer.py:2061] 2025-03-18 22:05:02,536 >>   Total optimization steps = 20
[INFO|trainer.py:2062] 2025-03-18 22:05:02,537 >>   Number of trainable parameters = 8,030,261,248
  0%|          | 0/20 [00:00<?, ?it/s][INFO|trainer.py:2143] 2025-03-18 22:05:02,540 >> Profiling server started: <_XLAC.profiler.ProfilerServer object at 0x7f01bdcb6770>
  5%|         | 1/20 [00:07<02:29,  7.86s/it]/home/your-username/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1810: UserWarning: For backward hooks to be called, module output should be a Tensor or a tuple of Tensors but received <class 'transformers.modeling_outputs.CausalLMOutputWithPast'>
  5%|         | 1/20 [00:07<02:29,  7.89s/it]Compilation at Step 0, time: 213.83555555343628
/home/your-username/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1810:
 10%|         | 2/20 [03:43<38:57, 129.87s/it]Compilation at Step 0, time: 213.12156581878662
/home/your-username/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:"
 10%|█         | 2/20 [03:40<38:29, 128.30s/it]Compilation at Step 1, time: 224.5414960384369
 15%|█▌        | 3/20 [07:22<48:31, 171.24s/it]Compilation at Step 1, time: 226.23664164543152
 15%|█▌        | 3/20 [07:26<48:56, 172.73s/it]Compilation at Step 1, time: 226.9180543422699
Compilation at Step 1, time: 224.3874273300171
 20%|██        | 4/20 [07:23<27:45, 104.10s/it]Starting to trace for 100000 ms. Remaining attempt(s): 2
2025-03-18 22:12:32.104419: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 847930 nanoseconds and will start immediately.
Starting to trace for 100000 ms. Remaining attempt(s): 2
2025-03-18 22:12:32.104373: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 763960 nanoseconds and will start immediately.
Starting to trace for 100000 ms. Remaining attempt(s): 2
2025-03-18 22:12:32.104538: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 854020 nanoseconds and will start immediately.
2025-03-18 22:12:32.104347: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 761070 nanoseconds and will start immediately.
Starting to trace for 100000 ms. Remaining attempt(s): 2
 85%|████████▌ | 17/20 [07:55<00:06,  2.26s/it]Compilation at Step -1, time: 3.676558494567871
Compilation at Step -1, time: 3.447533130645752
Compilation at Step -1, time: 3.5890843868255615
Compilation at Step -1, time: 3.4956483840942383
100%|██████████| 20/20 [11:39<00:00, 35.14s/it][INFO|trainer.py:2350] 2025-03-18 22:16:42,476 >>
Training completed. Do not forget to share your model on huggingface.co/models =)

100%|██████████| 20/20 [11:47<00:00, 35.23s/it][INFO|trainer.py:2350] 2025-03-18 22:16:43,239 >>

Training completed. Do not forget to share your model on huggingface.co/models =)

Limpia

Una vez que se complete el entrenamiento, sigue el siguiente paso para borrar el recurso en fila y la VM de TPU. De esta manera, se detendrá la facturación por el uso de tu VM de TPU.

  gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --force \
       --async