JetStream MaxText inference on v6e TPU VMs

This tutorial shows how to use JetStream to serve MaxText models on TPU v6e. JetStream is a throughput- and memory-optimized engine for large language model (LLM) inference on XLA devices (TPUs). In this tutorial, you run the inference benchmark for the Llama2-7B model.

Before you begin

Prepare to provision a TPU v6e with 4 chips:

  1. Follow the Set up the Cloud TPU environment guide to ensure you have appropriate access to use Cloud TPUs.

  2. Create a service identity for the TPU VM.

    gcloud alpha compute tpus tpu-vm service-identity create --zone=ZONE
  3. Authenticate with Google Cloud and configure the default project and zone for Google Cloud CLI.

    gcloud auth login
    gcloud config set project PROJECT_ID
    gcloud config set compute/zone ZONE
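
To confirm that the gcloud CLI is pointing at the intended project and zone before you create any resources, you can optionally print the active configuration:

    gcloud auth list
    gcloud config get-value project
    gcloud config get-value compute/zone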

Secure capacity

When you are ready to secure TPU capacity, review the quotas page to learn about the Cloud Quotas system. If you have additional questions about securing capacity, contact your Cloud TPU sales or account team.

Provision the Cloud TPU environment

You can provision TPU VMs with GKE, with GKE and XPK, or as queued resources.

Prerequisites

  • This tutorial has been tested with Python 3.10 or later.
  • Verify that your project has enough TPUS_PER_TPU_FAMILY quota, which specifies the maximum number of chips you can access within your Google Cloud project.
  • Verify that your project has enough TPU quota for:
    • TPU VM quota
    • IP Address quota
    • Hyperdisk balanced quota
  • User project permissions
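
You can review these quotas on the Quotas page in the Google Cloud console. If your gcloud release includes the alpha quota surface, you can also list them from the CLI; the service name below is an assumption based on the Cloud TPU API, and PROJECT_ID is a placeholder for your project:

    gcloud alpha services quota list \
        --service=tpu.googleapis.com \
        --consumer=projects/PROJECT_ID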

Create environment variables

In a Cloud Shell, create the following environment variables:

export NODE_ID=TPU_NODE_ID # TPU name
export PROJECT_ID=PROJECT_ID
export ACCELERATOR_TYPE=v6e-4
export ZONE=us-east1-d # or another zone that supports TPU v6e
export RUNTIME_VERSION=v2-alpha-tpuv6e
export SERVICE_ACCOUNT=YOUR_SERVICE_ACCOUNT
export QUEUED_RESOURCE_ID=QUEUED_RESOURCE_ID
export VALID_DURATION=VALID_DURATION

# Additional environment variable needed for Multislice:
export NUM_SLICES=NUM_SLICES

# Use a custom network for better performance as well as to avoid having the
# default network becoming overloaded.
export NETWORK_NAME=${PROJECT_ID}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw
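
If this custom network and its firewall rule don't already exist, you can create them. The following is a minimal sketch: the MTU value is an assumption (8896 is the largest MTU that Compute Engine VPC networks accept, matching the network's -mtu9k suffix), and the broad allow rule should be tightened to your own security requirements.

# Create an auto-mode VPC network with a jumbo-frame MTU.
gcloud compute networks create ${NETWORK_NAME} \
    --mtu=8896 \
    --project=${PROJECT_ID} \
    --subnet-mode=auto \
    --bgp-routing-mode=regional

# Allow traffic within the network so TPU hosts can reach each other.
gcloud compute firewall-rules create ${NETWORK_FW_NAME} \
    --network=${NETWORK_NAME} \
    --allow=tcp,icmp,udp \
    --project=${PROJECT_ID}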

Environment variable descriptions

NODE_ID: The user-assigned ID of the TPU that is created when the queued resource request is allocated.
PROJECT_ID: Your Google Cloud project name. Use an existing project or create a new one.
ZONE: See the TPU regions and zones document for the zones that support TPU v6e.
ACCELERATOR_TYPE: See the Accelerator types documentation for all supported accelerator types.
RUNTIME_VERSION: Use v2-alpha-tpuv6e for TPU v6e.
SERVICE_ACCOUNT: The email address of your service account, which you can find in the Google Cloud console under IAM -> Service Accounts. For example: tpu-service-account@<your_project_ID>.iam.gserviceaccount.com
NUM_SLICES: The number of slices to create (needed for Multislice only).
QUEUED_RESOURCE_ID: The user-assigned text ID of the queued resource request.
VALID_DURATION: The duration for which the queued resource request remains valid.
NETWORK_NAME: The name of a secondary network to use.
NETWORK_FW_NAME: The name of a secondary network firewall to use.

Provision a TPU v6e

gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
    --node-id ${NODE_ID} \
    --project ${PROJECT_ID} \
    --zone ${ZONE} \
    --accelerator-type ${ACCELERATOR_TYPE} \
    --runtime-version ${RUNTIME_VERSION} \
    --service-account ${SERVICE_ACCOUNT}

Use the list or describe commands to query the status of your queued resource.

   gcloud alpha compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
      --project ${PROJECT_ID} --zone ${ZONE}
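
The describe output includes a state field; the TPU is ready to use once the state reaches ACTIVE. To see every queued resource in the zone at once, use the list command:

   gcloud alpha compute tpus queued-resources list \
      --project ${PROJECT_ID} --zone ${ZONE}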

For a complete list of queued resource request statuses, see the Queued Resources documentation.

Connect to the TPU using SSH

   gcloud compute tpus tpu-vm ssh ${NODE_ID} --project ${PROJECT_ID} --zone ${ZONE}

Once you have connected to the TPU, you can run the inference benchmark.
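
Before starting the benchmark, you can optionally confirm that all four chips are visible from the VM. This assumes a TPU build of JAX is installed (the MaxText setup installs one):

   # Prints the number of attached TPU devices; expect 4 on v6e-4.
   python3 -c 'import jax; print(jax.device_count())'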

Run the Llama2-7B inference benchmark

To set up JetStream and MaxText, convert the model checkpoints, and run the inference benchmark, follow the instructions in the GitHub repository.
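
At a high level, the setup consists of cloning the two repositories and installing their dependencies. The following is a rough sketch using the repository locations current at the time of writing; treat the repository README as authoritative for the exact setup, checkpoint-conversion, and benchmark commands:

git clone https://github.com/AI-Hypercomputer/JetStream.git
git clone https://github.com/AI-Hypercomputer/maxtext.git
# MaxText ships a setup script that installs JAX for TPU and its other
# Python dependencies.
cd maxtext && bash setup.sh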

When the inference benchmark is complete, be sure to clean up the TPU resources.

Clean up

Delete the TPU:

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
    --project ${PROJECT_ID} \
    --zone ${ZONE} \
    --force \
    --async
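
Because the --async flag returns before the deletion finishes, you can optionally confirm that the queued resource is gone by listing the remaining queued resources in the zone:

gcloud compute tpus queued-resources list \
    --project ${PROJECT_ID} \
    --zone ${ZONE}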