Mempertahankan progres pelatihan menggunakan Autocheckpoint

Secara historis, saat VM TPU memerlukan pemeliharaan, prosedur akan segera dimulai, tanpa memberi waktu bagi pengguna untuk melakukan tindakan yang mempertahankan progres seperti menyimpan titik pemeriksaan. Hal ini ditunjukkan dalam Gambar 1(a).

Diagram yang menunjukkan dampak pemeliharaan host dengan dan tanpa pemeriksaan otomatis

Gambar 1. Ilustrasi fitur Autocheckpoint: (a) Tanpa Autocheckpoint, progres pelatihan dari checkpoint terakhir akan hilang saat ada peristiwa pemeliharaan mendatang. (b) Dengan Autocheckpoint, progres pelatihan sejak checkpoint terakhir dapat dipertahankan saat ada peristiwa pemeliharaan mendatang.

Anda dapat menggunakan Autocheckpoint (Gambar 1(b)) untuk mempertahankan progres pelatihan dengan mengonfigurasi kode untuk menyimpan titik pemeriksaan yang tidak terjadwal saat peristiwa pemeliharaan terjadi. Saat peristiwa pemeliharaan terjadi, progres sejak titik kontrol terakhir akan otomatis disimpan. Fitur ini berfungsi pada satu irisan dan Multislice.

Fitur Autocheckpoint berfungsi dengan framework yang dapat menangkap sinyal SIGTERM, lalu menyimpan checkpoint. Framework yang didukung meliputi:

Menggunakan Titik pemeriksaan otomatis

Fitur Autocheckpoint dinonaktifkan secara default. Saat membuat TPU atau meminta resource dalam antrean, Anda dapat mengaktifkan Autocheckpoint dengan menambahkan flag --autocheckpoint-enabled saat menyediakan TPU. Dengan mengaktifkan fitur ini, Cloud TPU akan melakukan langkah-langkah berikut setelah menerima notifikasi tentang peristiwa pemeliharaan:

  1. Menangkap sinyal SIGTERM yang dikirim ke proses menggunakan perangkat TPU
  2. Tunggu hingga proses keluar, atau 5 menit telah berlalu, mana saja yang lebih dulu
  3. Melakukan pemeliharaan pada slice yang terpengaruh

Infrastruktur yang digunakan oleh Autocheckpoint tidak bergantung pada framework ML. Framework ML apa pun dapat mendukung Autocheckpoint jika dapat menangkap sinyal SIGTERM dan memulai proses pembuatan checkpoint.

Dalam kode aplikasi, Anda perlu mengaktifkan kemampuan Autocheckpoint yang disediakan oleh framework ML. Misalnya, di Pax, hal ini berarti mengaktifkan flag command line saat meluncurkan pelatihan. Untuk informasi selengkapnya, lihat panduan memulai Autocheckpoint dengan Pax. Di balik layar, framework menyimpan titik pemeriksaan yang tidak terjadwal saat sinyal SIGTERM diterima, dan VM TPU yang terpengaruh akan menjalani pemeliharaan saat TPU tidak lagi digunakan.

Panduan memulai: Titik henti sementara otomatis dengan MaxText

MaxText adalah LLM berperforma tinggi, skalabilitas arbitrer, open source, dan telah diuji dengan baik yang ditulis dalam Python/JAX murni yang menargetkan Cloud TPU. MaxText berisi semua penyiapan yang diperlukan untuk menggunakan fitur Autocheckpoint.

File README MaxText menjelaskan dua cara untuk menjalankan MaxText dalam skala besar:

Saat menggunakan multihost_runner.py, aktifkan Autocheckpoint dengan menetapkan flag autocheckpoint-enabled saat menyediakan resource yang diantrekan.

Saat menggunakan multihost_job.py, aktifkan Autocheckpoint dengan menentukan flag command line ENABLE_AUTOCHECKPOINT=true saat meluncurkan tugas.

Panduan memulai: Titik periksa otomatis dengan Pax di satu slice

Bagian ini memberikan contoh cara menyiapkan dan menggunakan Autocheckpoint dengan Pax pada satu slice. Dengan penyiapan yang sesuai:

  • Titik pemeriksaan akan disimpan saat peristiwa pemeliharaan terjadi.
  • Cloud TPU akan melakukan pemeliharaan pada VM TPU yang terpengaruh setelah checkpoint disimpan.
  • Setelah Cloud TPU menyelesaikan pemeliharaan, Anda dapat menggunakan VM TPU seperti biasa.
  1. Gunakan flag autocheckpoint-enabled saat membuat VM TPU atau meminta resource dalam antrean.

    Contoh:

    export PROJECT=your-gcp-project-name
    export ZONE=zone-you-want-to-use
    export NODE_ID=your-node-id
    export ACCELERATOR_TYPE=your-accelerator-type
    gcloud config set project $PROJECT
    gcloud config set compute/zone $ZONE
    gcloud alpha compute tpus tpu-vm create $NODE_ID \
        --accelerator-type $ACCELERATOR_TYPE \
        --version tpu-ubuntu2204-base \
        --autocheckpoint-enabled
  2. Hubungkan ke TPU menggunakan SSH:

    gcloud compute tpus tpu-vm ssh $NODE_ID 
    
  3. Menginstal Pax di satu slice

    Fitur Autocheckpoint berfungsi di Pax versi 1.1.0 dan yang lebih baru. Di VM TPU, instal jax[tpu] dan paxml terbaru:

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  4. Luncurkan pelatihan dengan konfigurasi yang sesuai.

    Contoh berikut menunjukkan cara mengonfigurasi model LmCloudSpmd2B untuk menyimpan titik pemeriksaan yang dipicu oleh Autocheckpoint ke bucket Cloud Storage. Ganti your-storage-bucket dengan nama bucket yang ada, atau buat bucket baru.

    export JOB_LOG_DIR=gs://your-storage-bucket
    
    { python3 .local/lib/python3.10/site-packages/paxml/main.py \
        --jax_fully_async_checkpoint=1 \
        --exit_after_ondemand_checkpoint=1 \
        --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \
        --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt

    Perhatikan dua flag yang diteruskan ke perintah:

    • jax_fully_async_checkpoint: Jika flag ini diaktifkan, orbax.checkpoint.AsyncCheckpointer akan digunakan. Class AsyncCheckpointer otomatis menyimpan titik pemeriksaan saat skrip pelatihan menerima sinyal SIGTERM.
    • exit_after_ondemand_checkpoint: Dengan mengaktifkan tanda ini, proses TPU akan keluar setelah Autocheckpoint berhasil disimpan, yang memicu pemeliharaan untuk segera dilakukan. Jika Anda tidak menggunakan flag ini, pelatihan akan dilanjutkan setelah titik pemeriksaan disimpan dan Cloud TPU akan menunggu waktu tunggu habis (5 menit) sebelum melakukan pemeliharaan yang diperlukan.

Panduan memulai: Titik pemeriksaan otomatis dengan Pax di Multislice

Titik periksa otomatis tidak hanya berfungsi untuk satu slice, tetapi juga untuk Multislice. Bagian ini menjelaskan langkah-langkah yang diperlukan untuk menggunakan Autocheckpoint dengan Multislice.

  1. Tentukan Autocheckpoint selama pembuatan resource yang diantrekan.

    Lingkungan Multislice hanya dapat disediakan melalui permintaan resource yang diantrekan. Serupa dengan kasus satu slice, gunakan flag autocheckpoint-enabled dalam panggilan untuk membuat resource yang diantrekan.

    export QR_ID=your-qr-id
    export NODE_COUNT=your-node-count
    export ACCELERATOR_TYPE=your-accelerator-type
    
    gcloud alpha compute tpus queued-resources create $QR_ID \
        --node-count $NODE_COUNT \
        --accelerator-type $ACCELERATOR_TYPE \
        --runtime-version tpu-ubuntu2204-base \
        --autocheckpoint-enabled

    Untuk informasi selengkapnya tentang semua opsi yang tersedia, lihat Panduan pengguna multislice. Saat permintaan resource yang diantrekan dibuat dan dalam status ACTIVE, ikuti langkah-langkah berikutnya untuk menjalankan Pax dengan Autocheckpoint.

  2. Instal jax[tpu] dan paxml terbaru di semua VM TPU di lingkungan Multislice Anda.

    gcloud compute tpus queued-resources ssh $QR_ID \
        --node=all \
        --worker=all \
        --batch-size=your-batch-size \
        --command="pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"

    Tetapkan flag --batch-size ke jumlah koneksi serentak yang harus dibuat dengan pekerja TPU. Untuk informasi selengkapnya tentang cara memilih ukuran batch untuk beban kerja Multislice, lihat Mengoptimalkan pelatihan.

  3. Konfigurasikan model LmCloudSpmd2B untuk Autocheckpoint saat berlatih di lingkungan Multislice. Sebelum menjalankan skrip pelatihan, tetapkan DCN_MESH_SHAPE ke [2, 1, 1] seperti yang ditunjukkan pada contoh berikut:

    @experiment_registry.register
    class LmCloudSpmd2B(LmCloudSpmd):
    
        """SPMD model with 2B params.
    
        Global batch size = 2 * 2 * 1 * 32 = 128
        """
        PERCORE_BATCH_SIZE = 8
    
        NUM_LAYERS = 18
        MODEL_DIMS = 3072
        HIDDEN_DIMS = MODEL_DIMS * 4
    
        CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING
        ICI_MESH_SHAPE = [1, 4, 1]
        DCN_MESH_SHAPE = [2, 1, 1]
  4. Untuk menetapkan checkpoint lebih sering, tetapkan task_p.train.save_interval_steps dan task_p.train.save_max_to_keep, seperti yang ditunjukkan dalam contoh berikut:

    @experiment_registry.register
    class LmCloudSpmd2BLimitSteps(LmCloudSpmd2B):
    """SPMD model with 2B params and limited steps.
    
    Global batch size = 2 * 2 * 1 * 32 = 128
    """
    
    def task(self) -> pax_fiddle.Config[tasks_lib.SingleTask]:
        task_p = super().task()
        task_p.train.save_interval_steps = 50
        task_p.train.save_max_to_keep = 5
        return task_p
    
  5. Luncurkan pelatihan dengan menjalankan perintah berikut untuk setiap host. Ganti your-storage-bucket dengan nama bucket yang ada, atau buat bucket baru.

    export TF_CPP_MIN_LOG_LEVEL=0
    export JOB_LOG_DIR=gs://your-storage-bucket
    
    { python3 .local/lib/python3.10/site-packages/paxml/main.py \
        --jax_fully_async_checkpoint=1 \
        --num_hosts=2 \
        --host_idx=host-index \
        --server_addr=worker0-node0-ip-address \
        --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \
        --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt

    Saat meluncurkan pelatihan, selain flag command line yang dibahas dalam kasus satu slice, tiga flag lainnya diperlukan:

    • num_hosts: jumlah total host. Dalam hal ini, nilainya adalah 2.
    • host_idx: indeks host yang meluncurkan pelatihan. Nilai ini bervariasi dari 0 hingga N-1 dengan N adalah jumlah total host.
    • server_addr: alamat IP pekerja 0 dari node 0, dengan port yang tidak digunakan (misalnya, 8476). Untuk menemukan informasi ini, gunakan hostname -i pada pekerja 0 dari node 0.

Titik pemeriksaan otomatis dengan Orbax

Fitur Autocheckpoint tidak terbatas pada MaxText atau Pax. Setiap framework yang dapat menangkap sinyal SIGTERM dan memulai proses checkpointing berfungsi dengan infrastruktur yang disediakan oleh Autocheckpoint. Orbax, namespace yang menyediakan library utilitas umum untuk pengguna JAX, menyediakan kemampuan ini.

Seperti yang dijelaskan dalam dokumentasi Orbax, kemampuan ini diaktifkan secara default untuk pengguna orbax.checkpoint.CheckpointManager. Metode save yang dipanggil setelah setiap langkah akan otomatis memeriksa apakah peristiwa pemeliharaan akan segera terjadi, dan jika ya, akan menyimpan titik pemeriksaan meskipun nomor langkah bukan kelipatan save_interval_steps. Dokumentasi GitHub juga mengilustrasikan cara membuat pelatihan keluar setelah menyimpan Autocheckpoint, dengan modifikasi pada kode pengguna.