TPU Pod スライスでの JAX コードの実行
JAX コードを単一の TPU ボードで実行したら、TPU Pod スライスで実行してコードをスケールアップできます。TPU Pod スライスは、専用の高速ネットワーク接続で互いに接続された複数の TPU ボードです。このドキュメントでは、TPU Pod スライスでの JAX コードの実行についての概要を説明します。詳しくは、マルチホスト環境とマルチプロセス環境での JAX の使用をご覧ください。
データ ストレージにマウントされた NFS を使用する場合は、Pod スライス内のすべての TPU VM に OS Login を設定する必要があります。詳細については、データ ストレージに NFS を使用するをご覧ください。環境の設定
Cloud Shell で次のコマンドを実行して、
gcloud
の最新バージョンを実行していることを確認します。$ gcloud components update
gcloud
をインストールする必要がある場合は、次のコマンドを使用します。$ sudo apt install -y google-cloud-sdk
いくつかの環境変数を作成します。
$ export TPU_NAME=tpu-name $ export ZONE=us-central2-b $ export RUNTIME_VERSION=tpu-ubuntu2204-base $ export ACCELERATOR_TYPE=v4-32
TPU Pod スライスの作成
このドキュメントのコマンドを実行する前に、アカウントと Cloud TPU プロジェクトを設定するの手順に従ってください。 ローカルマシンで次のコマンドを実行します。
gcloud
コマンドを使用して TPU Pod スライスを作成します。たとえば、v4-32 Pod スライスを作成するには、次のコマンドを使用します。
$ gcloud compute tpus tpu-vm create ${TPU_NAME} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --version=${RUNTIME_VERSION}
Pod スライスに JAX をインストールする
TPU Pod スライスを作成したら、TPU Pod スライスのすべてのホストに JAX をインストールする必要があります。--worker=all
オプションを使用すると、1 つのコマンドですべてのホストに JAX をインストールできます。
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --worker=all \ --command="pip install \ --upgrade 'jax[tpu]>0.3.0' \ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"
Pod スライスで JAX コードを実行する
TPU Pod スライスで JAX コードを実行するには、TPU Pod スライスの各ホストでコードを実行する必要があります。jax.device_count()
呼び出しは、Pod スライスの各ホストで呼び出されるまで応答しなくなります。次の例は、TPU Pod スライスで JAX 計算を実行する方法を示しています。
コードの準備
scp
コマンドでは、gcloud
バージョン 344.0.0 以降が必要です。
gcloud --version
を使用して gcloud
のバージョンを確認し、必要に応じて gcloud components upgrade
を実行します。
次のコードを使用して、example.py
という名前のファイルを作成します。
# The following code snippet will be run on all TPU hosts
import jax
# The total number of TPU cores in the Pod
device_count = jax.device_count()
# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()
# The psum is performed over all mapped devices across the Pod
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)
# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
print('global device count:', jax.device_count())
print('local device count:', jax.local_device_count())
print('pmap result:', r)
Pod スライス内のすべての TPU ワーカー VM に example.py
をコピーする
$ gcloud compute tpus tpu-vm scp example.py ${TPU_NAME} \ --worker=all \ --zone=${ZONE}
以前に scp
コマンドを使用したことがない場合、次のようなエラーが表示されることがあります。
ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH
agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try
again.
このエラーを解決するには、エラー メッセージに示されている ssh-add
コマンドを実行してから、コマンドを再実行します。
Pod スライスでコードを実行する
すべての VM で example.py
プログラムを起動します。
$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --worker=all \ --command="python3 example.py"
出力(v4-32 Pod スライスで生成)
global device count: 16
local device count: 4
pmap result: [16. 16. 16. 16.]
クリーンアップ
TPU VM の使用を終了したら、次の手順に沿ってリソースをクリーンアップします。
Compute Engine インスタンスとの接続を切断していない場合は切断します。
(vm)$ exit
プロンプトが
username@projectname
に変わります。これは、現在、Cloud Shell 内にいることを示しています。Cloud TPU と Compute Engine リソースを削除します。
$ gcloud compute tpus tpu-vm delete ${TPU_NAME} \ --zone=${ZONE}
gcloud compute tpus execution-groups list
を実行して、リソースが削除されたことを確認します。削除には数分かかることがあります。次のコマンドの出力には、このチュートリアルで作成したリソースを含めないでください。$ gcloud compute tpus tpu-vm list --zone=${ZONE}