Treinar um modelo usando a TPU v5e
Com uma pegada menor de 256 chips por pod, a TPU v5e é otimizada para ser um produto de alto valor para treinamento, ajuste e serviço de transformadores, texto para imagem e redes neurais convolucionais (CNNs). Para mais informações sobre como usar a Cloud TPU v5e para veiculação, consulte Inferência usando v5e.
Para mais informações sobre o hardware e as configurações da TPU v5e do Cloud TPU, consulte TPU v5e.
Primeiros passos
As seções a seguir descrevem como começar a usar a TPU v5e.
Solicitação de cotas
Você precisa de cota para usar a TPU v5e no treinamento. Há diferentes tipos de cota para TPUs sob demanda, TPUs reservadas e VMs Spot de TPU. Há cotas separadas necessárias se você estiver usando a TPU v5e para inferência. Para mais informações sobre cotas, consulte Cotas. Para solicitar cota da TPU v5e, entre em contato com a equipe de vendas do Cloud.
Criar uma conta e um projeto do Google Cloud
Você precisa de uma conta e um projeto do Google Cloud para usar o Cloud TPU. Para mais informações, consulte Configurar um ambiente do Cloud TPU.
Criar uma Cloud TPU
A prática recomendada é provisionar Cloud TPU v5e como recursos enfileirados usando o comando queued-resource create
. Para mais informações, consulte
Gerenciar recursos em fila.
Você também pode usar a API Create Node (gcloud compute tpus tpu-vm create
) para
provisionar TPUs v5e do Cloud. Para mais informações, consulte Gerenciar recursos de TPU.
Para mais informações sobre as configurações v5e disponíveis para treinamento, consulte Tipos de Cloud TPU v5e para treinamento.
Configuração do framework
Esta seção descreve o processo geral de configuração para treinamento de modelo personalizado usando JAX ou PyTorch com TPU v5e.
Para instruções de configuração de inferência, consulte Introdução à inferência v5e.
Defina algumas variáveis de ambiente:
export PROJECT_ID=your_project_ID export ACCELERATOR_TYPE=v5litepod-16 export ZONE=us-west4-a export TPU_NAME=your_tpu_name export QUEUED_RESOURCE_ID=your_queued_resource_id
Configuração para JAX
Se você tiver formas de fração maiores que oito chips, terá várias VMs em uma fração. Nesse caso, use a flag --worker=all
para executar a
instalação em todas as VMs da TPU em uma única etapa sem usar SSH para fazer login em cada uma
separadamente:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Descrições de sinalizações de comando
Variável | Descrição |
TPU_NAME | O ID de texto atribuído pelo usuário da TPU criada quando a solicitação de recurso na fila é alocada. |
PROJECT_ID | Google Cloud Nome do projeto. Use um projeto atual ou crie um novo em Configurar seu Google Cloud projeto |
ZONA | Consulte o documento Regiões e zonas de TPU para saber quais são as zonas compatíveis. |
worker | A VM de TPU que tem acesso às TPUs subjacentes. |
Execute o comando a seguir para verificar o número de dispositivos. As saídas mostradas aqui foram produzidas com uma fração v5litepod-16. Esse código testa se tudo está instalado corretamente. Para isso, ele verifica se o JAX vê os TensorCores da Cloud TPU e pode executar operações básicas:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='python3 -c "import jax; print(jax.device_count()); print(jax.local_device_count())"'
A saída será semelhante a esta:
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16
4
16
4
16
4
16
4
jax.device_count()
mostra o número total de chips na fração especificada.
jax.local_device_count()
indica a contagem de chips acessíveis por uma única VM nesta fração.
# Check the number of chips in the given slice by summing the count of chips
# from all VMs through the
# jax.local_device_count() API call.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='python3 -c "import jax; xs=jax.numpy.ones(jax.local_device_count()); print(jax.pmap(lambda x: jax.lax.psum(x, \"i\"), axis_name=\"i\")(xs))"'
A saída será semelhante a esta:
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]
Confira os tutoriais do JAX neste documento para começar a treinar o v5e usando o JAX.
Configuração do PyTorch
A v5e só é compatível com o ambiente de execução PJRT, e o PyTorch 2.1+ usa o PJRT como ambiente de execução padrão para todas as versões de TPU.
Esta seção descreve como começar a usar o PJRT na v5e com PyTorch/XLA com comandos para todos os workers.
Instalar dependências
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' sudo apt-get update -y sudo apt-get install libomp5 -y pip install mkl mkl-include pip install tf-nightly tb-nightly tbp-nightly pip install numpy sudo apt-get install libopenblas-dev -y pip install torch~=PYTORCH_VERSION torchvision torch_xla[tpu]~=PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'
Substitua PYTORCH_VERSION
pela versão do PyTorch que você quer usar.
PYTORCH_VERSION
é usado para especificar a mesma versão do PyTorch/XLA. 2.6.0
é recomendado.
Para mais informações sobre as versões do PyTorch e do PyTorch/XLA, consulte PyTorch: primeiros passos e Versões do PyTorch/XLA.
Para mais informações sobre como instalar o PyTorch/XLA, consulte Instalação do PyTorch/XLA.
Se você receber um erro ao instalar as rodas para torch
, torch_xla
ou
torchvision
, como
pkg_resources.extern.packaging.requirements.InvalidRequirement: Expected end
or semicolon (after name and no valid version specifier) torch==nightly+20230222
,
faça downgrade da sua versão com este comando:
pip3 install setuptools==62.1.0
Executar um script com PJRT
unset LD_PRELOAD
Confira um exemplo usando um script Python para fazer um cálculo em uma VM v5e:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.local/lib/
export PJRT_DEVICE=TPU
export PT_XLA_DEBUG=0
export USE_TORCH=ON
unset LD_PRELOAD
export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"'
Isso gera um resultado semelhante ao seguinte:
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
xla:0
tensor([[ 1.8611, -0.3114, -2.4208],
[-1.0731, 0.3422, 3.1445],
[ 0.5743, 0.2379, 1.1105]], device='xla:0')
xla:0
tensor([[ 1.8611, -0.3114, -2.4208],
[-1.0731, 0.3422, 3.1445],
[ 0.5743, 0.2379, 1.1105]], device='xla:0')
Confira os tutoriais do PyTorch neste documento para começar a treinar com o v5e usando o PyTorch.
Exclua a TPU e o recurso na fila no fim da sessão. Para excluir um recurso em fila, exclua a fração e depois o recurso em fila em duas etapas:
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Essas duas etapas também podem ser usadas para remover solicitações de recursos enfileiradas que estão no estado FAILED
.
Exemplos de JAX/FLAX
As seções a seguir descrevem exemplos de como treinar modelos JAX e FLAX em TPU v5e.
Treinar o ImageNet na v5e
Neste tutorial, descrevemos como treinar o ImageNet na v5e usando dados de entrada falsos. Se você quiser usar dados reais, consulte o arquivo README no GitHub.
Configurar
Crie variáveis de ambiente:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west4-a export ACCELERATOR_TYPE=v5litepod-8 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
Descrições de variáveis de ambiente
Variável Descrição PROJECT_ID
O ID do seu projeto Google Cloud . Use um projeto atual ou crie um novo. TPU_NAME
O nome da TPU. ZONE
A zona em que a VM da TPU será criada. Para mais informações sobre as zonas compatíveis, consulte Regiões e zonas de TPU. ACCELERATOR_TYPE
O tipo de acelerador especifica a versão e o tamanho da Cloud TPU que você quer criar. Para mais informações sobre os tipos de aceleradores compatíveis com cada versão de TPU, consulte Versões de TPU. RUNTIME_VERSION
A versão do software da Cloud TPU. SERVICE_ACCOUNT
O endereço de e-mail da sua conta de serviço. Para encontrar o ID, acesse a página "Contas de serviço" no console do Google Cloud . Por exemplo:
tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
QUEUED_RESOURCE_ID
O ID de texto atribuído pelo usuário da solicitação de recurso em fila. -
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}
Você poderá usar o SSH na VM da TPU quando o recurso enfileirado estiver no estado
ACTIVE
:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Quando o QueuedResource estiver no estado
ACTIVE
, a saída será semelhante a esta:state: ACTIVE
Instale a versão mais recente do JAX e do jaxlib:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Clone o modelo do ImageNet e instale os requisitos correspondentes:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command="git clone https://github.com/coolkp/flax.git && cd flax && git checkout pmap-orbax-conversion && git pull"
Para gerar dados falsos, o modelo precisa de informações sobre as dimensões do conjunto de dados. Isso pode ser coletado dos metadados do conjunto de dados do ImageNet:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command="cd flax/examples/imagenet && pip install -r requirements-cloud-tpu.txt"
Treine o modelo
Depois de concluir todas as etapas anteriores, você pode treinar o modelo.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command="cd flax/examples/imagenet && bash ../../tests/download_dataset_metadata.sh && JAX_PLATFORMS=tpu python imagenet_fake_data_benchmark.py"
Excluir a TPU e o recurso na fila
Exclua a TPU e o recurso na fila no fim da sessão.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Modelos do Hugging Face FLAX
Os modelos Hugging Face implementados em FLAX funcionam sem configuração na Cloud TPU v5e. Esta seção fornece instruções para executar modelos conhecidos.
Treinar o ViT no Imagenette
Neste tutorial, mostramos como treinar o modelo Vision Transformer (ViT) da HuggingFace usando o conjunto de dados Imagenette da Fast AI no Cloud TPU v5e.
O modelo ViT foi o primeiro a treinar um codificador Transformer no ImageNet com excelentes resultados em comparação com as redes convolucionais. Para mais informações, consulte os seguintes recursos:
Configurar
Crie variáveis de ambiente:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west4-a export ACCELERATOR_TYPE=v5litepod-16 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
Descrições de variáveis de ambiente
Variável Descrição PROJECT_ID
O ID do seu projeto Google Cloud . Use um projeto atual ou crie um novo. TPU_NAME
O nome da TPU. ZONE
A zona em que a VM da TPU será criada. Para mais informações sobre as zonas compatíveis, consulte Regiões e zonas de TPU. ACCELERATOR_TYPE
O tipo de acelerador especifica a versão e o tamanho da Cloud TPU que você quer criar. Para mais informações sobre os tipos de aceleradores compatíveis com cada versão de TPU, consulte Versões de TPU. RUNTIME_VERSION
A versão do software da Cloud TPU. SERVICE_ACCOUNT
O endereço de e-mail da sua conta de serviço. Para encontrar o ID, acesse a página "Contas de serviço" no console do Google Cloud . Por exemplo:
tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
QUEUED_RESOURCE_ID
O ID de texto atribuído pelo usuário da solicitação de recurso em fila. -
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}
Você poderá usar o SSH na VM de TPU quando o recurso enfileirado estiver no estado
ACTIVE
:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Quando o recurso enfileirado estiver no estado
ACTIVE
, a saída será semelhante a esta:state: ACTIVE
Instale o JAX e a biblioteca dele:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Faça o download do repositório do Hugging Face e instale os requisitos:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='git clone https://github.com/huggingface/transformers.git && cd transformers && pip install . && pip install -r examples/flax/_tests_requirements.txt && pip install --upgrade huggingface-hub urllib3 zipp && pip install tensorflow==2.19 && sed -i 's/torchvision==0.12.0+cpu/torchvision==0.22.1/' examples/flax/vision/requirements.txt && pip install -r examples/flax/vision/requirements.txt && pip install tf-keras'
Faça o download do conjunto de dados Imagenette:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='cd transformers && wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz && tar -xvzf imagenette2.tgz'
Treine o modelo
Treine o modelo com um buffer pré-mapeado de 4 GB.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='cd transformers && JAX_PLATFORMS=tpu python3 examples/flax/vision/run_image_classification.py --train_dir "imagenette2/train" --validation_dir "imagenette2/val" --output_dir "./vit-imagenette" --learning_rate 1e-3 --preprocessing_num_workers 32 --per_device_train_batch_size 8 --per_device_eval_batch_size 8 --model_name_or_path google/vit-base-patch16-224-in21k --num_train_epochs 3'
Excluir a TPU e o recurso na fila
Exclua a TPU e o recurso na fila no fim da sessão.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Resultados de comparativo de mercado do ViT
O script de treinamento foi executado em v5litepod-4, v5litepod-16 e v5litepod-64. A tabela a seguir mostra as taxas de transferência com diferentes tipos de aceleradores.
Tipo de acelerador | v5litepod-4 | v5litepod-16 | v5litepod-64 |
Período | 3 | 3 | 3 |
Tamanho global do lote | 32 | 128 | 512 |
Capacidade (exemplos/segundo) | 263,40 | 429.34 | 470,71 |
Treinar a difusão no Pokémon
Neste tutorial, mostramos como treinar o modelo Stable Diffusion da HuggingFace usando o conjunto de dados Pokémon no Cloud TPU v5e.
O modelo Stable Diffusion é um modelo de texto latente para imagem que gera imagens fotorrealistas com base em qualquer entrada de texto. Para saber mais, acesse os recursos a seguir (links em inglês):
Configurar
Defina uma variável de ambiente para o nome do bucket de armazenamento:
export GCS_BUCKET_NAME=your_bucket_name
Configure um bucket de armazenamento para a saída do modelo:
gcloud storage buckets create gs://GCS_BUCKET_NAME \ --project=your_project \ --location=us-west1
Crie variáveis de ambiente:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west1-c export ACCELERATOR_TYPE=v5litepod-16 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
Descrições de variáveis de ambiente
Variável Descrição PROJECT_ID
O ID do seu projeto Google Cloud . Use um projeto atual ou crie um novo. TPU_NAME
O nome da TPU. ZONE
A zona em que a VM da TPU será criada. Para mais informações sobre as zonas compatíveis, consulte Regiões e zonas de TPU. ACCELERATOR_TYPE
O tipo de acelerador especifica a versão e o tamanho da Cloud TPU que você quer criar. Para mais informações sobre os tipos de aceleradores compatíveis com cada versão de TPU, consulte Versões de TPU. RUNTIME_VERSION
A versão do software da Cloud TPU. SERVICE_ACCOUNT
O endereço de e-mail da sua conta de serviço. Para encontrar o ID, acesse a página "Contas de serviço" no console do Google Cloud . Por exemplo:
tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
QUEUED_RESOURCE_ID
O ID de texto atribuído pelo usuário da solicitação de recurso em fila. -
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}
Será possível usar o SSH na VM de TPU quando o recurso enfileirado estiver no estado
ACTIVE
:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Quando o recurso na fila estiver no estado
ACTIVE
, a saída será semelhante a esta:state: ACTIVE
Instale o JAX e a biblioteca dele.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='pip install "jax[tpu]==0.4.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Faça o download do repositório do HuggingFace e instale os requisitos.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='git clone https://github.com/RissyRan/diffusers.git && cd diffusers && pip install . && pip install -U -r examples/text_to_image/requirements_flax.txt && pip install tensorflow==2.17.1 clu && pip install tensorboard==2.17.1'
Treine o modelo
Treine o modelo com um buffer pré-mapeado de 4 GB.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} --zone=${ZONE} --project=${PROJECT_ID} --worker=all --command="
git clone https://github.com/google/maxdiffusion
cd maxdiffusion
pip3 install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip3 install -r requirements.txt
pip3 install .
pip3 install gcsfs
export LIBTPU_INIT_ARGS=''
python -m src.maxdiffusion.train src/maxdiffusion/configs/base_2_base.yml run_name=my_run \
jax_cache_dir=gs://${GCS_BUCKET_NAME} activations_dtype=bfloat16 weights_dtype=bfloat16 \
per_device_batch_size=1 precision=DEFAULT dataset_save_location=gs://${GCS_BUCKET_NAME} \
output_dir=gs://${GCS_BUCKET_NAME}/ attention=flash"
Limpar
Exclua a TPU, o recurso enfileirado e o bucket do Cloud Storage no final da sessão.
Exclua a TPU:
gcloud compute tpus tpu-vm delete ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --quiet
Exclua o recurso na fila:
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --quiet
Exclua o bucket do Cloud Storage:
gcloud storage rm -r gs://${GCS_BUCKET_NAME}
Resultados de comparativos de mercado para difusão
O script de treinamento foi executado em v5litepod-4, v5litepod-16 e v5litepod-64. A tabela a seguir mostra as capacidades de processamento.
Tipo de acelerador | v5litepod-4 | v5litepod-16 | v5litepod-64 |
Etapa de treinamento | 1500 | 1500 | 1500 |
Tamanho global do lote | 32 | 64 | 128 |
Capacidade (exemplos/segundo) | 36,53 | 43,71 | 49,36 |
PyTorch/XLA
As seções a seguir descrevem exemplos de como treinar modelos PyTorch/XLA em TPUs v5e.
Treinar o ResNet usando o ambiente de execução PJRT
O PyTorch/XLA está migrando do XRT para o PjRt no PyTorch 2.0 e versões mais recentes. Confira as instruções atualizadas para configurar a v5e para cargas de trabalho de treinamento do PyTorch/XLA.
Configurar
Crie variáveis de ambiente:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west4-a export ACCELERATOR_TYPE=v5litepod-16 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
Descrições de variáveis de ambiente
Variável Descrição PROJECT_ID
O ID do seu projeto Google Cloud . Use um projeto atual ou crie um novo. TPU_NAME
O nome da TPU. ZONE
A zona em que a VM da TPU será criada. Para mais informações sobre as zonas compatíveis, consulte Regiões e zonas de TPU. ACCELERATOR_TYPE
O tipo de acelerador especifica a versão e o tamanho da Cloud TPU que você quer criar. Para mais informações sobre os tipos de aceleradores compatíveis com cada versão de TPU, consulte Versões de TPU. RUNTIME_VERSION
A versão do software da Cloud TPU. SERVICE_ACCOUNT
O endereço de e-mail da sua conta de serviço. Para encontrar o ID, acesse a página "Contas de serviço" no console do Google Cloud . Por exemplo:
tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
QUEUED_RESOURCE_ID
O ID de texto atribuído pelo usuário da solicitação de recurso em fila. -
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}
Será possível usar o SSH na VM da TPU quando o QueuedResource estiver no estado
ACTIVE
:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Quando o recurso na fila estiver no estado
ACTIVE
, a saída será semelhante a esta:state: ACTIVE
Instalar dependências específicas do Torch/XLA
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' sudo apt-get update -y sudo apt-get install libomp5 -y pip3 install mkl mkl-include pip3 install tf-nightly tb-nightly tbp-nightly pip3 install numpy sudo apt-get install libopenblas-dev -y pip install torch==PYTORCH_VERSION torchvision torch_xla[tpu]==PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'
Substitua
PYTORCH_VERSION
pela versão do PyTorch que você quer usar.PYTORCH_VERSION
é usado para especificar a mesma versão do PyTorch/XLA. 2.6.0 é recomendado.Para mais informações sobre as versões do PyTorch e do PyTorch/XLA, consulte PyTorch: primeiros passos e Versões do PyTorch/XLA.
Para mais informações sobre como instalar o PyTorch/XLA, consulte Instalação do PyTorch/XLA.
Treinar o modelo ResNet
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
date
export PJRT_DEVICE=TPU
export PT_XLA_DEBUG=0
export USE_TORCH=ON
export XLA_USE_BF16=1
export LIBTPU_INIT_ARGS=--xla_jf_auto_cross_replica_sharding
export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
git clone https://github.com/pytorch/xla.git
cd xla/
git checkout release-r2.6
python3 test/test_train_mp_imagenet.py --model=resnet50 --fake_data --num_epochs=1 —num_workers=16 --log_steps=300 --batch_size=64 --profile'
Excluir a TPU e o recurso na fila
Exclua a TPU e o recurso na fila no fim da sessão.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Resultado do comparativo de mercado
A tabela a seguir mostra as capacidades de processamento de comparativo de mercado.
Tipo de acelerador | Capacidade (exemplos/segundo) |
v5litepod-4 | 4240 ex/s |
v5litepod-16 | 10.810 ex/s |
v5litepod-64 | 46.154 ex/s |
Treinar ViT na v5e
Neste tutorial, vamos abordar como executar o VIT na v5e usando o repositório do HuggingFace no PyTorch/XLA no conjunto de dados cifar10.
Configurar
Crie variáveis de ambiente:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west4-a export ACCELERATOR_TYPE=v5litepod-16 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
Descrições de variáveis de ambiente
Variável Descrição PROJECT_ID
O ID do seu projeto Google Cloud . Use um projeto atual ou crie um novo. TPU_NAME
O nome da TPU. ZONE
A zona em que a VM da TPU será criada. Para mais informações sobre as zonas compatíveis, consulte Regiões e zonas de TPU. ACCELERATOR_TYPE
O tipo de acelerador especifica a versão e o tamanho da Cloud TPU que você quer criar. Para mais informações sobre os tipos de aceleradores compatíveis com cada versão de TPU, consulte Versões de TPU. RUNTIME_VERSION
A versão do software da Cloud TPU. SERVICE_ACCOUNT
O endereço de e-mail da sua conta de serviço. Para encontrar o ID, acesse a página "Contas de serviço" no console do Google Cloud . Por exemplo:
tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
QUEUED_RESOURCE_ID
O ID de texto atribuído pelo usuário da solicitação de recurso em fila. -
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}
Você poderá usar o SSH na VM da TPU quando o QueuedResource estiver no estado
ACTIVE
:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Quando o recurso enfileirado estiver no estado
ACTIVE
, a saída será semelhante a esta:state: ACTIVE
Instalar dependências do PyTorch/XLA
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' sudo apt-get update -y sudo apt-get install libomp5 -y pip3 install mkl mkl-include pip3 install tf-nightly tb-nightly tbp-nightly pip3 install numpy sudo apt-get install libopenblas-dev -y pip install torch==PYTORCH_VERSION torchvision torch_xla[tpu]==PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html pip install jax==0.4.38 jaxlib==0.4.38 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
Substitua
PYTORCH_VERSION
pela versão do PyTorch que você quer usar.PYTORCH_VERSION
é usado para especificar a mesma versão do PyTorch/XLA. 2.6.0 é recomendado.Para mais informações sobre as versões do PyTorch e do PyTorch/XLA, consulte PyTorch: primeiros passos e Versões do PyTorch/XLA.
Para mais informações sobre como instalar o PyTorch/XLA, consulte Instalação do PyTorch/XLA.
Faça o download do repositório do HuggingFace e instale os requisitos.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=" git clone https://github.com/suexu1025/transformers.git vittransformers; \ cd vittransformers; \ pip3 install .; \ pip3 install datasets; \ wget https://github.com/pytorch/xla/blob/master/scripts/capture_profile.py"
Treine o modelo
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
export PJRT_DEVICE=TPU
export PT_XLA_DEBUG=0
export USE_TORCH=ON
export TF_CPP_MIN_LOG_LEVEL=0
export XLA_USE_BF16=1
export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
cd vittransformers
python3 -u examples/pytorch/xla_spawn.py --num_cores 4 examples/pytorch/image-pretraining/run_mae.py --dataset_name=cifar10 \
--remove_unused_columns=False \
--label_names=pixel_values \
--mask_ratio=0.75 \
--norm_pix_loss=True \
--do_train=true \
--do_eval=true \
--base_learning_rate=1.5e-4 \
--lr_scheduler_type=cosine \
--weight_decay=0.05 \
--num_train_epochs=3 \
--warmup_ratio=0.05 \
--per_device_train_batch_size=8 \
--per_device_eval_batch_size=8 \
--logging_strategy=steps \
--logging_steps=30 \
--evaluation_strategy=epoch \
--save_strategy=epoch \
--load_best_model_at_end=True \
--save_total_limit=3 \
--seed=1337 \
--output_dir=MAE \
--overwrite_output_dir=true \
--logging_dir=./tensorboard-metrics \
--tpu_metrics_debug=true'
Excluir a TPU e o recurso na fila
Exclua a TPU e o recurso na fila no fim da sessão.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Resultado do comparativo de mercado
A tabela a seguir mostra as taxas de transferência de referência para diferentes tipos de aceleradores.
v5litepod-4 | v5litepod-16 | v5litepod-64 | |
Período | 3 | 3 | 3 |
Tamanho global do lote | 32 | 128 | 512 |
Capacidade (exemplos/segundo) | 201 | 657 | 2.844 |