JetStream MaxText-Inferenz auf v6e-TPU-VMs

In dieser Anleitung erfahren Sie, wie Sie mit JetStream MaxText-Modelle auf TPU v6e bereitstellen. JetStream ist eine durchsatz- und speicheroptimierte Engine für die LLM-Inferenz (Large Language Model) auf XLA-Geräten (TPUs). In dieser Anleitung führen Sie den Inferenz-Benchmark für das Llama2-7B-Modell aus.

Hinweise

TPU v6e mit 4 Chips vorbereiten:

  1. Folgen Sie der Anleitung Cloud TPU-Umgebung einrichten, um ein Google Cloud -Projekt einzurichten, die Google Cloud CLI zu konfigurieren, die Cloud TPU API zu aktivieren und sicherzustellen, dass Sie Zugriff auf Cloud TPUs haben.

  2. Authentifizieren Sie sich mit Google Cloud und konfigurieren Sie das Standardprojekt und die Standardzone für die Google Cloud CLI.

    gcloud auth login
    gcloud config set project PROJECT_ID
    gcloud config set compute/zone ZONE

Sichere Kapazität

Wenn Sie bereit sind, TPU-Kapazität zu reservieren, finden Sie unter Cloud TPU-Kontingente weitere Informationen zu den Cloud TPU-Kontingenten. Wenn Sie weitere Fragen zur Kapazitätssicherung haben, wenden Sie sich an Ihr Cloud TPU-Vertriebs- oder Account-Management-Team.

Cloud TPU-Umgebung bereitstellen

Sie können TPU-VMs mit GKE, mit GKE und XPK oder als in der Warteschlange befindliche Ressourcen bereitstellen.

Vorbereitung

  • Prüfen Sie, ob Ihr Projekt ein ausreichendes TPUS_PER_TPU_FAMILY-Kontingent hat. Dieses gibt die maximale Anzahl von Chips an, auf die Sie in IhremGoogle Cloud -Projekt zugreifen können.
  • Prüfen Sie, ob Ihr Projekt ein ausreichendes TPU-Kontingent für Folgendes hat:
    • TPU-VM-Kontingent
    • Kontingent für IP-Adressen
    • Hyperdisk Balanced-Kontingent
  • Nutzerberechtigungen für Projekte

Umgebungsvariablen erstellen

Erstellen Sie in Cloud Shell die folgenden Umgebungsvariablen:
export PROJECT_ID=your-project-id
export TPU_NAME=your-tpu-name
export ZONE=us-east5-b
export ACCELERATOR_TYPE=v6e-4
export RUNTIME_VERSION=v2-alpha-tpuv6e
export SERVICE_ACCOUNT=your-service-account
export QUEUED_RESOURCE_ID=your-queued-resource-id

Beschreibung der Befehls-Flags

Variable Beschreibung
PROJECT_ID Google Cloud ist der Projektname. Verwenden Sie ein vorhandenes Projekt oder erstellen Sie ein neues.
TPU_NAME Der Name der TPU.
ZONE Welche Zonen unterstützt werden, erfahren Sie im Dokument TPU-Regionen und ‑Zonen.
ACCELERATOR_TYPE Der Beschleunigertyp gibt die Version und Größe der Cloud TPU an, die Sie erstellen möchten. Weitere Informationen zu den unterstützten Beschleunigertypen für jede TPU-Version finden Sie unter TPU-Versionen.
RUNTIME_VERSION Die Cloud TPU-Softwareversion.
SERVICE_ACCOUNT Die E-Mail-Adresse Ihres Dienstkontos . Sie finden ihn in der Google Cloud Console auf der Seite Dienstkonten.

Beispiel: tpu-service-account@PROJECT_ID.iam.gserviceaccount.com

QUEUED_RESOURCE_ID Die vom Nutzer zugewiesene Text-ID der anstehenden Ressourcenanfrage.

TPU v6e bereitstellen

Verwenden Sie den folgenden Befehl, um eine TPU v6e bereitzustellen:

gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
    --node-id=${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --accelerator-type=${ACCELERATOR_TYPE} \
    --runtime-version=${RUNTIME_VERSION} \
    --service-account=${SERVICE_ACCOUNT}

Verwenden Sie die Befehle list oder describe, um den Status der in der Warteschlange befindlichen Ressource abzufragen.

gcloud alpha compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
    --project ${PROJECT_ID} --zone ${ZONE}

Weitere Informationen zu den Status von angeforderten Ressourcen in der Warteschlange finden Sie unter Warteschlange für Ressourcen verwalten.

Über SSH eine Verbindung zur TPU herstellen

   gcloud compute tpus tpu-vm ssh ${TPU_NAME}

Sobald Sie eine Verbindung zur TPU hergestellt haben, können Sie den Inferenz-Benchmark ausführen.

TPU-VM-Umgebung einrichten

  1. Erstellen Sie ein Verzeichnis zum Ausführen des Inferenz-Benchmarks:

    export MAIN_DIR=your-main-directory
    mkdir -p ${MAIN_DIR}
  2. Richten Sie eine virtuelle Python-Umgebung ein:

    cd ${MAIN_DIR}
    sudo apt update
    sudo apt install python3.10 python3.10-venv
    python3.10 -m venv venv
    source venv/bin/activate
  3. Git Large File Storage (LFS) für OpenOrca-Daten installieren:

    sudo apt-get install git-lfs
    git lfs install
  4. Klonen und installieren Sie JetStream:

    cd $MAIN_DIR
    git clone https://github.com/google/JetStream.git
    cd JetStream
    git checkout main
    pip install -e .
    cd benchmarks
    pip install -r requirements.in
  5. So richten Sie MaxText ein:

    cd $MAIN_DIR
    git clone https://github.com/google/maxtext.git
    cd maxtext
    git checkout main
    bash setup.sh
    pip install torch --index-url https://download.pytorch.org/whl/cpu
  6. Fordern Sie Zugriff auf Llama-Modelle an, um einen Downloadschlüssel von Meta für das Llama 2-Modell zu erhalten.

  7. Klonen Sie das Llama-Repository:

    cd $MAIN_DIR
    git clone https://github.com/meta-llama/llama
    cd llama
  8. Führen Sie bash download.sh aus. Gib den Downloadschlüssel ein, wenn du dazu aufgefordert wirst. Dieses Script erstellt ein llama-2-7b-Verzeichnis im llama-Verzeichnis.

    bash download.sh
  9. So erstellen Sie Storage-Buckets:

    export CHKPT_BUCKET=gs://your-checkpoint-bucket
    export BASE_OUTPUT_DIRECTORY=gs://your-output-dir
    export CONVERTED_CHECKPOINT_PATH=gs://bucket-to-store-converted-checkpoints
    export MAXTEXT_BUCKET_UNSCANNED=gs://bucket-to-store-unscanned-data
    gcloud storage buckets create ${CHKPT_BUCKET}
    gcloud storage buckets create ${BASE_OUTPUT_DIRECTORY}
    gcloud storage buckets create ${CONVERTED_CHECKPOINT_PATH}
    gcloud storage buckets create ${MAXTEXT_BUCKET_UNSCANNED}
    gcloud storage cp --recursive llama-2-7b/* ${CHKPT_BUCKET}

Prüfpunktkonvertierung ausführen

  1. So führen Sie die Umwandlung in gescannte Prüfpunkte durch:

    cd $MAIN_DIR/maxtext
    python3 -m MaxText.llama_or_mistral_ckpt \
        --base-model-path $MAIN_DIR/llama/llama-2-7b \
        --model-size llama2-7b \
        --maxtext-model-path ${CONVERTED_CHECKPOINT_PATH}
  2. So konvertieren Sie sie in nicht gescannte Prüfpunkte:

    export CONVERTED_CHECKPOINT=${CONVERTED_CHECKPOINT_PATH}/0/items
    export DIRECT_PARAMETER_CHECKPOINT_RUN=direct_generate_param_only_checkpoint
    python3 -m MaxText.generate_param_only_checkpoint \
        MaxText/configs/base.yml \
        base_output_directory=${MAXTEXT_BUCKET_UNSCANNED} \
        load_parameters_path=${CONVERTED_CHECKPOINT} \
        run_name=${DIRECT_PARAMETER_CHECKPOINT_RUN} \
        model_name='llama2-7b' \
        force_unroll=true

Inferenz ausführen

  1. Validierungstest ausführen:

    export UNSCANNED_CKPT_PATH=${MAXTEXT_BUCKET_UNSCANNED}/${DIRECT_PARAMETER_CHECKPOINT_RUN}/checkpoints/0/items
    python3 -m MaxText.decode \
        MaxText/configs/base.yml \
        load_parameters_path=${UNSCANNED_CKPT_PATH} \
        run_name=runner_decode_unscanned_${idx} \
        base_output_directory=${BASE_OUTPUT_DIRECTORY} \
        per_device_batch_size=1 \
        model_name='llama2-7b' \
        ici_autoregressive_parallelism=4 \
        max_prefill_predict_length=4 \
        max_target_length=16 \
        prompt="I love to" \
        attention=dot_product \
        scan_layers=false
  2. Führen Sie den Server in Ihrem aktuellen Terminal aus:

    export TOKENIZER_PATH=assets/tokenizer.llama2
    export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
    export MAX_PREFILL_PREDICT_LENGTH=1024
    export MAX_TARGET_LENGTH=2048
    export MODEL_NAME=llama2-7b
    export ICI_FSDP_PARALLELISM=1
    export ICI_AUTOREGRESSIVE_PARALLELISM=1
    export ICI_TENSOR_PARALLELISM=-1
    export SCAN_LAYERS=false
    export WEIGHT_DTYPE=bfloat16
    export PER_DEVICE_BATCH_SIZE=11
    
    cd $MAIN_DIR/maxtext
    python3 -m MaxText.maxengine_server \
        MaxText/configs/base.yml \
        tokenizer_path=${TOKENIZER_PATH} \
        load_parameters_path=${LOAD_PARAMETERS_PATH} \
        max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
        max_target_length=${MAX_TARGET_LENGTH} \
        model_name=${MODEL_NAME} \
        ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
        ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
        ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
        scan_layers=${SCAN_LAYERS} \
        weight_dtype=${WEIGHT_DTYPE} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE}
  3. Öffnen Sie ein neues Terminalfenster, stellen Sie eine Verbindung zur TPU her und wechseln Sie zu derselben virtuellen Umgebung, die Sie im ersten Terminalfenster verwendet haben:

    source venv/bin/activate
    
  4. Führen Sie die folgenden Befehle aus, um den JetStream-Benchmark auszuführen.

    export MAIN_DIR=your-main-directory
    cd $MAIN_DIR
    
    python JetStream/benchmarks/benchmark_serving.py \
        --tokenizer $MAIN_DIR/maxtext/assets/tokenizer.llama2 \
        --warmup-mode sampled \
        --save-result \
        --save-request-outputs \
        --request-outputs-file-path outputs.json \
        --num-prompts 1000 \
        --max-output-length 1024 \
        --dataset openorca \
        --dataset-path $MAIN_DIR/JetStream/benchmarks/open_orca_gpt4_tokenized_llama.calibration_1000.pkl

Ergebnisse

Die folgende Ausgabe wurde beim Ausführen des Benchmarks mit v6e-8 generiert. Die Ergebnisse variieren je nach Hardware, Software, Modell und Netzwerk.

Mean output size: 929.5959798994975
Median output size: 1026.0
P99 output size: 1026.0
Successful requests: 995
Benchmark duration: 195.533269 s
Total input tokens: 217011
Total generated tokens: 924948
Request throughput: 5.09 requests/s
Input token throughput: 1109.84 tokens/s
Output token throughput: 4730.39 tokens/s
Overall token throughput: 5840.23 tokens/s
Mean ttft: 538.49 ms
Median ttft: 95.66 ms
P99 ttft: 13937.86 ms
Mean ttst: 1218.72 ms
Median ttst: 152.57 ms
P99 ttst: 14241.30 ms
Mean TPOT: 91.83 ms
Median TPOT: 16.63 ms
P99 TPOT: 363.37 ms

Bereinigen

  1. Trennen Sie die Verbindung zum TPU:

    $ (vm) exit
  2. Löschen Sie die TPU:

    gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
        --project ${PROJECT_ID} \
        --zone ${ZONE} \
        --force \
        --async
  3. Löschen Sie die Buckets und ihren Inhalt:

    export CHKPT_BUCKET=gs://your-checkpoint-bucket
    export BASE_OUTPUT_DIRECTORY=gs://your-output-dir
    export CONVERTED_CHECKPOINT_PATH=gs://bucket-to-store-converted-checkpoints
    export MAXTEXT_BUCKET_UNSCANNED=gs://bucket-to-store-unscanned-data
    gcloud storage rm -r ${CHKPT_BUCKET}
    gcloud storage rm -r ${BASE_OUTPUT_DIRECTORY}
    gcloud storage rm -r ${CONVERTED_CHECKPOINT_PATH}
    gcloud storage rm -r ${MAXTEXT_BUCKET_UNSCANNED}