TPU 슬라이스에서 JAX 코드 실행
이 문서의 명령어를 실행하기 전 계정 및 Cloud TPU 프로젝트 설정의 안내를 따르도록 유의하세요.
단일 TPU 보드에서 JAX 코드를 실행한 후에는 TPU 슬라이스에서 실행하여 코드를 수직 확장할 수 있습니다. TPU 슬라이스는 전용 고속 네트워크 연결을 통해 서로 연결된 여러 TPU 보드입니다. 이 문서에서는 TPU 슬라이스에서 JAX 코드 실행에 관한 내용을 소개합니다. 더 자세한 내용은 다중 호스트 및 다중 프로세스 환경에서 JAX 사용을 참조하세요.
Cloud TPU 슬라이스 만들기
다음과 같이 몇 가지 환경 변수를 만듭니다.
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=europe-west4-b export ACCELERATOR_TYPE=v5litepod-32 export RUNTIME_VERSION=v2-alpha-tpuv5-lite
환경 변수 설명
변수 설명 PROJECT_ID
Google Cloud 프로젝트 ID입니다. 기존 프로젝트를 사용하거나 새 프로젝트를 만듭니다. TPU_NAME
TPU의 이름입니다. ZONE
TPU VM을 만들 영역입니다. 지원되는 영역에 대한 자세한 내용은 TPU 리전 및 영역을 참조하세요. ACCELERATOR_TYPE
가속기 유형은 만들려는 Cloud TPU의 버전과 크기를 지정합니다. 각 TPU 버전에서 지원되는 가속기 유형에 대한 자세한 내용은 TPU 버전을 참조하세요. RUNTIME_VERSION
Cloud TPU 소프트웨어 버전입니다. gcloud
명령어를 사용하여 TPU 슬라이스를 만듭니다. 예를 들어 v5litepod-32 슬라이스를 만들려면 다음 명령어를 사용합니다.$ gcloud compute tpus tpu-vm create ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --accelerator-type=${ACCELERATOR_TYPE} \ --version=${RUNTIME_VERSION}
슬라이스에 JAX 설치
TPU 슬라이스를 만든 후 TPU 슬라이스의 모든 호스트에서 JAX를 설치해야 합니다. --worker=all
및 --commamnd
파라미터를 사용해서 gcloud compute tpus tpu-vm ssh
명령어를 사용하여 이 작업을 수행할 수 있습니다.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --worker=all \ --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
슬라이스에서 JAX 코드 실행
TPU 슬라이스에서 JAX 코드를 실행하려면 TPU 슬라이스의 각 호스트에서 코드를 실행해야 합니다. jax.device_count()
호출은 슬라이스의 각 호스트에서 호출될 때까지 응답을 중지합니다. 다음 예시에서는 TPU 슬라이스에서 JAX 계산을 실행하는 방법을 보여줍니다.
코드 준비
344.0.0 이상 gcloud
버전이 필요합니다(scp
명령어).
gcloud --version
을 사용하여 gcloud
버전을 확인하고 필요하면 gcloud components upgrade
를 실행합니다.
다음 코드를 사용하여 example.py
라는 파일을 만듭니다.
import jax
# The total number of TPU cores in the slice
device_count = jax.device_count()
# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()
# The psum is performed over all mapped devices across the slice
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)
# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
print('global device count:', jax.device_count())
print('local device count:', jax.local_device_count())
print('pmap result:', r)
example.py
를 슬라이스의 모든 TPU 워커 VM에 복사
$ gcloud compute tpus tpu-vm scp ./example.py ${TPU_NAME}: \ --worker=all \ --zone=${ZONE} \ --project=${PROJECT_ID}
이전에 scp
명령어를 사용하지 않았으면 다음과 비슷한 오류가 표시될 수 있습니다.
ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH
agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try
again.
오류를 해결하려면 오류 메시지에 표시된 대로 ssh-add
명령어를 실행하고 명령어를 다시 실행합니다.
슬라이스에서 코드 실행
모든 VM에서 example.py
프로그램을 실행합니다.
$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --worker=all \ --command="python3 ./example.py"
출력(v5litepod-32 슬라이스에서 생성됨):
global device count: 32
local device count: 4
pmap result: [32. 32. 32. 32.]
삭제
TPU VM 사용이 완료되었으면 다음 단계에 따라 리소스를 삭제하세요.
Cloud TPU 및 Compute Engine 리소스를 삭제합니다.
$ gcloud compute tpus tpu-vm delete ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID}
gcloud compute tpus execution-groups list
를 실행하여 리소스가 삭제되었는지 확인합니다. 삭제하는 데 몇 분 정도 걸릴 수 있습니다. 다음 명령어의 출력에는 이 튜토리얼에서 만든 리소스가 포함되어서는 안 됩니다.$ gcloud compute tpus tpu-vm list --zone=${ZONE} \ --project=${PROJECT_ID}