How to convert trained Keras model to a single TensorFlow .pb file and make prediction



You are going to learn step by step how to freeze and convert your trained Keras model into a single TensorFlow .pb file.

When compared to TensorFlow, the Keras API might look less daunting and easier to work with, especially when you are doing quick experiments and building a model with standard layers. TensorFlow, on the other hand, is more versatile when you plan to deploy your model to different platforms across different programming languages. While there are many ways to convert a Keras model to its TensorFlow counterpart, I am going to show you one of the easiest, for when all you want is to make predictions with the converted model in deployment.

Here is an overview of what will be covered.

  • Keras to a single TensorFlow .pb file.
  • Load the .pb file with TensorFlow and make predictions.
  • (Optional) Visualize the graph in a Jupyter notebook.

Source code for this post is available on my GitHub.

Keras to TensorFlow .pb file

When you have trained a Keras model, it is good practice to save it as a single HDF5 file first so you can load it back later.

import os
os.makedirs('./model', exist_ok=True)
# Save architecture, weights and optimizer state to one HDF5 file.
model.save('./model/keras_model.h5')

If you run into the "incompatible with expected resource" issue with a model containing BatchNormalization layers, such as DenseNet, make sure to set the learning phase to 0 before loading the Keras model in a new session.

from keras import backend as K
# This line must be executed before loading Keras model.
# 0 = test/inference mode, 1 = training mode.
K.set_learning_phase(0)
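Learning phase 0 means inference (test) mode and 1 means training mode. For BatchNormalization the two modes compute different things, which a toy, framework-free sketch can show. The batch_norm function and its numbers are hypothetical and only mirror the standard behaviour of batch normalization; this is not Keras' implementation.

```python
import math

def batch_norm(batch, training, moving_mean, moving_var, eps=1e-3):
    """Toy batch normalization of a 1-D batch of scalars (not Keras code).

    training=True  (learning phase 1): use the statistics of this batch.
    training=False (learning phase 0): use the stored moving statistics.
    """
    if training:
        mean = sum(batch) / len(batch)
        var = sum((x - mean) ** 2 for x in batch) / len(batch)
    else:
        mean, var = moving_mean, moving_var
    return [(x - mean) / math.sqrt(var + eps) for x in batch]

# Moving statistics "learned" during training (made-up values).
moving_mean, moving_var = 0.5, 4.0

# Inference mode: a sample's output does not depend on the rest of the batch.
alone = batch_norm([2.0], training=False, moving_mean=moving_mean, moving_var=moving_var)
paired = batch_norm([2.0, 10.0], training=False, moving_mean=moving_mean, moving_var=moving_var)
print(alone[0] == paired[0])  # True

# Training mode: the same sample is normalized with this batch's statistics.
train_paired = batch_norm([2.0, 10.0], training=True, moving_mean=moving_mean, moving_var=moving_var)
print(train_paired[0] == alone[0])  # False
```

In inference mode each sample is normalized with the stored moving statistics, so the graph you later freeze gives the same answer for a sample regardless of what else is in the batch.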

Then you can load the model and find the names of its input and output tensors.

from keras.models import load_model
model = load_model('./model/keras_model.h5')
print(model.outputs)
# [<tf.Tensor 'dense_2/Softmax:0' shape=(?, 10) dtype=float32>]
print(model.inputs)
# [<tf.Tensor 'conv2d_1_input:0' shape=(?, 28, 28, 1) dtype=float32>]

As you can see, our simple model has only a single input and a single output, but your model might have multiple inputs/outputs.

We keep track of their names since we are going to locate them by name in the converted TensorFlow graph during inference.
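The names printed above are tensor names of the form <op name>:<output index>. The freezing step below passes node (op) names, without the ":0" suffix, to convert_variables_to_constants, while get_tensor_by_name at inference time needs the full tensor name. A small hypothetical helper (not part of Keras or TensorFlow) makes the relationship explicit:

```python
def split_tensor_name(tensor_name):
    """Split 'op_name:index' into ('op_name', index). Hypothetical helper."""
    op_name, sep, index = tensor_name.rpartition(":")
    if not sep or not index.isdigit():
        # No ':<index>' suffix: the whole string is already an op name.
        return tensor_name, 0
    return op_name, int(index)

# Names printed for the example model above.
output_op, output_index = split_tensor_name("dense_2/Softmax:0")
input_op, input_index = split_tensor_name("conv2d_1_input:0")
print(output_op, output_index)  # dense_2/Softmax 0
print(input_op, input_index)    # conv2d_1_input 0
```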

The first step is to get the computation graph of the TensorFlow backend that represents the Keras model, which includes both the forward pass and the training-related operations.

Then the graph is converted to a GraphDef protocol buffer, its variables are replaced by constants holding their current values, and it is pruned so that subgraphs not needed to compute the requested outputs, such as the training operations, are removed. This step is referred to as freezing the graph.
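To make the two halves of freezing concrete, here is a toy sketch on a dict-based graph: it walks backwards from the requested outputs to find the nodes they depend on, then copies those nodes and swaps each variable for a constant holding its current value. The graph format, node names and freeze_toy_graph are hypothetical; TensorFlow's convert_variables_to_constants does the real work on a GraphDef protocol buffer.

```python
def freeze_toy_graph(nodes, variable_values, output_names):
    """Toy freeze (not TensorFlow code).

    nodes: {name: {"op": str, "inputs": [input node names]}}
    variable_values: {variable node name: its current value}
    """
    # 1. Prune: collect every node the requested outputs transitively need.
    needed, stack = set(), list(output_names)
    while stack:
        name = stack.pop()
        if name in needed:
            continue
        needed.add(name)
        stack.extend(nodes[name]["inputs"])
    # 2. Copy the needed nodes, replacing variables with constants.
    frozen = {}
    for name in needed:
        node = nodes[name]
        if node["op"] == "Variable":
            frozen[name] = {"op": "Const", "inputs": [], "value": variable_values[name]}
        else:
            frozen[name] = {"op": node["op"], "inputs": list(node["inputs"])}
    return frozen

# Made-up graph: a forward pass plus loss and training nodes.
graph = {
    "input":    {"op": "Placeholder",   "inputs": []},
    "kernel":   {"op": "Variable",      "inputs": []},
    "matmul":   {"op": "MatMul",        "inputs": ["input", "kernel"]},
    "softmax":  {"op": "Softmax",       "inputs": ["matmul"]},
    "labels":   {"op": "Placeholder",   "inputs": []},
    "loss":     {"op": "Loss",          "inputs": ["softmax", "labels"]},
    "train_op": {"op": "ApplyGradient", "inputs": ["loss", "kernel"]},
}
frozen = freeze_toy_graph(graph, {"kernel": [[0.1, 0.2]]}, ["softmax"])
print(sorted(frozen))          # ['input', 'kernel', 'matmul', 'softmax']
print(frozen["kernel"]["op"])  # Const
```

The loss and training nodes disappear because nothing on the path to softmax depends on them.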

from keras import backend as K
import tensorflow as tf

def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """
    Freezes the state of a session into a pruned computation graph.

    Creates a new computation graph where variable nodes are replaced by
    constants taking their current value in the session. The new graph will be
    pruned so subgraphs that are not necessary to compute the requested
    outputs are removed.
    @param session The TensorFlow session to be frozen.
    @param keep_var_names A list of variable names that should not be frozen,
                          or None to freeze all the variables in the graph.
    @param output_names Names of the relevant graph outputs.
    @param clear_devices Remove the device directives from the graph for better portability.
    @return The frozen graph definition.
    """
    from tensorflow.python.framework.graph_util import convert_variables_to_constants
    graph = session.graph
    with graph.as_default():
        # All variable op names, minus the ones the caller wants to keep.
        freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
        # Copy so the caller's list is not mutated.
        output_names = list(output_names or [])
        # Also keep the (now constant) variable nodes in the pruned graph.
        output_names += [v.op.name for v in tf.global_variables()]
        # Graph -> GraphDef ProtoBuf
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        # Replace variables with constants and prune unneeded subgraphs.
        frozen_graph = convert_variables_to_constants(session, input_graph_def,
                                                      output_names, freeze_var_names)
        return frozen_graph

# Pass op names (no ':0' suffix) of the model's outputs.
frozen_graph = freeze_session(K.get_session(),
                              output_names=[out.op.name for out in model.outputs])

The frozen_graph is a GraphDef protocol buffer, which we can save as a single binary .pb file with the following function call.

# Save to ./model/tf_model.pb
tf.train.write_graph(frozen_graph, "model", "tf_model.pb", as_text=False)

Load .pb file and make predictions

Now we have everything we need to make predictions with the graph saved as a single .pb file.

To load it back, start a new session, either by restarting the Jupyter Notebook kernel or by running a new Python script.

The following lines deserialize the GraphDef from the .pb file and import it into the default graph of the running TensorFlow session.

import tensorflow as tf
from tensorflow.python.platform import gfile

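f = gfile.FastGFile("./model/tf_model.pb", 'rb')
graph_def = tf.GraphDef()
# Parses a serialized binary message into the current message.
graph_def.ParseFromString(f.read())
f.close()

sess = tf.Session()
with sess.graph.as_default():
    # Import a serialized TensorFlow `GraphDef` protocol buffer
    # and place into the current default `Graph`.
    tf.import_graph_def(graph_def)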

We locate the input tensor, so we can feed it some input data, and the output tensor, so we can grab the predictions, by name. The only difference from the names we recorded earlier is that all tensor names are now prefixed with the string "import/", the default name scope of tf.import_graph_def, so the input tensor is now named "import/conv2d_1_input:0" and the output tensor is "import/dense_2/Softmax:0".
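The renaming is mechanical, so a hypothetical helper (not a TensorFlow API) can derive the new names from the ones recorded from model.inputs and model.outputs. It assumes the default scope "import"; passing name="" to tf.import_graph_def keeps the original names instead.

```python
def imported_tensor_name(tensor_name, scope="import"):
    """Name a tensor gets after tf.import_graph_def(..., name=scope). Hypothetical helper."""
    if not scope:
        # With name="" the imported nodes keep their original names.
        return tensor_name
    return scope + "/" + tensor_name

print(imported_tensor_name("conv2d_1_input:0"))   # import/conv2d_1_input:0
print(imported_tensor_name("dense_2/Softmax:0"))  # import/dense_2/Softmax:0
```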

Making a prediction can be as simple as:

softmax_tensor = sess.graph.get_tensor_by_name('import/dense_2/Softmax:0')
# Feed the first 20 test samples and fetch the softmax output.
predictions = sess.run(softmax_tensor, {'import/conv2d_1_input:0': x_test[:20]})

In this case, predictions holds the softmax output with shape (20, 10): 20 samples, each with a probability for each of the 10 classes.
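To turn those probabilities into class labels, take the index of the largest value in each row. The sketch below does this in plain Python on made-up rows, assuming the array has been converted with predictions.tolist(); with NumPy, predictions.argmax(axis=1) gives the same result directly.

```python
def to_class_labels(predictions):
    """Index of the highest probability in each row (argmax; ties pick the first)."""
    return [max(range(len(row)), key=row.__getitem__) for row in predictions]

# Two made-up 10-class softmax rows.
fake_predictions = [
    [0.05, 0.90, 0.05] + [0.0] * 7,
    [0.0] * 9 + [1.0],
]
print(to_class_labels(fake_predictions))  # [1, 9]
```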

If your model has multiple inputs/outputs, you can do something like this.

predicts_1, predicts_2 = sess.run([output_tensor1, output_tensor2], {
    'import/input0:0': x_1[:20], 'import/input1:0': x_2[:20]})

Visualize the graph in Notebook (optional)

Do you wonder what the model freezing step has done to your model, for example which operations have been removed?

Let's compare those two graphs side by side in the Jupyter Notebook by loading a minimal version of TensorBoard to show the graph structures.

I have included a module that lets you do so.

from show_graph import show_graph
import tensorflow as tf
# Show current session graph with TensorBoard in Jupyter Notebook.
show_graph(tf.get_default_graph().as_graph_def())

You can run this block twice, once after training/loading the Keras model and once after loading and restoring the .pb file. Here are the results.

Train graph: 


Frozen graph:


You can easily see that the training-related operations have been removed from the frozen graph.
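If you prefer a textual check over the TensorBoard view, you can compare the op types of the two graphs. This sketch assumes you have collected (node name, op type) pairs from each GraphDef, for example with [(n.name, n.op) for n in graph_def.node]; the node lists below are made up. Op types are compared rather than names because the imported graph's names carry the "import/" prefix.

```python
from collections import Counter

def removed_op_types(train_nodes, frozen_nodes):
    """Count op types that appear fewer times after freezing."""
    before = Counter(op for _, op in train_nodes)
    after = Counter(op for _, op in frozen_nodes)
    # Counter subtraction drops zero and negative counts.
    return before - after

# Made-up (name, op) pairs standing in for the two GraphDefs.
train = [("input", "Placeholder"), ("kernel", "VariableV2"),
         ("matmul", "MatMul"), ("softmax", "Softmax"), ("update", "ApplyAdam")]
frozen = [("import/input", "Placeholder"), ("import/kernel", "Const"),
          ("import/matmul", "MatMul"), ("import/softmax", "Softmax")]
removed = removed_op_types(train, frozen)
print(sorted(removed))  # ['ApplyAdam', 'VariableV2']
```

Because Counter subtraction keeps only positive counts, op types whose count grew after freezing (such as Const, which replaces the variables) are not reported.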

Conclusion and further reading

You have learned how to convert your Keras model into a TensorFlow .pb file for inference purposes only. Be sure to check out the source code for this post on my GitHub.

Here are some related resources you might find helpful.

A Tool Developer's Guide to TensorFlow Model Files

Exporting and Importing a MetaGraph
