Salvataggio dei modelli TensorFlow per AI Explanations

Questa pagina spiega come salvare un modello TensorFlow da utilizzare con le AI Explanations, qualunque sia la versione di TensorFlow in uso, 2.x o 1.15.

TensorFlow 2

Se utilizzi TensorFlow 2.x, usa tf.saved_model.save per salvare il modello.

Un'opzione comune per ottimizzare i modelli TensorFlow salvati è che gli utenti forniscano le firme. Puoi specificare le firme di input quando salvi il modello. Se hai una sola firma di input, le AI Explanations utilizzano automaticamente la funzione di pubblicazione predefinita per le richieste di spiegazioni, seguendo il comportamento predefinito di tf.saved_model.save. Scopri di più su come specificare le firme di pubblicazione in TensorFlow.

Più firme di input

Se il tuo modello ha più di una firma di input, le spiegazioni dell'IA non possono determinare automaticamente quale definizione di firma utilizzare per recuperare una previsione dal modello. Pertanto, devi specificare quale definizione di firma vuoi che venga utilizzata da AI Explanations. Quando salvi il modello, specifica la firma della funzione predefinita di pubblicazione in una chiave univoca, xai-model:

tf.saved_model.save(m, model_dir, signatures={
    'serving_default': serving_fn,
    'xai_model': my_signature_default_fn # Required for AI Explanations
    })

In questo caso, le AI Explanations utilizzano la firma della funzione del modello che hai fornito con la chiave xai_model per interagire con il modello e generare spiegazioni. Utilizza la stringa esatta xai_model per la chiave. Per ulteriori informazioni, consulta questa panoramica delle definizioni delle firme.

Funzioni di pre-elaborazione

Se utilizzi una funzione di pre-elaborazione, devi specificare le firme per la funzione di pre-elaborazione e la funzione del modello quando salvi il modello. Utilizza la chiave xai_preprocess per specificare la funzione di pre-elaborazione:

tf.saved_model.save(m, model_dir, signatures={
    'serving_default': serving_fn,
    'xai_preprocess': preprocess_fn, # Required for AI Explanations
    'xai_model': model_fn # Required for AI Explanations
    })

In questo caso, AI Explanations utilizza la funzione di preelaborazione e la funzione del modello per le richieste di spiegazione. Assicurati che l'output della funzione di preelaborazione corrisponda all'input previsto dalla funzione del modello.

Prova i notebook di esempio completi di TensorFlow 2:

TensorFlow 1.15

Se utilizzi TensorFlow 1.15, non utilizzare tf.saved_model.save. Questa funzione non è supportata con le spiegazioni dell'IA quando utilizzi TensorFlow 1. Utilizza invece tf.estimator.export_savedmodel in combinazione con un valore tf.estimator.export.ServingInputReceiver appropriato.

Modelli creati con Keras

Se crei e addestri il modello in Keras, devi convertirlo in un estimatore TensorFlow e poi esportarlo in un modello SavedModel. Questa sezione si concentra sul salvataggio di un modello. Per un esempio completo funzionante, consulta i notebook di esempio:

Dopo aver creato, compilato, addestrato e valutato il modello Keras, devi eseguire i seguenti passaggi:

  • Converti il modello Keras in un estimatore TensorFlow utilizzando tf.keras.estimator.model_to_estimator
  • Fornisci una funzione input di pubblicazione utilizzando tf.estimator.export.build_raw_serving_input_receiver_fn
  • Esporta il modello come SavedModel utilizzando tf.estimator.export_saved_model.
# Build, compile, train, and evaluate your Keras model
model = tf.keras.Sequential(...)
model.compile(...)
model.fit(...)
model.predict(...)

## Convert your Keras model to an Estimator
keras_estimator = tf.keras.estimator.model_to_estimator(keras_model=model, model_dir='export')

## Define a serving input function appropriate for your model
def serving_input_receiver_fn():
  ...
  return tf.estimator.export.ServingInputReceiver(...)

## Export the SavedModel to Cloud Storage using your serving input function
export_path = keras_estimator.export_saved_model(
    'gs://' + 'YOUR_BUCKET_NAME', serving_input_receiver_fn).decode('utf-8')

print("Model exported to: ", export_path)

Passaggi successivi