Docker 컨테이너에서 TPU 워크로드 실행

Docker 컨테이너를 사용하면 코드와 필요한 모든 종속 항목을 배포 가능한 패키지 하나로 결합하여 애플리케이션을 쉽게 구성할 수 있습니다. TPU VM 내에서 Docker 컨테이너를 실행하여 Cloud TPU 애플리케이션 구성 및 공유를 단순화할 수 있습니다. 이 문서에서는 Cloud TPU에서 지원하는 각 ML 프레임워크에서 Docker 컨테이너를 설정하는 방법을 설명합니다.

Docker 컨테이너에서 PyTorch 모델 학습

TPU 기기

  1. Cloud TPU VM을 만듭니다.

    gcloud compute tpus tpu-vm create your-tpu-name \
    --zone=europe-west4-a \
    --accelerator-type=v2-8 \
    --version=tpu-ubuntu2204-base
  2. SSH를 사용하여 TPU VM에 연결

    gcloud compute tpus tpu-vm ssh your-tpu-name \
    --zone=europe-west4-a
  3. Google Cloud 사용자에게 Artifact Registry 리더 역할이 부여되었는지 확인합니다. 자세한 내용은 Artifact Registry 역할 부여를 참조하세요.

  4. 야간 PyTorch/XLA 이미지를 사용하여 TPU VM에서 컨테이너를 시작합니다.

    sudo docker run --net=host -ti --rm --name your-container-name --privileged us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm_cxx11 \
    bash
  5. TPU 런타임 구성

    PyTorch/XLA 런타임은 PJRT와 XRT라는 두 가지 옵션이 있습니다. XRT를 사용할 이유가 없으면 PJRT를 사용하는 것이 좋습니다. 다양한 런타임 구성에 대한 자세한 내용은 PJRT 런타임 문서를 참조하세요.

    PJRT

    export PJRT_DEVICE=TPU

    XRT

    export XRT_TPU_CONFIG="localservice;0;localhost:51011"
  6. PyTorch XLA 저장소 클론

    git clone --recursive https://github.com/pytorch/xla.git
  7. ResNet50 학습

    python3 xla/test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1

학습 스크립트가 완료되면 리소스를 삭제해야 합니다.

  1. exit를 입력하여 Docker 컨테이너를 종료합니다.
  2. exit를 입력하여 TPU VM을 종료합니다.
  3. TPU VM 삭제

    gcloud compute tpus tpu-vm delete your-tpu-name --zone=europe-west4-a

TPU 슬라이스

TPU 슬라이스에서 PyTorch 코드를 실행하는 경우 모든 TPU 워커에서 코드를 동시에 실행해야 합니다. 이를 위한 한 가지 방법은 gcloud compute tpus tpu-vm ssh 명령어를 --worker=all--command 플래그와 함께 사용하는 것입니다. 다음 절차에서는 각 TPU 워커를 더 쉽게 설정할 수 있도록 Docker 이미지를 만드는 방법을 보여줍니다.

  1. TPU VM 만들기

    gcloud compute tpus tpu-vm create your-tpu-name \
    --zone=us-central2-b \
    --accelerator-type=v4-32 \
    --version=tpu-ubuntu2204-base
  2. Docker 그룹에 현재 사용자를 추가합니다.

    gcloud compute tpus tpu-vm ssh your-tpu-name \
    --zone=us-central2-b \
    --worker=all \
    --command='sudo usermod -a -G docker $USER'
  3. PyTorch XLA 저장소 클론

    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=us-central2-b \
    --command="git clone --recursive https://github.com/pytorch/xla.git"
  4. 모든 TPU 워커의 컨테이너에서 학습 스크립트를 실행합니다.

    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=us-central2-b \
    --command="docker run --rm --privileged --net=host  -v ~/xla:/xla -e PJRT_DEVICE=TPU us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm_cxx11 python /xla/test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1"

    Docker 명령어 플래그:

    • --rm은 프로세스가 종료된 후 컨테이너를 삭제합니다.
    • --privileged는 TPU 기기를 컨테이너에 노출합니다.
    • --net=host는 모든 컨테이너 포트를 TPU VM에 바인딩하여 포드의 호스트 간 통신을 허용합니다.
    • -e는 환경 변수를 설정합니다.

학습 스크립트가 완료되면 리소스를 삭제해야 합니다.

다음 명령어를 사용하여 TPU VM을 삭제합니다.

gcloud compute tpus tpu-vm delete your-tpu-name \
--zone=us-central2-b

Docker 컨테이너에서 JAX 모델 학습

TPU 기기

  1. TPU VM 만들기

    gcloud compute tpus tpu-vm create your-tpu-name \
    --zone=europe-west4-a \
    --accelerator-type=v2-8 \
    --version=tpu-ubuntu2204-base
  2. SSH를 사용하여 TPU VM에 연결

    gcloud compute tpus tpu-vm ssh your-tpu-name  --zone=europe-west4-a
  3. TPU VM에서 Docker 데몬 시작

    sudo systemctl start docker
  4. Docker 컨테이너 시작

    sudo docker run --net=host -ti --rm --name your-container-name \
    --privileged us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm_cxx11 \
    bash
  5. JAX 설치

    pip install jax[tpu]
  6. FLAX 설치

    pip install --upgrade clu
    git clone https://github.com/google/flax.git
    pip install --user -e flax
  7. tensorflowtensorflow-dataset 패키지 설치

    pip install tensorflow
    pip install tensorflow-datasets
  8. FLAX MNIST 학습 스크립트 실행

    cd flax/examples/mnist
    python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=5

학습 스크립트가 완료되면 리소스를 삭제해야 합니다.

  1. exit를 입력하여 Docker 컨테이너를 종료합니다.
  2. exit를 입력하여 TPU VM을 종료합니다.
  3. TPU VM 삭제

    gcloud compute tpus tpu-vm delete your-tpu-name --zone=europe-west4-a

TPU 슬라이스

TPU 슬라이스에서 JAX 코드를 실행하는 경우 모든 TPU 워커에서 JAX 코드를 동시에 실행해야 합니다. 이를 위한 한 가지 방법은 gcloud compute tpus tpu-vm ssh 명령어를 --worker=all--command 플래그와 함께 사용하는 것입니다. 다음 절차에서는 각 TPU 워커를 더 쉽게 설정할 수 있도록 Docker 이미지를 만드는 방법을 보여줍니다.

  1. 현재 디렉터리에 이름이 Dockerfile인 파일을 만들고 다음 텍스트를 붙여넣습니다.

    FROM python:3.10
    RUN pip install jax[tpu]
    RUN pip install --upgrade clu
    RUN git clone https://github.com/google/flax.git
    RUN pip install --user -e flax
    RUN pip install tensorflow
    RUN pip install tensorflow-datasets
    WORKDIR ./flax/examples/mnist
  2. Artifact Registry 준비

    gcloud artifacts repositories create your-repo \
    --repository-format=docker \
    --location=europe-west4 --description="Docker repository" \
    --project=your-project
    
    gcloud artifacts repositories list \
    --project=your-project
    
    gcloud auth configure-docker europe-west4-docker.pkg.dev
  3. Docker 이미지 빌드

    docker build -t your-image-name .
  4. Artifact Registry에 푸시하기 전에 Docker 이미지에 태그를 추가합니다. Artifact Registry 작업에 대한 자세한 내용은 컨테이너 이미지로 작업하기를 참조하세요.

    docker tag your-image-name europe-west4-docker.pkg.dev/your-project/your-repo/your-image-name:your-tag
  5. Artifact Registry에 Docker 이미지 푸시

    docker push europe-west4-docker.pkg.dev/your-project/your-repo/your-image-name:your-tag
  6. TPU VM 만들기

    gcloud compute tpus tpu-vm create your-tpu-name \
    --zone=europe-west4-a \
    --accelerator-type=v2-8 \
    --version=tpu-ubuntu2204-base
  7. 모든 TPU 워커에서 Artifact Registry의 Docker 이미지 가져오기

    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=europe-west4-a \
    --command='sudo usermod -a -G docker ${USER}'
    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=europe-west4-a \
    --command="gcloud auth configure-docker europe-west4-docker.pkg.dev --quiet"
    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=europe-west4-a \
    --command="docker pull europe-west4-docker.pkg.dev/your-project/your-repo/your-image-name:your-tag"
  8. 모든 TPU 워커에서 컨테이너 실행

    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=europe-west4-a \
    --command="docker run -ti -d --privileged --net=host --name your-container-name europe-west4-docker.pkg.dev/your-project/your-repo/your-image-name:your-tag bash"
  9. 모든 TPU 워커에서 학습 스크립트 실행

    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=europe-west4-a \
    --command="docker exec --privileged your-container-name python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=5"

학습 스크립트가 완료되면 리소스를 삭제해야 합니다.

  1. 모든 작업자에서 컨테이너 종료

    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=europe-west4-a \
    --command="docker kill your-container-name"
  2. TPU VM 삭제

    gcloud compute tpus tpu-vm delete your-tpu-name \
    --zone=europe-west4-a

JAX 안정화 스택을 사용하여 Docker 컨테이너에서 JAX 모델 학습

JAX 안정화 스택 기본 이미지를 사용하여 MaxTextMaxDiffusion Docker 이미지를 빌드할 수 있습니다.

JAX 안정화 스택은 JAX를 orbax, flax, optax, libtpu.so와 같은 핵심 패키지와 번들로 묶어 MaxText 및 MaxDiffusion에 일관된 환경을 제공합니다. 이러한 라이브러리는 호환성을 보장하고 MaxText 및 MaxDiffusion을 안정적으로 빌드하고 실행할 수 있는 기반을 제공하도록 테스트되었습니다. 이렇게 하면 패키지 버전 간의 비호환성으로 인한 잠재적인 충돌을 방지할 수 있습니다.

JAX 안정화 스택에는 TPU 프로그램의 컴파일, 실행, ICI 네트워크 구성을 담당하는 핵심 라이브러리인 libtpu.so의 정식 출시 버전이 포함되어 있습니다. libtpu 출시 버전은 기존에 JAX에서 사용되던 나이틀리 빌드를 대체하며, HLO/StableHLO IR에서 PJRT 수준의 검증 테스트를 통해 TPU에서의 일관된 XLA 연산 기능을 보장합니다.

JAX 안정화 스택으로 MaxText 및 MaxDiffusion Docker 이미지를 빌드하려면 docker_build_dependency_image.sh 스크립트를 실행할 때 MODE 변수를 stable_stack으로 설정하고, BASEIMAGE 변수는 사용하려는 기본 이미지로 설정합니다.

docker_build_dependency_image.shMaxDiffusion GitHub 저장소MaxText GitHub 저장소 모두에 포함되어 있습니다. 사용하려는 저장소를 클론한 후 해당 저장소에서 docker_build_dependency_image.sh 스크립트를 실행해 Docker 이미지를 빌드합니다.

git clone https://github.com/AI-Hypercomputer/maxdiffusion.git
git clone https://github.com/AI-Hypercomputer/maxtext.git

다음 명령어는 us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1을 기본 이미지로 사용하여 MaxText 및 MaxDiffusion용 Docker 이미지를 생성합니다.

sudo bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1

사용 가능한 JAX 안정화 스택 기본 이미지 목록은 Artifact Registry의 JAX 안정화 스택 이미지를 참조하세요.

다음 단계