Downloading data from https://storage.googleapis.com/apache-beam-samples/image_captioning/Cat-with-beanie.jpg
1812110/1812110 [==============================] - 0s 0us/step
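
    # Scale the pixel values to [0, 1] and convert the image to a float32 tensor.
    # No batch dimension is added here: the model handler stacks the elements
    # it receives into a batch before it calls the model.
    img = np.array(img) / 255.0
    img_tensor = tf.cast(tf.convert_to_tensor(img), dtype=tf.float32)
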
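To check what the conversion step produces without downloading the photo, the following NumPy-only sketch applies the same scaling to a random 224×224 RGB array. The array and the `preprocess` helper are hypothetical stand-ins for `img` and the conversion cell. The sketch also stacks two examples along a new leading axis; this assumes that the TensorFlow model handlers build a batch this way before calling the model, so confirm that detail against the Beam version you use.

```python
import numpy as np

IMAGE_RES = 224


def preprocess(pixels: np.ndarray) -> np.ndarray:
    """Hypothetical helper: scale uint8 RGB pixels to [0, 1] as float32."""
    return (pixels / 255.0).astype(np.float32)


# Hypothetical stand-in for the resized photo: a random 224x224 RGB image.
rng = np.random.default_rng(seed=0)
fake_image = rng.integers(0, 256, size=(IMAGE_RES, IMAGE_RES, 3), dtype=np.uint8)

example = preprocess(fake_image)
print(example.shape, example.dtype)  # (224, 224, 3) float32

# Assumed batching behavior: stacking examples along a new leading axis gives
# the [batch, height, width, channels] shape that the classifier consumes.
batch = np.stack([example, example], axis=0)
print(batch.shape)  # (2, 224, 224, 3)
```

Each element passed to `beam.Create` therefore stays a single 224×224×3 image, and the batch axis appears only when elements are grouped.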

    from typing import Iterable

    import apache_beam as beam
    import numpy as np
    import tensorflow as tf
    from apache_beam.ml.inference.base import PredictionResult
    from apache_beam.ml.inference.base import RunInference
    from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerTensor

    # CLASSIFIER_URL and img_tensor are defined in the previous cells.
    model_handler = TFModelHandlerTensor(model_uri=CLASSIFIER_URL)


    class PostProcessor(beam.DoFn):
        """Maps each PredictionResult to its highest-scoring ImageNet label."""

        def setup(self):
            # Download the ImageNet label file once per DoFn instance.
            labels_path = tf.keras.utils.get_file(
                'ImageNetLabels.txt',
                'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt'
            )
            with open(labels_path) as f:
                self._imagenet_labels = np.array(f.read().splitlines())

        def process(self, element: PredictionResult) -> Iterable[str]:
            # element.inference holds the model's scores for a single image.
            predicted_class = np.argmax(element.inference)
            predicted_class_name = self._imagenet_labels[predicted_class]
            yield "Predicted Label: {}".format(predicted_class_name.title())


    with beam.Pipeline() as p:
        _ = (p
             | "Create PCollection" >> beam.Create([img_tensor])
             | "Perform inference" >> RunInference(model_handler)
             | "Post Processing" >> beam.ParDo(PostProcessor())
             | "Print" >> beam.Map(print))

Predicted Label: Tiger Cat
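`PostProcessor` keeps only the single highest-scoring class. If you want the few most likely labels instead, you can extend the same lookup to a top-k ranking. The sketch below is a hypothetical helper, not part of Apache Beam. It assumes that `element.inference` holds unnormalized logits aligned with the label list, converts them to probabilities with a softmax, and uses a toy four-entry label list in place of `ImageNetLabels.txt`.

```python
import numpy as np


def top_k_labels(logits, labels, k=3):
    """Hypothetical helper: return the k most likely (label, probability) pairs.

    Assumes `logits` is a 1-D sequence of unnormalized scores aligned with `labels`.
    """
    logits = np.asarray(logits, dtype=np.float64)
    # Numerically stable softmax: subtract the max before exponentiating.
    exp = np.exp(logits - logits.max())
    probs = exp / exp.sum()
    # argsort sorts ascending, so reverse it and keep the first k indices.
    top = np.argsort(probs)[::-1][:k]
    return [(str(labels[i]).title(), float(probs[i])) for i in top]


# Toy example with a hypothetical 4-class label list.
toy_labels = ["background", "tabby", "tiger cat", "egyptian cat"]
toy_logits = [0.1, 2.0, 3.5, 1.0]
for name, p in top_k_labels(toy_logits, toy_labels, k=2):
    print(f"{name}: {p:.3f}")
```

Inside `PostProcessor.process`, you could call `top_k_labels(np.asarray(element.inference), self._imagenet_labels)` and yield one line per pair. The softmax only rescales the scores, so the top-ranked label is the same one that `np.argmax` returns.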
Except as otherwise noted, the content of this page is licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2025-04-30 UTC.