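Run a calculation on a Cloud TPU VM using JAX
This document provides a brief introduction to working with JAX and Cloud TPU.
Note: This example shows how to run code on a v5litepod-8 (v5e) TPU, which is a single-host TPU. Single-host TPUs have only one TPU VM. To run code on TPUs with more than one TPU VM (for example, v5litepod-16 or larger), see Run JAX code on Cloud TPU slices.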
Before you begin
Before running the commands in this document, you must create a Google Cloud account, install the Google Cloud CLI, and configure the gcloud command. For
more information, see Set up the Cloud TPU environment.
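Create a Cloud TPU VM using gcloud
Define some environment variables to make the commands easier to use:
export PROJECT_ID=your-project-id
export TPU_NAME=your-tpu-name
export ZONE=us-east5-a
export ACCELERATOR_TYPE=v5litepod-8
export RUNTIME_VERSION=v2-alpha-tpuv5-lite
Environment variable descriptions
PROJECT_ID
Your Google Cloud project ID. Use an existing project or create a new one.
TPU_NAME
The name of the TPU.
ZONE
The zone in which to create the TPU VM. For more information about supported zones, see
TPU regions and zones.
ACCELERATOR_TYPE
The accelerator type specifies the version and size of the Cloud TPU you want to create. For more information about the accelerator types supported for each TPU version, see
TPU versions.
Create your TPU VM by running the following command from a Cloud Shell or from a terminal on your computer where the Google Cloud CLI is installed:
$ gcloud compute tpus tpu-vm create $TPU_NAME \
    --project=$PROJECT_ID \
    --zone=$ZONE \
    --accelerator-type=$ACCELERATOR_TYPE \
    --version=$RUNTIME_VERSION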
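Connect to your Cloud TPU VM
Connect to your TPU VM over SSH by using the following command:
$ gcloud compute tpus tpu-vm ssh $TPU_NAME \
    --project=$PROJECT_ID \
    --zone=$ZONE
If you fail to connect to a TPU VM using SSH, it might be because the TPU VM doesn't have an external IP address. To access a TPU VM without an external IP
address, follow the instructions in Connect to a TPU VM without a public IP
address.
Install JAX on your Cloud TPU VM
(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html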
System check
Verify that JAX can access the TPU and can run basic operations:
Start the Python 3 interpreter:
(vm)$ python3
>>> import jax
Display the number of TPU cores available:
>>> jax.device_count()
The number of TPU cores is displayed. The number depends on the TPU version you are using. For more information, see TPU versions.
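The device count can also be inspected per device from a script rather than the interactive interpreter. The following is a minimal sketch, assuming JAX is already installed on the VM; the variable names are illustrative and not part of the original walkthrough.

```python
import jax

# Backend that JAX selected: expected to be "tpu" on a working TPU VM,
# and "cpu" if no accelerator was found.
backend = jax.default_backend()
print("backend:", backend)

# Total number of devices visible to JAX, the same value the
# interactive check displays.
device_count = jax.device_count()
print("device count:", device_count)

# One line per device: index, platform, and hardware description.
devices = jax.devices()
for device in devices:
    print(device.id, device.platform, device.device_kind)
```

If the backend prints as cpu on a TPU VM, the TPU-enabled JAX installation is the first thing to recheck.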
Perform a calculation
>>> jax.numpy.add(1, 1)
The result of the numpy add is displayed:
Array(2, dtype=int32, weak_type=True)
Exit the Python interpreter
>>> exit()
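Running JAX code on a TPU VM
You can now run any JAX code you want. The Flax examples (https://github.com/google/flax/tree/master/examples)
are a great place to start with running standard ML models in JAX. For example,
to train a basic MNIST convolutional network:
Install Flax examples dependencies:
(vm)$ pip install --upgrade clu
(vm)$ pip install tensorflow
(vm)$ pip install tensorflow_datasets
Install Flax:
(vm)$ git clone https://github.com/google/flax.git
(vm)$ pip install --user flax
Run the Flax MNIST training script:
(vm)$ cd flax/examples/mnist
(vm)$ python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=5
The script downloads the dataset and starts training. The script output should look like this:
I0214 18:00:50.660087 140369022753856 train.py:146] epoch: 1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88
I0214 18:00:52.015867 140369022753856 train.py:146] epoch: 2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72
I0214 18:00:53.377511 140369022753856 train.py:146] epoch: 3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04
I0214 18:00:54.727168 140369022753856 train.py:146] epoch: 4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15
I0214 18:00:56.082807 140369022753856 train.py:146] epoch: 5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18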
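Clean up
To avoid incurring charges to your Google Cloud account for the resources used on this page, follow these steps.
Disconnect from the Cloud TPU instance, if you have not already done so:
(vm)$ exit
Your prompt should now be username@projectname, showing you are in the Cloud Shell.
Delete your Cloud TPU:
$ gcloud compute tpus tpu-vm delete $TPU_NAME \
    --project=$PROJECT_ID \
    --zone=$ZONE
Verify that the resources have been deleted by running the following command. Make
sure your TPU is no longer listed. The deletion might take several minutes.
$ gcloud compute tpus tpu-vm list --zone=$ZONE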
Performance notes
Here are a few important details that are particularly relevant to using TPUs in
JAX.
Padding
One of the most common causes of slow performance on TPUs is introducing
inadvertent padding:
Arrays in the Cloud TPU are tiled. This entails padding one of the
dimensions to a multiple of 8, and a different dimension to a multiple of
128.
The matrix multiplication unit performs best with pairs of large matrices
that minimize the need for padding.
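To get a feel for how much padding a given shape might incur, the tiling rule above can be sketched in plain Python. This is only an illustration: it assumes that the second-to-last dimension is padded to a multiple of 8 and the last dimension to a multiple of 128, which the text does not specify, and the helper names are hypothetical. The actual layout is chosen by the compiler and may differ.

```python
import math


def round_up(n, multiple):
    """Round n up to the nearest multiple of `multiple`."""
    return -(-n // multiple) * multiple


def padded_shape(shape, sublane=8, lane=128):
    """Shape after padding, ASSUMING the second-to-last dimension is padded
    to a multiple of `sublane` and the last to a multiple of `lane`."""
    if len(shape) < 2:
        raise ValueError("expected at least two dimensions")
    *leading, rows, cols = shape
    return (*leading, round_up(rows, sublane), round_up(cols, lane))


def utilization(shape):
    """Fraction of the padded array that holds real data."""
    return math.prod(shape) / math.prod(padded_shape(shape))


for shape in [(100, 100), (128, 256), (8, 129)]:
    print(shape, "->", padded_shape(shape), f"useful fraction: {utilization(shape):.2f}")
```

Under these assumptions a (100, 100) array pads to (104, 128), so roughly a quarter of the stored elements are padding, while (128, 256) needs none.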
bfloat16 dtype
By default, matrix multiplication in JAX on TPUs uses bfloat16
with float32 accumulation. This can be controlled with the precision argument on
relevant jax.numpy function calls (matmul, dot, einsum, etc.). In particular:
precision=jax.lax.Precision.DEFAULT: uses mixed bfloat16
precision (fastest)
precision=jax.lax.Precision.HIGH: uses multiple MXU passes to
achieve higher precision
precision=jax.lax.Precision.HIGHEST: uses even more MXU passes
to achieve full float32 precision
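As an illustration of the precision argument, the sketch below multiplies the same two float32 matrices at the DEFAULT and HIGHEST settings and reports the largest difference. The matrix shapes and variable names are assumptions chosen for the example (the dimensions are multiples of 8 and 128 to avoid padding). On a CPU backend the two results may be identical; the difference is expected to show up on TPU.

```python
import jax
import jax.numpy as jnp

# Two random float32 matrices; the shapes are arbitrary example values.
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(key_a, (256, 512), dtype=jnp.float32)
b = jax.random.normal(key_b, (512, 128), dtype=jnp.float32)

# Fastest setting: mixed bfloat16 precision on TPU.
fast = jnp.matmul(a, b, precision=jax.lax.Precision.DEFAULT)

# Slowest setting: full float32 precision.
exact = jnp.matmul(a, b, precision=jax.lax.Precision.HIGHEST)

# Largest elementwise gap between the two results.
max_abs_diff = float(jnp.max(jnp.abs(fast - exact)))
print("max abs difference:", max_abs_diff)
```

Whether the extra passes are worth the cost depends on the workload, so comparing the two settings on a representative input is a reasonable first step.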
JAX also adds the bfloat16 dtype, which you can use to explicitly cast arrays to
bfloat16. For example, jax.numpy.array(x, dtype=jax.numpy.bfloat16).
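A short sketch of the explicit cast follows; the input values are arbitrary examples. It also shows `astype`, which is an equivalent way to perform the same conversion.

```python
import jax.numpy as jnp

# Example float32 input; the values are arbitrary.
x = jnp.array([1.0, 1.001, 3.14159265], dtype=jnp.float32)

# Explicit cast, as described in the text.
x_bf16 = jnp.array(x, dtype=jnp.bfloat16)

# Equivalent cast using astype.
x_bf16_alt = x.astype(jnp.bfloat16)

print(x_bf16.dtype, "bytes per element:", x_bf16.dtype.itemsize)
print(x_bf16)
```

bfloat16 keeps the exponent range of float32 but only about 8 bits of significand, so values that differ only in low-order digits, such as 1.0 and 1.001, become indistinguishable after the cast.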
What's next
For more information about Cloud TPU, see:
Run JAX code on TPU slices
Manage TPUs
Cloud TPU System architecture
Last updated 2025-08-18 UTC.