Prima di iniziare
Segui i passaggi in Gestione delle risorse TPU per creare un'impostazione della VM TPU da --accelerator-type
a v5litepod-8
e connetterti alla VM TPU.
Configurare JetStream e MaxText
Scarica JetStream e il repository GitHub di MaxText
git clone -b jetstream-v0.2.0 https://github.com/google/maxtext.git git clone -b v0.2.0 https://github.com/google/JetStream.git
Imposta MaxText
# Create a python virtual environment sudo apt install python3.10-venv python -m venv .env source .env/bin/activate # Set up MaxText cd maxtext/ bash setup.sh
Converti i checkpoint del modello
Puoi eseguire JetStream MaxText Server con modelli Gemma o Llama2. Questa sezione descrive come eseguire il server JetStream MaxText con varie dimensioni di questi modelli.
Utilizza un checkpoint del modello Gemma
- Scarica un checkpoint Gemma da Kaggle.
Copia il checkpoint nel tuo bucket Cloud Storage
# Set YOUR_CKPT_PATH to the path to the checkpoints # Set CHKPT_BUCKET to the Cloud Storage bucket where you copied the checkpoints gsutil -m cp -r ${YOUR_CKPT_PATH} ${CHKPT_BUCKET}
Per un esempio che include i valori per
${YOUR_CKPT_PATH}
e${CHKPT_BUCKET}
, consulta lo script di conversione.Converti il checkpoint Gemma in un checkpoint non analizzato compatibile con MaxText.
# For gemma-7b bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh gemma 7b ${CHKPT_BUCKET}
Utilizza un checkpoint del modello Llama2
Scarica un checkpoint Llama2 dalla community open source o utilizzane uno generato da te.
Copia i checkpoint nel bucket Cloud Storage.
gsutil -m cp -r ${YOUR_CKPT_PATH} ${CHKPT_BUCKET}
Per un esempio che include i valori per
${YOUR_CKPT_PATH}
e${CHKPT_BUCKET}
, consulta lo script di conversione.Converti il checkpoint Lama2 in un checkpoint non scansionato compatibile con MaxText.
# For llama2-7b bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 7b ${CHKPT_BUCKET} # For llama2-13b bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 13b ${CHKPT_BUCKET}
Esegui il server JetStream MaxText
Questa sezione descrive come eseguire il server MaxText utilizzando un checkpoint compatibile con MaxText.
Configura le variabili di ambiente per il server MaxText
Esporta le seguenti variabili di ambiente in base al modello che stai utilizzando.
Utilizza il valore di UNSCANNED_CKPT_PATH
dell'output model_ckpt_conversion.sh
.
crea variabili di ambiente Gemma-7b per i flag del server
Configura i flag del server JetStream MaxText.
export TOKENIZER_PATH=assets/tokenizer.gemma
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=gemma-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=4
crea variabili di ambiente Lama2-7b per i flag del server
Configura i flag del server JetStream MaxText.
export TOKENIZER_PATH=assets/tokenizer.llama2
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=6
crea variabili di ambiente Lama2-13b per i flag del server
Configura i flag del server JetStream MaxText.
export TOKENIZER_PATH=assets/tokenizer.llama2
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-13b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=2
Avvia il server JetStream MaxText
cd ~/maxtext
python MaxText/maxengine_server.py \
MaxText/configs/base.yml \
tokenizer_path=${TOKENIZER_PATH} \
load_parameters_path=${LOAD_PARAMETERS_PATH} \
max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
max_target_length=${MAX_TARGET_LENGTH} \
model_name=${MODEL_NAME} \
ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
scan_layers=${SCAN_LAYERS} \
weight_dtype=${WEIGHT_DTYPE} \
per_device_batch_size=${PER_DEVICE_BATCH_SIZE}
Descrizioni dei flag del server JetStream MaxText
tokenizer_path
- Il percorso a un tokenizzatore (deve corrispondere al tuo modello).
load_parameters_path
- Carica i parametri (nessun stato di ottimizzazione) da una directory specifica
per_device_batch_size
- decodifica della dimensione del batch per dispositivo (1 chip TPU = 1 dispositivo)
max_prefill_predict_length
- Lunghezza massima della precompilazione durante l'autoregressione
max_target_length
- Lunghezza massima della sequenza
model_name
- Nome del modello
ici_fsdp_parallelism
- Il numero di shard per il parallelismo FSDP
ici_autoregressive_parallelism
- Il numero di shard per il parallelismo autoregressivo
ici_tensor_parallelism
- Il numero di shard per il parallelismo tensore
weight_dtype
- Tipo di dati sulla ponderazione (ad esempio bfloat16)
scan_layers
- Flag booleano dei livelli di scansione
Invia una richiesta di prova al server JetStream MaxText
cd ~
python JetStream/jetstream/tools/requester.py
L'output sarà simile al seguente:
Sending request to: dns:///[::1]:9000
Prompt: Today is a good day
Response: to be a fan
Esegui benchmark con il server JetStream MaxText
Per ottenere i migliori risultati dei benchmark, abilita la quantizzazione (utilizza punti di controllo addestrati o perfezionati per AQT per garantire l'accuratezza) sia per le ponderazioni che per la cache KV. Per abilitare la quantizzazione, imposta i relativi flag:
# Enable int8 quantization for both weights and KV cache
export QUANTIZATION=int8
export QUANTIZE_KVCACHE=true
# For Gemma 7b model, change per_device_batch_size to 12 to optimize performance.
export PER_DEVICE_BATCH_SIZE=12
cd ~/maxtext
python MaxText/maxengine_server.py \
MaxText/configs/base.yml \
tokenizer_path=${TOKENIZER_PATH} \
load_parameters_path=${LOAD_PARAMETERS_PATH} \
max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
max_target_length=${MAX_TARGET_LENGTH} \
model_name=${MODEL_NAME} \
ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
scan_layers=${SCAN_LAYERS} \
weight_dtype=${WEIGHT_DTYPE} \
per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
quantization=${QUANTIZATION} \
quantize_kvcache=${QUANTIZE_KVCACHE}
Benchmarking di Gemma-7b
Per confrontare Gemma-7b:
- Scarica il set di dati ShareGPT.
- Assicurati di utilizzare il tokenizzatore Gemma (tokenizer.gemma) quando esegui Gemma 7b.
- Aggiungi il flag
--warmup-first
per la tua prima esecuzione per riscaldare il server.
# Activate the env python virtual environment
cd ~
source .env/bin/activate
# Download the dataset
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
# Run the benchmark with the downloaded dataset and the tokenizer in MaxText
# You can control the qps by setting `--request-rate`, the default value is inf.
python JetStream/benchmarks/benchmark_serving.py \
--tokenizer /home/$USER/maxtext/assets/tokenizer.gemma \
--num-prompts 1000 \
--dataset sharegpt \
--dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \
--max-output-length 1024 \
--request-rate 5 \
--warmup-first true
Benchmarking della lama2 più grande
# Run the benchmark with the downloaded dataset and the tokenizer in MaxText
# You can control the qps by setting `--request-rate`, the default value is inf.
python JetStream/benchmarks/benchmark_serving.py \
--tokenizer maxtext/assets/tokenizer.llama2 \
--num-prompts 1000 \
--dataset sharegpt \
--dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \
--max-output-length 1024 \
--request-rate 5 \
--warmup-first true
Esegui la pulizia
Per evitare che al tuo Account Google Cloud vengano addebitati costi relativi alle risorse utilizzate in questo tutorial, elimina il progetto che contiene le risorse oppure mantieni il progetto ed elimina le singole risorse.
# Delete the Cloud Storage buckets
gcloud storage buckets delete ${MODEL_BUCKET}
gcloud storage buckets delete ${BASE_OUTPUT_DIRECTORY}
gcloud storage buckets delete ${DATASET_PATH}
# Clean up the MaxText and JetStream repositories.
rm -rf maxtext
rm -rf JetStream
# Clean up the python virtual environment
rm -rf .env