Introdução ao Trillium (v6e)

v6e é usado para se referir ao Trillium nesta documentação, na API TPU e nos registros. v6e representa a 6ª geração de TPUs do Google.

Com 256 chips por Pod, a arquitetura v6e compartilha muitas semelhanças com a v5e. Esse sistema é otimizado para treinamento, ajuste e serviço de transformadores, texto para imagem e redes neurais convolucionais (CNNs).

Para mais informações sobre a arquitetura e as configurações do sistema v6e, consulte TPU v6e.

Este documento de introdução se concentra nos processos de treinamento e veiculação de modelos usando os frameworks JAX ou PyTorch. Com cada framework, é possível provisionar TPUs usando recursos enfileirados ou o GKE. A configuração do GKE pode ser feita usando comandos XPK ou do GKE.

Procedimento geral para treinar ou veicular um modelo usando a v6e

  1. Preparar um Google Cloud projeto
  2. Capacidade segura
  3. Provisione o ambiente do Cloud TPU
  4. Executar uma carga de trabalho de treinamento ou inferência de modelo

Preparar um projeto do Google Cloud

Antes de usar o Cloud TPU, você precisa:

  • Crie uma Google Cloud conta e um projeto com o faturamento ativado
  • Instalar os componentes Alfa da Google Cloud CLI
  • Ativar a API Cloud TPU
  • Criar um agente de serviço da Cloud TPU
  • Criar uma conta de serviço do Cloud TPU e conceder permissões

Para mais informações, consulte Configurar o ambiente do Cloud TPU.

Capacidade segura

Entre em contato com o suporte doGoogle Cloud para solicitar a cota de Cloud TPU v6e e responder a dúvidas sobre capacidade.

Provisionar o ambiente da Cloud TPU

A v6e Cloud TPU pode ser provisionada e gerenciada com o GKE, com o GKE e o XPK (uma ferramenta CLI wrapper no GKE) ou como recursos enfileirados.

Pré-requisitos

  • Verifique se o projeto tem cota de TPUS_PER_TPU_FAMILY suficiente, que especifica o número máximo de chips que podem ser acessados no projeto Google Cloud.
  • A v6e foi testada com a seguinte configuração:
    • Python 3.10 ou mais recente
    • Versões noturnas do software:
      • JAX noturno 0.4.32.dev20240912
      • LibTPU noturna 0.1.dev20240912+nightly
    • Versões estáveis do software:
      • JAX + JAX Lib v0.4.37
  • Verifique se o projeto tem cota suficiente para:

    • Cota de VM da Cloud TPU
    • Cota de endereços IP
    • Cota para o Hyperdisk Balanced e para qualquer outro tipo de disco que você queira usar

  • Se você estiver usando o GKE com o XPK, consulte Permissões do console do Cloud na conta de usuário ou de serviço para saber quais permissões são necessárias para executar o XPK.

Criar variáveis de ambiente

No Cloud Shell, crie as seguintes variáveis de ambiente:

export NODE_ID=your-tpu-name
export PROJECT_ID=your-project-id
export ACCELERATOR_TYPE=v6e-16
export ZONE=us-east1-d
export RUNTIME_VERSION=v2-alpha-tpuv6e
export SERVICE_ACCOUNT=your-service-account
export QUEUED_RESOURCE_ID=your-queued-resource-id
export VALID_DURATION=your-duration 

# Additional environment variable needed for Multislice:
export NUM_SLICES=number-of-slices

# Use a custom network for better performance as well as to avoid having the default network becoming overloaded.
export NETWORK_NAME=${PROJECT_ID}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw

Descrições de sinalizações de comando

Variável Descrição
NODE_ID O ID atribuído pelo usuário da Cloud TPU criada quando a solicitação de recurso enfileirada é alocada.
PROJECT_ID Google Cloud nome do projeto. Use um projeto atual ou crie um novo. Para mais informações, consulte Configurar seu Google Cloud projeto.
ZONA Consulte o documento Regiões e zonas do Cloud TPU para saber quais são as zonas compatíveis.
ACCELERATOR_TYPE Consulte Tipos de aceleradores.
RUNTIME_VERSION v2-alpha-tpuv6e
SERVICE_ACCOUNT Esse é o endereço de e-mail da sua conta de serviço, que pode ser encontrado em Google Cloud Console -> IAM -> Contas de serviço

Por exemplo: tpu-service-account@your-project-ID.iam.gserviceaccount.com.com

NUM_SLICES O número de intervalos a serem criados (necessário apenas para Multislice).
QUEUED_RESOURCE_ID O ID de texto atribuído pelo usuário da solicitação de recurso em fila.
VALID_DURATION O período em que a solicitação de recurso na fila é válida.
NETWORK_NAME O nome de uma rede secundária a ser usada.
NETWORK_FW_NAME O nome de um firewall de rede secundário a ser usado.

Otimizar o desempenho da rede

Para ter o melhor desempenho,use uma rede com MTU (unidade máxima de transmissão) de 8.896.

Por padrão, uma nuvem privada virtual (VPC) só fornece uma MTU de 1.460 bytes,o que resulta em um desempenho de rede abaixo do ideal. É possível definir a MTU de uma rede VPC como qualquer valor entre 1.300 e 8.896 bytes (inclusive). Os tamanhos comuns de MTU personalizados são 1.500 bytes (Ethernet padrão) ou 8.896 bytes (o máximo possível). Para mais informações, consulte Tamanhos válidos de MTU da rede VPC.

Para mais informações sobre como mudar a configuração de MTU de uma rede padrão ou atual, consulte Alterar a configuração de MTU de uma rede VPC.

O exemplo a seguir cria uma rede com MTU de 8.896.

export RESOURCE_NAME=your-resource-name
export NETWORK_NAME=${RESOURCE_NAME}-privatenetwork
export NETWORK_FW_NAME=${RESOURCE_NAME}-privatefirewall
gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT_ID} \
   --subnet-mode=auto --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network=${NETWORK_NAME} \
   --allow tcp,icmp,udp --project=${PROJECT_ID}

Usar multi-NIC (opção para Multislice)

As seguintes variáveis de ambiente são necessárias para uma sub-rede secundária ao usar um ambiente Multislice.

export NETWORK_NAME_2=${RESOURCE_NAME}
export SUBNET_NAME_2=${RESOURCE_NAME}
export FIREWALL_RULE_NAME=${RESOURCE_NAME}
export ROUTER_NAME=${RESOURCE_NAME}-network-2
export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2
export REGION=your-region

Use os comandos a seguir para criar o roteamento de IP personalizado para a rede e a sub-rede.

gcloud compute networks create ${NETWORK_NAME_2} --mtu=8896 \
   --bgp-routing-mode=regional --subnet-mode=custom --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_2} \
   --network=${NETWORK_NAME_2} \
   --range=10.10.0.0/18 --region=${REGION} \
   --project=${PROJECT_ID}

gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_2} --allow tcp,icmp,udp \
   --source-ranges 10.10.0.0/18 --project=${PROJECT_ID}

gcloud compute routers create ${ROUTER_NAME} \
   --project=${PROJECT_ID} \
   --network=${NETWORK_NAME_2} \
   --region=${REGION}

gcloud compute routers nats create ${NAT_CONFIG} \
   --router=${ROUTER_NAME} \
   --region=${REGION} \
   --auto-allocate-nat-external-ips \
   --nat-all-subnet-ip-ranges \
   --project=${PROJECT_ID} \
   --enable-logging

Depois de criar uma fatia de várias redes, valide se as duas placas de interface de rede (NICs) estão sendo usadas configurando um cluster XPK e adicionando a flag --command ifconfig ao comando de criação de carga de trabalho XPK.

Use o comando workload create a seguir para mostrar a saída do comando ifconfig nos registros do console Google Cloud e verifique se eth0 e eth1 têm mtu=8896.

python3 xpk.py workload create \
   --cluster CLUSTER_NAME \
   {--base-docker-image maxtext_base_image | --docker-image your-cloud-image-name} \
   --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
   --tpu-type=${ACCELERATOR_TYPE} \
   --num-slices=${NUM_SLICES}  \
   --on-demand \
   --zone=${ZONE} \
   --project=${PROJECT_ID} \
   --command "ifconfig"

Se você quiser ativar os registros de depuração ou usar o TensorBoard da Vertex AI, adicione os seguintes argumentos opcionais ao comando:

   --enable-debug-logs \
   --use-vertex-tensorboard

Verifique se eth0 e eth1 têm mtu=8896. Para verificar se a multi-NIC está em execução, adicione a flag --command ifconfig ao comando de criação da carga de trabalho XPK. Verifique a saída dessa carga de trabalho XPK nos registros do console Google Cloud e confira se eth0 e eth1 têm mtu=8.896.

Melhorar as configurações de TCP

Se você criou as Cloud TPUs usando a interface de recursos enfileirados, execute o comando a seguir para melhorar o desempenho da rede aumentando os limites do buffer de recebimento de TCP.

gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \
   --project "${PROJECT_ID}" \
   --zone "${ZONE}" \
   --node=all \
   --worker=all \
   --command='
   sudo sh -c "echo \"4096 41943040 314572800\" > /proc/sys/net/ipv4/tcp_rmem"'

Provisionar com recursos na fila

É possível criar uma Cloud TPU v6e usando recursos enfileirados. Com os recursos enfileirados, você recebe capacidade assim que ela fica disponível. É possível especificar um horário de início e término opcional para quando a solicitação deve ser atendida. Para mais informações, consulte Gerenciar recursos em fila.

Provisionar Cloud TPUs v6e com GKE ou XPK

Se você estiver usando comandos do GKE com v6e, use comandos do Kubernetes ou XPK para provisionar TPUs do Cloud e treinar ou veicular modelos. Consulte Planejar o uso de Cloud TPUs no GKE para saber como planejar as configurações da Cloud TPU em clusters do GKE. As seções a seguir fornecem comandos para criar um cluster XPK com suporte para uma única NIC e várias NICs.

Criar um cluster XPK com suporte a uma única NIC

export CLUSTER_NAME=xpk-cluster-name
export ZONE=us-east1-d
export PROJECT_ID=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME=${CLUSTER_NAME}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw
gcloud compute networks create ${NETWORK_NAME} \
   --mtu=8896 \
   --project=${PROJECT_ID} \
   --subnet-mode=auto \
   --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} \
   --network=${NETWORK_NAME} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}"
python3 xpk.py cluster create --cluster=${CLUSTER_NAME} \
   --cluster-cpu-machine-type=e2-standard-8 \
   --num-slices=${NUM_SLICES} \
   --tpu-type=${TPU_TYPE} \
   --zone=${ZONE} \
   --project=${PROJECT_ID} \
   --on-demand \
   --custom-cluster-arguments="${CLUSTER_ARGUMENTS}"  \
   --create-vertex-tensorboard

Descrições de sinalizações de comando

Variável Descrição
CLUSTER_NAME O nome atribuído pelo usuário para o cluster XPK.
PROJECT_ID Google Cloud nome do projeto. Use um projeto atual ou crie um novo. Para mais informações, consulte Configurar seu Google Cloud projeto.
ZONA Consulte o documento Regiões e zonas do Cloud TPU para saber quais são as zonas compatíveis.
TPU_TYPE Consulte Tipos de aceleradores.
NUM_SLICES O número de intervalos que você quer criar
CLUSTER_ARGUMENTS A rede e a sub-rede a serem usadas.

Por exemplo: --network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}

NUM_SLICES O número de intervalos a serem criados.
NETWORK_NAME O nome de uma rede secundária a ser usada.
NETWORK_FW_NAME O nome de um firewall de rede secundário a ser usado.

Criar um cluster XPK com suporte a várias NICs

export CLUSTER_NAME=xpk-cluster-name
export REGION=your-region
export ZONE=us-east1-d
export PROJECT_ID=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME_1=${CLUSTER_NAME}-mtu9k-1-${ZONE}
export SUBNET_NAME_1=${CLUSTER_NAME}-privatesubnet-1-${ZONE}
export NETWORK_FW_NAME_1=${NETWORK_NAME_1}-fw-1-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-1-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-1-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-1-${ZONE}
gcloud compute networks create ${NETWORK_NAME_1} \
   --mtu=8896 \
   --bgp-routing-mode=regional \
   --subnet-mode=custom \
   --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_1} \
   --network=${NETWORK_NAME_1} \
   --range=10.11.0.0/18 \
   --region=${REGION} \
   --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_1} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \
   --project=${PROJECT_ID} \
   --network=${NETWORK_NAME_1} \
   --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \
   --router=${ROUTER_NAME} \
   --region=${REGION} \
   --auto-allocate-nat-external-ips \
   --nat-all-subnet-ip-ranges \
   --project=${PROJECT_ID} \
   --enable-logging
# Secondary subnet for multi-nic experience.
# Need custom IP routing to be different from the first network's subnet.

export NETWORK_NAME_2=${CLUSTER_NAME}-privatenetwork-2-${ZONE}
export SUBNET_NAME_2=${CLUSTER_NAME}-privatesubnet-2-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-2-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-2-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-2-${ZONE}
gcloud compute networks create ${NETWORK_NAME_2} \
   --mtu=8896 \
   --bgp-routing-mode=regional \
   --subnet-mode=custom \
   --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_2} \
   --network=${NETWORK_NAME_2} \
   --range=10.10.0.0/18 \
   --region=${REGION} \
   --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_2} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \
   --project=${PROJECT_ID} \
   --network=${NETWORK_NAME_2} \
   --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \
   --router=${ROUTER_NAME} \
   --region=${REGION} \
   --auto-allocate-nat-external-ips \
   --nat-all-subnet-ip-ranges \
   --project=${PROJECT_ID} \
   --enable-logging
export CLUSTER_ARGUMENTS="--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}"
export NODE_POOL_ARGUMENTS="--additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}"
python3 xpk.py cluster create \
   --cluster=${CLUSTER_NAME} \
   --cluster-cpu-machine-type=e2-standard-8 \
   --num-slices=${NUM_SLICES} \
   --tpu-type=${TPU_TYPE} \
   --zone=${ZONE}  \
   --project=${PROJECT_ID} \
   --on-demand \
   --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \
   --custom-nodepool-arguments="${NODE_POOL_ARGUMENTS}" \
   --create-vertex-tensorboard

Descrições de sinalizações de comando

Variável Descrição
CLUSTER_NAME O nome atribuído pelo usuário para o cluster XPK.
PROJECT_ID Google Cloud nome do projeto. Use um projeto atual ou crie um novo. Para mais informações, consulte Configurar seu Google Cloud projeto.
ZONA Consulte o documento Regiões e zonas do Cloud TPU para saber quais são as zonas compatíveis.
TPU_TYPE Consulte Tipos de aceleradores.
NUM_SLICES O número de intervalos que você quer criar
CLUSTER_ARGUMENTS A rede e a sub-rede a serem usadas.

Por exemplo: --enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}

NODE_POOL_ARGUMENTS Rede de nós adicional a ser usada.

Por exemplo: --additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}

NUM_SLICES O número de intervalos a serem criados (necessário apenas para Multislice).
NETWORK_NAME O nome de uma rede secundária a ser usada.
NETWORK_FW_NAME O nome de um firewall de rede secundário a ser usado.

Configuração do framework

Nesta seção, descrevemos o processo geral de configuração para treinamento de modelo de ML usando os frameworks JAX e PyTorch. Se você estiver usando o GKE, use XPK ou comandos do Kubernetes para configurar o framework.

Configuração para JAX

Esta seção fornece instruções de configuração para executar cargas de trabalho do JAX no GKE, com ou sem XPK, além de usar recursos enfileirados.

Configurar o JAX usando o GKE

Fatia única em um único host

O exemplo a seguir configura um pool de nós de host único 2x2 usando um arquivo YAML do Kubernetes.

apiVersion: v1
kind: Pod
metadata:
  name: tpu-pod-jax-v6e-a
spec:
  restartPolicy: Never
  nodeSelector:
    cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
    cloud.google.com/gke-tpu-topology: 2x2
  containers:
  - name: tpu-job
    image: python:3.10
    securityContext:
      privileged: true
    command:
    - bash
    - -c
    - |
      pip install -U --pre jax jaxlib libtpu-nightly requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
      JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python3 -c 'import jax; print("Total TPU chips:", jax.device_count())'
    resources:
      requests:
        google.com/tpu: 4
      limits:
        google.com/tpu: 4

Após a conclusão, você vai receber a seguinte mensagem no registro do GKE:

Total TPU chips: 4

Uma única fração em vários hosts

O exemplo a seguir configura um pool de nós multihost 4x4 usando um arquivo YAML do Kubernetes.

apiVersion: v1
kind: Service
metadata:
  name: headless-svc
spec:
  clusterIP: None
  selector:
    job-name: tpu-available-chips
---
apiVersion: batch/v1
kind: Job
metadata:
  name: tpu-available-chips
spec:
  backoffLimit: 0
  completions: 4
  parallelism: 4
  completionMode: Indexed
  template:
    spec:
      subdomain: headless-svc
      restartPolicy: Never
      nodeSelector:
        cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
        cloud.google.com/gke-tpu-topology: 4x4
      containers:
      - name: tpu-job
        image: python:3.10
        ports:
        - containerPort: 8471 # Default port using which TPU VMs communicate
        - containerPort: 8431 # Port to export TPU runtime metrics, if supported.
        securityContext:
          privileged: true
        command:
        - bash
        - -c
        - |
          pip install -U --pre jax jaxlib libtpu-nightly requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
          JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
        resources:
          requests:
            google.com/tpu: 4
          limits:
            google.com/tpu: 4

Após a conclusão, você vai receber a seguinte mensagem no registro do GKE:

Total TPU chips: 16

Multislice em vários hosts

O exemplo a seguir configura dois pools de nós multihospedeiro 4x4 usando um arquivo YAML do Kubernetes.

Como pré-requisito, você precisa instalar o JobSet v0.2.3 ou mais recente.

apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
  name: multislice-job
  annotations:
    alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
spec:
  failurePolicy:
    maxRestarts: 4
  replicatedJobs:
    - name: slice
      replicas: 2
      template:
        spec:
          parallelism: 4
          completions: 4
          backoffLimit: 0
          template:
            spec:
              hostNetwork: true
              dnsPolicy: ClusterFirstWithHostNet
              nodeSelector:
                cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                cloud.google.com/gke-tpu-topology: 4x4
              hostNetwork: true
              containers:
              - name: jax-tpu
                image: python:3.10
                ports:
                - containerPort: 8471
                - containerPort: 8080
                - containerPort: 8431
                securityContext:
                  privileged: true
                command:
                - bash
                - -c
                - |
                  pip install -U --pre jax jaxlib libtpu-nightly requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                  JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
                resources:
                  limits:
                   google.com/tpu: 4
                  requests:
                   google.com/tpu: 4

Após a conclusão, você vai receber a seguinte mensagem no registro do GKE:

Total TPU chips: 32

Para mais informações, consulte Executar uma carga de trabalho com vários setores na documentação do GKE.

Para melhorar a performance, ative o hostNetwork.

Multi-NIC

Para usar o manifesto de várias NICs a seguir, configure suas redes. Para mais informações, consulte Configurar o suporte a várias redes para pods do Kubernetes.

Para aproveitar a multi-NIC no GKE, inclua algumas anotações adicionais no manifesto do pod do Kubernetes.

Confira a seguir um exemplo de manifesto de carga de trabalho com várias NICs sem TPU.

apiVersion: v1
kind: Pod
metadata:
  name: sample-netdevice-pod-1
  annotations:
    networking.gke.io/default-interface: 'eth0'
    networking.gke.io/interfaces: |
      [
        {"interfaceName":"eth0","network":"default"},
        {"interfaceName":"eth1","network":"netdevice-network"}
      ]
spec:
  containers:
  - name: sample-netdevice-pod
    image: busybox
    command: ["sleep", "infinity"]
    ports:
    - containerPort: 80
  restartPolicy: Always
  tolerations:
  - key: "google.com/tpu"
    operator: "Exists"
    effect: "NoSchedule"

Se você usar o comando exec para se conectar ao pod do Kubernetes, verá a NIC adicional usando o seguinte código:

$ kubectl exec --stdin --tty sample-netdevice-pod-1 -- /bin/sh
/ # ip a
1: lo: <LOOPBACK,UP,LOWER_UP> mtu 65536 qdisc noqueue qlen 1000
    link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00
    inet 127.0.0.1/8 scope host lo
       valid_lft forever preferred_lft forever
2: eth0@if11: <BROADCAST,MULTICAST,UP,LOWER_UP,M-DOWN> mtu 1460 qdisc noqueue
    link/ether da:be:12:67:d2:25 brd ff:ff:ff:ff:ff:ff
    inet 10.124.2.6/24 brd 10.124.2.255 scope global eth0
       valid_lft forever preferred_lft forever
3: eth1: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1460 qdisc mq qlen 1000
    link/ether 42:01:ac:18:00:04 brd ff:ff:ff:ff:ff:ff
    inet 172.24.0.4/32 scope global eth1
       valid_lft forever preferred_lft forever

Configurar o JAX usando o GKE com XPK

Para configurar o JAX usando o GKE e o XPK, consulte o README do XPK.

Para configurar e executar o XPK com o MaxText, consulte Como executar o MaxText.

Configurar o JAX usando recursos na fila

Instale o JAX em todas as VMs do Cloud TPU na sua fração ou frações simultaneamente usando o comando gcloud alpha compute tpus tpu-vm ssh. Para Multislice, adicione a flag --node=all.

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
   --project ${PROJECT_ID} \
   --zone ${ZONE} \
   --worker=all \
   --command='
   pip install -U --pre jax jaxlib libtpu-nightly requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

Execute o comando a seguir para verificar quantos núcleos do Cloud TPU estão disponíveis na sua fração e testar se tudo está instalado corretamente:

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
   --project ${PROJECT_ID} \
   --zone ${ZONE} \
   --worker=all \
   --command='
   python3 -c "import jax; print(jax.device_count(), jax.local_device_count())"'

A saída será semelhante a esta ao executar em uma fração v6e-16:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16 4
16 4
16 4
16 4

jax.device_count() mostra o número total de chips na fração especificada. jax.local_device_count() indica a contagem de chips acessíveis por uma única VM nessa fatia.

gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
   git clone -b mlperf4.1 https://github.com/google/maxdiffusion.git &&
   cd maxdiffusion && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103 &&
   pip install setuptools==59.6.0 &&
   pip install -r requirements.txt && pip install .'

Resolver problemas de configuração do JAX

Uma dica geral é ativar o registro detalhado no manifesto da carga de trabalho do GKE. Em seguida, envie os registros para o suporte do GKE.

TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0

Mensagens de erro

no endpoints available for service 'jobset-webhook-service'

Esse erro significa que o jobset não foi instalado corretamente. Verifique se os pods do Kubernetes de implantação do jobset-controller-manager estão em execução. Para mais informações, consulte a documentação de solução de problemas do JobSet.

TPU initialization failed: Failed to connect

Verifique se a versão do nó do GKE é 1.30.4-gke.1348000 ou mais recente (o GKE 1.31 não é compatível).

Configuração do PyTorch

Esta seção descreve como começar a usar o PJRT na v6e com PyTorch/XLA. A versão 3.10 do Python é recomendada.

Configurar o PyTorch usando o GKE com XPK

Você pode usar o seguinte contêiner do Docker com XPK, que já tem as dependências do PyTorch instaladas:

us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028

Para criar uma carga de trabalho do XPK, use o seguinte comando:

python3 xpk.py workload create \
   --cluster ${CLUSTER_NAME} \
   {--base-docker-image maxtext_base_image | --docker-image your-cloud-image-name} \
   --workload ${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
   --tpu-type=${ACCELERATOR_TYPE} \
   --num-slices=${NUM_SLICES}  \
   --on-demand \
   --zone ${ZONE} \
   --project ${PROJECT_ID} \
   --enable-debug-logs \
   --command 'python3 -c "import torch; import torch_xla; import torch_xla.runtime as xr; print(xr.global_runtime_device_count())"'

Usar --base-docker-image cria uma nova imagem do Docker com o diretório de trabalho atual integrado ao novo Docker.

Configurar o PyTorch usando recursos na fila

Siga estas etapas para instalar o PyTorch usando recursos enfileirados e executar um pequeno script no v6e.

Instalar dependências usando SSH para acessar as VMs

Use o comando a seguir para instalar dependências em todas as VMs do Cloud TPU. Para Multislice, adicione a flag --worker=all:

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
   sudo apt update && sudo apt install -y python3-pip libopenblas-base && \
   pip3 install torch~=2.6.0 "torch_xla[tpu]~=2.6.0" -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'

Melhorar a performance de modelos com alocações frequentes e consideráveis

Para modelos com alocações frequentes e grandes, o uso da função tcmalloc melhora significativamente o desempenho em comparação com a implementação padrão da função malloc. Portanto, a função malloc padrão usada na VM do Cloud TPU é tcmalloc. No entanto, dependendo da sua carga de trabalho (por exemplo, com DLRM, que tem alocações muito grandes para as tabelas de incorporação), a função tcmalloc pode causar uma lentidão. Nesse caso, tente desativar a seguinte variável usando a função malloc padrão:

unset LD_PRELOAD

Usar um script Python para fazer um cálculo em uma VM v6e

Use o comando a seguir para executar um script que cria dois tensores, os adiciona e imprime o resultado:

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
   --project ${PROJECT_ID} \
   --zone ${ZONE} \
   --worker all \
   --command='
   unset LD_PRELOAD
   python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"'

Isso gera um resultado semelhante ao seguinte:

SSH: Attempting to connect to worker 0...
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
xla:0
tensor([[ 0.3355, -1.4628, -3.2610],
        [-1.4656,  0.3196, -2.8766],
        [ 0.8668, -1.5060,  0.7125]], device='xla:0')

v6e com SkyPilot

Você pode usar a Cloud TPU v6e com o SkyPilot. Siga estas etapas para adicionar informações de local e preço relacionadas ao v6e ao SkyPilot. Para mais informações, consulte o exemplo do SkyPilot TPU v6e.

Tutoriais de inferência

Os tutoriais a seguir mostram como executar a inferência na Cloud TPU v6e:

Exemplos de treinamento

As seções a seguir fornecem exemplos de treinamento de modelos MaxText, MaxDiffusion e PyTorch na Cloud TPU v6e.

Treinamento de MaxText e MaxDiffusion na VM do Cloud TPU v6e

As seções a seguir abordam o ciclo de vida de treinamento dos modelos MaxText e MaxDiffusion.

Em geral, as etapas de alto nível são:

  1. Crie a imagem de base da carga de trabalho.
  2. Execute a carga de trabalho usando XPK.
    1. Crie o comando de treinamento para a carga de trabalho.
    2. Implante a carga de trabalho.
  3. Acompanhe a carga de trabalho e confira as métricas.
  4. Exclua a carga de trabalho XPK se ela não for necessária.
  5. Exclua o cluster XPK quando ele não for mais necessário.

Criar imagem de base

Instale o MaxText ou o MaxDiffusion e crie a imagem do Docker:

  1. Clone o repositório que você quer usar e mude para o diretório dele:

    MaxText:

    git clone https://github.com/google/maxtext.git && cd maxtext
    

    MaxDiffusion:

    git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103
    
  2. Configure o Docker para usar a Google Cloud CLI:

    gcloud auth configure-docker
    
  3. Crie a imagem do Docker usando o comando a seguir ou a pilha estável do JAX. Para mais informações sobre o JAX Stable Stack, consulte Criar uma imagem Docker com o JAX Stable Stack.

    MaxText:

    bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.35
    

    MaxDiffusion:

    bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_stack MODE=jax_ai_image PROJECT=${PROJECT_ID} LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest
    
  4. Defina o ID do projeto na configuração ativa da CLI gcloud:

    gcloud config set project ${PROJECT_ID}
    
  5. Se você estiver iniciando a carga de trabalho em uma máquina que não tem a imagem criada localmente, faça upload dela.

    1. Defina a variável de ambiente CLOUD_IMAGE_NAME:

      export CLOUD_IMAGE_NAME=${USER}_runner
      
    2. Faça o upload da imagem:

      bash docker_upload_runner.sh ${CLOUD_IMAGE_NAME}
      

Executar a carga de trabalho usando XPK

  1. Defina as seguintes variáveis de ambiente se você não estiver usando os valores padrão definidos pelo MaxText ou MaxDiffusion:

    export BASE_OUTPUT_DIR=gs://YOUR_BUCKET
    export PER_DEVICE_BATCH_SIZE=2
    export NUM_STEPS=30
    export MAX_TARGET_LENGTH=8192
  2. Crie o script do modelo. Esse script será copiado como um comando de treinamento em uma etapa posterior.

    Não execute o script do modelo ainda.

    MaxText

    O MaxText é um LLM de código aberto de alto desempenho e altamente escalonável escrito em Python e JAX puros e destinado a TPUs e GPUs para treinamento e inferência. Google Cloud

    JAX_PLATFORMS=tpu,cpu \
    ENABLE_PJRT_COMPATIBILITY=true \
    TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \
    TPU_SLICE_BUILDER_DUMP_ICI=true && \
    python3 -m MaxText.train MaxText/configs/base.yml \
         base_output_directory=${BASE_OUTPUT_DIR} \
         dataset_type=synthetic \
         per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
         enable_checkpointing=false \
         gcs_metrics=true \
         profiler=xplane \
         skip_first_n_steps_for_profiler=5 \
         steps=${NUM_STEPS}  # attention='dot_product'"
    

    Gemma2

    O Gemma é uma família de LLMs de peso aberto desenvolvidos pelo Google DeepMind com base na pesquisa e tecnologia do Gemini.

    python3 -m MaxText.train MaxText/configs/base.yml \
        model_name=gemma2-27b \
        run_name=gemma2-27b-run \
        base_output_directory=${BASE_OUTPUT_DIR} \
        max_target_length=${MAX_TARGET_LENGTH} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        steps=${NUM_STEPS} \
        enable_checkpointing=false \
        use_iota_embed=true \
        gcs_metrics=true \
        dataset_type=synthetic \
        profiler=xplane \
        attention=flash
    

    Mixtral 8x7b

    O Mixtral é um modelo de IA de última geração desenvolvido pela Mistral AI, que usa uma arquitetura esparsa de combinação de especialistas (MoE, na sigla em inglês).

    python3 -m MaxText.train MaxText/configs/base.yml \
        base_output_directory=${BASE_OUTPUT_DIR} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        model_name=mixtral-8x7b \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        tokenizer_path=assets/tokenizer.mistral-v1 \
        attention=flash \
        dtype=bfloat16 \
        dataset_type=synthetic \
        profiler=xplane
    

    Llama3-8b

    O Llama é uma família de LLMs de peso aberto desenvolvidos pela Meta.

    Para ver um exemplo de como executar o Llama3 no PyTorch, consulte modelos torch_xla no repositório torchprime.

    MaxDiffusion

    O MaxDiffusion é uma coleção de implementações de referência de vários modelos de difusão latente escritos em Python e JAX puros que são executados em dispositivos XLA, incluindo TPUs e GPUs do Cloud. O Stable Diffusion é um modelo de texto latente para imagem que gera imagens fotorrealistas com base em qualquer entrada de texto.

    É necessário instalar uma ramificação específica do Git para executar o MaxDiffusion, conforme mostrado no comando git clone a seguir.

    Script de treinamento:

    git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103 && pip install -r requirements.txt && pip install . && pip install huggingface_hub==0.30.2 && OUT_DIR=${BASE_OUTPUT_DIR} && python src/maxdiffusion/train_sdxl.py src/maxdiffusion/configs/base_xl.yml revision=refs/pr/95 activations_dtype=bfloat16 weights_dtype=bfloat16  resolution=1024  per_device_batch_size=1 output_dir=${OUT_DIR} jax_cache_dir=${OUT_DIR}/cache_dir/ max_train_steps=200 attention=flash run_name=sdxl-ddp-v6e
    
  3. Exporte as seguintes variáveis:

    export CLUSTER_NAME=CLUSTER_NAME
    export ACCELERATOR_TYPE=ACCELERATOR_TYPE
    export NUM_SLICES=NUM_SLICES
    export YOUR_MODEL_SCRIPT=YOUR_MODEL_SCRIPT

    Descrições de variáveis de ambiente

    Variável Descrição
    CLUSTER_NAME O nome do cluster do XPK.
    ACCELERATOR_TYPE Consulte Tipos de aceleradores.
    NUM_SLICES O número de fatias de TPU.
    YOUR_MODEL_SCRIPT O script do modelo a ser executado como um comando de treinamento.
  4. Execute o modelo usando o script criado na etapa anterior. É necessário especificar a flag --base-docker-image para usar a imagem de base do MaxText ou especificar a flag --docker-image e a imagem que você quer usar.

    Opcional: é possível ativar o registro de depuração incluindo a flag --enable-debug-logs. Para mais informações, consulte Depurar o JAX no MaxText.

    Opcional: é possível criar um experimento da Vertex AI para fazer upload de dados para o TensorBoard da Vertex AI incluindo a flag --use-vertex-tensorboard. Para mais informações, consulte Monitorar o JAX no MaxText usando a Vertex AI.

    python3 xpk.py workload create \
      --cluster ${CLUSTER_NAME} \
      {--base-docker-image maxtext_base_image | --docker-image gcr.io/${PROJECT_ID}/${CLOUD_IMAGE_NAME}:latest} \
      --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
      --tpu-type=${ACCELERATOR_TYPE} \
      --num-slices=${NUM_SLICES}  \
      --on-demand \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      [--enable-debug-logs] \
      [--use-vertex-tensorboard] \
      --command="${YOUR_MODEL_SCRIPT}"

    A saída inclui um link para acompanhar sua carga de trabalho. Abra o link e clique na guia Registros para acompanhar sua carga de trabalho em tempo real.

Depurar o JAX no MaxText

Use comandos XPK complementares para diagnosticar por que o cluster ou a carga de trabalho não está sendo executada:

Monitorar o JAX no MaxText usando a Vertex AI

Para usar o TensorBoard, sua conta de usuário do Google Cloud precisa ter a função aiplatform.user. Execute o seguinte comando para conceder esse papel:

gcloud projects add-iam-policy-binding your-project-id \
   --member='user:your-email' \
   --role='roles/aiplatform.user'

Conferir dados escalares e de perfil no TensorBoard gerenciado da Vertex AI.

  1. Aumente as solicitações de gerenciamento de recursos (CRUD) para a zona que você está usando de 600 para 5.000. Isso pode não ser um problema para cargas de trabalho pequenas que usam menos de 16 VMs.

  2. Instale dependências como cloud-accelerator-diagnostics para a Vertex AI:

    # xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI
    cd ~/xpk
    pip install .
  3. Crie o cluster XPK usando a flag --create-vertex-tensorboard, conforme documentado em Criar o TensorBoard da Vertex AI. Você também pode executar esse comando em clusters atuais.

  4. Crie seu experimento da Vertex AI ao executar a carga de trabalho do XPK usando a flag --use-vertex-tensorboard e a flag opcional --experiment-name. Para conferir a lista completa de etapas, consulte Criar um experimento da Vertex AI para fazer upload de dados no TensorBoard da Vertex AI.

Os registros incluem um link para um TensorBoard da Vertex AI, semelhante ao seguinte:

View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name

Também é possível encontrar o link do TensorBoard da Vertex AI no console Google Cloud . Acesse Experimentos da Vertex AI no console Google Cloud . Selecione a região apropriada no menu suspenso.

O diretório do TensorBoard também é gravado no bucket do Cloud Storage especificado com ${BASE_OUTPUT_DIR}.

Excluir cargas de trabalho XPK

Use o comando xpk workload delete para excluir uma ou mais cargas de trabalho com base no prefixo ou status do job. Esse comando pode ser útil se você enviou cargas de trabalho XPK que não precisam mais ser executadas ou se há jobs presos na fila.

Excluir cluster XPK

Use o comando xpk cluster delete para excluir um cluster:

python3 xpk.py cluster delete --cluster ${CLUSTER_NAME} \
   --zone=${ZONE} --project=${PROJECT_ID}

Treinamento do Llama e do PyTorch/XLA em uma VM da Cloud TPU v6e

Neste tutorial, descrevemos como treinar modelos Llama usando PyTorch/XLA no Cloud TPU v6e com o conjunto de dados WikiText.

Acessar o Hugging Face e o modelo Llama 3

Você precisa de um token de acesso de usuário do Hugging Face para executar este tutorial. Para informações sobre como criar tokens de acesso do usuário, consulte a documentação do Hugging Face sobre tokens de acesso do usuário.

Você também precisa de permissão para acessar o modelo Llama-3-8B no Hugging Face. Para ter acesso, acesse o modelo Meta-Llama-3-8B no HuggingFace e solicite acesso.

Criar uma VM da Cloud TPU

Crie um Cloud TPU v6e com oito chips para executar o tutorial.

  1. Configure as variáveis de ambiente:

    export NODE_ID=your-tpu-name
    export PROJECT_ID=your-project-id
    export ACCELERATOR_TYPE=v6e-8
    export ZONE=us-east1-d
    export RUNTIME_VERSION=v2-alpha-tpuv6e
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id
    export VALID_DURATION=your-duration 
  2. Crie uma VM do Cloud TPU:

    gcloud alpha compute tpus tpu-vm create ${NODE_ID} --version=${RUNTIME_VERSION} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --zone=${ZONE} \
       --project=${PROJECT_ID}

Instalação

Instale o fork pytorch-tpu/transformers dos transformadores e dependências do Hugging Face. Este tutorial foi testado com as seguintes versões de dependência usadas neste exemplo:

  • torch: compatível com 2.5.0
  • torch_xla[tpu]: compatível com 2.5.0
  • jax: 0.4.33
  • jaxlib: 0.4.33
gcloud alpha compute tpus tpu-vm ssh ${NODE_ID} \
   --project=${PROJECT_ID} \
   --zone ${ZONE} \
   --worker=all \
   --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
   cd transformers
   sudo pip3 install -e .
   pip3 install datasets
   pip3 install evaluate
   pip3 install scikit-learn
   pip3 install accelerate
   pip install torch~=2.6.0 torch_xla[tpu]~=2.6.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
   pip install jax==0.4.38 jaxlib==0.4.38 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/'

Configurar configurações de modelo

O comando de treinamento na próxima seção, Executar o modelo, usa dois arquivos de configuração JSON para definir parâmetros do modelo e configuração de paralelismo de dados totalmente fragmentados (FSDP, na sigla em inglês). O sharding do FSDP permite usar um tamanho de lote maior durante o treinamento, fazendo o sharding dos pesos do modelo em várias TPUs. Ao treinar com modelos menores, pode ser suficiente usar o paralelismo de dados e replicar os pesos em cada dispositivo. Para mais informações sobre como fragmentar tensores em dispositivos no PyTorch/XLA, consulte o guia do usuário do PyTorch/XLA SPMD.

  1. Crie o arquivo de configuração de parâmetros do modelo. Confira a seguir a configuração de parâmetros do modelo para o Llama-3-8B. Para outros modelos, encontre a configuração no Hugging Face (em inglês). Por exemplo, consulte a configuração do Llama-2-7B.

    cat > llama-config.json << EOF
    {
      "architectures": [
        "LlamaForCausalLM"
      ],
      "attention_bias": false,
      "attention_dropout": 0.0,
      "bos_token_id": 128000,
      "eos_token_id": 128001,
      "hidden_act": "silu",
      "hidden_size": 4096,
      "initializer_range": 0.02,
      "intermediate_size": 14336,
      "max_position_embeddings": 8192,
      "model_type": "llama",
      "num_attention_heads": 32,
      "num_hidden_layers": 32,
      "num_key_value_heads": 8,
      "pretraining_tp": 1,
      "rms_norm_eps": 1e-05,
      "rope_scaling": null,
      "rope_theta": 500000.0,
      "tie_word_embeddings": false,
      "torch_dtype": "bfloat16",
      "transformers_version": "4.40.0.dev0",
      "use_cache": false,
      "vocab_size": 128256
    }
    EOF
    
  2. Crie o arquivo de configuração do FSDP:

    cat > fsdp-config.json << EOF
    {
      "fsdp_transformer_layer_cls_to_wrap": [
        "LlamaDecoderLayer"
      ],
      "xla": true,
      "xla_fsdp_v2": true,
      "xla_fsdp_grad_ckpt": true
    }
    EOF
    

    Para mais informações sobre a FSDP, consulte FSDPv2.

  3. Faça upload dos arquivos de configuração para as VMs da TPU usando o seguinte comando:

    gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json ${NODE_ID}:. \
       --worker=all \
       --project=${PROJECT_ID} \
       --zone=${ZONE}

Executar o modelo

Usando os arquivos de configuração criados na seção anterior, execute o script run_clm.py para treinar o modelo Llama-3-8B no conjunto de dados WikiText. O script de treinamento leva aproximadamente 10 minutos para ser executado em uma Cloud TPU v6e-8.

  1. Faça login no Hugging Face na sua Cloud TPU usando o seguinte comando:

    gcloud alpha compute tpus tpu-vm ssh ${NODE_ID} \
       --project=${PROJECT_ID} \
       --zone ${ZONE} \
       --worker=all \
       --command='
       pip3 install "huggingface_hub[cli]"
       huggingface-cli login --token HUGGING_FACE_TOKEN'
  2. Execute o treinamento de modelo:

    gcloud alpha compute tpus tpu-vm ssh ${NODE_ID} \
       --project=${PROJECT_ID} \
       --zone ${ZONE} \
       --worker=all \
       --command='
       export PJRT_DEVICE=TPU
       export XLA_USE_SPMD=1
       export ENABLE_PJRT_COMPATIBILITY=true
       # Optional variables for debugging:
       export XLA_IR_DEBUG=1
       export XLA_HLO_DEBUG=1
       export PROFILE_EPOCH=0
       export PROFILE_STEP=3
       export PROFILE_DURATION_MS=100000
       # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path
       export PROFILE_LOGDIR=PROFILE_PATH
       python3 transformers/examples/pytorch/language-modeling/run_clm.py \
         --dataset_name wikitext \
         --dataset_config_name wikitext-2-raw-v1 \
         --per_device_train_batch_size 16 \
         --do_train \
         --output_dir /home/$USER/tmp/test-clm \
         --overwrite_output_dir \
         --config_name /home/$USER/llama-config.json \
         --cache_dir /home/$USER/cache \
         --tokenizer_name meta-llama/Meta-Llama-3-8B \
         --block_size 8192 \
         --optim adafactor \
         --save_strategy no \
         --logging_strategy no \
         --fsdp "full_shard" \
         --fsdp_config /home/$USER/fsdp-config.json \
         --torch_dtype bfloat16 \
         --dataloader_drop_last yes \
         --flash_attention \
         --max_steps 20'

Solução de problemas do PyTorch/XLA

Se você definiu as variáveis opcionais para depuração na seção anterior, o perfil do modelo será armazenado no local especificado pela variável PROFILE_LOGDIR. É possível extrair o arquivo xplane.pb armazenado nesse local e usar tensorboard para conferir os perfis no navegador seguindo as instruções do TensorBoard.

Se o PyTorch/XLA não estiver funcionando como esperado, consulte o Guia de solução de problemas, que tem sugestões para depurar, criar perfis e otimizar seu modelo.

Resultados de comparativo de mercado

A seção a seguir contém resultados de comparativos de mercado para o MaxDiffusion no v6e.

MaxDiffusion

Executamos o script de treinamento do MaxDiffusion em uma v6e-4, uma v6e-16 e duas v6e-16. Confira as taxas de transferência na tabela a seguir.

v6e-4 v6e-16 Dois v6e-16
Etapas de treinamento 0,069 0,073 0,13
Tamanho global do lote 8 32 64
Capacidade (exemplos/segundo) 115,9 438,4 492,3