MaxDiffusion inference on v6e TPUs

This tutorial shows how to serve MaxDiffusion models on TPU v6e. In this tutorial, you generate images using the Stable Diffusion XL model.

Before you begin

Prepare to provision a TPU v6e with 4 chips:

  1. Follow the Set up the Cloud TPU environment guide to make sure you have the access you need to use Cloud TPUs.

  2. Create a service identity for the TPU VM.

    gcloud alpha compute tpus tpu-vm service-identity create --zone=ZONE
  3. Create a TPU service account and grant access to Google Cloud services.

    Service accounts allow the Cloud TPU service to access other Google Cloud services. A user-managed service account is recommended. You can create a service account from the Google Cloud console or with the gcloud CLI.

    Create a service account using the gcloud command-line tool:

    gcloud iam service-accounts create your-service-account-name \
    --description="your-sa-description" \
    --display-name="your-sa-display-name"
    export SERVICE_ACCOUNT_NAME=your-service-account-name

    Alternatively, create a service account from the Google Cloud console:

    1. Go to the Service Accounts page in the Google Cloud console.
    2. Click Create service account.
    3. Enter the service account name.
    4. (Optional) Enter a description for the service account.
    5. Click Create and continue.
    6. Choose the roles you want to grant to the service account.
    7. Click Continue.
    8. (Optional) Specify users or groups that can manage the service account.
    9. Click Done to finish creating the service account.

    After creating your service account, follow these steps to grant service account roles.

    The following roles are necessary:

    • TPU Admin: Needed to create a TPU
    • Storage Admin: Needed for accessing Cloud Storage
    • Logs Writer: Needed for writing logs to Cloud Logging
    • Monitoring Metric Writer: Needed for writing metrics to Cloud Monitoring

    You need the Project IAM Admin role (roles/resourcemanager.projectIamAdmin) to assign IAM roles. If you don't have it, ask your administrator or another user who has this role to grant it to you.

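    Use the following gcloud commands to grant the roles. The commands assume that SERVICE_ACCOUNT_NAME is set as shown earlier and that PROJECT_ID contains your project ID:

    export PROJECT_ID=your-project-id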
    gcloud projects add-iam-policy-binding ${PROJECT_ID} \
       --member serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \
       --role roles/tpu.admin
    gcloud projects add-iam-policy-binding ${PROJECT_ID} \
       --member serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \
       --role roles/storage.admin
    gcloud projects add-iam-policy-binding ${PROJECT_ID} \
       --member serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \
       --role roles/logging.logWriter
    gcloud projects add-iam-policy-binding ${PROJECT_ID} \
       --member serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com \
       --role roles/monitoring.metricWriter

    You can also assign roles in the Google Cloud console:

    1. Select your service account and click Add Principal.
    2. In the New Principals field, enter the email address of your service account.
    3. In the Select a role drop-down, search for a role (for example, Storage Admin) and select it.
    4. Click Save.
  4. Authenticate with Google Cloud and configure the default project and zone for the Google Cloud CLI.

    gcloud auth login
    gcloud config set project PROJECT_ID
    gcloud config set compute/zone ZONE
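The four add-iam-policy-binding commands in step 3 differ only in the role, so they can also be written as a loop. The following is a minimal bash sketch, assuming PROJECT_ID and SERVICE_ACCOUNT_NAME are set as described in step 3. The function name grant_tpu_roles and the DRY_RUN switch are hypothetical helpers introduced here; they are not part of gcloud.

```shell
#!/usr/bin/env bash
# Sketch: grant the four roles required by this tutorial to the TPU service account.
# Assumes PROJECT_ID and SERVICE_ACCOUNT_NAME are already set.
# DRY_RUN=1 is a convenience of this sketch (not a gcloud flag): it prints
# each command instead of running it.

# Space-separated on purpose so that the loop below splits it into roles.
REQUIRED_ROLES="roles/tpu.admin roles/storage.admin roles/logging.logWriter roles/monitoring.metricWriter"

grant_tpu_roles() {
  local member="serviceAccount:${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com"
  local role
  for role in ${REQUIRED_ROLES}; do
    if [ "${DRY_RUN:-0}" = "1" ]; then
      # Preview only: print the command that would be run.
      echo gcloud projects add-iam-policy-binding "${PROJECT_ID}" \
        --member "${member}" --role "${role}"
    else
      # Stop at the first failure so a permission error is not hidden.
      gcloud projects add-iam-policy-binding "${PROJECT_ID}" \
        --member "${member}" \
        --role "${role}" || return 1
    fi
  done
}
```

Run DRY_RUN=1 grant_tpu_roles to preview the commands, then grant_tpu_roles to apply them.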

Secure capacity

When you are ready to secure TPU capacity, review the quotas page to learn about the Cloud Quotas system. If you have additional questions about securing capacity, contact your Cloud TPU sales or account team.

Provision the Cloud TPU environment

You can provision TPU VMs with GKE, with GKE and XPK, or as queued resources. This tutorial uses queued resources.

Prerequisites

  • This tutorial has been tested with Python 3.10 or later.
  • Verify that your project has enough TPUS_PER_TPU_FAMILY quota, which specifies the maximum number of chips you can access within your Google Cloud project.
  • Verify that your project has enough quota for:
    • TPU VMs
    • IP addresses
    • Hyperdisk Balanced
  • Verify that your user account has the required project permissions.

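Provision a TPU v6e

Request a TPU v6e with 4 chips (v6e-4) as a queued resource. Replace QUEUED_RESOURCE_ID, TPU_NAME, PROJECT_ID, and ZONE with your own values, and replace SERVICE_ACCOUNT with the email address of the service account you created (your-service-account-name@PROJECT_ID.iam.gserviceaccount.com):
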
   gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \
        --node-id TPU_NAME \
        --project PROJECT_ID \
        --zone ZONE \
        --accelerator-type v6e-4 \
        --runtime-version v2-alpha-tpuv6e \
        --service-account SERVICE_ACCOUNT
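Rather than editing the placeholders in the command above by hand, you can keep them in shell variables. This bash sketch reuses the variable names that appear elsewhere in this tutorial; the default values are hypothetical and must be replaced, and create_v6e_queued_resource is a helper name introduced here for illustration. SERVICE_ACCOUNT is derived from the account name and project because the --service-account flag takes the account's email address.

```shell
#!/usr/bin/env bash
# Sketch: fill in the placeholders of the queued-resource command with shell
# variables. Every default below is a hypothetical example; export your own
# values before running this.
export PROJECT_ID="${PROJECT_ID:-your-project-id}"
export ZONE="${ZONE:-your-v6e-zone}"
export QUEUED_RESOURCE_ID="${QUEUED_RESOURCE_ID:-maxdiffusion-qr}"
export TPU_NAME="${TPU_NAME:-maxdiffusion-v6e}"
export SERVICE_ACCOUNT_NAME="${SERVICE_ACCOUNT_NAME:-your-service-account-name}"

# --service-account expects an email address; for a user-managed service
# account it has the form NAME@PROJECT_ID.iam.gserviceaccount.com.
export SERVICE_ACCOUNT="${SERVICE_ACCOUNT_NAME}@${PROJECT_ID}.iam.gserviceaccount.com"

# Same request as the command above, with the placeholders taken from variables.
create_v6e_queued_resource() {
  gcloud alpha compute tpus queued-resources create "${QUEUED_RESOURCE_ID}" \
    --node-id "${TPU_NAME}" \
    --project "${PROJECT_ID}" \
    --zone "${ZONE}" \
    --accelerator-type v6e-4 \
    --runtime-version v2-alpha-tpuv6e \
    --service-account "${SERVICE_ACCOUNT}"
}
```

After exporting your own values, run create_v6e_queued_resource. Because the variables are exported, you can also use them in the later commands, for example ${PROJECT_ID} and ${ZONE}.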

Use the list or describe commands to query the status of your queued resource.

   gcloud alpha compute tpus queued-resources describe QUEUED_RESOURCE_ID \
      --project PROJECT_ID --zone ZONE

For a complete list of queued resource request statuses, see the Queued Resources documentation. When the queued resource reaches the ACTIVE state, you can connect to the TPU VM.
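Provisioning can take a while, so you may want to poll the describe command until the request is ready. This is a bash sketch, assuming QUEUED_RESOURCE_ID, PROJECT_ID, and ZONE are set and that the describe output exposes the state at the field path state.state (an assumption; check the output of your gcloud version). The helper names, POLL_INTERVAL, and MAX_ATTEMPTS are hypothetical.

```shell
#!/usr/bin/env bash
# Sketch: poll the queued resource until it reaches ACTIVE, or stop early if
# it fails. Assumes QUEUED_RESOURCE_ID, PROJECT_ID, and ZONE are set, and that
# the state is available at the field path state.state.

POLL_INTERVAL="${POLL_INTERVAL:-30}"   # seconds between checks
MAX_ATTEMPTS="${MAX_ATTEMPTS:-120}"    # give up after this many checks

get_qr_state() {
  gcloud alpha compute tpus queued-resources describe "${QUEUED_RESOURCE_ID}" \
    --project "${PROJECT_ID}" --zone "${ZONE}" \
    --format="value(state.state)"
}

wait_for_active() {
  local attempt=1 state
  while [ "${attempt}" -le "${MAX_ATTEMPTS}" ]; do
    state="$(get_qr_state)"
    echo "check ${attempt}: ${state:-unknown}"
    case "${state}" in
      ACTIVE) return 0 ;;   # the TPU VM is ready
      FAILED) return 1 ;;   # the request failed; inspect the describe output
    esac
    sleep "${POLL_INTERVAL}"
    attempt=$((attempt + 1))
  done
  echo "timed out waiting for ${QUEUED_RESOURCE_ID}" >&2
  return 1
}
```

For example, wait_for_active && gcloud compute tpus tpu-vm ssh "${TPU_NAME}" --zone "${ZONE}" connects only after the TPU is ready.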

Connect to the TPU using SSH

   gcloud compute tpus tpu-vm ssh TPU_NAME --project PROJECT_ID --zone ZONE

Create a Conda environment

  1. Create a directory for Miniconda:

    mkdir -p ~/miniconda3
  2. Download the Miniconda installer script:

    wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
  3. Install Miniconda:

    bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
  4. Remove the Miniconda installer script:

    rm -f ~/miniconda3/miniconda.sh
  5. Add Miniconda to your PATH variable by appending the export command to ~/.bashrc:

    echo 'export PATH="$HOME/miniconda3/bin:$PATH"' >> ~/.bashrc
  6. Reload ~/.bashrc to apply the change to the current shell:

    source ~/.bashrc
  7. Create a new Conda environment:

    conda create -n tpu python=3.10
  8. Activate the Conda environment:

    source activate tpu
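To confirm that the previous two steps took effect, you can check which environment is active and which Python version it provides. This sketch relies on Conda setting CONDA_DEFAULT_ENV to the name of the active environment; check_conda_env and the PYTHON_BIN override are hypothetical names introduced here.

```shell
#!/usr/bin/env bash
# Sketch: check that the expected Conda environment is active and that its
# interpreter is Python 3.10 or later. Relies on Conda setting
# CONDA_DEFAULT_ENV to the name of the active environment.
check_conda_env() {
  local expected_env="${1:-tpu}"
  if [ "${CONDA_DEFAULT_ENV:-}" != "${expected_env}" ]; then
    echo "expected Conda env '${expected_env}', found '${CONDA_DEFAULT_ENV:-none}'" >&2
    return 1
  fi
  # PYTHON_BIN is a hypothetical override; by default the python on PATH,
  # which is the active environment's interpreter, is checked.
  "${PYTHON_BIN:-python}" -c 'import sys; sys.exit(0 if sys.version_info >= (3, 10) else 1)'
}
```

Running check_conda_env after activating the environment should succeed silently; a non-zero exit status means that the tpu environment is not active or that it provides an older Python.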

Set up MaxDiffusion

  1. Clone the MaxDiffusion repository and navigate to the MaxDiffusion directory:

    git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion
  2. Switch to the mlperf4.1 branch:

    git checkout mlperf4.1
  3. Install MaxDiffusion:

    pip install -e .
  4. Install dependencies:

    pip install -r requirements.txt
  5. Install JAX:

    pip install -U --pre "jax[tpu]" -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  6. Install additional dependencies:

    pip install huggingface_hub==0.25 absl-py flax tensorboardX google-cloud-storage torch tensorflow transformers
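Before generating images, it can help to confirm that JAX detects the TPU chips. This is a minimal bash sketch that assumes one JAX device per v6e chip, so a v6e-4 is expected to report 4 devices; check_tpu_devices and PYTHON_BIN are hypothetical names used for illustration.

```shell
#!/usr/bin/env bash
# Sketch: confirm that JAX can see the TPU chips before generating images.
# Assumes one JAX device per v6e chip, so a v6e-4 should report 4 devices.
# PYTHON_BIN is a hypothetical override (default: python on PATH).
check_tpu_devices() {
  local expected="${1:-4}"
  local count
  # Declared separately so that a failing python command is not masked by local.
  count="$("${PYTHON_BIN:-python}" -c 'import jax; print(jax.device_count())')" || return 1
  echo "JAX reports ${count} device(s); expected ${expected}"
  [ "${count}" -eq "${expected}" ]
}
```

Run check_tpu_devices on the TPU VM inside the tpu environment. A different count, or an error importing jax, suggests a problem with the JAX or libtpu installation.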

Generate images

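  1. Set the LIBTPU_INIT_ARGS environment variable to configure the TPU runtime. Export it so that the Python process started in the next step inherits it:

    export LIBTPU_INIT_ARGS="--xla_tpu_rwb_fusion=false --xla_tpu_dot_dot_fusion_duplicated=true --xla_tpu_scoped_vmem_limit_kib=65536"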
  2. Generate images using the prompt and configurations defined in src/maxdiffusion/configs/base_xl.yml:

    python -m src.maxdiffusion.generate_sdxl src/maxdiffusion/configs/base_xl.yml run_name="my_run"

    When the images have been generated, be sure to clean up the TPU resources.
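The LIBTPU_INIT_ARGS setting only reaches the generation script if it is exported, because a plain shell assignment is not passed to child processes. The following bash snippet demonstrates the difference, with a separate sh process standing in for python; child_view is a hypothetical helper, and the value is one of the flags from step 1.

```shell
#!/usr/bin/env bash
# Demonstrates why LIBTPU_INIT_ARGS has to be exported: a child process, such
# as the python command that runs generate_sdxl, only receives variables that
# are in the environment, and a plain assignment does not put them there.

unset LIBTPU_INIT_ARGS   # start from a clean state for the demo

child_view() {
  # Start a separate shell process and print what it sees for the variable.
  sh -c 'printf "%s\n" "${LIBTPU_INIT_ARGS:-<unset>}"'
}

LIBTPU_INIT_ARGS="--xla_tpu_scoped_vmem_limit_kib=65536"
echo "without export: $(child_view)"   # prints: without export: <unset>

export LIBTPU_INIT_ARGS
echo "with export:    $(child_view)"   # prints: with export:    --xla_tpu_scoped_vmem_limit_kib=65536
```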

Clean up

Delete the queued resource. The --force flag also deletes the TPU VM that the request created:

   gcloud compute tpus queued-resources delete QUEUED_RESOURCE_ID \
        --project PROJECT_ID \
        --zone ZONE \
        --force \
        --async
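Because the delete command uses --async, it returns before the deletion finishes. If you want to wait for it, the following bash sketch polls the describe command until it fails, treating that as a sign that the queued resource is gone. This is a heuristic: describe can also fail for unrelated reasons, such as expired credentials. It assumes QUEUED_RESOURCE_ID, PROJECT_ID, and ZONE are set; wait_for_deletion, POLL_INTERVAL, and MAX_ATTEMPTS are hypothetical names.

```shell
#!/usr/bin/env bash
# Sketch: wait for an asynchronous queued-resource deletion to finish.
# A failing describe is taken to mean the resource no longer exists; this is
# a heuristic, since describe can also fail for other reasons.
# Assumes QUEUED_RESOURCE_ID, PROJECT_ID, and ZONE are set.
wait_for_deletion() {
  local attempt=1
  while [ "${attempt}" -le "${MAX_ATTEMPTS:-60}" ]; do
    if ! gcloud alpha compute tpus queued-resources describe "${QUEUED_RESOURCE_ID}" \
        --project "${PROJECT_ID}" --zone "${ZONE}" >/dev/null 2>&1; then
      echo "${QUEUED_RESOURCE_ID} no longer found"
      return 0
    fi
    sleep "${POLL_INTERVAL:-30}"   # seconds between checks
    attempt=$((attempt + 1))
  done
  echo "${QUEUED_RESOURCE_ID} still exists after ${MAX_ATTEMPTS:-60} checks" >&2
  return 1
}
```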