Treinar o Llama 3 usando o PyTorch na TPU v5e

Este tutorial descreve como treinar um modelo Llama-3-8B usando PyTorch/XLA em TPU v5e usando o conjunto de dados WikiText. Consulte Meta-Llama-3-8B para detalhes do modelo.

O modelo Llama-3-8B é hospedado na plataforma Hugging Face.

Há duas versões do Meta-Llama-3-8B, uma para uso com Transformers e outra com a base de código original do Llama 3. Este tutorial usa a versão Transformers porque ela:

  • Integração perfeita com o ecossistema do Hugging Face: isso facilita o ajuste fino do modelo, o uso de pipelines pré-criados e o acesso a uma vasta coleção de conjuntos de dados e ferramentas.

  • Permite flexibilidade e personalização: a versão Transformers oferece opções significativas de flexibilidade e personalização para ajustar e implantar o modelo.

  • Oferece suporte da comunidade: a comunidade do Hugging Face oferece documentação, tutoriais e suporte amplos para o uso de modelos Transformers.

Para mais informações sobre os transformadores, consulte a documentação do Huggging Face Transformers.

Para acessar e usar o modelo Meta-Llama-3-8B, incluindo o download dos pesos e do tokenizer, você precisa de um token de acesso do usuário do Hugging Face. O token oferece:

  • Autenticação e autorização: o token de acesso funciona como uma credencial e permite que os servidores do Hugging Face autorizem seu acesso aos recursos do modelo. Isso garante que apenas usuários autorizados possam fazer o download e usar o modelo.

  • Segurança: o Hugging Face usa tokens de acesso para proteger os modelos e evitar acesso ou uso indevido.

Para informações sobre como criar e usar um token de acesso para este tutorial, consulte Executar o modelo. Para informações mais completas sobre a criação e o uso de tokens de acesso, consulte a documentação do Hugging Face sobre tokens de acesso do usuário.

Você também precisa de permissão para acessar o modelo Llama 3 8B no Hugging Face. Para receber essa permissão, acesse o modelo Meta-Llama-3-8B no Hugging Face e solicite acesso.

Preparar para provisionar uma TPU v5litepod-16

Este tutorial foi testado usando as seguintes variáveis de ambiente do Cloud TPU. É possível usar outras variáveis para provisionar a TPU, desde que o tipo de acelerador, a zona e a versão do ambiente de execução sejam compatíveis. Por exemplo, neste tutorial, europe-west4-b é usado como a zona. É possível usar qualquer outra zona que ofereça suporte à versão da TPU (tipo de acelerador) que você está executando (v5litepod-16 neste tutorial).

Defina as seguintes variáveis de ambiente da VM TPU.

   export TPU_NAME=queued-resources-node-id #The TPU name is the queued resource node-id
   export PROJECT_ID=your-project-id
   export ACCELERATOR_TYPE=v5litepod-16
   export ZONE=europe-west4-b
   export RUNTIME_VERSION=v2-alpha-tpuv5-lite
   export QUEUED_RESOURCE_ID=queued-resource-id
   export VALID_UNTIL_DURATION=1d

Quando você tiver acesso ao modelo Meta-Llama-3-8B no Hugging Face, prepare o ambiente da TPU para executar o tutorial.

  1. Siga o guia Configurar o ambiente do Cloud TPU para garantir que você tenha o acesso adequado para usar o Cloud TPU.

  2. Crie uma identidade de serviço para a VM da TPU.

    gcloud alpha compute tpus tpu-vm service-identity create --zone=zone
  3. Crie uma conta de serviço de TPU e conceda acesso aos serviços Google Cloud .

    As contas de serviço permitem que o serviço Google Cloud TPU acesse outros serviços Google Cloud. Recomendamos usar uma conta de serviço gerenciada pelo usuário. É possível criar uma conta de serviço no console do Google Cloud ou com o comando gcloud.

    Crie uma conta de serviço usando a ferramenta de linha de comando gcloud:

    gcloud iam service-accounts create your-service-account-name \
    --description="your-sa-description" \
    --display-name="your-sa-display-name"
    export SERVICE_ACCOUNT_NAME=your-service-account-name

    Crie uma conta de serviço no console do Google Cloud:

    1. Acesse a página "Contas de serviço" no console do Google Cloud.
    2. Clique em Criar conta de serviço.
    3. Insira o nome da conta de serviço.
    4. (Opcional) Digite uma descrição para a conta de serviço.
    5. Clique em Criar e continue.
    6. Escolha os papéis que você quer conceder à conta de serviço.
    7. Clique em Continuar.
    8. (Opcional) Especifique os usuários ou grupos que podem gerenciar a conta de serviço.
    9. Clique em Concluído para terminar a criação da conta de serviço.

    Depois de criar a conta de serviço, siga estas etapas para conceder papéis a ela.

    Os seguintes papéis são necessários:

    • Administrador da TPU: necessário para criar uma TPU.
    • Administrador do Storage: necessário para acessar o Cloud Storage.
    • Gravador de registros
    • Gravador de métricas do Monitoring: necessário para gravar métricas no Cloud Monitoring.

    O administrador precisa conceder a você o roles/resourcemanager.projectIamAdmin para que você possa atribuir papéis do IAM aos usuários. Um usuário com o papel de administrador do IAM do projeto roles/resourcemanager.projectIamAdmin também pode conceder esse papel.

    Use os comandos gcloud a seguir para adicionar papéis de conta de serviço:

    gcloud projects add-iam-policy-binding ${PROJECT_ID} \
       --member serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \
       --role roles/tpu.admin
    gcloud projects add-iam-policy-binding ${PROJECT_ID} \
       --member serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \
       --role roles/storage.admin
    gcloud projects add-iam-policy-binding ${PROJECT_ID} \
       --member serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \
       --role roles/logging.logWriter
    gcloud projects add-iam-policy-binding ${PROJECT_ID} \
       --member serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \
       --role roles/monitoring.metricWriter

    Também é possível atribuir papéis usando o console do Google Cloud.

    No console do Google Cloud, selecione os seguintes papéis:

    1. Selecione sua conta de serviço e clique em Adicionar principal.
    2. No campo Novos participantes, insira o endereço de e-mail da sua conta de serviço.
    3. No menu suspenso Selecionar um papel, pesquise e selecione o papel (por exemplo, Administrador do Storage).
    4. Clique em Salvar.
  4. Faça a autenticação com Google Cloud e configure o projeto e a zona padrão para a Google Cloud CLI.

    gcloud auth login
    gcloud config set project PROJECT_ID
    gcloud config set compute/zone ZONE

Capacidade segura

Quando estiver tudo pronto para garantir a capacidade do TPU, consulte a página de cotas para saber mais sobre o sistema de cotas do Cloud. Se você tiver outras dúvidas sobre como garantir capacidade, entre em contato com a equipe de vendas ou de contas do Cloud TPU.

Provisionar o ambiente do Cloud TPU

É possível provisionar VMs do TPU com o GKE, com o GKE e o XPK ou como recursos em fila.

Pré-requisitos

  • Este tutorial foi testado com o Python 3.10 ou mais recente.
  • Verifique se o projeto tem cota de TPUS_PER_TPU_FAMILY suficiente, que especifica o número máximo de chips que você pode acessar no projetoGoogle Cloud .
  • Verifique se o projeto tem cota suficiente de TPU para:
    • Cota de VM de TPU
    • Quota de endereço IP
    • Cota do Hyperdisk equilibrado
  • Permissões do projeto do usuário

Provisionar um TPU v5litepod-16

  1. Crie uma VM de TPU:

    gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID}   \
    --node-id=${TPU_NAME}  \
    --project=${PROJECT_ID}   \
    --zone=${ZONE}   \
    --accelerator-type=${ACCELERATOR_TYPE}   \
    --runtime-version=${RUNTIME_VERSION}   \
    --service-account=${SERVICE_ACCOUNT_NAME}   \
    --spot
  2. Verifique se a TPU está no estado ACTIVE:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
    --project=${PROJECT_ID} \
    --zone=${ZONE}

Quando a TPU se tornar ativa (ACTIVE), você verá uma saída semelhante a:

createTime: '2025-02-28T21:16:08.053492925Z'
name: projects/my-project/locations/zone/queuedResources/tpu-name-zone
spot: {}
state:
  state: ACTIVE
tpu:
  nodeSpec:
  - node:
      acceleratorType: v5litepod-16
      networkConfig:
        enableExternalIps: true
        network: default
      queuedResource: projects/19672137403/locations/zone/queuedResources/qr-name
      runtimeVersion: v2-alpha-tpuv5-lite
      schedulingConfig: {}
      my-service-account@your-project-id.iam.gserviceaccount.com
        email: 19672137854-compute@developer.iam.gserviceaccount.com
      shieldedInstanceConfig: {}
    nodeId: tpu-name
    parent: projects/19672137403/locations/zone

Instalação

Instale o pytorch-tpu/transformers fork dos transformadores e dependências do Hugging Face. Este tutorial foi testado com as seguintes versões de dependência:

  • torch: compatível com 2.6.0
  • torch_xla[tpu]: compatível com 2.6.0
  • jax: 0.4.38
  • jaxlib: 0.4.38

Instalar o software e as dependências do framework

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
    sudo apt install python3.10-venv
    python -m venv /home/$USER/venv/
    source ~/venv/bin/activate
    cd transformers
    pip3 install --user -e .
    pip3 install datasets
    pip3 install evaluate
    pip3 install scikit-learn
    pip3 install accelerate
    pip install torch~=2.6.0 torch_xla[tpu]~=2.6.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
    pip install jax==0.4.38 jaxlib==0.4.38 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'

Quando a instalação for concluída, você verá uma saída semelhante a esta:

Collecting jax==0.4.38
  Downloading jax-0.4.38-py3-none-any.whl (2.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.1/2.1 MB 18.0 MB/s eta 0:00:00
Collecting jaxlib==0.4.38
  Downloading jaxlib-0.4.38-cp310-cp310-manylinux2014_x86_64.whl (85.0 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 85.0/85.0 MB 10.1 MB/s eta 0:00:00
Successfully installed jax-0.4.38 jaxlib-0.4.38 ml-dtypes-0.5.1 opt-einsum-3.4.0
Collecting opt-einsum
  Downloading opt_einsum-3.4.0-py3-none-any.whl (71 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 71.9/71.9 KB 186.4 kB/s eta 0:00:00
Requirement already satisfied: numpy>=1.24 in /home/your-username/.local/lib/python3.10/site-packages (from jax==0.4.38) (2.2.3)
Requirement already satisfied: scipy>=1.10 in /home/your-username/.local/lib/python3.10/site-packages (from jax==0.4.38) (1.15.2)
Successfully installed jax-0.4.38 jaxlib-0.4.38 ml-dtypes-0.5.1 opt-einsum-3.4.0
Successfully installed jax-0.4.38 jaxlib-0.4.38 ml-dtypes-0.5.1 opt-einsum-3.4.0
Collecting ml-dtypes>=0.2.0
  Downloading ml_dtypes-0.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.7 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.7/4.7 MB 13.8 MB/s eta 0:00:00
Installing collected packages: opt-einsum, ml-dtypes, jaxlib, jax
Successfully installed jax-0.4.38 jaxlib-0.4.38 ml-dtypes-0.5.1 opt-einsum-3.4.0

Configurar configurações do modelo

O comando de treinamento na próxima seção, Executar o modelo, usa dois arquivos de configuração JSON para definir parâmetros de modelo e a configuração de dados paralelos totalmente fragmentados (FSDP, na sigla em inglês). O sharding do FSDP é usado para que os pesos do modelo se ajustem a um tamanho de lote maior durante o treinamento. Ao treinar com modelos menores, pode ser suficiente usar o paralelismo de dados e replicar os pesos em cada dispositivo. Para mais informações sobre como dividir tensores em dispositivos no PyTorch/XLA, consulte o Guia do usuário do SPMD do PyTorch/XLA.

  1. Esse comando cria o arquivo de configuração de parâmetros do modelo para Llama3-8B. Para outros modelos, encontre a configuração no Hugging Face. Por exemplo, consulte a configuração Llama2-7B.

    cat > llama-config.json <<EOF
    {
        "architectures": [
            "LlamaForCausalLM"
        ],
        "attention_bias": false,
        "attention_dropout": 0.0,
        "bos_token_id": 128000,
        "eos_token_id": 128001,
        "hidden_act": "silu",
        "hidden_size": 4096,
        "initializer_range": 0.02,
        "intermediate_size": 14336,
        "max_position_embeddings": 8192,
        "model_type": "llama",
        "num_attention_heads": 32,
        "num_hidden_layers": 32,
        "num_key_value_heads": 8,
        "pretraining_tp": 1,
        "rms_norm_eps": 1e-05,
        "rope_scaling": null,
        "rope_theta": 500000.0,
        "tie_word_embeddings": false,
        "torch_dtype": "bfloat16",
        "transformers_version": "4.40.0.dev0",
        "use_cache": false,
        "vocab_size": 128256
    }
    EOF
    
  2. Crie o arquivo de configuração do FSDP:

    cat > fsdp-config.json <<EOF
    {
        "fsdp_transformer_layer_cls_to_wrap": [
            "LlamaDecoderLayer"
        ],
        "xla": true,
        "xla_fsdp_v2": true,
        "xla_fsdp_grad_ckpt": true
    }
    EOF
    

    Para mais informações sobre o FSDP, consulte FSDPv2.

  3. Faça o upload dos arquivos de configuração para as VMs da TPU usando os seguintes comandos:

     ssh-add ~/.ssh/google_compute_engine #Setup SSH Key in the SSH agent.
     gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json $TPU_NAME:. \
        --worker=all \
        --project=${PROJECT_ID} \
        --zone=${ZONE}

    Esse comando vai gerar uma saída semelhante a esta:

    Using scp batch size of 4.Attempting to SCP into 1 nodes with a total of 4 workers.
    SCP: Attempting to connect to worker 0...
    SCP: Attempting to connect to worker 1...
    SCP: Attempting to connect to worker 2...
    SCP: Attempting to connect to worker 3...
    llama-config.json                    100%  707     4.1KB/s   00:00
    llama-config.json                    100%  707     4.0KB/s   00:00
    llama-config.json                    100%  707     4.1KB/s   00:00
    llama-config.json                    100%  707     4.1KB/s   00:00
    fsdp-config.json                     100%  156     0.9KB/s   00:00
    fsdp-config.json                     100%  156     0.9KB/s   00:00
    fsdp-config.json                     100%  156     0.9KB/s   00:00
    fsdp-config.json                     100%  156     0.9KB/s   00:00

Executar o modelo

Usando os arquivos de configuração criados na seção anterior, execute o script run_clm.py para treinar o modelo Llama 3 8B no conjunto de dados do WikiText. O script de treinamento leva aproximadamente 10 minutos para ser executado em uma TPU v5litepod-16.

  1. Gere um novo token do Hugging Face, caso ainda não tenha um:

    1. Clique em Seu perfil > Configurações > Tokens de acesso.
    2. Selecione Novo token.
    3. Especifique um Nome de sua escolha e um Papel de pelo menos "Ler".
    4. Selecione Gerar um token.
  2. Use seu token do Hugging Face para fazer login no Hugging Face na sua VM da TPU usando o comando abaixo.

    Substitua a variável de token huggingface-cli login pela que foi gerada pelo Hugging Face na etapa anterior:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='
    pip install -U "huggingface_hub[cli]"
    export PATH="/home/$USER/.local/bin/:$PATH"
    huggingface-cli login --token hf_abcxyzEFg'

    Esse comando vai fazer login no Hugging Face e mostrar o token ativo atual.

  3. Execute o treinamento do modelo:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} \
        --zone=${ZONE} \
        --worker=all \
        --command='
        source ~/venv/bin/activate
        export PJRT_DEVICE=TPU
        export XLA_USE_SPMD=1
        export ENABLE_PJRT_COMPATIBILITY=true
            # Optional variables for debugging:
        export XLA_IR_DEBUG=1
        export XLA_HLO_DEBUG=1
        export PROFILE_EPOCH=0
        export PROFILE_STEP=3
        export PROFILE_DURATION_MS=100000
            # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path
        export PROFILE_LOGDIR=your-bucket/profile_path
        python3 transformers/examples/pytorch/language-modeling/run_clm.py \
        --dataset_name wikitext \
        --dataset_config_name wikitext-2-raw-v1 \
        --per_device_train_batch_size 16 \
        --do_train \
        --output_dir /home/$USER/tmp/test-clm \
        --overwrite_output_dir \
        --config_name /home/$USER/llama-config.json \
        --cache_dir /home/$USER/cache \
        --tokenizer_name meta-llama/Meta-Llama-3-8B \
        --block_size 8192 \
        --optim adafactor \
        --save_strategy no \
        --logging_strategy no \
        --fsdp "full_shard" \
        --fsdp_config /home/$USER/fsdp-config.json \
        --torch_dtype bfloat16 \
        --dataloader_drop_last yes \
        --flash_attention \
        --max_steps 20'

A etapa de treinamento leva cerca de 10 minutos. No final do treinamento, você vai receber mensagens semelhantes a esta:

[INFO|trainer.py:2053] 2025-03-18 22:05:02,536 >> ***** Running training *****
[INFO|trainer.py:2054] 2025-03-18 22:05:02,536 >>   Num examples = 272
[INFO|trainer.py:2055] 2025-03-18 22:05:02,536 >>   Num Epochs = 2
[INFO|trainer.py:2056] 2025-03-18 22:05:02,536 >>   Instantaneous batch size per device = 16
[INFO|trainer.py:2059] 2025-03-18 22:05:02,536 >>   Total train batch size (w. parallel, distributed & accumulation) = 16
[INFO|trainer.py:2060] 2025-03-18 22:05:02,536 >>   Gradient Accumulation steps = 1
[INFO|trainer.py:2061] 2025-03-18 22:05:02,536 >>   Total optimization steps = 20
[INFO|trainer.py:2062] 2025-03-18 22:05:02,537 >>   Number of trainable parameters = 8,030,261,248
  0%|          | 0/20 [00:00<?, ?it/s][INFO|trainer.py:2143] 2025-03-18 22:05:02,540 >> Profiling server started: <_XLAC.profiler.ProfilerServer object at 0x7f01bdcb6770>
  5%|         | 1/20 [00:07<02:29,  7.86s/it]/home/your-username/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1810: UserWarning: For backward hooks to be called, module output should be a Tensor or a tuple of Tensors but received <class 'transformers.modeling_outputs.CausalLMOutputWithPast'>
  5%|         | 1/20 [00:07<02:29,  7.89s/it]Compilation at Step 0, time: 213.83555555343628
/home/your-username/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1810:
 10%|         | 2/20 [03:43<38:57, 129.87s/it]Compilation at Step 0, time: 213.12156581878662
/home/your-username/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:"
 10%|█         | 2/20 [03:40<38:29, 128.30s/it]Compilation at Step 1, time: 224.5414960384369
 15%|█▌        | 3/20 [07:22<48:31, 171.24s/it]Compilation at Step 1, time: 226.23664164543152
 15%|█▌        | 3/20 [07:26<48:56, 172.73s/it]Compilation at Step 1, time: 226.9180543422699
Compilation at Step 1, time: 224.3874273300171
 20%|██        | 4/20 [07:23<27:45, 104.10s/it]Starting to trace for 100000 ms. Remaining attempt(s): 2
2025-03-18 22:12:32.104419: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 847930 nanoseconds and will start immediately.
Starting to trace for 100000 ms. Remaining attempt(s): 2
2025-03-18 22:12:32.104373: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 763960 nanoseconds and will start immediately.
Starting to trace for 100000 ms. Remaining attempt(s): 2
2025-03-18 22:12:32.104538: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 854020 nanoseconds and will start immediately.
2025-03-18 22:12:32.104347: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 761070 nanoseconds and will start immediately.
Starting to trace for 100000 ms. Remaining attempt(s): 2
 85%|████████▌ | 17/20 [07:55<00:06,  2.26s/it]Compilation at Step -1, time: 3.676558494567871
Compilation at Step -1, time: 3.447533130645752
Compilation at Step -1, time: 3.5890843868255615
Compilation at Step -1, time: 3.4956483840942383
100%|██████████| 20/20 [11:39<00:00, 35.14s/it][INFO|trainer.py:2350] 2025-03-18 22:16:42,476 >>
Training completed. Do not forget to share your model on huggingface.co/models =)

100%|██████████| 20/20 [11:47<00:00, 35.23s/it][INFO|trainer.py:2350] 2025-03-18 22:16:43,239 >>

Training completed. Do not forget to share your model on huggingface.co/models =)

Limpar

Após o término do treinamento, use a etapa a seguir para excluir o recurso em fila e a VM de TPU. Isso vai interromper o faturamento do uso da VM TPU.

  gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --force \
       --async