Mantenha tudo organizado com as coleções
Salve e categorize o conteúdo com base nas suas preferências.
Executar um cálculo em uma VM da Cloud TPU usando o JAX
Este documento apresenta uma breve introdução sobre como trabalhar com o JAX e o Cloud TPU.
Antes de começar
Antes de executar os comandos neste documento, é necessário criar uma conta Google Cloud, instalar a Google Cloud CLI e configurar o comando gcloud. Para
mais informações, consulte Configurar o ambiente do Cloud TPU.
Criar uma VM do Cloud TPU usando gcloud
Defina algumas variáveis de ambiente para facilitar o uso dos comandos.
O ID do seu Google Cloud projeto. Use um projeto existente ou
crie um novo.
TPU_NAME
O nome da TPU.
ZONE
A zona em que a VM TPU será criada. Para mais informações sobre as zonas com suporte, consulte
Regiões e zonas de TPU.
ACCELERATOR_TYPE
O tipo de acelerador especifica a versão e o tamanho da Cloud TPU que você quer
criar. Para mais informações sobre os tipos de aceleradores compatíveis com cada versão de TPU, consulte
Versões de TPU.
Se você não conseguir se conectar a uma VM da TPU usando SSH, talvez seja porque a VM da TPU
não tem um endereço IP externo. Para acessar uma VM da TPU sem um endereço IP externo, siga as instruções em Conectar-se a uma VM da TPU sem um endereço IP público.
Verifique se o JAX pode acessar o TPU e executar operações básicas:
Inicie o interpretador do Python 3:
(vm)$python3
>>>importjax
Veja o número de núcleos de TPU disponíveis:
>>>jax.device_count()
O número de núcleos de TPU é exibido. O número de cores exibidas depende
da versão da TPU que você está usando. Para mais informações, consulte Versões da TPU.
Fazer um cálculo
>>>jax.numpy.add(1,1)
O resultado da adição "numpy" é exibido:
Resultado do comando:
Array(2,dtype=int32,weak_type=True)
Sair do interpretador Python
>>>exit()
Como executar código do JAX em uma VM de TPU
Agora é possível executar qualquer código do JAX que você quiser. Os exemplos de Flax
são um ótimo lugar para começar a executar modelos padrão de ML no JAX. Por exemplo,
para treinar uma rede convolucional básica MNIST:
Execute o seguinte comando para verificar se os recursos foram excluídos. Verifique
se a TPU não está mais listada. A exclusão pode levar vários minutos.
$gcloudcomputetpustpu-vmlist\--zone=$ZONE
Notas de desempenho
Veja alguns detalhes importantes que são relevantes principalmente para usar TPUs no
JAX.
Preenchimento
Uma das causas mais comuns do desempenho lento em TPUs é o preenchimento
involuntário:
As matrizes no Cloud TPU estão em blocos. Isso envolve o preenchimento de uma das
dimensões em um múltiplo de 8 e de uma dimensão diferente em um múltiplo de
128.
A unidade de multiplicação de matriz tem um melhor desempenho com pares
de matrizes grandes que minimizam a necessidade de preenchimento.
bfloat16 dtype
Por padrão, a multiplicação de matriz no JAX em TPUs usa bfloat16
com acumulação float32. Isso pode ser controlado com o argumento de precisão em
chamadas de função jax.numpy relevantes (matmul, dot, einsum etc.). Especificamente:
precision=jax.lax.Precision.DEFAULT: usa a precisão bfloat16
mista (mais rápida)
precision=jax.lax.Precision.HIGH: usa vários passes MXU para
aumentar a precisão
precision=jax.lax.Precision.HIGHEST: usa ainda mais passes MXU
para alcançar uma precisão float32 completa.
O JAX também adiciona o dtype bfloat16, que pode ser usado para transmitir matrizes explicitamente para
bfloat16. Por exemplo, jax.numpy.array(x, dtype=jax.numpy.bfloat16).
A seguir
Para mais informações sobre o Cloud TPU, consulte:
[[["Fácil de entender","easyToUnderstand","thumb-up"],["Meu problema foi resolvido","solvedMyProblem","thumb-up"],["Outro","otherUp","thumb-up"]],[["Difícil de entender","hardToUnderstand","thumb-down"],["Informações incorretas ou exemplo de código","incorrectInformationOrSampleCode","thumb-down"],["Não contém as informações/amostras de que eu preciso","missingTheInformationSamplesINeed","thumb-down"],["Problema na tradução","translationIssue","thumb-down"],["Outro","otherDown","thumb-down"]],["Última atualização 2025-08-18 UTC."],[],[],null,["# Run a calculation on a Cloud TPU VM using JAX\n=============================================\n\nThis document provides a brief introduction to working with JAX and Cloud TPU.\n| **Note:** This example shows how to run code on a v5litepod-8 (v5e) TPU which is a single-host TPU. Single-host TPUs have only 1 TPU VM. To run code on TPUs with more than one TPU VM (for example, v5litepod-16 or larger), see [Run JAX code on Cloud TPU slices](/tpu/docs/jax-pods).\n\n\nBefore you begin\n----------------\n\nBefore running the commands in this document, you must create a Google Cloud\naccount, install the Google Cloud CLI, and configure the `gcloud` command. For\nmore information, see [Set up the Cloud TPU environment](/tpu/docs/setup-gcp-account).\n\nCreate a Cloud TPU VM using `gcloud`\n------------------------------------\n\n1. Define some environment variables to make commands easier to use.\n\n\n ```bash\n export PROJECT_ID=your-project-id\n export TPU_NAME=your-tpu-name\n export ZONE=us-east5-a\n export ACCELERATOR_TYPE=v5litepod-8\n export RUNTIME_VERSION=v2-alpha-tpuv5-lite\n ``` \n\n #### Environment variable descriptions\n\n \u003cbr /\u003e\n\n2. Create your TPU VM by running the following command from a Cloud Shell or\n your computer terminal where the [Google Cloud CLI](/sdk/docs/install)\n is installed.\n\n ```bash\n $ gcloud compute tpus tpu-vm create $TPU_NAME \\\n --project=$PROJECT_ID \\\n --zone=$ZONE \\\n --accelerator-type=$ACCELERATOR_TYPE \\\n --version=$RUNTIME_VERSION\n ```\n\nConnect to your Cloud TPU VM\n----------------------------\n\nConnect to your TPU VM over SSH by using the following command: \n\n```bash\n$ gcloud compute tpus tpu-vm ssh $TPU_NAME \\\n --project=$PROJECT_ID \\\n --zone=$ZONE\n```\n\nIf you fail to connect to a TPU VM using SSH, it might be because the TPU VM\ndoesn't have an external IP address. To access a TPU VM without an external IP\naddress, follow the instructions in [Connect to a TPU VM without a public IP\naddress](/tpu/docs/tpu-iap).\n\nInstall JAX on your Cloud TPU VM\n--------------------------------\n\n```bash\n(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html\n```\n\nSystem check\n------------\n\nVerify that JAX can access the TPU and can run basic operations:\n\n1. Start the Python 3 interpreter:\n\n ```bash\n (vm)$ python3\n ``` \n\n ```bash\n \u003e\u003e\u003e import jax\n ```\n2. Display the number of TPU cores available:\n\n ```bash\n \u003e\u003e\u003e jax.device_count()\n ```\n\nThe number of TPU cores is displayed. The number of cores displayed is dependent\non the TPU version you are using. For more information, see [TPU versions](/tpu/docs/system-architecture-tpu-vm#versions).\n\n### Perform a calculation\n\n```bash\n\u003e\u003e\u003e jax.numpy.add(1, 1)\n```\n\nThe result of the numpy add is displayed:\n\nOutput from the command: \n\n```bash\nArray(2, dtype=int32, weak_type=True)\n```\n\n\u003cbr /\u003e\n\n### Exit the Python interpreter\n\n```bash\n\u003e\u003e\u003e exit()\n```\n\nRunning JAX code on a TPU VM\n----------------------------\n\nYou can now run any JAX code you want. The [Flax examples](https://github.com/google/flax/tree/master/examples)\nare a great place to start with running standard ML models in JAX. For example,\nto train a basic MNIST convolutional network:\n\n1. Install Flax examples dependencies:\n\n ```bash\n (vm)$ pip install --upgrade clu\n (vm)$ pip install tensorflow\n (vm)$ pip install tensorflow_datasets\n ```\n2. Install Flax:\n\n ```bash\n (vm)$ git clone https://github.com/google/flax.git\n (vm)$ pip install --user flax\n ```\n3. Run the Flax MNIST training script:\n\n ```bash\n (vm)$ cd flax/examples/mnist\n (vm)$ python3 main.py --workdir=/tmp/mnist \\\n --config=configs/default.py \\\n --config.learning_rate=0.05 \\\n --config.num_epochs=5\n ```\n\nThe script downloads the dataset and starts training. The script output should\nlook like this: \n\n```bash\nI0214 18:00:50.660087 140369022753856 train.py:146] epoch: 1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88\nI0214 18:00:52.015867 140369022753856 train.py:146] epoch: 2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72\nI0214 18:00:53.377511 140369022753856 train.py:146] epoch: 3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04\nI0214 18:00:54.727168 140369022753856 train.py:146] epoch: 4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15\nI0214 18:00:56.082807 140369022753856 train.py:146] epoch: 5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18\n```\n\n\nClean up\n--------\n\n\nTo avoid incurring charges to your Google Cloud account for\nthe resources used on this page, follow these steps.\n\nWhen you are done with your TPU VM, follow these steps to clean up your resources.\n\n1. Disconnect from the Cloud TPU instance, if you have not already done so:\n\n ```bash\n (vm)$ exit\n ```\n\n Your prompt should now be username@projectname, showing you are in the Cloud Shell.\n2. Delete your Cloud TPU:\n\n ```bash\n $ gcloud compute tpus tpu-vm delete $TPU_NAME \\\n --project=$PROJECT_ID \\\n --zone=$ZONE\n ```\n3. Verify the resources have been deleted by running the following command. Make\n sure your TPU is no longer listed. The deletion might take several minutes.\n\n ```bash\n $ gcloud compute tpus tpu-vm list \\\n --zone=$ZONE\n ```\n\nPerformance notes\n-----------------\n\nHere are a few important details that are particularly relevant to using TPUs in\nJAX.\n\n### Padding\n\nOne of the most common causes for slow performance on TPUs is introducing\ninadvertent padding:\n\n- Arrays in the Cloud TPU are tiled. This entails padding one of the dimensions to a multiple of 8, and a different dimension to a multiple of 128.\n- The matrix multiplication unit performs best with pairs of large matrices that minimize the need for padding.\n\n### bfloat16 dtype\n\nBy default, matrix multiplication in JAX on TPUs uses [bfloat16](/tpu/docs/bfloat16)\nwith float32 accumulation. This can be controlled with the precision argument on\nrelevant `jax.numpy` function calls (matmul, dot, einsum, etc). In particular:\n\n- `precision=jax.lax.Precision.DEFAULT`: uses mixed bfloat16 precision (fastest)\n- `precision=jax.lax.Precision.HIGH`: uses multiple MXU passes to achieve higher precision\n- `precision=jax.lax.Precision.HIGHEST`: uses even more MXU passes to achieve full float32 precision\n\nJAX also adds the bfloat16 dtype, which you can use to explicitly cast arrays to\n`bfloat16`. For example,\n`jax.numpy.array(x, dtype=jax.numpy.bfloat16)`.\n\n\nWhat's next\n-----------\n\nFor more information about Cloud TPU, see:\n\n- [Run JAX code on TPU slices](/tpu/docs/jax-pods)\n- [Manage TPUs](/tpu/docs/managing-tpus-tpu-vm)\n- [Cloud TPU System architecture](/tpu/docs/system-architecture-tpu-vm)"]]