Learn Tensorflow by Examples Series - Part 7

Transfer Learning with Tensorflow - I


Have you ever wondered why humans are able to learn new things relatively quickly and without many examples, while neural networks need a lot of time (and examples) to learn even basic image recognition?
The answer lies in a concept called transfer learning. An untrained neural network is like a newborn baby: it doesn't come pre-trained to recognize the world around it. Once a baby learns how to recognize and perceive visuals, it applies that learning to understand more complex sequences of events.
The way neural networks are generally trained, we make them learn to make sense of numbers they have never seen before (e.g. pixels in images) from scratch, every single time. Transfer learning is a technique with which we can reuse what a neural network has already learned and apply that knowledge to new, similar or more complicated tasks.
It's a terrible thing to see and have no vision. - Helen Keller
In this tutorial we will,
  1. Create and train a deep neural net: Call it model1, with 4 hidden layers, and train it to recognize digits 0-4 in the MNIST dataset.
  2. Transfer learn a new model: Create a new neural network, say model2, copy (transfer) the first two hidden layers of model1 to model2, and "freeze" model2's first layer.
  3. Train the new model on a different type of data: Train model2 (only the unfrozen layers) on digits 5-9, and compare how quickly it learns relative to the original model1.

1. Create and train a deep neural net

Create a deep neural net (with 4 hidden layers), say model1, and train it to recognize digits 0-4 in the MNIST dataset.
import tensorflow as tf
import numpy as np
from tensorflow.contrib.layers import fully_connected
tf.reset_default_graph()

n_inputs = 28*28
n_outputs = 10

# Define inputs
X = tf.placeholder(tf.float32, shape=(None, n_inputs), name='X')
y = tf.placeholder(tf.int64, shape=(None), name='y')

# Construct the neural network (fully_connected applies a ReLU activation by default)
with tf.name_scope('dnn'):
    hidden1 = fully_connected(inputs=X, num_outputs=300, scope='hidden1')
    hidden2 = fully_connected(inputs=hidden1, num_outputs=100, scope='hidden2')
    hidden3 = fully_connected(inputs=hidden2, num_outputs=100, scope='hidden3')
    hidden4 = fully_connected(inputs=hidden3, num_outputs=100, scope='hidden4')
    logits = fully_connected(inputs=hidden4, num_outputs=n_outputs, scope='outputs', activation_fn=None)
    
with tf.name_scope('loss'):
    xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits)
    loss = tf.reduce_mean(xentropy, name='loss')

learning_rate = 0.01
with tf.name_scope('train'):
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
    training_op = optimizer.minimize(loss)

with tf.name_scope('eval'):
    correct_op = tf.nn.in_top_k(predictions=logits, targets=y, k=1)
    accuracy_op = tf.reduce_mean(tf.cast(correct_op, tf.float32))

# Get the MNIST data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('/tmp/data')

init = tf.global_variables_initializer()
model1_saver = tf.train.Saver()

n_epochs = 20
batch_size = 1000
n_iterations = mnist.train.num_examples // batch_size
model1_save_path = './dnn_model.ckpt'

# We'll store the test dataset accuracy as we train the model.
model1_test_acc = list()

# Utility method to filter a batch, keeping only the digits between low and high (inclusive)
def subset(X_batch, y_batch, low, high):
    zipped = list(zip(X_batch, y_batch))
    zipped_batch_low_to_high = [x for x in zipped if low <= x[1] <= high]
    x_uz, y_uz = zip(*zipped_batch_low_to_high)
    return list(x_uz), list(y_uz)

# Train the network
with tf.Session() as sess:
    init.run()
    
    X_test_acc, y_test_acc = subset(mnist.test.images, mnist.test.labels, 0, 4)
    for epoch in range(n_epochs):
        for iteration in range(n_iterations):
            X_batch, y_batch = mnist.train.next_batch(batch_size)
            X_batch_0_4, y_batch_0_4 = subset(X_batch, y_batch, 0, 4)
            
            sess.run(training_op, feed_dict={X: X_batch_0_4, y: y_batch_0_4})
        
        acc_test = sess.run(accuracy_op, feed_dict={X: X_test_acc, y: y_test_acc})
        model1_test_acc.append(acc_test)
            
        print('epoch:', epoch)
        print('test acc:', acc_test)
        
        model1_saver.save(sess, save_path=model1_save_path)
At this point, we have model1 trained to recognize digits 0-4 in the MNIST dataset. We achieved approximately 97.2% accuracy on the test set.
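Since we stored the per-epoch test accuracy in model1_test_acc, you can optionally visualize model1's learning curve right away. This is a minimal sketch (it assumes matplotlib is installed; the same plotting setup is used again at the end of this tutorial):
from matplotlib import pyplot as plt
%matplotlib inline

plt.plot(model1_test_acc)
plt.xlabel('epoch'); plt.ylabel('test accuracy')
plt.title('model1 test accuracy (digits 0-4)')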

2. Transfer learn a new model

Create a new neural network, say model2 (identical structure), and copy (transfer) the first two hidden layers of model1 to model2.
We'll
  1. Extract and save the 1st and 2nd hidden layers of model1.
  2. Restore the saved 1st and 2nd hidden layers in a new session.
  3. Confirm that only hidden1 and hidden2 weights/biases are transferred.

# Transfer learning - copy hidden1 and hidden2 from model1 into a fresh model (model2)
model2_save_path = './new_dnn_model.ckpt'
model2_saver = tf.train.Saver()

X_train_acc_5_9, y_train_acc_5_9 = subset(mnist.train.images, mnist.train.labels, 5, 9)
X_test_acc_5_9, y_test_acc_5_9 = subset(mnist.test.images, mnist.test.labels, 5, 9)

# A Saver that handles only the hidden1 and hidden2 variables
reuse_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='hidden[12]')
reuse_vars_restorer = tf.train.Saver(reuse_vars)

with tf.Session() as sess:
    sess.run(init) # Initialize a fresh model; hidden3, hidden4 and outputs keep these random weights
    reuse_vars_restorer.restore(sess, save_path=model1_save_path) # Restore only the hidden1, hidden2 layers from model1
    model2_saver.save(sess, save_path=model2_save_path) # Save model2 (with the transferred layers)

# Restore model1 and model2 weights/biases so we can compare them
with tf.Session() as sess:
    sess.run(init)
    
    # Load model1 weights/biases
    model1_saver.restore(sess, save_path=model1_save_path)
    model1_vars = sess.run(tf.global_variables())
    
    # Load model2 weights/biases
    model2_saver.restore(sess, save_path=model2_save_path)
    model2_vars = sess.run(tf.global_variables())
    
# Confirm that only the hidden1 and hidden2 weights/biases are transferred
assert len(model1_vars) == len(model2_vars)
assert (model1_vars[0] == model2_vars[0]).all()     # hidden1 weights - transferred
assert (model1_vars[1] == model2_vars[1]).all()     # hidden1 biases - transferred
assert (model1_vars[2] == model2_vars[2]).all()     # hidden2 weights - transferred
assert (model1_vars[3] == model2_vars[3]).all()     # hidden2 biases - transferred
assert not (model1_vars[4] == model2_vars[4]).all() # hidden3 weights - not transferred
assert not (model1_vars[5] == model2_vars[5]).all() # hidden3 biases - not transferred
assert not (model1_vars[6] == model2_vars[6]).all() # hidden4 weights - not transferred
assert not (model1_vars[7] == model2_vars[7]).all() # hidden4 biases - not transferred
assert not (model1_vars[8] == model2_vars[8]).all() # output layer weights - not transferred
assert not (model1_vars[9] == model2_vars[9]).all() # output layer biases - not transferred
At this point, we've created model2: its hidden1 and hidden2 layers are copied from model1, while the remaining layers are freshly initialized, and we've saved it to a new location. The assertions confirm that only the first two hidden layers were transferred.
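If you want to double-check what actually got written to a checkpoint file, you can also list its variables with tf.train.NewCheckpointReader. This is just an optional sanity check; it isn't required for the rest of the tutorial.
# Optional: list the variables stored in the model2 checkpoint
reader = tf.train.NewCheckpointReader(model2_save_path)
for name, shape in sorted(reader.get_variable_to_shape_map().items()):
    print(name, shape)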

3. Train the new model on a different type of data

Train model2 (only the unfrozen layers) on digits 5-9, and compare how quickly it learns relative to a model trained from scratch.
We will,
  1. Train a model on MNIST data (digits 5-9) WITHOUT transfer learning.
  2. Train a model on MNIST data (digits 5-9) WITH transfer learning.
  3. Compare how quickly the two models train (with respect to the number of iterations), and how their accuracies compare.
n_epochs = 5

# Train on digits 5-9 without transfer learning
test_acc_without_transfer_learning = list()
 
with tf.Session() as sess:
    init.run()
     
    for epoch in range(n_epochs):
        for iteration in range(n_iterations):
            X_batch, y_batch = mnist.train.next_batch(batch_size)
            X_batch_5_9, y_batch_5_9 = subset(X_batch, y_batch, 5, 9)
            
            sess.run(training_op, feed_dict={X: X_batch_5_9, y: y_batch_5_9})
            
            # Record test accuracy after every iteration so we can compare learning curves later
            acc_test = sess.run(accuracy_op, feed_dict={X: X_test_acc_5_9, y: y_test_acc_5_9})
            test_acc_without_transfer_learning.append(acc_test)
         
        print('epoch:', epoch)
        print('test acc:', acc_test)

# Train the new model while keeping the hidden1 layer frozen
test_acc_with_transfer_learning = list()

with tf.Session() as sess:
    init.run()
    model2_saver.restore(sess, save_path=model2_save_path)
    
    # Only hidden2, hidden3, hidden4 and the output layer are trainable; hidden1 stays frozen.
    # (GradientDescentOptimizer creates no extra variables, so it is safe to build this op after init/restore.)
    train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='hidden[234]|outputs')
    new_training_op = optimizer.minimize(loss, var_list=train_vars)
    
    for epoch in range(n_epochs):
        for iteration in range(n_iterations):
            X_batch, y_batch = mnist.train.next_batch(batch_size)
            X_batch_5_9, y_batch_5_9 = subset(X_batch, y_batch, 5, 9)
            sess.run(new_training_op, feed_dict={X: X_batch_5_9, y: y_batch_5_9})
            
            # Record test accuracy after every iteration, as before
            acc_test = sess.run(accuracy_op, feed_dict={X: X_test_acc_5_9, y: y_test_acc_5_9})
            test_acc_with_transfer_learning.append(acc_test)
        
        print('epoch:', epoch)
        print('test acc:', acc_test)
     
Now let's plot the accuracy vs. iteration graph:
from matplotlib import pyplot as plt
%matplotlib inline

plt.plot(test_acc_without_transfer_learning, label='WITHOUT transfer learning')
plt.plot(test_acc_with_transfer_learning, label='WITH transfer learning')
plt.xlabel('iteration'); plt.ylabel('accuracy')
plt.legend(loc='lower right')
plt.title('Test accuracy - without vs with transfer learning')

Observations:
  • The model trained WITH transfer learning starts to learn the new data faster: it reaches ~60% accuracy in about 50 iterations, while the model WITHOUT transfer learning needs roughly 100 iterations.
  • The model trained WITH transfer learning doesn't quite reach the same final accuracy as the one trained WITHOUT transfer learning, because of its reduced degrees of freedom, i.e. it cannot adjust the frozen first layer as it needs to.
  • Because the model trained WITH transfer learning reaches high accuracy in fewer iterations, it can pick up reasonable patterns even with less training data.

Conclusion

  • Transfer learning speeds up the training process.
  • Transfer learning requires less training data.
  • A model trained with transfer learning may not reach quite the same accuracy as a model trained from scratch.

When to use transfer learning?

  • When the network is small and easy to train, it may not be worthwhile to invest in a transfer-learning-based model.
  • When your data contains a wide variety of patterns (e.g. images) and you already have a sophisticated model trained on a different type of data, transfer learning provides great benefits in terms of training speed and the amount of training data you need.

Why only transfer lower layers?

One thing you may have noticed is that we only transfer the lower layers of a neural network. The reason is that the lower layers are believed to capture low-level patterns in the data, for example edges and curves in an image. The higher layers consume the features produced by the lower layers, so their pattern recognition is built on top of those lower-level features.
If a network is trained on general enough data, for example images, the information in its lower layers can be applied to almost any image recognition application. That means that if we have a sophisticated pre-trained model, we can reuse its feature-extraction capabilities for our own image recognition task. Some examples of such models are MobileNet, Inception and ResNet.
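To make this concrete, here is a rough sketch of what reusing a pre-trained feature extractor could look like with the tensorflow_hub library. Treat it as an illustration only: the module URL, input pre-processing and new output layer here are assumptions, and the next tutorial walks through this properly.
import tensorflow as tf
import tensorflow_hub as hub  # assumed to be installed separately

# Hypothetical example: reuse MobileNetV2's lower layers as a frozen feature extractor
module = hub.Module('https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/feature_vector/2')

images = tf.placeholder(tf.float32, shape=(None, 224, 224, 3))  # pixel values scaled to [0, 1]
features = module(images)               # pre-trained features; not trained further by default
logits = tf.layers.dense(features, 10)  # only this new output layer would be trained on our data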

What's Next?

In the next tutorial, we'll use the lower layers of a standard sophisticated model (MobileNetV2) from tfhub.dev to train a model similar to the one above, and see how it compares with the models above (in terms of accuracy) on our digit recognition task.