Multihost inference is a method of running model inference that distributes a model across multiple accelerator hosts. This enables inference for large models that don't fit on a single host. Pathways can be deployed for both batch and real-time multihost inference use cases.
Before you begin
Make sure you have:
- Created a GKE cluster that uses Trillium chips (v6e-16).
- Installed Kubernetes tools.
- Enabled the TPU API.
- Enabled the Google Kubernetes Engine API.
- Ensured your Google Cloud project is allowlisted for Pathways.
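If you still need to set up these prerequisites, the following is a minimal sketch, assuming hypothetical `$PROJECT`, `$CLUSTER`, and `$ZONE` values and default cluster networking (Pathways allowlisting is handled through Google Cloud rather than a CLI command):

```bash
# Enable the TPU and GKE APIs.
gcloud services enable tpu.googleapis.com container.googleapis.com \
  --project=$PROJECT

# Create a cluster, then a v6e-16 node pool. A ct6e-standard-4t node
# hosts 4 TPU chips, so a 4x4 (16-chip) topology spans 4 nodes.
gcloud container clusters create $CLUSTER --project=$PROJECT --zone=$ZONE
gcloud container node-pools create v6e-16-pool \
  --cluster=$CLUSTER --project=$PROJECT --zone=$ZONE \
  --machine-type=ct6e-standard-4t \
  --tpu-topology=4x4 \
  --num-nodes=4
```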
Run batch inference using JetStream
JetStream is a throughput- and memory-optimized engine, written in JAX, for large language model (LLM) inference on XLA devices, primarily Tensor Processing Units (TPUs).
You can use a prebuilt JetStream Docker image to run a batch inference workload,
as shown in the following YAML. This container is built from the
OSS JetStream project.
For more information about MaxText-JetStream flags, see
JetStream MaxText server flags.
The following example uses Trillium chips (`v6e-16`) to load the Llama3.1-405b `int8` checkpoint and perform inference over it. This example assumes you already have a GKE cluster with at least one `v6e-16` node pool.
Start the model server and Pathways
- Get credentials to the cluster and add them to your local kubectl context.
```bash
gcloud container clusters get-credentials $CLUSTER \
  --zone=$ZONE \
  --project=$PROJECT && \
kubectl config set-context --current --namespace=default
```
- Install the `PathwaysJob` API on your GKE cluster:
```bash
kubectl apply --server-side -f https://github.com/kubernetes-sigs/jobset/releases/download/v0.8.0/manifests.yaml
kubectl apply --server-side -f https://github.com/google/pathways-job/releases/download/v0.1.0/install.yaml
```
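To confirm the installation, you can check that both custom resource definitions are registered; the exact CRD names here are an assumption based on the manifests above:

```bash
# Expect one JobSet CRD and one PathwaysJob CRD in the output.
kubectl get crd | grep -E 'jobset|pathwaysjob'
```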
- Copy and paste the following YAML into a file named `pathways-job.yaml`. This YAML has been optimized for the `v6e-16` slice shape. For more information about how to convert a Meta checkpoint into a JAX-compatible checkpoint, see Checkpoint conversion with Llama3.1-405B.

```yaml
apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-USERNAME
spec:
  maxRestarts: MAX_RESTARTS
  workers:
  - type: TPU_MACHINE_TYPE # ct6e-standard-4t for this test
    topology: TOPOLOGY # 4x4 for this test
    numSlices: WORKLOAD_NODEPOOL_COUNT # 1 for this test
  pathwaysDir: "gs://BUCKET_NAME" # Pre-create this bucket.
  controller:
    deploymentMode: colocate_head_with_workers
    mainContainerName: jetstream
    template:
      spec:
        containers:
        - name: jetstream
          image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-pathways:v0.2.0
          args:
          - MaxText/configs/v5e/inference/llama3_405b_v5e-64.yml
          - model_name=llama3.1-405b
          - load_parameters_path=GCS_CHECKPOINT_PATH
          - max_prefill_predict_length=1024
          - max_target_length=2048
          - async_checkpointing=false
          - steps=1
          - ici_fsdp_parallelism=1
          - ici_autoregressive_parallelism=2
          - ici_tensor_parallelism=8
          - scan_layers=false
          - weight_dtype=bfloat16
          - per_device_batch_size=10
          - enable_single_controller=true
          - quantization=int8
          - quantize_kvcache=true
          - checkpoint_is_quantized=true
          - enable_model_warmup=true
          imagePullPolicy: Always
          ports:
          - containerPort: 9000
        - name: jetstream-http
          image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.3
          imagePullPolicy: Always
          ports:
          - containerPort: 8000
```

Replace the following:
- `USERNAME`: your Google Cloud user ID
- `JAX_VERSION`: the version of JAX you are using
- `MAX_RESTARTS`: the maximum number of times the PathwaysJob can be restarted
- `TPU_MACHINE_TYPE`: the TPU machine type
- `TOPOLOGY`: the TPU topology
- `WORKLOAD_NODEPOOL_COUNT`: the number of node pools used by a Pathways workload
- `CUSTOM_PROXY_SERVER_IMAGE`: a custom Pathways proxy server image
- `CUSTOM_PATHWAYS_SERVER_IMAGE`: a custom Pathways server image
- `CUSTOM_PATHWAYS_WORKER_IMAGE`: a custom Pathways worker image
- `BUCKET_NAME`: the Cloud Storage bucket for temporary files
- `GCS_CHECKPOINT_PATH`: the converted checkpoint location; it should be similar to `gs://OUTPUT_BUCKET_DIRECTORY/bf16/unscanned/0/items` for the `bf16` checkpoint or `gs://OUTPUT_BUCKET_DIRECTORY/int8` for an `int8` checkpoint
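One way to fill in the placeholders and submit the job is a quick `sed` pass; the values below are purely illustrative:

```bash
# Substitute example values for the placeholders, then apply the manifest.
sed -e 's/USERNAME/alice/g' \
    -e 's/MAX_RESTARTS/4/g' \
    -e 's/TPU_MACHINE_TYPE/ct6e-standard-4t/g' \
    -e 's/TOPOLOGY/4x4/g' \
    -e 's/WORKLOAD_NODEPOOL_COUNT/1/g' \
    -e 's#BUCKET_NAME#my-pathways-bucket#g' \
    -e 's#GCS_CHECKPOINT_PATH#gs://my-bucket/int8#g' \
    pathways-job.yaml | kubectl apply -f -
```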
- Look at the Kubernetes logs to see if the JetStream model server is ready:

```bash
kubectl get pods | grep pathways-USERNAME-pathways-head-0-0-uuid # find uuid from here
kubectl logs -f pathways-USERNAME-pathways-head-0-0-uuid -c jetstream
```

The output is similar to the following, which indicates the JetStream model server is ready to serve requests:

```
2025-03-02 02:15:07,682 - JetstreamLogger - INFO - Initializing the driver with 1 prefill engines and 1 generate engines in interleaved mode
2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up prefill thread 0.
2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up transfer thread 0.
2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up generate thread 0.
2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up detokenize thread 0.
2025-03-02 02:15:07,685 - JetstreamLogger - INFO - Driver initialized.
...
...
...
INFO: Started server process [7]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: Uvicorn running on http://0.0.0.0:9999 (Press CTRL+C to quit)
```
Connect to the model server
You can access the JetStream Pathways deployment using GKE's ClusterIP service. The ClusterIP service is only reachable from within the cluster. Therefore, to access the service from outside the cluster, you must first establish a port-forwarding session by running the following command:
```bash
kubectl port-forward pod/pathways-USERNAME-pathways-head-0-0-UUID 8000:8000
```
Verify that you can access the JetStream HTTP server by opening a new terminal and running the following command:
```bash
curl --request POST \
  --header "Content-type: application/json" \
  -s \
  localhost:8000/generate \
  --data \
  '{
    "prompt": "What are the top 5 programming languages",
    "max_tokens": 200
  }'
```
The initial request can take several seconds to complete due to model warmup. The output should be similar to the following:
```json
{
  "response": " for web development?\nThe top 5 programming languages for web development are:\n1. **JavaScript**: JavaScript is the most popular language for web development, used by over 90% of websites for client-side scripting. It's also popular for server-side programming with technologies like Node.js.\n2. **HTML/CSS**: HTML (Hypertext Markup Language) and CSS (Cascading Style Sheets) are not programming languages, but are essential for building websites. HTML is used for structuring content, while CSS is used for styling and layout.\n3. **Python**: Python is a popular language for web development, especially with frameworks like Django and Flask. It's known for its simplicity, flexibility, and large community of developers.\n4. **Java**: Java is a popular language for building enterprise-level web applications, especially with frameworks like Spring and Hibernate. It's known for its platform independence, strong security features, and large community of developers.\n5. **PHP**: PHP is a mature language for web"
}
```
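Because this deployment targets batch inference, you might drive the server with a simple loop over a prompt file; `prompts.txt` (one prompt per line) is a hypothetical input, and the naive quoting assumes the prompts need no JSON escaping:

```bash
# POST each line of prompts.txt to the /generate endpoint.
while IFS= read -r prompt; do
  curl -s --request POST \
    --header "Content-type: application/json" \
    localhost:8000/generate \
    --data "{\"prompt\": \"${prompt}\", \"max_tokens\": 200}"
  echo
done < prompts.txt
```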
Disaggregated inference
Disaggregated serving is a technique for running large language models (LLMs) that separates the prefill and decode stages into different processes, potentially on different machines. This allows for better utilization of resources and can lead to improvements in performance and efficiency, especially for large models.
- Prefill: This stage processes the input prompt and generates an intermediate representation (such as a key-value cache). It's often compute-intensive.
- Decode: This stage generates the output tokens, one by one, using the prefill representation. It's typically memory-bandwidth bound.
By separating these stages, disaggregated serving allows prefill and decode to run in parallel on dedicated resources, improving throughput and reducing latency.
To enable disaggregated serving, modify the previous YAML to use two `v6e-8` slices: one for prefill and the other for generate. Before proceeding, ensure your GKE cluster has at least two node pools configured with this `v6e-8` topology.
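As a sketch, the two slices could be provisioned as follows; the pool names are hypothetical, and a `ct6e-standard-4t` node hosts 4 chips, so a `2x4` (8-chip) topology spans 2 nodes:

```bash
# Create one v6e-8 node pool for prefill and one for generate.
for pool in v6e-8-prefill v6e-8-generate; do
  gcloud container node-pools create $pool \
    --cluster=$CLUSTER --project=$PROJECT --zone=$ZONE \
    --machine-type=ct6e-standard-4t \
    --tpu-topology=2x4 \
    --num-nodes=2
done
```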
For optimal performance, specific XLA flags have been configured.
- To launch the JetStream server in disaggregated mode using Pathways, copy and paste the following YAML into a file named `pathways-job.yaml`:

```yaml
apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-USERNAME
spec:
  maxRestarts: MAX_RESTARTS
  workers:
  - type: ct6e-standard-4t
    topology: 2x4
    numSlices: 2
  pathwaysDir: "gs://BUCKET_NAME"
  controller:
    deploymentMode: colocate_head_with_workers
    mainContainerName: jetstream
    template:
      spec:
        containers:
        - name: jetstream
          image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-pathways:v0.2.0
          args:
          - MaxText/configs/base.yml
          - tokenizer_path=assets/tokenizer.llama2
          - load_parameters_path=gs://GCS_CHECKPOINT_PATH
          - max_prefill_predict_length=1024
          - max_target_length=2048
          - model_name=llama2-70b
          - ici_fsdp_parallelism=1
          - ici_autoregressive_parallelism=1
          - ici_tensor_parallelism=-1
          - scan_layers=false
          - weight_dtype=bfloat16
          - per_device_batch_size=1
          - checkpoint_is_quantized=true
          - quantization=int8
          - quantize_kvcache=true
          - compute_axis_order=0,2,1,3
          - ar_cache_axis_order=0,2,1,3
          - stack_prefill_result_cache=True
          - inference_server=ExperimentalMaxtextDisaggregatedServer_8
          - inference_benchmark_test=True
          - enable_model_warmup=True
          imagePullPolicy: Always
          securityContext:
            capabilities:
              add: ["SYS_PTRACE", "NET_ADMIN", "SYS_TIME"]
          ports:
          - containerPort: 9000
        - name: jetstream-http
          image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.3
          imagePullPolicy: Always
          ports:
          - containerPort: 8000
```

Replace the following:
- `USERNAME`: your Google Cloud user ID
- `JAX_VERSION`: the JAX version
- `MAX_RESTARTS`: the maximum number of times the PathwaysJob can be restarted
- `TPU_MACHINE_TYPE`: the [TPU machine type](/kubernetes-engine/docs/concepts/tpus#machine_type)
- `TOPOLOGY`: the TPU topology
- `WORKLOAD_NODEPOOL_COUNT`: the number of node pools used by a Pathways workload
- `CUSTOM_PROXY_SERVER_IMAGE`: a custom Pathways proxy server image
- `CUSTOM_PATHWAYS_SERVER_IMAGE`: a custom Pathways server image
- `CUSTOM_PATHWAYS_WORKER_IMAGE`: a custom Pathways worker image
- `CUSTOM_PYTHON_SIDECAR_IMAGE`: a custom Python sidecar server image
- `BUCKET_NAME`: the Cloud Storage bucket for temporary files; create this bucket before applying the YAML
- Apply this YAML. The model server will take some time to restore the checkpoint; for the 70B model, this may take about two minutes.

```bash
kubectl apply -f pathways-job.yaml
```
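Rather than polling by hand, you can block until the head pod reports ready; the pod-name pattern here mirrors the logs step below:

```bash
# Wait up to 10 minutes for the head pod to become Ready
# (checkpoint restore alone takes ~2 minutes for the 70B model).
HEAD_POD=$(kubectl get pods -o name | grep pathways-USERNAME-pathways-head)
kubectl wait --for=condition=Ready "$HEAD_POD" --timeout=10m
```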
- Look at the Kubernetes logs to see if the JetStream model server is ready:

```bash
HEAD_POD=$(kubectl get pods | grep pathways-USERNAME-pathways-head | awk '{print $1}')
kubectl logs -f ${HEAD_POD} -c jetstream
```

You will see output similar to the following, which indicates the JetStream model server is ready to serve requests:

```
2025-03-02 02:15:07,682 - JetstreamLogger - INFO - Initializing the driver with 1 prefill engines and 1 generate engines in interleaved mode
2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up prefill thread 0.
2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up transfer thread 0.
2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up generate thread 0.
2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up detokenize thread 0.
2025-03-02 02:15:07,685 - JetstreamLogger - INFO - Driver initialized.
...
...
...
INFO: Started server process [7]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: Uvicorn running on http://0.0.0.0:9999 (Press CTRL+C to quit)
```
Connect to the model server
You can access the JetStream Pathways deployment through GKE's ClusterIP service. The ClusterIP service is only reachable from within the cluster. Therefore, to access the service from outside the cluster, establish a port-forwarding session by running the following command:
```bash
kubectl port-forward pod/pathways-USERNAME-pathways-head-0-0-UUID 8000:8000
```
Verify that you can access the JetStream HTTP server by opening a new terminal and running the following command:
```bash
curl --request POST \
  --header "Content-type: application/json" \
  -s \
  localhost:8000/generate \
  --data \
  '{
    "prompt": "What are the top 5 programming languages",
    "max_tokens": 200
  }'
```
The initial request can take several seconds to complete due to model warmup. The output should be similar to the following:
```json
{
  "response": " used in software development?\nThe top 5 programming languages used in software development are:\n\n1. Java: Java is a popular programming language used for developing enterprise-level applications, Android apps, and web applications. Its platform independence and ability to run on any device that has a Java Virtual Machine (JVM) installed make it a favorite among developers.\n2. Python: Python is a versatile language that is widely used in software development, data analysis, artificial intelligence, and machine learning. Its simplicity, readability, and ease of use make it a popular choice among developers.\n3. JavaScript: JavaScript is a widely used programming language for web development, allowing developers to create interactive client-side functionality for web applications. It is also used for server-side programming, desktop and mobile application development, and game development.\n4. C++: C++ is a high-performance programming language used for developing operating systems, games, and other high-performance applications."
}
```
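When you're finished, delete the PathwaysJob to release the TPU slices for other workloads:

```bash
kubectl delete -f pathways-job.yaml
```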