Entrenar un modelo con la versión 5e de TPU
Con una superficie de 256 chips por pod, la TPU v5e se ha optimizado para ser un producto de alto valor para el entrenamiento, el ajuste y el servicio de transformadores, texto a imagen y redes neuronales convolucionales (CNN). Para obtener más información sobre cómo usar la versión 5e de TPU de Cloud para el servicio, consulta Inferencia con la versión 5e.
Para obtener más información sobre el hardware y las configuraciones de la TPU v5e de Cloud, consulta TPU v5e.
Empezar
En las siguientes secciones se describe cómo empezar a usar TPU v5e.
Cuota de solicitudes
Necesitas cuota para usar la TPU v5e en el entrenamiento. Hay distintos tipos de cuotas para las TPUs bajo demanda, las TPUs reservadas y las VMs Spot de TPU. Si usas tu TPU v5e para la inferencia, se requieren cuotas independientes. Para obtener más información sobre las cuotas, consulta Cuotas. Para solicitar cuota de TPU v5e, ponte en contacto con el equipo de Ventas de Cloud.
Crea una Google Cloud cuenta y un proyecto
Para usar Cloud TPU, necesitas una Google Cloud cuenta y un proyecto. Para obtener más información, consulta el artículo Configurar un entorno de TPU de Cloud.
Crear una TPU de Cloud
La práctica recomendada es aprovisionar las TPU de Cloud v5e como recursos en cola con el comando queued-resource create
. Para obtener más información, consulta Gestionar recursos en cola.
También puedes usar la API Create Node (gcloud compute tpus tpu-vm create
) para aprovisionar TPUs de Cloud v5e. Para obtener más información, consulta Gestionar recursos de TPU.
Para obtener más información sobre las configuraciones de v5e disponibles para el entrenamiento, consulta Tipos de TPU de Cloud v5e para el entrenamiento.
Configuración del framework
En esta sección se describe el proceso de configuración general para entrenar modelos personalizados con JAX o PyTorch con TPU v5e.
Para obtener instrucciones sobre cómo configurar la inferencia, consulta la introducción a la inferencia de la versión 5e.
Define algunas variables de entorno:
export PROJECT_ID=your_project_ID export ACCELERATOR_TYPE=v5litepod-16 export ZONE=us-west4-a export TPU_NAME=your_tpu_name export QUEUED_RESOURCE_ID=your_queued_resource_id
Configuración de JAX
Si tienes formas de porción con más de 8 chips, tendrás varias máquinas virtuales en una porción. En este caso, debes usar la marca --worker=all
para ejecutar la instalación en todas las VMs de TPU en un solo paso sin usar SSH para iniciar sesión en cada una por separado:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Descripciones de marcas de comandos
Variable | Descripción |
TPU_NAME | Es el ID de texto asignado por el usuario de la TPU que se crea cuando se asigna la solicitud de recurso en cola. |
PROJECT_ID | Google Cloud Nombre del proyecto. Usa un proyecto que ya tengas o crea uno nuevo en Configurar un Google Cloud proyecto |
ZONE | Consulta el documento Regiones y zonas de TPU para ver las zonas admitidas. |
trabajador | La VM de TPU que tiene acceso a las TPUs subyacentes. |
Puedes ejecutar el siguiente comando para comprobar el número de dispositivos (los resultados que se muestran aquí se han obtenido con un slice v5litepod-16). Este código comprueba que todo esté instalado correctamente verificando que JAX detecte los TensorCores de Cloud TPU y pueda ejecutar operaciones básicas:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='python3 -c "import jax; print(jax.device_count()); print(jax.local_device_count())"'
La salida será similar a la siguiente:
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16
4
16
4
16
4
16
4
jax.device_count()
muestra el número total de chips del segmento en cuestión.
jax.local_device_count()
indica el número de chips a los que puede acceder una sola máquina virtual en esta porción.
# Check the number of chips in the given slice by summing the count of chips
# from all VMs through the
# jax.local_device_count() API call.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='python3 -c "import jax; xs=jax.numpy.ones(jax.local_device_count()); print(jax.pmap(lambda x: jax.lax.psum(x, \"i\"), axis_name=\"i\")(xs))"'
La salida será similar a la siguiente:
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]
Prueba los tutoriales de JAX de este documento para empezar a entrenar la versión 5e con JAX.
Configuración de PyTorch
Ten en cuenta que la versión 5e solo admite el entorno de ejecución de PJRT y que PyTorch 2.1 o versiones posteriores usarán PJRT como entorno de ejecución predeterminado para todas las versiones de TPU.
En esta sección se describe cómo empezar a usar PJRT en la versión 5e con PyTorch/XLA con comandos para todos los trabajadores.
Instalar dependencias
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' sudo apt-get update -y sudo apt-get install libomp5 -y pip install mkl mkl-include pip install tf-nightly tb-nightly tbp-nightly pip install numpy sudo apt-get install libopenblas-dev -y pip install torch~=PYTORCH_VERSION torchvision torch_xla[tpu]~=PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'
Sustituye PYTORCH_VERSION
por la versión de PyTorch que quieras usar.
PYTORCH_VERSION
se usa para especificar la misma versión de PyTorch/XLA. Se recomienda la versión 2.6.0.
Para obtener más información sobre las versiones de PyTorch y PyTorch/XLA, consulta PyTorch - Get Started (PyTorch: primeros pasos) y PyTorch/XLA releases (Lanzamientos de PyTorch/XLA).
Para obtener más información sobre cómo instalar PyTorch/XLA, consulta Instalación de PyTorch/XLA.
Si se produce un error al instalar las ruedas de torch
, torch_xla
o torchvision
, como pkg_resources.extern.packaging.requirements.InvalidRequirement: Expected end
or semicolon (after name and no valid version specifier) torch==nightly+20230222
, cambia a una versión anterior con este comando:
pip3 install setuptools==62.1.0
Ejecutar una secuencia de comandos con PJRT
unset LD_PRELOAD
A continuación, se muestra un ejemplo en el que se usa una secuencia de comandos de Python para hacer un cálculo en una VM v5e:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.local/lib/
export PJRT_DEVICE=TPU
export PT_XLA_DEBUG=0
export USE_TORCH=ON
unset LD_PRELOAD
export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"'
Se generará un resultado similar al siguiente:
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
xla:0
tensor([[ 1.8611, -0.3114, -2.4208],
[-1.0731, 0.3422, 3.1445],
[ 0.5743, 0.2379, 1.1105]], device='xla:0')
xla:0
tensor([[ 1.8611, -0.3114, -2.4208],
[-1.0731, 0.3422, 3.1445],
[ 0.5743, 0.2379, 1.1105]], device='xla:0')
Consulta los tutoriales de PyTorch de este documento para empezar a entrenar la versión 5e con PyTorch.
Elimina tu TPU y el recurso en cola al final de la sesión. Para eliminar un recurso en cola, elimina el segmento y, a continuación, el recurso en cola en dos pasos:
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Estos dos pasos también se pueden usar para eliminar solicitudes de recursos en cola que estén en el estado FAILED
.
Ejemplos de JAX/FLAX
En las siguientes secciones se describen ejemplos de cómo entrenar modelos de JAX y Flax en TPU v5e.
Entrenar ImageNet en v5e
En este tutorial se describe cómo entrenar ImageNet en v5e con datos de entrada falsos. Si quieres usar datos reales, consulta el archivo README en GitHub.
Configurar
Crea variables de entorno:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west4-a export ACCELERATOR_TYPE=v5litepod-8 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
Descripciones de las variables de entorno
Variable Descripción PROJECT_ID
El ID de tu proyecto Google Cloud . Usa un proyecto que ya tengas o crea uno. TPU_NAME
El nombre de la TPU. ZONE
La zona en la que se creará la VM de TPU. Para obtener más información sobre las zonas admitidas, consulta Regiones y zonas de TPU. ACCELERATOR_TYPE
El tipo de acelerador especifica la versión y el tamaño de la TPU de Cloud que quieres crear. Para obtener más información sobre los tipos de aceleradores compatibles con cada versión de TPU, consulta Versiones de TPU. RUNTIME_VERSION
La versión de software de la TPU de Cloud. SERVICE_ACCOUNT
La dirección de correo de tu cuenta de servicio. Para encontrarlo, ve a la página Cuentas de servicio de la consola Google Cloud . Por ejemplo:
tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
QUEUED_RESOURCE_ID
ID de texto asignado por el usuario de la solicitud de recurso en cola. -
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}
Podrás conectarte a tu VM de TPU mediante SSH cuando el recurso en cola esté en estado
ACTIVE
:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Cuando QueuedResource esté en el estado
ACTIVE
, el resultado será similar al siguiente:state: ACTIVE
Instala la versión más reciente de JAX y jaxlib:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Clona el modelo ImageNet e instala los requisitos correspondientes:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command="git clone https://github.com/coolkp/flax.git && cd flax && git checkout pmap-orbax-conversion && git pull"
Para generar datos falsos, el modelo necesita información sobre las dimensiones del conjunto de datos. Esta información se puede obtener de los metadatos del conjunto de datos ImageNet:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command="cd flax/examples/imagenet && pip install -r requirements-cloud-tpu.txt"
Preparar el modelo
Una vez que hayas completado todos los pasos anteriores, podrás entrenar el modelo.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command="cd flax/examples/imagenet && bash ../../tests/download_dataset_metadata.sh && JAX_PLATFORMS=tpu python imagenet_fake_data_benchmark.py"
Eliminar la TPU y el recurso en cola
Elimina tu TPU y el recurso en cola al final de la sesión.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Modelos de Flax de Hugging Face
Los modelos de Hugging Face implementados en Flax funcionan sin problemas en Cloud TPU v5e. En esta sección se proporcionan instrucciones para ejecutar modelos populares.
Entrenar ViT en Imagenette
En este tutorial se muestra cómo entrenar el modelo Vision Transformer (ViT) de Hugging Face con el conjunto de datos Imagenette de Fast AI en la versión 5e de TPU de Cloud.
El modelo ViT fue el primero que entrenó correctamente un codificador Transformer en ImageNet con resultados excelentes en comparación con las redes convolucionales. Para obtener más información, consulta los siguientes recursos:
Configurar
Crea variables de entorno:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west4-a export ACCELERATOR_TYPE=v5litepod-16 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
Descripciones de las variables de entorno
Variable Descripción PROJECT_ID
El ID de tu proyecto Google Cloud . Usa un proyecto que ya tengas o crea uno. TPU_NAME
El nombre de la TPU. ZONE
La zona en la que se creará la VM de TPU. Para obtener más información sobre las zonas admitidas, consulta Regiones y zonas de TPU. ACCELERATOR_TYPE
El tipo de acelerador especifica la versión y el tamaño de la TPU de Cloud que quieres crear. Para obtener más información sobre los tipos de aceleradores compatibles con cada versión de TPU, consulta Versiones de TPU. RUNTIME_VERSION
La versión de software de la TPU de Cloud. SERVICE_ACCOUNT
La dirección de correo de tu cuenta de servicio. Para encontrarlo, ve a la página Cuentas de servicio de la consola Google Cloud . Por ejemplo:
tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
QUEUED_RESOURCE_ID
ID de texto asignado por el usuario de la solicitud de recurso en cola. -
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}
Podrás conectarte a tu VM de TPU mediante SSH cuando el recurso en cola tenga el estado
ACTIVE
:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Cuando el recurso en cola esté en el estado
ACTIVE
, el resultado será similar al siguiente:state: ACTIVE
Instala JAX y su biblioteca:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Descarga el repositorio de Hugging Face e instala los requisitos:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='git clone https://github.com/huggingface/transformers.git && cd transformers && pip install . && pip install -r examples/flax/_tests_requirements.txt && pip install --upgrade huggingface-hub urllib3 zipp && pip install tensorflow==2.19 && sed -i 's/torchvision==0.12.0+cpu/torchvision==0.22.1/' examples/flax/vision/requirements.txt && pip install -r examples/flax/vision/requirements.txt && pip install tf-keras'
Descarga el conjunto de datos Imagenette:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='cd transformers && wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz && tar -xvzf imagenette2.tgz'
Preparar el modelo
Entrena el modelo con un búfer preasignado de 4 GB.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='cd transformers && JAX_PLATFORMS=tpu python3 examples/flax/vision/run_image_classification.py --train_dir "imagenette2/train" --validation_dir "imagenette2/val" --output_dir "./vit-imagenette" --learning_rate 1e-3 --preprocessing_num_workers 32 --per_device_train_batch_size 8 --per_device_eval_batch_size 8 --model_name_or_path google/vit-base-patch16-224-in21k --num_train_epochs 3'
Eliminar la TPU y el recurso en cola
Elimina tu TPU y tu recurso en cola al final de la sesión.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Resultados de las comparativas de ViT
La secuencia de comandos de entrenamiento se ejecutó en v5litepod-4, v5litepod-16 y v5litepod-64. En la siguiente tabla se muestran los rendimientos con diferentes tipos de aceleradores.
Tipo de acelerador | v5litepod-4 | v5litepod-16 | v5litepod-64 |
Época | 3 | 3 | 3 |
Tamaño de lote global | 32 | 128 | 512 |
Rendimiento (ejemplos/s) | 263,40 EGP | 429,34 | 470,71 |
Entrenar Diffusion con Pokémon
En este tutorial se explica cómo entrenar el modelo Stable Diffusion de Hugging Face con el conjunto de datos Pokémon en la TPU v5e de Cloud.
El modelo Stable Diffusion es un modelo de conversión de texto a imagen latente que genera imágenes fotorrealistas a partir de cualquier entrada de texto. Para obtener más información, consulta los siguientes recursos:
Configurar
Define una variable de entorno para el nombre de tu segmento de almacenamiento:
export GCS_BUCKET_NAME=your_bucket_name
Configura un segmento de almacenamiento para la salida del modelo:
gcloud storage buckets create gs://GCS_BUCKET_NAME \ --project=your_project \ --location=us-west1
Crea variables de entorno:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west1-c export ACCELERATOR_TYPE=v5litepod-16 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
Descripciones de las variables de entorno
Variable Descripción PROJECT_ID
El ID de tu proyecto Google Cloud . Usa un proyecto que ya tengas o crea uno. TPU_NAME
El nombre de la TPU. ZONE
La zona en la que se creará la VM de TPU. Para obtener más información sobre las zonas admitidas, consulta Regiones y zonas de TPU. ACCELERATOR_TYPE
El tipo de acelerador especifica la versión y el tamaño de la TPU de Cloud que quieres crear. Para obtener más información sobre los tipos de aceleradores compatibles con cada versión de TPU, consulta Versiones de TPU. RUNTIME_VERSION
La versión de software de la TPU de Cloud. SERVICE_ACCOUNT
La dirección de correo de tu cuenta de servicio. Para encontrarlo, ve a la página Cuentas de servicio de la consola Google Cloud . Por ejemplo:
tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
QUEUED_RESOURCE_ID
ID de texto asignado por el usuario de la solicitud de recurso en cola. -
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}
Podrás conectarte a tu VM de TPU mediante SSH cuando el recurso en cola esté en el estado
ACTIVE
:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Cuando el recurso en cola esté en el estado
ACTIVE
, el resultado será similar al siguiente:state: ACTIVE
Instala JAX y su biblioteca.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='pip install "jax[tpu]==0.4.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Descarga el repositorio de Hugging Face e instala los requisitos.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='git clone https://github.com/RissyRan/diffusers.git && cd diffusers && pip install . && pip install -U -r examples/text_to_image/requirements_flax.txt && pip install tensorflow==2.17.1 clu && pip install tensorboard==2.17.1'
Preparar el modelo
Entrena el modelo con un búfer preasignado de 4 GB.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} --zone=${ZONE} --project=${PROJECT_ID} --worker=all --command="
git clone https://github.com/google/maxdiffusion
cd maxdiffusion
pip3 install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip3 install -r requirements.txt
pip3 install .
pip3 install gcsfs
export LIBTPU_INIT_ARGS=''
python -m src.maxdiffusion.train src/maxdiffusion/configs/base_2_base.yml run_name=my_run \
jax_cache_dir=gs://${GCS_BUCKET_NAME} activations_dtype=bfloat16 weights_dtype=bfloat16 \
per_device_batch_size=1 precision=DEFAULT dataset_save_location=gs://${GCS_BUCKET_NAME} \
output_dir=gs://${GCS_BUCKET_NAME}/ attention=flash"
Limpieza
Elimina tu TPU, el recurso en cola y el segmento de Cloud Storage al final de la sesión.
Elimina tu TPU:
gcloud compute tpus tpu-vm delete ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --quiet
Elimina el recurso en cola:
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --quiet
Elimina el segmento de Cloud Storage:
gcloud storage rm -r gs://${GCS_BUCKET_NAME}
Resultados de las comparativas de la difusión
La secuencia de comandos de entrenamiento se ejecutó en v5litepod-4, v5litepod-16 y v5litepod-64. En la tabla siguiente se muestran los rendimientos.
Tipo de acelerador | v5litepod-4 | v5litepod-16 | v5litepod-64 |
Paso de entrenamiento | 1500 | 1500 | 1500 |
Tamaño de lote global | 32 | 64 | 128 |
Rendimiento (ejemplos/s) | 36,53 | 43,71 | 49,36 |
PyTorch/XLA
En las siguientes secciones se describen ejemplos de cómo entrenar modelos de PyTorch/XLA en TPU v5e.
Entrenar ResNet con el tiempo de ejecución de PJRT
PyTorch/XLA está migrando de XRT a PjRt a partir de PyTorch 2.0. Aquí tienes las instrucciones actualizadas para configurar v5e para cargas de trabajo de entrenamiento de PyTorch/XLA.
Configurar
Crea variables de entorno:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west4-a export ACCELERATOR_TYPE=v5litepod-16 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
Descripciones de las variables de entorno
Variable Descripción PROJECT_ID
El ID de tu proyecto Google Cloud . Usa un proyecto que ya tengas o crea uno. TPU_NAME
El nombre de la TPU. ZONE
La zona en la que se creará la VM de TPU. Para obtener más información sobre las zonas admitidas, consulta Regiones y zonas de TPU. ACCELERATOR_TYPE
El tipo de acelerador especifica la versión y el tamaño de la TPU de Cloud que quieres crear. Para obtener más información sobre los tipos de aceleradores compatibles con cada versión de TPU, consulta Versiones de TPU. RUNTIME_VERSION
La versión de software de la TPU de Cloud. SERVICE_ACCOUNT
La dirección de correo de tu cuenta de servicio. Para encontrarlo, ve a la página Cuentas de servicio de la consola Google Cloud . Por ejemplo:
tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
QUEUED_RESOURCE_ID
ID de texto asignado por el usuario de la solicitud de recurso en cola. -
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}
Podrás conectarte a tu VM de TPU mediante SSH cuando tu QueuedResource esté en estado
ACTIVE
:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Cuando el recurso en cola esté en el estado
ACTIVE
, el resultado será similar al siguiente:state: ACTIVE
Instalar dependencias específicas de Torch o XLA
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' sudo apt-get update -y sudo apt-get install libomp5 -y pip3 install mkl mkl-include pip3 install tf-nightly tb-nightly tbp-nightly pip3 install numpy sudo apt-get install libopenblas-dev -y pip install torch==PYTORCH_VERSION torchvision torch_xla[tpu]==PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'
Sustituye
PYTORCH_VERSION
por la versión de PyTorch que quieras usar.PYTORCH_VERSION
se usa para especificar la misma versión de PyTorch/XLA. Se recomienda la versión 2.6.0.Para obtener más información sobre las versiones de PyTorch y PyTorch/XLA, consulta PyTorch - Get Started (PyTorch: primeros pasos) y PyTorch/XLA releases (Lanzamientos de PyTorch/XLA).
Para obtener más información sobre cómo instalar PyTorch/XLA, consulta Instalación de PyTorch/XLA.
Entrenar el modelo ResNet
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
date
export PJRT_DEVICE=TPU
export PT_XLA_DEBUG=0
export USE_TORCH=ON
export XLA_USE_BF16=1
export LIBTPU_INIT_ARGS=--xla_jf_auto_cross_replica_sharding
export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
git clone https://github.com/pytorch/xla.git
cd xla/
git checkout release-r2.6
python3 test/test_train_mp_imagenet.py --model=resnet50 --fake_data --num_epochs=1 —num_workers=16 --log_steps=300 --batch_size=64 --profile'
Eliminar la TPU y el recurso en cola
Elimina tu TPU y el recurso en cola al final de la sesión.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Resultado de la comparativa
En la tabla siguiente se muestran los rendimientos de referencia.
Tipo de acelerador | Rendimiento (ejemplos/segundo) |
v5litepod-4 | 4240 ex/s |
v5litepod-16 | 10.810 ex/s |
v5litepod-64 | 46.154 ex/s |
Entrenar ViT en v5e
En este tutorial se explica cómo ejecutar VIT en la versión 5e mediante el repositorio de Hugging Face en PyTorch/XLA con el conjunto de datos cifar10.
Configurar
Crea variables de entorno:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west4-a export ACCELERATOR_TYPE=v5litepod-16 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
Descripciones de las variables de entorno
Variable Descripción PROJECT_ID
El ID de tu proyecto Google Cloud . Usa un proyecto que ya tengas o crea uno. TPU_NAME
El nombre de la TPU. ZONE
La zona en la que se creará la VM de TPU. Para obtener más información sobre las zonas admitidas, consulta Regiones y zonas de TPU. ACCELERATOR_TYPE
El tipo de acelerador especifica la versión y el tamaño de la TPU de Cloud que quieres crear. Para obtener más información sobre los tipos de aceleradores compatibles con cada versión de TPU, consulta Versiones de TPU. RUNTIME_VERSION
La versión de software de la TPU de Cloud. SERVICE_ACCOUNT
La dirección de correo de tu cuenta de servicio. Para encontrarlo, ve a la página Cuentas de servicio de la consola Google Cloud . Por ejemplo:
tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
QUEUED_RESOURCE_ID
ID de texto asignado por el usuario de la solicitud de recurso en cola. -
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}
Podrás conectarte a tu VM de TPU mediante SSH cuando tu recurso QueuedResource esté en el estado
ACTIVE
:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Cuando el recurso en cola esté en el estado
ACTIVE
, el resultado será similar al siguiente:state: ACTIVE
Instalar dependencias de PyTorch/XLA
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' sudo apt-get update -y sudo apt-get install libomp5 -y pip3 install mkl mkl-include pip3 install tf-nightly tb-nightly tbp-nightly pip3 install numpy sudo apt-get install libopenblas-dev -y pip install torch==PYTORCH_VERSION torchvision torch_xla[tpu]==PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html pip install jax==0.4.38 jaxlib==0.4.38 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
Sustituye
PYTORCH_VERSION
por la versión de PyTorch que quieras usar.PYTORCH_VERSION
se usa para especificar la misma versión de PyTorch/XLA. Se recomienda la versión 2.6.0.Para obtener más información sobre las versiones de PyTorch y PyTorch/XLA, consulta PyTorch - Get Started (PyTorch: primeros pasos) y PyTorch/XLA releases (Lanzamientos de PyTorch/XLA).
Para obtener más información sobre cómo instalar PyTorch/XLA, consulta Instalación de PyTorch/XLA.
Descarga el repositorio de Hugging Face e instala los requisitos.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=" git clone https://github.com/suexu1025/transformers.git vittransformers; \ cd vittransformers; \ pip3 install .; \ pip3 install datasets; \ wget https://github.com/pytorch/xla/blob/master/scripts/capture_profile.py"
Preparar el modelo
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
export PJRT_DEVICE=TPU
export PT_XLA_DEBUG=0
export USE_TORCH=ON
export TF_CPP_MIN_LOG_LEVEL=0
export XLA_USE_BF16=1
export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
cd vittransformers
python3 -u examples/pytorch/xla_spawn.py --num_cores 4 examples/pytorch/image-pretraining/run_mae.py --dataset_name=cifar10 \
--remove_unused_columns=False \
--label_names=pixel_values \
--mask_ratio=0.75 \
--norm_pix_loss=True \
--do_train=true \
--do_eval=true \
--base_learning_rate=1.5e-4 \
--lr_scheduler_type=cosine \
--weight_decay=0.05 \
--num_train_epochs=3 \
--warmup_ratio=0.05 \
--per_device_train_batch_size=8 \
--per_device_eval_batch_size=8 \
--logging_strategy=steps \
--logging_steps=30 \
--evaluation_strategy=epoch \
--save_strategy=epoch \
--load_best_model_at_end=True \
--save_total_limit=3 \
--seed=1337 \
--output_dir=MAE \
--overwrite_output_dir=true \
--logging_dir=./tensorboard-metrics \
--tpu_metrics_debug=true'
Eliminar la TPU y el recurso en cola
Elimina tu TPU y el recurso en cola al final de la sesión.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Resultado de la comparativa
En la siguiente tabla se muestran los rendimientos de referencia de los distintos tipos de aceleradores.
v5litepod-4 | v5litepod-16 | v5litepod-64 | |
Época | 3 | 3 | 3 |
Tamaño de lote global | 32 | 128 | 512 |
Rendimiento (ejemplos/s) | 201 | 657 | 2844 |