TPU monitoring API
Unlock deep insights into your Cloud TPU hardware's performance and behavior with advanced TPU monitoring capabilities, built directly upon the foundational software layer, LibTPU. While LibTPU encompasses drivers, networking libraries, the XLA compiler, and TPU runtime for interacting with TPUs, the focus of this document is the TPU monitoring API.
The TPU monitoring API provides:
Comprehensive observability: Access to a telemetry API and metrics suite that give you detailed insights into the operational performance and specific behaviors of your TPUs.
Diagnostic toolkits: An SDK and command-line interface (CLI) designed for debugging and in-depth performance analysis of your TPU resources.
These monitoring features are designed to be a top-level, customer-facing solution, providing you with the essential tools to optimize your TPU workloads effectively.
The TPU monitoring API gives you detailed information about how machine learning workloads are performing on TPU hardware. It's designed to help you understand your TPU utilization, identify bottlenecks, and debug performance issues, and it provides finer-grained detail than interruption metrics, goodput metrics, and other high-level metrics.
Get started with the TPU monitoring API
Accessing these insights is straightforward: the TPU monitoring functionality is integrated into the LibTPU SDK, so it is included when you install LibTPU.
Install LibTPU
pip install libtpu
Alternatively, LibTPU updates are coordinated with JAX releases (published monthly), so installing the latest JAX release typically pulls in the latest compatible LibTPU version and its features.
Install JAX
pip install -U "jax[tpu]"
For PyTorch users, installing PyTorch/XLA provides the latest LibTPU and TPU monitoring functionality.
Install PyTorch/XLA
pip install torch~=2.6.0 'torch_xla[tpu]~=2.6.0' \
-f https://storage.googleapis.com/libtpu-releases/index.html \
-f https://storage.googleapis.com/libtpu-wheels/index.html
# Optional: if you're using custom kernels, install pallas dependencies
pip install 'torch_xla[pallas]' \
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
For more information about installing PyTorch/XLA, see Installation in the PyTorch/XLA GitHub repository.
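After the installation completes, you can run a quick check from Python to confirm that the monitoring functionality is available. The following is a minimal sketch; it assumes the libtpu.sdk import path and the monitoring.help() call described in the next sections.
# Minimal availability check (assumes the libtpu.sdk import path and the
# monitoring.help() call shown later in this document)
import libtpu.sdk as sdk
sdk.monitoring.help()  # prints the supported monitoring functionality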
Import the API in Python
To start using the TPU monitoring API, you need to import the libtpu module in your Python code.
import libtpu.sdk as sdk
Or
from libtpu import sdk
List all supported functionality
List all metric names and the functionality they support:
from libtpu import sdk
sdk.monitoring.help()
" libtpu.sdk.monitoring.help():
List all supported functionality.
libtpu.sdk.monitoring.list_support_metrics()
List support metric names in the list of str format.
libtpu.sdk.monitoring.get_metric(metric_name:str)
Get metric data with metric name. It represents the snapshot mode.
The metric data is a object with `description()` and `data()` methods,
where the `description()` returns a string describe the format of data
and data unit, `data()` returns the metric data in the list in str format.
"
Supported metrics
The following code sample shows how to list all supported metric names:
from libtpu import sdk
sdk.monitoring.list_supported_metrics()
["duty_cycle_pct", "tensorcore_util", "hbm_util", ...]
The following table shows all metrics and their corresponding definitions:
Metric | Definition | Metric name for API | Example values |
---|---|---|---|
Tensor Core Utilization | Measures the percentage of your TensorCore usage, calculated as the percentage of operations that are part of the TensorCore operations. Sampled for 10 microseconds every 1 second; you cannot modify the sampling rate. This metric lets you monitor the efficiency of your workloads on TPU devices. | `tensorcore_util` | `['1.11', '2.22', '3.33', '4.44']` (utilization percentage for accelerator IDs 0-3) |
Duty Cycle Percentage | Percentage of time over the past sample period (every 5 seconds; can be tuned by setting the LIBTPU_INIT_ARG flag) during which the accelerator was actively processing (recorded with cycles used to execute HLO programs over the last sampling period). This metric represents how busy a TPU is. The metric is emitted per chip. | `duty_cycle_pct` | `['10.00', '20.00', '30.00', '40.00']` (duty cycle percentage for accelerator IDs 0-3) |
HBM Capacity Total | This metric reports the total HBM capacity in bytes. | `hbm_capacity_total` | `['30000000000', '30000000000', '30000000000', '30000000000']` (total HBM capacity in bytes attached to accelerator IDs 0-3) |
HBM Capacity Usage | This metric reports the usage of HBM capacity in bytes over the past sample period (every 5 seconds; can be tuned by setting the LIBTPU_INIT_ARG flag). | `hbm_capacity_usage` | `['100', '200', '300', '400']` (HBM capacity usage in bytes for accelerator IDs 0-3) |
Buffer transfer latency | Network transfer latencies for megascale multi-slice traffic. This visualization lets you understand the overall network performance environment. | `buffer_transfer_latency` | `["'8MB+', '2233.25', '2182.02', '3761.93', '19277.01', '53553.6'"]` (buffer size, then mean, p50, p90, p99, and p99.9 of the network transfer latency distribution) |
High Level Operation Execution Time Distribution Metrics | Provides granular performance insights into the HLO compiled binary execution status, enabling regression detection and model-level debugging. | `hlo_exec_timing` | `["'tensorcore-0', '10.00', '10.00', '20.00', '30.00', '40.00'"]` (HLO execution time distribution for CoreType-CoreID, with mean, p50, p90, p95, and p999) |
High Level Optimizer queue size | HLO execution queue size monitoring tracks the number of compiled HLO programs waiting for or undergoing execution. This metric reveals execution pipeline congestion, enabling identification of performance bottlenecks in hardware execution, driver overhead, or resource allocation. | `hlo_queue_size` | `["tensorcore-0: 1", "tensorcore-1: 2"]` (queue size for CoreType-CoreID) |
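For example, you can combine the two HBM capacity metrics from the table to estimate per-accelerator HBM utilization. The following is a minimal sketch that assumes the snapshot-mode get_metric API described in the next section and the string-formatted values shown in the example column; the hbm_usage_fraction helper is illustrative and not part of LibTPU.
from libtpu import sdk

def hbm_usage_fraction():
    # Illustrative helper (not part of LibTPU): combine hbm_capacity_total
    # and hbm_capacity_usage into a usage fraction per accelerator.
    total = sdk.monitoring.get_metric(metric_name="hbm_capacity_total").data()
    usage = sdk.monitoring.get_metric(metric_name="hbm_capacity_usage").data()
    # Each metric returns one string value (in bytes) per accelerator ID.
    return [int(u) / int(t) for u, t in zip(usage, total)]

print(hbm_usage_fraction())  # one usage fraction per accelerator ID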
Read metric data - snapshot mode
To enable snapshot mode, specify the metric name when you call the sdk.monitoring.get_metric function. Snapshot mode lets you insert ad hoc metric checks into low-performing code to identify whether performance issues stem from software or hardware.
The following code sample shows how to use snapshot mode to read duty_cycle_pct.
import libtpu.sdk as sdk
metric = sdk.monitoring.get_metric(metric_name="duty_cycle_pct")
metric.description()
"The metric provides a list of duty cycle percentages, one for each
accelerator (from accelerator_0 to accelerator_x). The duty cycle represents
the percentage of time an accelerator was actively processing during the
last sample period, indicating TPU utilization."
metric.data()
["0.00", "0.00", "0.00", "0.00"]
# accelerator_0-3
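Because snapshot mode returns a point-in-time reading, you can take a snapshot immediately before and after a suspect section of code to see whether the accelerators were busy while it ran. The following is a minimal sketch; the log_duty_cycle helper is illustrative, and it assumes the duty_cycle_pct metric name from the table above.
import libtpu.sdk as sdk

def log_duty_cycle(label):
    # Illustrative helper: print a point-in-time duty cycle snapshot.
    metric = sdk.monitoring.get_metric(metric_name="duty_cycle_pct")
    print(f"[{label}] duty cycle per accelerator: {metric.data()}")

log_duty_cycle("before suspect section")
# ... run the code you want to investigate ...
log_duty_cycle("after suspect section")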
Access metrics using the CLI
The following steps show how to interact with LibTPU metrics using the CLI:
Install tpu-info:
pip install tpu-info
# Access the help information for tpu-info
tpu-info --help    # or: tpu-info -h
Run the default version of tpu-info:
tpu-info
The output is similar to the following:
TPU Chips
┏━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━┓
┃ Chip        ┃ Type        ┃ Devices ┃ PID    ┃
┡━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━┩
│ /dev/accel0 │ TPU v4 chip │ 1       │ 130007 │
│ /dev/accel1 │ TPU v4 chip │ 1       │ 130007 │
│ /dev/accel2 │ TPU v4 chip │ 1       │ 130007 │
│ /dev/accel3 │ TPU v4 chip │ 1       │ 130007 │
└─────────────┴─────────────┴─────────┴────────┘
TPU Runtime Utilization
┏━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Device ┃ Memory usage         ┃ Duty cycle ┃
┡━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ 0      │ 0.00 GiB / 31.75 GiB │ 0.00%      │
│ 1      │ 0.00 GiB / 31.75 GiB │ 0.00%      │
│ 2      │ 0.00 GiB / 31.75 GiB │ 0.00%      │
│ 3      │ 0.00 GiB / 31.75 GiB │ 0.00%      │
└────────┴──────────────────────┴────────────┘
TensorCore Utilization
┏━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Chip ID ┃ TensorCore Utilization ┃
┡━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩
│ 0       │ 0.00%                  │
│ 1       │ 0.00%                  │
│ 3       │ 0.00%                  │
│ 2       │ 0.00%                  │
└─────────┴────────────────────────┘
Buffer Transfer Latency
┏━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━┓
┃ Buffer Size ┃ Mean    ┃ P50     ┃ P90     ┃ P95      ┃ P999    ┃
┡━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━┩
│ 8MB+        │ 2233.25 │ 2182.02 │ 3761.93 │ 19277.01 │ 53553.6 │
└─────────────┴─────────┴─────────┴─────────┴──────────┴─────────┘
Use metrics to check TPU utilization
The following examples show how to use metrics from the monitoring API to track TPU utilization.
Monitor TPU duty cycle during JAX training
Scenario: You are running a JAX training script and want to monitor the TPU's duty_cycle_pct metric throughout the training process to confirm that your TPUs are being effectively utilized. You can log this metric periodically during training to track TPU utilization.
The following code sample shows how to monitor the TPU duty cycle during JAX training:
import jax
import jax.numpy as jnp
import libtpu.sdk as sdk
import time

# --- Your JAX model and training setup would go here ---
# --- Example placeholder model and data (replace with your actual setup) ---
def simple_model(x):
    return jnp.sum(x)

def loss_fn(params, x, y):
    preds = simple_model(x)
    return jnp.mean((preds - y)**2)

def train_step(params, x, y, optimizer):
    grads = jax.grad(loss_fn)(params, x, y)
    return optimizer.update(grads, params)

key = jax.random.PRNGKey(0)
params = jnp.array([1.0, 2.0])  # Example params
optimizer = ...  # Your optimizer (for example, optax.adam)
data_x = jnp.ones((10, 10))
data_y = jnp.zeros((10,))

num_epochs = 10
log_interval_steps = 2  # Log duty cycle every 2 steps

for epoch in range(num_epochs):
    for step in range(5):  # Example steps per epoch
        params = train_step(params, data_x, data_y, optimizer)

        if (step + 1) % log_interval_steps == 0:
            # --- Integrate TPU Monitoring API here to get duty_cycle_pct ---
            duty_cycle_metric = sdk.monitoring.get_metric(metric_name="duty_cycle_pct")
            duty_cycle_data = duty_cycle_metric.data()
            print(f"Epoch {epoch+1}, Step {step+1}: TPU Duty Cycle Data:")
            print(f"  Description: {duty_cycle_metric.description()}")
            print(f"  Data: {duty_cycle_data}")
            # --- End TPU Monitoring API Integration ---

        # --- Rest of your training loop logic ---
        time.sleep(0.1)  # Simulate some computation

print("Training complete.")
Check HBM utilization before running JAX inference
Scenario: Before running inference with your JAX model, check the current HBM (High Bandwidth Memory) utilization on the TPU to confirm that you have enough memory available and to get a baseline measurement before inference starts.
The following code sample shows how to check HBM utilization before JAX inference:
import jax
import jax.numpy as jnp
import libtpu.sdk as sdk

# --- Your JAX model and inference setup would go here ---
# --- Example placeholder model (replace with your actual model loading/setup) ---
def simple_model(x):
    return jnp.sum(x)

key = jax.random.PRNGKey(0)
params = ...  # Load your trained parameters

# Integrate the TPU monitoring API to get HBM utilization before inference
hbm_util_metric = sdk.monitoring.get_metric(metric_name="hbm_util")
hbm_util_data = hbm_util_metric.data()
print("HBM Utilization Before Inference:")
print(f"  Description: {hbm_util_metric.description()}")
print(f"  Data: {hbm_util_data}")
# End LibTPU SDK Integration

# Your Inference Logic
input_data = jnp.ones((1, 10))  # Example input
predictions = simple_model(input_data)
print("Inference Predictions:", predictions)
print("Inference complete.")
Export frequency of TPU metrics
The refresh frequency of TPU metrics is constrained to a minimum of one second. Host metric data is exported at a fixed frequency of 1 Hz, and the latency introduced by this export process is negligible. Runtime metrics from LibTPU are not subject to the same frequency constraint; however, for consistency, these metrics are also sampled at 1 Hz (one sample per second).
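Because host metrics are exported at 1 Hz, polling more often than once per second does not produce fresher data. The following is a minimal sketch that samples duty_cycle_pct at that rate for a fixed window; the loop and the 60-second window are illustrative, and the metric name and get_metric API are the ones described earlier in this document.
import time
import libtpu.sdk as sdk

# Sample the duty cycle once per second for 60 seconds. Polling faster than
# 1 Hz would only return repeated readings of the same exported sample.
samples = []
for _ in range(60):
    metric = sdk.monitoring.get_metric(metric_name="duty_cycle_pct")
    samples.append(metric.data())
    time.sleep(1)

print(f"Collected {len(samples)} duty cycle snapshots")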