You are going to learn, step by step, how to freeze and convert your trained Keras model into a single TensorFlow .pb file.
Compared to TensorFlow, the Keras API might look less daunting and easier to work with, especially when you are doing quick experiments and building a model with standard layers. TensorFlow, on the other hand, is more versatile when you plan to deploy your model to different platforms across different programming languages. While there are many ways to convert a Keras model to its TensorFlow counterpart, I am going to show you one of the easiest, for when all you want is to make predictions with the converted model in deployment situations.
Here is an overview of what will be covered:
Save the trained Keras model as a single HDF5 file.
Freeze the TensorFlow graph and save it as a single .pb file.
Load the .pb file in a new session and make predictions.
Compare the graphs before and after freezing with TensorBoard.
When you have trained a Keras model, it is a good practice to save it as a single HDF5 file first so you can load it back later after training.
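For context, here is a minimal sketch of the kind of model this post assumes: a small MNIST-style CNN whose layer names line up with the tensor names printed later ('conv2d_1_input', 'dense_2/Softmax'). The exact architecture is my stand-in, not necessarily the one in the post's source code.
# A stand-in MNIST-style CNN (hypothetical; substitute your own trained model).
from keras.models import Sequential
from keras.layers import Conv2D, Flatten, Dense

model = Sequential([
    Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    Flatten(),
    Dense(128, activation='relu'),
    Dense(10, activation='softmax'),
])
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])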
import os
# Create the target directory if needed, then save the whole model
# (architecture + weights + optimizer state) as one HDF5 file.
os.makedirs('./model', exist_ok=True)
model.save('./model/keras_model.h5')
In case you run into the "incompatible with expected resource" issue with a model containing BatchNormalization layers, such as DenseNet, make sure to set the learning phase to 0 before loading the Keras model in a new session.
from keras import backend as K
# This line must be executed before loading Keras model.
K.set_learning_phase(0)
Then you can load up the model and find the model's input and output tensors' names.
from keras.models import load_model
model = load_model('./model/keras_model.h5')
print(model.outputs)
# [<tf.Tensor 'dense_2/Softmax:0' shape=(?, 10) dtype=float32>]
print(model.inputs)
# [<tf.Tensor 'conv2d_1_input:0' shape=(?, 28, 28, 1) dtype=float32>]
As you can see, our simple model has only a single input and a single output; your model might have multiple inputs/outputs.
We keep track of their names since we are going to locate them by name in the converted TensorFlow graph during inference.
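If you prefer variables over copy-pasting strings, a tiny snippet like this (my addition) records them:
# Keep the graph-level tensor names around for later lookup by name.
input_names = [t.name for t in model.inputs]    # e.g. ['conv2d_1_input:0']
output_names = [t.name for t in model.outputs]  # e.g. ['dense_2/Softmax:0']
print(input_names, output_names)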
The first step is to get hold of the TensorFlow backend's computation graph that represents the Keras model; this graph includes both the forward pass and the training-related operations. The graph will then be converted to a GraphDef protocol buffer and pruned, so subgraphs that are not necessary to compute the requested outputs, such as the training operations, are removed. This step is referred to as freezing the graph.
from keras import backend as K
import tensorflow as tf
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """
    Freezes the state of a session into a pruned computation graph.

    Creates a new computation graph where variable nodes are replaced by
    constants taking their current value in the session. The new graph will be
    pruned so subgraphs that are not necessary to compute the requested
    outputs are removed.
    @param session The TensorFlow session to be frozen.
    @param keep_var_names A list of variable names that should not be frozen,
                          or None to freeze all the variables in the graph.
    @param output_names Names of the relevant graph outputs.
    @param clear_devices Remove the device directives from the graph for better portability.
    @return The frozen graph definition.
    """
    from tensorflow.python.framework.graph_util import convert_variables_to_constants
    graph = session.graph
    with graph.as_default():
        freeze_var_names = list(set(v.op.name for v in tf.global_variables())
                                .difference(keep_var_names or []))
        output_names = output_names or []
        output_names += [v.op.name for v in tf.global_variables()]
        # Graph -> GraphDef ProtoBuf
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = convert_variables_to_constants(session, input_graph_def,
                                                      output_names, freeze_var_names)
        return frozen_graph
frozen_graph = freeze_session(K.get_session(),
                              output_names=[out.op.name for out in model.outputs])
# Save to ./model/tf_model.pb
tf.train.write_graph(frozen_graph, "model", "tf_model.pb", as_text=False)
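As a quick sanity check (my addition, not part of the original workflow), you can peek at the last few nodes of the frozen GraphDef to confirm that the requested output survived the pruning:
# The requested output node (e.g. 'dense_2/Softmax') should appear
# among the last nodes of the frozen GraphDef.
for node in frozen_graph.node[-5:]:
    print(node.name)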
Now we have everything we need to predict with the graph saved as one single .pb file.
To load it back, start a new session, either by restarting the Jupyter Notebook kernel or by running a new Python script.
The following lines deserialize the GraphDef from the .pb file and restore it as the default graph in the current session.
import tensorflow as tf
from tensorflow.python.platform import gfile
f = gfile.FastGFile("./model/tf_model.pb", 'rb')
graph_def = tf.GraphDef()
# Parses a serialized binary message into the current message.
graph_def.ParseFromString(f.read())
f.close()
sess = tf.Session()
sess.graph.as_default()
# Import a serialized TensorFlow `GraphDef` protocol buffer
# and place into the current default `Graph`.
tf.import_graph_def(graph_def)
Locate the input tensor so we can feed it with some input data, and grab the predictions from the output tensor; we are going to get both by name. The only difference is that all tensors' names are now prefixed with the string "import/", so the input tensor becomes "import/conv2d_1_input:0" and the output tensor becomes "import/dense_2/Softmax:0".
Making a prediction can then be as simple as:
softmax_tensor = sess.graph.get_tensor_by_name('import/dense_2/Softmax:0')
predictions = sess.run(softmax_tensor, {'import/conv2d_1_input:0': x_test[:20]})
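The softmax output is one probability vector per sample. To turn it into class labels, take the argmax per row; this small follow-up is my addition:
import numpy as np

# Each row of `predictions` is a 10-way probability distribution;
# argmax recovers the most likely class index for each sample.
predicted_classes = np.argmax(predictions, axis=1)
print(predicted_classes)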
If your model has multiple inputs/outputs, you can do something like this (the names input0, input1 and the two output tensor lookups below are placeholders; substitute your own model's tensor names):
output_tensor1 = sess.graph.get_tensor_by_name('import/output0:0')
output_tensor2 = sess.graph.get_tensor_by_name('import/output1:0')
predicts_1, predicts_2 = sess.run([output_tensor1, output_tensor2], {
    'import/input0:0': x_1[:20], 'import/input1:0': x_2[:20]})
Do you wonder what the model freezing step has done to your model, such as which operations have been removed?
Let's compare those two graphs side by side in the Jupyter Notebook by loading a minimal version of TensorBoard to show the graph structures.
Here I included a show_graph.py module allowing you to do so.
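In case you don't have the module handy, here is a minimal sketch of what such a helper can look like, adapted from the widely circulated TensorBoard-in-Jupyter snippet; the show_graph.py shipped with this post's source code may differ in details.
# show_graph.py -- minimal sketch, assuming TF 1.x and a Jupyter Notebook.
import numpy as np
import tensorflow as tf
from IPython.display import display, HTML

def strip_consts(graph_def, max_const_size=32):
    """Strip large constant values out of graph_def so the page stays small."""
    strip_def = tf.GraphDef()
    for n0 in graph_def.node:
        n = strip_def.node.add()
        n.MergeFrom(n0)
        if n.op == 'Const':
            tensor = n.attr['value'].tensor
            size = len(tensor.tensor_content)
            if size > max_const_size:
                tensor.tensor_content = ('<stripped %d bytes>' % size).encode('utf-8')
    return strip_def

def show_graph(graph_def, max_const_size=32):
    """Render a GraphDef with TensorBoard's graph visualizer inside the notebook."""
    if hasattr(graph_def, 'as_graph_def'):
        graph_def = graph_def.as_graph_def()
    strip_def = strip_consts(graph_def, max_const_size=max_const_size)
    code = """
        <script>
          function load() {{
            document.getElementById("{id}").pbtxt = {data};
          }}
        </script>
        <link rel="import" href="https://tensorboard.appspot.com/tf-graph-basic.build.html" onload=load()>
        <div style="height:600px">
          <tf-graph-basic id="{id}"></tf-graph-basic>
        </div>
    """.format(data=repr(str(strip_def)), id='graph' + str(np.random.rand()))
    iframe = '<iframe seamless style="width:1200px;height:620px;border:0" srcdoc="{}"></iframe>'.format(
        code.replace('"', '&quot;'))
    display(HTML(iframe))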
from show_graph import show_graph
import tensorflow as tf
# Show current session graph with TensorBoard in Jupyter Notebook.
show_graph(tf.get_default_graph().as_graph_def())
You can run this block twice, once after Keras model training/loading and once after loading and restoring the .pb file. Here are the results.
Train graph:
Frozen graph:
You can easily see that the training-related operations have been removed in the frozen graph.
You have learned how to convert your Keras model into a TensorFlow .pb file for inference purposes only. Be sure to check out the source code for this post on my GitHub.
Here are some related resources you might find helpful.
A Tool Developer's Guide to TensorFlow Model Files
Exporting and Importing a MetaGraph