JetStream PyTorch Inference on v5e Cloud TPU VM


JetStream is a throughput and memory optimized engine for large language model (LLM) inference on XLA devices (TPUs).

Before you begin

Follow the steps in Set up the Cloud TPU environment to create a Google Cloud project, enable the TPU API, install the TPU CLI, and request TPU quota.

Follow the steps in Create a Cloud TPU using the CreateNode API to create a TPU VM, setting --accelerator-type to v5litepod-8.
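
For example, a v5litepod-8 TPU VM can be created with a command like the following. The runtime version shown is one commonly used for v5e; check the Cloud TPU documentation for the currently recommended version.

gcloud compute tpus tpu-vm create ${TPU_NAME} \
  --project=${PROJECT} \
  --zone=${ZONE} \
  --accelerator-type=v5litepod-8 \
  --version=v2-alpha-tpuv5-lite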

Clone the JetStream repository and install dependencies

  1. Connect to your TPU VM using SSH

    • Set ${TPU_NAME} to your TPU's name.
    • Set ${PROJECT} to your Google Cloud project ID.
    • Set ${ZONE} to the Google Cloud zone where you created your TPU.
      gcloud compute config-ssh
      gcloud compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT} --zone ${ZONE}
    
  2. Get the jetstream-pytorch code:

      git clone https://github.com/google/jetstream-pytorch.git

(Optional) Create a virtual environment using venv or conda and activate it.

sudo apt install python3.10-venv
python -m venv venv
source venv/bin/activate
  3. Check out the release tag and run the installation script:

cd jetstream-pytorch
git checkout jetstream-v0.2.4
source install_everything.sh
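
(Optional) To verify that the installation can see the TPU, run a quick device check with JAX, which install_everything.sh installs. On a v5litepod-8 this should print 8:

python -c "import jax; print(jax.device_count())"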

Run JetStream PyTorch

List the supported models:

jpt list

This command prints the list of supported models and variants:

meta-llama/Llama-2-7b-chat-hf
meta-llama/Llama-2-7b-hf
meta-llama/Llama-2-13b-chat-hf
meta-llama/Llama-2-13b-hf
meta-llama/Llama-2-70b-hf
meta-llama/Llama-2-70b-chat-hf
meta-llama/Meta-Llama-3-8B
meta-llama/Meta-Llama-3-8B-Instruct
meta-llama/Meta-Llama-3-70B
meta-llama/Meta-Llama-3-70B-Instruct
google/gemma-2b
google/gemma-2b-it
google/gemma-7b
google/gemma-7b-it
mistralai/Mixtral-8x7B-v0.1
mistralai/Mixtral-8x7B-Instruct-v0.1

To run the jetstream-pytorch server with one of these models:

jpt serve --model_id meta-llama/Llama-2-7b-chat-hf

The first time you run this command, jpt serve attempts to download the model weights from HuggingFace, which requires you to authenticate with HuggingFace.

To authenticate, run huggingface-cli login to set your access token, or pass your HuggingFace access token to the jpt serve command using the --hf_token flag:

jpt serve --model_id meta-llama/Llama-2-7b-chat-hf --hf_token=...

For more information about HuggingFace access tokens, see Access Tokens.

To log in using the HuggingFace Hub CLI, run the following commands and follow the prompts:

pip install -U "huggingface_hub[cli]"
huggingface-cli login

After the weights are downloaded, you no longer need to specify the --hf_token flag.

To run this model with int8 quantization, add --quantize_weights=1. Quantization is performed on the fly as the weights are loaded.
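
For example:

jpt serve --model_id meta-llama/Llama-2-7b-chat-hf --quantize_weights=1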

Weights downloaded from HuggingFace are stored by default in a checkpoints folder inside the directory where you run jpt. You can specify a different directory using the --working_dir flag.
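
For example, to keep checkpoints on a separate disk (the path shown here is illustrative):

jpt serve --model_id meta-llama/Llama-2-7b-chat-hf --working_dir=/mnt/disks/checkpoints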

If you want to use your own checkpoint, place it inside the checkpoints/<org>/<model>/hf_original directory (or the corresponding subdirectory under --working_dir). For example, Llama-2-7b checkpoints are stored as checkpoints/meta-llama/Llama-2-7b-hf/hf_original/*.safetensors. You can replace these files with modified weights in HuggingFace format.
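
For reference, the default layout looks like the following; the exact shard file names depend on the model:

checkpoints/
└── meta-llama/
    └── Llama-2-7b-hf/
        └── hf_original/
            ├── model-00001-of-00002.safetensors
            └── model-00002-of-00002.safetensors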

Run benchmarks

Change to the deps/JetStream folder that was downloaded when you ran install_everything.sh.

cd deps/JetStream
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
export dataset_path=ShareGPT_V3_unfiltered_cleaned_split.json
python benchmarks/benchmark_serving.py --tokenizer $tokenizer_path --num-prompts 2000 --dataset-path $dataset_path --dataset sharegpt --save-request-outputs

Here, $tokenizer_path should point to the tokenizer file that matches the model you are benchmarking.

For more information, see deps/JetStream/benchmarks/README.md.

Typical errors

If you get an Unexpected keyword argument 'device' error, try the following:

  • Uninstall the jax and jaxlib dependencies.
  • Reinstall them using source install_everything.sh.

If you get an Out of memory error, try the following:

  • Use a smaller batch size.
  • Use quantization.

Clean up

To avoid incurring charges to your Google Cloud account for the resources used in this tutorial, either delete the project that contains the resources, or keep the project and delete the individual resources.

  1. Clean up the GitHub repository

      # Clean up the jetstream-pytorch repository
      rm -rf jetstream-pytorch
    
    
  2. Clean up the Python virtual environment

    deactivate
    rm -rf venv
    
  3. Delete your TPU resources

    For more information, see Delete your TPU resources.
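
    For example, you can delete the TPU VM you created for this tutorial with:

    gcloud compute tpus tpu-vm delete ${TPU_NAME} --project=${PROJECT} --zone=${ZONE}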