Use Gemma open models with Dataflow

Gemma is a family of lightweight, state-of-the-art open models built from the research and technology used to create the Gemini models. You can use Gemma models in your Apache Beam inference pipelines. Gemma models are open weight models. The term open weight means that a model's pretrained parameters, or weights, are released, while details such as the original dataset, model architecture, and training code aren't provided.

Use cases

You can use Gemma models with Dataflow for sentiment analysis. With Dataflow and the Gemma models, you can process events, such as customer reviews, as they arrive: run each review through the model to analyze its sentiment, and then generate recommendations. By combining Gemma with the Apache Beam RunInference transform, you can run this entire workflow, from ingesting events to producing model output, in a single pipeline.
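
As a minimal sketch of the first step in that workflow, the following hypothetical helper turns raw customer reviews into prompts that follow the same template as the prompt in the pipeline example on this page. The function names and the prompt wording are assumptions for illustration; they aren't part of Gemma or Apache Beam.

```python
def build_sentiment_prompt(review: str) -> str:
    """Wraps one customer review in a sentiment-analysis prompt.

    The template mirrors the example prompt used with RunInference; it is an
    assumption, so tune the wording for your model and use case.
    """
    # Collapse runs of whitespace, and swap straight apostrophes for a
    # typographic one so they don't clash with the quotes in the template.
    cleaned = " ".join(review.split()).replace("'", "\u2019")
    return f"Tell me the sentiment of the phrase '{cleaned}': "


def build_prompts(reviews):
    """Builds one prompt per non-empty review, preserving input order."""
    return [build_sentiment_prompt(r) for r in reviews if r.strip()]


reviews = ["I like pizza", "  The delivery was   late again. ", ""]
prompts = build_prompts(reviews)
for prompt in prompts:
    print(prompt)
```

To send these prompts to the model in a batch pipeline, you could wrap the list in a NumPy array, for example `np.array(prompts)`, and pass it to `beam.Create`. In a streaming pipeline, you would instead apply `build_sentiment_prompt` with `beam.Map` to reviews read from your streaming source.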

Support and limitations

You can use Gemma open models with Apache Beam and Dataflow if your pipeline meets the following requirements:

  • The pipeline can be a batch or streaming pipeline, and it must use the Apache Beam Python SDK version 2.46.0 or later.
  • Dataflow jobs must use Runner v2.
  • Dataflow jobs must use GPUs. For a list of GPU types supported with Dataflow, see Availability. The T4 and L4 GPU types are recommended.
  • The model must be downloaded and saved in the .keras file format.
  • The TensorFlow model handler is recommended but not required.
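
The requirements above translate into a minimum Beam version and a handful of Dataflow pipeline options. The following sketch checks a version string and assembles the options as a list of command-line flags that you could pass to `PipelineOptions(flags)` from `apache_beam.options.pipeline_options`; in a real job, you might check `apache_beam.__version__`. The helper names, the example machine types, and all project, bucket, and image values are assumptions, so confirm the GPU and machine type combinations available in your region in the Dataflow GPU documentation.

```python
import re

# Minimum Apache Beam Python SDK version listed in the requirements.
MIN_BEAM_VERSION = (2, 46, 0)

# Example machine types (assumptions): L4 GPUs attach to G2 machine types,
# and T4 GPUs attach to N1 machine types.
MACHINE_TYPE_FOR_GPU = {
    "nvidia-l4": "g2-standard-4",
    "nvidia-tesla-t4": "n1-standard-4",
}


def beam_version_supported(version: str) -> bool:
    """Returns True if a Beam version string is 2.46.0 or later."""
    # Compare only the leading MAJOR.MINOR.PATCH numbers: "2.54.0rc1" -> (2, 54, 0).
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.groups()) >= MIN_BEAM_VERSION


def dataflow_gpu_flags(project, region, temp_location, container_image,
                       gpu_type="nvidia-l4", gpu_count=1):
    """Builds Dataflow pipeline options, as command-line flags, for a GPU job."""
    if gpu_type not in MACHINE_TYPE_FOR_GPU:
        raise ValueError(f"GPU type not covered by this sketch: {gpu_type}")
    accelerator = (f"worker_accelerator=type:{gpu_type};"
                   f"count:{gpu_count};install-nvidia-driver")
    return [
        "--runner=DataflowRunner",
        f"--project={project}",
        f"--region={region}",
        f"--temp_location={temp_location}",
        # Makes the Runner v2 requirement explicit.
        "--experiments=use_runner_v2",
        # Custom container image that includes the GPU dependencies.
        f"--sdk_container_image={container_image}",
        f"--machine_type={MACHINE_TYPE_FOR_GPU[gpu_type]}",
        # Attaches GPUs to each worker and installs the NVIDIA driver.
        f"--dataflow_service_options={accelerator}",
    ]


# Placeholder values; replace them with your own project, bucket, and image.
flags = dataflow_gpu_flags(
    project="my-project",
    region="us-central1",
    temp_location="gs://my-bucket/tmp",
    container_image="us-central1-docker.pkg.dev/my-project/my-repo/gemma-beam:latest",
)
print(beam_version_supported("2.46.0"))  # True
print(beam_version_supported("2.45.0"))  # False
for flag in flags:
    print(flag)
```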

Prerequisites

  • Access Gemma models through Kaggle.
  • Complete the consent form and accept the terms and conditions.
  • Download the Gemma model. Save it in the .keras file format in a location that your Dataflow job can access, such as a Cloud Storage bucket. When you specify a value for the model path variable, use the path to this storage location.
  • To run your job on Dataflow, create a custom container image. This step makes it possible to run the pipeline with GPUs on the Dataflow service.
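
Because the job loads the model from the path that you give it, checking that path before you launch can catch mistakes such as a missing .keras extension. The following hypothetical helper is a minimal sketch that assumes the model is stored in Cloud Storage; it only inspects the string and doesn't verify that the object exists or that the job can read it.

```python
from urllib.parse import urlparse


def validate_model_path(model_path: str) -> str:
    """Checks that a path looks like gs://BUCKET/.../FILENAME.keras.

    Only the string is inspected; the object's existence is not checked.
    """
    parsed = urlparse(model_path)
    if parsed.scheme != "gs" or not parsed.netloc:
        raise ValueError(f"Expected a gs://BUCKET/... path, got {model_path!r}")
    # The last path segment must be a non-empty name ending in .keras.
    filename = parsed.path.rsplit("/", 1)[-1]
    if not filename.endswith(".keras") or len(filename) <= len(".keras"):
        raise ValueError(f"Expected a .keras file, got {model_path!r}")
    return model_path


print(validate_model_path("gs://my-bucket/models/gemma_2b.keras"))
try:
    validate_model_path("gs://my-bucket/models/gemma_2b.h5")
except ValueError as err:
    print(err)
```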

Use Gemma in your pipeline

To use a Gemma model in your Apache Beam pipeline, follow these steps.

  1. In your Apache Beam code, after you import your pipeline dependencies, include a path to your saved model:

    model_path = "MODEL_PATH"

    Replace MODEL_PATH with the path where you saved the downloaded model. For example, if you save your model to a Cloud Storage bucket, the path has the format gs://STORAGE_PATH/FILENAME.keras.

  2. The Keras implementation of the Gemma models has a generate() method that generates text based on a prompt. To pass elements to the generate() method, use a custom inference function.

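    import numpy as np
    from apache_beam.ml.inference import utils

    def gemma_inference_function(model, batch, inference_args, model_id):
      # Stack the batch of string prompts into a single NumPy array.
      vectorized_batch = np.stack(batch, axis=0)
      # The only inference_arg expected here is max_length, which caps the
      # length, in tokens, of the generated sequence.
      predictions = model.generate(vectorized_batch, **inference_args)
      # Pair each prompt with its generated text as PredictionResult objects.
      return utils._convert_to_result(batch, predictions, model_id)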
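
  3. Run your pipeline, specifying the path to the saved model. This example uses a TensorFlow model handler. To run the job on Dataflow rather than locally, pass your Dataflow pipeline options to beam.Pipeline().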

    import apache_beam as beam
    import numpy as np
    from apache_beam.ml.inference.base import RunInference
    from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerNumpy

    class FormatOutput(beam.DoFn):
      def process(self, element, *args, **kwargs):
        # Each element is a PredictionResult that holds the prompt and the model output.
        yield "Input: {input}, Output: {output}".format(input=element.example, output=element.inference)

    # Instantiate a NumPy array of string prompts for the model.
    examples = np.array(["Tell me the sentiment of the phrase 'I like pizza': "])
    # Specify the model handler, providing a path and the custom inference function.
    model_handler = TFModelHandlerNumpy(model_path, inference_fn=gemma_inference_function)
    with beam.Pipeline() as p:
      _ = (p | beam.Create(examples)  # Create a PCollection of the prompts.
             | RunInference(model_handler, inference_args={'max_length': 32})  # Send the prompts to the model and get responses.
             | beam.ParDo(FormatOutput())  # Format the output.
             | beam.Map(print))  # Print the formatted output.
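
The Keras generate() method typically returns the prompt followed by the generated continuation, so element.inference usually repeats the input. If you only want the model's answer, a formatting function like the following hypothetical sketch can strip the echoed prompt; you could apply it with beam.Map(format_completion) in place of the FormatOutput step. It assumes that each element exposes example and inference attributes, as the PredictionResult objects in the example do. The stand-in result in the usage lines is placeholder text, not real Gemma output.

```python
from types import SimpleNamespace


def strip_prompt(prompt, generated):
    """Returns only the text that the model added after the prompt."""
    prompt, generated = str(prompt).rstrip(), str(generated)
    # Detokenized output may not preserve trailing spaces, so match the
    # prompt without them. If there is no echo, keep the output unchanged.
    if generated.startswith(prompt):
        return generated[len(prompt):].strip()
    return generated.strip()


def format_completion(element):
    """Formats a result object that has `example` and `inference` attributes."""
    answer = strip_prompt(element.example, element.inference)
    return f"Input: {str(element.example).strip()}, Output: {answer}"


# Local check with a stand-in result object (placeholder text, not model output).
result = SimpleNamespace(
    example="Tell me the sentiment of the phrase 'I like pizza': ",
    inference="Tell me the sentiment of the phrase 'I like pizza': Positive",
)
print(format_completion(result))
```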

What's next