Melatih Llama 3 menggunakan PyTorch di TPU v5e
Tutorial ini menjelaskan cara melatih model Llama-3-8B menggunakan PyTorch/XLA di TPU v5e menggunakan set data WikiText. Lihat Meta-Llama-3-8B untuk mengetahui detail model.
Model Llama-3-8B dihosting di platform Hugging Face.
Ada dua versi Meta-Llama-3-8B, satu untuk digunakan dengan Transformers dan satu lagi dengan codebase Llama 3 asli. Tutorial ini menggunakan versi Transformers karena:
Terintegrasi dengan lancar dengan ekosistem Hugging Face: Hal ini memudahkan penyesuaian model, penggunaan pipeline bawaan, dan akses ke koleksi set data dan alat yang luas.
Memungkinkan fleksibilitas dan penyesuaian: Versi Transformers menawarkan opsi penyesuaian dan fleksibilitas yang signifikan untuk menyesuaikan dan men-deploy model.
Memberikan dukungan komunitas: Komunitas Hugging Face menyediakan dokumentasi, tutorial, dan dukungan yang ekstensif untuk menggunakan model Transformer.
Untuk informasi selengkapnya tentang Transformer, lihat dokumentasi Hugging Face Transformers.
Untuk mengakses dan menggunakan model Meta-Llama-3-8B, termasuk mendownload bobot dan tokenizer-nya, Anda memerlukan token akses pengguna Hugging Face. Token ini menyediakan:
Autentikasi dan Otorisasi: Token akses berfungsi sebagai kredensial, memungkinkan server Hugging Face memberikan otorisasi akses Anda ke resource model. Hal ini memastikan bahwa hanya pengguna yang diotorisasi yang dapat mendownload dan menggunakan model.
Keamanan: Hugging Face menggunakan token akses untuk melindungi modelnya dan mencegah akses atau penyalahgunaan yang tidak sah.
Untuk informasi tentang cara membuat dan menggunakan token akses untuk tutorial ini, lihat Menjalankan model. Untuk informasi yang lebih komprehensif tentang cara membuat dan menggunakan token akses, lihat dokumentasi Hugging Face tentang token akses pengguna.
Anda juga memerlukan izin untuk mengakses model Llama 3 8B di Hugging Face. Untuk mendapatkan izin tersebut, buka model Meta-Llama-3-8B di Hugging Face dan minta akses.
Bersiap untuk menyediakan TPU v5litepod-16
Tutorial ini diuji menggunakan variabel lingkungan Cloud TPU berikut. Anda dapat menggunakan variabel lain untuk menyediakan TPU,
asalkan jenis akselerator, zona, dan versi runtime kompatibel.
Misalnya, dalam tutorial ini, europe-west4-b
digunakan sebagai zona di seluruh bagian. Anda dapat menggunakan zona lain yang mendukung
versi TPU (jenis akselerator) yang Anda jalankan (v5litepod-16 dalam tutorial ini).
Tetapkan variabel lingkungan VM TPU berikut.
export TPU_NAME=queued-resources-node-id #The TPU name is the queued resource node-id export PROJECT_ID=your-project-id export ACCELERATOR_TYPE=v5litepod-16 export ZONE=europe-west4-b export RUNTIME_VERSION=v2-alpha-tpuv5-lite export QUEUED_RESOURCE_ID=queued-resource-id export VALID_UNTIL_DURATION=1d
Jika Anda memiliki akses ke model Meta-Llama-3-8B di Hugging Face, siapkan lingkungan TPU untuk menjalankan tutorial.
Ikuti panduan Menyiapkan lingkungan Cloud TPU untuk memastikan Anda memiliki akses yang sesuai untuk menggunakan Cloud TPU.
Buat identitas layanan untuk VM TPU.
gcloud alpha compute tpus tpu-vm service-identity create --zone=zone
Buat akun layanan TPU dan berikan akses ke layanan Google Cloud .
Akun layanan memungkinkan layanan Google Cloud TPU mengakses layanan Google Cloudlainnya. Akun layanan yang dikelola pengguna direkomendasikan. Anda dapat membuat akun layanan dari Konsol Google Cloud atau melalui perintah
gcloud
.Buat akun layanan menggunakan alat command line
gcloud
:gcloud iam service-accounts create your-service-account-name \ --description="your-sa-description" \ --display-name="your-sa-display-name" export SERVICE_ACCOUNT_NAME=your-service-account-name
Buat akun layanan dari konsol Google Cloud:
- Buka halaman Akun Layanan di konsol Google Cloud.
- Klik Create service account.
- Masukkan nama akun layanan.
- (Opsional) Masukkan deskripsi akun layanan.
- Klik Buat dan lanjutkan.
- Pilih peran yang ingin Anda berikan ke akun layanan.
- Klik Lanjutkan.
- (Opsional) Tentukan pengguna atau grup yang dapat mengelola akun layanan.
- Klik Selesai untuk menyelesaikan pembuatan akun layanan.
Setelah membuat akun layanan, ikuti langkah-langkah berikut untuk memberikan peran akun layanan.
Peran berikut diperlukan:
- TPU Admin: Diperlukan untuk membuat TPU
- Storage Admin: Diperlukan untuk mengakses Cloud Storage
- Logs Writer
- Monitoring Metric Writer: Diperlukan untuk menulis metrik ke Cloud Monitoring
Administrator harus memberi Anda
roles/resourcemanager.projectIamAdmin
agar Anda dapat menetapkan peran IAM kepada pengguna. Pengguna dengan peran Project IAM Adminroles/resourcemanager.projectIamAdmin
juga dapat memberikan peran ini.Gunakan perintah
gcloud
berikut untuk menambahkan peran akun layanan:gcloud projects add-iam-policy-binding ${PROJECT_ID} \ --member serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \ --role roles/tpu.admin gcloud projects add-iam-policy-binding ${PROJECT_ID} \ --member serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \ --role roles/storage.admin gcloud projects add-iam-policy-binding ${PROJECT_ID} \ --member serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \ --role roles/logging.logWriter gcloud projects add-iam-policy-binding ${PROJECT_ID} \ --member serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \ --role roles/monitoring.metricWriter
Anda juga dapat menetapkan peran menggunakan konsol Google Cloud.
Dari konsol Google Cloud, pilih peran berikut:
- Pilih akun layanan Anda, lalu klik Tambahkan Akun Utama.
- Di kolom New Principals, masukkan alamat email akun layanan Anda.
- Di drop-down Select a role, telusuri peran (misalnya, Storage Admin) dan pilih peran tersebut.
- Klik Simpan.
Lakukan autentikasi dengan Google Cloud dan konfigurasikan project dan zona default untuk Google Cloud CLI.
gcloud auth login gcloud config set project PROJECT_ID gcloud config set compute/zone ZONE
Kapasitas aman
Jika Anda sudah siap untuk mendapatkan kapasitas TPU, tinjau halaman kuota untuk mempelajari sistem Kuota Cloud. Jika ada pertanyaan tambahan tentang cara mendapatkan kapasitas, hubungi tim penjualan atau akun Cloud TPU Anda.
Menyediakan lingkungan Cloud TPU
Anda dapat menyediakan VM TPU dengan GKE, dengan GKE dan XPK, atau sebagai resource dalam antrean.
Prasyarat
- Tutorial ini telah diuji dengan Python 3.10 atau yang lebih baru.
- Pastikan project Anda memiliki kuota
TPUS_PER_TPU_FAMILY
yang cukup, yang menentukan jumlah maksimum chip yang dapat Anda akses dalam projectGoogle Cloud . - Pastikan project Anda memiliki cukup kuota TPU untuk:
- Kuota VM TPU
- Kuota Alamat IP
- Kuota Hyperdisk Balanced
- Izin project pengguna
- Jika Anda menggunakan GKE dengan XPK, lihat Izin Konsol Cloud di akun pengguna atau layanan untuk mengetahui izin yang diperlukan untuk menjalankan XPK.
Menyediakan TPU v5litepod-16
Buat VM TPU:
gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT_NAME} \ --spot
Verifikasi bahwa TPU berada dalam status
ACTIVE
:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Saat TPU menjadi aktif (ACTIVE
), Anda akan melihat output yang mirip dengan:
createTime: '2025-02-28T21:16:08.053492925Z'
name: projects/my-project/locations/zone/queuedResources/tpu-name-zone
spot: {}
state:
state: ACTIVE
tpu:
nodeSpec:
- node:
acceleratorType: v5litepod-16
networkConfig:
enableExternalIps: true
network: default
queuedResource: projects/19672137403/locations/zone/queuedResources/qr-name
runtimeVersion: v2-alpha-tpuv5-lite
schedulingConfig: {}
my-service-account@your-project-id.iam.gserviceaccount.com
email: 19672137854-compute@developer.iam.gserviceaccount.com
shieldedInstanceConfig: {}
nodeId: tpu-name
parent: projects/19672137403/locations/zone
Penginstalan
Instal pytorch-tpu/transformers
fork dari
Hugging Face Transformers dan dependensinya. Tutorial ini diuji dengan
versi dependensi berikut:
torch
: kompatibel dengan 2.6.0torch_xla[tpu]
: kompatibel dengan 2.6.0jax
: 0.4.38jaxlib
: 0.4.38
Menginstal software dan dependensi framework
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git sudo apt install python3.10-venv python -m venv /home/$USER/venv/ source ~/venv/bin/activate cd transformers pip3 install --user -e . pip3 install datasets pip3 install evaluate pip3 install scikit-learn pip3 install accelerate pip install torch~=2.6.0 torch_xla[tpu]~=2.6.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html pip install jax==0.4.38 jaxlib==0.4.38 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'
Setelah penginstalan selesai, Anda akan melihat output yang mirip dengan:
Collecting jax==0.4.38
Downloading jax-0.4.38-py3-none-any.whl (2.1 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.1/2.1 MB 18.0 MB/s eta 0:00:00
Collecting jaxlib==0.4.38
Downloading jaxlib-0.4.38-cp310-cp310-manylinux2014_x86_64.whl (85.0 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 85.0/85.0 MB 10.1 MB/s eta 0:00:00
Successfully installed jax-0.4.38 jaxlib-0.4.38 ml-dtypes-0.5.1 opt-einsum-3.4.0
Collecting opt-einsum
Downloading opt_einsum-3.4.0-py3-none-any.whl (71 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 71.9/71.9 KB 186.4 kB/s eta 0:00:00
Requirement already satisfied: numpy>=1.24 in /home/your-username/.local/lib/python3.10/site-packages (from jax==0.4.38) (2.2.3)
Requirement already satisfied: scipy>=1.10 in /home/your-username/.local/lib/python3.10/site-packages (from jax==0.4.38) (1.15.2)
Successfully installed jax-0.4.38 jaxlib-0.4.38 ml-dtypes-0.5.1 opt-einsum-3.4.0
Successfully installed jax-0.4.38 jaxlib-0.4.38 ml-dtypes-0.5.1 opt-einsum-3.4.0
Collecting ml-dtypes>=0.2.0
Downloading ml_dtypes-0.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.7 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.7/4.7 MB 13.8 MB/s eta 0:00:00
Installing collected packages: opt-einsum, ml-dtypes, jaxlib, jax
Successfully installed jax-0.4.38 jaxlib-0.4.38 ml-dtypes-0.5.1 opt-einsum-3.4.0
Menyiapkan konfigurasi model
Perintah pelatihan di bagian berikutnya, Jalankan model, menggunakan dua file konfigurasi JSON untuk menentukan parameter model dan konfigurasi FSDP (Fully Sharded Data Parallel). Sharding FSDP digunakan untuk bobot model agar sesuai dengan ukuran batch yang lebih besar selama pelatihan. Saat melatih dengan model yang lebih kecil, Anda mungkin cukup menggunakan paralelisme data dan mereplikasi bobot di setiap perangkat. Untuk mengetahui informasi selengkapnya tentang cara melakukan shard tensor di seluruh perangkat di PyTorch/XLA, lihat Panduan Pengguna SPMD PyTorch/XLA.
Perintah ini membuat file konfigurasi parameter model untuk Llama3-8B. Untuk model lainnya, temukan konfigurasi di Hugging Face. Misalnya, lihat konfigurasi Llama2-7B.
cat > llama-config.json <<EOF { "architectures": [ "LlamaForCausalLM" ], "attention_bias": false, "attention_dropout": 0.0, "bos_token_id": 128000, "eos_token_id": 128001, "hidden_act": "silu", "hidden_size": 4096, "initializer_range": 0.02, "intermediate_size": 14336, "max_position_embeddings": 8192, "model_type": "llama", "num_attention_heads": 32, "num_hidden_layers": 32, "num_key_value_heads": 8, "pretraining_tp": 1, "rms_norm_eps": 1e-05, "rope_scaling": null, "rope_theta": 500000.0, "tie_word_embeddings": false, "torch_dtype": "bfloat16", "transformers_version": "4.40.0.dev0", "use_cache": false, "vocab_size": 128256 } EOF
Buat file konfigurasi FSDP:
cat > fsdp-config.json <<EOF { "fsdp_transformer_layer_cls_to_wrap": [ "LlamaDecoderLayer" ], "xla": true, "xla_fsdp_v2": true, "xla_fsdp_grad_ckpt": true } EOF
Untuk informasi selengkapnya tentang FSDP, lihat FSDPv2.
Upload file konfigurasi ke VM TPU menggunakan perintah berikut:
ssh-add ~/.ssh/google_compute_engine #Setup SSH Key in the SSH agent. gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json $TPU_NAME:. \ --worker=all \ --project=${PROJECT_ID} \ --zone=${ZONE}
Perintah ini akan menghasilkan output yang mirip dengan:
Using scp batch size of 4.Attempting to SCP into 1 nodes with a total of 4 workers. SCP: Attempting to connect to worker 0... SCP: Attempting to connect to worker 1... SCP: Attempting to connect to worker 2... SCP: Attempting to connect to worker 3... llama-config.json 100% 707 4.1KB/s 00:00 llama-config.json 100% 707 4.0KB/s 00:00 llama-config.json 100% 707 4.1KB/s 00:00 llama-config.json 100% 707 4.1KB/s 00:00 fsdp-config.json 100% 156 0.9KB/s 00:00 fsdp-config.json 100% 156 0.9KB/s 00:00 fsdp-config.json 100% 156 0.9KB/s 00:00 fsdp-config.json 100% 156 0.9KB/s 00:00
Menjalankan model
Dengan menggunakan file konfigurasi yang Anda buat di bagian sebelumnya, jalankan skrip run_clm.py
untuk melatih model Llama 3 8B pada set data WikiText. Skrip pelatihan memerlukan waktu sekitar 10 menit untuk dijalankan di TPU v5litepod-16.
Buat token Hugging Face baru jika Anda belum memilikinya:
- Klik Profil Anda > Setelan > Token Akses.
- Pilih New Token.
- Tentukan Nama pilihan Anda dan Peran minimal Baca.
- Pilih Buat token.
Gunakan token Hugging Face untuk login ke Hugging Face dari VM TPU menggunakan perintah berikut.
Ganti variabel token
huggingface-cli login
dengan variabel yang dibuat dari Hugging Face di langkah sebelumnya:gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' pip install -U "huggingface_hub[cli]" export PATH="/home/$USER/.local/bin/:$PATH" huggingface-cli login --token hf_abcxyzEFg'
Perintah ini akan membuat Anda login ke Hugging Face dan menampilkan token aktif saat ini.
Jalankan pelatihan model:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' source ~/venv/bin/activate export PJRT_DEVICE=TPU export XLA_USE_SPMD=1 export ENABLE_PJRT_COMPATIBILITY=true # Optional variables for debugging: export XLA_IR_DEBUG=1 export XLA_HLO_DEBUG=1 export PROFILE_EPOCH=0 export PROFILE_STEP=3 export PROFILE_DURATION_MS=100000 # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path export PROFILE_LOGDIR=your-bucket/profile_path python3 transformers/examples/pytorch/language-modeling/run_clm.py \ --dataset_name wikitext \ --dataset_config_name wikitext-2-raw-v1 \ --per_device_train_batch_size 16 \ --do_train \ --output_dir /home/$USER/tmp/test-clm \ --overwrite_output_dir \ --config_name /home/$USER/llama-config.json \ --cache_dir /home/$USER/cache \ --tokenizer_name meta-llama/Meta-Llama-3-8B \ --block_size 8192 \ --optim adafactor \ --save_strategy no \ --logging_strategy no \ --fsdp "full_shard" \ --fsdp_config /home/$USER/fsdp-config.json \ --torch_dtype bfloat16 \ --dataloader_drop_last yes \ --flash_attention \ --max_steps 20'
Langkah pelatihan memerlukan waktu sekitar 10 menit. Menjelang akhir pelatihan, Anda akan melihat pesan yang mirip dengan:
[INFO|trainer.py:2053] 2025-03-18 22:05:02,536 >> ***** Running training *****
[INFO|trainer.py:2054] 2025-03-18 22:05:02,536 >> Num examples = 272
[INFO|trainer.py:2055] 2025-03-18 22:05:02,536 >> Num Epochs = 2
[INFO|trainer.py:2056] 2025-03-18 22:05:02,536 >> Instantaneous batch size per device = 16
[INFO|trainer.py:2059] 2025-03-18 22:05:02,536 >> Total train batch size (w. parallel, distributed & accumulation) = 16
[INFO|trainer.py:2060] 2025-03-18 22:05:02,536 >> Gradient Accumulation steps = 1
[INFO|trainer.py:2061] 2025-03-18 22:05:02,536 >> Total optimization steps = 20
[INFO|trainer.py:2062] 2025-03-18 22:05:02,537 >> Number of trainable parameters = 8,030,261,248
0%| | 0/20 [00:00<?, ?it/s][INFO|trainer.py:2143] 2025-03-18 22:05:02,540 >> Profiling server started: <_XLAC.profiler.ProfilerServer object at 0x7f01bdcb6770>
5%|▌ | 1/20 [00:07<02:29, 7.86s/it]/home/your-username/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1810: UserWarning: For backward hooks to be called, module output should be a Tensor or a tuple of Tensors but received <class 'transformers.modeling_outputs.CausalLMOutputWithPast'>
5%|▌ | 1/20 [00:07<02:29, 7.89s/it]Compilation at Step 0, time: 213.83555555343628
/home/your-username/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1810:
10%|█ | 2/20 [03:43<38:57, 129.87s/it]Compilation at Step 0, time: 213.12156581878662
/home/your-username/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:"
10%|█ | 2/20 [03:40<38:29, 128.30s/it]Compilation at Step 1, time: 224.5414960384369
15%|█▌ | 3/20 [07:22<48:31, 171.24s/it]Compilation at Step 1, time: 226.23664164543152
15%|█▌ | 3/20 [07:26<48:56, 172.73s/it]Compilation at Step 1, time: 226.9180543422699
Compilation at Step 1, time: 224.3874273300171
20%|██ | 4/20 [07:23<27:45, 104.10s/it]Starting to trace for 100000 ms. Remaining attempt(s): 2
2025-03-18 22:12:32.104419: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 847930 nanoseconds and will start immediately.
Starting to trace for 100000 ms. Remaining attempt(s): 2
2025-03-18 22:12:32.104373: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 763960 nanoseconds and will start immediately.
Starting to trace for 100000 ms. Remaining attempt(s): 2
2025-03-18 22:12:32.104538: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 854020 nanoseconds and will start immediately.
2025-03-18 22:12:32.104347: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 761070 nanoseconds and will start immediately.
Starting to trace for 100000 ms. Remaining attempt(s): 2
85%|████████▌ | 17/20 [07:55<00:06, 2.26s/it]Compilation at Step -1, time: 3.676558494567871
Compilation at Step -1, time: 3.447533130645752
Compilation at Step -1, time: 3.5890843868255615
Compilation at Step -1, time: 3.4956483840942383
100%|██████████| 20/20 [11:39<00:00, 35.14s/it][INFO|trainer.py:2350] 2025-03-18 22:16:42,476 >>
Training completed. Do not forget to share your model on huggingface.co/models =)
100%|██████████| 20/20 [11:47<00:00, 35.23s/it][INFO|trainer.py:2350] 2025-03-18 22:16:43,239 >>
Training completed. Do not forget to share your model on huggingface.co/models =)
Pembersihan
Setelah pelatihan selesai, gunakan langkah berikut untuk menghapus VM TPU dan resource yang diantrekan. Tindakan ini akan menghentikan penagihan untuk penggunaan VM TPU Anda.
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --force \ --async