JAX-Code auf TPU-Slices ausführen

Bevor Sie die Befehle in diesem Dokument ausführen, müssen Sie den Anleitungen in Konto und Cloud TPU-Projekt einrichten gefolgt sein.

Nachdem Sie den JAX-Code auf einem einzelnen TPU-Board ausgeführt haben, können Sie den Code skalieren, indem Sie ihn auf einem TPU-Slice ausführen. TPU-Slices sind mehrere TPU-Boards, die über dedizierte Hochgeschwindigkeits-Netzwerkverbindungen miteinander verbunden sind. Dieses Dokument bietet eine Einführung zum Ausführen von JAX-Code auf TPU-Slices. Ausführlichere Informationen finden Sie unter JAX in Umgebungen mit mehreren Hosts und mehreren Prozessen verwenden.

Cloud TPU-Speichereinheit erstellen

  1. Erstellen Sie einige Umgebungsvariablen:

    export PROJECT_ID=your-project
    export ACCELERATOR_TYPE=v5litepod-32
    export ZONE=europe-west4-b
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export TPU_NAME=your-tpu-name

    Beschreibungen von Umgebungsvariablen

    PROJECT_ID
    Ihre Google Cloud Projekt-ID.
    ACCELERATOR_TYPE
    Mit dem Beschleunigertyp geben Sie die Version und Größe der Cloud TPU an, die Sie erstellen möchten. Weitere Informationen zu den unterstützten Beschleunigertypen für jede TPU-Version finden Sie unter TPU-Versionen.
    ZONE
    Die Zone, in der Sie die Cloud TPU erstellen möchten.
    RUNTIME_VERSION
    Die Version der Cloud TPU-Laufzeit.
    TPU_NAME
    Der vom Nutzer zugewiesene Name für Ihre Cloud TPU.
  2. Erstellen Sie mit dem Befehl gcloud ein TPU-Slice. Verwenden Sie beispielsweise den folgenden Befehl, um ein v5litepod-32-Slice zu erstellen:

    $ gcloud compute tpus tpu-vm create ${TPU_NAME}  \
    --zone=${ZONE} \
    --project=${PROJECT_ID} \
    --accelerator-type=${ACCELERATOR_TYPE}  \
    --version=${RUNTIME_VERSION} 

JAX auf Ihrem Slice installieren

Nachdem Sie das TPU-Slice erstellt haben, müssen Sie JAX auf allen Hosts im TPU-Slice installieren. Verwenden Sie dazu den Befehl gcloud compute tpus tpu-vm ssh mit den Parametern --worker=all und --commamnd.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --worker=all \
  --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

JAX-Code auf dem Slice ausführen

Wenn Sie JAX-Code auf einem TPU-Slice ausführen möchten, müssen Sie den Code auf jedem Host im TPU-Slice ausführen. Der jax.device_count()-Aufruf reagiert nicht mehr, bis er auf jedem Host im Slice aufgerufen wird. Im folgenden Beispiel wird gezeigt, wie eine JAX-Berechnung auf einem TPU-Speicherbereich ausgeführt wird.

Code vorbereiten

Sie benötigen die Version gcloud >= 344.0.0 (für den Befehl scp). Prüfen Sie mit gcloud --version Ihre gcloud-Version und führen Sie bei Bedarf gcloud components upgrade aus.

Erstellen Sie eine Datei mit dem Namen example.py und folgendem Code:


import jax

# The total number of TPU cores in the slice
device_count = jax.device_count()

# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()

# The psum is performed over all mapped devices across the slice
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
    print('global device count:', jax.device_count())
    print('local device count:', jax.local_device_count())
    print('pmap result:', r)

example.py in alle TPU-Worker-VMs im Slice kopieren

$ gcloud compute tpus tpu-vm scp ./example.py ${TPU_NAME}: \
  --worker=all \
  --zone=${ZONE} \
  --project=${PROJECT_ID}

Wenn Sie den Befehl scp noch nicht verwendet haben, wird möglicherweise ein Fehler wie der folgende angezeigt:

ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH
agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try
again.

Führen Sie den Befehl ssh-add wie in der Fehlermeldung angezeigt aus und führen Sie den Befehl noch einmal aus, um den Fehler zu beheben.

Code auf dem Slice ausführen

Starten Sie auf jeder VM das Programm example.py:

$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --worker=all \
  --command="python3 ./example.py"

Ausgabe (mit einem v5litepod-32-Slice erzeugt)

global device count: 32
local device count: 4
pmap result: [32. 32. 32. 32.]

Bereinigen

Wenn Sie mit Ihrer TPU-VM fertig sind, führen Sie die folgenden Schritte aus, um Ihre Ressourcen zu bereinigen.

  1. Löschen Sie Ihre Cloud TPU- und Compute Engine-Ressourcen.

    $ gcloud compute tpus tpu-vm delete ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID}
  2. Prüfen Sie, ob die Ressourcen gelöscht wurden. Führen Sie dazu gcloud compute tpus execution-groups list aus. Der Löschvorgang kann einige Minuten dauern. Die Ausgabe des folgenden Befehls sollte keine der in dieser Anleitung erstellten Ressourcen enthalten:

    $ gcloud compute tpus tpu-vm list --zone=${ZONE} \
    --project=${PROJECT_ID}