Run a calculation on a Cloud TPU VM using PyTorch

This quickstart shows you how to create a Cloud TPU, install PyTorch, and run a simple calculation on a Cloud TPU. For a more in-depth tutorial that shows how to train a model on a Cloud TPU, see one of the Cloud TPU PyTorch Tutorials.

Before you begin

Before you follow this quickstart, you must create a Google Cloud Platform account, install the Google Cloud CLI, and configure the gcloud command. For more information, see Set up an account and a Cloud TPU project.

Create a Cloud TPU with gcloud

To create a TPU VM in the default user project, network, and zone, run:

$ gcloud compute tpus tpu-vm create tpu-name \
   --zone=us-central1-a \
   --accelerator-type=v3-8 \
   --version=tpu-ubuntu2204-base

Command flag descriptions

zone
The zone where you plan to create your Cloud TPU.
accelerator-type
The accelerator type specifies the version and size of the Cloud TPU you want to create. For more information about supported accelerator types for each TPU version, see TPU versions.
version
The Cloud TPU software version.

When you create your TPU, you can pass the additional --network and --subnetwork flags to specify a network and subnetwork. If you do not want to use the default network, you must pass the --network flag. The --subnetwork flag is optional and specifies a subnetwork for whichever network you use (default or user-specified). See the gcloud API reference page for details on these flags.
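
The flag rules above can be sketched as a small Python helper that assembles the create command. This is an illustrative sketch, not part of the gcloud tooling: the function name build_create_command and its parameters are hypothetical, and "my-network" is a placeholder. It uses only the flags described on this page and builds the argument list without running anything.

```python
import shlex


def build_create_command(name, zone, accelerator_type, version,
                         network=None, subnetwork=None):
    """Assemble (but do not run) a `gcloud compute tpus tpu-vm create` command.

    Hypothetical helper for illustration; only flags named on this page are used.
    """
    cmd = [
        "gcloud", "compute", "tpus", "tpu-vm", "create", name,
        f"--zone={zone}",
        f"--accelerator-type={accelerator_type}",
        f"--version={version}",
    ]
    # --network is needed only when you are not using the default network.
    if network is not None:
        cmd.append(f"--network={network}")
    # --subnetwork is optional for both default and user-specified networks.
    if subnetwork is not None:
        cmd.append(f"--subnetwork={subnetwork}")
    return cmd


if __name__ == "__main__":
    # "my-network" is a placeholder network name, not a real resource.
    cmd = build_create_command("tpu-name", "us-central1-a", "v3-8",
                               "tpu-ubuntu2204-base", network="my-network")
    print(shlex.join(cmd))
```

Returning a list rather than a single string keeps each argument separate, so the result could be passed to subprocess.run without shell quoting concerns if you choose to execute it.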

Connect to your Cloud TPU VM

   $ gcloud compute tpus tpu-vm ssh tpu-name --zone=us-central1-a

Install PyTorch/XLA on your TPU VM

   (vm)$ pip install torch~=2.4.0 torch_xla[tpu]~=2.4.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html
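
The ~=2.4.0 specifier in the install command is a compatible-release constraint, which is equivalent to >=2.4.0 together with ==2.4.*. As a rough way to confirm what pip actually installed, the following sketch reads the installed versions and checks them against that rule. The helper name satisfies_compatible_release is hypothetical, and the check is deliberately simplified: it accepts only plain X.Y.Z versions with an optional +local suffix and treats anything else (pre-releases, post-releases) as not matching.

```python
import re
from importlib import metadata


def satisfies_compatible_release(installed, base="2.4.0"):
    """Return True if `installed` satisfies `~=base` for a three-part base.

    Simplified sketch: only "X.Y.Z" or "X.Y.Z+local" strings are recognized.
    """
    match = re.fullmatch(r"(\d+)\.(\d+)\.(\d+)(\+.*)?", installed)
    if match is None:
        return False
    have = tuple(int(part) for part in match.groups()[:3])
    want = tuple(int(part) for part in base.split("."))
    # ~=X.Y.Z means: same X.Y series, and at least X.Y.Z.
    return have[:2] == want[:2] and have >= want


if __name__ == "__main__":
    for package in ("torch", "torch_xla"):
        try:
            version = metadata.version(package)
        except metadata.PackageNotFoundError:
            print(f"{package}: not installed")
            continue
        ok = satisfies_compatible_release(version)
        print(f"{package} {version}: {'ok' if ok else 'outside ~=2.4.0'}")
```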
   

Set TPU runtime configuration

Ensure that the PyTorch/XLA runtime uses the TPU.

   (vm)$ export PJRT_DEVICE=TPU
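
If you prefer to set the variable from inside a script rather than in the shell, a minimal sketch looks like the following. The helper name configure_pjrt is hypothetical. The sketch assumes that setting the variable before importing torch_xla is the conservative choice, so that the runtime sees it when it initializes. An existing value is left untouched.

```python
import os


def configure_pjrt(env=None, device="TPU"):
    """Set PJRT_DEVICE in `env` unless it is already set.

    Returns the effective value. `env` defaults to the process environment;
    passing a plain dict makes the helper easy to try out in isolation.
    """
    if env is None:
        env = os.environ
    return env.setdefault("PJRT_DEVICE", device)


if __name__ == "__main__":
    # Call this before importing torch_xla in your own script.
    print("PJRT_DEVICE =", configure_pjrt())
```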

Verify PyTorch can access TPUs

  1. Create a file named tpu-count.py in the current directory and copy and paste the following script into it.

    import torch
    # Import the top-level package so that torch_xla.devices() resolves.
    import torch_xla
    print(f'PyTorch can access {len(torch_xla.devices())} TPU cores')
    
  2. Run the script:

    (vm)$ python3 tpu-count.py

    Output from the script shows the number of TPU cores that PyTorch can access:

    PyTorch can access 8 TPU cores
    

Perform a basic calculation

  1. Create a file named tpu-test.py in the current directory and copy and paste the following script into it.

    import torch
    import torch_xla.core.xla_model as xm
    
    dev = xm.xla_device()
    t1 = torch.randn(3,3,device=dev)
    t2 = torch.randn(3,3,device=dev)
    print(t1 + t2)
    
  2. Run the script:

    (vm)$ python3 tpu-test.py

    Output from the script shows the result of the computation:

    tensor([[-0.2121,  1.5589, -0.6951],
            [-0.7886, -0.2022,  0.9242],
            [ 0.8555, -1.8698,  1.4333]], device='xla:1')
    

Clean up

To avoid incurring charges to your Google Cloud account for the resources used on this page, follow these steps.

  1. Disconnect from the Cloud TPU VM, if you have not already done so:

    (vm)$ exit

    Your prompt should now be username@projectname, showing you are in the Cloud Shell.

  2. Delete your Cloud TPU.

    $ gcloud compute tpus tpu-vm delete tpu-name \
      --zone=us-central1-a

The output of this command should confirm that your TPU has been deleted.

What's next

Read more about Cloud TPU VMs: