Pathways interactive workloads are remote JAX workloads that run in a VM that is not part of the GKE cluster hosting the Pathways cluster. Unlike batch workloads, an interactive workload does not shut down the Pathways cluster components when it completes; the components remain available for other JAX clients to connect to. This document uses a Jupyter notebook as an example to demonstrate interactive workloads.
Using the IFRT interface, JAX users send commands to a Pathways cluster. JAX code, whether executed from a terminal, notebook, or any Python-compatible environment, can seamlessly interact with Pathways resources.
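For example, a minimal client-side sketch (assuming the pathwaysutils package is installed and a proxy endpoint is already reachable at 127.0.0.1:29000, an example address) looks like the following; the sections below show how to expose that endpoint:

import os

# Point JAX at the Pathways IFRT proxy. The address is an example; replace it
# with your proxy endpoint once it is reachable.
os.environ["JAX_PLATFORMS"] = "proxy"
os.environ["JAX_BACKEND_TARGET"] = "grpc://127.0.0.1:29000"

import pathwaysutils
import jax

pathwaysutils.initialize()
print(jax.devices())  # Lists the TPU devices served by the Pathways cluster.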
Before you begin
Make sure you have:
- Created a GKE cluster using XPK
- Installed XPK
- Installed Kubernetes tools
- Installed the gcloud CLI
- Enabled the TPU API
- Enabled the Google Kubernetes Engine API
- Ensured that Pathways is enabled for your Google Cloud project
Run Pathways in interactive mode
You can run Pathways in interactive mode using xpk or gcloud.
XPK
Set the following environment variables:
export WORKLOAD=WORKLOAD
export WORKLOAD_NODEPOOL_COUNT=WORKLOAD_NODEPOOL_COUNT
export TPU_TYPE=TPU_TYPE
export PROJECT=PROJECT
export ZONE=ZONE
export CLUSTER=CLUSTER
Replace the following:
- WORKLOAD: a unique name to identify your workload
- WORKLOAD_NODEPOOL_COUNT: the number of node pools used by a Pathways workload
- TPU_TYPE: the TPU type, which specifies the version and size of the Cloud TPU you want to create. For more information about supported TPU types for each TPU version, see TPU versions.
- PROJECT: your Google Cloud project ID
- ZONE: the zone where you plan to run your workload
- CLUSTER: the name of your GKE cluster
Create the Pathways containers on the cluster. To run a headless workload, run the following command:
xpk workload create-pathways \
  --headless \
  --workload=${WORKLOAD} \
  --num-slices=${WORKLOAD_NODEPOOL_COUNT} \
  --tpu-type=${TPU_TYPE} \
  --project=${PROJECT} \
  --zone=${ZONE} \
  --cluster=${CLUSTER}
At this point, your JAX workload can connect to the IFRT proxy server.
gcloud
The following YAML is the same as the batch workload YAML except it doesn't specify the main container.
- Replace the placeholders, copy the following YAML, and paste it into a file called pathways-headless-workload.yaml:

apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-USERNAME
spec:
  maxRestarts: MAX_RESTARTS
  workers:
  - type: TPU_MACHINE_TYPE
    topology: TOPOLOGY
    numSlices: WORKLOAD_NODEPOOL_COUNT
  pathwaysDir: "gs://BUCKET_NAME"
  controller:
    deploymentMode: default

Replace the following:
- USERNAME: your username
- MAX_RESTARTS: the maximum number of times the PathwaysJob can be restarted
- TPU_MACHINE_TYPE: the TPU machine type you want to use
- TOPOLOGY: the TPU topology
- WORKLOAD_NODEPOOL_COUNT: the number of node pools used by a Pathways workload
- BUCKET_NAME: a Cloud Storage bucket used to store temporary files
If you change the WORKLOAD_NODEPOOL_COUNT in the previous YAML, you need to delete this PathwaysJob and create a new PathwaysJob with the updated number of node pools. You also need to restart any connected notebooks to establish the connection with the new Pathways cluster.
- Apply the pathways-headless-workload.yaml file:

kubectl apply -f pathways-headless-workload.yaml
- Run kubectl get pods to check that all containers in the Pod are running. The following output is for a 2-slice v5p 2x2x2, where USER is the ID of the user running the command:

NAME                                       READY   STATUS    RESTARTS   AGE
pathways-USER-pathways-head-0-0-n848j      2/2     Running   0          49s
pathways-USER-pathways-workers-0-0-jxt2z   1/1     Running   0          71s
pathways-USER-pathways-workers-0-1-cxmhc   1/1     Running   0          70s
pathways-USER-pathways-workers-1-0-5kmz9   1/1     Running   0          71s
pathways-USER-pathways-workers-1-1-vg5n4   1/1     Running   0          71s
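If you prefer to check the Pod status programmatically rather than reading kubectl output, the following is a minimal sketch using the Kubernetes Python client (an assumption: the kubernetes package is installed, your kubeconfig points at the cluster, and the Pods run in the default namespace):

from kubernetes import client, config

# Load the local kubeconfig (the same credentials kubectl uses).
config.load_kube_config()
v1 = client.CoreV1Api()

# List Pathways Pods in the default namespace and report any that aren't Running yet.
pods = v1.list_namespaced_pod(namespace="default")
not_running = [
    pod.metadata.name
    for pod in pods.items
    if "pathways" in pod.metadata.name and pod.status.phase != "Running"
]
print("All Pathways Pods are running" if not not_running else f"Waiting on: {not_running}")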
Connecting to the Pathways cluster in interactive mode
You can connect to the Pathways cluster with or without port forwarding. Use one of the following sections to connect to the Pathways cluster.
Connect using port-forwarding
At this point you can use port-forwarding (from any host with access to your cluster's control plane) to access the proxy server:
Use the command appropriate for your workload:
XPK
PROXY_POD=$(kubectl get pods | grep ${WORKLOAD}-pathways-head | awk '{print $1}')
PROXY_PORT=29000
kubectl port-forward ${PROXY_POD} ${PROXY_PORT}:${PROXY_PORT}
You should see output similar to:
Forwarding from 127.0.0.1:29000 -> 29000
Forwarding from [::1]:29000 -> 29000
gcloud
PROXY_POD=$(kubectl get pods | grep pathways-${USER}-pathways-head | awk '{print $1}')
PROXY_PORT=29000
kubectl port-forward ${PROXY_POD} ${PROXY_PORT}:${PROXY_PORT}
On the same host, open a new terminal window. Set the JAX_PLATFORMS and JAX_BACKEND_TARGET environment variables, and run a Python script that imports pathwaysutils and jax:
python3 -m venv .venv
source .venv/bin/activate
pip install pathwaysutils jax
JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 python -c 'import pathwaysutils; import jax; import pprint; pathwaysutils.initialize(); pprint.pprint(jax.devices())'
You should see output like the following:
[device(144,TPU_DEVICE,coords=[0,0,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3),
device(145,TPU_DEVICE,coords=[1,0,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3),
device(146,TPU_DEVICE,coords=[0,1,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3),
device(147,TPU_DEVICE,coords=[1,1,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3),
device(148,TPU_DEVICE,coords=[0,0,1,0],vtask=1,slice=0,default_mem=device,mem_spaces=3),
device(149,TPU_DEVICE,coords=[1,0,1,0],vtask=1,slice=0,default_mem=device,mem_spaces=3),
device(150,TPU_DEVICE,coords=[0,1,1,0],vtask=1,slice=0,default_mem=device,mem_spaces=3),
device(151,TPU_DEVICE,coords=[1,1,1,0],vtask=1,slice=0,default_mem=device,mem_spaces=3),
device(162,TPU_DEVICE,coords=[0,0,0,0],vtask=0,slice=1,default_mem=device,mem_spaces=3),
device(163,TPU_DEVICE,coords=[1,0,0,0],vtask=0,slice=1,default_mem=device,mem_spaces=3),
device(164,TPU_DEVICE,coords=[0,1,0,0],vtask=0,slice=1,default_mem=device,mem_spaces=3),
device(165,TPU_DEVICE,coords=[1,1,0,0],vtask=0,slice=1,default_mem=device,mem_spaces=3),
device(166,TPU_DEVICE,coords=[0,0,1,0],vtask=1,slice=1,default_mem=device,mem_spaces=3),
device(167,TPU_DEVICE,coords=[1,0,1,0],vtask=1,slice=1,default_mem=device,mem_spaces=3),
device(168,TPU_DEVICE,coords=[0,1,1,0],vtask=1,slice=1,default_mem=device,mem_spaces=3),
device(169,TPU_DEVICE,coords=[1,1,1,0],vtask=1,slice=1,default_mem=device,mem_spaces=3)]
Waiting up to 5 seconds.
Sent all pending logs.
2024-11-13 21:38:51.267523: W external/xla/xla/python/ifrt_proxy/client/grpc_client.cc:63] IFRT proxy server disconnected: CANCELLED: Cancelled
Connect from hosts in the VPC without using port forwarding
If you don't want to use port forwarding, you can connect to the Pathways cluster using Cloud DNS or an internal load balancer.
Connect using Cloud DNS
Enabling Cloud DNS in your cluster switches the DNS provider from kube-dns to Cloud DNS. When enabled, a private Cloud DNS zone is created in your Virtual Private Cloud for the Cloud DNS names. For more information, see Using Cloud DNS for GKE.
If you enable Cloud DNS with either cluster scope plus additive VPC scope, or VPC scope, the Kubernetes Cloud DNS names are resolvable from non-GKE VMs inside your Virtual Private Cloud. The names have the format <service_name>.<namespace>.svc.<custom_dns_domain>. The Pathways head Pod is resolvable at <jobset_name>-pathways-head-0-0.<jobset_name>.<namespace>.svc.<custom_dns_domain>.
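As an illustration, the following sketch builds the JAX_BACKEND_TARGET string from those parts; the jobset name, namespace, custom domain, and port below are example values, so substitute your own:

# Build the gRPC target for the Pathways head Pod from its Cloud DNS name.
# These values are examples; replace them with your jobset name, namespace,
# and the custom DNS domain you configured for additive VPC scope.
jobset_name = "pathways-USERNAME"
namespace = "default"
custom_dns_domain = "USERNAME-test"
proxy_port = 29000

backend_target = (
    f"grpc://{jobset_name}-pathways-head-0-0."
    f"{jobset_name}.{namespace}.svc.{custom_dns_domain}:{proxy_port}"
)
print(backend_target)
# grpc://pathways-USERNAME-pathways-head-0-0.pathways-USERNAME.default.svc.USERNAME-test:29000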
The following example shows how to:
- Enable Cloud DNS with cluster scope
- Enable additive VPC scope with custom domain ${USER}-test
- Create a CPU-only VM in the same VPC as the GKE cluster.
On the CPU-only VM, open a terminal window and install pathwaysutils and jax in a virtual environment:

python3 -m venv .venv
source .venv/bin/activate
pip install pathwaysutils jax
Confirm the leader Cloud DNS entry is resolvable from a non-GKE host:
host pathways-USERNAME-pathways-head-0-0.pathways-USERNAME.default.svc.USERNAME-test
You should see output similar to:
pathways-<user>-pathways-head-0-0.pathways-<user>.default.svc.<user>-test has address 10.0.2.75
Connect to the Pathways cluster using the Cloud DNS name:
JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://pathways-USERNAME-pathways-head-0-0.pathways-USERNAME.default.svc.USERNAME-test:29000 python -c 'import pathwaysutils; import jax; import pprint; pathwaysutils.initialize(); pprint.pprint(jax.devices())'
You should see output similar to:
[device(216,TPU_DEVICE,coords=[0,0,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3),
device(217,TPU_DEVICE,coords=[1,0,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3),
device(218,TPU_DEVICE,coords=[0,1,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3),
device(219,TPU_DEVICE,coords=[1,1,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3),
device(220,TPU_DEVICE,coords=[0,0,1,0],vtask=1,slice=0,default_mem=device,mem_spaces=3),
device(221,TPU_DEVICE,coords=[1,0,1,0],vtask=1,slice=0,default_mem=device,mem_spaces=3),
device(222,TPU_DEVICE,coords=[0,1,1,0],vtask=1,slice=0,default_mem=device,mem_spaces=3),
device(223,TPU_DEVICE,coords=[1,1,1,0],vtask=1,slice=0,default_mem=device,mem_spaces=3),
device(234,TPU_DEVICE,coords=[0,0,0,0],vtask=0,slice=1,default_mem=device,mem_spaces=3),
device(235,TPU_DEVICE,coords=[1,0,0,0],vtask=0,slice=1,default_mem=device,mem_spaces=3),
device(236,TPU_DEVICE,coords=[0,1,0,0],vtask=0,slice=1,default_mem=device,mem_spaces=3),
device(237,TPU_DEVICE,coords=[1,1,0,0],vtask=0,slice=1,default_mem=device,mem_spaces=3),
device(238,TPU_DEVICE,coords=[0,0,1,0],vtask=1,slice=1,default_mem=device,mem_spaces=3),
device(239,TPU_DEVICE,coords=[1,0,1,0],vtask=1,slice=1,default_mem=device,mem_spaces=3),
device(240,TPU_DEVICE,coords=[0,1,1,0],vtask=1,slice=1,default_mem=device,mem_spaces=3),
device(241,TPU_DEVICE,coords=[1,1,1,0],vtask=1,slice=1,default_mem=device,mem_spaces=3)]
Waiting up to 5 seconds.
Sent all pending logs.
2024-11-14 00:02:49.882044: W external/xla/xla/python/ifrt_proxy/client/grpc_client.cc:63] IFRT proxy server disconnected: CANCELLED: Cancelled
Connect using an internal load balancer
To get a private IP address in your VPC that points to your Pathways deployment, create a Service backed by an internal load balancer. This does not require your cluster to have Cloud DNS enabled.
For clusters with many VMs, we recommend that you enable ILB subsetting when you create internal load balancers. For more information, see Enable GKE subsetting in an existing cluster. When ILB subsetting is not enabled, all nodes in the cluster are part of the backend instance group for every internal load balancer, which does not scale beyond 250 nodes. With ILB subsetting enabled, GKE creates network endpoint groups instead of instance groups, and only nodes that are running one of the Service's serving Pods are included. Enabling ILB subsetting has a one-time setup latency (~15 minutes). The following command shows how to enable ILB subsetting:
gcloud container clusters update ${CLUSTER} \
  --project=${PROJECT} \
  [--zone=${ZONE} | --region=${REGION}] \
  --enable-l4-ilb-subsetting
Once ILB subsetting is enabled, you can create a Kubernetes Service of type LoadBalancer using the following YAML. This causes GKE to create an internal load balancer inside your cluster's VPC:
apiVersion: v1
kind: Service
metadata:
  name: pathways-USERNAME-ilb
  annotations:
    networking.gke.io/load-balancer-type: "Internal"
    networking.gke.io/internal-load-balancer-allow-global-access: "true"
spec:
  type: LoadBalancer
  externalTrafficPolicy: Local
  selector:
    jobset.sigs.k8s.io/jobset-name: pathways-USERNAME
    jobset.sigs.k8s.io/replicatedjob-name: pathways-head
  ports:
  - name: tcp-port
    protocol: TCP
    port: 29000
    targetPort: 29000
Replace USERNAME with your Google Cloud user ID and save the file as pathways-headless-ilb.yaml.
Apply the manifest:
kubectl apply -f pathways-headless-ilb.yaml
After the load balancer is created (about a minute), the EXTERNAL-IP column shows a value:
kubectl get services
NAME                 TYPE           CLUSTER-IP      EXTERNAL-IP   PORT(S)        AGE
pathways-$USER       ClusterIP      None            <none>        <none>         30m
pathways-$USER-ilb   LoadBalancer   34.118.232.46   10.0.0.22     80:31246/TCP   2m41s
You can access the Pathways deployment without port forwarding from hosts in the same VPC as your cluster:
JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://10.0.0.22:29000 python -c 'import pathwaysutils; import jax; import pprint; pathwaysutils.initialize(); pprint.pprint(jax.devices())'
You should see output similar to:
[device(288,TPU_DEVICE,coords=[0,0,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3),
device(289,TPU_DEVICE,coords=[1,0,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3),
device(290,TPU_DEVICE,coords=[0,1,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3),
device(291,TPU_DEVICE,coords=[1,1,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3),
device(292,TPU_DEVICE,coords=[0,0,1,0],vtask=1,slice=0,default_mem=device,mem_spaces=3),
device(293,TPU_DEVICE,coords=[1,0,1,0],vtask=1,slice=0,default_mem=device,mem_spaces=3),
device(294,TPU_DEVICE,coords=[0,1,1,0],vtask=1,slice=0,default_mem=device,mem_spaces=3),
device(295,TPU_DEVICE,coords=[1,1,1,0],vtask=1,slice=0,default_mem=device,mem_spaces=3),
device(306,TPU_DEVICE,coords=[0,0,0,0],vtask=0,slice=1,default_mem=device,mem_spaces=3),
device(307,TPU_DEVICE,coords=[1,0,0,0],vtask=0,slice=1,default_mem=device,mem_spaces=3),
device(308,TPU_DEVICE,coords=[0,1,0,0],vtask=0,slice=1,default_mem=device,mem_spaces=3),
device(309,TPU_DEVICE,coords=[1,1,0,0],vtask=0,slice=1,default_mem=device,mem_spaces=3),
device(310,TPU_DEVICE,coords=[0,0,1,0],vtask=1,slice=1,default_mem=device,mem_spaces=3),
device(311,TPU_DEVICE,coords=[1,0,1,0],vtask=1,slice=1,default_mem=device,mem_spaces=3),
device(312,TPU_DEVICE,coords=[0,1,1,0],vtask=1,slice=1,default_mem=device,mem_spaces=3),
device(313,TPU_DEVICE,coords=[1,1,1,0],vtask=1,slice=1,default_mem=device,mem_spaces=3)]
Waiting up to 5 seconds.
Sent all pending logs.
2024-11-14 00:30:07.296917: W external/xla/xla/python/ifrt_proxy/client/grpc_client.cc:63] IFRT proxy server disconnected: CANCELLED: Cancelled
Jupyter notebooks
You can create a Jupyter notebook using Vertex AI, or you can create a self-hosted Jupyter notebook.
Create a Vertex AI workbench instance
After setting up and verifying your Pathways cluster, you can access the GKE TPU VMs from a Vertex AI Jupyter notebook. The following setup instructions assume your GKE Pathways cluster resides in the same Virtual Private Cloud network as your Workbench instance (the default network unless you've configured otherwise). Navigate to the Vertex AI Workbench console.
Create a new Workbench instance from the Instances tab with the Create new button. Ensure that the network is the same as your GKE cluster's network. Alternatively, you can use the command line to create a new Workbench instance:
gcloud workbench instances create INSTANCE_NAME \
  --machine-type=e2-standard-4 \
  --data-disk-size=100 \
  --location=ZONE \
  [--network=NETWORK]
Once the instance is created, navigate to it and click Open JupyterLab.
Create a self-hosted Jupyter notebook instance
The following command shows how to create a self-hosted Jupyter notebook instance using XPK:
xpk workload create-pathways \
--workload=${WORKLOAD} \
--num-slices=${WORKLOAD_NODEPOOL_COUNT} \
--tpu-type=${TPU_TYPE} \
--project=${PROJECT} \
--zone=${ZONE} \
--cluster=${CLUSTER} \
--docker-image=jupyter/base-notebook \
--command "start-notebook.sh"
This command will display a URL that you can use to open the notebook in your browser.
The following YAML shows how to create a self-hosted Jupyter notebook instance using kubectl. Apply the following YAML after a headless Pathways cluster has been created. For more information, see Run Pathways in interactive mode with kubectl.
apiVersion: batch/v1
kind: Job
metadata:
  name: jupyter-notebook-USERNAME
spec:
  template:
    spec:
      restartPolicy: OnFailure
      containers:
      - name: jupyter-notebook
        image: jupyter/base-notebook # Use the appropriate Jupyter image
        ports:
        - containerPort: 8888
Navigate to the notebook instance
Connect to the notebook from your local machine using port forwarding:
kubectl
MAIN_POD=$(kubectl get pods | grep jupyter-notebook-USERNAME | awk '{print $1}')
kubectl port-forward pod/${MAIN_POD} 8888:8888
XPK
MAIN_POD=$(kubectl get pods | grep ${WORKLOAD}-main | awk '{print $1}')
kubectl port-forward pod/${MAIN_POD} 8888:8888
Navigate in your local browser to http://localhost:8888?token=<your-token>.
Replace <your-token> with the token from the Jupyter notebook container's logs:
kubectl logs ${MAIN_POD}
You should see output similar to:
...
Or copy and paste one of these URLs:
    http://jupyter-notebook-<user>-bbbdh:8888/lab?token=<token>
    http://127.0.0.1:8888/lab?token=<token>
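If you want to extract the token programmatically instead of copying it by hand, a small sketch such as the following could help (the Pod name is a placeholder; substitute the value of ${MAIN_POD}):

import re
import subprocess

# Read the notebook container's logs and pull out the first token value.
pod_name = "jupyter-notebook-USERNAME-bbbdh"  # placeholder; use your ${MAIN_POD}
logs = subprocess.run(
    ["kubectl", "logs", pod_name],
    capture_output=True, text=True, check=True,
).stdout

match = re.search(r"token=(\w+)", logs)
if match:
    print(f"http://localhost:8888/?token={match.group(1)}")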
Notebook connectivity to the Pathways cluster
- From within Jupyterlab, create a new Python 3 notebook
- Connect to the Pathways proxy server
In the notebook, add a cell to install pathwaysutils, set JAX_PLATFORMS to proxy, and set JAX_BACKEND_TARGET to PROXY_ADDRESS.
!pip install pathwaysutils
%env JAX_PLATFORMS=proxy
# Replace your proxy address below:
%env JAX_BACKEND_TARGET=PROXY_ADDRESS
Add a second cell as a "hello world"-style check that prints the devices in the Pathways cluster.
import pathwaysutils
import jax
pathwaysutils.initialize()
print(jax.devices())
If everything is working well, you should see a message indicating the Pathways-on-Cloud backend was detected.
The number of JAX devices listed should match the number of TPU chips per slice multiplied by the number of slices you specified when you created the Pathways cluster.
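As a quick sanity check, you could add a cell like the following sketch; the slice and chip counts are example values for the 2-slice v5p 2x2x2 used earlier, so replace them with your own:

import jax

# Example values for a 2-slice v5p 2x2x2 deployment; replace with your own.
expected_slices = 2
expected_chips_per_slice = 8

devices = jax.devices()
slices = {d.slice_index for d in devices}

assert len(slices) == expected_slices, f"expected {expected_slices} slices, got {len(slices)}"
assert len(devices) == expected_slices * expected_chips_per_slice, (
    f"expected {expected_slices * expected_chips_per_slice} devices, got {len(devices)}"
)
print(f"{len(devices)} devices across {len(slices)} slices")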
Add your code to a notebook
Add your own JAX code and execute interactively on the TPUs in the Pathways cluster. The following code shows how to perform computations across two slices from a single notebook.
import jax
import jax.numpy as jnp
from jax import lax
import numpy as np
# You can use JAX APIs as usual across any of the devices.
jax.jit(jnp.sin, device=jax.devices()[-1])(np.pi / 2.)
# pmap can run across all devices on all slices
num_tpus = jax.device_count()
f = jax.pmap(lambda x: lax.psum(1, 'i'), 'i')
x = jnp.arange(num_tpus)
y = f(x)
print(y)
# You can also target devices from a specific slice
slice0_devices = [d for d in jax.devices() if d.slice_index == 0]
f = jax.pmap(lambda x: lax.psum(1, 'i'), 'i', devices=slice0_devices)
x = jnp.arange(len(slice0_devices))
y = f(x)
print(y)
print(y.global_shards)
# You can send data produced on one slice to another slice
slice1_devices = [d for d in jax.devices() if d.slice_index == 1]
g = jax.pmap(lambda x: x + lax.axis_index('i'), 'i', devices=slice1_devices)
z = g(y)
print(z)
print(z.global_shards)
Delete your Pathways interactive cluster
XPK workload
xpk workload delete --workload=WORKLOAD --cluster=CLUSTER --project=PROJECT --zone=ZONE
kubectl
kubectl delete -f pathways-headless-workload.yaml
What's next
- Create a GKE Cluster with Pathways
- Multihost inference with Pathways
- Batch workloads with Pathways
- Pathways interactive mode
- Porting JAX workloads to Pathways
- Resilient training with Pathways
- Troubleshooting Pathways