Inferencia de JetStream MaxText en VMs de TPU v6e
En este instructivo, se muestra cómo usar JetStream para entregar modelos de MaxText en TPU v6e. JetStream es un motor con capacidad de procesamiento y memoria optimizada para la inferencia de modelos de lenguaje grandes (LLM) en dispositivos XLA (TPU). En este instructivo, ejecutarás la comparativa de inferencia para el modelo Llama2-7B.
Antes de comenzar
Prepara el aprovisionamiento de una TPU v6e con 4 chips:
Sigue las instrucciones de la guía Configura el entorno de Cloud TPU para configurar un proyecto de Google Cloud , configurar Google Cloud CLI, habilitar la API de Cloud TPU y asegurarte de tener acceso para usar las Cloud TPU.
Autentica con Google Cloud y configura el proyecto y la zona predeterminados para Google Cloud CLI.
gcloud auth login gcloud config set project PROJECT_ID gcloud config set compute/zone ZONE
Cómo proteger la capacidad
Cuando tengas todo listo para proteger la capacidad de las TPU, consulta Cuotas de Cloud TPU para obtener más información sobre las cuotas de Cloud TPU. Si tienes más preguntas sobre cómo proteger la capacidad, comunícate con tu equipo de ventas o de cuentas de Cloud TPU.
Aprovisiona el entorno de Cloud TPU
Puedes aprovisionar VMs de TPU con GKE, con GKE y XPK, o como recursos en cola.
Requisitos previos
- Verifica que tu proyecto tenga suficiente cuota de
TPUS_PER_TPU_FAMILY
, que especifica la cantidad máxima de chips a los que puedes acceder en tu proyecto deGoogle Cloud . - Verifica que tu proyecto tenga suficiente cuota de TPU para lo siguiente:
- Cuota de VM de TPU
- Cuota de direcciones IP
- Quota de Hyperdisk Balanced
- Permisos del proyecto del usuario
- Si usas GKE con XPK, consulta Permisos de la consola de Google Cloud en la cuenta de usuario o de servicio para conocer los permisos necesarios para ejecutar XPK.
Crea variables de entorno
En Cloud Shell, crea las siguientes variables de entorno:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-east5-b export ACCELERATOR_TYPE=v6e-4 export RUNTIME_VERSION=v2-alpha-tpuv6e export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
Descripciones de las variables de entorno
Variable | Descripción |
---|---|
PROJECT_ID |
El Google Cloud ID de tu proyecto. Usa un proyecto existente o crea uno nuevo. |
TPU_NAME |
El nombre de la TPU. |
ZONE |
Es la zona en la que se creará la VM de TPU. Para obtener más información sobre las zonas admitidas, consulta Regiones y zonas de TPU. |
ACCELERATOR_TYPE |
El tipo de acelerador especifica la versión y el tamaño de la Cloud TPU que deseas crear. Para obtener más información sobre los tipos de aceleradores compatibles con cada versión de TPU, consulta Versiones de TPU. |
RUNTIME_VERSION |
La versión de software de Cloud TPU |
SERVICE_ACCOUNT |
La dirección de correo electrónico de tu cuenta de servicio. Para encontrarla, ve a la
página Cuentas de servicio en la
consola de Google Cloud .
Por ejemplo:
|
QUEUED_RESOURCE_ID |
El ID de texto asignado por el usuario de la solicitud de recursos en cola. |
Aprovisiona una TPU v6e
Usa el siguiente comando para aprovisionar una TPU v6e:
gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}
Usa los comandos list
o describe
para consultar el estado de tu recurso en cola.
gcloud alpha compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
--project ${PROJECT_ID} --zone ${ZONE}
Para obtener más información sobre los estados de las solicitudes de recursos en cola, consulta Administra recursos en cola.
Cómo conectarse a la TPU con SSH
gcloud compute tpus tpu-vm ssh ${TPU_NAME}
Una vez que te conectes a la TPU, podrás ejecutar la comparativa de inferencia.
Configura tu entorno de VM de TPU
Crea un directorio para ejecutar la comparativa de inferencia:
export MAIN_DIR=your-main-directory mkdir -p ${MAIN_DIR}
Configura un entorno virtual de Python:
cd ${MAIN_DIR} sudo apt update sudo apt install python3.10 python3.10-venv python3.10 -m venv venv source venv/bin/activate
Instala el almacenamiento de archivos grandes (LFS) de Git (para datos de OpenOrca):
sudo apt-get install git-lfs git lfs install
Clona e instala JetStream:
cd $MAIN_DIR git clone https://github.com/google/JetStream.git cd JetStream git checkout main pip install -e . cd benchmarks pip install -r requirements.in
Configura MaxText:
cd $MAIN_DIR git clone https://github.com/google/maxtext.git cd maxtext git checkout main bash setup.sh pip install torch --index-url https://download.pytorch.org/whl/cpu
Solicita acceso a los modelos de Llama para obtener una clave de descarga de Meta para el modelo Llama 2.
Clona el repositorio de Llama:
cd $MAIN_DIR git clone https://github.com/meta-llama/llama cd llama
Ejecuta
bash download.sh
. Cuando se te solicite, proporciona tu clave de descarga. Esta secuencia de comandos crea un directoriollama-2-7b
dentro de tu directoriollama
.bash download.sh
Crea buckets de almacenamiento:
export CHKPT_BUCKET=gs://your-checkpoint-bucket export BASE_OUTPUT_DIRECTORY=gs://your-output-dir export CONVERTED_CHECKPOINT_PATH=gs://bucket-to-store-converted-checkpoints export MAXTEXT_BUCKET_UNSCANNED=gs://bucket-to-store-unscanned-data gcloud storage buckets create ${CHKPT_BUCKET} gcloud storage buckets create ${BASE_OUTPUT_DIRECTORY} gcloud storage buckets create ${CONVERTED_CHECKPOINT_PATH} gcloud storage buckets create ${MAXTEXT_BUCKET_UNSCANNED} gcloud storage cp --recursive llama-2-7b/* ${CHKPT_BUCKET}
Realiza la conversión de puntos de control
Realiza la conversión a los puntos de control escaneados:
cd $MAIN_DIR/maxtext python3 -m MaxText.llama_or_mistral_ckpt \ --base-model-path $MAIN_DIR/llama/llama-2-7b \ --model-size llama2-7b \ --maxtext-model-path ${CONVERTED_CHECKPOINT_PATH}
Convierte en puntos de control sin analizar:
export CONVERTED_CHECKPOINT=${CONVERTED_CHECKPOINT_PATH}/0/items export DIRECT_PARAMETER_CHECKPOINT_RUN=direct_generate_param_only_checkpoint python3 -m MaxText.generate_param_only_checkpoint \ MaxText/configs/base.yml \ base_output_directory=${MAXTEXT_BUCKET_UNSCANNED} \ load_parameters_path=${CONVERTED_CHECKPOINT} \ run_name=${DIRECT_PARAMETER_CHECKPOINT_RUN} \ model_name='llama2-7b' \ force_unroll=true
Realiza la inferencia
Ejecuta una prueba de validación:
export UNSCANNED_CKPT_PATH=${MAXTEXT_BUCKET_UNSCANNED}/${DIRECT_PARAMETER_CHECKPOINT_RUN}/checkpoints/0/items python3 -m MaxText.decode \ MaxText/configs/base.yml \ load_parameters_path=${UNSCANNED_CKPT_PATH} \ run_name=runner_decode_unscanned_${idx} \ base_output_directory=${BASE_OUTPUT_DIRECTORY} \ per_device_batch_size=1 \ model_name='llama2-7b' \ ici_autoregressive_parallelism=4 \ max_prefill_predict_length=4 \ max_target_length=16 \ prompt="I love to" \ attention=dot_product \ scan_layers=false
Ejecuta el servidor en tu terminal actual:
export TOKENIZER_PATH=assets/tokenizer.llama2 export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH} export MAX_PREFILL_PREDICT_LENGTH=1024 export MAX_TARGET_LENGTH=2048 export MODEL_NAME=llama2-7b export ICI_FSDP_PARALLELISM=1 export ICI_AUTOREGRESSIVE_PARALLELISM=1 export ICI_TENSOR_PARALLELISM=-1 export SCAN_LAYERS=false export WEIGHT_DTYPE=bfloat16 export PER_DEVICE_BATCH_SIZE=11 cd $MAIN_DIR/maxtext python3 -m MaxText.maxengine_server \ MaxText/configs/base.yml \ tokenizer_path=${TOKENIZER_PATH} \ load_parameters_path=${LOAD_PARAMETERS_PATH} \ max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \ max_target_length=${MAX_TARGET_LENGTH} \ model_name=${MODEL_NAME} \ ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \ ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \ ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \ scan_layers=${SCAN_LAYERS} \ weight_dtype=${WEIGHT_DTYPE} \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE}
Abre una ventana de terminal nueva, conéctate a la TPU y cambia al mismo entorno virtual que usaste en la primera ventana de terminal:
source venv/bin/activate
Ejecuta los siguientes comandos para ejecutar la comparativa de JetStream.
export MAIN_DIR=your-main-directory cd $MAIN_DIR python JetStream/benchmarks/benchmark_serving.py \ --tokenizer $MAIN_DIR/maxtext/assets/tokenizer.llama2 \ --warmup-mode sampled \ --save-result \ --save-request-outputs \ --request-outputs-file-path outputs.json \ --num-prompts 1000 \ --max-output-length 1024 \ --dataset openorca \ --dataset-path $MAIN_DIR/JetStream/benchmarks/open_orca_gpt4_tokenized_llama.calibration_1000.pkl
Resultados
El siguiente resultado se generó cuando se ejecutó la comparativa con la v6e-8. Los resultados variarán según el hardware, el software, el modelo y las redes.
Mean output size: 929.5959798994975
Median output size: 1026.0
P99 output size: 1026.0
Successful requests: 995
Benchmark duration: 195.533269 s
Total input tokens: 217011
Total generated tokens: 924948
Request throughput: 5.09 requests/s
Input token throughput: 1109.84 tokens/s
Output token throughput: 4730.39 tokens/s
Overall token throughput: 5840.23 tokens/s
Mean ttft: 538.49 ms
Median ttft: 95.66 ms
P99 ttft: 13937.86 ms
Mean ttst: 1218.72 ms
Median ttst: 152.57 ms
P99 ttst: 14241.30 ms
Mean TPOT: 91.83 ms
Median TPOT: 16.63 ms
P99 TPOT: 363.37 ms
Limpia
Desconecta la TPU:
$ (vm) exit
Borra la TPU:
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \ --project ${PROJECT_ID} \ --zone ${ZONE} \ --force \ --async
Borra los buckets y su contenido:
export CHKPT_BUCKET=gs://your-checkpoint-bucket export BASE_OUTPUT_DIRECTORY=gs://your-output-dir export CONVERTED_CHECKPOINT_PATH=gs://bucket-to-store-converted-checkpoints export MAXTEXT_BUCKET_UNSCANNED=gs://bucket-to-store-unscanned-data gcloud storage rm -r ${CHKPT_BUCKET} gcloud storage rm -r ${BASE_OUTPUT_DIRECTORY} gcloud storage rm -r ${CONVERTED_CHECKPOINT_PATH} gcloud storage rm -r ${MAXTEXT_BUCKET_UNSCANNED}