Introdução ao Trillium (v6e)

O v6e é usado para se referir ao Trillium nesta documentação, na API TPU e nos registros. O v6e representa a 6ª geração de TPU do Google.

Com 256 chips por pod, a arquitetura v6e compartilha muitas semelhanças com a v5e. Esse sistema é otimizado para transformador, texto para imagem e treinamento, ajuste fino e exibição de rede neural convolucional (CNN).

Para mais informações sobre a arquitetura do sistema v6e e as configurações, consulte TPU v6e.

Este documento de introdução se concentra nos processos de treinamento e serviço de modelos usando os frameworks JAX, PyTorch ou TensorFlow. Com cada framework, é possível provisionar TPUs usando recursos em fila ou o GKE. A configuração do GKE pode ser feita usando comandos XPK ou do GKE.

Procedimento geral para treinar ou exibir um modelo usando a v6e

  1. Preparar um Google Cloud projeto
  2. Capacidade segura
  3. Provisionar o ambiente do Cloud TPU
  4. Executar uma carga de trabalho de treinamento ou inferência de modelo

Preparar um projeto do Google Cloud

Antes de usar o Cloud TPU, você precisa:

  • Criar um projeto e uma Google Cloud conta com o faturamento ativado
  • Instale os componentes Alfa da CLI do Google Cloud.
  • Ativar a API Cloud TPU
  • Criar um agente de serviço do Cloud TPU
  • Criar uma conta de serviço do Cloud TPU e conceder permissões

Para mais informações, consulte Configurar o ambiente do Cloud TPU.

Capacidade segura

Entre em contato com o Google Cloud suporte para solicitar a cota do Cloud TPU v6e e tirar dúvidas sobre a capacidade.

Provisionar o ambiente do Cloud TPU

O Cloud TPU v6e pode ser provisionado e gerenciado com o GKE, com o GKE e o XPK (uma ferramenta de wrapper da CLI sobre o GKE) ou como recursos em fila.

Pré-requisitos

  • Verifique se o projeto tem cota de TPUS_PER_TPU_FAMILY suficiente, que especifica o número máximo de chips que você pode acessar no projeto Google Cloud.
  • O v6e foi testado com a seguinte configuração:
    • Python 3.10 ou mais recente
    • Versões noturnas do software:
      • JAX noturno 0.4.32.dev20240912
      • LibTPU noturno 0.1.dev20240912+nightly
    • Versões estáveis do software:
      • JAX + JAX Lib da v0.4.37
  • Verifique se o projeto tem cota suficiente para:

    • Cota de VM do Cloud TPU
    • Cota de endereços IP
    • Quota do Hyperdisk equilibrado

  • Se você estiver usando o GKE com XPK, consulte Permissões do console do Google Cloud na conta de usuário ou de serviço para conferir as permissões necessárias para executar o XPK.

Criar variáveis de ambiente

No Cloud Shell, crie as seguintes variáveis de ambiente:

export NODE_ID=your-tpu-name
export PROJECT_ID=your-project-id
export ACCELERATOR_TYPE=v6e-16
export ZONE=us-east1-d
export RUNTIME_VERSION=v2-alpha-tpuv6e
export SERVICE_ACCOUNT=your-service-account
export QUEUED_RESOURCE_ID=your-queued-resource-id
export VALID_DURATION=your-duration 

# Additional environment variable needed for provisioning Multislice:
export NUM_SLICES=number-of-slices

# Use a custom network for better performance as well as to avoid having the default network becoming overloaded.

export NETWORK_NAME=${PROJECT_ID}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw

Descrições de sinalizações de comando

Variável Descrição
NODE_ID O ID atribuído pelo usuário do Cloud TPU, que é criado quando a solicitação de recurso enfileirada é alocada.
PROJECT_ID Google Cloud nome do projeto. Use um projeto existente ou crie um novo. Para mais informações, consulte Configurar seu projeto Google Cloud .
ZONA Consulte o documento Regiões e zonas do Cloud TPU para ver as zonas com suporte.
ACCELERATOR_TYPE Consulte Tipos de acelerador.
RUNTIME_VERSION v2-alpha-tpuv6e
SERVICE_ACCOUNT Esse é o endereço de e-mail da sua conta de serviço, que pode ser encontrado em Google Cloud Console -> IAM -> Contas de serviço

Por exemplo: tpu-service-account@.iam.gserviceaccount.com.com

NUM_SLICES O número de fatias a serem criadas (somente para fatias múltiplas).
QUEUED_RESOURCE_ID O ID de texto atribuído pelo usuário da solicitação de recurso em fila.
VALID_DURATION O período em que a solicitação de recurso em fila é válida.
NETWORK_NAME O nome de uma rede secundária a ser usada.
NETWORK_FW_NAME O nome de um firewall de rede secundário a ser usado.

Otimizar a performance da rede

Para ter a melhor performance,use uma rede com 8.896 MTU (unidade máxima de transmissão).

Por padrão, uma nuvem privada virtual (VPC) fornece apenas uma MTU de 1.460 bytes,o que vai proporcionar um desempenho de rede subótimo. É possível definir a MTU de uma rede VPC como qualquer valor entre 1.300 e 8.896 bytes (inclusivo). Os tamanhos comuns de MTU personalizados são 1.500 bytes (Ethernet padrão) ou 8.896 bytes (o máximo possível). Para mais informações, consulte Tamanhos válidos de MTU da rede VPC.

Para mais informações sobre como mudar a configuração de MTU de uma rede existente ou padrão, consulte Alterar a configuração de MTU de uma rede VPC.

O exemplo a seguir cria uma rede com 8.896 MTU.

export RESOURCE_NAME=your-resource-name
export NETWORK_NAME=${RESOURCE_NAME}-privatenetwork
export NETWORK_FW_NAME=${RESOURCE_NAME}-privatefirewall
gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT_ID} \
 --subnet-mode=auto --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network=${NETWORK_NAME} \
 --allow tcp,icmp,udp --project=${PROJECT_ID}

Como usar a multi-NIC (opção para Multislice)

As variáveis de ambiente a seguir são necessárias para uma sub-rede secundária quando você está usando um ambiente de várias fatias.

export NETWORK_NAME_2=${RESOURCE_NAME}
export SUBNET_NAME_2=${RESOURCE_NAME}
export FIREWALL_RULE_NAME=${RESOURCE_NAME}
export ROUTER_NAME=${RESOURCE_NAME}-network-2
export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2
export REGION=your-region

Use os comandos a seguir para criar um roteamento IP personalizado para a rede e a sub-rede.

gcloud compute networks create ${NETWORK_NAME_2} --mtu=8896 \
   --bgp-routing-mode=regional --subnet-mode=custom --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_2} \
   --network=${NETWORK_NAME_2} \
   --range=10.10.0.0/18 --region=${REGION} \
   --project=${PROJECT_ID}

gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_2} --allow tcp,icmp,udp \
   --source-ranges 10.10.0.0/18 --project=${PROJECT_ID}

gcloud compute routers create ${ROUTER_NAME} \
  --project=${PROJECT_ID} \
  --network=${NETWORK_NAME_2} \
  --region=${REGION}

gcloud compute routers nats create ${NAT_CONFIG} \
  --router=${ROUTER_NAME} \
  --region=${REGION} \
  --auto-allocate-nat-external-ips \
  --nat-all-subnet-ip-ranges \
  --project=${PROJECT_ID} \
  --enable-logging

Depois de criar uma fatia de várias redes, é possível validar se as duas placas de interface de rede (NICs, na sigla em inglês) estão sendo usadas configurando um cluster XPK e adicionando a flag --command ifconfig ao comando de criação de carga de trabalho XPK.

Use o comando xpk workload a seguir para mostrar a saída do comando ifconfig nos registros do console do Google Cloud e verifique se eth0 e eth1 têm mtu=8896.

python3 xpk.py workload create \
   --cluster your-cluster-name \
   (--base-docker-image maxtext_base_image|--docker-image your-cloud-image-name \
   --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
   --tpu-type=${ACCELERATOR_TYPE} \
   --num-slices=${NUM_SLICES}  \
   --on-demand \
   --zone=${ZONE} \
   --project=${PROJECT_ID} \
   [--enable-debug-logs] \
   [--use-vertex-tensorboard] \
   --command "ifconfig"

Verifique se eth0 e eth1 têm mtu=8,896. Para verificar se a NIC está em execução, adicione a flag --command ifconfig ao comando de criação de carga de trabalho XPK. Verifique a saída dessa carga de trabalho xpk nos registros do console do Google Cloud e verifique se eth0 e eth1 têm mtu=8896.

Melhorar as configurações de TCP

Se você criou os Cloud TPUs usando a interface de recursos enfileirados, execute o comando a seguir para melhorar a performance da rede aumentando os limites do buffer de recebimento do TCP.

gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \
  --project "${PROJECT}" \
  --zone "${ZONE}" \
  --node=all \
  --command='sudo sh -c "echo \"4096 41943040 314572800\" > /proc/sys/net/ipv4/tcp_rmem"' \
  --worker=all

Provisionar com recursos na fila

É possível criar uma Cloud TPU v6e usando recursos enfileirados. Os recursos em fila permitem que você receba capacidade quando ela estiver disponível. É possível especificar um horário de início e término opcional para quando a solicitação precisa ser preenchida. Para mais informações, consulte Gerenciar recursos na fila.

Provisionar Cloud TPUs v6e com o GKE ou XPK

Se você estiver usando comandos do GKE com o v6e, use comandos do Kubernetes ou o XPK para provisionar TPUs do Cloud e treinar ou oferecer modelos. Consulte Planejar Cloud TPUs no GKE para saber como planejar suas configurações do Cloud TPU em clusters do GKE. As seções a seguir fornecem comandos para criar um cluster XPK com suporte a uma NIC e suporte a várias NICs.

Criar um cluster XPK com suporte a uma única NIC

export CLUSTER_NAME=xpk-cluster-name
export ZONE=us-central2-b
export PROJECT_ID=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME=${CLUSTER_NAME}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw
   gcloud compute networks create ${NETWORK_NAME} \
   --mtu=8896 \
   --project=${PROJECT_ID} \
   --subnet-mode=auto \
   --bgp-routing-mode=regional
   gcloud compute firewall-rules create ${NETWORK_FW_NAME} \
   --network=${NETWORK_NAME} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}"
   python3 xpk.py cluster create --cluster=${CLUSTER_NAME} \
   --cluster-cpu-machine-type=n1-standard-8 \
   --num-slices=${NUM_SLICES} \
   --tpu-type=${TPU_TYPE} \
   --zone=${ZONE}  \
   --project=${PROJECT_ID} \
   --on-demand \
   --custom-cluster-arguments=${CLUSTER_ARGUMENTS}  \
   --create-vertex-tensorboard

Descrições de sinalizações de comando

Variável Descrição
CLUSTER_NAME O nome atribuído pelo usuário ao cluster XPK.
PROJECT_ID Google Cloud nome do projeto. Use um projeto existente ou crie um novo. Para mais informações, consulte Configurar seu projeto Google Cloud .
ZONA Consulte o documento Regiões e zonas do Cloud TPU para ver as zonas com suporte.
TPU_TYPE Consulte Tipos de acelerador.
NUM_SLICES O número de fatias que você quer criar
CLUSTER_ARGUMENTS A rede e a sub-rede a serem usadas.

Por exemplo: --network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}

NUM_SLICES O número de fatias a serem criadas.
NETWORK_NAME O nome de uma rede secundária a ser usada.
NETWORK_FW_NAME O nome de um firewall de rede secundário a ser usado.

Criar um cluster XPK com suporte a várias NICs

export CLUSTER_NAME xpk-cluster-name
export ZONE=us-central2-b
export PROJECT_ID=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME_1=${CLUSTER_NAME}-mtu9k-1-${ZONE}
export exportSUBNET_NAME_1=${CLUSTER_NAME}-privatesubnet-1-${ZONE}
export NETWORK_FW_NAME_1=${NETWORK_NAME_1}-fw-1-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-1-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-1-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-1-${ZONE}
   gcloud compute networks create ${NETWORK_NAME_1} \
   --mtu=8896 \
   --bgp-routing-mode=regional \
   --subnet-mode=custom \
   --project=${PROJECT_ID}
   gcloud compute networks subnets create ${SUBNET_NAME_1} \
   --network=${NETWORK_NAME_1} \
   --range=10.11.0.0/18 \
   --region=${REGION} \
   --project=${PROJECT_ID}
   gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_1} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
  gcloud compute routers create ${ROUTER_NAME} \
    --project=${PROJECT_ID} \
    --network=${NETWORK_NAME_1} \
    --region=${REGION}
  gcloud compute routers nats create ${NAT_CONFIG} \
     --router=${ROUTER_NAME} \
     --region=${REGION} \
     --auto-allocate-nat-external-ips \
     --nat-all-subnet-ip-ranges \
     --project=${PROJECT_ID} \
     --enable-logging
# Secondary subnet for multi-nic experience.
# Need custom IP routing to be different from the first network's subnet.

export NETWORK_NAME_2=${CLUSTER_NAME}-privatenetwork-2-${ZONE}
export SUBNET_NAME_2=${CLUSTER_NAME}-privatesubnet-2-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-2-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-2-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-2-${ZONE}
   gcloud compute networks create ${NETWORK_NAME_2} \
   --mtu=8896 \
   --bgp-routing-mode=regional \
   --subnet-mode=custom \
   --project=${PROJECT_ID}
   gcloud compute networks subnets create ${SUBNET_NAME_2} \
   --network=${NETWORK_NAME_2} \
   --range=10.10.0.0/18 \
   --region=${REGION} \
   --project=${PROJECT_ID}
   gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_2} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
   gcloud compute routers create ${ROUTER_NAME} \
     --project=${PROJECT_ID} \
     --network=${NETWORK_NAME_2} \
     --region=${REGION}
   gcloud compute routers nats create ${NAT_CONFIG} \
     --router=${ROUTER_NAME} \
     --region=${REGION} \
     --auto-allocate-nat-external-ips \
     --nat-all-subnet-ip-ranges \
     --project=${PROJECT_ID} \
     --enable-logging
export CLUSTER_ARGUMENTS="--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking
--network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}"

export NODE_POOL_ARGUMENTS="--additional-node-network
network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}"
python3 ~/xpk/xpk.py cluster create \
--cluster=${CLUSTER_NAME} \
--num-slices=${NUM_SLICES} \
--tpu-type=${TPU_TYPE} \
--zone=${ZONE}  \
--project=${PROJECT_ID} \
--on-demand \
--custom-cluster-arguments=${CLUSTER_ARGUMENTS} \
--custom-nodepool-arguments=${NODE_POOL_ARGUMENTS} \
--create-vertex-tensorboard

Descrições de sinalizações de comando

Variável Descrição
CLUSTER_NAME O nome atribuído pelo usuário ao cluster XPK.
PROJECT_ID Google Cloud nome do projeto. Use um projeto existente ou crie um novo. Para mais informações, consulte Configurar seu projeto Google Cloud .
ZONA Consulte o documento Regiões e zonas do Cloud TPU para ver as zonas com suporte.
TPU_TYPE Consulte Tipos de acelerador.
NUM_SLICES O número de fatias que você quer criar
CLUSTER_ARGUMENTS A rede e a sub-rede a serem usadas.

Por exemplo: --enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}

NODE_POOL_ARGUMENTS Rede de nó adicional a ser usada.

Por exemplo: --additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}

NUM_SLICES O número de fatias a serem criadas (somente para fatias múltiplas).
NETWORK_NAME O nome de uma rede secundária a ser usada.
NETWORK_FW_NAME O nome de um firewall de rede secundário a ser usado.

Configuração do framework

Esta seção descreve o processo de configuração geral para treinamento de modelos de ML usando os frameworks JAX, PyTorch ou TensorFlow. Se você estiver usando o GKE, poderá usar comandos XPK ou do Kubernetes para a configuração do framework.

Configuração para o JAX

Esta seção fornece instruções de configuração para executar cargas de trabalho do JAX no GKE, com ou sem XPK, e usar recursos em fila.

Configurar o JAX usando o GKE

Uma fatia em um host

O exemplo a seguir configura um pool de nós de host único 2x2 usando um arquivo YAML do Kubernetes.

apiVersion: v1
kind: Pod
metadata:
  name: tpu-pod-jax-v6e-a
spec:
  restartPolicy: Never
  nodeSelector:
    cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
    cloud.google.com/gke-tpu-topology: 2x2
  containers:
  - name: tpu-job
    image: python:3.10
    securityContext:
      privileged: true
    command:
    - bash
    - -c
    - |
      pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
      JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python3 -c 'import jax; print("Total TPU chips:", jax.device_count())'
    resources:
      requests:
        google.com/tpu: 4
      limits:
        google.com/tpu: 4

Após a conclusão, você vai encontrar a seguinte mensagem no registro do GKE:

Total TPU chips: 4

Fração única em vários hosts

O exemplo a seguir configura um pool de nós multi-host 4x4 usando um arquivo YAML do Kubernetes.

apiVersion: v1
kind: Service
metadata:
  name: headless-svc
spec:
  clusterIP: None
  selector:
    job-name: tpu-available-chips
---
apiVersion: batch/v1
kind: Job
metadata:
  name: tpu-available-chips
spec:
  backoffLimit: 0
  completions: 4
  parallelism: 4
  completionMode: Indexed
  template:
    spec:
      subdomain: headless-svc
      restartPolicy: Never
      nodeSelector:
        cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
        cloud.google.com/gke-tpu-topology: 4x4
      containers:
      - name: tpu-job
        image: python:3.10
        ports:
        - containerPort: 8471 # Default port using which TPU VMs communicate
        - containerPort: 8431 # Port to export TPU runtime metrics, if supported.
        securityContext:
          privileged: true
        command:
        - bash
        - -c
        - |
          pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
          JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
        resources:
          requests:
            google.com/tpu: 4
          limits:
            google.com/tpu: 4

Após a conclusão, você vai encontrar a seguinte mensagem no registro do GKE:

Total TPU chips: 16

Multislice em vários hosts

O exemplo a seguir configura dois pools de nós multi-host 4x4 usando um arquivo YAML do Kubernetes.

Como pré-requisito, é necessário instalar o JobSet v0.2.3 ou mais recente.

apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
  name: multislice-job
  annotations:
    alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
spec:
  failurePolicy:
    maxRestarts: 4
  replicatedJobs:
    - name: slice
      replicas: 2
      template:
        spec:
          parallelism: 4
          completions: 4
          backoffLimit: 0
          template:
            spec:
              hostNetwork: true
              dnsPolicy: ClusterFirstWithHostNet
              nodeSelector:
                cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                cloud.google.com/gke-tpu-topology: 4x4
              hostNetwork: true
              containers:
              - name: jax-tpu
                image: python:3.10
                ports:
                - containerPort: 8471
                - containerPort: 8080
                - containerPort: 8431
                securityContext:
                  privileged: true
                command:
                - bash
                - -c
                - |
                  pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                  JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
                resources:
                  limits:
                   google.com/tpu: 4
                  requests:
                   google.com/tpu: 4

Após a conclusão, você vai encontrar a seguinte mensagem no registro do GKE:

Total TPU chips: 32

Para mais informações, consulte Executar uma carga de trabalho com vários setores na documentação do GKE.

Para melhorar o desempenho, ative a hostNetwork.

Várias NICs

Para aproveitar o recurso de várias NICs no GKE, o manifesto do pod do Kubernetes precisa ter outras anotações. Confira a seguir um exemplo de manifesto de carga de trabalho multi-NIC sem TPU.

apiVersion: v1
kind: Pod
metadata:
  name: sample-netdevice-pod-1
  annotations:
    networking.gke.io/default-interface: 'eth0'
    networking.gke.io/interfaces: |
      [
        {"interfaceName":"eth0","network":"default"},
        {"interfaceName":"eth1","network":"netdevice-network"}
      ]
spec:
  containers:
  - name: sample-netdevice-pod
    image: busybox
    command: ["sleep", "infinity"]
    ports:
    - containerPort: 80
  restartPolicy: Always
  tolerations:
  - key: "google.com/tpu"
    operator: "Exists"
    effect: "NoSchedule"

Se você usar o comando exec para se conectar ao pod do Kubernetes, a NIC adicional vai aparecer usando o código a seguir.

$ k exec --stdin --tty sample-netdevice-pod-1 -- /bin/sh
/ # ip a
1: lo: <LOOPBACK,UP,LOWER_UP> mtu 65536 qdisc noqueue qlen 1000
    link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00
    inet 127.0.0.1/8 scope host lo
       valid_lft forever preferred_lft forever
2: eth0@if11: <BROADCAST,MULTICAST,UP,LOWER_UP,M-DOWN> mtu 1460 qdisc noqueue
    link/ether da:be:12:67:d2:25 brd ff:ff:ff:ff:ff:ff
    inet 10.124.2.6/24 brd 10.124.2.255 scope global eth0
       valid_lft forever preferred_lft forever
3: eth1: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1460 qdisc mq qlen 1000
    link/ether 42:01:ac:18:00:04 brd ff:ff:ff:ff:ff:ff
    inet 172.24.0.4/32 scope global eth1
       valid_lft forever preferred_lft forever

Configurar o JAX usando o GKE com o XPK

Para configurar o JAX usando o GKE e o XPK, consulte o README do xpk.

Para configurar e executar o XPK com o MaxText, consulte Como executar o MaxText.

Configurar o JAX usando recursos na fila

Instale o JAX em todas as VMs do Cloud TPU no seu ou nos seus segmentos simultaneamente usando o comando gcloud alpha compute tpus tpu-vm ssh. Para "Multislice", adicione a flag --node=all.

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
 --zone ${ZONE} --worker=all \
 --command='pip install -U --pre jax jaxlib libtpu-nightly requests
 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

Você pode executar o comando a seguir para verificar quantos núcleos do Cloud TPU estão disponíveis no seu slice e testar se tudo está instalado corretamente:

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
   --zone ${ZONE} --worker=all  \
   --command='python3 -c "import jax; print(jax.device_count(), jax.local_device_count())"'

O resultado será semelhante ao seguinte quando executado em uma fatia v6e-16:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16 4
16 4
16 4
16 4

jax.device_count() mostra o número total de ícones na fração. jax.local_device_count() indica a contagem de chips acessíveis por uma única VM nesta fatia.

gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} --zone=${ZONE} --worker=all  \
   --command='git clone -b mlperf4.1 https://github.com/google/maxdiffusion.git &&
   cd maxdiffusion && git checkout 975fdb7dbddaa9a53ad72a421cdb487dcdc491a3 &&
   && pip install -r requirements.txt  && pip install . '

Resolver problemas de configuração do JAX

Uma dica geral é ativar o registro detalhado no manifesto da carga de trabalho do GKE. Em seguida, envie os registros para o suporte do GKE.

TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0

Mensagens de erro

no endpoints available for service 'jobset-webhook-service'

Esse erro significa que o jobset não foi instalado corretamente. Verifique se os pods do Kubernetes da implantação jobset-controller-manager estão em execução. Para mais informações, consulte a documentação de solução de problemas do JobSet.

TPU initialization failed: Failed to connect

Verifique se a versão do nó do GKE é 1.30.4-gke.1348000 ou mais recente. Não há suporte para o GKE 1.31.

Configuração para PyTorch

Esta seção descreve como começar a usar o PJRT na v6e com o PyTorch/XLA. O Python 3.10 é a versão recomendada.

Configurar o PyTorch usando o GKE com o XPK

Você pode usar o contêiner do Docker abaixo com o XPK, que já tem as dependências do PyTorch instaladas:

us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028

Para criar uma carga de trabalho XPK, use o seguinte comando:

python3 xpk.py workload create \
    --cluster ${CLUSTER_NAME} \
    [--docker-image | --base-docker-image] us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028 \
    --workload ${USER} -xpk-${ACCELERATOR_TYPE} -${NUM_SLICES} \
    --tpu-type=${ACCELERATOR_TYPE} \
    --num-slices=${NUM_SLICES}  \
    --on-demand \
    --zone ${ZONE} \
    --project ${PROJECT_ID} \
    --enable-debug-logs \
    --command 'python3 -c "import torch; import torch_xla; import torch_xla.runtime as xr; print(xr.global_runtime_device_count())"'

O uso de --base-docker-image cria uma nova imagem do Docker com o diretório de trabalho atual integrado ao novo Docker.

Configurar o PyTorch usando recursos na fila

Siga estas etapas para instalar o PyTorch usando recursos em fila e executar um pequeno script na v6e.

Instalar dependências usando SSH para acessar as VMs

Use o comando a seguir para instalar dependências em todas as VMs do Cloud TPU. Para Multislice, adicione a flag --node=all:

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='sudo apt install -y libopenblas-base pip3 \
    install --pre torch==2.6.0.dev20241028+cpu torchvision==0.20.0.dev20241028+cpu \
    --index-url https://download.pytorch.org/whl/nightly/cpu
    pip install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241028-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html
    pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'

Melhorar o desempenho de modelos com alocações frequentes e de grande porte

Para modelos com alocações grandes e frequentes, o uso da função tcmalloc melhora significativamente o desempenho em comparação com a implementação padrão da função malloc. Portanto, a função malloc padrão usada na VM do Cloud TPU é tcmalloc. No entanto, dependendo da sua carga de trabalho (por exemplo, com o DLRM, que tem alocações muito grandes para as tabelas de incorporação), a função tcmalloc pode causar uma lentidão. Nesse caso, tente redefinir a seguinte variável usando a função padrão malloc:

unset LD_PRELOAD

Usar um script Python para fazer um cálculo na VM v6e

Use o comando abaixo para executar um script que cria dois tensores, os adiciona e imprime o resultado.

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}
   --project ${PROJECT_ID} \
   --zone ${ZONE} --worker all --command='
   unset LD_PRELOAD
   python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"
'

Isso gera um resultado semelhante ao seguinte:

SSH: Attempting to connect to worker 0...
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
xla:0
tensor([[ 0.3355, -1.4628, -3.2610],
        [-1.4656,  0.3196, -2.8766],
        [ 0.8668, -1.5060,  0.7125]], device='xla:0')

Configuração para o TensorFlow

É possível redefinir o ambiente de execução do Cloud TPU com a versão do TensorFlow compatível com a v6e executando os comandos a seguir:

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
    --zone  ${ZONE} --worker=all --command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} \
    --zone ${ZONE} --worker=all --command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'

Use o SSH para acessar o worker-0:

$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
     --zone ${ZONE}

Instale o TensorFlow no worker-0:

sudo apt install -y libopenblas-base
pip install cloud-tpu-client
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310
pip install cloud-tpu-client

pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

Exporte a variável de ambiente TPU_NAME:

export TPU_NAME=v6e-16

Você pode executar o script Python a seguir para verificar quantos núcleos do Cloud TPU estão disponíveis no seu slice e testar se tudo está instalado corretamente:

import TensorFlow as tf
print("TensorFlow version " + tf.__version__)

@tf.function
  def add_fn(x,y):
  z = x + y
  return z

  cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
  tf.config.experimental_connect_to_cluster(cluster_resolver)
  tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
  strategy = tf.distribute.TPUStrategy(cluster_resolver)

  x = tf.constant(1.)
  y = tf.constant(1.)
  z = strategy.run(add_fn, args=(x,y))
  print(z)

O resultado será semelhante ao seguinte quando executado em uma fatia v6e-16:

PerReplica:{
  0: tf.Tensor(2.0, shape=(), dtype=float32),
  1: tf.Tensor(2.0, shape=(), dtype=float32),
  2: tf.Tensor(2.0, shape=(), dtype=float32),
  3: tf.Tensor(2.0, shape=(), dtype=float32),
  4: tf.Tensor(2.0, shape=(), dtype=float32),
  5: tf.Tensor(2.0, shape=(), dtype=float32),
  6: tf.Tensor(2.0, shape=(), dtype=float32),
  7: tf.Tensor(2.0, shape=(), dtype=float32)
}

v6e com SkyPilot

É possível usar o Cloud TPU v6e com o SkyPilot. Siga as etapas abaixo para adicionar informações de local e preço relacionadas ao v6e ao SkyPilot.

  1. Adicione o seguinte ao final do arquivo ~/.sky/catalogs/v5/gcp/vms.csv:

    ,,,tpu-v6e-1,1,tpu-v6e-1,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-1,1,tpu-v6e-1,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-1,1,tpu-v6e-1,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,us-east5,us-east5-b,0,0
    
  2. Especifique os seguintes recursos em um arquivo YAML:

    # tpu_v6.yaml
    resources:
      accelerators: tpu-v6e-16 # Fill in the accelerator type you want to use
      accelerator_args:
        runtime_version: v2-alpha-tpuv6e # Official suggested runtime
    
  3. Inicie um cluster com o Cloud TPU v6e:

       sky launch tpu_v6.yaml -c tpu_v6
    
  4. Conecte-se ao Cloud TPU v6e usando SSH: ssh tpu_v6

Tutoriais de inferência

Os tutoriais a seguir mostram como executar a inferência na Cloud TPU v6e:

Exemplos de treinamento

As seções a seguir fornecem exemplos de treinamento de modelos MaxText, MaxDiffusion e PyTorch no Cloud TPU v6e.

Treinamento do MaxText e do MaxDiffusion na VM do Cloud TPU v6e

As seções a seguir abrangem o ciclo de vida de treinamento dos modelos MaxText e MaxDiffusion.

Em geral, as etapas gerais são:

  1. Crie a imagem de base da carga de trabalho.
  2. Execute a carga de trabalho usando o XPK.
    1. Crie o comando de treinamento para a carga de trabalho.
    2. Implante a carga de trabalho.
  3. Acompanhe a carga de trabalho e confira as métricas.
  4. Exclua a carga de trabalho do XPK se ela não for necessária.
  5. Exclua o cluster de XPK quando ele não for mais necessário.

Criar a imagem de base

Instale o MaxText ou o MaxDiffusion e crie a imagem do Docker:

  1. Clone o repositório que você quer usar e mude para o diretório do repositório:

    MaxText:

    git clone https://github.com/google/maxtext.git && cd maxtext
    

    MaxDiffusion:

    git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion
    
  2. Configure o Docker para usar a CLI do Google Cloud:

    gcloud auth configure-docker
    
  3. Crie a imagem do Docker usando o comando a seguir ou a pilha JAX Stable. Para mais informações sobre a pilha estável do JAX, consulte Criar uma imagem do Docker com a pilha estável do JAX.

    bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.37
    
  4. Se você estiver iniciando a carga de trabalho em uma máquina que não tem a imagem criada localmente, faça o upload da imagem:

    bash docker_upload_runner.sh CLOUD_IMAGE_NAME=${USER}_runner
    
Criar uma imagem do Docker com a pilha estável do JAX

É possível criar as imagens do Docker MaxText e MaxDiffusion usando a imagem base da pilha Stable do JAX.

A pilha estável do JAX oferece um ambiente consistente para MaxText e MaxDiffusion agrupando o JAX com pacotes principais, como orbax, flax e optax, além de um libtpu.so bem qualificado que direciona os utilitários de programa do Cloud TPU e outras ferramentas essenciais. Essas bibliotecas são testadas para garantir a compatibilidade e fornecer uma base estável para criar e executar o MaxText e o MaxDiffusion. Isso elimina possíveis conflitos devido a versões de pacotes incompatíveis.

O JAX Stable Stack inclui um libtpu.so totalmente lançado e qualificado, a biblioteca principal que orienta a compilação, a execução e a configuração da rede ICI do programa do Cloud TPU. A versão do libtpu substitui o build noturno usado anteriormente pelo JAX e garante a funcionalidade consistente das computações XLA no Cloud TPU com testes de qualificação no nível do PJRT em HLO/StableHLO IRs.

Para criar a imagem do Docker MaxText e MaxDiffusion com a pilha estável do JAX, ao executar o script docker_build_dependency_image.sh, defina a variável MODE como stable_stack e a variável BASEIMAGE como a imagem de base que você quer usar.

O exemplo a seguir especifica us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.37-rev1 como a imagem base:

bash docker_build_dependency_image.sh MODE=stable_stack
BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.37-rev1

Para uma lista de imagens de base do JAX Stable Stack disponíveis, consulte Imagens do JAX Stable Stack no Artifact Registry.

Executar a carga de trabalho usando o XPK

  1. Defina as seguintes variáveis de ambiente se você não estiver usando os valores padrão definidos por MaxText ou MaxDiffusion:

    export BASE_OUTPUT_DIR=gs://YOUR_BUCKET
    export PER_DEVICE_BATCH_SIZE=2
    export NUM_STEPS=30
    export MAX_TARGET_LENGTH=8192
  2. Crie o script do modelo. Esse script será copiado como um comando de treinamento em uma etapa posterior.

    Não execute o script do modelo ainda.

    MaxText

    O MaxText é um LLM de código aberto de alto desempenho e altamente escalonável escrito em Python e JAX puros e direcionado a Google Cloud TPUs e GPUs para treinamento e inferência.

    JAX_PLATFORMS=tpu,cpu \
    ENABLE_PJRT_COMPATIBILITY=true \
    TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \
    TPU_SLICE_BUILDER_DUMP_ICI=true && \
    python /deps/MaxText/train.py /deps/MaxText/configs/base.yml \
            base_output_directory=${BASE_OUTPUT_DIR} \
            dataset_type=synthetic \
            per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
            enable_checkpointing=false \
            gcs_metrics=true \
            profiler=xplane \
            skip_first_n_steps_for_profiler=5 \
            steps=${NUM_STEPS}  # attention='dot_product'"
    

    Gemma2

    O Gemma é uma família de LLMs de pesos abertos desenvolvidos pelo Google DeepMind, com base na pesquisa e tecnologia do Gemini.

    python3 MaxText/train.py MaxText/configs/base.yml \
        model_name=gemma2-27b \
        run_name=gemma2-27b-run \
        base_output_directory=${BASE_OUTPUT_DIR} \
        max_target_length=${MAX_TARGET_LENGTH} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        steps=${NUM_STEPS} \
        enable_checkpointing=false \
        use_iota_embed=true \
        gcs_metrics=true \
        dataset_type=synthetic \
        profiler=xplane \
        attention=flash
    

    Mixtral 8x7b

    O Mixtral é um modelo de IA de última geração desenvolvido pela Mistral AI, que utiliza uma arquitetura de mistura de especialistas (MoE, na sigla em inglês) esparsa.

    python3 MaxText/train.py MaxText/configs/base.yml \
        base_output_directory=${BASE_OUTPUT_DIR} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        model_name=mixtral-8x7b \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        tokenizer_path=assets/tokenizer.mistral-v1 \
        attention=flash \
        dtype=bfloat16 \
        dataset_type=synthetic \
        profiler=xplane
    

    Llama3-8b

    O Llama é uma família de LLMs de peso aberto desenvolvidos pela Meta.

    python3 MaxText/train.py MaxText/configs/base.yml \
        model_name=llama3-8b \
        base_output_directory=${BASE_OUTPUT_DIR} \
        dataset_type=synthetic \
        tokenizer_path=assets/tokenizer_llama3.tiktoken \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} # set to 4 \
        gcs_metrics=true \
        profiler=xplane \
        skip_first_n_steps_for_profiler=5 \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        attention=flash
    

    MaxDiffusion

    O MaxDiffusion é uma coleção de implementações de referência de vários modelos de difusão latente escritos em Python puro e JAX que são executados em dispositivos XLA, incluindo Cloud TPUs e GPUs. O Stable Diffusion é um modelo latente de texto para imagem que gera imagens fotorrealistas a partir de qualquer entrada de texto.

    É necessário instalar uma ramificação específica do Git para executar o MaxDiffusion, conforme mostrado no comando git checkout a seguir.

    git clone https://github.com/google/maxdiffusion.git
    && cd maxdiffusion
    && git checkout e712c9fc4cca764b0930067b6e33daae2433abf0
    && pip install -r requirements.txt
    && pip install .
    

    Script de treinamento:

        cd maxdiffusion && OUT_DIR=${BASE_OUTPUT_DIR} \
        python src/maxdiffusion/train_sdxl.py \
        src/maxdiffusion/configs/base_xl.yml \
        revision=refs/pr/95 \
        activations_dtype=bfloat16 \
        weights_dtype=bfloat16 \
        resolution=1024 \
        per_device_batch_size=1 \
        output_dir=${OUT_DIR}  \
        jax_cache_dir=${OUT_DIR}/cache_dir/ \
        max_train_steps=200 \
        attention=flash run_name=sdxl-ddp-v6e
        
  3. Execute o modelo usando o script criado na etapa anterior. É necessário especificar a flag --base-docker-image para usar a imagem base MaxText ou especificar a flag --docker-image e a imagem que você quer usar.

    Opcional: é possível ativar o registro de depuração incluindo a flag --enable-debug-logs. Para mais informações, consulte Depurar o JAX no MaxText.

    Opcional: é possível criar um experimento da Vertex AI para fazer upload de dados para o TensorBoard da Vertex AI incluindo a flag --use-vertex-tensorboard. Para mais informações, consulte Monitorar JAX no MaxText usando a Vertex AI.

    python3 xpk.py workload create \
        --cluster ${CLUSTER_NAME} \
        {--base-docker-image maxtext_base_image|--docker-image ${CLOUD_IMAGE_NAME}} \
        --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
        --tpu-type=${ACCELERATOR_TYPE} \
        --num-slices=${NUM_SLICES}  \
        --on-demand \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        [--enable-debug-logs] \
        [--use-vertex-tensorboard] \
        --command=$YOUR-MODEL-SCRIPT

    Exporte as seguintes variáveis:

    export CLUSTER_NAME=CLUSTER_NAME: The name of your XPK cluster.
    export ACCELERATOR_TYPEACCELERATOR_TYPE: The version and size of your TPU. For example, `v6e-256`.
    export NUM_SLICES=NUM_SLICES: The number of Cloud TPU slices.
    export YOUR_MODEL_SCRIPT=YOUR_MODEL_SCRIPT: The model script to execute as a training command.

    A saída inclui um link para acompanhar sua carga de trabalho, semelhante a este:

    [XPK] Follow your workload here: https://console.cloud.google.com/kubernetes/service/zone/project_id/default/workload_name/details?project=project_id
    

    Abra o link e clique na guia Registros para acompanhar sua carga de trabalho em tempo real.

Depurar o JAX no MaxText

Use comandos XPK complementares para diagnosticar por que o cluster ou a carga de trabalho não está em execução.

Monitorar JAX no MaxText usando a Vertex AI

Acesse dados escalares e de perfil pelo TensorBoard gerenciado da Vertex AI.

  1. Aumente as solicitações de gerenciamento de recursos (CRUD) da zona que você está usando de 600 para 5.000. Isso pode não ser um problema para cargas de trabalho pequenas que usam menos de 16 VMs.
  2. Instale dependências, como cloud-accelerator-diagnostics, para o Vertex AI:

    # xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI
    cd ~/xpk
    pip install .
  3. Crie o cluster XPK usando a flag --create-vertex-tensorboard, conforme documentado em Criar o TensorBoard da Vertex AI. Também é possível executar esse comando em clusters atuais.

  4. Crie seu experimento da Vertex AI ao executar a carga de trabalho do XPK usando a flag --use-vertex-tensorboard e a flag opcional --experiment-name. Para conferir a lista completa de etapas, consulte Criar um experimento da Vertex AI para fazer upload de dados para o TensorBoard da Vertex AI.

Os registros incluem um link para um TensorBoard da Vertex AI, semelhante a este:

View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name

Também é possível encontrar o link do TensorBoard da Vertex AI no console do Google Cloud. Acesse Experimentos da Vertex AI no console do Google Cloud. Selecione a região adequada no menu suspenso.

O diretório do TensorBoard também é gravado no bucket do Cloud Storage especificado com ${BASE_OUTPUT_DIR}.

Excluir cargas de trabalho do XPK

Use o comando xpk workload delete para excluir uma ou mais cargas de trabalho com base no prefixo ou status do job. Esse comando pode ser útil se você enviou cargas de trabalho XPK que não precisam mais ser executadas ou se tiver trabalhos presos na fila.

Excluir cluster XPK

Use o comando xpk cluster delete para excluir um cluster:

python3 xpk.py cluster delete --cluster ${CLUSTER_NAME} \
--zone=${ZONE} --project=${PROJECT_ID}

Treinamento de Llama e PyTorch/XLA na VM da v6e do Cloud TPU

Este tutorial descreve como treinar modelos Llama usando PyTorch/XLA na Cloud TPU v6e com o conjunto de dados WikiText.

Receber acesso ao Hugging Face e ao modelo Llama 3

Você precisa de um token de acesso de usuário do Hugging Face para executar este tutorial. Para informações sobre como criar e usar tokens de acesso do usuário, consulte a documentação do Hugging Face sobre tokens de acesso do usuário.

Você também precisa de permissão para acessar o modelo Llama 3 8B no Hugging Face. Para ter acesso, acesse o modelo Meta-Llama-3-8B no HuggingFace e solicite acesso.

Criar uma VM do Cloud TPU

Crie um Cloud TPU v6e com 8 chips para executar o tutorial.

  1. Configure as variáveis de ambiente:

    export ACCELERATOR_TYPE=v6e-8
    export VERSION=v2-alpha-tpuv6e
    export TPU_NAME=$USER-$ACCELERATOR_TYPE
    export PROJECT_ID=your-project-id
    export ZONE=your-zone
  2. Crie uma VM do Cloud TPU:

    gcloud alpha compute tpus tpu-vm create ${TPU_NAME} --version=${VERSION} \
        --accelerator-type=${ACCELERATOR_TYPE} \
        --zone=${ZONE} \
        --project=${PROJECT_ID}

Instalação

Instale o fork pytorch-tpu/transformers dos transformadores e dependências do Hugging Face. Este tutorial foi testado com as seguintes versões de dependência usadas neste exemplo:

  • torch: compatível com 2.5.0
  • torch_xla[tpu]: compatível com 2.5.0
  • jax: 0.4.33
  • jaxlib: 0.4.33
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT} --zone ${ZONE} \
    --worker=all --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
    cd transformers
    sudo pip3 install -e .
    pip3 install datasets
    pip3 install evaluate
    pip3 install scikit-learn
    pip3 install accelerate
    pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
    pip install jax==0.4.33 jaxlib==0.4.33 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'

Configurar as configurações do modelo

O comando de treinamento na próxima seção, Executar o modelo, usa dois arquivos de configuração JSON para definir parâmetros de modelo e a configuração de FSDP (Fully Sharded Data Parallel). O sharding de FSDP é usado para que os pesos do modelo se ajustem a um tamanho de lote maior durante o treinamento. Ao treinar com modelos menores, pode ser suficiente usar o paralelismo de dados e replicar os pesos em cada dispositivo. Para mais informações sobre como dividir tensores em dispositivos no PyTorch/XLA, consulte o Guia do usuário do PyTorch/XLA SPMD.

  1. Crie o arquivo de configuração do parâmetro do modelo. Confira a seguir a configuração do parâmetro do modelo para o Llama3-8B. Para outros modelos, encontre a configuração no Hugging Face. Por exemplo, consulte a configuração da Llama2-7B.

    cat > llama-config.json << EOF
    {
        "architectures": [
            "LlamaForCausalLM"
        ],
        "attention_bias": false,
        "attention_dropout": 0.0,
        "bos_token_id": 128000,
        "eos_token_id": 128001,
        "hidden_act": "silu",
        "hidden_size": 4096,
        "initializer_range": 0.02,
        "intermediate_size": 14336,
        "max_position_embeddings": 8192,
        "model_type": "llama",
        "num_attention_heads": 32,
        "num_hidden_layers": 32,
        "num_key_value_heads": 8,
        "pretraining_tp": 1,
        "rms_norm_eps": 1e-05,
        "rope_scaling": null,
        "rope_theta": 500000.0,
        "tie_word_embeddings": false,
        "torch_dtype": "bfloat16",
        "transformers_version": "4.40.0.dev0",
        "use_cache": false,
        "vocab_size": 128256
    }
    EOF
  2. Crie o arquivo de configuração do FSDP:

    cat > fsdp-config.json << EOF
    {
        "fsdp_transformer_layer_cls_to_wrap": [
            "LlamaDecoderLayer"
        ],
        "xla": true,
        "xla_fsdp_v2": true,
        "xla_fsdp_grad_ckpt": true
    }
    EOF

    Para mais informações sobre o FSDP, consulte FSDPv2.

  3. Faça upload dos arquivos de configuração para as VMs do Cloud TPU usando o seguinte comando:

    gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json ${TPU_NAME}:. \
        --worker=all \
        --project=${PROJECT_ID} \
        --zone=${ZONE}

Executar o modelo

Usando os arquivos de configuração criados na seção anterior, execute o script run_clm.py para treinar o modelo Llama 3 8B no conjunto de dados WikiText. O script de treinamento leva aproximadamente 10 minutos para ser executado em uma Cloud TPU v6e-8.

  1. Faça login no Hugging Face no Cloud TPU usando o seguinte comando:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT} \
        --zone ${ZONE} \
        --worker=all \
        --command='
        pip3 install "huggingface_hub[cli]"
        huggingface-cli login --token HUGGING_FACE_TOKEN'
  2. Execute o treinamento do modelo:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT} \
        --zone ${ZONE} \
        --worker=all \
        --command='
        export PJRT_DEVICE=TPU
        export XLA_USE_SPMD=1
        export ENABLE_PJRT_COMPATIBILITY=true
            # Optional variables for debugging:
        export XLA_IR_DEBUG=1
        export XLA_HLO_DEBUG=1
        export PROFILE_EPOCH=0
        export PROFILE_STEP=3
        export PROFILE_DURATION_MS=100000
            # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path
        export PROFILE_LOGDIR=PROFILE_PATH
        python3 transformers/examples/pytorch/language-modeling/run_clm.py \
        --dataset_name wikitext \
        --dataset_config_name wikitext-2-raw-v1 \
        --per_device_train_batch_size 16 \
        --do_train \
        --output_dir /home/$USER/tmp/test-clm \
        --overwrite_output_dir \
        --config_name /home/$USER/llama-config.json \
        --cache_dir /home/$USER/cache \
        --tokenizer_name meta-llama/Meta-Llama-3-8B \
        --block_size 8192 \
        --optim adafactor \
        --save_strategy no \
        --logging_strategy no \
        --fsdp "full_shard" \
        --fsdp_config /home/$USER/fsdp-config.json \
        --torch_dtype bfloat16 \
        --dataloader_drop_last yes \
        --flash_attention \
        --max_steps 20'

Solução de problemas do PyTorch/XLA

Se você definir as variáveis opcionais para depuração na seção anterior, o perfil do modelo será armazenado no local especificado pela variável PROFILE_LOGDIR. É possível extrair o arquivo xplane.pb armazenado neste local e usar tensorboard para conferir os perfis no navegador usando as instruções do TensorBoard. Se o PyTorch/XLA não estiver funcionando como esperado, consulte o guia de solução de problemas, que tem sugestões para depurar, criar perfis e otimizar seu modelo.

Treinamento do DLRM DCN v2 na v6e

Este tutorial mostra como treinar o modelo DLRM DCN v2 no Cloud TPU v6e. É necessário provisionar uma TPU v6e com 64, 128 ou 256 chips.

Se você estiver executando em um TPU com vários hosts, redefina tpu-runtime com a versão apropriada do TensorFlow executando os comandos abaixo. Se você estiver executando em uma TPU de host único, não será necessário executar os dois comandos a seguir.

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID}
--zone  ${ZONE} --worker=all \
--command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} \
 --zone  ${ZONE}   \
 --worker=all \
 --command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'

Conectar-se ao worker-0 usando SSH

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone ${ZONE} --project {$PROJECT_ID}

Definir o nome da Cloud TPU

export TPU_NAME=${TPU_NAME}

Executar o DLRM v2

Copie o seguinte snippet de código em um arquivo chamado script.sh:

pip install --user setuptools==65.5.0

pip install cloud-tpu-client

pip install gin-config && pip install tensorflow-datasets && pip install tf-keras-nightly --no-deps

pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl -f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

git clone https://github.com/tensorflow/recommenders.git
git clone https://github.com/tensorflow/models.git

export PYTHONPATH=~/recommenders/:~/models/
export TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true --tf_xla_sparse_core_disable_table_stacking=true --tf_mlir_enable_convert_control_to_data_outputs_pass=true --tf_mlir_enable_merge_control_flow_pass=true'

TF_USE_LEGACY_KERAS=1 TPU_LOAD_LIBRARY=0 python3 ./models/official/recommendation/ranking/train.py  --mode=train     --model_dir=gs://ptxla-debug/tf/sc/dlrm/runs/2/ --params_override="
runtime:
  distribution_strategy: tpu
  mixed_precision_dtype: 'mixed_bfloat16'
task:
  use_synthetic_data: false
  use_tf_record_reader: true
  train_data:
    input_path: 'gs://trillium-datasets/criteo/train/day_*/*'
    global_batch_size: 16384
    use_cached_data: true
  validation_data:
    input_path: 'gs://trillium-datasets/criteo/eval/day_*/*'
    global_batch_size: 16384
    use_cached_data: true
  model:
    num_dense_features: 13
    bottom_mlp: [512, 256, 128]
    embedding_dim: 128
    interaction: 'multi_layer_dcn'
    dcn_num_layers: 3
    dcn_low_rank_dim: 512
    size_threshold: 8000
    top_mlp: [1024, 1024, 512, 256, 1]
    use_multi_hot: true
    concat_dense: false
    dcn_use_bias: true
    vocab_sizes: [40000000,39060,17295,7424,20265,3,7122,1543,63,40000000,3067956,405282,10,2209,11938,155,4,976,14,40000000,40000000,40000000,590152,12973,108,36]
    multi_hot_sizes: [3,2,1,2,6,1,1,1,1,7,3,8,1,6,9,5,1,1,1,12,100,27,10,3,1,1]
    max_ids_per_chip_per_sample: 128
    max_ids_per_table: [280, 128, 64, 272, 432, 624, 64, 104, 368, 352, 288, 328, 304, 576, 336, 368, 312, 392, 408, 552, 2880, 1248, 720, 112, 320, 256]
    max_unique_ids_per_table: [104, 56, 40, 32, 72, 32, 40, 32, 32, 144, 64, 192, 32, 40, 136, 32, 32, 32, 32, 240, 1352, 432, 120, 80, 32, 32]
    use_partial_tpu_embedding: false
    size_threshold: 0
    initialize_tables_on_host: true
trainer:
  train_steps: 10000
  validation_interval: 1000
  validation_steps: 660
  summary_interval: 1000
  steps_per_loop: 1000
  checkpoint_interval: 0
  optimizer_config:
    embedding_optimizer: 'Adagrad'
    dense_optimizer: 'Adagrad'
    lr_config:
      decay_exp: 2
      decay_start_steps: 70000
      decay_steps: 30000
      learning_rate: 0.025
      warmup_steps: 0
    dense_sgd_config:
      decay_exp: 2
      decay_start_steps: 70000
      decay_steps: 30000
      learning_rate: 0.00025
      warmup_steps: 8000
  train_tf_function: true
  train_tf_while_loop: true
  eval_tf_while_loop: true
  use_orbit: true
  pipeline_sparse_and_dense_execution: true"

Se você estiver executando o TensorFlow no GKE, instale a roda do TensorFlow Cloud TPU e o libtpu usando o seguinte comando:

pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

Defina as seguintes flags, que são necessárias para executar cargas de trabalho de recomendação, como o DLRM DCN:

ENV TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true \
--tf_mlir_enable_tpu_variable_runtime_reformatting_pass=false \
--tf_mlir_enable_convert_control_to_data_outputs_pass=true \
--tf_mlir_enable_merge_control_flow_pass=true --tf_xla_disable_full_embedding_pipelining=true' \
ENV LIBTPU_INIT_ARGS="--xla_sc_splitting_along_feature_dimension=auto \
--copy_with_dynamic_shape_op_output_pjrt_buffer=true"

Execute script.sh:

chmod +x script.sh
./script.sh

Resultados da comparação

A seção a seguir contém resultados de comparação de mercado para o DLRM DCN v2 e MaxDiffusion na v6e.

DLRM DCN v2

O script de treinamento do DLRM DCN v2 foi executado em diferentes escalas. Confira as taxas de transferência na tabela a seguir.

v6e-64 v6e-128 v6e-256
Etapas de treinamento 7.000 7.000 7.000
Tamanho global do lote 131072 262144 524288
Capacidade (exemplos/s) 2975334 5111808 10066329

MaxDiffusion

Executamos o script de treinamento para MaxDiffusion em uma v6e-4, uma v6e-16 e duas v6e-16. Confira as taxas de transferência na tabela a seguir.

v6e-4 v6e-16 Duas v6e-16
Etapas de treinamento 0,069 0,073 0,13
Tamanho global do lote 8 32 64
Capacidade (exemplos/s) 115,9 438,4 492.3