Pengantar Trillium (v6e)

v6e digunakan untuk merujuk ke Trillium dalam dokumentasi, TPU API, dan log ini. v6e mewakili TPU generasi ke-6 Google.

Dengan 256 chip per Pod, arsitektur v6e memiliki banyak kesamaan dengan v5e. Sistem ini dioptimalkan untuk pelatihan, penyesuaian, dan penayangan transformer, text-to-image, dan jaringan neural konvolusional (CNN).

Untuk mengetahui informasi selengkapnya tentang arsitektur dan konfigurasi sistem v6e, lihat TPU v6e.

Dokumen pengantar ini berfokus pada proses pelatihan dan penyajian model menggunakan framework JAX atau PyTorch. Dengan setiap framework, Anda dapat menyediakan TPU menggunakan resource dalam antrean atau GKE. Penyiapan GKE dapat dilakukan menggunakan XPK atau perintah GKE.

Prosedur umum untuk melatih atau menyajikan model menggunakan v6e

  1. Menyiapkan Google Cloud project
  2. Kapasitas aman
  3. Menyediakan lingkungan Cloud TPU
  4. Menjalankan workload pelatihan atau inferensi model

Menyiapkan project Google Cloud

Sebelum dapat menggunakan Cloud TPU, Anda harus:

  • Buat akun dan project dengan penagihan diaktifkan Google Cloud
  • Menginstal komponen alfa Google Cloud CLI
  • Aktifkan Cloud TPU API
  • Membuat agen layanan Cloud TPU
  • Buat akun layanan Cloud TPU dan berikan izin

Untuk mengetahui informasi selengkapnya, lihat Menyiapkan lingkungan Cloud TPU.

Kapasitas yang aman

Hubungi Google Cloud dukungan untuk meminta kuota Cloud TPU v6e dan untuk mendapatkan jawaban atas pertanyaan tentang kapasitas.

Menyediakan lingkungan Cloud TPU

Cloud TPU v6e dapat disediakan dan dikelola dengan GKE, dengan GKE dan XPK (alat CLI wrapper melalui GKE), atau sebagai resource dalam antrean.

Prasyarat

  • Pastikan project Anda memiliki kuota TPUS_PER_TPU_FAMILY yang cukup, yang menentukan jumlah maksimum chip yang dapat Anda akses dalam project Google Cloud Anda.
  • v6e telah diuji dengan konfigurasi berikut:
    • Python 3.10 atau yang lebih baru
    • Versi software setiap hari:
      • JAX harian 0.4.32.dev20240912
      • LibTPU Harian 0.1.dev20240912+nightly
    • Versi software stabil:
      • JAX + JAX Lib v0.4.37
  • Pastikan project Anda memiliki kuota yang cukup untuk:

    • Kuota VM Cloud TPU
    • Kuota alamat IP
    • Kuota untuk Hyperdisk Balanced dan untuk jenis disk lain yang ingin Anda gunakan

  • Jika Anda menggunakan GKE dengan XPK, lihat Izin Konsol Cloud di akun pengguna atau akun layanan untuk mengetahui izin yang diperlukan untuk menjalankan XPK.

Buat variabel lingkungan

Di Cloud Shell, buat variabel lingkungan berikut:

export NODE_ID=your-tpu-name
export PROJECT_ID=your-project-id
export ACCELERATOR_TYPE=v6e-16
export ZONE=us-east1-d
export RUNTIME_VERSION=v2-alpha-tpuv6e
export SERVICE_ACCOUNT=your-service-account
export QUEUED_RESOURCE_ID=your-queued-resource-id
export VALID_DURATION=your-duration 

# Additional environment variable needed for Multislice:
export NUM_SLICES=number-of-slices

# Use a custom network for better performance as well as to avoid having the default network becoming overloaded.
export NETWORK_NAME=${PROJECT_ID}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw

Deskripsi tanda perintah

Variabel Deskripsi
NODE_ID ID yang ditetapkan pengguna dari Cloud TPU yang dibuat saat permintaan resource dalam antrean dialokasikan.
PROJECT_ID Nama projectGoogle Cloud . Gunakan project yang sudah ada atau buat project baru. Untuk mengetahui informasi selengkapnya, lihat Menyiapkan Google Cloud project.
ZONE Lihat dokumen Region dan zona Cloud TPU untuk mengetahui zona yang didukung.
ACCELERATOR_TYPE Lihat Jenis Akselerator.
RUNTIME_VERSION v2-alpha-tpuv6e
SERVICE_ACCOUNT Ini adalah alamat email untuk akun layanan Anda yang dapat Anda temukan di Google Cloud Console -> IAM -> Service Accounts

Contoh: tpu-service-account@your-project-ID.iam.gserviceaccount.com.com

NUM_SLICES Jumlah slice yang akan dibuat (hanya diperlukan untuk Multislice).
QUEUED_RESOURCE_ID ID teks yang ditetapkan pengguna untuk permintaan resource yang diantrekan.
VALID_DURATION Durasi validitas permintaan resource yang diantrekan.
NETWORK_NAME Nama jaringan sekunder yang akan digunakan.
NETWORK_FW_NAME Nama firewall jaringan sekunder yang akan digunakan.

Mengoptimalkan performa jaringan

Untuk performa terbaik,gunakan jaringan dengan MTU (unit transmisi maksimum) 8.896.

Secara default, Virtual Private Cloud (VPC) hanya menyediakan MTU sebesar 1.460 byte yang akan memberikan performa jaringan yang kurang optimal. Anda dapat menyetel MTU jaringan VPC ke nilai apa pun antara 1.300 byte dan 8.896 byte (inklusif). Ukuran MTU kustom umum adalah 1.500 byte (Ethernet standar) atau 8.896 byte (maksimum yang mungkin). Untuk mengetahui informasi selengkapnya, lihat Ukuran MTU jaringan VPC yang valid.

Untuk mengetahui informasi selengkapnya tentang cara mengubah setelan MTU untuk jaringan yang ada atau default, lihat Mengubah setelan MTU jaringan VPC.

Contoh berikut membuat jaringan dengan MTU 8.896.

export RESOURCE_NAME=your-resource-name
export NETWORK_NAME=${RESOURCE_NAME}-privatenetwork
export NETWORK_FW_NAME=${RESOURCE_NAME}-privatefirewall
gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT_ID} \
   --subnet-mode=auto --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network=${NETWORK_NAME} \
   --allow tcp,icmp,udp --project=${PROJECT_ID}

Menggunakan multi-NIC (opsi untuk Multislice)

Variabel lingkungan berikut diperlukan untuk subnet sekunder saat Anda menggunakan lingkungan Multislice.

export NETWORK_NAME_2=${RESOURCE_NAME}
export SUBNET_NAME_2=${RESOURCE_NAME}
export FIREWALL_RULE_NAME=${RESOURCE_NAME}
export ROUTER_NAME=${RESOURCE_NAME}-network-2
export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2
export REGION=your-region

Gunakan perintah berikut untuk membuat perutean IP kustom untuk jaringan dan subnet.

gcloud compute networks create ${NETWORK_NAME_2} --mtu=8896 \
   --bgp-routing-mode=regional --subnet-mode=custom --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_2} \
   --network=${NETWORK_NAME_2} \
   --range=10.10.0.0/18 --region=${REGION} \
   --project=${PROJECT_ID}

gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_2} --allow tcp,icmp,udp \
   --source-ranges 10.10.0.0/18 --project=${PROJECT_ID}

gcloud compute routers create ${ROUTER_NAME} \
   --project=${PROJECT_ID} \
   --network=${NETWORK_NAME_2} \
   --region=${REGION}

gcloud compute routers nats create ${NAT_CONFIG} \
   --router=${ROUTER_NAME} \
   --region=${REGION} \
   --auto-allocate-nat-external-ips \
   --nat-all-subnet-ip-ranges \
   --project=${PROJECT_ID} \
   --enable-logging

Setelah membuat slice multi-jaringan, Anda dapat memvalidasi bahwa kedua kartu antarmuka jaringan (NIC) sedang digunakan dengan menyiapkan cluster XPK dan menambahkan flag --command ifconfig ke perintah pembuatan workload XPK.

Gunakan perintah workload create berikut untuk menampilkan output perintah ifconfig di log konsol Google Cloud dan periksa apakah eth0 dan eth1 memiliki mtu=8896.

python3 xpk.py workload create \
   --cluster CLUSTER_NAME \
   {--base-docker-image maxtext_base_image | --docker-image your-cloud-image-name} \
   --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
   --tpu-type=${ACCELERATOR_TYPE} \
   --num-slices=${NUM_SLICES}  \
   --on-demand \
   --zone=${ZONE} \
   --project=${PROJECT_ID} \
   --command "ifconfig"

Jika Anda ingin mengaktifkan log debug atau menggunakan Vertex AI TensorBoard, tambahkan argumen opsional berikut ke perintah:

   --enable-debug-logs \
   --use-vertex-tensorboard

Verifikasi bahwa eth0 dan eth1 memiliki mtu=8.896. Anda dapat memverifikasi bahwa multi-NIC berjalan dengan menambahkan flag --command ifconfig ke perintah pembuatan workload XPK. Periksa output beban kerja XPK tersebut di log konsol Google Cloud dan verifikasi bahwa eth0 dan eth1 memiliki mtu=8.896.

Meningkatkan kualitas setelan TCP

Jika Anda membuat Cloud TPU menggunakan antarmuka resource dalam antrean, Anda dapat menjalankan perintah berikut untuk meningkatkan performa jaringan dengan meningkatkan batas buffer penerimaan TCP.

gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \
   --project "${PROJECT_ID}" \
   --zone "${ZONE}" \
   --node=all \
   --worker=all \
   --command='
   sudo sh -c "echo \"4096 41943040 314572800\" > /proc/sys/net/ipv4/tcp_rmem"'

Penyediaan dengan resource dalam antrean

Anda dapat membuat Cloud TPU v6e menggunakan resource dalam antrean. Resource dalam antrean memungkinkan Anda menerima kapasitas setelah tersedia. Anda dapat menentukan waktu mulai dan berakhir opsional untuk kapan permintaan harus diisi. Untuk mengetahui informasi selengkapnya, lihat Mengelola resource dalam antrean.

Menyediakan Cloud TPU v6e dengan GKE atau XPK

Jika menggunakan perintah GKE dengan v6e, Anda dapat menggunakan perintah Kubernetes atau XPK untuk menyediakan Cloud TPU dan melatih atau menyajikan model. Lihat Merencanakan Cloud TPU di GKE untuk mempelajari cara merencanakan konfigurasi Cloud TPU di cluster GKE. Bagian berikut memberikan perintah untuk membuat cluster XPK dengan dukungan NIC tunggal dan dukungan multi-NIC.

Membuat cluster XPK dengan dukungan NIC tunggal

export CLUSTER_NAME=xpk-cluster-name
export ZONE=us-east1-d
export PROJECT_ID=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME=${CLUSTER_NAME}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw
gcloud compute networks create ${NETWORK_NAME} \
   --mtu=8896 \
   --project=${PROJECT_ID} \
   --subnet-mode=auto \
   --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} \
   --network=${NETWORK_NAME} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}"
python3 xpk.py cluster create --cluster=${CLUSTER_NAME} \
   --cluster-cpu-machine-type=e2-standard-8 \
   --num-slices=${NUM_SLICES} \
   --tpu-type=${TPU_TYPE} \
   --zone=${ZONE} \
   --project=${PROJECT_ID} \
   --on-demand \
   --custom-cluster-arguments="${CLUSTER_ARGUMENTS}"  \
   --create-vertex-tensorboard

Deskripsi tanda perintah

Variabel Deskripsi
CLUSTER_NAME Nama yang ditetapkan pengguna untuk cluster XPK.
PROJECT_ID Nama projectGoogle Cloud . Gunakan project yang sudah ada atau buat project baru. Untuk mengetahui informasi selengkapnya, lihat Menyiapkan Google Cloud project.
ZONE Lihat dokumen Region dan zona Cloud TPU untuk mengetahui zona yang didukung.
TPU_TYPE Lihat Jenis Akselerator.
NUM_SLICES Jumlah irisan yang ingin Anda buat
CLUSTER_ARGUMENTS Jaringan dan subnetwork yang akan digunakan.

Contoh: --network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}

NUM_SLICES Jumlah irisan yang akan dibuat.
NETWORK_NAME Nama jaringan sekunder yang akan digunakan.
NETWORK_FW_NAME Nama firewall jaringan sekunder yang akan digunakan.

Membuat cluster XPK dengan dukungan multi-NIC

export CLUSTER_NAME=xpk-cluster-name
export REGION=your-region
export ZONE=us-east1-d
export PROJECT_ID=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME_1=${CLUSTER_NAME}-mtu9k-1-${ZONE}
export SUBNET_NAME_1=${CLUSTER_NAME}-privatesubnet-1-${ZONE}
export NETWORK_FW_NAME_1=${NETWORK_NAME_1}-fw-1-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-1-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-1-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-1-${ZONE}
gcloud compute networks create ${NETWORK_NAME_1} \
   --mtu=8896 \
   --bgp-routing-mode=regional \
   --subnet-mode=custom \
   --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_1} \
   --network=${NETWORK_NAME_1} \
   --range=10.11.0.0/18 \
   --region=${REGION} \
   --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_1} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \
   --project=${PROJECT_ID} \
   --network=${NETWORK_NAME_1} \
   --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \
   --router=${ROUTER_NAME} \
   --region=${REGION} \
   --auto-allocate-nat-external-ips \
   --nat-all-subnet-ip-ranges \
   --project=${PROJECT_ID} \
   --enable-logging
# Secondary subnet for multi-nic experience.
# Need custom IP routing to be different from the first network's subnet.

export NETWORK_NAME_2=${CLUSTER_NAME}-privatenetwork-2-${ZONE}
export SUBNET_NAME_2=${CLUSTER_NAME}-privatesubnet-2-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-2-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-2-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-2-${ZONE}
gcloud compute networks create ${NETWORK_NAME_2} \
   --mtu=8896 \
   --bgp-routing-mode=regional \
   --subnet-mode=custom \
   --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_2} \
   --network=${NETWORK_NAME_2} \
   --range=10.10.0.0/18 \
   --region=${REGION} \
   --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_2} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \
   --project=${PROJECT_ID} \
   --network=${NETWORK_NAME_2} \
   --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \
   --router=${ROUTER_NAME} \
   --region=${REGION} \
   --auto-allocate-nat-external-ips \
   --nat-all-subnet-ip-ranges \
   --project=${PROJECT_ID} \
   --enable-logging
export CLUSTER_ARGUMENTS="--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}"
export NODE_POOL_ARGUMENTS="--additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}"
python3 xpk.py cluster create \
   --cluster=${CLUSTER_NAME} \
   --cluster-cpu-machine-type=e2-standard-8 \
   --num-slices=${NUM_SLICES} \
   --tpu-type=${TPU_TYPE} \
   --zone=${ZONE}  \
   --project=${PROJECT_ID} \
   --on-demand \
   --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \
   --custom-nodepool-arguments="${NODE_POOL_ARGUMENTS}" \
   --create-vertex-tensorboard

Deskripsi tanda perintah

Variabel Deskripsi
CLUSTER_NAME Nama yang ditetapkan pengguna untuk cluster XPK.
PROJECT_ID Nama projectGoogle Cloud . Gunakan project yang sudah ada atau buat project baru. Untuk mengetahui informasi selengkapnya, lihat Menyiapkan Google Cloud project.
ZONE Lihat dokumen Region dan zona Cloud TPU untuk mengetahui zona yang didukung.
TPU_TYPE Lihat Jenis Akselerator.
NUM_SLICES Jumlah irisan yang ingin Anda buat
CLUSTER_ARGUMENTS Jaringan dan subnetwork yang akan digunakan.

Contoh: --enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}

NODE_POOL_ARGUMENTS Jaringan node tambahan yang akan digunakan.

Contoh: --additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}

NUM_SLICES Jumlah slice yang akan dibuat (hanya diperlukan untuk Multislice).
NETWORK_NAME Nama jaringan sekunder yang akan digunakan.
NETWORK_FW_NAME Nama firewall jaringan sekunder yang akan digunakan.

Penyiapan framework

Bagian ini menjelaskan proses penyiapan umum untuk pelatihan model ML menggunakan framework JAX dan PyTorch. Jika menggunakan GKE, Anda dapat menggunakan perintah XPK atau Kubernetes untuk penyiapan framework.

Penyiapan untuk JAX

Bagian ini memberikan petunjuk penyiapan untuk menjalankan workload JAX di GKE, dengan atau tanpa XPK, serta menggunakan resource dalam antrean.

Menyiapkan JAX menggunakan GKE

Slice tunggal di host tunggal

Contoh berikut menyiapkan node pool satu host 2x2 menggunakan file YAML Kubernetes.

apiVersion: v1
kind: Pod
metadata:
  name: tpu-pod-jax-v6e-a
spec:
  restartPolicy: Never
  nodeSelector:
    cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
    cloud.google.com/gke-tpu-topology: 2x2
  containers:
  - name: tpu-job
    image: python:3.10
    securityContext:
      privileged: true
    command:
    - bash
    - -c
    - |
      pip install -U --pre jax jaxlib libtpu-nightly requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
      JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python3 -c 'import jax; print("Total TPU chips:", jax.device_count())'
    resources:
      requests:
        google.com/tpu: 4
      limits:
        google.com/tpu: 4

Setelah berhasil diselesaikan, Anda akan melihat pesan berikut di log GKE:

Total TPU chips: 4

Slice tunggal di multi-host

Contoh berikut menyiapkan node pool multi-host 4x4 menggunakan file YAML Kubernetes.

apiVersion: v1
kind: Service
metadata:
  name: headless-svc
spec:
  clusterIP: None
  selector:
    job-name: tpu-available-chips
---
apiVersion: batch/v1
kind: Job
metadata:
  name: tpu-available-chips
spec:
  backoffLimit: 0
  completions: 4
  parallelism: 4
  completionMode: Indexed
  template:
    spec:
      subdomain: headless-svc
      restartPolicy: Never
      nodeSelector:
        cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
        cloud.google.com/gke-tpu-topology: 4x4
      containers:
      - name: tpu-job
        image: python:3.10
        ports:
        - containerPort: 8471 # Default port using which TPU VMs communicate
        - containerPort: 8431 # Port to export TPU runtime metrics, if supported.
        securityContext:
          privileged: true
        command:
        - bash
        - -c
        - |
          pip install -U --pre jax jaxlib libtpu-nightly requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
          JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
        resources:
          requests:
            google.com/tpu: 4
          limits:
            google.com/tpu: 4

Setelah berhasil diselesaikan, Anda akan melihat pesan berikut di log GKE:

Total TPU chips: 16

Multi-slice di multi-host

Contoh berikut menyiapkan dua node pool multi-host 4x4 menggunakan file YAML Kubernetes.

Sebagai prasyarat, Anda perlu menginstal JobSet v0.2.3 atau yang lebih baru.

apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
  name: multislice-job
  annotations:
    alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
spec:
  failurePolicy:
    maxRestarts: 4
  replicatedJobs:
    - name: slice
      replicas: 2
      template:
        spec:
          parallelism: 4
          completions: 4
          backoffLimit: 0
          template:
            spec:
              hostNetwork: true
              dnsPolicy: ClusterFirstWithHostNet
              nodeSelector:
                cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                cloud.google.com/gke-tpu-topology: 4x4
              hostNetwork: true
              containers:
              - name: jax-tpu
                image: python:3.10
                ports:
                - containerPort: 8471
                - containerPort: 8080
                - containerPort: 8431
                securityContext:
                  privileged: true
                command:
                - bash
                - -c
                - |
                  pip install -U --pre jax jaxlib libtpu-nightly requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                  JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
                resources:
                  limits:
                   google.com/tpu: 4
                  requests:
                   google.com/tpu: 4

Setelah berhasil diselesaikan, Anda akan melihat pesan berikut di log GKE:

Total TPU chips: 32

Untuk mengetahui informasi selengkapnya, lihat Menjalankan workload Multislice dalam dokumentasi GKE.

Untuk performa yang lebih baik, Aktifkan hostNetwork.

Multi-NIC

Untuk menggunakan manifes multi-NIC berikut, Anda harus menyiapkan jaringan. Untuk mengetahui informasi selengkapnya, lihat Menyiapkan dukungan multi-jaringan untuk Pod Kubernetes.

Untuk memanfaatkan multi-NIC di GKE, Anda harus menyertakan beberapa anotasi tambahan ke manifes Pod Kubernetes.

Berikut adalah contoh manifes workload multi-NIC non-TPU.

apiVersion: v1
kind: Pod
metadata:
  name: sample-netdevice-pod-1
  annotations:
    networking.gke.io/default-interface: 'eth0'
    networking.gke.io/interfaces: |
      [
        {"interfaceName":"eth0","network":"default"},
        {"interfaceName":"eth1","network":"netdevice-network"}
      ]
spec:
  containers:
  - name: sample-netdevice-pod
    image: busybox
    command: ["sleep", "infinity"]
    ports:
    - containerPort: 80
  restartPolicy: Always
  tolerations:
  - key: "google.com/tpu"
    operator: "Exists"
    effect: "NoSchedule"

Jika Anda menggunakan perintah exec untuk terhubung ke Pod Kubernetes, Anda akan melihat NIC tambahan menggunakan kode berikut:

$ kubectl exec --stdin --tty sample-netdevice-pod-1 -- /bin/sh
/ # ip a
1: lo: <LOOPBACK,UP,LOWER_UP> mtu 65536 qdisc noqueue qlen 1000
    link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00
    inet 127.0.0.1/8 scope host lo
       valid_lft forever preferred_lft forever
2: eth0@if11: <BROADCAST,MULTICAST,UP,LOWER_UP,M-DOWN> mtu 1460 qdisc noqueue
    link/ether da:be:12:67:d2:25 brd ff:ff:ff:ff:ff:ff
    inet 10.124.2.6/24 brd 10.124.2.255 scope global eth0
       valid_lft forever preferred_lft forever
3: eth1: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1460 qdisc mq qlen 1000
    link/ether 42:01:ac:18:00:04 brd ff:ff:ff:ff:ff:ff
    inet 172.24.0.4/32 scope global eth1
       valid_lft forever preferred_lft forever

Menyiapkan JAX menggunakan GKE dengan XPK

Untuk menyiapkan JAX menggunakan GKE dan XPK, lihat README XPK.

Untuk menyiapkan dan menjalankan XPK dengan MaxText, lihat Cara menjalankan MaxText.

Menyiapkan JAX menggunakan resource dalam antrean

Instal JAX di semua VM Cloud TPU dalam slice atau slice Anda secara bersamaan menggunakan perintah gcloud alpha compute tpus tpu-vm ssh. Untuk Multislice, tambahkan tanda --node=all.

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
   --project ${PROJECT_ID} \
   --zone ${ZONE} \
   --worker=all \
   --command='
   pip install -U --pre jax jaxlib libtpu-nightly requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

Anda dapat menjalankan perintah berikut untuk memeriksa jumlah core Cloud TPU yang tersedia di slice Anda dan untuk menguji apakah semuanya telah diinstal dengan benar:

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
   --project ${PROJECT_ID} \
   --zone ${ZONE} \
   --worker=all \
   --command='
   python3 -c "import jax; print(jax.device_count(), jax.local_device_count())"'

Outputnya mirip dengan berikut ini saat dijalankan pada slice v6e-16:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16 4
16 4
16 4
16 4

jax.device_count() menampilkan jumlah total chip dalam slice yang diberikan. jax.local_device_count() menunjukkan jumlah chip yang dapat diakses oleh satu VM dalam slice ini.

gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
   git clone -b mlperf4.1 https://github.com/google/maxdiffusion.git &&
   cd maxdiffusion && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103 &&
   pip install setuptools==59.6.0 &&
   pip install -r requirements.txt && pip install .'

Memecahkan masalah penyiapan JAX

Tips umumnya adalah mengaktifkan logging verbose dalam manifes workload GKE Anda. Kemudian, berikan log tersebut ke dukungan GKE.

TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0

Pesan error

no endpoints available for service 'jobset-webhook-service'

Error ini berarti jobset tidak diinstal dengan benar. Periksa apakah Pod Kubernetes deployment jobset-controller-manager sedang berjalan. Untuk mengetahui informasi selengkapnya, lihat dokumentasi pemecahan masalah JobSet.

TPU initialization failed: Failed to connect

Pastikan versi node GKE Anda adalah 1.30.4-gke.1348000 atau yang lebih baru (GKE 1.31 tidak didukung).

Penyiapan untuk PyTorch

Bagian ini menjelaskan cara mulai menggunakan PJRT di v6e dengan PyTorch/XLA. Python 3.10 adalah versi Python yang direkomendasikan.

Menyiapkan PyTorch menggunakan GKE dengan XPK

Anda dapat menggunakan container Docker berikut dengan XPK yang telah menginstal dependensi PyTorch:

us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028

Untuk membuat workload XPK, gunakan perintah berikut:

python3 xpk.py workload create \
   --cluster ${CLUSTER_NAME} \
   {--base-docker-image maxtext_base_image | --docker-image your-cloud-image-name} \
   --workload ${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
   --tpu-type=${ACCELERATOR_TYPE} \
   --num-slices=${NUM_SLICES}  \
   --on-demand \
   --zone ${ZONE} \
   --project ${PROJECT_ID} \
   --enable-debug-logs \
   --command 'python3 -c "import torch; import torch_xla; import torch_xla.runtime as xr; print(xr.global_runtime_device_count())"'

Menggunakan --base-docker-image akan membuat image Docker baru dengan direktori kerja saat ini yang dibangun ke dalam Docker baru.

Menyiapkan PyTorch menggunakan resource dalam antrean

Ikuti langkah-langkah berikut untuk menginstal PyTorch menggunakan resource dalam antrean dan menjalankan skrip kecil di v6e.

Instal dependensi menggunakan SSH untuk mengakses VM

Gunakan perintah berikut untuk menginstal dependensi di semua VM Cloud TPU. Untuk Multislice, tambahkan tanda --worker=all:

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
   sudo apt update && sudo apt install -y python3-pip libopenblas-base && \
   pip3 install torch~=2.6.0 "torch_xla[tpu]~=2.6.0" -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'

Meningkatkan performa model dengan alokasi yang besar dan sering

Untuk model yang memiliki alokasi berukuran besar dan sering, penggunaan fungsi tcmalloc meningkatkan performa secara signifikan dibandingkan dengan implementasi fungsi malloc default, sehingga fungsi malloc default yang digunakan di VM Cloud TPU adalah tcmalloc. Namun, bergantung pada beban kerja Anda (misalnya, dengan DLRM yang memiliki alokasi sangat besar untuk tabel sematannya), fungsi tcmalloc dapat menyebabkan perlambatan. Jika demikian, Anda dapat mencoba membatalkan setelan variabel berikut menggunakan fungsi malloc default:

unset LD_PRELOAD

Menggunakan skrip Python untuk melakukan perhitungan di VM v6e

Gunakan perintah berikut untuk menjalankan skrip yang membuat dua tensor, menjumlahkannya, dan mencetak hasilnya:

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
   --project ${PROJECT_ID} \
   --zone ${ZONE} \
   --worker all \
   --command='
   unset LD_PRELOAD
   python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"'

Tindakan ini akan menghasilkan output yang mirip dengan berikut ini:

SSH: Attempting to connect to worker 0...
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
xla:0
tensor([[ 0.3355, -1.4628, -3.2610],
        [-1.4656,  0.3196, -2.8766],
        [ 0.8668, -1.5060,  0.7125]], device='xla:0')

v6e dengan SkyPilot

Anda dapat menggunakan Cloud TPU v6e dengan SkyPilot. Gunakan langkah-langkah berikut untuk menambahkan informasi lokasi dan harga terkait v6e ke SkyPilot. Untuk mengetahui informasi selengkapnya, lihat contoh SkyPilot TPU v6e.

Tutorial inferensi

Tutorial berikut menunjukkan cara menjalankan inferensi di Cloud TPU v6e:

Contoh pelatihan

Bagian berikut memberikan contoh untuk melatih model MaxText, MaxDiffusion, dan PyTorch di Cloud TPU v6e.

Pelatihan MaxText dan MaxDiffusion di VM Cloud TPU v6e

Bagian berikut mencakup siklus proses pelatihan model MaxText dan MaxDiffusion.

Secara umum, langkah-langkah tingkat tingginya adalah:

  1. Buat image dasar workload.
  2. Jalankan workload Anda menggunakan XPK.
    1. Buat perintah pelatihan untuk workload.
    2. Deploy workload.
  3. Ikuti workload dan lihat metrik.
  4. Hapus workload XPK jika tidak diperlukan.
  5. Hapus cluster XPK jika tidak diperlukan lagi.

Membangun image dasar

Instal MaxText atau MaxDiffusion dan bangun image Docker:

  1. Clone repositori yang ingin Anda gunakan dan ubah ke direktori untuk repositori:

    MaxText:

    git clone https://github.com/google/maxtext.git && cd maxtext
    

    MaxDiffusion:

    git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103
    
  2. Konfigurasi Docker agar menggunakan Google Cloud CLI:

    gcloud auth configure-docker
    
  3. Bangun image Docker menggunakan perintah berikut atau menggunakan JAX Stable Stack. Untuk mengetahui informasi selengkapnya tentang JAX Stable Stack, lihat Membangun image Docker dengan JAX Stable Stack.

    MaxText:

    bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.35
    

    MaxDiffusion:

    bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_stack MODE=jax_ai_image PROJECT=${PROJECT_ID} LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest
    
  4. Tetapkan project ID Anda dalam konfigurasi gcloud CLI aktif:

    gcloud config set project ${PROJECT_ID}
    
  5. Jika Anda meluncurkan beban kerja dari mesin yang tidak memiliki image yang dibuat secara lokal, upload image tersebut.

    1. Tetapkan variabel lingkungan CLOUD_IMAGE_NAME:

      export CLOUD_IMAGE_NAME=${USER}_runner
      
    2. Upload gambar:

      bash docker_upload_runner.sh ${CLOUD_IMAGE_NAME}
      

Menjalankan workload menggunakan XPK

  1. Tetapkan variabel lingkungan berikut jika Anda tidak menggunakan nilai default yang ditetapkan oleh MaxText atau MaxDiffusion:

    export BASE_OUTPUT_DIR=gs://YOUR_BUCKET
    export PER_DEVICE_BATCH_SIZE=2
    export NUM_STEPS=30
    export MAX_TARGET_LENGTH=8192
  2. Buat skrip model Anda. Skrip ini akan disalin sebagai perintah pelatihan pada langkah selanjutnya.

    Jangan jalankan skrip model terlebih dahulu.

    MaxText

    MaxText adalah LLM open source berperforma tinggi dan sangat skalabel yang ditulis dalam Python dan JAX murni serta menargetkan TPU dan GPU untuk pelatihan dan inferensi. Google Cloud

    JAX_PLATFORMS=tpu,cpu \
    ENABLE_PJRT_COMPATIBILITY=true \
    TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \
    TPU_SLICE_BUILDER_DUMP_ICI=true && \
    python3 -m MaxText.train MaxText/configs/base.yml \
         base_output_directory=${BASE_OUTPUT_DIR} \
         dataset_type=synthetic \
         per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
         enable_checkpointing=false \
         gcs_metrics=true \
         profiler=xplane \
         skip_first_n_steps_for_profiler=5 \
         steps=${NUM_STEPS}  # attention='dot_product'"
    

    Gemma2

    Gemma adalah serangkaian LLM dengan bobot terbuka yang dikembangkan oleh Google DeepMind, berdasarkan riset dan teknologi Gemini.

    python3 -m MaxText.train MaxText/configs/base.yml \
        model_name=gemma2-27b \
        run_name=gemma2-27b-run \
        base_output_directory=${BASE_OUTPUT_DIR} \
        max_target_length=${MAX_TARGET_LENGTH} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        steps=${NUM_STEPS} \
        enable_checkpointing=false \
        use_iota_embed=true \
        gcs_metrics=true \
        dataset_type=synthetic \
        profiler=xplane \
        attention=flash
    

    Mixtral 8x7b

    Mixtral adalah model AI canggih yang dikembangkan oleh Mistral AI, dengan memanfaatkan arsitektur mixture-of-experts (MoE) yang jarang.

    python3 -m MaxText.train MaxText/configs/base.yml \
        base_output_directory=${BASE_OUTPUT_DIR} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        model_name=mixtral-8x7b \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        tokenizer_path=assets/tokenizer.mistral-v1 \
        attention=flash \
        dtype=bfloat16 \
        dataset_type=synthetic \
        profiler=xplane
    

    Llama3-8b

    Llama adalah serangkaian LLM dengan bobot terbuka yang dikembangkan oleh Meta.

    Untuk mengetahui contoh cara menjalankan Llama3 di PyTorch, lihat model torch_xla di repositori torchprime.

    MaxDiffusion

    MaxDiffusion adalah kumpulan implementasi referensi dari berbagai model difusi laten yang ditulis dalam Python dan JAX murni yang berjalan di perangkat XLA, termasuk Cloud TPU dan GPU. Stable Diffusion adalah model text-to-image laten yang menghasilkan gambar fotorealistik dari input teks apa pun.

    Anda harus menginstal cabang Git tertentu untuk menjalankan MaxDiffusion seperti yang ditunjukkan dalam perintah git clone berikut.

    Skrip pelatihan:

    git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103 && pip install -r requirements.txt && pip install . && pip install huggingface_hub==0.30.2 && OUT_DIR=${BASE_OUTPUT_DIR} && python src/maxdiffusion/train_sdxl.py src/maxdiffusion/configs/base_xl.yml revision=refs/pr/95 activations_dtype=bfloat16 weights_dtype=bfloat16  resolution=1024  per_device_batch_size=1 output_dir=${OUT_DIR} jax_cache_dir=${OUT_DIR}/cache_dir/ max_train_steps=200 attention=flash run_name=sdxl-ddp-v6e
    
  3. Ekspor variabel berikut:

    export CLUSTER_NAME=CLUSTER_NAME
    export ACCELERATOR_TYPE=ACCELERATOR_TYPE
    export NUM_SLICES=NUM_SLICES
    export YOUR_MODEL_SCRIPT=YOUR_MODEL_SCRIPT

    Deskripsi variabel lingkungan

    Variabel Deskripsi
    CLUSTER_NAME Nama cluster XPK Anda.
    ACCELERATOR_TYPE Lihat Jenis Akselerator.
    NUM_SLICES Jumlah slice TPU.
    YOUR_MODEL_SCRIPT Skrip model yang akan dieksekusi sebagai perintah pelatihan.
  4. Jalankan model menggunakan skrip yang Anda buat pada langkah sebelumnya. Anda harus menentukan flag --base-docker-image untuk menggunakan gambar dasar MaxText atau menentukan flag --docker-image dan gambar yang ingin Anda gunakan.

    Opsional: Anda dapat mengaktifkan logging debug dengan menyertakan flag --enable-debug-logs. Untuk mengetahui informasi selengkapnya, lihat Men-debug JAX di MaxText.

    Opsional: Anda dapat membuat Eksperimen Vertex AI untuk mengupload data ke Vertex AI TensorBoard dengan menyertakan tanda --use-vertex-tensorboard. Untuk mengetahui informasi selengkapnya, lihat Memantau JAX di MaxText menggunakan Vertex AI.

    python3 xpk.py workload create \
      --cluster ${CLUSTER_NAME} \
      {--base-docker-image maxtext_base_image | --docker-image gcr.io/${PROJECT_ID}/${CLOUD_IMAGE_NAME}:latest} \
      --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
      --tpu-type=${ACCELERATOR_TYPE} \
      --num-slices=${NUM_SLICES}  \
      --on-demand \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      [--enable-debug-logs] \
      [--use-vertex-tensorboard] \
      --command="${YOUR_MODEL_SCRIPT}"

    Outputnya mencakup link untuk memantau beban kerja Anda. Buka link dan klik tab Log untuk melacak beban kerja Anda secara real time.

Men-debug JAX di MaxText

Gunakan perintah XPK tambahan untuk mendiagnosis alasan cluster atau beban kerja tidak berjalan:

Memantau JAX di MaxText menggunakan Vertex AI

Untuk menggunakan TensorBoard, akun pengguna Google Cloud Anda harus memiliki peran aiplatform.user. Jalankan perintah berikut untuk memberikan peran ini:

gcloud projects add-iam-policy-binding your-project-id \
   --member='user:your-email' \
   --role='roles/aiplatform.user'

Melihat data skalar dan profil melalui TensorBoard terkelola Vertex AI.

  1. Tingkatkan permintaan pengelolaan resource (CRUD) untuk zona yang Anda gunakan dari 600 menjadi 5.000. Hal ini mungkin tidak menjadi masalah untuk workload kecil yang menggunakan kurang dari 16 VM.

  2. Instal dependensi seperti cloud-accelerator-diagnostics untuk Vertex AI:

    # xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI
    cd ~/xpk
    pip install .
  3. Buat cluster XPK menggunakan flag --create-vertex-tensorboard, seperti yang didokumentasikan di Membuat Vertex AI TensorBoard. Anda juga dapat menjalankan perintah ini di cluster yang sudah ada.

  4. Buat eksperimen Vertex AI saat menjalankan beban kerja XPK menggunakan tanda --use-vertex-tensorboard dan tanda --experiment-name opsional. Untuk mengetahui daftar lengkap langkah-langkahnya, lihat Membuat Vertex AI Experiment untuk mengupload data ke Vertex AI TensorBoard.

Log mencakup link ke Vertex AI TensorBoard, mirip dengan berikut:

View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name

Anda juga dapat menemukan link Vertex AI TensorBoard di konsol Google Cloud . Buka Vertex AI Experiments di konsol Google Cloud . Pilih wilayah yang sesuai dari drop-down.

Direktori TensorBoard juga ditulis ke bucket Cloud Storage yang Anda tentukan dengan ${BASE_OUTPUT_DIR}.

Menghapus workload XPK

Gunakan perintah xpk workload delete untuk menghapus satu atau beberapa workload berdasarkan awalan tugas atau status tugas. Perintah ini mungkin berguna jika Anda mengirim workload XPK yang tidak perlu lagi dijalankan, atau jika Anda memiliki tugas yang macet dalam antrean.

Menghapus cluster XPK

Gunakan perintah xpk cluster delete untuk menghapus cluster:

python3 xpk.py cluster delete --cluster ${CLUSTER_NAME} \
   --zone=${ZONE} --project=${PROJECT_ID}

Pelatihan Llama dan PyTorch/XLA di VM Cloud TPU v6e

Tutorial ini menjelaskan cara melatih model Llama menggunakan PyTorch/XLA di Cloud TPU v6e menggunakan set data WikiText.

Mendapatkan akses ke Hugging Face dan model Llama 3

Anda memerlukan token akses pengguna Hugging Face untuk menjalankan tutorial ini. Untuk mengetahui informasi tentang cara membuat token akses pengguna, lihat dokumentasi Hugging Face tentang token akses pengguna.

Anda juga memerlukan izin untuk mengakses model Llama-3-8B di Hugging Face. Untuk mendapatkan akses, buka model Meta-Llama-3-8B di HuggingFace dan minta akses.

Buat VM Cloud TPU

Buat Cloud TPU v6e dengan 8 chip untuk menjalankan tutorial.

  1. Siapkan variabel lingkungan:

    export NODE_ID=your-tpu-name
    export PROJECT_ID=your-project-id
    export ACCELERATOR_TYPE=v6e-8
    export ZONE=us-east1-d
    export RUNTIME_VERSION=v2-alpha-tpuv6e
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id
    export VALID_DURATION=your-duration 
  2. Buat VM Cloud TPU:

    gcloud alpha compute tpus tpu-vm create ${NODE_ID} --version=${RUNTIME_VERSION} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --zone=${ZONE} \
       --project=${PROJECT_ID}

Penginstalan

Instal fork pytorch-tpu/transformers dari transformer Hugging Face dan dependensi. Tutorial ini diuji dengan versi dependensi berikut yang digunakan dalam contoh ini:

  • torch: kompatibel dengan 2.5.0
  • torch_xla[tpu]: kompatibel dengan 2.5.0
  • jax: 0.4.33
  • jaxlib: 0.4.33
gcloud alpha compute tpus tpu-vm ssh ${NODE_ID} \
   --project=${PROJECT_ID} \
   --zone ${ZONE} \
   --worker=all \
   --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
   cd transformers
   sudo pip3 install -e .
   pip3 install datasets
   pip3 install evaluate
   pip3 install scikit-learn
   pip3 install accelerate
   pip install torch~=2.6.0 torch_xla[tpu]~=2.6.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
   pip install jax==0.4.38 jaxlib==0.4.38 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/'

Menyiapkan konfigurasi model

Perintah pelatihan di bagian berikutnya, Jalankan model, menggunakan dua file konfigurasi JSON untuk menentukan parameter model dan konfigurasi Fully Sharded Data Parallel (FSDP). Sharding FSDP memungkinkan Anda menggunakan ukuran batch yang lebih besar saat melatih dengan melakukan sharding bobot model di beberapa TPU. Saat melatih dengan model yang lebih kecil, mungkin cukup menggunakan paralelisme data dan mereplikasi bobot di setiap perangkat. Untuk mengetahui informasi selengkapnya tentang cara memecah tensor di seluruh perangkat di PyTorch/XLA, lihat Panduan pengguna SPMD PyTorch/XLA.

  1. Buat file konfigurasi parameter model. Berikut adalah konfigurasi parameter model untuk Llama-3-8B. Untuk model lainnya, temukan konfigurasi di Hugging Face. Misalnya, lihat konfigurasi Llama-2-7B.

    cat > llama-config.json << EOF
    {
      "architectures": [
        "LlamaForCausalLM"
      ],
      "attention_bias": false,
      "attention_dropout": 0.0,
      "bos_token_id": 128000,
      "eos_token_id": 128001,
      "hidden_act": "silu",
      "hidden_size": 4096,
      "initializer_range": 0.02,
      "intermediate_size": 14336,
      "max_position_embeddings": 8192,
      "model_type": "llama",
      "num_attention_heads": 32,
      "num_hidden_layers": 32,
      "num_key_value_heads": 8,
      "pretraining_tp": 1,
      "rms_norm_eps": 1e-05,
      "rope_scaling": null,
      "rope_theta": 500000.0,
      "tie_word_embeddings": false,
      "torch_dtype": "bfloat16",
      "transformers_version": "4.40.0.dev0",
      "use_cache": false,
      "vocab_size": 128256
    }
    EOF
    
  2. Buat file konfigurasi FSDP:

    cat > fsdp-config.json << EOF
    {
      "fsdp_transformer_layer_cls_to_wrap": [
        "LlamaDecoderLayer"
      ],
      "xla": true,
      "xla_fsdp_v2": true,
      "xla_fsdp_grad_ckpt": true
    }
    EOF
    

    Untuk mengetahui informasi selengkapnya tentang FSDP, lihat FSDPv2.

  3. Upload file konfigurasi ke VM TPU menggunakan perintah berikut:

    gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json ${NODE_ID}:. \
       --worker=all \
       --project=${PROJECT_ID} \
       --zone=${ZONE}

Menjalankan model

Dengan menggunakan file konfigurasi yang Anda buat di bagian sebelumnya, jalankan skrip run_clm.py untuk melatih model Llama-3-8B pada set data WikiText. Skrip pelatihan memerlukan waktu sekitar 10 menit untuk dijalankan di Cloud TPU v6e-8.

  1. Login ke Hugging Face di Cloud TPU menggunakan perintah berikut:

    gcloud alpha compute tpus tpu-vm ssh ${NODE_ID} \
       --project=${PROJECT_ID} \
       --zone ${ZONE} \
       --worker=all \
       --command='
       pip3 install "huggingface_hub[cli]"
       huggingface-cli login --token HUGGING_FACE_TOKEN'
  2. Jalankan pelatihan model:

    gcloud alpha compute tpus tpu-vm ssh ${NODE_ID} \
       --project=${PROJECT_ID} \
       --zone ${ZONE} \
       --worker=all \
       --command='
       export PJRT_DEVICE=TPU
       export XLA_USE_SPMD=1
       export ENABLE_PJRT_COMPATIBILITY=true
       # Optional variables for debugging:
       export XLA_IR_DEBUG=1
       export XLA_HLO_DEBUG=1
       export PROFILE_EPOCH=0
       export PROFILE_STEP=3
       export PROFILE_DURATION_MS=100000
       # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path
       export PROFILE_LOGDIR=PROFILE_PATH
       python3 transformers/examples/pytorch/language-modeling/run_clm.py \
         --dataset_name wikitext \
         --dataset_config_name wikitext-2-raw-v1 \
         --per_device_train_batch_size 16 \
         --do_train \
         --output_dir /home/$USER/tmp/test-clm \
         --overwrite_output_dir \
         --config_name /home/$USER/llama-config.json \
         --cache_dir /home/$USER/cache \
         --tokenizer_name meta-llama/Meta-Llama-3-8B \
         --block_size 8192 \
         --optim adafactor \
         --save_strategy no \
         --logging_strategy no \
         --fsdp "full_shard" \
         --fsdp_config /home/$USER/fsdp-config.json \
         --torch_dtype bfloat16 \
         --dataloader_drop_last yes \
         --flash_attention \
         --max_steps 20'

Memecahkan masalah PyTorch/XLA

Jika Anda menetapkan variabel opsional untuk proses debug di bagian sebelumnya, profil untuk model akan disimpan di lokasi yang ditentukan oleh variabel PROFILE_LOGDIR. Anda dapat mengekstrak file xplane.pb yang disimpan di lokasi ini dan menggunakan tensorboard untuk melihat profil di browser menggunakan petunjuk TensorBoard.

Jika PyTorch/XLA tidak berperforma seperti yang diharapkan, lihat Panduan pemecahan masalah, yang berisi saran untuk men-debug, membuat profil, dan mengoptimalkan model Anda.

Hasil tolok ukur

Bagian berikut berisi hasil tolok ukur untuk MaxDiffusion di v6e.

MaxDiffusion

Kami menjalankan skrip pelatihan untuk MaxDiffusion di v6e-4, v6e-16, dan dua v6e-16. Lihat throughput dalam tabel berikut.

v6e-4 v6e-16 Dua v6e-16
Langkah-langkah pelatihan 0,069 0,073 0,13
Ukuran batch global 8 32 64
Throughput (contoh/dtk) 115,9 438,4 492,3