Before you begin
Follow the steps in Set up the Cloud TPU environment to create a Google Cloud project, activate the TPU API, install the TPU CLI, and request TPU quota.
Follow the steps in Create a Cloud TPU using the CreateNode API to create a TPU VM, setting `--accelerator-type` to `v5litepod-8`.
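For reference, a minimal create command might look like the following. The `--version` value here is an assumption; use the v5e runtime version recommended in the setup guide:

```bash
gcloud compute tpus tpu-vm create ${TPU_NAME} \
  --project=${PROJECT} \
  --zone=${ZONE} \
  --accelerator-type=v5litepod-8 \
  --version=v2-alpha-tpuv5-lite
```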
Clone the JetStream repository and install dependencies
Connect to your TPU VM using SSH:
- Set ${TPU_NAME} to your TPU's name.
- Set ${PROJECT} to your Google Cloud project ID.
- Set ${ZONE} to the Google Cloud zone where you created your TPU.
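For example (placeholder values; use your own project and a zone that has v5e capacity):

```bash
export TPU_NAME=my-v5e-tpu
export PROJECT=my-gcp-project
export ZONE=us-west4-a
```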
```bash
gcloud compute config-ssh
gcloud compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT} --zone ${ZONE}
```
Get the jetstream-pytorch code:

```bash
git clone https://github.com/google/jetstream-pytorch.git
cd jetstream-pytorch
git checkout jetstream-v0.2.4
```
(Optional) Create a virtual env using `venv` or `conda` and activate it:

```bash
sudo apt install python3.10-venv
python -m venv venv
source venv/bin/activate
```
- Run the installation script:

```bash
source install_everything.sh
```
Run JetStream PyTorch
List the supported models:

```bash
jpt list
```
This prints the list of supported models and variants:
```
meta-llama/Llama-2-7b-chat-hf
meta-llama/Llama-2-7b-hf
meta-llama/Llama-2-13b-chat-hf
meta-llama/Llama-2-13b-hf
meta-llama/Llama-2-70b-hf
meta-llama/Llama-2-70b-chat-hf
meta-llama/Meta-Llama-3-8B
meta-llama/Meta-Llama-3-8B-Instruct
meta-llama/Meta-Llama-3-70B
meta-llama/Meta-Llama-3-70B-Instruct
google/gemma-2b
google/gemma-2b-it
google/gemma-7b
google/gemma-7b-it
mistralai/Mixtral-8x7B-v0.1
mistralai/Mixtral-8x7B-Instruct-v0.1
```
To run the jetstream-pytorch server with one model:

```bash
jpt serve --model_id meta-llama/Llama-2-7b-chat-hf
```
The first time you run this model, the `jpt serve` command attempts to download the weights from HuggingFace, which requires you to authenticate with HuggingFace. To authenticate, run `huggingface-cli login` to set your access token, or pass your HuggingFace access token to the `jpt serve` command using the `--hf_token` flag:

```bash
jpt serve --model_id meta-llama/Llama-2-7b-chat-hf --hf_token=...
```
For more information about HuggingFace access tokens, see Access Tokens.
To log in using HuggingFace Hub, run the following commands and follow the prompts:

```bash
pip install -U "huggingface_hub[cli]"
huggingface-cli login
```
After the weights are downloaded, you no longer need to specify the `--hf_token` flag.
To run this model with `int8` quantization, add `--quantize_weights=1`. Quantization is done on the fly as the weights load.
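For example:

```bash
jpt serve --model_id meta-llama/Llama-2-7b-chat-hf --quantize_weights=1
```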
Weights downloaded from HuggingFace are stored by default in a `checkpoints` folder in the directory where you run `jpt`. You can also specify a different directory using the `--working_dir` flag.
If you want to use your own checkpoint, place the files in the `checkpoints/<org>/<model>/hf_original` directory (or the corresponding subdirectory of `--working_dir`). For example, Llama-2-7b checkpoints go in `checkpoints/meta-llama/Llama-2-7b-hf/hf_original/*.safetensors`. You can replace these files with modified weights in HuggingFace format.
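For instance, a sketch of copying your own weights into place (the source path is hypothetical):

```bash
# Hypothetical source path; point this at your own HuggingFace-format weights.
mkdir -p checkpoints/meta-llama/Llama-2-7b-hf/hf_original
cp /path/to/your/weights/*.safetensors checkpoints/meta-llama/Llama-2-7b-hf/hf_original/
```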
Run benchmarks
Change to the `deps/JetStream` folder that was downloaded when you ran `install_everything.sh`. Set `$tokenizer_path` to the path of your model's tokenizer file.

```bash
cd deps/JetStream
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
export dataset_path=ShareGPT_V3_unfiltered_cleaned_split.json
python benchmarks/benchmark_serving.py --tokenizer $tokenizer_path --num-prompts 2000 --dataset-path $dataset_path --dataset sharegpt --save-request-outputs
```
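If you are unsure where the tokenizer lives, one plausible location is alongside the downloaded weights (this path is an assumption based on the checkpoint layout described earlier; adjust it to your setup):

```bash
# Assumed path relative to deps/JetStream; verify before running.
export tokenizer_path=../../checkpoints/meta-llama/Llama-2-7b-hf/hf_original/tokenizer.model
```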
For more information, see `deps/JetStream/benchmarks/README.md`.
Typical errors
If you get an `Unexpected keyword argument 'device'` error, try the following:
- Uninstall the `jax` and `jaxlib` dependencies.
- Reinstall using `source install_everything.sh`.
If you get an `Out of memory` error, try the following:
- Use a smaller batch size.
- Use quantization (`--quantize_weights=1`).
Clean up
To avoid incurring charges to your Google Cloud account for the resources used in this tutorial, either delete the project that contains the resources, or keep the project and delete the individual resources.
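For example, if you created a dedicated project for this tutorial, you can delete the whole project (this permanently deletes all resources in it):

```bash
gcloud projects delete ${PROJECT}
```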
Clean up the GitHub repositories
```bash
# Remove the jetstream-pytorch repository, which includes the
# deps/JetStream and deps/xla checkouts.
rm -rf jetstream-pytorch
```
Clean up the python virtual environment
```bash
rm -rf venv
```
Delete your TPU resources
For more information, see Delete your TPU resources.
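For example, you can delete the TPU VM created for this tutorial with:

```bash
gcloud compute tpus tpu-vm delete ${TPU_NAME} --project=${PROJECT} --zone=${ZONE}
```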