TPU スライスで JAX コードを実行する

このドキュメントのコマンドを実行する前に、アカウントと Cloud TPU プロジェクトを設定するの説明に従ってください。

JAX コードを単一の TPU ボードで実行したら、TPU スライスで実行してコードをスケールアップできます。TPU スライスは、専用の高速ネットワーク接続で互いに接続された複数の TPU ボードです。このドキュメントでは、TPU スライスでの JAX コードの実行についての概要を説明します。詳しくは、マルチホスト環境とマルチプロセス環境での JAX の使用をご覧ください。

Cloud TPU スライスを作成する

  1. いくつかの環境変数を作成します。

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=europe-west4-b
    export ACCELERATOR_TYPE=v5litepod-32
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite

    環境変数の説明

    変数 説明
    PROJECT_ID 実際の Google Cloud のプロジェクト ID。既存のプロジェクトを使用するか、新しいプロジェクトを作成します
    TPU_NAME TPU の名前。
    ZONE TPU VM を作成するゾーン。サポートされているゾーンの詳細については、TPU のリージョンとゾーンをご覧ください。
    ACCELERATOR_TYPE アクセラレータ タイプでは、作成する Cloud TPU のバージョンとサイズを指定します。TPU の各バージョンでサポートされているアクセラレータ タイプの詳細については、TPU のバージョンをご覧ください。
    RUNTIME_VERSION Cloud TPU ソフトウェアのバージョン

  2. gcloud コマンドを使用して TPU スライスを作成します。たとえば、v5litepod-32 スライスを作成するには、次のコマンドを使用します。

    $ gcloud compute tpus tpu-vm create ${TPU_NAME}  \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        --accelerator-type=${ACCELERATOR_TYPE}  \
        --version=${RUNTIME_VERSION} 

スライスに JAX をインストールする

TPU スライスを作成したら、TPU スライスのすべてのホストに JAX をインストールする必要があります。これを行うには、--worker=all パラメータと --commamnd パラメータを使用して gcloud compute tpus tpu-vm ssh コマンドを実行します。

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --worker=all \
  --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

スライスで JAX コードを実行する

TPU スライスで JAX コードを実行するには、TPU スライスの各ホストでコードを実行する必要があります。jax.device_count() 呼び出しは、スライスの各ホストで呼び出されるまで応答しなくなります。次の例は、TPU スライスで JAX 計算を実行する方法を示しています。

コードを準備する

scp コマンドでは、gcloud バージョン 344.0.0 以降が必要です。gcloud --version を使用して gcloud のバージョンを確認し、必要に応じて gcloud components upgrade を実行します。

次のコードを含むファイルを example.py という名前で作成します。


import jax

# The total number of TPU cores in the slice
device_count = jax.device_count()

# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()

# The psum is performed over all mapped devices across the slice
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
    print('global device count:', jax.device_count())
    print('local device count:', jax.local_device_count())
    print('pmap result:', r)

スライス内のすべての TPU ワーカー VM に example.py をコピーする

$ gcloud compute tpus tpu-vm scp ./example.py ${TPU_NAME}: \
  --worker=all \
  --zone=${ZONE} \
  --project=${PROJECT_ID}

以前に scp コマンドを使用したことがない場合、次のようなエラーが表示されることがあります。

ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH
agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try
again.

このエラーを解決するには、エラー メッセージに示されている ssh-add コマンドを実行してから、コマンドを再実行します。

スライスでコードを実行する

すべての VM で example.py プログラムを起動します。

$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --worker=all \
  --command="python3 ./example.py"

出力(v5litepod-32 スライスで生成)

global device count: 32
local device count: 4
pmap result: [32. 32. 32. 32.]

クリーンアップ

TPU VM の使用を終了したら、次の手順に沿ってリソースをクリーンアップします。

  1. Cloud TPU リソースと Compute Engine リソースを削除します。

    $ gcloud compute tpus tpu-vm delete ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID}
  2. gcloud compute tpus execution-groups list を実行して、リソースが削除されたことを確認します。削除には数分かかることがあります。次のコマンドの出力には、このチュートリアルで作成したリソースが含まれていないはずです。

    $ gcloud compute tpus tpu-vm list --zone=${ZONE} \
    --project=${PROJECT_ID}