TPU-Monitoring-Bibliothek
Mit den erweiterten TPU-Überwachungsfunktionen, die direkt auf der grundlegenden Softwareschicht LibTPU basieren, erhalten Sie detaillierte Einblicke in die Leistung und das Verhalten Ihrer Cloud TPU-Hardware. LibTPU umfasst Treiber, Netzwerkbibliotheken, den XLA-Compiler und die TPU-Laufzeit für die Interaktion mit TPUs. In diesem Dokument geht es jedoch hauptsächlich um die TPU Monitoring Library.
Die TPU Monitoring Library bietet:
Umfassende Beobachtbarkeit: Zugriff auf die Telemetrie-API und die Messwerte-Suite. So erhalten Sie detaillierte Informationen zur Betriebsleistung und zum spezifischen Verhalten Ihrer TPUs.
Diagnosetoolkits: Bietet ein SDK und eine Befehlszeilenschnittstelle (Command-Line Interface, CLI), mit denen Sie Ihre TPU-Ressourcen debuggen und detaillierte Leistungsanalysen durchführen können.
Diese Monitoring-Funktionen sind als kundenorientierte Lösung auf höchster Ebene konzipiert und bieten Ihnen die wichtigsten Tools, um Ihre TPU-Arbeitslasten effektiv zu optimieren.
Die TPU Monitoring Library bietet Ihnen detaillierte Informationen zur Leistung von ML-Arbeitslasten auf TPU-Hardware. Sie soll Ihnen helfen, die TPU-Auslastung zu verstehen, Engpässe zu identifizieren und Leistungsprobleme zu beheben. Sie erhalten detailliertere Informationen als mit Unterbrechungs- und Goodput-Messwerten sowie anderen Messwerten.
Erste Schritte mit der TPU Monitoring Library
Der Zugriff auf diese aussagekräftigen Statistiken ist ganz einfach. Die TPU-Monitoring-Funktion ist in das LibTPU SDK integriert. Sie ist also enthalten, wenn Sie LibTPU installieren.
LibTPU installieren
pip install libtpu
Alternativ werden die LibTPU-Updates mit JAX-Releases koordiniert. Wenn Sie also das neueste JAX-Release (monatlich veröffentlicht) installieren, wird in der Regel die neueste kompatible LibTPU-Version mit ihren Funktionen verwendet.
JAX installieren
pip install -U "jax[tpu]"
Für PyTorch-Nutzer bietet die Installation von PyTorch/XLA die neuesten LibTPU- und TPU-Monitoring-Funktionen.
PyTorch/XLA installieren
pip install torch~=2.6.0 'torch_xla[tpu]~=2.6.0' \
-f https://storage.googleapis.com/libtpu-releases/index.html \
-f https://storage.googleapis.com/libtpu-wheels/index.html
# Optional: if you're using custom kernels, install pallas dependencies
pip install 'torch_xla[pallas]' \
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
Weitere Informationen zur Installation von PyTorch/XLA finden Sie im PyTorch/XLA-GitHub-Repository unter Installation.
Bibliothek in Python importieren
Wenn Sie die TPU Monitoring Library verwenden möchten, müssen Sie das Modul libtpu
in Ihren Python-Code importieren.
from libtpu.sdk import tpumonitoring
Alle unterstützten Funktionen auflisten
Liste aller Messwertnamen und der Funktionen, die sie unterstützen:
from libtpu.sdk import tpumonitoring
tpumonitoring.help()
" libtpu.sdk.monitoring.help():
List all supported functionality.
libtpu.sdk.monitoring.list_support_metrics()
List support metric names in the list of str format.
libtpu.sdk.monitoring.get_metric(metric_name:str)
Get metric data with metric name. It represents the snapshot mode.
The metric data is a object with `description()` and `data()` methods,
where the `description()` returns a string describe the format of data
and data unit, `data()` returns the metric data in the list in str format.
"
Unterstützte Messwerte
Das folgende Codebeispiel zeigt, wie Sie alle unterstützten Messwertnamen auflisten:
from libtpu.sdk import tpumonitoring
tpumonitoring.list_supported_metrics()
["duty_cycle_pct", "tensorcore_util", "hbm_util", ...]
In der folgenden Tabelle sind alle Messwerte und die entsprechenden Definitionen aufgeführt:
Messwert | Definition | Messwertname für die API | Beispielwerte |
---|---|---|---|
Tensor Core-Auslastung | Misst den Prozentsatz der TensorCore-Nutzung, berechnet als Prozentsatz der Vorgänge, die Teil der TensorCore-Vorgänge sind. Alle 1 Sekunde wird eine Stichprobe von 10 Mikrosekunden erstellt. Die Sampling-Rate kann nicht geändert werden. Mit diesem Messwert können Sie die Effizienz Ihrer Arbeitslasten auf TPU-Geräten überwachen. |
tensorcore_util
|
['1.11', '2.22', '3.33', '4.44']
# utilization percentage for accelerator ID 0-3 |
Prozentsatz des Arbeitszyklus | Prozentsatz der Zeit im vergangenen Stichprobenzeitraum (alle 5 Sekunden; kann durch Festlegen des Flags LIBTPU_INIT_ARG angepasst werden), in der der Beschleuniger aktiv Daten verarbeitet hat (aufgezeichnet mit Zyklen, die zum Ausführen von HLO-Programmen im letzten Stichprobenzeitraum verwendet wurden). Dieser Messwert gibt an, wie ausgelastet eine TPU ist. Er wird pro Chip ausgegeben.
|
duty_cycle_pct
|
['10.00', '20.00', '30.00', '40.00']
# Duty cycle percentage for accelerator ID 0-3 |
HBM-Gesamtkapazität | Dieser Messwert gibt die gesamte HBM-Kapazität in Byte an. |
hbm_capacity_total
|
['30000000000', '30000000000', '30000000000', '30000000000']
# Gesamte HBM-Kapazität in Byte, die an die Beschleuniger-ID 0–3 angehängt ist |
HBM-Kapazitätsnutzung | Dieser Messwert gibt die Nutzung der HBM-Kapazität in Byte im letzten Stichprobenzeitraum (alle 5 Sekunden; kann durch Festlegen des Flags LIBTPU_INIT_ARG angepasst werden) an.
|
hbm_capacity_usage
|
['100', '200', '300', '400']
# Kapazitätsnutzung für HBM in Byte, die an die Beschleuniger-ID 0–3 angehängt ist |
Latenz der Pufferübertragung | Netzwerkübertragungslatenzen für Megascale-Multi-Slice-Traffic. Diese Visualisierung gibt Ihnen einen Überblick über die allgemeine Netzwerkumgebung. |
buffer_transfer_latency
|
["'8MB+', '2233.25', '2182.02', '3761.93', '19277.01', '53553.6'"]
# buffer size, mean, p50, p90, p99, p99.9 of network transfer latency distribution |
Messwerte für die Verteilung der Ausführungszeit von Vorgängen auf hoher Ebene | Bietet detaillierte Informationen zum Ausführungsstatus des kompilierten HLO-Binärprogramms, sodass Regressionen erkannt und Fehler auf Modellebene behoben werden können. |
hlo_exec_timing
|
["'tensorcore-0', '10.00', '10.00', '20.00', '30.00', '40.00'"]
# The HLO execution time duration distribution for CoreType-CoreID with mean, p50, p90, p95, p999 |
Warteschlangengröße für High Level Optimizer | Mit der Überwachung der Größe der HLO-Ausführungswarteschlange wird die Anzahl der kompilierten HLO-Programme erfasst, die auf die Ausführung warten oder gerade ausgeführt werden. Dieser Messwert gibt Aufschluss über Engpässe in der Ausführungspipeline und ermöglicht die Identifizierung von Leistungsengpässen bei der Hardwareausführung, beim Treiber-Overhead oder bei der Ressourcenzuweisung. |
hlo_queue_size
|
["tensorcore-0: 1", "tensorcore-1: 2"]
# Measures queue size for CoreType-CoreID. |
Gesamt-End-to-End-Latenz | Dieser Messwert gibt die kollektive End-to-End-Latenz über das DCN in Mikrosekunden an, vom Host, der den Vorgang initiiert, bis alle Peers die Ausgabe empfangen. Dazu gehören die datenseitige Reduzierung auf dem Host und das Senden der Ausgabe an die TPU. Die Ergebnisse sind Strings mit Details zu Puffergröße, Typ und mittleren, p50-, p90-, p95- und p99,9-Latenzen. |
collective_e2e_latency
|
["8MB+-ALL_REDUCE, 1000, 2000, 3000, 4000, 5000", …]
# Transfer size-collective op, mean, p50, p90, p95, p999 of collective end to end latency |
Messwertdaten lesen – Snapshot-Modus
Wenn Sie den Snapshot-Modus aktivieren möchten, geben Sie den Messwertnamen an, wenn Sie die Funktion tpumonitoring.get_metric
aufrufen. Im Snapshot-Modus können Sie Ad-hoc-Messwertprüfungen in Code mit geringer Leistung einfügen, um festzustellen, ob Leistungsprobleme auf Software oder Hardware zurückzuführen sind.
Das folgende Codebeispiel zeigt, wie der Snapshot-Modus zum Lesen von duty_cycle
verwendet wird.
from libtpu.sdk import tpumonitoring
metric = tpumonitoring.get_metric("duty_cycle_pct")
metric.description()
"The metric provides a list of duty cycle percentages, one for each
accelerator (from accelerator_0 to accelerator_x). The duty cycle represents
the percentage of time an accelerator was actively processing during the
last sample period, indicating TPU utilization."
metric.data()
["0.00", "0.00", "0.00", "0.00"]
# accelerator_0-3
Über die Befehlszeile auf Messwerte zugreifen
In den folgenden Schritten wird gezeigt, wie Sie über die CLI mit LibTPU-Messwerten interagieren:
Installieren Sie
tpu-info
:pip install tpu-info
# Access help information of tpu-info tpu-info --help / -h
Führen Sie die Standardversion von
tpu-info
aus:tpu-info
Die Ausgabe sieht etwa so aus:
TPU Chips
┏━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━┓
┃ Chip ┃ Type ┃ Devices ┃ PID ┃
┡━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━┩
│ /dev/accel0 │ TPU v4 chip │ 1 │ 130007 │
│ /dev/accel1 │ TPU v4 chip │ 1 │ 130007 │
│ /dev/accel2 │ TPU v4 chip │ 1 │ 130007 │
│ /dev/accel3 │ TPU v4 chip │ 1 │ 130007 │
└─────────────┴─────────────┴─────────┴────────┘
TPU Runtime Utilization
┏━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Device ┃ Memory usage ┃ Duty cycle ┃
┡━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ 0 │ 0.00 GiB / 31.75 GiB │ 0.00% │
│ 1 │ 0.00 GiB / 31.75 GiB │ 0.00% │
│ 2 │ 0.00 GiB / 31.75 GiB │ 0.00% │
│ 3 │ 0.00 GiB / 31.75 GiB │ 0.00% │
└────────┴──────────────────────┴────────────┘
TensorCore Utilization
┏━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Chip ID ┃ TensorCore Utilization ┃
┡━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩
│ 0 │ 0.00% │
│ 1 │ 0.00% │
│ 3 │ 0.00% │
│ 2 │ 0.00% |
└─────────┴────────────────────────┘
Buffer Transfer Latency
┏━━━━━━━━━━━━━┳━━━━━┳━━━━━┳━━━━━┳━━━━━━┓
┃ Buffer Size ┃ P50 ┃ P90 ┃ P95 ┃ P999 ┃
┡━━━━━━━━━━━━━╇━━━━━╇━━━━━╇━━━━━╇━━━━━━┩
│ 8MB+ | 0us │ 0us │ 0us │ 0us |
└─────────────┴─────┴─────┴─────┴──────┘
Messwerte zur Überprüfung der TPU-Auslastung verwenden
In den folgenden Beispielen wird gezeigt, wie Sie Messwerte aus der TPU Monitoring Library verwenden, um die TPU-Auslastung zu verfolgen.
TPU-Arbeitszyklus während des JAX-Trainings überwachen
Szenario:Sie führen ein JAX-Trainingsskript aus und möchten den duty_cycle_pct
-Messwert der TPU während des gesamten Trainingsprozesses überwachen, um zu bestätigen, dass Ihre TPUs effektiv genutzt werden. Sie können diesen Messwert während des Trainings regelmäßig protokollieren, um die TPU-Auslastung zu verfolgen.
Das folgende Codebeispiel zeigt, wie Sie den TPU-Duty-Cycle während des JAX-Trainings überwachen:
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
import time
# --- Your JAX model and training setup would go here ---
# --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
return jnp.sum(x)
def loss_fn(params, x, y):
preds = simple_model(x)
return jnp.mean((preds - y)**2)
def train_step(params, x, y, optimizer):
grads = jax.grad(loss_fn)(params, x, y)
return optimizer.update(grads, params)
key = jax.random.PRNGKey(0)
params = jnp.array([1.0, 2.0]) # Example params
optimizer = ... # Your optimizer (for example, optax.adam)
data_x = jnp.ones((10, 10))
data_y = jnp.zeros((10,))
num_epochs = 10
log_interval_steps = 2 # Log duty cycle every 2 steps
for epoch in range(num_epochs):
for step in range(5): # Example steps per epoch
params = train_step(params, data_x, data_y, optimizer)
if (step + 1) % log_interval_steps == 0:
# --- Integrate TPU Monitoring Library here to get duty_cycle ---
duty_cycle_metric = tpumonitoring.get_metric("duty_cycle_pct")
duty_cycle_data = duty_cycle_metric.data
print(f"Epoch {epoch+1}, Step {step+1}: TPU Duty Cycle Data:")
print(f" Description: {duty_cycle_metric.description}")
print(f" Data: {duty_cycle_data}")
# --- End TPU Monitoring Library Integration ---
# --- Rest of your training loop logic ---
time.sleep(0.1) # Simulate some computation
print("Training complete.")
HBM-Auslastung vor dem Ausführen der JAX-Inferenz prüfen
Szenario : Bevor Sie die Inferenz mit Ihrem JAX-Modell ausführen, prüfen Sie die aktuelle HBM-Auslastung (High Bandwidth Memory) auf der TPU, um zu bestätigen, dass genügend Arbeitsspeicher verfügbar ist, und um eine Basismessung zu erhalten, bevor die Inferenz beginnt.
# The following code sample shows how to check HBM utilization before JAX inference:
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
# --- Your JAX model and inference setup would go here ---
# --- Example placeholder model (replace with your actual model loading/setup)---
def simple_model(x):
return jnp.sum(x)
key = jax.random.PRNGKey(0)
params = ... # Load your trained parameters
# Integrate the TPU Monitoring Library to get HBM utilization before inference
hbm_util_metric = tpumonitoring.get_metric("hbm_util")
hbm_util_data = hbm_util_metric.data
print("HBM Utilization Before Inference:")
print(f" Description: {hbm_util_metric.description}")
print(f" Data: {hbm_util_data}")
# End TPU Monitoring Library Integration
# Your Inference Logic
input_data = jnp.ones((1, 10)) # Example input
predictions = simple_model(input_data)
print("Inference Predictions:", predictions)
print("Inference complete.")
Exporthäufigkeit von TPU-Messwerten
Die Aktualisierungshäufigkeit von TPU-Messwerten ist auf mindestens eine Sekunde begrenzt. Messwertdaten für Hosts werden mit einer festen Häufigkeit von 1 Hz exportiert. Die durch diesen Exportprozess verursachte Latenz ist vernachlässigbar. Für Laufzeitmesswerte aus LibTPU gilt diese Häufigkeitseinschränkung nicht. Aus Konsistenzgründen werden diese Messwerte jedoch auch mit 1 Hz oder 1 Stichprobe pro Sekunde erfasst.