Inférence JetStream MaxText sur les VM TPU v6e

Ce tutoriel explique comment utiliser JetStream pour diffuser des modèles MaxText sur TPU v6e. JetStream est un moteur optimisé pour le débit et la mémoire pour l'inférence LLM (Large Language Model) sur les appareils XLA (TPU). Dans ce tutoriel, vous exécutez le benchmark d'inférence pour le modèle Llama2-7B.

Avant de commencer

Préparez-vous à provisionner un TPU v6e avec quatre puces:

  1. Suivez le guide Configurer l'environnement Cloud TPU pour configurer un projet Google Cloud , configurer Google Cloud CLI, activer l'API Cloud TPU et vous assurer que vous avez accès à l'utilisation des Cloud TPU.

  2. Authentifiez-vous avec Google Cloud et configurez le projet et la zone par défaut pour Google Cloud CLI.

    gcloud auth login
    gcloud config set project PROJECT_ID
    gcloud config set compute/zone ZONE

Sécuriser la capacité

Lorsque vous êtes prêt à sécuriser la capacité de TPU, consultez la section Quotas Cloud TPU pour en savoir plus sur les quotas Cloud TPU. Si vous avez d'autres questions sur la sécurisation de la capacité, contactez votre équipe commerciale ou chargée de votre compte Cloud TPU.

Provisionner l'environnement Cloud TPU

Vous pouvez provisionner des VM TPU avec GKE, avec GKE et XPK, ou en tant que ressources mises en file d'attente.

Prérequis

  • Vérifiez que votre projet dispose d'un quota TPUS_PER_TPU_FAMILY suffisant, qui spécifie le nombre maximal de puces auxquelles vous pouvez accéder dans votre projetGoogle Cloud .
  • Vérifiez que votre projet dispose d'un quota TPU suffisant pour :
    • Quota de VM TPU
    • Quota d'adresses IP
    • Quota Hyperdisk équilibré
  • Autorisations des utilisateurs pour les projets

Créer des variables d'environnement

Dans Cloud Shell, créez les variables d'environnement suivantes:

export PROJECT_ID=your-project-id
export TPU_NAME=your-tpu-name
export ZONE=us-east5-b
export ACCELERATOR_TYPE=v6e-4
export RUNTIME_VERSION=v2-alpha-tpuv6e
export SERVICE_ACCOUNT=your-service-account
export QUEUED_RESOURCE_ID=your-queued-resource-id

Descriptions des variables d'environnement

Variable Description
PROJECT_ID L'ID de votre Google Cloud projet. Utilisez un projet existant ou créez-en un.
TPU_NAME Nom du TPU.
ZONE Zone dans laquelle créer la VM TPU. Pour en savoir plus sur les zones compatibles, consultez la section Régions et zones de TPU.
ACCELERATOR_TYPE Le type d'accélérateur spécifie la version et la taille du Cloud TPU que vous souhaitez créer. Pour en savoir plus sur les types d'accélérateurs compatibles avec chaque version de TPU, consultez la section Versions de TPU.
RUNTIME_VERSION Version logicielle de Cloud TPU.
SERVICE_ACCOUNT Adresse e-mail de votre compte de service. Pour le trouver, accédez à la page Comptes de service dans la console Google Cloud .

Par exemple : tpu-service-account@PROJECT_ID.iam.gserviceaccount.com

QUEUED_RESOURCE_ID ID de texte attribué par l'utilisateur de la requête de ressource mise en file d'attente.

Provisionner un TPU v6e

Utilisez la commande suivante pour provisionner un TPU v6e:

gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
    --node-id=${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --accelerator-type=${ACCELERATOR_TYPE} \
    --runtime-version=${RUNTIME_VERSION} \
    --service-account=${SERVICE_ACCOUNT}

Utilisez les commandes list ou describe pour interroger l'état de votre ressource mise en file d'attente.

gcloud alpha compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
    --project ${PROJECT_ID} --zone ${ZONE}

Pour en savoir plus sur les états des demandes de ressources en file d'attente, consultez la section Gérer les ressources en file d'attente.

Se connecter au TPU à l'aide de SSH

   gcloud compute tpus tpu-vm ssh ${TPU_NAME}

Une fois connecté au TPU, vous pouvez exécuter le benchmark d'inférence.

Configurer votre environnement VM TPU

  1. Créez un répertoire pour exécuter le benchmark d'inférence:

    export MAIN_DIR=your-main-directory
    mkdir -p ${MAIN_DIR}
  2. Configurez un environnement virtuel Python:

    cd ${MAIN_DIR}
    sudo apt update
    sudo apt install python3.10 python3.10-venv
    python3.10 -m venv venv
    source venv/bin/activate
  3. Installez le stockage de fichiers volumineux (LFS, Large File Storage) Git (pour les données OpenOrca):

    sudo apt-get install git-lfs
    git lfs install
  4. Clonez et installez JetStream:

    cd $MAIN_DIR
    git clone https://github.com/google/JetStream.git
    cd JetStream
    git checkout main
    pip install -e .
    cd benchmarks
    pip install -r requirements.in
  5. Configurez MaxText:

    cd $MAIN_DIR
    git clone https://github.com/google/maxtext.git
    cd maxtext
    git checkout main
    bash setup.sh
    pip install torch --index-url https://download.pytorch.org/whl/cpu
  6. Demander l'accès aux modèles Llama pour obtenir une clé de téléchargement de Meta pour le modèle Llama 2.

  7. Clonez le dépôt Llama:

    cd $MAIN_DIR
    git clone https://github.com/meta-llama/llama
    cd llama
  8. Exécutez bash download.sh. Lorsque vous y êtes invité, indiquez votre clé de téléchargement. Ce script crée un répertoire llama-2-7b dans votre répertoire llama.

    bash download.sh
  9. Créez des buckets de stockage:

    export CHKPT_BUCKET=gs://your-checkpoint-bucket
    export BASE_OUTPUT_DIRECTORY=gs://your-output-dir
    export CONVERTED_CHECKPOINT_PATH=gs://bucket-to-store-converted-checkpoints
    export MAXTEXT_BUCKET_UNSCANNED=gs://bucket-to-store-unscanned-data
    gcloud storage buckets create ${CHKPT_BUCKET}
    gcloud storage buckets create ${BASE_OUTPUT_DIRECTORY}
    gcloud storage buckets create ${CONVERTED_CHECKPOINT_PATH}
    gcloud storage buckets create ${MAXTEXT_BUCKET_UNSCANNED}
    gcloud storage cp --recursive llama-2-7b/* ${CHKPT_BUCKET}

Effectuer la conversion des points de contrôle

  1. Effectuez la conversion en points de contrôle numérisés:

    cd $MAIN_DIR/maxtext
    python3 -m MaxText.llama_or_mistral_ckpt \
        --base-model-path $MAIN_DIR/llama/llama-2-7b \
        --model-size llama2-7b \
        --maxtext-model-path ${CONVERTED_CHECKPOINT_PATH}
  2. Convertir en points de contrôle non analysés:

    export CONVERTED_CHECKPOINT=${CONVERTED_CHECKPOINT_PATH}/0/items
    export DIRECT_PARAMETER_CHECKPOINT_RUN=direct_generate_param_only_checkpoint
    python3 -m MaxText.generate_param_only_checkpoint \
        MaxText/configs/base.yml \
        base_output_directory=${MAXTEXT_BUCKET_UNSCANNED} \
        load_parameters_path=${CONVERTED_CHECKPOINT} \
        run_name=${DIRECT_PARAMETER_CHECKPOINT_RUN} \
        model_name='llama2-7b' \
        force_unroll=true

Effectuer une inférence

  1. Exécutez un test de validation:

    export UNSCANNED_CKPT_PATH=${MAXTEXT_BUCKET_UNSCANNED}/${DIRECT_PARAMETER_CHECKPOINT_RUN}/checkpoints/0/items
    python3 -m MaxText.decode \
        MaxText/configs/base.yml \
        load_parameters_path=${UNSCANNED_CKPT_PATH} \
        run_name=runner_decode_unscanned_${idx} \
        base_output_directory=${BASE_OUTPUT_DIRECTORY} \
        per_device_batch_size=1 \
        model_name='llama2-7b' \
        ici_autoregressive_parallelism=4 \
        max_prefill_predict_length=4 \
        max_target_length=16 \
        prompt="I love to" \
        attention=dot_product \
        scan_layers=false
  2. Exécutez le serveur dans votre terminal actuel:

    export TOKENIZER_PATH=assets/tokenizer.llama2
    export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
    export MAX_PREFILL_PREDICT_LENGTH=1024
    export MAX_TARGET_LENGTH=2048
    export MODEL_NAME=llama2-7b
    export ICI_FSDP_PARALLELISM=1
    export ICI_AUTOREGRESSIVE_PARALLELISM=1
    export ICI_TENSOR_PARALLELISM=-1
    export SCAN_LAYERS=false
    export WEIGHT_DTYPE=bfloat16
    export PER_DEVICE_BATCH_SIZE=11
    
    cd $MAIN_DIR/maxtext
    python3 -m MaxText.maxengine_server \
        MaxText/configs/base.yml \
        tokenizer_path=${TOKENIZER_PATH} \
        load_parameters_path=${LOAD_PARAMETERS_PATH} \
        max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
        max_target_length=${MAX_TARGET_LENGTH} \
        model_name=${MODEL_NAME} \
        ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
        ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
        ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
        scan_layers=${SCAN_LAYERS} \
        weight_dtype=${WEIGHT_DTYPE} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE}
  3. Ouvrez une nouvelle fenêtre de terminal, connectez-vous au TPU et passez au même environnement virtuel que celui que vous avez utilisé dans la première fenêtre de terminal:

    source venv/bin/activate
    
  4. Exécutez les commandes suivantes pour exécuter le benchmark JetStream.

    export MAIN_DIR=your-main-directory
    cd $MAIN_DIR
    
    python JetStream/benchmarks/benchmark_serving.py \
        --tokenizer $MAIN_DIR/maxtext/assets/tokenizer.llama2 \
        --warmup-mode sampled \
        --save-result \
        --save-request-outputs \
        --request-outputs-file-path outputs.json \
        --num-prompts 1000 \
        --max-output-length 1024 \
        --dataset openorca \
        --dataset-path $MAIN_DIR/JetStream/benchmarks/open_orca_gpt4_tokenized_llama.calibration_1000.pkl

Résultats

Le résultat suivant a été généré lors de l'exécution du benchmark avec la version v6e-8. Les résultats varient en fonction du matériel, des logiciels, du modèle et de la mise en réseau.

Mean output size: 929.5959798994975
Median output size: 1026.0
P99 output size: 1026.0
Successful requests: 995
Benchmark duration: 195.533269 s
Total input tokens: 217011
Total generated tokens: 924948
Request throughput: 5.09 requests/s
Input token throughput: 1109.84 tokens/s
Output token throughput: 4730.39 tokens/s
Overall token throughput: 5840.23 tokens/s
Mean ttft: 538.49 ms
Median ttft: 95.66 ms
P99 ttft: 13937.86 ms
Mean ttst: 1218.72 ms
Median ttst: 152.57 ms
P99 ttst: 14241.30 ms
Mean TPOT: 91.83 ms
Median TPOT: 16.63 ms
P99 TPOT: 363.37 ms

Effectuer un nettoyage

  1. Se déconnecter du TPU:

    $ (vm) exit
  2. Supprimez le TPU:

    gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
        --project ${PROJECT_ID} \
        --zone ${ZONE} \
        --force \
        --async
  3. Supprimez les buckets et leur contenu:

    export CHKPT_BUCKET=gs://your-checkpoint-bucket
    export BASE_OUTPUT_DIRECTORY=gs://your-output-dir
    export CONVERTED_CHECKPOINT_PATH=gs://bucket-to-store-converted-checkpoints
    export MAXTEXT_BUCKET_UNSCANNED=gs://bucket-to-store-unscanned-data
    gcloud storage rm -r ${CHKPT_BUCKET}
    gcloud storage rm -r ${BASE_OUTPUT_DIRECTORY}
    gcloud storage rm -r ${CONVERTED_CHECKPOINT_PATH}
    gcloud storage rm -r ${MAXTEXT_BUCKET_UNSCANNED}