Inferência do MaxDiffusion em TPUs v6e

Este tutorial mostra como disponibilizar modelos MaxDiffusion na TPU v6e. Neste tutorial, você vai gerar imagens usando o modelo Stable Diffusion XL.

Antes de começar

Prepare-se para provisionar uma TPU v6e com 4 chips:

  1. Siga o guia Configurar o ambiente do Cloud TPU para configurar um projeto Google Cloud , configurar o CLI do Google Cloud, ativar a API Cloud TPU e garantir que você tenha acesso para usar o Cloud TPU.

  2. Faça a autenticação com Google Cloud e configure o projeto e a zona padrão para a Google Cloud CLI.

    gcloud auth login
    gcloud config set project PROJECT_ID
    gcloud config set compute/zone ZONE

Capacidade segura

Quando estiver tudo pronto para garantir a capacidade da TPU, consulte Cotas da Cloud TPU para mais informações. Se você tiver outras dúvidas sobre como garantir a capacidade, entre em contato com a equipe de vendas ou de conta do Cloud TPU.

Provisionar o ambiente do Cloud TPU

É possível provisionar VMs do TPU com o GKE, com o GKE e o XPK ou como recursos em fila.

Pré-requisitos

  • Verifique se o projeto tem cota de TPUS_PER_TPU_FAMILY suficiente, que especifica o número máximo de chips que você pode acessar no projetoGoogle Cloud .
  • Verifique se o projeto tem cota suficiente de TPU para:
    • Cota da VM de TPU
    • Cota de endereços IP
    • Quota do Hyperdisk equilibrado
  • Permissões do projeto do usuário

Provisionar um TPU v6e

   gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \
        --node-id TPU_NAME \
        --project PROJECT_ID \
        --zone ZONE \
        --accelerator-type v6e-4 \
        --runtime-version v2-alpha-tpuv6e \
        --service-account SERVICE_ACCOUNT

Use os comandos list ou describe para consultar o status do recurso em fila.

   gcloud alpha compute tpus queued-resources describe QUEUED_RESOURCE_ID  \
      --project=PROJECT_ID --zone=ZONE

Para uma lista completa de status de solicitações de recursos em fila, consulte a documentação de Recursos em fila.

Conectar-se à TPU usando SSH

   gcloud compute tpus tpu-vm ssh TPU_NAME

Criar um ambiente da Conda

  1. Crie um diretório para o Miniconda:

    mkdir -p ~/miniconda3
  2. Faça o download do script de instalação do Miniconda:

    wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
  3. Instale o Miniconda:

    bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
  4. Remova o script do instalador do Miniconda:

    rm -rf ~/miniconda3/miniconda.sh
  5. Adicione o Miniconda à variável PATH:

    export PATH="$HOME/miniconda3/bin:$PATH"
  6. Atualize ~/.bashrc para aplicar as mudanças à variável PATH:

    source ~/.bashrc
  7. Crie um novo ambiente do Conda:

    conda create -n tpu python=3.10
  8. Ative o ambiente da Conda:

    source activate tpu

Configurar o MaxDiffusion

  1. Clone o repositório MaxDiffusion e navegue até o diretório MaxDiffusion:

    git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion
  2. Alterne para a ramificação mlperf-4.1:

    git checkout mlperf4.1
  3. Instale o MaxDiffusion:

    pip install -e .
  4. Instale as dependências:

    pip install -r requirements.txt
  5. Instale o JAX:

    pip install -U --pre jax[tpu] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  6. Instale dependências extras:

     pip install huggingface_hub==0.25 absl-py flax tensorboardX google-cloud-storage torch tensorflow transformers 

Gerar imagens

  1. Defina variáveis de ambiente para configurar o ambiente de execução da TPU:

    LIBTPU_INIT_ARGS="--xla_tpu_rwb_fusion=false --xla_tpu_dot_dot_fusion_duplicated=true --xla_tpu_scoped_vmem_limit_kib=65536"
  2. Gerar imagens usando o comando e as configurações definidas em src/maxdiffusion/configs/base_xl.yml:

    python -m src.maxdiffusion.generate_sdxl src/maxdiffusion/configs/base_xl.yml run_name="my_run"

    Depois que as imagens forem geradas, limpe os recursos da TPU.

Limpar

Exclua a TPU:

gcloud compute tpus queued-resources delete QUEUED_RESOURCE_ID \
    --project PROJECT_ID \
    --zone ZONE \
    --force \
    --async