자동 체크포인트를 사용해 학습 진행 상태 보존

지금까지 TPU VM에 유지보수가 필요할 때는 사용자가 체크포인트 저장과 같은 진행 상태 보존 작업을 수행할 시간 없이 절차가 즉시 시작되었습니다. 이 내용은 그림 1(a)에 나와 있습니다.

자동 체크포인트를 사용하는 경우와 사용하지 않는 경우의 호스트 유지보수 영향을 보여주는 다이어그램

그림 1. 자동 체크포인트 기능 그림: (a) 자동 체크포인트가 없으면 예정된 유지보수 이벤트가 있을 때 마지막 체크포인트의 학습 진행 상태가 손실됩니다. (b) 자동 체크포인트를 사용하면 예정된 유지보수 이벤트가 있을 때 마지막 체크포인트 이후의 학습 진행 상태를 보존할 수 있습니다.

자동 체크포인트(그림 1(b))를 사용하면 유지보수 이벤트가 발생할 때 예약되지 않은 체크포인트를 저장하도록 코드를 구성하여 학습 진행 상태를 보존할 수 있습니다. 유지보수 이벤트가 발생하면 마지막 체크포인트 이후의 진행 상태가 자동으로 저장됩니다. 이 기능은 단일 슬라이스와 멀티슬라이스 모두에서 작동합니다.

자동 체크포인트 기능은 SIGTERM 신호를 캡처하고 이후에 체크포인트를 저장하는 프레임워크에서 작동합니다. 지원되는 프레임워크는 다음과 같습니다.

자동 체크포인트 사용

자동 체크포인트 기능은 기본적으로 사용 중지되어 있습니다. TPU를 만들거나 또는 큐에 추가된 리소스를 요청할 때는 TPU를 프로비저닝할 때 --autocheckpoint-enabled 플래그를 추가하여 자동 체크포인트를 사용 설정할 수 있습니다. 이 기능을 사용 설정하면 유지보수 이벤트 알림이 수신되었을 때 Cloud TPU가 다음 단계를 수행합니다.

  1. TPU 기기를 사용해서 전송된 SIGTERM 신호를 진행 상태에 캡처합니다.
  2. 프로세스가 종료되거나 5분이 경과될 때까지 기다립니다.
  3. 영향을 받는 리소스에서 유지보수를 수행합니다.

자동 체크포인트에 사용되는 인프라는 ML 프레임워크에 독립적입니다. SIGTERM 신호를 캡처하고 체크포인트 지정 프로세스를 시작할 수 있는 한 어떤 ML 프레임워크라도 자동 체크포인트를 지원할 수 있습니다.

애플리케이션 코드에서 ML 프레임워크에서 제공된 자동 체크포인트 기능을 사용 설정해야 합니다. 예를 들어 Pax에서는 학습을 시작할 때 명령줄 플래그를 사용 설정해야 합니다. 자세한 내용은 Pax를 사용한 자동 체크포인트 빠른 시작을 참조하세요. 이 과정 중에 프레임워크는 SIGTERM 신호가 수신될 때 예약되지 않은 체크포인트를 저장하고 TPU가 더 이상 사용 중이 아닐 때 영향을 받는 TPU VM에 유지보수가 진행됩니다.

빠른 시작: MaxText를 사용한 자동 체크포인트

MaxText는 Cloud TPU를 대상으로 순수 Python/JAX로 작성되어 임의로 확장 가능하고 테스트를 철저하게 거친 고성능 오픈소스 LLM입니다. MaxText에는 자동 체크포인트 기능을 사용하는 데 필요한 모든 설정이 포함됩니다.

MaxText README 파일에서는 규모에 맞게 MaxText를 실행하기 위한 두 가지 방법에 대해 설명합니다.

multihost_runner.py를 사용하는 경우 큐에 추가된 리소스를 프로비저닝할 때 autocheckpoint-enabled 플래그를 설정하여 자동 체크포인트를 사용 설정합니다.

multihost_job.py를 사용할 경우에는 작업을 실행할 때 ENABLE_AUTOCHECKPOINT=true 명령줄 플래그를 지정하여 자동 체크포인트를 사용 설정합니다.

빠른 시작: 단일 슬라이스에서 Pax를 사용한 자동 체크포인트

이 섹션에서는 단일 슬라이스에서 Pax와 함께 자동 체크포인트를 설정하고 사용하는 방법의 예시를 보여줍니다. 다음과 같이 되도록 적절한 설정을 사용합니다.

  • 유지보수 이벤트가 발생할 때 체크포인트가 저장됩니다.
  • 체크포인트 저장 후 영향을 받는 TPU VM에서 Cloud TPU가 유지보수를 수행합니다.
  • Cloud TPU가 유지보수를 완료하면 일반적으로 TPU VM을 사용할 수 있습니다.
  1. TPU VM을 만들거나 큐에 추가된 리소스를 요청할 때 autocheckpoint-enabled 플래그를 사용합니다.

    예를 들면 다음과 같습니다.

    1. 환경 변수를 설정합니다.

      export PROJECT_ID=your-project-id
      export TPU_NAME=your-tpu-name
      export ZONE=zone-you-want-to-use
      export ACCELERATOR_TYPE=your-accelerator-type
      export RUNTIME_VERSION=tpu-ubuntu2204-base

      환경 변수 설명

      변수 설명
      PROJECT_ID Google Cloud 프로젝트 ID입니다. 기존 프로젝트를 사용하거나 새 프로젝트를 만듭니다.
      TPU_NAME TPU의 이름입니다.
      ZONE TPU VM을 만들 영역입니다. 지원되는 영역에 대한 자세한 내용은 TPU 리전 및 영역을 참조하세요.
      ACCELERATOR_TYPE 가속기 유형은 만들려는 Cloud TPU의 버전과 크기를 지정합니다. 각 TPU 버전에서 지원되는 가속기 유형에 대한 자세한 내용은 TPU 버전을 참조하세요.
      RUNTIME_VERSION Cloud TPU 소프트웨어 버전입니다.

    2. 활성 구성에서 프로젝트 ID 및 영역을 설정합니다.

      gcloud config set project $PROJECT_ID
      gcloud config set compute/zone $ZONE
    3. TPU를 만듭니다.

      gcloud alpha compute tpus tpu-vm create $TPU_NAME \
          --accelerator-type $ACCELERATOR_TYPE \
          --version $RUNTIME_VERSION \
          --autocheckpoint-enabled
  2. SSH를 사용하여 TPU에 연결합니다.

    gcloud compute tpus tpu-vm ssh $TPU_NAME
    
  3. 단일 슬라이스에 Pax 설치

    자동 체크포인트 기능은 Pax 버전 1.1.0 이상에서 작동합니다. TPU VM에서 jax[tpu] 및 최신 paxml을 설치합니다.

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  4. LmCloudSpmd2B 모델을 구성합니다. 학습 스크립트를 실행하기 전에 ICI_MESH_SHAPE[1, 8, 1]로 변경합니다.

    @experiment_registry.register
    class LmCloudSpmd2B(LmCloudSpmd):
    
        """SPMD model with 2B params.
    
        Global batch size = 2 * 2 * 1 * 32 = 128
        """
        PERCORE_BATCH_SIZE = 8
    
        NUM_LAYERS = 18
        MODEL_DIMS = 3072
        HIDDEN_DIMS = MODEL_DIMS * 4
    
        CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING
        ICI_MESH_SHAPE = [1, 8, 1]
  5. 적절한 구성으로 학습을 실행합니다.

    다음 예시에서는 자동 체크포인트로 트리거된 체크포인트를 Cloud Storage 버킷에 저장하도록 LmCloudSpmd2B 모델을 구성하는 방법을 보여줍니다. your-storage-bucket을 기존 버킷의 이름으로 바꾸거나 새 버킷을 만듭니다.

    export JOB_LOG_DIR=gs://your-storage-bucket
    
    { python3 .local/lib/python3.10/site-packages/paxml/main.py \
        --jax_fully_async_checkpoint=1 \
        --exit_after_ondemand_checkpoint=1 \
        --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \
        --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt

    명령어에 전달되는 다음 두 가지 플래그에 주의하세요.

    • jax_fully_async_checkpoint: 이 플래그를 설정하면 orbax.checkpoint.AsyncCheckpointer가 사용됩니다. AsyncCheckpointer 클래스는 학습 스크립트에 SIGTERM 신호가 수신될 때 자동으로 체크포인트를 저장합니다.
    • exit_after_ondemand_checkpoint: 이 플래그를 설정하면 자동 체크포인트가 성공적으로 저장된 후 TPU 프로세스가 종료되고, 유지보수가 즉시 수행되도록 트리거됩니다. 이 플래그를 사용하지 않으면 체크포인트가 저장된 후에도 학습이 계속되고 Cloud TPU가 시간 초과(5분)가 발생할 때까지 기다린 후에 필요한 유지보수를 수행합니다.

Orbax에서의 자동 체크포인트

자동 체크포인트 기능은 MaxText 또는 Pax로 제한되지 않습니다. SIGTERM 신호를 캡처하고 체크포인트 지정 프로세스를 시작할 수 있는 모든 프레임워크가 자동 체크포인트로 제공되는 인프라를 지원합니다. JAX 사용자를 위한 일반적인 유틸리티 라이브러리를 제공하는 네임스페이스인 Orbax에서도 이러한 기능이 제공됩니다.

Orbax 문서에 설명된 대로 이러한 기능은 orbax.checkpoint.CheckpointManager 사용자에게 기본적으로 사용 설정되어 있습니다. 모든 단계에서 유지보수 이벤트가 임박했는지 여부를 자동으로 확인한 후 호출되는 save 메서드는 단계 번호가 save_interval_steps의 배수가 아니더라도 체크포인트를 저장합니다. 또한 GitHub 문서에서는 사용자 코드 수정과 함께 자동 체크포인트 저장 후 학습을 종료하는 방법을 설명합니다.