PyTorch を使用して Cloud TPU VM で計算を実行する

このドキュメントでは、PyTorch と Cloud TPU の使用について簡単に説明します。

始める前に

このドキュメントのコマンドを実行する前に、 Google Cloud アカウントを作成し、Google Cloud CLI をインストールして、gcloud コマンドを構成する必要があります。詳細については、Cloud TPU 環境を設定するをご覧ください。

gcloud を使用して Cloud TPU を作成する

  1. コマンドを使いやすくするため、いくつかの環境変数を定義します。

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-east5-a
    export ACCELERATOR_TYPE=v5litepod-8
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite

    環境変数の説明

    変数 説明
    PROJECT_ID 実際の Google Cloud のプロジェクト ID。既存のプロジェクトを使用するか、新しいプロジェクトを作成します
    TPU_NAME TPU の名前。
    ZONE TPU VM を作成するゾーン。サポートされているゾーンの詳細については、TPU のリージョンとゾーンをご覧ください。
    ACCELERATOR_TYPE アクセラレータ タイプでは、作成する Cloud TPU のバージョンとサイズを指定します。TPU の各バージョンでサポートされているアクセラレータ タイプの詳細については、TPU のバージョンをご覧ください。
    RUNTIME_VERSION Cloud TPU ソフトウェアのバージョン

  2. 次のコマンドを実行して TPU VM を作成します。

    $ gcloud compute tpus tpu-vm create $TPU_NAME \
        --project=$PROJECT_ID \
        --zone=$ZONE \
        --accelerator-type=$ACCELERATOR_TYPE \
        --version=$RUNTIME_VERSION

Cloud TPU VM に接続する

次のコマンドを使用して、SSH 経由で TPU VM に接続します。

$ gcloud compute tpus tpu-vm ssh $TPU_NAME \
    --project=$PROJECT_ID \
    --zone=$ZONE

SSH を使用して TPU VM に接続できない場合は、TPU VM に外部 IP アドレスがないことが原因である可能性があります。外部 IP アドレスのない TPU VM にアクセスするには、パブリック IP アドレスを持たない TPU VM に接続するの説明に従ってください。

TPU VM に PyTorch/XLA をインストールする

$ (vm) sudo apt-get update
$ (vm) sudo apt-get install libopenblas-dev -y
$ (vm) pip install numpy
$ (vm) pip install torch torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html

PyTorch が TPU にアクセスできることを確認する

次のコマンドを使用して、PyTorch が TPU にアクセスできることを確認します。

$ (vm) PJRT_DEVICE=TPU python3 -c "import torch_xla.core.xla_model as xm; print(xm.get_xla_supported_devices(\"TPU\"))"

コマンドの出力は次のようになります。

['xla:0', 'xla:1', 'xla:2', 'xla:3', 'xla:4', 'xla:5', 'xla:6', 'xla:7']

基本的な計算を行う

  1. 現在のディレクトリに tpu-test.py という名前のファイルを作成し、次のスクリプトをコピーして貼り付けます。

    import torch
    import torch_xla.core.xla_model as xm
    
    dev = xm.xla_device()
    t1 = torch.randn(3,3,device=dev)
    t2 = torch.randn(3,3,device=dev)
    print(t1 + t2)
    
  2. スクリプトを実行します。

    (vm)$ PJRT_DEVICE=TPU python3 tpu-test.py

    スクリプトからの出力に、計算の結果が示されます。

    tensor([[-0.2121,  1.5589, -0.6951],
            [-0.7886, -0.2022,  0.9242],
            [ 0.8555, -1.8698,  1.4333]], device='xla:1')
    

クリーンアップ

このページで使用したリソースについて、 Google Cloud アカウントに課金されないようにするには、次の手順を実施します。

  1. Cloud TPU インスタンスとの接続を切断していない場合は切断します。

    (vm)$ exit

    プロンプトが username@projectname に変わります。これは、現在、Cloud Shell 内にいることを示しています。

  2. Cloud TPU を削除します。

    $ gcloud compute tpus tpu-vm delete $TPU_NAME \
        --project=$PROJECT_ID \
        --zone=$ZONE
  3. 次のコマンドを実行して、リソースが削除されたことを確認します。TPU がリストに表示されないことを確認します。削除には数分かかることがあります。

    $ gcloud compute tpus tpu-vm list \
        --zone=$ZONE

次のステップ

Cloud TPU VM の詳細を確認する。