JAX를 사용하여 Cloud TPU VM에서 계산 실행
이 문서에서는 JAX 및 Cloud TPU 작업에 대해 간략히 안내합니다.
시작하기 전에
이 문서의 명령어를 실행하기 전에 Google Cloud계정을 만들고 Google Cloud CLI를 설치하고 gcloud
명령어를 구성해야 합니다. 자세한 내용은 Cloud TPU 환경 설정을 참조하세요.
gcloud
를 사용하여 Cloud TPU VM 만들기
명령어를 더 쉽게 사용할 수 있도록 몇 가지 환경 변수를 정의합니다.
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-east5-a export ACCELERATOR_TYPE=v5litepod-8 export RUNTIME_VERSION=v2-alpha-tpuv5-lite
환경 변수 설명
변수 설명 PROJECT_ID
Google Cloud 프로젝트 ID입니다. 기존 프로젝트를 사용하거나 새 프로젝트를 만듭니다. TPU_NAME
TPU의 이름입니다. ZONE
TPU VM을 만들 영역입니다. 지원되는 영역에 대한 자세한 내용은 TPU 리전 및 영역을 참조하세요. ACCELERATOR_TYPE
가속기 유형은 만들려는 Cloud TPU의 버전과 크기를 지정합니다. 각 TPU 버전에서 지원되는 가속기 유형에 대한 자세한 내용은 TPU 버전을 참조하세요. RUNTIME_VERSION
Cloud TPU 소프트웨어 버전입니다. Cloud Shell 또는 Google Cloud CLI가 설치된 컴퓨터 터미널에서 다음 명령어를 실행하여 TPU VM을 만듭니다.
$ gcloud compute tpus tpu-vm create $TPU_NAME \ --project=$PROJECT_ID \ --zone=$ZONE \ --accelerator-type=$ACCELERATOR_TYPE \ --version=$RUNTIME_VERSION
Cloud TPU VM에 연결
다음 명령어를 사용하여 SSH를 통해 TPU VM에 연결합니다.
$ gcloud compute tpus tpu-vm ssh $TPU_NAME \ --project=$PROJECT_ID \ --zone=$ZONE
SSH를 사용하여 TPU VM에 연결할 수 없는 경우 TPU VM에 외부 IP 주소가 없기 때문일 수 있습니다. 외부 IP 주소가 없는 TPU VM에 액세스하려면 공개 IP 주소가 없는 TPU VM에 연결의 안내를 따르세요.
Cloud TPU VM에 JAX 설치
(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
시스템 확인
JAX가 TPU에 액세스할 수 있고 기본 작업을 실행할 수 있는지 확인합니다.
Python 3 인터프리터 시작:
(vm)$ python3
>>> import jax
사용 가능한 TPU 코어 수 표시:
>>> jax.device_count()
TPU 코어 수가 표시됩니다. 표시되는 코어 수는 사용 중인 TPU 버전에 따라 다릅니다. 자세한 내용은 TPU 버전을 참조하세요.
계산 수행
>>> jax.numpy.add(1, 1)
numpy add의 결과가 표시됩니다.
명령어에서 출력합니다.
Array(2, dtype=int32, weak_type=True)
Python 인터프리터 종료
>>> exit()
TPU VM에서 JAX 코드 실행
이제 원하는 JAX 코드를 실행할 수 있습니다. Flax 예시는 JAX에서 표준 ML 모드 실행을 시작할 수 있는 훌륭한 장소입니다. 예를 들어 기본 MNIST 컨볼루션 네트워크를 학습시키려면 다음 안내를 따르세요.
Flax 예시 종속 항목 설치
(vm)$ pip install --upgrade clu (vm)$ pip install tensorflow (vm)$ pip install tensorflow_datasets
Flax를 설치합니다.
(vm)$ git clone https://github.com/google/flax.git (vm)$ pip install --user flax
Flax MNIST 학습 스크립트를 실행합니다.
(vm)$ cd flax/examples/mnist (vm)$ python3 main.py --workdir=/tmp/mnist \ --config=configs/default.py \ --config.learning_rate=0.05 \ --config.num_epochs=5
스크립트가 데이터 세트를 다운로드하고 학습을 시작합니다. 스크립트 출력은 다음과 같아야 합니다.
I0214 18:00:50.660087 140369022753856 train.py:146] epoch: 1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88 I0214 18:00:52.015867 140369022753856 train.py:146] epoch: 2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72 I0214 18:00:53.377511 140369022753856 train.py:146] epoch: 3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04 I0214 18:00:54.727168 140369022753856 train.py:146] epoch: 4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15 I0214 18:00:56.082807 140369022753856 train.py:146] epoch: 5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18
삭제
이 페이지에서 사용한 리소스 비용이 Google Cloud 계정에 청구되지 않도록 하려면 다음 단계를 수행합니다.
TPU VM 사용이 완료되었으면 다음 단계에 따라 리소스를 삭제하세요.
Cloud TPU 인스턴스에서 아직 연결을 해제하지 않았으면 연결을 해제합니다.
(vm)$ exit
프롬프트가 username@projectname으로 바뀌면 Cloud Shell에 있는 것입니다.
Cloud TPU를 삭제합니다.
$ gcloud compute tpus tpu-vm delete $TPU_NAME \ --project=$PROJECT_ID \ --zone=$ZONE
다음 명령어를 실행하여 리소스가 삭제되었는지 확인합니다. TPU가 더 이상 나열되지 않았는지 확인합니다. 삭제하는 데 몇 분 정도 걸릴 수 있습니다.
$ gcloud compute tpus tpu-vm list \ --zone=$ZONE
성능 참고사항
다음은 JAX에서 특히 TPU 사용과 관련된 몇 가지 중요한 세부정보입니다.
패딩
TPU에서 성능이 느려지는 가장 일반적인 원인 중 하나는 의도치 않은 패딩의 도입입니다.
- Cloud TPU의 배열은 타일로 나누어집니다. 여기에는 차원 중 하나를 8의 배수로 패딩하고 다른 차원을 128의 배수로 패딩하는 작업이 수반됩니다.
- 행렬 곱셈 단위는 패딩 요구를 최소화하는 대규모 행렬의 쌍에서 효과가 가장 높습니다.
bfloat16 dtype
기본적으로 JAX에서 TPU의 행렬 곱셈에는 float32 누적의 bfloat16이 사용됩니다. 이것은 관련된 jax.numpy
함수 호출에서 정밀도 인수로 제어될 수 있습니다(matmul, dot, einsum 등). 특히 다음 옵션이 지원됩니다.
precision=jax.lax.Precision.DEFAULT
: 혼합 bfloat16 정밀도(가장 빠름) 사용precision=jax.lax.Precision.HIGH
: 여러 MXU 패스를 사용하여 정밀도 향상precision=jax.lax.Precision.HIGHEST
: 훨씬 더 많은 MXU 패스를 사용하여 전체 float32 정밀도 달성
JAX는 bfloat16 dtype도 추가하여 배열을 bfloat16
으로 명시적으로 변환하는 데 사용할 수 있으며 예를 들면 jax.numpy.array(x, dtype=jax.numpy.bfloat16)
입니다.
다음 단계
Cloud TPU에 대한 자세한 내용은 다음을 참조하세요.