How to compress your Keras model x5 smaller with TensorFlow model optimization



This tutorial demonstrates how to make your Keras model 5 times smaller with TensorFlow model optimization, which is particularly important when deploying to resource-constrained environments.

According to the official TensorFlow model optimization documentation, weight pruning means eliminating unnecessary values in weight tensors. We set some of the neural network's parameters to zero to remove what we estimate are unnecessary connections between its layers. This is done during training so that the network can adapt to the changes.
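As a minimal sketch of the idea, the snippet below zeroes the smallest-magnitude fraction of a flat list of weights. It assumes that "unnecessary" means "small in absolute value", the criterion the prune_low_magnitude API used later is named after. The helper prune_by_magnitude and the sample weights are hypothetical, and a real toolkit applies this to tensors repeatedly during training rather than once to a list.

```python
def prune_by_magnitude(weights, sparsity):
    """Hypothetical helper: zero out the `sparsity` fraction of weights with the smallest magnitude."""
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError("sparsity must be between 0 and 1")
    n_prune = int(round(len(weights) * sparsity))
    # Indices ordered from smallest to largest absolute value
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = list(weights)
    for i in order[:n_prune]:
        pruned[i] = 0.0
    return pruned


# Hypothetical example weights
weights = [0.8, -0.05, 0.3, -0.9, 0.01, 0.2, -0.4, 0.07, 0.6, -0.02]
print(prune_by_magnitude(weights, 0.5))
# → [0.8, 0.0, 0.3, -0.9, 0.0, 0.0, -0.4, 0.0, 0.6, 0.0]
```

At 50% sparsity, the five weights closest to zero are removed and the large ones survive unchanged.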

Here is a breakdown of how you can adopt this technique.

  1. Train a Keras model to reach an acceptable accuracy, as usual.
  2. Make the Keras layers or model ready to be pruned.
  3. Create a pruning schedule and train the model for more epochs.
  4. Export the pruned model by stripping the pruning wrappers from it.
  5. Convert the Keras model to TensorFlow Lite, with optional quantization.

Prune your pre-trained Keras model

Your pre-trained model has already achieved the desired accuracy, and now you want to cut down its size while maintaining its performance. The pruning API can help you do that.

To use the pruning API, install the tensorflow-model-optimization package and a nightly TensorFlow build (the GPU variant, tf-nightly-gpu, in the commands below).

pip uninstall -yq tensorflow
pip uninstall -yq tf-nightly
pip install -Uq tf-nightly-gpu
pip install -q tensorflow-model-optimization

Then load your previously trained model and make it "prunable". The Keras-based API can be applied to individual layers or to the entire model. Since you have the entire model pre-trained, it is easiest to apply pruning to the whole model; the algorithm is then applied to every layer that supports weight pruning.

For the pruning schedule, we start at a sparsity level of 50% and gradually train the model up to 90% sparsity. X% sparsity means that X% of the values in a weight tensor are pruned away, i.e. set to zero.

Furthermore, we give the model time to recover after each pruning step, so pruning does not happen on every training step: we set the pruning frequency to 100, so pruning happens once every 100 steps. It is like pruning a bonsai: we trim it gradually so the tree can heal from each cut, instead of removing 90% of its branches in one day.
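To make the schedule concrete, here is a pure-Python sketch of a polynomial ramp from 50% to 90% sparsity, with pruning applied only every 100 steps. It follows our reading of the formula documented for sparsity.PolynomialDecay (a cubic ramp by default); the helper names polynomial_decay_sparsity and should_prune, and the end_step of 1000, are hypothetical and chosen only for illustration.

```python
def polynomial_decay_sparsity(step, initial_sparsity=0.50, final_sparsity=0.90,
                              begin_step=0, end_step=1000, power=3):
    """Hypothetical helper: target sparsity at `step`, ramping from initial to final."""
    # Fraction of the schedule completed, clamped to [0, 1]
    progress = (step - begin_step) / (end_step - begin_step)
    progress = min(1.0, max(0.0, progress))
    return final_sparsity + (initial_sparsity - final_sparsity) * (1.0 - progress) ** power


def should_prune(step, begin_step=0, end_step=1000, frequency=100):
    """Hypothetical helper: prune only every `frequency` steps inside [begin_step, end_step]."""
    in_range = begin_step <= step <= end_step
    return in_range and (step - begin_step) % frequency == 0


for step in range(0, 1001, 250):
    print(step, round(polynomial_decay_sparsity(step), 4), should_prune(step))
```

The cubic shape removes most weights early, while the network still has many redundant connections, and slows down as it approaches the 90% target.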

Because the model has already reached a satisfactory accuracy, we can start pruning immediately. We therefore set begin_step to 0 and train for only four more epochs.

The end step is calculated from the number of training examples, the batch size, and the total number of epochs to train.
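The end step is simply the number of steps per epoch times the number of epochs, where a final partial batch still counts as a step. A small sketch; the helper pruning_end_step and the sample numbers (60,000 training images, batch size 128) are hypothetical.

```python
import math


def pruning_end_step(num_train_samples, batch_size, epochs):
    """Hypothetical helper: total optimizer steps over `epochs` epochs."""
    # A final partial batch still counts as one step, hence the ceiling
    steps_per_epoch = math.ceil(num_train_samples / batch_size)
    return steps_per_epoch * epochs


# Hypothetical numbers: 60,000 training images, batch size 128, 4 epochs
print(pruning_end_step(60000, 128, 4))  # → 1876
```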

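import numpy as np
import tensorflow as tf
from tensorflow_model_optimization.sparsity import keras as sparsity

# Backend agnostic way to save/restore models
# _, keras_file = tempfile.mkstemp('.h5')
# print('Saving model to: ', keras_file)
# tf.keras.models.save_model(model, keras_file, include_optimizer=False)

# Load the serialized model.
# keras_file, num_train_samples and batch_size come from your original training run.
loaded_model = tf.keras.models.load_model(keras_file)

epochs = 4
end_step = np.ceil(1.0 * num_train_samples / batch_size).astype(np.int32) * epochs

new_pruning_params = {
    'pruning_schedule': sparsity.PolynomialDecay(initial_sparsity=0.50,
                                                 final_sparsity=0.90,
                                                 begin_step=0,
                                                 end_step=end_step,
                                                 frequency=100)
}

new_pruned_model = sparsity.prune_low_magnitude(loaded_model, **new_pruning_params)
new_pruned_model.summary()

# Recompile before training; the loss and optimizer here are assumptions,
# reuse the ones from your original training run
new_pruned_model.compile(
    loss=tf.keras.losses.categorical_crossentropy,
    optimizer='adam',
    metrics=['accuracy'])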


Don't panic if the new_pruned_model summary shows more parameters than the original model. The extra ones are non-trainable variables added by the pruning wrappers, which we will remove later.
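Conceptually, a pruning wrapper keeps each prunable kernel together with a same-sized mask, which is why the parameter count grows; stripping keeps only the masked kernel. The toy class below is a hypothetical illustration of that idea (ToyPruneWrapper and its methods are made up), not the tfmot implementation.

```python
class ToyPruneWrapper:
    """Hypothetical toy wrapper: a kernel plus a same-sized, non-trainable mask."""

    def __init__(self, kernel):
        self.kernel = list(kernel)       # the layer's trainable weights
        self.mask = [1.0] * len(kernel)  # extra state added by the wrapper

    def param_count(self):
        # Kernel and mask are both stored, so the summary shows more parameters
        return len(self.kernel) + len(self.mask)

    def update_mask(self, threshold):
        # Keep only weights whose magnitude reaches the threshold
        self.mask = [1.0 if abs(w) >= threshold else 0.0 for w in self.kernel]

    def strip(self):
        # Fold the mask into the kernel and discard the wrapper's extra state
        return [w if m else 0.0 for w, m in zip(self.kernel, self.mask)]


layer = ToyPruneWrapper([0.5, -0.01, 0.2, -0.03])
print(layer.param_count())  # → 8
layer.update_mask(threshold=0.1)
stripped = layer.strip()
print(stripped, len(stripped))  # → [0.5, 0.0, 0.2, 0.0] 4
```

After stripping, the parameter count is back to the kernel's original size, but the zeros introduced by pruning remain.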

Now let's train and prune the model.

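# Add a pruning step callback to peg the pruning step to the optimizer's
# step. Also add a callback to add pruning summaries to tensorboard
# (logdir is any directory you choose for the TensorBoard logs)
callbacks = [
    sparsity.UpdatePruningStep(),
    sparsity.PruningSummaries(log_dir=logdir, profile_batch=0)
]

new_pruned_model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs,
                     callbacks=callbacks, validation_data=(x_test, y_test))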

score = new_pruned_model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

The test loss and accuracy of the pruned model should be similar to those of your original Keras model.

Export the pruned model

The pruning wrappers are easy to remove, as shown below; afterwards, the total number of parameters is the same as in your original model.

final_model = sparsity.strip_pruning(new_pruned_model)

Now you can check what percentage of the weights were pruned by counting how many are exactly zero.

import numpy as np

# Report the percentage of exactly-zero entries in each weight tensor
for i, w in enumerate(final_model.get_weights()):
    print(
        "{} -- Total:{}, Zeros: {:.2f}%".format(
            final_model.weights[i].name, w.size, np.sum(w == 0) / w.size * 100
        )
    )

Here are the results. About 89% of the weights in the convolution and dense kernels are pruned, close to the 90% target, while the biases and batch normalization parameters are left untouched.

Weight name                              Total params  Zeros (%)
conv2d_2/kernel:0                        800           89.12%
conv2d_2/bias:0                          32            0.00%
batch_normalization_1/gamma:0            32            0.00%
batch_normalization_1/beta:0             32            0.00%
batch_normalization_1/moving_mean:0      32            0.00%
batch_normalization_1/moving_variance:0  32            0.00%
conv2d_3/kernel:0                        51200         89.09%
conv2d_3/bias:0                          64            0.00%
dense_2/kernel:0                         3211264       89.09%
dense_2/bias:0                           1024          0.00%
dense_3/kernel:0                         10240         89.09%
dense_3/bias:0                           10            0.00%

Because the pruned weights are all zeros, which compress extremely well, simply running a generic file compression algorithm (e.g. zip) over the pruned Keras model makes it about 5x smaller.

import os
import tempfile
import zipfile

_, new_pruned_keras_file = tempfile.mkstemp(".h5")
print("Saving pruned model to: ", new_pruned_keras_file)
tf.keras.models.save_model(final_model, new_pruned_keras_file, include_optimizer=False)

# Zip the .h5 model file
_, zip3 = tempfile.mkstemp(".zip")
with zipfile.ZipFile(zip3, "w", compression=zipfile.ZIP_DEFLATED) as f:
    f.write(new_pruned_keras_file)

print(
    "Size of the pruned model before compression: %.2f Mb"
    % (os.path.getsize(new_pruned_keras_file) / float(2 ** 20))
)
print(
    "Size of the pruned model after compression: %.2f Mb"
    % (os.path.getsize(zip3) / float(2 ** 20))
)

Here is what you get: a model about 5x smaller.

Size of the pruned model before compression: 12.52 Mb
Size of the pruned model after compression: 2.51 Mb
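Why does a generic compressor help so much? A pruned tensor is still stored as dense float32 values, but 90% of them are identical zero bytes, which DEFLATE (the algorithm behind zipfile.ZIP_DEFLATED) encodes very cheaply. The stdlib-only sketch below compares synthetic random weights before and after zeroing 90% of them; the data is made up, so the exact ratios will differ from a real model's.

```python
import random
import struct
import zlib


def to_float32_bytes(values):
    """Serialize floats as little-endian float32, as a weight tensor is stored on disk."""
    return struct.pack("<%df" % len(values), *values)


random.seed(0)
n = 100_000
# Synthetic "weights"; a real model's distribution will differ
dense = [random.uniform(-1.0, 1.0) for _ in range(n)]

# 90% sparsity: zero every weight below the 90th-percentile magnitude
threshold = sorted(abs(w) for w in dense)[int(0.9 * n)]
sparse = [w if abs(w) >= threshold else 0.0 for w in dense]

raw = len(to_float32_bytes(dense))  # 4 bytes per weight, zero or not
dense_zipped = len(zlib.compress(to_float32_bytes(dense), 9))
sparse_zipped = len(zlib.compress(to_float32_bytes(sparse), 9))

print("raw size:         %d bytes" % raw)
print("dense, deflated:  %d bytes (%.1fx smaller)" % (dense_zipped, raw / dense_zipped))
print("sparse, deflated: %d bytes (%.1fx smaller)" % (sparse_zipped, raw / sparse_zipped))
```

The uncompressed size is identical for both versions, which is why pruning alone does not shrink the .h5 file; the savings only appear once a compressor exploits the zeros.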

Convert Keras model to TensorFlow Lite

TensorFlow Lite is one format you can use to deploy models to mobile devices. To convert the model to a TensorFlow Lite flatbuffer, use the TFLiteConverter as below:

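# Create the .tflite file
tflite_model_file = "/tmp/sparse_mnist.tflite"
# TF 2.x API. On TF 1.x, convert from the saved .h5 file instead with
# tf.lite.TFLiteConverter.from_keras_model_file(new_pruned_keras_file)
converter = tf.lite.TFLiteConverter.from_keras_model(final_model)
tflite_model = converter.convert()
with open(tflite_model_file, "wb") as f:
    f.write(tflite_model)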

You can then zip the .tflite file the same way to make it about 5x smaller as well.

Post-training quantization converts the weights to 8-bit precision as part of the conversion from the Keras model to TensorFlow Lite's flatbuffer format, shrinking the model by another 4x since each weight takes 1 byte instead of 4. Just add the following line to the previous snippet before calling convert().

# Older releases used tf.lite.Optimize.OPTIMIZE_FOR_SIZE, now deprecated in favor of DEFAULT
converter.optimizations = [tf.lite.Optimize.DEFAULT]

The compressed 8-bit TensorFlow Lite model takes only 0.60 Mb, compared with the original Keras model's 12.52 Mb, while maintaining comparable test accuracy. That is roughly a 20x size reduction in total (12.52 / 0.60 ≈ 20.9).
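The 4x factor from quantization comes from storing each weight in 1 byte instead of 4. The sketch below uses a simplified symmetric, per-tensor int8 scheme on synthetic weights; the helpers quantize_int8 and dequantize are hypothetical, and TensorFlow Lite's actual quantization (for example, how it chooses scales) may differ.

```python
import random


def quantize_int8(weights):
    """Hypothetical helper: symmetric per-tensor int8 quantization, w ~ scale * q, q in [-127, 127]."""
    max_abs = max(abs(w) for w in weights) or 1.0  # avoid dividing by zero
    scale = max_abs / 127.0
    q = [max(-127, min(127, round(w / scale))) for w in weights]
    return q, scale


def dequantize(q, scale):
    """Map int8 codes back to approximate float values."""
    return [scale * v for v in q]


random.seed(1)
weights = [random.gauss(0.0, 0.1) for _ in range(1000)]  # synthetic weights
q, scale = quantize_int8(weights)
restored = dequantize(q, scale)

max_err = max(abs(w - r) for w, r in zip(weights, restored))
print("float32: %d bytes, int8: %d bytes" % (4 * len(weights), len(q)))
print("scale = %.6f, max abs error = %.6f" % (scale, max_err))
```

Each weight loses at most half a quantization step (scale / 2) of precision, which is why accuracy usually stays comparable. Pruned zeros stay exactly zero after quantization, so the zip savings still apply.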

You can evaluate the accuracy of the converted TensorFlow Lite model by feeding the test dataset to eval_model, like this:

import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path=str(tflite_model_file))
# Tensors must be allocated before any input can be set
interpreter.allocate_tensors()
input_index = interpreter.get_input_details()[0]["index"]
output_index = interpreter.get_output_details()[0]["index"]

def eval_model(interpreter, x_test, y_test):
    total_seen = 0
    num_correct = 0

    for img, label in zip(x_test, y_test):
        # Add a batch dimension and match the model's float32 input type
        inp = img.reshape((1, 28, 28, 1)).astype(np.float32)
        total_seen += 1
        interpreter.set_tensor(input_index, inp)
        # Run inference before reading the output tensor
        interpreter.invoke()
        predictions = interpreter.get_tensor(output_index)
        if np.argmax(predictions) == np.argmax(label):
            num_correct += 1

        if total_seen % 1000 == 0:
            print("Accuracy after %i images: %f" %
                  (total_seen, float(num_correct) / float(total_seen)))

    return float(num_correct) / float(total_seen)

print(eval_model(interpreter, x_test, y_test))

Conclusion and Further reading

In this tutorial, we showed you how to create sparse models with the TensorFlow model optimization toolkit's weight pruning API. Right now, this lets you create models that take significantly less space on disk. A sparse model can also be executed more efficiently by skipping computation on the zeroed weights; in the future, TensorFlow Lite will provide such capabilities.

Check out the official TensorFlow model optimization page and their GitHub page for more information.

The source code for this post is available on my GitHub and can be run in a Google Colab notebook.
