Pelatihan Cloud TPU v5e

Dengan jejak 256 chip yang lebih kecil per Pod, TPU v5e dioptimalkan untuk menjadi produk bernilai tinggi untuk pelatihan, penyesuaian, dan penayangan transformer, text-to-image, dan Convolutional Neural Network (CNN). Untuk informasi selengkapnya tentang cara menggunakan Cloud TPU v5e untuk penayangan, lihat Inferensi menggunakan v5e.

Untuk mengetahui informasi selengkapnya tentang konfigurasi dan hardware TPU Cloud TPU v5e, lihat TPU v5e.

Mulai

Bagian berikut menjelaskan cara mulai menggunakan TPU v5e.

Kuota permintaan

Anda memerlukan kuota untuk menggunakan TPU v5e untuk pelatihan. Ada berbagai jenis kuota untuk TPU on-demand, TPU yang direservasi, dan VM Spot TPU. Ada kuota terpisah yang diperlukan jika Anda menggunakan TPU v5e untuk inferensi. Untuk mengetahui informasi selengkapnya tentang kuota, lihat Kuota. Untuk meminta kuota TPU v5e, hubungi Cloud Sales.

Membuat akun dan project Google Cloud

Anda memerlukan akun dan project Google Cloud untuk menggunakan Cloud TPU. Untuk mengetahui informasi selengkapnya, lihat Menyiapkan lingkungan Cloud TPU.

Buat Cloud TPU

Praktik terbaiknya adalah menyediakan Cloud TPU v5 sebagai resource dalam antrean menggunakan perintah queued-resource create. Untuk mengetahui informasi selengkapnya, lihat Mengelola resource dalam antrean.

Anda juga dapat menggunakan Create Node API (gcloud compute tpus tpu-vm create) untuk menyediakan Cloud TPU v5e. Untuk mengetahui informasi selengkapnya, lihat Mengelola resource TPU.

Untuk mengetahui informasi selengkapnya tentang konfigurasi v5e yang tersedia untuk pelatihan, lihat Jenis Cloud TPU v5e untuk pelatihan.

Penyiapan framework

Bagian ini menjelaskan proses penyiapan umum untuk pelatihan model kustom menggunakan JAX atau PyTorch dengan TPU v5e.

Untuk petunjuk penyiapan inferensi, lihat pengantar inferensi v5e.

Tentukan beberapa variabel lingkungan:

export PROJECT_ID=your_project_ID
export ACCELERATOR_TYPE=v5litepod-16
export ZONE=us-west4-a
export TPU_NAME=your_tpu_name
export QUEUED_RESOURCE_ID=your_queued_resource_id

Penyiapan untuk JAX

Jika memiliki bentuk slice lebih dari 8 chip, Anda akan memiliki beberapa VM dalam satu slice. Dalam hal ini, Anda perlu menggunakan flag --worker=all untuk menjalankan penginstalan di semua VM TPU dalam satu langkah tanpa menggunakan SSH untuk login ke setiap VM secara terpisah:

gcloud compute tpus tpu-vm ssh ${TPU_NAME}  \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

Deskripsi flag perintah

Variabel Deskripsi
TPU_NAME ID teks TPU yang ditetapkan pengguna yang dibuat saat permintaan resource yang diantrekan dialokasikan.
PROJECT_ID Google Cloud Nama Project. Gunakan project yang ada atau buat project baru di Menyiapkan Google Cloud project
ZONA Lihat dokumen Region dan zona TPU untuk zona yang didukung.
pekerja VM TPU yang memiliki akses ke TPU yang mendasarinya.

Anda dapat menjalankan perintah berikut untuk memeriksa jumlah perangkat (output yang ditampilkan di sini dihasilkan dengan slice v5litepod-16). Kode ini menguji apakah semuanya diinstal dengan benar dengan memeriksa apakah JAX melihat TensorCore Cloud TPU dan dapat menjalankan operasi dasar:

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='python3 -c "import jax; print(jax.device_count()); print(jax.local_device_count())"'

Outputnya akan seperti ini:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16
4
16
4
16
4
16
4

jax.device_count() menampilkan jumlah total chip dalam slice yang diberikan. jax.local_device_count() menunjukkan jumlah chip yang dapat diakses oleh satu VM di slice ini.

# Check the number of chips in the given slice by summing the count of chips
# from all VMs through the
# jax.local_device_count() API call.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='python3 -c "import jax; xs=jax.numpy.ones(jax.local_device_count()); print(jax.pmap(lambda x: jax.lax.psum(x, \"i\"), axis_name=\"i\")(xs))"'

Outputnya akan seperti ini:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]

Coba Tutorial JAX dalam dokumen ini untuk memulai pelatihan v5e menggunakan JAX.

Penyiapan untuk PyTorch

Perhatikan bahwa v5e hanya mendukung runtime PJRT dan PyTorch 2.1+ akan menggunakan PJRT sebagai runtime default untuk semua versi TPU.

Bagian ini menjelaskan cara mulai menggunakan PJRT di v5e dengan PyTorch/XLA dengan perintah untuk semua pekerja.

Menginstal dependensi

gcloud compute tpus tpu-vm ssh ${TPU_NAME}  \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
      sudo apt-get update -y
      sudo apt-get install libomp5 -y
      pip install mkl mkl-include
      pip install tf-nightly tb-nightly tbp-nightly
      pip install numpy
      sudo apt-get install libopenblas-dev -y
      pip install torch~=PYTORCH_VERSION torchvision torch_xla[tpu]~=PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'

Ganti PYTORCH_VERSION dengan versi PyTorch yang ingin Anda gunakan. PYTORCH_VERSION digunakan untuk menentukan versi yang sama untuk PyTorch/XLA. 2.6.0 direkomendasikan.

Untuk mengetahui informasi selengkapnya tentang versi PyTorch dan PyTorch/XLA, lihat PyTorch - Mulai dan rilis PyTorch/XLA.

Untuk informasi selengkapnya tentang cara menginstal PyTorch/XLA, lihat Penginstalan PyTorch/XLA.

Jika Anda mendapatkan error saat menginstal roda untuk torch, torch_xla, atau torchvision seperti pkg_resources.extern.packaging.requirements.InvalidRequirement: Expected end or semicolon (after name and no valid version specifier) torch==nightly+20230222, downgrade versi Anda dengan perintah ini:

pip3 install setuptools==62.1.0

Menjalankan skrip dengan PJRT

unset LD_PRELOAD

Berikut adalah contoh penggunaan skrip Python untuk melakukan penghitungan di VM v5e:

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
      export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.local/lib/
      export PJRT_DEVICE=TPU
      export PT_XLA_DEBUG=0
      export USE_TORCH=ON
      unset LD_PRELOAD
      export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
      python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"'

Tindakan ini akan menghasilkan output yang mirip dengan berikut ini:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
xla:0
tensor([[ 1.8611, -0.3114, -2.4208],
[-1.0731, 0.3422, 3.1445],
[ 0.5743, 0.2379, 1.1105]], device='xla:0')
xla:0
tensor([[ 1.8611, -0.3114, -2.4208],
[-1.0731, 0.3422, 3.1445],
[ 0.5743, 0.2379, 1.1105]], device='xla:0')

Coba Tutorial PyTorch dalam dokumen ini untuk memulai pelatihan v5e menggunakan PyTorch.

Hapus TPU dan resource yang diantrekan di akhir sesi. Untuk menghapus resource dalam antrean, hapus slice, lalu resource dalam antrean dalam 2 langkah:

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

Kedua langkah ini juga dapat digunakan untuk menghapus permintaan resource yang diantrean yang berada dalam status FAILED.

Contoh JAX/FLAX

Bagian berikut menjelaskan contoh cara melatih model JAX dan FLAX di TPU v5e.

Melatih ImageNet di v5e

Tutorial ini menjelaskan cara melatih ImageNet di v5e menggunakan data input palsu. Jika Anda ingin menggunakan data sebenarnya, lihat file README di GitHub.

Siapkan

  1. Buat variabel lingkungan:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west4-a
    export ACCELERATOR_TYPE=v5litepod-8
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    Deskripsi variabel lingkungan

    Variabel Deskripsi
    PROJECT_ID Project ID Google Cloud Anda. Gunakan project yang ada atau buat project baru.
    TPU_NAME Nama TPU.
    ZONE Zona tempat VM TPU akan dibuat. Untuk mengetahui informasi selengkapnya tentang zona yang didukung, lihat Region dan zona TPU.
    ACCELERATOR_TYPE Jenis akselerator menentukan versi dan ukuran Cloud TPU yang ingin Anda buat. Untuk mengetahui informasi selengkapnya tentang jenis akselerator yang didukung untuk setiap versi TPU, lihat versi TPU.
    RUNTIME_VERSION Versi software Cloud TPU.
    SERVICE_ACCOUNT Alamat email untuk akun layanan Anda. Anda dapat menemukannya dengan membuka halaman Akun Layanan di konsol Google Cloud .

    Misalnya: tpu-service-account@PROJECT_ID.iam.gserviceaccount.com

    QUEUED_RESOURCE_ID ID teks yang ditetapkan pengguna untuk permintaan resource yang diantrekan.

  2. Buat resource TPU:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --service-account=${SERVICE_ACCOUNT}
    

    Anda akan dapat menggunakan SSH ke VM TPU setelah resource yang diantrekan berada dalam status ACTIVE:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    Jika QueuedResource berada dalam status ACTIVE, output akan mirip dengan yang berikut ini:

     state: ACTIVE
    
  3. Instal JAX dan jaxlib versi terbaru:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  4. Clone model ImageNet dan instal persyaratan yang sesuai:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command="git clone https://github.com/coolkp/flax.git && cd flax && git checkout pmap-orbax-conversion && git pull"
    
  5. Untuk membuat data palsu, model memerlukan informasi tentang dimensi set data. Hal ini dapat dikumpulkan dari metadata set data ImageNet:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command="cd flax/examples/imagenet && pip install -r requirements-cloud-tpu.txt"
    

Melatih model

Setelah semua langkah sebelumnya selesai, Anda dapat melatih model.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command="cd flax/examples/imagenet && bash ../../tests/download_dataset_metadata.sh && JAX_PLATFORMS=tpu python imagenet_fake_data_benchmark.py"

Menghapus TPU dan resource yang diantrekan

Hapus TPU dan resource yang diantrekan di akhir sesi.

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

Model FLAX Hugging Face

Model Hugging Face yang diimplementasikan di FLAX berfungsi secara langsung di Cloud TPU v5e. Bagian ini memberikan petunjuk untuk menjalankan model populer.

Melatih ViT di Imagenette

Tutorial ini menunjukkan cara melatih model Vision Transformer (ViT) dari HuggingFace menggunakan set data Imagenette Fast AI di Cloud TPU v5e.

Model ViT adalah model pertama yang berhasil melatih encoder Transformer di ImageNet dengan hasil yang sangat baik dibandingkan dengan jaringan konvolusi. Untuk informasi selengkapnya, lihat referensi berikut:

Siapkan

  1. Buat variabel lingkungan:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west4-a
    export ACCELERATOR_TYPE=v5litepod-16
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    Deskripsi variabel lingkungan

    Variabel Deskripsi
    PROJECT_ID Project ID Google Cloud Anda. Gunakan project yang ada atau buat project baru.
    TPU_NAME Nama TPU.
    ZONE Zona tempat VM TPU akan dibuat. Untuk mengetahui informasi selengkapnya tentang zona yang didukung, lihat Region dan zona TPU.
    ACCELERATOR_TYPE Jenis akselerator menentukan versi dan ukuran Cloud TPU yang ingin Anda buat. Untuk mengetahui informasi selengkapnya tentang jenis akselerator yang didukung untuk setiap versi TPU, lihat versi TPU.
    RUNTIME_VERSION Versi software Cloud TPU.
    SERVICE_ACCOUNT Alamat email untuk akun layanan Anda. Anda dapat menemukannya dengan membuka halaman Akun Layanan di konsol Google Cloud .

    Misalnya: tpu-service-account@PROJECT_ID.iam.gserviceaccount.com

    QUEUED_RESOURCE_ID ID teks yang ditetapkan pengguna untuk permintaan resource yang diantrekan.

  2. Buat resource TPU:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --service-account=${SERVICE_ACCOUNT}
    

    Anda akan dapat menggunakan SSH ke VM TPU setelah resource dalam antrean berada dalam status ACTIVE:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    Jika resource yang diantrekan berada dalam status ACTIVE, output-nya akan mirip dengan yang berikut ini:

     state: ACTIVE
    
  3. Instal JAX dan library-nya:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  4. Download repositori dan persyaratan penginstalan Hugging Face:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='git clone https://github.com/huggingface/transformers.git && cd transformers && pip install . && pip install -r examples/flax/_tests_requirements.txt && pip install --upgrade huggingface-hub urllib3 zipp && pip install tensorflow==2.19 && sed -i 's/torchvision==0.12.0+cpu/torchvision==0.22.1/' examples/flax/vision/requirements.txt && pip install -r examples/flax/vision/requirements.txt && pip install tf-keras'
    
  5. Download set data Imagenette:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='cd transformers && wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz && tar -xvzf imagenette2.tgz'
    

Melatih model

Latih model dengan buffering yang telah dipetakan sebelumnya sebesar 4 GB.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='cd transformers && JAX_PLATFORMS=tpu python3 examples/flax/vision/run_image_classification.py --train_dir "imagenette2/train" --validation_dir "imagenette2/val" --output_dir "./vit-imagenette" --learning_rate 1e-3 --preprocessing_num_workers 32 --per_device_train_batch_size 8 --per_device_eval_batch_size 8 --model_name_or_path google/vit-base-patch16-224-in21k --num_train_epochs 3'

Menghapus TPU dan resource yang diantrekan

Hapus TPU dan resource yang diantrekan di akhir sesi.

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

Hasil benchmark ViT

Skrip pelatihan dijalankan di v5litepod-4, v5litepod-16, dan v5litepod-64. Tabel berikut menunjukkan throughput dengan berbagai jenis akselerator.

Jenis akselerator v5litepod-4 v5litepod-16 v5litepod-64
Epoch 3 3 3
Ukuran batch global 32 128 512
Throughput (contoh/dtk) 263,40 429,34 470,71

Melatih Difusi di Pokémon

Tutorial ini menunjukkan cara melatih model Stable Diffusion dari HuggingFace menggunakan set data Pokémon di Cloud TPU v5e.

Model Stable Diffusion adalah model teks ke gambar laten yang menghasilkan gambar fotorealistik dari input teks apa pun. Untuk informasi selengkapnya, lihat referensi berikut:

Siapkan

  1. Tetapkan variabel lingkungan untuk nama bucket penyimpanan Anda:

    export GCS_BUCKET_NAME=your_bucket_name
  2. Siapkan bucket penyimpanan untuk output model Anda:

    gcloud storage buckets create gs://GCS_BUCKET_NAME \
        --project=your_project \
        --location=us-west1
  3. Buat variabel lingkungan:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west1-c
    export ACCELERATOR_TYPE=v5litepod-16
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    Deskripsi variabel lingkungan

    Variabel Deskripsi
    PROJECT_ID Project ID Google Cloud Anda. Gunakan project yang ada atau buat project baru.
    TPU_NAME Nama TPU.
    ZONE Zona tempat VM TPU akan dibuat. Untuk mengetahui informasi selengkapnya tentang zona yang didukung, lihat Region dan zona TPU.
    ACCELERATOR_TYPE Jenis akselerator menentukan versi dan ukuran Cloud TPU yang ingin Anda buat. Untuk mengetahui informasi selengkapnya tentang jenis akselerator yang didukung untuk setiap versi TPU, lihat versi TPU.
    RUNTIME_VERSION Versi software Cloud TPU.
    SERVICE_ACCOUNT Alamat email untuk akun layanan Anda. Anda dapat menemukannya dengan membuka halaman Akun Layanan di konsol Google Cloud .

    Misalnya: tpu-service-account@PROJECT_ID.iam.gserviceaccount.com

    QUEUED_RESOURCE_ID ID teks yang ditetapkan pengguna untuk permintaan resource yang diantrekan.

  4. Buat resource TPU:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --service-account=${SERVICE_ACCOUNT}
    

    Anda akan dapat menggunakan SSH ke VM TPU setelah resource yang diantrekan berada dalam status ACTIVE:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    Jika resource yang diantrekan berada dalam status ACTIVE, output-nya akan mirip dengan berikut ini:

     state: ACTIVE
    
  5. Instal JAX dan library-nya.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='pip install "jax[tpu]==0.4.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  6. Download repositori HuggingFace dan instal persyaratan.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
         --project=${PROJECT_ID} \
         --zone=${ZONE} \
         --worker=all \
         --command='git clone https://github.com/RissyRan/diffusers.git && cd diffusers && pip install . && pip install -U -r examples/text_to_image/requirements_flax.txt && pip install tensorflow==2.17.1 clu && pip install tensorboard==2.17.1'
    

Melatih model

Latih model dengan buffering yang telah dipetakan sebelumnya sebesar 4 GB.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} --zone=${ZONE} --project=${PROJECT_ID} --worker=all --command="
    git clone https://github.com/google/maxdiffusion
    cd maxdiffusion
    pip3 install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    pip3 install -r requirements.txt
    pip3 install .
    pip3 install gcsfs
    export LIBTPU_INIT_ARGS=''
    python -m src.maxdiffusion.train src/maxdiffusion/configs/base_2_base.yml run_name=my_run \
    jax_cache_dir=gs://${GCS_BUCKET_NAME} activations_dtype=bfloat16 weights_dtype=bfloat16 \
    per_device_batch_size=1 precision=DEFAULT dataset_save_location=gs://${GCS_BUCKET_NAME} \
    output_dir=gs://${GCS_BUCKET_NAME}/ attention=flash"

Pembersihan

Hapus TPU, resource yang diantrekan, dan bucket Cloud Storage di akhir sesi.

  1. Menghapus TPU:

    gcloud compute tpus tpu-vm delete ${TPU_NAME} \
        --project=${PROJECT_ID} \
        --zone=${ZONE} \
        --quiet
    
  2. Hapus resource yang ada dalam antrean:

    gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
        --project=${PROJECT_ID} \
        --zone=${ZONE} \
        --quiet
    
  3. Hapus bucket Cloud Storage:

    gcloud storage rm -r gs://${GCS_BUCKET_NAME}
    

Hasil benchmark untuk difusi

Skrip pelatihan dijalankan di v5litepod-4, v5litepod-16, dan v5litepod-64. Tabel berikut menunjukkan throughput.

Jenis akselerator v5litepod-4 v5litepod-16 v5litepod-64
Langkah Latihan 1500 1500 1500
Ukuran batch global 32 64 128
Throughput (contoh/dtk) 36,53 43,71 49,36

PyTorch/XLA

Bagian berikut menjelaskan contoh cara melatih model PyTorch/XLA di TPU v5e.

Melatih ResNet menggunakan runtime PJRT

PyTorch/XLA bermigrasi dari XRT ke PjRt dari PyTorch 2.0+. Berikut adalah petunjuk yang diperbarui untuk menyiapkan v5e untuk workload pelatihan PyTorch/XLA.

Siapkan
  1. Buat variabel lingkungan:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west4-a
    export ACCELERATOR_TYPE=v5litepod-16
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    Deskripsi variabel lingkungan

    Variabel Deskripsi
    PROJECT_ID Project ID Google Cloud Anda. Gunakan project yang ada atau buat project baru.
    TPU_NAME Nama TPU.
    ZONE Zona tempat VM TPU akan dibuat. Untuk mengetahui informasi selengkapnya tentang zona yang didukung, lihat Region dan zona TPU.
    ACCELERATOR_TYPE Jenis akselerator menentukan versi dan ukuran Cloud TPU yang ingin Anda buat. Untuk mengetahui informasi selengkapnya tentang jenis akselerator yang didukung untuk setiap versi TPU, lihat versi TPU.
    RUNTIME_VERSION Versi software Cloud TPU.
    SERVICE_ACCOUNT Alamat email untuk akun layanan Anda. Anda dapat menemukannya dengan membuka halaman Akun Layanan di konsol Google Cloud .

    Misalnya: tpu-service-account@PROJECT_ID.iam.gserviceaccount.com

    QUEUED_RESOURCE_ID ID teks yang ditetapkan pengguna untuk permintaan resource yang diantrekan.

  2. Buat resource TPU:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --service-account=${SERVICE_ACCOUNT}
    

    Anda akan dapat melakukan SSH ke VM TPU setelah QueuedResource berada dalam status ACTIVE:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    Jika resource yang diantrekan berada dalam status ACTIVE, output-nya akan mirip dengan berikut ini:

     state: ACTIVE
    
  3. Menginstal dependensi khusus Torch/XLA

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --project=${PROJECT_ID} \
      --zone=${ZONE} \
      --worker=all \
      --command='
         sudo apt-get update -y
         sudo apt-get install libomp5 -y
         pip3 install mkl mkl-include
         pip3 install tf-nightly tb-nightly tbp-nightly
         pip3 install numpy
         sudo apt-get install libopenblas-dev -y
         pip install torch==PYTORCH_VERSION torchvision torch_xla[tpu]==PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'

    Ganti PYTORCH_VERSION dengan versi PyTorch yang ingin Anda gunakan. PYTORCH_VERSION digunakan untuk menentukan versi yang sama untuk PyTorch/XLA. 2.6.0 direkomendasikan.

    Untuk mengetahui informasi selengkapnya tentang versi PyTorch dan PyTorch/XLA, lihat PyTorch - Mulai dan rilis PyTorch/XLA.

    Untuk informasi selengkapnya tentang cara menginstal PyTorch/XLA, lihat Penginstalan PyTorch/XLA.

Melatih model ResNet
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
      date
      export PJRT_DEVICE=TPU
      export PT_XLA_DEBUG=0
      export USE_TORCH=ON
      export XLA_USE_BF16=1
      export LIBTPU_INIT_ARGS=--xla_jf_auto_cross_replica_sharding
      export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
      export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
      git clone https://github.com/pytorch/xla.git
      cd xla/
      git checkout release-r2.6
      python3 test/test_train_mp_imagenet.py --model=resnet50  --fake_data --num_epochs=1 —num_workers=16  --log_steps=300 --batch_size=64 --profile'

Menghapus TPU dan resource yang diantrekan

Hapus TPU dan resource yang diantrekan di akhir sesi.

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet
Hasil benchmark

Tabel berikut menunjukkan throughput benchmark.

Jenis akselerator Throughput (contoh/detik)
v5litepod-4 4240 ex/s
v5litepod-16 10.810 ex/s
v5litepod-64 46.154 ex/s

Melatih ViT di v5e

Tutorial ini akan membahas cara menjalankan VIT di v5e menggunakan repositori HuggingFace di PyTorch/XLA pada set data cifar10.

Siapkan

  1. Buat variabel lingkungan:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west4-a
    export ACCELERATOR_TYPE=v5litepod-16
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    Deskripsi variabel lingkungan

    Variabel Deskripsi
    PROJECT_ID Project ID Google Cloud Anda. Gunakan project yang ada atau buat project baru.
    TPU_NAME Nama TPU.
    ZONE Zona tempat VM TPU akan dibuat. Untuk mengetahui informasi selengkapnya tentang zona yang didukung, lihat Region dan zona TPU.
    ACCELERATOR_TYPE Jenis akselerator menentukan versi dan ukuran Cloud TPU yang ingin Anda buat. Untuk mengetahui informasi selengkapnya tentang jenis akselerator yang didukung untuk setiap versi TPU, lihat versi TPU.
    RUNTIME_VERSION Versi software Cloud TPU.
    SERVICE_ACCOUNT Alamat email untuk akun layanan Anda. Anda dapat menemukannya dengan membuka halaman Akun Layanan di konsol Google Cloud .

    Misalnya: tpu-service-account@PROJECT_ID.iam.gserviceaccount.com

    QUEUED_RESOURCE_ID ID teks yang ditetapkan pengguna untuk permintaan resource yang diantrekan.

  2. Buat resource TPU:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --service-account=${SERVICE_ACCOUNT}
    

    Anda akan dapat menggunakan SSH ke VM TPU setelah QueuedResource berada dalam status ACTIVE:

     gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    Jika resource yang diantrekan berada dalam status ACTIVE, output-nya akan mirip dengan yang berikut ini:

     state: ACTIVE
    
  3. Menginstal dependensi PyTorch/XLA

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='
      sudo apt-get update -y
      sudo apt-get install libomp5 -y
      pip3 install mkl mkl-include
      pip3 install tf-nightly tb-nightly tbp-nightly
      pip3 install numpy
      sudo apt-get install libopenblas-dev -y
      pip install torch==PYTORCH_VERSION torchvision torch_xla[tpu]==PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
      pip install jax==0.4.38 jaxlib==0.4.38 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/

    Ganti PYTORCH_VERSION dengan versi PyTorch yang ingin Anda gunakan. PYTORCH_VERSION digunakan untuk menentukan versi yang sama untuk PyTorch/XLA. 2.6.0 direkomendasikan.

    Untuk mengetahui informasi selengkapnya tentang versi PyTorch dan PyTorch/XLA, lihat PyTorch - Mulai dan rilis PyTorch/XLA.

    Untuk informasi selengkapnya tentang cara menginstal PyTorch/XLA, lihat Penginstalan PyTorch/XLA.

  4. Download repositori HuggingFace dan persyaratan penginstalan.

       gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command="
          git clone https://github.com/suexu1025/transformers.git vittransformers; \
          cd vittransformers; \
          pip3 install .; \
          pip3 install datasets; \
          wget https://github.com/pytorch/xla/blob/master/scripts/capture_profile.py"
    

Melatih model

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
      export PJRT_DEVICE=TPU
      export PT_XLA_DEBUG=0
      export USE_TORCH=ON
      export TF_CPP_MIN_LOG_LEVEL=0
      export XLA_USE_BF16=1
      export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
      export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
      cd vittransformers
      python3 -u examples/pytorch/xla_spawn.py --num_cores 4 examples/pytorch/image-pretraining/run_mae.py --dataset_name=cifar10 \
      --remove_unused_columns=False \
      --label_names=pixel_values \
      --mask_ratio=0.75 \
      --norm_pix_loss=True \
      --do_train=true \
      --do_eval=true \
      --base_learning_rate=1.5e-4 \
      --lr_scheduler_type=cosine \
      --weight_decay=0.05 \
      --num_train_epochs=3 \
      --warmup_ratio=0.05 \
      --per_device_train_batch_size=8 \
      --per_device_eval_batch_size=8 \
      --logging_strategy=steps \
      --logging_steps=30 \
      --evaluation_strategy=epoch \
      --save_strategy=epoch \
      --load_best_model_at_end=True \
      --save_total_limit=3 \
      --seed=1337 \
      --output_dir=MAE \
      --overwrite_output_dir=true \
      --logging_dir=./tensorboard-metrics \
      --tpu_metrics_debug=true'

Menghapus TPU dan resource yang diantrekan

Hapus TPU dan resource yang diantrekan di akhir sesi.

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

Hasil benchmark

Tabel berikut menunjukkan throughput benchmark untuk berbagai jenis akselerator.

v5litepod-4 v5litepod-16 v5litepod-64
Epoch 3 3 3
Ukuran batch global 32 128 512
Throughput (contoh/dtk) 201 657 2.844