Pelatihan Cloud TPU v5e
Dengan jejak 256 chip yang lebih kecil per Pod, TPU v5e dioptimalkan untuk menjadi produk bernilai tinggi untuk pelatihan, penyesuaian, dan penayangan transformer, text-to-image, dan Convolutional Neural Network (CNN). Untuk informasi selengkapnya tentang cara menggunakan Cloud TPU v5e untuk penayangan, lihat Inferensi menggunakan v5e.
Untuk mengetahui informasi selengkapnya tentang konfigurasi dan hardware TPU Cloud TPU v5e, lihat TPU v5e.
Mulai
Bagian berikut menjelaskan cara mulai menggunakan TPU v5e.
Kuota permintaan
Anda memerlukan kuota untuk menggunakan TPU v5e untuk pelatihan. Ada berbagai jenis kuota untuk TPU on-demand, TPU yang direservasi, dan VM Spot TPU. Ada kuota terpisah yang diperlukan jika Anda menggunakan TPU v5e untuk inferensi. Untuk mengetahui informasi selengkapnya tentang kuota, lihat Kuota. Untuk meminta kuota TPU v5e, hubungi Cloud Sales.
Membuat akun dan project Google Cloud
Anda memerlukan akun dan project Google Cloud untuk menggunakan Cloud TPU. Untuk mengetahui informasi selengkapnya, lihat Menyiapkan lingkungan Cloud TPU.
Buat Cloud TPU
Praktik terbaiknya adalah menyediakan Cloud TPU v5 sebagai resource dalam antrean
menggunakan perintah queued-resource create
. Untuk mengetahui informasi selengkapnya, lihat
Mengelola resource dalam antrean.
Anda juga dapat menggunakan Create Node API (gcloud compute tpus tpu-vm create
) untuk menyediakan Cloud TPU v5e. Untuk mengetahui informasi selengkapnya, lihat Mengelola resource TPU.
Untuk mengetahui informasi selengkapnya tentang konfigurasi v5e yang tersedia untuk pelatihan, lihat Jenis Cloud TPU v5e untuk pelatihan.
Penyiapan framework
Bagian ini menjelaskan proses penyiapan umum untuk pelatihan model kustom menggunakan JAX atau PyTorch dengan TPU v5e.
Untuk petunjuk penyiapan inferensi, lihat pengantar inferensi v5e.
Tentukan beberapa variabel lingkungan:
export PROJECT_ID=your_project_ID export ACCELERATOR_TYPE=v5litepod-16 export ZONE=us-west4-a export TPU_NAME=your_tpu_name export QUEUED_RESOURCE_ID=your_queued_resource_id
Penyiapan untuk JAX
Jika memiliki bentuk slice lebih dari 8 chip, Anda akan memiliki beberapa VM dalam satu
slice. Dalam hal ini, Anda perlu menggunakan flag --worker=all
untuk menjalankan penginstalan di semua VM TPU dalam satu langkah tanpa menggunakan SSH untuk login ke setiap VM secara terpisah:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Deskripsi flag perintah
Variabel | Deskripsi |
TPU_NAME | ID teks TPU yang ditetapkan pengguna yang dibuat saat permintaan resource yang diantrekan dialokasikan. |
PROJECT_ID | Google Cloud Nama Project. Gunakan project yang ada atau buat project baru di Menyiapkan Google Cloud project |
ZONA | Lihat dokumen Region dan zona TPU untuk zona yang didukung. |
pekerja | VM TPU yang memiliki akses ke TPU yang mendasarinya. |
Anda dapat menjalankan perintah berikut untuk memeriksa jumlah perangkat (output yang ditampilkan di sini dihasilkan dengan slice v5litepod-16). Kode ini menguji apakah semuanya diinstal dengan benar dengan memeriksa apakah JAX melihat TensorCore Cloud TPU dan dapat menjalankan operasi dasar:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='python3 -c "import jax; print(jax.device_count()); print(jax.local_device_count())"'
Outputnya akan seperti ini:
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16
4
16
4
16
4
16
4
jax.device_count()
menampilkan jumlah total chip dalam slice yang diberikan.
jax.local_device_count()
menunjukkan jumlah chip yang dapat diakses oleh satu VM di slice ini.
# Check the number of chips in the given slice by summing the count of chips
# from all VMs through the
# jax.local_device_count() API call.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='python3 -c "import jax; xs=jax.numpy.ones(jax.local_device_count()); print(jax.pmap(lambda x: jax.lax.psum(x, \"i\"), axis_name=\"i\")(xs))"'
Outputnya akan seperti ini:
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]
Coba Tutorial JAX dalam dokumen ini untuk memulai pelatihan v5e menggunakan JAX.
Penyiapan untuk PyTorch
Perhatikan bahwa v5e hanya mendukung runtime PJRT dan PyTorch 2.1+ akan menggunakan PJRT sebagai runtime default untuk semua versi TPU.
Bagian ini menjelaskan cara mulai menggunakan PJRT di v5e dengan PyTorch/XLA dengan perintah untuk semua pekerja.
Menginstal dependensi
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' sudo apt-get update -y sudo apt-get install libomp5 -y pip install mkl mkl-include pip install tf-nightly tb-nightly tbp-nightly pip install numpy sudo apt-get install libopenblas-dev -y pip install torch~=PYTORCH_VERSION torchvision torch_xla[tpu]~=PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'
Ganti PYTORCH_VERSION
dengan versi PyTorch yang ingin Anda gunakan.
PYTORCH_VERSION
digunakan untuk menentukan versi yang sama untuk PyTorch/XLA. 2.6.0
direkomendasikan.
Untuk mengetahui informasi selengkapnya tentang versi PyTorch dan PyTorch/XLA, lihat PyTorch - Mulai dan rilis PyTorch/XLA.
Untuk informasi selengkapnya tentang cara menginstal PyTorch/XLA, lihat Penginstalan PyTorch/XLA.
Jika Anda mendapatkan error saat menginstal roda untuk torch
, torch_xla
, atau
torchvision
seperti
pkg_resources.extern.packaging.requirements.InvalidRequirement: Expected end
or semicolon (after name and no valid version specifier) torch==nightly+20230222
,
downgrade versi Anda dengan perintah ini:
pip3 install setuptools==62.1.0
Menjalankan skrip dengan PJRT
unset LD_PRELOAD
Berikut adalah contoh penggunaan skrip Python untuk melakukan penghitungan di VM v5e:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.local/lib/
export PJRT_DEVICE=TPU
export PT_XLA_DEBUG=0
export USE_TORCH=ON
unset LD_PRELOAD
export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"'
Tindakan ini akan menghasilkan output yang mirip dengan berikut ini:
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
xla:0
tensor([[ 1.8611, -0.3114, -2.4208],
[-1.0731, 0.3422, 3.1445],
[ 0.5743, 0.2379, 1.1105]], device='xla:0')
xla:0
tensor([[ 1.8611, -0.3114, -2.4208],
[-1.0731, 0.3422, 3.1445],
[ 0.5743, 0.2379, 1.1105]], device='xla:0')
Coba Tutorial PyTorch dalam dokumen ini untuk memulai pelatihan v5e menggunakan PyTorch.
Hapus TPU dan resource yang diantrekan di akhir sesi. Untuk menghapus resource dalam antrean, hapus slice, lalu resource dalam antrean dalam 2 langkah:
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Kedua langkah ini juga dapat digunakan untuk menghapus permintaan resource yang diantrean yang berada dalam
status FAILED
.
Contoh JAX/FLAX
Bagian berikut menjelaskan contoh cara melatih model JAX dan FLAX di TPU v5e.
Melatih ImageNet di v5e
Tutorial ini menjelaskan cara melatih ImageNet di v5e menggunakan data input palsu. Jika Anda ingin menggunakan data sebenarnya, lihat file README di GitHub.
Siapkan
Buat variabel lingkungan:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west4-a export ACCELERATOR_TYPE=v5litepod-8 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
Deskripsi variabel lingkungan
Variabel Deskripsi PROJECT_ID
Project ID Google Cloud Anda. Gunakan project yang ada atau buat project baru. TPU_NAME
Nama TPU. ZONE
Zona tempat VM TPU akan dibuat. Untuk mengetahui informasi selengkapnya tentang zona yang didukung, lihat Region dan zona TPU. ACCELERATOR_TYPE
Jenis akselerator menentukan versi dan ukuran Cloud TPU yang ingin Anda buat. Untuk mengetahui informasi selengkapnya tentang jenis akselerator yang didukung untuk setiap versi TPU, lihat versi TPU. RUNTIME_VERSION
Versi software Cloud TPU. SERVICE_ACCOUNT
Alamat email untuk akun layanan Anda. Anda dapat menemukannya dengan membuka halaman Akun Layanan di konsol Google Cloud . Misalnya:
tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
QUEUED_RESOURCE_ID
ID teks yang ditetapkan pengguna untuk permintaan resource yang diantrekan. -
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}
Anda akan dapat menggunakan SSH ke VM TPU setelah resource yang diantrekan berada dalam status
ACTIVE
:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Jika QueuedResource berada dalam status
ACTIVE
, output akan mirip dengan yang berikut ini:state: ACTIVE
Instal JAX dan jaxlib versi terbaru:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Clone model ImageNet dan instal persyaratan yang sesuai:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command="git clone https://github.com/coolkp/flax.git && cd flax && git checkout pmap-orbax-conversion && git pull"
Untuk membuat data palsu, model memerlukan informasi tentang dimensi set data. Hal ini dapat dikumpulkan dari metadata set data ImageNet:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command="cd flax/examples/imagenet && pip install -r requirements-cloud-tpu.txt"
Melatih model
Setelah semua langkah sebelumnya selesai, Anda dapat melatih model.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command="cd flax/examples/imagenet && bash ../../tests/download_dataset_metadata.sh && JAX_PLATFORMS=tpu python imagenet_fake_data_benchmark.py"
Menghapus TPU dan resource yang diantrekan
Hapus TPU dan resource yang diantrekan di akhir sesi.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Model FLAX Hugging Face
Model Hugging Face yang diimplementasikan di FLAX berfungsi secara langsung di Cloud TPU v5e. Bagian ini memberikan petunjuk untuk menjalankan model populer.
Melatih ViT di Imagenette
Tutorial ini menunjukkan cara melatih model Vision Transformer (ViT) dari HuggingFace menggunakan set data Imagenette Fast AI di Cloud TPU v5e.
Model ViT adalah model pertama yang berhasil melatih encoder Transformer di ImageNet dengan hasil yang sangat baik dibandingkan dengan jaringan konvolusi. Untuk informasi selengkapnya, lihat referensi berikut:
Siapkan
Buat variabel lingkungan:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west4-a export ACCELERATOR_TYPE=v5litepod-16 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
Deskripsi variabel lingkungan
Variabel Deskripsi PROJECT_ID
Project ID Google Cloud Anda. Gunakan project yang ada atau buat project baru. TPU_NAME
Nama TPU. ZONE
Zona tempat VM TPU akan dibuat. Untuk mengetahui informasi selengkapnya tentang zona yang didukung, lihat Region dan zona TPU. ACCELERATOR_TYPE
Jenis akselerator menentukan versi dan ukuran Cloud TPU yang ingin Anda buat. Untuk mengetahui informasi selengkapnya tentang jenis akselerator yang didukung untuk setiap versi TPU, lihat versi TPU. RUNTIME_VERSION
Versi software Cloud TPU. SERVICE_ACCOUNT
Alamat email untuk akun layanan Anda. Anda dapat menemukannya dengan membuka halaman Akun Layanan di konsol Google Cloud . Misalnya:
tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
QUEUED_RESOURCE_ID
ID teks yang ditetapkan pengguna untuk permintaan resource yang diantrekan. -
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}
Anda akan dapat menggunakan SSH ke VM TPU setelah resource dalam antrean berada dalam status
ACTIVE
:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Jika resource yang diantrekan berada dalam status
ACTIVE
, output-nya akan mirip dengan yang berikut ini:state: ACTIVE
Instal JAX dan library-nya:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Download repositori dan persyaratan penginstalan Hugging Face:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='git clone https://github.com/huggingface/transformers.git && cd transformers && pip install . && pip install -r examples/flax/_tests_requirements.txt && pip install --upgrade huggingface-hub urllib3 zipp && pip install tensorflow==2.19 && sed -i 's/torchvision==0.12.0+cpu/torchvision==0.22.1/' examples/flax/vision/requirements.txt && pip install -r examples/flax/vision/requirements.txt && pip install tf-keras'
Download set data Imagenette:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='cd transformers && wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz && tar -xvzf imagenette2.tgz'
Melatih model
Latih model dengan buffering yang telah dipetakan sebelumnya sebesar 4 GB.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='cd transformers && JAX_PLATFORMS=tpu python3 examples/flax/vision/run_image_classification.py --train_dir "imagenette2/train" --validation_dir "imagenette2/val" --output_dir "./vit-imagenette" --learning_rate 1e-3 --preprocessing_num_workers 32 --per_device_train_batch_size 8 --per_device_eval_batch_size 8 --model_name_or_path google/vit-base-patch16-224-in21k --num_train_epochs 3'
Menghapus TPU dan resource yang diantrekan
Hapus TPU dan resource yang diantrekan di akhir sesi.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Hasil benchmark ViT
Skrip pelatihan dijalankan di v5litepod-4, v5litepod-16, dan v5litepod-64. Tabel berikut menunjukkan throughput dengan berbagai jenis akselerator.
Jenis akselerator | v5litepod-4 | v5litepod-16 | v5litepod-64 |
Epoch | 3 | 3 | 3 |
Ukuran batch global | 32 | 128 | 512 |
Throughput (contoh/dtk) | 263,40 | 429,34 | 470,71 |
Melatih Difusi di Pokémon
Tutorial ini menunjukkan cara melatih model Stable Diffusion dari HuggingFace menggunakan set data Pokémon di Cloud TPU v5e.
Model Stable Diffusion adalah model teks ke gambar laten yang menghasilkan gambar fotorealistik dari input teks apa pun. Untuk informasi selengkapnya, lihat referensi berikut:
Siapkan
Tetapkan variabel lingkungan untuk nama bucket penyimpanan Anda:
export GCS_BUCKET_NAME=your_bucket_name
Siapkan bucket penyimpanan untuk output model Anda:
gcloud storage buckets create gs://GCS_BUCKET_NAME \ --project=your_project \ --location=us-west1
Buat variabel lingkungan:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west1-c export ACCELERATOR_TYPE=v5litepod-16 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
Deskripsi variabel lingkungan
Variabel Deskripsi PROJECT_ID
Project ID Google Cloud Anda. Gunakan project yang ada atau buat project baru. TPU_NAME
Nama TPU. ZONE
Zona tempat VM TPU akan dibuat. Untuk mengetahui informasi selengkapnya tentang zona yang didukung, lihat Region dan zona TPU. ACCELERATOR_TYPE
Jenis akselerator menentukan versi dan ukuran Cloud TPU yang ingin Anda buat. Untuk mengetahui informasi selengkapnya tentang jenis akselerator yang didukung untuk setiap versi TPU, lihat versi TPU. RUNTIME_VERSION
Versi software Cloud TPU. SERVICE_ACCOUNT
Alamat email untuk akun layanan Anda. Anda dapat menemukannya dengan membuka halaman Akun Layanan di konsol Google Cloud . Misalnya:
tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
QUEUED_RESOURCE_ID
ID teks yang ditetapkan pengguna untuk permintaan resource yang diantrekan. -
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}
Anda akan dapat menggunakan SSH ke VM TPU setelah resource yang diantrekan berada dalam status
ACTIVE
:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Jika resource yang diantrekan berada dalam status
ACTIVE
, output-nya akan mirip dengan berikut ini:state: ACTIVE
Instal JAX dan library-nya.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='pip install "jax[tpu]==0.4.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Download repositori HuggingFace dan instal persyaratan.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='git clone https://github.com/RissyRan/diffusers.git && cd diffusers && pip install . && pip install -U -r examples/text_to_image/requirements_flax.txt && pip install tensorflow==2.17.1 clu && pip install tensorboard==2.17.1'
Melatih model
Latih model dengan buffering yang telah dipetakan sebelumnya sebesar 4 GB.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} --zone=${ZONE} --project=${PROJECT_ID} --worker=all --command="
git clone https://github.com/google/maxdiffusion
cd maxdiffusion
pip3 install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip3 install -r requirements.txt
pip3 install .
pip3 install gcsfs
export LIBTPU_INIT_ARGS=''
python -m src.maxdiffusion.train src/maxdiffusion/configs/base_2_base.yml run_name=my_run \
jax_cache_dir=gs://${GCS_BUCKET_NAME} activations_dtype=bfloat16 weights_dtype=bfloat16 \
per_device_batch_size=1 precision=DEFAULT dataset_save_location=gs://${GCS_BUCKET_NAME} \
output_dir=gs://${GCS_BUCKET_NAME}/ attention=flash"
Pembersihan
Hapus TPU, resource yang diantrekan, dan bucket Cloud Storage di akhir sesi.
Menghapus TPU:
gcloud compute tpus tpu-vm delete ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --quiet
Hapus resource yang ada dalam antrean:
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --quiet
Hapus bucket Cloud Storage:
gcloud storage rm -r gs://${GCS_BUCKET_NAME}
Hasil benchmark untuk difusi
Skrip pelatihan dijalankan di v5litepod-4, v5litepod-16, dan v5litepod-64. Tabel berikut menunjukkan throughput.
Jenis akselerator | v5litepod-4 | v5litepod-16 | v5litepod-64 |
Langkah Latihan | 1500 | 1500 | 1500 |
Ukuran batch global | 32 | 64 | 128 |
Throughput (contoh/dtk) | 36,53 | 43,71 | 49,36 |
PyTorch/XLA
Bagian berikut menjelaskan contoh cara melatih model PyTorch/XLA di TPU v5e.
Melatih ResNet menggunakan runtime PJRT
PyTorch/XLA bermigrasi dari XRT ke PjRt dari PyTorch 2.0+. Berikut adalah petunjuk yang diperbarui untuk menyiapkan v5e untuk workload pelatihan PyTorch/XLA.
Siapkan
Buat variabel lingkungan:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west4-a export ACCELERATOR_TYPE=v5litepod-16 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
Deskripsi variabel lingkungan
Variabel Deskripsi PROJECT_ID
Project ID Google Cloud Anda. Gunakan project yang ada atau buat project baru. TPU_NAME
Nama TPU. ZONE
Zona tempat VM TPU akan dibuat. Untuk mengetahui informasi selengkapnya tentang zona yang didukung, lihat Region dan zona TPU. ACCELERATOR_TYPE
Jenis akselerator menentukan versi dan ukuran Cloud TPU yang ingin Anda buat. Untuk mengetahui informasi selengkapnya tentang jenis akselerator yang didukung untuk setiap versi TPU, lihat versi TPU. RUNTIME_VERSION
Versi software Cloud TPU. SERVICE_ACCOUNT
Alamat email untuk akun layanan Anda. Anda dapat menemukannya dengan membuka halaman Akun Layanan di konsol Google Cloud . Misalnya:
tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
QUEUED_RESOURCE_ID
ID teks yang ditetapkan pengguna untuk permintaan resource yang diantrekan. -
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}
Anda akan dapat melakukan SSH ke VM TPU setelah QueuedResource berada dalam status
ACTIVE
:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Jika resource yang diantrekan berada dalam status
ACTIVE
, output-nya akan mirip dengan berikut ini:state: ACTIVE
Menginstal dependensi khusus Torch/XLA
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' sudo apt-get update -y sudo apt-get install libomp5 -y pip3 install mkl mkl-include pip3 install tf-nightly tb-nightly tbp-nightly pip3 install numpy sudo apt-get install libopenblas-dev -y pip install torch==PYTORCH_VERSION torchvision torch_xla[tpu]==PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'
Ganti
PYTORCH_VERSION
dengan versi PyTorch yang ingin Anda gunakan.PYTORCH_VERSION
digunakan untuk menentukan versi yang sama untuk PyTorch/XLA. 2.6.0 direkomendasikan.Untuk mengetahui informasi selengkapnya tentang versi PyTorch dan PyTorch/XLA, lihat PyTorch - Mulai dan rilis PyTorch/XLA.
Untuk informasi selengkapnya tentang cara menginstal PyTorch/XLA, lihat Penginstalan PyTorch/XLA.
Melatih model ResNet
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
date
export PJRT_DEVICE=TPU
export PT_XLA_DEBUG=0
export USE_TORCH=ON
export XLA_USE_BF16=1
export LIBTPU_INIT_ARGS=--xla_jf_auto_cross_replica_sharding
export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
git clone https://github.com/pytorch/xla.git
cd xla/
git checkout release-r2.6
python3 test/test_train_mp_imagenet.py --model=resnet50 --fake_data --num_epochs=1 —num_workers=16 --log_steps=300 --batch_size=64 --profile'
Menghapus TPU dan resource yang diantrekan
Hapus TPU dan resource yang diantrekan di akhir sesi.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Hasil benchmark
Tabel berikut menunjukkan throughput benchmark.
Jenis akselerator | Throughput (contoh/detik) |
v5litepod-4 | 4240 ex/s |
v5litepod-16 | 10.810 ex/s |
v5litepod-64 | 46.154 ex/s |
Melatih ViT di v5e
Tutorial ini akan membahas cara menjalankan VIT di v5e menggunakan repositori HuggingFace di PyTorch/XLA pada set data cifar10.
Siapkan
Buat variabel lingkungan:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west4-a export ACCELERATOR_TYPE=v5litepod-16 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
Deskripsi variabel lingkungan
Variabel Deskripsi PROJECT_ID
Project ID Google Cloud Anda. Gunakan project yang ada atau buat project baru. TPU_NAME
Nama TPU. ZONE
Zona tempat VM TPU akan dibuat. Untuk mengetahui informasi selengkapnya tentang zona yang didukung, lihat Region dan zona TPU. ACCELERATOR_TYPE
Jenis akselerator menentukan versi dan ukuran Cloud TPU yang ingin Anda buat. Untuk mengetahui informasi selengkapnya tentang jenis akselerator yang didukung untuk setiap versi TPU, lihat versi TPU. RUNTIME_VERSION
Versi software Cloud TPU. SERVICE_ACCOUNT
Alamat email untuk akun layanan Anda. Anda dapat menemukannya dengan membuka halaman Akun Layanan di konsol Google Cloud . Misalnya:
tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
QUEUED_RESOURCE_ID
ID teks yang ditetapkan pengguna untuk permintaan resource yang diantrekan. -
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}
Anda akan dapat menggunakan SSH ke VM TPU setelah QueuedResource berada dalam status
ACTIVE
:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Jika resource yang diantrekan berada dalam status
ACTIVE
, output-nya akan mirip dengan yang berikut ini:state: ACTIVE
Menginstal dependensi PyTorch/XLA
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' sudo apt-get update -y sudo apt-get install libomp5 -y pip3 install mkl mkl-include pip3 install tf-nightly tb-nightly tbp-nightly pip3 install numpy sudo apt-get install libopenblas-dev -y pip install torch==PYTORCH_VERSION torchvision torch_xla[tpu]==PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html pip install jax==0.4.38 jaxlib==0.4.38 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
Ganti
PYTORCH_VERSION
dengan versi PyTorch yang ingin Anda gunakan.PYTORCH_VERSION
digunakan untuk menentukan versi yang sama untuk PyTorch/XLA. 2.6.0 direkomendasikan.Untuk mengetahui informasi selengkapnya tentang versi PyTorch dan PyTorch/XLA, lihat PyTorch - Mulai dan rilis PyTorch/XLA.
Untuk informasi selengkapnya tentang cara menginstal PyTorch/XLA, lihat Penginstalan PyTorch/XLA.
Download repositori HuggingFace dan persyaratan penginstalan.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=" git clone https://github.com/suexu1025/transformers.git vittransformers; \ cd vittransformers; \ pip3 install .; \ pip3 install datasets; \ wget https://github.com/pytorch/xla/blob/master/scripts/capture_profile.py"
Melatih model
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
export PJRT_DEVICE=TPU
export PT_XLA_DEBUG=0
export USE_TORCH=ON
export TF_CPP_MIN_LOG_LEVEL=0
export XLA_USE_BF16=1
export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
cd vittransformers
python3 -u examples/pytorch/xla_spawn.py --num_cores 4 examples/pytorch/image-pretraining/run_mae.py --dataset_name=cifar10 \
--remove_unused_columns=False \
--label_names=pixel_values \
--mask_ratio=0.75 \
--norm_pix_loss=True \
--do_train=true \
--do_eval=true \
--base_learning_rate=1.5e-4 \
--lr_scheduler_type=cosine \
--weight_decay=0.05 \
--num_train_epochs=3 \
--warmup_ratio=0.05 \
--per_device_train_batch_size=8 \
--per_device_eval_batch_size=8 \
--logging_strategy=steps \
--logging_steps=30 \
--evaluation_strategy=epoch \
--save_strategy=epoch \
--load_best_model_at_end=True \
--save_total_limit=3 \
--seed=1337 \
--output_dir=MAE \
--overwrite_output_dir=true \
--logging_dir=./tensorboard-metrics \
--tpu_metrics_debug=true'
Menghapus TPU dan resource yang diantrekan
Hapus TPU dan resource yang diantrekan di akhir sesi.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Hasil benchmark
Tabel berikut menunjukkan throughput benchmark untuk berbagai jenis akselerator.
v5litepod-4 | v5litepod-16 | v5litepod-64 | |
Epoch | 3 | 3 | 3 |
Ukuran batch global | 32 | 128 | 512 |
Throughput (contoh/dtk) | 201 | 657 | 2.844 |