PyTorch を使用した拡散モデルのトレーニング


このチュートリアルでは、PyTorch Lightning と Pytorch XLA を使用して TPU で拡散モデルをトレーニングする方法について説明します。

目標

  • Cloud TPU の作成
  • PyTorch Lightning をインストールする
  • 拡散リポジトリのクローンを作成する
  • Imagenette データセットを準備する
  • トレーニング スクリプトを実行します

費用

このドキュメントでは、Google Cloud の次の課金対象のコンポーネントを使用します。

  • Compute Engine
  • Cloud TPU

料金計算ツールを使うと、予想使用量に基づいて費用の見積もりを生成できます。 新しい Google Cloud ユーザーは無料トライアルをご利用いただける場合があります。

始める前に

このチュートリアルを開始する前に、Google Cloud プロジェクトが正しく設定されていることを確認します。

  1. Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
  2. In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Go to project selector

  3. Make sure that billing is enabled for your Google Cloud project.

  4. In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Go to project selector

  5. Make sure that billing is enabled for your Google Cloud project.

  6. このチュートリアルでは、Google Cloud の課金対象となるコンポーネントを使用します。費用を見積もるには、Cloud TPU の料金ページを確認してください。不要な課金を回避するために、このチュートリアルを完了したら、作成したリソースを必ずクリーンアップしてください。

Cloud TPU の作成

このチュートリアルでは v4-8 を使用しますが、単一ホストのすべてのアクセラレータ サイズで同様に動作します。

コマンドを簡単に使用できるように、いくつかの環境変数を設定します。

export ZONE=us-central2-b
export PROJECT_ID=your-project-id
export ACCELERATOR_TYPE=v4-8
export RUNTIME_VERSION=tpu-ubuntu2204-base
export TPU_NAME=your_tpu_name

Cloud TPU の作成

gcloud compute tpus tpu-vm create ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--accelerator-type=${ACCELERATOR_TYPE} \
--version=${RUNTIME_VERSION} \
--subnetwork=tpusubnet

必要なソフトウェアのインストール

  1. PyTorch / XLA の最新リリース v2.4.0 とともに、必要なパッケージをインストールします。

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --command="sudo apt-get update -y && sudo apt-get install libgl1 -y
    git clone https://github.com/pytorch-tpu/stable-diffusion.git
    cd stable-diffusion
    pip install -r requirements.txt
    pip install -e .
    pip install https://github.com/Lightning-AI/lightning/archive/refs/heads/master.zip -U
    pip install -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
    pip install clip
    pip install torch~=2.4.0 torch_xla[tpu]~=2.4.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html"
  2. torch 2.2 以降と互換性があるようにソースファイルを修正します。

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --command="cd stable-diffusion/
    sed -i 's/from torch._six import string_classes/string_classes = (str, bytes)/g' src/taming-transformers/taming/data/utils.py
    sed -i 's/trainer_kwargs\\[\"callbacks\"\\]/# trainer_kwargs\\[\"callbacks\"\\]/g' main_tpu.py"
  3. Imagenette(Imagenet データセットの小さいバージョン)をダウンロードして、適切なディレクトリに移動します。

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --command="wget -nv https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz
    tar -xf  imagenette2.tgz
    mkdir -p ~/.cache/autoencoders/data/ILSVRC2012_train/data
    mkdir -p ~/.cache/autoencoders/data/ILSVRC2012_validation/data
    mv imagenette2/train/*  ~/.cache/autoencoders/data/ILSVRC2012_train/data
    mv imagenette2/val/* ~/.cache/autoencoders/data/ILSVRC2012_validation/data"
  4. 最初のステージの事前トレーニング済みモデルをダウンロードします。

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --command="cd stable-diffusion/
    wget -nv -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip
    cd  models/first_stage_models/vq-f8/
    unzip -o model.zip"

モデルのトレーニング

次のコマンドでトレーニングを実行します。v4-8 では、トレーニング プロセスに約 30 分かかることが想定されます。

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--command="python3 stable-diffusion/main_tpu.py --train --no-test --base=stable-diffusion/configs/latent-diffusion/cin-ldm-vq-f8-ss.yaml -- data.params.batch_size=32 lightning.trainer.max_epochs=5 model.params.first_stage_config.params.ckpt_path=stable-diffusion/models/first_stage_models/vq-f8/model.ckpt lightning.trainer.enable_checkpointing=False lightning.strategy.sync_module_states=False"

クリーンアップ

作成したリソースを使用した後、アカウントに不要な請求が発生しないようにクリーンアップを行います。

Google Cloud CLI を使用して Cloud TPU リソースを削除します。

  $  gcloud compute tpus tpu-vm delete diffusion-tutorial --zone=us-central2-b
  

次のステップ