TPU Monitoring Library

Dapatkan insight mendalam tentang performa dan perilaku hardware Cloud TPU Anda dengan kemampuan pemantauan TPU tingkat lanjut, yang dibangun langsung di atas lapisan software dasar, LibTPU. Meskipun LibTPU mencakup driver, library jaringan, compiler XLA, dan runtime TPU untuk berinteraksi dengan TPU, fokus dokumen ini adalah TPU Monitoring Library.

TPU Monitoring Library menyediakan:

  • Kemampuan pengamatan yang komprehensif: Dapatkan akses ke Telemetry API dan rangkaian metrik. Dengan demikian, Anda dapat memperoleh insight mendetail tentang performa operasional dan perilaku spesifik TPU Anda.

  • Toolkit diagnostik: Menyediakan SDK dan antarmuka command line (CLI) yang dirancang untuk memungkinkan proses debug dan analisis performa mendalam pada resource TPU Anda.

Fitur pemantauan ini dirancang sebagai solusi tingkat teratas yang ditujukan untuk pelanggan, yang memberi Anda alat penting untuk mengoptimalkan workload TPU secara efektif.

TPU Monitoring Library memberi Anda informasi mendetail tentang performa workload machine learning di hardware TPU. Alat ini dirancang untuk membantu Anda memahami penggunaan TPU, mengidentifikasi bottleneck, dan men-debug masalah performa. Metrik ini memberi Anda informasi yang lebih mendetail daripada metrik gangguan, metrik goodput, dan metrik lainnya.

Mulai menggunakan TPU Monitoring Library

Insight yang efektif ini dapat diakses dengan mudah. Fungsi pemantauan TPU terintegrasi dengan LibTPU SDK, sehingga fungsi ini disertakan saat Anda menginstal LibTPU.

Instal LibTPU

pip install libtpu

Atau, update LibTPU dikoordinasikan dengan rilis JAX, yang berarti bahwa saat Anda menginstal rilis JAX terbaru (dirilis setiap bulan), Anda biasanya akan diarahkan ke versi LibTPU yang kompatibel terbaru dan fiturnya.

Menginstal JAX

pip install -U "jax[tpu]"

Untuk pengguna PyTorch, menginstal PyTorch/XLA akan memberikan fungsi pemantauan TPU dan LibTPU terbaru.

Menginstal PyTorch/XLA

pip install torch~=2.6.0 'torch_xla[tpu]~=2.6.0' \
  -f https://storage.googleapis.com/libtpu-releases/index.html \
  -f https://storage.googleapis.com/libtpu-wheels/index.html

  # Optional: if you're using custom kernels, install pallas dependencies
pip install 'torch_xla[pallas]' \
  -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
  -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

Untuk mengetahui informasi selengkapnya tentang cara menginstal PyTorch/XLA, lihat Penginstalan di repositori GitHub PyTorch/XLA.

Mengimpor library di Python

Untuk mulai menggunakan TPU Monitoring Library, Anda perlu mengimpor modul libtpu dalam kode Python Anda.

from libtpu.sdk import tpumonitoring

Mencantumkan semua fungsi yang didukung

Mencantumkan semua nama metrik dan fungsi yang didukungnya:


from libtpu.sdk import tpumonitoring

tpumonitoring.help()
" libtpu.sdk.monitoring.help():
      List all supported functionality.

  libtpu.sdk.monitoring.list_support_metrics()
      List support metric names in the list of str format.

  libtpu.sdk.monitoring.get_metric(metric_name:str)
      Get metric data with metric name. It represents the snapshot mode.
      The metric data is a object with `description()` and `data()` methods,
      where the `description()` returns a string describe the format of data
      and data unit, `data()` returns the metric data in the list in str format.
"

Metrik yang didukung

Contoh kode berikut menunjukkan cara mencantumkan semua nama metrik yang didukung:

from libtpu.sdk import tpumonitoring

tpumonitoring.list_supported_metrics()

["duty_cycle_pct", "tensorcore_util", "hbm_util", ...]

Tabel berikut menampilkan semua metrik dan definisi yang sesuai:

Metrik Definisi Nama metrik untuk API Contoh nilai
Penggunaan Tensor Core Mengukur persentase penggunaan TensorCore Anda, yang dihitung sebagai persentase operasi yang merupakan bagian dari operasi TensorCore. Sampel diambil setiap 1 detik selama 10 mikrodetik. Anda tidak dapat mengubah rasio pengambilan sampel. Metrik ini memungkinkan Anda memantau efisiensi workload di perangkat TPU. tensorcore_util ['1.11', '2.22', '3.33', '4.44']

# persentase pemakaian untuk ID akselerator 0-3
Persentase Siklus Tugas Persentase waktu selama periode sampel terakhir (setiap 5 detik; dapat disesuaikan dengan menetapkan flag LIBTPU_INIT_ARG) saat akselerator secara aktif memproses (direkam dengan siklus yang digunakan untuk mengeksekusi program HLO selama periode pengambilan sampel terakhir). Metrik ini menunjukkan seberapa sibuk TPU. Metrik ini dikeluarkan per chip. duty_cycle_pct ['10.00', '20.00', '30.00', '40.00']

# Persentase siklus tugas untuk ID akselerator 0-3
Total Kapasitas HBM Metrik ini melaporkan total kapasitas HBM dalam byte. hbm_capacity_total ['30000000000', '30000000000', '30000000000', '30000000000']

# Total kapasitas HBM dalam byte yang terpasang ke ID akselerator 0-3
Penggunaan Kapasitas HBM Metrik ini melaporkan penggunaan kapasitas HBM dalam byte selama periode pengambilan sampel terakhir (setiap 5 detik; dapat disesuaikan dengan menyetel flag LIBTPU_INIT_ARG). hbm_capacity_usage ['100', '200', '300', '400']

# Penggunaan kapasitas untuk HBM dalam byte yang terpasang ke ID akselerator 0-3
Latensi transfer buffer Latensi transfer jaringan untuk traffic multi-slice berskala besar. Visualisasi ini memungkinkan Anda memahami lingkungan performa jaringan secara keseluruhan. buffer_transfer_latency ["'8MB+', '2233.25', '2182.02', '3761.93', '19277.01', '53553.6'"]

# ukuran buffer, rata-rata, p50, p90, p99, p99,9 dari distribusi latensi transfer jaringan
Metrik Distribusi Waktu Eksekusi Operasi Tingkat Tinggi Memberikan insight performa terperinci tentang status eksekusi biner yang dikompilasi HLO, sehingga memungkinkan deteksi regresi dan proses debug tingkat model. hlo_exec_timing ["'tensorcore-0', '10.00', '10.00', '20.00', '30.00', '40.00'"]

# Distribusi durasi waktu eksekusi HLO untuk CoreType-CoreID dengan mean, p50, p90, p95, p999
Ukuran antrean Pengoptimal Tingkat Tinggi Pemantauan ukuran antrean eksekusi HLO melacak jumlah program HLO yang dikompilasi yang menunggu atau sedang dieksekusi. Metrik ini mengungkapkan kemacetan pipeline eksekusi, sehingga memungkinkan identifikasi bottleneck performa dalam eksekusi hardware, overhead driver, atau alokasi resource. hlo_queue_size ["tensorcore-0: 1", "tensorcore-1: 2"]

# Mengukur ukuran antrean untuk CoreType-CoreID.
Latensi End-to-End Kolektif Metrik ini mengukur latensi kolektif end-to-end melalui DCN dalam mikrodetik, dari host yang memulai operasi hingga semua peer menerima output. Hal ini mencakup pengurangan data sisi host dan pengiriman output ke TPU. Hasilnya adalah string yang menjelaskan ukuran buffer, jenis, dan latensi rata-rata, p50, p90, p95, dan p99,9. collective_e2e_latency ["8MB+-ALL_REDUCE, 1000, 2000, 3000, 4000, 5000", …]

# Transfer size-collective op, mean, p50, p90, p95, p999 of collective end to end latency

Membaca data metrik - mode snapshot

Untuk mengaktifkan mode snapshot, tentukan nama metrik saat Anda memanggil fungsi tpumonitoring.get_metric. Mode Snapshot memungkinkan Anda menyisipkan pemeriksaan metrik ad hoc ke dalam kode berperforma rendah untuk mengidentifikasi apakah masalah performa berasal dari software atau hardware.

Contoh kode berikut menunjukkan cara menggunakan mode snapshot untuk membaca duty_cycle.

from libtpu.sdk import tpumonitoring

metric = tpumonitoring.get_metric("duty_cycle_pct")

metric.description()
"The metric provides a list of duty cycle percentages, one for each
accelerator (from accelerator_0 to accelerator_x). The duty cycle represents
the percentage of time an accelerator was actively processing during the
last sample period, indicating TPU utilization."

metric.data()
["0.00", "0.00", "0.00", "0.00"]

# accelerator_0-3

Mengakses metrik menggunakan CLI

Langkah-langkah berikut menunjukkan cara berinteraksi dengan metrik LibTPU menggunakan CLI:

  1. Instal tpu-info:

    pip install tpu-info
    
    
    # Access help information of tpu-info
    tpu-info --help / -h
    
    
  2. Jalankan visi default tpu-info:

    tpu-info
    

    Outputnya mirip dengan hal berikut ini:

   TPU Chips
   ┏━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━┓
    Chip         Type         Devices  PID       ┡━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━┩
    /dev/accel0  TPU v4 chip  1        130007     /dev/accel1  TPU v4 chip  1        130007     /dev/accel2  TPU v4 chip  1        130007     /dev/accel3  TPU v4 chip  1        130007    └─────────────┴─────────────┴─────────┴────────┘

   TPU Runtime Utilization
   ┏━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
    Device  Memory usage          Duty cycle    ┡━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
    0       0.00 GiB / 31.75 GiB       0.00%     1       0.00 GiB / 31.75 GiB       0.00%     2       0.00 GiB / 31.75 GiB       0.00%     3       0.00 GiB / 31.75 GiB       0.00%    └────────┴──────────────────────┴────────────┘

   TensorCore Utilization
   ┏━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓
    Chip ID  TensorCore Utilization    ┡━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩
    0                         0.00%     1                         0.00%     3                         0.00%     2                         0.00% |
   └─────────┴────────────────────────┘

   Buffer Transfer Latency
   ┏━━━━━━━━━━━━━┳━━━━━┳━━━━━┳━━━━━┳━━━━━━┓
    Buffer Size  P50  P90  P95  P999    ┡━━━━━━━━━━━━━╇━━━━━╇━━━━━╇━━━━━╇━━━━━━┩
          8MB+  | 0us  0us  0us   0us |
   └─────────────┴─────┴─────┴─────┴──────┘

Menggunakan metrik untuk memeriksa pemanfaatan TPU

Contoh berikut menunjukkan cara menggunakan metrik dari TPU Monitoring Library untuk melacak pemanfaatan TPU.

Memantau siklus tugas TPU selama pelatihan JAX

Skenario: Anda menjalankan skrip pelatihan JAX dan ingin memantau metrik duty_cycle_pct TPU selama proses pelatihan untuk mengonfirmasi bahwa TPU Anda digunakan secara efektif. Anda dapat mencatat metrik ini secara berkala selama pelatihan untuk melacak pemanfaatan TPU.

Contoh kode berikut menunjukkan cara memantau Siklus Tugas TPU selama pelatihan JAX:

import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
import time

 # --- Your JAX model and training setup would go here ---
 #  --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
    return jnp.sum(x)

def loss_fn(params, x, y):
    preds = simple_model(x)
    return jnp.mean((preds - y)**2)

def train_step(params, x, y, optimizer):
    grads = jax.grad(loss_fn)(params, x, y)
    return optimizer.update(grads, params)

key = jax.random.PRNGKey(0)
params = jnp.array([1.0, 2.0]) # Example params
optimizer = ... # Your optimizer (for example, optax.adam)
data_x = jnp.ones((10, 10))
data_y = jnp.zeros((10,))

num_epochs = 10
log_interval_steps = 2  # Log duty cycle every 2 steps

for epoch in range(num_epochs):
    for step in range(5): # Example steps per epoch

        params = train_step(params, data_x, data_y, optimizer)

        if (step + 1) % log_interval_steps == 0:
            # --- Integrate TPU Monitoring Library here to get duty_cycle ---
            duty_cycle_metric = tpumonitoring.get_metric("duty_cycle_pct")
            duty_cycle_data = duty_cycle_metric.data
            print(f"Epoch {epoch+1}, Step {step+1}: TPU Duty Cycle Data:")
            print(f"  Description: {duty_cycle_metric.description}")
            print(f"  Data: {duty_cycle_data}")
            # --- End TPU Monitoring Library Integration ---

        # --- Rest of your training loop logic ---
        time.sleep(0.1) # Simulate some computation

print("Training complete.")

Memeriksa pemanfaatan HBM sebelum menjalankan inferensi JAX

Skenario: Sebelum menjalankan inferensi dengan model JAX, periksa penggunaan HBM (High Bandwidth Memory) saat ini di TPU untuk mengonfirmasi bahwa Anda memiliki memori yang cukup dan untuk mendapatkan pengukuran dasar sebelum inferensi dimulai.

# The following code sample shows how to check HBM utilization before JAX inference:
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring

  # --- Your JAX model and inference setup would go here ---
  # --- Example placeholder model (replace with your actual model loading/setup)---
def simple_model(x):
    return jnp.sum(x)

key = jax.random.PRNGKey(0)
params = ... # Load your trained parameters

  # Integrate the TPU Monitoring Library to get HBM utilization before inference
hbm_util_metric = tpumonitoring.get_metric("hbm_util")
hbm_util_data = hbm_util_metric.data
print("HBM Utilization Before Inference:")
print(f"  Description: {hbm_util_metric.description}")
print(f"  Data: {hbm_util_data}")
  # End TPU Monitoring Library Integration

  # Your Inference Logic
input_data = jnp.ones((1, 10)) # Example input
predictions = simple_model(input_data)
print("Inference Predictions:", predictions)

print("Inference complete.")

Frekuensi ekspor metrik TPU

Frekuensi pemuatan ulang metrik TPU dibatasi hingga minimum satu detik. Data metrik host diekspor pada frekuensi tetap 1 Hz. Latensi yang disebabkan oleh proses ekspor ini dapat diabaikan. Metrik runtime dari LibTPU tidak tunduk pada batasan frekuensi yang sama. Namun, agar konsisten, metrik ini juga diambil sampelnya pada 1 Hz atau 1 sampel per detik.