Ejecutar código JAX en slices de TPU

Antes de ejecutar los comandos de este documento, asegúrate de haber seguido las instrucciones que se indican en Configurar una cuenta y un proyecto de Cloud TPU.

Una vez que hayas ejecutado tu código JAX en una sola placa de TPU, podrás ampliarlo ejecutándolo en un sector de TPU. Las porciones de TPU son varias placas de TPU conectadas entre sí a través de conexiones de red específicas de alta velocidad. Este documento es una introducción a la ejecución de código JAX en slices de TPU. Para obtener información más detallada, consulta Usar JAX en entornos de varios hosts y varios procesos.

Crear un slice de TPU de Cloud

  1. Crea algunas variables de entorno:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=europe-west4-b
    export ACCELERATOR_TYPE=v5litepod-32
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite

    Descripciones de las variables de entorno

    Variable Descripción
    PROJECT_ID El ID de tu proyecto Google Cloud . Usa un proyecto que ya tengas o crea uno.
    TPU_NAME El nombre de la TPU.
    ZONE La zona en la que se creará la VM de TPU. Para obtener más información sobre las zonas admitidas, consulta Regiones y zonas de TPU.
    ACCELERATOR_TYPE El tipo de acelerador especifica la versión y el tamaño de la TPU de Cloud que quieres crear. Para obtener más información sobre los tipos de aceleradores compatibles con cada versión de TPU, consulta Versiones de TPU.
    RUNTIME_VERSION La versión de software de la TPU de Cloud.

  2. Crea una porción de TPU con el comando gcloud. Por ejemplo, para crear una porción de v5litepod-32, usa el siguiente comando:

    $ gcloud compute tpus tpu-vm create ${TPU_NAME}  \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        --accelerator-type=${ACCELERATOR_TYPE}  \
        --version=${RUNTIME_VERSION} 

Instalar JAX en tu slice

Después de crear el segmento de TPU, debes instalar JAX en todos los hosts del segmento de TPU. Para ello, puedes usar el comando gcloud compute tpus tpu-vm ssh con los parámetros --worker=all y --commamnd.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --worker=all \
  --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

Ejecutar código JAX en el sector

Para ejecutar código JAX en un segmento de TPU, debes ejecutar el código en cada host del segmento de TPU. La llamada jax.device_count() deja de responder hasta que se llama a cada host de la porción. En el siguiente ejemplo se muestra cómo ejecutar un cálculo de JAX en una porción de TPU.

Preparar el código

Necesitas la versión gcloud >= 344.0.0 (para el comando scp). Usa gcloud --version para comprobar la versión de gcloud y ejecuta gcloud components upgrade si es necesario.

Crea un archivo llamado example.py con el siguiente código:


import jax

# The total number of TPU cores in the slice
device_count = jax.device_count()

# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()

# The psum is performed over all mapped devices across the slice
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
    print('global device count:', jax.device_count())
    print('local device count:', jax.local_device_count())
    print('pmap result:', r)

Copia example.py en todas las VMs de trabajador de TPU del segmento

$ gcloud compute tpus tpu-vm scp ./example.py ${TPU_NAME}: \
  --worker=all \
  --zone=${ZONE} \
  --project=${PROJECT_ID}

Si no has usado el comando scp anteriormente, es posible que veas un error similar al siguiente:

ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH
agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try
again.

Para solucionar el error, ejecute el comando ssh-add tal como se muestra en el mensaje de error y vuelva a ejecutar el comando.

Ejecutar el código en el slice

Inicia el programa example.py en cada VM:

$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --worker=all \
  --command="python3 ./example.py"

Salida (producida con un segmento v5litepod-32):

global device count: 32
local device count: 4
pmap result: [32. 32. 32. 32.]

Limpieza

Cuando hayas terminado de usar tu VM de TPU, sigue estos pasos para limpiar tus recursos.

  1. Elimina los recursos de TPU de Cloud y de Compute Engine.

    $ gcloud compute tpus tpu-vm delete ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID}
  2. Para comprobar que los recursos se han eliminado, ejecuta gcloud compute tpus execution-groups list. El proceso de eliminación puede tardar varios minutos. El resultado del siguiente comando no debe incluir ninguno de los recursos creados en este tutorial:

    $ gcloud compute tpus tpu-vm list --zone=${ZONE} \
    --project=${PROJECT_ID}