Run a calculation on a Cloud TPU VM using PyTorch

This document provides a brief introduction to working with PyTorch and Cloud TPU.

Before you begin

Before running the commands in this document, you must create a Google Cloud account, install the Google Cloud CLI, and configure the gcloud command. For more information, see Set up the Cloud TPU environment.

Create a Cloud TPU using gcloud

  1. Define some environment variables to make the commands easier to use.

    export PROJECT_ID=your-project
    export ACCELERATOR_TYPE=v5litepod-8
    export ZONE=us-east5-a
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export TPU_NAME=your-tpu-name

    PROJECT_ID: Your Google Cloud project ID.
    ACCELERATOR_TYPE: The accelerator type, which specifies the version and size of the Cloud TPU you want to create. For more information about supported accelerator types for each TPU version, see TPU versions.
    ZONE: The zone where you plan to create your Cloud TPU.
    RUNTIME_VERSION: The Cloud TPU runtime version.
    TPU_NAME: The user-assigned name for your Cloud TPU.
  2. Create your TPU VM by running the following command:

    $ gcloud compute tpus tpu-vm create $TPU_NAME \
        --project=$PROJECT_ID \
        --zone=$ZONE \
        --accelerator-type=$ACCELERATOR_TYPE \
        --version=$RUNTIME_VERSION
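If you script these steps, you can build the create command from the environment variables instead of retyping it. The following Python sketch shows one way to do this. `build_create_command` is a hypothetical helper and is not part of any Google Cloud library. It uses only the `--project`, `--zone`, `--accelerator-type`, and `--version` flags of `gcloud compute tpus tpu-vm create`. By default it only prints the command; it runs the command only when you pass `--run`.

```python
import os
import shlex
import subprocess
import sys

# Environment variables from step 1, mapped to the gcloud flags they fill.
FLAG_FOR_VAR = {
    "PROJECT_ID": "--project",
    "ZONE": "--zone",
    "ACCELERATOR_TYPE": "--accelerator-type",
    "RUNTIME_VERSION": "--version",
}


def build_create_command(env=os.environ):
    """Hypothetical helper: build the `tpu-vm create` argv from environment variables."""
    missing = [var for var in ["TPU_NAME", *FLAG_FOR_VAR] if not env.get(var)]
    if missing:
        raise ValueError(f"unset environment variables: {', '.join(missing)}")
    argv = ["gcloud", "compute", "tpus", "tpu-vm", "create", env["TPU_NAME"]]
    argv += [f"{flag}={env[var]}" for var, flag in FLAG_FOR_VAR.items()]
    return argv


if __name__ == "__main__":
    try:
        cmd = build_create_command()
    except ValueError as err:
        print(f"error: {err}", file=sys.stderr)
    else:
        print(shlex.join(cmd))  # show the exact command before running it
        if "--run" in sys.argv[1:]:
            subprocess.run(cmd, check=True)
```

Printing the command first lets you compare it with the `gcloud` invocation in step 2 before you create a billable resource.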

Connect to your Cloud TPU VM

Connect to your TPU VM over SSH using the following command:

$ gcloud compute tpus tpu-vm ssh $TPU_NAME \
    --project=$PROJECT_ID \
    --zone=$ZONE

If you can't connect to your TPU VM using SSH, the TPU VM might not have an external IP address. To access a TPU VM without an external IP address, follow the instructions in Connect to a TPU VM without a public IP address.

Install PyTorch/XLA on your TPU VM

(vm)$ sudo apt-get update
(vm)$ sudo apt-get install libopenblas-dev -y
(vm)$ pip install numpy
(vm)$ pip install torch~=2.6.0 'torch_xla[tpu]~=2.6.0' -f
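Because torch_xla wheels are built against a specific PyTorch release, the installed torch and torch_xla packages should share the same major and minor version (2.6 in this guide). The following Python sketch compares the two using the standard-library `importlib.metadata` module. Run it on the TPU VM. The helper names `release` and `same_release` are illustrative.

```python
from importlib.metadata import PackageNotFoundError, version


def release(ver):
    """Return the (major, minor) pair of a version string such as '2.6.0+cpu'."""
    public = ver.split("+", 1)[0]  # drop any local suffix like '+cpu'
    major, minor = public.split(".")[:2]
    return int(major), int(minor)


def same_release(torch_ver, xla_ver):
    """True if torch and torch_xla share a major.minor release."""
    return release(torch_ver) == release(xla_ver)


if __name__ == "__main__":
    try:
        torch_ver, xla_ver = version("torch"), version("torch_xla")
    except PackageNotFoundError as err:
        print(f"not installed: {err}")
    else:
        status = "OK" if same_release(torch_ver, xla_ver) else "MISMATCH"
        print(f"torch {torch_ver}, torch_xla {xla_ver}: {status}")
```

A `MISMATCH` result usually means pip resolved a different torch release than the one torch_xla expects.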

Verify PyTorch can access TPUs

Use the following command to verify that PyTorch can access your TPUs.

(vm)$ PJRT_DEVICE=TPU python3 -c "import torch_xla.core.xla_model as xm; print(xm.get_xla_supported_devices(\"TPU\"))"

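The output lists the TPU devices that PyTorch/XLA can see. It should look similar to the following, although the number of devices depends on your accelerator type: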

['xla:0', 'xla:1', 'xla:2', 'xla:3']
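If you prefer a script to a shell one-liner, you can set `PJRT_DEVICE` from Python. This sketch assumes that setting the variable in `os.environ` before importing torch_xla has the same effect as the `PJRT_DEVICE=TPU` prefix. `setdefault` keeps any value you already exported. The `ImportError` branch is an illustrative addition so the script reports a clear message on machines without torch_xla.

```python
import os

# Equivalent to prefixing the command with PJRT_DEVICE=TPU. Set it before
# importing torch_xla so the value is in place when the runtime starts.
os.environ.setdefault("PJRT_DEVICE", "TPU")


def list_tpu_devices():
    """Return the XLA device strings for the TPUs visible to this process."""
    import torch_xla.core.xla_model as xm  # imported after PJRT_DEVICE is set
    return xm.get_xla_supported_devices("TPU")


if __name__ == "__main__":
    try:
        print(list_tpu_devices())
    except ImportError:
        print("torch_xla is not installed; install PyTorch/XLA first.")
```

Run it with `python3` on the TPU VM. The printed list should match the output of the one-liner.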

Perform a basic calculation

  1. Create a Python file in the current directory, and copy and paste the following script into it.

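    import torch
    import torch_xla.core.xla_model as xm

    # Get the default XLA device (a TPU when PJRT_DEVICE=TPU).
    dev = xm.xla_device()

    # Create two random 3x3 tensors directly on the device and add them.
    t1 = torch.randn(3, 3, device=dev)
    t2 = torch.randn(3, 3, device=dev)
    print(t1 + t2)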
  2. Run the script:

    (vm)$ PJRT_DEVICE=TPU python3

    The output from the script shows the result of the computation. Because t1 and t2 are random, your values will differ:

    tensor([[-0.2121,  1.5589, -0.6951],
            [-0.7886, -0.2022,  0.9242],
            [ 0.8555, -1.8698,  1.4333]], device='xla:1')
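To check a device result against a CPU computation, you can generate the inputs on the CPU and copy them to the XLA device with `.to()`. After adding them on the device, copy the sum back with `.cpu()`. The following sketch shows this pattern. Run it with `PJRT_DEVICE=TPU` on the TPU VM. The CPU fallback used when torch_xla can't be imported is an addition for illustration only and is not part of the original steps.

```python
import torch

try:
    import torch_xla.core.xla_model as xm
    dev = xm.xla_device()
except ImportError:
    # Illustrative fallback: use the CPU when torch_xla isn't installed,
    # so the script still runs (the comparison is then trivial).
    dev = torch.device("cpu")

torch.manual_seed(0)  # make the CPU-generated inputs reproducible
a = torch.randn(3, 3)
b = torch.randn(3, 3)

# Copy the inputs to the device, add them there, and copy the sum back to the CPU.
device_sum = (a.to(dev) + b.to(dev)).cpu()
cpu_sum = a + b

print(f"computed on {dev}")
print("matches CPU result:", torch.allclose(device_sum, cpu_sum))
```

Calling `.cpu()` on the device result also forces PyTorch/XLA to run the pending computation, because the values are needed on the host.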

Clean up

To avoid incurring charges to your Google Cloud account for the resources used on this page, follow these steps.

  1. Disconnect from the Cloud TPU instance, if you have not already done so:

    (vm)$ exit

    Your prompt should now be username@projectname, showing that you are back in Cloud Shell.

  2. Delete your Cloud TPU.

    $ gcloud compute tpus tpu-vm delete $TPU_NAME \
        --project=$PROJECT_ID \
        --zone=$ZONE

  3. Verify the resources have been deleted by running the following command. Make sure your TPU is no longer listed. The deletion might take several minutes.

    $ gcloud compute tpus tpu-vm list \
        --project=$PROJECT_ID \
        --zone=$ZONE
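Because deletion can take several minutes, you might want to script the check in step 3. The following Python sketch runs `gcloud compute tpus tpu-vm list` with `--format=value(name)` and reports whether your TPU still appears. It assumes that each output line is either a short name or a full resource path ending in the TPU name, so it compares only the last path segment. The helper names are hypothetical. The check is skipped if `gcloud` or the environment variables from the setup step are missing.

```python
import os
import shutil
import subprocess


def tpu_still_listed(list_output, tpu_name):
    """Return True if tpu_name appears in `tpu-vm list --format=value(name)` output.

    Each line may be a full resource path such as
    projects/PROJECT/locations/ZONE/nodes/NAME, so only the last
    path segment is compared.
    """
    return any(
        line.strip().rsplit("/", 1)[-1] == tpu_name
        for line in list_output.splitlines()
        if line.strip()
    )


def list_tpu_names(project, zone):
    """Run `gcloud compute tpus tpu-vm list` and return its output, one name per line."""
    result = subprocess.run(
        ["gcloud", "compute", "tpus", "tpu-vm", "list",
         f"--project={project}", f"--zone={zone}", "--format=value(name)"],
        capture_output=True, text=True, check=True,
    )
    return result.stdout


if __name__ == "__main__":
    needed = ("PROJECT_ID", "ZONE", "TPU_NAME")
    if shutil.which("gcloud") and all(os.environ.get(v) for v in needed):
        out = list_tpu_names(os.environ["PROJECT_ID"], os.environ["ZONE"])
        listed = tpu_still_listed(out, os.environ["TPU_NAME"])
        print("TPU still listed; deletion may be in progress" if listed else "TPU deleted")
    else:
        print("Skipping check: gcloud or PROJECT_ID/ZONE/TPU_NAME is not available.")
```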

What's next

Read more about Cloud TPU VMs: