Exécuter du code JAX sur des tranches TPU

Avant d'exécuter les commandes de ce document, assurez-vous d'avoir suivi les instructions de la section Configurer un compte et un projet Cloud TPU.

Une fois que votre code JAX s'exécute sur une seule carte TPU, vous pouvez augmenter la capacité en l'exécutant sur une tranche TPU. Les tranches de TPU sont des cartes de TPU interconnectées sur des connexions réseau haut débit dédiées. Ce document est une introduction à l'exécution de code JAX sur des tranches de TPU. Pour des informations plus détaillées, consultez la section Utiliser JAX dans des environnements multihôtes et multiprocessus.

Créer une tranche Cloud TPU

  1. Créez des variables d'environnement:

    export PROJECT_ID=your-project
    export ACCELERATOR_TYPE=v5litepod-32
    export ZONE=europe-west4-b
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export TPU_NAME=your-tpu-name

    Descriptions des variables d'environnement

    PROJECT_ID
    L'ID de votre Google Cloud projet.
    ACCELERATOR_TYPE
    Le type d'accélérateur spécifie la version et la taille du Cloud TPU que vous souhaitez créer. Pour en savoir plus sur les types d'accélérateurs compatibles avec chaque version de TPU, consultez la section Versions de TPU.
    ZONE
    Zone dans laquelle vous prévoyez de créer votre Cloud TPU.
    RUNTIME_VERSION
    Version du runtime Cloud TPU.
    TPU_NAME
    Nom attribué par l'utilisateur à votre Cloud TPU.
  2. Créez une tranche TPU à l'aide de la commande gcloud. Par exemple, pour créer une tranche v5litepod-32, utilisez la commande suivante:

    $ gcloud compute tpus tpu-vm create ${TPU_NAME}  \
    --zone=${ZONE} \
    --project=${PROJECT_ID} \
    --accelerator-type=${ACCELERATOR_TYPE}  \
    --version=${RUNTIME_VERSION} 

Installer JAX sur votre tranche

Après avoir créé la tranche de TPU, vous devez installer JAX sur tous les hôtes de la tranche de TPU. Pour ce faire, utilisez la commande gcloud compute tpus tpu-vm ssh avec les paramètres --worker=all et --commamnd.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --worker=all \
  --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

Exécuter du code JAX sur la tranche

Pour exécuter du code JAX sur une tranche TPU, vous devez exécuter le code sur chaque hôte de la tranche TPU. L'appel jax.device_count() cesse de répondre jusqu'à ce qu'il soit appelé sur chaque hôte de la tranche. L'exemple suivant montre comment exécuter un calcul JAX sur une tranche de TPU.

Préparer le code

Vous avez besoin de la version gcloud >= 344.0.0 (pour la commande scp). Utilisez gcloud --version pour vérifier votre version de gcloud et exécutez gcloud components upgrade, si nécessaire.

Créez un fichier nommé example.py avec le code suivant:


import jax

# The total number of TPU cores in the slice
device_count = jax.device_count()

# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()

# The psum is performed over all mapped devices across the slice
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
    print('global device count:', jax.device_count())
    print('local device count:', jax.local_device_count())
    print('pmap result:', r)

Copiez example.py sur toutes les VM de nœuds de calcul TPU de la tranche.

$ gcloud compute tpus tpu-vm scp ./example.py ${TPU_NAME}: \
  --worker=all \
  --zone=${ZONE} \
  --project=${PROJECT_ID}

Si vous n'avez jamais utilisé la commande scp, un message d'erreur semblable à celui-ci peut s'afficher:

ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH
agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try
again.

Pour résoudre l'erreur, exécutez la commande ssh-add comme indiqué dans le message d'erreur, puis réexécutez la commande.

Exécuter le code sur la tranche

Lancez le programme example.py sur chaque VM :

$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --worker=all \
  --command="python3 ./example.py"

Résultat (généré avec une tranche v5litepod-32):

global device count: 32
local device count: 4
pmap result: [32. 32. 32. 32.]

Effectuer un nettoyage

Lorsque vous avez fini d'utiliser votre VM TPU, procédez comme suit pour nettoyer vos ressources.

  1. Supprimez vos ressources Cloud TPU et Compute Engine.

    $ gcloud compute tpus tpu-vm delete ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID}
  2. Vérifiez que les ressources ont été supprimées en exécutant la commande gcloud compute tpus execution-groups list. La suppression peut prendre plusieurs minutes. Le résultat de la commande suivante ne doit inclure aucune des ressources créées dans ce tutoriel:

    $ gcloud compute tpus tpu-vm list --zone=${ZONE} \
    --project=${PROJECT_ID}