Introduzione a Trillium (v6e)

v6e viene utilizzato per fare riferimento a Trillium in questa documentazione, nell'API TPU e nei log. v6e rappresenta la sesta generazione di TPU di Google.

Con 256 chip per pod, l'architettura v6e condivide molte somiglianze con la v5e. Questo sistema è ottimizzato per l'addestramento, la messa a punto e la pubblicazione di trasformatori, conversione di testo in immagini e reti neurali convoluzionali (CNN).

Per ulteriori informazioni sull'architettura e sulle configurazioni di sistema v6e, consulta TPU v6e.

Questo documento introduttivo si concentra sulle procedure di addestramento e pubblicazione dei modelli utilizzando i framework JAX, PyTorch o TensorFlow. Con ogni framework, puoi eseguire il provisioning delle TPU utilizzando le risorse in coda o GKE. La configurazione di GKE può essere eseguita utilizzando XPK o i comandi GKE.

Procedura generale per addestrare o eseguire il servizio di un modello utilizzando la versione 6e

  1. Preparare un Google Cloud progetto
  2. Capacità sicura
  3. Esegui il provisioning dell'ambiente Cloud TPU
  4. Esegui un carico di lavoro di addestramento o inferenza del modello

Preparare un Google Cloud progetto

Prima di poter utilizzare Cloud TPU, devi:

  • Crea un Google Cloud account e un progetto con la fatturazione abilitata
  • Installa i componenti alpha di Google Cloud CLI
  • Abilita l'API Cloud TPU
  • Crea un agente di servizio Cloud TPU
  • Crea un account di servizio Cloud TPU e concedi le autorizzazioni

Per saperne di più, vedi Configurare l'ambiente Cloud TPU.

Capacità sicura

Contatta l'Google Cloud assistenza per richiedere una quota Cloud TPU v6e e per rispondere a qualsiasi domanda sulla capacità.

Esegui il provisioning dell'ambiente Cloud TPU

È possibile eseguire il provisioning e la gestione di Cloud TPU v6e con GKE, con GKE e XPK (uno strumento CLI wrapper su GKE) o come risorse in coda.

Prerequisiti

  • Verifica che il tuo progetto disponga di una quota TPUS_PER_TPU_FAMILY sufficiente, che specifica il numero massimo di chip a cui puoi accedere all'interno del progetto. Google Cloud
  • La versione 6e è stata testata con la seguente configurazione:
    • Python 3.10 o versioni successive
    • Versioni software Nightly:
      • JAX per notte 0.4.32.dev20240912
      • LibTPU notturna 0.1.dev20240912+nightly
    • Versioni software stabili:
      • JAX + JAX Lib della versione 0.4.37
  • Verifica che il tuo progetto disponga di una quota sufficiente per:

    • Quota di VM Cloud TPU
    • Quota di indirizzi IP
    • Quota Hyperdisk bilanciato

  • Se utilizzi GKE con XPK, consulta Autorizzazioni di Cloud Console per l'account utente o di servizio per conoscere le autorizzazioni necessarie per eseguire XPK.

Creare variabili di ambiente

In Cloud Shell, crea le seguenti variabili di ambiente:

export NODE_ID=your-tpu-name
export PROJECT_ID=your-project-id
export ACCELERATOR_TYPE=v6e-16
export ZONE=us-east1-d
export RUNTIME_VERSION=v2-alpha-tpuv6e
export SERVICE_ACCOUNT=your-service-account
export QUEUED_RESOURCE_ID=your-queued-resource-id
export VALID_DURATION=your-duration 

# Additional environment variable needed for Multislice:
export NUM_SLICES=number-of-slices

# Use a custom network for better performance as well as to avoid having the default network becoming overloaded.

export NETWORK_NAME=${PROJECT_ID}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw

Descrizioni dei flag dei comandi

Variabile Descrizione
NODE_ID L'ID assegnato dall'utente della Cloud TPU che viene creato quando viene allocata la richiesta di risorse in coda.
PROJECT_ID Google Cloud nome progetto. Utilizza un progetto esistente o creane uno nuovo. Per ulteriori informazioni, vedi Configurare il Google Cloud progetto.
ZONA Consulta il documento Regioni e zone di Cloud TPU per le zone supportate.
ACCELERATOR_TYPE Consulta la sezione Tipi di acceleratore.
RUNTIME_VERSION v2-alpha-tpuv6e
SERVICE_ACCOUNT Si tratta dell'indirizzo email del tuo account di servizio che puoi trovare in Google Cloud Console -> IAM -> Account di servizio

Ad esempio: tpu-service-account@your-project-ID.iam.gserviceaccount.com.com

NUM_SLICES Il numero di slice da creare (necessario solo per Multislice).
QUEUED_RESOURCE_ID L'ID testo assegnato dall'utente della richiesta di risorsa in coda.
VALID_DURATION La durata di validità della richiesta di risorse in coda.
NETWORK_NAME Il nome di una rete secondaria da utilizzare.
NETWORK_FW_NAME Il nome di un firewall di rete secondario da utilizzare.

Ottimizzare le prestazioni della rete

Per le migliori prestazioni,utilizza una rete con 8896 MTU (unità massima di trasmissione).

Per impostazione predefinita, un Virtual Private Cloud (VPC) fornisce solo un MTU di 1460 byte,che offrirà prestazioni di rete non ottimali. Puoi impostare l'MTU di una rete VPC su qualsiasi valore compreso tra 1300 e 8896 byte (inclusi). Le dimensioni MTU personalizzate comuni sono 1500 byte (Ethernet standard) o 8896 byte (il massimo possibile). Per maggiori informazioni, consulta Dimensioni MTU valide per le reti VPC.

Per saperne di più sulla modifica dell'impostazione MTU per una rete esistente o predefinita, consulta Modificare l'impostazione MTU di una rete VPC.

L'esempio seguente crea una rete con 8896 MTU.

export RESOURCE_NAME=your-resource-name
export NETWORK_NAME=${RESOURCE_NAME}-privatenetwork
export NETWORK_FW_NAME=${RESOURCE_NAME}-privatefirewall
gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT_ID} \
 --subnet-mode=auto --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network=${NETWORK_NAME} \
 --allow tcp,icmp,udp --project=${PROJECT_ID}

Utilizzo di più NIC (opzione per Multislice)

Le seguenti variabili di ambiente sono necessarie per una subnet secondaria quando utilizzi un ambiente Multislice.

export NETWORK_NAME_2=${RESOURCE_NAME}
export SUBNET_NAME_2=${RESOURCE_NAME}
export FIREWALL_RULE_NAME=${RESOURCE_NAME}
export ROUTER_NAME=${RESOURCE_NAME}-network-2
export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2
export REGION=your-region

Utilizza i seguenti comandi per creare il routing IP personalizzato per la rete e la subnet.

gcloud compute networks create ${NETWORK_NAME_2} --mtu=8896 \
   --bgp-routing-mode=regional --subnet-mode=custom --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_2} \
   --network=${NETWORK_NAME_2} \
   --range=10.10.0.0/18 --region=${REGION} \
   --project=${PROJECT_ID}

gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_2} --allow tcp,icmp,udp \
   --source-ranges 10.10.0.0/18 --project=${PROJECT_ID}

gcloud compute routers create ${ROUTER_NAME} \
  --project=${PROJECT_ID} \
  --network=${NETWORK_NAME_2} \
  --region=${REGION}

gcloud compute routers nats create ${NAT_CONFIG} \
  --router=${ROUTER_NAME} \
  --region=${REGION} \
  --auto-allocate-nat-external-ips \
  --nat-all-subnet-ip-ranges \
  --project=${PROJECT_ID} \
  --enable-logging

Dopo aver creato uno slice multirete, puoi verificare che entrambe le schede di interfaccia di rete (NIC) siano in uso configurando un cluster XPK e aggiungendo il flag --command ifconfig al comando di creazione del carico di lavoro XPK.

Utilizza il seguente comando xpk workload per visualizzare l'output del comando ifconfig nei log della console Google Cloud e controlla che sia eth0 sia eth1 abbiano mtu=8896.

python3 xpk.py workload create \
   --cluster CLUSTER_NAME \
   {--base-docker-image maxtext_base_image | --docker-image your-cloud-image-name} \
   --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
   --tpu-type=${ACCELERATOR_TYPE} \
   --num-slices=${NUM_SLICES}  \
   --on-demand \
   --zone=${ZONE} \
   --project=${PROJECT_ID} \
   --command "ifconfig"

Se vuoi abilitare i log di debug o utilizzare Vertex AI TensorBoard, aggiungi i seguenti argomenti facoltativi al comando:

    --enable-debug-logs \
    --use-vertex-tensorboard

Verifica che sia eth0 che eth1 abbiano mtu=8896. Puoi verificare che la multi-NIC sia in esecuzione aggiungendo il flag --command ifconfig al comando di creazione del carico di lavoro XPK. Controlla l'output del carico di lavoro xpk nei log della console Google Cloud e verifica che sia eth0 sia eth1 abbiano mtu=8896.

Migliorare le impostazioni TCP

Se hai creato le tue Cloud TPU utilizzando l'interfaccia delle risorse in coda, puoi eseguire il seguente comando per migliorare le prestazioni della rete aumentando i limiti del buffer di ricezione TCP.

gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \
  --project "${PROJECT}" \
  --zone "${ZONE}" \
  --node=all \
  --command='sudo sh -c "echo \"4096 41943040 314572800\" > /proc/sys/net/ipv4/tcp_rmem"' \
  --worker=all

Eseguire il provisioning con risorse in coda

Puoi creare una Cloud TPU v6e utilizzando le risorse in coda. Le risorse in coda ti consentono di ricevere la capacità non appena diventa disponibile. Puoi specificare un'ora di inizio e di fine facoltative per indicare quando deve essere compilata la richiesta. Per ulteriori informazioni, consulta Gestire le risorse in coda.

Provisionare Cloud TPU v6e con GKE o XPK

Se utilizzi i comandi GKE con la versione 6e, puoi utilizzare i comandi Kubernetes o XPK per eseguire il provisioning delle TPU Cloud e addestrare o pubblicare i modelli. Consulta Pianificare le Cloud TPU in GKE per scoprire come pianificare le configurazioni di Cloud TPU nei cluster GKE. Le seguenti sezioni forniscono i comandi per creare un cluster XPK con supporto per una singola NIC e per più NIC.

Creare un cluster XPK con il supporto di una singola NIC

export CLUSTER_NAME=xpk-cluster-name
export ZONE=us-east1-d
export PROJECT_ID=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME=${CLUSTER_NAME}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw
gcloud compute networks create ${NETWORK_NAME} \
   --mtu=8896 \
   --project=${PROJECT_ID} \
   --subnet-mode=auto \
   --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} \
   --network=${NETWORK_NAME} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}"
python3 xpk.py cluster create --cluster=${CLUSTER_NAME} \
   --cluster-cpu-machine-type=n1-standard-8 \
   --num-slices=${NUM_SLICES} \
   --tpu-type=${TPU_TYPE} \
   --zone=${ZONE}  \
   --project=${PROJECT_ID} \
   --on-demand \
   --custom-cluster-arguments="${CLUSTER_ARGUMENTS}"  \
   --create-vertex-tensorboard

Descrizioni dei flag dei comandi

Variabile Descrizione
CLUSTER_NAME Il nome assegnato dall'utente al cluster XPK.
PROJECT_ID Google Cloud nome progetto. Utilizza un progetto esistente o creane uno nuovo. Per ulteriori informazioni, vedi Configurare il Google Cloud progetto.
ZONA Consulta il documento Regioni e zone di Cloud TPU per le zone supportate.
TPU_TYPE Consulta la sezione Tipi di acceleratore.
NUM_SLICES Il numero di slice che vuoi creare
CLUSTER_ARGUMENTS La rete e la subnet da utilizzare.

Ad esempio: --network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}

NUM_SLICES Il numero di sezioni da creare.
NETWORK_NAME Il nome di una rete secondaria da utilizzare.
NETWORK_FW_NAME Il nome di un firewall di rete secondario da utilizzare.

Creare un cluster XPK con supporto multi-NIC

export CLUSTER_NAME=xpk-cluster-name
export ZONE=us-east1-d
export PROJECT_ID=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME_1=${CLUSTER_NAME}-mtu9k-1-${ZONE}
export SUBNET_NAME_1=${CLUSTER_NAME}-privatesubnet-1-${ZONE}
export NETWORK_FW_NAME_1=${NETWORK_NAME_1}-fw-1-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-1-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-1-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-1-${ZONE}
gcloud compute networks create ${NETWORK_NAME_1} \
   --mtu=8896 \
   --bgp-routing-mode=regional \
   --subnet-mode=custom \
   --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_1} \
   --network=${NETWORK_NAME_1} \
   --range=10.11.0.0/18 \
   --region=${REGION} \
   --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_1} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \
    --project=${PROJECT_ID} \
    --network=${NETWORK_NAME_1} \
    --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \
    --router=${ROUTER_NAME} \
    --region=${REGION} \
    --auto-allocate-nat-external-ips \
    --nat-all-subnet-ip-ranges \
    --project=${PROJECT_ID} \
    --enable-logging
# Secondary subnet for multi-nic experience.
# Need custom IP routing to be different from the first network's subnet.

export NETWORK_NAME_2=${CLUSTER_NAME}-privatenetwork-2-${ZONE}
export SUBNET_NAME_2=${CLUSTER_NAME}-privatesubnet-2-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-2-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-2-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-2-${ZONE}
gcloud compute networks create ${NETWORK_NAME_2} \
   --mtu=8896 \
   --bgp-routing-mode=regional \
   --subnet-mode=custom \
   --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_2} \
   --network=${NETWORK_NAME_2} \
   --range=10.10.0.0/18 \
   --region=${REGION} \
   --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_2} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \
     --project=${PROJECT_ID} \
     --network=${NETWORK_NAME_2} \
     --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \
     --router=${ROUTER_NAME} \
     --region=${REGION} \
     --auto-allocate-nat-external-ips \
     --nat-all-subnet-ip-ranges \
     --project=${PROJECT_ID} \
     --enable-logging
export CLUSTER_ARGUMENTS="--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking
--network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}"

export NODE_POOL_ARGUMENTS="--additional-node-network
network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}"
python3 ~/xpk/xpk.py cluster create \
    --cluster=${CLUSTER_NAME} \
    --num-slices=${NUM_SLICES} \
    --tpu-type=${TPU_TYPE} \
    --zone=${ZONE}  \
    --project=${PROJECT_ID} \
    --on-demand \
    --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \
    --custom-nodepool-arguments="${NODE_POOL_ARGUMENTS}" \
    --create-vertex-tensorboard

Descrizioni dei flag dei comandi

Variabile Descrizione
CLUSTER_NAME Il nome assegnato dall'utente al cluster XPK.
PROJECT_ID Google Cloud nome progetto. Utilizza un progetto esistente o creane uno nuovo. Per ulteriori informazioni, vedi Configurare il Google Cloud progetto.
ZONA Consulta il documento Regioni e zone di Cloud TPU per le zone supportate.
TPU_TYPE Consulta la sezione Tipi di acceleratore.
NUM_SLICES Il numero di slice che vuoi creare
CLUSTER_ARGUMENTS La rete e la subnet da utilizzare.

Ad esempio: --enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}

NODE_POOL_ARGUMENTS La rete del nodo aggiuntivo da utilizzare.

Ad esempio: --additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}

NUM_SLICES Il numero di slice da creare (obbligatorio solo per Multislice).
NETWORK_NAME Il nome di una rete secondaria da utilizzare.
NETWORK_FW_NAME Il nome di un firewall di rete secondario da utilizzare.

Configurazione del framework

Questa sezione descrive la procedura di configurazione generale per l'addestramento dei modelli di ML utilizzando i framework JAX, PyTorch o TensorFlow. Se utilizzi GKE, puoi utilizzare XPK o i comandi Kubernetes per la configurazione del framework.

Configurazione per JAX

Questa sezione fornisce istruzioni di configurazione per l'esecuzione di workload JAX su GKE, con o senza XPK, nonché per l'utilizzo delle risorse in coda.

Configura JAX utilizzando GKE

Singolo slice su un singolo host

L'esempio seguente configura un pool di nodi a un host 2x2 utilizzando un file YAML di Kubernetes.

apiVersion: v1
kind: Pod
metadata:
  name: tpu-pod-jax-v6e-a
spec:
  restartPolicy: Never
  nodeSelector:
    cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
    cloud.google.com/gke-tpu-topology: 2x2
  containers:
  - name: tpu-job
    image: python:3.10
    securityContext:
      privileged: true
    command:
    - bash
    - -c
    - |
      pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
      JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python3 -c 'import jax; print("Total TPU chips:", jax.device_count())'
    resources:
      requests:
        google.com/tpu: 4
      limits:
        google.com/tpu: 4

Al termine dell'operazione, dovresti visualizzare il seguente messaggio nel log GKE:

Total TPU chips: 4

Singolo slice su più host

L'esempio seguente configura un node pool multi-host 4x4 utilizzando un file YAML di Kubernetes.

apiVersion: v1
kind: Service
metadata:
  name: headless-svc
spec:
  clusterIP: None
  selector:
    job-name: tpu-available-chips
---
apiVersion: batch/v1
kind: Job
metadata:
  name: tpu-available-chips
spec:
  backoffLimit: 0
  completions: 4
  parallelism: 4
  completionMode: Indexed
  template:
    spec:
      subdomain: headless-svc
      restartPolicy: Never
      nodeSelector:
        cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
        cloud.google.com/gke-tpu-topology: 4x4
      containers:
      - name: tpu-job
        image: python:3.10
        ports:
        - containerPort: 8471 # Default port using which TPU VMs communicate
        - containerPort: 8431 # Port to export TPU runtime metrics, if supported.
        securityContext:
          privileged: true
        command:
        - bash
        - -c
        - |
          pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
          JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
        resources:
          requests:
            google.com/tpu: 4
          limits:
            google.com/tpu: 4

Al termine dell'operazione, dovresti visualizzare il seguente messaggio nel log GKE:

Total TPU chips: 16

Multislice su più host

L'esempio seguente configura due pool di nodi multi-host 4x4 utilizzando un file YAML Kubernetes.

Come prerequisito, devi installare JobSet versione 0.2.3 o successive.

apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
  name: multislice-job
  annotations:
    alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
spec:
  failurePolicy:
    maxRestarts: 4
  replicatedJobs:
    - name: slice
      replicas: 2
      template:
        spec:
          parallelism: 4
          completions: 4
          backoffLimit: 0
          template:
            spec:
              hostNetwork: true
              dnsPolicy: ClusterFirstWithHostNet
              nodeSelector:
                cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                cloud.google.com/gke-tpu-topology: 4x4
              hostNetwork: true
              containers:
              - name: jax-tpu
                image: python:3.10
                ports:
                - containerPort: 8471
                - containerPort: 8080
                - containerPort: 8431
                securityContext:
                  privileged: true
                command:
                - bash
                - -c
                - |
                  pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                  JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
                resources:
                  limits:
                   google.com/tpu: 4
                  requests:
                   google.com/tpu: 4

Al termine dell'operazione, dovresti visualizzare il seguente messaggio nel log GKE:

Total TPU chips: 32

Per saperne di più, consulta Eseguire un workload multislice nella documentazione di GKE.

Per migliorare le prestazioni, abilita hostNetwork.

Multi-NIC

Per sfruttare la multi-NIC in GKE, il manifest del pod Kubernetes deve avere annotazioni aggiuntive. Di seguito è riportato un manifest di esempio per il carico di lavoro NIC multiple non TPU.

apiVersion: v1
kind: Pod
metadata:
  name: sample-netdevice-pod-1
  annotations:
    networking.gke.io/default-interface: 'eth0'
    networking.gke.io/interfaces: |
      [
        {"interfaceName":"eth0","network":"default"},
        {"interfaceName":"eth1","network":"netdevice-network"}
      ]
spec:
  containers:
  - name: sample-netdevice-pod
    image: busybox
    command: ["sleep", "infinity"]
    ports:
    - containerPort: 80
  restartPolicy: Always
  tolerations:
  - key: "google.com/tpu"
    operator: "Exists"
    effect: "NoSchedule"

Se utilizzi il comando exec per connetterti al pod Kubernetes, dovresti vedere la NIC aggiuntiva utilizzando il seguente codice.

$ k exec --stdin --tty sample-netdevice-pod-1 -- /bin/sh
/ # ip a
1: lo: <LOOPBACK,UP,LOWER_UP> mtu 65536 qdisc noqueue qlen 1000
    link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00
    inet 127.0.0.1/8 scope host lo
       valid_lft forever preferred_lft forever
2: eth0@if11: <BROADCAST,MULTICAST,UP,LOWER_UP,M-DOWN> mtu 1460 qdisc noqueue
    link/ether da:be:12:67:d2:25 brd ff:ff:ff:ff:ff:ff
    inet 10.124.2.6/24 brd 10.124.2.255 scope global eth0
       valid_lft forever preferred_lft forever
3: eth1: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1460 qdisc mq qlen 1000
    link/ether 42:01:ac:18:00:04 brd ff:ff:ff:ff:ff:ff
    inet 172.24.0.4/32 scope global eth1
       valid_lft forever preferred_lft forever

Configura JAX utilizzando GKE con XPK

Per configurare JAX utilizzando GKE e XPK, consulta il file README di xpk.

Per configurare ed eseguire XPK con MaxText, consulta Come eseguire MaxText.

Configurare JAX utilizzando le risorse in coda

Installa JAX su tutte le VM Cloud TPU nel tuo o nei tuoi slice contemporaneamente utilizzando il comando gcloud alpha compute tpus tpu-vm ssh. Per Multislice, aggiungi il flag --node=all.

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
 --zone ${ZONE} --worker=all \
 --command='pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

Puoi eseguire il seguente comando per verificare quanti core Cloud TPU sono disponibili nel tuo slice e per verificare che tutto sia installato correttamente:

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
   --zone ${ZONE} --worker=all  \
   --command='python3 -c "import jax; print(jax.device_count(), jax.local_device_count())"'

L'output è simile al seguente quando viene eseguito su una sezione v6e-16:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16 4
16 4
16 4
16 4

jax.device_count() mostra il numero totale di chip nell'intervallo specificato. jax.local_device_count() indica il numero di chip accessibili da una singola VM in questo slice.

gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} --zone=${ZONE} --worker=all  \
   --command='git clone -b mlperf4.1 https://github.com/google/maxdiffusion.git &&
   cd maxdiffusion && git checkout 975fdb7dbddaa9a53ad72a421cdb487dcdc491a3 &&
   && pip install -r requirements.txt  && pip install . '

Risolvere i problemi di configurazione di JAX

Un suggerimento generale è abilitare la registrazione dettagliata nel manifest del workload GKE. Quindi, fornisci i log all'assistenza GKE.

TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0

Messaggi di errore

no endpoints available for service 'jobset-webhook-service'

Questo errore indica che il jobset non è stato installato correttamente. Controlla se i pod Kubernetes del deployment jobset-controller-manager sono in esecuzione. Per ulteriori informazioni, consulta la documentazione sulla risoluzione dei problemi relativi a JobSet.

TPU initialization failed: Failed to connect

Assicurati che la versione del nodo GKE sia 1.30.4-gke.1348000 o successiva (GKE 1.31 non è supportato).

Configurazione per PyTorch

Questa sezione descrive come iniziare a utilizzare PJRT nella versione 6e con PyTorch/XLA. Python 3.10 è la versione di Python consigliata.

Configurare PyTorch utilizzando GKE con XPK

Puoi utilizzare il seguente contenitore Docker con XPK in cui sono già installate le dipendenze di PyTorch:

us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028

Per creare un carico di lavoro XPK, utilizza il seguente comando:

python3 xpk.py workload create \
    --cluster ${CLUSTER_NAME} \
    {--docker-image | --base-docker-image} us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028 \
    --workload ${USER} -xpk-${ACCELERATOR_TYPE} -${NUM_SLICES} \
    --tpu-type=${ACCELERATOR_TYPE} \
    --num-slices=${NUM_SLICES}  \
    --on-demand \
    --zone ${ZONE} \
    --project ${PROJECT_ID} \
    --enable-debug-logs \
    --command 'python3 -c "import torch; import torch_xla; import torch_xla.runtime as xr; print(xr.global_runtime_device_count())"'

L'utilizzo di --base-docker-image crea una nuova immagine Docker con la directory di lavoro corrente integrata nel nuovo Docker.

Configurare PyTorch utilizzando le risorse in coda

Segui questi passaggi per installare PyTorch utilizzando le risorse in coda ed eseguire un piccolo script sulla versione 6e.

Installa le dipendenze utilizzando SSH per accedere alle VM

Utilizza il seguente comando per installare le dipendenze su tutte le VM Cloud TPU. Per Multislice, aggiungi il flag --worker=all:

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='sudo apt install -y libopenblas-base pip3 && \
               pip install torch~=2.6.0 "torch_xla[tpu]~=2.6.0" -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'

Migliora le prestazioni dei modelli con allocazioni significative e frequenti

Per i modelli con allocazioni di dimensioni considerevoli e frequenti, l'utilizzo della funzione tcmalloc migliora notevolmente le prestazioni rispetto all'implementazione della funzione malloc predefinita, pertanto la funzione malloc predefinita utilizzata sulla VM Cloud TPU è tcmalloc. Tuttavia, a seconda del carico di lavoro (ad esempio, con DLRM che ha allocazioni molto grandi per le tabelle di embedding), la funzione tcmalloc potrebbe causare un rallentamento, nel qual caso puoi provare a reimpostare la seguente variabile utilizzando la funzione malloc predefinita:

unset LD_PRELOAD

Utilizzare uno script Python per eseguire un calcolo sulla VM v6e

Utilizza il seguente comando per eseguire uno script che crea due tensori, li somma e stampa il risultato.

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
   --project ${PROJECT_ID} \
   --zone ${ZONE} --worker all --command='
   unset LD_PRELOAD
   python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"
'

Viene generato un output simile al seguente:

SSH: Attempting to connect to worker 0...
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
xla:0
tensor([[ 0.3355, -1.4628, -3.2610],
        [-1.4656,  0.3196, -2.8766],
        [ 0.8668, -1.5060,  0.7125]], device='xla:0')

Configurazione per TensorFlow

Puoi reimpostare il runtime Cloud TPU con la versione di TensorFlow compatibile con la v6e eseguendo i seguenti comandi:

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
    --zone  ${ZONE} --worker=all --command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} \
    --zone ${ZONE} --worker=all --command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'

Utilizza SSH per accedere a worker-0:

$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
     --zone ${ZONE}

Installa TensorFlow su worker-0:

sudo apt install -y libopenblas-base
pip install cloud-tpu-client
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310
pip install cloud-tpu-client

pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

Esporta la variabile di ambiente TPU_NAME:

export TPU_NAME=v6e-16

Puoi eseguire il seguente script Python per verificare quanti core Cloud TPU sono disponibili nel tuo slice e per verificare che tutto sia installato correttamente:

import TensorFlow as tf
print("TensorFlow version " + tf.__version__)

@tf.function
  def add_fn(x,y):
  z = x + y
  return z

  cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
  tf.config.experimental_connect_to_cluster(cluster_resolver)
  tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
  strategy = tf.distribute.TPUStrategy(cluster_resolver)

  x = tf.constant(1.)
  y = tf.constant(1.)
  z = strategy.run(add_fn, args=(x,y))
  print(z)

L'output è simile al seguente quando viene eseguito su una sezione v6e-16:

PerReplica:{
  0: tf.Tensor(2.0, shape=(), dtype=float32),
  1: tf.Tensor(2.0, shape=(), dtype=float32),
  2: tf.Tensor(2.0, shape=(), dtype=float32),
  3: tf.Tensor(2.0, shape=(), dtype=float32),
  4: tf.Tensor(2.0, shape=(), dtype=float32),
  5: tf.Tensor(2.0, shape=(), dtype=float32),
  6: tf.Tensor(2.0, shape=(), dtype=float32),
  7: tf.Tensor(2.0, shape=(), dtype=float32)
}

v6e con SkyPilot

Puoi utilizzare Cloud TPU v6e con SkyPilot. Segui questa procedura per aggiungere a SkyPilot informazioni su prezzi e località relativi a v6e.

  1. Aggiungi quanto segue alla fine del file ~/.sky/catalogs/v5/gcp/vms.csv:

    ,,,tpu-v6e-1,1,tpu-v6e-1,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-1,1,tpu-v6e-1,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-1,1,tpu-v6e-1,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,us-east5,us-east5-b,0,0
    
  2. Specifica le seguenti risorse in un file YAML:

    # tpu_v6.yaml
    resources:
      accelerators: tpu-v6e-16 # Fill in the accelerator type you want to use
      accelerator_args:
        runtime_version: v2-alpha-tpuv6e # Official suggested runtime
    
  3. Avvia un cluster con Cloud TPU v6e:

       sky launch tpu_v6.yaml -c tpu_v6
    
  4. Connettiti a Cloud TPU v6e tramite SSH: ssh tpu_v6

Tutorial sull'inferenza

I seguenti tutorial mostrano come eseguire l'inferenza su Cloud TPU v6e:

Esempi di addestramento

Le sezioni seguenti forniscono esempi per l'addestramento di modelli MaxText, MaxDiffusion e PyTorch su Cloud TPU v6e.

Addestramento di MaxText e MaxDiffusion su VM Cloud TPU v6e

Le sezioni seguenti illustrano il ciclo di vita dell'addestramento dei modelli MaxText e MaxDiffusion.

In generale, i passaggi di alto livello sono:

  1. Crea l'immagine di base del workload.
  2. Esegui il tuo carico di lavoro utilizzando XPK.
    1. Crea il comando di addestramento per il carico di lavoro.
    2. Esegui il deployment del carico di lavoro.
  3. Monitora il carico di lavoro e visualizza le metriche.
  4. Elimina il workload XPK se non è necessario.
  5. Elimina il cluster XPK quando non è più necessario.

Crea l'immagine di base

Installa MaxText o MaxDiffusion e crea l'immagine Docker:

  1. Clona il repository che vuoi utilizzare e passa alla directory del repository:

    MaxText:

    git clone https://github.com/google/maxtext.git && cd maxtext
    

    MaxDiffusion:

    git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion
    
  2. Configura Docker in modo che utilizzi Google Cloud CLI:

    gcloud auth configure-docker
    
  3. Crea l'immagine Docker utilizzando il seguente comando o JAX Stable Stack. Per ulteriori informazioni su JAX Stable Stack, consulta Creare un'immagine Docker con JAX Stable Stack.

    bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.35
    
  4. Se avvii il carico di lavoro da una macchina su cui l'immagine non è stata compilata localmente, caricala:

    bash docker_upload_runner.sh CLOUD_IMAGE_NAME=${USER}_runner
    

Esegui il carico di lavoro utilizzando XPK

  1. Imposta le seguenti variabili di ambiente se non utilizzi i valori predefiniti impostati da MaxText o MaxDiffusion:

    export BASE_OUTPUT_DIR=gs://YOUR_BUCKET
    export PER_DEVICE_BATCH_SIZE=2
    export NUM_STEPS=30
    export MAX_TARGET_LENGTH=8192
  2. Crea lo script del modello. Questo script verrà copiato come comando di addestramento in un passaggio successivo.

    Non eseguire ancora lo script del modello.

    MaxText

    MaxText è un LLM open source ad alte prestazioni e altamente scalabile scritto in Python e JAX puro e che ha come target Google Cloud TPU e GPU per l'addestramento e l'inferenza.

    JAX_PLATFORMS=tpu,cpu \
    ENABLE_PJRT_COMPATIBILITY=true \
    TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \
    TPU_SLICE_BUILDER_DUMP_ICI=true && \
    python /deps/MaxText/train.py /deps/MaxText/configs/base.yml \
            base_output_directory=${BASE_OUTPUT_DIR} \
            dataset_type=synthetic \
            per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
            enable_checkpointing=false \
            gcs_metrics=true \
            profiler=xplane \
            skip_first_n_steps_for_profiler=5 \
            steps=${NUM_STEPS}  # attention='dot_product'"
    

    Gemma2

    Gemma è una famiglia di LLM con pesi aperti sviluppati da Google DeepMind, basata sulla ricerca e sulla tecnologia di Gemini.

    python3 MaxText/train.py MaxText/configs/base.yml \
        model_name=gemma2-27b \
        run_name=gemma2-27b-run \
        base_output_directory=${BASE_OUTPUT_DIR} \
        max_target_length=${MAX_TARGET_LENGTH} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        steps=${NUM_STEPS} \
        enable_checkpointing=false \
        use_iota_embed=true \
        gcs_metrics=true \
        dataset_type=synthetic \
        profiler=xplane \
        attention=flash
    

    Mixtral 8x7b

    Mixtral è un modello di IA all'avanguardia sviluppato da Mistral AI, che utilizza un'architettura sparse mixture-of-experts (MoE).

    python3 MaxText/train.py MaxText/configs/base.yml \
        base_output_directory=${BASE_OUTPUT_DIR} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        model_name=mixtral-8x7b \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        tokenizer_path=assets/tokenizer.mistral-v1 \
        attention=flash \
        dtype=bfloat16 \
        dataset_type=synthetic \
        profiler=xplane
    

    Llama3-8b

    Llama è una famiglia di LLM con pesi aperti sviluppati da Meta.

    python3 MaxText/train.py MaxText/configs/base.yml \
        model_name=llama3-8b \
        base_output_directory=${BASE_OUTPUT_DIR} \
        dataset_type=synthetic \
        tokenizer_path=assets/tokenizer_llama3.tiktoken \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} # set to 4 \
        gcs_metrics=true \
        profiler=xplane \
        skip_first_n_steps_for_profiler=5 \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        attention=flash
    

    MaxDiffusion

    MaxDiffusion è una raccolta di implementazioni di riferimento di vari modelli di diffusione latente scritti in puro Python e JAX che vengono eseguiti su dispositivi XLA, tra cui GPU e Cloud TPU. Stable Diffusion è un modello latente di conversione di testo in immagine che genera immagini fotorealistiche da qualsiasi input di testo.

    Per eseguire MaxDiffusion devi installare un ramo Git specifico, come mostrato nel seguente comando git checkout.

    git clone https://github.com/google/maxdiffusion.git
    && cd maxdiffusion
    && git checkout e712c9fc4cca764b0930067b6e33daae2433abf0
    && pip install -r requirements.txt
    && pip install .
    

    Script di addestramento:

        cd maxdiffusion && OUT_DIR=${BASE_OUTPUT_DIR} \
        python src/maxdiffusion/train_sdxl.py \
        src/maxdiffusion/configs/base_xl.yml \
        revision=refs/pr/95 \
        activations_dtype=bfloat16 \
        weights_dtype=bfloat16 \
        resolution=1024 \
        per_device_batch_size=1 \
        output_dir=${OUT_DIR}  \
        jax_cache_dir=${OUT_DIR}/cache_dir/ \
        max_train_steps=200 \
        attention=flash run_name=sdxl-ddp-v6e
        
  3. Esporta le seguenti variabili:

    export CLUSTER_NAME=CLUSTER_NAME
    export ACCELERATOR_TYPE=ACCELERATOR_TYPE
    export NUM_SLICES=NUM_SLICES
    export YOUR_MODEL_SCRIPT=YOUR_MODEL_SCRIPT

    Descrizioni delle variabili di ambiente

    Variabile Descrizione
    CLUSTER_NAME Il nome del cluster XPK.
    ACCELERATOR_TYPE Consulta la sezione Tipi di acceleratore.
    NUM_SLICES Il numero di sezioni TPU.
    YOUR_MODEL_SCRIPT Lo script del modello da eseguire come comando di addestramento.
  4. Esegui il modello utilizzando lo script creato nel passaggio precedente. Devi specificare il flag --base-docker-image per utilizzare l'immagine di base MaxText o il flag --docker-image e l'immagine che vuoi utilizzare.

    (Facoltativo) Puoi attivare la registrazione di log di debug includendo il flag --enable-debug-logs. Per ulteriori informazioni, consulta Eseguire il debug di JAX su MaxText.

    (Facoltativo) Puoi creare un esperimento Vertex AI per caricare i dati in Vertex AI TensorBoard includendo il flag --use-vertex-tensorboard. Per ulteriori informazioni, consulta Monitorare JAX su MaxText utilizzando Vertex AI.

    python3 xpk.py workload create \
        --cluster ${CLUSTER_NAME} \
        {--base-docker-image maxtext_base_image|--docker-image ${CLOUD_IMAGE_NAME}} \
        --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
        --tpu-type=${ACCELERATOR_TYPE} \
        --num-slices=${NUM_SLICES}  \
        --on-demand \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        [--enable-debug-logs] \
        [--use-vertex-tensorboard] \
        --command=${YOUR_MODEL_SCRIPT}

    L'output include un link per seguire il carico di lavoro. Apri il link e fai clic sulla scheda Log per monitorare il carico di lavoro in tempo reale.

Eseguire il debug di JAX su MaxText

Utilizza i comandi XPK supplementari per diagnosticare il motivo per cui il cluster o il carico di lavoro non è in esecuzione.

Monitorare JAX su MaxText utilizzando Vertex AI

Visualizza i dati scalari e di profilo tramite TensorBoard gestito da Vertex AI.

  1. Aumenta le richieste di gestione delle risorse (CRUD) per la zona in uso da 600 a 5000. Questo potrebbe non essere un problema per i carichi di lavoro di piccole dimensioni che utilizzano meno di 16 VM.
  2. Installa le dipendenze, ad esempio cloud-accelerator-diagnostics per Vertex AI:

    # xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI
    cd ~/xpk
    pip install .
  3. Crea il tuo cluster XPK utilizzando il flag --create-vertex-tensorboard, come documentato in Creare Vertex AI TensorBoard. Puoi anche eseguire questo comando sui cluster esistenti.

  4. Crea l'esperimento Vertex AI quando esegui il tuo carico di lavoro XPK utilizzando il flag --use-vertex-tensorboard e il flag facoltativo --experiment-name. Per l'elenco completo dei passaggi, consulta Creare un esperimento Vertex AI per caricare i dati su Vertex AI TensorBoard.

I log includono un link a un Vertex AI TensorBoard, simile al seguente:

View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name

Puoi trovare il link a Vertex AI TensorBoard anche nella console Google Cloud. Vai a Vertex AI Experiments nella console Google Cloud. Seleziona la regione appropriata dal menu a discesa.

La directory di TensorBoard viene scritta anche nel bucket Cloud Storage che hai specificato con ${BASE_OUTPUT_DIR}.

Eliminare i workload XPK

Utilizza il comando xpk workload delete per eliminare uno o più workload in base al prefisso o allo stato del job. Questo comando può essere utile se hai inviato carichi di lavoro XPK che non devono più essere eseguiti o se hai job bloccati nella coda.

Elimina il cluster XPK

Utilizza il comando xpk cluster delete per eliminare un cluster:

python3 xpk.py cluster delete --cluster ${CLUSTER_NAME} \
--zone=${ZONE} --project=${PROJECT_ID}

Addestramento di Llama e PyTorch/XLA su una VM Cloud TPU v6e

Questo tutorial descrive come addestrare i modelli Llama utilizzando PyTorch/XLA su Cloud TPU v6e utilizzando il set di dati WikiText.

Accedi a Hugging Face e al modello Llama 3

Per eseguire questo tutorial, devi disporre di un token di accesso utente Hugging Face. Per informazioni sulla creazione e sui token di accesso utente, consulta la documentazione di Hugging Face sui token di accesso utente.

Devi anche disporre dell'autorizzazione per accedere al modello Llama 3 8B su Hugging Face. Per ottenere l'accesso, vai al modello Meta-Llama-3-8B su HuggingFace e richiedi l'accesso.

Crea una VM Cloud TPU

Crea una Cloud TPU v6e con 8 chip per eseguire il tutorial.

  1. Configura le variabili di ambiente:

    export ACCELERATOR_TYPE=v6e-8
    export VERSION=v2-alpha-tpuv6e
    export TPU_NAME=$USER-$ACCELERATOR_TYPE
    export PROJECT_ID=your-project-id
    export ZONE=us-east1-d
  2. Crea una VM Cloud TPU:

    gcloud alpha compute tpus tpu-vm create ${TPU_NAME} --version=${VERSION} \
        --accelerator-type=${ACCELERATOR_TYPE} \
        --zone=${ZONE} \
        --project=${PROJECT_ID}

Installazione

Installa il pytorch-tpu/transformersfork di Hugging Face Transformers e delle relative dipendenze. Questo tutorial è stato testato con le seguenti versioni delle dipendenze utilizzate in questo esempio:

  • torch: compatibile con 2.5.0
  • torch_xla[tpu]: compatibile con 2.5.0
  • jax: 0.4.33
  • jaxlib: 0.4.33
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} --zone ${ZONE} \
    --worker=all --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
    cd transformers
    sudo pip3 install -e .
    pip3 install datasets
    pip3 install evaluate
    pip3 install scikit-learn
    pip3 install accelerate
    pip install torch~=2.6.0 torch_xla[tpu]~=2.6.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
    pip install jax==0.4.38 jaxlib==0.4.38 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'

Configura le configurazioni del modello

Il comando di addestramento nella sezione successiva, Esegui il modello, utilizza due file di configurazione JSON per definire i parametri del modello e la configurazione FSDP (Fully Sharded Data Parallel). Lo sharding FSDP viene utilizzato per adattare i pesi del modello a un batch di dimensioni maggiori durante l'addestramento. Quando si esegue l'addestramento con modelli più piccoli, potrebbe essere sufficiente utilizzare il parallelismo dei dati e replicare i pesi su ogni dispositivo. Per ulteriori informazioni su come suddividere i tensori su più dispositivi in PyTorch/XLA, consulta la Guida dell'utente di PyTorch/XLA SPMD.

  1. Crea il file di configurazione dei parametri del modello. Di seguito è riportata la configurazione del parametro del modello per Llama3-8B. Per altri modelli, trova la configurazione su Hugging Face. Ad esempio, consulta la configurazione Llama2-7B.

    cat > llama-config.json << EOF
    {
        "architectures": [
            "LlamaForCausalLM"
        ],
        "attention_bias": false,
        "attention_dropout": 0.0,
        "bos_token_id": 128000,
        "eos_token_id": 128001,
        "hidden_act": "silu",
        "hidden_size": 4096,
        "initializer_range": 0.02,
        "intermediate_size": 14336,
        "max_position_embeddings": 8192,
        "model_type": "llama",
        "num_attention_heads": 32,
        "num_hidden_layers": 32,
        "num_key_value_heads": 8,
        "pretraining_tp": 1,
        "rms_norm_eps": 1e-05,
        "rope_scaling": null,
        "rope_theta": 500000.0,
        "tie_word_embeddings": false,
        "torch_dtype": "bfloat16",
        "transformers_version": "4.40.0.dev0",
        "use_cache": false,
        "vocab_size": 128256
    }
    EOF
    
  2. Crea il file di configurazione FSDP:

    cat > fsdp-config.json << EOF
    {
        "fsdp_transformer_layer_cls_to_wrap": [
            "LlamaDecoderLayer"
        ],
        "xla": true,
        "xla_fsdp_v2": true,
        "xla_fsdp_grad_ckpt": true
    }
    EOF
    

    Per ulteriori informazioni su FSDP, consulta FSDPv2.

  3. Carica i file di configurazione nelle VM Cloud TPU utilizzando il seguente comando:

    gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json ${TPU_NAME}:. \
        --worker=all \
        --project=${PROJECT_ID} \
        --zone=${ZONE}

Esegui il modello

Utilizzando i file di configurazione creati nella sezione precedente, esegui lo script run_clm.py per addestrare il modello Llama 3 8B sul set di dati WikiText. L'esecuzione dello script di addestramento su una Cloud TPU v6e-8 richiede circa 10 minuti.

  1. Accedi a Hugging Face su Cloud TPU utilizzando il seguente comando:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT} \
        --zone ${ZONE} \
        --worker=all \
        --command='
        pip3 install "huggingface_hub[cli]"
        huggingface-cli login --token HUGGING_FACE_TOKEN'
  2. Esegui l'addestramento del modello:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT} \
        --zone ${ZONE} \
        --worker=all \
        --command='
        export PJRT_DEVICE=TPU
        export XLA_USE_SPMD=1
        export ENABLE_PJRT_COMPATIBILITY=true
            # Optional variables for debugging:
        export XLA_IR_DEBUG=1
        export XLA_HLO_DEBUG=1
        export PROFILE_EPOCH=0
        export PROFILE_STEP=3
        export PROFILE_DURATION_MS=100000
            # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path
        export PROFILE_LOGDIR=PROFILE_PATH
        python3 transformers/examples/pytorch/language-modeling/run_clm.py \
        --dataset_name wikitext \
        --dataset_config_name wikitext-2-raw-v1 \
        --per_device_train_batch_size 16 \
        --do_train \
        --output_dir /home/$USER/tmp/test-clm \
        --overwrite_output_dir \
        --config_name /home/$USER/llama-config.json \
        --cache_dir /home/$USER/cache \
        --tokenizer_name meta-llama/Meta-Llama-3-8B \
        --block_size 8192 \
        --optim adafactor \
        --save_strategy no \
        --logging_strategy no \
        --fsdp "full_shard" \
        --fsdp_config /home/$USER/fsdp-config.json \
        --torch_dtype bfloat16 \
        --dataloader_drop_last yes \
        --flash_attention \
        --max_steps 20'

Risoluzione dei problemi relativi a PyTorch/XLA

Se imposti le variabili facoltative per il debug nella sezione precedente, il profilo del modello verrà archiviato nella posizione specificata dalla variabile PROFILE_LOGDIR. Puoi estrarre il file xplane.pb memorizzato in questa posizione e utilizzare tensorboard per visualizzare i profili nel browser seguendo le istruzioni di TensorBoard. Se PyTorch/XLA non funziona come previsto, consulta la guida alla risoluzione dei problemi, che contiene suggerimenti per il debug, il profiling e l'ottimizzazione del modello.

Addestramento di DLRM DCN v2 su v6e

Questo tutorial mostra come addestrare il modello DLRM DCN v2 su Cloud TPU v6e. Devi eseguire il provisioning di una TPU v6e con 64, 128 o 256 chip.

Se esegui l'operazione su una TPU multi-host, reimposta tpu-runtime con la versione di TensorFlow appropriata eseguendo i seguenti comandi. Se stai eseguendo l'operazione su una TPU a un solo host, non devi eseguire i due comandi riportati di seguito.

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  \
    --project ${PROJECT_ID} \
    --zone ${ZONE} \
    --worker=all \
    --command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  \
    --project ${PROJECT_ID} \
    --zone  ${ZONE} \
    --worker=all \
    --command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'

Connettiti a worker-0 tramite SSH

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone ${ZONE} --project ${PROJECT_ID}

Imposta il nome di Cloud TPU

export TPU_NAME=your-tpu-name

Esegui DLRM v2

Copia il seguente snippet di codice in un file denominato script.sh:

pip install --user setuptools==65.5.0

pip install cloud-tpu-client

pip install gin-config && pip install tensorflow-datasets && pip install tf-keras-nightly --no-deps

pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl -f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

git clone https://github.com/tensorflow/recommenders.git
git clone https://github.com/tensorflow/models.git

export PYTHONPATH=~/recommenders/:~/models/
export TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true --tf_xla_sparse_core_disable_table_stacking=true --tf_mlir_enable_convert_control_to_data_outputs_pass=true --tf_mlir_enable_merge_control_flow_pass=true'

TF_USE_LEGACY_KERAS=1 TPU_LOAD_LIBRARY=0 python3 ./models/official/recommendation/ranking/train.py  --mode=train     --model_dir=gs://ptxla-debug/tf/sc/dlrm/runs/2/ --params_override="
runtime:
  distribution_strategy: tpu
  mixed_precision_dtype: 'mixed_bfloat16'
task:
  use_synthetic_data: false
  use_tf_record_reader: true
  train_data:
    input_path: 'gs://trillium-datasets/criteo/train/day_*/*'
    global_batch_size: 16384
    use_cached_data: true
  validation_data:
    input_path: 'gs://trillium-datasets/criteo/eval/day_*/*'
    global_batch_size: 16384
    use_cached_data: true
  model:
    num_dense_features: 13
    bottom_mlp: [512, 256, 128]
    embedding_dim: 128
    interaction: 'multi_layer_dcn'
    dcn_num_layers: 3
    dcn_low_rank_dim: 512
    size_threshold: 8000
    top_mlp: [1024, 1024, 512, 256, 1]
    use_multi_hot: true
    concat_dense: false
    dcn_use_bias: true
    vocab_sizes: [40000000,39060,17295,7424,20265,3,7122,1543,63,40000000,3067956,405282,10,2209,11938,155,4,976,14,40000000,40000000,40000000,590152,12973,108,36]
    multi_hot_sizes: [3,2,1,2,6,1,1,1,1,7,3,8,1,6,9,5,1,1,1,12,100,27,10,3,1,1]
    max_ids_per_chip_per_sample: 128
    max_ids_per_table: [280, 128, 64, 272, 432, 624, 64, 104, 368, 352, 288, 328, 304, 576, 336, 368, 312, 392, 408, 552, 2880, 1248, 720, 112, 320, 256]
    max_unique_ids_per_table: [104, 56, 40, 32, 72, 32, 40, 32, 32, 144, 64, 192, 32, 40, 136, 32, 32, 32, 32, 240, 1352, 432, 120, 80, 32, 32]
    use_partial_tpu_embedding: false
    size_threshold: 0
    initialize_tables_on_host: true
trainer:
  train_steps: 10000
  validation_interval: 1000
  validation_steps: 660
  summary_interval: 1000
  steps_per_loop: 1000
  checkpoint_interval: 0
  optimizer_config:
    embedding_optimizer: 'Adagrad'
    dense_optimizer: 'Adagrad'
    lr_config:
      decay_exp: 2
      decay_start_steps: 70000
      decay_steps: 30000
      learning_rate: 0.025
      warmup_steps: 0
    dense_sgd_config:
      decay_exp: 2
      decay_start_steps: 70000
      decay_steps: 30000
      learning_rate: 0.00025
      warmup_steps: 8000
  train_tf_function: true
  train_tf_while_loop: true
  eval_tf_while_loop: true
  use_orbit: true
  pipeline_sparse_and_dense_execution: true"

Se esegui TensorFlow su GKE, installa il pacchetto TensorFlow Cloud TPU e libtpu utilizzando il seguente comando:

pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

Imposta i seguenti flag, necessari per eseguire i workload di consigli (ad esempio DLRM DCN):

ENV TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true \
--tf_mlir_enable_tpu_variable_runtime_reformatting_pass=false \
--tf_mlir_enable_convert_control_to_data_outputs_pass=true \
--tf_mlir_enable_merge_control_flow_pass=true --tf_xla_disable_full_embedding_pipelining=true' \
ENV LIBTPU_INIT_ARGS="--xla_sc_splitting_along_feature_dimension=auto \
--copy_with_dynamic_shape_op_output_pjrt_buffer=true"

Esegui script.sh:

chmod +x script.sh
./script.sh

Risultati del benchmarking

La sezione seguente contiene i risultati del benchmarking per DLRM DCN v2 e MaxDiffusion su v6e.

DLRM DCN v2

Lo script di addestramento DLRM DCN v2 è stato eseguito su scale diverse. Consulta le portate nella tabella seguente.

v6e-64 v6e-128 v6e-256
Passaggi di addestramento 7000 7000 7000
Dimensione del batch globale 131072 262144 524288
Velocità effettiva (esempi/sec) 2975334 5111808 10066329

MaxDiffusion

Abbiamo eseguito lo script di addestramento per MaxDiffusion su una v6e-4, una v6e-16 e due v6e-16. Consulta le portate nella tabella seguente.

v6e-4 v6e-16 Due v6e-16
Passaggi di addestramento 0,069 0,073 0,13
Dimensione del batch globale 8 32 64
Velocità effettiva (esempi/sec) 115,9 438,4 492,3