TPU 监控库

借助直接基于基础软件层 LibTPU 构建的高级 TPU 监控功能,深入了解 Cloud TPU 硬件的性能和行为。虽然 LibTPU 包含用于与 TPU 交互的驱动程序、网络库、XLA 编译器和 TPU 运行时,但本文档的重点是 TPU 监控库。

TPU 监控库提供以下功能:

  • 全面的可观测性:可访问遥测 API 和指标套件。这样,您就可以详细了解 TPU 的运行性能和具体行为。

  • 诊断工具包:提供 SDK 和命令行界面 (CLI),旨在对 TPU 资源进行调试和深入的性能分析。

这些监控功能适合面向客户的顶级解决方案,为您提供有效优化 TPU 工作负载所需的基本工具。

TPU 监控库可为您提供有关机器学习工作负载在 TPU 硬件上的执行情况的详细信息。它旨在帮助您了解 TPU 利用率、找出瓶颈并调试性能问题。与中断指标、goodput 指标和其他指标相比,它可为您提供更详细的信息。

开始使用 TPU 监控库

您可以轻松获取强有力的数据分析。TPU 监控功能与 LibTPU SDK 集成,因此在安装 LibTPU 时会包含该功能。

安装 LibTPU

pip install libtpu

或者,LibTPU 更新与 JAX 版本同步,这意味着当您安装最新的 JAX 版本(每月发布)时,系统通常会将您锁定到最新的兼容 LibTPU 版本及其功能。

安装 JAX

pip install -U "jax[tpu]"

对于 PyTorch 用户,安装 PyTorch/XLA 可提供最新的 LibTPU 和 TPU 监控功能。

安装 PyTorch/XLA

pip install torch~=2.6.0 'torch_xla[tpu]~=2.6.0' \
  -f https://storage.googleapis.com/libtpu-releases/index.html \
  -f https://storage.googleapis.com/libtpu-wheels/index.html

  # Optional: if you're using custom kernels, install pallas dependencies
pip install 'torch_xla[pallas]' \
  -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
  -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

如需详细了解如何安装 PyTorch/XLA,请参阅 PyTorch/XLA GitHub 代码库中的安装

在 Python 中导入库

如需开始使用 TPU 监控库,您需要在 Python 代码中导入 libtpu 模块。

from libtpu.sdk import tpumonitoring

列出所有支持的功能

列出所有指标名称及其支持的功能:


from libtpu.sdk import tpumonitoring

tpumonitoring.help()
" libtpu.sdk.monitoring.help():
      List all supported functionality.

  libtpu.sdk.monitoring.list_support_metrics()
      List support metric names in the list of str format.

  libtpu.sdk.monitoring.get_metric(metric_name:str)
      Get metric data with metric name. It represents the snapshot mode.
      The metric data is a object with `description()` and `data()` methods,
      where the `description()` returns a string describe the format of data
      and data unit, `data()` returns the metric data in the list in str format.
"

支持的指标

以下代码示例展示了如何列出所有受支持的指标名称:

from libtpu.sdk import tpumonitoring

tpumonitoring.list_supported_metrics()

["duty_cycle_pct", "tensorcore_util", "hbm_util", ...]

下表展示了所有指标及其对应的定义:

指标 定义 API 的指标名称 示例值
Tensor Core 利用率 衡量 TensorCore 用量的百分比,以属于 TensorCore 操作一部分的操作百分比形式计算。每 1 秒采样 10 微秒。您无法修改采样率。借助此指标,您可以监控 TPU 设备上工作负载的效率。 tensorcore_util ['1.11', '2.22', '3.33', '4.44']

# 加速器 ID 0-3 的利用率百分比
占空比百分比 加速器活跃处理(通过用于上一个采样周期内执行 HLO 程序的周期记录)的时间占过去的采样周期(每 5 秒;可通过设置 LIBTPU_INIT_ARG 标志进行调整)的百分比。此指标表示 TPU 的繁忙程度。该指标按芯片发出。 duty_cycle_pct ['10.00', '20.00', '30.00', '40.00']

# 加速器 ID 0-3 的占空比百分比
HBM 总容量 此指标报告 HBM 总容量(以字节为单位)。 hbm_capacity_total ['30000000000', '30000000000', '30000000000', '30000000000']

# 附加到加速器 ID 0-3 的 HBM 总容量(以字节为单位)
HBM 容量用量 此指标报告过去的采样周期(每 5 秒;可通过设置 LIBTPU_INIT_ARG 标志进行调整)的 HBM 容量用量(以字节为单位)。 hbm_capacity_usage ['100', '200', '300', '400']

# 附加到加速器 ID 0-3 的 HBM 容量用量(以字节为单位)
缓冲区传输延迟时间 超大规模多切片流量的网络传输延迟时间。这种可视化图表可让您了解整体网络性能环境。 buffer_transfer_latency ["'8MB+', '2233.25', '2182.02', '3761.93', '19277.01', '53553.6'"]

# 网络传输延迟时间分布,包括缓冲区大小、平均值、p50、p90、p99、p99.9
高级别操作执行时间分布指标 针对 HLO 编译二进制文件执行状态提供详细的性能数据分析,以便进行回归检测和模型级调试。 hlo_exec_timing ["'tensorcore-0', '10.00', '10.00', '20.00', '30.00', '40.00'"]

# CoreType-CoreID 的 HLO 执行时间时长分布,包含平均值、p50、p90、p95、p999
高级优化器队列大小 HLO 执行队列大小监控会跟踪正在等待执行或正在执行的已编译 HLO 程序的数量。此指标展示了执行流水线拥塞情况,从而能够找出硬件执行、驱动程序开销或资源分配中的性能瓶颈。 hlo_queue_size ["tensorcore-0: 1", "tensorcore-1: 2"]

# 衡量 CoreType-CoreID 的队列大小。
总体端到端延迟时间 此指标用于衡量 DCN 上从发起操作的主机到接收输出的所有对等方的端到端集体延迟时间(以微秒为单位)。它包括主机端数据缩减和向 TPU 发送输出。结果是字符串,详细说明了缓冲区大小、类型以及平均延迟时间、p50、p90、p95 和 p99.9 延迟时间。 collective_e2e_latency ["8MB+-ALL_REDUCE, 1000, 2000, 3000, 4000, 5000", …]

# 传输大小-集体操作、集体端到端延迟时间的平均值、p50、p90、p95、p999

读取指标数据 - 快照模式

如需启用快照模式,请在调用 tpumonitoring.get_metric 函数时指定指标名称。借助快照模式,您可以将临时指标检查插入到低性能代码中,以确定性能问题是源自软件还是硬件。

以下代码示例展示了如何使用快照模式读取 duty_cycle

from libtpu.sdk import tpumonitoring

metric = tpumonitoring.get_metric("duty_cycle_pct")

metric.description()
"The metric provides a list of duty cycle percentages, one for each
accelerator (from accelerator_0 to accelerator_x). The duty cycle represents
the percentage of time an accelerator was actively processing during the
last sample period, indicating TPU utilization."

metric.data()
["0.00", "0.00", "0.00", "0.00"]

# accelerator_0-3

使用 CLI 访问指标

以下步骤展示了如何使用 CLI 与 LibTPU 指标进行交互:

  1. 安装 tpu-info

    pip install tpu-info
    
    
    # Access help information of tpu-info
    tpu-info --help / -h
    
    
  2. 运行 tpu-info 的默认版本:

    tpu-info
    

    输出类似于以下内容:

   TPU Chips
   ┏━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━┓
    Chip         Type         Devices  PID       ┡━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━┩
    /dev/accel0  TPU v4 chip  1        130007     /dev/accel1  TPU v4 chip  1        130007     /dev/accel2  TPU v4 chip  1        130007     /dev/accel3  TPU v4 chip  1        130007    └─────────────┴─────────────┴─────────┴────────┘

   TPU Runtime Utilization
   ┏━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
    Device  Memory usage          Duty cycle    ┡━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
    0       0.00 GiB / 31.75 GiB       0.00%     1       0.00 GiB / 31.75 GiB       0.00%     2       0.00 GiB / 31.75 GiB       0.00%     3       0.00 GiB / 31.75 GiB       0.00%    └────────┴──────────────────────┴────────────┘

   TensorCore Utilization
   ┏━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓
    Chip ID  TensorCore Utilization    ┡━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩
    0                         0.00%     1                         0.00%     3                         0.00%     2                         0.00% |
   └─────────┴────────────────────────┘

   Buffer Transfer Latency
   ┏━━━━━━━━━━━━━┳━━━━━┳━━━━━┳━━━━━┳━━━━━━┓
    Buffer Size  P50  P90  P95  P999    ┡━━━━━━━━━━━━━╇━━━━━╇━━━━━╇━━━━━╇━━━━━━┩
          8MB+  | 0us  0us  0us   0us |
   └─────────────┴─────┴─────┴─────┴──────┘

使用指标检查 TPU 利用率

以下示例展示了如何使用 TPU 监控库中的指标来跟踪 TPU 利用率。

在 JAX 训练期间监控 TPU 占空比

场景:您正在运行 JAX 训练脚本,并希望在整个训练过程中监控 TPU 的 duty_cycle_pct 指标,以确认 TPU 是否得到了有效利用。您可以在训练期间定期记录此指标,以跟踪 TPU 利用率。

以下代码示例展示了如何在 JAX 训练期间监控 TPU 占空比:

import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
import time

 # --- Your JAX model and training setup would go here ---
 #  --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
    return jnp.sum(x)

def loss_fn(params, x, y):
    preds = simple_model(x)
    return jnp.mean((preds - y)**2)

def train_step(params, x, y, optimizer):
    grads = jax.grad(loss_fn)(params, x, y)
    return optimizer.update(grads, params)

key = jax.random.PRNGKey(0)
params = jnp.array([1.0, 2.0]) # Example params
optimizer = ... # Your optimizer (for example, optax.adam)
data_x = jnp.ones((10, 10))
data_y = jnp.zeros((10,))

num_epochs = 10
log_interval_steps = 2  # Log duty cycle every 2 steps

for epoch in range(num_epochs):
    for step in range(5): # Example steps per epoch

        params = train_step(params, data_x, data_y, optimizer)

        if (step + 1) % log_interval_steps == 0:
            # --- Integrate TPU Monitoring Library here to get duty_cycle ---
            duty_cycle_metric = tpumonitoring.get_metric("duty_cycle_pct")
            duty_cycle_data = duty_cycle_metric.data
            print(f"Epoch {epoch+1}, Step {step+1}: TPU Duty Cycle Data:")
            print(f"  Description: {duty_cycle_metric.description}")
            print(f"  Data: {duty_cycle_data}")
            # --- End TPU Monitoring Library Integration ---

        # --- Rest of your training loop logic ---
        time.sleep(0.1) # Simulate some computation

print("Training complete.")

在运行 JAX 推理之前检查 HBM 利用率

场景:在使用 JAX 模型运行推理之前,请检查 TPU 上的当前 HBM(高带宽内存)利用率,以确认您有足够的可用内存,并在推理开始前获取基准测量结果。

# The following code sample shows how to check HBM utilization before JAX inference:
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring

  # --- Your JAX model and inference setup would go here ---
  # --- Example placeholder model (replace with your actual model loading/setup)---
def simple_model(x):
    return jnp.sum(x)

key = jax.random.PRNGKey(0)
params = ... # Load your trained parameters

  # Integrate the TPU Monitoring Library to get HBM utilization before inference
hbm_util_metric = tpumonitoring.get_metric("hbm_util")
hbm_util_data = hbm_util_metric.data
print("HBM Utilization Before Inference:")
print(f"  Description: {hbm_util_metric.description}")
print(f"  Data: {hbm_util_data}")
  # End TPU Monitoring Library Integration

  # Your Inference Logic
input_data = jnp.ones((1, 10)) # Example input
predictions = simple_model(input_data)
print("Inference Predictions:", predictions)

print("Inference complete.")

TPU 指标的导出频率

TPU 指标的刷新频率仅限于最短 1 秒。主机指标数据以 1 Hz 的固定频率导出。此导出过程导致的延迟时间可以忽略不计。LibTPU 中的运行时指标不受相同的频率限制的约束。但是,为了保持一致性,这些指标的采样频率也为 1 Hz,即每秒 1 次采样。