Run a calculation on a Cloud TPU VM using PyTorch
This document provides a brief introduction to working with PyTorch and Cloud TPU.
Before you begin
Before running the commands in this document, you must create a Google Cloud account, install the Google Cloud CLI, and configure the gcloud command. For more information, see Set up the Cloud TPU environment.
Create a Cloud TPU using gcloud
Define some environment variables to make the commands easier to use.
export PROJECT_ID=your-project-id
export TPU_NAME=your-tpu-name
export ZONE=us-east5-a
export ACCELERATOR_TYPE=v5litepod-8
export RUNTIME_VERSION=v2-alpha-tpuv5-lite
Environment variable descriptions
PROJECT_ID: Your Google Cloud project ID. Use an existing project or create a new one.
TPU_NAME: The name of the TPU.
ZONE: The zone in which to create the TPU VM. For more information about supported zones, see TPU regions and zones.
ACCELERATOR_TYPE: The accelerator type specifies the version and size of the Cloud TPU you want to create. For more information about supported accelerator types for each TPU version, see TPU versions.
RUNTIME_VERSION: The Cloud TPU software version.

Create your TPU VM by running the following command:
$ gcloud compute tpus tpu-vm create $TPU_NAME \
    --project=$PROJECT_ID \
    --zone=$ZONE \
    --accelerator-type=$ACCELERATOR_TYPE \
    --version=$RUNTIME_VERSION
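If you prefer to drive this step from a script, the same command can be assembled from the environment variables defined above. The following is a minimal sketch, assuming Python 3 on the machine where gcloud is installed; the helper name `build_create_command` is hypothetical, and it uses only the flags that appear in the command above.

```python
from typing import List, Mapping

# Variable names taken from the export commands earlier on this page.
REQUIRED_VARS = ("PROJECT_ID", "TPU_NAME", "ZONE", "ACCELERATOR_TYPE", "RUNTIME_VERSION")


def build_create_command(env: Mapping[str, str]) -> List[str]:
    """Return the argv for `gcloud compute tpus tpu-vm create` (hypothetical helper).

    Raises ValueError listing every required variable that is unset or empty,
    so a typo in an export line fails before gcloud is ever called.
    """
    missing = [name for name in REQUIRED_VARS if not env.get(name)]
    if missing:
        raise ValueError("missing environment variables: " + ", ".join(missing))
    # Same flags as the shell command above, one argv element per flag.
    return [
        "gcloud", "compute", "tpus", "tpu-vm", "create", env["TPU_NAME"],
        f"--project={env['PROJECT_ID']}",
        f"--zone={env['ZONE']}",
        f"--accelerator-type={env['ACCELERATOR_TYPE']}",
        f"--version={env['RUNTIME_VERSION']}",
    ]
```

For example, `subprocess.run(build_create_command(os.environ), check=True)` runs the command with the variables exported in your current shell. Passing a list instead of a single string avoids shell quoting issues.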
Connect to your Cloud TPU VM
Connect to your TPU VM over SSH using the following command:
$ gcloud compute tpus tpu-vm ssh $TPU_NAME \
    --project=$PROJECT_ID \
    --zone=$ZONE
If you fail to connect to a TPU VM using SSH, it might be because the TPU VM doesn't have an external IP address. To access a TPU VM without an external IP address, follow the instructions in Connect to a TPU VM without a public IP address.
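`gcloud compute tpus tpu-vm ssh` also accepts a `--command` flag that runs a single command on the VM and then returns, which is useful for non-interactive checks. The following is a minimal sketch, assuming Python 3; the helper name `build_ssh_command` is hypothetical.

```python
from typing import List, Optional


def build_ssh_command(tpu_name: str, project: str, zone: str,
                      remote_command: Optional[str] = None) -> List[str]:
    """Return the argv for `gcloud compute tpus tpu-vm ssh` (hypothetical helper).

    With remote_command set, gcloud runs that command on the VM instead of
    opening an interactive shell.
    """
    argv = ["gcloud", "compute", "tpus", "tpu-vm", "ssh", tpu_name,
            f"--project={project}", f"--zone={zone}"]
    if remote_command is not None:
        argv.append(f"--command={remote_command}")
    return argv
```

For example, `build_ssh_command("my-tpu", "my-proj", "us-east5-a", remote_command="hostname")` produces a command that prints the VM's hostname and exits.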
Install PyTorch/XLA on your TPU VM
(vm)$ sudo apt-get update
(vm)$ sudo apt-get install libopenblas-dev -y
(vm)$ pip install numpy
(vm)$ pip install torch torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
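To confirm which versions pip installed, you can query the package metadata from Python. The following is a minimal sketch using the standard-library `importlib.metadata` module (Python 3.8 or later); the helper name `installed_versions` and the file name you save it under (for example, `check_versions.py`) are hypothetical. The distribution names match the ones passed to `pip install` above.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    result: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            # Not installed in the current Python environment.
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions(["numpy", "torch", "torch_xla"]).items():
        print(f"{name}: {ver or 'NOT INSTALLED'}")
```

Run it on the VM with `(vm)$ python3 check_versions.py`; any package reported as NOT INSTALLED points to a failed `pip install` step.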
Verify PyTorch can access TPUs
Use the following command to verify PyTorch can access your TPUs:
(vm)$ PJRT_DEVICE=TPU python3 -c "import torch_xla.core.xla_model as xm; print(xm.get_xla_supported_devices(\"TPU\"))"
The output from the command should look like the following:
['xla:0', 'xla:1', 'xla:2', 'xla:3', 'xla:4', 'xla:5', 'xla:6', 'xla:7']
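If you run this check from a script, you can parse the printed list instead of reading it by eye. The following is a minimal sketch, assuming the command prints a Python list literal like the sample output above; `count_xla_devices` is a hypothetical helper name.

```python
import ast


def count_xla_devices(output: str) -> int:
    """Count the 'xla:N' entries in the list printed by the verification command."""
    # The command prints a Python list literal, so parse it safely instead of eval().
    devices = ast.literal_eval(output.strip())
    if not isinstance(devices, list):
        # Anything other than a list is not the output this helper expects.
        raise ValueError(f"expected a list of device names, got {output!r}")
    return sum(1 for d in devices if isinstance(d, str) and d.startswith("xla:"))
```

For the `v5litepod-8` accelerator type used on this page, the sample output above gives a count of 8.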
Perform a basic calculation
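Create a file named tpu-test.py in the current directory and copy and paste the following script into it:

import torch
import torch_xla.core.xla_model as xm

# Get the default XLA (TPU) device.
dev = xm.xla_device()
# Create two random 3x3 tensors directly on the TPU.
t1 = torch.randn(3, 3, device=dev)
t2 = torch.randn(3, 3, device=dev)
# Printing forces the addition to run and copies the result to the CPU for display.
print(t1 + t2)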
Run the script:
(vm)$ PJRT_DEVICE=TPU python3 tpu-test.py
The output from the script shows the result of the computation. Because the input tensors are random, your values will differ:
tensor([[-0.2121,  1.5589, -0.6951],
        [-0.7886, -0.2022,  0.9242],
        [ 0.8555, -1.8698,  1.4333]], device='xla:1')
Clean up
To avoid incurring charges to your Google Cloud account for the resources used on this page, follow these steps.
Disconnect from the Cloud TPU instance, if you have not already done so:
(vm)$ exit
Your prompt should now be username@projectname, showing you are in the Cloud Shell.
Delete your Cloud TPU:
$ gcloud compute tpus tpu-vm delete $TPU_NAME \
    --project=$PROJECT_ID \
    --zone=$ZONE
Verify that the resources have been deleted by running the following command and checking that your TPU is no longer listed. The deletion might take several minutes.
$ gcloud compute tpus tpu-vm list \
    --zone=$ZONE
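Because deletion can take several minutes, you might want to poll the list command instead of rerunning it by hand. The following is a minimal sketch, assuming Python 3 and gcloud's global `--format` flag with the `value(name)` projection, which prints one resource name per line; since that name can be a full resource path, the helper compares only the last path segment. The helper names `tpu_still_listed` and `wait_until_deleted`, and the timeout and poll interval, are hypothetical choices.

```python
import subprocess
import time


def tpu_still_listed(list_output: str, tpu_name: str) -> bool:
    """Return True if tpu_name appears in `tpu-vm list --format="value(name)"` output."""
    for line in list_output.splitlines():
        line = line.strip()
        # Compare only the last path segment, so both "my-tpu" and
        # "projects/.../nodes/my-tpu" style names match.
        if line and line.rsplit("/", 1)[-1] == tpu_name:
            return True
    return False


def wait_until_deleted(tpu_name: str, zone: str,
                       timeout_s: float = 900.0, poll_s: float = 30.0) -> bool:
    """Poll the list command until tpu_name disappears or timeout_s elapses.

    The 15-minute default timeout and 30-second poll interval are arbitrary
    assumptions, not documented limits.
    """
    argv = ["gcloud", "compute", "tpus", "tpu-vm", "list",
            f"--zone={zone}", "--format=value(name)"]
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        out = subprocess.run(argv, capture_output=True, text=True, check=True).stdout
        if not tpu_still_listed(out, tpu_name):
            return True
        time.sleep(poll_s)
    return False
```

`wait_until_deleted` returns True once the TPU is gone and False if it is still listed when the timeout expires.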
What's next
Read more about Cloud TPU VMs: