Menjalankan kode JAX pada slice Pod TPU

Setelah menjalankan kode JAX di satu board TPU, Anda dapat meningkatkan skala kode dengan menjalankannya di slice Pod TPU. Slice Pod TPU adalah beberapa board TPU yang terhubung satu sama lain melalui koneksi jaringan khusus berkecepatan tinggi. Dokumen ini merupakan pengantar untuk menjalankan kode JAX pada slice Pod TPU. Untuk mengetahui informasi yang lebih mendalam, lihat Menggunakan JAX di lingkungan multi-host dan multi-proses.

Jika ingin menggunakan NFS yang terpasang untuk penyimpanan data, Anda harus menetapkan Login OS untuk semua VM TPU dalam slice Pod. Untuk informasi selengkapnya, lihat Menggunakan NFS untuk penyimpanan data.

Membuat slice Pod TPU

Sebelum menjalankan perintah dalam dokumen ini, pastikan Anda telah mengikuti petunjuk di Menyiapkan akun dan project Cloud TPU. Jalankan perintah berikut di mesin lokal Anda.

Buat slice Pod TPU menggunakan perintah gcloud. Misalnya, untuk membuat slice Pod v4-32, gunakan perintah berikut:

$ gcloud compute tpus tpu-vm create tpu-name  \
  --zone=us-central2-b \
  --accelerator-type=v4-32  \
  --version=tpu-ubuntu2204-base 

Menginstal JAX di slice Pod

Setelah membuat slice Pod TPU, Anda harus menginstal JAX di semua host di slice Pod TPU. Anda dapat menginstal JAX di semua host dengan satu perintah menggunakan opsi --worker=all:

  gcloud compute tpus tpu-vm ssh tpu-name \
  --zone=us-central2-b --worker=all --command="pip install \
  --upgrade 'jax[tpu]>0.3.0' \
  -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"

Menjalankan kode JAX pada slice Pod

Untuk menjalankan kode JAX di slice Pod TPU, Anda harus menjalankan kode pada setiap host di slice Pod TPU. Panggilan jax.device_count() akan berhenti merespons hingga dipanggil pada setiap host di slice Pod. Contoh berikut mengilustrasikan cara menjalankan penghitungan JAX sederhana pada slice Pod TPU.

Menyiapkan kode

Anda memerlukan gcloud versi >= 344.0.0 (untuk perintah scp). Gunakan gcloud --version untuk memeriksa versi gcloud, dan menjalankan gcloud components upgrade, jika diperlukan.

Buat file bernama example.py dengan kode berikut:

# The following code snippet will be run on all TPU hosts
import jax

# The total number of TPU cores in the Pod
device_count = jax.device_count()

# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()

# The psum is performed over all mapped devices across the Pod
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
    print('global device count:', jax.device_count())
    print('local device count:', jax.local_device_count())
    print('pmap result:', r)

Salin example.py ke semua VM worker TPU di slice Pod

$ gcloud compute tpus tpu-vm scp example.py tpu-name: \
  --worker=all \
  --zone=us-central2-b

Jika belum pernah menggunakan perintah scp, Anda mungkin melihat error seperti berikut:

ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH
agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try
again.

Untuk mengatasi error ini, jalankan perintah ssh-add seperti yang ditampilkan dalam pesan error dan jalankan kembali perintah tersebut.

Menjalankan kode di slice Pod

Luncurkan program example.py di setiap VM:

$ gcloud compute tpus tpu-vm ssh tpu-name \
  --zone=us-central2-b \
  --worker=all \
  --command="python3 example.py"

Output (diproduksi dengan slice Pod v4-32):

global device count: 16
local device count: 4
pmap result: [16. 16. 16. 16.]

Pembersihan

Setelah selesai, Anda dapat melepaskan resource VM TPU menggunakan perintah gcloud:

$ gcloud compute tpus tpu-vm delete tpu-name \
  --zone=us-central2-b