准备工作
按照设置 Cloud TPU 环境中的步骤创建一个 Google Cloud 项目、激活 TPU API、安装 TPU CLI 并申请 TPU 配额。
按照使用 CreateNode API 创建 Cloud TPU 中的步骤创建一个 TPU 虚拟机,并将 --accelerator-type
设置为 v5litepod-8
。
克隆 JetStream 代码库并安装依赖项
使用 SSH 连接到您的 TPU 虚拟机
- 将 ${TPU_NAME} 设置为您的 TPU 名称。
- 将 ${PROJECT} 设为您的 Google Cloud 项目
- 将 ${ZONE} 设置为要在其中创建 TPU 的 Google Cloud 可用区
gcloud compute config-ssh gcloud compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT} --zone ${ZONE}
克隆 JetStream 代码库
git clone https://github.com/google/jetstream-pytorch.git
(可选)使用
venv
或conda
创建一个虚拟 Python 环境并将其激活。运行安装脚本
cd jetstream-pytorch source install_everything.sh
下载并转换权重
从 GitHub 下载官方 Llama 权重。
转换权重。
- 将 ${IN_CKPOINT} 设置为包含 Llama 权重的位置
- 将 ${OUT_CKPOINT} 设置为位置写入检查点
export input_ckpt_dir=${IN_CKPOINT} export output_ckpt_dir=${OUT_CKPOINT} export quantize=True python -m convert_checkpoints --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize=$quantize
在本地运行 JetStream PyTorch 引擎
如需在本地运行 JetStream PyTorch 引擎,请设置标记生成器路径:
export tokenizer_path=${TOKENIZER_PATH} # tokenizer model file path from meta-llama
使用 Llama 7B 运行 JetStream PyTorch 引擎
python run_interactive.py --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path
使用 Llama 13b 运行 JetStream PyTorch 引擎
python run_interactive.py --size=13b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path
运行 JetStream 服务器
python run_server.py --param_size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --platform=tpu=8
注意:--platform=tpu=
参数需要指定 TPU 设备的数量(v4-8
为 4,v5lite-8
为 8)。例如 --platform=tpu=8
。
运行 run_server.py
后,JetStream PyTorch 引擎即可接收 gRPC 调用。
运行基准测试
切换到您运行 install_everything.sh
时下载的 deps/JetStream
文件夹。
cd deps/JetStream
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
export dataset_path=ShareGPT_V3_unfiltered_cleaned_split.json
python benchmarks/benchmark_serving.py --tokenizer $tokenizer_path --num-prompts 2000 --dataset-path $dataset_path --dataset sharegpt --save-request-outputs
如需了解详情,请参阅 deps/JetStream/benchmarks/README.md
。
典型错误
如果您收到 Unexpected keyword argument 'device'
错误,请尝试以下操作:
- 卸载
jax
和jaxlib
依赖项 - 使用
source install_everything.sh
重新安装
如果您收到 Out of memory
错误,请尝试以下操作:
- 使用较小的批次大小
- 使用量化
清理
为避免因本教程中使用的资源导致您的 Google Cloud 账号产生费用,请删除包含这些资源的项目,或者保留项目但删除各个资源。
清理 GitHub 代码库
# Clean up the JetStream repository rm -rf JetStream # Clean up the xla repository rm -rf xla
清理 Python 虚拟环境
rm -rf .env
删除 TPU 资源
如需了解详情,请参阅删除 TPU 资源。