TPU 슬라이스에서 JAX 코드 실행

이 문서의 명령어를 실행하기 전 계정 및 Cloud TPU 프로젝트 설정의 안내를 따르도록 유의하세요.

단일 TPU 보드에서 JAX 코드를 실행한 후에는 TPU 슬라이스에서 실행하여 코드를 수직 확장할 수 있습니다. TPU 슬라이스는 전용 고속 네트워크 연결을 통해 서로 연결된 여러 TPU 보드입니다. 이 문서에서는 TPU 슬라이스에서 JAX 코드 실행에 관한 내용을 소개합니다. 더 자세한 내용은 다중 호스트 및 다중 프로세스 환경에서 JAX 사용을 참조하세요.

Cloud TPU 슬라이스 만들기

  1. 다음과 같이 몇 가지 환경 변수를 만듭니다.

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=europe-west4-b
    export ACCELERATOR_TYPE=v5litepod-32
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite

    환경 변수 설명

    변수 설명
    PROJECT_ID Google Cloud 프로젝트 ID입니다. 기존 프로젝트를 사용하거나 새 프로젝트를 만듭니다.
    TPU_NAME TPU의 이름입니다.
    ZONE TPU VM을 만들 영역입니다. 지원되는 영역에 대한 자세한 내용은 TPU 리전 및 영역을 참조하세요.
    ACCELERATOR_TYPE 가속기 유형은 만들려는 Cloud TPU의 버전과 크기를 지정합니다. 각 TPU 버전에서 지원되는 가속기 유형에 대한 자세한 내용은 TPU 버전을 참조하세요.
    RUNTIME_VERSION Cloud TPU 소프트웨어 버전입니다.

  2. gcloud 명령어를 사용하여 TPU 슬라이스를 만듭니다. 예를 들어 v5litepod-32 슬라이스를 만들려면 다음 명령어를 사용합니다.

    $ gcloud compute tpus tpu-vm create ${TPU_NAME}  \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        --accelerator-type=${ACCELERATOR_TYPE}  \
        --version=${RUNTIME_VERSION} 

슬라이스에 JAX 설치

TPU 슬라이스를 만든 후 TPU 슬라이스의 모든 호스트에서 JAX를 설치해야 합니다. --worker=all--commamnd 파라미터를 사용해서 gcloud compute tpus tpu-vm ssh 명령어를 사용하여 이 작업을 수행할 수 있습니다.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --worker=all \
  --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

슬라이스에서 JAX 코드 실행

TPU 슬라이스에서 JAX 코드를 실행하려면 TPU 슬라이스의 각 호스트에서 코드를 실행해야 합니다. jax.device_count() 호출은 슬라이스의 각 호스트에서 호출될 때까지 응답을 중지합니다. 다음 예시에서는 TPU 슬라이스에서 JAX 계산을 실행하는 방법을 보여줍니다.

코드 준비

344.0.0 이상 gcloud 버전이 필요합니다(scp 명령어). gcloud --version을 사용하여 gcloud 버전을 확인하고 필요하면 gcloud components upgrade를 실행합니다.

다음 코드를 사용하여 example.py라는 파일을 만듭니다.


import jax

# The total number of TPU cores in the slice
device_count = jax.device_count()

# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()

# The psum is performed over all mapped devices across the slice
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
    print('global device count:', jax.device_count())
    print('local device count:', jax.local_device_count())
    print('pmap result:', r)

example.py를 슬라이스의 모든 TPU 워커 VM에 복사

$ gcloud compute tpus tpu-vm scp ./example.py ${TPU_NAME}: \
  --worker=all \
  --zone=${ZONE} \
  --project=${PROJECT_ID}

이전에 scp 명령어를 사용하지 않았으면 다음과 비슷한 오류가 표시될 수 있습니다.

ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH
agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try
again.

오류를 해결하려면 오류 메시지에 표시된 대로 ssh-add 명령어를 실행하고 명령어를 다시 실행합니다.

슬라이스에서 코드 실행

모든 VM에서 example.py 프로그램을 실행합니다.

$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --worker=all \
  --command="python3 ./example.py"

출력(v5litepod-32 슬라이스에서 생성됨):

global device count: 32
local device count: 4
pmap result: [32. 32. 32. 32.]

삭제

TPU VM 사용이 완료되었으면 다음 단계에 따라 리소스를 삭제하세요.

  1. Cloud TPU 및 Compute Engine 리소스를 삭제합니다.

    $ gcloud compute tpus tpu-vm delete ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID}
  2. gcloud compute tpus execution-groups list를 실행하여 리소스가 삭제되었는지 확인합니다. 삭제하는 데 몇 분 정도 걸릴 수 있습니다. 다음 명령어의 출력에는 이 튜토리얼에서 만든 리소스가 포함되어서는 안 됩니다.

    $ gcloud compute tpus tpu-vm list --zone=${ZONE} \
    --project=${PROJECT_ID}