Melatih Llama 3 menggunakan PyTorch di TPU v5e

Tutorial ini menjelaskan cara melatih model Llama-3-8B menggunakan PyTorch/XLA di TPU v5e menggunakan set data WikiText. Lihat Meta-Llama-3-8B untuk mengetahui detail model.

Model Llama-3-8B dihosting di platform Hugging Face.

Ada dua versi Meta-Llama-3-8B, satu untuk digunakan dengan Transformers dan satu lagi dengan codebase Llama 3 asli. Tutorial ini menggunakan versi Transformers karena:

  • Terintegrasi dengan lancar dengan ekosistem Hugging Face: Hal ini memudahkan penyesuaian model, penggunaan pipeline bawaan, dan akses ke koleksi set data dan alat yang luas.

  • Memungkinkan fleksibilitas dan penyesuaian: Versi Transformers menawarkan opsi penyesuaian dan fleksibilitas yang signifikan untuk menyesuaikan dan men-deploy model.

  • Memberikan dukungan komunitas: Komunitas Hugging Face menyediakan dokumentasi, tutorial, dan dukungan yang ekstensif untuk menggunakan model Transformer.

Untuk informasi selengkapnya tentang Transformer, lihat dokumentasi Hugging Face Transformers.

Untuk mengakses dan menggunakan model Meta-Llama-3-8B, termasuk mendownload bobot dan tokenizer-nya, Anda memerlukan token akses pengguna Hugging Face. Token ini menyediakan:

  • Autentikasi dan Otorisasi: Token akses berfungsi sebagai kredensial, memungkinkan server Hugging Face memberikan otorisasi akses Anda ke resource model. Hal ini memastikan bahwa hanya pengguna yang diotorisasi yang dapat mendownload dan menggunakan model.

  • Keamanan: Hugging Face menggunakan token akses untuk melindungi modelnya dan mencegah akses atau penyalahgunaan yang tidak sah.

Untuk informasi tentang cara membuat dan menggunakan token akses untuk tutorial ini, lihat Menjalankan model. Untuk informasi yang lebih komprehensif tentang cara membuat dan menggunakan token akses, lihat dokumentasi Hugging Face tentang token akses pengguna.

Anda juga memerlukan izin untuk mengakses model Llama 3 8B di Hugging Face. Untuk mendapatkan izin tersebut, buka model Meta-Llama-3-8B di Hugging Face dan minta akses.

Bersiap untuk menyediakan TPU v5litepod-16

Tutorial ini diuji menggunakan variabel lingkungan Cloud TPU berikut. Anda dapat menggunakan variabel lain untuk menyediakan TPU, asalkan jenis akselerator, zona, dan versi runtime kompatibel. Misalnya, dalam tutorial ini, europe-west4-b digunakan sebagai zona di seluruh bagian. Anda dapat menggunakan zona lain yang mendukung versi TPU (jenis akselerator) yang Anda jalankan (v5litepod-16 dalam tutorial ini).

Tetapkan variabel lingkungan VM TPU berikut.

   export TPU_NAME=queued-resources-node-id #The TPU name is the queued resource node-id
   export PROJECT_ID=your-project-id
   export ACCELERATOR_TYPE=v5litepod-16
   export ZONE=europe-west4-b
   export RUNTIME_VERSION=v2-alpha-tpuv5-lite
   export QUEUED_RESOURCE_ID=queued-resource-id
   export VALID_UNTIL_DURATION=1d

Jika Anda memiliki akses ke model Meta-Llama-3-8B di Hugging Face, siapkan lingkungan TPU untuk menjalankan tutorial.

  1. Ikuti panduan Menyiapkan lingkungan Cloud TPU untuk memastikan Anda memiliki akses yang sesuai untuk menggunakan Cloud TPU.

  2. Buat identitas layanan untuk VM TPU.

    gcloud alpha compute tpus tpu-vm service-identity create --zone=zone
  3. Buat akun layanan TPU dan berikan akses ke layanan Google Cloud .

    Akun layanan memungkinkan layanan Google Cloud TPU mengakses layanan Google Cloudlainnya. Akun layanan yang dikelola pengguna direkomendasikan. Anda dapat membuat akun layanan dari Konsol Google Cloud atau melalui perintah gcloud.

    Buat akun layanan menggunakan alat command line gcloud:

    gcloud iam service-accounts create your-service-account-name \
    --description="your-sa-description" \
    --display-name="your-sa-display-name"
    export SERVICE_ACCOUNT_NAME=your-service-account-name

    Buat akun layanan dari konsol Google Cloud:

    1. Buka halaman Akun Layanan di konsol Google Cloud.
    2. Klik Create service account.
    3. Masukkan nama akun layanan.
    4. (Opsional) Masukkan deskripsi akun layanan.
    5. Klik Buat dan lanjutkan.
    6. Pilih peran yang ingin Anda berikan ke akun layanan.
    7. Klik Lanjutkan.
    8. (Opsional) Tentukan pengguna atau grup yang dapat mengelola akun layanan.
    9. Klik Selesai untuk menyelesaikan pembuatan akun layanan.

    Setelah membuat akun layanan, ikuti langkah-langkah berikut untuk memberikan peran akun layanan.

    Peran berikut diperlukan:

    • TPU Admin: Diperlukan untuk membuat TPU
    • Storage Admin: Diperlukan untuk mengakses Cloud Storage
    • Logs Writer
    • Monitoring Metric Writer: Diperlukan untuk menulis metrik ke Cloud Monitoring

    Administrator harus memberi Anda roles/resourcemanager.projectIamAdmin agar Anda dapat menetapkan peran IAM kepada pengguna. Pengguna dengan peran Project IAM Admin roles/resourcemanager.projectIamAdmin juga dapat memberikan peran ini.

    Gunakan perintah gcloud berikut untuk menambahkan peran akun layanan:

    gcloud projects add-iam-policy-binding ${PROJECT_ID} \
       --member serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \
       --role roles/tpu.admin
    gcloud projects add-iam-policy-binding ${PROJECT_ID} \
       --member serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \
       --role roles/storage.admin
    gcloud projects add-iam-policy-binding ${PROJECT_ID} \
       --member serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \
       --role roles/logging.logWriter
    gcloud projects add-iam-policy-binding ${PROJECT_ID} \
       --member serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \
       --role roles/monitoring.metricWriter

    Anda juga dapat menetapkan peran menggunakan konsol Google Cloud.

    Dari konsol Google Cloud, pilih peran berikut:

    1. Pilih akun layanan Anda, lalu klik Tambahkan Akun Utama.
    2. Di kolom New Principals, masukkan alamat email akun layanan Anda.
    3. Di drop-down Select a role, telusuri peran (misalnya, Storage Admin) dan pilih peran tersebut.
    4. Klik Simpan.
  4. Lakukan autentikasi dengan Google Cloud dan konfigurasikan project dan zona default untuk Google Cloud CLI.

    gcloud auth login
    gcloud config set project PROJECT_ID
    gcloud config set compute/zone ZONE

Kapasitas aman

Jika Anda sudah siap untuk mendapatkan kapasitas TPU, tinjau halaman kuota untuk mempelajari sistem Kuota Cloud. Jika ada pertanyaan tambahan tentang cara mendapatkan kapasitas, hubungi tim penjualan atau akun Cloud TPU Anda.

Menyediakan lingkungan Cloud TPU

Anda dapat menyediakan VM TPU dengan GKE, dengan GKE dan XPK, atau sebagai resource dalam antrean.

Prasyarat

  • Tutorial ini telah diuji dengan Python 3.10 atau yang lebih baru.
  • Pastikan project Anda memiliki kuota TPUS_PER_TPU_FAMILY yang cukup, yang menentukan jumlah maksimum chip yang dapat Anda akses dalam projectGoogle Cloud .
  • Pastikan project Anda memiliki cukup kuota TPU untuk:
    • Kuota VM TPU
    • Kuota Alamat IP
    • Kuota Hyperdisk Balanced
  • Izin project pengguna

Menyediakan TPU v5litepod-16

  1. Buat VM TPU:

    gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID}   \
    --node-id=${TPU_NAME}  \
    --project=${PROJECT_ID}   \
    --zone=${ZONE}   \
    --accelerator-type=${ACCELERATOR_TYPE}   \
    --runtime-version=${RUNTIME_VERSION}   \
    --service-account=${SERVICE_ACCOUNT_NAME}   \
    --spot
  2. Verifikasi bahwa TPU berada dalam status ACTIVE:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
    --project=${PROJECT_ID} \
    --zone=${ZONE}

Saat TPU menjadi aktif (ACTIVE), Anda akan melihat output yang mirip dengan:

createTime: '2025-02-28T21:16:08.053492925Z'
name: projects/my-project/locations/zone/queuedResources/tpu-name-zone
spot: {}
state:
  state: ACTIVE
tpu:
  nodeSpec:
  - node:
      acceleratorType: v5litepod-16
      networkConfig:
        enableExternalIps: true
        network: default
      queuedResource: projects/19672137403/locations/zone/queuedResources/qr-name
      runtimeVersion: v2-alpha-tpuv5-lite
      schedulingConfig: {}
      my-service-account@your-project-id.iam.gserviceaccount.com
        email: 19672137854-compute@developer.iam.gserviceaccount.com
      shieldedInstanceConfig: {}
    nodeId: tpu-name
    parent: projects/19672137403/locations/zone

Penginstalan

Instal pytorch-tpu/transformers fork dari Hugging Face Transformers dan dependensinya. Tutorial ini diuji dengan versi dependensi berikut:

  • torch: kompatibel dengan 2.6.0
  • torch_xla[tpu]: kompatibel dengan 2.6.0
  • jax: 0.4.38
  • jaxlib: 0.4.38

Menginstal software dan dependensi framework

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
    sudo apt install python3.10-venv
    python -m venv /home/$USER/venv/
    source ~/venv/bin/activate
    cd transformers
    pip3 install --user -e .
    pip3 install datasets
    pip3 install evaluate
    pip3 install scikit-learn
    pip3 install accelerate
    pip install torch~=2.6.0 torch_xla[tpu]~=2.6.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
    pip install jax==0.4.38 jaxlib==0.4.38 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'

Setelah penginstalan selesai, Anda akan melihat output yang mirip dengan:

Collecting jax==0.4.38
  Downloading jax-0.4.38-py3-none-any.whl (2.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.1/2.1 MB 18.0 MB/s eta 0:00:00
Collecting jaxlib==0.4.38
  Downloading jaxlib-0.4.38-cp310-cp310-manylinux2014_x86_64.whl (85.0 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 85.0/85.0 MB 10.1 MB/s eta 0:00:00
Successfully installed jax-0.4.38 jaxlib-0.4.38 ml-dtypes-0.5.1 opt-einsum-3.4.0
Collecting opt-einsum
  Downloading opt_einsum-3.4.0-py3-none-any.whl (71 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 71.9/71.9 KB 186.4 kB/s eta 0:00:00
Requirement already satisfied: numpy>=1.24 in /home/your-username/.local/lib/python3.10/site-packages (from jax==0.4.38) (2.2.3)
Requirement already satisfied: scipy>=1.10 in /home/your-username/.local/lib/python3.10/site-packages (from jax==0.4.38) (1.15.2)
Successfully installed jax-0.4.38 jaxlib-0.4.38 ml-dtypes-0.5.1 opt-einsum-3.4.0
Successfully installed jax-0.4.38 jaxlib-0.4.38 ml-dtypes-0.5.1 opt-einsum-3.4.0
Collecting ml-dtypes>=0.2.0
  Downloading ml_dtypes-0.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.7 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.7/4.7 MB 13.8 MB/s eta 0:00:00
Installing collected packages: opt-einsum, ml-dtypes, jaxlib, jax
Successfully installed jax-0.4.38 jaxlib-0.4.38 ml-dtypes-0.5.1 opt-einsum-3.4.0

Menyiapkan konfigurasi model

Perintah pelatihan di bagian berikutnya, Jalankan model, menggunakan dua file konfigurasi JSON untuk menentukan parameter model dan konfigurasi FSDP (Fully Sharded Data Parallel). Sharding FSDP digunakan untuk bobot model agar sesuai dengan ukuran batch yang lebih besar selama pelatihan. Saat melatih dengan model yang lebih kecil, Anda mungkin cukup menggunakan paralelisme data dan mereplikasi bobot di setiap perangkat. Untuk mengetahui informasi selengkapnya tentang cara melakukan shard tensor di seluruh perangkat di PyTorch/XLA, lihat Panduan Pengguna SPMD PyTorch/XLA.

  1. Perintah ini membuat file konfigurasi parameter model untuk Llama3-8B. Untuk model lainnya, temukan konfigurasi di Hugging Face. Misalnya, lihat konfigurasi Llama2-7B.

    cat > llama-config.json <<EOF
    {
        "architectures": [
            "LlamaForCausalLM"
        ],
        "attention_bias": false,
        "attention_dropout": 0.0,
        "bos_token_id": 128000,
        "eos_token_id": 128001,
        "hidden_act": "silu",
        "hidden_size": 4096,
        "initializer_range": 0.02,
        "intermediate_size": 14336,
        "max_position_embeddings": 8192,
        "model_type": "llama",
        "num_attention_heads": 32,
        "num_hidden_layers": 32,
        "num_key_value_heads": 8,
        "pretraining_tp": 1,
        "rms_norm_eps": 1e-05,
        "rope_scaling": null,
        "rope_theta": 500000.0,
        "tie_word_embeddings": false,
        "torch_dtype": "bfloat16",
        "transformers_version": "4.40.0.dev0",
        "use_cache": false,
        "vocab_size": 128256
    }
    EOF
    
  2. Buat file konfigurasi FSDP:

    cat > fsdp-config.json <<EOF
    {
        "fsdp_transformer_layer_cls_to_wrap": [
            "LlamaDecoderLayer"
        ],
        "xla": true,
        "xla_fsdp_v2": true,
        "xla_fsdp_grad_ckpt": true
    }
    EOF
    

    Untuk informasi selengkapnya tentang FSDP, lihat FSDPv2.

  3. Upload file konfigurasi ke VM TPU menggunakan perintah berikut:

     ssh-add ~/.ssh/google_compute_engine #Setup SSH Key in the SSH agent.
     gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json $TPU_NAME:. \
        --worker=all \
        --project=${PROJECT_ID} \
        --zone=${ZONE}

    Perintah ini akan menghasilkan output yang mirip dengan:

    Using scp batch size of 4.Attempting to SCP into 1 nodes with a total of 4 workers.
    SCP: Attempting to connect to worker 0...
    SCP: Attempting to connect to worker 1...
    SCP: Attempting to connect to worker 2...
    SCP: Attempting to connect to worker 3...
    llama-config.json                    100%  707     4.1KB/s   00:00
    llama-config.json                    100%  707     4.0KB/s   00:00
    llama-config.json                    100%  707     4.1KB/s   00:00
    llama-config.json                    100%  707     4.1KB/s   00:00
    fsdp-config.json                     100%  156     0.9KB/s   00:00
    fsdp-config.json                     100%  156     0.9KB/s   00:00
    fsdp-config.json                     100%  156     0.9KB/s   00:00
    fsdp-config.json                     100%  156     0.9KB/s   00:00

Menjalankan model

Dengan menggunakan file konfigurasi yang Anda buat di bagian sebelumnya, jalankan skrip run_clm.py untuk melatih model Llama 3 8B pada set data WikiText. Skrip pelatihan memerlukan waktu sekitar 10 menit untuk dijalankan di TPU v5litepod-16.

  1. Buat token Hugging Face baru jika Anda belum memilikinya:

    1. Klik Profil Anda > Setelan > Token Akses.
    2. Pilih New Token.
    3. Tentukan Nama pilihan Anda dan Peran minimal Baca.
    4. Pilih Buat token.
  2. Gunakan token Hugging Face untuk login ke Hugging Face dari VM TPU menggunakan perintah berikut.

    Ganti variabel token huggingface-cli login dengan variabel yang dibuat dari Hugging Face di langkah sebelumnya:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='
    pip install -U "huggingface_hub[cli]"
    export PATH="/home/$USER/.local/bin/:$PATH"
    huggingface-cli login --token hf_abcxyzEFg'

    Perintah ini akan membuat Anda login ke Hugging Face dan menampilkan token aktif saat ini.

  3. Jalankan pelatihan model:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} \
        --zone=${ZONE} \
        --worker=all \
        --command='
        source ~/venv/bin/activate
        export PJRT_DEVICE=TPU
        export XLA_USE_SPMD=1
        export ENABLE_PJRT_COMPATIBILITY=true
            # Optional variables for debugging:
        export XLA_IR_DEBUG=1
        export XLA_HLO_DEBUG=1
        export PROFILE_EPOCH=0
        export PROFILE_STEP=3
        export PROFILE_DURATION_MS=100000
            # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path
        export PROFILE_LOGDIR=your-bucket/profile_path
        python3 transformers/examples/pytorch/language-modeling/run_clm.py \
        --dataset_name wikitext \
        --dataset_config_name wikitext-2-raw-v1 \
        --per_device_train_batch_size 16 \
        --do_train \
        --output_dir /home/$USER/tmp/test-clm \
        --overwrite_output_dir \
        --config_name /home/$USER/llama-config.json \
        --cache_dir /home/$USER/cache \
        --tokenizer_name meta-llama/Meta-Llama-3-8B \
        --block_size 8192 \
        --optim adafactor \
        --save_strategy no \
        --logging_strategy no \
        --fsdp "full_shard" \
        --fsdp_config /home/$USER/fsdp-config.json \
        --torch_dtype bfloat16 \
        --dataloader_drop_last yes \
        --flash_attention \
        --max_steps 20'

Langkah pelatihan memerlukan waktu sekitar 10 menit. Menjelang akhir pelatihan, Anda akan melihat pesan yang mirip dengan:

[INFO|trainer.py:2053] 2025-03-18 22:05:02,536 >> ***** Running training *****
[INFO|trainer.py:2054] 2025-03-18 22:05:02,536 >>   Num examples = 272
[INFO|trainer.py:2055] 2025-03-18 22:05:02,536 >>   Num Epochs = 2
[INFO|trainer.py:2056] 2025-03-18 22:05:02,536 >>   Instantaneous batch size per device = 16
[INFO|trainer.py:2059] 2025-03-18 22:05:02,536 >>   Total train batch size (w. parallel, distributed & accumulation) = 16
[INFO|trainer.py:2060] 2025-03-18 22:05:02,536 >>   Gradient Accumulation steps = 1
[INFO|trainer.py:2061] 2025-03-18 22:05:02,536 >>   Total optimization steps = 20
[INFO|trainer.py:2062] 2025-03-18 22:05:02,537 >>   Number of trainable parameters = 8,030,261,248
  0%|          | 0/20 [00:00<?, ?it/s][INFO|trainer.py:2143] 2025-03-18 22:05:02,540 >> Profiling server started: <_XLAC.profiler.ProfilerServer object at 0x7f01bdcb6770>
  5%|         | 1/20 [00:07<02:29,  7.86s/it]/home/your-username/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1810: UserWarning: For backward hooks to be called, module output should be a Tensor or a tuple of Tensors but received <class 'transformers.modeling_outputs.CausalLMOutputWithPast'>
  5%|         | 1/20 [00:07<02:29,  7.89s/it]Compilation at Step 0, time: 213.83555555343628
/home/your-username/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1810:
 10%|         | 2/20 [03:43<38:57, 129.87s/it]Compilation at Step 0, time: 213.12156581878662
/home/your-username/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:"
 10%|█         | 2/20 [03:40<38:29, 128.30s/it]Compilation at Step 1, time: 224.5414960384369
 15%|█▌        | 3/20 [07:22<48:31, 171.24s/it]Compilation at Step 1, time: 226.23664164543152
 15%|█▌        | 3/20 [07:26<48:56, 172.73s/it]Compilation at Step 1, time: 226.9180543422699
Compilation at Step 1, time: 224.3874273300171
 20%|██        | 4/20 [07:23<27:45, 104.10s/it]Starting to trace for 100000 ms. Remaining attempt(s): 2
2025-03-18 22:12:32.104419: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 847930 nanoseconds and will start immediately.
Starting to trace for 100000 ms. Remaining attempt(s): 2
2025-03-18 22:12:32.104373: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 763960 nanoseconds and will start immediately.
Starting to trace for 100000 ms. Remaining attempt(s): 2
2025-03-18 22:12:32.104538: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 854020 nanoseconds and will start immediately.
2025-03-18 22:12:32.104347: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 761070 nanoseconds and will start immediately.
Starting to trace for 100000 ms. Remaining attempt(s): 2
 85%|████████▌ | 17/20 [07:55<00:06,  2.26s/it]Compilation at Step -1, time: 3.676558494567871
Compilation at Step -1, time: 3.447533130645752
Compilation at Step -1, time: 3.5890843868255615
Compilation at Step -1, time: 3.4956483840942383
100%|██████████| 20/20 [11:39<00:00, 35.14s/it][INFO|trainer.py:2350] 2025-03-18 22:16:42,476 >>
Training completed. Do not forget to share your model on huggingface.co/models =)

100%|██████████| 20/20 [11:47<00:00, 35.23s/it][INFO|trainer.py:2350] 2025-03-18 22:16:43,239 >>

Training completed. Do not forget to share your model on huggingface.co/models =)

Pembersihan

Setelah pelatihan selesai, gunakan langkah berikut untuk menghapus VM TPU dan resource yang diantrekan. Tindakan ini akan menghentikan penagihan untuk penggunaan VM TPU Anda.

  gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --force \
       --async