Biblioteca de monitoramento de TPU
Tenha insights detalhados sobre o desempenho e o comportamento do hardware da Cloud TPU com recursos avançados de monitoramento, criados diretamente na camada de software fundamental, a LibTPU. Embora a LibTPU inclua drivers, bibliotecas de rede, o compilador XLA e o ambiente de execução de TPU para interagir com TPUs, o foco deste documento é a biblioteca de monitoramento de TPU.
A biblioteca de monitoramento de TPU oferece:
Observabilidade abrangente: tenha acesso à API de telemetria e ao conjunto de métricas. Assim, você recebe insights detalhados sobre a performance operacional e comportamentos específicos das suas TPUs.
Kits de ferramentas de diagnóstico: fornecem um SDK e uma interface de linha de comando (CLI) projetados para permitir a depuração e a análise detalhada de performance dos seus recursos de TPU.
Esses recursos de monitoramento foram projetados para serem uma solução de alto nível voltada ao cliente, oferecendo as ferramentas essenciais para otimizar suas cargas de trabalho de TPU de maneira eficaz.
A biblioteca de monitoramento de TPU oferece informações detalhadas sobre o desempenho das cargas de trabalho de machine learning no hardware de TPU. Ele foi projetado para ajudar você a entender a utilização da TPU, identificar gargalos e depurar problemas de desempenho. Ela oferece informações mais detalhadas do que as métricas de interrupção, de goodput e outras.
Começar a usar a biblioteca de monitoramento de TPU
É fácil acessar esses insights valiosos. A funcionalidade de monitoramento de TPU é integrada ao SDK LibTPU. Portanto, ela é incluída quando você instala o LibTPU.
Instalar a LibTPU
pip install libtpu
Como alternativa, as atualizações do LibTPU são coordenadas com os lançamentos do JAX. Isso significa que, ao instalar a versão mais recente do JAX (lançada mensalmente), você geralmente terá acesso à versão mais recente compatível do LibTPU e aos recursos dela.
Instalar o JAX
pip install -U "jax[tpu]"
Para usuários do PyTorch, a instalação do PyTorch/XLA oferece a funcionalidade mais recente de LibTPU e monitoramento de TPU.
Instalar o PyTorch/XLA
pip install torch~=2.6.0 'torch_xla[tpu]~=2.6.0' \
-f https://storage.googleapis.com/libtpu-releases/index.html \
-f https://storage.googleapis.com/libtpu-wheels/index.html
# Optional: if you're using custom kernels, install pallas dependencies
pip install 'torch_xla[pallas]' \
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
Para mais informações sobre como instalar o PyTorch/XLA, consulte Instalação no repositório do PyTorch/XLA no GitHub.
Importar a biblioteca em Python
Para começar a usar a biblioteca de monitoramento de TPU, importe o módulo libtpu
no seu código Python.
from libtpu.sdk import tpumonitoring
Listar todas as funcionalidades compatíveis
Liste todos os nomes de métricas e a funcionalidade que eles oferecem:
from libtpu.sdk import tpumonitoring
tpumonitoring.help()
" libtpu.sdk.monitoring.help():
List all supported functionality.
libtpu.sdk.monitoring.list_support_metrics()
List support metric names in the list of str format.
libtpu.sdk.monitoring.get_metric(metric_name:str)
Get metric data with metric name. It represents the snapshot mode.
The metric data is a object with `description()` and `data()` methods,
where the `description()` returns a string describe the format of data
and data unit, `data()` returns the metric data in the list in str format.
"
Métricas compatíveis
O exemplo de código a seguir mostra como listar todos os nomes de métricas aceitos:
from libtpu.sdk import tpumonitoring
tpumonitoring.list_supported_metrics()
["duty_cycle_pct", "tensorcore_util", "hbm_util", ...]
A tabela a seguir mostra todas as métricas e as definições correspondentes:
Métrica | Definição | Nome da métrica para a API | Exemplos de valores |
---|---|---|---|
Utilização do Tensor Core | Mede a porcentagem do uso do TensorCore, calculada como a porcentagem de operações que fazem parte das operações do TensorCore. Amostras coletadas a cada 1 segundo. Não é possível modificar a taxa de amostragem. Com essa métrica, é possível monitorar a eficiência das cargas de trabalho em dispositivos TPU. |
tensorcore_util
|
['1.11', '2.22', '3.33', '4.44']
# utilization percentage for accelerator ID 0-3 |
Porcentagem do ciclo de trabalho | Porcentagem de tempo durante o período de amostra anterior (a cada 5 segundos; pode ser ajustada definindo a flag LIBTPU_INIT_ARG ) em que o acelerador estava processando ativamente (registrado com ciclos usados para executar programas HLO durante o último período de amostragem). Essa métrica representa o grau de ocupação de uma TPU e é emitida por chip.
|
duty_cycle_pct
|
['10.00', '20.00', '30.00', '40.00']
# Porcentagem do ciclo de trabalho para o ID do acelerador 0 a 3 |
Capacidade total de HBM | Essa métrica informa a capacidade total de HBM em bytes. |
hbm_capacity_total
|
['30000000000', '30000000000', '30000000000', '30000000000']
# Capacidade total de HBM em bytes anexada aos ID de acelerador 0 a 3 |
Uso da capacidade de HBM | Essa métrica informa o uso da capacidade de HBM em bytes no último período de amostragem (a cada 5 segundos, que pode ser ajustado definindo a flag LIBTPU_INIT_ARG ).
|
hbm_capacity_usage
|
['100', '200', '300', '400']
# Uso da capacidade para HBM em bytes anexados ao ID do acelerador 0-3 |
Latência de transferência de buffer | Latências de transferência de rede para tráfego multislice em megascale. Com essa visualização, você entende o ambiente geral de performance da rede. |
buffer_transfer_latency
|
["'8MB+', '2233.25', '2182.02', '3761.93', '19277.01', '53553.6'"]
# tamanho do buffer, média, p50, p90, p99, p99.9 da distribuição de latência de transferência de rede |
Métricas de distribuição do tempo de execução de operações de alto nível | Fornece insights granulares de desempenho sobre o status de execução do binário compilado do HLO, permitindo a detecção de regressão e a depuração no nível do modelo. |
hlo_exec_timing
|
["'tensorcore-0', '10.00', '10.00', '20.00', '30.00', '40.00'"]
# A distribuição da duração do tempo de execução do HLO para CoreType-CoreID com média, p50, p90, p95, p999 |
Tamanho da fila do otimizador de alto nível | O monitoramento do tamanho da fila de execução de HLO rastreia o número de programas HLO compilados aguardando ou em execução. Essa métrica revela congestionamento no pipeline de execução, permitindo a identificação de gargalos de desempenho na execução de hardware, sobrecarga de driver ou alocação de recursos. |
hlo_queue_size
|
["tensorcore-0: 1", "tensorcore-1: 2"]
# Mede o tamanho da fila para CoreType-CoreID. |
Latência coletiva de ponta a ponta | Essa métrica mede a latência coletiva de ponta a ponta no DCN em microssegundos, desde o host que inicia a operação até todos os peers que recebem a saída. Isso inclui a redução de dados do lado do host e o envio da saída para a TPU. Os resultados são strings que detalham o tamanho, o tipo e as latências média, p50, p90, p95 e p99,9 do buffer. |
collective_e2e_latency
|
["8MB+-ALL_REDUCE, 1000, 2000, 3000, 4000, 5000", …]
# Transfer size-collective op, mean, p50, p90, p95, p999 of collective end to end latency |
Ler dados de métricas: modo de snapshot
Para ativar o modo de snapshot, especifique o nome da métrica ao chamar a função
tpumonitoring.get_metric
. O modo
de snapshot permite inserir verificações de métricas ad hoc em códigos de baixo desempenho para
identificar se os problemas de performance são causados por software ou hardware.
O exemplo de código a seguir mostra como usar o modo de snapshot para ler o duty_cycle
.
from libtpu.sdk import tpumonitoring
metric = tpumonitoring.get_metric("duty_cycle_pct")
metric.description()
"The metric provides a list of duty cycle percentages, one for each
accelerator (from accelerator_0 to accelerator_x). The duty cycle represents
the percentage of time an accelerator was actively processing during the
last sample period, indicating TPU utilization."
metric.data()
["0.00", "0.00", "0.00", "0.00"]
# accelerator_0-3
Acessar métricas usando a CLI
As etapas a seguir mostram como interagir com as métricas da LibTPU usando a CLI:
Instale
tpu-info
:pip install tpu-info
# Access help information of tpu-info tpu-info --help / -h
Execute a visão padrão de
tpu-info
:tpu-info
O resultado será assim:
TPU Chips
┏━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━┓
┃ Chip ┃ Type ┃ Devices ┃ PID ┃
┡━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━┩
│ /dev/accel0 │ TPU v4 chip │ 1 │ 130007 │
│ /dev/accel1 │ TPU v4 chip │ 1 │ 130007 │
│ /dev/accel2 │ TPU v4 chip │ 1 │ 130007 │
│ /dev/accel3 │ TPU v4 chip │ 1 │ 130007 │
└─────────────┴─────────────┴─────────┴────────┘
TPU Runtime Utilization
┏━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Device ┃ Memory usage ┃ Duty cycle ┃
┡━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ 0 │ 0.00 GiB / 31.75 GiB │ 0.00% │
│ 1 │ 0.00 GiB / 31.75 GiB │ 0.00% │
│ 2 │ 0.00 GiB / 31.75 GiB │ 0.00% │
│ 3 │ 0.00 GiB / 31.75 GiB │ 0.00% │
└────────┴──────────────────────┴────────────┘
TensorCore Utilization
┏━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Chip ID ┃ TensorCore Utilization ┃
┡━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩
│ 0 │ 0.00% │
│ 1 │ 0.00% │
│ 3 │ 0.00% │
│ 2 │ 0.00% |
└─────────┴────────────────────────┘
Buffer Transfer Latency
┏━━━━━━━━━━━━━┳━━━━━┳━━━━━┳━━━━━┳━━━━━━┓
┃ Buffer Size ┃ P50 ┃ P90 ┃ P95 ┃ P999 ┃
┡━━━━━━━━━━━━━╇━━━━━╇━━━━━╇━━━━━╇━━━━━━┩
│ 8MB+ | 0us │ 0us │ 0us │ 0us |
└─────────────┴─────┴─────┴─────┴──────┘
Usar métricas para verificar a utilização da TPU
Os exemplos a seguir mostram como usar métricas da biblioteca de monitoramento de TPU para acompanhar a utilização da TPU.
Monitorar o ciclo de trabalho da TPU durante o treinamento do JAX
Cenário:você está executando um script de treinamento do JAX e quer monitorar a métrica duty_cycle_pct
da TPU durante todo o processo de treinamento para confirmar se as TPUs estão sendo usadas de maneira eficaz. É possível registrar essa métrica periodicamente durante o treinamento para acompanhar a utilização da TPU.
O exemplo de código a seguir mostra como monitorar o ciclo de trabalho da TPU durante o treinamento do JAX:
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
import time
# --- Your JAX model and training setup would go here ---
# --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
return jnp.sum(x)
def loss_fn(params, x, y):
preds = simple_model(x)
return jnp.mean((preds - y)**2)
def train_step(params, x, y, optimizer):
grads = jax.grad(loss_fn)(params, x, y)
return optimizer.update(grads, params)
key = jax.random.PRNGKey(0)
params = jnp.array([1.0, 2.0]) # Example params
optimizer = ... # Your optimizer (for example, optax.adam)
data_x = jnp.ones((10, 10))
data_y = jnp.zeros((10,))
num_epochs = 10
log_interval_steps = 2 # Log duty cycle every 2 steps
for epoch in range(num_epochs):
for step in range(5): # Example steps per epoch
params = train_step(params, data_x, data_y, optimizer)
if (step + 1) % log_interval_steps == 0:
# --- Integrate TPU Monitoring Library here to get duty_cycle ---
duty_cycle_metric = tpumonitoring.get_metric("duty_cycle_pct")
duty_cycle_data = duty_cycle_metric.data
print(f"Epoch {epoch+1}, Step {step+1}: TPU Duty Cycle Data:")
print(f" Description: {duty_cycle_metric.description}")
print(f" Data: {duty_cycle_data}")
# --- End TPU Monitoring Library Integration ---
# --- Rest of your training loop logic ---
time.sleep(0.1) # Simulate some computation
print("Training complete.")
Verificar a utilização da HBM antes de executar a inferência do JAX
Cenário : antes de executar a inferência com seu modelo JAX, verifique a utilização atual da HBM (memória de alta largura de banda) na TPU para confirmar se você tem memória suficiente disponível e para receber uma medição de base antes do início da inferência.
# The following code sample shows how to check HBM utilization before JAX inference:
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
# --- Your JAX model and inference setup would go here ---
# --- Example placeholder model (replace with your actual model loading/setup)---
def simple_model(x):
return jnp.sum(x)
key = jax.random.PRNGKey(0)
params = ... # Load your trained parameters
# Integrate the TPU Monitoring Library to get HBM utilization before inference
hbm_util_metric = tpumonitoring.get_metric("hbm_util")
hbm_util_data = hbm_util_metric.data
print("HBM Utilization Before Inference:")
print(f" Description: {hbm_util_metric.description}")
print(f" Data: {hbm_util_data}")
# End TPU Monitoring Library Integration
# Your Inference Logic
input_data = jnp.ones((1, 10)) # Example input
predictions = simple_model(input_data)
print("Inference Predictions:", predictions)
print("Inference complete.")
Frequência de exportação das métricas da TPU
A frequência de atualização das métricas de TPU é limitada a um mínimo de um segundo. Os dados de métricas do host são exportados em uma frequência fixa de 1 Hz. A latência introduzida por esse processo de exportação é insignificante. As métricas de tempo de execução da LibTPU não estão sujeitas à mesma restrição de frequência. No entanto, para fins de consistência, essas métricas também são amostradas a 1 Hz ou 1 amostra por segundo.