Run an interactive workload with Pathways

Pathways interactive workloads are remote JAX workloads that run in a VM that is not part of the GKE cluster hosting the Pathways cluster. Unlike batch workloads, completing an interactive workload does not shut down the Pathways cluster components, which remain available for connection by other JAX clients. This document uses a Jupyter notebook as an example to demonstrate interactive workloads.

Using the IFRT interface, JAX users send commands to a Pathways cluster. JAX code, whether executed from a terminal, notebook, or any Python-compatible environment, can seamlessly interact with Pathways resources.
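
As a preview of the pattern used throughout this guide, a minimal client looks like the following. The JAX_PLATFORMS and JAX_BACKEND_TARGET environment variables that select the proxy backend are covered in the connection sections later.

# Minimal IFRT proxy client pattern; requires JAX_PLATFORMS=proxy and
# JAX_BACKEND_TARGET pointing at your Pathways proxy server.
import pathwaysutils
import jax

pathwaysutils.initialize()
print(jax.devices())  # lists the TPU devices served by the Pathways cluster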

Before you begin

Make sure you have:

Run Pathways in interactive mode

You can run Pathways in interactive mode using xpk or gcloud.

XPK

  1. Set the following environment variables:

    export WORKLOAD=WORKLOAD
    export WORKLOAD_NODEPOOL_COUNT=WORKLOAD_NODEPOOL_COUNT
    export TPU_TYPE=TPU_TYPE
    export PROJECT=PROJECT
    export ZONE=ZONE
    export CLUSTER=CLUSTER

    Replace the following:

    • WORKLOAD: a unique name that identifies your workload
    • WORKLOAD_NODEPOOL_COUNT: the number of node pools used by the Pathways workload
    • TPU_TYPE: the TPU type, which specifies the version and size of the Cloud TPU you want to create. For more information about supported TPU types for each TPU version, see TPU versions.
    • PROJECT: your Google Cloud project ID
    • ZONE: the zone where you plan to run your workload
    • CLUSTER: the name of your GKE cluster
  2. Create the Pathways containers on the cluster. To run a headless workload, run the following command:

    xpk workload create-pathways \
    --headless \
    --workload=${WORKLOAD} \
    --num-slices=${WORKLOAD_NODEPOOL_COUNT} \
    --tpu-type=${TPU_TYPE} \
    --project=${PROJECT} \
    --zone=${ZONE} \
    --cluster=${CLUSTER}

At this point, your JAX workload can connect to the IFRT proxy server.
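
To verify that the headless workload is up before connecting, you can list the workloads on the cluster (assuming your xpk version includes the workload list command):

xpk workload list --cluster=${CLUSTER} --project=${PROJECT} --zone=${ZONE}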

gcloud

The following YAML is the same as the batch workload YAML, except that it doesn't specify the main container.

  1. Copy the following YAML into a file called pathways-headless-workload.yaml and replace the placeholders.
    apiVersion: pathways-job.pathways.domain/v1
    kind: PathwaysJob
    metadata:
      name: pathways-USERNAME
    spec:
      maxRestarts: MAX_RESTARTS
      workers:
        - type: TPU_MACHINE_TYPE
          topology: TOPOLOGY
          numSlices: WORKLOAD_NODEPOOL_COUNT
      pathwaysDir: "gs://BUCKET_NAME"
      controller:
        deploymentMode: default
        
    Replace the following:
    • USERNAME: your username
    • MAX_RESTARTS: the maximum number of times the PathwaysJob can be restarted
    • TPU_MACHINE_TYPE: the TPU machine type you want to use
    • TOPOLOGY: the TPU topology
    • WORKLOAD_NODEPOOL_COUNT: the number of node pools used by the Pathways workload
    • BUCKET_NAME: a Cloud Storage bucket used to store temporary files
    To change the number of node pools (pathways-worker replicas) specified by WORKLOAD_NODEPOOL_COUNT in the previous YAML, delete this PathwaysJob and create a new one with the updated number of node pools, as shown after this procedure. You also need to restart any connected notebooks to re-establish the connection with the new Pathways cluster.
  2. Apply the pathways-headless-workload.yaml file:
      kubectl apply -f pathways-headless-workload.yaml
      
  3. Run kubectl get pods to check that all containers in the Pod are running. The following output is for two slices of v5p 2x2x2, where USER is the ID of the user running the command:
        NAME                                         READY   STATUS    RESTARTS   AGE
        pathways-USER-pathways-head-0-0-n848j      2/2     Running   0          49s
        pathways-USER-pathways-workers-0-0-jxt2z   1/1     Running   0          71s
        pathways-USER-pathways-workers-0-1-cxmhc   1/1     Running   0          70s
        pathways-USER-pathways-workers-1-0-5kmz9   1/1     Running   0          71s
        pathways-USER-pathways-workers-1-1-vg5n4   1/1     Running   0          71s
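
To change the node pool count referenced in the note in step 1, delete and re-create the PathwaysJob:

kubectl delete -f pathways-headless-workload.yaml
# Edit numSlices (WORKLOAD_NODEPOOL_COUNT) in the YAML, then re-create it:
kubectl apply -f pathways-headless-workload.yaml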
        

Connecting to the Pathways cluster in interactive mode

You can connect to the Pathways cluster with or without port forwarding. Use one of the following methods to connect.

Connect using port-forwarding

At this point, you can use port forwarding (from any host with access to your cluster's control plane) to access the proxy server.

Use the command appropriate for your workload:

XPK

PROXY_POD=$(kubectl get pods | grep ${WORKLOAD}-pathways-head | awk '{print $1}')
PROXY_PORT=29000
kubectl port-forward ${PROXY_POD} ${PROXY_PORT}:${PROXY_PORT}

You should see output similar to:

Forwarding from 127.0.0.1:29000 -> 29000
Forwarding from [::1]:29000 -> 29000

gcloud

PROXY_POD=$(kubectl get pods | grep pathways-${USER}-pathways-head | awk '{print $1}')
PROXY_PORT=29000
kubectl port-forward ${PROXY_POD} ${PROXY_PORT}:${PROXY_PORT}

On the same host, open a new terminal window. Set the JAX_PLATFORMS and JAX_BACKEND_TARGET environment variables, and run a Python script that imports pathwaysutils and jax:

python3 -m venv .venv
source .venv/bin/activate
pip install pathwaysutils jax

JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 python -c 'import pathwaysutils; import jax; import pprint; pathwaysutils.initialize(); pprint.pprint(jax.devices())'

You should see output like the following:

[device(144,TPU_DEVICE,coords=[0,0,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3),
device(145,TPU_DEVICE,coords=[1,0,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3),
device(146,TPU_DEVICE,coords=[0,1,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3),
device(147,TPU_DEVICE,coords=[1,1,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3),
device(148,TPU_DEVICE,coords=[0,0,1,0],vtask=1,slice=0,default_mem=device,mem_spaces=3),
device(149,TPU_DEVICE,coords=[1,0,1,0],vtask=1,slice=0,default_mem=device,mem_spaces=3),
device(150,TPU_DEVICE,coords=[0,1,1,0],vtask=1,slice=0,default_mem=device,mem_spaces=3),
device(151,TPU_DEVICE,coords=[1,1,1,0],vtask=1,slice=0,default_mem=device,mem_spaces=3),
device(162,TPU_DEVICE,coords=[0,0,0,0],vtask=0,slice=1,default_mem=device,mem_spaces=3),
device(163,TPU_DEVICE,coords=[1,0,0,0],vtask=0,slice=1,default_mem=device,mem_spaces=3),
device(164,TPU_DEVICE,coords=[0,1,0,0],vtask=0,slice=1,default_mem=device,mem_spaces=3),
device(165,TPU_DEVICE,coords=[1,1,0,0],vtask=0,slice=1,default_mem=device,mem_spaces=3),
device(166,TPU_DEVICE,coords=[0,0,1,0],vtask=1,slice=1,default_mem=device,mem_spaces=3),
device(167,TPU_DEVICE,coords=[1,0,1,0],vtask=1,slice=1,default_mem=device,mem_spaces=3),
device(168,TPU_DEVICE,coords=[0,1,1,0],vtask=1,slice=1,default_mem=device,mem_spaces=3),
device(169,TPU_DEVICE,coords=[1,1,1,0],vtask=1,slice=1,default_mem=device,mem_spaces=3)]
Waiting up to 5 seconds.
Sent all pending logs.
2024-11-13 21:38:51.267523: W external/xla/xla/python/ifrt_proxy/client/grpc_client.cc:63] IFRT proxy server disconnected: CANCELLED: Cancelled

Connect from hosts in the VPC without using port forwarding

If you don't want to use port forwarding, you can connect to the Pathways cluster using Cloud DNS or an internal load balancer.

Connect using Cloud DNS

Enabling Cloud DNS in your cluster switches the cluster's DNS provider from kube-dns to Cloud DNS. When enabled, a private Cloud DNS zone is created in your Virtual Private Cloud for the Kubernetes DNS names. For more information, see Using Cloud DNS for GKE.

If you enable Cloud DNS with either the cluster scope and additive VPC scope, or the VPC scope, the Kubernetes DNS names are resolvable from non-GKE VMs inside your Virtual Private Cloud. The names have the format <service_name>.<namespace>.svc.<custom_dns_domain>. The Pathways head Pod is reachable at <jobset_name>-pathways-head-0-0.<jobset_name>.<namespace>.svc.<custom_dns_domain>.

The following example shows how to:

  • Enable Cloud DNS with cluster scope
  • Enable additive VPC scope with custom domain ${USER}-test
  • Create a CPU-only VM in the same VPC as the GKE cluster.
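
A sketch of the corresponding Cloud DNS update follows, assuming current gcloud flag names (verify against your gcloud version; creating the CPU-only VM is a standard Compute Engine step and is omitted here):

    gcloud container clusters update ${CLUSTER} \
      --project=${PROJECT} \
      --location=${ZONE} \
      --cluster-dns=clouddns \
      --cluster-dns-scope=cluster \
      --additive-vpc-scope-dns-domain=${USER}-test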
  1. On the CPU-only VM, set up a Python environment with pathwaysutils and jax installed:

    python3 -m venv .venv
    source .venv/bin/activate
    pip install pathwaysutils jax

  2. Confirm the leader Cloud DNS entry is resolvable from a non-GKE host:

    host pathways-USERNAME-pathways-head-0-0.pathways-USERNAME.default.svc.USERNAME-test

    You should see output similar to:

    pathways-<user>-pathways-head-0-0.pathways-<user>.default.svc.<user>-test has address 10.0.2.75
  3. Connect to the Pathways cluster using the Cloud DNS name:

    JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://pathways-USERNAME-pathways-head-0-0.pathways-USERNAME.default.svc.USERNAME-test:29000 python -c 'import pathwaysutils; import jax; import pprint; pathwaysutils.initialize(); pprint.pprint(jax.devices())'

    You should see output similar to:

    [device(216,TPU_DEVICE,coords=[0,0,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3),
    device(217,TPU_DEVICE,coords=[1,0,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3),
    device(218,TPU_DEVICE,coords=[0,1,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3),
    device(219,TPU_DEVICE,coords=[1,1,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3),
    device(220,TPU_DEVICE,coords=[0,0,1,0],vtask=1,slice=0,default_mem=device,mem_spaces=3),
    device(221,TPU_DEVICE,coords=[1,0,1,0],vtask=1,slice=0,default_mem=device,mem_spaces=3),
    device(222,TPU_DEVICE,coords=[0,1,1,0],vtask=1,slice=0,default_mem=device,mem_spaces=3),
    device(223,TPU_DEVICE,coords=[1,1,1,0],vtask=1,slice=0,default_mem=device,mem_spaces=3),
    device(234,TPU_DEVICE,coords=[0,0,0,0],vtask=0,slice=1,default_mem=device,mem_spaces=3),
    device(235,TPU_DEVICE,coords=[1,0,0,0],vtask=0,slice=1,default_mem=device,mem_spaces=3),
    device(236,TPU_DEVICE,coords=[0,1,0,0],vtask=0,slice=1,default_mem=device,mem_spaces=3),
    device(237,TPU_DEVICE,coords=[1,1,0,0],vtask=0,slice=1,default_mem=device,mem_spaces=3),
    device(238,TPU_DEVICE,coords=[0,0,1,0],vtask=1,slice=1,default_mem=device,mem_spaces=3),
    device(239,TPU_DEVICE,coords=[1,0,1,0],vtask=1,slice=1,default_mem=device,mem_spaces=3),
    device(240,TPU_DEVICE,coords=[0,1,1,0],vtask=1,slice=1,default_mem=device,mem_spaces=3),
    device(241,TPU_DEVICE,coords=[1,1,1,0],vtask=1,slice=1,default_mem=device,mem_spaces=3)]
    Waiting up to 5 seconds.
    Sent all pending logs.
    2024-11-14 00:02:49.882044: W external/xla/xla/python/ifrt_proxy/client/grpc_client.cc:63] IFRT proxy server disconnected: CANCELLED: Cancelled

Connect using an internal load balancer

To get a private IP address in your VPC that points to your Pathways deployment, create a Service backed by an internal load balancer. This approach doesn't require Cloud DNS to be enabled on your cluster.

For clusters with many VMs, we recommend that you enable ILB subsetting if you are creating internal load balancers. For more information, see Enable GKE subsetting in an existing cluster. When ILB subsetting is not enabled, all nodes in the cluster are part of the backend instance group for every internal load balancer, which does not scale beyond 250 nodes. With ILB subsetting enabled, GKE creates network endpoint groups instead of instance groups, and only nodes that run one of the Service's serving Pods are included. Enabling ILB subsetting has a one-time setup latency of about 15 minutes. The following command shows how to enable ILB subsetting:

gcloud container clusters update ${CLUSTER} \
  --project=${PROJECT} \
  [--zone=${ZONE} | --region=${REGION}] \
  --enable-l4-ilb-subsetting
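
To confirm that subsetting is enabled, you can inspect the cluster (field name per current gcloud output; verify for your gcloud version):

gcloud container clusters describe ${CLUSTER} \
  --project=${PROJECT} \
  [--zone=${ZONE} | --region=${REGION}] \
  --format="value(networkConfig.enableL4ilbSubsetting)"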

After ILB subsetting is enabled, you can create a Kubernetes Service of type LoadBalancer using the following YAML. This causes GKE to create an internal load balancer inside your cluster's VPC:

apiVersion: v1
kind: Service
metadata:
  name: pathways-USERNAME-ilb
  annotations:
    networking.gke.io/load-balancer-type: "Internal"
    networking.gke.io/internal-load-balancer-allow-global-access: "true"
spec:
  type: LoadBalancer
  externalTrafficPolicy: Local
  selector:
    jobset.sigs.k8s.io/jobset-name: pathways-USERNAME
    jobset.sigs.k8s.io/replicatedjob-name: pathways-head
  ports:
  - name: tcp-port
    protocol: TCP
    port: 29000
    targetPort: 29000

Replace USERNAME with your username and save the file as pathways-headless-ilb.yaml.

Apply the manifest:

kubectl apply -f pathways-headless-ilb.yaml

After the load balancer is created (about a minute later), the EXTERNAL-IP column has a value:

kubectl get services
NAME                  TYPE           CLUSTER-IP      EXTERNAL-IP   PORT(S)        AGE
pathways-$USER       ClusterIP      None            <none>        <none>         30m
pathways-$USER-ilb   LoadBalancer   34.118.232.46   10.0.0.22     80:31246/TCP   2m41s
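
To capture the internal address programmatically, you can use a standard kubectl JSONPath query:

ILB_IP=$(kubectl get service pathways-USERNAME-ilb \
  -o jsonpath='{.status.loadBalancer.ingress[0].ip}')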

You can access the Pathways deployment without port forwarding from hosts in the same VPC as your cluster:

JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://10.0.0.22:29000 python -c 'import pathwaysutils; import jax; import pprint; pathwaysutils.initialize(); pprint.pprint(jax.devices())'

You should see output similar to:

[device(288,TPU_DEVICE,coords=[0,0,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3),
 device(289,TPU_DEVICE,coords=[1,0,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3),
 device(290,TPU_DEVICE,coords=[0,1,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3),
 device(291,TPU_DEVICE,coords=[1,1,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3),
 device(292,TPU_DEVICE,coords=[0,0,1,0],vtask=1,slice=0,default_mem=device,mem_spaces=3),
 device(293,TPU_DEVICE,coords=[1,0,1,0],vtask=1,slice=0,default_mem=device,mem_spaces=3),
 device(294,TPU_DEVICE,coords=[0,1,1,0],vtask=1,slice=0,default_mem=device,mem_spaces=3),
 device(295,TPU_DEVICE,coords=[1,1,1,0],vtask=1,slice=0,default_mem=device,mem_spaces=3),
 device(306,TPU_DEVICE,coords=[0,0,0,0],vtask=0,slice=1,default_mem=device,mem_spaces=3),
 device(307,TPU_DEVICE,coords=[1,0,0,0],vtask=0,slice=1,default_mem=device,mem_spaces=3),
 device(308,TPU_DEVICE,coords=[0,1,0,0],vtask=0,slice=1,default_mem=device,mem_spaces=3),
 device(309,TPU_DEVICE,coords=[1,1,0,0],vtask=0,slice=1,default_mem=device,mem_spaces=3),
 device(310,TPU_DEVICE,coords=[0,0,1,0],vtask=1,slice=1,default_mem=device,mem_spaces=3),
 device(311,TPU_DEVICE,coords=[1,0,1,0],vtask=1,slice=1,default_mem=device,mem_spaces=3),
 device(312,TPU_DEVICE,coords=[0,1,1,0],vtask=1,slice=1,default_mem=device,mem_spaces=3),
 device(313,TPU_DEVICE,coords=[1,1,1,0],vtask=1,slice=1,default_mem=device,mem_spaces=3)]
Waiting up to 5 seconds.
Sent all pending logs.
2024-11-14 00:30:07.296917: W external/xla/xla/python/ifrt_proxy/client/grpc_client.cc:63] IFRT proxy server disconnected: CANCELLED: Cancelled

Jupyter notebooks

You can create a Jupyter notebook using Vertex AI, or you can create a self-hosted Jupyter notebook.

Create a Vertex AI workbench instance

After setting up and verifying your Pathways cluster, you can access the GKE TPU VMs from a Vertex AI Jupyter notebook. The following setup instructions assume your GKE Pathways cluster resides in the same Virtual Private Cloud network (the default network, unless you've configured otherwise). Navigate to the Vertex AI Workbench console.

Create a new Workbench instance from the Instances tab using the Create new button. Ensure that the network is the same as your GKE cluster's network. Alternatively, you can use the command line to create a new Workbench instance:

gcloud workbench instances create INSTANCE_NAME \
--machine-type=e2-standard-4 \
--data-disk-size=100 \
--location=ZONE \
[--network=NETWORK]

Once the instance is created, navigate to it and click Open JupyterLab.

Create a self-hosted Jupyter notebook instance

The following command shows how to create a self-hosted Jupyter notebook instance using XPK:

xpk workload create-pathways \
--workload=${WORKLOAD} \
--num-slices=${WORKLOAD_NODEPOOL_COUNT} \
--tpu-type=${TPU_TYPE} \
--project=${PROJECT} \
--zone=${ZONE} \
--cluster=${CLUSTER} \
--docker-image=jupyter/base-notebook \
--command "start-notebook.sh"

This command displays a URL that you can use to open the notebook in your browser.

The following YAML shows how to create a self-hosted Jupyter notebook instance using kubectl. Apply the YAML after a headless Pathways cluster has been created, as described in Run Pathways in interactive mode.

apiVersion: batch/v1
kind: Job
metadata:
  name: jupyter-notebook-USERNAME
spec:
  template:
    spec:
      restartPolicy: OnFailure
      containers:
      - name: jupyter-notebook
        image: jupyter/base-notebook  # Use the appropriate Jupyter image
        ports:
        - containerPort: 8888
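
Replace USERNAME with your username, save the manifest (for example, as jupyter-notebook.yaml; the filename is arbitrary), and apply it:

kubectl apply -f jupyter-notebook.yaml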

Connect to the notebook from your local machine using port forwarding:

kubectl

  MAIN_POD=$(kubectl get pods | grep jupyter-notebook-USERNAME | awk '{print $1}')
  kubectl port-forward pod/${MAIN_POD} 8888:8888

XPK

  MAIN_POD=$(kubectl get pods | grep ${WORKLOAD}-main | awk '{print $1}')
  kubectl port-forward pod/${MAIN_POD} 8888:8888

Navigate in your local browser to http://localhost:8888?token=YOUR_TOKEN. Replace YOUR_TOKEN with the token from the Jupyter notebook container's logs, which you can view with:

kubectl logs ${MAIN_POD}

The output looks similar to the following:

...
Or copy and paste one of these URLs:
  http://jupyter-notebook-<user>-bbbdh:8888/lab?token=<token>
  http://127.0.0.1:8888/lab?token=<token>
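
If you only need the token itself, a simple filter works (a grep sketch that assumes an alphanumeric token):

kubectl logs ${MAIN_POD} | grep -oE 'token=[A-Za-z0-9]+' | head -n 1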

Notebook connectivity to the Pathways cluster

  1. From within JupyterLab, create a new Python 3 notebook.
  2. Connect to the Pathways proxy server.

In the notebook, add a cell to install pathwaysutils, set JAX_PLATFORMS to proxy, and set JAX_BACKEND_TARGET to PROXY_ADDRESS, your proxy address (for example, grpc://localhost:29000 when using port forwarding).

!pip install pathwaysutils
%env JAX_PLATFORMS=proxy
# Replace your proxy address below:
%env JAX_BACKEND_TARGET=PROXY_ADDRESS

Add a second cell as a "hello world" type check and print the devices in the Pathways cluster.

import pathwaysutils
import jax

pathwaysutils.initialize()
print(jax.devices())

If everything is working well, you should see a message indicating the Pathways-on-Cloud backend was detected.

The number of JAX devices listed should match the number of TPU chips and the number of slices you specified when you created the Pathways cluster.
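
You can make this check explicit. The following is a small sketch; slice_index is the same device attribute used in the multi-slice example later in this guide:

import jax

devices = jax.devices()
num_slices = len({d.slice_index for d in devices})
print(f"{len(devices)} devices across {num_slices} slice(s)")
# For the two-slice v5p 2x2x2 example used in this guide, expect
# 16 devices across 2 slices (8 TPU chips per slice).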

Add your code to a notebook

Add your own JAX code and execute it interactively on the TPUs in the Pathways cluster. The following code shows how to perform computations across two slices from a single notebook.

import jax
import jax.numpy as jnp
from jax import lax
import numpy as np

# You can use JAX APIs as usual across any of the devices.
jax.jit(jnp.sin, device=jax.devices()[-1])(np.pi / 2.)

# pmap can run across all devices on all slices
num_tpus = jax.device_count()
f = jax.pmap(lambda x: lax.psum(1, 'i'), 'i')
x = jnp.arange(num_tpus)
y = f(x)
print(y)

# You can also target devices from a specific slice
slice0_devices = [d for d in jax.devices() if d.slice_index == 0]
f = jax.pmap(lambda x: lax.psum(1, 'i'), 'i', devices=slice0_devices)
x = jnp.arange(len(slice0_devices))
y = f(x)
print(y)
print(y.global_shards)

# You can send data produced on one slice to another slice
slice1_devices = [d for d in jax.devices() if d.slice_index == 1]
g = jax.pmap(lambda x: x + lax.axis_index('i'), 'i', devices=slice1_devices)
z = g(y)
print(z)
print(z.global_shards)
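
Beyond pmap, the newer jax.sharding APIs also work through the proxy. The following is a minimal sketch assuming the two-slice topology from the examples above; adjust the reshape to your slice count:

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Arrange all devices into a (slice, chip) mesh; assumes 2 slices.
devices = np.array(jax.devices()).reshape(2, -1)
mesh = Mesh(devices, axis_names=("slice", "chip"))
sharding = NamedSharding(mesh, P("slice", "chip"))

# Shard a 16x16 array: rows across slices, columns across chips.
x = jax.device_put(jnp.arange(256.0).reshape(16, 16), sharding)

# The jit-compiled computation runs across both slices.
y = jax.jit(lambda a: jnp.sin(a) * 2.0)(x)
print(y.sharding)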

Delete your Pathways interactive cluster

XPK workload

xpk workload delete --workload=${WORKLOAD} --cluster=${CLUSTER} --project=${PROJECT} --zone=${ZONE}

kubectl

kubectl delete -f pathways-headless-workload.yaml
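
If you also created the internal load balancer Service or the Jupyter notebook Job from this guide, delete those resources as well:

kubectl delete -f pathways-headless-ilb.yaml
kubectl delete job jupyter-notebook-USERNAME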

What's next