Run JAX code on TPU slices
==========================

*Last updated (UTC): 2025-08-11.*

Before running the commands in this document, make sure you have followed the
instructions in [Set up an account and Cloud TPU project](/tpu/docs/setup-gcp-account).

After you have your JAX code running on a single TPU board, you can scale up
your code by running it on a [TPU slice](/tpu/docs/system-architecture-tpu-vm#slices).
TPU slices are multiple TPU boards connected to each other over dedicated
high-speed network connections. This document is an introduction to running JAX
code on TPU slices; for more in-depth information, see
[Using JAX in multi-host and multi-process environments](https://jax.readthedocs.io/en/latest/multi_process.html).

Create a Cloud TPU slice
------------------------

1. Create some environment variables:

    ```bash
    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=europe-west4-b
    export ACCELERATOR_TYPE=v5litepod-32
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    ```

    #### Environment variable descriptions

    - `PROJECT_ID`: the ID of your Google Cloud project.
    - `TPU_NAME`: a name for the TPU slice you are creating.
    - `ZONE`: the zone in which to create the TPU slice.
    - `ACCELERATOR_TYPE`: the accelerator type, which specifies the TPU
      version and the number of TPU cores in the slice.
    - `RUNTIME_VERSION`: the Cloud TPU software (runtime) version.

2. Create a TPU slice using the `gcloud` command. For example, to create a
   v5litepod-32 slice use the following command:

    ```bash
    $ gcloud compute tpus tpu-vm create ${TPU_NAME} \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        --accelerator-type=${ACCELERATOR_TYPE} \
        --version=${RUNTIME_VERSION}
    ```

Install JAX on your slice
-------------------------

After creating the TPU slice, you must install JAX on all hosts in the TPU
slice. You can do this with the `gcloud compute tpus tpu-vm ssh` command,
using the `--worker=all` and `--command` parameters.

```bash
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --zone=${ZONE} \
    --project=${PROJECT_ID} \
    --worker=all \
    --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
```

Run JAX code on the slice
-------------------------

To run JAX code on a TPU slice, you must run the code **on each host in the
TPU slice**. The `jax.device_count()` call blocks until it is called on each
host in the slice. The following example illustrates how to run a JAX
calculation on a TPU slice.

### Prepare the code

You need `gcloud` version >= 344.0.0 (for the
[scp](/sdk/gcloud/reference/compute/tpus/tpu-vm/scp) command).
Use `gcloud --version` to check your `gcloud` version, and
run `gcloud components upgrade` if needed.

Create a file called `example.py` with the following code:

```python
import jax

# The total number of TPU cores in the slice
device_count = jax.device_count()

# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()

# The psum is performed over all mapped devices across the slice
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
    print('global device count:', jax.device_count())
    print('local device count:', jax.local_device_count())
    print('pmap result:', r)
```

### Copy `example.py` to all TPU worker VMs in the slice

```bash
$ gcloud compute tpus tpu-vm scp ./example.py ${TPU_NAME}: \
    --worker=all \
    --zone=${ZONE} \
    --project=${PROJECT_ID}
```

If you have not previously used the `scp` command, you might see an
error similar to the following:

    ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH
    agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try
    again.

To resolve the error, run the `ssh-add` command as displayed in the
error message and rerun the command.

### Run the code on the slice

Launch the `example.py` program on every VM:

```bash
$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --zone=${ZONE} \
    --project=${PROJECT_ID} \
    --worker=all \
    --command="python3 ./example.py"
```

### Output (produced with a v5litepod-32 slice)

    global device count: 32
    local device count: 4
    pmap result: [32. 32. 32. 32.]

Clean up
--------

When you are done with your TPU VM, follow these steps to clean up your resources.

1. Delete your Cloud TPU and Compute Engine resources.

    ```bash
    $ gcloud compute tpus tpu-vm delete ${TPU_NAME} \
        --zone=${ZONE} \
        --project=${PROJECT_ID}
    ```

2. Verify the resources have been deleted by running
   `gcloud compute tpus tpu-vm list`. The deletion might take several minutes.
   The output from the following command shouldn't include any of the resources
   created in this tutorial:

    ```bash
    $ gcloud compute tpus tpu-vm list --zone=${ZONE} \
        --project=${PROJECT_ID}
    ```
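As a closing sanity check, the numbers in the sample output follow directly from the slice shape: a v5litepod-32 slice contains 32 chips in total, the sample output shows 4 chips attached to each worker VM, and a `psum` of ones over all mapped devices leaves the global device count on every device. The following plain-Python sketch (no TPU or JAX required; the helper names are illustrative, not part of any gcloud or JAX API) reproduces that arithmetic:

```python
# Illustrative sketch of the slice arithmetic behind the sample output above.
# These helper functions are hypothetical, not part of any real API.

def global_device_count(accelerator_type: str) -> int:
    # The numeric suffix of the accelerator type is the total chip count,
    # e.g. "v5litepod-32" -> 32 chips in the slice.
    return int(accelerator_type.split("-")[-1])

def expected_pmap_result(global_devices: int, local_devices: int) -> list:
    # psum adds the ones contributed by every mapped device in the slice,
    # so each of this host's local devices ends up holding the global sum.
    return [float(global_devices)] * local_devices

devices = global_device_count("v5litepod-32")  # 32 chips in the slice
local = 4                                      # chips per worker VM (from the sample output)
hosts = devices // local                       # number of worker VMs running example.py

print("global device count:", devices)         # 32
print("local device count:", local)            # 4
print("pmap result:", expected_pmap_result(devices, local))  # [32.0, 32.0, 32.0, 32.0]
```

With 4 chips per worker VM, this also implies the `--worker=all` commands above fan out to 8 hosts, each of which must run `example.py` for the collective to complete.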