This tutorial shows you how to train the Llama 3 8B large language model (LLM) on Google Kubernetes Engine (GKE) using MaxText, Ray Train, and TPUs.
This tutorial provides a complete, end-to-end walkthrough, from configuring the necessary cloud infrastructure to submitting and successfully running the training workload on multi-host TPUs.
This tutorial is for Platform admins and operators and Data and AI specialists who want to learn how to train large models on a distributed, multi-host TPU slice.
Background
The combination of GKE, KubeRay, MaxText, and TPUs provides a powerful and scalable platform for large-scale model training. This section describes the key technologies used in this guide:
JAX
JAX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning.
JAX provides an extensible system of composable function transformations, such as jax.grad, jax.jit, and jax.vmap, and uses the XLA compiler to generate highly optimized code that scales efficiently on accelerators like GPUs and TPUs. The core power of JAX lies in its composability, which lets users combine these transformations to build complex, high-performance numerical programs for distributed execution.
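For example, these transformations compose on ordinary Python functions. The following is a minimal, generic JAX sketch for illustration only; it is unrelated to the MaxText code used later in this tutorial:

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Simple squared-error loss for a linear model.
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)

# Compose transformations: differentiate with respect to w, then JIT-compile with XLA.
grad_fn = jax.jit(jax.grad(loss))

w = jnp.ones(3)
x = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
y = jnp.array([1.0, 2.0])
print(grad_fn(w, x, y))  # Gradient of the loss with respect to w
```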
MaxText
MaxText is a high-performance, open-source large language model (LLM) framework designed for scalability and customizability. MaxText is built on JAX and optimized to run efficiently on Cloud TPUs and GPUs.
TPUs
Tensor Processing Units (TPUs) are custom-designed accelerators created by Google to optimize machine learning workloads. Unlike general-purpose CPUs or parallel-processing GPUs, TPUs are highly specialized for the massive matrix and tensor computations at the foundation of deep learning, making them efficient at this specific task. The primary advantage of TPUs is performance at scale.
This tutorial uses TPU Trillium, which is the sixth generation of TPUs. For more information, see Benefits of using TPU Trillium.
KubeRay
KubeRay is a Kubernetes operator that provides a unified way to deploy, manage, and monitor Ray applications on Kubernetes. The KubeRay operator is installed and managed through the Ray on GKE add-on, which is the recommended way to deploy and manage Ray clusters on GKE.
Objectives
This tutorial shows you how to do the following:
- Set up a GKE cluster with a multi-host TPU node pool.
- Configure KubeRay to manage the distributed training environment.
- Build a custom Docker image that contains MaxText, Ray, and JAX dependencies.
- Create a Python training script that uses Ray Train's JaxTrainer to orchestrate the MaxText training loop across the TPU slice.
- Define a RayCluster custom resource to provision the head and worker nodes with the necessary TPU resources.
- Submit the training Job to the RayCluster and monitor its progress.
- Use Cloud Storage to store model checkpoints.
Before you begin
- Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
- Install the Google Cloud CLI.
- If you're using an external identity provider (IdP), you must first sign in to the gcloud CLI with your federated identity.
- To initialize the gcloud CLI, run the following command:

      gcloud init

- Create or select a Google Cloud project:
  - Create a Google Cloud project:

        gcloud projects create PROJECT_ID

    Replace PROJECT_ID with a name for the Google Cloud project you are creating. To create a project, you need the Project Creator role (roles/resourcemanager.projectCreator), which contains the resourcemanager.projects.create permission. Learn how to grant roles.
  - Select the Google Cloud project that you created:

        gcloud config set project PROJECT_ID

    Replace PROJECT_ID with your Google Cloud project name. Selecting a project doesn't require a specific IAM role; you can select any project that you've been granted a role on.
- Verify that billing is enabled for your Google Cloud project.
- Enable the required API:

      gcloud services enable container.googleapis.com

  To enable APIs, you need the Service Usage Admin IAM role (roles/serviceusage.serviceUsageAdmin), which contains the serviceusage.services.enable permission. Learn how to grant roles.
- Grant roles to your user account. Run the following command once for each of the following IAM roles: roles/container.admin, roles/iam.serviceAccountAdmin:

      gcloud projects add-iam-policy-binding PROJECT_ID --member="user:USER_IDENTIFIER" --role=ROLE

  Replace the following:
  - PROJECT_ID: your project ID.
  - USER_IDENTIFIER: the identifier for your user account. For example, myemail@example.com.
  - ROLE: the IAM role that you grant to your user account.
- Because this tutorial uses TPU Trillium (v6e), select a region or zone where TPU Trillium capacity is available. For more information, see Cloud TPU quotas.
Prepare your environment
In this tutorial, you use Cloud Shell. Cloud Shell comes
preinstalled with the gcloud, helm, and kubectl command-line tools that
are used in this tutorial.
- Go to the Google Cloud console. 
- At the top of the Google Cloud console window, click the Activate Cloud Shell button. A Cloud Shell session opens inside a new frame in the Google Cloud console and displays a command-line prompt.
- Create and activate a Python virtual environment:

      python3 -m venv ray-env
      source ray-env/bin/activate

- Install the Ray CLI and other dependencies:

      pip install "ray[default]==2.49.1"

- Set the following environment variables:

      export PROJECT_ID=$(gcloud config get project)
      export PROJECT_NUMBER=$(gcloud projects describe ${PROJECT_ID} --format="value(projectNumber)")
      export GS_BUCKET=GS_BUCKET
      export KSA_NAME=KSA_NAME
      export NAMESPACE=default
      export CLUSTER_NAME=CLUSTER_NAME
      export REGION=REGION
      export ZONE=ZONE
      export ARTIFACT_REGISTRY=ARTIFACT_REGISTRY

  Replace the following:
  - GS_BUCKET: the name of the Cloud Storage bucket.
  - KSA_NAME: the name of the Kubernetes Service Account.
  - CLUSTER_NAME: the name of the new cluster.
  - REGION: the region where your TPU Trillium capacity is available.
  - ZONE: the zone where your TPU Trillium capacity is available. For more information, see TPU availability in GKE.
  - ARTIFACT_REGISTRY: the name of the Artifact Registry repository.
Create a GKE cluster
You can configure KubeRay on TPUs in a GKE Autopilot or Standard cluster. We recommend that you use an Autopilot cluster for a fully managed Kubernetes experience. To choose the GKE mode of operation that's the best fit for your workloads, see About GKE modes of operation.
Autopilot
- In Cloud Shell, run the following command:

      gcloud container clusters create-auto $CLUSTER_NAME \
          --enable-ray-operator \
          --location=$REGION

- To communicate with your cluster, configure kubectl:

      gcloud container clusters get-credentials $CLUSTER_NAME \
          --location=$REGION
Standard
- In Cloud Shell, create a Standard cluster that enables the Ray operator add-on by running the following command:

      gcloud container clusters create $CLUSTER_NAME \
          --addons=RayOperator,GcsFuseCsiDriver \
          --machine-type=n1-standard-16 \
          --workload-pool=$PROJECT_ID.svc.id.goog \
          --location=$ZONE

  This command also enables the GcsFuseCsiDriver add-on, which allows Pods to mount Cloud Storage buckets as local file systems. The cluster creation might take several minutes.

- To communicate with your cluster, configure kubectl:

      gcloud container clusters get-credentials $CLUSTER_NAME \
          --location=$ZONE

- Create a multi-host TPU slice node pool:

      gcloud container node-pools create v6e-16 \
          --location=$ZONE \
          --cluster=$CLUSTER_NAME \
          --machine-type=ct6e-standard-4t \
          --threads-per-core=1 \
          --tpu-topology=4x4 \
          --num-nodes=4
GKE provisions a node pool of four TPU Trillium (v6e) VMs that are configured together as a multi-host TPU slice with a 4x4 topology, ready for distributed training workloads.
The Ray operator-enabled GKE cluster automatically installs KubeRay and the KubeRay TPU webhook in your cluster.
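If you created a Standard cluster, you can optionally confirm that the multi-host slice nodes registered with the cluster. This check is a sketch that assumes the standard GKE TPU node labels:

```bash
# Optional: list the nodes that back the 4x4 TPU Trillium slice.
# Assumes the standard GKE TPU node labels; a 4x4 topology maps to 4 nodes.
kubectl get nodes \
  -l cloud.google.com/gke-tpu-accelerator=tpu-v6e-slice,cloud.google.com/gke-tpu-topology=4x4
```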
Configure a Cloud Storage bucket and a service account
- Create a Cloud Storage bucket for shared checkpoints between the multi-host TPU nodes:

      gsutil mb -p ${PROJECT_ID} -c STANDARD -l ${REGION} gs://${GS_BUCKET}

- To enable access to the Cloud Storage bucket, create a Kubernetes Service Account:

      kubectl create serviceaccount ${KSA_NAME} --namespace ${NAMESPACE}

- To enable access to the Cloud Storage bucket, add the required IAM policy binding to the service account:

      gcloud storage buckets add-iam-policy-binding gs://${GS_BUCKET} \
          --member "principal://iam.googleapis.com/projects/${PROJECT_NUMBER}/locations/global/workloadIdentityPools/${PROJECT_ID}.svc.id.goog/subject/ns/${NAMESPACE}/sa/${KSA_NAME}" \
          --role "roles/storage.objectUser"
Create the training script
The following script uses Ray Train's JaxTrainer to run a distributed MaxText
training job. The script configures the training environment for a multi-host
TPU slice node pool and runs the MaxText training job on each worker node. The
train_loop_per_worker function wraps the MaxText main entry point, and uses
Ray's distributed scheduler to execute the MaxText trainer on a multi-host
TPU slice.
- Save the following Python script as maxtext_ray_trainer.py:
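  The tutorial's full maxtext_ray_trainer.py isn't reproduced here. The following is a minimal sketch of the shape such a script can take with Ray Train's experimental JaxTrainer in ray[default]==2.49.1; the JaxTrainer import path, the TPU-related ScalingConfig fields, and the MaxText.train module path are assumptions to verify against the versions in your image:

  ```python
  """Illustrative sketch only: a JaxTrainer-based wrapper around MaxText training.

  Assumptions (verify against your Ray and MaxText versions):
    - Ray Train's JaxTrainer and TPU-aware ScalingConfig fields
      (use_tpu, topology, accelerator_type) are available in ray[default]==2.49.1.
    - MaxText is importable as `MaxText.train` inside the container image.
  """
  import sys

  import ray
  from ray.train import RunConfig, ScalingConfig
  from ray.train.v2.jax import JaxTrainer  # assumed import path


  def train_loop_per_worker(config: dict):
      # Each Ray worker runs on one TPU VM host of the slice and calls the
      # MaxText entry point with argv-style arguments.
      from MaxText import train as maxtext_train  # assumed module path

      argv = ["train.py"] + config["maxtext_argv"]
      maxtext_train.main(argv)


  def main(maxtext_argv: list):
      ray.init()

      trainer = JaxTrainer(
          train_loop_per_worker=train_loop_per_worker,
          train_loop_config={"maxtext_argv": maxtext_argv},
          scaling_config=ScalingConfig(
              num_workers=4,                     # one worker per TPU VM host (4x4 topology)
              use_tpu=True,                      # assumed TPU-specific fields
              topology="4x4",
              accelerator_type="TPU-V6E",
              resources_per_worker={"TPU": 4},
          ),
          run_config=RunConfig(name="maxtext-llama3-8b"),
      )
      result = trainer.fit()
      print(f"Training finished: {result}")


  if __name__ == "__main__":
      # The MaxText config path and overrides arrive through `ray job submit`.
      main(sys.argv[1:])
  ```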
- To host the custom image, create an Artifact Registry repository:

      gcloud artifacts repositories create ${ARTIFACT_REGISTRY} \
          --repository-format=docker --location=${REGION} && \
      gcloud auth configure-docker ${REGION}-docker.pkg.dev
- To build an image that includes Ray and MaxText dependencies for training, create a Dockerfile:
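  The tutorial's exact Dockerfile isn't reproduced here; the following sketch shows one way such an image can be assembled. The base image tag, the MaxText repository URL and revision, and the dependency installation commands are assumptions to adapt to your environment:

  ```dockerfile
  # Illustrative sketch only: adjust the base image tag, MaxText revision,
  # and dependency pins to match your environment.
  FROM rayproject/ray:2.49.1-py311

  # The Ray base images run as a non-root user; switch to root for installs.
  USER root
  RUN apt-get update && apt-get install -y --no-install-recommends git && \
      rm -rf /var/lib/apt/lists/*

  WORKDIR /app

  # Fetch MaxText and install its Python dependencies (assumed repository layout).
  RUN git clone https://github.com/AI-Hypercomputer/maxtext.git /app/maxtext && \
      pip install --no-cache-dir -r /app/maxtext/requirements.txt

  # JAX with TPU support, from the libtpu release index documented by JAX.
  RUN pip install --no-cache-dir "jax[tpu]" \
        -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

  # Make the MaxText `src` layout importable and copy in the Ray Train driver.
  ENV PYTHONPATH=/app/maxtext/src
  COPY maxtext_ray_trainer.py /app/maxtext_ray_trainer.py
  ```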
- Build, tag, and push the Docker image to Artifact Registry:

      export DOCKER_IMAGE=${REGION}-docker.pkg.dev/${PROJECT_ID}/${ARTIFACT_REGISTRY}/ray-maxtext:latest
      gcloud builds submit --tag ${DOCKER_IMAGE}
Train the model
- Save the following sample manifest as maxtext-tpu-cluster.yaml:

  This RayCluster spec creates a TPU worker group with four workers (numOfHosts: 4) per replica. Each worker requests four TPU chips (google.com/tpu: "4"). The workers are scheduled on nodes that run TPU Trillium (tpu-v6e-slice) and that belong to the same colocated multi-host slice. KubeRay scales all four workers atomically, and GKE bootstraps the required JAX environment variables, as well as Pod affinities for scheduling, through a mutating webhook.
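  As a reference for the fields described above, here is a sketch of what such a RayCluster manifest can look like. The image reference, resource sizes, the Cloud Storage FUSE volume at /data, and the service account wiring are assumptions that rely on the environment variables exported earlier; the manifest is applied with envsubst in the next step:

  ```yaml
  # Illustrative sketch of a RayCluster for a 4x4 TPU Trillium slice.
  # Field values (image, resources, gcsfuse options) are placeholders to adapt.
  apiVersion: ray.io/v1
  kind: RayCluster
  metadata:
    name: maxtext-tpu-cluster
  spec:
    headGroupSpec:
      rayStartParams: {}
      template:
        metadata:
          annotations:
            gke-gcsfuse/volumes: "true"     # enable the Cloud Storage FUSE sidecar
        spec:
          serviceAccountName: ${KSA_NAME}
          containers:
          - name: ray-head
            image: ${DOCKER_IMAGE}
            resources:
              requests:
                cpu: "4"
                memory: 16Gi
              limits:
                cpu: "4"
                memory: 16Gi
    workerGroupSpecs:
    - groupName: tpu-group
      replicas: 1
      numOfHosts: 4                         # one worker Pod per TPU VM host
      rayStartParams: {}
      template:
        metadata:
          annotations:
            gke-gcsfuse/volumes: "true"
        spec:
          serviceAccountName: ${KSA_NAME}
          nodeSelector:
            cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
            cloud.google.com/gke-tpu-topology: 4x4
          containers:
          - name: ray-worker
            image: ${DOCKER_IMAGE}
            resources:
              requests:
                google.com/tpu: "4"         # four TPU chips per host
              limits:
                google.com/tpu: "4"
            volumeMounts:
            - name: gcs-fuse-checkpoints
              mountPath: /data              # shared checkpoint path used by the training job
          volumes:
          - name: gcs-fuse-checkpoints
            csi:
              driver: gcsfuse.csi.storage.gke.io
              volumeAttributes:
                bucketName: ${GS_BUCKET}
  ```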
- To configure required values in the YAML file, create the RayCluster using envsubst:

      envsubst < maxtext-tpu-cluster.yaml | kubectl apply -f -
- Verify the cluster is ready and running:

      kubectl get rayclusters maxtext-tpu-cluster

  The output should be similar to the following:

      NAME                  DESIRED WORKERS   AVAILABLE WORKERS   CPUS   MEMORY        GPUS   STATUS   AGE
      maxtext-tpu-cluster   4                 4                   40     798027216Ki   0      ready    11m
- To access the Ray Dashboard through the Ray head service, establish a port-forwarding session:

      kubectl port-forward svc/maxtext-tpu-cluster-head-svc 8265:8265 >/dev/null 2>&1 &
- Verify the RayCluster is reachable from your local environment:

      ray list nodes --address http://localhost:8265

  The output should be similar to the following:

      ======== List: 2025-09-13 03:53:16.988269 ========
      Stats:
      ------------------------------
      Total: 5

      Table:
      ------------------------------
          NODE_ID                                                    NODE_IP    IS_HEAD_NODE    STATE    STATE_MESSAGE    NODE_NAME    RESOURCES_TOTAL    LABELS
       0  92c79d04c34b659c1e3044f7642ad3fd47eb16f290785237149fab56   10.84.0.9  (...)
- Submit the JaxTrainer script to the RayCluster and check that the RayJob completes successfully:

      ray job submit \
          --address http://localhost:8265 \
          -- python /app/maxtext_ray_trainer.py \
          /app/maxtext/src/MaxText/configs/base.yml \
          base_output_directory=/data/ \
          dataset_type=synthetic \
          per_device_batch_size=1 \
          max_target_length=4096 \
          model_name=llama3-8b \
          steps=100 \
          ici_fsdp_parallelism=4 \
          ici_tensor_parallelism=4 \
          run_name=rayjob-8b-4096-tp4-4x4

  The preceding command submits the Python script, which calls the JaxTrainer Ray code, to the RayCluster. The ray job submit command includes some MaxText-specific arguments to pass to the model configuration.

  In your terminal, you should see output similar to the following:

      (RayTrainWorker pid=21663, ip=10.12.3.6) completed step: 99, seconds: 1.100, TFLOP/s/device: 179.739, Tokens/s/device: 3725.218, total_weights: 65536, loss: 0.000 [repeated 3x across cluster]
      ------------------------------------------
      Job 'raysubmit_zCrJcWnuymMQv4C3' succeeded
      ------------------------------------------
Clean up
To avoid incurring charges to your Google Cloud account for the resources used in this tutorial, either delete the project that contains the resources, or keep the project and delete the individual resources.
- Delete the RayCluster:

      kubectl delete raycluster maxtext-tpu-cluster

- Delete the GKE cluster:

      gcloud container clusters delete $CLUSTER_NAME --zone=$ZONE

- Delete the Cloud Storage bucket:

      gsutil rm -r gs://${GS_BUCKET}

- Delete the Artifact Registry repository:

      gcloud artifacts repositories delete ${ARTIFACT_REGISTRY} --location=${REGION} --quiet
What's next
- Learn about Ray on Kubernetes.
- Learn how to Serve vLLM on GKE with TPUs.
- Learn how to Serve SDXL on GKE with TPUs.
- Learn more About TPUs in GKE.