Trillium (v6e) 소개

v6e는 이 문서, TPU API, 로그에서 Trillium을 나타내는 데 사용됩니다. v6e는 Google의 6세대 TPU를 나타냅니다.

포드당 칩이 256개인 v6e 아키텍처는 v5e와 많은 유사점을 공유합니다. 이 시스템은 변환기, 텍스트-이미지, 컨볼루셔널 신경망 (CNN) 학습, 미세 조정, 서빙에 최적화되어 있습니다.

v6e 시스템 아키텍처 및 구성에 관한 자세한 내용은 v6e 문서를 참고하세요.

이 소개 문서에서는 JAX, PyTorch 또는 TensorFlow 프레임워크를 사용하여 모델을 학습하고 제공하는 프로세스에 중점을 둡니다. 각 프레임워크를 사용하면 대기열에 추가된 리소스 또는 Google Kubernetes Engine (GKE)을 사용하여 TPU를 프로비저닝할 수 있습니다. GKE 설정은 XPK 또는 GKE 명령어를 사용하여 수행할 수 있습니다.

v6e를 사용하여 모델을 학습하거나 제공하는 일반적인 절차

  1. Google Cloud 프로젝트 준비하기
  2. 용량 확보
  3. TPU 환경 설정
  4. Cloud TPU 환경 프로비저닝
  5. 모델 학습 또는 추론 워크로드 실행
  6. 삭제

Google Cloud 프로젝트 준비

  1. Google 계정에 로그인합니다. 아직 계정이 없다면 새 계정을 만듭니다.
  2. Google Cloud 콘솔의 프로젝트 선택기 페이지에서 클라우드 프로젝트를 선택하거나 만듭니다.
  3. Google Cloud 프로젝트에 결제를 사용 설정합니다. Google Cloud를 사용하려면 결제가 필요합니다.
  4. gcloud alpha 구성요소를 설치합니다.
  5. 다음 명령어를 실행하여 최신 버전의 gcloud 구성요소를 설치합니다.

    gcloud components update
    
  6. Cloud Shell에서 다음 gcloud 명령어를 통해 TPU API를 사용 설정합니다. Google Cloud 콘솔에서도 사용 설정할 수 있습니다.

    gcloud services enable tpu.googleapis.com
    
  7. Compute Engine API에 TPU 서비스 계정으로 권한 사용 설정

    서비스 계정은 Cloud TPU 서비스가 다른 Google Cloud 서비스에 액세스하도록 허용합니다. Google Cloud에서 권장하는 방식은 사용자 관리형 서비스 계정입니다. 다음 가이드를 따라 역할을 만들고 부여합니다. 다음 역할이 필요합니다.

    • TPU 관리자
    • 스토리지 관리자
    • 로그 작성자
    • 모니터링 측정항목 작성자

    a. GKE의 사용자 계정으로 XPK 권한을 설정합니다. XPK

  8. Google 계정으로 인증하고 기본 프로젝트 ID 및 영역을 설정합니다.
    auth login는 gcloud가 Google 사용자 인증 정보로 Google Cloud에 액세스하도록 승인합니다.
    PROJECT_ID는 Google Cloud 프로젝트 이름입니다.
    ZONE는 TPU를 만들려는 영역입니다.

     gcloud auth login
     gcloud config set project ${PROJECT_ID}
     gcloud config set compute/zone ${ZONE}
    
  9. TPU VM의 서비스 ID를 만듭니다.

     gcloud alpha compute tpus tpu-vm service-identity create --zone=${ZONE}
    

용량 확보

Cloud TPU 지원 영업/계정에 문의하여 TPU 할당량을 요청하고 용량에 관한 질문에 답변받으세요.

Cloud TPU 환경 프로비저닝

v6e TPU는 GKE, GKE 및 XPK (GKE용 래퍼 CLI 도구) 또는 대기열에 추가된 리소스로 프로비저닝하고 관리할 수 있습니다.

기본 요건

  • Google Cloud 프로젝트 내에서 액세스할 수 있는 최대 칩 수를 지정하는 TPUS_PER_TPU_FAMILY 할당량이 프로젝트에 충분한지 확인합니다.
  • v6e는 다음 구성으로 테스트되었습니다.
    • python 3.10 이상
    • Nightly 소프트웨어 버전:
      • nightly JAX 0.4.32.dev20240912
      • nightly LibTPU 0.1.dev20240912+nightly
    • 안정화 소프트웨어 버전:
      • JAX + v0.4.35의 JAX 라이브러리
  • 프로젝트에 다음에 충분한 TPU 할당량이 있는지 확인합니다.

    • TPU VM 할당량
    • IP 주소 할당량
    • Hyperdisk-balance 할당량

  • 사용자 프로젝트 권한

환경 변수

Cloud Shell에서 다음 환경 변수를 만듭니다.

export NODE_ID=TPU_NODE_ID # TPU name
export PROJECT_ID=PROJECT_ID
export ACCELERATOR_TYPE=v6e-16
export ZONE=us-east1-d
export RUNTIME_VERSION=v2-alpha-tpuv6e
export SERVICE_ACCOUNT=your-service-account
export QUEUED_RESOURCE_ID=QUEUED_RESOURCE_ID
export VALID_DURATION=VALID_DURATION

# Additional environment variable needed for Multislice:
export NUM_SLICES=NUM_SLICES

# Use a custom network for better performance as well as to avoid having the default network becoming overloaded.

export NETWORK_NAME=${PROJECT_ID}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw

명령어 플래그 설명

변수 설명
NODE_ID 큐에 추가된 리소스 요청이 할당될 때 생성되는 TPU의 사용자 할당 ID입니다.
PROJECT_ID Google Cloud 프로젝트 이름 에서 기존 프로젝트를 사용하거나 새 프로젝트를 만듭니다.
ZONE 지원되는 영역에 대해서는 TPU 리전 및 영역 문서를 참조하세요.
ACCELERATOR_TYPE 가속기 유형을 참고하세요.
RUNTIME_VERSION v2-alpha-tpuv6e
SERVICE_ACCOUNT Google Cloud 콘솔 -> IAM -> 서비스 계정에서 찾을 수 있는 서비스 계정의 이메일 주소입니다.

예: tpu-service-account@<your_project_ID>.iam.gserviceaccount.com.com

NUM_SLICES 생성할 슬라이스 수입니다 (멀티슬라이스에만 필요).
QUEUED_RESOURCE_ID 큐에 추가된 리소스 요청의 사용자 할당 텍스트 ID입니다.
VALID_DURATION 대기열에 추가된 리소스 요청이 유효한 기간입니다.
NETWORK_NAME 사용할 보조 네트워크의 이름입니다.
NETWORK_FW_NAME 사용할 보조 네트워크 방화벽의 이름입니다.

네트워크 성능 최적화

최상의 성능을 위해 MTU (최대 전송 단위)가 8,896인 네트워크를 사용하세요.

기본적으로 Virtual Private Cloud (VPC)는 1,460바이트의 MTU만 제공하므로 최적의 네트워크 성능을 제공하지 않습니다. VPC 네트워크의 MTU를 1,300바이트에서 8,896바이트 (양 끝값 포함) 사이의 값으로 설정할 수 있습니다. 일반적인 커스텀 MTU 크기는 1,500바이트(표준 이더넷) 또는 8,896바이트 (최대)입니다. 자세한 내용은 유효한 VPC 네트워크 MTU 크기를 참고하세요.

기존 또는 기본 네트워크의 MTU 설정을 변경하는 방법에 관한 자세한 내용은 VPC 네트워크의 MTU 설정 변경을 참고하세요.

다음 예에서는 MTU가 8,896인 네트워크를 만듭니다.

export RESOURCE_NAME=RESOURCE_NAME
export NETWORK_NAME=${RESOURCE_NAME}-privatenetwork
export NETWORK_FW_NAME=${RESOURCE_NAME}-privatefirewall
export PROJECT=X
gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT_ID} \
 --subnet-mode=auto --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network ${NETWORK_NAME}
 --allow tcp,icmp,udp --project=${PROJECT}

멀티 NIC 사용 (멀티슬라이스 옵션)

멀티슬라이스 환경을 사용할 때 보조 서브넷에는 다음 환경 변수가 필요합니다.

export NETWORK_NAME_2=${RESOURCE_NAME}
export SUBNET_NAME_2=${RESOURCE_NAME}
export FIREWALL_RULE_NAME=${RESOURCE_NAME}
export ROUTER_NAME=${RESOURCE_NAME}-network-2
export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2
export REGION=us-central2

다음 명령어를 사용하여 네트워크 및 서브넷의 맞춤 IP 라우팅을 만듭니다.

gcloud compute networks create "${NETWORK_NAME_2}" --mtu=8896
   --bgp-routing-mode=regional --subnet-mode=custom --project=${PROJECT_ID}
gcloud compute networks subnets create "${SUBNET_NAME_2}" \
   --network="${NETWORK_NAME_2}" \
   --range=10.10.0.0/18 --region="${REGION}" \
   --project=$PROJECT

gcloud compute firewall-rules create "${FIREWALL_RULE_NAME}" \
   --network "${NETWORK_NAME_2}" --allow tcp,icmp,udp \
   --source-ranges 10.10.0.0/18 --project=${PROJECT_ID}

gcloud compute routers create "${ROUTER_NAME}" \
  --project="${PROJECT_ID}" \
  --network="${NETWORK_NAME_2}" \
  --region="${REGION}"
gcloud compute routers nats create "${NAT_CONFIG}" \
  --router="${ROUTER_NAME}" \
  --region="${REGION}" \
  --auto-allocate-nat-external-ips \
  --nat-all-subnet-ip-ranges \
  --project="${PROJECT_ID}" \
  --enable-logging

다중 네트워크 슬라이스가 생성되면 XPK 워크로드의 일부로 --command ifconfig를 실행하여 두 NIC이 모두 사용되고 있는지 확인할 수 있습니다. 다음 xpk workload 명령어를 사용하여 Cloud 콘솔 로그에 ifconfig 명령어의 출력을 표시하고 eth0과 eth1 모두 mtu=8896인지 확인합니다.

python3 xpk.py workload create \
   --cluster CLUSTER_NAME \
   (--base-docker-image maxtext_base_image|--docker-image CLOUD_IMAGE_NAME \
   --workload ${USER}-xpk-$ACCELERATOR_TYPE-$NUM_SLICES \
   --tpu-type=${ACCELERATOR_TYPE} \
   --num-slices=${NUM_SLICES}  \
   --on-demand \
   --zone $ZONE \
   --project $PROJECT_ID \
   [--enable-debug-logs] \
   [--use-vertex-tensorboard] \
   --command "ifconfig"

eth0 및 eth1 모두 mtu=8,896인지 확인합니다. 멀티 닉이 실행 중인지 확인하는 한 가지 방법은 XPK 워크로드의 일부로 --command 'ifconfig' 명령어를 실행하는 것입니다. 그런 다음 Cloud 콘솔 로그에서 해당 xpk 워크로드의 출력을 확인하고 eth0과 eth1 모두 mtu=8896인지 확인합니다.

TCP 설정 개선

큐에 추가된 리소스 인터페이스를 사용하여 만든 TPU의 경우 다음 명령어를 실행하여 TCP 수신 버퍼 한도를 늘려 네트워크 성능을 개선할 수 있습니다.

gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \
  --project "$PROJECT" \
  --zone "$ZONE" \
  --node=all \
  --command='sudo sh -c "echo \"4096 41943040 314572800\" > /proc/sys/net/ipv4/tcp_rmem"' \
  --worker=all

대기열에 추가된 리소스를 사용한 프로비저닝

할당된 용량은 큐에 추가된 리소스 create 명령어를 사용하여 프로비저닝할 수 있습니다.

  1. TPU 큐에 추가된 리소스 요청을 만듭니다.

    --reserved 플래그는 주문형 리소스가 아닌 예약된 리소스에만 필요합니다.

    gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
      --node-id ${TPU_NAME} \
      --project ${PROJECT_ID} \
      --zone ${ZONE} \
      --accelerator-type ${ACCELERATOR_TYPE} \
      --runtime-version ${RUNTIME_VERSION} \
      --valid-until-duration ${VALID_DURATION} \
      --service-account ${SERVICE_ACCOUNT} \
      [--reserved]
    
      # The following flags are only needed if you are using Multislice.
      --node-count node-count  # Number of slices in a Multislice \
      --node-prefix node-prefix # An optional user-defined node prefix;
       the default is QUEUED_RESOURCE_ID.

    큐에 추가된 리소스 요청이 성공적으로 생성되면 'response' 필드 내 상태가 'WAITING_FOR_RESOURCES' 또는 'FAILED'입니다. 큐에 추가된 리소스 요청이 'WAITING_FOR_RESOURCES' 상태인 경우 리소스가 큐에 추가되었으며 할당된 TPU 용량이 충분하면 프로비저닝됩니다. 큐에 추가된 리소스 요청이 '실패' 상태이면 실패 이유가 출력에 표시됩니다. 지정된 기간 내에 v6e가 프로비저닝되지 않으면 대기열에 추가된 리소스 요청이 만료되고 상태는 'FAILED'가 됩니다. 자세한 내용은 큐에 추가된 리소스 공개 문서를 참고하세요.

    큐에 추가된 리소스 요청이 'ACTIVE' 상태이면 SSH를 사용하여 TPU VM에 연결할 수 있습니다. list 또는 describe 명령어를 사용하여 큐에 추가된 리소스의 상태를 쿼리합니다.

    gcloud alpha compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
       --project ${PROJECT_ID} --zone ${ZONE}
    

    큐에 추가된 리소스가 'ACTIVE' 상태이면 출력은 다음과 비슷합니다.

      state:
       state: ACTIVE
    
  2. TPU VM을 관리합니다. TPU VM을 관리하는 옵션은 TPU VM 관리를 참고하세요.

  3. SSH를 사용하여 TPU VM에 연결

    TPU 슬라이스의 각 TPU VM에 바이너리를 설치하고 코드를 실행할 수 있습니다. VM 유형 섹션을 참고하여 슬라이스에 포함될 수 있는 VM 수를 확인하세요.

    바이너리를 설치하거나 코드를 실행하려면 SSH를 사용하여 tpu-vm ssh 명령어로 VM에 연결하면 됩니다.

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
       --node=all # add this flag if you are using Multislice
    

    SSH를 사용하여 특정 VM에 연결하려면 0부터 시작하는 색인을 따르는 --worker 플래그를 사용합니다.

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --worker=1
    

    칩이 8개를 초과하는 슬라이스 형태이면 한 슬라이스에 여러 VM이 있습니다. 이 경우 gcloud alpha compute tpus tpu-vm ssh 명령어에서 --worker=all--command 매개변수를 사용하여 모든 VM에서 명령어를 동시에 실행합니다. 예를 들면 다음과 같습니다.

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} \
      --zone  ${ZONE} --worker=all \
      --command='pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
      -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  4. 큐에 추가된 리소스 삭제

    세션이 끝나면 큐에 추가된 리소스를 삭제하거나 'FAILED' 상태의 큐에 추가된 리소스 요청을 삭제합니다. 큐에 추가된 리소스를 삭제하려면 슬라이스를 삭제한 후 큐에 추가된 리소스 요청을 다음과 같이 두 단계로 삭제합니다.

    gcloud alpha compute tpus tpu-vm delete $TPU_NAME --project=${PROJECT_ID} \
     --zone=${ZONE} --quiet
    
    gcloud alpha compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
     --project ${PROJECT_ID} --zone ${ZONE} --quiet
    
    gcloud alpha compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
      --project ${PROJECT_ID} --zone ${ZONE} --quiet --force
    

v6e에서 GKE 사용

v6e에서 GKE 명령어를 사용하는 경우 Kubernetes 명령어 또는 XPK를 사용하여 TPU를 프로비저닝하고 모델을 학습 또는 제공할 수 있습니다. TPU 및 v6e와 함께 GKE를 사용하는 방법은 GKE의 TPU 계획을 참고하세요.

프레임워크 설정

이 섹션에서는 JAX, PyTorch 또는 TensorFlow 프레임워크를 사용하는 ML 모델 학습을 위한 일반적인 설정 프로세스를 설명합니다. 큐에 추가된 리소스 또는 GKE를 사용하여 TPU를 프로비저닝할 수 있습니다. GKE 설정은 XPK 또는 Kubernetes 명령어를 사용하여 수행할 수 있습니다.

JAX 설정

이 섹션에서는 XPK 유무와 관계없이 GKE에서 JAX 워크로드를 실행하는 예와 대기열에 추가된 리소스를 사용하는 예를 제공합니다.

GKE를 사용하여 JAX 설정

다음 예에서는 Kubernetes YAML 파일을 사용하여 2X2 단일 호스트를 설정합니다.

단일 호스트의 단일 슬라이스

apiVersion: v1
kind: Pod
metadata:
  name: tpu-pod-jax-v6e-a
spec:
  restartPolicy: Never
  nodeSelector:
    cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
    cloud.google.com/gke-tpu-topology: 2x2
  containers:
  - name: tpu-job
    image: python:3.10
    securityContext:
      privileged: true
    command:
    - bash
    - -c
    - |
      pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
      JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python3 -c 'import jax; print("Total TPU chips:", jax.device_count())'
    resources:
      requests:
        google.com/tpu: 4
      limits:
        google.com/tpu: 4

완료되면 GKE 로그에 다음 메시지가 표시됩니다.

Total TPU chips: 4

멀티 호스트의 단일 슬라이스

다음 예에서는 Kubernetes YAML 파일을 사용하여 4X4 멀티 호스트 노드 풀을 설정합니다.

apiVersion: v1
kind: Service
metadata:
  name: headless-svc
spec:
  clusterIP: None
  selector:
    job-name: tpu-available-chips
---
apiVersion: batch/v1
kind: Job
metadata:
  name: tpu-available-chips
spec:
  backoffLimit: 0
  completions: 4
  parallelism: 4
  completionMode: Indexed
  template:
    spec:
      subdomain: headless-svc
      restartPolicy: Never
      nodeSelector:
        cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
        cloud.google.com/gke-tpu-topology: 4x4
      containers:
      - name: tpu-job
        image: python:3.10
        ports:
        - containerPort: 8471 # Default port using which TPU VMs communicate
        - containerPort: 8431 # Port to export TPU runtime metrics, if supported.
        securityContext:
          privileged: true
        command:
        - bash
        - -c
        - |
          pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
          JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
        resources:
          requests:
            google.com/tpu: 4
          limits:
            google.com/tpu: 4

완료되면 GKE 로그에 다음 메시지가 표시됩니다.

Total TPU chips: 16

멀티 호스트의 멀티슬라이스

다음 예에서는 Kubernetes YAML 파일을 사용하여 4X4 멀티호스트 노드 풀 2개를 설정합니다.

기본 요건으로 v0.2.3 이상의 JobSet을 설치해야 합니다.

apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
  name: multislice-job
  annotations:
    alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
spec:
  failurePolicy:
    maxRestarts: 4
  replicatedJobs:
    - name: slice
      replicas: 2
      template:
        spec:
          parallelism: 4
          completions: 4
          backoffLimit: 0
          template:
            spec:
              hostNetwork: true
              dnsPolicy: ClusterFirstWithHostNet
              nodeSelector:
                cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                cloud.google.com/gke-tpu-topology: 4x4
              hostNetwork: true
              containers:
              - name: jax-tpu
                image: python:3.10
                ports:
                - containerPort: 8471
                - containerPort: 8080
                - containerPort: 8431
                securityContext:
                  privileged: true
                command:
                - bash
                - -c
                - |
                  pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                  JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
                resources:
                  limits:
                   google.com/tpu: 4
                  requests:
                   google.com/tpu: 4

완료되면 GKE 로그에 다음 메시지가 표시됩니다.

Total TPU chips: 32

자세한 내용은 GKE 문서의 멀티슬라이스 워크로드 실행을 참고하세요.

성능을 개선하려면 hostNetwork를 사용 설정합니다.

멀티 NIC

GKE에서 멀티 NIC를 활용하려면 Kubernetes Pod 매니페스트에 추가 주석이 있어야 합니다. 다음은 TPU가 아닌 멀티 NIC 워크로드 매니페스트의 예입니다.

apiVersion: v1
kind: Pod
metadata:
  name: sample-netdevice-pod-1
  annotations:
    networking.gke.io/default-interface: 'eth0'
    networking.gke.io/interfaces: |
      [
        {"interfaceName":"eth0","network":"default"},
        {"interfaceName":"eth1","network":"netdevice-network"}
      ]
spec:
  containers:
  - name: sample-netdevice-pod
    image: busybox
    command: ["sleep", "infinity"]
    ports:
    - containerPort: 80
  restartPolicy: Always
  tolerations:
  - key: "google.com/tpu"
    operator: "Exists"
    effect: "NoSchedule"

Kubernetes Pod에 exec하면 다음 코드를 사용하여 추가 NIC가 표시됩니다.

$ k exec --stdin --tty sample-netdevice-pod-1 -- /bin/sh
/ # ip a
1: lo: <LOOPBACK,UP,LOWER_UP> mtu 65536 qdisc noqueue qlen 1000
    link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00
    inet 127.0.0.1/8 scope host lo
       valid_lft forever preferred_lft forever
2: eth0@if11: <BROADCAST,MULTICAST,UP,LOWER_UP,M-DOWN> mtu 1460 qdisc noqueue
    link/ether da:be:12:67:d2:25 brd ff:ff:ff:ff:ff:ff
    inet 10.124.2.6/24 brd 10.124.2.255 scope global eth0
       valid_lft forever preferred_lft forever
3: eth1: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1460 qdisc mq qlen 1000
    link/ether 42:01:ac:18:00:04 brd ff:ff:ff:ff:ff:ff
    inet 172.24.0.4/32 scope global eth1
       valid_lft forever preferred_lft forever

XPK와 함께 GKE를 사용하여 JAX 설정

xpk README에서 예시를 확인하세요.

MaxText로 XPK를 설정하고 실행하려면 MaxText 실행 방법을 참고하세요.

대기열에 추가된 리소스를 사용하여 JAX 설정

gcloud alpha compute tpus tpu-vm ssh를 사용하여 슬라이스의 모든 TPU VM에 JAX를 동시에 설치합니다. 멀티슬라이스의 경우 --node=all를 추가합니다.


gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
 --zone ${ZONE} --worker=all \
 --command='pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html</code>'

다음 Python 코드를 실행하여 슬라이스에서 사용할 수 있는 TPU 코어 수를 확인하고 모든 것이 올바르게 설치되었는지 테스트할 수 있습니다 (여기에 표시된 출력은 v6e-16 슬라이스로 생성됨).

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
   --zone ${ZONE} --worker=all  \
   --command='python3 -c "import jax; print(jax.device_count(), jax.local_device_count())"'

출력은 다음과 비슷합니다.

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16 4
16 4
16 4
16 4

jax.device_count()는 지정된 슬라이스의 칩 총개수를 보여줍니다. jax.local_device_count()는 이 슬라이스에서 단일 VM이 액세스할 수 있는 칩 수를 나타냅니다.

gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} --zone=${ZONE} --worker=all  \
   --command='git clone -b mlperf4.1 https://github.com/google/maxdiffusion.git &&
   cd maxdiffusion && git checkout 975fdb7dbddaa9a53ad72a421cdb487dcdc491a3 &&
   && pip install -r requirements.txt  && pip install . '

JAX 설정 문제 해결

일반적인 팁은 GKE 워크로드 매니페스트에서 상세 로깅을 사용 설정하는 것입니다. 그런 다음 GKE 지원팀에 로그를 제공합니다.

TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0

오류 메시지

no endpoints available for service 'jobset-webhook-service'

이 오류는 작업 세트가 올바르게 설치되지 않았음을 의미합니다. jobset-controller-manager 배포 Kubernetes Pod가 실행 중인지 확인합니다. 자세한 내용은 JobSet 문제 해결 문서를 참고하세요.

TPU initialization failed: Failed to connect

GKE 노드 버전이 1.30.4-gke.1348000 이상인지 확인합니다 (GKE 1.31은 지원되지 않음).

PyTorch 설정

이 섹션에서는 PyTorch/XLA를 사용하여 v6e에서 PJRT 사용을 시작하는 방법을 설명합니다. Python 3.10이 권장되는 Python 버전입니다.

XPK를 사용하여 GKE를 통해 PyTorch 설정

PyTorch 종속 항목이 이미 설치된 XPK와 함께 다음 Docker 컨테이너를 사용할 수 있습니다.

us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028

XPK 워크로드를 만들려면 다음 명령어를 사용합니다.

python3 xpk.py workload create \
--cluster ${CLUSTER_NAME} \
[--docker-image | --base-docker-image] us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028 \
--workload ${USER} -xpk-${ACCELERATOR_TYPE} -${NUM_SLICES} \
--tpu-type=${ACCELERATOR_TYPE} \
--num-slices=${NUM_SLICES}  \
--on-demand \
--zone ${ZONE} \
--project ${PROJECT_ID} \
--enable-debug-logs \
--command 'python3 -c "import torch; import torch_xla; import torch_xla.runtime as xr; print(xr.global_runtime_device_count())"'

--base-docker-image를 사용하면 현재 작업 디렉터리가 새 Docker에 빌드된 새 Docker 이미지가 생성됩니다.

대기열에 추가된 리소스를 사용하여 PyTorch 설정

대기열에 추가된 리소스를 사용하여 PyTorch를 설치하고 v6e에서 작은 스크립트를 실행하려면 다음 단계를 따르세요.

SSH를 사용하여 종속 항목을 설치하여 VM에 액세스합니다.

멀티슬라이스의 경우 --node=all를 추가합니다.

   gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='sudo apt install -y libopenblas-base pip3 \
    install --pre torch==2.6.0.dev20241028+cpu torchvision==0.20.0.dev20241028+cpu \
    --index-url https://download.pytorch.org/whl/nightly/cpu
    pip install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241028-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html
    pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'

상당한 크기의 자주 할당되는 모델의 성능 개선

상당한 크기의 자주 할당되는 모델의 경우 tcmalloc을 사용하면 기본 malloc 구현에 비해 성능이 크게 향상되는 것으로 확인되었습니다. 따라서 TPU VM에 사용되는 기본 malloctcmalloc입니다. 그러나 워크로드에 따라 (예를 들어 임베딩 테이블에 대한 대규모 할당이 있는 DLRM) tcmalloc이 느려질 수 있으며, 이 경우 대신 기본 malloc을 사용하여 다음 변수를 설정 해제할 수 있습니다.

unset LD_PRELOAD

Python 스크립트를 사용하여 v6e VM에서 계산을 실행합니다.

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}
   --project ${PROJECT_ID} \
   --zone ${ZONE} --worker all --command='
   unset LD_PRELOAD
   python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"
'

그러면 다음과 비슷한 출력이 생성됩니다.

SSH: Attempting to connect to worker 0...
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
xla:0
tensor([[ 0.3355, -1.4628, -3.2610],
        [-1.4656,  0.3196, -2.8766],
        [ 0.8668, -1.5060,  0.7125]], device='xla:0')

TensorFlow 설정

v6e 공개 미리보기의 경우 tf-nightly 런타임 버전만 지원됩니다.

다음 명령어를 실행하여 v6e 호환 TensorFlow 버전으로 tpu-runtime를 재설정할 수 있습니다.

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
    --zone  ${ZONE} --worker=all --command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} \
    --zone ${ZONE} --worker=all --command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'

SSH를 사용하여 worker-0에 액세스합니다.

$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
     --zone ${ZONE}

worker-0에 TensorFlow를 설치합니다.

sudo apt install -y libopenblas-base
pip install cloud-tpu-client
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310
pip install cloud-tpu-client

pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

TPU_NAME 환경 변수를 내보냅니다.

export TPU_NAME=v6e-16

다음 Python 스크립트를 실행하여 슬라이스에서 사용할 수 있는 TPU 코어 수를 확인하고 모든 것이 올바르게 설치되었는지 테스트할 수 있습니다 (여기에 표시된 출력은 v6e-16 슬라이스로 생성됨).

import TensorFlow as tf
print("TensorFlow version " + tf.__version__)

@tf.function
  def add_fn(x,y):
  z = x + y
  return z

  cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
  tf.config.experimental_connect_to_cluster(cluster_resolver)
  tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
  strategy = tf.distribute.TPUStrategy(cluster_resolver)

  x = tf.constant(1.)
  y = tf.constant(1.)
  z = strategy.run(add_fn, args=(x,y))
  print(z)

출력은 다음과 비슷합니다.

PerReplica:{
  0: tf.Tensor(2.0, shape=(), dtype=float32),
  1: tf.Tensor(2.0, shape=(), dtype=float32),
  2: tf.Tensor(2.0, shape=(), dtype=float32),
  3: tf.Tensor(2.0, shape=(), dtype=float32),
  4: tf.Tensor(2.0, shape=(), dtype=float32),
  5: tf.Tensor(2.0, shape=(), dtype=float32),
  6: tf.Tensor(2.0, shape=(), dtype=float32),
  7: tf.Tensor(2.0, shape=(), dtype=float32)
}

SkyPilot이 포함된 v6e

SkyPilot에서 TPU v6e를 사용할 수 있습니다. 다음 단계에 따라 SkyPilot에 v6e 관련 위치/가격 정보를 추가합니다.

  1. ~/.sky/catalogs/v5/gcp/vms.csv 끝에 다음을 추가합니다.

    ,,,tpu-v6e-1,1,tpu-v6e-1,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-1,1,tpu-v6e-1,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-1,1,tpu-v6e-1,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,us-east5,us-east5-b,0,0
    
  2. YAML 파일에서 다음 리소스를 지정합니다.

    # tpu_v6.yaml
    resources:
      accelerators: tpu-v6e-16 # Fill in the accelerator type you want to use
      accelerator_args:
        runtime_version: v2-alpha-tpuv6e # Official suggested runtime
    
  3. TPU v6e로 클러스터를 실행합니다.

       sky launch tpu_v6.yaml -c tpu_v6
    
  4. SSH를 사용하여 TPU v6e에 연결합니다. ssh tpu_v6

추론 튜토리얼

다음 섹션에서는 JetStream을 사용하여 MaxText 및 PyTorch 모델을 서빙하는 방법과 TPU v6e에서 MaxDiffusion 모델을 서빙하는 방법을 안내합니다.

v6e Cloud TPU VM에서 JetStream MaxText 추론

이 튜토리얼에서는 JetStream을 사용하여 TPU v6e에서 MaxText (JAX) 모델을 제공하는 방법을 보여줍니다. JetStream은 XLA 기기 (TPU)에서 대규모 언어 모델 (LLM) 추론을 위한 처리량 및 메모리 최적화 엔진입니다. 이 튜토리얼에서는 Llama2-7B 모델의 추론 벤치마크를 실행합니다.

시작하기 전에

  1. 칩 4개가 있는 TPU v6e 프로비저닝 준비:

    gcloud alpha compute tpus queued-resources create $QUEUED_RESOURCE_ID \
        --node-id $TPU_NAME \
        --project $PROJECT_ID \
        --zone $ZONE \
        --accelerator-type $ACCELERATOR_TYPE \
        --runtime-version $RUNTIME_VERSION \
        --service-account $SERVICE_ACCOUNT
  2. SSH를 사용하여 TPU에 연결합니다.

    gcloud compute tpus tpu-vm ssh $TPU_NAME

튜토리얼 실행

JetStream 및 MaxText를 설정하고, 모델 체크포인트를 변환하고, 추론 벤치마크를 실행하려면 GitHub 저장소의 안내를 따르세요.

삭제

TPU를 삭제합니다.

gcloud compute tpus queued-resources delete $QUEUED_RESOURCE_ID \
    --project $PROJECT_ID \
    --zone $ZONE \
    --force \
    --async

v6e Cloud TPU VM에서 JetStream PyTorch 추론

이 튜토리얼에서는 JetStream을 사용하여 TPU v6e에서 PyTorch 모델을 제공하는 방법을 보여줍니다. JetStream은 XLA 기기(TPU)에서 대규모 언어 모델 (LLM) 추론을 위한 처리량 및 메모리 최적화 엔진입니다. 이 튜토리얼에서는 Llama2-7B 모델의 추론 벤치마크를 실행합니다.

시작하기 전에

  1. 칩 4개가 있는 TPU v6e 프로비저닝 준비:

    gcloud alpha compute tpus queued-resources create $QUEUED_RESOURCE_ID \
        --node-id $TPU_NAME \
        --project $PROJECT_ID \
        --zone $ZONE \
        --accelerator-type $ACCLERATOR_TYPE \
        --runtime-version $RUNTIME_VERSION \
        --service-account $SERVICE_ACCOUNT
  2. SSH를 사용하여 TPU에 연결합니다.

    gcloud compute tpus tpu-vm ssh $TPU_NAME

튜토리얼 실행

JetStream-PyTorch를 설정하고, 모델 체크포인트를 변환하고, 추론 벤치마크를 실행하려면 GitHub 저장소의 안내를 따르세요.

삭제

TPU를 삭제합니다.

   gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
      --project ${PROJECT_ID} \
      --zone ${ZONE} \
      --force \
      --async

v6e Cloud TPU VM에서 MaxDiffusion 추론

이 튜토리얼에서는 TPU v6e에서 MaxDiffusion 모델을 제공하는 방법을 보여줍니다. 이 튜토리얼에서는 Stable Diffusion XL 모델을 사용하여 이미지를 생성합니다.

시작하기 전에

  1. 칩 4개가 있는 TPU v6e 프로비저닝 준비:

    gcloud alpha compute tpus queued-resources create $QUEUED_RESOURCE_ID \
        --node-id $TPU_NAME \
        --project $PROJECT_ID \
        --zone $ZONE \
        --accelerator-type $ACCELERATOR_TYPE \
        --runtime-version $RUNTIME_VERSION \
        --service-account $SERVICE_ACCOUNT
  2. SSH를 사용하여 TPU에 연결합니다.

    gcloud compute tpus tpu-vm ssh $TPU_NAME

Conda 환경 만들기

  1. Miniconda의 디렉터리를 만듭니다.

    mkdir -p ~/miniconda3
  2. Miniconda 설치 프로그램 스크립트를 다운로드합니다.

    wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
  3. Miniconda를 설치합니다.

    bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
  4. Miniconda 설치 프로그램 스크립트를 삭제합니다.

    rm -rf ~/miniconda3/miniconda.sh
  5. PATH 변수에 Miniconda를 추가합니다.

    export PATH="$HOME/miniconda3/bin:$PATH"
  6. ~/.bashrc를 새로고침하여 PATH 변수에 변경사항을 적용합니다.

    source ~/.bashrc
  7. 새 Conda 환경을 만듭니다.

    conda create -n tpu python=3.10
  8. Conda 환경을 활성화합니다.

    source activate tpu

MaxDiffusion 설정

  1. MaxDiffusion 저장소를 클론하고 MaxDiffusion 디렉터리로 이동합니다.

    https://github.com/google/maxdiffusion.git && cd maxdiffusion
  2. mlperf-4.1 브랜치로 전환합니다.

    git checkout mlperf4.1
  3. MaxDiffusion을 설치합니다.

    pip install -e .
  4. 종속 항목을 설치합니다.

    pip install -r requirements.txt
  5. JAX를 설치합니다.

    pip install -U --pre jax[tpu] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

이미지 생성

  1. TPU 런타임을 구성할 환경 변수를 설정합니다.

    LIBTPU_INIT_ARGS="--xla_tpu_rwb_fusion=false --xla_tpu_dot_dot_fusion_duplicated=true --xla_tpu_scoped_vmem_limit_kib=65536"
  2. src/maxdiffusion/configs/base_xl.yml에 정의된 프롬프트와 구성을 사용하여 이미지를 생성합니다.

    python -m src.maxdiffusion.generate_sdxl src/maxdiffusion/configs/base_xl.yml run_name="my_run"

삭제

TPU를 삭제합니다.

gcloud compute tpus queued-resources delete $QUEUED_RESOURCE \
    --project $PROJECT_ID \
    --zone $ZONE \
    --force \
    --async

v6e Cloud TPU VM에서 vLLM 추론

이 튜토리얼에서는 TPU VM에서 vLLM을 시작하는 방법을 보여줍니다. 프로덕션에서 Trillium에 vLLM을 배포하는 권장사항의 예는 GKE vLLM 튜토리얼을 참고하세요.

시작하기 전에

  1. 칩 4개가 있는 TPU v6e 프로비저닝 준비:

    gcloud alpha compute tpus queued-resources create $QUEUED_RESOURCE_ID \
       --node-id $TPU_NAME \
       --project $PROJECT_ID \
       --zone $ZONE \
       --accelerator-type $ACCELERATOR_TYPE \
       --runtime-version $RUNTIME_VERSION \
       --service-account $SERVICE_ACCOUNT

    명령어 플래그 설명

    변수 설명
    NODE_ID 큐에 추가된 리소스 요청이 할당될 때 생성되는 TPU의 사용자 할당 ID입니다.
    PROJECT_ID Google Cloud 프로젝트 이름 에서 기존 프로젝트를 사용하거나 새 프로젝트를 만듭니다.
    ZONE 지원되는 영역에 대해서는 TPU 리전 및 영역 문서를 참고하세요.
    ACCELERATOR_TYPE 가속기 유형을 참고하세요.
    RUNTIME_VERSION v2-alpha-tpuv6e
    SERVICE_ACCOUNT Google Cloud 콘솔 -> IAM -> 서비스 계정에서 찾을 수 있는 서비스 계정의 이메일 주소입니다.

    예: tpu-service-account@<your_project_ID>.iam.gserviceaccount.com.com

  2. SSH를 사용하여 TPU에 연결합니다.

    gcloud compute tpus tpu-vm ssh $TPU_NAME
    

Create a Conda environment

  1. (Recommended) Create a new conda environment for vLLM:

    conda create -n vllm python=3.10 -y
    conda activate vllm

TPU에서 vLLM 설정

  1. vLLM 저장소를 클론하고 vLLM 디렉터리로 이동합니다.

    git clone https://github.com/vllm-project/vllm.git && cd vllm
    
  2. 기존 torch 및 torch-xla 패키지를 삭제합니다.

    pip uninstall torch torch-xla -y
    
  3. PyTorch 및 PyTorch XLA를 설치합니다.

    pip install --pre torch==2.6.0.dev20241028+cpu torchvision==0.20.0.dev20241028+cpu --index-url https://download.pytorch.org/whl/nightly/cpu
    pip install 'torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev-cp310-cp310-linux_x86_64.whl' -f https://storage.googleapis.com/libtpu-releases/index.html
    
  4. JAX 및 Pallas를 설치합니다.

    pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
    pip install jaxlib==0.4.32.dev20240829 jax==0.4.32.dev20240829 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
    
    
  5. 다른 빌드 종속 항목을 설치합니다.

    pip install -r requirements-tpu.txt
    VLLM_TARGET_DEVICE="tpu" python setup.py develop
    sudo apt-get install libopenblas-base libopenmpi-dev libomp-dev
    

모델 액세스 권한 얻기

HuggingFace 저장소에서 Llama3 모델 제품군을 사용하려면 동의 계약에 서명해야 합니다.

토큰을 아직 만들지 않았다면 새 Hugging Face 토큰을 생성합니다.

  1. 내 프로필 > 설정 > 액세스 토큰을 클릭합니다.
  2. 새 토큰을 선택합니다.
  3. 원하는 이름과 Read 이상의 역할을 지정합니다.
  4. 토큰 생성을 선택합니다.
  5. 생성된 토큰을 클립보드에 복사하고 환경 변수로 설정한 다음 huggingface-cli로 인증합니다.

    export TOKEN=''
    git config --global credential.helper store
    huggingface-cli login --token $TOKEN

벤치마킹 데이터 다운로드

  1. /data 디렉터리를 만들고 Hugging Face에서 ShareGPT 데이터 세트를 다운로드합니다.

    mkdir ~/data && cd ~/data
    wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
    

vLLM 서버 시작

다음 명령어는 Hugging Face 모델 허브에서 TPU VM의 /tmp 디렉터리로 모델 가중치를 다운로드하고, 다양한 입력 도형을 사전 컴파일하고, 모델 컴파일을 ~/.cache/vllm/xla_cache에 씁니다.

자세한 내용은 vLLM 문서를 참고하세요.

   cd ~/vllm
   vllm serve "meta-llama/Meta-Llama-3.1-8B" --download_dir /tmp --num-scheduler-steps 4 --swap-space 16 --disable-log-requests --tensor_parallel_size=4 --max-model-len=2048 &> serve.log &

vLLM 벤치마크 실행

vLLM 벤치마킹 스크립트를 실행합니다.

   python benchmarks/benchmark_serving.py \
       --backend vllm \
       --model "meta-llama/Meta-Llama-3.1-8B"  \
       --dataset-name sharegpt \
       --dataset-path ~/data/ShareGPT_V3_unfiltered_cleaned_split.json  \
       --num-prompts 1000

삭제

TPU를 삭제합니다.

gcloud compute tpus queued-resources delete $QUEUED_RESOURCE_ID \
    --project $PROJECT_ID \
    --zone $ZONE \
    --force \
    --async

학습 튜토리얼

다음 섹션에서는 TPU v6e에서 MaxText, MaxDiffusion, PyTorch 모델을 학습시키는 방법을 안내합니다.

v6e Cloud TPU VM에서 MaxText 및 MaxDiffusion 학습

다음 섹션에서는 MaxTextMaxDiffusion 모델의 학습 수명 주기를 다룹니다.

일반적인 대략적인 단계는 다음과 같습니다.

  1. 워크로드 기본 이미지를 빌드합니다.
  2. XPK를 사용하여 워크로드를 실행합니다.
    1. 워크로드의 학습 명령어를 빌드합니다.
    2. 워크로드를 배포합니다.
  3. 워크로드를 따라가고 측정항목을 확인합니다.
  4. 필요하지 않은 경우 XPK 워크로드를 삭제합니다.
  5. 더 이상 필요하지 않으면 XPK 클러스터를 삭제합니다.

기본 이미지 빌드

MaxText 또는 MaxDiffusion을 설치하고 Docker 이미지를 빌드합니다.

  1. 사용할 저장소를 클론하고 저장소의 디렉터리로 변경합니다.

    MaxText:

    git clone https://github.com/google/maxtext.git && cd maxtext
    

    MaxDiffusion:

    git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion
    
  2. Google Cloud CLI를 사용하도록 Docker를 구성합니다.

    gcloud auth configure-docker
    
  3. 다음 명령어를 사용하거나 JAX 안정화 스택을 사용하여 Docker 이미지를 빌드합니다. JAX Stable Stack에 관한 자세한 내용은 JAX Stable Stack으로 Docker 이미지 빌드를 참고하세요.

    bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.35
    
  4. 로컬에서 빌드된 이미지가 없는 머신에서 워크로드를 실행하는 경우 이미지를 업로드합니다.

    bash docker_upload_runner.sh CLOUD_IMAGE_NAME=${USER}_runner
    
JAX 안정화 스택으로 Docker 이미지 빌드

JAX Stable Stack 기본 이미지를 사용하여 MaxText 및 MaxDiffusion Docker 이미지를 빌드할 수 있습니다.

JAX Stable Stack은 JAX를 orbax, flax, optax와 같은 핵심 패키지와 함께 번들로 묶고 TPU 프로그램 유틸리티 및 기타 필수 도구를 실행하는 잘 정규화된 libtpu.so를 제공하여 MaxText 및 MaxDiffusion을 위한 일관된 환경을 제공합니다. 이러한 라이브러리는 호환성을 보장하고 MaxText 및 MaxDiffusion을 빌드하고 실행하기 위한 안정적인 기반을 제공하도록 테스트됩니다. 이렇게 하면 호환되지 않는 패키지 버전으로 인한 잠재적 충돌을 방지할 수 있습니다.

JAX Stable Stack에는 TPU 프로그램 컴파일, 실행, ICI 네트워크 구성을 구동하는 핵심 라이브러리인 완전히 출시되고 검증된 libtpu.so가 포함되어 있습니다. libtpu 출시는 이전에 JAX에서 사용한 야간 빌드를 대체하며 HLO/StableHLO IR에서 PJRT 수준의 검증 테스트를 통해 TPU에서 XLA 계산의 일관된 기능을 보장합니다.

JAX 안정화 스택으로 MaxText 및 MaxDiffusion Docker 이미지를 빌드하려면 docker_build_dependency_image.sh 스크립트를 실행할 때 MODE 변수를 stable_stack로 설정하고 BASEIMAGE 변수를 사용하려는 기본 이미지로 설정합니다.

다음 예에서는 us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1를 기본 이미지로 지정합니다.

bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1

사용 가능한 JAX Stable Stack 기본 이미지 목록은 Artifact Registry의 JAX Stable Stack 이미지를 참고하세요.

XPK를 사용하여 워크로드 실행

  1. MaxText 또는 MaxDiffusion에서 설정한 기본값을 사용하지 않는 경우 다음 환경 변수를 설정합니다.

    export BASE_OUTPUT_DIR=gs://YOUR_BUCKET
    export PER_DEVICE_BATCH_SIZE=2
    export NUM_STEPS=30
    export MAX_TARGET_LENGTH=8192
  2. 모델 스크립트를 빌드합니다. 이 스크립트는 이후 단계에서 학습 명령어로 복사됩니다.

    아직 모델 스크립트를 실행하지 마세요.

    MaxText

    MaxText는 순수 Python 및 JAX로 작성되었으며 학습 및 추론을 위해 Google Cloud TPU 및 GPU를 타겟팅하는 고성능의 확장성이 뛰어난 오픈소스 LLM입니다.

    JAX_PLATFORMS=tpu,cpu \
    ENABLE_PJRT_COMPATIBILITY=true \
    TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \
    TPU_SLICE_BUILDER_DUMP_ICI=true && \
    python /deps/MaxText/train.py /deps/MaxText/configs/base.yml \
            base_output_directory=${BASE_OUTPUT_DIR} \
            dataset_type=synthetic \
            per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
            enable_checkpointing=false \
            gcs_metrics=true \
            profiler=xplane \
            skip_first_n_steps_for_profiler=5 \
            steps=${NUM_STEPS}  # attention='dot_product'"
    

    Gemma2

    Gemma는 Gemini 연구 및 기술을 기반으로 Google DeepMind에서 개발한 개방형 가중치 LLM 제품군입니다.

    python3 MaxText/train.py MaxText/configs/base.yml \
        model_name=gemma2-27b \
        run_name=gemma2-27b-run \
        base_output_directory=${BASE_OUTPUT_DIR} \
        max_target_length=${MAX_TARGET_LENGTH} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        steps=${NUM_STEPS} \
        enable_checkpointing=false \
        use_iota_embed=true \
        gcs_metrics=true \
        dataset_type=synthetic \
        profiler=xplane \
        attention=flash
    

    Mixtral 8x7b

    Mixtral은 Mistral AI에서 개발한 최신 AI 모델로, 희소 전문가 망 (MoE) 아키텍처를 활용합니다.

    python3 MaxText/train.py MaxText/configs/base.yml \
        base_output_directory=${BASE_OUTPUT_DIR} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        model_name=mixtral-8x7b \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        tokenizer_path=assets/tokenizer.mistral-v1 \
        attention=flash \
        dtype=bfloat16 \
        dataset_type=synthetic \
        profiler=xplane
    

    Llama3-8b

    Llama는 Meta에서 개발한 개방형 가중치 LLM 제품군입니다.

    python3 MaxText/train.py MaxText/configs/base.yml \
        model_name=llama3-8b \
        base_output_directory=${BASE_OUTPUT_DIR} \
        dataset_type=synthetic \
        tokenizer_path=assets/tokenizer_llama3.tiktoken \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} # set to 4 \
        gcs_metrics=true \
        profiler=xplane \
        skip_first_n_steps_for_profiler=5 \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        attention=flash"
    

    MaxDiffusion

    MaxDiffusion은 Cloud TPU 및 GPU를 비롯한 XLA 기기에서 실행되는 순수 Python 및 JAX로 작성된 다양한 잠재 확산 모델의 참조 구현 모음입니다. Stable Diffusion은 모든 텍스트 입력에서 실사 이미지를 생성하는 잠재 텍스트 이미지 변환 모델입니다.

    MaxDiffusion을 실행하려면 다음 git checkout 명령어와 같이 특정 Git 브랜치를 설치해야 합니다.

    git clone https://github.com/google/maxdiffusion.git
    && cd maxdiffusion
    && git checkout e712c9fc4cca764b0930067b6e33daae2433abf0
    && pip install -r requirements.txt
    && pip install .
    

    학습 스크립트:

        cd maxdiffusion && OUT_DIR=${BASE_OUTPUT_DIR} \
        python src/maxdiffusion/train_sdxl.py \
        src/maxdiffusion/configs/base_xl.yml \
        revision=refs/pr/95 \
        activations_dtype=bfloat16 \
        weights_dtype=bfloat16 \
        resolution=1024 \
        per_device_batch_size=1 \
        output_dir=${OUT_DIR}  \
        jax_cache_dir=${OUT_DIR}/cache_dir/ \
        max_train_steps=200 \
        attention=flash run_name=sdxl-ddp-v6e
    
        
  3. 이전 단계에서 만든 스크립트를 사용하여 모델을 실행합니다. MaxText 기본 이미지를 사용하려면 --base-docker-image 플래그를 지정하거나 --docker-image 플래그와 사용하려는 이미지를 지정해야 합니다.

    선택사항: --enable-debug-logs 플래그를 포함하여 디버그 로깅을 사용 설정할 수 있습니다. 자세한 내용은 MaxText에서 JAX 디버그를 참고하세요.

    선택사항: --use-vertex-tensorboard 플래그를 포함하여 Vertex AI 실험을 만들어 Vertex AI 텐서보드에 데이터를 업로드할 수 있습니다. 자세한 내용은 Vertex AI를 사용하여 MaxText에서 JAX 모니터링을 참고하세요.

    python3 xpk.py workload create \
        --cluster CLUSTER_NAME \
        {--base-docker-image maxtext_base_image|--docker-image ${CLOUD_IMAGE_NAME}} \
        --workload ${USER}-xpk-$ACCELERATOR_TYPE-$NUM_SLICES \
        --tpu-type=$ACCELERATOR_TYPE \
        --num-slices=$NUM_SLICES  \
        --on-demand \
        --zone $ZONE \
        --project $PROJECT_ID \
        [--enable-debug-logs] \
        [--use-vertex-tensorboard] \
        --command $YOUR-MODEL-SCRIPT

    다음 변수를 내보냅니다.

    export ClUSTER_NAME=CLUSTER_NAME: XPK 클러스터의 이름입니다. export ACCELERATOR_TYPEACCELERATOR_TYPE: TPU의 버전 및 크기입니다. 예를 들면 v6e-256입니다. export NUM_SLICES=NUM_SLICES: TPU 슬라이스 수입니다. export YOUR_MODEL_SCRIPT=YOUR_MODEL_SCRIPT: 학습 명령으로 실행할 모델 스크립트입니다.

    출력에는 다음과 유사한 워크로드 추적 링크가 포함됩니다.

    [XPK] Follow your workload here: https://console.cloud.google.com/kubernetes/service/zone/project_id/default/workload_name/details?project=project_id
    

    링크를 열고 로그 탭을 클릭하여 워크로드를 실시간으로 추적합니다.

MaxText에서 JAX 디버그

보조 XPK 명령어를 사용하여 클러스터 또는 워크로드가 실행되지 않는 이유를 진단합니다.

  • XPK 워크로드 목록
  • XPK 검사기
  • XPK 워크로드를 만들 때 --enable-debug-logs 플래그를 사용하여 워크로드 로그에서 상세 로깅을 사용 설정합니다.

Vertex AI를 사용하여 MaxText에서 JAX 모니터링

Vertex AI의 관리형 TensorBoard를 통해 스칼라 및 프로필 데이터를 확인합니다.

  1. 사용 중인 영역의 리소스 관리 (CRUD) 요청을 600에서 5,000으로 늘립니다. VM 16대 미만을 사용하는 소규모 워크로드의 경우 문제가 되지 않을 수 있습니다.
  2. Vertex AI용 cloud-accelerator-diagnostics와 같은 종속 항목을 설치합니다.

    # xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI
    cd ~/xpk
    pip install .
  3. Vertex AI TensorBoard 만들기에 설명된 대로 --create-vertex-tensorboard 플래그를 사용하여 XPK 클러스터를 만듭니다. 기존 클러스터에서 이 명령어를 실행할 수도 있습니다.

  4. --use-vertex-tensorboard 플래그와 선택사항인 --experiment-name 플래그를 사용하여 XPK 워크로드를 실행할 때 Vertex AI 실험을 만듭니다. 단계의 전체 목록은 Vertex AI 실험을 만들어 Vertex AI 텐서보드에 데이터 업로드하기를 참고하세요.

로그에는 다음과 유사한 Vertex AI 텐서보드 링크가 포함됩니다.

View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name

Google Cloud 콘솔에서도 Vertex AI TensorBoard 링크를 찾을 수 있습니다. Google Cloud 콘솔에서 Vertex AI Experiments로 이동합니다. 드롭다운에서 적절한 지역을 선택합니다.

TensorBoard 디렉터리도 ${BASE_OUTPUT_DIR}로 지정한 Cloud Storage 버킷에 씁니다.

XPK 워크로드 삭제

xpk workload delete 명령어를 사용하여 작업 접두사 또는 작업 상태를 기반으로 하나 이상의 워크로드를 삭제합니다. 이 명령어는 더 이상 실행할 필요가 없는 XPK 워크로드를 전송했거나 대기열에 멈춘 작업이 있는 경우에 유용할 수 있습니다.

XPK 클러스터 삭제

xpk cluster delete 명령어를 사용하여 클러스터를 삭제합니다.

python3 xpk.py cluster delete --cluster ${CLUSTER_NAME} \
--zone $ZONE --project $PROJECT_ID

v6e Cloud TPU VM에서 Llama 및 PyTorch/XLA 학습

이 튜토리얼에서는 WikiText 데이터 세트를 사용하여 TPU v6e에서 PyTorch/XLA를 사용하여 Llama 모델을 학습시키는 방법을 설명합니다. 또한 여기에서 PyTorch TPU 모델 스크립트에 Docker 이미지로 액세스할 수 있습니다.

설치

가상 환경에 Hugging Face Transformers의 pytorch-tpu/transformers 포크와 종속 항목을 설치합니다.

git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
cd transformers
pip3 install -e .
pip3 install datasets
pip3 install evaluate
pip3 install scikit-learn
pip3 install accelerate

모델 구성 설정

다음 섹션의 학습 명령어인 모델 스크립트 빌드는 두 개의 JSON 구성 파일을 사용하여 모델 매개변수와 FSDP (전체 샤딩 데이터 병렬) 구성을 정의합니다. FSDP 샤딩은 학습 중에 모델 가중치가 더 큰 배치 크기에 맞게 조정되는 데 사용됩니다. 소형 모델로 학습하는 경우 데이터 병렬 처리를 사용하고 각 기기에서 가중치를 복제하는 것으로 충분할 수 있습니다. PyTorch/XLA에서 여러 기기 간에 텐서를 샤딩하는 방법에 관한 자세한 내용은 PyTorch/XLA SPMD 사용자 가이드를 참고하세요.

  1. 모델 매개변수 구성 파일을 만듭니다. 다음은 Llama3-8B의 모델 매개변수 구성입니다. 다른 모델의 경우 Hugging Face에서 구성을 찾습니다. 예를 들어 Llama2-7B 구성을 참고하세요.

    {
        "architectures": [
            "LlamaForCausalLM"
        ],
        "attention_bias": false,
        "attention_dropout": 0.0,
        "bos_token_id": 128000,
        "eos_token_id": 128001,
        "hidden_act": "silu",
        "hidden_size": 4096,
        "initializer_range": 0.02,
        "intermediate_size": 14336,
        "max_position_embeddings": 8192,
        "model_type": "llama",
        "num_attention_heads": 32,
        "num_hidden_layers": 32,
        "num_key_value_heads": 8,
        "pretraining_tp": 1,
        "rms_norm_eps": 1e-05,
        "rope_scaling": null,
        "rope_theta": 500000.0,
        "tie_word_embeddings": false,
        "torch_dtype": "bfloat16",
        "transformers_version": "4.40.0.dev0",
        "use_cache": false,
        "vocab_size": 128256
    }
  2. FSDP 구성 파일을 만듭니다.

    {
        "fsdp_transformer_layer_cls_to_wrap": [
            "LlamaDecoderLayer"
        ],
        "xla": true,
        "xla_fsdp_v2": true,
        "xla_fsdp_grad_ckpt": true
    }

    FSDP에 관한 자세한 내용은 FSDPv2를 참고하세요.

  3. 다음 명령어를 사용하여 구성 파일을 TPU VM에 업로드합니다.

        gcloud alpha compute tpus tpu-vm scp your-config-file.json $TPU_NAME:. \
            --worker=all \
            --project=$PROJECT \
            --zone $ZONE

    현재 작업 디렉터리에서 구성 파일을 만들고 XPK에서 --base-docker-image 플래그를 사용할 수도 있습니다.

모델 스크립트 빌드

--config_name 플래그를 사용하여 모델 매개변수 구성 파일을 지정하고 --fsdp_config 플래그를 사용하여 FSDP 구성 파일을 지정하여 모델 스크립트를 빌드합니다. 다음 섹션인 모델 실행에서 TPU에서 이 스크립트를 실행합니다. 아직 모델 스크립트를 실행하지 마세요.

    export PJRT_DEVICE=TPU
    export XLA_USE_SPMD=1
    export ENABLE_PJRT_COMPATIBILITY=true
    # Optional variables for debugging:
    export XLA_IR_DEBUG=1
    export XLA_HLO_DEBUG=1
    export PROFILE_EPOCH=0
    export PROFILE_STEP=3
    export PROFILE_DURATION_MS=100000
    export PROFILE_LOGDIR=local VM path or gs://my-bucket/profile_path
    python3 transformers/examples/pytorch/language-modeling/run_clm.py \
        --dataset_name wikitext \
        --dataset_config_name wikitext-2-raw-v1 \
        --per_device_train_batch_size 8 \
        --do_train \
        --output_dir /home/$USER/tmp/test-clm \
        --overwrite_output_dir \
        --config_name /home/$USER/config-8B.json \
        --cache_dir /home/$USER/cache \
        --tokenizer_name meta-llama/Meta-Llama-3-8B \
        --block_size 8192 \
        --optim adafactor \
        --save_strategy no \
        --logging_strategy no \
        --fsdp "full_shard" \
        --fsdp_config /home/$USER/fsdp_config.json \
        --torch_dtype bfloat16 \
        --dataloader_drop_last yes \
        --flash_attention \
        --max_steps 20

모델 실행

이전 단계인 모델 스크립트 빌드에서 만든 스크립트를 사용하여 모델을 실행합니다.

단일 호스트 TPU VM (예: v6e-4)을 사용하는 경우 TPU VM에서 직접 학습 명령어를 실행할 수 있습니다. 멀티 호스트 TPU VM을 사용하는 경우 --worker=all 플래그를 사용하여 모든 호스트에서 동시에 스크립트를 실행합니다.

gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT \
    --zone $ZONE \
    --worker=all \
    --command=your-command

PyTorch/XLA 문제 해결

이전 섹션에서 디버깅을 위한 선택적 변수를 설정하면 모델의 프로필이 PROFILE_LOGDIR 변수로 지정된 위치에 저장됩니다. 이 위치에 저장된 xplane.pb 파일을 추출하고 tensorboard를 사용하여 텐서보드 안내에 따라 브라우저에서 프로필을 볼 수 있습니다. PyTorch/XLA가 예상대로 작동하지 않으면 모델 디버깅, 프로파일링, 최적화에 관한 제안사항이 포함된 문제 해결 가이드를 참고하세요.

v6e에서 DLRM DCN v2 학습

이 튜토리얼에서는 TPU v6e에서 DLRM DCN v2 모델을 학습시키는 방법을 보여줍니다. 64, 128 또는 256개 칩이 있는 TPU v6e를 프로비저닝해야 합니다.

멀티 호스트에서 실행하는 경우 다음 명령어를 실행하여 적절한 TensorFlow 버전으로 tpu-runtime를 재설정합니다. 단일 호스트에서 실행하는 경우 다음 두 명령어를 실행할 필요가 없습니다.

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID}
--zone  ${ZONE} --worker=all \
--command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} \
 --zone  ${ZONE}   \
 --worker=all \
 --command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'

worker-0에 SSH로 연결

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone ${ZONE} --project {$PROJECT_ID}

TPU 이름 설정

export TPU_NAME=${TPU_NAME}

DLRM v2 실행

pip install cloud-tpu-client

pip install gin-config && pip install tensorflow-datasets && pip install tf-keras-nightly --no-deps

pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl -f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

git clone https://github.com/tensorflow/recommenders.git
git clone https://github.com/tensorflow/models.git

export PYTHONPATH=~/recommenders/:~/models/
export TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true --tf_xla_sparse_core_disable_table_stacking=true --tf_mlir_enable_convert_control_to_data_outputs_pass=true --tf_mlir_enable_merge_control_flow_pass=true'

TF_USE_LEGACY_KERAS=1 TPU_LOAD_LIBRARY=0 python3 ./models/official/recommendation/ranking/train.py  --mode=train     --model_dir=gs://ptxla-debug/tf/sc/dlrm/runs/2/ --params_override="
runtime:
  distribution_strategy: tpu
  mixed_precision_dtype: 'mixed_bfloat16'
task:
  use_synthetic_data: false
  use_tf_record_reader: true
  train_data:
    input_path: 'gs://trillium-datasets/criteo/train/day_*/*'
    global_batch_size: 16384
    use_cached_data: true
  validation_data:
    input_path: 'gs://trillium-datasets/criteo/eval/day_*/*'
    global_batch_size: 16384
    use_cached_data: true
  model:
    num_dense_features: 13
    bottom_mlp: [512, 256, 128]
    embedding_dim: 128
    interaction: 'multi_layer_dcn'
    dcn_num_layers: 3
    dcn_low_rank_dim: 512
    size_threshold: 8000
    top_mlp: [1024, 1024, 512, 256, 1]
    use_multi_hot: true
    concat_dense: false
    dcn_use_bias: true
    vocab_sizes: [40000000,39060,17295,7424,20265,3,7122,1543,63,40000000,3067956,405282,10,2209,11938,155,4,976,14,40000000,40000000,40000000,590152,12973,108,36]
    multi_hot_sizes: [3,2,1,2,6,1,1,1,1,7,3,8,1,6,9,5,1,1,1,12,100,27,10,3,1,1]
    max_ids_per_chip_per_sample: 128
    max_ids_per_table: [280, 128, 64, 272, 432, 624, 64, 104, 368, 352, 288, 328, 304, 576, 336, 368, 312, 392, 408, 552, 2880, 1248, 720, 112, 320, 256]
    max_unique_ids_per_table: [104, 56, 40, 32, 72, 32, 40, 32, 32, 144, 64, 192, 32, 40, 136, 32, 32, 32, 32, 240, 1352, 432, 120, 80, 32, 32]
    use_partial_tpu_embedding: false
    size_threshold: 0
    initialize_tables_on_host: true
trainer:
  train_steps: 10000
  validation_interval: 1000
  validation_steps: 660
  summary_interval: 1000
  steps_per_loop: 1000
  checkpoint_interval: 0
  optimizer_config:
    embedding_optimizer: 'Adagrad'
    dense_optimizer: 'Adagrad'
    lr_config:
      decay_exp: 2
      decay_start_steps: 70000
      decay_steps: 30000
      learning_rate: 0.025
      warmup_steps: 0
    dense_sgd_config:
      decay_exp: 2
      decay_start_steps: 70000
      decay_steps: 30000
      learning_rate: 0.00025
      warmup_steps: 8000
  train_tf_function: true
  train_tf_while_loop: true
  eval_tf_while_loop: true
  use_orbit: true
  pipeline_sparse_and_dense_execution: true"

script.sh을 실행합니다.

chmod +x script.sh
./script.sh
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

추천 워크로드 (DLRM DCN)를 실행하려면 다음 플래그가 필요합니다.

ENV TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true \
--tf_mlir_enable_tpu_variable_runtime_reformatting_pass=false \
--tf_mlir_enable_convert_control_to_data_outputs_pass=true \
--tf_mlir_enable_merge_control_flow_pass=true --tf_xla_disable_full_embedding_pipelining=true' \
ENV LIBTPU_INIT_ARGS="--xla_sc_splitting_along_feature_dimension=auto \
--copy_with_dynamic_shape_op_output_pjrt_buffer=true"

벤치마킹 결과

다음 섹션에는 v6e의 DLRM DCN v2 및 MaxDiffusion에 대한 벤치마킹 결과가 포함되어 있습니다.

DLRM DCN v2

DLRM DCN v2 학습 스크립트는 다양한 규모로 실행되었습니다. 다음 표에서 처리량을 확인하세요.

v6e-64 v6e-128 v6e-256
학습 단계 7000 7000 7000
전역 배치 크기 131072 262144 524288
처리량(예시/초) 2975334 5111808 10066329

MaxDiffusion

v6e-4, v6e-16, 2xv6e-16에서 MaxDiffusion 학습 스크립트를 실행했습니다. 다음 표에서 처리량을 확인하세요.

v6e-4 v6e-16 v6e-16 2개
학습 단계 0.069 0.073 0.13
전역 배치 크기 8 32 64
처리량(예시/초) 115.9 438.4 492.3

수집 예약

Trillium (v6e)에는 '수집 예약'이라는 새로운 기능이 포함되어 있습니다. 이 기능은 GKE와 Cloud TPU API 모두에서 단일 호스트 추론 워크로드를 실행하는 여러 TPU 슬라이스를 관리하는 방법을 제공합니다. 이러한 슬라이스를 컬렉션으로 그룹화하면 수요에 맞게 복제본 수를 쉽게 조정할 수 있습니다. 소프트웨어 업데이트는 컬렉션 내 슬라이스의 일부를 항상 수신 트래픽을 처리하는 데 사용할 수 있도록 신중하게 제어됩니다.

GKE에서 수집 예약을 사용하는 방법에 관한 자세한 내용은 GKE 문서를 참고하세요.

수집 예약 기능은 v6e에만 적용됩니다.

Cloud TPU API에서 수집 예약 사용

Cloud TPU API의 단일 호스트 컬렉션은 대기열에 추가된 리소스로, 특수 플래그 (--workload-type = availability-optimized)가 설정되어 워크로드 제공에 사용될 것임을 기본 인프라에 나타냅니다.

다음 명령어는 Cloud TPU API를 사용하여 단일 호스트 컬렉션을 프로비저닝합니다.

gcloud alpha compute tpus queued-resources create my-collection \
   --project=$PROJECT_ID \
   --zone=${ZONE} \
   --accelerator-type $ACCELERATOR_TYPE \
   --node-count ${NODE_COUNT} \
   --workload-type=availability-optimized

모니터링 및 프로파일링

Cloud TPU v6e는 이전 세대의 Cloud TPU와 동일한 메서드를 사용한 모니터링 및 프로파일링을 지원합니다. 모니터링에 관한 자세한 내용은 TPU VM 모니터링을 참고하세요.