JetStream MaxText-Inferenz auf v6e-TPU-VMs

In dieser Anleitung wird gezeigt, wie Sie mit JetStream MaxText-Modelle auf TPU v6e bereitstellen. JetStream ist eine durchsatz- und speicheroptimierte Engine für die LLM-Inferenz (Large Language Model) auf XLA-Geräten (TPUs). In dieser Anleitung führen Sie die Inferenzbenchmark für das Modell Llama2-7B aus.

Hinweis

Bereiten Sie die Bereitstellung einer v6e-TPU mit 4 Chips vor:

  1. Folgen Sie der Anleitung unter Cloud TPU-Umgebung einrichten, um ein Google Cloud -Projekt einzurichten, die Google Cloud CLI zu konfigurieren, die Cloud TPU API zu aktivieren und dafür zu sorgen, dass Sie Zugriff auf Cloud TPUs haben.

  2. Authentifizieren Sie sich bei Google Cloud und konfigurieren Sie das Standardprojekt und die Standardzone für die Google Cloud CLI.

    gcloud auth login
    gcloud config set project PROJECT_ID
    gcloud config set compute/zone ZONE

Kapazität sichern

Wenn Sie sich TPU-Kapazität sichern möchten, finden Sie unter Cloud TPU-Kontingente weitere Informationen zu den Cloud TPU-Kontingenten. Wenn Sie weitere Fragen zur Sicherung von Kapazität haben, wenden Sie sich an Ihr Cloud TPU-Vertriebs- oder ‑Account-Management-Team.

Cloud TPU-Umgebung bereitstellen

Sie können TPU-VMs mit GKE, mit GKE und XPK oder als in die Warteschlange gestellte Ressourcen bereitstellen.

Vorbereitung

  • Prüfen Sie, ob Ihr Projekt über ein ausreichendes TPUS_PER_TPU_FAMILY-Kontingent verfügt. Damit wird die maximale Anzahl von Chips angegeben, auf die Sie in IhremGoogle Cloud -Projekt zugreifen können.
  • Prüfen Sie, ob folgende Kontingente für Ihr Projekt ausreichen:
    • TPU-VM-Kontingent
    • Kontingent für IP-Adressen
    • Hyperdisk Balanced-Kontingent
  • Nutzerprojektberechtigungen

Umgebungsvariablen erstellen

Erstellen Sie in einer Cloud Shell die folgenden Umgebungsvariablen:

export PROJECT_ID=your-project-id
export TPU_NAME=your-tpu-name
export ZONE=us-east5-b
export ACCELERATOR_TYPE=v6e-4
export RUNTIME_VERSION=v2-alpha-tpuv6e
export SERVICE_ACCOUNT=your-service-account
export QUEUED_RESOURCE_ID=your-queued-resource-id

Beschreibungen von Umgebungsvariablen

Variable Beschreibung
PROJECT_ID Ihre Google Cloud -Projekt-ID. Verwenden Sie ein vorhandenes Projekt oder erstellen Sie ein neues Projekt.
TPU_NAME Der Name der TPU
ZONE Die Zone, in der die TPU-VM erstellt werden soll. Weitere Informationen zu unterstützten Zonen finden Sie unter TPU-Regionen und ‑Zonen.
ACCELERATOR_TYPE Der Beschleunigertyp gibt die Version und Größe der Cloud TPU an, die Sie erstellen möchten. Weitere Informationen zu den unterstützten Beschleunigertypen für die einzelnen TPU-Versionen finden Sie unter TPU-Versionen.
RUNTIME_VERSION Die Softwareversion der Cloud TPU.
SERVICE_ACCOUNT Die E‑Mail-Adresse für Ihr Dienstkonto. Sie finden sie in der Google Cloud Console auf der SeiteDienstkonten.

Beispiel: tpu-service-account@PROJECT_ID.iam.gserviceaccount.com

QUEUED_RESOURCE_ID Die vom Nutzer zugewiesene Text-ID der Anfrage für in die Warteschlange gestellte Ressourcen.

TPU v6e bereitstellen

Verwenden Sie den folgenden Befehl, um eine TPU v6e bereitzustellen:

gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
    --node-id=${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --accelerator-type=${ACCELERATOR_TYPE} \
    --runtime-version=${RUNTIME_VERSION} \
    --service-account=${SERVICE_ACCOUNT}

Verwenden Sie den Befehl list oder describe, um den Status Ihrer in die Warteschlange gestellten Ressource abzufragen.

gcloud alpha compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
    --project ${PROJECT_ID} --zone ${ZONE}

Weitere Informationen zum Status von Anfragen für in die Warteschlange gestellte Ressourcen finden Sie unter In die Warteschlange gestellte Ressourcen verwalten.

SSH-Verbindung zur TPU herstellen

   gcloud compute tpus tpu-vm ssh ${TPU_NAME}

Sobald Sie eine Verbindung zur TPU hergestellt haben, können Sie die Inferenzbenchmark ausführen.

TPU-VM-Umgebung einrichten

  1. Erstellen Sie ein Verzeichnis zum Ausführen der Inferenzbenchmark:

    export MAIN_DIR=your-main-directory
    mkdir -p ${MAIN_DIR}
  2. Richten Sie eine virtuelle Python-Umgebung ein:

    cd ${MAIN_DIR}
    sudo apt update
    sudo apt install python3.10 python3.10-venv
    python3.10 -m venv venv
    source venv/bin/activate
  3. Installieren Sie Git Large File Storage (LFS) für OpenOrca-Daten:

    sudo apt-get install git-lfs
    git lfs install
  4. Klonen und installieren Sie JetStream:

    cd $MAIN_DIR
    git clone https://github.com/google/JetStream.git
    cd JetStream
    git checkout main
    pip install -e .
    cd benchmarks
    pip install -r requirements.in
  5. Richten Sie MaxText ein:

    cd $MAIN_DIR
    git clone https://github.com/google/maxtext.git
    cd maxtext
    git checkout main
    bash setup.sh
    pip install torch --index-url https://download.pytorch.org/whl/cpu
  6. Fordern Sie Zugriff auf Llama-Modelle an, um einen Downloadschlüssel von Meta für das Llama 2-Modell zu erhalten.

  7. Klonen Sie das Llama-Repository:

    cd $MAIN_DIR
    git clone https://github.com/meta-llama/llama
    cd llama
  8. Führen Sie bash download.sh aus. Geben Sie bei entsprechender Aufforderung Ihren Downloadschlüssel ein. Mit diesem Script wird ein llama-2-7b-Verzeichnis in Ihrem llama-Verzeichnis erstellt.

    bash download.sh
  9. Erstellen Sie Storage-Buckets:

    export CHKPT_BUCKET=gs://your-checkpoint-bucket
    export BASE_OUTPUT_DIRECTORY=gs://your-output-dir
    export CONVERTED_CHECKPOINT_PATH=gs://bucket-to-store-converted-checkpoints
    export MAXTEXT_BUCKET_UNSCANNED=gs://bucket-to-store-unscanned-data
    gcloud storage buckets create ${CHKPT_BUCKET}
    gcloud storage buckets create ${BASE_OUTPUT_DIRECTORY}
    gcloud storage buckets create ${CONVERTED_CHECKPOINT_PATH}
    gcloud storage buckets create ${MAXTEXT_BUCKET_UNSCANNED}
    gcloud storage cp --recursive llama-2-7b/* ${CHKPT_BUCKET}

Prüfpunktkonvertierung durchführen

  1. Führen Sie die Konvertierung in gescannte Prüfpunkte durch:

    cd $MAIN_DIR/maxtext
    python3 -m MaxText.llama_or_mistral_ckpt \
        --base-model-path $MAIN_DIR/llama/llama-2-7b \
        --model-size llama2-7b \
        --maxtext-model-path ${CONVERTED_CHECKPOINT_PATH}
  2. Führen Sie die Konvertierung in nicht gescannte Prüfpunkte durch:

    export CONVERTED_CHECKPOINT=${CONVERTED_CHECKPOINT_PATH}/0/items
    export DIRECT_PARAMETER_CHECKPOINT_RUN=direct_generate_param_only_checkpoint
    python3 -m MaxText.generate_param_only_checkpoint \
        MaxText/configs/base.yml \
        base_output_directory=${MAXTEXT_BUCKET_UNSCANNED} \
        load_parameters_path=${CONVERTED_CHECKPOINT} \
        run_name=${DIRECT_PARAMETER_CHECKPOINT_RUN} \
        model_name='llama2-7b' \
        force_unroll=true

Inferenz durchführen

  1. Führen Sie einen Validierungstest aus:

    export UNSCANNED_CKPT_PATH=${MAXTEXT_BUCKET_UNSCANNED}/${DIRECT_PARAMETER_CHECKPOINT_RUN}/checkpoints/0/items
    python3 -m MaxText.decode \
        MaxText/configs/base.yml \
        load_parameters_path=${UNSCANNED_CKPT_PATH} \
        run_name=runner_decode_unscanned_${idx} \
        base_output_directory=${BASE_OUTPUT_DIRECTORY} \
        per_device_batch_size=1 \
        model_name='llama2-7b' \
        ici_autoregressive_parallelism=4 \
        max_prefill_predict_length=4 \
        max_target_length=16 \
        prompt="I love to" \
        attention=dot_product \
        scan_layers=false
  2. Führen Sie den Server im aktuellen Terminal aus:

    export TOKENIZER_PATH=assets/tokenizer.llama2
    export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
    export MAX_PREFILL_PREDICT_LENGTH=1024
    export MAX_TARGET_LENGTH=2048
    export MODEL_NAME=llama2-7b
    export ICI_FSDP_PARALLELISM=1
    export ICI_AUTOREGRESSIVE_PARALLELISM=1
    export ICI_TENSOR_PARALLELISM=-1
    export SCAN_LAYERS=false
    export WEIGHT_DTYPE=bfloat16
    export PER_DEVICE_BATCH_SIZE=11
    
    cd $MAIN_DIR/maxtext
    python3 -m MaxText.maxengine_server \
        MaxText/configs/base.yml \
        tokenizer_path=${TOKENIZER_PATH} \
        load_parameters_path=${LOAD_PARAMETERS_PATH} \
        max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
        max_target_length=${MAX_TARGET_LENGTH} \
        model_name=${MODEL_NAME} \
        ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
        ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
        ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
        scan_layers=${SCAN_LAYERS} \
        weight_dtype=${WEIGHT_DTYPE} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE}
  3. Öffnen Sie ein neues Terminalfenster, stellen Sie eine Verbindung zur TPU her und wechseln Sie zu der virtuellen Umgebung, die Sie im ersten Terminalfenster verwendet haben:

    source venv/bin/activate
    
  4. Führen Sie die folgenden Befehle aus, um die JetStream-Benchmark auszuführen.

    export MAIN_DIR=your-main-directory
    cd $MAIN_DIR
    
    python JetStream/benchmarks/benchmark_serving.py \
        --tokenizer $MAIN_DIR/maxtext/assets/tokenizer.llama2 \
        --warmup-mode sampled \
        --save-result \
        --save-request-outputs \
        --request-outputs-file-path outputs.json \
        --num-prompts 1000 \
        --max-output-length 1024 \
        --dataset openorca \
        --dataset-path $MAIN_DIR/JetStream/benchmarks/open_orca_gpt4_tokenized_llama.calibration_1000.pkl

Ergebnisse

Die folgende Ausgabe wurde beim Ausführen der Benchmark mit v6e‑8 generiert. Die Ergebnisse variieren je nach Hardware, Software, Modell und Netzwerk.

Mean output size: 929.5959798994975
Median output size: 1026.0
P99 output size: 1026.0
Successful requests: 995
Benchmark duration: 195.533269 s
Total input tokens: 217011
Total generated tokens: 924948
Request throughput: 5.09 requests/s
Input token throughput: 1109.84 tokens/s
Output token throughput: 4730.39 tokens/s
Overall token throughput: 5840.23 tokens/s
Mean ttft: 538.49 ms
Median ttft: 95.66 ms
P99 ttft: 13937.86 ms
Mean ttst: 1218.72 ms
Median ttst: 152.57 ms
P99 ttst: 14241.30 ms
Mean TPOT: 91.83 ms
Median TPOT: 16.63 ms
P99 TPOT: 363.37 ms

Bereinigen

  1. Trennen Sie die Verbindung zur TPU:

    $ (vm) exit
  2. Löschen Sie die TPU:

    gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
        --project ${PROJECT_ID} \
        --zone ${ZONE} \
        --force \
        --async
  3. Löschen Sie die Buckets und ihre Inhalte:

    export CHKPT_BUCKET=gs://your-checkpoint-bucket
    export BASE_OUTPUT_DIRECTORY=gs://your-output-dir
    export CONVERTED_CHECKPOINT_PATH=gs://bucket-to-store-converted-checkpoints
    export MAXTEXT_BUCKET_UNSCANNED=gs://bucket-to-store-unscanned-data
    gcloud storage rm -r ${CHKPT_BUCKET}
    gcloud storage rm -r ${BASE_OUTPUT_DIRECTORY}
    gcloud storage rm -r ${CONVERTED_CHECKPOINT_PATH}
    gcloud storage rm -r ${MAXTEXT_BUCKET_UNSCANNED}