Train a TensorFlow model with Keras on Google Kubernetes Engine
The following section provides an example of
fine-tuning a BERT model
for sequence classification using the
Hugging Face transformers library
with TensorFlow. The dataset is downloaded into a mounted
Parallelstore-backed volume, so the training job can read data directly
from the volume.
Prerequisites
- Ensure your node has at least 8 GiB of memory available.
- Create a PersistentVolumeClaim requesting a Parallelstore-backed volume (see /kubernetes-engine/docs/how-to/persistent-volumes/parallelstore-csi-new-volume#pvc); a hedged example sketch follows this list.
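If you don't yet have the PersistentVolumeClaim, the following sketch shows its general shape. This is a minimal, hypothetical example: the StorageClass name (`parallelstore-class`) and the requested capacity are placeholders, and Parallelstore has its own capacity requirements, so follow the linked guide for the exact values. Only the claim name matters here, because the Job manifest below references `parallelstore-pvc`.

    # Hypothetical PVC sketch. The StorageClass name and capacity are
    # placeholders; see the linked PVC guide for the values your cluster needs.
    apiVersion: v1
    kind: PersistentVolumeClaim
    metadata:
      name: parallelstore-pvc            # must match claimName in the Job manifest
    spec:
      accessModes:
      - ReadWriteMany                    # Parallelstore volumes support shared access
      storageClassName: parallelstore-class   # placeholder StorageClass name
      resources:
        requests:
          storage: 12Ti                  # placeholder; Parallelstore enforces minimum capacities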
[[["Easy to understand","easyToUnderstand","thumb-up"],["Solved my problem","solvedMyProblem","thumb-up"],["Other","otherUp","thumb-up"]],[["Hard to understand","hardToUnderstand","thumb-down"],["Incorrect information or sample code","incorrectInformationOrSampleCode","thumb-down"],["Missing the information/samples I need","missingTheInformationSamplesINeed","thumb-down"],["Other","otherDown","thumb-down"]],["Last updated 2025-08-25 UTC."],[],[],null,["# Train a TensorFlow model with Keras on Google Kubernetes Engine\n\nThe following section provides an example of\n[fine-tuning a BERT model](https://huggingface.co/docs/transformers/training#train-a-tensorflow-model-with-keras)\nfor sequence classification using the\n[Hugging Face transformers](https://github.com/huggingface/transformers) library\nwith TensorFlow. The dataset is downloaded into a mounted\nParallelstore-backed volume, allowing the model training to directly read data\nfrom the volume.\n\nPrerequisites\n-------------\n\n- Ensure your node has at least 8 GiB of memory available.\n- [Create a PersistentVolumeClaim requesting for a Parallelstore-backed volume](/kubernetes-engine/docs/how-to/persistent-volumes/parallelstore-csi-new-volume#pvc).\n\nSave the following YAML manifest (`parallelstore-csi-job-example.yaml`) for your model training Job. \n\n apiVersion: batch/v1\n kind: Job\n metadata:\n name: parallelstore-csi-job-example\n spec:\n template:\n metadata:\n annotations:\n gke-parallelstore/cpu-limit: \"0\"\n gke-parallelstore/memory-limit: \"0\"\n spec:\n securityContext:\n runAsUser: 1000\n runAsGroup: 100\n fsGroup: 100\n containers:\n - name: tensorflow\n image: jupyter/tensorflow-notebook@sha256:173f124f638efe870bb2b535e01a76a80a95217e66ed00751058c51c09d6d85d\n command: [\"bash\", \"-c\"]\n args:\n - |\n pip install transformers datasets\n python - \u003c\u003cEOF\n from datasets import load_dataset\n dataset = load_dataset(\"glue\", \"cola\", cache_dir='/data')\n dataset = dataset[\"train\"]\n from transformers import AutoTokenizer\n import numpy as np\n tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")\n tokenized_data = tokenizer(dataset[\"sentence\"], return_tensors=\"np\", padding=True)\n tokenized_data = dict(tokenized_data)\n labels = np.array(dataset[\"label\"])\n from transformers import TFAutoModelForSequenceClassification\n from tensorflow.keras.optimizers import Adam\n model = TFAutoModelForSequenceClassification.from_pretrained(\"bert-base-cased\")\n model.compile(optimizer=Adam(3e-5))\n model.fit(tokenized_data, labels)\n EOF\n volumeMounts:\n - name: parallelstore-volume\n mountPath: /data\n volumes:\n - name: parallelstore-volume\n persistentVolumeClaim:\n claimName: parallelstore-pvc\n restartPolicy: Never\n backoffLimit: 1\n\nApply the YAML manifest to the cluster.\n\n`kubectl apply -f parallelstore-csi-job-example.yaml`\n\nCheck your data loading and model training progress with the following command: \n\n POD_NAME=$(kubectl get pod | grep 'parallelstore-csi-job-example' | awk '{print $1}')\n kubectl logs -f $POD_NAME -c tensorflow\n\n| **Note:** The model training takes approximately five minutes to complete."]]