Entrenamiento con aceleradores de TPU

Vertex AI permite el entrenamiento con diversos frameworks y bibliotecas mediante una VM de TPU. Al configurar los recursos de computación, puedes especificar máquinas virtuales TPU v2, TPU v3 o TPU v5e. TPU v5e admite JAX 0.4.6+, TensorFlow 2.15+ y PyTorch 2.1+. TPU v6e admite Python 3.10+, JAX 0.4.37+ y PyTorch 2.1+ con PJRT como tiempo de ejecución predeterminado.

Para obtener información sobre cómo configurar máquinas virtuales de TPU para el entrenamiento personalizado, consulta Configurar recursos de computación para el entrenamiento personalizado.

Entrenamiento de TensorFlow

Contenedor prediseñado

Usa un contenedor de entrenamiento prediseñado que admita TPUs y crea una aplicación de entrenamiento de Python.

Contenedor personalizado

Usa un contenedor personalizado en el que hayas instalado versiones de tensorflow y libtpu creadas específicamente para las VMs de TPU. El servicio TPU de Cloud se encarga del mantenimiento de estas bibliotecas, que se enumeran en la documentación sobre las configuraciones de TPU compatibles.

Selecciona la versión tensorflow que quieras y la biblioteca libtpu correspondiente. A continuación, instálalos en la imagen de contenedor Docker al compilar el contenedor.

Por ejemplo, si quieres usar TensorFlow 2.15, incluye las siguientes instrucciones en tu Dockerfile:

  # Download and install `tensorflow`.
  RUN pip install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/tensorflow/tf-2.15.0/tensorflow-2.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

  # Download and install `libtpu`.
  # You must save `libtpu.so` in the '/lib' directory of the container image.
  RUN curl -L https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/libtpu/1.9.0/libtpu.so -o /lib/libtpu.so

  # TensorFlow training on TPU v5e requires the PJRT runtime. To enable the PJRT
  # runtime, configure the following environment variables in your Dockerfile.
  # For details, see https://cloud.google.com/tpu/docs/runtimes#tf-pjrt-support.
  # ENV NEXT_PLUGGABLE_DEVICE_USE_C_API=true
  # ENV TF_PLUGGABLE_DEVICE_LIBRARY_PATH=/lib/libtpu.so

pod de TPU

tensorflow entrenar a un TPU Pod requiere una configuración adicional en el contenedor de entrenamiento. Vertex AI mantiene una imagen de Docker base que gestiona la configuración inicial.

URIs de imagen Versión de Python
  • us-docker.pkg.dev/vertex-ai/training/tf-tpu-pod-base-cp38:latest
  • europe-docker.pkg.dev/vertex-ai/training/tf-tpu-pod-base-cp38:latest
  • asia-docker.pkg.dev/vertex-ai/training/tf-tpu-pod-base-cp38:latest
Python 3.8
  • us-docker.pkg.dev/vertex-ai/training/tf-tpu-pod-base-cp310:latest
  • europe-docker.pkg.dev/vertex-ai/training/tf-tpu-pod-base-cp310:latest
  • asia-docker.pkg.dev/vertex-ai/training/tf-tpu-pod-base-cp310:latest
Python 3.10

Para crear un contenedor personalizado, sigue estos pasos:

  1. Elige la imagen base de la versión de Python que quieras. Las ruedas de TensorFlow para TensorFlow 2.12 y versiones anteriores son compatibles con Python 3.8. TensorFlow 2.13 y versiones posteriores son compatibles con Python 3.10 o versiones posteriores. Para ver las versiones específicas de TensorFlow, consulta las configuraciones de Cloud TPU.
  2. Amplía la imagen con el código de entrenador y el comando de inicio.
# Specifies base image and tag
FROM us-docker.pkg.dev/vertex-ai/training/tf-tpu-pod-base-cp310:latest
WORKDIR /root

# Download and install `tensorflow`.
RUN pip install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/tensorflow/tf-2.15.0/tensorflow-2.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

# Download and install `libtpu`.
# You must save `libtpu.so` in the '/lib' directory of the container image.
RUN curl -L https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/libtpu/1.9.0/libtpu.so -o /lib/libtpu.so

# Copies the trainer code to the docker image.
COPY your-path-to/model.py /root/model.py
COPY your-path-to/trainer.py /root/trainer.py

# The base image is setup so that it runs the CMD that you provide.
# You can provide CMD inside the Dockerfile like as follows.
# Use CMD, not ENTRYPOINT, to avoid accidentally overriding the pod base image's ENTRYPOINT.
# Alternatively, you can pass it as an `args` value in ContainerSpec:
# (https://cloud.google.com/vertex-ai/docs/reference/rest/v1/CustomJobSpec#containerspec)
CMD ["python3", "trainer.py"]

Entrenamiento de PyTorch

Puedes usar contenedores precompilados o personalizados para PyTorch al entrenar con TPUs.

Contenedor prediseñado

Usa un contenedor de entrenamiento prediseñado que admita las TPUs y crea una aplicación de entrenamiento de Python.

Contenedor personalizado

Usa un contenedor personalizado en el que hayas instalado la biblioteca PyTorch.

Por ejemplo, tu archivo Dockerfile podría tener el siguiente aspecto:

FROM python:3.10

# v5e, v6e specific requirement - enable PJRT runtime
ENV PJRT_DEVICE=TPU

# install pytorch and torch_xla
RUN pip3 install torch~=2.1.0 torchvision torch_xla[tpu]~=2.1.0
 -f https://storage.googleapis.com/libtpu-releases/index.html

# Add your artifacts here
COPY trainer.py .

# Run the trainer code
CMD ["python3", "trainer.py"]

pod de TPU

El entrenamiento se ejecuta en todos los hosts del pod de TPU (consulta Ejecutar código de PyTorch en sectores de pods de TPU).

Vertex AI espera una respuesta de todos los hosts para decidir si se ha completado el trabajo.

Entrenamiento de JAX

Contenedor prediseñado

No hay contenedores prediseñados para JAX.

Contenedor personalizado

Usa un contenedor personalizado en el que hayas instalado la biblioteca JAX.

Por ejemplo, tu archivo Dockerfile podría tener el siguiente aspecto:

# Install JAX.
RUN pip install 'jax[tpu]>=0.4.6' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

# Add your artifacts here
COPY trainer.py trainer.py

# Set an entrypoint.
ENTRYPOINT ["python3", "trainer.py"]

pod de TPU

El entrenamiento se ejecuta en todos los hosts del pod de TPU (consulta Ejecutar código JAX en sectores de pods de TPUs).

Vertex AI monitoriza el primer host del Pod de TPU para decidir si la tarea se ha completado. Puedes usar el siguiente fragmento de código para asegurarte de que todos los hosts se cierren al mismo tiempo:

# Your training logic
...

if jax.process_count() > 1:
  # Make sure all hosts stay up until the end of main.
  x = jnp.ones([jax.local_device_count()])
  x = jax.device_get(jax.pmap(lambda x: jax.lax.psum(x, 'i'), 'i')(x))
  assert x[0] == jax.device_count()

Variables de entorno

En la siguiente tabla se detallan las variables de entorno que puedes usar en el contenedor:

Nombre Valor
TPU_NODE_NAME my-first-tpu-node
TPU_CONFIG {"project": "tenant-project-xyz", "zone": "us-central1-b", "tpu_node_name": "my-first-tpu-node"}

Cuenta de servicio personalizada

Se puede usar una cuenta de servicio personalizada para el entrenamiento con TPU. Para saber cómo usar una cuenta de servicio personalizada, consulta la página sobre cómo usar una cuenta de servicio personalizada.

IP privada (emparejamiento de redes de VPC) para el entrenamiento

Se puede usar una IP privada para el entrenamiento con TPU. Consulta la página sobre cómo usar una IP privada para el entrenamiento personalizado.

Controles de Servicio de VPC

Los proyectos con Controles de Servicio de VPC habilitados pueden enviar trabajos de entrenamiento de TPU.

Limitaciones

Se aplican las siguientes limitaciones cuando entrenas con una VM de TPU:

Tipos de TPU

Consulta los tipos de TPU para obtener más información sobre los aceleradores de TPU, como el límite de memoria.