TPU スライスで JAX コードを実行する
このドキュメントのコマンドを実行する前に、アカウントと Cloud TPU プロジェクトを設定するの説明に従ってください。
JAX コードを単一の TPU ボードで実行したら、TPU スライスで実行してコードをスケールアップできます。TPU スライスは、専用の高速ネットワーク接続で互いに接続された複数の TPU ボードです。このドキュメントでは、TPU スライスでの JAX コードの実行についての概要を説明します。詳しくは、マルチホスト環境とマルチプロセス環境での JAX の使用をご覧ください。
Cloud TPU スライスを作成する
いくつかの環境変数を作成します。
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=europe-west4-b export ACCELERATOR_TYPE=v5litepod-32 export RUNTIME_VERSION=v2-alpha-tpuv5-lite
環境変数の説明
変数 説明 PROJECT_ID
実際の Google Cloud のプロジェクト ID。既存のプロジェクトを使用するか、新しいプロジェクトを作成します。 TPU_NAME
TPU の名前。 ZONE
TPU VM を作成するゾーン。サポートされているゾーンの詳細については、TPU のリージョンとゾーンをご覧ください。 ACCELERATOR_TYPE
アクセラレータ タイプでは、作成する Cloud TPU のバージョンとサイズを指定します。TPU の各バージョンでサポートされているアクセラレータ タイプの詳細については、TPU のバージョンをご覧ください。 RUNTIME_VERSION
Cloud TPU ソフトウェアのバージョン。 gcloud
コマンドを使用して TPU スライスを作成します。たとえば、v5litepod-32 スライスを作成するには、次のコマンドを使用します。$ gcloud compute tpus tpu-vm create ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --accelerator-type=${ACCELERATOR_TYPE} \ --version=${RUNTIME_VERSION}
スライスに JAX をインストールする
TPU スライスを作成したら、TPU スライスのすべてのホストに JAX をインストールする必要があります。これを行うには、--worker=all
パラメータと --commamnd
パラメータを使用して gcloud compute tpus tpu-vm ssh
コマンドを実行します。
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --worker=all \ --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
スライスで JAX コードを実行する
TPU スライスで JAX コードを実行するには、TPU スライスの各ホストでコードを実行する必要があります。jax.device_count()
呼び出しは、スライスの各ホストで呼び出されるまで応答しなくなります。次の例は、TPU スライスで JAX 計算を実行する方法を示しています。
コードを準備する
scp
コマンドでは、gcloud
バージョン 344.0.0 以降が必要です。gcloud --version
を使用して gcloud
のバージョンを確認し、必要に応じて gcloud components upgrade
を実行します。
次のコードを含むファイルを example.py
という名前で作成します。
import jax
# The total number of TPU cores in the slice
device_count = jax.device_count()
# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()
# The psum is performed over all mapped devices across the slice
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)
# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
print('global device count:', jax.device_count())
print('local device count:', jax.local_device_count())
print('pmap result:', r)
スライス内のすべての TPU ワーカー VM に example.py
をコピーする
$ gcloud compute tpus tpu-vm scp ./example.py ${TPU_NAME}: \ --worker=all \ --zone=${ZONE} \ --project=${PROJECT_ID}
以前に scp
コマンドを使用したことがない場合、次のようなエラーが表示されることがあります。
ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH
agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try
again.
このエラーを解決するには、エラー メッセージに示されている ssh-add
コマンドを実行してから、コマンドを再実行します。
スライスでコードを実行する
すべての VM で example.py
プログラムを起動します。
$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --worker=all \ --command="python3 ./example.py"
出力(v5litepod-32 スライスで生成)
global device count: 32
local device count: 4
pmap result: [32. 32. 32. 32.]
クリーンアップ
TPU VM の使用を終了したら、次の手順に沿ってリソースをクリーンアップします。
Cloud TPU リソースと Compute Engine リソースを削除します。
$ gcloud compute tpus tpu-vm delete ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID}
gcloud compute tpus execution-groups list
を実行して、リソースが削除されたことを確認します。削除には数分かかることがあります。次のコマンドの出力には、このチュートリアルで作成したリソースが含まれていないはずです。$ gcloud compute tpus tpu-vm list --zone=${ZONE} \ --project=${PROJECT_ID}