TPU Pod スライスでの JAX コードの実行

JAX コードを単一の TPU ボードで実行したら、TPU Pod スライスで実行してコードをスケールアップできます。TPU Pod スライスは、専用の高速ネットワーク接続で互いに接続された複数の TPU ボードです。このドキュメントでは、TPU Pod スライスでの JAX コードの実行についての概要を説明します。詳しくは、マルチホスト環境とマルチプロセス環境での JAX の使用をご覧ください。

データ ストレージにマウントされた NFS を使用する場合は、Pod スライス内のすべての TPU VM に OS Login を設定する必要があります。詳細については、データ ストレージに NFS を使用するをご覧ください。

環境の設定

  1. Cloud Shell で次のコマンドを実行して、gcloud の最新バージョンを実行していることを確認します。

    $ gcloud components update

    gcloud をインストールする必要がある場合は、次のコマンドを使用します。

    $ sudo apt install -y google-cloud-sdk
  2. いくつかの環境変数を作成します。

    $ export TPU_NAME=tpu-name
    $ export ZONE=us-central2-b
    $ export RUNTIME_VERSION=tpu-ubuntu2204-base
    $ export ACCELERATOR_TYPE=v4-32

TPU Pod スライスの作成

このドキュメントのコマンドを実行する前に、アカウントと Cloud TPU プロジェクトを設定するの手順に従ってください。 ローカルマシンで次のコマンドを実行します。

gcloud コマンドを使用して TPU Pod スライスを作成します。たとえば、v4-32 Pod スライスを作成するには、次のコマンドを使用します。

$ gcloud compute tpus tpu-vm create ${TPU_NAME}  \
  --zone=${ZONE} \
  --accelerator-type=${ACCELERATOR_TYPE}  \
  --version=${RUNTIME_VERSION} 

Pod スライスに JAX をインストールする

TPU Pod スライスを作成したら、TPU Pod スライスのすべてのホストに JAX をインストールする必要があります。--worker=all オプションを使用すると、1 つのコマンドですべてのホストに JAX をインストールできます。

  gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} \
  --worker=all \
  --command="pip install \
  --upgrade 'jax[tpu]>0.3.0' \
  -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"

Pod スライスで JAX コードを実行する

TPU Pod スライスで JAX コードを実行するには、TPU Pod スライスの各ホストでコードを実行する必要があります。jax.device_count() 呼び出しは、Pod スライスの各ホストで呼び出されるまで応答しなくなります。次の例は、TPU Pod スライスで JAX 計算を実行する方法を示しています。

コードの準備

scp コマンドでは、gcloud バージョン 344.0.0 以降が必要です。 gcloud --version を使用して gcloud のバージョンを確認し、必要に応じて gcloud components upgrade を実行します。

次のコードを使用して、example.py という名前のファイルを作成します。

# The following code snippet will be run on all TPU hosts
import jax

# The total number of TPU cores in the Pod
device_count = jax.device_count()

# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()

# The psum is performed over all mapped devices across the Pod
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
    print('global device count:', jax.device_count())
    print('local device count:', jax.local_device_count())
    print('pmap result:', r)

Pod スライス内のすべての TPU ワーカー VM に example.py をコピーする

$ gcloud compute tpus tpu-vm scp example.py ${TPU_NAME} \
  --worker=all \
  --zone=${ZONE}

以前に scp コマンドを使用したことがない場合、次のようなエラーが表示されることがあります。

ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH
agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try
again.

このエラーを解決するには、エラー メッセージに示されている ssh-add コマンドを実行してから、コマンドを再実行します。

Pod スライスでコードを実行する

すべての VM で example.py プログラムを起動します。

$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} \
  --worker=all \
  --command="python3 example.py"

出力(v4-32 Pod スライスで生成)

global device count: 16
local device count: 4
pmap result: [16. 16. 16. 16.]

クリーンアップ

TPU VM の使用を終了したら、次の手順に沿ってリソースをクリーンアップします。

  1. Compute Engine インスタンスとの接続を切断していない場合は切断します。

    (vm)$ exit

    プロンプトが username@projectname に変わります。これは、現在、Cloud Shell 内にいることを示しています。

  2. Cloud TPU と Compute Engine リソースを削除します。

    $ gcloud compute tpus tpu-vm delete ${TPU_NAME} \
      --zone=${ZONE}
  3. gcloud compute tpus execution-groups list を実行して、リソースが削除されたことを確認します。削除には数分かかることがあります。次のコマンドの出力には、このチュートリアルで作成したリソースを含めないでください。

    $ gcloud compute tpus tpu-vm list --zone=${ZONE}