Serve Gemma using TPUs on GKE with JetStream


This guide shows you how to serve a Gemma large language model (LLM) using Tensor Processing Units (TPUs) on Google Kubernetes Engine (GKE) with JetStream through MaxText. In this guide, you download the Gemma 7B parameter instruction tuned model weights to Cloud Storage and deploy them on a GKE Autopilot or Standard cluster using a container that runs JetStream.

If you need the scalability, resilience, and cost-effectiveness offered by Kubernetes features when deploying your model on JetStream, this guide is a good starting point.

Background

By serving Gemma using TPUs on GKE with JetStream, you can build a robust, production-ready serving solution with all the benefits of managed Kubernetes, including cost-efficiency, scalability and higher availability. This section describes the key technologies used in this tutorial.

Gemma

Gemma is a set of openly available, lightweight, generative artificial intelligence (AI) models released under an open license. These AI models are available to run in your applications, hardware, mobile devices, or hosted services. You can use the Gemma models for text generation, and you can also tune them for specialized tasks.

To learn more, see the Gemma documentation.

TPUs

TPUs are Google's custom-developed application-specific integrated circuits (ASICs) used to accelerate machine learning and AI models built using frameworks such as TensorFlow, PyTorch, and JAX.

Before you use TPUs in GKE, we recommend that you complete the following learning path:

  1. Learn about current TPU version availability with the Cloud TPU system architecture.
  2. Learn about TPUs in GKE.

This tutorial covers serving the Gemma 7B model. GKE deploys the model on single-host TPU v5e nodes with TPU topologies configured based on the model requirements for serving prompts with low latency.

JetStream

JetStream is an open source inference serving framework developed by Google. JetStream enables high-performance, high-throughput, and memory-optimized inference on TPUs and GPUs. It provides advanced performance optimizations, including continuous batching and quantization techniques, to facilitate LLM deployment. JetStream enables PyTorch/XLA and JAX TPU serving to achieve optimal performance.

To learn more about these optimizations, refer to the JetStream PyTorch and JetStream MaxText project repositories.

MaxText

MaxText is a performant, scalable, and adaptable JAX LLM implementation, built on open source JAX libraries such as Flax, Orbax, and Optax. MaxText's decoder-only LLM implementation is written in Python. It leverages the XLA compiler heavily to achieve high performance without needing to build custom kernels.

To learn more about the latest models and parameter sizes that MaxText supports, see the MaxText project repository.

Objectives

This tutorial is intended for Generative AI customers who use JAX, new or existing users of GKE, ML Engineers, MLOps (DevOps) engineers, or platform administrators who are interested in using Kubernetes container orchestration capabilities for serving LLMs.

This tutorial covers the following steps:

  1. Prepare a GKE Autopilot or Standard cluster with the recommended TPU topology based on the model characteristics.
  2. Deploy JetStream components on GKE.
  3. Get and publish the Gemma 7B instruction tuned model.
  4. Serve and interact with the published model.

Architecture

This section describes the GKE architecture used in this tutorial. The architecture comprises a GKE Autopilot or Standard cluster that provisions TPUs and hosts JetStream components to deploy and serve the models.

The following diagram shows you the components of this architecture:

Architecture of GKE cluster with single-host TPU node pools containing the Maxengine and Max HTTP components.

This architecture includes the following components:

  • A GKE Autopilot or Standard regional cluster.
  • Two single-host TPU slice node pools that host the JetStream deployment.
  • The Service component spreads inbound traffic to all JetStream HTTP replicas.
  • JetStream HTTP is an HTTP server that accepts requests, wraps them in the format that JetStream requires, and sends them to JetStream's gRPC client.
  • Maxengine is a JetStream server that performs inferencing with continuous batching.

Before you begin

  • Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
  • In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

  • Make sure that billing is enabled for your Google Cloud project.

  • Enable the required API.

  • Make sure that you have the following role or roles on the project: roles/container.admin, roles/iam.serviceAccountAdmin

    Check for the roles

    1. In the Google Cloud console, go to the IAM page.

    2. Select the project.
    3. In the Principal column, find all rows that identify you or a group that you're included in. To learn which groups you're included in, contact your administrator.

    4. For all rows that specify or include you, check the Role column to see whether the list of roles includes the required roles.

    Grant the roles

    1. In the Google Cloud console, go to the IAM page.

    2. Select the project.
    3. Click Grant access.
    4. In the New principals field, enter your user identifier. This is typically the email address for a Google Account.

    5. In the Select a role list, select a role.
    6. To grant additional roles, click Add another role and add each additional role.
    7. Click Save.
  • Ensure that you have sufficient quota for eight TPU v5e PodSlice Lite chips. In this tutorial, you use on-demand instances.
  • Create a Kaggle account, if you don't already have one.

Get access to the model

To get access to the Gemma model for deployment to GKE, you must first sign the license consent agreement. Follow these instructions:

  1. Access the Gemma model consent page on Kaggle.com.
  2. Log in to Kaggle if you haven't done so already.
  3. Click Request Access.
  4. In the Choose Account for Consent section, select Verify via Kaggle Account to use your Kaggle account for consent.
  5. Accept the model Terms and Conditions.

Generate an access token

To access the model through Kaggle, you need a Kaggle API token.

Follow these steps to generate a new token if you don't have one already:

  1. In your browser, go to Kaggle settings.
  2. Under the API section, click Create New Token.

A file named kaggle.json is downloaded.
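
Before you upload the token, it can help to confirm that the downloaded file looks like a Kaggle API token. The following is a minimal sketch, assuming that the file is a JSON document with username and key fields. The function name check_kaggle_token is hypothetical and is not part of any Kaggle or Google Cloud tooling.

```shell
# Hypothetical helper: sanity-check a Kaggle API token file.
# Assumption: the token is JSON that contains "username" and "key" fields.
check_kaggle_token() {
  local file="$1"
  # The file must exist and be non-empty.
  [ -s "$file" ] || { echo "missing or empty: $file" >&2; return 1; }
  # Both expected field names must appear somewhere in the file.
  grep -q '"username"' "$file" || { echo "no username field in $file" >&2; return 1; }
  grep -q '"key"' "$file" || { echo "no key field in $file" >&2; return 1; }
  echo "ok: $file"
}
```

For example, running check_kaggle_token kaggle.json prints ok: kaggle.json when both fields are present. The check looks only for the field names and doesn't verify that the credentials are valid.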

Prepare the environment

In this tutorial, you use Cloud Shell to manage resources hosted on Google Cloud. Cloud Shell comes preinstalled with the software you'll need for this tutorial, including kubectl and gcloud CLI.

To set up your environment with Cloud Shell, follow these steps:

  1. In the Google Cloud console, launch a Cloud Shell session by clicking Activate Cloud Shell. This launches a session in the bottom pane of the Google Cloud console.

  2. Set the default environment variables:

    gcloud config set project PROJECT_ID
    export PROJECT_ID=$(gcloud config get project)
    export CLUSTER_NAME=CLUSTER_NAME
    export BUCKET_NAME=BUCKET_NAME
    export REGION=REGION
    export LOCATION=LOCATION
    export CLUSTER_VERSION=CLUSTER_VERSION
    

    Replace the following values:

    • PROJECT_ID: your Google Cloud project ID.
    • CLUSTER_NAME: the name of your GKE cluster.
    • BUCKET_NAME: the name of your Cloud Storage bucket. You don't need to specify the gs:// prefix.
    • REGION: the region where your GKE cluster, Cloud Storage bucket, and TPU nodes are located. The region contains zones where TPU v5e machine types are available (for example, us-west1, us-west4, us-central1, us-east1, us-east5, or europe-west4). For Autopilot clusters, ensure that you have sufficient TPU v5e zonal resources for your region of choice.
    • (Standard cluster only) LOCATION: the zone where the TPU resources are available (for example, us-west4-a). For Autopilot clusters, you don't need to specify the zone, only the region.
    • CLUSTER_VERSION: the GKE version, which must support the machine type that you want to use. Note that the default GKE version might not have availability for your target TPU. For a list of minimum GKE versions available by TPU machine type, see TPU availability in GKE.
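
Because the later commands read these variables, an unset value usually surfaces only as a confusing gcloud or kubectl error. The following sketch shows one way to fail early. The function name require_vars is hypothetical, and the approach is an illustration rather than part of the tutorial's tooling.

```shell
# Hypothetical helper: report any named variable that is unset or empty.
require_vars() {
  local name value missing=0
  for name in "$@"; do
    # Read the value of the variable whose name is stored in $name.
    eval "value=\${${name}:-}"
    if [ -z "$value" ]; then
      echo "unset: $name" >&2
      missing=1
    fi
  done
  return "$missing"
}
```

For an Autopilot cluster you might run the following, and add LOCATION for a Standard cluster:

    require_vars PROJECT_ID CLUSTER_NAME BUCKET_NAME REGION CLUSTER_VERSION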

Create and configure Google Cloud resources

Follow these instructions to create the required resources.

Create a GKE cluster

You can serve Gemma on TPUs in a GKE Autopilot or Standard cluster. We recommend that you use an Autopilot cluster for a fully managed Kubernetes experience. To choose the GKE mode of operation that's the best fit for your workloads, see Choose a GKE mode of operation.

Autopilot

In Cloud Shell, run the following command:

gcloud container clusters create-auto ${CLUSTER_NAME} \
  --project=${PROJECT_ID} \
  --region=${REGION} \
  --cluster-version=${CLUSTER_VERSION}

Standard

  1. Create a regional GKE Standard cluster that uses Workload Identity Federation for GKE.

    gcloud container clusters create ${CLUSTER_NAME} \
        --enable-ip-alias \
        --machine-type=e2-standard-4 \
        --num-nodes=2 \
        --cluster-version=${CLUSTER_VERSION} \
        --workload-pool=${PROJECT_ID}.svc.id.goog \
        --location=${REGION}
    

    The cluster creation might take several minutes.

  2. Run the following command to create a node pool for your cluster:

    gcloud container node-pools create gemma-7b-tpu-nodepool \
      --cluster=${CLUSTER_NAME} \
      --machine-type=ct5lp-hightpu-8t \
      --project=${PROJECT_ID} \
      --num-nodes=2 \
      --region=${REGION} \
      --node-locations=${LOCATION}
    

    GKE creates a TPU v5e node pool with a 2x4 topology and two nodes.
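
The topology string describes how the TPU chips in a slice are arranged, and multiplying its dimensions gives the chip count. For 2x4 that is 8, which matches the google.com/tpu: 8 resource request used in the manifests in this tutorial. The following sketch shows the arithmetic. The function name chips_in_topology is hypothetical.

```shell
# Hypothetical helper: multiply the dimensions of a topology string such as 2x4.
chips_in_topology() {
  local expr_str
  # Accept only digits separated by single x characters.
  case "$1" in
    ''|*[!0-9x]*|x*|*x|*xx*)
      echo "invalid topology: $1" >&2
      return 1
      ;;
  esac
  # Turn 2x4 into 2*4 and let shell arithmetic evaluate the product.
  expr_str="$(printf '%s' "$1" | tr 'x' '*')"
  echo "$(( $expr_str ))"
}
```

For example, chips_in_topology 2x4 prints 8.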

Create a Cloud Storage bucket

In Cloud Shell, run the following command:

gcloud storage buckets create gs://${BUCKET_NAME} --location=${REGION}

This creates a Cloud Storage bucket to store the model files you download from Kaggle.
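
The command adds the gs:// prefix itself, so a BUCKET_NAME value that already includes the prefix produces an invalid URI. The following sketch normalizes either form. The function name bucket_uri is hypothetical.

```shell
# Hypothetical helper: print a gs:// URI for a bucket name given with or without the prefix.
bucket_uri() {
  local name="${1#gs://}"   # drop a leading gs:// if present
  name="${name%/}"          # drop one trailing slash if present
  printf 'gs://%s\n' "$name"
}
```

Both bucket_uri my-bucket and bucket_uri gs://my-bucket/ print gs://my-bucket.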

Upload the access token to Cloud Shell

In Cloud Shell, you can upload the Kaggle API token to your Google Cloud project:

  1. In Cloud Shell, click More > Upload.
  2. Select File and click Choose Files.
  3. Open the kaggle.json file.
  4. Click Upload.

Create a Kubernetes Secret for Kaggle credentials

In Cloud Shell, do the following:

  1. Configure kubectl to communicate with your cluster:

    gcloud container clusters get-credentials ${CLUSTER_NAME} --location=${REGION}
    
  2. Create a Secret to store the Kaggle credentials:

    kubectl create secret generic kaggle-secret \
        --from-file=kaggle.json
    

Configure workload access using Workload Identity Federation for GKE

Assign a Kubernetes ServiceAccount to the application and configure that Kubernetes ServiceAccount to act as an IAM service account.

  1. Create an IAM service account for your application:

    gcloud iam service-accounts create wi-jetstream
    
  2. Add an IAM policy binding for your IAM service account to manage Cloud Storage:

    gcloud projects add-iam-policy-binding ${PROJECT_ID} \
        --member "serviceAccount:wi-jetstream@${PROJECT_ID}.iam.gserviceaccount.com" \
        --role roles/storage.objectUser
    
    gcloud projects add-iam-policy-binding ${PROJECT_ID} \
        --member "serviceAccount:wi-jetstream@${PROJECT_ID}.iam.gserviceaccount.com" \
        --role roles/storage.insightsCollectorService
    
  3. Allow the Kubernetes ServiceAccount to impersonate the IAM service account by adding an IAM policy binding between the two service accounts. This binding allows the Kubernetes ServiceAccount to act as the IAM service account:

    gcloud iam service-accounts add-iam-policy-binding wi-jetstream@${PROJECT_ID}.iam.gserviceaccount.com \
        --role roles/iam.workloadIdentityUser \
        --member "serviceAccount:${PROJECT_ID}.svc.id.goog[default/default]"
    
  4. Annotate the Kubernetes service account with the email address of the IAM service account:

    kubectl annotate serviceaccount default \
        iam.gke.io/gcp-service-account=wi-jetstream@${PROJECT_ID}.iam.gserviceaccount.com
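
The two identity strings in these commands follow fixed patterns: the IAM service account email is built from the account name and project ID, and the member string names the Kubernetes ServiceAccount as namespace/name inside the square brackets. The following sketch prints both so that you can compare them with what you typed. The function names gsa_email and wi_member are hypothetical.

```shell
# Hypothetical helper: print the email address of an IAM service account.
gsa_email() {
  # $1: project ID, $2: IAM service account name
  printf '%s@%s.iam.gserviceaccount.com\n' "$2" "$1"
}

# Hypothetical helper: print the member string for a Kubernetes ServiceAccount.
wi_member() {
  # $1: project ID, $2: Kubernetes namespace, $3: Kubernetes ServiceAccount name
  printf 'serviceAccount:%s.svc.id.goog[%s/%s]\n' "$1" "$2" "$3"
}
```

This tutorial uses the default ServiceAccount in the default namespace, so wi_member "${PROJECT_ID}" default default prints the same member value that the binding command uses.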
    

Convert the model checkpoints

In this section, you create a Job to do the following:

  1. Download the base Orbax checkpoint from Kaggle.
  2. Upload the checkpoint to a Cloud Storage bucket.
  3. Convert the checkpoint to a MaxText compatible checkpoint.
  4. Unscan the checkpoint to be used for serving.

Deploy the model checkpoint conversion Job

Follow these instructions to download and convert the Gemma 7B model checkpoint files.

  1. Create the following manifest as job-7b.yaml. Replace BUCKET_NAME with the name of your Cloud Storage bucket:

    apiVersion: batch/v1
    kind: Job
    metadata:
      name: data-loader-7b
    spec:
      ttlSecondsAfterFinished: 30
      template:
        spec:
          restartPolicy: Never
          containers:
          - name: inference-checkpoint
            image: us-docker.pkg.dev/cloud-tpu-images/inference/inference-checkpoint:v0.2.2
            args:
            - -b=BUCKET_NAME
            - -m=google/gemma/maxtext/7b-it/2
            volumeMounts:
            - mountPath: "/kaggle/"
              name: kaggle-credentials
              readOnly: true
            resources:
              requests:
                google.com/tpu: 8
              limits:
                google.com/tpu: 8
          nodeSelector:
            cloud.google.com/gke-tpu-topology: 2x4
            cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
          volumes:
          - name: kaggle-credentials
            secret:
              defaultMode: 0400
              secretName: kaggle-secret
    
  2. Apply the manifest:

    kubectl apply -f job-7b.yaml
    
  3. Wait for the Pod that the Job creates to begin running:

    kubectl get pod -w
    

    The output is similar to the following. This might take a few minutes:

    NAME                  READY   STATUS              RESTARTS   AGE
    data-loader-7b-abcd   0/1     ContainerCreating   0          28s
    data-loader-7b-abcd   1/1     Running             0          51s
    

    For Autopilot clusters, it may take a few minutes to provision the required TPU resources.

  4. View the logs from the Job:

    kubectl logs -f jobs/data-loader-7b
    

    When the Job is completed, the output is similar to the following:

    Successfully generated decode checkpoint at: gs://BUCKET_NAME/final/unscanned/gemma_7b-it/0/checkpoints/0/items
    + echo -e '\nCompleted unscanning checkpoint to gs://BUCKET_NAME/final/unscanned/gemma_7b-it/0/checkpoints/0/items'
    
    Completed unscanning checkpoint to gs://BUCKET_NAME/final/unscanned/gemma_7b-it/0/checkpoints/0/items
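
The manifests in this tutorial contain the literal placeholder BUCKET_NAME, which kubectl doesn't expand from your environment. As an alternative to editing the file by hand, the following sketch fills in the placeholder with sed. The function name render_manifest is hypothetical.

```shell
# Hypothetical helper: print a manifest with the BUCKET_NAME placeholder filled in.
render_manifest() {
  # $1: manifest path, $2: bucket name without the gs:// prefix
  sed "s|BUCKET_NAME|$2|g" "$1"
}
```

The rendered manifest goes to standard output, so you can inspect it first or pipe it to kubectl:

    render_manifest job-7b.yaml "${BUCKET_NAME}" | kubectl apply -f -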
    

Deploy JetStream

In this section, you deploy the JetStream container to serve the Gemma model.

Follow these instructions to deploy the Gemma 7B instruction tuned model.

  1. Save the following Deployment manifest as jetstream-gemma-deployment.yaml. A Deployment is a Kubernetes API object that lets you run multiple replicas of Pods distributed among the nodes in a cluster. Replace BUCKET_NAME with the name of your Cloud Storage bucket:

    apiVersion: apps/v1
    kind: Deployment
    metadata:
      name: maxengine-server
    spec:
      replicas: 1
      selector:
        matchLabels:
          app: maxengine-server
      template:
        metadata:
          labels:
            app: maxengine-server
        spec:
          nodeSelector:
            cloud.google.com/gke-tpu-topology: 2x4
            cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
          containers:
          - name: maxengine-server
            image: us-docker.pkg.dev/cloud-tpu-images/inference/maxengine-server:v0.2.2
            args:
            - model_name=gemma-7b
            - tokenizer_path=assets/tokenizer.gemma
            - per_device_batch_size=4
            - max_prefill_predict_length=1024
            - max_target_length=2048
            - async_checkpointing=false
            - ici_fsdp_parallelism=1
            - ici_autoregressive_parallelism=-1
            - ici_tensor_parallelism=1
            - scan_layers=false
            - weight_dtype=bfloat16
            - load_parameters_path=gs://BUCKET_NAME/final/unscanned/gemma_7b-it/0/checkpoints/0/items
            - prometheus_port=PROMETHEUS_PORT
            ports:
            - containerPort: 9000
            resources:
              requests:
                google.com/tpu: 8
              limits:
                google.com/tpu: 8
          - name: jetstream-http
            image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.2
            ports:
            - containerPort: 8000
    ---
    apiVersion: v1
    kind: Service
    metadata:
      name: jetstream-svc
    spec:
      selector:
        app: maxengine-server
      ports:
      - protocol: TCP
        name: jetstream-http
        port: 8000
        targetPort: 8000
      - protocol: TCP
        name: jetstream-grpc
        port: 9000
        targetPort: 9000
    

    The manifest sets the following key properties:

    • tokenizer_path: the path to your model's tokenizer.
    • load_parameters_path: the path in the Cloud Storage bucket where your checkpoints are stored.
    • per_device_batch_size: the decoding batch size per device, where one TPU chip equals one device.
    • max_prefill_predict_length: the maximum length for the prefill when doing autoregression.
    • max_target_length: the maximum sequence length.
    • model_name: the model name (gemma-7b).
    • ici_fsdp_parallelism: the number of shards for fully sharded data parallelism (FSDP).
    • ici_tensor_parallelism: the number of shards for tensor parallelism.
    • ici_autoregressive_parallelism: the number of shards for autoregressive parallelism.
    • prometheus_port: the port on which to expose Prometheus metrics. Remove this argument if metrics aren't needed.
    • scan_layers: whether to scan layers (boolean).
    • weight_dtype: the weight data type (bfloat16).
  2. Apply the manifest:

    kubectl apply -f jetstream-gemma-deployment.yaml
    
  3. Verify the Deployment:

    kubectl get deployment
    

    The output is similar to the following:

    NAME                              READY   UP-TO-DATE   AVAILABLE   AGE
    maxengine-server                  1/1     1            1           ##s
    

    For Autopilot clusters, it may take a few minutes to provision the required TPU resources.

  4. View the HTTP server logs to check that the model has been loaded and compiled. It may take the server a few minutes to complete this operation.

    kubectl logs deploy/maxengine-server -f -c jetstream-http
    

    The output is similar to the following:

    INFO:     Started server process [1]
    INFO:     Waiting for application startup.
    INFO:     Application startup complete.
    INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
    
  5. View the MaxEngine logs and verify that the compilation is done.

    kubectl logs deploy/maxengine-server -f -c maxengine-server
    

    The output is similar to the following:

    2024-03-29 17:09:08,047 - jax._src.dispatch - DEBUG - Finished XLA compilation of jit(initialize) in 0.26236414909362793 sec
    2024-03-29 17:09:08,150 - root - INFO - ---------Generate params 0 loaded.---------
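
Some of the arguments in the Deployment manifest are easier to reason about with the arithmetic written out. The following sketch rests on two assumptions that come from reading the parameter descriptions, not from an official sizing formula: that the total decode batch is per_device_batch_size multiplied by the number of TPU devices, and that the tokens left for generation are max_target_length minus the prefill length.

```shell
# Values copied from the Deployment manifest in this tutorial.
per_device_batch_size=4
tpu_chips=8                      # google.com/tpu: 8, one device per chip
max_prefill_predict_length=1024
max_target_length=2048

# Assumption: total batch slots = per-device batch size x number of devices.
total_batch=$((per_device_batch_size * tpu_chips))
# Assumption: tokens left for generation = maximum sequence length - prefill length.
max_new_tokens=$((max_target_length - max_prefill_predict_length))

echo "total decode batch: ${total_batch}"
echo "tokens left after a full-length prefill: ${max_new_tokens}"
```

With the values from the manifest, the sketch prints a total decode batch of 32 and 1024 tokens left after a full-length prefill.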
    

Serve the model

In this section, you interact with the model.

Set up port forwarding

You can access the JetStream Deployment through the ClusterIP Service that you created in the preceding step. ClusterIP Services are reachable only from within the cluster, so to access the Service from outside the cluster, establish a port forwarding session by running the following command:

kubectl port-forward svc/jetstream-svc 8000:8000

Interact with the model using curl

  1. Verify that you can access the JetStream HTTP server by opening a new terminal and running the following command:

    curl --request POST \
    --header "Content-type: application/json" \
    -s \
    localhost:8000/generate \
    --data \
    '{
        "prompt": "What are the top 5 programming languages",
        "max_tokens": 200
    }'
    

    The initial request can take several seconds to complete due to model warmup. The output is similar to the following:

    {
        "response": "\nfor data science in 2023?\n\n**1. Python:**\n- Widely used for data science due to its simplicity, readability, and extensive libraries for data wrangling, analysis, visualization, and machine learning.\n- Popular libraries include pandas, scikit-learn, and matplotlib.\n\n**2. R:**\n- Statistical programming language widely used for data analysis, visualization, and modeling.\n- Popular libraries include ggplot2, dplyr, and caret.\n\n**3. Java:**\n- Enterprise-grade language with strong performance and scalability.\n- Popular libraries include Spark, TensorFlow, and Weka.\n\n**4. C++:**\n- High-performance language often used for data analytics and machine learning models.\n- Popular libraries include TensorFlow, PyTorch, and OpenCV.\n\n**5. SQL:**\n- Relational database language essential for data wrangling and querying large datasets.\n- Popular tools"
    }
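
Typing the JSON body by hand works for a fixed prompt, but a prompt that contains double quotes or backslashes breaks the request. The following sketch builds the body from a prompt and a token limit. The function names json_escape and build_payload are hypothetical, and the escaping covers only backslashes and double quotes, so it suits single-line prompts.

```shell
# Hypothetical helper: escape backslashes and double quotes for use inside a JSON string.
# Limitation: control characters such as newlines are not handled.
json_escape() {
  printf '%s' "$1" | sed -e 's/\\/\\\\/g' -e 's/"/\\"/g'
}

# Hypothetical helper: print a request body with the prompt and max_tokens fields.
build_payload() {
  # $1: prompt text, $2: max_tokens (integer)
  printf '{"prompt": "%s", "max_tokens": %d}\n' "$(json_escape "$1")" "$2"
}
```

You can then pass the result to the same curl command:

    curl --request POST --header "Content-type: application/json" -s localhost:8000/generate --data "$(build_payload "What are the top 5 programming languages" 200)"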
    

(Optional) Interact with the model through a Gradio chat interface

In this section, you build a web chat application that lets you interact with your instruction tuned model.

Gradio is a Python library that has a ChatInterface wrapper that creates user interfaces for chatbots.

Deploy the chat interface

  1. In Cloud Shell, save the following manifest as gradio.yaml:

    apiVersion: apps/v1
    kind: Deployment
    metadata:
      name: gradio
      labels:
        app: gradio
    spec:
      replicas: 1
      selector:
        matchLabels:
          app: gradio
      template:
        metadata:
          labels:
            app: gradio
        spec:
          containers:
          - name: gradio
            image: us-docker.pkg.dev/google-samples/containers/gke/gradio-app:v1.0.3
            resources:
              requests:
                cpu: "512m"
                memory: "512Mi"
              limits:
                cpu: "1"
                memory: "512Mi"
            env:
            - name: CONTEXT_PATH
              value: "/generate"
            - name: HOST
              value: "http://jetstream-svc:8000"
            - name: LLM_ENGINE
              value: "max"
            - name: MODEL_ID
              value: "gemma"
            - name: USER_PROMPT
              value: "<start_of_turn>user\nprompt<end_of_turn>\n"
            - name: SYSTEM_PROMPT
              value: "<start_of_turn>model\nprompt<end_of_turn>\n"
            ports:
            - containerPort: 7860
    ---
    apiVersion: v1
    kind: Service
    metadata:
      name: gradio
    spec:
      selector:
        app: gradio
      ports:
        - protocol: TCP
          port: 8080
          targetPort: 7860
      type: ClusterIP
    
  2. Apply the manifest:

    kubectl apply -f gradio.yaml
    
  3. Wait for the deployment to be available:

    kubectl wait --for=condition=Available --timeout=300s deployment/gradio
    

Use the chat interface

  1. In Cloud Shell, run the following command:

    kubectl port-forward service/gradio 8080:8080
    

    This creates a port forward from Cloud Shell to the Gradio service.

  2. Click the Web Preview button, which is at the top right of the Cloud Shell taskbar. Click Preview on Port 8080. A new tab opens in your browser.

  3. Interact with Gemma using the Gradio chat interface. Add a prompt and click Submit.
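
The USER_PROMPT value in the Gradio manifest shows the turn markers that wrap a user message for the instruction tuned model. The following sketch produces the same shape for a given message, assuming that the chat application substitutes the message for the word prompt in that template. The function name format_user_turn is hypothetical.

```shell
# Hypothetical helper: wrap user text in the turn markers from the USER_PROMPT template.
format_user_turn() {
  # $1: the user's message
  printf '<start_of_turn>user\n%s<end_of_turn>\n' "$1"
}
```

The output has two lines: the opening marker with the user role, then the message followed by the closing marker.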

Troubleshoot issues

  • If you get the message Empty reply from server, it's possible the container has not finished downloading the model data. Check the Pod's logs again for the Connected message which indicates that the model is ready to serve.
  • If you see Connection refused, verify that your port forwarding is active.
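
Because the server can take a few minutes to become ready, retrying a request is often simpler than checking the logs repeatedly. The following sketch is a generic retry loop. The function name retry and its argument order are hypothetical.

```shell
# Hypothetical helper: run a command until it succeeds or the attempts run out.
retry() {
  # $1: maximum attempts, $2: seconds to sleep between attempts, rest: the command
  local attempts="$1" delay="$2" n=1
  shift 2
  while ! "$@"; do
    if [ "$n" -ge "$attempts" ]; then
      echo "giving up after ${n} attempts: $*" >&2
      return 1
    fi
    n=$((n + 1))
    sleep "$delay"
  done
}
```

For example, the following command tries a small request up to 10 times, 5 seconds apart, while the port forwarding session is active:

    retry 10 5 curl --silent --fail --request POST --header "Content-type: application/json" --data '{"prompt": "Hi", "max_tokens": 1}' localhost:8000/generate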

Clean up

To avoid incurring charges to your Google Cloud account for the resources used in this tutorial, either delete the project that contains the resources, or keep the project and delete the individual resources.

Delete the deployed resources

To avoid incurring charges to your Google Cloud account for the resources that you created in this guide, run the following commands and follow the prompts:

gcloud container clusters delete ${CLUSTER_NAME} --region=${REGION}

gcloud iam service-accounts delete wi-jetstream@${PROJECT_ID}.iam.gserviceaccount.com

gcloud storage rm --recursive gs://${BUCKET_NAME}

What's next