Google Kubernetes Engine で Keras を使用して TensorFlow モデルをトレーニングする

次のセクションでは、TensorFlow で Hugging Face Transformers ライブラリを使用して、シーケンス分類用の BERT モデルをファインチューニングする例を示します。データセットは、マウントされた Parallelstore が基盤となるボリュームにダウンロードされます。これにより、モデル トレーニングでボリュームからデータを直接読み取ることができます。

前提条件

モデル トレーニング ジョブの次の YAML マニフェスト(parallelstore-csi-job-example.yaml)を保存します。

  apiVersion: batch/v1
  kind: Job
  metadata:
    name: parallelstore-csi-job-example
  spec:
    template:
      metadata:
        annotations:
            gke-parallelstore/cpu-limit: "0"
            gke-parallelstore/memory-limit: "0"
      spec:
        securityContext:
          runAsUser: 1000
          runAsGroup: 100
          fsGroup: 100
        containers:
        - name: tensorflow
          image: jupyter/tensorflow-notebook@sha256:173f124f638efe870bb2b535e01a76a80a95217e66ed00751058c51c09d6d85d
          command: ["bash", "-c"]
          args:
          - |
            pip install transformers datasets
            python - <<EOF
            from datasets import load_dataset
            dataset = load_dataset("glue", "cola", cache_dir='/data')
            dataset = dataset["train"]
            from transformers import AutoTokenizer
            import numpy as np
            tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
            tokenized_data = tokenizer(dataset["sentence"], return_tensors="np", padding=True)
            tokenized_data = dict(tokenized_data)
            labels = np.array(dataset["label"])
            from transformers import TFAutoModelForSequenceClassification
            from tensorflow.keras.optimizers import Adam
            model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-cased")
            model.compile(optimizer=Adam(3e-5))
            model.fit(tokenized_data, labels)
            EOF
          volumeMounts:
          - name: parallelstore-volume
            mountPath: /data
        volumes:
        - name: parallelstore-volume
          persistentVolumeClaim:
            claimName: parallelstore-pvc
        restartPolicy: Never
    backoffLimit: 1

YAML マニフェストをクラスタに適用します。

kubectl apply -f parallelstore-csi-job-example.yaml

次のコマンドを使用して、データの読み込みとモデルのトレーニングの進行状況を確認します。

POD_NAME=$(kubectl get pod | grep 'parallelstore-csi-job-example' | awk '{print $1}')
kubectl logs -f $POD_NAME -c tensorflow