Learn TensorFlow by Examples Series - Part 8

Transfer Learning with TensorFlow - II - Using modules from TF-Hub

We'll take a state-of-the-art model from TensorFlow Hub (https://tfhub.dev) and use it for our digit recognition task through transfer learning.
In the previous tutorial (Transfer Learning - Part I), we trained a model from scratch and reused its lower layers (transfer learning) for a similar but different task. Reusing the lower layers sped up training, but the accuracy achieved was lower than that of a similar model trained from scratch without transfer learning. Here is the accuracy curve we ended up with:

In this tutorial, instead of training a model from scratch and reusing its lower layers, we'll take MobileNetV2, a state-of-the-art model pre-trained on the ILSVRC-2012 (ImageNet Large Scale Visual Recognition Challenge) dataset, and transfer its lower layers to a model for our digit recognition task. We'll also compare its accuracy with the models we trained in the previous tutorial (Transfer Learning - Part I).
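To make "reusing the lower layers of a pre-trained network" concrete without any TensorFlow machinery, here is a toy sketch in numpy. It is an illustration under stated assumptions, not the tutorial's code: the data is synthetic, a fixed random projection followed by ReLU (W_frozen, a hypothetical stand-in) plays the role of the frozen MobileNetV2 layers, and only a new softmax layer on top (W_head, b_head) is trained with plain gradient descent.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic toy data (illustration only): 200 samples, 20 inputs, 3 classes.
n_samples, n_inputs, n_features, n_classes = 200, 20, 50, 3
X = rng.normal(size=(n_samples, n_inputs))
y = np.argmax(X @ rng.normal(size=(n_inputs, n_classes)), axis=1)

# Hypothetical stand-in for the frozen pre-trained layers: fixed weights, never updated.
W_frozen = rng.normal(size=(n_inputs, n_features)) / np.sqrt(n_inputs)
W_frozen_before = W_frozen.copy()

def extract_features(x, W):
    """Frozen 'lower layers': a linear map followed by ReLU."""
    return np.maximum(x @ W, 0.0)

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def loss_and_grads(W, b, feats, labels):
    """Mean softmax cross-entropy of the new head and its gradients w.r.t. W and b."""
    n = len(labels)
    probs = softmax(feats @ W + b)
    loss = -np.mean(np.log(probs[np.arange(n), labels]))
    dlogits = probs.copy()
    dlogits[np.arange(n), labels] -= 1.0
    dlogits /= n
    return loss, feats.T @ dlogits, dlogits.sum(axis=0)

# New trainable head on top of the frozen features.
W_head = np.zeros((n_features, n_classes))
b_head = np.zeros(n_classes)

feats = extract_features(X, W_frozen)  # computed once, since the extractor never changes
learning_rate = 0.1
initial_loss, _, _ = loss_and_grads(W_head, b_head, feats, y)
for step in range(300):
    _, dW, db = loss_and_grads(W_head, b_head, feats, y)
    W_head -= learning_rate * dW  # only the head's parameters are updated
    b_head -= learning_rate * db
final_loss, _, _ = loss_and_grads(W_head, b_head, feats, y)
print(f"head loss: {initial_loss:.3f} -> {final_loss:.3f}")
```

In this toy the frozen features are computed once and reused; the TensorFlow code in this tutorial instead runs the frozen module again on every batch, but still only updates the new layers.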
Specifically, we will:
  1. Download the MobileNetV2 feature-vector module from TensorFlow Hub.
  2. Use it as a feature extractor for our images in a DNN model.
  3. Train the model on digits 5-9 of the MNIST dataset.
  4. Compare its training speed and accuracy with our previous models.
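Step 3 trains only on digits 5-9, so every MNIST batch has to be filtered by label. The tutorial's subset() function does this with Python lists; a vectorized alternative, assuming the images and labels arrive as numpy arrays (as they do from input_data.read_data_sets), is a boolean mask. The helper name subset_by_label is hypothetical.

```python
import numpy as np

def subset_by_label(images, labels, lo, hi):
    """Hypothetical helper: keep only the samples whose label lies in [lo, hi] (inclusive)."""
    images = np.asarray(images)
    labels = np.asarray(labels)
    mask = (labels >= lo) & (labels <= hi)  # one boolean per sample
    return images[mask], labels[mask]

# Tiny illustrative batch: 6 flattened "images" of 4 pixels each.
images = np.arange(24, dtype=np.float32).reshape(6, 4)
labels = np.array([0, 5, 9, 3, 7, 4])
x_5_9, y_5_9 = subset_by_label(images, labels, 5, 9)
print(y_5_9)        # [5 9 7]
print(x_5_9.shape)  # (3, 4)
```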
Since MNIST images are 28x28x1 and this MobileNetV2 module expects 96x96x3 inputs, we'll pad them to 96x96x3.
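The padding can also be done for a whole batch at once with numpy's np.pad instead of a Python loop. This sketch assumes flattened 784-pixel images, as read_data_sets returns them, and reproduces the tutorial's layout: the digit sits in the top-left corner of channel 0, and everything else is zero. Centering the digit or copying it into all three channels would be other reasonable choices; pad_to_96x96x3 is a hypothetical helper name.

```python
import numpy as np

def pad_to_96x96x3(flat_images):
    """Hypothetical helper: place each 28x28 image in the top-left corner of a 96x96x3 zero canvas.

    Same layout as add_padding_and_reshape() in the training code:
    only channel 0 holds the digit, channels 1 and 2 stay zero.
    """
    images = np.asarray(flat_images, dtype=np.float32).reshape(-1, 28, 28, 1)
    # (before, after) padding for each axis: batch, height, width, channels
    pad_width = ((0, 0), (0, 96 - 28), (0, 96 - 28), (0, 3 - 1))
    return np.pad(images, pad_width, mode="constant")

rng = np.random.default_rng(0)
batch = rng.random((5, 784), dtype=np.float32)  # 5 flattened MNIST-like images
padded = pad_to_96x96x3(batch)
print(padded.shape)  # (5, 96, 96, 3)
```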

Installing tensorflow_hub with pip

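# The code below uses TF 1.x APIs (tf.placeholder, tf.contrib, hub.Module),
# so install a 1.x release rather than tf-nightly, which is now a 2.x build.
# TF 1.x wheels are only available for Python 3.7 and older.
!pip install absl-py
!pip install 'tensorflow_hub==0.4.0'
!pip install 'tensorflow>=1.13,<2'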
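Because the training code relies on TF 1.x graph-mode APIs, it can help to check which versions pip actually installed before running it. A small sketch using only the Python standard library (importlib.metadata, Python 3.8+); installed_major_version is a hypothetical helper, and the distribution names checked are assumptions about your environment.

```python
from importlib import metadata

def installed_major_version(dist_name):
    """Hypothetical helper: major version of an installed pip distribution, or None if absent."""
    try:
        version = metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None
    return int(version.split(".")[0])

# A major version of 1 for tensorflow is what the TF 1.x code in this tutorial needs.
for dist in ("tensorflow", "tf_nightly", "tensorflow_hub"):
    print(dist, installed_major_version(dist))
```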

Training the model

# Imports
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
from matplotlib import pyplot as plt
%matplotlib inline

from tensorflow.contrib.layers import fully_connected
from tensorflow.examples.tutorials.mnist import input_data

# Labels are kept as integers (one_hot=False is the default),
# as required by sparse_softmax_cross_entropy_with_logits
mnist = input_data.read_data_sets('/tmp/data')

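n_outputs = 10  # one logit per digit class 0-9 (only labels 5-9 are used in training)

# Load the MobileNetV2 feature-vector module (width multiplier 1.0, 96x96 input) from TF-Hub.
# trainable=False (the default) keeps its pre-trained weights frozen.
module = hub.Module('https://tfhub.dev/google/imagenet/mobilenet_v2_100_96/feature_vector/2')

# Inputs: padded images with pixel values in [0, 1], and integer class labels
tl_X = tf.placeholder(tf.float32, shape=(None, 96, 96, 3), name='tl_X')
tl_y = tf.placeholder(tf.int64, shape=(None,), name='tl_y')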

with tf.name_scope('transfer_learning_dnn'):
    mobilenet_feature_vector = module(tl_X)  # (batch, 1280) image features from the frozen module
    # New trainable layers on top: a ReLU hidden layer and a linear output layer
    tl_hidden1 = fully_connected(inputs=mobilenet_feature_vector, num_outputs=100, scope='tl_hidden1')
    tl_logits = fully_connected(inputs=tl_hidden1, num_outputs=n_outputs, scope='tl_logits', activation_fn=None)
with tf.name_scope('transfer_learning_loss'):
    tl_xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tl_y, logits=tl_logits)
    tl_loss = tf.reduce_mean(tl_xentropy, name='tl_loss')
learning_rate = 0.01
with tf.name_scope('transfer_learning_train'):
    tl_optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
    tl_training_op = tl_optimizer.minimize(tl_loss)
with tf.name_scope('transfer_learning_eval'):
    tl_correct_op = tf.nn.in_top_k(predictions=tl_logits, targets=tl_y, k=1)
    tl_accuracy_op = tf.reduce_mean(tf.cast(tl_correct_op, tf.float32))

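def add_padding_and_reshape(x):
    # Reshape each flat 784-pixel image to (28, 28, 1) and place it in the
    # top-left corner (channel 0) of a zero-filled (96, 96, 3) array
    result = list()
    for original in x:
        a = original.reshape(28, 28, 1)
        n = np.zeros([96, 96, 3], dtype=np.float32)
        n[:28, :28, :1] = a
        result.append(n)
    return np.array(result)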

def subset(X_batch, y_batch, lo, hi):
    # Keep only the (image, label) pairs whose label lies in [lo, hi]
    zipped = list(zip(X_batch, y_batch))
    zipped_lo_to_hi = [x for x in zipped if lo <= x[1] <= hi]
    x_uz, y_uz = zip(*zipped_lo_to_hi)
    return list(x_uz), list(y_uz)

X_test_acc_5_9, y_test_acc_5_9 = subset(mnist.test.images, mnist.test.labels, 5, 9)
tl_X_test_acc_5_9 = add_padding_and_reshape(X_test_acc_5_9)

n_epochs = 5
batch_size = 1000
n_iterations = mnist.train.num_examples // batch_size

# Transfer learning with MobileNetV2: train only the new layers
tl_train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='tl_hidden1|tl_logits')
tl_new_training_op = tl_optimizer.minimize(tl_loss, var_list=tl_train_vars)
tl_new_test_acc = list()
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)  # also loads the pre-trained MobileNetV2 weights into the module's variables
    for epoch in range(n_epochs):
        for iteration in range(n_iterations):
            tl_X_batch, tl_y_batch = mnist.train.next_batch(batch_size)
            tl_X_batch_5_9, tl_y_batch_5_9 = subset(tl_X_batch, tl_y_batch, 5, 9)
            tl_y_batch_5_9 = np.array(tl_y_batch_5_9)
            tl_padded_X_batch_5_9 = np.array(add_padding_and_reshape(tl_X_batch_5_9))
            sess.run(tl_new_training_op, feed_dict={tl_X: tl_padded_X_batch_5_9, tl_y: tl_y_batch_5_9})
        # Evaluate once per epoch (running the test set after every iteration is slow)
        tl_acc_test = sess.run(tl_accuracy_op, feed_dict={tl_X: tl_X_test_acc_5_9, tl_y: y_test_acc_5_9})
        tl_new_test_acc.append(tl_acc_test)
        print('epoch:', epoch, 'test acc:', tl_acc_test)
# Plot test accuracy per epoch
plt.xlabel('epoch'); plt.ylabel('accuracy')
plt.plot(tl_new_test_acc, label='mobilenet-transfer-learned')
plt.legend(loc='lower right')
plt.title('Test accuracy (digits 5-9) - MobileNetV2 transfer-learned model')
plt.show()
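The loss and evaluation ops in the graph compute fairly simple quantities. As a hedged numpy sketch of what they do for a batch of logits: tf.nn.sparse_softmax_cross_entropy_with_logits gives the per-example cross-entropy for integer labels, and tf.nn.in_top_k with k=1 checks whether the true class has the highest logit (tie handling may differ from argmax). The function names below are hypothetical, and the logits are made-up example values.

```python
import numpy as np

def sparse_softmax_cross_entropy(labels, logits):
    """Hypothetical numpy counterpart: per-example cross-entropy for integer labels."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels]

def top1_accuracy(labels, logits):
    """Hypothetical counterpart of in_top_k(k=1) averaged over the batch."""
    return np.mean(np.argmax(logits, axis=1) == labels)

logits = np.array([[2.0, 0.5, 0.1],
                   [0.2, 0.3, 3.0],
                   [1.0, 1.5, 0.2]])
labels = np.array([0, 2, 0])
print(sparse_softmax_cross_entropy(labels, logits).mean())  # the batch loss, like tl_loss
print(top1_accuracy(labels, logits))  # 2 of the 3 examples are classified correctly
```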
Here is the accuracy plot:

That looks pretty good: it reaches ~87% accuracy. Let's see how it compares with the other models we trained with and without transfer learning:


  • The MobileNetV2 transfer-learned model performs better than the Part I model that was trained on digits 5-9 by reusing the lower layers of a model trained on digits 0-4.
  • It still doesn't match the accuracy of the original model trained from scratch without any transfer learning.


  • For sophisticated tasks, if a sufficiently general model already exists (for example, MobileNetV2 for image recognition), it's better to take advantage of it with transfer learning.
  • Even if you intend to train a model from scratch, which may take a long time, it's a good idea to first train a quick model with transfer learning to establish a performance baseline. In fact, as we make progress in training AI to recognize the world, transfer learning is already playing a fundamental role.
  • For smaller models that train within an hour, transfer learning may not be that useful.