JAX を使用して Cloud TPU VM で計算を実行する

このドキュメントでは、JAX と Cloud TPU の使用について簡単に説明します。

始める前に

このドキュメントのコマンドを実行する前に、 Google Cloudアカウントを作成し、Google Cloud CLI をインストールして、gcloud コマンドを構成する必要があります。詳細については、Cloud TPU 環境を設定するをご覧ください。

gcloud を使用して Cloud TPU VM を作成する

  1. コマンドを使いやすくするため、いくつかの環境変数を定義します。

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-east5-a
    export ACCELERATOR_TYPE=v5litepod-8
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite

    環境変数の説明

    変数 説明
    PROJECT_ID 実際の Google Cloud のプロジェクト ID。既存のプロジェクトを使用するか、新しいプロジェクトを作成します
    TPU_NAME TPU の名前。
    ZONE TPU VM を作成するゾーン。サポートされているゾーンの詳細については、TPU のリージョンとゾーンをご覧ください。
    ACCELERATOR_TYPE アクセラレータ タイプでは、作成する Cloud TPU のバージョンとサイズを指定します。TPU の各バージョンでサポートされているアクセラレータ タイプの詳細については、TPU のバージョンをご覧ください。
    RUNTIME_VERSION Cloud TPU ソフトウェアのバージョン

  2. TPU VM を作成するには、Google Cloud Shell、または Google Cloud CLI がインストールされているコンピュータ ターミナルから、次のコマンドを実行します。

    $ gcloud compute tpus tpu-vm create $TPU_NAME \
        --project=$PROJECT_ID \
        --zone=$ZONE \
        --accelerator-type=$ACCELERATOR_TYPE \
        --version=$RUNTIME_VERSION

Cloud TPU VM に接続する

次のコマンドを使用して、SSH 経由で TPU VM に接続します。

$ gcloud compute tpus tpu-vm ssh $TPU_NAME \
    --project=$PROJECT_ID \
    --zone=$ZONE

SSH を使用して TPU VM に接続できない場合は、TPU VM に外部 IP アドレスがないことが原因である可能性があります。外部 IP アドレスのない TPU VM にアクセスするには、パブリック IP アドレスを持たない TPU VM に接続するの手順を行ってください。

Cloud TPU VM に JAX をインストールする

(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

システム チェック

JAX が TPU にアクセスできること、基本オペレーションを実行できることを確認します。

  1. Python 3 インタプリタを起動します。

    (vm)$ python3
    >>> import jax
  2. 使用可能な TPU コアの数を表示します。

    >>> jax.device_count()

TPU コアの数が表示されます。表示されるコア数は、使用している TPU のバージョンによって異なります。詳細については、TPU のバージョンをご覧ください。

計算を実行する

>>> jax.numpy.add(1, 1)

numpy の加算の結果が表示されます。

コマンドからの出力は次のようになります。

Array(2, dtype=int32, weak_type=True)

Python インタプリタを終了する

>>> exit()

TPU VM で JAX コードを実行する

これからは、任意の JAX コードを実行できます。Flax の例は、JAX で標準の ML モデルの実行を開始するのに適しています。たとえば、基本的な MNIST 畳み込みネットワークをトレーニングする場合は、次の手順を行います。

  1. Flax の依存関係のサンプルをインストールします。

    (vm)$ pip install --upgrade clu
    (vm)$ pip install tensorflow
    (vm)$ pip install tensorflow_datasets
  2. Flax をインストールします。

    (vm)$ git clone https://github.com/google/flax.git
    (vm)$ pip install --user flax
  3. Flax MNIST トレーニング スクリプトを実行します。

    (vm)$ cd flax/examples/mnist
    (vm)$ python3 main.py --workdir=/tmp/mnist \
        --config=configs/default.py \
        --config.learning_rate=0.05 \
        --config.num_epochs=5

スクリプトによりデータセットがダウンロードされ、トレーニングが開始されます。スクリプトの出力は次のようになります。

I0214 18:00:50.660087 140369022753856 train.py:146] epoch:  1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88
I0214 18:00:52.015867 140369022753856 train.py:146] epoch:  2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72
I0214 18:00:53.377511 140369022753856 train.py:146] epoch:  3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04
I0214 18:00:54.727168 140369022753856 train.py:146] epoch:  4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15
I0214 18:00:56.082807 140369022753856 train.py:146] epoch:  5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18

クリーンアップ

このページで使用したリソースについて、 Google Cloud アカウントに課金されないようにするには、次の手順を実施します。

TPU VM の使用を終了したら、次の手順でリソースをクリーンアップします。

  1. Cloud TPU インスタンスとの接続を切断していない場合は切断します。

    (vm)$ exit

    プロンプトが username@projectname と表示され、Cloud Shell 内にいることが示されます。

  2. Cloud TPU を削除します。

    $ gcloud compute tpus tpu-vm delete $TPU_NAME \
        --project=$PROJECT_ID \
        --zone=$ZONE
  3. 次のコマンドを実行して、リソースが削除されたことを確認します。TPU がリストに表示されないことを確認します。削除には数分かかることがあります。

    $ gcloud compute tpus tpu-vm list \
        --zone=$ZONE

パフォーマンスに関する注意事項

ここでは、特に JAX での TPU の使用に関連する重要事項をいくつか説明します。

パディング

TPU のパフォーマンスが低下する最も一般的な原因の一つは、誤ったパディングです。

  • Cloud TPU 内の配列はタイル状になっています。そのためには、ディメンションのいずれかを 8 の倍数にパディングし、別のディメンションを 128 の倍数にパディングします。
  • 行列乗算ユニットは、パディングの必要性を最小限に抑える大規模な行列のペアで最も高いパフォーマンスを発揮します。

bfloat16 dtype

デフォルトでは、TPU 上の JAX の行列乗算は、float32 累積と bfloat16 を使用します。これは、関連する jax.numpy 関数呼び出し(matmul、dot、einsum など)の precision 引数で制御できます。具体例は次のとおりです。

  • precision=jax.lax.Precision.DEFAULT: bfloat16 の混合精度を使用する(最速)
  • precision=jax.lax.Precision.HIGH: 複数の MXU パスを使用して精度を高める
  • precision=jax.lax.Precision.HIGHEST: さらに多くの MXU パスを使用して、フル精度の float32 を実現する

また、JAX では bfloat16 dtype も追加されます。これを使って、配列を明示的に bfloat16 にキャストできます。例: jax.numpy.array(x, dtype=jax.numpy.bfloat16)

次のステップ

Cloud TPU の詳細については、以下をご覧ください。