TPU monitoring API

Unlock deep insights into your Cloud TPU hardware's performance and behavior with advanced TPU monitoring capabilities, built directly upon the foundational software layer, LibTPU. While LibTPU encompasses drivers, networking libraries, the XLA compiler, and TPU runtime for interacting with TPUs, the focus of this document is the TPU monitoring API.

The TPU monitoring API provides:

  • Comprehensive observability: Gain access to a telemetry API and metrics suite that provide detailed insights into the operational performance and specific behaviors of your TPUs.

  • Diagnostic toolkits: Provides an SDK and command-line interface (CLI) designed to enable debugging and in-depth performance analysis of your TPU resources.

These monitoring features are designed to be a top-level, customer-facing solution, providing you with the essential tools to optimize your TPU workloads effectively.

The TPU monitoring API gives you detailed information about how your machine learning workloads perform on TPU hardware. It's designed to help you understand your TPU utilization, identify bottlenecks, and debug performance issues. It provides more granular information than interruption metrics, goodput metrics, and other high-level metrics.

Get started with the TPU monitoring API

Accessing these powerful insights is straightforward. The TPU monitoring functionality is integrated with the LibTPU SDK, so the functionality is included when you install LibTPU.

Install LibTPU

pip install libtpu

Alternatively, LibTPU updates are coordinated with JAX releases, so installing the latest JAX release (released monthly) typically pins you to the latest compatible LibTPU version and its features.

Install JAX

pip install -U "jax[tpu]"
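
To check which LibTPU version a JAX installation pulled in, you can query the installed packages. The following is a minimal sketch; the libtpu and libtpu-nightly package names are assumptions, because the wheel name depends on the release.

import importlib.metadata as md
import jax

# Print the JAX version and, if present, the LibTPU wheel it was pinned to.
print("jax:", jax.__version__)
for pkg in ("libtpu", "libtpu-nightly"):
    try:
        print(f"{pkg}:", md.version(pkg))
    except md.PackageNotFoundError:
        pass

# On a TPU VM, this lists the TpuDevice entries JAX can see.
print("devices:", jax.devices())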

For PyTorch users, installing PyTorch/XLA provides the latest LibTPU and TPU monitoring functionality.

Install PyTorch/XLA

pip install torch~=2.6.0 'torch_xla[tpu]~=2.6.0' \
  -f https://storage.googleapis.com/libtpu-releases/index.html \
  -f https://storage.googleapis.com/libtpu-wheels/index.html

# Optional: if you're using custom kernels, install Pallas dependencies
pip install 'torch_xla[pallas]' \
  -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
  -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

For more information about installing PyTorch/XLA, see Installation in the PyTorch/XLA GitHub repository.
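
To confirm that PyTorch/XLA loaded LibTPU and can create an XLA device, a quick check such as the following can help. This is a sketch that assumes a standard torch_xla installation on a TPU VM.

import torch_xla.core.xla_model as xm

# On a TPU VM this typically prints an XLA device such as xla:0.
device = xm.xla_device()
print("PyTorch/XLA device:", device)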

Import the API in Python

To start using the TPU monitoring API, you need to import the libtpu module in your Python code.

import libtpu as sdk

Or

from libtpu import sdk

List all supported functionality

To list all supported functionality and how to use it, call the help() function:


import libtpu as sdk

sdk.monitoring.help()
" libtpu.sdk.monitoring.help():
      List all supported functionality.

  libtpu.sdk.monitoring.list_support_metrics()
      List support metric names in the list of str format.

  libtpu.sdk.monitoring.get_metric(metric_name:str)
      Get metric data with metric name. It represents the snapshot mode.
      The metric data is a object with `description()` and `data()` methods,
      where the `description()` returns a string describe the format of data
      and data unit, `data()` returns the metric data in the list in str format.
"

Supported metrics

The following code sample shows how to list all supported metric names:

from libtpu import sdk

sdk.monitoring.list_supported_metrics()

["duty_cycle_pct", "tensorcore_util", "hbm_util", ...]

The following list describes each metric, its API metric name, and example values:

  • Tensor Core Utilization (tensorcore_util): Measures the percentage of your TensorCore usage, calculated as the percentage of operations that are part of the TensorCore operations. Sampled for 10 microseconds every 1 second; you cannot modify the sampling rate. This metric lets you monitor the efficiency of your workloads on TPU devices.
    Example values: ['1.11', '2.22', '3.33', '4.44']  # Utilization percentage for accelerator IDs 0-3

  • Duty Cycle Percentage (duty_cycle_pct): Percentage of time over the past sample period (every 5 seconds; can be tuned by setting the LIBTPU_INIT_ARG flag) during which the accelerator was actively processing, recorded as cycles used to execute HLO programs over the last sampling period. This metric represents how busy a TPU is and is emitted per chip.
    Example values: ['10.00', '20.00', '30.00', '40.00']  # Duty cycle percentage for accelerator IDs 0-3

  • HBM Capacity Total (hbm_capacity_total): Reports the total HBM capacity in bytes.
    Example values: ['30000000000', '30000000000', '30000000000', '30000000000']  # Total HBM capacity in bytes attached to accelerator IDs 0-3

  • HBM Capacity Usage (hbm_capacity_usage): Reports the HBM capacity usage in bytes over the past sample period (every 5 seconds; can be tuned by setting the LIBTPU_INIT_ARG flag).
    Example values: ['100', '200', '300', '400']  # HBM capacity usage in bytes attached to accelerator IDs 0-3

  • Buffer Transfer Latency (buffer_transfer_latency): Network transfer latencies for megascale multi-slice traffic. This metric lets you understand the overall network performance environment.
    Example values: ["'8MB+', '2233.25', '2182.02', '3761.93', '19277.01', '53553.6'"]  # Buffer size, then mean, p50, p90, p99, and p99.9 of the network transfer latency distribution

  • High Level Operation Execution Time Distribution (hlo_exec_timing): Provides granular performance insights into HLO compiled binary execution, enabling regression detection and model-level debugging.
    Example values: ["'tensorcore-0', '10.00', '10.00', '20.00', '30.00', '40.00'"]  # HLO execution time distribution for CoreType-CoreID, with mean, p50, p90, p95, and p999

  • High Level Optimizer Queue Size (hlo_queue_size): Tracks the number of compiled HLO programs waiting for or undergoing execution. This metric reveals execution pipeline congestion, which helps identify performance bottlenecks in hardware execution, driver overhead, or resource allocation.
    Example values: ["tensorcore-0: 1", "tensorcore-1: 2"]  # Queue size for CoreType-CoreID
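
To take a quick snapshot of everything at once, you can combine list_supported_metrics() and get_metric(). The following is a minimal sketch that assumes every metric the SDK lists can be fetched on your TPU generation:

from libtpu import sdk

# Print a one-off snapshot of every supported metric, with its description.
for name in sdk.monitoring.list_supported_metrics():
    metric = sdk.monitoring.get_metric(metric_name=name)
    print(f"{name}: {metric.description()}")
    print(f"  data: {metric.data()}")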

Read metric data - snapshot mode

To enable snapshot mode, specify the metric name when you call the sdk.monitoring.get_metric function. Snapshot mode lets you insert ad hoc metric checks into low-performance code to identify whether performance issues stem from software or hardware.

The following code sample shows how to use snapshot mode to read the duty_cycle_pct metric.

import libtpu.sdk as sdk

metric = sdk.monitoring.get_metric(metric_name="duty_cycle_pct")

metric.description()
"The metric provides a list of duty cycle percentages, one for each
accelerator (from accelerator_0 to accelerator_x). The duty cycle represents
the percentage of time an accelerator was actively processing during the
last sample period, indicating TPU utilization."

metric.data()
["0.00", "0.00", "0.00", "0.00"]

# accelerator_0-3
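
Because data() returns strings, convert the values before aggregating them. The following sketch, which assumes the per-accelerator output shown above, averages the duty cycle across all local accelerators:

import libtpu.sdk as sdk

def mean_duty_cycle_pct():
    """Return the average duty cycle across all local accelerators."""
    metric = sdk.monitoring.get_metric(metric_name="duty_cycle_pct")
    values = [float(v) for v in metric.data()]  # data() returns strings such as "0.00"
    return sum(values) / len(values)

print(f"Mean duty cycle: {mean_duty_cycle_pct():.2f}%")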

Access metrics using the CLI

The following steps show how to interact with LibTPU metrics using the CLI:

  1. Install tpu-info:

    pip install tpu-info
    
    
    # Access help information for tpu-info
    tpu-info --help
    
    
  2. Run the default view of tpu-info:

    tpu-info
    

    The output is similar to the following:

    TPU Chips
    ┏━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━┓
    ┃ Chip        ┃ Type        ┃ Devices ┃ PID    ┃
    ┡━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━┩
    │ /dev/accel0 │ TPU v4 chip │ 1       │ 130007 │
    │ /dev/accel1 │ TPU v4 chip │ 1       │ 130007 │
    │ /dev/accel2 │ TPU v4 chip │ 1       │ 130007 │
    │ /dev/accel3 │ TPU v4 chip │ 1       │ 130007 │
    └─────────────┴─────────────┴─────────┴────────┘

    TPU Runtime Utilization
    ┏━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
    ┃ Device ┃ Memory usage         ┃ Duty cycle ┃
    ┡━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
    │ 0      │ 0.00 GiB / 31.75 GiB │      0.00% │
    │ 1      │ 0.00 GiB / 31.75 GiB │      0.00% │
    │ 2      │ 0.00 GiB / 31.75 GiB │      0.00% │
    │ 3      │ 0.00 GiB / 31.75 GiB │      0.00% │
    └────────┴──────────────────────┴────────────┘

    TensorCore Utilization
    ┏━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓
    ┃ Chip ID ┃ TensorCore Utilization ┃
    ┡━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩
    │ 0       │                  0.00% │
    │ 1       │                  0.00% │
    │ 3       │                  0.00% │
    │ 2       │                  0.00% │
    └─────────┴────────────────────────┘

    Buffer Transfer Latency
    ┏━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━┳━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┓
    ┃ Buffer Size ┃ Mean    ┃ P50    ┃ P90    ┃ P95     ┃ P999    ┃
    ┡━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━╇━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━┩
    │ 8MB+        │ 2233.25 │ 2182.0 │ 3761.9 │ 19277.0 │ 53553.6 │
    └─────────────┴─────────┴────────┴────────┴─────────┴─────────┘

Use metrics to check TPU utilization

The following examples show how to use metrics from the monitoring API to track TPU utilization.

Monitor TPU duty cycle during JAX training

Scenario: You are running a JAX training script and want to monitor the TPU's duty_cycle_pct metric throughout the training process to confirm your TPUs are being effectively utilized. You can log this metric periodically during training to track TPU utilization.

The following code sample shows how to monitor TPU Duty Cycle during JAX training:

import jax
import jax.numpy as jnp
import libtpu.sdk as sdk
import time

# --- Your JAX model and training setup would go here ---
# --- Example placeholder model and data (replace with your actual setup) ---
def simple_model(x):
    return jnp.sum(x)

def loss_fn(params, x, y):
    preds = simple_model(x)
    return jnp.mean((preds - y)**2)

def train_step(params, x, y, optimizer):
    grads = jax.grad(loss_fn)(params, x, y)
    return optimizer.update(grads, params)

key = jax.random.PRNGKey(0)
params = jnp.array([1.0, 2.0]) # Example params
optimizer = ... # Your optimizer (for example, optax.adam)
data_x = jnp.ones((10, 10))
data_y = jnp.zeros((10,))

num_epochs = 10
log_interval_steps = 2  # Log duty cycle every 2 steps

for epoch in range(num_epochs):
    for step in range(5): # Example steps per epoch

        params = train_step(params, data_x, data_y, optimizer)

        if (step + 1) % log_interval_steps == 0:
            # --- Integrate TPU Monitoring API here to get duty_cycle ---
            duty_cycle_metric = sdk.monitoring.get_metric(metric_name="duty_cycle_pct")
            duty_cycle_data = duty_cycle_metric.data()  # description() and data() are methods
            print(f"Epoch {epoch+1}, Step {step+1}: TPU Duty Cycle Data:")
            print(f"  Description: {duty_cycle_metric.description()}")
            print(f"  Data: {duty_cycle_data}")
            # --- End TPU Monitoring API Integration ---

        # --- Rest of your training loop logic ---
        time.sleep(0.1) # Simulate some computation

print("Training complete.")

Check HBM utilization before running JAX inference

Scenario: Before running inference with your JAX model, check the current HBM (High Bandwidth Memory) utilization on the TPU to confirm that you have enough memory available and to get a baseline measurement before inference starts.

The following code sample shows how to check HBM utilization before running JAX inference:

import jax
import jax.numpy as jnp
import libtpu.sdk as sdk

# --- Your JAX model and inference setup would go here ---
# --- Example placeholder model (replace with your actual model loading/setup) ---
def simple_model(x):
    return jnp.sum(x)

key = jax.random.PRNGKey(0)
params = ... # Load your trained parameters

# Integrate the TPU monitoring API to get HBM utilization before inference
hbm_util_metric = sdk.monitoring.get_metric(metric_name="hbm_util")
hbm_util_data = hbm_util_metric.data()  # description() and data() are methods
print("HBM Utilization Before Inference:")
print(f"  Description: {hbm_util_metric.description()}")
print(f"  Data: {hbm_util_data}")
# End TPU monitoring API integration

# Your inference logic
input_data = jnp.ones((1, 10)) # Example input
predictions = simple_model(input_data)
print("Inference Predictions:", predictions)

print("Inference complete.")

Export frequency of TPU metrics

The refresh frequency of TPU metrics is constrained to a minimum of one second. Host metric data is exported at a fixed frequency of 1 Hz, and the latency introduced by this export process is negligible. Runtime metrics from LibTPU are not subject to the same frequency constraint; however, for consistency, they are also sampled at 1 Hz (one sample per second).
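
In practice, polling faster than once per second doesn't return fresher data. The following minimal polling sketch, which assumes the duty_cycle_pct metric shown earlier, respects the 1 Hz export frequency:

import time
import libtpu.sdk as sdk

POLL_INTERVAL_S = 1.0  # metrics refresh at most once per second, so poll no faster

for _ in range(30):  # poll for roughly 30 seconds
    metric = sdk.monitoring.get_metric(metric_name="duty_cycle_pct")
    print(time.strftime("%H:%M:%S"), metric.data())
    time.sleep(POLL_INTERVAL_S)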